mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-12 07:09:43 +00:00
Compare commits
88 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f25ac02e6c | |||
| 23cb668cac | |||
| 2ea3043b1b | |||
| 0f61e2415f | |||
| 76a425c600 | |||
| df71f3ce24 | |||
| 326aca0a48 | |||
| be46bdea8f | |||
| 306429a85b | |||
| 12f2f35760 | |||
| a024d33750 | |||
| 63cd2111ad | |||
| abe9e79825 | |||
| 503fc4e9f4 | |||
| 92b479f9ac | |||
| b954337ac7 | |||
| 5f6f476f32 | |||
| 502fdc0630 | |||
| 9db6213895 | |||
| aa1d906802 | |||
| eff8a6fd12 | |||
| c54cd529a2 | |||
| a5ca206c49 | |||
| 88100943ef | |||
| a95b15ccc0 | |||
| a97d078d95 | |||
| 98662e5f24 | |||
| 4d8f242af9 | |||
| 1ff8986c77 | |||
| f0aeded142 | |||
| da5d2f3e91 | |||
| d6ea3bbce0 | |||
| 7aedbbf81a | |||
| 1ee8d824f5 | |||
| f7c4f99545 | |||
| 92b6254473 | |||
| 79137f58d1 | |||
| da9c2e66f4 | |||
| 45730cc71e | |||
| 5d4af4b0b1 | |||
| 0050d7c61c | |||
| cf2897f545 | |||
| 2c18210d02 | |||
| 44bf283701 | |||
| a51682b266 | |||
| ed49c9935a | |||
| 52455d03a7 | |||
| 4afb253825 | |||
| 96c664e09f | |||
| 8bd0aec618 | |||
| e82e7a02e9 | |||
| 845b359d39 | |||
| a6ff3cfebb | |||
| 271d92dcaa | |||
| 8e940bf361 | |||
| 6e8be57eb2 | |||
| 723013c71b | |||
| bf6ac5e110 | |||
| 3ce5bcf24d | |||
| 6f5bb4d4a4 | |||
| f29311ccb0 | |||
| 0c79cf8f4e | |||
| f2ff370459 | |||
| 25f60c301b | |||
| 0699b46d87 | |||
| b8f7e401d4 | |||
| 656fc0f059 | |||
| 829d2d1ad9 | |||
| 4ccf28437a | |||
| 9a49e57c72 | |||
| 6c28ef894a | |||
| bf3c8746b7 | |||
| 9f32e00f90 | |||
| fcaa0ea5f9 | |||
| 5ac9356135 | |||
| b74e2a6113 | |||
| a4bed41132 | |||
| 5c8dd883be | |||
| 38f6fc816b | |||
| abde7be3b3 | |||
| b6c528a438 | |||
| 6d331310ab | |||
| 5dfdec9288 | |||
| 50977a2c28 | |||
| a0d7627d81 | |||
| 1ad2da403d | |||
| 2d3a605b3c | |||
| f173265354 |
@@ -78,7 +78,7 @@ jobs:
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
|
||||
- name: Install lerobot with all extras
|
||||
run: uv sync --all-extras
|
||||
run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional
|
||||
|
||||
- name: Run pytest (all extras)
|
||||
run: uv run pytest tests -vv --maxfail=10
|
||||
|
||||
@@ -119,6 +119,7 @@ jobs:
|
||||
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
|
||||
container:
|
||||
image: ${{ needs.build-docker-cpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
|
||||
options: --shm-size "16gb"
|
||||
credentials:
|
||||
username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
|
||||
@@ -158,3 +159,36 @@ jobs:
|
||||
run: pytest tests -vv --maxfail=10
|
||||
- name: Run end-to-end tests
|
||||
run: make test-end-to-end
|
||||
|
||||
# This job runs multi-GPU training tests with 4 GPUs
|
||||
nightly-multi-gpu-tests:
|
||||
name: Nightly Multi-GPU Tests
|
||||
needs: [build-docker-gpu-nightly]
|
||||
runs-on:
|
||||
group: aws-g4dn-12xlarge # Instance with 4 GPUs
|
||||
env:
|
||||
HF_HOME: /home/user_lerobot/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
|
||||
TORCH_HOME: /home/user_lerobot/.cache/torch
|
||||
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
|
||||
CUDA_VISIBLE_DEVICES: "0,1,2,3"
|
||||
container:
|
||||
image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
|
||||
options: --gpus all --shm-size "16gb"
|
||||
credentials:
|
||||
username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Verify GPU availability
|
||||
run: |
|
||||
nvidia-smi
|
||||
python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
|
||||
|
||||
- name: Run multi-GPU training tests
|
||||
# TODO(Steven): Investigate why motors tests are failing in multi-GPU setup
|
||||
run: pytest tests -vv --maxfail=10 --ignore=tests/motors/
|
||||
timeout-minutes: 10
|
||||
|
||||
@@ -82,6 +82,14 @@ jobs:
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Remove Tags with Git dependencies
|
||||
# TODO(Steven): Temporary patch to remove libero and pi from PyPi 0.4.0 release due to its reliance on git dependencies.
|
||||
run: |
|
||||
echo "::info:: Checking for Git dependencies to remove from pyproject.toml..."
|
||||
grep -E '@ git\+https|lerobot\[pi\]|lerobot\[libero\]' pyproject.toml | sed 's/^/::warning:: Removing line: /' || true
|
||||
sed -E -i '/@ git\+https|lerobot\[pi\]|lerobot\[libero\]/d' pyproject.toml
|
||||
echo "::info:: Git dependencies removed. Proceeding with build."
|
||||
|
||||
- name: Install build dependencies
|
||||
run: python -m pip install build
|
||||
|
||||
@@ -103,7 +111,7 @@ jobs:
|
||||
- name: Publish to TestPyPI for pre-releases
|
||||
# True for tags like 'v0.2.0-rc1'
|
||||
if: startsWith(github.ref, 'refs/tags/v') && contains(github.ref, '-')
|
||||
uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
|
||||
uses: pypa/gh-action-pypi-publish@v1.13.0 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
|
||||
with:
|
||||
repository-url: https://test.pypi.org/legacy/
|
||||
verbose: true
|
||||
@@ -111,7 +119,7 @@ jobs:
|
||||
|
||||
- name: Publish to PyPI
|
||||
if: startsWith(github.ref, 'refs/tags/v') && !contains(github.ref, '-')
|
||||
uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
|
||||
uses: pypa/gh-action-pypi-publish@v1.13.0 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
|
||||
with:
|
||||
verbose: true
|
||||
print-hash: true
|
||||
@@ -138,7 +146,7 @@ jobs:
|
||||
- name: Setup uv and Python
|
||||
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
enable-cache: true
|
||||
enable-cache: true # zizmor: ignore[cache-poisoning]
|
||||
version: ${{ env.UV_VERSION }}
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
- name: Create uv virtual environment
|
||||
|
||||
@@ -27,15 +27,17 @@ env:
|
||||
This issue was closed because it has been stalled for 14 days with no activity.
|
||||
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
||||
CLOSE_PR_MESSAGE: >
|
||||
This PR was closed because it has been stalled for 14 days with no activity.
|
||||
This PR was closed because it has been stalled for 21 days with no activity.
|
||||
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
||||
WARN_ISSUE_MESSAGE: >
|
||||
This issue has been automatically marked as stale because it has not had
|
||||
recent activity (6 months). It will be closed if no further activity occurs.
|
||||
Any change, comment or update to this issue will reset this count.
|
||||
Thank you for your contributions.
|
||||
WARN_PR_MESSAGE: >
|
||||
This PR has been automatically marked as stale because it has not had
|
||||
recent activity (6 months). It will be closed if no further activity occurs.
|
||||
recent activity (1 year). It will be closed if no further activity occurs.
|
||||
Any change, comment or update to this PR will reset this count.
|
||||
Thank you for your contributions.
|
||||
|
||||
jobs:
|
||||
@@ -56,10 +58,10 @@ jobs:
|
||||
stale-pr-label: stale
|
||||
exempt-issue-labels: never-stale
|
||||
exempt-pr-labels: never-stale
|
||||
days-before-issue-stale: 180 # TODO(Steven): Will modify this to 90 after initial cleanup
|
||||
days-before-issue-stale: 180
|
||||
days-before-issue-close: 14
|
||||
days-before-pr-stale: 180
|
||||
days-before-pr-close: 14
|
||||
days-before-pr-stale: 365
|
||||
days-before-pr-close: 21
|
||||
delete-branch: true
|
||||
close-issue-message: ${{ env.CLOSE_ISSUE_MESSAGE }}
|
||||
close-pr-message: ${{ env.CLOSE_PR_MESSAGE }}
|
||||
|
||||
@@ -0,0 +1,183 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This workflow handles full testing with unboud dependencies versions.
|
||||
name: Unbound Dependency Tests
|
||||
|
||||
on:
|
||||
# Allows running this workflow manually from the Actions tab
|
||||
workflow_dispatch:
|
||||
|
||||
# Run on the 1st and 15th of every month at 09:00 UTC
|
||||
schedule:
|
||||
- cron: '0 2 1,15 * *'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
# Sets up the environment variables
|
||||
env:
|
||||
UV_VERSION: "0.8.0"
|
||||
PYTHON_VERSION: "3.10"
|
||||
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu:unbound
|
||||
|
||||
# Ensures that only the latest action is built, canceling older runs.
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
|
||||
# This job runs the E2E tests + pytest with all unbound extras
|
||||
full-tests:
|
||||
name: Full Unbound Tests
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install apt dependencies
|
||||
run: |
|
||||
sudo apt-get update && sudo apt-get install -y build-essential \
|
||||
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
|
||||
speech-dispatcher libgeos-dev portaudio19-dev
|
||||
|
||||
- name: Setup uv and Python
|
||||
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
enable-cache: true
|
||||
version: ${{ env.UV_VERSION }}
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
|
||||
- name: Unbound dependencies
|
||||
run: |
|
||||
sed -i 's/,[[:space:]]*<[0-9\.]*//g' pyproject.toml
|
||||
echo "Dependencies unbound:" && cat pyproject.toml
|
||||
|
||||
- name: Install lerobot with all extras
|
||||
run: uv sync --all-extras
|
||||
|
||||
- name: Run pytest (all extras)
|
||||
run: uv run pytest tests -vv
|
||||
|
||||
- name: Run end-to-end tests
|
||||
run: uv run make test-end-to-end
|
||||
|
||||
# This job builds a GPU enabled image for testing
|
||||
build-and-push-docker:
|
||||
name: Build and Push Docker
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
outputs:
|
||||
image_tag: ${{ env.DOCKER_IMAGE_NAME }}
|
||||
env:
|
||||
GITHUB_REF: ${{ github.ref }}
|
||||
steps:
|
||||
- name: Install Git LFS
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
cache-binary: false
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v6 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/Dockerfile.internal
|
||||
push: true
|
||||
tags: ${{ env.DOCKER_IMAGE_NAME }}
|
||||
build-args: |
|
||||
UNBOUND_DEPS=true
|
||||
|
||||
# This job runs pytest with all unbound extras in a GPU enabled host
|
||||
# It runs everytime a test image is created
|
||||
gpu-tests:
|
||||
name: GPU Unbound Tests
|
||||
needs: [build-and-push-docker]
|
||||
runs-on:
|
||||
group: aws-g6-4xlarge-plus
|
||||
env:
|
||||
HF_HOME: /home/user_lerobot/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
|
||||
TORCH_HOME: /home/user_lerobot/.cache/torch
|
||||
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
|
||||
container:
|
||||
image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images]
|
||||
options: --gpus all --shm-size "16gb"
|
||||
credentials:
|
||||
username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Run pytest on GPU
|
||||
run: pytest tests -vv
|
||||
- name: Run end-to-end tests
|
||||
run: make test-end-to-end
|
||||
|
||||
# This job deletes the test image recently created
|
||||
# It runs everytime after the gpu-tests have finished
|
||||
delete-unbound-image:
|
||||
name: Delete Unbound Image
|
||||
needs: [gpu-tests, build-and-push-docker]
|
||||
if: always() && needs.build-and-push-docker.result == 'success'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Get Docker Hub Token and Delete Image
|
||||
# zizmor: ignore[template-injection]
|
||||
run: |
|
||||
IMAGE_NAME=$(echo "${{ needs.build-and-push-docker.outputs.image_tag }}" | cut -d':' -f1)
|
||||
IMAGE_TAG=$(echo "${{ needs.build-and-push-docker.outputs.image_tag }}" | cut -d':' -f2)
|
||||
|
||||
echo "Attempting to delete image: $IMAGE_NAME:$IMAGE_TAG"
|
||||
|
||||
TOKEN=$(curl -s -H "Content-Type: application/json" \
|
||||
-X POST \
|
||||
-d '{"username": "${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}", "password": "${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}"}' \
|
||||
https://hub.docker.com/v2/users/login/ | jq -r .token)
|
||||
|
||||
if [ "$TOKEN" == "null" ] || [ -z "$TOKEN" ]; then
|
||||
echo "::error::Failed to get Docker Hub token."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
HTTP_RESPONSE=$(curl -s -o /dev/null -w "%{http_code}" \
|
||||
-H "Authorization: JWT ${TOKEN}" \
|
||||
-X DELETE \
|
||||
https://hub.docker.com/v2/repositories/${IMAGE_NAME}/tags/${IMAGE_TAG}/)
|
||||
|
||||
if [ "$HTTP_RESPONSE" -eq 204 ]; then
|
||||
echo "Successfully deleted Docker image tag: $IMAGE_NAME:$IMAGE_TAG"
|
||||
else
|
||||
echo "::error::Failed to delete Docker image. HTTP status: $HTTP_RESPONSE"
|
||||
exit 1
|
||||
fi
|
||||
+12
-11
@@ -26,7 +26,7 @@ repos:
|
||||
|
||||
##### General Code Quality & Formatting #####
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
rev: v6.0.0
|
||||
hooks:
|
||||
- id: check-added-large-files
|
||||
args: ['--maxkb=1024']
|
||||
@@ -39,20 +39,20 @@ repos:
|
||||
- id: trailing-whitespace
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.12.4
|
||||
rev: v0.14.1
|
||||
hooks:
|
||||
- id: ruff-format
|
||||
- id: ruff
|
||||
args: [--fix, --exit-non-zero-on-fix]
|
||||
|
||||
- repo: https://github.com/adhtruong/mirrors-typos
|
||||
rev: v1.34.0
|
||||
rev: v1.38.1
|
||||
hooks:
|
||||
- id: typos
|
||||
args: [--force-exclude]
|
||||
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.20.0
|
||||
rev: v3.21.0
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
args: [--py310-plus]
|
||||
@@ -68,12 +68,12 @@ repos:
|
||||
|
||||
##### Security #####
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
rev: v8.27.2
|
||||
rev: v8.28.0
|
||||
hooks:
|
||||
- id: gitleaks
|
||||
|
||||
- repo: https://github.com/woodruffw/zizmor-pre-commit
|
||||
rev: v1.11.0
|
||||
rev: v1.15.2
|
||||
hooks:
|
||||
- id: zizmor
|
||||
|
||||
@@ -86,11 +86,12 @@ repos:
|
||||
|
||||
# TODO(Steven): Uncomment when ready to use
|
||||
##### Static Analysis & Typing #####
|
||||
# - repo: https://github.com/pre-commit/mirrors-mypy
|
||||
# rev: v1.16.0
|
||||
# hooks:
|
||||
# - id: mypy
|
||||
# args: [--python-version=3.10]
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.18.2
|
||||
hooks:
|
||||
- id: mypy
|
||||
args: [--config-file=pyproject.toml]
|
||||
exclude: ^(examples|benchmarks|tests)/
|
||||
|
||||
##### Docstring Checks #####
|
||||
# - repo: https://github.com/akaihola/darglint2
|
||||
|
||||
+1
-2
@@ -72,7 +72,6 @@ post it.
|
||||
|
||||
Look at our implementations for [datasets](./src/lerobot/datasets/), [policies](./src/lerobot/policies/),
|
||||
environments ([aloha](https://github.com/huggingface/gym-aloha),
|
||||
[xarm](https://github.com/huggingface/gym-xarm),
|
||||
[pusht](https://github.com/huggingface/gym-pusht))
|
||||
and follow the same api design.
|
||||
|
||||
@@ -138,7 +137,7 @@ Follow these steps to start contributing:
|
||||
4. for development, we advise to use a tool like `poetry` or `uv` instead of just `pip` to easily track our dependencies.
|
||||
Follow the instructions to [install poetry](https://python-poetry.org/docs/#installation) (use a version >=2.1.0) or to [install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) if you don't have one of them already.
|
||||
|
||||
Set up a development environment with conda or miniconda:
|
||||
Set up a development environment with conda:
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot-dev python=3.10 && conda activate lerobot-dev
|
||||
|
||||
@@ -119,10 +119,9 @@ test-tdmpc-ete-train:
|
||||
--policy.type=tdmpc \
|
||||
--policy.device=$(DEVICE) \
|
||||
--policy.push_to_hub=false \
|
||||
--env.type=xarm \
|
||||
--env.task=XarmLift-v0 \
|
||||
--env.type=pusht \
|
||||
--env.episode_length=5 \
|
||||
--dataset.repo_id=lerobot/xarm_lift_medium \
|
||||
--dataset.repo_id=lerobot/pusht_image \
|
||||
--dataset.image_transforms.enable=true \
|
||||
--dataset.episodes="[0]" \
|
||||
--batch_size=2 \
|
||||
@@ -140,9 +139,10 @@ test-tdmpc-ete-eval:
|
||||
lerobot-eval \
|
||||
--policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
|
||||
--policy.device=$(DEVICE) \
|
||||
--env.type=xarm \
|
||||
--env.type=pusht \
|
||||
--env.episode_length=5 \
|
||||
--env.task=XarmLift-v0 \
|
||||
--env.observation_height=96 \
|
||||
--env.observation_width=96 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1
|
||||
|
||||
|
||||
@@ -104,14 +104,14 @@ LeRobot works with Python 3.10+ and PyTorch 2.2+.
|
||||
|
||||
### Environment Setup
|
||||
|
||||
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
|
||||
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniforge`](https://conda-forge.org/download/):
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
When using `miniconda`, install `ffmpeg` in your environment:
|
||||
When using `conda`, install `ffmpeg` in your environment:
|
||||
|
||||
```bash
|
||||
conda install ffmpeg -c conda-forge
|
||||
@@ -197,7 +197,7 @@ wandb login
|
||||
|
||||
### Visualize datasets
|
||||
|
||||
Check out [example 1](https://github.com/huggingface/lerobot/blob/main/examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically downloads data from the Hugging Face hub.
|
||||
Check out [example 1](https://github.com/huggingface/lerobot/blob/main/examples/dataset/load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically downloads data from the Hugging Face hub.
|
||||
|
||||
You can also locally visualize episodes from a dataset on the hub by executing our script from the command line:
|
||||
|
||||
@@ -207,13 +207,13 @@ lerobot-dataset-viz \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
or from a dataset in a local folder with the `root` option and the `--local-files-only` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
|
||||
or from a dataset in a local folder with the `root` option and the `--mode local` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
|
||||
|
||||
```bash
|
||||
lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--root ./my_local_data_dir \
|
||||
--local-files-only 1 \
|
||||
--mode local \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
@@ -310,7 +310,7 @@ To upload these to the hub, run the following:
|
||||
huggingface-cli upload ${hf_user}/${repo_name} path/to/pretrained_model
|
||||
```
|
||||
|
||||
See [eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/eval.py) for an example of how other people may use your policy.
|
||||
See [lerobot_eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_eval.py) for an example of how other people may use your policy.
|
||||
|
||||
### Acknowledgment
|
||||
|
||||
|
||||
@@ -75,6 +75,14 @@ RUN uv venv --python python${PYTHON_VERSION}
|
||||
# Install Python dependencies for caching
|
||||
COPY --chown=user_lerobot:user_lerobot pyproject.toml README.md MANIFEST.in ./
|
||||
COPY --chown=user_lerobot:user_lerobot src/ src/
|
||||
|
||||
ARG UNBOUND_DEPS=false
|
||||
|
||||
RUN if [ "$UNBOUND_DEPS" = "true" ]; then \
|
||||
sed -i 's/,[[:space:]]*<[0-9\.]*//g' pyproject.toml; \
|
||||
echo "Dependencies unbound:" && cat pyproject.toml; \
|
||||
fi
|
||||
|
||||
RUN uv pip install --no-cache ".[all]"
|
||||
|
||||
# Copy the rest of the application source code
|
||||
|
||||
@@ -61,6 +61,14 @@ RUN uv venv
|
||||
# Install Python dependencies for caching
|
||||
COPY --chown=user_lerobot:user_lerobot pyproject.toml README.md MANIFEST.in ./
|
||||
COPY --chown=user_lerobot:user_lerobot src/ src/
|
||||
|
||||
ARG UNBOUND_DEPS=false
|
||||
|
||||
RUN if [ "$UNBOUND_DEPS" = "true" ]; then \
|
||||
sed -i 's/,[[:space:]]*<[0-9\.]*//g' pyproject.toml; \
|
||||
echo "Dependencies unbound:" && cat pyproject.toml; \
|
||||
fi
|
||||
|
||||
RUN uv pip install --no-cache ".[all]"
|
||||
|
||||
# Copy the rest of the application code
|
||||
|
||||
@@ -7,8 +7,6 @@
|
||||
- sections:
|
||||
- local: il_robots
|
||||
title: Imitation Learning for Robots
|
||||
- local: il_sim
|
||||
title: Imitation Learning in Sim
|
||||
- local: cameras
|
||||
title: Cameras
|
||||
- local: integrate_hardware
|
||||
@@ -19,20 +17,37 @@
|
||||
title: Train RL in Simulation
|
||||
- local: async
|
||||
title: Use Async Inference
|
||||
- local: multi_gpu_training
|
||||
title: Multi GPU training
|
||||
title: "Tutorials"
|
||||
- sections:
|
||||
- local: lerobot-dataset-v3
|
||||
title: Using LeRobotDataset
|
||||
- local: porting_datasets_v3
|
||||
title: Porting Large Datasets
|
||||
- local: using_dataset_tools
|
||||
title: Using the Dataset Tools
|
||||
title: "Datasets"
|
||||
- sections:
|
||||
- local: act
|
||||
title: ACT
|
||||
- local: smolvla
|
||||
title: Finetune SmolVLA
|
||||
title: SmolVLA
|
||||
- local: pi0
|
||||
title: π₀ (Pi0)
|
||||
- local: pi05
|
||||
title: π₀.₅ (Pi05)
|
||||
- local: groot
|
||||
title: NVIDIA GR00T N1.5
|
||||
title: "Policies"
|
||||
- sections:
|
||||
- local: il_sim
|
||||
title: Imitation Learning in Sim
|
||||
- local: libero
|
||||
title: Using Libero
|
||||
title: "Policies"
|
||||
|
||||
- local: metaworld
|
||||
title: Using MetaWorld
|
||||
title: "Simulation"
|
||||
- sections:
|
||||
- local: introduction_processors
|
||||
title: Introduction to Robot Processors
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
# ACT (Action Chunking with Transformers)
|
||||
|
||||
ACT is a **lightweight and efficient policy for imitation learning**, especially well-suited for fine-grained manipulation tasks. It's the **first model we recommend when you're starting out** with LeRobot due to its fast training time, low computational requirements, and strong performance.
|
||||
|
||||
<div class="video-container">
|
||||
<iframe
|
||||
width="100%"
|
||||
height="415"
|
||||
src="https://www.youtube.com/embed/ft73x0LfGpM"
|
||||
title="LeRobot ACT Tutorial"
|
||||
frameborder="0"
|
||||
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
|
||||
allowfullscreen
|
||||
></iframe>
|
||||
</div>
|
||||
|
||||
_Watch this tutorial from the LeRobot team to learn how ACT works: [LeRobot ACT Tutorial](https://www.youtube.com/watch?v=ft73x0LfGpM)_
|
||||
|
||||
## Model Overview
|
||||
|
||||
Action Chunking with Transformers (ACT) was introduced in the paper [Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware](https://arxiv.org/abs/2304.13705) by Zhao et al. The policy was designed to enable precise, contact-rich manipulation tasks using affordable hardware and minimal demonstration data.
|
||||
|
||||
### Why ACT is Great for Beginners
|
||||
|
||||
ACT stands out as an excellent starting point for several reasons:
|
||||
|
||||
- **Fast Training**: Trains in a few hours on a single GPU
|
||||
- **Lightweight**: Only ~80M parameters, making it efficient and easy to work with
|
||||
- **Data Efficient**: Often achieves high success rates with just 50 demonstrations
|
||||
|
||||
### Architecture
|
||||
|
||||
ACT uses a transformer-based architecture with three main components:
|
||||
|
||||
1. **Vision Backbone**: ResNet-18 processes images from multiple camera viewpoints
|
||||
2. **Transformer Encoder**: Synthesizes information from camera features, joint positions, and a learned latent variable
|
||||
3. **Transformer Decoder**: Generates coherent action sequences using cross-attention
|
||||
|
||||
The policy takes as input:
|
||||
|
||||
- Multiple RGB images (e.g., from wrist cameras, front/top cameras)
|
||||
- Current robot joint positions
|
||||
- A latent style variable `z` (learned during training, set to zero during inference)
|
||||
|
||||
And outputs a chunk of `k` future action sequences.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot by following our [Installation Guide](./installation).
|
||||
2. ACT is included in the base LeRobot installation, so no additional dependencies are needed!
|
||||
|
||||
## Training ACT
|
||||
|
||||
ACT works seamlessly with the standard LeRobot training pipeline. Here's a complete example for training ACT on your dataset:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/your_dataset \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_your_dataset \
|
||||
--job_name=act_your_dataset \
|
||||
--policy.device=cuda \
|
||||
--wandb.enable=true \
|
||||
--policy.repo_id=${HF_USER}/act_policy
|
||||
```
|
||||
|
||||
### Training Tips
|
||||
|
||||
1. **Start with defaults**: ACT's default hyperparameters work well for most tasks
|
||||
2. **Training duration**: Expect a few hours for 100k training steps on a single GPU
|
||||
3. **Batch size**: Start with batch size 8 and adjust based on your GPU memory
|
||||
|
||||
### Train using Google Colab
|
||||
|
||||
If your local computer doesn't have a powerful GPU, you can utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act).
|
||||
|
||||
## Evaluating ACT
|
||||
|
||||
Once training is complete, you can evaluate your ACT policy using the `lerobot-record` command with your trained policy. This will run inference and record evaluation episodes:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=my_robot \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=${HF_USER}/eval_act_your_dataset \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.single_task="Your task description" \
|
||||
--policy.path=${HF_USER}/act_policy
|
||||
```
|
||||
@@ -31,15 +31,15 @@ Then, spin up a policy server (in one terminal, or in a separate machine) specif
|
||||
You can spin up a policy server running:
|
||||
|
||||
```shell
|
||||
python src/lerobot/async_inference/policy_server.py \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080 \
|
||||
python -m lerobot.async_inference.policy_server \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080
|
||||
```
|
||||
|
||||
This will start a policy server listening on `127.0.0.1:8080` (`localhost`, port 8080). At this stage, the policy server is empty, as all information related to which policy to run and with which parameters are specified during the first handshake with the client. Spin up a client with:
|
||||
|
||||
```shell
|
||||
python src/lerobot/async_inference/robot_client.py \
|
||||
python -m lerobot.async_inference.robot_client \
|
||||
--server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
|
||||
--robot.type=so100_follower \ # ROBOT: your robot type
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
|
||||
@@ -113,9 +113,9 @@ As such, spinning up a policy server is as easy as specifying the host address a
|
||||
<hfoptions id="start_policy_server">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.scripts.server.policy_server \
|
||||
--host="localhost" \
|
||||
--port=8080
|
||||
python -m lerobot.async_inference.policy_server \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
@@ -148,7 +148,7 @@ The `RobotClient` streams observations to the `PolicyServer`, and receives actio
|
||||
<hfoptions id="start_robot_client">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python src/lerobot/async_inference/robot_client.py \
|
||||
python -m lerobot.async_inference.robot_client \
|
||||
--server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
|
||||
--robot.type=so100_follower \ # ROBOT: your robot type
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
# GR00T N1.5 Policy
|
||||
|
||||
GR00T N1.5 is an open foundation model from NVIDIA designed for generalized humanoid robot reasoning and skills. It is a cross-embodiment model that accepts multimodal input, including language and images, to perform manipulation tasks in diverse environments.
|
||||
|
||||
This document outlines the specifics of its integration and usage within the LeRobot framework.
|
||||
|
||||
## Model Overview
|
||||
|
||||
NVIDIA Isaac GR00T N1.5 is an upgraded version of the GR00T N1 foundation model. It is built to improve generalization and language-following abilities for humanoid robots.
|
||||
|
||||
Developers and researchers can post-train GR00T N1.5 with their own real or synthetic data to adapt it for specific humanoid robots or tasks.
|
||||
|
||||
GR00T N1.5 (specifically the GR00T-N1.5-3B model) is built using pre-trained vision and language encoders. It utilizes a flow matching action transformer to model a chunk of actions, conditioned on vision, language, and proprioception.
|
||||
|
||||
Its strong performance comes from being trained on an expansive and diverse humanoid dataset, which includes:
|
||||
|
||||
- Real captured data from robots.
|
||||
- Synthetic data generated using NVIDIA Isaac GR00T Blueprint.
|
||||
- Internet-scale video data.
|
||||
|
||||
This approach allows the model to be highly adaptable through post-training for specific embodiments, tasks, and environments.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
As of today, GR00T N1.5 requires flash attention for it's internal working.
|
||||
|
||||
We are working on making this optional, but in the meantime that means that we require an extra installation step and it can only be used in CUDA enabled devices.
|
||||
|
||||
1. Following the Environment Setup of our [Installation Guide](./installation). **Attention** don't install `lerobot` in this step.
|
||||
2. Install [Flash Attention](https://github.com/Dao-AILab/flash-attention) by running:
|
||||
|
||||
```bash
|
||||
# Check https://pytorch.org/get-started/locally/ for your system
|
||||
pip install "torch>=2.2.1,<2.8.0" "torchvision>=0.21.0,<0.23.0" # --index-url https://download.pytorch.org/whl/cu1XX
|
||||
pip install ninja "packaging>=24.2,<26.0" # flash attention dependencies
|
||||
pip install "flash-attn>=2.5.9,<3.0.0" --no-build-isolation
|
||||
python -c "import flash_attn; print(f'Flash Attention {flash_attn.__version__} imported successfully')"
|
||||
```
|
||||
|
||||
3. Install LeRobot by running:
|
||||
|
||||
```bash
|
||||
pip install lerobot[groot] # consider also installing libero,dev and test tags
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
To use GR00T in your LeRobot configuration, specify the policy type as:
|
||||
|
||||
```python
|
||||
policy.type=groot
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Training Command Example
|
||||
|
||||
Here's a complete training command for finetuning the base GR00T model on your own dataset:
|
||||
|
||||
```bash
|
||||
# Using a multi-GPU setup
|
||||
accelerate launch \
|
||||
--multi_gpu \
|
||||
--num_processes=$NUM_GPUS \
|
||||
$(which lerobot-train) \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--save_checkpoint=true \
|
||||
--batch_size=$BATCH_SIZE \
|
||||
--steps=$NUM_STEPS \
|
||||
--save_freq=$SAVE_FREQ \
|
||||
--log_freq=$LOG_FREQ \
|
||||
--policy.push_to_hub=true \
|
||||
--policy.type=groot \
|
||||
--policy.repo_id=$REPO_ID \
|
||||
--policy.tune_diffusion_model=false \
|
||||
--dataset.repo_id=$DATASET_ID \
|
||||
--wandb.enable=true \
|
||||
--wandb.disable_artifact=true \
|
||||
--job_name=$JOB_NAME
|
||||
```
|
||||
|
||||
## Performance Results
|
||||
|
||||
### Libero Benchmark Results
|
||||
|
||||
GR00T has demonstrated strong performance on the Libero benchmark suite. To compare and test its LeRobot implementation, we finetuned the GR00T N1.5 model for 30k steps on the Libero dataset and compared the results to the GR00T reference results.
|
||||
|
||||
| Benchmark | LeRobot Implementation | GR00T Reference |
|
||||
| ------------------ | ---------------------- | --------------- |
|
||||
| **Libero Spatial** | 82.0% | 92.0% |
|
||||
| **Libero Object** | 99.0% | 92.0% |
|
||||
| **Libero Long** | 82.0% | 76.0% |
|
||||
| **Average** | 87.0% | 87.0% |
|
||||
|
||||
These results demonstrate GR00T's strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, you can follow the instructions in the [Libero](https://huggingface.co/docs/lerobot/libero) section.
|
||||
|
||||
### Evaluate in your hardware setup
|
||||
|
||||
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Imitation Learning for Robots](./il_robots). For example:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=bi_so100_follower \
|
||||
--robot.left_arm_port=/dev/ttyACM1 \
|
||||
--robot.right_arm_port=/dev/ttyACM0 \
|
||||
--robot.id=bimanual_follower \
|
||||
--robot.cameras='{ right: {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30},
|
||||
left: {"type": "opencv", "index_or_path": 2, "width": 640, "height": 480, "fps": 30},
|
||||
top: {"type": "opencv", "index_or_path": 4, "width": 640, "height": 480, "fps": 30},
|
||||
}' \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=<user>/eval_groot-bimanual \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm"
|
||||
--policy.path=<user>/groot-bimanual # your trained model
|
||||
--dataset.episode_time_s=30
|
||||
--dataset.reset_time_s=10
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This model follows the **Apache 2.0 License**, consistent with the original [GR00T repository](https://github.com/NVIDIA/Isaac-GR00T).
|
||||
@@ -165,7 +165,7 @@ huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
Then store your Hugging Face repository name in a variable:
|
||||
|
||||
```bash
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
HF_USER=$(hf auth whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
@@ -513,13 +513,14 @@ from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.policies.factory import make_processor
|
||||
|
||||
|
||||
NUM_EPISODES = 5
|
||||
FPS = 30
|
||||
@@ -562,7 +563,7 @@ init_rerun(session_name="recording")
|
||||
# Connect the robot
|
||||
robot.connect()
|
||||
|
||||
preprocessor, postprocessor = make_processor(
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
# Installation
|
||||
|
||||
## Install [`miniforge`](https://conda-forge.org/download/)
|
||||
|
||||
```bash
|
||||
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
|
||||
bash Miniforge3-$(uname)-$(uname -m).sh
|
||||
```
|
||||
|
||||
## Environment Setup
|
||||
|
||||
Create a virtual environment with Python 3.10, using [`Miniconda`](https://docs.anaconda.com/miniconda/install/#quick-command-line-install)
|
||||
Create a virtual environment with Python 3.10, using conda:
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
@@ -14,7 +21,7 @@ Then activate your conda environment, you have to do this each time you open a s
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
When using `miniconda`, install `ffmpeg` in your environment:
|
||||
When using `conda`, install `ffmpeg` in your environment:
|
||||
|
||||
```bash
|
||||
conda install ffmpeg -c conda-forge
|
||||
@@ -74,6 +81,9 @@ _Replace `[...]` with your desired features._
|
||||
For a full list of optional dependencies, see:
|
||||
https://pypi.org/project/lerobot/
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install libero or pi, you will have to do: `pip install "lerobot[pi,libero]@git+https://github.com/huggingface/lerobot.git"`
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
|
||||
@@ -91,7 +101,7 @@ LeRobot provides optional extras for specific functionalities. Multiple extras c
|
||||
|
||||
### Simulations
|
||||
|
||||
Install environment packages: `aloha` ([gym-aloha](https://github.com/huggingface/gym-aloha)), `xarm` ([gym-xarm](https://github.com/huggingface/gym-xarm)), or `pusht` ([gym-pusht](https://github.com/huggingface/gym-pusht))
|
||||
Install environment packages: `aloha` ([gym-aloha](https://github.com/huggingface/gym-aloha)), or `pusht` ([gym-pusht](https://github.com/huggingface/gym-pusht))
|
||||
Example:
|
||||
|
||||
```bash
|
||||
|
||||
@@ -8,7 +8,7 @@ To that end, we provide the [`Robot`](https://github.com/huggingface/lerobot/blo
|
||||
|
||||
- Your own robot which exposes a communication interface (e.g. serial, CAN, TCP)
|
||||
- A way to read sensor data and send motor commands programmatically, e.g. manufacturer's SDK or API, or your own protocol implementation.
|
||||
- LeRobot installed in your environment. Follow our [Installation Guide](./installation.mdx).
|
||||
- LeRobot installed in your environment. Follow our [Installation Guide](./installation).
|
||||
|
||||
## Choose your motors
|
||||
|
||||
@@ -65,7 +65,7 @@ class MyCoolRobotConfig(RobotConfig):
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
[Cameras tutorial](./cameras.mdx) to understand how to detect and add your camera.
|
||||
[Cameras tutorial](./cameras) to understand how to detect and add your camera.
|
||||
|
||||
Next, we'll create our actual robot class which inherits from `Robot`. This abstract class defines a contract you must follow for your robot to be usable with the rest of the LeRobot tools.
|
||||
|
||||
@@ -208,34 +208,36 @@ LeRobot supports saving and loading calibration data automatically. This is usef
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
> @property
|
||||
> def is_calibrated(self) -> bool:
|
||||
> return True
|
||||
>
|
||||
> def calibrate(self) -> None:
|
||||
> pass
|
||||
> ```
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return True
|
||||
|
||||
def calibrate(self) -> None:
|
||||
pass
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
### `is_calibrated`
|
||||
|
||||
This should reflect whether your robot has the required calibration loaded.
|
||||
|
||||
```
|
||||
<!-- prettier-ignore-end -->python
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.bus.is_calibrated
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
### `calibrate()`
|
||||
|
||||
The goal of the calibration is twofold:
|
||||
- Know the physical range of motion of each motors in order to only send commands within this range.
|
||||
- Normalize raw motors positions to sensible continuous values (e.g. percentages, degrees) instead of arbitrary discrete value dependant on the specific motor used that will not replicate elsewhere.
|
||||
|
||||
- Know the physical range of motion of each motors in order to only send commands within this range.
|
||||
- Normalize raw motors positions to sensible continuous values (e.g. percentages, degrees) instead of arbitrary discrete value dependant on the specific motor used that will not replicate elsewhere.
|
||||
|
||||
It should implement the logic for calibration (if relevant) and update the `self.calibration` dictionary. If you are using Feetech or Dynamixel motors, our bus interfaces already include methods to help with this.
|
||||
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
def calibrate(self) -> None:
|
||||
@@ -335,6 +337,134 @@ For implementing teleoperation devices, we also provide a [`Teleoperator`](https
|
||||
|
||||
The main differences are in the I/O functions: a teleoperator allows you to produce action via `get_action` and can receive feedback actions via `send_feedback`. Feedback could be anything controllable on the teleoperation device that could help the person controlling it understand the consequences of the actions sent. Think motion/force feedback on a leader arm, vibrations on a gamepad controller for example. To implement a teleoperator, you can follow this same tutorial and adapt it for these two methods.
|
||||
|
||||
## Using Your Own `LeRobot` Devices 🔌
|
||||
|
||||
You can easily extend `lerobot` with your own custom hardware—be it a camera, robot, or teleoperation device—by creating a separate, installable Python package. If you follow a few simple conventions, the `lerobot` command-line tools (like `lerobot-teleop` and `lerobot-record`) will **automatically discover and integrate your creations** without requiring any changes to the `lerobot` source code.
|
||||
|
||||
This guide outlines the conventions your plugin must follow.
|
||||
|
||||
### The 4 Core Conventions
|
||||
|
||||
To ensure your custom device is discoverable, you must adhere to the following four rules.
|
||||
|
||||
#### 1\. Create an Installable Package with a Specific Prefix
|
||||
|
||||
Your project must be a standard, installable Python package. Crucially, the name of your package (as defined in `pyproject.toml` or `setup.py`) must begin with one of these prefixes:
|
||||
|
||||
- `lerobot_robot_` for a robot.
|
||||
- `lerobot_camera_` for a camera.
|
||||
- `lerobot_teleoperator_` for a teleoperation device.
|
||||
|
||||
This prefix system is how `lerobot` automatically finds your plugin in the Python environment.
|
||||
|
||||
#### 2\. Follow the `SomethingConfig`/`Something` Naming Pattern
|
||||
|
||||
Your device's implementation class must be named after its configuration class, simply by removing the `Config` suffix.
|
||||
|
||||
- **Config Class:** `MyAwesomeTeleopConfig`
|
||||
- **Device Class:** `MyAwesomeTeleop`
|
||||
|
||||
#### 3\. Place Your Files in a Predictable Structure
|
||||
|
||||
The device class (`MyAwesomeTeleop`) must be located in a predictable module relative to its configuration class (`MyAwesomeTeleopConfig`). `lerobot` will automatically search in these locations:
|
||||
|
||||
- In the **same module** as the config class.
|
||||
- In a **submodule named after the device** (e.g., `my_awesome_teleop.py`).
|
||||
|
||||
The recommended and simplest structure is to place them in separate, clearly named files within the same directory.
|
||||
|
||||
#### 4\. Expose Classes in `__init__.py`
|
||||
|
||||
Your package's `__init__.py` file should import and expose both the configuration and the device classes, making them easily accessible.
|
||||
|
||||
### Putting It All Together: A Complete Example
|
||||
|
||||
Let's create a new teleoperator called `my_awesome_teleop`.
|
||||
|
||||
#### Directory Structure
|
||||
|
||||
Here is what the project folder should look like. The package name, `lerobot_teleoperator_my_awesome_teleop`, follows **Convention \#1**.
|
||||
|
||||
```
|
||||
lerobot_teleoperator_my_awesome_teleop/
|
||||
├── pyproject.toml # (or setup.py) lists lerobot as a dependency
|
||||
└── lerobot_teleoperator_my_awesome_teleop/
|
||||
├── __init__.py
|
||||
├── config_my_awesome_teleop.py
|
||||
└── my_awesome_teleop.py
|
||||
```
|
||||
|
||||
#### File Contents
|
||||
|
||||
- **`config_my_awesome_teleop.py`**: Defines the configuration class. Note the `Config` suffix (**Convention \#2**).
|
||||
|
||||
```python
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.teleoperators.config import TeleoperatorConfig
|
||||
|
||||
@TeleoperatorConfig.register_subclass("my_awesome_teleop")
|
||||
@dataclass
|
||||
class MyAwesomeTeleopConfig(TeleoperatorConfig):
|
||||
# Your configuration fields go here
|
||||
port: str = "192.168.1.1"
|
||||
```
|
||||
|
||||
- **`my_awesome_teleop.py`**: Implements the device. The class name `MyAwesomeTeleop` matches its config class name (**Convention \#2**). This file structure adheres to **Convention \#3**.
|
||||
|
||||
```python
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
|
||||
from .config_my_awesome_teleop import MyAwesomeTeleopConfig
|
||||
|
||||
class MyAwesomeTeleop(Teleoperator):
|
||||
config_class = MyAwesomeTeleopConfig
|
||||
name = "my_awesome_teleop"
|
||||
|
||||
def __init__(self, config: MyAwesomeTeleopConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Your device logic (e.g., connect) goes here
|
||||
```
|
||||
|
||||
- **`__init__.py`**: Exposes the key classes (**Convention \#4**).
|
||||
|
||||
```python
|
||||
from .config_my_awesome_teleop import MyAwesomeTeleopConfig
|
||||
from .my_awesome_teleop import MyAwesomeTeleop
|
||||
```
|
||||
|
||||
### Installation and Usage
|
||||
|
||||
1. **Install your new plugin in your Python environment.** You can install your local plugin package using `pip`'s editable mode or from PyPi.
|
||||
|
||||
```bash
|
||||
# Locally
|
||||
# Navigate to your plugin's root directory and install it
|
||||
cd lerobot_teleoperator_my_awesome_teleop
|
||||
pip install -e .
|
||||
|
||||
# From PyPi
|
||||
pip install lerobot_teleoperator_my_awesome_teleop
|
||||
```
|
||||
|
||||
2. **Use it directly from the command line.** Now, you can use your custom device by referencing its type.
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate --teleop.type=my_awesome_teleop \
|
||||
# other arguments
|
||||
```
|
||||
|
||||
And that's it\! Your custom device is now fully integrated.
|
||||
|
||||
### Looking for an example ?
|
||||
|
||||
Check out these two packages from the community:
|
||||
|
||||
- https://github.com/SpesRobotics/lerobot-robot-xarm
|
||||
- https://github.com/SpesRobotics/lerobot-teleoperator-teleop
|
||||
|
||||
## Wrapping Up
|
||||
|
||||
Once your robot class is complete, you can leverage the LeRobot ecosystem:
|
||||
|
||||
@@ -297,9 +297,9 @@ LeRobot provides many registered processor steps. Here are the most commonly use
|
||||
|
||||
### Next Steps
|
||||
|
||||
- **[Implement Your Own Processor](implement_your_own_processor.mdx)** - Create custom processor steps
|
||||
- **[Debug Your Pipeline](debug_processor_pipeline.mdx)** - Troubleshoot and optimize pipelines
|
||||
- **[Processors for Robots and Teleoperators](processors_robots_teleop.mdx)** - Real-world integration patterns
|
||||
- **[Implement Your Own Processor](./implement_your_own_processor)** - Create custom processor steps
|
||||
- **[Debug Your Pipeline](./debug_processor_pipeline)** - Troubleshoot and optimize pipelines
|
||||
- **[Processors for Robots and Teleoperators](./processors_robots_teleop)** - Real-world integration patterns
|
||||
|
||||
## Summary
|
||||
|
||||
|
||||
@@ -279,3 +279,36 @@ python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id=<HF_USER/DAT
|
||||
- Aggregates parquet files: `episode-0000.parquet`, `episode-0001.parquet`, … → **`file-0000.parquet`**, …
|
||||
- Aggregates mp4 files: `episode-0000.mp4`, `episode-0001.mp4`, … → **`file-0000.mp4`**, …
|
||||
- Updates `meta/episodes/*` (chunked Parquet) with per‑episode lengths, tasks, and byte/frame offsets.
|
||||
|
||||
## Common Issues
|
||||
|
||||
### Always call `finalize()` before pushing
|
||||
|
||||
When creating or recording datasets, you **must** call `dataset.finalize()` to properly close parquet writers. See the [PR #1903](https://github.com/huggingface/lerobot/pull/1903) for more details.
|
||||
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Create dataset and record episodes
|
||||
dataset = LeRobotDataset.create(...)
|
||||
|
||||
for episode in range(num_episodes):
|
||||
# Record frames
|
||||
for frame in episode_data:
|
||||
dataset.add_frame(frame)
|
||||
dataset.save_episode()
|
||||
|
||||
# Call finalize() when done recording and before push_to_hub()
|
||||
dataset.finalize() # Closes parquet writers, writes metadata footers
|
||||
dataset.push_to_hub()
|
||||
```
|
||||
|
||||
**Why is this necessary?**
|
||||
|
||||
Dataset v3.0 uses incremental parquet writing with buffered metadata for efficiency. The `finalize()` method:
|
||||
|
||||
- Flushes any buffered episode metadata to disk
|
||||
- Closes parquet writers to write footer metadata, otherwise the parquet files will be corrupt
|
||||
- Ensures the dataset is valid for loading
|
||||
|
||||
Without calling `finalize()`, your parquet files will be incomplete and the dataset won't load properly.
|
||||
|
||||
@@ -125,3 +125,42 @@ lerobot-train \
|
||||
LeRobot uses MuJoCo for simulation. You need to set the rendering backend before training or evaluation:
|
||||
|
||||
- `export MUJOCO_GL=egl` → for headless servers (e.g. HPC, cloud)
|
||||
|
||||
## Reproducing π₀.₅ results
|
||||
|
||||
We reproduce the results of π₀.₅ on the LIBERO benchmark using the LeRobot implementation. We take the Physical Intelligence LIBERO base model (`pi05_libero`) and finetune for an additional 6k steps in bfloat16, with batch size of 256 on 8 H100 GPUs using the [HuggingFace LIBERO dataset](https://huggingface.co/datasets/HuggingFaceVLA/libero).
|
||||
|
||||
The finetuned model can be found here:
|
||||
|
||||
- **π₀.₅ LIBERO**: [lerobot/pi05_libero_finetuned](https://huggingface.co/lerobot/pi05_libero_finetuned)
|
||||
|
||||
We then evaluate the finetuned model using the LeRobot LIBERO implementation, by running the following command:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--output_dir=/logs/ \
|
||||
--env.type=libero \
|
||||
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=10 \
|
||||
--policy.path=pi05_libero_finetuned \
|
||||
--policy.n_action_steps=10 \
|
||||
--output_dir=./eval_logs/ \
|
||||
--env.max_parallel_tasks=1
|
||||
```
|
||||
|
||||
**Note:** We set `n_action_steps=10`, similar to the original OpenPI implementation.
|
||||
|
||||
### Results
|
||||
|
||||
We obtain the following results on the LIBERO benchmark:
|
||||
|
||||
| Model | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average |
|
||||
| -------- | -------------- | ------------- | ----------- | --------- | -------- |
|
||||
| **π₀.₅** | 97.0 | 99.0 | 98.0 | 96.0 | **97.5** |
|
||||
|
||||
These results are consistent with the original [results](https://github.com/Physical-Intelligence/openpi/tree/main/examples/libero#results) reported by Physical Intelligence:
|
||||
|
||||
| Model | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average |
|
||||
| -------- | -------------- | ------------- | ----------- | --------- | --------- |
|
||||
| **π₀.₅** | 98.8 | 98.2 | 98.0 | 92.4 | **96.85** |
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
# Meta-World
|
||||
|
||||
Meta-World is a well-designed, open-source simulation benchmark for multi-task and meta reinforcement learning in continuous-control robotic manipulation. It gives researchers a shared, realistic playground to test whether algorithms can _learn many different tasks_ and _generalize quickly to new ones_ — two central challenges for real-world robotics.
|
||||
|
||||
- 📄 [MetaWorld paper](https://arxiv.org/pdf/1910.10897)
|
||||
- 💻 [Original MetaWorld repo](https://github.com/Farama-Foundation/Metaworld)
|
||||
|
||||

|
||||
|
||||
## Why Meta-World matters
|
||||
|
||||
- **Diverse, realistic tasks.** Meta-World bundles a large suite of simulated manipulation tasks (50 in the MT50 suite) using everyday objects and a common tabletop Sawyer arm. This diversity exposes algorithms to a wide variety of dynamics, contacts and goal specifications while keeping a consistent control and observation structure.
|
||||
- **Focus on generalization and multi-task learning.** By evaluating across task distributions that share structure but differ in goals and objects, Meta-World reveals whether an agent truly learns transferable skills rather than overfitting to a narrow task.
|
||||
- **Standardized evaluation protocol.** It provides clear evaluation modes and difficulty splits, so different methods can be compared fairly across easy, medium, hard and very-hard regimes.
|
||||
- **Empirical insight.** Past evaluations on Meta-World show impressive progress on some fronts, but also highlight that current multi-task and meta-RL methods still struggle with large, diverse task sets. That gap points to important research directions.
|
||||
|
||||
## What it enables in LeRobot
|
||||
|
||||
In LeRobot, you can evaluate any policy or vision-language-action (VLA) model on Meta-World tasks and get a clear success-rate measure. The integration is designed to be straightforward:
|
||||
|
||||
- We provide a LeRobot-ready dataset for Meta-World (MT50) on the HF Hub: `https://huggingface.co/datasets/lerobot/metaworld_mt50`.
|
||||
- This dataset is formatted for the MT50 evaluation that uses all 50 tasks (the most challenging multi-task setting).
|
||||
- MT50 gives the policy a one-hot task vector and uses fixed object/goal positions for consistency.
|
||||
|
||||
- Task descriptions and the exact keys required for evaluation are available in the repo/dataset — use these to ensure your policy outputs the right success signals.
|
||||
|
||||
## Quick start, train a SmolVLA policy on Meta-World
|
||||
|
||||
Example command to train a SmolVLA policy on a subset of tasks:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.type=smolvla \
|
||||
--policy.repo_id=${HF_USER}/metaworld-test \
|
||||
--policy.load_vlm_weights=true \
|
||||
--dataset.repo_id=lerobot/metaworld_mt50 \
|
||||
--env.type=metaworld \
|
||||
--env.task=assembly-v3,dial-turn-v3,handle-press-side-v3 \
|
||||
--output_dir=./outputs/ \
|
||||
--steps=100000 \
|
||||
--batch_size=4 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval_freq=1000
|
||||
```
|
||||
|
||||
Notes:
|
||||
|
||||
- `--env.task` accepts explicit task lists (comma separated) or difficulty groups (e.g., `env.task="hard"`).
|
||||
- Adjust `batch_size`, `steps`, and `eval_freq` to match your compute budget.
|
||||
- **Gymnasium Assertion Error**: if you encounter an error like
|
||||
`AssertionError: ['human', 'rgb_array', 'depth_array']` when running MetaWorld environments, this comes from a mismatch between MetaWorld and your Gymnasium version.
|
||||
We recommend using:
|
||||
|
||||
```bash
|
||||
pip install "gymnasium==1.1.0"
|
||||
```
|
||||
|
||||
to ensure proper compatibility.
|
||||
|
||||
## Quick start — evaluate a trained policy
|
||||
|
||||
To evaluate a trained policy on the Meta-World medium difficulty split:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path="your-policy-id" \
|
||||
--env.type=metaworld \
|
||||
--env.task=medium \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=2
|
||||
```
|
||||
|
||||
This will run episodes and return per-task success rates using the standard Meta-World evaluation keys.
|
||||
|
||||
## Practical tips
|
||||
|
||||
- If you care about generalization, run on the full MT50 suite — it’s intentionally challenging and reveals strengths/weaknesses better than a few narrow tasks.
|
||||
- Use the one-hot task conditioning for multi-task training (MT10 / MT50 conventions) so policies have explicit task context.
|
||||
- Inspect the dataset task descriptions and the `info["is_success"]` keys when writing post-processing or logging so your success metrics line up with the benchmark.
|
||||
@@ -0,0 +1,125 @@
|
||||
# Multi-GPU Training
|
||||
|
||||
This guide shows you how to train policies on multiple GPUs using [Hugging Face Accelerate](https://huggingface.co/docs/accelerate).
|
||||
|
||||
## Installation
|
||||
|
||||
First, ensure you have accelerate installed:
|
||||
|
||||
```bash
|
||||
pip install accelerate
|
||||
```
|
||||
|
||||
## Training with Multiple GPUs
|
||||
|
||||
You can launch training in two ways:
|
||||
|
||||
### Option 1: Without config (specify parameters directly)
|
||||
|
||||
You can specify all parameters directly in the command without running `accelerate config`:
|
||||
|
||||
```bash
|
||||
accelerate launch \
|
||||
--multi_gpu \
|
||||
--num_processes=2 \
|
||||
$(which lerobot-train) \
|
||||
--dataset.repo_id=${HF_USER}/my_dataset \
|
||||
--policy.type=act \
|
||||
--policy.repo_id=${HF_USER}/my_trained_policy \
|
||||
--output_dir=outputs/train/act_multi_gpu \
|
||||
--job_name=act_multi_gpu \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
**Key accelerate parameters:**
|
||||
|
||||
- `--multi_gpu`: Enable multi-GPU training
|
||||
- `--num_processes=2`: Number of GPUs to use
|
||||
- `--mixed_precision=fp16`: Use fp16 mixed precision (or `bf16` if supported)
|
||||
|
||||
### Option 2: Using accelerate config
|
||||
|
||||
If you prefer to save your configuration, you can optionally configure accelerate for your hardware setup by running:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
This interactive setup will ask you questions about your training environment (number of GPUs, mixed precision settings, etc.) and saves the configuration for future use. For a simple multi-GPU setup on a single machine, you can use these recommended settings:
|
||||
|
||||
- Compute environment: This machine
|
||||
- Number of machines: 1
|
||||
- Number of processes: (number of GPUs you want to use)
|
||||
- GPU ids to use: (leave empty to use all)
|
||||
- Mixed precision: fp16 or bf16 (recommended for faster training)
|
||||
|
||||
Then launch training with:
|
||||
|
||||
```bash
|
||||
accelerate launch $(which lerobot-train) \
|
||||
--dataset.repo_id=${HF_USER}/my_dataset \
|
||||
--policy.type=act \
|
||||
--policy.repo_id=${HF_USER}/my_trained_policy \
|
||||
--output_dir=outputs/train/act_multi_gpu \
|
||||
--job_name=act_multi_gpu \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
When you launch training with accelerate:
|
||||
|
||||
1. **Automatic detection**: LeRobot automatically detects if it's running under accelerate
|
||||
2. **Data distribution**: Your batch is automatically split across GPUs
|
||||
3. **Gradient synchronization**: Gradients are synchronized across GPUs during backpropagation
|
||||
4. **Single process logging**: Only the main process logs to wandb and saves checkpoints
|
||||
|
||||
## Learning Rate and Training Steps Scaling
|
||||
|
||||
**Important:** LeRobot does **NOT** automatically scale learning rates or training steps based on the number of GPUs. This gives you full control over your training hyperparameters.
|
||||
|
||||
### Why No Automatic Scaling?
|
||||
|
||||
Many distributed training frameworks automatically scale the learning rate by the number of GPUs (e.g., `lr = base_lr × num_gpus`).
|
||||
However, LeRobot keeps the learning rate exactly as you specify it.
|
||||
|
||||
### When and How to Scale
|
||||
|
||||
If you want to scale your hyperparameters when using multiple GPUs, you should do it manually:
|
||||
|
||||
**Learning Rate Scaling:**
|
||||
|
||||
```bash
|
||||
# Example: 2 GPUs with linear LR scaling
|
||||
# Base LR: 1e-4, with 2 GPUs -> 2e-4
|
||||
accelerate launch --num_processes=2 $(which lerobot-train) \
|
||||
--optimizer.lr=2e-4 \
|
||||
--dataset.repo_id=lerobot/pusht \
|
||||
--policy=act
|
||||
```
|
||||
|
||||
**Training Steps Scaling:**
|
||||
|
||||
Since the effective batch size `bs` increases with multiple GPUs (batch_size × num_gpus), you may want to reduce the number of training steps proportionally:
|
||||
|
||||
```bash
|
||||
# Example: 2 GPUs with effective batch size 2x larger
|
||||
# Original: batch_size=8, steps=100000
|
||||
# With 2 GPUs: batch_size=8 (16 in total), steps=50000
|
||||
accelerate launch --num_processes=2 $(which lerobot-train) \
|
||||
--batch_size=8 \
|
||||
--steps=50000 \
|
||||
--dataset.repo_id=lerobot/pusht \
|
||||
--policy=act
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration.
|
||||
- Training logs, checkpoints, and hub uploads are only done by the main process to avoid conflicts. Non-main processes have console logging disabled to prevent duplicate output.
|
||||
- The effective batch size is `batch_size × num_gpus`. If you use 4 GPUs with `--batch_size=8`, your effective batch size is 32.
|
||||
- Learning rate scheduling is handled correctly across multiple processes—LeRobot sets `step_scheduler_with_optimizer=False` to prevent accelerate from adjusting scheduler steps based on the number of processes.
|
||||
- When saving or pushing models, LeRobot automatically unwraps the model from accelerate's distributed wrapper to ensure compatibility.
|
||||
- WandB integration automatically initializes only on the main process, preventing multiple runs from being created.
|
||||
|
||||
For more advanced configurations and troubleshooting, see the [Accelerate documentation](https://huggingface.co/docs/accelerate). If you want to learn more about how to train on a large number of GPUs, checkout this awesome guide: [Ultrascale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook).
|
||||
@@ -79,7 +79,7 @@ After running the example:
|
||||
- Android: after starting the script, open the printed local URL on your phone, tap Start, then press and hold Move.
|
||||
- iOS: open HEBI Mobile I/O first; B1 enables motion. A3 controls the gripper.
|
||||
|
||||
Additionally you can customize mapping or safety limits by editing the processor steps shown in the examples. You can also remap inputs (e.g., use a different analog input) or adapt the pipeline to other robots (e.g., LeKiwi) by modifying the input and kinematics steps. More about this in the [Processors for Robots and Teleoperators](./processors_robots_teleop.mdx) guide.
|
||||
Additionally you can customize mapping or safety limits by editing the processor steps shown in the examples. You can also remap inputs (e.g., use a different analog input) or adapt the pipeline to other robots (e.g., LeKiwi) by modifying the input and kinematics steps. More about this in the [Processors for Robots and Teleoperators](./processors_robots_teleop) guide.
|
||||
|
||||
- Run this example to record a dataset, which saves absolute end effector observations and actions:
|
||||
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
# π₀ (Pi0)
|
||||
|
||||
π₀ is a **Vision-Language-Action model for general robot control**, from Physical Intelligence. The LeRobot implementation is adapted from their open source [OpenPI](https://github.com/Physical-Intelligence/openpi) repository.
|
||||
|
||||
## Model Overview
|
||||
|
||||
π₀ represents a breakthrough in robotics as the first general-purpose robot foundation model developed by [Physical Intelligence](https://www.physicalintelligence.company/blog/pi0). Unlike traditional robot programs that are narrow specialists programmed for repetitive motions, π₀ is designed to be a generalist policy that can understand visual inputs, interpret natural language instructions, and control a variety of different robots across diverse tasks.
|
||||
|
||||
### The Vision for Physical Intelligence
|
||||
|
||||
As described by Physical Intelligence, while AI has achieved remarkable success in digital domains, from chess-playing to drug discovery, human intelligence still dramatically outpaces AI in the physical world. To paraphrase Moravec's paradox, winning a game of chess represents an "easy" problem for AI, but folding a shirt or cleaning up a table requires solving some of the most difficult engineering problems ever conceived. π₀ represents a first step toward developing artificial physical intelligence that enables users to simply ask robots to perform any task they want, just like they can with large language models.
|
||||
|
||||
### Architecture and Approach
|
||||
|
||||
π₀ combines several key innovations:
|
||||
|
||||
- **Flow Matching**: Uses a novel method to augment pre-trained VLMs with continuous action outputs via flow matching (a variant of diffusion models)
|
||||
- **Cross-Embodiment Training**: Trained on data from 8 distinct robot platforms including UR5e, Bimanual UR5e, Franka, Bimanual Trossen, Bimanual ARX, Mobile Trossen, and Mobile Fibocom
|
||||
- **Internet-Scale Pre-training**: Inherits semantic knowledge from a pre-trained 3B parameter Vision-Language Model
|
||||
- **High-Frequency Control**: Outputs motor commands at up to 50 Hz for real-time dexterous manipulation
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot by following our [Installation Guide](./installation).
|
||||
2. Install Pi0 dependencies by running:
|
||||
|
||||
```bash
|
||||
pip install -e ".[pi]"
|
||||
```
|
||||
|
||||
## Training Data and Capabilities
|
||||
|
||||
π₀ is trained on the largest robot interaction dataset to date, combining three key data sources:
|
||||
|
||||
1. **Internet-Scale Pre-training**: Vision-language data from the web for semantic understanding
|
||||
2. **Open X-Embodiment Dataset**: Open-source robot manipulation datasets
|
||||
3. **Physical Intelligence Dataset**: Large and diverse dataset of dexterous tasks across 8 distinct robots
|
||||
|
||||
## Usage
|
||||
|
||||
To use π₀ in LeRobot, specify the policy type as:
|
||||
|
||||
```python
|
||||
policy.type=pi0
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
For training π₀, you can use the standard LeRobot training script with the appropriate configuration:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=pi0 \
|
||||
--output_dir=./outputs/pi0_training \
|
||||
--job_name=pi0_training \
|
||||
--policy.pretrained_path=lerobot/pi0_base \
|
||||
--policy.repo_id=your_repo_id \
|
||||
--policy.compile_model=true \
|
||||
--policy.gradient_checkpointing=true \
|
||||
--policy.dtype=bfloat16 \
|
||||
--steps=3000 \
|
||||
--policy.device=cuda \
|
||||
--batch_size=32
|
||||
```
|
||||
|
||||
### Key Training Parameters
|
||||
|
||||
- **`--policy.compile_model=true`**: Enables model compilation for faster training
|
||||
- **`--policy.gradient_checkpointing=true`**: Reduces memory usage significantly during training
|
||||
- **`--policy.dtype=bfloat16`**: Use mixed precision training for efficiency
|
||||
- **`--batch_size=32`**: Batch size for training, adapt this based on your GPU memory
|
||||
- **`--policy.pretrained_path=lerobot/pi0_base`**: The base π₀ model you want to finetune, options are:
|
||||
- [lerobot/pi0_base](https://huggingface.co/lerobot/pi0_base)
|
||||
- [lerobot/pi0_libero](https://huggingface.co/lerobot/pi0_libero) (specifically trained on the Libero dataset)
|
||||
|
||||
## License
|
||||
|
||||
This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
|
||||
@@ -0,0 +1,107 @@
|
||||
# π₀.₅ (Pi05) Policy
|
||||
|
||||
π₀.₅ is a **Vision-Language-Action model with open-world generalization**, from Physical Intelligence. The LeRobot implementation is adapted from their open source [OpenPI](https://github.com/Physical-Intelligence/openpi) repository.
|
||||
|
||||
## Model Overview
|
||||
|
||||
π₀.₅ represents a significant evolution from π₀, developed by [Physical Intelligence](https://www.physicalintelligence.company/blog/pi05) to address a big challenge in robotics: **open-world generalization**. While robots can perform impressive tasks in controlled environments, π₀.₅ is designed to generalize to entirely new environments and situations that were never seen during training.
|
||||
|
||||
### The Generalization Challenge
|
||||
|
||||
As Physical Intelligence explains, the fundamental challenge isn't performing tasks of agility or dexterity, but generalization, the ability to correctly perform tasks in new settings with new objects. Consider a robot cleaning different homes: each home has different objects in different places. Generalization must occur at multiple levels:
|
||||
|
||||
- **Physical Level**: Understanding how to pick up a spoon (by the handle) or plate (by the edge), even with unseen objects in cluttered environments
|
||||
- **Semantic Level**: Understanding task semantics, where to put clothes and shoes (laundry hamper, not on the bed), and what tools are appropriate for cleaning spills
|
||||
- **Environmental Level**: Adapting to "messy" real-world environments like homes, grocery stores, offices, and hospitals
|
||||
|
||||
### Co-Training on Heterogeneous Data
|
||||
|
||||
The breakthrough innovation in π₀.₅ is **co-training on heterogeneous data sources**. The model learns from:
|
||||
|
||||
1. **Multimodal Web Data**: Image captioning, visual question answering, object detection
|
||||
2. **Verbal Instructions**: Humans coaching robots through complex tasks step-by-step
|
||||
3. **Subtask Commands**: High-level semantic behavior labels (e.g., "pick up the pillow" for an unmade bed)
|
||||
4. **Cross-Embodiment Robot Data**: Data from various robot platforms with different capabilities
|
||||
5. **Multi-Environment Data**: Static robots deployed across many different homes
|
||||
6. **Mobile Manipulation Data**: ~400 hours of mobile robot demonstrations
|
||||
|
||||
This diverse training mixture creates a "curriculum" that enables generalization across physical, visual, and semantic levels simultaneously.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot by following our [Installation Guide](./installation).
|
||||
2. Install Pi0.5 dependencies by running:
|
||||
|
||||
```bash
|
||||
pip install -e ".[pi]"
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
To use π₀.₅ in your LeRobot configuration, specify the policy type as:
|
||||
|
||||
```python
|
||||
policy.type=pi05
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Training Command Example
|
||||
|
||||
Here's a complete training command for finetuning the base π₀.₅ model on your own dataset:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py\
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=pi05 \
|
||||
--output_dir=./outputs/pi05_training \
|
||||
--job_name=pi05_training \
|
||||
--policy.repo_id=your_repo_id \
|
||||
--policy.pretrained_path=lerobot/pi05_base \
|
||||
--policy.compile_model=true \
|
||||
--policy.gradient_checkpointing=true \
|
||||
--wandb.enable=true \
|
||||
--policy.dtype=bfloat16 \
|
||||
--steps=3000 \
|
||||
--policy.device=cuda \
|
||||
--batch_size=32
|
||||
```
|
||||
|
||||
### Key Training Parameters
|
||||
|
||||
- **`--policy.compile_model=true`**: Enables model compilation for faster training
|
||||
- **`--policy.gradient_checkpointing=true`**: Reduces memory usage significantly during training
|
||||
- **`--policy.dtype=bfloat16`**: Use mixed precision training for efficiency
|
||||
- **`--batch_size=32`**: Batch size for training, adapt this based on your GPU memory
|
||||
- **`--policy.pretrained_path=lerobot/pi05_base`**: The base π₀.₅ model you want to finetune, options are:
|
||||
- [lerobot/pi05_base](https://huggingface.co/lerobot/pi05_base)
|
||||
- [lerobot/pi05_libero](https://huggingface.co/lerobot/pi05_libero) (specifically trained on the Libero dataset)
|
||||
|
||||
If your dataset is not converted with `quantiles`, you can convert it with the following command:
|
||||
|
||||
```bash
|
||||
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
|
||||
--repo-id=your_dataset \
|
||||
```
|
||||
|
||||
Or train pi05 with this normalization mapping: `--policy.normalization_mapping='{"ACTION": "MEAN_STD", "STATE": "MEAN_STD", "VISUAL": "IDENTITY"}'`
|
||||
|
||||
## Performance Results
|
||||
|
||||
### Libero Benchmark Results
|
||||
|
||||
π₀.₅ has demonstrated strong performance on the Libero benchmark suite. To compare and test its LeRobot implementation, we finetuned the libero base model for an additional 6k steps on the Libero dataset and compared the results to the OpenPI reference results.
|
||||
|
||||
| Benchmark | LeRobot Implementation | OpenPI Reference |
|
||||
| ------------------ | ---------------------- | ---------------- |
|
||||
| **Libero Spatial** | 97.0% | 98.8% |
|
||||
| **Libero Object** | 99.0% | 98.2% |
|
||||
| **Libero Goal** | 98.0% | 98.0% |
|
||||
| **Libero 10** | 96.0% | 92.4% |
|
||||
| **Average** | 97.5% | 96.85% |
|
||||
|
||||
These results demonstrate π₀.₅'s strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, you can follow the instructions in the [Libero](https://huggingface.co/docs/lerobot/libero) section.
|
||||
|
||||
## License
|
||||
|
||||
This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
|
||||
@@ -0,0 +1,27 @@
|
||||
## Research Paper
|
||||
|
||||
Paper: https://research.nvidia.com/labs/gear/gr00t-n1_5/
|
||||
|
||||
## Repository
|
||||
|
||||
Code: https://github.com/NVIDIA/Isaac-GR00T
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@inproceedings{gr00tn1_2025,
|
||||
archivePrefix = {arxiv},
|
||||
eprint = {2503.14734},
|
||||
title = {{GR00T} {N1}: An Open Foundation Model for Generalist Humanoid Robots},
|
||||
author = {NVIDIA and Johan Bjorck andFernando Castañeda, Nikita Cherniadev and Xingye Da and Runyu Ding and Linxi "Jim" Fan and Yu Fang and Dieter Fox and Fengyuan Hu and Spencer Huang and Joel Jang and Zhenyu Jiang and Jan Kautz and Kaushil Kundalia and Lawrence Lao and Zhiqi Li and Zongyu Lin and Kevin Lin and Guilin Liu and Edith Llontop and Loic Magne and Ajay Mandlekar and Avnish Narayan and Soroush Nasiriany and Scott Reed and You Liang Tan and Guanzhi Wang and Zu Wang and Jing Wang and Qi Wang and Jiannan Xiang and Yuqi Xie and Yinzhen Xu and Zhenjia Xu and Seonghyeon Ye and Zhiding Yu and Ao Zhang and Hao Zhang and Yizhou Zhao and Ruijie Zheng and Yuke Zhu},
|
||||
month = {March},
|
||||
year = {2025},
|
||||
booktitle = {ArXiv Preprint},
|
||||
}
|
||||
```
|
||||
|
||||
## Additional Resources
|
||||
|
||||
Blog: https://developer.nvidia.com/isaac/gr00t
|
||||
|
||||
Hugging Face Model: https://huggingface.co/nvidia/GR00T-N1.5-3B
|
||||
@@ -1,4 +1,4 @@
|
||||
# Finetune SmolVLA
|
||||
# SmolVLA
|
||||
|
||||
SmolVLA is Hugging Face’s lightweight foundation model for robotics. Designed for easy fine-tuning on LeRobot datasets, it helps accelerate your development!
|
||||
|
||||
|
||||
@@ -0,0 +1,102 @@
|
||||
# Using Dataset Tools
|
||||
|
||||
This guide covers the dataset tools utilities available in LeRobot for modifying and editing existing datasets.
|
||||
|
||||
## Overview
|
||||
|
||||
LeRobot provides several utilities for manipulating datasets:
|
||||
|
||||
1. **Delete Episodes** - Remove specific episodes from a dataset
|
||||
2. **Split Dataset** - Divide a dataset into multiple smaller datasets
|
||||
3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids`
|
||||
4. **Add Features** - Add new features to a dataset
|
||||
5. **Remove Features** - Remove features from a dataset
|
||||
|
||||
The core implementation is in `lerobot.datasets.dataset_tools`.
|
||||
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
|
||||
|
||||
## Command-Line Tool: lerobot-edit-dataset
|
||||
|
||||
`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, and remove features.
|
||||
|
||||
Run `lerobot-edit-dataset --help` for more information on the configuration of each operation.
|
||||
|
||||
### Usage Examples
|
||||
|
||||
#### Delete Episodes
|
||||
|
||||
Remove specific episodes from a dataset. This is useful for filtering out undesired data.
|
||||
|
||||
```bash
|
||||
# Delete episodes 0, 2, and 5 (modifies original dataset)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]"
|
||||
|
||||
# Delete episodes and save to a new dataset (preserves original dataset)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--new_repo_id lerobot/pusht_after_deletion \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]"
|
||||
```
|
||||
|
||||
#### Split Dataset
|
||||
|
||||
Divide a dataset into multiple subsets.
|
||||
|
||||
```bash
|
||||
# Split by fractions (e.g. 80% train, 20% test, 20% val)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type split \
|
||||
--operation.splits '{"train": 0.8, "test": 0.2, "val": 0.2}'
|
||||
|
||||
# Split by specific episode indices
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type split \
|
||||
--operation.splits '{"task1": [0, 1, 2, 3], "task2": [4, 5]}'
|
||||
```
|
||||
|
||||
There are no constraints on the split names, they can be determined by the user. Resulting datasets are saved under the repo id with the split name appended, e.g. `lerobot/pusht_train`, `lerobot/pusht_task1`, `lerobot/pusht_task2`.
|
||||
|
||||
#### Merge Datasets
|
||||
|
||||
Combine multiple datasets into a single dataset.
|
||||
|
||||
```bash
|
||||
# Merge train and validation splits back into one dataset
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_merged \
|
||||
--operation.type merge \
|
||||
--operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']"
|
||||
```
|
||||
|
||||
#### Remove Features
|
||||
|
||||
Remove features from a dataset.
|
||||
|
||||
```bash
|
||||
# Remove a camera feature
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type remove_feature \
|
||||
--operation.feature_names "['observation.images.top']"
|
||||
```
|
||||
|
||||
### Push to Hub
|
||||
|
||||
Add the `--push_to_hub` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
|
||||
|
||||
```bash
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--new_repo_id lerobot/pusht_after_deletion \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
There is also a tool for adding features to a dataset that is not yet covered in `lerobot-edit-dataset`.
|
||||
@@ -132,17 +132,15 @@ print(f"\n{dataset[0][camera_key].shape=}") # (4, c, h, w)
|
||||
print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
|
||||
print(f"{dataset[0]['action'].shape=}\n") # (64, c)
|
||||
|
||||
# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
|
||||
# PyTorch datasets.
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=32,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
for batch in dataloader:
|
||||
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
|
||||
print(f"{batch['observation.state'].shape=}") # (32, 6, c)
|
||||
print(f"{batch['action'].shape=}") # (32, 64, c)
|
||||
break
|
||||
if __name__ == "__main__":
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=32,
|
||||
shuffle=True,
|
||||
)
|
||||
for batch in dataloader:
|
||||
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
|
||||
print(f"{batch['observation.state'].shape=}") # (32, 6, c)
|
||||
print(f"{batch['action'].shape=}") # (32, 64, c)
|
||||
break
|
||||
|
||||
@@ -0,0 +1,124 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Example script demonstrating dataset tools utilities.
|
||||
|
||||
This script shows how to:
|
||||
1. Delete episodes from a dataset
|
||||
2. Split a dataset into train/val sets
|
||||
3. Add/remove features
|
||||
4. Merge datasets
|
||||
|
||||
Usage:
|
||||
python examples/dataset/use_dataset_tools.py
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.datasets.dataset_tools import (
|
||||
add_features,
|
||||
delete_episodes,
|
||||
merge_datasets,
|
||||
modify_features,
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def main():
|
||||
dataset = LeRobotDataset("lerobot/pusht")
|
||||
|
||||
print(f"Original dataset: {dataset.meta.total_episodes} episodes, {dataset.meta.total_frames} frames")
|
||||
print(f"Features: {list(dataset.meta.features.keys())}")
|
||||
|
||||
print("\n1. Deleting episodes 0 and 2...")
|
||||
filtered_dataset = delete_episodes(dataset, episode_indices=[0, 2], repo_id="lerobot/pusht_filtered")
|
||||
print(f"Filtered dataset: {filtered_dataset.meta.total_episodes} episodes")
|
||||
|
||||
print("\n2. Splitting dataset into train/val...")
|
||||
splits = split_dataset(
|
||||
dataset,
|
||||
splits={"train": 0.8, "val": 0.2},
|
||||
)
|
||||
print(f"Train split: {splits['train'].meta.total_episodes} episodes")
|
||||
print(f"Val split: {splits['val'].meta.total_episodes} episodes")
|
||||
|
||||
print("\n3. Adding features...")
|
||||
|
||||
reward_values = np.random.randn(dataset.meta.total_frames).astype(np.float32)
|
||||
|
||||
def compute_success(row_dict, episode_index, frame_index):
|
||||
episode_length = 10
|
||||
return float(frame_index >= episode_length - 10)
|
||||
|
||||
dataset_with_features = add_features(
|
||||
dataset,
|
||||
features={
|
||||
"reward": (
|
||||
reward_values,
|
||||
{"dtype": "float32", "shape": (1,), "names": None},
|
||||
),
|
||||
"success": (
|
||||
compute_success,
|
||||
{"dtype": "float32", "shape": (1,), "names": None},
|
||||
),
|
||||
},
|
||||
repo_id="lerobot/pusht_with_features",
|
||||
)
|
||||
|
||||
print(f"New features: {list(dataset_with_features.meta.features.keys())}")
|
||||
|
||||
print("\n4. Removing the success feature...")
|
||||
dataset_cleaned = remove_feature(
|
||||
dataset_with_features, feature_names="success", repo_id="lerobot/pusht_cleaned"
|
||||
)
|
||||
print(f"Features after removal: {list(dataset_cleaned.meta.features.keys())}")
|
||||
|
||||
print("\n5. Using modify_features to add and remove features simultaneously...")
|
||||
dataset_modified = modify_features(
|
||||
dataset_with_features,
|
||||
add_features={
|
||||
"discount": (
|
||||
np.ones(dataset.meta.total_frames, dtype=np.float32) * 0.99,
|
||||
{"dtype": "float32", "shape": (1,), "names": None},
|
||||
),
|
||||
},
|
||||
remove_features="reward",
|
||||
repo_id="lerobot/pusht_modified",
|
||||
)
|
||||
print(f"Modified features: {list(dataset_modified.meta.features.keys())}")
|
||||
|
||||
print("\n6. Merging train and val splits back together...")
|
||||
merged = merge_datasets([splits["train"], splits["val"]], output_repo_id="lerobot/pusht_merged")
|
||||
print(f"Merged dataset: {merged.meta.total_episodes} episodes")
|
||||
|
||||
print("\n7. Complex workflow example...")
|
||||
|
||||
if len(dataset.meta.camera_keys) > 1:
|
||||
camera_to_remove = dataset.meta.camera_keys[0]
|
||||
print(f"Removing camera: {camera_to_remove}")
|
||||
dataset_no_cam = remove_feature(
|
||||
dataset, feature_names=camera_to_remove, repo_id="pusht_no_first_camera"
|
||||
)
|
||||
print(f"Remaining cameras: {dataset_no_cam.meta.camera_keys}")
|
||||
|
||||
print("\nDone! Check ~/.cache/huggingface/lerobot/ for the created datasets.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -133,4 +133,6 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
@@ -130,4 +130,6 @@ robot.disconnect()
|
||||
leader_arm.disconnect()
|
||||
keyboard.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
@@ -194,4 +194,6 @@ for episode_idx in range(NUM_EPISODES):
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
@@ -200,4 +200,6 @@ log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
phone.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
@@ -362,6 +362,8 @@ def port_droid(
|
||||
lerobot_dataset.save_episode()
|
||||
logging.info("Save_episode")
|
||||
|
||||
lerobot_dataset.finalize()
|
||||
|
||||
if push_to_hub:
|
||||
lerobot_dataset.push_to_hub(
|
||||
# Add openx tag, since it belongs to the openx collection of datasets
|
||||
|
||||
@@ -195,4 +195,6 @@ for episode_idx in range(NUM_EPISODES):
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
@@ -199,4 +199,6 @@ log_say("Stop recording")
|
||||
leader.disconnect()
|
||||
follower.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
"""This script demonstrates how to train ACT Policy on a real-world dataset."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
|
||||
|
||||
def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]:
|
||||
if delta_indices is None:
|
||||
return [0]
|
||||
|
||||
return [i / fps for i in delta_indices]
|
||||
|
||||
|
||||
output_directory = Path("outputs/robot_learning_tutorial/act")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Select your device
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
|
||||
# This specifies the inputs the model will be expecting and the outputs it will produce
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
features = dataset_to_policy_features(dataset_metadata.features)
|
||||
|
||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
||||
|
||||
cfg = ACTConfig(input_features=input_features, output_features=output_features)
|
||||
policy = ACTPolicy(cfg)
|
||||
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
|
||||
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
# To perform action chunking, ACT expects a given number of actions as targets
|
||||
delta_timestamps = {
|
||||
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
|
||||
}
|
||||
|
||||
# add image features if they are present
|
||||
delta_timestamps |= {
|
||||
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
|
||||
}
|
||||
|
||||
# Instantiate the dataset
|
||||
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
|
||||
|
||||
# Create the optimizer and dataloader for offline training
|
||||
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
|
||||
batch_size = 32
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
# Number of training steps and logging frequency
|
||||
training_steps = 1
|
||||
log_freq = 1
|
||||
|
||||
# Run training loop
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = preprocessor(batch)
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
|
||||
# Save the policy checkpoint, alongside the pre/post processors
|
||||
policy.save_pretrained(output_directory)
|
||||
preprocessor.save_pretrained(output_directory)
|
||||
postprocessor.save_pretrained(output_directory)
|
||||
|
||||
# Save all assets to the Hub
|
||||
policy.push_to_hub("fracapuano/robot_learning_tutorial_act")
|
||||
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
|
||||
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
|
||||
@@ -0,0 +1,57 @@
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "fracapuano/robot_learning_tutorial_act"
|
||||
model = ACTPolicy.from_pretrained(model_id)
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
preprocess, postprocess = make_pre_post_processors(model.config, dataset_stats=dataset_metadata.stats)
|
||||
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_metadata.features, device=device
|
||||
)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
|
||||
action = make_robot_action(action, dataset_metadata.features)
|
||||
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
@@ -0,0 +1,11 @@
|
||||
from lerobot.async_inference.configs import PolicyServerConfig
|
||||
from lerobot.async_inference.policy_server import serve
|
||||
|
||||
host = ... # something like "127.0.0.1" if you're exposing to localhost
|
||||
port = ... # something like 8080
|
||||
|
||||
config = PolicyServerConfig(
|
||||
host=host,
|
||||
port=port,
|
||||
)
|
||||
serve(config)
|
||||
@@ -0,0 +1,55 @@
|
||||
import threading
|
||||
|
||||
from lerobot.async_inference.configs import RobotClientConfig
|
||||
from lerobot.async_inference.helpers import visualize_action_queue_size
|
||||
from lerobot.async_inference.robot_client import RobotClient
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.robots.so100_follower import SO100FollowerConfig
|
||||
|
||||
# these cameras must match the ones expected by the policy - find your cameras with lerobot-find-cameras
|
||||
# check the config.json on the Hub for the policy you are using to see the expected camera specs
|
||||
camera_cfg = {
|
||||
"up": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_cfg)
|
||||
|
||||
server_address = ... # something like "127.0.0.1:8080" if using localhost
|
||||
|
||||
# 3. Create client configuration
|
||||
client_cfg = RobotClientConfig(
|
||||
robot=robot_cfg,
|
||||
server_address=server_address,
|
||||
policy_device="mps",
|
||||
policy_type="act",
|
||||
pretrained_name_or_path="fracapuano/robot_learning_tutorial_act",
|
||||
chunk_size_threshold=0.5, # g
|
||||
actions_per_chunk=50, # make sure this is less than the max actions of the policy
|
||||
)
|
||||
|
||||
# 4. Create and start client
|
||||
client = RobotClient(client_cfg)
|
||||
|
||||
# 5. Provide a textual description of the task
|
||||
task = ...
|
||||
|
||||
if client.start():
|
||||
# Start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
|
||||
action_receiver_thread.start()
|
||||
|
||||
try:
|
||||
# Run the control loop
|
||||
client.control_loop(task)
|
||||
except KeyboardInterrupt:
|
||||
client.stop()
|
||||
action_receiver_thread.join()
|
||||
# (Optionally) plot the action queue size
|
||||
visualize_action_queue_size(client.action_queue_size)
|
||||
@@ -0,0 +1,99 @@
|
||||
"""This script demonstrates how to train Diffusion Policy on a real-world dataset."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
|
||||
|
||||
def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]:
|
||||
if delta_indices is None:
|
||||
return [0]
|
||||
|
||||
return [i / fps for i in delta_indices]
|
||||
|
||||
|
||||
output_directory = Path("outputs/robot_learning_tutorial/diffusion")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Select your device
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
|
||||
# This specifies the inputs the model will be expecting and the outputs it will produce
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
features = dataset_to_policy_features(dataset_metadata.features)
|
||||
|
||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
||||
|
||||
cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
|
||||
policy = DiffusionPolicy(cfg)
|
||||
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
|
||||
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
# To perform action chunking, ACT expects a given number of actions as targets
|
||||
delta_timestamps = {
|
||||
"observation.state": make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps),
|
||||
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
|
||||
}
|
||||
|
||||
# add image features if they are present
|
||||
delta_timestamps |= {
|
||||
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
|
||||
}
|
||||
|
||||
# Instantiate the dataset
|
||||
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
|
||||
|
||||
# Create the optimizer and dataloader for offline training
|
||||
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
|
||||
batch_size = 32
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
# Number of training steps and logging frequency
|
||||
training_steps = 1
|
||||
log_freq = 1
|
||||
|
||||
# Run training loop
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = preprocessor(batch)
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
|
||||
# Save the policy checkpoint, alongside the pre/post processors
|
||||
policy.save_pretrained(output_directory)
|
||||
preprocessor.save_pretrained(output_directory)
|
||||
postprocessor.save_pretrained(output_directory)
|
||||
|
||||
# Save all assets to the Hub
|
||||
policy.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
|
||||
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
|
||||
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
|
||||
@@ -0,0 +1,60 @@
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "fracapuano/robot_learning_tutorial_diffusion"
|
||||
|
||||
model = DiffusionPolicy.from_pretrained(model_id)
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config, model_id, dataset_stats=dataset_metadata.stats
|
||||
)
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_metadata.features, device=device
|
||||
)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_metadata.features)
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
@@ -0,0 +1,67 @@
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "lerobot/pi0_base"
|
||||
|
||||
model = PI0Policy.from_pretrained(model_id)
|
||||
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config,
|
||||
model_id,
|
||||
# This overrides allows to run on MPS, otherwise defaults to CUDA (if available)
|
||||
preprocessor_overrides={"device_processor": {"device": str(device)}},
|
||||
)
|
||||
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"base_0_rgb": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"left_wrist_0_rgb": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
"right_wrist_0_rgb": OpenCVCameraConfig(index_or_path=2, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
task = "" # something like "pick the red block"
|
||||
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
|
||||
|
||||
# This is used to match the raw observation keys to the keys expected by the policy
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
|
||||
)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_features)
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
@@ -0,0 +1,345 @@
|
||||
import multiprocessing as mp
|
||||
import signal
|
||||
from pathlib import Path
|
||||
from queue import Empty, Full
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.gym_manipulator import make_robot_env
|
||||
from lerobot.robots.so100_follower import SO100FollowerConfig
|
||||
from lerobot.teleoperators.so100_leader import SO100LeaderConfig
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
|
||||
LOG_EVERY = 10
|
||||
SEND_EVERY = 10
|
||||
|
||||
|
||||
def run_learner(
|
||||
transitions_queue: mp.Queue,
|
||||
parameters_queue: mp.Queue,
|
||||
shutdown_event: mp.Event,
|
||||
policy_learner: SACPolicy,
|
||||
online_buffer: ReplayBuffer,
|
||||
offline_buffer: ReplayBuffer,
|
||||
lr: float = 3e-4,
|
||||
batch_size: int = 32,
|
||||
device: torch.device = "mps",
|
||||
):
|
||||
"""The learner process - trains SAC policy on transitions streamed from the actor, updating parameters
|
||||
for the actor to adopt."""
|
||||
policy_learner.train()
|
||||
policy_learner.to(device)
|
||||
|
||||
# Create Adam optimizer from scratch - simple and clean
|
||||
optimizer = optim.Adam(policy_learner.parameters(), lr=lr)
|
||||
|
||||
print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
|
||||
print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")
|
||||
|
||||
training_step = 0
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
# retrieve incoming transitions from the actor process
|
||||
try:
|
||||
transitions = transitions_queue.get(timeout=0.1)
|
||||
for transition in transitions:
|
||||
# HIL-SERL: Add ALL transitions to online buffer
|
||||
online_buffer.add(**transition)
|
||||
|
||||
# HIL-SERL: Add ONLY human intervention transitions to offline buffer
|
||||
is_intervention = transition.get("complementary_info", {}).get("is_intervention", False)
|
||||
if is_intervention:
|
||||
offline_buffer.add(**transition)
|
||||
print(
|
||||
f"[LEARNER] Human intervention detected! Added to offline buffer (now {len(offline_buffer)} transitions)"
|
||||
)
|
||||
|
||||
except Empty:
|
||||
pass # No transitions available, continue
|
||||
|
||||
# Train if we have enough data
|
||||
if len(online_buffer) >= policy_learner.config.online_step_before_learning:
|
||||
# Sample from online buffer (autonomous + human data)
|
||||
online_batch = online_buffer.sample(batch_size // 2)
|
||||
|
||||
# Sample from offline buffer (human demonstrations only, either precollected or at runtime)
|
||||
offline_batch = offline_buffer.sample(batch_size // 2)
|
||||
|
||||
# Combine batches - this is the key HIL-SERL mechanism!
|
||||
batch = {}
|
||||
for key in online_batch:
|
||||
if key in offline_batch:
|
||||
batch[key] = torch.cat([online_batch[key], offline_batch[key]], dim=0)
|
||||
else:
|
||||
batch[key] = online_batch[key]
|
||||
|
||||
loss, _ = policy_learner.forward(batch)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
training_step += 1
|
||||
|
||||
if training_step % LOG_EVERY == 0:
|
||||
print(
|
||||
f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, "
|
||||
f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
|
||||
)
|
||||
|
||||
# Send updated parameters to actor every 10 training steps
|
||||
if training_step % SEND_EVERY == 0:
|
||||
try:
|
||||
state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()}
|
||||
parameters_queue.put_nowait(state_dict)
|
||||
print("[LEARNER] Sent updated parameters to actor")
|
||||
except Full:
|
||||
# Missing write due to queue not being consumed (should happen rarely)
|
||||
pass
|
||||
|
||||
print("[LEARNER] Learner process finished")
|
||||
|
||||
|
||||
def run_actor(
|
||||
transitions_queue: mp.Queue,
|
||||
parameters_queue: mp.Queue,
|
||||
shutdown_event: mp.Event,
|
||||
policy_actor: SACPolicy,
|
||||
reward_classifier: Classifier,
|
||||
env_cfg: HILSerlRobotEnvConfig,
|
||||
device: torch.device = "mps",
|
||||
output_directory: Path | None = None,
|
||||
):
|
||||
"""The actor process - interacts with environment and collects data.
|
||||
The policy is frozen and only the parameters are updated, popping the most recent ones from a queue."""
|
||||
policy_actor.eval()
|
||||
policy_actor.to(device)
|
||||
|
||||
reward_classifier.eval()
|
||||
reward_classifier.to(device)
|
||||
|
||||
# Create robot environment inside the actor process
|
||||
env, teleop_device = make_robot_env(env_cfg)
|
||||
|
||||
try:
|
||||
for episode in range(MAX_EPISODES):
|
||||
if shutdown_event.is_set():
|
||||
break
|
||||
|
||||
obs, _info = env.reset()
|
||||
episode_reward = 0.0
|
||||
step = 0
|
||||
episode_transitions = []
|
||||
|
||||
print(f"[ACTOR] Starting episode {episode + 1}")
|
||||
|
||||
while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
|
||||
try:
|
||||
new_params = parameters_queue.get_nowait()
|
||||
policy_actor.load_state_dict(new_params)
|
||||
print("[ACTOR] Updated policy parameters from learner")
|
||||
except Empty: # No new updated parameters available from learner, waiting
|
||||
pass
|
||||
|
||||
# Get action from policy
|
||||
policy_obs = make_policy_obs(obs, device=device)
|
||||
action_tensor = policy_actor.select_action(policy_obs) # predicts a single action
|
||||
action = action_tensor.squeeze(0).cpu().numpy()
|
||||
|
||||
# Step environment
|
||||
next_obs, _env_reward, terminated, truncated, _info = env.step(action)
|
||||
done = terminated or truncated
|
||||
|
||||
# Predict reward
|
||||
policy_next_obs = make_policy_obs(next_obs, device=device)
|
||||
reward = reward_classifier.predict_reward(policy_next_obs)
|
||||
|
||||
if reward >= 1.0 and not done: # success detected! halt episode
|
||||
terminated = True
|
||||
done = True
|
||||
|
||||
# In HIL-SERL, human interventions come from the teleop device
|
||||
is_intervention = False
|
||||
if hasattr(teleop_device, "get_teleop_events"):
|
||||
# Real intervention detection from teleop device
|
||||
teleop_events = teleop_device.get_teleop_events()
|
||||
is_intervention = teleop_events.get(TeleopEvents.IS_INTERVENTION, False)
|
||||
|
||||
# Store transition with intervention metadata
|
||||
transition = {
|
||||
"state": policy_obs,
|
||||
"action": action,
|
||||
"reward": float(reward) if hasattr(reward, "item") else reward,
|
||||
"next_state": policy_next_obs,
|
||||
"done": done,
|
||||
"truncated": truncated,
|
||||
"complementary_info": {
|
||||
"is_intervention": is_intervention,
|
||||
},
|
||||
}
|
||||
|
||||
episode_transitions.append(transition)
|
||||
|
||||
episode_reward += reward
|
||||
step += 1
|
||||
|
||||
obs = next_obs
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
# Send episode transitions to learner
|
||||
transitions_queue.put_nowait(episode_transitions)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("[ACTOR] Interrupted by user")
|
||||
finally:
|
||||
# Clean up
|
||||
if hasattr(env, "robot") and env.robot.is_connected:
|
||||
env.robot.disconnect()
|
||||
if teleop_device and hasattr(teleop_device, "disconnect"):
|
||||
teleop_device.disconnect()
|
||||
if output_directory is not None:
|
||||
policy_actor.save_pretrained(output_directory)
|
||||
print(f"[ACTOR] Latest actor policy saved at: {output_directory}")
|
||||
|
||||
print("[ACTOR] Actor process finished")
|
||||
|
||||
|
||||
def make_policy_obs(obs, device: torch.device = "cpu"):
|
||||
return {
|
||||
"observation.state": torch.from_numpy(obs["agent_pos"]).float().unsqueeze(0).to(device),
|
||||
**{
|
||||
f"observation.image.{k}": torch.from_numpy(obs["pixels"][k]).float().unsqueeze(0).to(device)
|
||||
for k in obs["pixels"]
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
"""Main function - coordinates actor and learner processes."""
|
||||
|
||||
device = "mps" # or "cuda" or "cpu"
|
||||
output_directory = Path("outputs/robot_learning_tutorial/hil_serl")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ...
|
||||
leader_port = ...
|
||||
|
||||
# the robot ids are used the load the right calibration files
|
||||
follower_id = ...
|
||||
leader_id = ...
|
||||
|
||||
# A pretrained model (to be used in-distribution!)
|
||||
reward_classifier_id = "fracapuano/reward_classifier_hil_serl_example"
|
||||
reward_classifier = Classifier.from_pretrained(reward_classifier_id)
|
||||
|
||||
reward_classifier.to(device)
|
||||
reward_classifier.eval()
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
# Robot and environment configuration
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id)
|
||||
teleop_cfg = SO100LeaderConfig(port=leader_port, id=leader_id)
|
||||
processor_cfg = HILSerlProcessorConfig(control_mode="leader")
|
||||
|
||||
env_cfg = HILSerlRobotEnvConfig(robot=robot_cfg, teleop=teleop_cfg, processor=processor_cfg)
|
||||
|
||||
# Create robot environment
|
||||
env, teleop_device = make_robot_env(env_cfg)
|
||||
|
||||
obs_features = hw_to_dataset_features(env.robot.observation_features, "observation")
|
||||
action_features = hw_to_dataset_features(env.robot.action_features, "action")
|
||||
|
||||
# Create SAC policy for action selection
|
||||
policy_cfg = SACConfig(
|
||||
device=device,
|
||||
input_features=obs_features,
|
||||
output_features=action_features,
|
||||
)
|
||||
|
||||
policy_actor = SACPolicy(policy_cfg)
|
||||
policy_learner = SACPolicy(policy_cfg)
|
||||
|
||||
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
|
||||
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)
|
||||
|
||||
# Online buffer: initialized from scratch
|
||||
online_replay_buffer = ReplayBuffer(device=device, state_keys=list(obs_features.keys()))
|
||||
# Offline buffer: Created from dataset (pre-populated it with demonstrations)
|
||||
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
lerobot_dataset=offline_dataset, device=device, state_keys=list(obs_features.keys())
|
||||
)
|
||||
|
||||
# Create communication channels between learner and actor processes
|
||||
transitions_queue = mp.Queue(maxsize=10)
|
||||
parameters_queue = mp.Queue(maxsize=2)
|
||||
shutdown_event = mp.Event()
|
||||
|
||||
|
||||
# Signal handler for graceful shutdown
|
||||
def signal_handler(sig):
|
||||
print(f"\nSignal {sig} received, shutting down...")
|
||||
shutdown_event.set()
|
||||
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
# Create processes
|
||||
learner_process = mp.Process(
|
||||
target=run_learner,
|
||||
args=(
|
||||
transitions_queue,
|
||||
parameters_queue,
|
||||
shutdown_event,
|
||||
policy_learner,
|
||||
online_replay_buffer,
|
||||
offline_replay_buffer,
|
||||
),
|
||||
kwargs={"device": device}, # can run on accelerated hardware for training
|
||||
)
|
||||
|
||||
actor_process = mp.Process(
|
||||
target=run_actor,
|
||||
args=(
|
||||
transitions_queue,
|
||||
parameters_queue,
|
||||
shutdown_event,
|
||||
policy_actor,
|
||||
reward_classifier,
|
||||
env_cfg,
|
||||
output_directory,
|
||||
),
|
||||
kwargs={"device": "cpu"}, # actor is frozen, can run on CPU or accelerate for inference
|
||||
)
|
||||
|
||||
learner_process.start()
|
||||
actor_process.start()
|
||||
|
||||
try:
|
||||
# Wait for actor to finish (it controls the episode loop)
|
||||
actor_process.join()
|
||||
shutdown_event.set()
|
||||
learner_process.join(timeout=10)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("Main process interrupted")
|
||||
shutdown_event.set()
|
||||
actor_process.join(timeout=5)
|
||||
learner_process.join(timeout=10)
|
||||
|
||||
finally:
|
||||
if learner_process.is_alive():
|
||||
learner_process.terminate()
|
||||
if actor_process.is_alive():
|
||||
actor_process.terminate()
|
||||
@@ -0,0 +1,62 @@
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
|
||||
# Device to use for training
|
||||
device = "mps" # or "cuda", or "cpu"
|
||||
|
||||
# Load the dataset used for training
|
||||
repo_id = "lerobot/example_hil_serl_dataset"
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
|
||||
# Configure the policy to extract features from the image frames
|
||||
camera_keys = dataset.meta.camera_keys
|
||||
|
||||
config = RewardClassifierConfig(
|
||||
num_cameras=len(camera_keys),
|
||||
device=device,
|
||||
# backbone model to extract features from the image frames
|
||||
model_name="microsoft/resnet-18",
|
||||
)
|
||||
|
||||
# Make policy, preprocessor, and optimizer
|
||||
policy = make_policy(config, ds_meta=dataset.meta)
|
||||
optimizer = config.get_optimizer_preset().build(policy.parameters())
|
||||
preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats)
|
||||
|
||||
|
||||
classifier_id = "fracapuano/reward_classifier_hil_serl_example"
|
||||
|
||||
# Instantiate a dataloader
|
||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
|
||||
|
||||
# Training loop
|
||||
num_epochs = 5
|
||||
for epoch in range(num_epochs):
|
||||
total_loss = 0
|
||||
total_accuracy = 0
|
||||
for batch in dataloader:
|
||||
# Preprocess the batch and move it to the correct device.
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Forward pass
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
# Backward pass and optimization
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
total_accuracy += output_dict["accuracy"]
|
||||
|
||||
avg_loss = total_loss / len(dataloader)
|
||||
avg_accuracy = total_accuracy / len(dataloader)
|
||||
print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.2f}%")
|
||||
|
||||
print("Training finished!")
|
||||
|
||||
# You can now save the trained policy.
|
||||
policy.push_to_hub(classifier_id)
|
||||
@@ -0,0 +1,66 @@
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "lerobot/smolvla_base"
|
||||
|
||||
model = SmolVLAPolicy.from_pretrained(model_id)
|
||||
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config,
|
||||
model_id,
|
||||
# This overrides allows to run on MPS, otherwise defaults to CUDA (if available)
|
||||
preprocessor_overrides={"device_processor": {"device": str(device)}},
|
||||
)
|
||||
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"camera1": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"camera2": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
task = "" # something like "pick the red block"
|
||||
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
|
||||
|
||||
# This is used to match the raw observation keys to the keys expected by the policy
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
|
||||
)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_features)
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
+98
-82
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
|
||||
|
||||
[project]
|
||||
name = "lerobot"
|
||||
version = "0.3.4"
|
||||
version = "0.4.0"
|
||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||
readme = "README.md"
|
||||
license = { text = "Apache-2.0" }
|
||||
@@ -59,28 +59,30 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici
|
||||
dependencies = [
|
||||
|
||||
# Hugging Face dependencies
|
||||
"datasets>=4.0.0",
|
||||
"diffusers>=0.27.2",
|
||||
"huggingface-hub[hf-transfer,cli]>=0.34.2",
|
||||
"datasets>=4.0.0,<4.2.0",
|
||||
"diffusers>=0.27.2,<0.36.0",
|
||||
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
|
||||
"accelerate>=1.10.0,<2.0.0",
|
||||
|
||||
# Core dependencies
|
||||
"cmake>=3.29.0.1",
|
||||
"einops>=0.8.0",
|
||||
"opencv-python-headless>=4.9.0",
|
||||
"av>=14.2.0",
|
||||
"jsonlines>=4.0.0",
|
||||
"packaging>=24.2",
|
||||
"pynput>=1.7.7",
|
||||
"pyserial>=3.5",
|
||||
"wandb>=0.20.0",
|
||||
"setuptools>=71.0.0,<81.0.0",
|
||||
"cmake>=3.29.0.1,<4.2.0",
|
||||
"einops>=0.8.0,<0.9.0",
|
||||
"opencv-python-headless>=4.9.0,<4.13.0",
|
||||
"av>=15.0.0,<16.0.0",
|
||||
"jsonlines>=4.0.0,<5.0.0",
|
||||
"packaging>=24.2,<26.0",
|
||||
"pynput>=1.7.7,<1.9.0",
|
||||
"pyserial>=3.5,<4.0",
|
||||
"wandb>=0.20.0,<0.22.0", # TODO: Bumb dependency (compatible with protobuf)
|
||||
|
||||
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
|
||||
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
|
||||
"torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency
|
||||
|
||||
"draccus==0.10.0", # TODO: Remove ==
|
||||
"gymnasium>=0.29.1,<1.0.0", # TODO: Bumb dependency
|
||||
"rerun-sdk>=0.21.0,<0.23.0", # TODO: Bumb dependency
|
||||
"gymnasium>=1.1.1,<2.0.0",
|
||||
"rerun-sdk>=0.24.0,<0.27.0",
|
||||
|
||||
# Support dependencies
|
||||
"deepdiff>=7.0.1,<9.0.0",
|
||||
@@ -92,51 +94,56 @@ dependencies = [
|
||||
[project.optional-dependencies]
|
||||
|
||||
# Common
|
||||
pygame-dep = ["pygame>=2.5.1"]
|
||||
placo-dep = ["placo>=0.9.6"]
|
||||
transformers-dep = ["transformers>=4.52.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]
|
||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.10.0"]
|
||||
transformers-dep = ["transformers>=4.53.0,<5.0.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bumb dependency (compatible with wandb)
|
||||
|
||||
# Motors
|
||||
feetech = ["feetech-servo-sdk>=1.0.0"]
|
||||
dynamixel = ["dynamixel-sdk>=3.7.31"]
|
||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
|
||||
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
|
||||
|
||||
# Robots
|
||||
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0"]
|
||||
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
|
||||
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
|
||||
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1"]
|
||||
reachy2 = ["reachy2_sdk>=1.0.14"]
|
||||
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
|
||||
reachy2 = ["reachy2_sdk>=1.0.14,<1.1.0"]
|
||||
kinematics = ["lerobot[placo-dep]"]
|
||||
intelrealsense = [
|
||||
"pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
|
||||
"pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'",
|
||||
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
|
||||
"pyrealsense2-macosx>=2.54,<2.55.0 ; sys_platform == 'darwin'",
|
||||
]
|
||||
phone = ["hebi-py>=2.8.0", "teleop>=0.1.0"]
|
||||
# stretch = [
|
||||
# "hello-robot-stretch-body>=0.7.27 ; sys_platform == 'linux'",
|
||||
# "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
|
||||
# "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"
|
||||
# ] # TODO: Currently not supported
|
||||
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
|
||||
|
||||
# Policies
|
||||
pi0 = ["lerobot[transformers-dep]"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.11", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
|
||||
groot = [
|
||||
"lerobot[transformers-dep]",
|
||||
"peft>=0.13.0,<1.0.0",
|
||||
"dm-tree>=0.1.8,<1.0.0",
|
||||
"timm>=1.0.0,<1.1.0",
|
||||
"safetensors>=0.4.3,<1.0.0",
|
||||
"Pillow>=10.0.0,<13.0.0",
|
||||
"decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
|
||||
"ninja>=1.11.1,<2.0.0",
|
||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||
]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3"]
|
||||
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
|
||||
test = ["pytest>=8.1.0", "pytest-timeout>=2.4.0", "pytest-cov>=5.0.0", "mock-serial>=0.0.1 ; sys_platform != 'win32'"]
|
||||
video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
|
||||
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
|
||||
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||
|
||||
# Simulation
|
||||
aloha = ["gym-aloha>=0.1.1"]
|
||||
pusht = ["gym-pusht>=0.1.5", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||
xarm = ["gym-xarm>=0.1.1"]
|
||||
aloha = ["gym-aloha>=0.1.2,<0.2.0"]
|
||||
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||
libero = ["lerobot[transformers-dep]", "libero @ git+https://github.com/huggingface/lerobot-libero.git@main#egg=libero"]
|
||||
|
||||
metaworld = ["metaworld==3.0.0"]
|
||||
|
||||
# All
|
||||
all = [
|
||||
@@ -147,8 +154,9 @@ all = [
|
||||
"lerobot[reachy2]",
|
||||
"lerobot[kinematics]",
|
||||
"lerobot[intelrealsense]",
|
||||
"lerobot[pi0]",
|
||||
"lerobot[pi]",
|
||||
"lerobot[smolvla]",
|
||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[async]",
|
||||
"lerobot[dev]",
|
||||
@@ -156,9 +164,9 @@ all = [
|
||||
"lerobot[video_benchmark]",
|
||||
"lerobot[aloha]",
|
||||
"lerobot[pusht]",
|
||||
"lerobot[xarm]",
|
||||
"lerobot[phone]",
|
||||
"lerobot[libero]",
|
||||
"lerobot[metaworld]",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
@@ -175,6 +183,7 @@ lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main"
|
||||
lerobot-info="lerobot.scripts.lerobot_info:main"
|
||||
lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
|
||||
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
[tool.setuptools.packages.find]
|
||||
@@ -232,9 +241,6 @@ exclude_dirs = [
|
||||
"tests",
|
||||
"benchmarks",
|
||||
"src/lerobot/datasets/push_dataset_to_hub",
|
||||
"src/lerobot/datasets/v2/convert_dataset_v1_to_v2",
|
||||
"src/lerobot/policies/pi0/conversion_scripts",
|
||||
"src/lerobot/scripts/push_dataset_to_hub.py",
|
||||
]
|
||||
skips = ["B101", "B311", "B404", "B603", "B615"]
|
||||
|
||||
@@ -249,6 +255,8 @@ default.extend-ignore-identifiers-re = [
|
||||
"pn",
|
||||
"ser",
|
||||
"ein",
|
||||
"thw",
|
||||
"inpt",
|
||||
]
|
||||
|
||||
# TODO: Uncomment when ready to use
|
||||
@@ -270,80 +278,88 @@ default.extend-ignore-identifiers-re = [
|
||||
# TODO: Enable mypy gradually module by module across multiple PRs
|
||||
# Uncomment [tool.mypy] first, then uncomment individual module overrides as they get proper type annotations
|
||||
|
||||
# [tool.mypy]
|
||||
# python_version = "3.10"
|
||||
[tool.mypy]
|
||||
python_version = "3.10"
|
||||
ignore_missing_imports = true
|
||||
follow_imports = "skip"
|
||||
# warn_return_any = true
|
||||
# warn_unused_configs = true
|
||||
# ignore_missing_imports = false
|
||||
# strict = true
|
||||
# disallow_untyped_defs = true
|
||||
# disallow_incomplete_defs = true
|
||||
# check_untyped_defs = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.*"
|
||||
ignore_errors = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.envs.*"
|
||||
ignore_errors = false
|
||||
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.utils.*"
|
||||
# # include = "src/lerobot/utils/**/*.py"
|
||||
# ignore_errors = false
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.configs.*"
|
||||
ignore_errors = false
|
||||
|
||||
# extra strictness for configs
|
||||
disallow_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
check_untyped_defs = true
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.configs.*"
|
||||
# # include = "src/lerobot/configs/**/*.py"
|
||||
# module = "lerobot.optim.*"
|
||||
# ignore_errors = false
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.model.*"
|
||||
ignore_errors = false
|
||||
|
||||
# # Data processing modules
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.processor.*"
|
||||
# # include = "src/lerobot/processor/**/*.py"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.datasets.*"
|
||||
# # include = "src/lerobot/datasets/**/*.py"
|
||||
# ignore_errors = false
|
||||
|
||||
# # Core machine learning modules
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.optim.*"
|
||||
# # include = "src/lerobot/optim/**/*.py"
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.model.*"
|
||||
# # include = "src/lerobot/model/**/*.py"
|
||||
|
||||
# # Hardware interfaces
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.cameras.*"
|
||||
# # include = "src/lerobot/cameras/**/*.py"
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.cameras.*"
|
||||
ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.motors.*"
|
||||
# # include = "src/lerobot/motors/**/*.py"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.robots.*"
|
||||
# # include = "src/lerobot/robots/**/*.py"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.teleoperators.*"
|
||||
# # include = "src/lerobot/teleoperators/**/*.py"
|
||||
# ignore_errors = false
|
||||
|
||||
# # Complex modules (enable these last)
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.policies.*"
|
||||
# # include = "src/lerobot/policies/**/*.py"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.rl.*"
|
||||
# # include = "src/lerobot/rl/**/*.py"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.envs.*"
|
||||
# # include = "src/lerobot/envs/**/*.py"
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.async_inference.*"
|
||||
# # include = "src/lerobot/async_inference/**/*.py"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.transport.*"
|
||||
# # include = "src/lerobot/transport/**/*.py"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.scripts.*"
|
||||
# # include = "src/lerobot/scripts/**/*.py"
|
||||
# ignore_errors = false
|
||||
|
||||
+325
-120
@@ -1,3 +1,4 @@
|
||||
#
|
||||
# This file is autogenerated by pip-compile with Python 3.10
|
||||
# by the following command:
|
||||
#
|
||||
@@ -12,47 +13,62 @@ absl-py==2.3.1
|
||||
# dm-tree
|
||||
# labmaze
|
||||
# mujoco
|
||||
accelerate==1.9.0
|
||||
# via lerobot
|
||||
# tensorboard
|
||||
accelerate==1.11.0
|
||||
# via
|
||||
# lerobot
|
||||
# peft
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.12.15
|
||||
aiohttp==3.13.1
|
||||
# via fsspec
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
antlr4-python3-runtime==4.9.3
|
||||
# via
|
||||
# hydra-core
|
||||
# omegaconf
|
||||
anyio==4.11.0
|
||||
# via
|
||||
# starlette
|
||||
# watchfiles
|
||||
asttokens==3.0.0
|
||||
# via stack-data
|
||||
async-timeout==5.0.1
|
||||
# via aiohttp
|
||||
attrs==25.3.0
|
||||
attrs==25.4.0
|
||||
# via
|
||||
# aiohttp
|
||||
# dm-tree
|
||||
# jsonlines
|
||||
# jsonschema
|
||||
# referencing
|
||||
# rerun-sdk
|
||||
av==15.0.0
|
||||
av==15.1.0
|
||||
# via lerobot
|
||||
blinker==1.9.0
|
||||
# via flask
|
||||
certifi==2025.7.14
|
||||
bddl==1.0.1
|
||||
# via libero
|
||||
certifi==2025.10.5
|
||||
# via
|
||||
# requests
|
||||
# sentry-sdk
|
||||
cffi==1.17.1
|
||||
cffi==2.0.0
|
||||
# via pymunk
|
||||
cfgv==3.4.0
|
||||
# via pre-commit
|
||||
charset-normalizer==3.4.2
|
||||
charset-normalizer==3.4.4
|
||||
# via requests
|
||||
click==8.2.1
|
||||
click==8.3.0
|
||||
# via
|
||||
# flask
|
||||
# uvicorn
|
||||
# wandb
|
||||
cloudpickle==3.1.1
|
||||
# via gymnasium
|
||||
cmake==4.0.3
|
||||
# via
|
||||
# gymnasium
|
||||
# libero
|
||||
cmake==4.1.0
|
||||
# via lerobot
|
||||
cmeel==0.57.3
|
||||
# via
|
||||
@@ -94,27 +110,27 @@ coal-library==3.0.1
|
||||
# via pin
|
||||
contourpy==1.3.2
|
||||
# via matplotlib
|
||||
coverage[toml]==7.10.1
|
||||
coverage[toml]==7.11.0
|
||||
# via pytest-cov
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
datasets==3.6.0
|
||||
datasets==4.1.1
|
||||
# via lerobot
|
||||
debugpy==1.8.15
|
||||
debugpy==1.8.17
|
||||
# via lerobot
|
||||
decorator==5.2.1
|
||||
# via ipython
|
||||
deepdiff==8.5.0
|
||||
deepdiff==8.6.1
|
||||
# via lerobot
|
||||
diffusers==0.34.0
|
||||
diffusers==0.35.2
|
||||
# via lerobot
|
||||
dill==0.3.8
|
||||
dill==0.4.0
|
||||
# via
|
||||
# datasets
|
||||
# multiprocess
|
||||
distlib==0.4.0
|
||||
# via virtualenv
|
||||
dm-control==1.0.14
|
||||
dm-control==1.0.34
|
||||
# via gym-aloha
|
||||
dm-env==1.6
|
||||
# via dm-control
|
||||
@@ -122,29 +138,45 @@ dm-tree==0.1.9
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# lerobot
|
||||
docopt==0.6.2
|
||||
# via num2words
|
||||
draccus==0.10.0
|
||||
# via lerobot
|
||||
dynamixel-sdk==3.7.31
|
||||
dynamixel-sdk==3.8.4
|
||||
# via lerobot
|
||||
easydict==1.13
|
||||
# via libero
|
||||
egl-probe @ git+https://github.com/huggingface/egl_probe.git
|
||||
# via
|
||||
# libero
|
||||
# robomimic
|
||||
eigenpy==3.10.3
|
||||
# via coal-library
|
||||
einops==0.8.1
|
||||
# via lerobot
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
eiquadprog==1.2.9
|
||||
# via placo
|
||||
etils[epath,epy]==1.13.0
|
||||
# via mujoco
|
||||
exceptiongroup==1.3.0
|
||||
# via
|
||||
# anyio
|
||||
# ipython
|
||||
# pytest
|
||||
executing==2.2.0
|
||||
executing==2.2.1
|
||||
# via stack-data
|
||||
farama-notifications==0.0.4
|
||||
# via gymnasium
|
||||
fastapi==0.119.1
|
||||
# via teleop
|
||||
fastjsonschema==2.21.2
|
||||
# via nbformat
|
||||
feetech-servo-sdk==1.0.0
|
||||
# via lerobot
|
||||
filelock==3.18.0
|
||||
filelock==3.20.0
|
||||
# via
|
||||
# datasets
|
||||
# diffusers
|
||||
@@ -152,24 +184,25 @@ filelock==3.18.0
|
||||
# torch
|
||||
# transformers
|
||||
# virtualenv
|
||||
flask==3.1.1
|
||||
# via lerobot
|
||||
fonttools==4.59.0
|
||||
fonttools==4.60.1
|
||||
# via matplotlib
|
||||
frozenlist==1.7.0
|
||||
frozenlist==1.8.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec[http]==2025.3.0
|
||||
fsspec[http]==2025.9.0
|
||||
# via
|
||||
# datasets
|
||||
# etils
|
||||
# huggingface-hub
|
||||
# torch
|
||||
future==1.0.0
|
||||
# via libero
|
||||
gitdb==4.0.12
|
||||
# via gitpython
|
||||
gitpython==3.1.45
|
||||
# via wandb
|
||||
glfw==2.9.0
|
||||
glfw==2.10.0
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
@@ -177,61 +210,79 @@ grpcio==1.73.1
|
||||
# via
|
||||
# grpcio-tools
|
||||
# lerobot
|
||||
# reachy2-sdk
|
||||
# reachy2-sdk-api
|
||||
# tensorboard
|
||||
grpcio-tools==1.73.1
|
||||
# via
|
||||
# lerobot
|
||||
# reachy2-sdk-api
|
||||
gym-aloha==0.1.3
|
||||
# via lerobot
|
||||
gym-aloha==0.1.1
|
||||
gym-hil==0.1.13
|
||||
# via lerobot
|
||||
gym-hil==0.1.10
|
||||
gym-pusht==0.1.6
|
||||
# via lerobot
|
||||
gym-pusht==0.1.5
|
||||
# via lerobot
|
||||
gym-xarm==0.1.1
|
||||
# via lerobot
|
||||
gymnasium==0.29.1
|
||||
gymnasium==1.2.1
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gym-pusht
|
||||
# gym-xarm
|
||||
# gymnasium-robotics
|
||||
# lerobot
|
||||
# pettingzoo
|
||||
gymnasium-robotics==1.2.4
|
||||
# via gym-xarm
|
||||
# libero
|
||||
# metaworld
|
||||
h11==0.16.0
|
||||
# via uvicorn
|
||||
h5py==3.15.1
|
||||
# via robomimic
|
||||
hebi-py==2.11.0
|
||||
# via lerobot
|
||||
hf-transfer==0.1.9
|
||||
# via huggingface-hub
|
||||
hf-xet==1.1.5
|
||||
hf-xet==1.1.10
|
||||
# via huggingface-hub
|
||||
hidapi==0.14.0.post4
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
huggingface-hub[cli,hf-transfer]==0.34.3
|
||||
httptools==0.7.1
|
||||
# via uvicorn
|
||||
huggingface-hub[cli,hf-transfer]==0.35.3
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# tokenizers
|
||||
# transformers
|
||||
identify==2.6.12
|
||||
hydra-core==1.3.2
|
||||
# via libero
|
||||
identify==2.6.15
|
||||
# via pre-commit
|
||||
idna==3.10
|
||||
idna==3.11
|
||||
# via
|
||||
# anyio
|
||||
# requests
|
||||
# yarl
|
||||
imageio[ffmpeg]==2.37.0
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gymnasium-robotics
|
||||
# lerobot
|
||||
# metaworld
|
||||
# robomimic
|
||||
# scikit-image
|
||||
imageio-ffmpeg==0.6.0
|
||||
# via imageio
|
||||
# via
|
||||
# imageio
|
||||
# robomimic
|
||||
importlib-metadata==8.7.0
|
||||
# via diffusers
|
||||
iniconfig==2.1.0
|
||||
importlib-resources==6.5.2
|
||||
# via etils
|
||||
iniconfig==2.3.0
|
||||
# via pytest
|
||||
inquirerpy==0.3.4
|
||||
# via huggingface-hub
|
||||
@@ -239,50 +290,71 @@ ipython==8.37.0
|
||||
# via meshcat
|
||||
ischedule==1.2.7
|
||||
# via placo
|
||||
itsdangerous==2.2.0
|
||||
# via flask
|
||||
jedi==0.19.2
|
||||
# via ipython
|
||||
jinja2==3.1.6
|
||||
# via
|
||||
# flask
|
||||
# gymnasium-robotics
|
||||
# torch
|
||||
# via torch
|
||||
jsonlines==4.0.0
|
||||
# via lerobot
|
||||
kiwisolver==1.4.8
|
||||
jsonschema==4.25.1
|
||||
# via nbformat
|
||||
jsonschema-specifications==2025.9.1
|
||||
# via jsonschema
|
||||
jupyter-core==5.9.1
|
||||
# via nbformat
|
||||
jupytext==1.18.1
|
||||
# via bddl
|
||||
kiwisolver==1.4.9
|
||||
# via matplotlib
|
||||
labmaze==1.0.6
|
||||
# via dm-control
|
||||
lazy-loader==0.4
|
||||
# via scikit-image
|
||||
lxml==6.0.0
|
||||
libero @ git+https://github.com/huggingface/lerobot-libero.git@main
|
||||
# via lerobot
|
||||
llvmlite==0.45.1
|
||||
# via numba
|
||||
lxml==6.0.2
|
||||
# via dm-control
|
||||
markupsafe==3.0.2
|
||||
markdown==3.9
|
||||
# via tensorboard
|
||||
markdown-it-py==4.0.0
|
||||
# via
|
||||
# jupytext
|
||||
# mdit-py-plugins
|
||||
markupsafe==3.0.3
|
||||
# via
|
||||
# flask
|
||||
# jinja2
|
||||
# werkzeug
|
||||
matplotlib==3.10.5
|
||||
# via lerobot
|
||||
matplotlib-inline==0.1.7
|
||||
matplotlib==3.10.7
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
matplotlib-inline==0.2.1
|
||||
# via ipython
|
||||
mdit-py-plugins==0.5.0
|
||||
# via jupytext
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mergedeep==1.3.4
|
||||
# via draccus
|
||||
meshcat==0.3.2
|
||||
# via placo
|
||||
metaworld==3.0.0
|
||||
# via lerobot
|
||||
mock-serial==0.0.1
|
||||
# via lerobot
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
mujoco==2.3.7
|
||||
mujoco==3.3.7
|
||||
# via
|
||||
# dm-control
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gym-xarm
|
||||
# gymnasium-robotics
|
||||
multidict==6.6.3
|
||||
# libero
|
||||
# metaworld
|
||||
# robosuite
|
||||
multidict==6.7.0
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
@@ -290,17 +362,25 @@ multiprocess==0.70.16
|
||||
# via datasets
|
||||
mypy-extensions==1.1.0
|
||||
# via typing-inspect
|
||||
nbformat==5.10.4
|
||||
# via jupytext
|
||||
networkx==3.4.2
|
||||
# via
|
||||
# bddl
|
||||
# scikit-image
|
||||
# torch
|
||||
ninja==1.13.0
|
||||
# via lerobot
|
||||
nodeenv==1.9.1
|
||||
# via pre-commit
|
||||
num2words==0.5.14
|
||||
# via lerobot
|
||||
numba==0.62.1
|
||||
# via robosuite
|
||||
numpy==2.2.6
|
||||
# via
|
||||
# accelerate
|
||||
# bddl
|
||||
# cmeel-boost
|
||||
# contourpy
|
||||
# datasets
|
||||
@@ -309,25 +389,43 @@ numpy==2.2.6
|
||||
# dm-env
|
||||
# dm-tree
|
||||
# gymnasium
|
||||
# gymnasium-robotics
|
||||
# h5py
|
||||
# hebi-py
|
||||
# imageio
|
||||
# labmaze
|
||||
# libero
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# metaworld
|
||||
# mujoco
|
||||
# numba
|
||||
# opencv-python
|
||||
# opencv-python-headless
|
||||
# pandas
|
||||
# pettingzoo
|
||||
# peft
|
||||
# pyquaternion
|
||||
# reachy2-sdk
|
||||
# rerun-sdk
|
||||
# robomimic
|
||||
# robosuite
|
||||
# scikit-image
|
||||
# scipy
|
||||
# shapely
|
||||
# teleop
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# tifffile
|
||||
# torchvision
|
||||
# transformers
|
||||
# transforms3d
|
||||
omegaconf==2.3.0
|
||||
# via hydra-core
|
||||
opencv-python==4.12.0.88
|
||||
# via gym-pusht
|
||||
# via
|
||||
# gym-pusht
|
||||
# libero
|
||||
# reachy2-sdk
|
||||
# robosuite
|
||||
opencv-python-headless==4.12.0.88
|
||||
# via lerobot
|
||||
orderly-set==5.5.0
|
||||
@@ -337,53 +435,63 @@ packaging==25.0
|
||||
# accelerate
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
# hydra-core
|
||||
# jupytext
|
||||
# lazy-loader
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# peft
|
||||
# pytest
|
||||
# reachy2-sdk
|
||||
# scikit-image
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# transformers
|
||||
# wandb
|
||||
pandas==2.3.1
|
||||
pandas==2.3.3
|
||||
# via
|
||||
# datasets
|
||||
# lerobot
|
||||
parso==0.8.4
|
||||
parso==0.8.5
|
||||
# via jedi
|
||||
pettingzoo==1.24.3
|
||||
# via gymnasium-robotics
|
||||
peft==0.17.1
|
||||
# via lerobot
|
||||
pexpect==4.9.0
|
||||
# via ipython
|
||||
pfzy==0.3.4
|
||||
# via inquirerpy
|
||||
pillow==11.3.0
|
||||
pillow==12.0.0
|
||||
# via
|
||||
# diffusers
|
||||
# imageio
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# rerun-sdk
|
||||
# robosuite
|
||||
# scikit-image
|
||||
# tensorboard
|
||||
# torchvision
|
||||
pin==3.4.0
|
||||
# via placo
|
||||
placo==0.9.14
|
||||
# via lerobot
|
||||
platformdirs==4.3.8
|
||||
platformdirs==4.5.0
|
||||
# via
|
||||
# jupyter-core
|
||||
# virtualenv
|
||||
# wandb
|
||||
pluggy==1.6.0
|
||||
# via
|
||||
# pytest
|
||||
# pytest-cov
|
||||
pre-commit==4.2.0
|
||||
pre-commit==4.3.0
|
||||
# via lerobot
|
||||
prompt-toolkit==3.0.51
|
||||
prompt-toolkit==3.0.52
|
||||
# via
|
||||
# inquirerpy
|
||||
# ipython
|
||||
propcache==0.3.2
|
||||
propcache==0.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
@@ -392,11 +500,17 @@ protobuf==6.31.0
|
||||
# dm-control
|
||||
# grpcio-tools
|
||||
# lerobot
|
||||
# reachy2-sdk
|
||||
# reachy2-sdk-api
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# wandb
|
||||
psutil==7.0.0
|
||||
psutil==7.1.1
|
||||
# via
|
||||
# accelerate
|
||||
# imageio
|
||||
# peft
|
||||
# robomimic
|
||||
ptyprocess==0.7.0
|
||||
# via pexpect
|
||||
pure-eval==0.2.3
|
||||
@@ -405,11 +519,13 @@ pyarrow==21.0.0
|
||||
# via
|
||||
# datasets
|
||||
# rerun-sdk
|
||||
pycparser==2.22
|
||||
pycparser==2.23
|
||||
# via cffi
|
||||
pydantic==2.11.7
|
||||
# via wandb
|
||||
pydantic-core==2.33.2
|
||||
pydantic==2.12.3
|
||||
# via
|
||||
# fastapi
|
||||
# wandb
|
||||
pydantic-core==2.41.4
|
||||
# via pydantic
|
||||
pygame==2.6.1
|
||||
# via
|
||||
@@ -424,40 +540,42 @@ pymunk==6.11.1
|
||||
# via
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
pyngrok==7.2.12
|
||||
pyngrok==7.4.1
|
||||
# via meshcat
|
||||
pynput==1.8.1
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
pyobjc-core==11.1
|
||||
pyobjc-core==12.0
|
||||
# via
|
||||
# pyobjc-framework-applicationservices
|
||||
# pyobjc-framework-cocoa
|
||||
# pyobjc-framework-coretext
|
||||
# pyobjc-framework-quartz
|
||||
pyobjc-framework-applicationservices==11.1
|
||||
pyobjc-framework-applicationservices==12.0
|
||||
# via pynput
|
||||
pyobjc-framework-cocoa==11.1
|
||||
pyobjc-framework-cocoa==12.0
|
||||
# via
|
||||
# pyobjc-framework-applicationservices
|
||||
# pyobjc-framework-coretext
|
||||
# pyobjc-framework-quartz
|
||||
pyobjc-framework-coretext==11.1
|
||||
pyobjc-framework-coretext==12.0
|
||||
# via pyobjc-framework-applicationservices
|
||||
pyobjc-framework-quartz==11.1
|
||||
pyobjc-framework-quartz==12.0
|
||||
# via
|
||||
# pynput
|
||||
# pyobjc-framework-applicationservices
|
||||
# pyobjc-framework-coretext
|
||||
pyopengl==3.1.9
|
||||
pyopengl==3.1.10
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
pyparsing==3.2.3
|
||||
pyparsing==3.2.5
|
||||
# via
|
||||
# dm-control
|
||||
# matplotlib
|
||||
pyquaternion==0.9.9
|
||||
# via reachy2-sdk
|
||||
pyrealsense2-macosx==2.54.2
|
||||
# via lerobot
|
||||
pyserial==3.5
|
||||
@@ -465,12 +583,14 @@ pyserial==3.5
|
||||
# dynamixel-sdk
|
||||
# feetech-servo-sdk
|
||||
# lerobot
|
||||
pytest==8.4.1
|
||||
pytest==8.4.2
|
||||
# via
|
||||
# bddl
|
||||
# lerobot
|
||||
# pytest-cov
|
||||
# pytest-timeout
|
||||
pytest-cov==6.2.1
|
||||
# teleop
|
||||
pytest-cov==7.0.0
|
||||
# via lerobot
|
||||
pytest-timeout==2.4.0
|
||||
# via lerobot
|
||||
@@ -478,46 +598,73 @@ python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# matplotlib
|
||||
# pandas
|
||||
python-dotenv==1.1.1
|
||||
# via uvicorn
|
||||
pytz==2025.2
|
||||
# via pandas
|
||||
pyyaml==6.0.2
|
||||
pyyaml==6.0.3
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# draccus
|
||||
# hebi-py
|
||||
# huggingface-hub
|
||||
# jupytext
|
||||
# omegaconf
|
||||
# peft
|
||||
# pre-commit
|
||||
# pyngrok
|
||||
# pyyaml-include
|
||||
# timm
|
||||
# transformers
|
||||
# uvicorn
|
||||
# wandb
|
||||
pyyaml-include==1.4.1
|
||||
# via draccus
|
||||
pyzmq==27.0.0
|
||||
pyzmq==27.1.0
|
||||
# via
|
||||
# lerobot
|
||||
# meshcat
|
||||
regex==2025.7.34
|
||||
reachy2-sdk==1.0.14
|
||||
# via lerobot
|
||||
reachy2-sdk-api==1.0.21
|
||||
# via reachy2-sdk
|
||||
referencing==0.37.0
|
||||
# via
|
||||
# jsonschema
|
||||
# jsonschema-specifications
|
||||
regex==2025.10.23
|
||||
# via
|
||||
# diffusers
|
||||
# transformers
|
||||
requests==2.32.4
|
||||
requests==2.32.5
|
||||
# via
|
||||
# datasets
|
||||
# diffusers
|
||||
# dm-control
|
||||
# huggingface-hub
|
||||
# teleop
|
||||
# transformers
|
||||
# wandb
|
||||
rerun-sdk==0.22.1
|
||||
rerun-sdk==0.26.1
|
||||
# via lerobot
|
||||
rhoban-cmeel-jsoncpp==1.9.4.9
|
||||
# via placo
|
||||
safetensors==0.5.3
|
||||
robomimic==0.2.0
|
||||
# via libero
|
||||
robosuite==1.4.0
|
||||
# via libero
|
||||
rpds-py==0.28.0
|
||||
# via
|
||||
# jsonschema
|
||||
# referencing
|
||||
safetensors==0.6.2
|
||||
# via
|
||||
# accelerate
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# transformers
|
||||
scikit-image==0.25.2
|
||||
# via
|
||||
@@ -526,10 +673,12 @@ scikit-image==0.25.2
|
||||
scipy==1.15.3
|
||||
# via
|
||||
# dm-control
|
||||
# metaworld
|
||||
# robosuite
|
||||
# scikit-image
|
||||
sentry-sdk==2.34.1
|
||||
sentry-sdk==2.42.1
|
||||
# via wandb
|
||||
shapely==2.1.1
|
||||
shapely==2.1.2
|
||||
# via gym-pusht
|
||||
six==1.17.0
|
||||
# via
|
||||
@@ -537,64 +686,106 @@ six==1.17.0
|
||||
# python-dateutil
|
||||
smmap==5.0.2
|
||||
# via gitdb
|
||||
sniffio==1.3.1
|
||||
# via anyio
|
||||
stack-data==0.6.3
|
||||
# via ipython
|
||||
starlette==0.48.0
|
||||
# via fastapi
|
||||
sympy==1.14.0
|
||||
# via torch
|
||||
termcolor==3.1.0
|
||||
teleop==0.1.2
|
||||
# via lerobot
|
||||
tensorboard==2.20.0
|
||||
# via robomimic
|
||||
tensorboard-data-server==0.7.2
|
||||
# via tensorboard
|
||||
tensorboardx==2.6.4
|
||||
# via robomimic
|
||||
termcolor==3.1.0
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
thop==0.1.1.post2209072238
|
||||
# via libero
|
||||
tifffile==2025.5.10
|
||||
# via scikit-image
|
||||
tokenizers==0.21.4
|
||||
timm==1.0.20
|
||||
# via lerobot
|
||||
tokenizers==0.22.1
|
||||
# via transformers
|
||||
toml==0.10.2
|
||||
# via draccus
|
||||
tomli==2.2.1
|
||||
tomli==2.3.0
|
||||
# via
|
||||
# cmeel
|
||||
# coverage
|
||||
# jupytext
|
||||
# pytest
|
||||
torch==2.7.1
|
||||
# via
|
||||
# accelerate
|
||||
# lerobot
|
||||
# peft
|
||||
# robomimic
|
||||
# thop
|
||||
# timm
|
||||
# torchvision
|
||||
torchcodec==0.5
|
||||
# via lerobot
|
||||
torchvision==0.22.1
|
||||
# via lerobot
|
||||
tornado==6.5.1
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
# timm
|
||||
tornado==6.5.2
|
||||
# via meshcat
|
||||
tqdm==4.67.1
|
||||
# via
|
||||
# datasets
|
||||
# dm-control
|
||||
# huggingface-hub
|
||||
# peft
|
||||
# robomimic
|
||||
# transformers
|
||||
traitlets==5.14.3
|
||||
# via
|
||||
# ipython
|
||||
# jupyter-core
|
||||
# matplotlib-inline
|
||||
transformers==4.51.3
|
||||
# via lerobot
|
||||
typing-extensions==4.14.1
|
||||
# nbformat
|
||||
transformers==4.57.1
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
# peft
|
||||
transforms3d==0.4.2
|
||||
# via teleop
|
||||
typing-extensions==4.15.0
|
||||
# via
|
||||
# aiosignal
|
||||
# anyio
|
||||
# etils
|
||||
# exceptiongroup
|
||||
# fastapi
|
||||
# gymnasium
|
||||
# huggingface-hub
|
||||
# ipython
|
||||
# multidict
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# referencing
|
||||
# rerun-sdk
|
||||
# starlette
|
||||
# torch
|
||||
# typing-inspect
|
||||
# typing-inspection
|
||||
# uvicorn
|
||||
# virtualenv
|
||||
# wandb
|
||||
typing-inspect==0.9.0
|
||||
# via draccus
|
||||
typing-inspection==0.4.1
|
||||
typing-inspection==0.4.2
|
||||
# via pydantic
|
||||
tzdata==2025.2
|
||||
# via pandas
|
||||
@@ -604,22 +795,36 @@ urllib3==2.5.0
|
||||
# via
|
||||
# requests
|
||||
# sentry-sdk
|
||||
virtualenv==20.32.0
|
||||
uvicorn[standard]==0.38.0
|
||||
# via teleop
|
||||
uvloop==0.22.1
|
||||
# via uvicorn
|
||||
virtualenv==20.35.3
|
||||
# via pre-commit
|
||||
wandb==0.21.0
|
||||
# via lerobot
|
||||
wcwidth==0.2.13
|
||||
wandb==0.21.4
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
watchfiles==1.1.1
|
||||
# via uvicorn
|
||||
wcwidth==0.2.14
|
||||
# via prompt-toolkit
|
||||
websocket-client==1.9.0
|
||||
# via teleop
|
||||
websockets==15.0.1
|
||||
# via uvicorn
|
||||
werkzeug==3.1.3
|
||||
# via flask
|
||||
wrapt==1.17.2
|
||||
# via tensorboard
|
||||
wrapt==2.0.0
|
||||
# via dm-tree
|
||||
xxhash==3.5.0
|
||||
xxhash==3.6.0
|
||||
# via datasets
|
||||
yarl==1.20.1
|
||||
yarl==1.22.0
|
||||
# via aiohttp
|
||||
zipp==3.23.0
|
||||
# via importlib-metadata
|
||||
# via
|
||||
# etils
|
||||
# importlib-metadata
|
||||
|
||||
# The following packages are considered to be unsafe in a requirements file:
|
||||
# setuptools
|
||||
|
||||
+325
-114
@@ -13,47 +13,62 @@ absl-py==2.3.1
|
||||
# dm-tree
|
||||
# labmaze
|
||||
# mujoco
|
||||
accelerate==1.9.0
|
||||
# via lerobot
|
||||
# tensorboard
|
||||
accelerate==1.11.0
|
||||
# via
|
||||
# lerobot
|
||||
# peft
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.12.15
|
||||
aiohttp==3.13.1
|
||||
# via fsspec
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
antlr4-python3-runtime==4.9.3
|
||||
# via
|
||||
# hydra-core
|
||||
# omegaconf
|
||||
anyio==4.11.0
|
||||
# via
|
||||
# starlette
|
||||
# watchfiles
|
||||
asttokens==3.0.0
|
||||
# via stack-data
|
||||
async-timeout==5.0.1
|
||||
# via aiohttp
|
||||
attrs==25.3.0
|
||||
attrs==25.4.0
|
||||
# via
|
||||
# aiohttp
|
||||
# dm-tree
|
||||
# jsonlines
|
||||
# jsonschema
|
||||
# referencing
|
||||
# rerun-sdk
|
||||
av==15.0.0
|
||||
av==15.1.0
|
||||
# via lerobot
|
||||
blinker==1.9.0
|
||||
# via flask
|
||||
certifi==2025.7.14
|
||||
bddl==1.0.1
|
||||
# via libero
|
||||
certifi==2025.10.5
|
||||
# via
|
||||
# requests
|
||||
# sentry-sdk
|
||||
cffi==1.17.1
|
||||
cffi==2.0.0
|
||||
# via pymunk
|
||||
cfgv==3.4.0
|
||||
# via pre-commit
|
||||
charset-normalizer==3.4.2
|
||||
charset-normalizer==3.4.4
|
||||
# via requests
|
||||
click==8.2.1
|
||||
click==8.3.0
|
||||
# via
|
||||
# flask
|
||||
# uvicorn
|
||||
# wandb
|
||||
cloudpickle==3.1.1
|
||||
# via gymnasium
|
||||
cmake==4.0.3
|
||||
# via
|
||||
# gymnasium
|
||||
# libero
|
||||
cmake==4.1.0
|
||||
# via lerobot
|
||||
cmeel==0.57.3
|
||||
# via
|
||||
@@ -95,27 +110,29 @@ coal-library==3.0.1
|
||||
# via pin
|
||||
contourpy==1.3.2
|
||||
# via matplotlib
|
||||
coverage[toml]==7.10.1
|
||||
coverage[toml]==7.11.0
|
||||
# via pytest-cov
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
datasets==3.6.0
|
||||
datasets==4.1.1
|
||||
# via lerobot
|
||||
debugpy==1.8.15
|
||||
debugpy==1.8.17
|
||||
# via lerobot
|
||||
decorator==5.2.1
|
||||
# via ipython
|
||||
deepdiff==8.5.0
|
||||
decord==0.6.0
|
||||
# via lerobot
|
||||
diffusers==0.34.0
|
||||
deepdiff==8.6.1
|
||||
# via lerobot
|
||||
dill==0.3.8
|
||||
diffusers==0.35.2
|
||||
# via lerobot
|
||||
dill==0.4.0
|
||||
# via
|
||||
# datasets
|
||||
# multiprocess
|
||||
distlib==0.4.0
|
||||
# via virtualenv
|
||||
dm-control==1.0.14
|
||||
dm-control==1.0.34
|
||||
# via gym-aloha
|
||||
dm-env==1.6
|
||||
# via dm-control
|
||||
@@ -123,31 +140,48 @@ dm-tree==0.1.9
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# lerobot
|
||||
docopt==0.6.2
|
||||
# via num2words
|
||||
draccus==0.10.0
|
||||
# via lerobot
|
||||
dynamixel-sdk==3.7.31
|
||||
dynamixel-sdk==3.8.4
|
||||
# via lerobot
|
||||
easydict==1.13
|
||||
# via libero
|
||||
egl-probe @ git+https://github.com/huggingface/egl_probe.git
|
||||
# via
|
||||
# libero
|
||||
# robomimic
|
||||
eigenpy==3.10.3
|
||||
# via coal-library
|
||||
einops==0.8.1
|
||||
# via lerobot
|
||||
# via
|
||||
# flash-attn
|
||||
# lerobot
|
||||
# libero
|
||||
eiquadprog==1.2.9
|
||||
# via placo
|
||||
etils[epath,epy]==1.13.0
|
||||
# via mujoco
|
||||
evdev==1.9.2
|
||||
# via pynput
|
||||
exceptiongroup==1.3.0
|
||||
# via
|
||||
# anyio
|
||||
# ipython
|
||||
# pytest
|
||||
executing==2.2.0
|
||||
executing==2.2.1
|
||||
# via stack-data
|
||||
farama-notifications==0.0.4
|
||||
# via gymnasium
|
||||
fastapi==0.119.1
|
||||
# via teleop
|
||||
fastjsonschema==2.21.2
|
||||
# via nbformat
|
||||
feetech-servo-sdk==1.0.0
|
||||
# via lerobot
|
||||
filelock==3.18.0
|
||||
filelock==3.20.0
|
||||
# via
|
||||
# datasets
|
||||
# diffusers
|
||||
@@ -155,24 +189,27 @@ filelock==3.18.0
|
||||
# torch
|
||||
# transformers
|
||||
# virtualenv
|
||||
flask==3.1.1
|
||||
flash-attn==2.8.3
|
||||
# via lerobot
|
||||
fonttools==4.59.0
|
||||
fonttools==4.60.1
|
||||
# via matplotlib
|
||||
frozenlist==1.7.0
|
||||
frozenlist==1.8.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec[http]==2025.3.0
|
||||
fsspec[http]==2025.9.0
|
||||
# via
|
||||
# datasets
|
||||
# etils
|
||||
# huggingface-hub
|
||||
# torch
|
||||
future==1.0.0
|
||||
# via libero
|
||||
gitdb==4.0.12
|
||||
# via gitpython
|
||||
gitpython==3.1.45
|
||||
# via wandb
|
||||
glfw==2.9.0
|
||||
glfw==2.10.0
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
@@ -180,61 +217,79 @@ grpcio==1.73.1
|
||||
# via
|
||||
# grpcio-tools
|
||||
# lerobot
|
||||
# reachy2-sdk
|
||||
# reachy2-sdk-api
|
||||
# tensorboard
|
||||
grpcio-tools==1.73.1
|
||||
# via
|
||||
# lerobot
|
||||
# reachy2-sdk-api
|
||||
gym-aloha==0.1.3
|
||||
# via lerobot
|
||||
gym-aloha==0.1.1
|
||||
gym-hil==0.1.13
|
||||
# via lerobot
|
||||
gym-hil==0.1.10
|
||||
gym-pusht==0.1.6
|
||||
# via lerobot
|
||||
gym-pusht==0.1.5
|
||||
# via lerobot
|
||||
gym-xarm==0.1.1
|
||||
# via lerobot
|
||||
gymnasium==0.29.1
|
||||
gymnasium==1.2.1
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gym-pusht
|
||||
# gym-xarm
|
||||
# gymnasium-robotics
|
||||
# lerobot
|
||||
# pettingzoo
|
||||
gymnasium-robotics==1.2.4
|
||||
# via gym-xarm
|
||||
# libero
|
||||
# metaworld
|
||||
h11==0.16.0
|
||||
# via uvicorn
|
||||
h5py==3.15.1
|
||||
# via robomimic
|
||||
hebi-py==2.11.0
|
||||
# via lerobot
|
||||
hf-transfer==0.1.9
|
||||
# via huggingface-hub
|
||||
hf-xet==1.1.5
|
||||
hf-xet==1.1.10
|
||||
# via huggingface-hub
|
||||
hidapi==0.14.0.post4
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
huggingface-hub[cli,hf-transfer]==0.34.3
|
||||
httptools==0.7.1
|
||||
# via uvicorn
|
||||
huggingface-hub[cli,hf-transfer]==0.35.3
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# tokenizers
|
||||
# transformers
|
||||
identify==2.6.12
|
||||
hydra-core==1.3.2
|
||||
# via libero
|
||||
identify==2.6.15
|
||||
# via pre-commit
|
||||
idna==3.10
|
||||
idna==3.11
|
||||
# via
|
||||
# anyio
|
||||
# requests
|
||||
# yarl
|
||||
imageio[ffmpeg]==2.37.0
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gymnasium-robotics
|
||||
# lerobot
|
||||
# metaworld
|
||||
# robomimic
|
||||
# scikit-image
|
||||
imageio-ffmpeg==0.6.0
|
||||
# via imageio
|
||||
# via
|
||||
# imageio
|
||||
# robomimic
|
||||
importlib-metadata==8.7.0
|
||||
# via diffusers
|
||||
iniconfig==2.1.0
|
||||
importlib-resources==6.5.2
|
||||
# via etils
|
||||
iniconfig==2.3.0
|
||||
# via pytest
|
||||
inquirerpy==0.3.4
|
||||
# via huggingface-hub
|
||||
@@ -242,50 +297,71 @@ ipython==8.37.0
|
||||
# via meshcat
|
||||
ischedule==1.2.7
|
||||
# via placo
|
||||
itsdangerous==2.2.0
|
||||
# via flask
|
||||
jedi==0.19.2
|
||||
# via ipython
|
||||
jinja2==3.1.6
|
||||
# via
|
||||
# flask
|
||||
# gymnasium-robotics
|
||||
# torch
|
||||
# via torch
|
||||
jsonlines==4.0.0
|
||||
# via lerobot
|
||||
kiwisolver==1.4.8
|
||||
jsonschema==4.25.1
|
||||
# via nbformat
|
||||
jsonschema-specifications==2025.9.1
|
||||
# via jsonschema
|
||||
jupyter-core==5.9.1
|
||||
# via nbformat
|
||||
jupytext==1.18.1
|
||||
# via bddl
|
||||
kiwisolver==1.4.9
|
||||
# via matplotlib
|
||||
labmaze==1.0.6
|
||||
# via dm-control
|
||||
lazy-loader==0.4
|
||||
# via scikit-image
|
||||
lxml==6.0.0
|
||||
libero @ git+https://github.com/huggingface/lerobot-libero.git@main
|
||||
# via lerobot
|
||||
llvmlite==0.45.1
|
||||
# via numba
|
||||
lxml==6.0.2
|
||||
# via dm-control
|
||||
markupsafe==3.0.2
|
||||
markdown==3.9
|
||||
# via tensorboard
|
||||
markdown-it-py==4.0.0
|
||||
# via
|
||||
# jupytext
|
||||
# mdit-py-plugins
|
||||
markupsafe==3.0.3
|
||||
# via
|
||||
# flask
|
||||
# jinja2
|
||||
# werkzeug
|
||||
matplotlib==3.10.5
|
||||
# via lerobot
|
||||
matplotlib-inline==0.1.7
|
||||
matplotlib==3.10.7
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
matplotlib-inline==0.2.1
|
||||
# via ipython
|
||||
mdit-py-plugins==0.5.0
|
||||
# via jupytext
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mergedeep==1.3.4
|
||||
# via draccus
|
||||
meshcat==0.3.2
|
||||
# via placo
|
||||
metaworld==3.0.0
|
||||
# via lerobot
|
||||
mock-serial==0.0.1
|
||||
# via lerobot
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
mujoco==2.3.7
|
||||
mujoco==3.3.7
|
||||
# via
|
||||
# dm-control
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gym-xarm
|
||||
# gymnasium-robotics
|
||||
multidict==6.6.3
|
||||
# libero
|
||||
# metaworld
|
||||
# robosuite
|
||||
multidict==6.7.0
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
@@ -293,42 +369,63 @@ multiprocess==0.70.16
|
||||
# via datasets
|
||||
mypy-extensions==1.1.0
|
||||
# via typing-inspect
|
||||
nbformat==5.10.4
|
||||
# via jupytext
|
||||
networkx==3.4.2
|
||||
# via
|
||||
# bddl
|
||||
# scikit-image
|
||||
# torch
|
||||
ninja==1.13.0
|
||||
# via lerobot
|
||||
nodeenv==1.9.1
|
||||
# via pre-commit
|
||||
num2words==0.5.14
|
||||
# via lerobot
|
||||
numba==0.62.1
|
||||
# via robosuite
|
||||
numpy==2.2.6
|
||||
# via
|
||||
# accelerate
|
||||
# bddl
|
||||
# cmeel-boost
|
||||
# contourpy
|
||||
# datasets
|
||||
# decord
|
||||
# diffusers
|
||||
# dm-control
|
||||
# dm-env
|
||||
# dm-tree
|
||||
# gymnasium
|
||||
# gymnasium-robotics
|
||||
# h5py
|
||||
# hebi-py
|
||||
# imageio
|
||||
# labmaze
|
||||
# libero
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# metaworld
|
||||
# mujoco
|
||||
# numba
|
||||
# opencv-python
|
||||
# opencv-python-headless
|
||||
# pandas
|
||||
# pettingzoo
|
||||
# peft
|
||||
# pyquaternion
|
||||
# reachy2-sdk
|
||||
# rerun-sdk
|
||||
# robomimic
|
||||
# robosuite
|
||||
# scikit-image
|
||||
# scipy
|
||||
# shapely
|
||||
# teleop
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# tifffile
|
||||
# torchvision
|
||||
# transformers
|
||||
# transforms3d
|
||||
nvidia-cublas-cu12==12.6.4.1
|
||||
# via
|
||||
# nvidia-cudnn-cu12
|
||||
@@ -366,8 +463,14 @@ nvidia-nvjitlink-cu12==12.6.85
|
||||
# torch
|
||||
nvidia-nvtx-cu12==12.6.77
|
||||
# via torch
|
||||
omegaconf==2.3.0
|
||||
# via hydra-core
|
||||
opencv-python==4.12.0.88
|
||||
# via gym-pusht
|
||||
# via
|
||||
# gym-pusht
|
||||
# libero
|
||||
# reachy2-sdk
|
||||
# robosuite
|
||||
opencv-python-headless==4.12.0.88
|
||||
# via lerobot
|
||||
orderly-set==5.5.0
|
||||
@@ -377,53 +480,63 @@ packaging==25.0
|
||||
# accelerate
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
# hydra-core
|
||||
# jupytext
|
||||
# lazy-loader
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# peft
|
||||
# pytest
|
||||
# reachy2-sdk
|
||||
# scikit-image
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# transformers
|
||||
# wandb
|
||||
pandas==2.3.1
|
||||
pandas==2.3.3
|
||||
# via
|
||||
# datasets
|
||||
# lerobot
|
||||
parso==0.8.4
|
||||
parso==0.8.5
|
||||
# via jedi
|
||||
pettingzoo==1.24.3
|
||||
# via gymnasium-robotics
|
||||
peft==0.17.1
|
||||
# via lerobot
|
||||
pexpect==4.9.0
|
||||
# via ipython
|
||||
pfzy==0.3.4
|
||||
# via inquirerpy
|
||||
pillow==11.3.0
|
||||
pillow==12.0.0
|
||||
# via
|
||||
# diffusers
|
||||
# imageio
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# rerun-sdk
|
||||
# robosuite
|
||||
# scikit-image
|
||||
# tensorboard
|
||||
# torchvision
|
||||
pin==3.4.0
|
||||
# via placo
|
||||
placo==0.9.14
|
||||
# via lerobot
|
||||
platformdirs==4.3.8
|
||||
platformdirs==4.5.0
|
||||
# via
|
||||
# jupyter-core
|
||||
# virtualenv
|
||||
# wandb
|
||||
pluggy==1.6.0
|
||||
# via
|
||||
# pytest
|
||||
# pytest-cov
|
||||
pre-commit==4.2.0
|
||||
pre-commit==4.3.0
|
||||
# via lerobot
|
||||
prompt-toolkit==3.0.51
|
||||
prompt-toolkit==3.0.52
|
||||
# via
|
||||
# inquirerpy
|
||||
# ipython
|
||||
propcache==0.3.2
|
||||
propcache==0.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
@@ -432,11 +545,17 @@ protobuf==6.31.0
|
||||
# dm-control
|
||||
# grpcio-tools
|
||||
# lerobot
|
||||
# reachy2-sdk
|
||||
# reachy2-sdk-api
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# wandb
|
||||
psutil==7.0.0
|
||||
psutil==7.1.1
|
||||
# via
|
||||
# accelerate
|
||||
# imageio
|
||||
# peft
|
||||
# robomimic
|
||||
ptyprocess==0.7.0
|
||||
# via pexpect
|
||||
pure-eval==0.2.3
|
||||
@@ -445,11 +564,13 @@ pyarrow==21.0.0
|
||||
# via
|
||||
# datasets
|
||||
# rerun-sdk
|
||||
pycparser==2.22
|
||||
pycparser==2.23
|
||||
# via cffi
|
||||
pydantic==2.11.7
|
||||
# via wandb
|
||||
pydantic-core==2.33.2
|
||||
pydantic==2.12.3
|
||||
# via
|
||||
# fastapi
|
||||
# wandb
|
||||
pydantic-core==2.41.4
|
||||
# via pydantic
|
||||
pygame==2.6.1
|
||||
# via
|
||||
@@ -464,20 +585,22 @@ pymunk==6.11.1
|
||||
# via
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
pyngrok==7.2.12
|
||||
pyngrok==7.4.1
|
||||
# via meshcat
|
||||
pynput==1.8.1
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
pyopengl==3.1.9
|
||||
pyopengl==3.1.10
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
pyparsing==3.2.3
|
||||
pyparsing==3.2.5
|
||||
# via
|
||||
# dm-control
|
||||
# matplotlib
|
||||
pyquaternion==0.9.9
|
||||
# via reachy2-sdk
|
||||
pyrealsense2==2.56.5.9235
|
||||
# via lerobot
|
||||
pyserial==3.5
|
||||
@@ -485,12 +608,14 @@ pyserial==3.5
|
||||
# dynamixel-sdk
|
||||
# feetech-servo-sdk
|
||||
# lerobot
|
||||
pytest==8.4.1
|
||||
pytest==8.4.2
|
||||
# via
|
||||
# bddl
|
||||
# lerobot
|
||||
# pytest-cov
|
||||
# pytest-timeout
|
||||
pytest-cov==6.2.1
|
||||
# teleop
|
||||
pytest-cov==7.0.0
|
||||
# via lerobot
|
||||
pytest-timeout==2.4.0
|
||||
# via lerobot
|
||||
@@ -498,48 +623,75 @@ python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# matplotlib
|
||||
# pandas
|
||||
python-dotenv==1.1.1
|
||||
# via uvicorn
|
||||
python-xlib==0.33
|
||||
# via pynput
|
||||
pytz==2025.2
|
||||
# via pandas
|
||||
pyyaml==6.0.2
|
||||
pyyaml==6.0.3
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# draccus
|
||||
# hebi-py
|
||||
# huggingface-hub
|
||||
# jupytext
|
||||
# omegaconf
|
||||
# peft
|
||||
# pre-commit
|
||||
# pyngrok
|
||||
# pyyaml-include
|
||||
# timm
|
||||
# transformers
|
||||
# uvicorn
|
||||
# wandb
|
||||
pyyaml-include==1.4.1
|
||||
# via draccus
|
||||
pyzmq==27.0.0
|
||||
pyzmq==27.1.0
|
||||
# via
|
||||
# lerobot
|
||||
# meshcat
|
||||
regex==2025.7.34
|
||||
reachy2-sdk==1.0.14
|
||||
# via lerobot
|
||||
reachy2-sdk-api==1.0.21
|
||||
# via reachy2-sdk
|
||||
referencing==0.37.0
|
||||
# via
|
||||
# jsonschema
|
||||
# jsonschema-specifications
|
||||
regex==2025.10.23
|
||||
# via
|
||||
# diffusers
|
||||
# transformers
|
||||
requests==2.32.4
|
||||
requests==2.32.5
|
||||
# via
|
||||
# datasets
|
||||
# diffusers
|
||||
# dm-control
|
||||
# huggingface-hub
|
||||
# teleop
|
||||
# transformers
|
||||
# wandb
|
||||
rerun-sdk==0.22.1
|
||||
rerun-sdk==0.26.1
|
||||
# via lerobot
|
||||
rhoban-cmeel-jsoncpp==1.9.4.9
|
||||
# via placo
|
||||
safetensors==0.5.3
|
||||
robomimic==0.2.0
|
||||
# via libero
|
||||
robosuite==1.4.0
|
||||
# via libero
|
||||
rpds-py==0.28.0
|
||||
# via
|
||||
# jsonschema
|
||||
# referencing
|
||||
safetensors==0.6.2
|
||||
# via
|
||||
# accelerate
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# transformers
|
||||
scikit-image==0.25.2
|
||||
# via
|
||||
@@ -548,10 +700,12 @@ scikit-image==0.25.2
|
||||
scipy==1.15.3
|
||||
# via
|
||||
# dm-control
|
||||
# metaworld
|
||||
# robosuite
|
||||
# scikit-image
|
||||
sentry-sdk==2.34.1
|
||||
sentry-sdk==2.42.1
|
||||
# via wandb
|
||||
shapely==2.1.1
|
||||
shapely==2.1.2
|
||||
# via gym-pusht
|
||||
six==1.17.0
|
||||
# via
|
||||
@@ -560,66 +714,109 @@ six==1.17.0
|
||||
# python-xlib
|
||||
smmap==5.0.2
|
||||
# via gitdb
|
||||
sniffio==1.3.1
|
||||
# via anyio
|
||||
stack-data==0.6.3
|
||||
# via ipython
|
||||
starlette==0.48.0
|
||||
# via fastapi
|
||||
sympy==1.14.0
|
||||
# via torch
|
||||
termcolor==3.1.0
|
||||
teleop==0.1.2
|
||||
# via lerobot
|
||||
tensorboard==2.20.0
|
||||
# via robomimic
|
||||
tensorboard-data-server==0.7.2
|
||||
# via tensorboard
|
||||
tensorboardx==2.6.4
|
||||
# via robomimic
|
||||
termcolor==3.1.0
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
thop==0.1.1.post2209072238
|
||||
# via libero
|
||||
tifffile==2025.5.10
|
||||
# via scikit-image
|
||||
tokenizers==0.21.4
|
||||
timm==1.0.20
|
||||
# via lerobot
|
||||
tokenizers==0.22.1
|
||||
# via transformers
|
||||
toml==0.10.2
|
||||
# via draccus
|
||||
tomli==2.2.1
|
||||
tomli==2.3.0
|
||||
# via
|
||||
# cmeel
|
||||
# coverage
|
||||
# jupytext
|
||||
# pytest
|
||||
torch==2.7.1
|
||||
# via
|
||||
# accelerate
|
||||
# flash-attn
|
||||
# lerobot
|
||||
# peft
|
||||
# robomimic
|
||||
# thop
|
||||
# timm
|
||||
# torchvision
|
||||
torchcodec==0.5
|
||||
# via lerobot
|
||||
torchvision==0.22.1
|
||||
# via lerobot
|
||||
tornado==6.5.1
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
# timm
|
||||
tornado==6.5.2
|
||||
# via meshcat
|
||||
tqdm==4.67.1
|
||||
# via
|
||||
# datasets
|
||||
# dm-control
|
||||
# huggingface-hub
|
||||
# peft
|
||||
# robomimic
|
||||
# transformers
|
||||
traitlets==5.14.3
|
||||
# via
|
||||
# ipython
|
||||
# jupyter-core
|
||||
# matplotlib-inline
|
||||
transformers==4.51.3
|
||||
# via lerobot
|
||||
# nbformat
|
||||
transformers==4.57.1
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
# peft
|
||||
transforms3d==0.4.2
|
||||
# via teleop
|
||||
triton==3.3.1
|
||||
# via torch
|
||||
typing-extensions==4.14.1
|
||||
typing-extensions==4.15.0
|
||||
# via
|
||||
# aiosignal
|
||||
# anyio
|
||||
# etils
|
||||
# exceptiongroup
|
||||
# fastapi
|
||||
# gymnasium
|
||||
# huggingface-hub
|
||||
# ipython
|
||||
# multidict
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# referencing
|
||||
# rerun-sdk
|
||||
# starlette
|
||||
# torch
|
||||
# typing-inspect
|
||||
# typing-inspection
|
||||
# uvicorn
|
||||
# virtualenv
|
||||
# wandb
|
||||
typing-inspect==0.9.0
|
||||
# via draccus
|
||||
typing-inspection==0.4.1
|
||||
typing-inspection==0.4.2
|
||||
# via pydantic
|
||||
tzdata==2025.2
|
||||
# via pandas
|
||||
@@ -629,22 +826,36 @@ urllib3==2.5.0
|
||||
# via
|
||||
# requests
|
||||
# sentry-sdk
|
||||
virtualenv==20.32.0
|
||||
uvicorn[standard]==0.38.0
|
||||
# via teleop
|
||||
uvloop==0.22.1
|
||||
# via uvicorn
|
||||
virtualenv==20.35.3
|
||||
# via pre-commit
|
||||
wandb==0.21.0
|
||||
# via lerobot
|
||||
wcwidth==0.2.13
|
||||
wandb==0.21.4
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
watchfiles==1.1.1
|
||||
# via uvicorn
|
||||
wcwidth==0.2.14
|
||||
# via prompt-toolkit
|
||||
websocket-client==1.9.0
|
||||
# via teleop
|
||||
websockets==15.0.1
|
||||
# via uvicorn
|
||||
werkzeug==3.1.3
|
||||
# via flask
|
||||
wrapt==1.17.2
|
||||
# via tensorboard
|
||||
wrapt==2.0.0
|
||||
# via dm-tree
|
||||
xxhash==3.5.0
|
||||
xxhash==3.6.0
|
||||
# via datasets
|
||||
yarl==1.20.1
|
||||
yarl==1.22.0
|
||||
# via aiohttp
|
||||
zipp==3.23.0
|
||||
# via importlib-metadata
|
||||
# via
|
||||
# etils
|
||||
# importlib-metadata
|
||||
|
||||
# The following packages are considered to be unsafe in a requirements file:
|
||||
# setuptools
|
||||
|
||||
+4
-4
@@ -1,9 +1,9 @@
|
||||
# requirements.in
|
||||
|
||||
# requirements-macos.txt was generated on macOS and is platform-specific (macOS 15.5 24F74 arm64).
|
||||
# Darwin MacBook-Pro.local 24.5.0 Darwin Kernel Version 24.5.0: Tue Apr 22 19:54:43 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T8132 arm64
|
||||
# requirements-macos.txt was generated on macOS and is platform-specific (macOS 26.0.1 25A362 arm64).
|
||||
# Darwin MacBook-Pro.local 25.0.0 Darwin Kernel Version 25.0.0: Wed Sep 17 21:42:08 PDT 2025; root:xnu-12377.1.9~141/RELEASE_ARM64_T8132 arm64
|
||||
|
||||
# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.2 LTS x86_64).
|
||||
# Linux mlerobot-linux 6.14.0-27-generic #27~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue Jul 22 17:38:49 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
|
||||
# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.3 LTS x86_64).
|
||||
# Linux mlerobot-linux 6.14.0-33-generic #33~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Sep 19 17:02:30 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
|
||||
|
||||
-e .[all]
|
||||
|
||||
@@ -57,7 +57,6 @@ available_tasks_per_env = {
|
||||
"AlohaTransferCube-v0",
|
||||
],
|
||||
"pusht": ["PushT-v0"],
|
||||
"xarm": ["XarmLift-v0"],
|
||||
}
|
||||
available_envs = list(available_tasks_per_env.keys())
|
||||
|
||||
@@ -75,16 +74,6 @@ available_datasets_per_env = {
|
||||
# TODO(alexander-soare): Add "lerobot/pusht_keypoints". Right now we can't because this is too tightly
|
||||
# coupled with tests.
|
||||
"pusht": ["lerobot/pusht", "lerobot/pusht_image"],
|
||||
"xarm": [
|
||||
"lerobot/xarm_lift_medium",
|
||||
"lerobot/xarm_lift_medium_replay",
|
||||
"lerobot/xarm_push_medium",
|
||||
"lerobot/xarm_push_medium_replay",
|
||||
"lerobot/xarm_lift_medium_image",
|
||||
"lerobot/xarm_lift_medium_replay_image",
|
||||
"lerobot/xarm_push_medium_image",
|
||||
"lerobot/xarm_push_medium_replay_image",
|
||||
],
|
||||
}
|
||||
|
||||
available_real_world_datasets = [
|
||||
@@ -195,7 +184,6 @@ available_motors = [
|
||||
available_policies_per_env = {
|
||||
"aloha": ["act"],
|
||||
"pusht": ["diffusion", "vqbet"],
|
||||
"xarm": ["tdmpc"],
|
||||
"koch_real": ["act_koch_real"],
|
||||
"aloha_real": ["act_aloha_real"],
|
||||
}
|
||||
|
||||
@@ -142,11 +142,6 @@ class RobotClientConfig:
|
||||
default=False, metadata={"help": "Visualize the action queue size"}
|
||||
)
|
||||
|
||||
# Verification configuration
|
||||
verify_robot_cameras: bool = field(
|
||||
default=True, metadata={"help": "Verify that the robot cameras match the policy cameras"}
|
||||
)
|
||||
|
||||
@property
|
||||
def environment_dt(self) -> float:
|
||||
"""Environment time step, in seconds"""
|
||||
|
||||
@@ -23,7 +23,7 @@ DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS
|
||||
DEFAULT_OBS_QUEUE_TIMEOUT = 2
|
||||
|
||||
# All action chunking policies
|
||||
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "pi0", "tdmpc", "vqbet"]
|
||||
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"]
|
||||
|
||||
# TODO: Add all other robots
|
||||
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower"]
|
||||
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so100_follower"]
|
||||
|
||||
@@ -16,7 +16,7 @@ import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
@@ -25,7 +25,14 @@ from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
|
||||
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
|
||||
from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401
|
||||
from lerobot.policies import ( # noqa: F401
|
||||
ACTConfig,
|
||||
DiffusionConfig,
|
||||
PI0Config,
|
||||
PI05Config,
|
||||
SmolVLAConfig,
|
||||
VQBeTConfig,
|
||||
)
|
||||
from lerobot.robots.robot import Robot
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.utils import init_logging
|
||||
@@ -55,15 +62,6 @@ def visualize_action_queue_size(action_queue_size: list[int]) -> None:
|
||||
plt.show()
|
||||
|
||||
|
||||
def validate_robot_cameras_for_policy(
|
||||
lerobot_observation_features: dict[str, dict], policy_image_features: dict[str, PolicyFeature]
|
||||
) -> None:
|
||||
image_keys = list(filter(is_image_key, lerobot_observation_features))
|
||||
assert set(image_keys) == set(policy_image_features.keys()), (
|
||||
f"Policy image features must match robot cameras! Received {list(policy_image_features.keys())} != {image_keys}"
|
||||
)
|
||||
|
||||
|
||||
def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]:
|
||||
return hw_to_dataset_features(robot.observation_features, OBS_STR, use_video=False)
|
||||
|
||||
@@ -85,11 +83,11 @@ def resize_robot_observation_image(image: torch.tensor, resize_dims: tuple[int,
|
||||
return resized.squeeze(0)
|
||||
|
||||
|
||||
# TODO(Steven): Consider implementing a pipeline step for this
|
||||
def raw_observation_to_observation(
|
||||
raw_observation: RawObservation,
|
||||
lerobot_features: dict[str, dict],
|
||||
policy_image_features: dict[str, PolicyFeature],
|
||||
device: str,
|
||||
) -> Observation:
|
||||
observation = {}
|
||||
|
||||
@@ -98,9 +96,7 @@ def raw_observation_to_observation(
|
||||
if isinstance(v, torch.Tensor): # VLAs present natural-language instructions in observations
|
||||
if "image" in k:
|
||||
# Policy expects images in shape (B, C, H, W)
|
||||
observation[k] = prepare_image(v).unsqueeze(0).to(device)
|
||||
else:
|
||||
observation[k] = v.to(device)
|
||||
observation[k] = prepare_image(v).unsqueeze(0)
|
||||
else:
|
||||
observation[k] = v
|
||||
|
||||
@@ -272,6 +268,7 @@ class RemotePolicyConfig:
|
||||
lerobot_features: dict[str, PolicyFeature]
|
||||
actions_per_chunk: int
|
||||
device: str = "cpu"
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
"""
|
||||
Example:
|
||||
```shell
|
||||
python src/lerobot/async_inference/policy_server.py \
|
||||
python -m lerobot.async_inference.policy_server \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080 \
|
||||
--fps=30 \
|
||||
@@ -32,12 +32,17 @@ from concurrent import futures
|
||||
from dataclasses import asdict
|
||||
from pprint import pformat
|
||||
from queue import Empty, Queue
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
import grpc
|
||||
import torch
|
||||
|
||||
from lerobot.policies.factory import get_policy_class
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.processor import (
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
)
|
||||
from lerobot.transport import (
|
||||
services_pb2, # type: ignore
|
||||
services_pb2_grpc, # type: ignore
|
||||
@@ -82,6 +87,8 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
self.lerobot_features = None
|
||||
self.actions_per_chunk = None
|
||||
self.policy = None
|
||||
self.preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]] | None = None
|
||||
self.postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction] | None = None
|
||||
|
||||
@property
|
||||
def running(self):
|
||||
@@ -146,6 +153,19 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
start = time.perf_counter()
|
||||
self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path)
|
||||
self.policy.to(self.device)
|
||||
|
||||
# Load preprocessor and postprocessor, overriding device to match requested device
|
||||
device_override = {"device": self.device}
|
||||
self.preprocessor, self.postprocessor = make_pre_post_processors(
|
||||
self.policy.config,
|
||||
pretrained_path=policy_specs.pretrained_name_or_path,
|
||||
preprocessor_overrides={
|
||||
"device_processor": device_override,
|
||||
"rename_observations_processor": {"rename_map": policy_specs.rename_map},
|
||||
},
|
||||
postprocessor_overrides={"device_processor": device_override},
|
||||
)
|
||||
|
||||
end = time.perf_counter()
|
||||
|
||||
self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")
|
||||
@@ -173,7 +193,7 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
# Calculate FPS metrics
|
||||
fps_metrics = self.fps_tracker.calculate_fps_metrics(obs_timestamp)
|
||||
|
||||
self.logger.info(
|
||||
self.logger.debug(
|
||||
f"Received observation #{obs_timestep} | "
|
||||
f"Avg FPS: {fps_metrics['avg_fps']:.2f} | " # fps at which observations are received from client
|
||||
f"Target: {fps_metrics['target_fps']:.2f} | "
|
||||
@@ -189,7 +209,7 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
if not self._enqueue_observation(
|
||||
timed_observation # wrapping a RawObservation
|
||||
):
|
||||
self.logger.info(f"Observation #{obs_timestep} has been filtered out")
|
||||
self.logger.debug(f"Observation #{obs_timestep} has been filtered out")
|
||||
|
||||
return services_pb2.Empty()
|
||||
|
||||
@@ -301,23 +321,6 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
for i, action in enumerate(action_chunk)
|
||||
]
|
||||
|
||||
def _prepare_observation(self, observation_t: TimedObservation) -> Observation:
|
||||
"""
|
||||
Prepare observation, ready for policy inference.
|
||||
E.g.: To keep observation sampling rate high (and network packet tiny) we send int8 [0,255] images from the
|
||||
client and then convert them to float32 [0,1] images here, before running inference.
|
||||
"""
|
||||
# RawObservation from robot.get_observation() - wrong keys, wrong dtype, wrong image shape
|
||||
observation: Observation = raw_observation_to_observation(
|
||||
observation_t.get_observation(),
|
||||
self.lerobot_features,
|
||||
self.policy_image_features,
|
||||
self.device,
|
||||
)
|
||||
# processed Observation - right keys, right dtype, right image shape
|
||||
|
||||
return observation
|
||||
|
||||
def _get_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Get an action chunk from the policy. The chunk contains only"""
|
||||
chunk = self.policy.predict_action_chunk(observation)
|
||||
@@ -327,44 +330,76 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
return chunk[:, : self.actions_per_chunk, :]
|
||||
|
||||
def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]:
|
||||
"""Predict an action chunk based on an observation"""
|
||||
inference_starts = time.perf_counter()
|
||||
"""Predict an action chunk based on an observation.
|
||||
|
||||
Pipeline:
|
||||
1. Convert raw observation to LeRobot format
|
||||
2. Apply preprocessor (tokenization, normalization, batching, device placement)
|
||||
3. Run policy inference to get action chunk
|
||||
4. Apply postprocessor (unnormalization, device movement)
|
||||
5. Convert to TimedAction list
|
||||
"""
|
||||
"""1. Prepare observation"""
|
||||
start_time = time.perf_counter()
|
||||
observation = self._prepare_observation(observation_t)
|
||||
preprocessing_time = time.perf_counter() - start_time
|
||||
start_prepare = time.perf_counter()
|
||||
observation: Observation = raw_observation_to_observation(
|
||||
observation_t.get_observation(),
|
||||
self.lerobot_features,
|
||||
self.policy_image_features,
|
||||
)
|
||||
prepare_time = time.perf_counter() - start_prepare
|
||||
|
||||
"""2. Apply preprocessor"""
|
||||
start_preprocess = time.perf_counter()
|
||||
observation = self.preprocessor(observation)
|
||||
self.last_processed_obs: TimedObservation = observation_t
|
||||
preprocessing_time = time.perf_counter() - start_preprocess
|
||||
|
||||
"""2. Get action chunk"""
|
||||
start_time = time.perf_counter()
|
||||
"""3. Get action chunk"""
|
||||
start_inference = time.perf_counter()
|
||||
action_tensor = self._get_action_chunk(observation)
|
||||
inference_time = time.perf_counter() - start_time
|
||||
inference_time = time.perf_counter() - start_inference
|
||||
self.logger.info(
|
||||
f"Preprocessing and inference took {inference_time:.4f}s, action shape: {action_tensor.shape}"
|
||||
)
|
||||
|
||||
"""3. Post-inference processing"""
|
||||
start_time = time.perf_counter()
|
||||
# Move to CPU before serializing
|
||||
action_tensor = action_tensor.cpu().squeeze(0)
|
||||
"""4. Apply postprocessor"""
|
||||
# Apply postprocessor (handles unnormalization and device movement)
|
||||
# Postprocessor expects (B, action_dim) per action, but we have (B, chunk_size, action_dim)
|
||||
# So we process each action in the chunk individually
|
||||
start_postprocess = time.perf_counter()
|
||||
_, chunk_size, _ = action_tensor.shape
|
||||
|
||||
# Process each action in the chunk
|
||||
processed_actions = []
|
||||
for i in range(chunk_size):
|
||||
# Extract action at timestep i: (B, action_dim)
|
||||
single_action = action_tensor[:, i, :]
|
||||
processed_action = self.postprocessor(single_action)
|
||||
processed_actions.append(processed_action)
|
||||
|
||||
# Stack back to (B, chunk_size, action_dim), then remove batch dim
|
||||
action_tensor = torch.stack(processed_actions, dim=1).squeeze(0)
|
||||
self.logger.debug(f"Postprocessed action shape: {action_tensor.shape}")
|
||||
|
||||
"""5. Convert to TimedAction list"""
|
||||
action_chunk = self._time_action_chunk(
|
||||
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
|
||||
)
|
||||
postprocessing_time = time.perf_counter() - start_time
|
||||
inference_stops = time.perf_counter()
|
||||
postprocess_stops = time.perf_counter()
|
||||
postprocessing_time = postprocess_stops - start_postprocess
|
||||
|
||||
self.logger.info(
|
||||
f"Observation {observation_t.get_timestep()} |"
|
||||
f"Inference time: {1000 * (inference_stops - inference_starts):.2f}ms"
|
||||
f"Observation {observation_t.get_timestep()} | "
|
||||
f"Total time: {1000 * (postprocess_stops - start_prepare):.2f}ms"
|
||||
)
|
||||
|
||||
# full-process latency breakdown for debugging purposes
|
||||
self.logger.debug(
|
||||
f"Observation {observation_t.get_timestep()} | "
|
||||
f"Preprocessing time: {1000 * (preprocessing_time - inference_starts):.2f}ms | "
|
||||
f"Inference time: {1000 * (inference_time - preprocessing_time):.2f}ms | "
|
||||
f"Postprocessing time: {1000 * (postprocessing_time - inference_time):.2f}ms | "
|
||||
f"Total time: {1000 * (postprocessing_time - inference_starts):.2f}ms"
|
||||
f"Prepare time: {1000 * prepare_time:.2f}ms | "
|
||||
f"Preprocessing time: {1000 * preprocessing_time:.2f}ms | "
|
||||
f"Inference time: {1000 * inference_time:.2f}ms | "
|
||||
f"Postprocessing time: {1000 * postprocessing_time:.2f}ms | "
|
||||
f"Total time: {1000 * (postprocess_stops - start_prepare):.2f}ms"
|
||||
)
|
||||
|
||||
return action_chunk
|
||||
|
||||
@@ -48,10 +48,10 @@ import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_so100_follower,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
so100_follower,
|
||||
@@ -75,7 +75,6 @@ from .helpers import (
|
||||
TimedObservation,
|
||||
get_logger,
|
||||
map_robot_keys_to_lerobot_features,
|
||||
validate_robot_cameras_for_policy,
|
||||
visualize_action_queue_size,
|
||||
)
|
||||
|
||||
@@ -97,14 +96,6 @@ class RobotClient:
|
||||
|
||||
lerobot_features = map_robot_keys_to_lerobot_features(self.robot)
|
||||
|
||||
if config.verify_robot_cameras:
|
||||
# Load policy config for validation
|
||||
policy_config = PreTrainedConfig.from_pretrained(config.pretrained_name_or_path)
|
||||
policy_image_features = policy_config.image_features
|
||||
|
||||
# The cameras specified for inference must match the one supported by the policy chosen
|
||||
validate_robot_cameras_for_policy(lerobot_features, policy_image_features)
|
||||
|
||||
# Use environment variable if server_address is not provided in config
|
||||
self.server_address = config.server_address
|
||||
|
||||
@@ -214,7 +205,7 @@ class RobotClient:
|
||||
)
|
||||
_ = self.stub.SendObservations(observation_iterator)
|
||||
obs_timestep = obs.get_timestep()
|
||||
self.logger.info(f"Sent observation #{obs_timestep} | ")
|
||||
self.logger.debug(f"Sent observation #{obs_timestep} | ")
|
||||
|
||||
return True
|
||||
|
||||
@@ -467,7 +458,7 @@ class RobotClient:
|
||||
if self._ready_to_send_observation():
|
||||
_captured_observation = self.control_loop_observation(task, verbose)
|
||||
|
||||
self.logger.info(f"Control loop (ms): {(time.perf_counter() - control_loop_start) * 1000:.2f}")
|
||||
self.logger.debug(f"Control loop (ms): {(time.perf_counter() - control_loop_start) * 1000:.2f}")
|
||||
# Dynamically adjust sleep time to maintain the desired control frequency
|
||||
time.sleep(max(0, self.config.environment_dt - (time.perf_counter() - control_loop_start)))
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
import abc
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
|
||||
|
||||
from .configs import CameraConfig, ColorMode
|
||||
|
||||
@@ -89,7 +89,7 @@ class Camera(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""Capture and return a single frame from the camera.
|
||||
|
||||
Args:
|
||||
@@ -102,7 +102,7 @@ class Camera(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def async_read(self, timeout_ms: float = ...) -> np.ndarray:
|
||||
def async_read(self, timeout_ms: float = ...) -> NDArray[Any]:
|
||||
"""Asynchronously capture and return a single frame from the camera.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -18,7 +18,7 @@ import abc
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
import draccus
|
||||
import draccus # type: ignore # TODO: add type stubs for draccus
|
||||
|
||||
|
||||
class ColorMode(str, Enum):
|
||||
@@ -34,11 +34,11 @@ class Cv2Rotation(int, Enum):
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class CameraConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
class CameraConfig(draccus.ChoiceRegistry, abc.ABC): # type: ignore # TODO: add type stubs for draccus
|
||||
fps: int | None = None
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
return str(self.get_choice_name(self.__class__))
|
||||
|
||||
@@ -14,3 +14,5 @@
|
||||
|
||||
from .camera_opencv import OpenCVCamera
|
||||
from .configuration_opencv import OpenCVCameraConfig
|
||||
|
||||
__all__ = ["OpenCVCamera", "OpenCVCameraConfig"]
|
||||
|
||||
@@ -25,11 +25,12 @@ from pathlib import Path
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
|
||||
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
|
||||
|
||||
# Fix MSMF hardware transform compatibility for Windows before importing cv2
|
||||
if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
|
||||
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
|
||||
import cv2
|
||||
import numpy as np
|
||||
import cv2 # type: ignore # TODO: add type stubs for OpenCV
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
@@ -121,7 +122,7 @@ class OpenCVCamera(Camera):
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: np.ndarray | None = None
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
self.rotation: int | None = get_cv2_rotation(config.rotation)
|
||||
@@ -140,7 +141,7 @@ class OpenCVCamera(Camera):
|
||||
"""Checks if the camera is currently connected and opened."""
|
||||
return isinstance(self.videocapture, cv2.VideoCapture) and self.videocapture.isOpened()
|
||||
|
||||
def connect(self, warmup: bool = True):
|
||||
def connect(self, warmup: bool = True) -> None:
|
||||
"""
|
||||
Connects to the OpenCV camera specified in the configuration.
|
||||
|
||||
@@ -180,12 +181,14 @@ class OpenCVCamera(Camera):
|
||||
|
||||
def _configure_capture_settings(self) -> None:
|
||||
"""
|
||||
Applies the specified FPS, width, and height settings to the connected camera.
|
||||
Applies the specified FOURCC, FPS, width, and height settings to the connected camera.
|
||||
|
||||
This method attempts to set the camera properties via OpenCV. It checks if
|
||||
the camera successfully applied the settings and raises an error if not.
|
||||
FOURCC is set first (if specified) as it can affect the available FPS and resolution options.
|
||||
|
||||
Args:
|
||||
fourcc: The desired FOURCC code (e.g., "MJPG", "YUYV"). If None, auto-detect.
|
||||
fps: The desired frames per second. If None, the setting is skipped.
|
||||
width: The desired capture width. If None, the setting is skipped.
|
||||
height: The desired capture height. If None, the setting is skipped.
|
||||
@@ -199,10 +202,11 @@ class OpenCVCamera(Camera):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.")
|
||||
|
||||
if self.fps is None:
|
||||
self.fps = self.videocapture.get(cv2.CAP_PROP_FPS)
|
||||
else:
|
||||
self._validate_fps()
|
||||
# Set FOURCC first (if specified) as it can affect available FPS/resolution options
|
||||
if self.config.fourcc is not None:
|
||||
self._validate_fourcc()
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
default_width = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_WIDTH)))
|
||||
default_height = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
|
||||
@@ -216,18 +220,56 @@ class OpenCVCamera(Camera):
|
||||
else:
|
||||
self._validate_width_and_height()
|
||||
|
||||
if self.fps is None:
|
||||
self.fps = self.videocapture.get(cv2.CAP_PROP_FPS)
|
||||
else:
|
||||
self._validate_fps()
|
||||
|
||||
def _validate_fps(self) -> None:
|
||||
"""Validates and sets the camera's frames per second (FPS)."""
|
||||
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
if self.fps is None:
|
||||
raise ValueError(f"{self} FPS is not set")
|
||||
|
||||
success = self.videocapture.set(cv2.CAP_PROP_FPS, float(self.fps))
|
||||
actual_fps = self.videocapture.get(cv2.CAP_PROP_FPS)
|
||||
# Use math.isclose for robust float comparison
|
||||
if not success or not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
|
||||
raise RuntimeError(f"{self} failed to set fps={self.fps} ({actual_fps=}).")
|
||||
|
||||
def _validate_fourcc(self) -> None:
|
||||
"""Validates and sets the camera's FOURCC code."""
|
||||
|
||||
fourcc_code = cv2.VideoWriter_fourcc(*self.config.fourcc)
|
||||
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
success = self.videocapture.set(cv2.CAP_PROP_FOURCC, fourcc_code)
|
||||
actual_fourcc_code = self.videocapture.get(cv2.CAP_PROP_FOURCC)
|
||||
|
||||
# Convert actual FOURCC code back to string for comparison
|
||||
actual_fourcc_code_int = int(actual_fourcc_code)
|
||||
actual_fourcc = "".join([chr((actual_fourcc_code_int >> 8 * i) & 0xFF) for i in range(4)])
|
||||
|
||||
if not success or actual_fourcc != self.config.fourcc:
|
||||
logger.warning(
|
||||
f"{self} failed to set fourcc={self.config.fourcc} (actual={actual_fourcc}, success={success}). "
|
||||
f"Continuing with default format."
|
||||
)
|
||||
|
||||
def _validate_width_and_height(self) -> None:
|
||||
"""Validates and sets the camera's frame capture width and height."""
|
||||
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
if self.capture_width is None or self.capture_height is None:
|
||||
raise ValueError(f"{self} capture_width or capture_height is not set")
|
||||
|
||||
width_success = self.videocapture.set(cv2.CAP_PROP_FRAME_WIDTH, float(self.capture_width))
|
||||
height_success = self.videocapture.set(cv2.CAP_PROP_FRAME_HEIGHT, float(self.capture_height))
|
||||
|
||||
@@ -258,11 +300,12 @@ class OpenCVCamera(Camera):
|
||||
"""
|
||||
found_cameras_info = []
|
||||
|
||||
targets_to_scan: list[str | int]
|
||||
if platform.system() == "Linux":
|
||||
possible_paths = sorted(Path("/dev").glob("video*"), key=lambda p: p.name)
|
||||
targets_to_scan = [str(p) for p in possible_paths]
|
||||
else:
|
||||
targets_to_scan = list(range(MAX_OPENCV_INDEX))
|
||||
targets_to_scan = [int(i) for i in range(MAX_OPENCV_INDEX)]
|
||||
|
||||
for target in targets_to_scan:
|
||||
camera = cv2.VideoCapture(target)
|
||||
@@ -271,6 +314,12 @@ class OpenCVCamera(Camera):
|
||||
default_height = int(camera.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
default_fps = camera.get(cv2.CAP_PROP_FPS)
|
||||
default_format = camera.get(cv2.CAP_PROP_FORMAT)
|
||||
|
||||
# Get FOURCC code and convert to string
|
||||
default_fourcc_code = camera.get(cv2.CAP_PROP_FOURCC)
|
||||
default_fourcc_code_int = int(default_fourcc_code)
|
||||
default_fourcc = "".join([chr((default_fourcc_code_int >> 8 * i) & 0xFF) for i in range(4)])
|
||||
|
||||
camera_info = {
|
||||
"name": f"OpenCV Camera @ {target}",
|
||||
"type": "OpenCV",
|
||||
@@ -278,6 +327,7 @@ class OpenCVCamera(Camera):
|
||||
"backend_api": camera.getBackendName(),
|
||||
"default_stream_profile": {
|
||||
"format": default_format,
|
||||
"fourcc": default_fourcc,
|
||||
"width": default_width,
|
||||
"height": default_height,
|
||||
"fps": default_fps,
|
||||
@@ -289,7 +339,7 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return found_cameras_info
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
|
||||
@@ -317,6 +367,9 @@ class OpenCVCamera(Camera):
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
ret, frame = self.videocapture.read()
|
||||
|
||||
if not ret or frame is None:
|
||||
@@ -329,7 +382,7 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return processed_frame
|
||||
|
||||
def _postprocess_image(self, image: np.ndarray, color_mode: ColorMode | None = None) -> np.ndarray:
|
||||
def _postprocess_image(self, image: NDArray[Any], color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Applies color conversion, dimension validation, and rotation to a raw frame.
|
||||
|
||||
@@ -372,7 +425,7 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return processed_image
|
||||
|
||||
def _read_loop(self):
|
||||
def _read_loop(self) -> None:
|
||||
"""
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
|
||||
@@ -383,6 +436,9 @@ class OpenCVCamera(Camera):
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
"""
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
color_image = self.read()
|
||||
@@ -419,7 +475,7 @@ class OpenCVCamera(Camera):
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame asynchronously.
|
||||
|
||||
@@ -462,7 +518,7 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self):
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnects from the camera and cleans up resources.
|
||||
|
||||
|
||||
@@ -17,6 +17,8 @@ from pathlib import Path
|
||||
|
||||
from ..configs import CameraConfig, ColorMode, Cv2Rotation
|
||||
|
||||
__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation"]
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("opencv")
|
||||
@dataclass
|
||||
@@ -33,8 +35,9 @@ class OpenCVCameraConfig(CameraConfig):
|
||||
OpenCVCameraConfig(0, 30, 1280, 720) # 1280x720 @ 30FPS
|
||||
OpenCVCameraConfig(/dev/video4, 60, 640, 480) # 640x480 @ 60FPS
|
||||
|
||||
# Advanced configurations
|
||||
OpenCVCameraConfig(128422271347, 30, 640, 480, rotation=Cv2Rotation.ROTATE_90) # With 90° rotation
|
||||
# Advanced configurations with FOURCC format
|
||||
OpenCVCameraConfig(128422271347, 30, 640, 480, rotation=Cv2Rotation.ROTATE_90, fourcc="MJPG") # With 90° rotation and MJPG format
|
||||
OpenCVCameraConfig(0, 30, 1280, 720, fourcc="YUYV") # With YUYV format
|
||||
```
|
||||
|
||||
Attributes:
|
||||
@@ -46,17 +49,21 @@ class OpenCVCameraConfig(CameraConfig):
|
||||
color_mode: Color mode for image output (RGB or BGR). Defaults to RGB.
|
||||
rotation: Image rotation setting (0°, 90°, 180°, or 270°). Defaults to no rotation.
|
||||
warmup_s: Time reading frames before returning from connect (in seconds)
|
||||
fourcc: FOURCC code for video format (e.g., "MJPG", "YUYV", "I420"). Defaults to None (auto-detect).
|
||||
|
||||
Note:
|
||||
- Only 3-channel color output (RGB/BGR) is currently supported.
|
||||
- FOURCC codes must be 4-character strings (e.g., "MJPG", "YUYV"). Some common FOUCC codes: https://learn.microsoft.com/en-us/windows/win32/medfound/video-fourccs#fourcc-constants
|
||||
- Setting FOURCC can help achieve higher frame rates on some cameras.
|
||||
"""
|
||||
|
||||
index_or_path: int | Path
|
||||
color_mode: ColorMode = ColorMode.RGB
|
||||
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
|
||||
warmup_s: int = 1
|
||||
fourcc: str | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
|
||||
@@ -71,3 +78,8 @@ class OpenCVCameraConfig(CameraConfig):
|
||||
raise ValueError(
|
||||
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
|
||||
)
|
||||
|
||||
if self.fourcc is not None and (not isinstance(self.fourcc, str) or len(self.fourcc) != 4):
|
||||
raise ValueError(
|
||||
f"`fourcc` must be a 4-character string (e.g., 'MJPG', 'YUYV'), but '{self.fourcc}' is provided."
|
||||
)
|
||||
|
||||
@@ -16,6 +16,8 @@ from dataclasses import dataclass
|
||||
|
||||
from ..configs import CameraConfig, ColorMode
|
||||
|
||||
__all__ = ["CameraConfig", "ColorMode", "Reachy2CameraConfig"]
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("reachy2_camera")
|
||||
@dataclass
|
||||
@@ -62,7 +64,7 @@ class Reachy2CameraConfig(CameraConfig):
|
||||
port: int = 50065
|
||||
# use_depth: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
if self.name not in ["teleop", "depth"]:
|
||||
raise ValueError(f"`name` is expected to be 'teleop' or 'depth', but {self.name} is provided.")
|
||||
if (self.name == "teleop" and self.image_type not in ["left", "right"]) or (
|
||||
|
||||
@@ -23,13 +23,17 @@ import time
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
|
||||
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
|
||||
|
||||
# Fix MSMF hardware transform compatibility for Windows before importing cv2
|
||||
if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
|
||||
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
|
||||
import cv2
|
||||
import numpy as np
|
||||
from reachy2_sdk.media.camera import CameraView
|
||||
from reachy2_sdk.media.camera_manager import CameraManager
|
||||
import cv2 # type: ignore # TODO: add type stubs for OpenCV
|
||||
import numpy as np # type: ignore # TODO: add type stubs for numpy
|
||||
from reachy2_sdk.media.camera import CameraView # type: ignore # TODO: add type stubs for reachy2_sdk
|
||||
from reachy2_sdk.media.camera_manager import ( # type: ignore # TODO: add type stubs for reachy2_sdk
|
||||
CameraManager,
|
||||
)
|
||||
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
|
||||
@@ -73,7 +77,7 @@ class Reachy2Camera(Camera):
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: np.ndarray | None = None
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
def __str__(self) -> str:
|
||||
@@ -83,13 +87,17 @@ class Reachy2Camera(Camera):
|
||||
def is_connected(self) -> bool:
|
||||
"""Checks if the camera is currently connected and opened."""
|
||||
if self.config.name == "teleop":
|
||||
return self.cam_manager._grpc_connected and self.cam_manager.teleop if self.cam_manager else False
|
||||
return bool(
|
||||
self.cam_manager._grpc_connected and self.cam_manager.teleop if self.cam_manager else False
|
||||
)
|
||||
elif self.config.name == "depth":
|
||||
return self.cam_manager._grpc_connected and self.cam_manager.depth if self.cam_manager else False
|
||||
return bool(
|
||||
self.cam_manager._grpc_connected and self.cam_manager.depth if self.cam_manager else False
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")
|
||||
|
||||
def connect(self, warmup: bool = True):
|
||||
def connect(self, warmup: bool = True) -> None:
|
||||
"""
|
||||
Connects to the Reachy2 CameraManager as specified in the configuration.
|
||||
"""
|
||||
@@ -131,7 +139,7 @@ class Reachy2Camera(Camera):
|
||||
camera_manager.disconnect()
|
||||
return initialized_cameras
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
|
||||
@@ -152,7 +160,7 @@ class Reachy2Camera(Camera):
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
frame = None
|
||||
frame: NDArray[Any] = np.empty((0, 0, 3), dtype=np.uint8)
|
||||
|
||||
if self.cam_manager is None:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
@@ -179,7 +187,7 @@ class Reachy2Camera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
def _read_loop(self):
|
||||
def _read_loop(self) -> None:
|
||||
"""
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
|
||||
@@ -190,6 +198,9 @@ class Reachy2Camera(Camera):
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
"""
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
color_image = self.read()
|
||||
@@ -226,7 +237,7 @@ class Reachy2Camera(Camera):
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame asynchronously.
|
||||
|
||||
@@ -269,7 +280,7 @@ class Reachy2Camera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self):
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Stops the background read thread (if running).
|
||||
|
||||
|
||||
@@ -21,11 +21,12 @@ import time
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import cv2 # type: ignore # TODO: add type stubs for OpenCV
|
||||
import numpy as np # type: ignore # TODO: add type stubs for numpy
|
||||
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
|
||||
|
||||
try:
|
||||
import pyrealsense2 as rs
|
||||
import pyrealsense2 as rs # type: ignore # TODO: add type stubs for pyrealsense2
|
||||
except Exception as e:
|
||||
logging.info(f"Could not import realsense: {e}")
|
||||
|
||||
@@ -132,7 +133,7 @@ class RealSenseCamera(Camera):
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: np.ndarray | None = None
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
self.rotation: int | None = get_cv2_rotation(config.rotation)
|
||||
@@ -150,7 +151,7 @@ class RealSenseCamera(Camera):
|
||||
"""Checks if the camera pipeline is started and streams are active."""
|
||||
return self.rs_pipeline is not None and self.rs_profile is not None
|
||||
|
||||
def connect(self, warmup: bool = True):
|
||||
def connect(self, warmup: bool = True) -> None:
|
||||
"""
|
||||
Connects to the RealSense camera specified in the configuration.
|
||||
|
||||
@@ -264,7 +265,7 @@ class RealSenseCamera(Camera):
|
||||
serial_number = str(found_devices[0]["serial_number"])
|
||||
return serial_number
|
||||
|
||||
def _configure_rs_pipeline_config(self, rs_config):
|
||||
def _configure_rs_pipeline_config(self, rs_config: Any) -> None:
|
||||
"""Creates and configures the RealSense pipeline configuration object."""
|
||||
rs.config.enable_device(rs_config, self.serial_number)
|
||||
|
||||
@@ -293,6 +294,9 @@ class RealSenseCamera(Camera):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Cannot validate settings for {self} as it is not connected.")
|
||||
|
||||
if self.rs_profile is None:
|
||||
raise RuntimeError(f"{self}: rs_profile must be initialized before use.")
|
||||
|
||||
stream = self.rs_profile.get_stream(rs.stream.color).as_video_stream_profile()
|
||||
|
||||
if self.fps is None:
|
||||
@@ -308,7 +312,7 @@ class RealSenseCamera(Camera):
|
||||
self.width, self.height = actual_width, actual_height
|
||||
self.capture_width, self.capture_height = actual_width, actual_height
|
||||
|
||||
def read_depth(self, timeout_ms: int = 200) -> np.ndarray:
|
||||
def read_depth(self, timeout_ms: int = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame (depth) synchronously from the camera.
|
||||
|
||||
@@ -336,6 +340,9 @@ class RealSenseCamera(Camera):
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.rs_pipeline is None:
|
||||
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
|
||||
|
||||
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
|
||||
|
||||
if not ret or frame is None:
|
||||
@@ -351,7 +358,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return depth_map_processed
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> np.ndarray:
|
||||
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame (color) synchronously from the camera.
|
||||
|
||||
@@ -376,6 +383,9 @@ class RealSenseCamera(Camera):
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.rs_pipeline is None:
|
||||
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
|
||||
|
||||
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
|
||||
|
||||
if not ret or frame is None:
|
||||
@@ -392,8 +402,8 @@ class RealSenseCamera(Camera):
|
||||
return color_image_processed
|
||||
|
||||
def _postprocess_image(
|
||||
self, image: np.ndarray, color_mode: ColorMode | None = None, depth_frame: bool = False
|
||||
) -> np.ndarray:
|
||||
self, image: NDArray[Any], color_mode: ColorMode | None = None, depth_frame: bool = False
|
||||
) -> NDArray[Any]:
|
||||
"""
|
||||
Applies color conversion, dimension validation, and rotation to a raw color frame.
|
||||
|
||||
@@ -438,7 +448,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return processed_image
|
||||
|
||||
def _read_loop(self):
|
||||
def _read_loop(self) -> None:
|
||||
"""
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
|
||||
@@ -449,6 +459,9 @@ class RealSenseCamera(Camera):
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
"""
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
color_image = self.read(timeout_ms=500)
|
||||
@@ -474,7 +487,7 @@ class RealSenseCamera(Camera):
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
def _stop_read_thread(self):
|
||||
def _stop_read_thread(self) -> None:
|
||||
"""Signals the background read thread to stop and waits for it to join."""
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
@@ -486,7 +499,7 @@ class RealSenseCamera(Camera):
|
||||
self.stop_event = None
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame data (color) asynchronously.
|
||||
|
||||
@@ -529,7 +542,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self):
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnects from the camera, stops the pipeline, and cleans up resources.
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ class RealSenseCameraConfig(CameraConfig):
|
||||
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
|
||||
warmup_s: int = 1
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
|
||||
|
||||
@@ -15,15 +15,19 @@
|
||||
# limitations under the License.
|
||||
|
||||
import platform
|
||||
from typing import cast
|
||||
|
||||
from lerobot.utils.import_utils import make_device_from_device_class
|
||||
|
||||
from .camera import Camera
|
||||
from .configs import CameraConfig, Cv2Rotation
|
||||
|
||||
|
||||
def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[str, Camera]:
|
||||
cameras = {}
|
||||
cameras: dict[str, Camera] = {}
|
||||
|
||||
for key, cfg in camera_configs.items():
|
||||
# TODO(Steven): Consider just using the make_device_from_device_class for all types
|
||||
if cfg.type == "opencv":
|
||||
from .opencv import OpenCVCamera
|
||||
|
||||
@@ -40,20 +44,23 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s
|
||||
cameras[key] = Reachy2Camera(cfg)
|
||||
|
||||
else:
|
||||
raise ValueError(f"The camera type '{cfg.type}' is not valid.")
|
||||
try:
|
||||
cameras[key] = cast(Camera, make_device_from_device_class(cfg))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error creating camera {key} with config {cfg}: {e}") from e
|
||||
|
||||
return cameras
|
||||
|
||||
|
||||
def get_cv2_rotation(rotation: Cv2Rotation) -> int | None:
|
||||
import cv2
|
||||
import cv2 # type: ignore # TODO: add type stubs for OpenCV
|
||||
|
||||
if rotation == Cv2Rotation.ROTATE_90:
|
||||
return cv2.ROTATE_90_CLOCKWISE
|
||||
return int(cv2.ROTATE_90_CLOCKWISE)
|
||||
elif rotation == Cv2Rotation.ROTATE_180:
|
||||
return cv2.ROTATE_180
|
||||
return int(cv2.ROTATE_180)
|
||||
elif rotation == Cv2Rotation.ROTATE_270:
|
||||
return cv2.ROTATE_90_COUNTERCLOCKWISE
|
||||
return int(cv2.ROTATE_90_COUNTERCLOCKWISE)
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -62,8 +69,8 @@ def get_cv2_backend() -> int:
|
||||
import cv2
|
||||
|
||||
if platform.system() == "Windows":
|
||||
return cv2.CAP_MSMF # Use MSMF for Windows instead of AVFOUNDATION
|
||||
return int(cv2.CAP_MSMF) # Use MSMF for Windows instead of AVFOUNDATION
|
||||
# elif platform.system() == "Darwin": # macOS
|
||||
# return cv2.CAP_AVFOUNDATION
|
||||
else: # Linux and others
|
||||
return cv2.CAP_ANY
|
||||
return int(cv2.CAP_ANY)
|
||||
|
||||
@@ -57,7 +57,7 @@ class EvalConfig:
|
||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
||||
use_async_envs: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
if self.batch_size > self.n_episodes:
|
||||
raise ValueError(
|
||||
"The eval batch size is greater than the number of eval episodes "
|
||||
|
||||
@@ -13,8 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import datetime as dt
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot import envs, policies # noqa: F401
|
||||
@@ -22,6 +22,8 @@ from lerobot.configs import parser
|
||||
from lerobot.configs.default import EvalConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalPipelineConfig:
|
||||
@@ -34,25 +36,31 @@ class EvalPipelineConfig:
|
||||
output_dir: Path | None = None
|
||||
job_name: str | None = None
|
||||
seed: int | None = 1000
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
self.policy.pretrained_path = Path(policy_path)
|
||||
|
||||
else:
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
|
||||
)
|
||||
|
||||
if not self.job_name:
|
||||
if self.env is None:
|
||||
self.job_name = f"{self.policy.type}"
|
||||
self.job_name = f"{self.policy.type if self.policy is not None else 'scratch'}"
|
||||
else:
|
||||
self.job_name = f"{self.env.type}_{self.policy.type}"
|
||||
self.job_name = (
|
||||
f"{self.env.type}_{self.policy.type if self.policy is not None else 'scratch'}"
|
||||
)
|
||||
|
||||
logger.warning(f"No job name provided, using '{self.job_name}' as job name.")
|
||||
|
||||
if not self.output_dir:
|
||||
now = dt.datetime.now()
|
||||
|
||||
@@ -16,14 +16,19 @@ import inspect
|
||||
import pkgutil
|
||||
import sys
|
||||
from argparse import ArgumentError
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from pkgutil import ModuleInfo
|
||||
from types import ModuleType
|
||||
from typing import Any, TypeVar, cast
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.utils.utils import has_method
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., object])
|
||||
|
||||
PATH_KEY = "path"
|
||||
PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
|
||||
|
||||
@@ -60,7 +65,7 @@ def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict:
|
||||
def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict[str, str]:
|
||||
"""Parse plugin-related arguments from command-line arguments.
|
||||
|
||||
This function extracts arguments from command-line arguments that match a specified suffix pattern.
|
||||
@@ -127,7 +132,7 @@ def load_plugin(plugin_path: str) -> None:
|
||||
f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
|
||||
) from e
|
||||
|
||||
def iter_namespace(ns_pkg):
|
||||
def iter_namespace(ns_pkg: ModuleType) -> Iterable[ModuleInfo]:
|
||||
return pkgutil.iter_modules(ns_pkg.__path__, ns_pkg.__name__ + ".")
|
||||
|
||||
try:
|
||||
@@ -148,6 +153,8 @@ def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | No
|
||||
|
||||
|
||||
def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[str]:
|
||||
if args is None:
|
||||
return []
|
||||
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
|
||||
|
||||
|
||||
@@ -171,7 +178,8 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
|
||||
if isinstance(fields_to_filter, str):
|
||||
fields_to_filter = [fields_to_filter]
|
||||
|
||||
filtered_args = args
|
||||
filtered_args = [] if args is None else list(args)
|
||||
|
||||
for field in fields_to_filter:
|
||||
if get_path_arg(field, args):
|
||||
if get_type_arg(field, args):
|
||||
@@ -184,7 +192,7 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
|
||||
return filtered_args
|
||||
|
||||
|
||||
def wrap(config_path: Path | None = None):
|
||||
def wrap(config_path: Path | None = None) -> Callable[[F], F]:
|
||||
"""
|
||||
HACK: Similar to draccus.wrap but does three additional things:
|
||||
- Will remove '.path' arguments from CLI in order to process them later on.
|
||||
@@ -195,9 +203,9 @@ def wrap(config_path: Path | None = None):
|
||||
from the CLI '.type' arguments
|
||||
"""
|
||||
|
||||
def wrapper_outer(fn):
|
||||
def wrapper_outer(fn: F) -> F:
|
||||
@wraps(fn)
|
||||
def wrapper_inner(*args, **kwargs):
|
||||
def wrapper_inner(*args: Any, **kwargs: Any) -> Any:
|
||||
argspec = inspect.getfullargspec(fn)
|
||||
argtype = argspec.annotations[argspec.args[0]]
|
||||
if len(args) > 0 and type(args[0]) is argtype:
|
||||
@@ -225,6 +233,6 @@ def wrap(config_path: Path | None = None):
|
||||
response = fn(cfg, *args, **kwargs)
|
||||
return response
|
||||
|
||||
return wrapper_inner
|
||||
return cast(F, wrapper_inner)
|
||||
|
||||
return wrapper_outer
|
||||
return cast(Callable[[F], F], wrapper_outer)
|
||||
|
||||
@@ -14,12 +14,12 @@
|
||||
import abc
|
||||
import builtins
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import TypeVar
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import draccus
|
||||
from huggingface_hub import hf_hub_download
|
||||
@@ -34,10 +34,11 @@ from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
||||
|
||||
T = TypeVar("T", bound="PreTrainedConfig")
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: ignore[misc,name-defined] #TODO: draccus issue
|
||||
"""
|
||||
Base configuration class for policy models.
|
||||
|
||||
@@ -57,12 +58,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
|
||||
device: str | None = None # cuda | cpu | mp
|
||||
device: str | None = None # e.g. "cuda", "cuda:0", "cpu", or "mps"
|
||||
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
||||
# automatic gradient scaling is used.
|
||||
use_amp: bool = False
|
||||
|
||||
push_to_hub: bool = True
|
||||
push_to_hub: bool = True # type: ignore[assignment] # TODO: use a different name to avoid override
|
||||
repo_id: str | None = None
|
||||
|
||||
# Upload on private repository on the Hugging Face hub.
|
||||
@@ -71,38 +72,43 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
tags: list[str] | None = None
|
||||
# Add tags to your policy on the hub.
|
||||
license: str | None = None
|
||||
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
|
||||
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
|
||||
pretrained_path: Path | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.pretrained_path = None
|
||||
def __post_init__(self) -> None:
|
||||
if not self.device or not is_torch_device_available(self.device):
|
||||
auto_device = auto_select_torch_device()
|
||||
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||
logger.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||
self.device = auto_device.type
|
||||
|
||||
# Automatically deactivate AMP if necessary
|
||||
if self.use_amp and not is_amp_available(self.device):
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
|
||||
)
|
||||
self.use_amp = False
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
choice_name = self.get_choice_name(self.__class__)
|
||||
if not isinstance(choice_name, str):
|
||||
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
|
||||
return choice_name
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def observation_delta_indices(self) -> list | None:
|
||||
def observation_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def action_delta_indices(self) -> list | None:
|
||||
def action_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def reward_delta_indices(self) -> list | None:
|
||||
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -152,13 +158,13 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
force_download: bool = False,
|
||||
resume_download: bool = None,
|
||||
proxies: dict | None = None,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict[Any, Any] | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
**policy_kwargs,
|
||||
**policy_kwargs: Any,
|
||||
) -> T:
|
||||
model_id = str(pretrained_name_or_path)
|
||||
config_file: str | None = None
|
||||
@@ -166,7 +172,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
if CONFIG_NAME in os.listdir(model_id):
|
||||
config_file = os.path.join(model_id, CONFIG_NAME)
|
||||
else:
|
||||
print(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
|
||||
logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
|
||||
else:
|
||||
try:
|
||||
config_file = hf_hub_download(
|
||||
@@ -192,6 +198,9 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
with draccus.config_type("json"):
|
||||
orig_config = draccus.parse(cls, config_file, args=[])
|
||||
|
||||
if config_file is None:
|
||||
raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
|
||||
|
||||
with open(config_file) as f:
|
||||
config = json.load(f)
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import datetime as dt
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
from huggingface_hub import hf_hub_download
|
||||
@@ -63,18 +64,18 @@ class TrainPipelineConfig(HubMixin):
|
||||
scheduler: LRSchedulerConfig | None = None
|
||||
eval: EvalConfig = field(default_factory=EvalConfig)
|
||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||
checkpoint_path: Path | None = field(init=False, default=None)
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
self.checkpoint_path = None
|
||||
|
||||
def validate(self):
|
||||
def validate(self) -> None:
|
||||
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
# Only load the policy config
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
self.policy.pretrained_path = Path(policy_path)
|
||||
elif self.resume:
|
||||
# The entire train config is already loaded, we just need to get the checkpoint dir
|
||||
config_path = parser.parse_arg("config_path")
|
||||
@@ -82,14 +83,22 @@ class TrainPipelineConfig(HubMixin):
|
||||
raise ValueError(
|
||||
f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
|
||||
)
|
||||
|
||||
if not Path(config_path).resolve().exists():
|
||||
raise NotADirectoryError(
|
||||
f"{config_path=} is expected to be a local path. "
|
||||
"Resuming from the hub is not supported for now."
|
||||
)
|
||||
policy_path = Path(config_path).parent
|
||||
self.policy.pretrained_path = policy_path
|
||||
self.checkpoint_path = policy_path.parent
|
||||
|
||||
policy_dir = Path(config_path).parent
|
||||
if self.policy is not None:
|
||||
self.policy.pretrained_path = policy_dir
|
||||
self.checkpoint_path = policy_dir.parent
|
||||
|
||||
if self.policy is None:
|
||||
raise ValueError(
|
||||
"Policy is not configured. Please specify a pretrained policy with `--policy.path`."
|
||||
)
|
||||
|
||||
if not self.job_name:
|
||||
if self.env is None:
|
||||
@@ -126,8 +135,8 @@ class TrainPipelineConfig(HubMixin):
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["policy"]
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return draccus.encode(self)
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return draccus.encode(self) # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
||||
@@ -139,13 +148,13 @@ class TrainPipelineConfig(HubMixin):
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
force_download: bool = False,
|
||||
resume_download: bool = None,
|
||||
proxies: dict | None = None,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict[Any, Any] | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> "TrainPipelineConfig":
|
||||
model_id = str(pretrained_name_or_path)
|
||||
config_file: str | None = None
|
||||
@@ -181,4 +190,6 @@ class TrainPipelineConfig(HubMixin):
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
||||
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need an offline dataset
|
||||
# NOTE: In RL, we don't need an offline dataset
|
||||
# TODO: Make `TrainPipelineConfig.dataset` optional
|
||||
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
|
||||
|
||||
@@ -35,9 +35,11 @@ class NormalizationMode(str, Enum):
|
||||
MIN_MAX = "MIN_MAX"
|
||||
MEAN_STD = "MEAN_STD"
|
||||
IDENTITY = "IDENTITY"
|
||||
QUANTILES = "QUANTILES"
|
||||
QUANTILE10 = "QUANTILE10"
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyFeature:
|
||||
type: FeatureType
|
||||
shape: tuple
|
||||
shape: tuple[int, ...]
|
||||
|
||||
@@ -31,15 +31,15 @@ from lerobot.datasets.utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
get_file_size_in_mb,
|
||||
get_parquet_file_size_in_mb,
|
||||
get_video_size_in_mb,
|
||||
to_parquet_with_hf_images,
|
||||
update_chunk_file_indices,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.video_utils import concatenate_video_files
|
||||
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
|
||||
|
||||
|
||||
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
||||
@@ -130,10 +130,34 @@ def update_meta_data(
|
||||
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
|
||||
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
|
||||
for key, video_idx in videos_idx.items():
|
||||
df[f"videos/{key}/chunk_index"] = df[f"videos/{key}/chunk_index"] + video_idx["chunk"]
|
||||
df[f"videos/{key}/file_index"] = df[f"videos/{key}/file_index"] + video_idx["file"]
|
||||
df[f"videos/{key}/from_timestamp"] = df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
|
||||
df[f"videos/{key}/to_timestamp"] = df[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"]
|
||||
# Store original video file indices before updating
|
||||
orig_chunk_col = f"videos/{key}/chunk_index"
|
||||
orig_file_col = f"videos/{key}/file_index"
|
||||
df["_orig_chunk"] = df[orig_chunk_col].copy()
|
||||
df["_orig_file"] = df[orig_file_col].copy()
|
||||
|
||||
# Update chunk and file indices to point to destination
|
||||
df[orig_chunk_col] = video_idx["chunk"]
|
||||
df[orig_file_col] = video_idx["file"]
|
||||
|
||||
# Apply per-source-file timestamp offsets
|
||||
src_to_offset = video_idx.get("src_to_offset", {})
|
||||
if src_to_offset:
|
||||
# Apply offset based on original source file
|
||||
for idx in df.index:
|
||||
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
|
||||
offset = src_to_offset.get(src_key, 0)
|
||||
df.at[idx, f"videos/{key}/from_timestamp"] += offset
|
||||
df.at[idx, f"videos/{key}/to_timestamp"] += offset
|
||||
else:
|
||||
# Fallback to simple offset (for backward compatibility)
|
||||
df[f"videos/{key}/from_timestamp"] = (
|
||||
df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
|
||||
)
|
||||
df[f"videos/{key}/to_timestamp"] = df[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"]
|
||||
|
||||
# Clean up temporary columns
|
||||
df = df.drop(columns=["_orig_chunk", "_orig_file"])
|
||||
|
||||
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
|
||||
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
|
||||
@@ -193,6 +217,10 @@ def aggregate_datasets(
|
||||
robot_type=robot_type,
|
||||
features=features,
|
||||
root=aggr_root,
|
||||
use_videos=len(video_keys) > 0,
|
||||
chunks_size=chunk_size,
|
||||
data_files_size_in_mb=data_files_size_in_mb,
|
||||
video_files_size_in_mb=video_files_size_in_mb,
|
||||
)
|
||||
|
||||
logging.info("Find all tasks")
|
||||
@@ -236,6 +264,11 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
Returns:
|
||||
dict: Updated videos_idx with current chunk and file indices.
|
||||
"""
|
||||
for key in videos_idx:
|
||||
videos_idx[key]["episode_duration"] = 0
|
||||
# Track offset for each source (chunk, file) pair
|
||||
videos_idx[key]["src_to_offset"] = {}
|
||||
|
||||
for key, video_idx in videos_idx.items():
|
||||
unique_chunk_file_pairs = {
|
||||
(chunk, file)
|
||||
@@ -249,6 +282,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
|
||||
chunk_idx = video_idx["chunk"]
|
||||
file_idx = video_idx["file"]
|
||||
current_offset = video_idx["latest_duration"]
|
||||
|
||||
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
|
||||
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
|
||||
@@ -263,21 +297,25 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
file_index=file_idx,
|
||||
)
|
||||
|
||||
# If a new file is created, we don't want to increment the latest_duration
|
||||
update_latest_duration = False
|
||||
src_duration = get_video_duration_in_s(src_path)
|
||||
|
||||
if not dst_path.exists():
|
||||
# First write to this destination file
|
||||
# Store offset before incrementing
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(str(src_path), str(dst_path))
|
||||
continue # not accumulating further, already copied the file in place
|
||||
videos_idx[key]["episode_duration"] += src_duration
|
||||
current_offset += src_duration
|
||||
continue
|
||||
|
||||
# Check file sizes before appending
|
||||
src_size = get_video_size_in_mb(src_path)
|
||||
dst_size = get_video_size_in_mb(dst_path)
|
||||
src_size = get_file_size_in_mb(src_path)
|
||||
dst_size = get_file_size_in_mb(dst_path)
|
||||
|
||||
if dst_size + src_size >= video_files_size_in_mb:
|
||||
# Rotate to a new chunk/file
|
||||
# Rotate to a new file, this source becomes start of new destination
|
||||
# So its offset should be 0
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
|
||||
dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format(
|
||||
video_key=key,
|
||||
@@ -286,25 +324,22 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
)
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(str(src_path), str(dst_path))
|
||||
# Reset offset for next file
|
||||
current_offset = src_duration
|
||||
else:
|
||||
# Get the timestamps shift for this video
|
||||
timestamps_shift_s = dst_meta.info["total_frames"] / dst_meta.info["fps"]
|
||||
|
||||
# Append to existing video file
|
||||
# Append to existing video file - use current accumulated offset
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
|
||||
concatenate_video_files(
|
||||
[dst_path, src_path],
|
||||
dst_path,
|
||||
)
|
||||
# Update the latest_duration when appending (shifts timestamps!)
|
||||
update_latest_duration = not update_latest_duration
|
||||
current_offset += src_duration
|
||||
|
||||
videos_idx[key]["episode_duration"] += src_duration
|
||||
|
||||
# Update the videos_idx with the final chunk and file indices for this key
|
||||
videos_idx[key]["chunk"] = chunk_idx
|
||||
videos_idx[key]["file"] = file_idx
|
||||
|
||||
if update_latest_duration:
|
||||
videos_idx[key]["latest_duration"] += timestamps_shift_s
|
||||
|
||||
return videos_idx
|
||||
|
||||
|
||||
@@ -389,9 +424,6 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
videos_idx,
|
||||
)
|
||||
|
||||
for k in videos_idx:
|
||||
videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]
|
||||
|
||||
meta_idx = append_or_create_parquet_file(
|
||||
df,
|
||||
src_path,
|
||||
@@ -403,6 +435,10 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
aggr_root=dst_meta.root,
|
||||
)
|
||||
|
||||
# Increment latest_duration by the total duration added from this source dataset
|
||||
for k in videos_idx:
|
||||
videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]
|
||||
|
||||
return meta_idx
|
||||
|
||||
|
||||
|
||||
@@ -23,6 +23,9 @@ Please, update your dataset to the new format using this command:
|
||||
python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id={repo_id}
|
||||
```
|
||||
|
||||
If you already have a converted version uploaded to the hub, then this error might be because of
|
||||
an older version in your local cache. Consider deleting the cached version and retrying.
|
||||
|
||||
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
||||
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||
"""
|
||||
|
||||
@@ -17,6 +17,179 @@ import numpy as np
|
||||
|
||||
from lerobot.datasets.utils import load_image_as_numpy
|
||||
|
||||
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
|
||||
|
||||
|
||||
class RunningQuantileStats:
|
||||
"""
|
||||
Maintains running statistics for batches of vectors, including mean,
|
||||
standard deviation, min, max, and approximate quantiles.
|
||||
|
||||
Statistics are computed per feature dimension and updated incrementally
|
||||
as new batches are observed. Quantiles are estimated using histograms,
|
||||
which adapt dynamically if the observed data range expands.
|
||||
"""
|
||||
|
||||
def __init__(self, quantile_list: list[float] | None = None, num_quantile_bins: int = 5000):
|
||||
self._count = 0
|
||||
self._mean = None
|
||||
self._mean_of_squares = None
|
||||
self._min = None
|
||||
self._max = None
|
||||
self._histograms = None
|
||||
self._bin_edges = None
|
||||
self._num_quantile_bins = num_quantile_bins
|
||||
|
||||
self._quantile_list = quantile_list
|
||||
if self._quantile_list is None:
|
||||
self._quantile_list = DEFAULT_QUANTILES
|
||||
self._quantile_keys = [f"q{int(q * 100):02d}" for q in self._quantile_list]
|
||||
|
||||
def update(self, batch: np.ndarray) -> None:
|
||||
"""Update the running statistics with a batch of vectors.
|
||||
|
||||
Args:
|
||||
batch: An array where all dimensions except the last are batch dimensions.
|
||||
"""
|
||||
batch = batch.reshape(-1, batch.shape[-1])
|
||||
num_elements, vector_length = batch.shape
|
||||
|
||||
if self._count == 0:
|
||||
self._mean = np.mean(batch, axis=0)
|
||||
self._mean_of_squares = np.mean(batch**2, axis=0)
|
||||
self._min = np.min(batch, axis=0)
|
||||
self._max = np.max(batch, axis=0)
|
||||
self._histograms = [np.zeros(self._num_quantile_bins) for _ in range(vector_length)]
|
||||
self._bin_edges = [
|
||||
np.linspace(self._min[i] - 1e-10, self._max[i] + 1e-10, self._num_quantile_bins + 1)
|
||||
for i in range(vector_length)
|
||||
]
|
||||
else:
|
||||
if vector_length != self._mean.size:
|
||||
raise ValueError("The length of new vectors does not match the initialized vector length.")
|
||||
|
||||
new_max = np.max(batch, axis=0)
|
||||
new_min = np.min(batch, axis=0)
|
||||
max_changed = np.any(new_max > self._max)
|
||||
min_changed = np.any(new_min < self._min)
|
||||
self._max = np.maximum(self._max, new_max)
|
||||
self._min = np.minimum(self._min, new_min)
|
||||
|
||||
if max_changed or min_changed:
|
||||
self._adjust_histograms()
|
||||
|
||||
self._count += num_elements
|
||||
|
||||
batch_mean = np.mean(batch, axis=0)
|
||||
batch_mean_of_squares = np.mean(batch**2, axis=0)
|
||||
|
||||
# Update running mean and mean of squares
|
||||
self._mean += (batch_mean - self._mean) * (num_elements / self._count)
|
||||
self._mean_of_squares += (batch_mean_of_squares - self._mean_of_squares) * (
|
||||
num_elements / self._count
|
||||
)
|
||||
|
||||
self._update_histograms(batch)
|
||||
|
||||
def get_statistics(self) -> dict[str, np.ndarray]:
|
||||
"""Compute and return the statistics of the vectors processed so far.
|
||||
|
||||
Args:
|
||||
quantiles: List of quantiles to compute (e.g., [0.01, 0.10, 0.50, 0.90, 0.99]). If None, no quantiles computed.
|
||||
|
||||
Returns:
|
||||
Dictionary containing the computed statistics.
|
||||
"""
|
||||
if self._count < 2:
|
||||
raise ValueError("Cannot compute statistics for less than 2 vectors.")
|
||||
|
||||
variance = self._mean_of_squares - self._mean**2
|
||||
|
||||
stddev = np.sqrt(np.maximum(0, variance))
|
||||
|
||||
stats = {
|
||||
"min": self._min.copy(),
|
||||
"max": self._max.copy(),
|
||||
"mean": self._mean.copy(),
|
||||
"std": stddev,
|
||||
"count": np.array([self._count]),
|
||||
}
|
||||
|
||||
quantile_results = self._compute_quantiles()
|
||||
for i, q in enumerate(self._quantile_keys):
|
||||
stats[q] = quantile_results[i]
|
||||
|
||||
return stats
|
||||
|
||||
def _adjust_histograms(self):
|
||||
"""Adjust histograms when min or max changes."""
|
||||
for i in range(len(self._histograms)):
|
||||
old_edges = self._bin_edges[i]
|
||||
old_hist = self._histograms[i]
|
||||
|
||||
# Create new edges with small padding to ensure range coverage
|
||||
padding = (self._max[i] - self._min[i]) * 1e-10
|
||||
new_edges = np.linspace(
|
||||
self._min[i] - padding, self._max[i] + padding, self._num_quantile_bins + 1
|
||||
)
|
||||
|
||||
# Redistribute existing histogram counts to new bins
|
||||
# We need to map each old bin center to the new bins
|
||||
old_centers = (old_edges[:-1] + old_edges[1:]) / 2
|
||||
new_hist = np.zeros(self._num_quantile_bins)
|
||||
|
||||
for old_center, count in zip(old_centers, old_hist, strict=False):
|
||||
if count > 0:
|
||||
# Find which new bin this old center belongs to
|
||||
bin_idx = np.searchsorted(new_edges, old_center) - 1
|
||||
bin_idx = max(0, min(bin_idx, self._num_quantile_bins - 1))
|
||||
new_hist[bin_idx] += count
|
||||
|
||||
self._histograms[i] = new_hist
|
||||
self._bin_edges[i] = new_edges
|
||||
|
||||
def _update_histograms(self, batch: np.ndarray) -> None:
|
||||
"""Update histograms with new vectors."""
|
||||
for i in range(batch.shape[1]):
|
||||
hist, _ = np.histogram(batch[:, i], bins=self._bin_edges[i])
|
||||
self._histograms[i] += hist
|
||||
|
||||
def _compute_quantiles(self) -> list[np.ndarray]:
|
||||
"""Compute quantiles based on histograms."""
|
||||
results = []
|
||||
for q in self._quantile_list:
|
||||
target_count = q * self._count
|
||||
q_values = []
|
||||
|
||||
for hist, edges in zip(self._histograms, self._bin_edges, strict=True):
|
||||
q_value = self._compute_single_quantile(hist, edges, target_count)
|
||||
q_values.append(q_value)
|
||||
|
||||
results.append(np.array(q_values))
|
||||
return results
|
||||
|
||||
def _compute_single_quantile(self, hist: np.ndarray, edges: np.ndarray, target_count: float) -> float:
|
||||
"""Compute a single quantile value from histogram and bin edges."""
|
||||
cumsum = np.cumsum(hist)
|
||||
idx = np.searchsorted(cumsum, target_count)
|
||||
|
||||
if idx == 0:
|
||||
return edges[0]
|
||||
if idx >= len(cumsum):
|
||||
return edges[-1]
|
||||
|
||||
# If not edge case, interpolate within the bin
|
||||
count_before = cumsum[idx - 1]
|
||||
count_in_bin = cumsum[idx] - count_before
|
||||
|
||||
# If no samples in this bin, use the bin edge
|
||||
if count_in_bin == 0:
|
||||
return edges[idx]
|
||||
|
||||
# Linear interpolation within the bin
|
||||
fraction = (target_count - count_before) / count_in_bin
|
||||
return edges[idx] + fraction * (edges[idx + 1] - edges[idx])
|
||||
|
||||
|
||||
def estimate_num_samples(
|
||||
dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
|
||||
@@ -72,33 +245,282 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
|
||||
return images
|
||||
|
||||
|
||||
def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
|
||||
return {
|
||||
"min": np.min(array, axis=axis, keepdims=keepdims),
|
||||
"max": np.max(array, axis=axis, keepdims=keepdims),
|
||||
"mean": np.mean(array, axis=axis, keepdims=keepdims),
|
||||
"std": np.std(array, axis=axis, keepdims=keepdims),
|
||||
"count": np.array([len(array)]),
|
||||
def _reshape_stats_by_axis(
|
||||
stats: dict[str, np.ndarray],
|
||||
axis: int | tuple[int, ...] | None,
|
||||
keepdims: bool,
|
||||
original_shape: tuple[int, ...],
|
||||
) -> dict[str, np.ndarray]:
|
||||
"""Reshape all statistics to match NumPy's output conventions.
|
||||
|
||||
Applies consistent reshaping to all statistics (except 'count') based on the
|
||||
axis and keepdims parameters. This ensures statistics have the correct shape
|
||||
for broadcasting with the original data.
|
||||
|
||||
Args:
|
||||
stats: Dictionary of computed statistics
|
||||
axis: Axis or axes along which statistics were computed
|
||||
keepdims: Whether to keep reduced dimensions as size-1 dimensions
|
||||
original_shape: Shape of the original array
|
||||
|
||||
Returns:
|
||||
Dictionary with reshaped statistics
|
||||
|
||||
Note:
|
||||
The 'count' statistic is never reshaped as it represents metadata
|
||||
rather than per-feature statistics.
|
||||
"""
|
||||
if axis == (1,) and not keepdims:
|
||||
return stats
|
||||
|
||||
result = {}
|
||||
for key, value in stats.items():
|
||||
if key == "count":
|
||||
result[key] = value
|
||||
else:
|
||||
result[key] = _reshape_single_stat(value, axis, keepdims, original_shape)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _reshape_for_image_stats(value: np.ndarray, keepdims: bool) -> np.ndarray:
|
||||
"""Reshape statistics for image data (axis=(0,2,3))."""
|
||||
if keepdims and value.ndim == 1:
|
||||
return value.reshape(1, -1, 1, 1)
|
||||
return value
|
||||
|
||||
|
||||
def _reshape_for_vector_stats(
|
||||
value: np.ndarray, keepdims: bool, original_shape: tuple[int, ...]
|
||||
) -> np.ndarray:
|
||||
"""Reshape statistics for vector data (axis=0 or axis=(0,))."""
|
||||
if not keepdims:
|
||||
return value
|
||||
|
||||
if len(original_shape) == 1 and value.ndim > 0:
|
||||
return value.reshape(1)
|
||||
elif len(original_shape) >= 2 and value.ndim == 1:
|
||||
return value.reshape(1, -1)
|
||||
return value
|
||||
|
||||
|
||||
def _reshape_for_feature_stats(value: np.ndarray, keepdims: bool) -> np.ndarray:
|
||||
"""Reshape statistics for feature-wise computation (axis=(1,))."""
|
||||
if not keepdims:
|
||||
return value
|
||||
|
||||
if value.ndim == 0:
|
||||
return value.reshape(1, 1)
|
||||
elif value.ndim == 1:
|
||||
return value.reshape(-1, 1)
|
||||
return value
|
||||
|
||||
|
||||
def _reshape_for_global_stats(
|
||||
value: np.ndarray, keepdims: bool, original_shape: tuple[int, ...]
|
||||
) -> np.ndarray | float:
|
||||
"""Reshape statistics for global reduction (axis=None)."""
|
||||
if keepdims:
|
||||
target_shape = tuple(1 for _ in original_shape)
|
||||
return value.reshape(target_shape)
|
||||
# Keep at least 1-D arrays to satisfy validator
|
||||
return np.atleast_1d(value)
|
||||
|
||||
|
||||
def _reshape_single_stat(
|
||||
value: np.ndarray, axis: int | tuple[int, ...] | None, keepdims: bool, original_shape: tuple[int, ...]
|
||||
) -> np.ndarray | float:
|
||||
"""Apply appropriate reshaping to a single statistic array.
|
||||
|
||||
This function transforms statistic arrays to match expected output shapes
|
||||
based on the axis configuration and keepdims parameter.
|
||||
|
||||
Args:
|
||||
value: The statistic array to reshape
|
||||
axis: Axis or axes that were reduced during computation
|
||||
keepdims: Whether to maintain reduced dimensions as size-1 dimensions
|
||||
original_shape: Shape of the original data before reduction
|
||||
|
||||
Returns:
|
||||
Reshaped array following NumPy broadcasting conventions
|
||||
|
||||
"""
|
||||
if axis == (0, 2, 3):
|
||||
return _reshape_for_image_stats(value, keepdims)
|
||||
|
||||
if axis in [0, (0,)]:
|
||||
return _reshape_for_vector_stats(value, keepdims, original_shape)
|
||||
|
||||
if axis == (1,):
|
||||
return _reshape_for_feature_stats(value, keepdims)
|
||||
|
||||
if axis is None:
|
||||
return _reshape_for_global_stats(value, keepdims, original_shape)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def _prepare_array_for_stats(array: np.ndarray, axis: int | tuple[int, ...] | None) -> tuple[np.ndarray, int]:
|
||||
"""Prepare array for statistics computation by reshaping according to axis.
|
||||
|
||||
Args:
|
||||
array: Input data array
|
||||
axis: Axis or axes along which to compute statistics
|
||||
|
||||
Returns:
|
||||
Tuple of (reshaped_array, sample_count)
|
||||
"""
|
||||
if axis == (0, 2, 3): # Image data
|
||||
batch_size, channels, height, width = array.shape
|
||||
reshaped = array.transpose(0, 2, 3, 1).reshape(-1, channels)
|
||||
return reshaped, batch_size
|
||||
|
||||
if axis == 0 or axis == (0,): # Vector data
|
||||
reshaped = array
|
||||
if array.ndim == 1:
|
||||
reshaped = array.reshape(-1, 1)
|
||||
return reshaped, array.shape[0]
|
||||
|
||||
if axis == (1,): # Feature-wise statistics
|
||||
return array.T, array.shape[1]
|
||||
|
||||
if axis is None: # Global statistics
|
||||
reshaped = array.reshape(-1, 1)
|
||||
# For backward compatibility, count represents the first dimension size
|
||||
return reshaped, array.shape[0] if array.ndim > 0 else 1
|
||||
|
||||
raise ValueError(f"Unsupported axis configuration: {axis}")
|
||||
|
||||
|
||||
def _compute_basic_stats(
|
||||
array: np.ndarray, sample_count: int, quantile_list: list[float] | None = None
|
||||
) -> dict[str, np.ndarray]:
|
||||
"""Compute basic statistics for arrays with insufficient samples for quantiles.
|
||||
|
||||
Args:
|
||||
array: Reshaped array ready for statistics computation
|
||||
sample_count: Number of samples represented in the data
|
||||
|
||||
Returns:
|
||||
Dictionary with basic statistics and quantiles set to mean values
|
||||
"""
|
||||
if quantile_list is None:
|
||||
quantile_list = DEFAULT_QUANTILES
|
||||
quantile_list_keys = [f"q{int(q * 100):02d}" for q in quantile_list]
|
||||
|
||||
stats = {
|
||||
"min": np.min(array, axis=0),
|
||||
"max": np.max(array, axis=0),
|
||||
"mean": np.mean(array, axis=0),
|
||||
"std": np.std(array, axis=0),
|
||||
"count": np.array([sample_count]),
|
||||
}
|
||||
|
||||
for q in quantile_list_keys:
|
||||
stats[q] = stats["mean"].copy()
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def get_feature_stats(
|
||||
array: np.ndarray,
|
||||
axis: int | tuple[int, ...] | None,
|
||||
keepdims: bool,
|
||||
quantile_list: list[float] | None = None,
|
||||
) -> dict[str, np.ndarray]:
|
||||
"""Compute comprehensive statistics for array features along specified axes.
|
||||
|
||||
This function calculates min, max, mean, std, and quantiles (1%, 10%, 50%, 90%, 99%)
|
||||
for the input array along the specified axes. It handles different data layouts:
|
||||
- Image data: axis=(0,2,3) computes per-channel statistics
|
||||
- Vector data: axis=0 computes per-feature statistics
|
||||
- Feature-wise: axis=1 computes statistics across features
|
||||
- Global: axis=None computes statistics over entire array
|
||||
|
||||
Args:
|
||||
array: Input data array with shape appropriate for the specified axis
|
||||
axis: Axis or axes along which to compute statistics
|
||||
- (0, 2, 3): For image data (batch, channels, height, width)
|
||||
- 0 or (0,): For vector/tabular data (samples, features)
|
||||
- (1,): For computing across features
|
||||
- None: For global statistics over entire array
|
||||
keepdims: If True, reduced axes are kept as dimensions with size 1
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- 'min': Minimum values
|
||||
- 'max': Maximum values
|
||||
- 'mean': Mean values
|
||||
- 'std': Standard deviation
|
||||
- 'count': Number of samples (always shape (1,))
|
||||
- 'q01', 'q10', 'q50', 'q90', 'q99': Quantile values
|
||||
|
||||
"""
|
||||
if quantile_list is None:
|
||||
quantile_list = DEFAULT_QUANTILES
|
||||
|
||||
original_shape = array.shape
|
||||
reshaped, sample_count = _prepare_array_for_stats(array, axis)
|
||||
|
||||
if reshaped.shape[0] < 2:
|
||||
stats = _compute_basic_stats(reshaped, sample_count, quantile_list)
|
||||
else:
|
||||
running_stats = RunningQuantileStats()
|
||||
running_stats.update(reshaped)
|
||||
stats = running_stats.get_statistics()
|
||||
stats["count"] = np.array([sample_count])
|
||||
|
||||
stats = _reshape_stats_by_axis(stats, axis, keepdims, original_shape)
|
||||
return stats
|
||||
|
||||
|
||||
def compute_episode_stats(
|
||||
episode_data: dict[str, list[str] | np.ndarray],
|
||||
features: dict,
|
||||
quantile_list: list[float] | None = None,
|
||||
) -> dict:
|
||||
"""Compute comprehensive statistics for all features in an episode.
|
||||
|
||||
Processes different data types appropriately:
|
||||
- Images/videos: Samples from paths, computes per-channel stats, normalizes to [0,1]
|
||||
- Numerical arrays: Computes per-feature statistics
|
||||
- Strings: Skipped (no statistics computed)
|
||||
|
||||
Args:
|
||||
episode_data: Dictionary mapping feature names to data
|
||||
- For images/videos: list of file paths
|
||||
- For numerical data: numpy arrays
|
||||
features: Dictionary describing each feature's dtype and shape
|
||||
|
||||
Returns:
|
||||
Dictionary mapping feature names to their statistics dictionaries.
|
||||
Each statistics dictionary contains min, max, mean, std, count, and quantiles.
|
||||
|
||||
Note:
|
||||
Image statistics are normalized to [0,1] range and have shape (3,1,1) for
|
||||
per-channel values when dtype is 'image' or 'video'.
|
||||
"""
|
||||
if quantile_list is None:
|
||||
quantile_list = DEFAULT_QUANTILES
|
||||
|
||||
def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
|
||||
ep_stats = {}
|
||||
for key, data in episode_data.items():
|
||||
if features[key]["dtype"] == "string":
|
||||
continue # HACK: we should receive np.arrays of strings
|
||||
elif features[key]["dtype"] in ["image", "video"]:
|
||||
ep_ft_array = sample_images(data) # data is a list of image paths
|
||||
axes_to_reduce = (0, 2, 3) # keep channel dim
|
||||
continue
|
||||
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
ep_ft_array = sample_images(data)
|
||||
axes_to_reduce = (0, 2, 3)
|
||||
keepdims = True
|
||||
else:
|
||||
ep_ft_array = data # data is already a np.ndarray
|
||||
axes_to_reduce = 0 # compute stats over the first axis
|
||||
keepdims = data.ndim == 1 # keep as np.array
|
||||
ep_ft_array = data
|
||||
axes_to_reduce = 0
|
||||
keepdims = data.ndim == 1
|
||||
|
||||
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
|
||||
ep_stats[key] = get_feature_stats(
|
||||
ep_ft_array, axis=axes_to_reduce, keepdims=keepdims, quantile_list=quantile_list
|
||||
)
|
||||
|
||||
# finally, we normalize and remove batch dim for images
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
|
||||
@@ -107,20 +529,37 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu
|
||||
return ep_stats
|
||||
|
||||
|
||||
def _validate_stat_value(value: np.ndarray, key: str, feature_key: str) -> None:
|
||||
"""Validate a single statistic value."""
|
||||
if not isinstance(value, np.ndarray):
|
||||
raise ValueError(
|
||||
f"Stats must be composed of numpy array, but key '{key}' of feature '{feature_key}' "
|
||||
f"is of type '{type(value)}' instead."
|
||||
)
|
||||
|
||||
if value.ndim == 0:
|
||||
raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
|
||||
|
||||
if key == "count" and value.shape != (1,):
|
||||
raise ValueError(f"Shape of 'count' must be (1), but is {value.shape} instead.")
|
||||
|
||||
if "image" in feature_key and key != "count" and value.shape != (3, 1, 1):
|
||||
raise ValueError(f"Shape of quantile '{key}' must be (3,1,1), but is {value.shape} instead.")
|
||||
|
||||
|
||||
def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
|
||||
for i in range(len(stats_list)):
|
||||
for fkey in stats_list[i]:
|
||||
for k, v in stats_list[i][fkey].items():
|
||||
if not isinstance(v, np.ndarray):
|
||||
raise ValueError(
|
||||
f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
|
||||
)
|
||||
if v.ndim == 0:
|
||||
raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
|
||||
if k == "count" and v.shape != (1,):
|
||||
raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
|
||||
if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
|
||||
raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
|
||||
"""Validate that all statistics have correct types and shapes.
|
||||
|
||||
Args:
|
||||
stats_list: List of statistics dictionaries to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If any statistic has incorrect type or shape
|
||||
"""
|
||||
for stats in stats_list:
|
||||
for feature_key, feature_stats in stats.items():
|
||||
for stat_key, stat_value in feature_stats.items():
|
||||
_validate_stat_value(stat_value, stat_key, feature_key)
|
||||
|
||||
|
||||
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
|
||||
@@ -143,7 +582,7 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d
|
||||
weighted_variances = (variances + delta_means**2) * counts
|
||||
total_variance = weighted_variances.sum(axis=0) / total_count
|
||||
|
||||
return {
|
||||
aggregated = {
|
||||
"min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
|
||||
"max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
|
||||
"mean": total_mean,
|
||||
@@ -151,6 +590,17 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d
|
||||
"count": total_count,
|
||||
}
|
||||
|
||||
if stats_ft_list:
|
||||
quantile_keys = [k for k in stats_ft_list[0] if k.startswith("q") and k[1:].isdigit()]
|
||||
|
||||
for q_key in quantile_keys:
|
||||
if all(q_key in s for s in stats_ft_list):
|
||||
quantile_values = np.stack([s[q_key] for s in stats_ft_list])
|
||||
weighted_quantiles = quantile_values * counts
|
||||
aggregated[q_key] = weighted_quantiles.sum(axis=0) / total_count
|
||||
|
||||
return aggregated
|
||||
|
||||
|
||||
def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -68,7 +68,30 @@ def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True)
|
||||
return PIL.Image.fromarray(image_array)
|
||||
|
||||
|
||||
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
|
||||
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1):
|
||||
"""
|
||||
Saves a NumPy array or PIL Image to a file.
|
||||
|
||||
This function handles both NumPy arrays and PIL Image objects, converting
|
||||
the former to a PIL Image before saving. It includes error handling for
|
||||
the save operation.
|
||||
|
||||
Args:
|
||||
image (np.ndarray | PIL.Image.Image): The image data to save.
|
||||
fpath (Path): The destination file path for the image.
|
||||
compress_level (int, optional): The compression level for the saved
|
||||
image, as used by PIL.Image.save(). Defaults to 1.
|
||||
Refer to: https://github.com/huggingface/lerobot/pull/2135
|
||||
for more details on the default value rationale.
|
||||
|
||||
Raises:
|
||||
TypeError: If the input 'image' is not a NumPy array or a
|
||||
PIL.Image.Image object.
|
||||
|
||||
Side Effects:
|
||||
Prints an error message to the console if the image writing process
|
||||
fails for any reason.
|
||||
"""
|
||||
try:
|
||||
if isinstance(image, np.ndarray):
|
||||
img = image_array_to_pil_image(image)
|
||||
@@ -76,7 +99,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
|
||||
img = image
|
||||
else:
|
||||
raise TypeError(f"Unsupported image type: {type(image)}")
|
||||
img.save(fpath)
|
||||
img.save(fpath, compress_level=compress_level)
|
||||
except Exception as e:
|
||||
print(f"Error writing image {fpath}: {e}")
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import gc
|
||||
import logging
|
||||
import shutil
|
||||
import tempfile
|
||||
@@ -26,6 +25,8 @@ import numpy as np
|
||||
import packaging.version
|
||||
import pandas as pd
|
||||
import PIL.Image
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
import torch.utils
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
@@ -46,13 +47,9 @@ from lerobot.datasets.utils import (
|
||||
embed_images,
|
||||
flatten_dict,
|
||||
get_delta_indices,
|
||||
get_hf_dataset_cache_dir,
|
||||
get_hf_dataset_size_in_mb,
|
||||
get_file_size_in_mb,
|
||||
get_hf_features_from_features,
|
||||
get_parquet_file_size_in_mb,
|
||||
get_parquet_num_frames,
|
||||
get_safe_version,
|
||||
get_video_size_in_mb,
|
||||
hf_transform_to_torch,
|
||||
is_valid_version,
|
||||
load_episodes,
|
||||
@@ -60,7 +57,6 @@ from lerobot.datasets.utils import (
|
||||
load_nested_dataset,
|
||||
load_stats,
|
||||
load_tasks,
|
||||
to_parquet_with_hf_images,
|
||||
update_chunk_file_indices,
|
||||
validate_episode_buffer,
|
||||
validate_frame,
|
||||
@@ -90,10 +86,15 @@ class LeRobotDatasetMetadata:
|
||||
root: str | Path | None = None,
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
metadata_buffer_size: int = 10,
|
||||
):
|
||||
self.repo_id = repo_id
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
||||
self.writer = None
|
||||
self.latest_episode = None
|
||||
self.metadata_buffer: list[dict] = []
|
||||
self.metadata_buffer_size = metadata_buffer_size
|
||||
|
||||
try:
|
||||
if force_cache_sync:
|
||||
@@ -107,6 +108,54 @@ class LeRobotDatasetMetadata:
|
||||
self.pull_from_repo(allow_patterns="meta/")
|
||||
self.load_metadata()
|
||||
|
||||
def _flush_metadata_buffer(self) -> None:
|
||||
"""Write all buffered episode metadata to parquet file."""
|
||||
if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0:
|
||||
return
|
||||
|
||||
combined_dict = {}
|
||||
for episode_dict in self.metadata_buffer:
|
||||
for key, value in episode_dict.items():
|
||||
if key not in combined_dict:
|
||||
combined_dict[key] = []
|
||||
# Extract value and serialize numpy arrays
|
||||
# because PyArrow's from_pydict function doesn't support numpy arrays
|
||||
val = value[0] if isinstance(value, list) else value
|
||||
combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val)
|
||||
|
||||
first_ep = self.metadata_buffer[0]
|
||||
chunk_idx = first_ep["meta/episodes/chunk_index"][0]
|
||||
file_idx = first_ep["meta/episodes/file_index"][0]
|
||||
|
||||
table = pa.Table.from_pydict(combined_dict)
|
||||
|
||||
if not self.writer:
|
||||
path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx))
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.writer = pq.ParquetWriter(
|
||||
path, schema=table.schema, compression="snappy", use_dictionary=True
|
||||
)
|
||||
|
||||
self.writer.write_table(table)
|
||||
|
||||
self.latest_episode = self.metadata_buffer[-1]
|
||||
self.metadata_buffer.clear()
|
||||
|
||||
def _close_writer(self) -> None:
|
||||
"""Close and cleanup the parquet writer if it exists."""
|
||||
self._flush_metadata_buffer()
|
||||
|
||||
writer = getattr(self, "writer", None)
|
||||
if writer is not None:
|
||||
writer.close()
|
||||
self.writer = None
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor
|
||||
"""
|
||||
self._close_writer()
|
||||
|
||||
def load_metadata(self):
|
||||
self.info = load_info(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
@@ -138,6 +187,12 @@ class LeRobotDatasetMetadata:
|
||||
return packaging.version.parse(self.info["codebase_version"])
|
||||
|
||||
def get_data_file_path(self, ep_index: int) -> Path:
|
||||
if self.episodes is None:
|
||||
self.episodes = load_episodes(self.root)
|
||||
if ep_index >= len(self.episodes):
|
||||
raise IndexError(
|
||||
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
|
||||
)
|
||||
ep = self.episodes[ep_index]
|
||||
chunk_idx = ep["data/chunk_index"]
|
||||
file_idx = ep["data/file_index"]
|
||||
@@ -145,6 +200,12 @@ class LeRobotDatasetMetadata:
|
||||
return Path(fpath)
|
||||
|
||||
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
|
||||
if self.episodes is None:
|
||||
self.episodes = load_episodes(self.root)
|
||||
if ep_index >= len(self.episodes):
|
||||
raise IndexError(
|
||||
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
|
||||
)
|
||||
ep = self.episodes[ep_index]
|
||||
chunk_idx = ep[f"videos/{vid_key}/chunk_index"]
|
||||
file_idx = ep[f"videos/{vid_key}/file_index"]
|
||||
@@ -260,72 +321,75 @@ class LeRobotDatasetMetadata:
|
||||
write_tasks(self.tasks, self.root)
|
||||
|
||||
def _save_episode_metadata(self, episode_dict: dict) -> None:
|
||||
"""Save episode metadata to a parquet file and update the Hugging Face dataset of episodes metadata.
|
||||
"""Buffer episode metadata and write to parquet in batches for efficiency.
|
||||
|
||||
This function processes episodes metadata from a dictionary, converts it into a Hugging Face dataset,
|
||||
and saves it as a parquet file. It handles both the creation of new parquet files and the
|
||||
updating of existing ones based on size constraints. After saving the metadata, it reloads
|
||||
the Hugging Face dataset to ensure it is up-to-date.
|
||||
This function accumulates episode metadata in a buffer and flushes it when the buffer
|
||||
reaches the configured size. This reduces I/O overhead by writing multiple episodes
|
||||
at once instead of one row at a time.
|
||||
|
||||
Notes: We both need to update parquet files and HF dataset:
|
||||
- `pandas` loads parquet file in RAM
|
||||
- `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk,
|
||||
or loads directly from pyarrow cache.
|
||||
"""
|
||||
# Convert buffer into HF Dataset
|
||||
# Convert to list format for each value
|
||||
episode_dict = {key: [value] for key, value in episode_dict.items()}
|
||||
ep_dataset = datasets.Dataset.from_dict(episode_dict)
|
||||
ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
|
||||
df = pd.DataFrame(ep_dataset)
|
||||
num_frames = episode_dict["length"][0]
|
||||
|
||||
if self.episodes is None:
|
||||
if self.latest_episode is None:
|
||||
# Initialize indices and frame count for a new dataset made of the first episode data
|
||||
chunk_idx, file_idx = 0, 0
|
||||
df["meta/episodes/chunk_index"] = [chunk_idx]
|
||||
df["meta/episodes/file_index"] = [file_idx]
|
||||
df["dataset_from_index"] = [0]
|
||||
df["dataset_to_index"] = [num_frames]
|
||||
else:
|
||||
# Retrieve information from the latest parquet file
|
||||
latest_ep = self.episodes[-1]
|
||||
chunk_idx = latest_ep["meta/episodes/chunk_index"]
|
||||
file_idx = latest_ep["meta/episodes/file_index"]
|
||||
if self.episodes is not None and len(self.episodes) > 0:
|
||||
# It means we are resuming recording, so we need to load the latest episode
|
||||
# Update the indices to avoid overwriting the latest episode
|
||||
chunk_idx = self.episodes[-1]["meta/episodes/chunk_index"]
|
||||
file_idx = self.episodes[-1]["meta/episodes/file_index"]
|
||||
latest_num_frames = self.episodes[-1]["dataset_to_index"]
|
||||
episode_dict["dataset_from_index"] = [latest_num_frames]
|
||||
episode_dict["dataset_to_index"] = [latest_num_frames + num_frames]
|
||||
|
||||
latest_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
|
||||
|
||||
if latest_size_in_mb + ep_size_in_mb >= self.data_files_size_in_mb:
|
||||
# Size limit is reached, prepare new parquet file
|
||||
# When resuming, move to the next file
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
|
||||
else:
|
||||
episode_dict["dataset_from_index"] = [0]
|
||||
episode_dict["dataset_to_index"] = [num_frames]
|
||||
|
||||
episode_dict["meta/episodes/chunk_index"] = [chunk_idx]
|
||||
episode_dict["meta/episodes/file_index"] = [file_idx]
|
||||
else:
|
||||
chunk_idx = self.latest_episode["meta/episodes/chunk_index"][0]
|
||||
file_idx = self.latest_episode["meta/episodes/file_index"][0]
|
||||
|
||||
latest_path = (
|
||||
self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
if self.writer is None
|
||||
else self.writer.where
|
||||
)
|
||||
|
||||
if Path(latest_path).exists():
|
||||
latest_size_in_mb = get_file_size_in_mb(Path(latest_path))
|
||||
latest_num_frames = self.latest_episode["episode_index"][0]
|
||||
|
||||
av_size_per_frame = latest_size_in_mb / latest_num_frames if latest_num_frames > 0 else 0.0
|
||||
|
||||
if latest_size_in_mb + av_size_per_frame * num_frames >= self.data_files_size_in_mb:
|
||||
# Size limit is reached, flush buffer and prepare new parquet file
|
||||
self._flush_metadata_buffer()
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
|
||||
self._close_writer()
|
||||
|
||||
# Update the existing pandas dataframe with new row
|
||||
df["meta/episodes/chunk_index"] = [chunk_idx]
|
||||
df["meta/episodes/file_index"] = [file_idx]
|
||||
df["dataset_from_index"] = [latest_ep["dataset_to_index"]]
|
||||
df["dataset_to_index"] = [latest_ep["dataset_to_index"] + num_frames]
|
||||
episode_dict["meta/episodes/chunk_index"] = [chunk_idx]
|
||||
episode_dict["meta/episodes/file_index"] = [file_idx]
|
||||
episode_dict["dataset_from_index"] = [self.latest_episode["dataset_to_index"][0]]
|
||||
episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames]
|
||||
|
||||
if latest_size_in_mb + ep_size_in_mb < self.data_files_size_in_mb:
|
||||
# Size limit wasnt reached, concatenate latest dataframe with new one
|
||||
latest_df = pd.read_parquet(latest_path)
|
||||
df = pd.concat([latest_df, df], ignore_index=True)
|
||||
# Add to buffer
|
||||
self.metadata_buffer.append(episode_dict)
|
||||
self.latest_episode = episode_dict
|
||||
|
||||
# Memort optimization
|
||||
del latest_df
|
||||
gc.collect()
|
||||
|
||||
# Write the resulting dataframe from RAM to disk
|
||||
path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_parquet(path, index=False)
|
||||
|
||||
if self.episodes is not None:
|
||||
# Remove the episodes cache directory, necessary to avoid cache bloat
|
||||
cached_dir = get_hf_dataset_cache_dir(self.episodes)
|
||||
if cached_dir is not None:
|
||||
shutil.rmtree(cached_dir)
|
||||
|
||||
self.episodes = load_episodes(self.root)
|
||||
if len(self.metadata_buffer) >= self.metadata_buffer_size:
|
||||
self._flush_metadata_buffer()
|
||||
|
||||
def save_episode(
|
||||
self,
|
||||
@@ -438,6 +502,10 @@ class LeRobotDatasetMetadata:
|
||||
robot_type: str | None = None,
|
||||
root: str | Path | None = None,
|
||||
use_videos: bool = True,
|
||||
metadata_buffer_size: int = 10,
|
||||
chunks_size: int | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
) -> "LeRobotDatasetMetadata":
|
||||
"""Creates metadata for a LeRobotDataset."""
|
||||
obj = cls.__new__(cls)
|
||||
@@ -452,11 +520,24 @@ class LeRobotDatasetMetadata:
|
||||
obj.tasks = None
|
||||
obj.episodes = None
|
||||
obj.stats = None
|
||||
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, features, use_videos, robot_type)
|
||||
obj.info = create_empty_dataset_info(
|
||||
CODEBASE_VERSION,
|
||||
fps,
|
||||
features,
|
||||
use_videos,
|
||||
robot_type,
|
||||
chunks_size,
|
||||
data_files_size_in_mb,
|
||||
video_files_size_in_mb,
|
||||
)
|
||||
if len(obj.video_keys) > 0 and not use_videos:
|
||||
raise ValueError()
|
||||
write_json(obj.info, obj.root / INFO_PATH)
|
||||
obj.revision = None
|
||||
obj.writer = None
|
||||
obj.latest_episode = None
|
||||
obj.metadata_buffer = []
|
||||
obj.metadata_buffer_size = metadata_buffer_size
|
||||
return obj
|
||||
|
||||
|
||||
@@ -603,6 +684,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Unused attributes
|
||||
self.image_writer = None
|
||||
self.episode_buffer = None
|
||||
self.writer = None
|
||||
self.latest_episode = None
|
||||
self._current_file_start_frame = None # Track the starting frame index of the current parquet file
|
||||
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
@@ -611,6 +695,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
||||
)
|
||||
|
||||
# Track dataset state for efficient incremental writing
|
||||
self._lazy_loading = False
|
||||
self._recorded_frames = self.meta.total_frames
|
||||
self._writer_closed_for_reading = False
|
||||
|
||||
# Load actual data
|
||||
try:
|
||||
if force_cache_sync:
|
||||
@@ -620,7 +709,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if not self._check_cached_episodes_sufficient():
|
||||
raise FileNotFoundError("Cached dataset doesn't contain all requested episodes")
|
||||
except (AssertionError, FileNotFoundError, NotADirectoryError):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
if is_valid_version(self.revision):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
self.download(download_videos)
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
|
||||
@@ -629,6 +719,19 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
||||
|
||||
def _close_writer(self) -> None:
|
||||
"""Close and cleanup the parquet writer if it exists."""
|
||||
writer = getattr(self, "writer", None)
|
||||
if writer is not None:
|
||||
writer.close()
|
||||
self.writer = None
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor
|
||||
"""
|
||||
self._close_writer()
|
||||
|
||||
def push_to_hub(
|
||||
self,
|
||||
branch: str | None = None,
|
||||
@@ -734,14 +837,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return hf_dataset
|
||||
|
||||
def _check_cached_episodes_sufficient(self) -> bool:
|
||||
"""Check if the cached dataset contains all requested episodes."""
|
||||
"""Check if the cached dataset contains all requested episodes and their video files."""
|
||||
if self.hf_dataset is None or len(self.hf_dataset) == 0:
|
||||
return False
|
||||
|
||||
# Get available episode indices from cached dataset
|
||||
available_episodes = {
|
||||
ep_idx.item() if isinstance(ep_idx, torch.Tensor) else ep_idx
|
||||
for ep_idx in self.hf_dataset["episode_index"]
|
||||
for ep_idx in self.hf_dataset.unique("episode_index")
|
||||
}
|
||||
|
||||
# Determine requested episodes
|
||||
@@ -753,7 +856,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
requested_episodes = set(self.episodes)
|
||||
|
||||
# Check if all requested episodes are available in cached data
|
||||
return requested_episodes.issubset(available_episodes)
|
||||
if not requested_episodes.issubset(available_episodes):
|
||||
return False
|
||||
|
||||
# Check if all required video files exist
|
||||
if len(self.meta.video_keys) > 0:
|
||||
for ep_idx in requested_episodes:
|
||||
for vid_key in self.meta.video_keys:
|
||||
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
|
||||
if not video_path.exists():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def create_hf_dataset(self) -> datasets.Dataset:
|
||||
features = get_hf_features_from_features(self.features)
|
||||
@@ -769,8 +883,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
"""Number of frames in selected episodes."""
|
||||
return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames
|
||||
"""Number of frames in selected episodes.
|
||||
|
||||
Note: When episodes a subset of the full dataset is requested, we must return the
|
||||
actual loaded data length (len(self.hf_dataset)) rather than metadata total_frames.
|
||||
self.meta.total_frames is the total number of frames in the full dataset.
|
||||
"""
|
||||
if self.episodes is not None and self.hf_dataset is not None:
|
||||
return len(self.hf_dataset)
|
||||
return self.meta.total_frames
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
@@ -848,10 +969,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
return item
|
||||
|
||||
def _ensure_hf_dataset_loaded(self):
|
||||
"""Lazy load the HF dataset only when needed for reading."""
|
||||
if self._lazy_loading or self.hf_dataset is None:
|
||||
# Close the writer before loading to ensure parquet file is properly finalized
|
||||
if self.writer is not None:
|
||||
self._close_writer()
|
||||
self._writer_closed_for_reading = True
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
self._lazy_loading = False
|
||||
|
||||
def __len__(self):
|
||||
return self.num_frames
|
||||
|
||||
def __getitem__(self, idx) -> dict:
|
||||
# Ensure dataset is loaded when we actually need to read from it
|
||||
self._ensure_hf_dataset_loaded()
|
||||
item = self.hf_dataset[idx]
|
||||
ep_idx = item["episode_index"].item()
|
||||
|
||||
@@ -890,6 +1023,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"})',\n"
|
||||
)
|
||||
|
||||
def finalize(self):
|
||||
"""
|
||||
Close the parquet writers. This function needs to be called after data collection/conversion, else footer metadata won't be written to the parquet files.
|
||||
The dataset won't be valid and can't be loaded as ds = LeRobotDataset(repo_id=repo, root=HF_LEROBOT_HOME.joinpath(repo))
|
||||
"""
|
||||
self._close_writer()
|
||||
self.meta._close_writer()
|
||||
|
||||
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
|
||||
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
|
||||
ep_buffer = {}
|
||||
@@ -1097,74 +1238,104 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
ep_dict = {key: episode_buffer[key] for key in self.hf_features}
|
||||
ep_dataset = datasets.Dataset.from_dict(ep_dict, features=self.hf_features, split="train")
|
||||
ep_dataset = embed_images(ep_dataset)
|
||||
ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
|
||||
ep_num_frames = len(ep_dataset)
|
||||
df = pd.DataFrame(ep_dataset)
|
||||
|
||||
if self.meta.episodes is None:
|
||||
if self.latest_episode is None:
|
||||
# Initialize indices and frame count for a new dataset made of the first episode data
|
||||
chunk_idx, file_idx = 0, 0
|
||||
latest_num_frames = 0
|
||||
global_frame_index = 0
|
||||
self._current_file_start_frame = 0
|
||||
# However, if the episodes already exists
|
||||
# It means we are resuming recording, so we need to load the latest episode
|
||||
# Update the indices to avoid overwriting the latest episode
|
||||
if self.meta.episodes is not None and len(self.meta.episodes) > 0:
|
||||
latest_ep = self.meta.episodes[-1]
|
||||
global_frame_index = latest_ep["dataset_to_index"]
|
||||
chunk_idx = latest_ep["data/chunk_index"]
|
||||
file_idx = latest_ep["data/file_index"]
|
||||
|
||||
# When resuming, move to the next file
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
|
||||
self._current_file_start_frame = global_frame_index
|
||||
else:
|
||||
# Retrieve information from the latest parquet file
|
||||
latest_ep = self.meta.episodes[-1]
|
||||
latest_ep = self.latest_episode
|
||||
chunk_idx = latest_ep["data/chunk_index"]
|
||||
file_idx = latest_ep["data/file_index"]
|
||||
global_frame_index = latest_ep["index"][-1] + 1
|
||||
|
||||
latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
|
||||
latest_num_frames = get_parquet_num_frames(latest_path)
|
||||
latest_size_in_mb = get_file_size_in_mb(latest_path)
|
||||
|
||||
frames_in_current_file = global_frame_index - self._current_file_start_frame
|
||||
av_size_per_frame = (
|
||||
latest_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0
|
||||
)
|
||||
|
||||
# Determine if a new parquet file is needed
|
||||
if latest_size_in_mb + ep_size_in_mb >= self.meta.data_files_size_in_mb:
|
||||
# Size limit is reached, prepare new parquet file
|
||||
if (
|
||||
latest_size_in_mb + av_size_per_frame * ep_num_frames >= self.meta.data_files_size_in_mb
|
||||
or self._writer_closed_for_reading
|
||||
):
|
||||
# Size limit is reached or writer was closed for reading, prepare new parquet file
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
|
||||
latest_num_frames = 0
|
||||
else:
|
||||
# Update the existing parquet file with new rows
|
||||
latest_df = pd.read_parquet(latest_path)
|
||||
df = pd.concat([latest_df, df], ignore_index=True)
|
||||
self._close_writer()
|
||||
self._writer_closed_for_reading = False
|
||||
self._current_file_start_frame = global_frame_index
|
||||
|
||||
# Memort optimization
|
||||
del latest_df
|
||||
gc.collect()
|
||||
ep_dict["data/chunk_index"] = chunk_idx
|
||||
ep_dict["data/file_index"] = file_idx
|
||||
|
||||
# Write the resulting dataframe from RAM to disk
|
||||
path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if len(self.meta.image_keys) > 0:
|
||||
to_parquet_with_hf_images(df, path)
|
||||
else:
|
||||
df.to_parquet(path)
|
||||
|
||||
if self.hf_dataset is not None:
|
||||
# Remove hf dataset cache directory, necessary to avoid cache bloat
|
||||
cached_dir = get_hf_dataset_cache_dir(self.hf_dataset)
|
||||
if cached_dir is not None:
|
||||
shutil.rmtree(cached_dir)
|
||||
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
table = ep_dataset.with_format("arrow")[:]
|
||||
if not self.writer:
|
||||
self.writer = pq.ParquetWriter(
|
||||
path, schema=table.schema, compression="snappy", use_dictionary=True
|
||||
)
|
||||
self.writer.write_table(table)
|
||||
|
||||
metadata = {
|
||||
"data/chunk_index": chunk_idx,
|
||||
"data/file_index": file_idx,
|
||||
"dataset_from_index": latest_num_frames,
|
||||
"dataset_to_index": latest_num_frames + ep_num_frames,
|
||||
"dataset_from_index": global_frame_index,
|
||||
"dataset_to_index": global_frame_index + ep_num_frames,
|
||||
}
|
||||
|
||||
# Store metadata with episode data for next episode
|
||||
self.latest_episode = {**ep_dict, **metadata}
|
||||
|
||||
# Mark that the HF dataset needs reloading (lazy loading approach)
|
||||
# This avoids expensive reloading during sequential recording
|
||||
self._lazy_loading = True
|
||||
# Update recorded frames count for efficient length tracking
|
||||
self._recorded_frames += ep_num_frames
|
||||
|
||||
return metadata
|
||||
|
||||
def _save_episode_video(self, video_key: str, episode_index: int) -> dict:
|
||||
# Encode episode frames into a temporary video
|
||||
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
|
||||
ep_size_in_mb = get_video_size_in_mb(ep_path)
|
||||
ep_size_in_mb = get_file_size_in_mb(ep_path)
|
||||
ep_duration_in_s = get_video_duration_in_s(ep_path)
|
||||
|
||||
if self.meta.episodes is None or (
|
||||
f"videos/{video_key}/chunk_index" not in self.meta.episodes.column_names
|
||||
or f"videos/{video_key}/file_index" not in self.meta.episodes.column_names
|
||||
if (
|
||||
episode_index == 0
|
||||
or self.meta.latest_episode is None
|
||||
or f"videos/{video_key}/chunk_index" not in self.meta.latest_episode
|
||||
):
|
||||
# Initialize indices for a new dataset made of the first episode data
|
||||
chunk_idx, file_idx = 0, 0
|
||||
if self.meta.episodes is not None and len(self.meta.episodes) > 0:
|
||||
# It means we are resuming recording, so we need to load the latest episode
|
||||
# Update the indices to avoid overwriting the latest episode
|
||||
old_chunk_idx = self.meta.episodes[-1][f"videos/{video_key}/chunk_index"]
|
||||
old_file_idx = self.meta.episodes[-1][f"videos/{video_key}/file_index"]
|
||||
chunk_idx, file_idx = update_chunk_file_indices(
|
||||
old_chunk_idx, old_file_idx, self.meta.chunks_size
|
||||
)
|
||||
latest_duration_in_s = 0.0
|
||||
new_path = self.root / self.meta.video_path.format(
|
||||
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
@@ -1172,16 +1343,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(ep_path), str(new_path))
|
||||
else:
|
||||
# Retrieve information from the latest updated video file (possibly several episodes ago)
|
||||
latest_ep = self.meta.episodes[episode_index - 1]
|
||||
chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"]
|
||||
file_idx = latest_ep[f"videos/{video_key}/file_index"]
|
||||
# Retrieve information from the latest updated video file using latest_episode
|
||||
latest_ep = self.meta.latest_episode
|
||||
chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"][0]
|
||||
file_idx = latest_ep[f"videos/{video_key}/file_index"][0]
|
||||
|
||||
latest_path = self.root / self.meta.video_path.format(
|
||||
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
latest_size_in_mb = get_video_size_in_mb(latest_path)
|
||||
latest_duration_in_s = get_video_duration_in_s(latest_path)
|
||||
latest_size_in_mb = get_file_size_in_mb(latest_path)
|
||||
latest_duration_in_s = latest_ep[f"videos/{video_key}/to_timestamp"][0]
|
||||
|
||||
if latest_size_in_mb + ep_size_in_mb >= self.meta.video_files_size_in_mb:
|
||||
# Move temporary episode video to a new video file in the dataset
|
||||
@@ -1315,6 +1486,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.delta_timestamps = None
|
||||
obj.delta_indices = None
|
||||
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||
obj.writer = None
|
||||
obj.latest_episode = None
|
||||
obj._current_file_start_frame = None
|
||||
# Initialize tracking for incremental recording
|
||||
obj._lazy_loading = False
|
||||
obj._recorded_frames = 0
|
||||
obj._writer_closed_for_reading = False
|
||||
return obj
|
||||
|
||||
|
||||
|
||||
@@ -206,6 +206,11 @@ class ImageTransformsConfig:
|
||||
type="SharpnessJitter",
|
||||
kwargs={"sharpness": (0.5, 1.5)},
|
||||
),
|
||||
"affine": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="RandomAffine",
|
||||
kwargs={"degrees": (-5.0, 5.0), "translate": (0.05, 0.05)},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -217,6 +222,8 @@ def make_transform_from_config(cfg: ImageTransformConfig):
|
||||
return v2.ColorJitter(**cfg.kwargs)
|
||||
elif cfg.type == "SharpnessJitter":
|
||||
return SharpnessJitter(**cfg.kwargs)
|
||||
elif cfg.type == "RandomAffine":
|
||||
return v2.RandomAffine(**cfg.kwargs)
|
||||
else:
|
||||
raise ValueError(f"Transform '{cfg.type}' is not valid.")
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ import pandas
|
||||
import pandas as pd
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
from datasets import Dataset, concatenate_datasets
|
||||
from datasets import Dataset
|
||||
from datasets.table import embed_table_storage
|
||||
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
@@ -44,7 +44,7 @@ from lerobot.datasets.backward_compatibility import (
|
||||
ForwardCompatibilityError,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR
|
||||
from lerobot.utils.utils import is_valid_numpy_dtype_string
|
||||
from lerobot.utils.utils import SuppressProgressBars, is_valid_numpy_dtype_string
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
|
||||
@@ -94,12 +94,6 @@ def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int:
|
||||
return hf_ds.data.nbytes // (1024**2)
|
||||
|
||||
|
||||
def get_hf_dataset_cache_dir(hf_ds: Dataset) -> Path | None:
|
||||
if hf_ds.cache_files is None or len(hf_ds.cache_files) == 0:
|
||||
return None
|
||||
return Path(hf_ds.cache_files[0]["filename"]).parents[2]
|
||||
|
||||
|
||||
def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]:
|
||||
if file_idx == chunks_size - 1:
|
||||
file_idx = 0
|
||||
@@ -123,8 +117,9 @@ def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None)
|
||||
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
|
||||
|
||||
# TODO(rcadene): set num_proc to accelerate conversion to pyarrow
|
||||
datasets = [Dataset.from_parquet(str(path), features=features) for path in paths]
|
||||
return concatenate_datasets(datasets)
|
||||
with SuppressProgressBars():
|
||||
datasets = Dataset.from_parquet([str(path) for path in paths], features=features)
|
||||
return datasets
|
||||
|
||||
|
||||
def get_parquet_num_frames(parquet_path: str | Path) -> int:
|
||||
@@ -132,10 +127,14 @@ def get_parquet_num_frames(parquet_path: str | Path) -> int:
|
||||
return metadata.num_rows
|
||||
|
||||
|
||||
def get_video_size_in_mb(mp4_path: Path) -> float:
|
||||
file_size_bytes = mp4_path.stat().st_size
|
||||
file_size_mb = file_size_bytes / (1024**2)
|
||||
return file_size_mb
|
||||
def get_file_size_in_mb(file_path: Path) -> float:
|
||||
"""Get file size on disk in megabytes.
|
||||
|
||||
Args:
|
||||
file_path (Path): Path to the file.
|
||||
"""
|
||||
file_size_bytes = file_path.stat().st_size
|
||||
return file_size_bytes / (1024**2)
|
||||
|
||||
|
||||
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
|
||||
|
||||
@@ -0,0 +1,260 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script augments existing LeRobot datasets with quantile statistics.
|
||||
|
||||
Most datasets created before the quantile feature was added do not contain
|
||||
quantile statistics (q01, q10, q50, q90, q99) in their metadata. This script:
|
||||
|
||||
1. Loads an existing LeRobot dataset in v3.0 format
|
||||
2. Checks if it already contains quantile statistics
|
||||
3. If missing, computes quantile statistics for all features
|
||||
4. Updates the dataset metadata with the new quantile statistics
|
||||
|
||||
Usage:
|
||||
|
||||
```bash
|
||||
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
|
||||
--repo-id=lerobot/pusht \
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
from requests import HTTPError
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, get_feature_stats
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.datasets.utils import write_stats
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
def has_quantile_stats(stats: dict[str, dict] | None, quantile_list_keys: list[str] | None = None) -> bool:
|
||||
"""Check if dataset statistics already contain quantile information.
|
||||
|
||||
Args:
|
||||
stats: Dataset statistics dictionary
|
||||
|
||||
Returns:
|
||||
True if quantile statistics are present, False otherwise
|
||||
"""
|
||||
if quantile_list_keys is None:
|
||||
quantile_list_keys = [f"q{int(q * 100):02d}" for q in DEFAULT_QUANTILES]
|
||||
|
||||
if stats is None:
|
||||
return False
|
||||
|
||||
for feature_stats in stats.values():
|
||||
if any(q_key in feature_stats for q_key in quantile_list_keys):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def process_single_episode(dataset: LeRobotDataset, episode_idx: int) -> dict:
|
||||
"""Process a single episode and return its statistics.
|
||||
|
||||
Args:
|
||||
dataset: The LeRobot dataset
|
||||
episode_idx: Index of the episode to process
|
||||
|
||||
Returns:
|
||||
Dictionary containing episode statistics
|
||||
"""
|
||||
logging.info(f"Computing stats for episode {episode_idx}")
|
||||
|
||||
start_idx = dataset.meta.episodes[episode_idx]["dataset_from_index"]
|
||||
end_idx = dataset.meta.episodes[episode_idx]["dataset_to_index"]
|
||||
|
||||
collected_data: dict[str, list] = {}
|
||||
for idx in range(start_idx, end_idx):
|
||||
item = dataset[idx]
|
||||
for key, value in item.items():
|
||||
if key not in dataset.features:
|
||||
continue
|
||||
|
||||
if key not in collected_data:
|
||||
collected_data[key] = []
|
||||
collected_data[key].append(value)
|
||||
|
||||
ep_stats = {}
|
||||
for key, data_list in collected_data.items():
|
||||
if dataset.features[key]["dtype"] == "string":
|
||||
continue
|
||||
|
||||
data = torch.stack(data_list).cpu().numpy()
|
||||
if dataset.features[key]["dtype"] in ["image", "video"]:
|
||||
if data.dtype == np.uint8:
|
||||
data = data.astype(np.float32) / 255.0
|
||||
|
||||
axes_to_reduce = (0, 2, 3)
|
||||
keepdims = True
|
||||
else:
|
||||
axes_to_reduce = 0
|
||||
keepdims = data.ndim == 1
|
||||
|
||||
ep_stats[key] = get_feature_stats(
|
||||
data, axis=axes_to_reduce, keepdims=keepdims, quantile_list=DEFAULT_QUANTILES
|
||||
)
|
||||
|
||||
if dataset.features[key]["dtype"] in ["image", "video"]:
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
|
||||
}
|
||||
|
||||
return ep_stats
|
||||
|
||||
|
||||
def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dict[str, dict]:
|
||||
"""Compute quantile statistics for all episodes in the dataset.
|
||||
|
||||
Args:
|
||||
dataset: The LeRobot dataset to compute statistics for
|
||||
|
||||
Returns:
|
||||
Dictionary containing aggregated statistics with quantiles
|
||||
|
||||
Note:
|
||||
Video decoding operations are not thread-safe, so we process episodes sequentially
|
||||
when video keys are present. For datasets without videos, we use parallel processing
|
||||
with ThreadPoolExecutor for better performance.
|
||||
"""
|
||||
logging.info(f"Computing quantile statistics for dataset with {dataset.num_episodes} episodes")
|
||||
|
||||
episode_stats_list = []
|
||||
has_videos = len(dataset.meta.video_keys) > 0
|
||||
|
||||
if has_videos:
|
||||
logging.info("Dataset contains video keys - using sequential processing for thread safety")
|
||||
for episode_idx in tqdm(range(dataset.num_episodes), desc="Processing episodes"):
|
||||
ep_stats = process_single_episode(dataset, episode_idx)
|
||||
episode_stats_list.append(ep_stats)
|
||||
else:
|
||||
logging.info("Dataset has no video keys - using parallel processing for better performance")
|
||||
max_workers = min(dataset.num_episodes, 16)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
future_to_episode = {
|
||||
executor.submit(process_single_episode, dataset, episode_idx): episode_idx
|
||||
for episode_idx in range(dataset.num_episodes)
|
||||
}
|
||||
|
||||
episode_results = {}
|
||||
with tqdm(total=dataset.num_episodes, desc="Processing episodes") as pbar:
|
||||
for future in concurrent.futures.as_completed(future_to_episode):
|
||||
episode_idx = future_to_episode[future]
|
||||
ep_stats = future.result()
|
||||
episode_results[episode_idx] = ep_stats
|
||||
pbar.update(1)
|
||||
|
||||
for episode_idx in range(dataset.num_episodes):
|
||||
if episode_idx in episode_results:
|
||||
episode_stats_list.append(episode_results[episode_idx])
|
||||
|
||||
if not episode_stats_list:
|
||||
raise ValueError("No episode data found for computing statistics")
|
||||
|
||||
logging.info(f"Aggregating statistics from {len(episode_stats_list)} episodes")
|
||||
return aggregate_stats(episode_stats_list)
|
||||
|
||||
|
||||
def augment_dataset_with_quantile_stats(
|
||||
repo_id: str,
|
||||
root: str | Path | None = None,
|
||||
overwrite: bool = False,
|
||||
) -> None:
|
||||
"""Augment a dataset with quantile statistics if they are missing.
|
||||
|
||||
Args:
|
||||
repo_id: Repository ID of the dataset
|
||||
root: Local root directory for the dataset
|
||||
overwrite: Overwrite existing quantile statistics if they already exist
|
||||
"""
|
||||
logging.info(f"Loading dataset: {repo_id}")
|
||||
dataset = LeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=root,
|
||||
)
|
||||
|
||||
if not overwrite and has_quantile_stats(dataset.meta.stats):
|
||||
logging.info("Dataset already contains quantile statistics. No action needed.")
|
||||
return
|
||||
|
||||
logging.info("Dataset does not contain quantile statistics. Computing them now...")
|
||||
|
||||
new_stats = compute_quantile_stats_for_dataset(dataset)
|
||||
|
||||
logging.info("Updating dataset metadata with new quantile statistics")
|
||||
dataset.meta.stats = new_stats
|
||||
|
||||
write_stats(new_stats, dataset.meta.root)
|
||||
|
||||
logging.info("Successfully updated dataset with quantile statistics")
|
||||
dataset.push_to_hub()
|
||||
|
||||
hub_api = HfApi()
|
||||
try:
|
||||
hub_api.delete_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
|
||||
except HTTPError as e:
|
||||
logging.info(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})")
|
||||
pass
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=None, repo_type="dataset")
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to run the augmentation script."""
|
||||
parser = argparse.ArgumentParser(description="Augment LeRobot dataset with quantile statistics")
|
||||
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository ID of the dataset (e.g., 'lerobot/pusht')",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=str,
|
||||
help="Local root directory for the dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite",
|
||||
action="store_true",
|
||||
help="Overwrite existing quantile statistics if they already exist",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
root = Path(args.root) if args.root else None
|
||||
|
||||
init_logging()
|
||||
|
||||
augment_dataset_with_quantile_stats(
|
||||
repo_id=args.repo_id,
|
||||
root=root,
|
||||
overwrite=args.overwrite,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -26,11 +26,20 @@ This script will help you convert any LeRobot dataset already pushed to the hub
|
||||
|
||||
Usage:
|
||||
|
||||
Convert a dataset from the hub:
|
||||
```bash
|
||||
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
|
||||
--repo-id=lerobot/pusht
|
||||
```
|
||||
|
||||
Convert a local dataset (works in place):
|
||||
```bash
|
||||
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
|
||||
--repo-id=lerobot/pusht \
|
||||
--root=/path/to/local/dataset/directory
|
||||
--push-to-hub=false
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -60,9 +69,9 @@ from lerobot.datasets.utils import (
|
||||
LEGACY_TASKS_PATH,
|
||||
cast_stats_to_numpy,
|
||||
flatten_dict,
|
||||
get_file_size_in_mb,
|
||||
get_parquet_file_size_in_mb,
|
||||
get_parquet_num_frames,
|
||||
get_video_size_in_mb,
|
||||
load_info,
|
||||
update_chunk_file_indices,
|
||||
write_episodes,
|
||||
@@ -75,7 +84,7 @@ from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
V21 = "v2.1"
|
||||
|
||||
V30 = "v3.0"
|
||||
|
||||
"""
|
||||
-------------------------
|
||||
@@ -89,7 +98,7 @@ OLD
|
||||
videos/chunk-000/CAMERA/episode_000000.mp4
|
||||
|
||||
NEW
|
||||
videos/chunk-000/file_000.mp4
|
||||
videos/CAMERA/chunk-000/file_000.mp4
|
||||
-------------------------
|
||||
OLD
|
||||
episodes.jsonl
|
||||
@@ -145,6 +154,17 @@ def legacy_load_tasks(local_dir: Path) -> tuple[dict, dict]:
|
||||
return tasks, task_to_task_index
|
||||
|
||||
|
||||
def validate_local_dataset_version(local_path: Path) -> None:
|
||||
"""Validate that the local dataset has the expected v2.1 version."""
|
||||
info = load_info(local_path)
|
||||
dataset_version = info.get("codebase_version", "unknown")
|
||||
if dataset_version != V21:
|
||||
raise ValueError(
|
||||
f"Local dataset has codebase version '{dataset_version}', expected '{V21}'. "
|
||||
f"This script is specifically for converting v2.1 datasets to v3.0."
|
||||
)
|
||||
|
||||
|
||||
def convert_tasks(root, new_root):
|
||||
logging.info(f"Converting tasks from {root} to {new_root}")
|
||||
tasks, _ = legacy_load_tasks(root)
|
||||
@@ -290,7 +310,7 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f
|
||||
episodes_metadata = []
|
||||
|
||||
for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"):
|
||||
ep_size_in_mb = get_video_size_in_mb(ep_path)
|
||||
ep_size_in_mb = get_file_size_in_mb(ep_path)
|
||||
ep_duration_in_s = get_video_duration_in_s(ep_path)
|
||||
|
||||
# Check if adding this episode would exceed the limit
|
||||
@@ -407,13 +427,13 @@ def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_
|
||||
|
||||
def convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb):
|
||||
info = load_info(root)
|
||||
info["codebase_version"] = "v3.0"
|
||||
info["codebase_version"] = V30
|
||||
del info["total_chunks"]
|
||||
del info["total_videos"]
|
||||
info["data_files_size_in_mb"] = data_file_size_in_mb
|
||||
info["video_files_size_in_mb"] = video_file_size_in_mb
|
||||
info["data_path"] = DEFAULT_DATA_PATH
|
||||
info["video_path"] = DEFAULT_VIDEO_PATH
|
||||
info["video_path"] = DEFAULT_VIDEO_PATH if info["video_path"] is not None else None
|
||||
info["fps"] = int(info["fps"])
|
||||
logging.info(f"Converting info from {root} to {new_root}")
|
||||
for key in info["features"]:
|
||||
@@ -429,16 +449,36 @@ def convert_dataset(
|
||||
branch: str | None = None,
|
||||
data_file_size_in_mb: int | None = None,
|
||||
video_file_size_in_mb: int | None = None,
|
||||
root: str | Path | None = None,
|
||||
push_to_hub: bool = True,
|
||||
force_conversion: bool = False,
|
||||
):
|
||||
root = HF_LEROBOT_HOME / repo_id
|
||||
old_root = HF_LEROBOT_HOME / f"{repo_id}_old"
|
||||
new_root = HF_LEROBOT_HOME / f"{repo_id}_v30"
|
||||
|
||||
if data_file_size_in_mb is None:
|
||||
data_file_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
|
||||
if video_file_size_in_mb is None:
|
||||
video_file_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||
|
||||
# First check if the dataset already has a v3.0 version
|
||||
if root is None and not force_conversion:
|
||||
try:
|
||||
print("Trying to download v3.0 version of the dataset from the hub...")
|
||||
snapshot_download(repo_id, repo_type="dataset", revision=V30, local_dir=HF_LEROBOT_HOME / repo_id)
|
||||
return
|
||||
except Exception:
|
||||
print("Dataset does not have an uploaded v3.0 version. Continuing with conversion.")
|
||||
|
||||
# Set root based on whether local dataset path is provided
|
||||
use_local_dataset = False
|
||||
root = HF_LEROBOT_HOME / repo_id if root is None else Path(root) / repo_id
|
||||
if root.exists():
|
||||
validate_local_dataset_version(root)
|
||||
use_local_dataset = True
|
||||
print(f"Using local dataset at {root}")
|
||||
|
||||
old_root = root.parent / f"{root.name}_old"
|
||||
new_root = root.parent / f"{root.name}_v30"
|
||||
|
||||
# Handle old_root cleanup if both old_root and root exist
|
||||
if old_root.is_dir() and root.is_dir():
|
||||
shutil.rmtree(str(root))
|
||||
shutil.move(str(old_root), str(root))
|
||||
@@ -446,12 +486,13 @@ def convert_dataset(
|
||||
if new_root.is_dir():
|
||||
shutil.rmtree(new_root)
|
||||
|
||||
snapshot_download(
|
||||
repo_id,
|
||||
repo_type="dataset",
|
||||
revision=V21,
|
||||
local_dir=root,
|
||||
)
|
||||
if not use_local_dataset:
|
||||
snapshot_download(
|
||||
repo_id,
|
||||
repo_type="dataset",
|
||||
revision=V21,
|
||||
local_dir=root,
|
||||
)
|
||||
|
||||
convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb)
|
||||
convert_tasks(root, new_root)
|
||||
@@ -462,21 +503,22 @@ def convert_dataset(
|
||||
shutil.move(str(root), str(old_root))
|
||||
shutil.move(str(new_root), str(root))
|
||||
|
||||
hub_api = HfApi()
|
||||
try:
|
||||
hub_api.delete_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
|
||||
except HTTPError as e:
|
||||
print(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})")
|
||||
pass
|
||||
hub_api.delete_files(
|
||||
delete_patterns=["data/chunk*/episode_*", "meta/*.jsonl", "videos/chunk*"],
|
||||
repo_id=repo_id,
|
||||
revision=branch,
|
||||
repo_type="dataset",
|
||||
)
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
if push_to_hub:
|
||||
hub_api = HfApi()
|
||||
try:
|
||||
hub_api.delete_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
|
||||
except HTTPError as e:
|
||||
print(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})")
|
||||
pass
|
||||
hub_api.delete_files(
|
||||
delete_patterns=["data/chunk*/episode_*", "meta/*.jsonl", "videos/chunk*"],
|
||||
repo_id=repo_id,
|
||||
revision=branch,
|
||||
repo_type="dataset",
|
||||
)
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
|
||||
LeRobotDataset(repo_id).push_to_hub()
|
||||
LeRobotDataset(repo_id).push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -507,6 +549,23 @@ if __name__ == "__main__":
|
||||
default=None,
|
||||
help="File size in MB. Defaults to 100 for data and 500 for videos.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Local directory to use for downloading/writing the dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
type=lambda input: input.lower() == "true",
|
||||
default=True,
|
||||
help="Push the converted dataset to the hub.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force-conversion",
|
||||
action="store_true",
|
||||
help="Force conversion even if the dataset already has a v3.0 version.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_dataset(**vars(args))
|
||||
|
||||
@@ -342,8 +342,8 @@ def encode_video_frames(
|
||||
# Define video output frame size (assuming all input frames are the same size)
|
||||
if len(input_list) == 0:
|
||||
raise FileNotFoundError(f"No images found in {imgs_dir}.")
|
||||
dummy_image = Image.open(input_list[0])
|
||||
width, height = dummy_image.size
|
||||
with Image.open(input_list[0]) as dummy_image:
|
||||
width, height = dummy_image.size
|
||||
|
||||
# Define video codec options
|
||||
video_options = {}
|
||||
@@ -373,11 +373,12 @@ def encode_video_frames(
|
||||
|
||||
# Loop through input frames and encode them
|
||||
for input_data in input_list:
|
||||
input_image = Image.open(input_data).convert("RGB")
|
||||
input_frame = av.VideoFrame.from_image(input_image)
|
||||
packet = output_stream.encode(input_frame)
|
||||
if packet:
|
||||
output.mux(packet)
|
||||
with Image.open(input_data) as input_image:
|
||||
input_image = input_image.convert("RGB")
|
||||
input_frame = av.VideoFrame.from_image(input_image)
|
||||
packet = output_stream.encode(input_frame)
|
||||
if packet:
|
||||
output.mux(packet)
|
||||
|
||||
# Flush the encoder
|
||||
packet = output_stream.encode()
|
||||
@@ -451,11 +452,9 @@ def concatenate_video_files(
|
||||
stream_map[input_stream.index] = output_container.add_stream_from_template(
|
||||
template=input_stream, opaque=True
|
||||
)
|
||||
stream_map[
|
||||
input_stream.index
|
||||
].time_base = (
|
||||
input_stream.time_base
|
||||
) # set the time base to the input stream time base (missing in the codec context)
|
||||
|
||||
# set the time base to the input stream time base (missing in the codec context)
|
||||
stream_map[input_stream.index].time_base = input_stream.time_base
|
||||
|
||||
# Demux + remux packets (no re-encode)
|
||||
for packet in input_container.demux():
|
||||
@@ -644,6 +643,9 @@ class VideoEncodingManager:
|
||||
)
|
||||
self.dataset._batch_save_episode_video(start_ep, end_ep)
|
||||
|
||||
# Finalize the dataset to properly close all writers
|
||||
self.dataset.finalize()
|
||||
|
||||
# Clean up episode images if recording was interrupted
|
||||
if exc_type is not None:
|
||||
interrupted_episode_index = self.dataset.num_episodes
|
||||
|
||||
@@ -12,4 +12,4 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv # noqa: F401
|
||||
from .configs import AlohaEnv, EnvConfig, PushtEnv # noqa: F401
|
||||
|
||||
+72
-47
@@ -37,6 +37,16 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@property
|
||||
def package_name(self) -> str:
|
||||
"""Package name to import if environment not found in gym registry"""
|
||||
return f"gym_{self.type}"
|
||||
|
||||
@property
|
||||
def gym_id(self) -> str:
|
||||
"""ID string used in gym.make() to instantiate the environment"""
|
||||
return f"{self.package_name}/{self.task}"
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def gym_kwargs(self) -> dict:
|
||||
@@ -50,6 +60,8 @@ class AlohaEnv(EnvConfig):
|
||||
fps: int = 50
|
||||
episode_length: int = 400
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
observation_height: int = 480
|
||||
observation_width: int = 640
|
||||
render_mode: str = "rgb_array"
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
@@ -67,10 +79,14 @@ class AlohaEnv(EnvConfig):
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels":
|
||||
self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
|
||||
self.features["top"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
||||
)
|
||||
elif self.obs_type == "pixels_agent_pos":
|
||||
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
|
||||
self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
|
||||
self.features["pixels/top"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
||||
)
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
@@ -91,6 +107,8 @@ class PushtEnv(EnvConfig):
|
||||
render_mode: str = "rgb_array"
|
||||
visualization_width: int = 384
|
||||
visualization_height: int = 384
|
||||
observation_height: int = 384
|
||||
observation_width: int = 384
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
@@ -108,7 +126,9 @@ class PushtEnv(EnvConfig):
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels_agent_pos":
|
||||
self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3))
|
||||
self.features["pixels"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
||||
)
|
||||
elif self.obs_type == "environment_state_agent_pos":
|
||||
self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))
|
||||
|
||||
@@ -123,45 +143,6 @@ class PushtEnv(EnvConfig):
|
||||
}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("xarm")
|
||||
@dataclass
|
||||
class XarmEnv(EnvConfig):
|
||||
task: str | None = "XarmLift-v0"
|
||||
fps: int = 15
|
||||
episode_length: int = 200
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
render_mode: str = "rgb_array"
|
||||
visualization_width: int = 384
|
||||
visualization_height: int = 384
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
||||
"pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: ACTION,
|
||||
"agent_pos": OBS_STATE,
|
||||
"pixels": OBS_IMAGE,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels_agent_pos":
|
||||
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
"visualization_width": self.visualization_width,
|
||||
"visualization_height": self.visualization_height,
|
||||
"max_episode_steps": self.episode_length,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImagePreprocessingConfig:
|
||||
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
|
||||
@@ -254,7 +235,9 @@ class LiberoEnv(EnvConfig):
|
||||
render_mode: str = "rgb_array"
|
||||
camera_name: str = "agentview_image,robot0_eye_in_hand_image"
|
||||
init_states: bool = True
|
||||
camera_name_mapping: dict[str, str] | None = (None,)
|
||||
camera_name_mapping: dict[str, str] | None = None
|
||||
observation_height: int = 360
|
||||
observation_width: int = 360
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
@@ -272,18 +255,18 @@ class LiberoEnv(EnvConfig):
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels":
|
||||
self.features["pixels/agentview_image"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(360, 360, 3)
|
||||
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
||||
)
|
||||
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(360, 360, 3)
|
||||
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
||||
)
|
||||
elif self.obs_type == "pixels_agent_pos":
|
||||
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(8,))
|
||||
self.features["pixels/agentview_image"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(360, 360, 3)
|
||||
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
||||
)
|
||||
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(360, 360, 3)
|
||||
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported obs_type: {self.obs_type}")
|
||||
@@ -294,3 +277,45 @@ class LiberoEnv(EnvConfig):
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("metaworld")
|
||||
@dataclass
|
||||
class MetaworldEnv(EnvConfig):
|
||||
task: str = "metaworld-push-v2" # add all tasks
|
||||
fps: int = 80
|
||||
episode_length: int = 400
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
render_mode: str = "rgb_array"
|
||||
multitask_eval: bool = True
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"agent_pos": OBS_STATE,
|
||||
"top": f"{OBS_IMAGE}",
|
||||
"pixels/top": f"{OBS_IMAGE}",
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels":
|
||||
self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 480, 3))
|
||||
|
||||
elif self.obs_type == "pixels_agent_pos":
|
||||
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
|
||||
self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 480, 3))
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported obs_type: {self.obs_type}")
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
}
|
||||
|
||||
+32
-12
@@ -16,8 +16,9 @@
|
||||
import importlib
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.registration import registry as gym_registry
|
||||
|
||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv, XarmEnv
|
||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
|
||||
|
||||
|
||||
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
@@ -25,8 +26,6 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
return AlohaEnv(**kwargs)
|
||||
elif env_type == "pusht":
|
||||
return PushtEnv(**kwargs)
|
||||
elif env_type == "xarm":
|
||||
return XarmEnv(**kwargs)
|
||||
elif env_type == "libero":
|
||||
return LiberoEnv(**kwargs)
|
||||
else:
|
||||
@@ -63,6 +62,9 @@ def make_env(
|
||||
if "libero" in cfg.type:
|
||||
from lerobot.envs.libero import create_libero_envs
|
||||
|
||||
if cfg.task is None:
|
||||
raise ValueError("LiberoEnv requires a task to be specified")
|
||||
|
||||
return create_libero_envs(
|
||||
task=cfg.task,
|
||||
n_envs=n_envs,
|
||||
@@ -71,20 +73,38 @@ def make_env(
|
||||
gym_kwargs=cfg.gym_kwargs,
|
||||
env_cls=env_cls,
|
||||
)
|
||||
elif "metaworld" in cfg.type:
|
||||
from lerobot.envs.metaworld import create_metaworld_envs
|
||||
|
||||
package_name = f"gym_{cfg.type}"
|
||||
try:
|
||||
importlib.import_module(package_name)
|
||||
except ModuleNotFoundError as e:
|
||||
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
|
||||
raise e
|
||||
if cfg.task is None:
|
||||
raise ValueError("MetaWorld requires a task to be specified")
|
||||
|
||||
gym_handle = f"{package_name}/{cfg.task}"
|
||||
return create_metaworld_envs(
|
||||
task=cfg.task,
|
||||
n_envs=n_envs,
|
||||
gym_kwargs=cfg.gym_kwargs,
|
||||
env_cls=env_cls,
|
||||
)
|
||||
|
||||
if cfg.gym_id not in gym_registry:
|
||||
print(f"gym id '{cfg.gym_id}' not found, attempting to import '{cfg.package_name}'...")
|
||||
try:
|
||||
importlib.import_module(cfg.package_name)
|
||||
except ModuleNotFoundError as e:
|
||||
raise ModuleNotFoundError(
|
||||
f"Package '{cfg.package_name}' required for env '{cfg.type}' not found. "
|
||||
f"Please install it or check PYTHONPATH."
|
||||
) from e
|
||||
|
||||
if cfg.gym_id not in gym_registry:
|
||||
raise gym.error.NameNotFound(
|
||||
f"Environment '{cfg.gym_id}' not registered even after importing '{cfg.package_name}'."
|
||||
)
|
||||
|
||||
def _make_one():
|
||||
return gym.make(gym_handle, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
|
||||
return gym.make(cfg.gym_id, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
|
||||
|
||||
vec = env_cls([_make_one for _ in range(n_envs)])
|
||||
vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)
|
||||
|
||||
# normalize to {suite: {task_id: vec_env}} for consistency
|
||||
suite_name = cfg.type # e.g., "pusht", "aloha"
|
||||
|
||||
+15
-11
@@ -260,19 +260,23 @@ class LiberoEnv(gym.Env):
|
||||
|
||||
is_success = self._env.check_success()
|
||||
terminated = done or is_success
|
||||
info["is_success"] = is_success
|
||||
|
||||
info.update(
|
||||
{
|
||||
"task": self.task,
|
||||
"task_id": self.task_id,
|
||||
"done": done,
|
||||
"is_success": is_success,
|
||||
}
|
||||
)
|
||||
observation = self._format_raw_obs(raw_obs)
|
||||
if done:
|
||||
if terminated:
|
||||
info["final_info"] = {
|
||||
"task": self.task,
|
||||
"task_id": self.task_id,
|
||||
"done": bool(done),
|
||||
"is_success": bool(is_success),
|
||||
}
|
||||
self.reset()
|
||||
info.update(
|
||||
{
|
||||
"task": self.task,
|
||||
"task_id": self.task_id,
|
||||
"done": done,
|
||||
"is_success": is_success,
|
||||
}
|
||||
)
|
||||
truncated = False
|
||||
return observation, reward, terminated, truncated, info
|
||||
|
||||
|
||||
@@ -0,0 +1,313 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Sequence
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
import metaworld
|
||||
import metaworld.policies as policies
|
||||
import numpy as np
|
||||
from gymnasium import spaces
|
||||
|
||||
# ---- Load configuration data from the external JSON file ----
|
||||
CONFIG_PATH = Path(__file__).parent / "metaworld_config.json"
|
||||
try:
|
||||
with open(CONFIG_PATH) as f:
|
||||
data = json.load(f)
|
||||
except FileNotFoundError as err:
|
||||
raise FileNotFoundError(
|
||||
"Could not find 'metaworld_config.json'. "
|
||||
"Please ensure the configuration file is in the same directory as the script."
|
||||
) from err
|
||||
except json.JSONDecodeError as err:
|
||||
raise ValueError(
|
||||
"Failed to decode 'metaworld_config.json'. Please ensure it is a valid JSON file."
|
||||
) from err
|
||||
|
||||
# ---- Process the loaded data ----
|
||||
|
||||
# extract and type-check top-level dicts
|
||||
task_descriptions_obj = data.get("TASK_DESCRIPTIONS")
|
||||
if not isinstance(task_descriptions_obj, dict):
|
||||
raise TypeError("Expected TASK_DESCRIPTIONS to be a dict[str, str]")
|
||||
TASK_DESCRIPTIONS: dict[str, str] = task_descriptions_obj
|
||||
|
||||
task_name_to_id_obj = data.get("TASK_NAME_TO_ID")
|
||||
if not isinstance(task_name_to_id_obj, dict):
|
||||
raise TypeError("Expected TASK_NAME_TO_ID to be a dict[str, int]")
|
||||
TASK_NAME_TO_ID: dict[str, int] = task_name_to_id_obj
|
||||
|
||||
# difficulty -> tasks mapping
|
||||
difficulty_to_tasks = data.get("DIFFICULTY_TO_TASKS")
|
||||
if not isinstance(difficulty_to_tasks, dict):
|
||||
raise TypeError("Expected 'DIFFICULTY_TO_TASKS' to be a dict[str, list[str]]")
|
||||
DIFFICULTY_TO_TASKS: dict[str, list[str]] = difficulty_to_tasks
|
||||
|
||||
# convert policy strings -> actual policy classes
|
||||
task_policy_mapping = data.get("TASK_POLICY_MAPPING")
|
||||
if not isinstance(task_policy_mapping, dict):
|
||||
raise TypeError("Expected 'TASK_POLICY_MAPPING' to be a dict[str, str]")
|
||||
TASK_POLICY_MAPPING: dict[str, Any] = {
|
||||
task_name: getattr(policies, policy_class_name)
|
||||
for task_name, policy_class_name in task_policy_mapping.items()
|
||||
}
|
||||
ACTION_DIM = 4
|
||||
OBS_DIM = 4
|
||||
|
||||
|
||||
class MetaworldEnv(gym.Env):
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 80}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
camera_name="corner2",
|
||||
obs_type="pixels",
|
||||
render_mode="rgb_array",
|
||||
observation_width=480,
|
||||
observation_height=480,
|
||||
visualization_width=640,
|
||||
visualization_height=480,
|
||||
):
|
||||
super().__init__()
|
||||
self.task = task.replace("metaworld-", "")
|
||||
self.obs_type = obs_type
|
||||
self.render_mode = render_mode
|
||||
self.observation_width = observation_width
|
||||
self.observation_height = observation_height
|
||||
self.visualization_width = visualization_width
|
||||
self.visualization_height = visualization_height
|
||||
self.camera_name = camera_name
|
||||
|
||||
self._env = self._make_envs_task(self.task)
|
||||
self._max_episode_steps = self._env.max_path_length
|
||||
self.task_description = TASK_DESCRIPTIONS[self.task]
|
||||
|
||||
self.expert_policy = TASK_POLICY_MAPPING[self.task]()
|
||||
|
||||
if self.obs_type == "state":
|
||||
raise NotImplementedError()
|
||||
elif self.obs_type == "pixels":
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
"pixels": spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(self.observation_height, self.observation_width, 3),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
}
|
||||
)
|
||||
elif self.obs_type == "pixels_agent_pos":
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
"pixels": spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(self.observation_height, self.observation_width, 3),
|
||||
dtype=np.uint8,
|
||||
),
|
||||
"agent_pos": spaces.Box(
|
||||
low=-1000.0,
|
||||
high=1000.0,
|
||||
shape=(OBS_DIM,),
|
||||
dtype=np.float64,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
self.action_space = spaces.Box(low=-1, high=1, shape=(ACTION_DIM,), dtype=np.float32)
|
||||
|
||||
def render(self) -> np.ndarray:
|
||||
"""
|
||||
Render the current environment frame.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The rendered RGB image from the environment.
|
||||
"""
|
||||
image = self._env.render()
|
||||
if self.camera_name == "corner2":
|
||||
# Images from this camera are flipped — correct them
|
||||
image = np.flip(image, (0, 1))
|
||||
return image
|
||||
|
||||
def _make_envs_task(self, env_name: str):
|
||||
mt1 = metaworld.MT1(env_name, seed=42)
|
||||
env = mt1.train_classes[env_name](render_mode="rgb_array", camera_name=self.camera_name)
|
||||
env.set_task(mt1.train_tasks[0])
|
||||
if self.camera_name == "corner2":
|
||||
env.model.cam_pos[2] = [
|
||||
0.75,
|
||||
0.075,
|
||||
0.7,
|
||||
] # corner2 position, similar to https://arxiv.org/pdf/2206.14244
|
||||
env.reset()
|
||||
env._freeze_rand_vec = False # otherwise no randomization
|
||||
return env
|
||||
|
||||
def _format_raw_obs(self, raw_obs: np.ndarray) -> dict[str, Any]:
|
||||
image = None
|
||||
if self._env is not None:
|
||||
image = self._env.render()
|
||||
if self.camera_name == "corner2":
|
||||
# NOTE: The "corner2" camera in MetaWorld environments outputs images with both axes inverted.
|
||||
image = np.flip(image, (0, 1))
|
||||
agent_pos = raw_obs[:4]
|
||||
if self.obs_type == "state":
|
||||
raise NotImplementedError(
|
||||
"'state' obs_type not implemented for MetaWorld. Use pixel modes instead."
|
||||
)
|
||||
|
||||
elif self.obs_type in ("pixels", "pixels_agent_pos"):
|
||||
assert image is not None, (
|
||||
"Expected `image` to be rendered before constructing pixel-based observations. "
|
||||
"This likely means `env.render()` returned None or the environment was not provided."
|
||||
)
|
||||
|
||||
if self.obs_type == "pixels":
|
||||
obs = {"pixels": image.copy()}
|
||||
|
||||
else: # pixels_agent_pos
|
||||
obs = {
|
||||
"pixels": image.copy(),
|
||||
"agent_pos": agent_pos,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unknown obs_type: {self.obs_type}")
|
||||
return obs
|
||||
|
||||
def reset(
|
||||
self,
|
||||
seed: int | None = None,
|
||||
**kwargs,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
"""
|
||||
Reset the environment to its initial state.
|
||||
|
||||
Args:
|
||||
seed (Optional[int]): Random seed for environment initialization.
|
||||
|
||||
Returns:
|
||||
observation (Dict[str, Any]): The initial formatted observation.
|
||||
info (Dict[str, Any]): Additional info about the reset state.
|
||||
"""
|
||||
super().reset(seed=seed)
|
||||
|
||||
raw_obs, info = self._env.reset(seed=seed)
|
||||
|
||||
observation = self._format_raw_obs(raw_obs)
|
||||
|
||||
info = {"is_success": False}
|
||||
return observation, info
|
||||
|
||||
def step(self, action: np.ndarray) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
|
||||
"""
|
||||
Perform one environment step.
|
||||
|
||||
Args:
|
||||
action (np.ndarray): The action to execute, must be 1-D with shape (action_dim,).
|
||||
|
||||
Returns:
|
||||
observation (Dict[str, Any]): The formatted observation after the step.
|
||||
reward (float): The scalar reward for this step.
|
||||
terminated (bool): Whether the episode terminated successfully.
|
||||
truncated (bool): Whether the episode was truncated due to a time limit.
|
||||
info (Dict[str, Any]): Additional environment info.
|
||||
"""
|
||||
if action.ndim != 1:
|
||||
raise ValueError(
|
||||
f"Expected action to be 1-D (shape (action_dim,)), "
|
||||
f"but got shape {action.shape} with ndim={action.ndim}"
|
||||
)
|
||||
raw_obs, reward, done, truncated, info = self._env.step(action)
|
||||
|
||||
# Determine whether the task was successful
|
||||
is_success = bool(info.get("success", 0))
|
||||
terminated = done or is_success
|
||||
info.update(
|
||||
{
|
||||
"task": self.task,
|
||||
"done": done,
|
||||
"is_success": is_success,
|
||||
}
|
||||
)
|
||||
|
||||
# Format the raw observation into the expected structure
|
||||
observation = self._format_raw_obs(raw_obs)
|
||||
if terminated:
|
||||
info["final_info"] = {
|
||||
"task": self.task,
|
||||
"done": bool(done),
|
||||
"is_success": bool(is_success),
|
||||
}
|
||||
self.reset()
|
||||
|
||||
return observation, reward, terminated, truncated, info
|
||||
|
||||
def close(self):
|
||||
self._env.close()
|
||||
|
||||
|
||||
# ---- Main API ----------------------------------------------------------------
|
||||
|
||||
|
||||
def create_metaworld_envs(
|
||||
task: str,
|
||||
n_envs: int,
|
||||
gym_kwargs: dict[str, Any] | None = None,
|
||||
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
|
||||
) -> dict[str, dict[int, Any]]:
|
||||
"""
|
||||
Create vectorized Meta-World environments with a consistent return shape.
|
||||
|
||||
Returns:
|
||||
dict[task_group][task_id] -> vec_env (env_cls([...]) with exactly n_envs factories)
|
||||
Notes:
|
||||
- n_envs is the number of rollouts *per task* (episode_index = 0..n_envs-1).
|
||||
- `task` can be a single difficulty group (e.g., "easy", "medium", "hard") or a comma-separated list.
|
||||
- If a task name is not in DIFFICULTY_TO_TASKS, we treat it as a single custom task.
|
||||
"""
|
||||
if env_cls is None or not callable(env_cls):
|
||||
raise ValueError("env_cls must be a callable that wraps a list of environment factory callables.")
|
||||
if not isinstance(n_envs, int) or n_envs <= 0:
|
||||
raise ValueError(f"n_envs must be a positive int; got {n_envs}.")
|
||||
|
||||
gym_kwargs = dict(gym_kwargs or {})
|
||||
task_groups = [t.strip() for t in task.split(",") if t.strip()]
|
||||
if not task_groups:
|
||||
raise ValueError("`task` must contain at least one Meta-World task or difficulty group.")
|
||||
|
||||
print(f"Creating Meta-World envs | task_groups={task_groups} | n_envs(per task)={n_envs}")
|
||||
|
||||
out: dict[str, dict[int, Any]] = defaultdict(dict)
|
||||
|
||||
for group in task_groups:
|
||||
# if not in difficulty presets, treat it as a single custom task
|
||||
tasks = DIFFICULTY_TO_TASKS.get(group, [group])
|
||||
|
||||
for tid, task_name in enumerate(tasks):
|
||||
print(f"Building vec env | group={group} | task_id={tid} | task={task_name}")
|
||||
|
||||
# build n_envs factories
|
||||
fns = [(lambda tn=task_name: MetaworldEnv(task=tn, **gym_kwargs)) for _ in range(n_envs)]
|
||||
|
||||
out[group][tid] = env_cls(fns)
|
||||
|
||||
# return a plain dict for consistency
|
||||
return {group: dict(task_map) for group, task_map in out.items()}
|
||||
@@ -0,0 +1,121 @@
|
||||
{
|
||||
"TASK_DESCRIPTIONS": {
|
||||
"assembly-v3": "Pick up a nut and place it onto a peg",
|
||||
"basketball-v3": "Dunk the basketball into the basket",
|
||||
"bin-picking-v3": "Grasp the puck from one bin and place it into another bin",
|
||||
"box-close-v3": "Grasp the cover and close the box with it",
|
||||
"button-press-topdown-v3": "Press a button from the top",
|
||||
"button-press-topdown-wall-v3": "Bypass a wall and press a button from the top",
|
||||
"button-press-v3": "Press a button",
|
||||
"button-press-wall-v3": "Bypass a wall and press a button",
|
||||
"coffee-button-v3": "Push a button on the coffee machine",
|
||||
"coffee-pull-v3": "Pull a mug from a coffee machine",
|
||||
"coffee-push-v3": "Push a mug under a coffee machine",
|
||||
"dial-turn-v3": "Rotate a dial 180 degrees",
|
||||
"disassemble-v3": "Pick a nut out of a peg",
|
||||
"door-close-v3": "Close a door with a revolving joint",
|
||||
"door-lock-v3": "Lock the door by rotating the lock clockwise",
|
||||
"door-open-v3": "Open a door with a revolving joint",
|
||||
"door-unlock-v3": "Unlock the door by rotating the lock counter-clockwise",
|
||||
"hand-insert-v3": "Insert the gripper into a hole",
|
||||
"drawer-close-v3": "Push and close a drawer",
|
||||
"drawer-open-v3": "Open a drawer",
|
||||
"faucet-open-v3": "Rotate the faucet counter-clockwise",
|
||||
"faucet-close-v3": "Rotate the faucet clockwise",
|
||||
"hammer-v3": "Hammer a screw on the wall",
|
||||
"handle-press-side-v3": "Press a handle down sideways",
|
||||
"handle-press-v3": "Press a handle down",
|
||||
"handle-pull-side-v3": "Pull a handle up sideways",
|
||||
"handle-pull-v3": "Pull a handle up",
|
||||
"lever-pull-v3": "Pull a lever down 90 degrees",
|
||||
"peg-insert-side-v3": "Insert a peg sideways",
|
||||
"pick-place-wall-v3": "Pick a puck, bypass a wall and place the puck",
|
||||
"pick-out-of-hole-v3": "Pick up a puck from a hole",
|
||||
"reach-v3": "Reach a goal position",
|
||||
"push-back-v3": "Push the puck to a goal",
|
||||
"push-v3": "Push the puck to a goal",
|
||||
"pick-place-v3": "Pick and place a puck to a goal",
|
||||
"plate-slide-v3": "Slide a plate into a cabinet",
|
||||
"plate-slide-side-v3": "Slide a plate into a cabinet sideways",
|
||||
"plate-slide-back-v3": "Get a plate from the cabinet",
|
||||
"plate-slide-back-side-v3": "Get a plate from the cabinet sideways",
|
||||
"peg-unplug-side-v3": "Unplug a peg sideways",
|
||||
"soccer-v3": "Kick a soccer into the goal",
|
||||
"stick-push-v3": "Grasp a stick and push a box using the stick",
|
||||
"stick-pull-v3": "Grasp a stick and pull a box with the stick",
|
||||
"push-wall-v3": "Bypass a wall and push a puck to a goal",
|
||||
"reach-wall-v3": "Bypass a wall and reach a goal",
|
||||
"shelf-place-v3": "Pick and place a puck onto a shelf",
|
||||
"sweep-into-v3": "Sweep a puck into a hole",
|
||||
"sweep-v3": "Sweep a puck off the table",
|
||||
"window-open-v3": "Push and open a window",
|
||||
"window-close-v3": "Push and close a window"
|
||||
},
|
||||
"TASK_NAME_TO_ID": {
|
||||
"assembly-v3": 0, "basketball-v3": 1, "bin-picking-v3": 2, "box-close-v3": 3,
|
||||
"button-press-topdown-v3": 4, "button-press-topdown-wall-v3": 5, "button-press-v3": 6,
|
||||
"button-press-wall-v3": 7, "coffee-button-v3": 8, "coffee-pull-v3": 9, "coffee-push-v3": 10,
|
||||
"dial-turn-v3": 11, "disassemble-v3": 12, "door-close-v3": 13, "door-lock-v3": 14,
|
||||
"door-open-v3": 15, "door-unlock-v3": 16, "drawer-close-v3": 17, "drawer-open-v3": 18,
|
||||
"faucet-close-v3": 19, "faucet-open-v3": 20, "hammer-v3": 21, "hand-insert-v3": 22,
|
||||
"handle-press-side-v3": 23, "handle-press-v3": 24, "handle-pull-side-v3": 25,
|
||||
"handle-pull-v3": 26, "lever-pull-v3": 27, "peg-insert-side-v3": 28, "peg-unplug-side-v3": 29,
|
||||
"pick-out-of-hole-v3": 30, "pick-place-v3": 31, "pick-place-wall-v3": 32,
|
||||
"plate-slide-back-side-v3": 33, "plate-slide-back-v3": 34, "plate-slide-side-v3": 35,
|
||||
"plate-slide-v3": 36, "push-back-v3": 37, "push-v3": 38, "push-wall-v3": 39, "reach-v3": 40,
|
||||
"reach-wall-v3": 41, "shelf-place-v3": 42, "soccer-v3": 43, "stick-pull-v3": 44,
|
||||
"stick-push-v3": 45, "sweep-into-v3": 46, "sweep-v3": 47, "window-open-v3": 48,
|
||||
"window-close-v3": 49
|
||||
},
|
||||
"DIFFICULTY_TO_TASKS": {
|
||||
"easy": [
|
||||
"button-press-v3", "button-press-topdown-v3", "button-press-topdown-wall-v3",
|
||||
"button-press-wall-v3", "coffee-button-v3", "dial-turn-v3", "door-close-v3",
|
||||
"door-lock-v3", "door-open-v3", "door-unlock-v3", "drawer-close-v3", "drawer-open-v3",
|
||||
"faucet-close-v3", "faucet-open-v3", "handle-press-v3", "handle-press-side-v3",
|
||||
"handle-pull-v3", "handle-pull-side-v3", "lever-pull-v3", "plate-slide-v3",
|
||||
"plate-slide-back-v3", "plate-slide-back-side-v3", "plate-slide-side-v3", "reach-v3",
|
||||
"reach-wall-v3", "window-close-v3", "window-open-v3", "peg-unplug-side-v3"
|
||||
],
|
||||
"medium": [
|
||||
"basketball-v3", "bin-picking-v3", "box-close-v3", "coffee-pull-v3", "coffee-push-v3",
|
||||
"hammer-v3", "peg-insert-side-v3", "push-wall-v3", "soccer-v3", "sweep-v3", "sweep-into-v3"
|
||||
],
|
||||
"hard": [
|
||||
"assembly-v3", "hand-insert-v3", "pick-out-of-hole-v3", "pick-place-v3", "push-v3", "push-back-v3"
|
||||
],
|
||||
"very_hard": [
|
||||
"shelf-place-v3", "disassemble-v3", "stick-pull-v3", "stick-push-v3", "pick-place-wall-v3"
|
||||
]
|
||||
},
|
||||
"TASK_POLICY_MAPPING": {
|
||||
"assembly-v3": "SawyerAssemblyV3Policy", "basketball-v3": "SawyerBasketballV3Policy",
|
||||
"bin-picking-v3": "SawyerBinPickingV3Policy", "box-close-v3": "SawyerBoxCloseV3Policy",
|
||||
"button-press-topdown-v3": "SawyerButtonPressTopdownV3Policy",
|
||||
"button-press-topdown-wall-v3": "SawyerButtonPressTopdownWallV3Policy",
|
||||
"button-press-v3": "SawyerButtonPressV3Policy", "button-press-wall-v3": "SawyerButtonPressWallV3Policy",
|
||||
"coffee-button-v3": "SawyerCoffeeButtonV3Policy", "coffee-pull-v3": "SawyerCoffeePullV3Policy",
|
||||
"coffee-push-v3": "SawyerCoffeePushV3Policy", "dial-turn-v3": "SawyerDialTurnV3Policy",
|
||||
"disassemble-v3": "SawyerDisassembleV3Policy", "door-close-v3": "SawyerDoorCloseV3Policy",
|
||||
"door-lock-v3": "SawyerDoorLockV3Policy", "door-open-v3": "SawyerDoorOpenV3Policy",
|
||||
"door-unlock-v3": "SawyerDoorUnlockV3Policy", "drawer-close-v3": "SawyerDrawerCloseV3Policy",
|
||||
"drawer-open-v3": "SawyerDrawerOpenV3Policy", "faucet-close-v3": "SawyerFaucetCloseV3Policy",
|
||||
"faucet-open-v3": "SawyerFaucetOpenV3Policy", "hammer-v3": "SawyerHammerV3Policy",
|
||||
"hand-insert-v3": "SawyerHandInsertV3Policy", "handle-press-side-v3": "SawyerHandlePressSideV3Policy",
|
||||
"handle-press-v3": "SawyerHandlePressV3Policy", "handle-pull-side-v3": "SawyerHandlePullSideV3Policy",
|
||||
"handle-pull-v3": "SawyerHandlePullV3Policy", "lever-pull-v3": "SawyerLeverPullV3Policy",
|
||||
"peg-insert-side-v3": "SawyerPegInsertionSideV3Policy", "peg-unplug-side-v3": "SawyerPegUnplugSideV3Policy",
|
||||
"pick-out-of-hole-v3": "SawyerPickOutOfHoleV3Policy", "pick-place-v3": "SawyerPickPlaceV3Policy",
|
||||
"pick-place-wall-v3": "SawyerPickPlaceWallV3Policy",
|
||||
"plate-slide-back-side-v3": "SawyerPlateSlideBackSideV3Policy",
|
||||
"plate-slide-back-v3": "SawyerPlateSlideBackV3Policy",
|
||||
"plate-slide-side-v3": "SawyerPlateSlideSideV3Policy", "plate-slide-v3": "SawyerPlateSlideV3Policy",
|
||||
"push-back-v3": "SawyerPushBackV3Policy", "push-v3": "SawyerPushV3Policy",
|
||||
"push-wall-v3": "SawyerPushWallV3Policy", "reach-v3": "SawyerReachV3Policy",
|
||||
"reach-wall-v3": "SawyerReachWallV3Policy", "shelf-place-v3": "SawyerShelfPlaceV3Policy",
|
||||
"soccer-v3": "SawyerSoccerV3Policy", "stick-pull-v3": "SawyerStickPullV3Policy",
|
||||
"stick-push-v3": "SawyerStickPushV3Policy", "sweep-into-v3": "SawyerSweepIntoV3Policy",
|
||||
"sweep-v3": "SawyerSweepV3Policy", "window-open-v3": "SawyerWindowOpenV3Policy",
|
||||
"window-close-v3": "SawyerWindowCloseV3Policy"
|
||||
}
|
||||
}
|
||||
+10
-10
@@ -48,25 +48,25 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
|
||||
for imgkey, img in imgs.items():
|
||||
# TODO(aliberts, rcadene): use transforms.ToTensor()?
|
||||
img = torch.from_numpy(img)
|
||||
img_tensor = torch.from_numpy(img)
|
||||
|
||||
# When preprocessing observations in a non-vectorized environment, we need to add a batch dimension.
|
||||
# This is the case for human-in-the-loop RL where there is only one environment.
|
||||
if img.ndim == 3:
|
||||
img = img.unsqueeze(0)
|
||||
if img_tensor.ndim == 3:
|
||||
img_tensor = img_tensor.unsqueeze(0)
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||
_, h, w, c = img_tensor.shape
|
||||
assert c < h and c < w, f"expect channel last images, but instead got {img_tensor.shape=}"
|
||||
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
assert img_tensor.dtype == torch.uint8, f"expect torch.uint8, but instead {img_tensor.dtype=}"
|
||||
|
||||
# convert to channel first of type float32 in range [0,1]
|
||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||
img = img.type(torch.float32)
|
||||
img /= 255
|
||||
img_tensor = einops.rearrange(img_tensor, "b h w c -> b c h w").contiguous()
|
||||
img_tensor = img_tensor.type(torch.float32)
|
||||
img_tensor /= 255
|
||||
|
||||
return_observations[imgkey] = img
|
||||
return_observations[imgkey] = img_tensor
|
||||
|
||||
if "environment_state" in observations:
|
||||
env_state = torch.from_numpy(observations["environment_state"]).float()
|
||||
|
||||
@@ -22,18 +22,18 @@ class RobotKinematics:
|
||||
self,
|
||||
urdf_path: str,
|
||||
target_frame_name: str = "gripper_frame_link",
|
||||
joint_names: list[str] = None,
|
||||
joint_names: list[str] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize placo-based kinematics solver.
|
||||
|
||||
Args:
|
||||
urdf_path: Path to the robot URDF file
|
||||
target_frame_name: Name of the end-effector frame in the URDF
|
||||
joint_names: List of joint names to use for the kinematics solver
|
||||
urdf_path (str): Path to the robot URDF file
|
||||
target_frame_name (str): Name of the end-effector frame in the URDF
|
||||
joint_names (list[str] | None): List of joint names to use for the kinematics solver
|
||||
"""
|
||||
try:
|
||||
import placo
|
||||
import placo # type: ignore[import-not-found] # C++ library with Python bindings, no type stubs available. TODO: Create stub file or request upstream typing support.
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"placo is required for RobotKinematics. "
|
||||
@@ -52,7 +52,7 @@ class RobotKinematics:
|
||||
# Initialize frame task for IK
|
||||
self.tip_frame = self.solver.add_frame_task(self.target_frame_name, np.eye(4))
|
||||
|
||||
def forward_kinematics(self, joint_pos_deg):
|
||||
def forward_kinematics(self, joint_pos_deg: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Compute forward kinematics for given joint configuration given the target frame name in the constructor.
|
||||
|
||||
@@ -77,8 +77,12 @@ class RobotKinematics:
|
||||
return self.robot.get_T_world_frame(self.target_frame_name)
|
||||
|
||||
def inverse_kinematics(
|
||||
self, current_joint_pos, desired_ee_pose, position_weight=1.0, orientation_weight=0.01
|
||||
):
|
||||
self,
|
||||
current_joint_pos: np.ndarray,
|
||||
desired_ee_pose: np.ndarray,
|
||||
position_weight: float = 1.0,
|
||||
orientation_weight: float = 0.01,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute inverse kinematics using placo solver.
|
||||
|
||||
|
||||
@@ -1 +1,17 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .motors_bus import Motor, MotorCalibration, MotorNormMode, MotorsBus
|
||||
|
||||
@@ -60,7 +60,7 @@ class OperatingMode(Enum):
|
||||
|
||||
# This mode controls position. This mode is identical to the Multi-turn Position Control from existing
|
||||
# DYNAMIXEL. 512 turns are supported(-256[rev] ~ 256[rev]). This mode is ideal for multi-turn wrists or
|
||||
# conveyer systems or a system that requires an additional reduction gear. Note that Max Position
|
||||
# conveyor systems or a system that requires an additional reduction gear. Note that Max Position
|
||||
# Limit(48), Min Position Limit(52) are not used on Extended Position Control Mode.
|
||||
EXTENDED_POSITION = 4
|
||||
|
||||
|
||||
@@ -206,8 +206,12 @@ MODEL_BAUDRATE_TABLE = {
|
||||
# Sign-Magnitude encoding bits
|
||||
STS_SMS_SERIES_ENCODINGS_TABLE = {
|
||||
"Homing_Offset": 11,
|
||||
"Goal_Position": 15,
|
||||
"Goal_Velocity": 15,
|
||||
"Goal_Speed": 15,
|
||||
"Present_Position": 15,
|
||||
"Present_Velocity": 15,
|
||||
"Present_Speed": 15,
|
||||
}
|
||||
|
||||
MODEL_ENCODING_TABLE = {
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
import logging
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
@@ -79,7 +80,11 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
|
||||
@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
|
||||
@dataclass
|
||||
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
|
||||
"""Used by Physical Intelligence to train Pi0"""
|
||||
"""Used by Physical Intelligence to train Pi0.
|
||||
|
||||
Automatically scales warmup and decay steps if num_training_steps < num_decay_steps.
|
||||
This ensures the learning rate schedule completes properly even with shorter training runs.
|
||||
"""
|
||||
|
||||
num_warmup_steps: int
|
||||
num_decay_steps: int
|
||||
@@ -87,23 +92,39 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
|
||||
decay_lr: float
|
||||
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
del num_training_steps
|
||||
# Auto-scale scheduler parameters if training steps are shorter than configured decay steps
|
||||
actual_warmup_steps = self.num_warmup_steps
|
||||
actual_decay_steps = self.num_decay_steps
|
||||
|
||||
if num_training_steps < self.num_decay_steps:
|
||||
# Calculate scaling factor to fit the schedule into the available training steps
|
||||
scale_factor = num_training_steps / self.num_decay_steps
|
||||
actual_warmup_steps = int(self.num_warmup_steps * scale_factor)
|
||||
actual_decay_steps = num_training_steps
|
||||
|
||||
logging.info(
|
||||
f"Auto-scaling LR scheduler: "
|
||||
f"num_training_steps ({num_training_steps}) < num_decay_steps ({self.num_decay_steps}). "
|
||||
f"Scaling warmup: {self.num_warmup_steps} → {actual_warmup_steps}, "
|
||||
f"decay: {self.num_decay_steps} → {actual_decay_steps} "
|
||||
f"(scale factor: {scale_factor:.3f})"
|
||||
)
|
||||
|
||||
def lr_lambda(current_step):
|
||||
def linear_warmup_schedule(current_step):
|
||||
if current_step <= 0:
|
||||
return 1 / (self.num_warmup_steps + 1)
|
||||
frac = 1 - current_step / self.num_warmup_steps
|
||||
return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1
|
||||
return 1 / (actual_warmup_steps + 1)
|
||||
frac = 1 - current_step / actual_warmup_steps
|
||||
return (1 / (actual_warmup_steps + 1) - 1) * frac + 1
|
||||
|
||||
def cosine_decay_schedule(current_step):
|
||||
step = min(current_step, self.num_decay_steps)
|
||||
cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
|
||||
step = min(current_step, actual_decay_steps)
|
||||
cosine_decay = 0.5 * (1 + math.cos(math.pi * step / actual_decay_steps))
|
||||
alpha = self.decay_lr / self.peak_lr
|
||||
decayed = (1 - alpha) * cosine_decay + alpha
|
||||
return decayed
|
||||
|
||||
if current_step < self.num_warmup_steps:
|
||||
if current_step < actual_warmup_steps:
|
||||
return linear_warmup_schedule(current_step)
|
||||
|
||||
return cosine_decay_schedule(current_step)
|
||||
|
||||
@@ -14,8 +14,9 @@
|
||||
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .groot.configuration_groot import GrootConfig as GrootConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .pi0.processor_pi0 import Pi0NewLineProcessor
|
||||
from .pi05.configuration_pi05 import PI05Config as PI05Config
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
@@ -25,7 +26,9 @@ __all__ = [
|
||||
"ACTConfig",
|
||||
"DiffusionConfig",
|
||||
"PI0Config",
|
||||
"PI05Config",
|
||||
"SmolVLAConfig",
|
||||
"TDMPCConfig",
|
||||
"VQBeTConfig",
|
||||
"GrootConfig",
|
||||
]
|
||||
|
||||
@@ -626,8 +626,8 @@ class ACTDecoderLayer(nn.Module):
|
||||
x: (Decoder Sequence, Batch, Channel) tensor of input tokens.
|
||||
encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are
|
||||
cross-attending with.
|
||||
decoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder).
|
||||
encoder_pos_embed: (DS, 1, C) Positional_embedding for the queries (from the decoder).
|
||||
encoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder).
|
||||
decoder_pos_embed: (DS, 1, C) positional embedding for the queries (from the decoder).
|
||||
Returns:
|
||||
(DS, B, C) tensor of decoder output features.
|
||||
"""
|
||||
|
||||
@@ -90,16 +90,16 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
self._queues[OBS_ENV_STATE] = deque(maxlen=self.config.n_obs_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
# stack n latest observations from the queue
|
||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||
actions = self.diffusion.generate_actions(batch)
|
||||
actions = self.diffusion.generate_actions(batch, noise=noise)
|
||||
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method handles caching a history of observations and an action trajectory generated by the
|
||||
@@ -131,7 +131,7 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
if len(self._queues[ACTION]) == 0:
|
||||
actions = self.predict_action_chunk(batch)
|
||||
actions = self.predict_action_chunk(batch, noise=noise)
|
||||
self._queues[ACTION].extend(actions.transpose(0, 1))
|
||||
|
||||
action = self._queues[ACTION].popleft()
|
||||
@@ -199,17 +199,25 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
# ========= inference ============
|
||||
def conditional_sample(
|
||||
self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
|
||||
self,
|
||||
batch_size: int,
|
||||
global_cond: Tensor | None = None,
|
||||
generator: torch.Generator | None = None,
|
||||
noise: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
device = get_device_from_parameters(self)
|
||||
dtype = get_dtype_from_parameters(self)
|
||||
|
||||
# Sample prior.
|
||||
sample = torch.randn(
|
||||
size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
sample = (
|
||||
noise
|
||||
if noise is not None
|
||||
else torch.randn(
|
||||
size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
)
|
||||
)
|
||||
|
||||
self.noise_scheduler.set_timesteps(self.num_inference_steps)
|
||||
@@ -264,7 +272,7 @@ class DiffusionModel(nn.Module):
|
||||
# Concatenate features then flatten to (B, global_cond_dim).
|
||||
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
|
||||
|
||||
def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
def generate_actions(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""
|
||||
This function expects `batch` to have:
|
||||
{
|
||||
@@ -282,7 +290,7 @@ class DiffusionModel(nn.Module):
|
||||
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
|
||||
|
||||
# run sampling
|
||||
actions = self.conditional_sample(batch_size, global_cond=global_cond)
|
||||
actions = self.conditional_sample(batch_size, global_cond=global_cond, noise=noise)
|
||||
|
||||
# Extract `n_action_steps` steps worth of actions (from the current observation).
|
||||
start = n_obs_steps - 1
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user