mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 08:47:05 +00:00
Compare commits
30 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c5925399a9 | |||
| f478ae5bfa | |||
| b4d40d0228 | |||
| db5c26f07d | |||
| 8904768db4 | |||
| b0efa73520 | |||
| 00b662de02 | |||
| 5c51a74484 | |||
| db8547e35d | |||
| c17d949531 | |||
| 1e131f93f8 | |||
| 2fb5c7add0 | |||
| 4f2ef024d8 | |||
| 6139b133ca | |||
| 85de893fa7 | |||
| a4c66e530b | |||
| a225127527 | |||
| e489ba24fc | |||
| d324ffe810 | |||
| 1a24f770d3 | |||
| 92fba37225 | |||
| 3e45120272 | |||
| f0d2b37beb | |||
| cbc8bfb2e6 | |||
| 0d1be72dc8 | |||
| 96b7c212c4 | |||
| 4303b3c930 | |||
| 63dca86df8 | |||
| 8a0cc3d664 | |||
| 8bb8ed4803 |
@@ -44,7 +44,7 @@ permissions:
|
||||
# Sets up the environment variables
|
||||
env:
|
||||
UV_VERSION: "0.8.0"
|
||||
PYTHON_VERSION: "3.10"
|
||||
PYTHON_VERSION: "3.12"
|
||||
|
||||
# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
|
||||
concurrency:
|
||||
@@ -61,6 +61,7 @@ jobs:
|
||||
MUJOCO_GL: egl
|
||||
HF_HOME: /mnt/cache/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
@@ -89,5 +90,11 @@ jobs:
|
||||
- name: Install lerobot with test extras
|
||||
run: uv sync --extra "test"
|
||||
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
uv run hf auth whoami
|
||||
|
||||
- name: Run pytest
|
||||
run: uv run pytest tests -vv --maxfail=10
|
||||
|
||||
@@ -37,7 +37,7 @@ permissions:
|
||||
# Sets up the environment variables
|
||||
env:
|
||||
UV_VERSION: "0.8.0"
|
||||
PYTHON_VERSION: "3.10"
|
||||
PYTHON_VERSION: "3.12"
|
||||
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu
|
||||
|
||||
# Ensures that only the latest action is built, canceling older runs.
|
||||
@@ -60,6 +60,7 @@ jobs:
|
||||
MUJOCO_GL: egl
|
||||
HF_HOME: /mnt/cache/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
@@ -87,6 +88,12 @@ jobs:
|
||||
- name: Install lerobot with all extras
|
||||
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
|
||||
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
uv run hf auth whoami
|
||||
|
||||
- name: Run pytest (all extras)
|
||||
run: uv run pytest tests -vv --maxfail=10
|
||||
|
||||
@@ -162,6 +169,7 @@ jobs:
|
||||
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
|
||||
TORCH_HOME: /home/user_lerobot/.cache/torch
|
||||
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
container:
|
||||
image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images]
|
||||
options: --gpus all --shm-size "16gb"
|
||||
@@ -173,8 +181,13 @@ jobs:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
hf auth whoami
|
||||
- name: Fix ptxas permissions
|
||||
run: chmod +x /lerobot/.venv/lib/python3.10/site-packages/triton/backends/nvidia/bin/ptxas
|
||||
run: chmod +x /lerobot/.venv/lib/python3.12/site-packages/triton/backends/nvidia/bin/ptxas
|
||||
- name: Run pytest on GPU
|
||||
run: pytest tests -vv --maxfail=10
|
||||
- name: Run end-to-end tests
|
||||
|
||||
@@ -28,7 +28,7 @@ on:
|
||||
# Sets up the environment variables
|
||||
env:
|
||||
UV_VERSION: "0.8.0"
|
||||
PYTHON_VERSION: "3.10"
|
||||
PYTHON_VERSION: "3.12"
|
||||
DOCKER_IMAGE_NAME_CPU: huggingface/lerobot-cpu:latest
|
||||
DOCKER_IMAGE_NAME_GPU: huggingface/lerobot-gpu:latest
|
||||
|
||||
@@ -119,6 +119,7 @@ jobs:
|
||||
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
|
||||
TORCH_HOME: /home/user_lerobot/.cache/torch
|
||||
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
container:
|
||||
image: ${{ needs.build-docker-cpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
|
||||
options: --shm-size "16gb"
|
||||
@@ -130,6 +131,11 @@ jobs:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
hf auth whoami
|
||||
- name: Run pytest on CPU
|
||||
run: pytest tests -vv --maxfail=10
|
||||
- name: Run end-to-end tests
|
||||
@@ -146,6 +152,7 @@ jobs:
|
||||
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
|
||||
TORCH_HOME: /home/user_lerobot/.cache/torch
|
||||
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
container:
|
||||
image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
|
||||
options: --gpus all --shm-size "16gb"
|
||||
@@ -157,6 +164,11 @@ jobs:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
hf auth whoami
|
||||
- name: Run pytest on GPU
|
||||
run: pytest tests -vv --maxfail=10
|
||||
- name: Run end-to-end tests
|
||||
@@ -174,6 +186,7 @@ jobs:
|
||||
TORCH_HOME: /home/user_lerobot/.cache/torch
|
||||
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
|
||||
CUDA_VISIBLE_DEVICES: "0,1,2,3"
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
container:
|
||||
image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
|
||||
options: --gpus all --shm-size "16gb"
|
||||
@@ -185,12 +198,15 @@ jobs:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
hf auth whoami
|
||||
- name: Verify GPU availability
|
||||
run: |
|
||||
nvidia-smi
|
||||
python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
|
||||
|
||||
- name: Run multi-GPU training tests
|
||||
# TODO(Steven): Investigate why motors tests are failing in multi-GPU setup
|
||||
run: pytest tests -vv --maxfail=10 --ignore=tests/motors/
|
||||
timeout-minutes: 10
|
||||
run: pytest -vv tests/training/
|
||||
|
||||
@@ -50,7 +50,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.10'
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Run pre-commit hooks
|
||||
uses: pre-commit/action@v3.0.1 # zizmor: ignore[unpinned-uses]
|
||||
|
||||
@@ -22,7 +22,7 @@ on:
|
||||
# Sets up the environment variables
|
||||
env:
|
||||
UV_VERSION: "0.8.0"
|
||||
PYTHON_VERSION: "3.10"
|
||||
PYTHON_VERSION: "3.12"
|
||||
|
||||
jobs:
|
||||
# This job builds the Python package and publishes it to PyPI
|
||||
@@ -45,7 +45,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.10'
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Extract Version
|
||||
id: extract_info
|
||||
@@ -83,14 +83,6 @@ jobs:
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Remove Tags with Git dependencies
|
||||
# TODO(Steven): Temporary patch to remove pi from PyPi 0.4.0 release due to its reliance on git dependencies.
|
||||
run: |
|
||||
echo "::info:: Checking for Git dependencies to remove from pyproject.toml..."
|
||||
grep -E '@ git\+https|lerobot\[pi\]' pyproject.toml | sed 's/^/::warning:: Removing line: /' || true
|
||||
sed -E -i '/@ git\+https|lerobot\[pi\]/d' pyproject.toml
|
||||
echo "::info:: Git dependencies removed. Proceeding with build."
|
||||
|
||||
- name: Install build dependencies
|
||||
run: python -m pip install build
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ permissions:
|
||||
# Sets up the environment variables
|
||||
env:
|
||||
UV_VERSION: "0.8.0"
|
||||
PYTHON_VERSION: "3.10"
|
||||
PYTHON_VERSION: "3.12"
|
||||
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu:unbound
|
||||
|
||||
# Ensures that only the latest action is built, canceling older runs.
|
||||
@@ -48,6 +48,7 @@ jobs:
|
||||
MUJOCO_GL: egl
|
||||
HF_HOME: /mnt/cache/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
@@ -79,7 +80,11 @@ jobs:
|
||||
|
||||
- name: Install lerobot with all extras
|
||||
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
|
||||
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
uv run hf auth whoami
|
||||
- name: Run pytest (all extras)
|
||||
run: uv run pytest tests -vv
|
||||
|
||||
@@ -137,6 +142,7 @@ jobs:
|
||||
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
|
||||
TORCH_HOME: /home/user_lerobot/.cache/torch
|
||||
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
container:
|
||||
image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images]
|
||||
options: --gpus all --shm-size "16gb"
|
||||
@@ -148,6 +154,11 @@ jobs:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
hf auth whoami
|
||||
- name: Run pytest on GPU
|
||||
run: pytest tests -vv
|
||||
- name: Run end-to-end tests
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
default_language_version:
|
||||
python: python3.10
|
||||
python: python3.12
|
||||
|
||||
exclude: "tests/artifacts/.*\\.safetensors$"
|
||||
|
||||
@@ -55,7 +55,7 @@ repos:
|
||||
rev: v3.21.0
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
args: [--py310-plus]
|
||||
args: [--py312-plus]
|
||||
|
||||
##### Markdown Quality #####
|
||||
- repo: https://github.com/rbubley/mirrors-prettier
|
||||
|
||||
@@ -135,7 +135,7 @@ Learn how to implement your own simulation environment or benchmark and distribu
|
||||
|
||||
## Citation
|
||||
|
||||
If you use LeRobot in your research, please cite:
|
||||
If you use LeRobot in your project, please cite the GitHub repository to acknowledge the ongoing development and contributors:
|
||||
|
||||
```bibtex
|
||||
@misc{cadene2024lerobot,
|
||||
@@ -146,6 +146,23 @@ If you use LeRobot in your research, please cite:
|
||||
}
|
||||
```
|
||||
|
||||
If you are referencing our research or the academic paper, please also cite our ICLR publication:
|
||||
|
||||
<details>
|
||||
<summary><b>ICLR 2026 Paper</b></summary>
|
||||
|
||||
```bibtex
|
||||
@inproceedings{cadenelerobot,
|
||||
title={LeRobot: An Open-Source Library for End-to-End Robot Learning},
|
||||
author={Cadene, Remi and Alibert, Simon and Capuano, Francesco and Aractingi, Michel and Zouitine, Adil and Kooijmans, Pepijn and Choghari, Jade and Russi, Martino and Pascal, Caroline and Palma, Steven and Shukor, Mustafa and Moss, Jess and Soare, Alexander and Aubakirova, Dana and Lhoest, Quentin and Gallou\'edec, Quentin and Wolf, Thomas},
|
||||
booktitle={The Fourteenth International Conference on Learning Representations},
|
||||
year={2026},
|
||||
url={https://arxiv.org/abs/2602.22818}
|
||||
}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## Contribute
|
||||
|
||||
We welcome contributions from everyone in the community! To get started, please read our [CONTRIBUTING.md](./CONTRIBUTING.md) guide. Whether you're adding a new feature, improving documentation, or fixing a bug, your help and feedback are invaluable. We're incredibly excited about the future of open-source robotics and can't wait to work with you on what's next—thank you for your support!
|
||||
|
||||
@@ -24,7 +24,7 @@ ARG OS_VERSION=22.04
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
|
||||
|
||||
# Define Python version argument
|
||||
ARG PYTHON_VERSION=3.10
|
||||
ARG PYTHON_VERSION=3.12
|
||||
|
||||
# Configure environment variables
|
||||
ENV DEBIAN_FRONTEND=noninteractive \
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
# docker run -it --rm lerobot-user
|
||||
|
||||
# Configure the base image
|
||||
ARG PYTHON_VERSION=3.10
|
||||
ARG PYTHON_VERSION=3.12
|
||||
FROM python:${PYTHON_VERSION}-slim
|
||||
|
||||
# Configure environment variables
|
||||
|
||||
@@ -31,6 +31,8 @@
|
||||
title: Using Subtasks in the Dataset
|
||||
- local: streaming_video_encoding
|
||||
title: Streaming Video Encoding
|
||||
- local: multi_dataset_training
|
||||
title: Multi-Dataset Training
|
||||
title: "Datasets"
|
||||
- sections:
|
||||
- local: act
|
||||
|
||||
@@ -48,7 +48,7 @@ python -m lerobot.async_inference.robot_client \
|
||||
--task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act`
|
||||
--policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc)
|
||||
--pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base)
|
||||
--policy_device=mps \ # POLICY: the device to run the policy on, on the server
|
||||
--policy_device=mps \ # POLICY: the device to run the policy on, on the server (cuda, mps, xpu, cpu)
|
||||
--actions_per_chunk=50 \ # POLICY: the number of actions to output at once
|
||||
--chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server
|
||||
--aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions
|
||||
|
||||
@@ -32,7 +32,7 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
# your policy-specific dependencies
|
||||
]
|
||||
requires-python = ">= 3.11"
|
||||
requires-python = ">= 3.12"
|
||||
|
||||
[build-system]
|
||||
build-backend = # your-build-backend
|
||||
@@ -82,7 +82,7 @@ Create your policy implementation by inheriting from LeRobot's base `PreTrainedP
|
||||
# modeling_my_custom_policy.py
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Dict, Any
|
||||
from typing import Any
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
||||
@@ -91,7 +91,7 @@ class MyCustomPolicy(PreTrainedPolicy):
|
||||
config_class = MyCustomPolicyConfig
|
||||
name = "my_custom_policy"
|
||||
|
||||
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: Dict[str, Any] = None):
|
||||
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: dict[str, Any] = None):
|
||||
super().__init__(config, dataset_stats)
|
||||
...
|
||||
```
|
||||
@@ -102,7 +102,7 @@ Create processor functions:
|
||||
|
||||
```python
|
||||
# processor_my_custom_policy.py
|
||||
from typing import Dict, Any
|
||||
from typing import Any
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ The EarthRover Mini Plus is a fully open source mobile robot that connects throu
|
||||
### Hardware
|
||||
|
||||
- EarthRover Mini robot
|
||||
- Computer with Python 3.10 or newer
|
||||
- Computer with Python 3.12 or newer
|
||||
- Internet connection
|
||||
|
||||
### Setting Up the Frodobots SDK
|
||||
@@ -170,13 +170,13 @@ Once you can drive the robot well, you can start recording data to train AI mode
|
||||
We use Hugging Face to store your data online. First, log in with your token from [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Store your Hugging Face username:
|
||||
|
||||
```bash
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
|
||||
@@ -155,10 +155,10 @@ Upload your repository to Hugging Face:
|
||||
pip install huggingface_hub
|
||||
|
||||
# Login to Hugging Face
|
||||
huggingface-cli login
|
||||
hf auth login
|
||||
|
||||
# Create a new repository
|
||||
huggingface-cli repo create my-custom-env --type space --org my-org
|
||||
hf repo create my-org/my-custom-env
|
||||
|
||||
# Initialize git and push
|
||||
git init
|
||||
|
||||
@@ -159,7 +159,7 @@ We use the Hugging Face hub features for uploading your dataset. If you haven't
|
||||
Add your token to the CLI by running this command:
|
||||
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Then store your Hugging Face repository name in a variable:
|
||||
@@ -327,7 +327,7 @@ You can look for other LeRobot datasets on the hub by searching for `LeRobot` [t
|
||||
You can also push your local dataset to the Hub manually, running:
|
||||
|
||||
```bash
|
||||
huggingface-cli upload ${HF_USER}/record-test ~/.cache/huggingface/lerobot/{repo-id} --repo-type dataset
|
||||
hf upload ${HF_USER}/record-test ~/.cache/huggingface/lerobot/{repo-id} --repo-type dataset
|
||||
```
|
||||
|
||||
#### Record function
|
||||
@@ -491,7 +491,7 @@ If your local computer doesn't have a powerful GPU you could utilize Google Cola
|
||||
Once training is done, upload the latest checkpoint with:
|
||||
|
||||
```bash
|
||||
huggingface-cli upload ${HF_USER}/act_so101_test \
|
||||
hf upload ${HF_USER}/act_so101_test \
|
||||
outputs/train/act_so101_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
@@ -499,7 +499,7 @@ You can also upload intermediate checkpoints with:
|
||||
|
||||
```bash
|
||||
CKPT=010000
|
||||
huggingface-cli upload ${HF_USER}/act_so101_test${CKPT} \
|
||||
hf upload ${HF_USER}/act_so101_test${CKPT} \
|
||||
outputs/train/act_so101_test/checkpoints/${CKPT}/pretrained_model
|
||||
```
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# Installation
|
||||
|
||||
This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.10 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-).
|
||||
This guide uses `conda` (via miniforge) to manage environments (recommended). If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.12 and `ffmpeg` installed with the `libsvtav1` encoder, then skip ahead to [Environment Setup](#step-2-environment-setup).
|
||||
|
||||
## Step 1: Install [`miniforge`](https://conda-forge.org/download/)
|
||||
## Step 1 (`conda` only): Install [`miniforge`](https://conda-forge.org/download/)
|
||||
|
||||
```bash
|
||||
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
|
||||
@@ -11,22 +11,47 @@ bash Miniforge3-$(uname)-$(uname -m).sh
|
||||
|
||||
## Step 2: Environment Setup
|
||||
|
||||
Create a virtual environment with Python 3.10, using conda:
|
||||
Create a virtual environment with Python 3.12:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
<hfoptions id="create_venv">
|
||||
<hfoption id="conda">
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda create -y -n lerobot python=3.12
|
||||
```
|
||||
|
||||
Then activate your conda environment, you have to do this each time you open a shell to use lerobot:
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="uv">
|
||||
```bash
|
||||
uv python install 3.12
|
||||
uv venv --python 3.12
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
Then activate your virtual environment, you have to do this each time you open a shell to use lerobot:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
<hfoptions id="activate_venv">
|
||||
<hfoption id="conda">```bash
|
||||
conda activate lerobot
|
||||
```</hfoption>
|
||||
<hfoption id="uv">
|
||||
```bash
|
||||
# Linux/macOSsource
|
||||
source .venv/bin/activate
|
||||
# Windows PowerShell
|
||||
source .venv\Scripts\Activate.ps1
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
When using `conda`, install `ffmpeg` in your environment:
|
||||
|
||||
```bash
|
||||
conda install ffmpeg -c conda-forge
|
||||
ffmpeg -version # ffmpeg 8.X is not yet supported !
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
@@ -47,6 +72,9 @@ conda install ffmpeg -c conda-forge
|
||||
> conda install evdev -c conda-forge
|
||||
> ```
|
||||
|
||||
> [!IMPORTANT]
|
||||
> If you are using `uv` you will have to install `ffmpeg` system-wide (outside of the virtual environment). You rely on `uv` and `torchcodec` ability to dynamically link to the system `ffmpeg`.
|
||||
|
||||
## Step 3: Install LeRobot 🤗
|
||||
|
||||
### From Source
|
||||
@@ -60,23 +88,45 @@ cd lerobot
|
||||
|
||||
Then, install the library in editable mode. This is useful if you plan to contribute to the code.
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
<hfoptions id="install_lerobot_src">
|
||||
<hfoption id="conda">
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="uv">
|
||||
```bash
|
||||
uv pip install -e .
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
### Installation from PyPI
|
||||
|
||||
**Core Library:**
|
||||
Install the base package with:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
<hfoptions id="install_lerobot_pypi">
|
||||
<hfoption id="conda">
|
||||
```bash
|
||||
pip install lerobot
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="uv">
|
||||
```bash
|
||||
uv pip install lerobot
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
_This installs only the default dependencies._
|
||||
|
||||
**Extra Features:**
|
||||
To install additional functionality, use one of the following:
|
||||
To install additional functionality, use one of the following (If you are using `uv`, replace `pip install` with `uv pip install` in the commands below.):
|
||||
|
||||
```bash
|
||||
pip install 'lerobot[all]' # All available features
|
||||
@@ -90,13 +140,10 @@ _Replace `[...]` with your desired features._
|
||||
For a full list of optional dependencies, see:
|
||||
https://pypi.org/project/lerobot/
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install pi, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
|
||||
To install these for linux run:
|
||||
To install these for Linux run:
|
||||
|
||||
```bash
|
||||
sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev
|
||||
@@ -106,7 +153,7 @@ For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/
|
||||
|
||||
## Optional dependencies
|
||||
|
||||
LeRobot provides optional extras for specific functionalities. Multiple extras can be combined (e.g., `.[aloha,feetech]`). For all available extras, refer to `pyproject.toml`.
|
||||
LeRobot provides optional extras for specific functionalities. Multiple extras can be combined (e.g., `.[aloha,feetech]`). For all available extras, refer to `pyproject.toml`. If you are using `uv`, replace `pip install` with `uv pip install` in the commands below.
|
||||
|
||||
### Simulations
|
||||
|
||||
|
||||
@@ -279,13 +279,13 @@ We use the Hugging Face hub features for uploading your dataset. If you haven't
|
||||
Add your token to the CLI by running this command:
|
||||
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Then store your Hugging Face repository name in a variable:
|
||||
|
||||
```bash
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
|
||||
@@ -0,0 +1,232 @@
|
||||
# Multi-Dataset Training
|
||||
|
||||
This guide covers how to train a single policy on multiple heterogeneous datasets using `MultiLeRobotDataset`.
|
||||
|
||||
## Overview
|
||||
|
||||
Real-world robot learning datasets come from different environments, robots, and camera setups. A RoboCasa dataset might have three cameras named `robot0_agentview_left`, `robot0_agentview_right`, and `robot0_eye_in_hand`, while a LIBERO dataset uses `observation.images.front` and `observation.images.wrist`, and a RoboMME dataset uses bare `image` and `wrist_image` keys. State and action dimensions also differ.
|
||||
|
||||
`MultiLeRobotDataset` lets you train on all of them jointly by:
|
||||
|
||||
- **Mapping** each dataset's feature keys into a shared namespace
|
||||
- **Padding** features that a dataset doesn't have with zeros
|
||||
- **Weighting** how often each dataset is sampled
|
||||
- **Transforming** samples per-dataset (e.g. padding actions to a common dimension)
|
||||
- **Aggregating** statistics across all sub-datasets for normalization
|
||||
|
||||
## Configuration
|
||||
|
||||
Multi-dataset training is configured via `MultiDatasetConfig` in a YAML config file. Instead of a single `dataset.repo_id`, you provide a `datasets` list where each entry is a `SubDatasetConfig`.
|
||||
|
||||
### SubDatasetConfig fields
|
||||
|
||||
| Field | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `repo_id` | `str` | required | HuggingFace repo ID or local dataset name |
|
||||
| `root` | `str \| None` | `None` | Local root directory for the dataset |
|
||||
| `episodes` | `list[int] \| None` | `None` | Subset of episode indices to use |
|
||||
| `revision` | `str \| None` | `None` | Dataset version / revision |
|
||||
| `video_backend` | `str` | auto | Video decoding backend (`pyav`, `torchcodec`, etc.) |
|
||||
| `weight` | `float` | `1.0` | Relative sampling weight for this dataset |
|
||||
| `feature_map` | `dict[str, str]` | `{}` | Maps dataset keys to unified policy keys |
|
||||
| `transforms` | `list` | `None` | Per-dataset transform steps (applied per sample) |
|
||||
|
||||
### Example: Three-dataset config
|
||||
|
||||
```yaml
|
||||
dataset:
|
||||
type: multi
|
||||
use_imagenet_stats: true
|
||||
datasets:
|
||||
# RoboCasa: 3 cameras, state(16), action(12)
|
||||
- repo_id: pepijn223/robocasa_PrepareCoffee
|
||||
root: /data/robocasa_PrepareCoffee
|
||||
weight: 1.0
|
||||
feature_map:
|
||||
observation.images.robot0_agentview_left: observation.images.front_left
|
||||
observation.images.robot0_agentview_right: observation.images.front_right
|
||||
observation.images.robot0_eye_in_hand: observation.images.wrist
|
||||
|
||||
# LIBERO-plus: 2 cameras, state(8), action(7)
|
||||
- repo_id: pepijn223/libero_plus_lerobot
|
||||
root: /data/libero_plus_lerobot
|
||||
weight: 0.5
|
||||
feature_map:
|
||||
observation.images.front: observation.images.front_left
|
||||
observation.images.wrist: observation.images.wrist
|
||||
transforms:
|
||||
- type: pad_action
|
||||
kwargs: {target_dim: 12}
|
||||
- type: pad_state
|
||||
kwargs: {target_dim: 16}
|
||||
|
||||
# RoboMME: 2 cameras (non-standard keys), state(8), action(8)
|
||||
- repo_id: pepijn223/robomme_data_lerobot
|
||||
root: /data/robomme_data_lerobot
|
||||
weight: 0.3
|
||||
feature_map:
|
||||
image: observation.images.front_left
|
||||
wrist_image: observation.images.wrist
|
||||
state: observation.state
|
||||
actions: action
|
||||
transforms:
|
||||
- type: pad_action
|
||||
kwargs: {target_dim: 12}
|
||||
- type: pad_state
|
||||
kwargs: {target_dim: 16}
|
||||
```
|
||||
|
||||
## Feature Mapping
|
||||
|
||||
The `feature_map` dictionary renames dataset-local keys into a shared namespace. Keys not listed pass through unchanged. In the example above, all three datasets end up with the same camera key names (`observation.images.front_left`, `observation.images.wrist`) even though they use different conventions internally.
|
||||
|
||||
After mapping, the **union** of all features across datasets defines the unified schema. If a feature exists in some datasets but not others, it is automatically zero-padded for datasets that lack it, and a boolean `{key}_is_pad` flag is added to the sample so the policy can optionally mask padded features.
|
||||
|
||||
## Automatic Padding
|
||||
|
||||
When a sub-dataset doesn't have a feature that exists in the unified schema:
|
||||
|
||||
- **Images/videos**: padded with a black frame (zeros) matching the expected resolution
|
||||
- **Float tensors** (state, action): padded with zeros
|
||||
- **Integer/bool tensors**: padded with zeros / False
|
||||
|
||||
A companion `{key}_is_pad = True` tensor is added so the model can distinguish real data from padding.
|
||||
|
||||
## Per-Dataset Transforms
|
||||
|
||||
Each sub-dataset can have its own `transforms` pipeline that runs after feature renaming but before cross-dataset padding. This is useful for making shapes compatible before PyTorch's collate function stacks the batch.
|
||||
|
||||
### Built-in transforms
|
||||
|
||||
| Name | Description | Parameters |
|
||||
|------|-------------|------------|
|
||||
| `pad_action` | Zero-pad `action` to a target dimension | `target_dim: int` |
|
||||
| `pad_state` | Zero-pad `observation.state` to a target dimension | `target_dim: int` |
|
||||
| `resize_images` | Resize all `observation.images.*` tensors | `height: int`, `width: int` |
|
||||
|
||||
### Custom transforms
|
||||
|
||||
You can register your own transforms in `lerobot/datasets/transforms.py`:
|
||||
|
||||
```python
|
||||
from lerobot.datasets.transforms import DatasetTransformStep, register_dataset_transform
|
||||
|
||||
@register_dataset_transform("my_transform")
|
||||
class MyTransform(DatasetTransformStep):
|
||||
def __init__(self, some_param: int):
|
||||
self.some_param = some_param
|
||||
|
||||
def __call__(self, sample: dict) -> dict:
|
||||
# Modify sample in-place or return a new dict
|
||||
sample["action"] = sample["action"] * self.some_param
|
||||
return sample
|
||||
```
|
||||
|
||||
Then reference it in the config:
|
||||
|
||||
```yaml
|
||||
transforms:
|
||||
- type: my_transform
|
||||
kwargs: {some_param: 2}
|
||||
```
|
||||
|
||||
## Weighted Sampling
|
||||
|
||||
The `weight` field on each sub-dataset controls how often it is sampled during training. Weights are relative and automatically normalized to probabilities. For example, with weights `[1.0, 0.5, 0.3]`, the first dataset is sampled roughly 56% of the time, the second 28%, and the third 16%.
|
||||
|
||||
This uses `WeightedEpisodeAwareSampler`, which respects episode boundaries (so `drop_n_last_frames` and similar policy settings work correctly) while sampling across datasets proportionally.
|
||||
|
||||
## Stats Aggregation
|
||||
|
||||
Normalization statistics (mean, std, min, max, quantiles) are automatically aggregated across all sub-datasets using the mapped feature keys. The aggregation uses a weighted parallel variance algorithm so that datasets with more frames contribute proportionally to the global statistics.
|
||||
|
||||
The aggregated stats are used by the standard LeRobot preprocessor for normalization during training.
|
||||
|
||||
## Training
|
||||
|
||||
Launch training the same way as single-dataset training. The factory and training script automatically detect `MultiDatasetConfig` and set up the weighted sampler:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.lerobot_train \
|
||||
--config_path path/to/multi_dataset_config.yaml
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
The data flow during training with `MultiLeRobotDataset`:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────┐
|
||||
│ MultiLeRobotDataset.__getitem__(global_idx) │
|
||||
│ │
|
||||
│ 1. Map global_idx → (dataset_idx, local_idx) │
|
||||
│ 2. Fetch sample from sub-dataset │
|
||||
│ 3. Rename keys via feature_map │
|
||||
│ 4. Apply per-dataset transforms (pad_action, etc.) │
|
||||
│ 5. Zero-pad missing features + add _is_pad flags │
|
||||
│ 6. Add dataset_index tag │
|
||||
└─────────────────────┬───────────────────────────────────┘
|
||||
│
|
||||
┌────────────▼────────────┐
|
||||
│ PyTorch DataLoader │
|
||||
│ (collates into batch) │
|
||||
└────────────┬────────────┘
|
||||
│
|
||||
┌────────────▼────────────┐
|
||||
│ LeRobot Preprocessor │
|
||||
│ (normalize, tokenize) │
|
||||
└────────────┬────────────┘
|
||||
│
|
||||
┌────────────▼────────────┐
|
||||
│ Policy forward + loss │
|
||||
└─────────────────────────┘
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### `NewMultiLeRobotDataset`
|
||||
|
||||
```python
|
||||
from lerobot.datasets.multi_dataset import NewMultiLeRobotDataset
|
||||
|
||||
dataset = NewMultiLeRobotDataset(
|
||||
configs=[...], # list[SubDatasetConfig]
|
||||
image_transforms=None, # optional image augmentation
|
||||
delta_timestamps=None, # optional temporal neighbors
|
||||
tolerance_s=1e-4, # timestamp tolerance
|
||||
)
|
||||
|
||||
dataset.num_frames # total frames across all sub-datasets
|
||||
dataset.num_episodes # total episodes
|
||||
dataset.meta # MultiDatasetMeta (stats, features, episodes)
|
||||
dataset.dataset_weights # list of per-dataset weights
|
||||
dataset.features # unified feature dict (union of all mapped features)
|
||||
dataset.camera_keys # unified camera key list
|
||||
```
|
||||
|
||||
### `WeightedEpisodeAwareSampler`
|
||||
|
||||
```python
|
||||
from lerobot.datasets.sampler import WeightedEpisodeAwareSampler
|
||||
|
||||
sampler = WeightedEpisodeAwareSampler(
|
||||
dataset_from_indices=dataset.meta.episodes["dataset_from_index"],
|
||||
dataset_to_indices=dataset.meta.episodes["dataset_to_index"],
|
||||
dataset_membership=dataset.meta.episodes["dataset_source"],
|
||||
dataset_weights=dataset.dataset_weights,
|
||||
shuffle=True,
|
||||
)
|
||||
```
|
||||
|
||||
### `DatasetTransformPipeline`
|
||||
|
||||
```python
|
||||
from lerobot.datasets.transforms import DatasetTransformPipeline, DatasetTransformStepConfig
|
||||
|
||||
pipeline = DatasetTransformPipeline([
|
||||
DatasetTransformStepConfig(type="pad_action", kwargs={"target_dim": 12}),
|
||||
DatasetTransformStepConfig(type="pad_state", kwargs={"target_dim": 16}),
|
||||
])
|
||||
|
||||
sample = pipeline(sample) # modifies the sample dict
|
||||
```
|
||||
@@ -34,11 +34,6 @@ As described by Physical Intelligence, while AI has achieved remarkable success
|
||||
pip install -e ".[pi]"
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
|
||||
>
|
||||
> This will be solved in the next patch release
|
||||
|
||||
## Training Data and Capabilities
|
||||
|
||||
π₀ is trained on the largest robot interaction dataset to date, combining three key data sources:
|
||||
|
||||
@@ -36,11 +36,6 @@ This diverse training mixture creates a "curriculum" that enables generalization
|
||||
pip install -e ".[pi]"
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
|
||||
>
|
||||
> This will be solved in the next patch release
|
||||
|
||||
## Usage
|
||||
|
||||
To use π₀.₅ in your LeRobot configuration, specify the policy type as:
|
||||
|
||||
+10
-15
@@ -43,16 +43,11 @@ This approach can transform **any existing VLM** into a VLA by training it to pr
|
||||
pip install -e ".[pi]"
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install the pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
|
||||
>
|
||||
> This will be solved in the next patch release
|
||||
|
||||
## Training a Custom FAST Tokenizer
|
||||
|
||||
You have two options for the FAST tokenizer:
|
||||
|
||||
1. **Use the pre-trained tokenizer**: The `physical-intelligence/fast` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer.
|
||||
1. **Use the pre-trained tokenizer**: The `lerobot/fast-action-tokenizer` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer.
|
||||
|
||||
2. **Train your own tokenizer**: For maximum performance on your specific dataset, you can finetune the tokenizer on your own data.
|
||||
|
||||
@@ -114,15 +109,15 @@ lerobot-train \
|
||||
|
||||
### Key Training Parameters
|
||||
|
||||
| Parameter | Description | Default |
|
||||
| -------------------------------------- | -------------------------------------------------- | ---------------------------- |
|
||||
| `--policy.gradient_checkpointing=true` | Reduces memory usage significantly during training | `false` |
|
||||
| `--policy.dtype=bfloat16` | Use mixed precision training for efficiency | `float32` |
|
||||
| `--policy.chunk_size` | Number of action steps to predict (action horizon) | `50` |
|
||||
| `--policy.n_action_steps` | Number of action steps to execute | `50` |
|
||||
| `--policy.max_action_tokens` | Maximum number of FAST tokens per action chunk | `256` |
|
||||
| `--policy.action_tokenizer_name` | FAST tokenizer to use | `physical-intelligence/fast` |
|
||||
| `--policy.compile_model=true` | Enable torch.compile for faster training | `false` |
|
||||
| Parameter | Description | Default |
|
||||
| -------------------------------------- | -------------------------------------------------- | ------------------------------- |
|
||||
| `--policy.gradient_checkpointing=true` | Reduces memory usage significantly during training | `false` |
|
||||
| `--policy.dtype=bfloat16` | Use mixed precision training for efficiency | `float32` |
|
||||
| `--policy.chunk_size` | Number of action steps to predict (action horizon) | `50` |
|
||||
| `--policy.n_action_steps` | Number of action steps to execute | `50` |
|
||||
| `--policy.max_action_tokens` | Maximum number of FAST tokens per action chunk | `256` |
|
||||
| `--policy.action_tokenizer_name` | FAST tokenizer to use | `lerobot/fast-action-tokenizer` |
|
||||
| `--policy.compile_model=true` | Enable torch.compile for faster training | `false` |
|
||||
|
||||
## Inference
|
||||
|
||||
|
||||
+140
-186
@@ -1,23 +1,49 @@
|
||||
# Unitree G1
|
||||
|
||||
This guide covers the complete setup process for the Unitree G1 humanoid, from initial connection to running gr00t_wbc locomotion.
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/unitree_thumbnail.jpg"
|
||||
alt="Unitree G1 locomanipulation demo"
|
||||
style={{ width: "100%" }}
|
||||
/>
|
||||
|
||||
## About
|
||||
|
||||
We support both 29 and 23 DOF G1 EDU version. We introduce:
|
||||
|
||||
- **`unitree g1` robot class, handling low level read/write from/to the humanoid**
|
||||
- **ZMQ socket bridge** for remote communication and camera streaming, allowing for remote policy deployment over wlan, eth or directly on the robot
|
||||
- **Locomotion policies** from NVIDIA gr00t and Amazon FAR Holosoma
|
||||
- **Simulation mode** for testing policies without the physical robot in mujoco
|
||||
The Unitree G1 humanoid is now supported in LeRobot! You can teleoperate, train locomanipulation policies, test in sim, and more. Both 29 and 23 DoF variants are supported.
|
||||
|
||||
---
|
||||
|
||||
## Connection guide
|
||||
## Part 1: Getting Started
|
||||
|
||||
### Step 1: Configure Ethernet Interface
|
||||
### Install LeRobot on Your Machine
|
||||
|
||||
Set a static IP on the same subnet as the robot:
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.12
|
||||
conda activate lerobot
|
||||
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
|
||||
cd unitree_sdk2_python && pip install -e .
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
pip install -e '.[unitree_g1]'
|
||||
```
|
||||
|
||||
### Test the Installation (Simulation)
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=true \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.id=wbc_unitree \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
This will launch a [MuJoCo sim instance](https://huggingface.co/lerobot/unitree-g1-mujoco/tree/main) for the G1.
|
||||
|
||||
- Press `9` to release the robot
|
||||
- Press `7` / `8` to increase / decrease waist height
|
||||
|
||||
### Connect to the Robot
|
||||
|
||||
The G1's Ethernet IP is fixed at `192.168.123.164`. Your machine must have a static IP on the same subnet: `192.168.123.x` where `x ≠ 164`.
|
||||
|
||||
```bash
|
||||
# Replace 'enp131s0' with your ethernet interface name (check with `ip a`)
|
||||
@@ -26,272 +52,200 @@ sudo ip addr add 192.168.123.200/24 dev enp131s0
|
||||
sudo ip link set enp131s0 up
|
||||
```
|
||||
|
||||
**Note**: The G1's Ethernet IP is fixed at `192.168.123.164`. Your computer must use `192.168.123.x` with x ≠ 164.
|
||||
|
||||
### Step 2: SSH into the Robot
|
||||
### SSH into the Robot
|
||||
|
||||
```bash
|
||||
ssh unitree@192.168.123.164
|
||||
# Password: 123
|
||||
```
|
||||
|
||||
You should now be connected to the G1's Orin.
|
||||
### Install LeRobot on the G1
|
||||
|
||||
From the robot:
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.12
|
||||
conda activate lerobot
|
||||
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
|
||||
cd unitree_sdk2_python && pip install -e .
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
pip install -e '.[unitree_g1]'
|
||||
```
|
||||
|
||||
> **Note:** The Unitree SDK requires CycloneDDS v0.10.2. See the [Unitree SDK docs](https://github.com/unitreerobotics/unitree_sdk2_python) for details.
|
||||
|
||||
---
|
||||
|
||||
## Part 2: Enable WiFi on the Robot
|
||||
|
||||
Wlan0 is disabled by default on the G1. To enable it:
|
||||
|
||||
### Step 1: Enable WiFi Hardware
|
||||
Wi-Fi connectivity is blocked by default on the G1. To activate:
|
||||
|
||||
```bash
|
||||
sudo rfkill unblock wifi
|
||||
sudo rfkill unblock all
|
||||
|
||||
# Bring up wlan0
|
||||
sudo ip link set wlan0 up
|
||||
|
||||
# Enable NetworkManager control of wlan0
|
||||
sudo nmcli radio wifi on
|
||||
sudo nmcli device set wlan0 managed yes
|
||||
sudo systemctl restart NetworkManager
|
||||
```
|
||||
|
||||
### Step 2: Enable Internet Forwarding
|
||||
|
||||
**On your laptop:**
|
||||
**On your laptop** (share internet via Ethernet):
|
||||
|
||||
```bash
|
||||
# Enable IP forwarding
|
||||
sudo sysctl -w net.ipv4.ip_forward=1
|
||||
|
||||
# Set up NAT (replace wlp132s0f0 with your WiFi interface)
|
||||
# Replace wlp132s0f0 with your WiFi interface name
|
||||
sudo iptables -t nat -A POSTROUTING -o wlp132s0f0 -s 192.168.123.0/24 -j MASQUERADE
|
||||
sudo iptables -A FORWARD -i wlp132s0f0 -o enp131s0 -m state --state RELATED,ESTABLISHED -j ACCEPT
|
||||
sudo iptables -A FORWARD -i enp131s0 -o wlp132s0f0 -j ACCEPT
|
||||
```
|
||||
|
||||
**On the G1:**
|
||||
**On the G1** (set default route through your laptop):
|
||||
|
||||
```bash
|
||||
# Add laptop as default gateway
|
||||
sudo ip route del default 2>/dev/null || true
|
||||
sudo ip route add default via 192.168.123.200 dev eth0
|
||||
echo "nameserver 8.8.8.8" | sudo tee /etc/resolv.conf
|
||||
|
||||
# Test connection
|
||||
# Verify
|
||||
ping -c 3 8.8.8.8
|
||||
```
|
||||
|
||||
### Step 3: Connect to WiFi Network
|
||||
**Connect to a WiFi network:**
|
||||
|
||||
```bash
|
||||
# List available networks
|
||||
nmcli device wifi list
|
||||
|
||||
# Connect to your WiFi (example)
|
||||
sudo nmcli connection add type wifi ifname wlan0 con-name "YourNetwork" ssid "YourNetwork"
|
||||
sudo nmcli connection modify "YourNetwork" wifi-sec.key-mgmt wpa-psk
|
||||
sudo nmcli connection modify "YourNetwork" wifi-sec.psk "YourPassword"
|
||||
sudo nmcli connection modify "YourNetwork" connection.autoconnect yes
|
||||
sudo nmcli connection up "YourNetwork"
|
||||
|
||||
# Check WiFi IP address
|
||||
ip a show wlan0
|
||||
```
|
||||
|
||||
### Step 4: SSH Over WiFi
|
||||
|
||||
Once connected to WiFi, note the robot's IP address and disconnect the Ethernet cable. You can now SSH over WiFi:
|
||||
You can now SSH over WiFi:
|
||||
|
||||
```bash
|
||||
ssh unitree@<YOUR_ROBOT_IP>
|
||||
ssh unitree@<ROBOT_WIFI_IP>
|
||||
# Password: 123
|
||||
```
|
||||
|
||||
Replace `<YOUR_ROBOT_IP>` with your robot's actual WiFi IP address.
|
||||
|
||||
---
|
||||
|
||||
## Part 3: Robot Server Setup
|
||||
## Part 3: Teleoperation & Locomotion
|
||||
|
||||
### Step 1: Install LeRobot on the Orin
|
||||
|
||||
SSH into the robot and install LeRobot:
|
||||
|
||||
```bash
|
||||
ssh unitree@<YOUR_ROBOT_IP>
|
||||
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
pip install -e '.[unitree_g1]'
|
||||
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
|
||||
cd unitree_sdk2_python && pip install -e .
|
||||
```
|
||||
|
||||
**Note**: The Unitree SDK requires CycloneDDS v0.10.2 to be installed. See the [Unitree SDK documentation](https://github.com/unitreerobotics/unitree_sdk2_python) for details.
|
||||
|
||||
### Step 2: Run the Robot Server
|
||||
### Run the Robot Server
|
||||
|
||||
On the robot:
|
||||
|
||||
```bash
|
||||
python src/lerobot/robots/unitree_g1/run_g1_server.py
|
||||
python src/lerobot/robots/unitree_g1/run_g1_server.py --camera
|
||||
```
|
||||
|
||||
**Important**: Keep this terminal running. The server must be active for remote control.
|
||||
### Run the Locomotion Policy
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=false \
|
||||
--robot.robot_ip=<ROBOT_IP> \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.id=wbc_unitree \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "<ROBOT_IP>", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
--display_data=true \
|
||||
--robot.controller=HolosomaLocomotionController
|
||||
```
|
||||
|
||||
We support both [HolosomaLocomotionController](https://github.com/amazon-far/holosoma) and [GrootLocomotionController](https://github.com/NVlabs/GR00T-WholeBodyControl).
|
||||
|
||||
---
|
||||
|
||||
## Part 4: Controlling the robot
|
||||
## Part 4: Loco-Manipulation with the Homunculus Exoskeleton
|
||||
|
||||
With the robot server running, you can now control the robot remotely. Let's launch a locomotion policy
|
||||
We provide a loco-manipulation solution via the Homunculus Exoskeleton — an open-source 7 DoF exoskeleton for whole-body control. Assembly instructions [here](https://github.com/nepyope/hmc_exo).
|
||||
|
||||
### Step 1: Install LeRobot on your machine
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
pip install -e '.[unitree_g1]'
|
||||
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
|
||||
cd unitree_sdk2_python && pip install -e .
|
||||
```
|
||||
|
||||
### Step 2: Update Robot IP in Config
|
||||
|
||||
Edit the config file to match your robot's WiFi IP:
|
||||
|
||||
```python
|
||||
# In src/lerobot/robots/unitree_g1/config_unitree_g1.py
|
||||
robot_ip: str = "<YOUR_ROBOT_IP>" # Replace with your robot's WiFi IP.
|
||||
```
|
||||
|
||||
### Step 3: Run the Locomotion Policy
|
||||
|
||||
```bash
|
||||
# Run GR00T locomotion controller
|
||||
python examples/unitree_g1/gr00t_locomotion.py --repo-id "nepyope/GR00T-WholeBodyControl_g1"
|
||||
|
||||
# Run Holosoma locomotion controller
|
||||
python examples/unitree_g1/holosoma_locomotion.py
|
||||
|
||||
```
|
||||
|
||||
Press `Ctrl+C` to stop the policy.
|
||||
|
||||
---
|
||||
|
||||
## Running in Simulation Mode (MuJoCo)
|
||||
|
||||
You can test policies before deploying on the physical robot using MuJoCo simulation. Set `is_simulation=True` in config or pass `--robot.is_simulation=true` via CLI.
|
||||
|
||||
### Calibrate Exoskeleton Teleoperator
|
||||
### Calibrate
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo
|
||||
```
|
||||
|
||||
### Teleoperate in Simulation
|
||||
During calibration move each joint through its entire range. After fitting, move the joint in a neutral position and press `n` to advance.
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=true \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--fps=100
|
||||
```
|
||||
|
||||
### Record Dataset in Simulation
|
||||
### Record a Dataset
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=true \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--dataset.repo_id=your-username/dataset-name \
|
||||
--dataset.single_task="Test" \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.reset_time_s=5 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=true \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--dataset.repo_id=your-username/dataset-name \
|
||||
--dataset.single_task="Test" \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.reset_time_s=5 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2
|
||||
```
|
||||
|
||||
Example simulation dataset: [nepyope/teleop_test_sim](https://huggingface.co/datasets/nepyope/teleop_test_sim)
|
||||
> **Note:** Omit `--teleop.left_arm_config.port` and `--teleop.right_arm_config.port` if you're only using the joystick.
|
||||
|
||||
Example dataset: [nepyope/unitree_box_move_blue_full](https://huggingface.co/datasets/nepyope/unitree_box_move_blue_full)
|
||||
|
||||
---
|
||||
|
||||
## Running on Real Robot
|
||||
## Part 5: Training & Inference
|
||||
|
||||
Once the robot server is running on the G1 (see Part 3), you can teleoperate and record on the real robot.
|
||||
|
||||
### Start the Camera Server
|
||||
|
||||
On the robot, start the ZMQ image server:
|
||||
### Train
|
||||
|
||||
```bash
|
||||
python src/lerobot/cameras/zmq/image_server.py
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/dataset-name \
|
||||
--policy.type=pi05 \
|
||||
--output_dir=./outputs/pi05_training \
|
||||
--job_name=pi05_training \
|
||||
--policy.repo_id=your-username/your-repo-id \
|
||||
--policy.pretrained_path=lerobot/pi05_base \
|
||||
--policy.compile_model=true \
|
||||
--policy.gradient_checkpointing=true \
|
||||
--wandb.enable=true \
|
||||
--policy.dtype=bfloat16 \
|
||||
--policy.freeze_vision_encoder=false \
|
||||
--policy.train_expert_only=false \
|
||||
--steps=3000 \
|
||||
--policy.device=cuda \
|
||||
--batch_size=32
|
||||
```
|
||||
|
||||
Keep this running in a separate terminal for camera streaming during recording.
|
||||
### Inference with RTC
|
||||
|
||||
### Teleoperate Real Robot
|
||||
Once trained, we recommend deploying policies using inference-time RTC:
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=false \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--fps=100
|
||||
python examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=your-username/your-repo-id \
|
||||
--policy.device=cuda \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=false \
|
||||
--robot.controller=HolosomaLocomotionController \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "<ROBOT_IP>", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
--task="task_description" \
|
||||
--duration=1000 \
|
||||
--fps=30 \
|
||||
--rtc.enabled=true
|
||||
```
|
||||
|
||||
### Record Dataset on Real Robot
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=false \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--dataset.repo_id=your-username/dataset-name \
|
||||
--dataset.single_task="Test" \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.reset_time_s=5 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
```
|
||||
|
||||
**Note**: Update `server_address` to match your robot's camera server IP.
|
||||
|
||||
Example real robot dataset: [nepyope/teleop_test_real](https://huggingface.co/datasets/nepyope/teleop_test_real)
|
||||
|
||||
---
|
||||
|
||||
## Additional Resources
|
||||
@@ -300,8 +254,8 @@ Example real robot dataset: [nepyope/teleop_test_real](https://huggingface.co/da
|
||||
- [GR00T-WholeBodyControl](https://github.com/NVlabs/GR00T-WholeBodyControl)
|
||||
- [Holosoma](https://github.com/amazon-far/holosoma)
|
||||
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
|
||||
- [Unitree_IL_Lerobot](https://github.com/unitreerobotics/unitree_IL_lerobot)
|
||||
- [Unitree IL LeRobot](https://github.com/unitreerobotics/unitree_IL_lerobot)
|
||||
|
||||
---
|
||||
|
||||
_Last updated: December 2025_
|
||||
_Last updated: March 2026_
|
||||
|
||||
@@ -0,0 +1,490 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
SLURM-distributed SARM RA-BC annotation pipeline.
|
||||
|
||||
Computes SARM progress values for all frames in a dataset, distributed across
|
||||
SLURM workers, then merges the shards into a single sarm_progress.parquet.
|
||||
|
||||
Two subcommands, each a separate SLURM submission:
|
||||
|
||||
compute – N workers, each computes progress for a subset of episodes
|
||||
aggregate – 1 worker, merges N shards into sarm_progress.parquet, pushes to hub
|
||||
|
||||
Usage:
|
||||
python slurm_compute_rabc.py compute \\
|
||||
--repo-id user/dataset --reward-model-path user/sarm_model \\
|
||||
--stride 10 --device cpu --workers 50 --partition cpu
|
||||
|
||||
python slurm_compute_rabc.py aggregate \\
|
||||
--repo-id user/dataset --reward-model-path user/sarm_model \\
|
||||
--partition cpu --push-to-hub
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
|
||||
|
||||
class ComputeProgressShards(PipelineStep):
|
||||
"""Each worker computes SARM progress for its assigned episodes."""
|
||||
|
||||
def __init__(
|
||||
self, repo_id, reward_model_path, stride=1, head_mode="sparse", device="cpu", shard_dir="rabc_shards"
|
||||
):
|
||||
super().__init__()
|
||||
if stride < 1:
|
||||
raise ValueError(f"stride must be >= 1, got {stride}")
|
||||
self.repo_id = repo_id
|
||||
self.reward_model_path = reward_model_path
|
||||
self.stride = stride
|
||||
self.head_mode = head_mode
|
||||
self.device = device
|
||||
self.shard_dir = shard_dir
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.policies.sarm.compute_rabc_weights import (
|
||||
generate_all_frame_indices,
|
||||
interpolate_progress,
|
||||
load_sarm_resources,
|
||||
)
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
init_logging()
|
||||
|
||||
dataset, reward_model, preprocess = load_sarm_resources(
|
||||
self.repo_id,
|
||||
self.reward_model_path,
|
||||
self.device,
|
||||
)
|
||||
|
||||
if hasattr(preprocess, "eval"):
|
||||
preprocess.eval()
|
||||
for step in preprocess.steps:
|
||||
if hasattr(step, "eval"):
|
||||
step.eval()
|
||||
|
||||
image_key = reward_model.config.image_key
|
||||
state_key = reward_model.config.state_key
|
||||
frame_gap = reward_model.config.frame_gap
|
||||
center_idx = reward_model.config.n_obs_steps // 2
|
||||
|
||||
dual_mode = reward_model.config.uses_dual_heads
|
||||
compute_sparse = self.head_mode in ("sparse", "both") or not dual_mode
|
||||
compute_dense = self.head_mode in ("dense", "both") and dual_mode
|
||||
|
||||
my_episodes = list(range(dataset.num_episodes))[rank::world_size]
|
||||
if not my_episodes:
|
||||
logging.info(f"Rank {rank}: no episodes assigned")
|
||||
return
|
||||
logging.info(f"Rank {rank}: {len(my_episodes)} / {dataset.num_episodes} episodes")
|
||||
|
||||
all_rows = []
|
||||
|
||||
for ep_idx in tqdm(my_episodes, desc=f"Rank {rank}"):
|
||||
ep = dataset.meta.episodes[ep_idx]
|
||||
ep_start, ep_end = ep["dataset_from_index"], ep["dataset_to_index"]
|
||||
task = dataset[ep_start].get("task", "perform the task")
|
||||
|
||||
all_ep_indices = generate_all_frame_indices(ep_start, ep_end, frame_gap)
|
||||
if self.stride > 1:
|
||||
compute_indices = [i for i in all_ep_indices if (i - ep_start) % self.stride == 0]
|
||||
if (ep_end - 1) not in compute_indices:
|
||||
compute_indices.append(ep_end - 1)
|
||||
compute_indices = sorted(set(compute_indices))
|
||||
else:
|
||||
compute_indices = all_ep_indices
|
||||
|
||||
frame_results = {}
|
||||
for qi in tqdm(compute_indices, desc=f" Ep {ep_idx}", leave=False):
|
||||
try:
|
||||
sample = dataset[qi]
|
||||
batch = {
|
||||
image_key: sample[image_key],
|
||||
"task": task,
|
||||
"index": qi,
|
||||
"episode_index": ep_idx,
|
||||
}
|
||||
if state_key in sample:
|
||||
batch[state_key] = sample[state_key]
|
||||
|
||||
with torch.no_grad():
|
||||
processed = preprocess(batch)
|
||||
vf = processed["video_features"].to(self.device)
|
||||
tf = processed["text_features"].to(self.device)
|
||||
sf = processed.get("state_features")
|
||||
if sf is not None:
|
||||
sf = sf.to(self.device)
|
||||
lengths = processed.get("lengths")
|
||||
|
||||
sparse_val = dense_val = np.nan
|
||||
if compute_sparse:
|
||||
r = reward_model.calculate_rewards(
|
||||
text_embeddings=tf,
|
||||
video_embeddings=vf,
|
||||
state_features=sf,
|
||||
lengths=lengths,
|
||||
return_all_frames=True,
|
||||
head_mode="sparse",
|
||||
)
|
||||
sparse_val = float(r[0, center_idx] if r.ndim == 2 else r[center_idx])
|
||||
if compute_dense:
|
||||
r = reward_model.calculate_rewards(
|
||||
text_embeddings=tf,
|
||||
video_embeddings=vf,
|
||||
state_features=sf,
|
||||
lengths=lengths,
|
||||
return_all_frames=True,
|
||||
head_mode="dense",
|
||||
)
|
||||
dense_val = float(r[0, center_idx] if r.ndim == 2 else r[center_idx])
|
||||
|
||||
frame_results[qi] = (sparse_val, dense_val)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed frame {qi}: {e}")
|
||||
|
||||
if not frame_results:
|
||||
logging.warning(f"Episode {ep_idx}: all frames failed, skipping")
|
||||
continue
|
||||
|
||||
# Interpolate to all frames in this episode
|
||||
computed_idx = np.array(sorted(frame_results.keys()))
|
||||
all_frame_arr = np.arange(ep_start, ep_end)
|
||||
|
||||
sparse_vals = np.array([frame_results[i][0] for i in computed_idx]) if compute_sparse else None
|
||||
dense_vals = np.array([frame_results[i][1] for i in computed_idx]) if compute_dense else None
|
||||
|
||||
if self.stride > 1 and len(computed_idx) > 1:
|
||||
if compute_sparse:
|
||||
sparse_vals = interpolate_progress(computed_idx, sparse_vals, all_frame_arr)
|
||||
if compute_dense:
|
||||
dense_vals = interpolate_progress(computed_idx, dense_vals, all_frame_arr)
|
||||
output_frames = all_frame_arr
|
||||
else:
|
||||
# Use only successfully computed frames to avoid indexing mismatch on failures
|
||||
output_frames = computed_idx
|
||||
|
||||
for i, fi in enumerate(output_frames):
|
||||
row = {"index": int(fi), "episode_index": ep_idx, "frame_index": int(fi - ep_start)}
|
||||
if compute_sparse:
|
||||
row["progress_sparse"] = float(sparse_vals[i])
|
||||
if compute_dense:
|
||||
row["progress_dense"] = float(dense_vals[i])
|
||||
all_rows.append(row)
|
||||
|
||||
if all_rows:
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame(all_rows).sort_values("index").reset_index(drop=True)
|
||||
table = pa.Table.from_pandas(df, preserve_index=False)
|
||||
table = table.replace_schema_metadata({b"reward_model_path": self.reward_model_path.encode()})
|
||||
shard_dir = Path(self.shard_dir)
|
||||
shard_dir.mkdir(parents=True, exist_ok=True)
|
||||
out = shard_dir / f"shard_{rank:05d}.parquet"
|
||||
pq.write_table(table, out)
|
||||
logging.info(f"Rank {rank}: saved {len(df)} rows to {out}")
|
||||
|
||||
|
||||
class AggregateProgress(PipelineStep):
|
||||
"""Merge all shard parquets into final sarm_progress.parquet."""
|
||||
|
||||
def __init__(self, repo_id, reward_model_path, shard_dir="rabc_shards", push_to_hub=False):
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self.reward_model_path = reward_model_path
|
||||
self.shard_dir = shard_dir
|
||||
self.push_to_hub = push_to_hub
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
init_logging()
|
||||
if rank != 0:
|
||||
return
|
||||
|
||||
shard_dir = Path(self.shard_dir)
|
||||
shards = sorted(shard_dir.glob("shard_*.parquet"))
|
||||
if not shards:
|
||||
raise FileNotFoundError(f"No shards found in {shard_dir}")
|
||||
|
||||
# Log shard modification time range to help detect stale files
|
||||
mtimes = [os.path.getmtime(s) for s in shards]
|
||||
oldest = datetime.datetime.fromtimestamp(min(mtimes)).isoformat(timespec="seconds")
|
||||
newest = datetime.datetime.fromtimestamp(max(mtimes)).isoformat(timespec="seconds")
|
||||
logging.info(f"Aggregating {len(shards)} shards (oldest: {oldest}, newest: {newest})")
|
||||
|
||||
df = pd.concat([pd.read_parquet(s) for s in shards], ignore_index=True)
|
||||
df = df.sort_values("index").reset_index(drop=True)
|
||||
|
||||
table = pa.Table.from_pandas(df, preserve_index=False)
|
||||
table = table.replace_schema_metadata({b"reward_model_path": self.reward_model_path.encode()})
|
||||
|
||||
temp_ds = LeRobotDataset(self.repo_id, download_videos=False)
|
||||
out_path = Path(temp_ds.root) / "sarm_progress.parquet"
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
pq.write_table(table, out_path)
|
||||
logging.info(f"Saved {len(df)} rows to {out_path}")
|
||||
|
||||
for col in ["progress_sparse", "progress_dense"]:
|
||||
if col in df.columns:
|
||||
v = df[col].dropna()
|
||||
logging.info(
|
||||
f"{col}: mean={v.mean():.4f} std={v.std():.4f} min={v.min():.4f} max={v.max():.4f}"
|
||||
)
|
||||
|
||||
if self.push_to_hub:
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
api = HfApi()
|
||||
hub_path = "sarm_progress.parquet"
|
||||
logging.info(f"Uploading to {self.repo_id}/{hub_path}")
|
||||
api.upload_file(
|
||||
path_or_fileobj=str(out_path),
|
||||
path_in_repo=hub_path,
|
||||
repo_id=self.repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
logging.info(f"Uploaded: https://huggingface.co/datasets/{self.repo_id}/blob/main/{hub_path}")
|
||||
|
||||
|
||||
def make_compute_executor(
|
||||
repo_id,
|
||||
reward_model_path,
|
||||
stride,
|
||||
head_mode,
|
||||
device,
|
||||
shard_dir,
|
||||
logs_dir,
|
||||
job_name,
|
||||
slurm,
|
||||
workers,
|
||||
partition,
|
||||
cpus_per_task,
|
||||
mem_per_cpu,
|
||||
):
|
||||
kwargs = {
|
||||
"pipeline": [
|
||||
ComputeProgressShards(repo_id, reward_model_path, stride, head_mode, device, str(shard_dir)),
|
||||
],
|
||||
"logging_dir": str(logs_dir / job_name),
|
||||
}
|
||||
|
||||
if slurm:
|
||||
kwargs.update(
|
||||
{
|
||||
"job_name": job_name,
|
||||
"tasks": workers,
|
||||
"workers": workers,
|
||||
"time": "24:00:00",
|
||||
"partition": partition,
|
||||
"cpus_per_task": cpus_per_task,
|
||||
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
|
||||
}
|
||||
)
|
||||
return SlurmPipelineExecutor(**kwargs)
|
||||
|
||||
kwargs.update({"tasks": workers, "workers": 1})
|
||||
return LocalPipelineExecutor(**kwargs)
|
||||
|
||||
|
||||
def make_aggregate_executor(
|
||||
repo_id,
|
||||
reward_model_path,
|
||||
shard_dir,
|
||||
logs_dir,
|
||||
job_name,
|
||||
slurm,
|
||||
partition,
|
||||
cpus_per_task,
|
||||
mem_per_cpu,
|
||||
push_to_hub,
|
||||
):
|
||||
kwargs = {
|
||||
"pipeline": [
|
||||
AggregateProgress(repo_id, reward_model_path, str(shard_dir), push_to_hub),
|
||||
],
|
||||
"logging_dir": str(logs_dir / job_name),
|
||||
}
|
||||
|
||||
if slurm:
|
||||
kwargs.update(
|
||||
{
|
||||
"job_name": job_name,
|
||||
"tasks": 1,
|
||||
"workers": 1,
|
||||
"time": "02:00:00",
|
||||
"partition": partition,
|
||||
"cpus_per_task": cpus_per_task,
|
||||
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
|
||||
}
|
||||
)
|
||||
return SlurmPipelineExecutor(**kwargs)
|
||||
|
||||
kwargs.update({"tasks": 1, "workers": 1})
|
||||
return LocalPipelineExecutor(**kwargs)
|
||||
|
||||
|
||||
def _add_shared_args(p):
|
||||
p.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Hugging Face repository identifier, e.g. 'user/dataset'.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--shard-dir",
|
||||
type=Path,
|
||||
default=Path("rabc_shards"),
|
||||
help="Directory to read/write per-rank parquet shards.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--logs-dir",
|
||||
type=Path,
|
||||
default=Path("logs"),
|
||||
help="Directory for datatrove logs.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--job-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="SLURM job name (defaults to rabc_<subcommand>).",
|
||||
)
|
||||
p.add_argument(
|
||||
"--slurm",
|
||||
type=int,
|
||||
default=1,
|
||||
help="1 = submit via SLURM; 0 = run locally (useful for debugging).",
|
||||
)
|
||||
p.add_argument(
|
||||
"--partition",
|
||||
type=str,
|
||||
default=None,
|
||||
help="SLURM partition to submit to.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--cpus-per-task",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of CPUs per SLURM task.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--mem-per-cpu",
|
||||
type=str,
|
||||
default="4G",
|
||||
help="Memory per CPU, e.g. '4G' or '1950M'.",
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="SLURM-distributed SARM RA-BC annotation pipeline",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
sub = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
# compute subcommand
|
||||
cp = sub.add_parser(
|
||||
"compute",
|
||||
help="Distribute progress computation across SLURM workers.",
|
||||
)
|
||||
_add_shared_args(cp)
|
||||
cp.add_argument(
|
||||
"--reward-model-path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path or HF repo id of the SARM reward model.",
|
||||
)
|
||||
cp.add_argument(
|
||||
"--stride",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Compute every Nth frame; intermediate frames are interpolated (must be >= 1).",
|
||||
)
|
||||
cp.add_argument(
|
||||
"--head-mode",
|
||||
type=str,
|
||||
default="sparse",
|
||||
choices=["sparse", "dense", "both"],
|
||||
help="Which reward head(s) to compute.",
|
||||
)
|
||||
cp.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cpu",
|
||||
help="Device for reward model inference, e.g. 'cpu' or 'cuda'.",
|
||||
)
|
||||
cp.add_argument(
|
||||
"--workers",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Number of parallel SLURM tasks (one shard per worker).",
|
||||
)
|
||||
|
||||
# aggregate subcommand
|
||||
ap = sub.add_parser(
|
||||
"aggregate",
|
||||
help="Merge per-rank shards into a single sarm_progress.parquet.",
|
||||
)
|
||||
_add_shared_args(ap)
|
||||
ap.add_argument(
|
||||
"--reward-model-path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path or HF repo id of the SARM reward model (stored in parquet metadata).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--push-to-hub",
|
||||
action="store_true",
|
||||
help="Upload sarm_progress.parquet to the Hugging Face Hub after aggregation.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
job_name = args.job_name or f"rabc_{args.command}"
|
||||
kwargs = vars(args)
|
||||
kwargs["slurm"] = kwargs.pop("slurm") == 1
|
||||
kwargs["job_name"] = job_name
|
||||
command = kwargs.pop("command")
|
||||
|
||||
executor = make_compute_executor(**kwargs) if command == "compute" else make_aggregate_executor(**kwargs)
|
||||
|
||||
executor.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -18,6 +18,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
@@ -70,6 +71,9 @@ def main():
|
||||
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
|
||||
robot.connect()
|
||||
|
||||
# TODO(Steven): Update this example to use pipelines
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="lekiwi_evaluate")
|
||||
@@ -95,6 +99,9 @@ def main():
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
@@ -109,6 +116,9 @@ def main():
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
|
||||
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
@@ -45,6 +46,9 @@ def main():
|
||||
leader_arm = SO100Leader(leader_arm_config)
|
||||
keyboard = KeyboardTeleop(keyboard_config)
|
||||
|
||||
# TODO(Steven): Update this example to use pipelines
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, ACTION)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
|
||||
@@ -89,6 +93,9 @@ def main():
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
@@ -104,6 +111,9 @@ def main():
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
|
||||
@@ -17,16 +17,30 @@
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.processor import (
|
||||
RobotAction,
|
||||
RobotObservation,
|
||||
RobotProcessorPipeline,
|
||||
make_default_teleop_action_processor,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.pipelines import (
|
||||
make_so10x_fk_observation_pipeline,
|
||||
make_so10x_ik_action_pipeline,
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.pipeline_utils import build_dataset_features
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
@@ -37,10 +51,6 @@ TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
|
||||
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
URDF_PATH = "./SO101/so101_new_calib.urdf"
|
||||
|
||||
|
||||
def main():
|
||||
# Create the robot configuration & robot
|
||||
@@ -54,31 +64,68 @@ def main():
|
||||
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# Attach FK/IK pipelines so the robot works in EE space
|
||||
motor_names = list(robot.bus.motors.keys())
|
||||
robot.set_output_pipeline(make_so10x_fk_observation_pipeline(URDF_PATH, motor_names))
|
||||
robot.set_input_pipeline(make_so10x_ik_action_pipeline(URDF_PATH, motor_names))
|
||||
# Create policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# Create the dataset — obs auto-derived from FK pipeline, EE action spec explicit
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joints observation to EE observation
|
||||
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())
|
||||
)
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=build_dataset_features(
|
||||
robot,
|
||||
use_videos=True,
|
||||
action_features={
|
||||
f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
|
||||
},
|
||||
features=combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose_processor,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
# User for now should be explicit on the feature keys that were used for record
|
||||
# Alternatively, the user can pass the processor step that has the right features
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=make_default_teleop_action_processor(),
|
||||
initial_features=create_initial_features(
|
||||
action={
|
||||
f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
|
||||
}
|
||||
),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Create policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# Build Policy Processors
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
@@ -104,18 +151,21 @@ def main():
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop — pipelines applied internally by robot
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
@@ -130,6 +180,9 @@ def main():
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
|
||||
@@ -16,17 +16,21 @@
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
robot_action_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.pipelines import make_so10x_fk_observation_pipeline
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
EEBoundsAndSafety,
|
||||
EEReferenceAndDelta,
|
||||
ForwardKinematicsJointsToEE,
|
||||
GripperVelocityToJoint,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
@@ -35,7 +39,6 @@ from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
|
||||
from lerobot.teleoperators.phone.teleop_phone import Phone
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.pipeline_utils import build_dataset_features
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
@@ -46,10 +49,6 @@ RESET_TIME_SEC = 30
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
|
||||
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
URDF_PATH = "./SO101/so101_new_calib.urdf"
|
||||
|
||||
|
||||
def main():
|
||||
# Create the robot and teleoperator configurations
|
||||
@@ -66,59 +65,77 @@ def main():
|
||||
robot = SO100Follower(robot_config)
|
||||
phone = Phone(teleop_config)
|
||||
|
||||
motor_names = list(robot.bus.motors.keys())
|
||||
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path=URDF_PATH,
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=motor_names,
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Phone output pipeline: map raw phone gesture to EE delta (no robot obs needed)
|
||||
phone.set_output_pipeline(
|
||||
RobotProcessorPipeline[RobotAction, RobotAction](
|
||||
steps=[MapPhoneActionToRobotAction(platform=teleop_config.phone_os)],
|
||||
to_transition=robot_action_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
# Build pipeline to convert phone action to EE action
|
||||
phone_to_robot_ee_pose_processor = RobotProcessorPipeline[
|
||||
tuple[RobotAction, RobotObservation], RobotAction
|
||||
](
|
||||
steps=[
|
||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
use_latched_reference=True,
|
||||
),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.20,
|
||||
),
|
||||
GripperVelocityToJoint(speed_factor=20.0),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Robot FK observation pipeline: joints → EE pose
|
||||
robot.set_output_pipeline(make_so10x_fk_observation_pipeline(URDF_PATH, motor_names))
|
||||
|
||||
# Robot input pipeline: EE delta + current robot obs → joint commands
|
||||
robot.set_input_pipeline(
|
||||
RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
|
||||
motor_names=motor_names,
|
||||
use_latched_reference=True,
|
||||
),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.20,
|
||||
),
|
||||
GripperVelocityToJoint(speed_factor=20.0),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=motor_names,
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Dataset features auto-derived from robot's FK obs pipeline and phone's mapped action pipeline
|
||||
# Build pipeline to convert joint observation to EE observation
|
||||
robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())
|
||||
)
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=build_dataset_features(robot, phone, use_videos=True),
|
||||
features=combine_feature_dicts(
|
||||
# Run the feature contract of the pipelines
|
||||
# This tells you how the features would look like after the pipeline steps
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=phone_to_robot_ee_pose_processor,
|
||||
initial_features=create_initial_features(action=phone.action_features),
|
||||
use_videos=True,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
@@ -141,7 +158,7 @@ def main():
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop — pipelines applied internally by robot and phone
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
@@ -151,6 +168,9 @@ def main():
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
@@ -166,6 +186,9 @@ def main():
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
|
||||
@@ -78,6 +78,7 @@ from torch import Tensor
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
@@ -87,8 +88,8 @@ from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||
from lerobot.processor.factory import (
|
||||
_make_identity_observation_pipeline as make_default_robot_observation_processor,
|
||||
_make_identity_robot_action_pipeline as make_default_robot_action_processor,
|
||||
make_default_robot_action_processor,
|
||||
make_default_robot_observation_processor,
|
||||
)
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
@@ -97,6 +98,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
bi_so_follower,
|
||||
koch_follower,
|
||||
so_follower,
|
||||
unitree_g1,
|
||||
)
|
||||
from lerobot.robots.utils import make_robot_from_config
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
@@ -17,16 +17,30 @@
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.processor import (
|
||||
RobotAction,
|
||||
RobotObservation,
|
||||
RobotProcessorPipeline,
|
||||
make_default_teleop_action_processor,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.pipelines import (
|
||||
make_so10x_fk_observation_pipeline,
|
||||
make_so10x_ik_action_pipeline,
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.pipeline_utils import build_dataset_features
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
@@ -37,10 +51,6 @@ TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
|
||||
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
URDF_PATH = "./SO101/so101_new_calib.urdf"
|
||||
|
||||
|
||||
def main():
|
||||
# Create the robot configuration & robot
|
||||
@@ -54,31 +64,68 @@ def main():
|
||||
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# Attach FK/IK pipelines so the robot works in EE space
|
||||
motor_names = list(robot.bus.motors.keys())
|
||||
robot.set_output_pipeline(make_so10x_fk_observation_pipeline(URDF_PATH, motor_names))
|
||||
robot.set_input_pipeline(make_so10x_ik_action_pipeline(URDF_PATH, motor_names))
|
||||
# Create policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# Create the dataset — obs auto-derived from FK pipeline, EE action spec explicit
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joints observation to EE observation
|
||||
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())
|
||||
)
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=build_dataset_features(
|
||||
robot,
|
||||
use_videos=True,
|
||||
action_features={
|
||||
f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
|
||||
},
|
||||
features=combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose_processor,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
# User for now should be explicit on the feature keys that were used for record
|
||||
# Alternatively, the user can pass the processor step that has the right features
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=make_default_teleop_action_processor(),
|
||||
initial_features=create_initial_features(
|
||||
action={
|
||||
f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
|
||||
}
|
||||
),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Create policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# Build Policy Processors
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
@@ -88,7 +135,7 @@ def main():
|
||||
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
|
||||
)
|
||||
|
||||
# Connect the robot
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
@@ -104,18 +151,21 @@ def main():
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop — pipelines applied internally by robot
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
@@ -130,6 +180,9 @@ def main():
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
|
||||
@@ -17,20 +17,25 @@
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.pipelines import (
|
||||
make_so10x_fk_observation_pipeline,
|
||||
make_so10x_ik_action_pipeline,
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
EEBoundsAndSafety,
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.teleoperators.so_leader.pipelines import make_so10x_leader_fk_pipeline
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.pipeline_utils import (
|
||||
build_dataset_features,
|
||||
check_action_space_compatibility,
|
||||
check_observation_space_compatibility,
|
||||
)
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
@@ -41,10 +46,6 @@ RESET_TIME_SEC = 30
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# NOTE: Use the URDF from the SO-ARM100 repo:
|
||||
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
URDF_PATH = "./SO101/so101_new_calib.urdf"
|
||||
|
||||
|
||||
def main():
|
||||
# Create the robot and teleoperator configurations
|
||||
@@ -61,17 +62,77 @@ def main():
|
||||
follower = SO100Follower(follower_config)
|
||||
leader = SO100Leader(leader_config)
|
||||
|
||||
# Attach EE-space pipelines to the objects
|
||||
motor_names = list(follower.bus.motors.keys())
|
||||
follower.set_output_pipeline(make_so10x_fk_observation_pipeline(URDF_PATH, motor_names))
|
||||
follower.set_input_pipeline(make_so10x_ik_action_pipeline(URDF_PATH, motor_names))
|
||||
leader.set_output_pipeline(make_so10x_leader_fk_pipeline(URDF_PATH, list(leader.bus.motors.keys())))
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
follower_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(follower.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Dataset features are derived automatically from robot/teleop pipelines
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
leader_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(leader.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert follower joints to EE observation
|
||||
follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=follower_kinematics_solver, motor_names=list(follower.bus.motors.keys())
|
||||
),
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Build pipeline to convert leader joints to EE action
|
||||
leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=leader_kinematics_solver, motor_names=list(leader.bus.motors.keys())
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to follower joints
|
||||
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
[
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=follower_kinematics_solver,
|
||||
motor_names=list(follower.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=build_dataset_features(follower, leader, use_videos=True),
|
||||
features=combine_feature_dicts(
|
||||
# Run the feature contract of the pipelines
|
||||
# This tells you how the features would look like after the pipeline steps
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=leader_joints_to_ee,
|
||||
initial_features=create_initial_features(action=leader.action_features),
|
||||
use_videos=True,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=follower_joints_to_ee,
|
||||
initial_features=create_initial_features(observation=follower.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=follower.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
@@ -81,13 +142,9 @@ def main():
|
||||
leader.connect()
|
||||
follower.connect()
|
||||
|
||||
# Verify action/observation space alignment (warns on mismatch)
|
||||
check_action_space_compatibility(leader, follower)
|
||||
check_observation_space_compatibility(follower, leader)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording_ee")
|
||||
init_rerun(session_name="recording_phone")
|
||||
|
||||
try:
|
||||
if not leader.is_connected or not follower.is_connected:
|
||||
@@ -98,8 +155,7 @@ def main():
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Pipelines applied automatically inside robot.get_observation(),
|
||||
# teleop.get_action(), and robot.send_action()
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=follower,
|
||||
events=events,
|
||||
@@ -109,6 +165,9 @@ def main():
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
@@ -124,6 +183,9 @@ def main():
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
|
||||
@@ -14,23 +14,27 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.pipelines import (
|
||||
make_so10x_fk_observation_pipeline,
|
||||
make_so10x_ik_action_pipeline,
|
||||
import time
|
||||
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
robot_action_observation_to_transition,
|
||||
robot_action_to_transition,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
EEBoundsAndSafety,
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.scripts.lerobot_teleoperate import teleop_loop
|
||||
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.teleoperators.so_leader.pipelines import make_so10x_leader_fk_pipeline
|
||||
from lerobot.utils.pipeline_utils import check_action_space_compatibility
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
FPS = 30
|
||||
|
||||
# NOTE: Use the URDF from the SO-ARM100 repo:
|
||||
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
URDF_PATH = "./SO101/so101_new_calib.urdf"
|
||||
|
||||
|
||||
def main():
|
||||
# Initialize the robot and teleoperator config
|
||||
@@ -43,14 +47,47 @@ def main():
|
||||
follower = SO100Follower(follower_config)
|
||||
leader = SO100Leader(leader_config)
|
||||
|
||||
# Attach EE-space pipelines to the objects
|
||||
motor_names = list(follower.bus.motors.keys())
|
||||
follower.set_output_pipeline(make_so10x_fk_observation_pipeline(URDF_PATH, motor_names))
|
||||
follower.set_input_pipeline(make_so10x_ik_action_pipeline(URDF_PATH, motor_names))
|
||||
leader.set_output_pipeline(make_so10x_leader_fk_pipeline(URDF_PATH, list(leader.bus.motors.keys())))
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
follower_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(follower.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Verify action space alignment (warns if leader EE ≠ follower action_features)
|
||||
check_action_space_compatibility(leader, follower)
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
leader_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(leader.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert teleop joints to EE action
|
||||
leader_to_ee = RobotProcessorPipeline[RobotAction, RobotAction](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=leader_kinematics_solver, motor_names=list(leader.bus.motors.keys())
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# build pipeline to convert EE action to robot joints
|
||||
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
[
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=follower_kinematics_solver,
|
||||
motor_names=list(follower.bus.motors.keys()),
|
||||
initial_guess_current_joints=False,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Connect to the robot and teleoperator
|
||||
follower.connect()
|
||||
@@ -60,12 +97,28 @@ def main():
|
||||
init_rerun(session_name="so100_so100_EE_teleop")
|
||||
|
||||
print("Starting teleop loop...")
|
||||
try:
|
||||
# Pipelines applied automatically inside teleop.get_action() and robot.send_action()
|
||||
teleop_loop(teleop=leader, robot=follower, fps=FPS, display_data=True)
|
||||
finally:
|
||||
follower.disconnect()
|
||||
leader.disconnect()
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get robot observation
|
||||
robot_obs = follower.get_observation()
|
||||
|
||||
# Get teleop observation
|
||||
leader_joints_obs = leader.get_action()
|
||||
|
||||
# teleop joints -> teleop EE action
|
||||
leader_ee_act = leader_to_ee(leader_joints_obs)
|
||||
|
||||
# teleop EE -> robot joints
|
||||
follower_joints_act = ee_to_follower_joints((leader_ee_act, robot_obs))
|
||||
|
||||
# Send action to robot
|
||||
_ = follower.send_action(follower_joints_act)
|
||||
|
||||
# Visualize
|
||||
log_rerun_data(observation=leader_ee_act, action=follower_joints_act)
|
||||
|
||||
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
+59
-117
@@ -25,11 +25,11 @@ discord = "https://discord.gg/s3KuuzsPFb"
|
||||
|
||||
[project]
|
||||
name = "lerobot"
|
||||
version = "0.4.5"
|
||||
version = "0.5.1"
|
||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||
dynamic = ["readme"]
|
||||
license = { text = "Apache-2.0" }
|
||||
requires-python = ">=3.10"
|
||||
requires-python = ">=3.12"
|
||||
authors = [
|
||||
{ name = "Rémi Cadène", email = "re.cadene@gmail.com" },
|
||||
{ name = "Simon Alibert", email = "alibert.sim@gmail.com" },
|
||||
@@ -50,7 +50,8 @@ classifiers = [
|
||||
"Intended Audience :: Education",
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Topic :: Software Development :: Build Tools",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
]
|
||||
@@ -61,26 +62,28 @@ dependencies = [
|
||||
# Hugging Face dependencies
|
||||
"datasets>=4.0.0,<5.0.0",
|
||||
"diffusers>=0.27.2,<0.36.0",
|
||||
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
|
||||
"huggingface-hub>=1.0.0,<2.0.0",
|
||||
"accelerate>=1.10.0,<2.0.0",
|
||||
|
||||
# Core dependencies
|
||||
"numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless.
|
||||
"setuptools>=71.0.0,<81.0.0",
|
||||
"cmake>=3.29.0.1,<4.2.0",
|
||||
"packaging>=24.2,<26.0",
|
||||
|
||||
"torch>=2.2.1,<2.11.0",
|
||||
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
|
||||
"torchvision>=0.21.0,<0.26.0",
|
||||
|
||||
"einops>=0.8.0,<0.9.0",
|
||||
"opencv-python-headless>=4.9.0,<4.13.0",
|
||||
"av>=15.0.0,<16.0.0",
|
||||
"jsonlines>=4.0.0,<5.0.0",
|
||||
"packaging>=24.2,<26.0",
|
||||
"pynput>=1.7.7,<1.9.0",
|
||||
"pynput>=1.7.8,<1.9.0",
|
||||
"pyserial>=3.5,<4.0",
|
||||
|
||||
"wandb>=0.24.0,<0.25.0",
|
||||
|
||||
"torch>=2.2.1,<2.11.0", # TODO: Bump dependency
|
||||
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency
|
||||
"torchvision>=0.21.0,<0.26.0", # TODO: Bump dependency
|
||||
|
||||
"draccus==0.10.0", # TODO: Remove ==
|
||||
"draccus==0.10.0", # TODO: Relax version constraint
|
||||
"gymnasium>=1.1.1,<2.0.0",
|
||||
"rerun-sdk>=0.24.0,<0.27.0",
|
||||
|
||||
@@ -95,10 +98,14 @@ dependencies = [
|
||||
|
||||
# Common
|
||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.10.0"]
|
||||
transformers-dep = ["transformers>=4.57.1,<5.0.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.9.17"]
|
||||
transformers-dep = ["transformers>=5.3.0,<6.0.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
|
||||
qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"]
|
||||
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"] # NOTE: Explicitly listing contourpy helps the resolver converge faster.
|
||||
|
||||
# Motors
|
||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
|
||||
@@ -112,34 +119,36 @@ gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
|
||||
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
|
||||
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
|
||||
unitree_g1 = [
|
||||
"unitree-sdk2==1.0.1",
|
||||
"pyzmq>=26.2.1,<28.0.0",
|
||||
"onnxruntime>=1.16.0,<2.0.0",
|
||||
"pin>=3.0.0,<4.0.0",
|
||||
"meshcat>=0.3.0,<0.4.0",
|
||||
"matplotlib>=3.9.0,<4.0.0",
|
||||
"lerobot[matplotlib-dep]",
|
||||
"lerobot[pygame-dep]",
|
||||
"casadi>=3.6.0,<4.0.0",
|
||||
]
|
||||
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
||||
kinematics = ["lerobot[placo-dep]"]
|
||||
intelrealsense = [
|
||||
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
|
||||
"pyrealsense2-macosx>=2.54,<2.55.0 ; sys_platform == 'darwin'",
|
||||
"pyrealsense2-macosx>=2.54,<2.57.0 ; sys_platform == 'darwin'",
|
||||
]
|
||||
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
|
||||
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0", "lerobot[scipy-dep]"]
|
||||
|
||||
# Policies
|
||||
wallx = [
|
||||
"transformers==4.49.0",
|
||||
"peft==0.17.1",
|
||||
"scipy==1.15.3",
|
||||
"torchdiffeq==0.2.5",
|
||||
"qwen_vl_utils==0.0.11"
|
||||
"lerobot[transformers-dep]",
|
||||
"lerobot[peft]",
|
||||
"lerobot[scipy-dep]",
|
||||
"torchdiffeq>=0.2.4,<0.3.0",
|
||||
"lerobot[qwen-vl-utils-dep]",
|
||||
]
|
||||
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi", "scipy>=1.10.1,<1.15"]
|
||||
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
|
||||
groot = [
|
||||
"lerobot[transformers-dep]",
|
||||
"peft>=0.13.0,<1.0.0",
|
||||
"lerobot[peft]",
|
||||
"dm-tree>=0.1.8,<1.0.0",
|
||||
"timm>=1.0.0,<1.1.0",
|
||||
"safetensors>=0.4.3,<1.0.0",
|
||||
@@ -148,13 +157,13 @@ groot = [
|
||||
"ninja>=1.11.1,<2.0.0",
|
||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||
]
|
||||
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.14,<0.1.0"]
|
||||
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
|
||||
peft = ["lerobot[transformers-dep]", "peft>=0.18.0,<1.0.0"]
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
|
||||
@@ -162,13 +171,27 @@ test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0
|
||||
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||
|
||||
# Simulation
|
||||
aloha = ["gym-aloha>=0.1.2,<0.2.0"]
|
||||
# NOTE: Explicitly listing scipy helps flatten the dependecy tree.
|
||||
aloha = ["gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
|
||||
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0"]
|
||||
metaworld = ["metaworld==3.0.0"]
|
||||
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
libero_plus = [
|
||||
"lerobot[transformers-dep]",
|
||||
"libero @ git+https://github.com/sylvestf/LIBERO-plus.git@main ; sys_platform == 'linux'",
|
||||
"lerobot[scipy-dep]",
|
||||
]
|
||||
robomme = [
|
||||
"robomme @ git+https://github.com/RoboMME/robomme_benchmark.git@main ; sys_platform == 'linux'",
|
||||
]
|
||||
metaworld = ["metaworld==3.0.0", "lerobot[scipy-dep]"]
|
||||
|
||||
# All
|
||||
all = [
|
||||
# NOTE(resolver hint): scipy is pulled in transitively via lerobot[scipy-dep] through
|
||||
# multiple extras (aloha, metaworld, pi, wallx, phone). Listing it explicitly
|
||||
# helps pip's resolver converge by constraining scipy early, before it encounters
|
||||
# the loose scipy requirements from transitive deps like dm-control and metaworld.
|
||||
"scipy>=1.14.0,<2.0.0",
|
||||
"lerobot[dynamixel]",
|
||||
"lerobot[gamepad]",
|
||||
"lerobot[hopejr]",
|
||||
@@ -176,8 +199,8 @@ all = [
|
||||
"lerobot[reachy2]",
|
||||
"lerobot[kinematics]",
|
||||
"lerobot[intelrealsense]",
|
||||
# "lerobot[wallx]",
|
||||
# "lerobot[pi]", TODO(Pepijn): Update pi to transformers v5
|
||||
"lerobot[wallx]",
|
||||
"lerobot[pi]",
|
||||
"lerobot[smolvla]",
|
||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||
"lerobot[xvla]",
|
||||
@@ -189,10 +212,11 @@ all = [
|
||||
"lerobot[aloha]",
|
||||
"lerobot[pusht]",
|
||||
"lerobot[phone]",
|
||||
"lerobot[libero]",
|
||||
"lerobot[libero]; sys_platform == 'linux'",
|
||||
"lerobot[metaworld]",
|
||||
"lerobot[sarm]",
|
||||
"lerobot[peft]",
|
||||
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
@@ -221,7 +245,7 @@ lerobot = ["envs/*.json"]
|
||||
where = ["src"]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py310"
|
||||
target-version = "py312"
|
||||
line-length = 110
|
||||
exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"]
|
||||
|
||||
@@ -313,7 +337,7 @@ default.extend-ignore-identifiers-re = [
|
||||
# Uncomment [tool.mypy] first, then uncomment individual module overrides as they get proper type annotations
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.10"
|
||||
python_version = "3.12"
|
||||
ignore_missing_imports = true
|
||||
follow_imports = "skip"
|
||||
# warn_return_any = true
|
||||
@@ -397,85 +421,3 @@ ignore_errors = false
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.scripts.*"
|
||||
# ignore_errors = false
|
||||
|
||||
[tool.uv]
|
||||
# wallx requires transformers==4.49.0 which conflicts with other extras that need >=4.53.0
|
||||
conflicts = [
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "transformers-dep" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "pi" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "smolvla" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "groot" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "xvla" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "sarm" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "hilserl" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "libero" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "peft" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "all" },
|
||||
],
|
||||
# pi uses custom branch which conflicts with transformers-dep
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "transformers-dep" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "smolvla" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "groot" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "xvla" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "sarm" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "hilserl" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "libero" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "peft" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "all" },
|
||||
],
|
||||
]
|
||||
|
||||
+170
-271
@@ -1,76 +1,73 @@
|
||||
#
|
||||
# This file is autogenerated by pip-compile with Python 3.10
|
||||
# This file is autogenerated by pip-compile with Python 3.12
|
||||
# by the following command:
|
||||
#
|
||||
# pip-compile --output-file=requirements-macos.txt requirements.in
|
||||
#
|
||||
-e .[all]
|
||||
# via -[all]
|
||||
absl-py==2.3.1
|
||||
absl-py==2.4.0
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# dm-tree
|
||||
# labmaze
|
||||
# mujoco
|
||||
# tensorboard
|
||||
accelerate==1.11.0
|
||||
accelerate==1.13.0
|
||||
# via
|
||||
# lerobot
|
||||
# peft
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.13.1
|
||||
aiohttp==3.13.3
|
||||
# via fsspec
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
annotated-doc==0.0.4
|
||||
# via
|
||||
# fastapi
|
||||
# typer
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
antlr4-python3-runtime==4.9.3
|
||||
# via
|
||||
# hydra-core
|
||||
# omegaconf
|
||||
anyio==4.11.0
|
||||
anyio==4.12.1
|
||||
# via
|
||||
# httpx
|
||||
# starlette
|
||||
# watchfiles
|
||||
asttokens==3.0.0
|
||||
asttokens==3.0.1
|
||||
# via stack-data
|
||||
async-timeout==5.0.1
|
||||
# via aiohttp
|
||||
attrs==25.4.0
|
||||
# via
|
||||
# aiohttp
|
||||
# dm-tree
|
||||
# jsonlines
|
||||
# jsonschema
|
||||
# referencing
|
||||
# rerun-sdk
|
||||
av==15.1.0
|
||||
# via lerobot
|
||||
bddl==1.0.1
|
||||
# via libero
|
||||
certifi==2025.10.5
|
||||
# via
|
||||
# lerobot
|
||||
# qwen-vl-utils
|
||||
certifi==2026.2.25
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
# requests
|
||||
# sentry-sdk
|
||||
cffi==2.0.0
|
||||
# via pymunk
|
||||
cfgv==3.4.0
|
||||
cfgv==3.5.0
|
||||
# via pre-commit
|
||||
charset-normalizer==3.4.4
|
||||
charset-normalizer==3.4.5
|
||||
# via requests
|
||||
click==8.3.0
|
||||
click==8.3.1
|
||||
# via
|
||||
# typer
|
||||
# uvicorn
|
||||
# wandb
|
||||
cloudpickle==3.1.1
|
||||
# via
|
||||
# gymnasium
|
||||
# libero
|
||||
cmake==4.1.0
|
||||
cloudpickle==3.1.2
|
||||
# via gymnasium
|
||||
cmake==4.1.3
|
||||
# via lerobot
|
||||
cmeel==0.57.3
|
||||
cmeel==0.59.0
|
||||
# via
|
||||
# cmeel-assimp
|
||||
# cmeel-boost
|
||||
@@ -108,15 +105,17 @@ cmeel-zlib==1.3.1
|
||||
# via cmeel-assimp
|
||||
coal-library==3.0.1
|
||||
# via pin
|
||||
contourpy==1.3.2
|
||||
# via matplotlib
|
||||
coverage[toml]==7.11.0
|
||||
contourpy==1.3.3
|
||||
# via
|
||||
# lerobot
|
||||
# matplotlib
|
||||
coverage[toml]==7.13.4
|
||||
# via pytest-cov
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
datasets==4.1.1
|
||||
datasets==4.6.1
|
||||
# via lerobot
|
||||
debugpy==1.8.17
|
||||
debugpy==1.8.20
|
||||
# via lerobot
|
||||
decorator==5.2.1
|
||||
# via ipython
|
||||
@@ -130,7 +129,7 @@ dill==0.4.0
|
||||
# multiprocess
|
||||
distlib==0.4.0
|
||||
# via virtualenv
|
||||
dm-control==1.0.34
|
||||
dm-control==1.0.37
|
||||
# via gym-aloha
|
||||
dm-env==1.6
|
||||
# via dm-control
|
||||
@@ -138,69 +137,55 @@ dm-tree==0.1.9
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# lerobot
|
||||
docopt==0.6.2
|
||||
# via num2words
|
||||
draccus==0.10.0
|
||||
# via lerobot
|
||||
dynamixel-sdk==3.8.4
|
||||
# via lerobot
|
||||
easydict==1.13
|
||||
# via libero
|
||||
egl-probe @ git+https://github.com/huggingface/egl_probe.git
|
||||
# via
|
||||
# libero
|
||||
# robomimic
|
||||
eigenpy==3.10.3
|
||||
# via coal-library
|
||||
einops==0.8.1
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
einops==0.8.2
|
||||
# via lerobot
|
||||
eiquadprog==1.2.9
|
||||
# via placo
|
||||
etils[epath,epy]==1.13.0
|
||||
etils[epath,epy]==1.14.0
|
||||
# via mujoco
|
||||
exceptiongroup==1.3.0
|
||||
# via
|
||||
# anyio
|
||||
# ipython
|
||||
# pytest
|
||||
executing==2.2.1
|
||||
# via stack-data
|
||||
faker==34.0.2
|
||||
# via lerobot
|
||||
farama-notifications==0.0.4
|
||||
# via gymnasium
|
||||
fastapi==0.119.1
|
||||
# via teleop
|
||||
fastjsonschema==2.21.2
|
||||
# via nbformat
|
||||
fastapi==0.135.1
|
||||
# via
|
||||
# lerobot
|
||||
# teleop
|
||||
feetech-servo-sdk==1.0.0
|
||||
# via lerobot
|
||||
filelock==3.20.0
|
||||
filelock==3.25.0
|
||||
# via
|
||||
# datasets
|
||||
# diffusers
|
||||
# huggingface-hub
|
||||
# python-discovery
|
||||
# torch
|
||||
# transformers
|
||||
# virtualenv
|
||||
fonttools==4.60.1
|
||||
fonttools==4.61.1
|
||||
# via matplotlib
|
||||
frozenlist==1.8.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec[http]==2025.9.0
|
||||
fsspec[http]==2026.2.0
|
||||
# via
|
||||
# datasets
|
||||
# etils
|
||||
# huggingface-hub
|
||||
# torch
|
||||
future==1.0.0
|
||||
# via libero
|
||||
gitdb==4.0.12
|
||||
# via gitpython
|
||||
gitpython==3.1.45
|
||||
gitpython==3.1.46
|
||||
# via wandb
|
||||
glfw==2.10.0
|
||||
# via
|
||||
@@ -212,7 +197,6 @@ grpcio==1.73.1
|
||||
# lerobot
|
||||
# reachy2-sdk
|
||||
# reachy2-sdk-api
|
||||
# tensorboard
|
||||
grpcio-tools==1.73.1
|
||||
# via
|
||||
# lerobot
|
||||
@@ -223,71 +207,67 @@ gym-hil==0.1.13
|
||||
# via lerobot
|
||||
gym-pusht==0.1.6
|
||||
# via lerobot
|
||||
gymnasium==1.2.1
|
||||
gymnasium==1.2.3
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
# libero
|
||||
# metaworld
|
||||
h11==0.16.0
|
||||
# via uvicorn
|
||||
h5py==3.15.1
|
||||
# via robomimic
|
||||
# via
|
||||
# httpcore
|
||||
# uvicorn
|
||||
hebi-py==2.11.0
|
||||
# via lerobot
|
||||
hf-transfer==0.1.9
|
||||
# via huggingface-hub
|
||||
hf-xet==1.1.10
|
||||
hf-xet==1.3.2
|
||||
# via huggingface-hub
|
||||
hidapi==0.14.0.post4
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
httpcore==1.0.9
|
||||
# via httpx
|
||||
httptools==0.7.1
|
||||
# via uvicorn
|
||||
huggingface-hub[cli,hf-transfer]==0.35.3
|
||||
httpx==0.28.1
|
||||
# via
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
huggingface-hub==1.6.0
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# tokenizers
|
||||
# transformers
|
||||
hydra-core==1.3.2
|
||||
# via libero
|
||||
identify==2.6.15
|
||||
identify==2.6.17
|
||||
# via pre-commit
|
||||
idna==3.11
|
||||
# via
|
||||
# anyio
|
||||
# httpx
|
||||
# requests
|
||||
# yarl
|
||||
imageio[ffmpeg]==2.37.0
|
||||
imageio[ffmpeg]==2.37.2
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# lerobot
|
||||
# metaworld
|
||||
# robomimic
|
||||
# scikit-image
|
||||
imageio-ffmpeg==0.6.0
|
||||
# via
|
||||
# imageio
|
||||
# robomimic
|
||||
importlib-metadata==8.7.0
|
||||
# via imageio
|
||||
importlib-metadata==8.7.1
|
||||
# via diffusers
|
||||
importlib-resources==6.5.2
|
||||
# via etils
|
||||
iniconfig==2.3.0
|
||||
# via pytest
|
||||
inquirerpy==0.3.4
|
||||
# via huggingface-hub
|
||||
ipython==8.37.0
|
||||
ipython==9.11.0
|
||||
# via meshcat
|
||||
ipython-pygments-lexers==1.1.1
|
||||
# via ipython
|
||||
ischedule==1.2.7
|
||||
# via placo
|
||||
jedi==0.19.2
|
||||
@@ -296,44 +276,24 @@ jinja2==3.1.6
|
||||
# via torch
|
||||
jsonlines==4.0.0
|
||||
# via lerobot
|
||||
jsonschema==4.25.1
|
||||
# via nbformat
|
||||
jsonschema-specifications==2025.9.1
|
||||
# via jsonschema
|
||||
jupyter-core==5.9.1
|
||||
# via nbformat
|
||||
jupytext==1.18.1
|
||||
# via bddl
|
||||
kiwisolver==1.4.9
|
||||
# via matplotlib
|
||||
labmaze==1.0.6
|
||||
# via dm-control
|
||||
lazy-loader==0.4
|
||||
lazy-loader==0.5
|
||||
# via scikit-image
|
||||
libero @ git+https://github.com/huggingface/lerobot-libero.git@main
|
||||
# via lerobot
|
||||
llvmlite==0.45.1
|
||||
# via numba
|
||||
librt==0.8.1
|
||||
# via mypy
|
||||
lxml==6.0.2
|
||||
# via dm-control
|
||||
markdown==3.9
|
||||
# via tensorboard
|
||||
markdown-it-py==4.0.0
|
||||
# via
|
||||
# jupytext
|
||||
# mdit-py-plugins
|
||||
# via rich
|
||||
markupsafe==3.0.3
|
||||
# via
|
||||
# jinja2
|
||||
# werkzeug
|
||||
matplotlib==3.10.7
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
# via jinja2
|
||||
matplotlib==3.10.8
|
||||
# via lerobot
|
||||
matplotlib-inline==0.2.1
|
||||
# via ipython
|
||||
mdit-py-plugins==0.5.0
|
||||
# via jupytext
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mergedeep==1.3.4
|
||||
@@ -346,41 +306,35 @@ mock-serial==0.0.1
|
||||
# via lerobot
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
mujoco==3.3.7
|
||||
mujoco==3.5.0
|
||||
# via
|
||||
# dm-control
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# libero
|
||||
# metaworld
|
||||
# robosuite
|
||||
multidict==6.7.0
|
||||
multidict==6.7.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
multiprocess==0.70.16
|
||||
multiprocess==0.70.18
|
||||
# via datasets
|
||||
mypy==1.19.1
|
||||
# via lerobot
|
||||
mypy-extensions==1.1.0
|
||||
# via typing-inspect
|
||||
nbformat==5.10.4
|
||||
# via jupytext
|
||||
networkx==3.4.2
|
||||
# via
|
||||
# bddl
|
||||
# mypy
|
||||
# typing-inspect
|
||||
networkx==3.6.1
|
||||
# via
|
||||
# scikit-image
|
||||
# torch
|
||||
ninja==1.13.0
|
||||
# via lerobot
|
||||
nodeenv==1.9.1
|
||||
nodeenv==1.10.0
|
||||
# via pre-commit
|
||||
num2words==0.5.14
|
||||
# via lerobot
|
||||
numba==0.62.1
|
||||
# via robosuite
|
||||
numpy==2.2.6
|
||||
# via
|
||||
# accelerate
|
||||
# bddl
|
||||
# cmeel-boost
|
||||
# contourpy
|
||||
# datasets
|
||||
@@ -389,16 +343,14 @@ numpy==2.2.6
|
||||
# dm-env
|
||||
# dm-tree
|
||||
# gymnasium
|
||||
# h5py
|
||||
# hebi-py
|
||||
# imageio
|
||||
# labmaze
|
||||
# libero
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# metaworld
|
||||
# mujoco
|
||||
# numba
|
||||
# opencv-python
|
||||
# opencv-python-headless
|
||||
# pandas
|
||||
@@ -406,26 +358,18 @@ numpy==2.2.6
|
||||
# pyquaternion
|
||||
# reachy2-sdk
|
||||
# rerun-sdk
|
||||
# robomimic
|
||||
# robosuite
|
||||
# scikit-image
|
||||
# scipy
|
||||
# shapely
|
||||
# teleop
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# tifffile
|
||||
# torchvision
|
||||
# transformers
|
||||
# transforms3d
|
||||
omegaconf==2.3.0
|
||||
# via hydra-core
|
||||
opencv-python==4.12.0.88
|
||||
opencv-python==4.13.0.92
|
||||
# via
|
||||
# gym-pusht
|
||||
# libero
|
||||
# reachy2-sdk
|
||||
# robosuite
|
||||
opencv-python-headless==4.12.0.88
|
||||
# via lerobot
|
||||
orderly-set==5.5.0
|
||||
@@ -435,97 +379,87 @@ packaging==25.0
|
||||
# accelerate
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
# hydra-core
|
||||
# jupytext
|
||||
# lazy-loader
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# peft
|
||||
# pytest
|
||||
# qwen-vl-utils
|
||||
# reachy2-sdk
|
||||
# scikit-image
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# transformers
|
||||
# wandb
|
||||
pandas==2.3.3
|
||||
# via
|
||||
# datasets
|
||||
# lerobot
|
||||
parso==0.8.5
|
||||
parso==0.8.6
|
||||
# via jedi
|
||||
peft==0.17.1
|
||||
pathspec==1.0.4
|
||||
# via mypy
|
||||
peft==0.18.1
|
||||
# via lerobot
|
||||
pexpect==4.9.0
|
||||
# via ipython
|
||||
pfzy==0.3.4
|
||||
# via inquirerpy
|
||||
pillow==12.0.0
|
||||
pillow==12.1.1
|
||||
# via
|
||||
# diffusers
|
||||
# imageio
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# qwen-vl-utils
|
||||
# rerun-sdk
|
||||
# robosuite
|
||||
# scikit-image
|
||||
# tensorboard
|
||||
# torchvision
|
||||
pin==3.4.0
|
||||
# via placo
|
||||
placo==0.9.14
|
||||
placo==0.9.16
|
||||
# via lerobot
|
||||
platformdirs==4.5.0
|
||||
platformdirs==4.9.4
|
||||
# via
|
||||
# jupyter-core
|
||||
# python-discovery
|
||||
# virtualenv
|
||||
# wandb
|
||||
pluggy==1.6.0
|
||||
# via
|
||||
# pytest
|
||||
# pytest-cov
|
||||
pre-commit==4.3.0
|
||||
pre-commit==4.5.1
|
||||
# via lerobot
|
||||
prompt-toolkit==3.0.52
|
||||
# via
|
||||
# inquirerpy
|
||||
# ipython
|
||||
# via ipython
|
||||
propcache==0.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
protobuf==6.31.0
|
||||
protobuf==6.31.1
|
||||
# via
|
||||
# dm-control
|
||||
# grpcio-tools
|
||||
# lerobot
|
||||
# reachy2-sdk
|
||||
# reachy2-sdk-api
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# wandb
|
||||
psutil==7.1.1
|
||||
psutil==7.2.2
|
||||
# via
|
||||
# accelerate
|
||||
# imageio
|
||||
# peft
|
||||
# robomimic
|
||||
ptyprocess==0.7.0
|
||||
# via pexpect
|
||||
pure-eval==0.2.3
|
||||
# via stack-data
|
||||
pyarrow==21.0.0
|
||||
pyarrow==23.0.1
|
||||
# via
|
||||
# datasets
|
||||
# rerun-sdk
|
||||
pycparser==2.23
|
||||
pycparser==3.0
|
||||
# via cffi
|
||||
pydantic==2.12.3
|
||||
pydantic==2.12.5
|
||||
# via
|
||||
# fastapi
|
||||
# wandb
|
||||
pydantic-core==2.41.4
|
||||
pydantic-core==2.41.5
|
||||
# via pydantic
|
||||
pygame==2.6.1
|
||||
# via
|
||||
@@ -535,33 +469,35 @@ pygame==2.6.1
|
||||
pygments==2.19.2
|
||||
# via
|
||||
# ipython
|
||||
# ipython-pygments-lexers
|
||||
# pytest
|
||||
# rich
|
||||
pymunk==6.11.1
|
||||
# via
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
pyngrok==7.4.1
|
||||
pyngrok==7.5.1
|
||||
# via meshcat
|
||||
pynput==1.8.1
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
pyobjc-core==12.0
|
||||
pyobjc-core==12.1
|
||||
# via
|
||||
# pyobjc-framework-applicationservices
|
||||
# pyobjc-framework-cocoa
|
||||
# pyobjc-framework-coretext
|
||||
# pyobjc-framework-quartz
|
||||
pyobjc-framework-applicationservices==12.0
|
||||
pyobjc-framework-applicationservices==12.1
|
||||
# via pynput
|
||||
pyobjc-framework-cocoa==12.0
|
||||
pyobjc-framework-cocoa==12.1
|
||||
# via
|
||||
# pyobjc-framework-applicationservices
|
||||
# pyobjc-framework-coretext
|
||||
# pyobjc-framework-quartz
|
||||
pyobjc-framework-coretext==12.0
|
||||
pyobjc-framework-coretext==12.1
|
||||
# via pyobjc-framework-applicationservices
|
||||
pyobjc-framework-quartz==12.0
|
||||
pyobjc-framework-quartz==12.1
|
||||
# via
|
||||
# pynput
|
||||
# pyobjc-framework-applicationservices
|
||||
@@ -570,13 +506,13 @@ pyopengl==3.1.10
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
pyparsing==3.2.5
|
||||
pyparsing==3.3.2
|
||||
# via
|
||||
# dm-control
|
||||
# matplotlib
|
||||
pyquaternion==0.9.9
|
||||
# via reachy2-sdk
|
||||
pyrealsense2-macosx==2.54.2
|
||||
pyrealsense2-macosx==2.56.5
|
||||
# via lerobot
|
||||
pyserial==3.5
|
||||
# via
|
||||
@@ -585,7 +521,6 @@ pyserial==3.5
|
||||
# lerobot
|
||||
pytest==8.4.2
|
||||
# via
|
||||
# bddl
|
||||
# lerobot
|
||||
# pytest-cov
|
||||
# pytest-timeout
|
||||
@@ -596,11 +531,14 @@ pytest-timeout==2.4.0
|
||||
# via lerobot
|
||||
python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# faker
|
||||
# matplotlib
|
||||
# pandas
|
||||
python-dotenv==1.1.1
|
||||
python-discovery==1.1.1
|
||||
# via virtualenv
|
||||
python-dotenv==1.2.2
|
||||
# via uvicorn
|
||||
pytz==2025.2
|
||||
pytz==2026.1.post1
|
||||
# via pandas
|
||||
pyyaml==6.0.3
|
||||
# via
|
||||
@@ -609,13 +547,10 @@ pyyaml==6.0.3
|
||||
# draccus
|
||||
# hebi-py
|
||||
# huggingface-hub
|
||||
# jupytext
|
||||
# omegaconf
|
||||
# peft
|
||||
# pre-commit
|
||||
# pyngrok
|
||||
# pyyaml-include
|
||||
# timm
|
||||
# transformers
|
||||
# uvicorn
|
||||
# wandb
|
||||
@@ -625,15 +560,13 @@ pyzmq==27.1.0
|
||||
# via
|
||||
# lerobot
|
||||
# meshcat
|
||||
reachy2-sdk==1.0.14
|
||||
qwen-vl-utils==0.0.14
|
||||
# via lerobot
|
||||
reachy2-sdk==1.0.15
|
||||
# via lerobot
|
||||
reachy2-sdk-api==1.0.21
|
||||
# via reachy2-sdk
|
||||
referencing==0.37.0
|
||||
# via
|
||||
# jsonschema
|
||||
# jsonschema-specifications
|
||||
regex==2025.10.23
|
||||
regex==2026.2.28
|
||||
# via
|
||||
# diffusers
|
||||
# transformers
|
||||
@@ -642,184 +575,150 @@ requests==2.32.5
|
||||
# datasets
|
||||
# diffusers
|
||||
# dm-control
|
||||
# huggingface-hub
|
||||
# qwen-vl-utils
|
||||
# teleop
|
||||
# transformers
|
||||
# wandb
|
||||
rerun-sdk==0.26.1
|
||||
rerun-sdk==0.26.2
|
||||
# via lerobot
|
||||
rhoban-cmeel-jsoncpp==1.9.4.9
|
||||
# via placo
|
||||
robomimic==0.2.0
|
||||
# via libero
|
||||
robosuite==1.4.0
|
||||
# via libero
|
||||
rpds-py==0.28.0
|
||||
# via
|
||||
# jsonschema
|
||||
# referencing
|
||||
safetensors==0.6.2
|
||||
rich==14.3.3
|
||||
# via typer
|
||||
safetensors==0.7.0
|
||||
# via
|
||||
# accelerate
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# transformers
|
||||
scikit-image==0.25.2
|
||||
# via
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
scipy==1.15.3
|
||||
scipy==1.17.1
|
||||
# via
|
||||
# dm-control
|
||||
# lerobot
|
||||
# metaworld
|
||||
# robosuite
|
||||
# scikit-image
|
||||
sentry-sdk==2.42.1
|
||||
# torchdiffeq
|
||||
sentry-sdk==2.54.0
|
||||
# via wandb
|
||||
shapely==2.1.2
|
||||
# via gym-pusht
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
six==1.17.0
|
||||
# via
|
||||
# pynput
|
||||
# python-dateutil
|
||||
smmap==5.0.2
|
||||
smmap==5.0.3
|
||||
# via gitdb
|
||||
sniffio==1.3.1
|
||||
# via anyio
|
||||
stack-data==0.6.3
|
||||
# via ipython
|
||||
starlette==0.48.0
|
||||
starlette==0.52.1
|
||||
# via fastapi
|
||||
sympy==1.14.0
|
||||
# via torch
|
||||
teleop==0.1.2
|
||||
teleop==0.1.4
|
||||
# via lerobot
|
||||
tensorboard==2.20.0
|
||||
# via robomimic
|
||||
tensorboard-data-server==0.7.2
|
||||
# via tensorboard
|
||||
tensorboardx==2.6.4
|
||||
# via robomimic
|
||||
termcolor==3.1.0
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
thop==0.1.1.post2209072238
|
||||
# via libero
|
||||
tifffile==2025.5.10
|
||||
termcolor==3.3.0
|
||||
# via lerobot
|
||||
tifffile==2026.3.3
|
||||
# via scikit-image
|
||||
timm==1.0.20
|
||||
# via lerobot
|
||||
tokenizers==0.22.1
|
||||
tokenizers==0.22.2
|
||||
# via transformers
|
||||
toml==0.10.2
|
||||
# via draccus
|
||||
tomli==2.3.0
|
||||
# via
|
||||
# cmeel
|
||||
# coverage
|
||||
# jupytext
|
||||
# pytest
|
||||
torch==2.7.1
|
||||
torch==2.10.0
|
||||
# via
|
||||
# accelerate
|
||||
# lerobot
|
||||
# peft
|
||||
# robomimic
|
||||
# thop
|
||||
# timm
|
||||
# torchdiffeq
|
||||
# torchvision
|
||||
torchcodec==0.5
|
||||
torchcodec==0.10.0
|
||||
# via lerobot
|
||||
torchvision==0.22.1
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
# timm
|
||||
tornado==6.5.2
|
||||
torchdiffeq==0.2.5
|
||||
# via lerobot
|
||||
torchvision==0.25.0
|
||||
# via lerobot
|
||||
tornado==6.5.4
|
||||
# via meshcat
|
||||
tqdm==4.67.1
|
||||
tqdm==4.67.3
|
||||
# via
|
||||
# datasets
|
||||
# dm-control
|
||||
# huggingface-hub
|
||||
# peft
|
||||
# robomimic
|
||||
# transformers
|
||||
traitlets==5.14.3
|
||||
# via
|
||||
# ipython
|
||||
# jupyter-core
|
||||
# matplotlib-inline
|
||||
# nbformat
|
||||
transformers==4.57.1
|
||||
transformers==5.3.0
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
# peft
|
||||
transforms3d==0.4.2
|
||||
# via teleop
|
||||
typer==0.24.1
|
||||
# via
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
typing-extensions==4.15.0
|
||||
# via
|
||||
# aiosignal
|
||||
# anyio
|
||||
# etils
|
||||
# exceptiongroup
|
||||
# faker
|
||||
# fastapi
|
||||
# gymnasium
|
||||
# huggingface-hub
|
||||
# ipython
|
||||
# multidict
|
||||
# mypy
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# referencing
|
||||
# rerun-sdk
|
||||
# starlette
|
||||
# torch
|
||||
# typing-inspect
|
||||
# typing-inspection
|
||||
# uvicorn
|
||||
# virtualenv
|
||||
# wandb
|
||||
typing-inspect==0.9.0
|
||||
# via draccus
|
||||
typing-inspection==0.4.2
|
||||
# via pydantic
|
||||
tzdata==2025.2
|
||||
# via
|
||||
# fastapi
|
||||
# pydantic
|
||||
tzdata==2025.3
|
||||
# via pandas
|
||||
u-msgpack-python==2.8.0
|
||||
# via meshcat
|
||||
urllib3==2.5.0
|
||||
urllib3==2.6.3
|
||||
# via
|
||||
# requests
|
||||
# sentry-sdk
|
||||
uvicorn[standard]==0.38.0
|
||||
uvicorn[standard]==0.41.0
|
||||
# via teleop
|
||||
uvloop==0.22.1
|
||||
# via uvicorn
|
||||
virtualenv==20.35.3
|
||||
virtualenv==21.1.0
|
||||
# via pre-commit
|
||||
wandb==0.21.4
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
wandb==0.24.2
|
||||
# via lerobot
|
||||
watchfiles==1.1.1
|
||||
# via uvicorn
|
||||
wcwidth==0.2.14
|
||||
wcwidth==0.6.0
|
||||
# via prompt-toolkit
|
||||
websocket-client==1.9.0
|
||||
# via teleop
|
||||
websockets==15.0.1
|
||||
websockets==16.0
|
||||
# via uvicorn
|
||||
werkzeug==3.1.3
|
||||
# via tensorboard
|
||||
wrapt==2.0.0
|
||||
wrapt==2.1.2
|
||||
# via dm-tree
|
||||
xxhash==3.6.0
|
||||
# via datasets
|
||||
yarl==1.22.0
|
||||
yarl==1.23.0
|
||||
# via aiohttp
|
||||
zipp==3.23.0
|
||||
# via
|
||||
|
||||
+209
-188
@@ -1,12 +1,12 @@
|
||||
#
|
||||
# This file is autogenerated by pip-compile with Python 3.10
|
||||
# This file is autogenerated by pip-compile with Python 3.12
|
||||
# by the following command:
|
||||
#
|
||||
# pip-compile --output-file=requirements-ubuntu.txt requirements.in
|
||||
#
|
||||
-e .[all]
|
||||
# via -[all]
|
||||
absl-py==2.3.1
|
||||
absl-py==2.4.0
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
@@ -14,30 +14,33 @@ absl-py==2.3.1
|
||||
# labmaze
|
||||
# mujoco
|
||||
# tensorboard
|
||||
accelerate==1.11.0
|
||||
accelerate==1.13.0
|
||||
# via
|
||||
# lerobot
|
||||
# peft
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.13.1
|
||||
aiohttp==3.13.3
|
||||
# via fsspec
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
annotated-doc==0.0.4
|
||||
# via
|
||||
# fastapi
|
||||
# typer
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
antlr4-python3-runtime==4.9.3
|
||||
# via
|
||||
# hydra-core
|
||||
# omegaconf
|
||||
anyio==4.11.0
|
||||
anyio==4.12.1
|
||||
# via
|
||||
# httpx
|
||||
# starlette
|
||||
# watchfiles
|
||||
asttokens==3.0.0
|
||||
asttokens==3.0.1
|
||||
# via stack-data
|
||||
async-timeout==5.0.1
|
||||
# via aiohttp
|
||||
attrs==25.4.0
|
||||
# via
|
||||
# aiohttp
|
||||
@@ -47,30 +50,35 @@ attrs==25.4.0
|
||||
# referencing
|
||||
# rerun-sdk
|
||||
av==15.1.0
|
||||
# via lerobot
|
||||
bddl==1.0.1
|
||||
# via libero
|
||||
certifi==2025.10.5
|
||||
# via
|
||||
# lerobot
|
||||
# qwen-vl-utils
|
||||
bddl==1.0.1
|
||||
# via hf-libero
|
||||
certifi==2026.2.25
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
# requests
|
||||
# sentry-sdk
|
||||
cffi==2.0.0
|
||||
# via pymunk
|
||||
cfgv==3.4.0
|
||||
cfgv==3.5.0
|
||||
# via pre-commit
|
||||
charset-normalizer==3.4.4
|
||||
charset-normalizer==3.4.5
|
||||
# via requests
|
||||
click==8.3.0
|
||||
click==8.3.1
|
||||
# via
|
||||
# typer
|
||||
# uvicorn
|
||||
# wandb
|
||||
cloudpickle==3.1.1
|
||||
cloudpickle==3.1.2
|
||||
# via
|
||||
# gymnasium
|
||||
# libero
|
||||
cmake==4.1.0
|
||||
# hf-libero
|
||||
cmake==4.1.3
|
||||
# via lerobot
|
||||
cmeel==0.57.3
|
||||
cmeel==0.59.0
|
||||
# via
|
||||
# cmeel-assimp
|
||||
# cmeel-boost
|
||||
@@ -108,20 +116,24 @@ cmeel-zlib==1.3.1
|
||||
# via cmeel-assimp
|
||||
coal-library==3.0.1
|
||||
# via pin
|
||||
contourpy==1.3.2
|
||||
# via matplotlib
|
||||
coverage[toml]==7.11.0
|
||||
contourpy==1.3.3
|
||||
# via
|
||||
# lerobot
|
||||
# matplotlib
|
||||
coverage[toml]==7.13.4
|
||||
# via pytest-cov
|
||||
cuda-bindings==12.9.4
|
||||
# via torch
|
||||
cuda-pathfinder==1.4.1
|
||||
# via cuda-bindings
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
datasets==4.1.1
|
||||
datasets==4.6.1
|
||||
# via lerobot
|
||||
debugpy==1.8.17
|
||||
debugpy==1.8.20
|
||||
# via lerobot
|
||||
decorator==5.2.1
|
||||
# via ipython
|
||||
decord==0.6.0
|
||||
# via lerobot
|
||||
deepdiff==8.6.1
|
||||
# via lerobot
|
||||
diffusers==0.35.2
|
||||
@@ -132,7 +144,7 @@ dill==0.4.0
|
||||
# multiprocess
|
||||
distlib==0.4.0
|
||||
# via virtualenv
|
||||
dm-control==1.0.34
|
||||
dm-control==1.0.37
|
||||
# via gym-aloha
|
||||
dm-env==1.6
|
||||
# via dm-control
|
||||
@@ -140,7 +152,6 @@ dm-tree==0.1.9
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# lerobot
|
||||
docopt==0.6.2
|
||||
# via num2words
|
||||
draccus==0.10.0
|
||||
@@ -148,66 +159,60 @@ draccus==0.10.0
|
||||
dynamixel-sdk==3.8.4
|
||||
# via lerobot
|
||||
easydict==1.13
|
||||
# via libero
|
||||
egl-probe @ git+https://github.com/huggingface/egl_probe.git
|
||||
# via
|
||||
# libero
|
||||
# robomimic
|
||||
# via hf-libero
|
||||
egl-probe==1.0.2
|
||||
# via robomimic
|
||||
eigenpy==3.10.3
|
||||
# via coal-library
|
||||
einops==0.8.1
|
||||
einops==0.8.2
|
||||
# via
|
||||
# flash-attn
|
||||
# hf-libero
|
||||
# lerobot
|
||||
# libero
|
||||
eiquadprog==1.2.9
|
||||
# via placo
|
||||
etils[epath,epy]==1.13.0
|
||||
etils[epath,epy]==1.14.0
|
||||
# via mujoco
|
||||
evdev==1.9.2
|
||||
evdev==1.9.3
|
||||
# via pynput
|
||||
exceptiongroup==1.3.0
|
||||
# via
|
||||
# anyio
|
||||
# ipython
|
||||
# pytest
|
||||
executing==2.2.1
|
||||
# via stack-data
|
||||
faker==34.0.2
|
||||
# via lerobot
|
||||
farama-notifications==0.0.4
|
||||
# via gymnasium
|
||||
fastapi==0.119.1
|
||||
# via teleop
|
||||
fastapi==0.135.1
|
||||
# via
|
||||
# lerobot
|
||||
# teleop
|
||||
fastjsonschema==2.21.2
|
||||
# via nbformat
|
||||
feetech-servo-sdk==1.0.0
|
||||
# via lerobot
|
||||
filelock==3.20.0
|
||||
filelock==3.25.0
|
||||
# via
|
||||
# datasets
|
||||
# diffusers
|
||||
# huggingface-hub
|
||||
# python-discovery
|
||||
# torch
|
||||
# transformers
|
||||
# virtualenv
|
||||
flash-attn==2.8.3
|
||||
# via lerobot
|
||||
fonttools==4.60.1
|
||||
fonttools==4.61.1
|
||||
# via matplotlib
|
||||
frozenlist==1.8.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec[http]==2025.9.0
|
||||
fsspec[http]==2026.2.0
|
||||
# via
|
||||
# datasets
|
||||
# etils
|
||||
# huggingface-hub
|
||||
# torch
|
||||
future==1.0.0
|
||||
# via libero
|
||||
# via hf-libero
|
||||
gitdb==4.0.12
|
||||
# via gitpython
|
||||
gitpython==3.1.45
|
||||
gitpython==3.1.46
|
||||
# via wandb
|
||||
glfw==2.10.0
|
||||
# via
|
||||
@@ -230,50 +235,60 @@ gym-hil==0.1.13
|
||||
# via lerobot
|
||||
gym-pusht==0.1.6
|
||||
# via lerobot
|
||||
gymnasium==1.2.1
|
||||
gymnasium==1.2.3
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gym-pusht
|
||||
# hf-libero
|
||||
# lerobot
|
||||
# libero
|
||||
# metaworld
|
||||
h11==0.16.0
|
||||
# via uvicorn
|
||||
h5py==3.15.1
|
||||
# via
|
||||
# httpcore
|
||||
# uvicorn
|
||||
h5py==3.16.0
|
||||
# via robomimic
|
||||
hebi-py==2.11.0
|
||||
# via lerobot
|
||||
hf-transfer==0.1.9
|
||||
# via huggingface-hub
|
||||
hf-xet==1.1.10
|
||||
hf-egl-probe==1.0.2
|
||||
# via hf-libero
|
||||
hf-libero==0.1.3
|
||||
# via lerobot
|
||||
hf-xet==1.3.2
|
||||
# via huggingface-hub
|
||||
hidapi==0.14.0.post4
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
httpcore==1.0.9
|
||||
# via httpx
|
||||
httptools==0.7.1
|
||||
# via uvicorn
|
||||
huggingface-hub[cli,hf-transfer]==0.35.3
|
||||
httpx==0.28.1
|
||||
# via
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
huggingface-hub==1.6.0
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# tokenizers
|
||||
# transformers
|
||||
hydra-core==1.3.2
|
||||
# via libero
|
||||
identify==2.6.15
|
||||
# via hf-libero
|
||||
identify==2.6.17
|
||||
# via pre-commit
|
||||
idna==3.11
|
||||
# via
|
||||
# anyio
|
||||
# httpx
|
||||
# requests
|
||||
# yarl
|
||||
imageio[ffmpeg]==2.37.0
|
||||
imageio[ffmpeg]==2.37.2
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
@@ -285,16 +300,14 @@ imageio-ffmpeg==0.6.0
|
||||
# via
|
||||
# imageio
|
||||
# robomimic
|
||||
importlib-metadata==8.7.0
|
||||
importlib-metadata==8.7.1
|
||||
# via diffusers
|
||||
importlib-resources==6.5.2
|
||||
# via etils
|
||||
iniconfig==2.3.0
|
||||
# via pytest
|
||||
inquirerpy==0.3.4
|
||||
# via huggingface-hub
|
||||
ipython==8.37.0
|
||||
ipython==9.11.0
|
||||
# via meshcat
|
||||
ipython-pygments-lexers==1.1.1
|
||||
# via ipython
|
||||
ischedule==1.2.7
|
||||
# via placo
|
||||
jedi==0.19.2
|
||||
@@ -303,40 +316,41 @@ jinja2==3.1.6
|
||||
# via torch
|
||||
jsonlines==4.0.0
|
||||
# via lerobot
|
||||
jsonschema==4.25.1
|
||||
jsonschema==4.26.0
|
||||
# via nbformat
|
||||
jsonschema-specifications==2025.9.1
|
||||
# via jsonschema
|
||||
jupyter-core==5.9.1
|
||||
# via nbformat
|
||||
jupytext==1.18.1
|
||||
jupytext==1.19.1
|
||||
# via bddl
|
||||
kiwisolver==1.4.9
|
||||
# via matplotlib
|
||||
labmaze==1.0.6
|
||||
# via dm-control
|
||||
lazy-loader==0.4
|
||||
lazy-loader==0.5
|
||||
# via scikit-image
|
||||
libero @ git+https://github.com/huggingface/lerobot-libero.git@main
|
||||
# via lerobot
|
||||
llvmlite==0.45.1
|
||||
librt==0.8.1
|
||||
# via mypy
|
||||
llvmlite==0.46.0
|
||||
# via numba
|
||||
lxml==6.0.2
|
||||
# via dm-control
|
||||
markdown==3.9
|
||||
markdown==3.10.2
|
||||
# via tensorboard
|
||||
markdown-it-py==4.0.0
|
||||
# via
|
||||
# jupytext
|
||||
# mdit-py-plugins
|
||||
# rich
|
||||
markupsafe==3.0.3
|
||||
# via
|
||||
# jinja2
|
||||
# werkzeug
|
||||
matplotlib==3.10.7
|
||||
matplotlib==3.10.8
|
||||
# via
|
||||
# hf-libero
|
||||
# lerobot
|
||||
# libero
|
||||
matplotlib-inline==0.2.1
|
||||
# via ipython
|
||||
mdit-py-plugins==0.5.0
|
||||
@@ -353,36 +367,38 @@ mock-serial==0.0.1
|
||||
# via lerobot
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
mujoco==3.3.7
|
||||
mujoco==3.5.0
|
||||
# via
|
||||
# dm-control
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# libero
|
||||
# hf-libero
|
||||
# metaworld
|
||||
# robosuite
|
||||
multidict==6.7.0
|
||||
multidict==6.7.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
multiprocess==0.70.16
|
||||
multiprocess==0.70.18
|
||||
# via datasets
|
||||
mypy==1.19.1
|
||||
# via lerobot
|
||||
mypy-extensions==1.1.0
|
||||
# via typing-inspect
|
||||
# via
|
||||
# mypy
|
||||
# typing-inspect
|
||||
nbformat==5.10.4
|
||||
# via jupytext
|
||||
networkx==3.4.2
|
||||
networkx==3.6.1
|
||||
# via
|
||||
# bddl
|
||||
# scikit-image
|
||||
# torch
|
||||
ninja==1.13.0
|
||||
# via lerobot
|
||||
nodeenv==1.9.1
|
||||
nodeenv==1.10.0
|
||||
# via pre-commit
|
||||
num2words==0.5.14
|
||||
# via lerobot
|
||||
numba==0.62.1
|
||||
numba==0.64.0
|
||||
# via robosuite
|
||||
numpy==2.2.6
|
||||
# via
|
||||
@@ -391,7 +407,6 @@ numpy==2.2.6
|
||||
# cmeel-boost
|
||||
# contourpy
|
||||
# datasets
|
||||
# decord
|
||||
# diffusers
|
||||
# dm-control
|
||||
# dm-env
|
||||
@@ -399,9 +414,10 @@ numpy==2.2.6
|
||||
# gymnasium
|
||||
# h5py
|
||||
# hebi-py
|
||||
# hf-libero
|
||||
# imageio
|
||||
# labmaze
|
||||
# libero
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# metaworld
|
||||
@@ -426,49 +442,51 @@ numpy==2.2.6
|
||||
# torchvision
|
||||
# transformers
|
||||
# transforms3d
|
||||
nvidia-cublas-cu12==12.6.4.1
|
||||
nvidia-cublas-cu12==12.8.4.1
|
||||
# via
|
||||
# nvidia-cudnn-cu12
|
||||
# nvidia-cusolver-cu12
|
||||
# torch
|
||||
nvidia-cuda-cupti-cu12==12.6.80
|
||||
nvidia-cuda-cupti-cu12==12.8.90
|
||||
# via torch
|
||||
nvidia-cuda-nvrtc-cu12==12.6.77
|
||||
nvidia-cuda-nvrtc-cu12==12.8.93
|
||||
# via torch
|
||||
nvidia-cuda-runtime-cu12==12.6.77
|
||||
nvidia-cuda-runtime-cu12==12.8.90
|
||||
# via torch
|
||||
nvidia-cudnn-cu12==9.5.1.17
|
||||
nvidia-cudnn-cu12==9.10.2.21
|
||||
# via torch
|
||||
nvidia-cufft-cu12==11.3.0.4
|
||||
nvidia-cufft-cu12==11.3.3.83
|
||||
# via torch
|
||||
nvidia-cufile-cu12==1.11.1.6
|
||||
nvidia-cufile-cu12==1.13.1.3
|
||||
# via torch
|
||||
nvidia-curand-cu12==10.3.7.77
|
||||
nvidia-curand-cu12==10.3.9.90
|
||||
# via torch
|
||||
nvidia-cusolver-cu12==11.7.1.2
|
||||
nvidia-cusolver-cu12==11.7.3.90
|
||||
# via torch
|
||||
nvidia-cusparse-cu12==12.5.4.2
|
||||
nvidia-cusparse-cu12==12.5.8.93
|
||||
# via
|
||||
# nvidia-cusolver-cu12
|
||||
# torch
|
||||
nvidia-cusparselt-cu12==0.6.3
|
||||
nvidia-cusparselt-cu12==0.7.1
|
||||
# via torch
|
||||
nvidia-nccl-cu12==2.26.2
|
||||
nvidia-nccl-cu12==2.27.5
|
||||
# via torch
|
||||
nvidia-nvjitlink-cu12==12.6.85
|
||||
nvidia-nvjitlink-cu12==12.8.93
|
||||
# via
|
||||
# nvidia-cufft-cu12
|
||||
# nvidia-cusolver-cu12
|
||||
# nvidia-cusparse-cu12
|
||||
# torch
|
||||
nvidia-nvtx-cu12==12.6.77
|
||||
nvidia-nvshmem-cu12==3.4.5
|
||||
# via torch
|
||||
nvidia-nvtx-cu12==12.8.90
|
||||
# via torch
|
||||
omegaconf==2.3.0
|
||||
# via hydra-core
|
||||
opencv-python==4.12.0.88
|
||||
opencv-python==4.13.0.92
|
||||
# via
|
||||
# gym-pusht
|
||||
# libero
|
||||
# hf-libero
|
||||
# reachy2-sdk
|
||||
# robosuite
|
||||
opencv-python-headless==4.12.0.88
|
||||
@@ -487,6 +505,7 @@ packaging==25.0
|
||||
# matplotlib
|
||||
# peft
|
||||
# pytest
|
||||
# qwen-vl-utils
|
||||
# reachy2-sdk
|
||||
# scikit-image
|
||||
# tensorboard
|
||||
@@ -497,21 +516,21 @@ pandas==2.3.3
|
||||
# via
|
||||
# datasets
|
||||
# lerobot
|
||||
parso==0.8.5
|
||||
parso==0.8.6
|
||||
# via jedi
|
||||
peft==0.17.1
|
||||
pathspec==1.0.4
|
||||
# via mypy
|
||||
peft==0.18.1
|
||||
# via lerobot
|
||||
pexpect==4.9.0
|
||||
# via ipython
|
||||
pfzy==0.3.4
|
||||
# via inquirerpy
|
||||
pillow==12.0.0
|
||||
pillow==12.1.1
|
||||
# via
|
||||
# diffusers
|
||||
# imageio
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# qwen-vl-utils
|
||||
# rerun-sdk
|
||||
# robosuite
|
||||
# scikit-image
|
||||
@@ -519,28 +538,27 @@ pillow==12.0.0
|
||||
# torchvision
|
||||
pin==3.4.0
|
||||
# via placo
|
||||
placo==0.9.14
|
||||
placo==0.9.16
|
||||
# via lerobot
|
||||
platformdirs==4.5.0
|
||||
platformdirs==4.9.4
|
||||
# via
|
||||
# jupyter-core
|
||||
# python-discovery
|
||||
# virtualenv
|
||||
# wandb
|
||||
pluggy==1.6.0
|
||||
# via
|
||||
# pytest
|
||||
# pytest-cov
|
||||
pre-commit==4.3.0
|
||||
pre-commit==4.5.1
|
||||
# via lerobot
|
||||
prompt-toolkit==3.0.52
|
||||
# via
|
||||
# inquirerpy
|
||||
# ipython
|
||||
# via ipython
|
||||
propcache==0.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
protobuf==6.31.0
|
||||
protobuf==6.31.1
|
||||
# via
|
||||
# dm-control
|
||||
# grpcio-tools
|
||||
@@ -550,7 +568,7 @@ protobuf==6.31.0
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# wandb
|
||||
psutil==7.1.1
|
||||
psutil==7.2.2
|
||||
# via
|
||||
# accelerate
|
||||
# imageio
|
||||
@@ -560,17 +578,17 @@ ptyprocess==0.7.0
|
||||
# via pexpect
|
||||
pure-eval==0.2.3
|
||||
# via stack-data
|
||||
pyarrow==21.0.0
|
||||
pyarrow==23.0.1
|
||||
# via
|
||||
# datasets
|
||||
# rerun-sdk
|
||||
pycparser==2.23
|
||||
pycparser==3.0
|
||||
# via cffi
|
||||
pydantic==2.12.3
|
||||
pydantic==2.12.5
|
||||
# via
|
||||
# fastapi
|
||||
# wandb
|
||||
pydantic-core==2.41.4
|
||||
pydantic-core==2.41.5
|
||||
# via pydantic
|
||||
pygame==2.6.1
|
||||
# via
|
||||
@@ -580,12 +598,14 @@ pygame==2.6.1
|
||||
pygments==2.19.2
|
||||
# via
|
||||
# ipython
|
||||
# ipython-pygments-lexers
|
||||
# pytest
|
||||
# rich
|
||||
pymunk==6.11.1
|
||||
# via
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
pyngrok==7.4.1
|
||||
pyngrok==7.5.1
|
||||
# via meshcat
|
||||
pynput==1.8.1
|
||||
# via
|
||||
@@ -595,7 +615,7 @@ pyopengl==3.1.10
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
pyparsing==3.2.5
|
||||
pyparsing==3.3.2
|
||||
# via
|
||||
# dm-control
|
||||
# matplotlib
|
||||
@@ -621,13 +641,16 @@ pytest-timeout==2.4.0
|
||||
# via lerobot
|
||||
python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# faker
|
||||
# matplotlib
|
||||
# pandas
|
||||
python-dotenv==1.1.1
|
||||
python-discovery==1.1.1
|
||||
# via virtualenv
|
||||
python-dotenv==1.2.2
|
||||
# via uvicorn
|
||||
python-xlib==0.33
|
||||
# via pynput
|
||||
pytz==2025.2
|
||||
pytz==2026.1.post1
|
||||
# via pandas
|
||||
pyyaml==6.0.3
|
||||
# via
|
||||
@@ -642,7 +665,6 @@ pyyaml==6.0.3
|
||||
# pre-commit
|
||||
# pyngrok
|
||||
# pyyaml-include
|
||||
# timm
|
||||
# transformers
|
||||
# uvicorn
|
||||
# wandb
|
||||
@@ -652,7 +674,9 @@ pyzmq==27.1.0
|
||||
# via
|
||||
# lerobot
|
||||
# meshcat
|
||||
reachy2-sdk==1.0.14
|
||||
qwen-vl-utils==0.0.14
|
||||
# via lerobot
|
||||
reachy2-sdk==1.0.15
|
||||
# via lerobot
|
||||
reachy2-sdk-api==1.0.21
|
||||
# via reachy2-sdk
|
||||
@@ -660,7 +684,7 @@ referencing==0.37.0
|
||||
# via
|
||||
# jsonschema
|
||||
# jsonschema-specifications
|
||||
regex==2025.10.23
|
||||
regex==2026.2.28
|
||||
# via
|
||||
# diffusers
|
||||
# transformers
|
||||
@@ -669,60 +693,62 @@ requests==2.32.5
|
||||
# datasets
|
||||
# diffusers
|
||||
# dm-control
|
||||
# huggingface-hub
|
||||
# qwen-vl-utils
|
||||
# teleop
|
||||
# transformers
|
||||
# wandb
|
||||
rerun-sdk==0.26.1
|
||||
rerun-sdk==0.26.2
|
||||
# via lerobot
|
||||
rhoban-cmeel-jsoncpp==1.9.4.9
|
||||
# via placo
|
||||
rich==14.3.3
|
||||
# via typer
|
||||
robomimic==0.2.0
|
||||
# via libero
|
||||
# via hf-libero
|
||||
robosuite==1.4.0
|
||||
# via libero
|
||||
rpds-py==0.28.0
|
||||
# via hf-libero
|
||||
rpds-py==0.30.0
|
||||
# via
|
||||
# jsonschema
|
||||
# referencing
|
||||
safetensors==0.6.2
|
||||
safetensors==0.7.0
|
||||
# via
|
||||
# accelerate
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# transformers
|
||||
scikit-image==0.25.2
|
||||
# via
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
scipy==1.15.3
|
||||
scipy==1.17.1
|
||||
# via
|
||||
# dm-control
|
||||
# lerobot
|
||||
# metaworld
|
||||
# robosuite
|
||||
# scikit-image
|
||||
sentry-sdk==2.42.1
|
||||
# torchdiffeq
|
||||
sentry-sdk==2.54.0
|
||||
# via wandb
|
||||
shapely==2.1.2
|
||||
# via gym-pusht
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
six==1.17.0
|
||||
# via
|
||||
# pynput
|
||||
# python-dateutil
|
||||
# python-xlib
|
||||
smmap==5.0.2
|
||||
smmap==5.0.3
|
||||
# via gitdb
|
||||
sniffio==1.3.1
|
||||
# via anyio
|
||||
stack-data==0.6.3
|
||||
# via ipython
|
||||
starlette==0.48.0
|
||||
starlette==0.52.1
|
||||
# via fastapi
|
||||
sympy==1.14.0
|
||||
# via torch
|
||||
teleop==0.1.2
|
||||
teleop==0.1.4
|
||||
# via lerobot
|
||||
tensorboard==2.20.0
|
||||
# via robomimic
|
||||
@@ -730,46 +756,38 @@ tensorboard-data-server==0.7.2
|
||||
# via tensorboard
|
||||
tensorboardx==2.6.4
|
||||
# via robomimic
|
||||
termcolor==3.1.0
|
||||
termcolor==3.3.0
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
thop==0.1.1.post2209072238
|
||||
# via libero
|
||||
tifffile==2025.5.10
|
||||
# via hf-libero
|
||||
tifffile==2026.3.3
|
||||
# via scikit-image
|
||||
timm==1.0.20
|
||||
# via lerobot
|
||||
tokenizers==0.22.1
|
||||
tokenizers==0.22.2
|
||||
# via transformers
|
||||
toml==0.10.2
|
||||
# via draccus
|
||||
tomli==2.3.0
|
||||
# via
|
||||
# cmeel
|
||||
# coverage
|
||||
# jupytext
|
||||
# pytest
|
||||
torch==2.7.1
|
||||
torch==2.10.0
|
||||
# via
|
||||
# accelerate
|
||||
# flash-attn
|
||||
# lerobot
|
||||
# peft
|
||||
# robomimic
|
||||
# thop
|
||||
# timm
|
||||
# torchdiffeq
|
||||
# torchvision
|
||||
torchcodec==0.5
|
||||
torchcodec==0.10.0
|
||||
# via lerobot
|
||||
torchvision==0.22.1
|
||||
torchdiffeq==0.2.5
|
||||
# via lerobot
|
||||
torchvision==0.25.0
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
# timm
|
||||
tornado==6.5.2
|
||||
tornado==6.5.4
|
||||
# via meshcat
|
||||
tqdm==4.67.1
|
||||
tqdm==4.67.3
|
||||
# via
|
||||
# datasets
|
||||
# dm-control
|
||||
@@ -783,26 +801,29 @@ traitlets==5.14.3
|
||||
# jupyter-core
|
||||
# matplotlib-inline
|
||||
# nbformat
|
||||
transformers==4.57.1
|
||||
transformers==5.3.0
|
||||
# via
|
||||
# hf-libero
|
||||
# lerobot
|
||||
# libero
|
||||
# peft
|
||||
transforms3d==0.4.2
|
||||
# via teleop
|
||||
triton==3.3.1
|
||||
triton==3.6.0
|
||||
# via torch
|
||||
typer==0.24.1
|
||||
# via
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
typing-extensions==4.15.0
|
||||
# via
|
||||
# aiosignal
|
||||
# anyio
|
||||
# etils
|
||||
# exceptiongroup
|
||||
# faker
|
||||
# fastapi
|
||||
# gymnasium
|
||||
# huggingface-hub
|
||||
# ipython
|
||||
# multidict
|
||||
# mypy
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# referencing
|
||||
@@ -811,46 +832,46 @@ typing-extensions==4.15.0
|
||||
# torch
|
||||
# typing-inspect
|
||||
# typing-inspection
|
||||
# uvicorn
|
||||
# virtualenv
|
||||
# wandb
|
||||
typing-inspect==0.9.0
|
||||
# via draccus
|
||||
typing-inspection==0.4.2
|
||||
# via pydantic
|
||||
tzdata==2025.2
|
||||
# via
|
||||
# fastapi
|
||||
# pydantic
|
||||
tzdata==2025.3
|
||||
# via pandas
|
||||
u-msgpack-python==2.8.0
|
||||
# via meshcat
|
||||
urllib3==2.5.0
|
||||
urllib3==2.6.3
|
||||
# via
|
||||
# requests
|
||||
# sentry-sdk
|
||||
uvicorn[standard]==0.38.0
|
||||
uvicorn[standard]==0.41.0
|
||||
# via teleop
|
||||
uvloop==0.22.1
|
||||
# via uvicorn
|
||||
virtualenv==20.35.3
|
||||
virtualenv==21.1.0
|
||||
# via pre-commit
|
||||
wandb==0.21.4
|
||||
wandb==0.24.2
|
||||
# via
|
||||
# hf-libero
|
||||
# lerobot
|
||||
# libero
|
||||
watchfiles==1.1.1
|
||||
# via uvicorn
|
||||
wcwidth==0.2.14
|
||||
wcwidth==0.6.0
|
||||
# via prompt-toolkit
|
||||
websocket-client==1.9.0
|
||||
# via teleop
|
||||
websockets==15.0.1
|
||||
websockets==16.0
|
||||
# via uvicorn
|
||||
werkzeug==3.1.3
|
||||
werkzeug==3.1.6
|
||||
# via tensorboard
|
||||
wrapt==2.0.0
|
||||
wrapt==2.1.2
|
||||
# via dm-tree
|
||||
xxhash==3.6.0
|
||||
# via datasets
|
||||
yarl==1.22.0
|
||||
yarl==1.23.0
|
||||
# via aiohttp
|
||||
zipp==3.23.0
|
||||
# via
|
||||
|
||||
+4
-4
@@ -1,9 +1,9 @@
|
||||
# requirements.in
|
||||
|
||||
# requirements-macos.txt was generated on macOS and is platform-specific (macOS 26.0.1 25A362 arm64).
|
||||
# Darwin MacBook-Pro.local 25.0.0 Darwin Kernel Version 25.0.0: Wed Sep 17 21:42:08 PDT 2025; root:xnu-12377.1.9~141/RELEASE_ARM64_T8132 arm64
|
||||
# requirements-macos.txt was generated on macOS and is platform-specific (macOS 26.3.1 25D2128 arm64).
|
||||
# Darwin MacBook-Pro.local 25.3.0 Darwin Kernel Version 25.3.0: Wed Jan 28 20:54:55 PST 2026; root:xnu-12377.91.3~2/RELEASE_ARM64_T8132 arm64
|
||||
|
||||
# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.3 LTS x86_64).
|
||||
# Linux mlerobot-linux 6.14.0-33-generic #33~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Sep 19 17:02:30 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
|
||||
# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.4 LTS x86_64).
|
||||
# Linux lerobot-linux 6.17.0-14-generic #14~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Thu Jan 15 15:52:10 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
|
||||
|
||||
-e .[all]
|
||||
|
||||
@@ -49,9 +49,14 @@ import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.robots import (
|
||||
RobotConfig, # noqa: F401
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_so_follower,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
so_follower,
|
||||
)
|
||||
from lerobot.transport import (
|
||||
services_pb2, # type: ignore
|
||||
|
||||
@@ -181,7 +181,7 @@ class ZMQCamera(Camera):
|
||||
try:
|
||||
message = self.socket.recv_string()
|
||||
except Exception as e:
|
||||
# Check for ZMQ timeout (EAGAIN/Again) without requiring global zmq import
|
||||
# zmq is lazy-imported in connect(), so check by name to avoid a top-level import
|
||||
if type(e).__name__ == "Again":
|
||||
raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e
|
||||
raise
|
||||
|
||||
@@ -23,6 +23,7 @@ import base64
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
@@ -42,10 +43,57 @@ def encode_image(image: np.ndarray, quality: int = 80) -> str:
|
||||
return base64.b64encode(buffer).decode("utf-8")
|
||||
|
||||
|
||||
class CameraCaptureThread:
|
||||
"""Background thread that continuously captures and encodes frames from a camera."""
|
||||
|
||||
def __init__(self, camera: OpenCVCamera, name: str):
|
||||
self.camera = camera
|
||||
self.name = name
|
||||
self.latest_encoded: str | None = None # Pre-encoded JPEG as base64
|
||||
self.latest_timestamp: float = 0.0
|
||||
self.frame_lock = threading.Lock()
|
||||
self.running = False
|
||||
self.thread: threading.Thread | None = None
|
||||
|
||||
def start(self):
|
||||
"""Start the capture thread."""
|
||||
self.running = True
|
||||
self.thread = threading.Thread(target=self._capture_loop, daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
def stop(self):
|
||||
"""Stop the capture thread."""
|
||||
self.running = False
|
||||
if self.thread:
|
||||
self.thread.join(timeout=1.0)
|
||||
|
||||
def _capture_loop(self):
|
||||
"""Continuously capture and encode frames at the camera's native rate."""
|
||||
while self.running:
|
||||
try:
|
||||
frame = self.camera.read() # Blocks at camera's native rate
|
||||
timestamp = time.time()
|
||||
# Encode immediately in capture thread (this is the slow part)
|
||||
encoded = encode_image(frame)
|
||||
with self.frame_lock:
|
||||
self.latest_encoded = encoded
|
||||
self.latest_timestamp = timestamp
|
||||
except Exception as e:
|
||||
logger.warning(f"Camera {self.name} capture error: {e}")
|
||||
time.sleep(0.01)
|
||||
|
||||
def get_latest(self) -> tuple[str | None, float]:
|
||||
"""Get the latest encoded frame and its timestamp."""
|
||||
with self.frame_lock:
|
||||
return self.latest_encoded, self.latest_timestamp
|
||||
|
||||
|
||||
class ImageServer:
|
||||
def __init__(self, config: dict, port: int = 5555):
|
||||
# fps controls the publish loop rate (how often frames are sent over ZMQ), not the camera capture rate
|
||||
self.fps = config.get("fps", 30)
|
||||
self.cameras: dict[str, OpenCVCamera] = {}
|
||||
self.capture_threads: dict[str, CameraCaptureThread] = {}
|
||||
|
||||
for name, cfg in config.get("cameras", {}).items():
|
||||
shape = cfg.get("shape", [480, 640])
|
||||
@@ -61,6 +109,10 @@ class ImageServer:
|
||||
self.cameras[name] = camera
|
||||
logger.info(f"Camera {name}: {shape[1]}x{shape[0]}")
|
||||
|
||||
# Create capture thread for this camera
|
||||
capture_thread = CameraCaptureThread(camera, name)
|
||||
self.capture_threads[name] = capture_thread
|
||||
|
||||
# ZMQ PUB socket
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(zmq.PUB)
|
||||
@@ -73,6 +125,18 @@ class ImageServer:
|
||||
def run(self):
|
||||
frame_count = 0
|
||||
frame_times = deque(maxlen=60)
|
||||
last_published_ts: dict[str, float] = {}
|
||||
|
||||
# Start all capture threads
|
||||
for capture_thread in self.capture_threads.values():
|
||||
capture_thread.start()
|
||||
|
||||
# Wait for first frames to be captured and encoded
|
||||
logger.info("Waiting for cameras to start capturing...")
|
||||
for name, capture_thread in self.capture_threads.items():
|
||||
while capture_thread.get_latest()[0] is None:
|
||||
time.sleep(0.01)
|
||||
logger.info(f"Camera {name} ready (capture + encode in background)")
|
||||
|
||||
try:
|
||||
while True:
|
||||
@@ -80,10 +144,12 @@ class ImageServer:
|
||||
|
||||
# Build message
|
||||
message = {"timestamps": {}, "images": {}}
|
||||
for name, cam in self.cameras.items():
|
||||
frame = cam.read() # Returns RGB
|
||||
message["timestamps"][name] = time.time()
|
||||
message["images"][name] = encode_image(frame)
|
||||
for name, capture_thread in self.capture_threads.items():
|
||||
encoded, timestamp = capture_thread.get_latest()
|
||||
if encoded is not None and timestamp > last_published_ts.get(name, 0.0):
|
||||
message["timestamps"][name] = timestamp
|
||||
message["images"][name] = encoded
|
||||
last_published_ts[name] = timestamp
|
||||
|
||||
# Send as JSON string (suppress if buffer full)
|
||||
with contextlib.suppress(zmq.Again):
|
||||
@@ -102,6 +168,8 @@ class ImageServer:
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
for capture_thread in self.capture_threads.values():
|
||||
capture_thread.stop()
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
self.socket.close()
|
||||
|
||||
@@ -16,18 +16,13 @@
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.datasets.transforms import ImageTransformsConfig
|
||||
from lerobot.datasets.transforms import DatasetTransformStepConfig, ImageTransformsConfig
|
||||
from lerobot.datasets.video_utils import get_safe_default_codec
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetConfig:
|
||||
# You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data
|
||||
# keys common between the datasets are kept. Each dataset gets and additional transform that inserts the
|
||||
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
|
||||
# datasets are provided.
|
||||
repo_id: str
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
root: str | None = None
|
||||
episodes: list[int] | None = None
|
||||
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
||||
@@ -37,6 +32,32 @@ class DatasetConfig:
|
||||
streaming: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class SubDatasetConfig:
|
||||
"""Configuration for a single dataset within a MultiDatasetConfig."""
|
||||
|
||||
repo_id: str
|
||||
root: str | None = None
|
||||
episodes: list[int] | None = None
|
||||
revision: str | None = None
|
||||
video_backend: str = field(default_factory=get_safe_default_codec)
|
||||
weight: float = 1.0
|
||||
# Maps dataset-local feature keys to unified policy keys.
|
||||
# Keys not listed pass through unchanged.
|
||||
feature_map: dict[str, str] = field(default_factory=dict)
|
||||
# Per-dataset transforms applied after feature renaming, before cross-dataset padding.
|
||||
transforms: list[DatasetTransformStepConfig] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiDatasetConfig:
|
||||
"""Configuration for training on multiple datasets jointly."""
|
||||
|
||||
datasets: list[SubDatasetConfig] = field(default_factory=list)
|
||||
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
||||
use_imagenet_stats: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class WandBConfig:
|
||||
enable: bool = False
|
||||
|
||||
@@ -24,7 +24,7 @@ from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot import envs
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from lerobot.configs.default import DatasetConfig, EvalConfig, MultiDatasetConfig, PeftConfig, WandBConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.optim import OptimizerConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
@@ -35,7 +35,7 @@ TRAIN_CONFIG_NAME = "train_config.json"
|
||||
|
||||
@dataclass
|
||||
class TrainPipelineConfig(HubMixin):
|
||||
dataset: DatasetConfig
|
||||
dataset: DatasetConfig | MultiDatasetConfig
|
||||
env: envs.EnvConfig | None = None
|
||||
policy: PreTrainedConfig | None = None
|
||||
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
|
||||
@@ -50,6 +50,9 @@ class TrainPipelineConfig(HubMixin):
|
||||
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
||||
# AND for the evaluation environments.
|
||||
seed: int | None = 1000
|
||||
# Set to True to use deterministic cuDNN algorithms for reproducibility.
|
||||
# This disables cudnn.benchmark and may reduce training speed by ~10-20%.
|
||||
cudnn_deterministic: bool = False
|
||||
# Number of workers for the dataloader.
|
||||
num_workers: int = 4
|
||||
batch_size: int = 8
|
||||
@@ -126,8 +129,9 @@ class TrainPipelineConfig(HubMixin):
|
||||
train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
|
||||
self.output_dir = Path("outputs/train") / train_dir
|
||||
|
||||
if isinstance(self.dataset.repo_id, list):
|
||||
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
|
||||
if isinstance(self.dataset, MultiDatasetConfig):
|
||||
if len(self.dataset.datasets) < 1:
|
||||
raise ValueError("MultiDatasetConfig.datasets must contain at least one sub-dataset.")
|
||||
|
||||
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
|
||||
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
|
||||
@@ -140,8 +144,7 @@ class TrainPipelineConfig(HubMixin):
|
||||
"'policy.repo_id' argument missing. Please specify it to push the model to the hub."
|
||||
)
|
||||
|
||||
if self.use_rabc and not self.rabc_progress_path:
|
||||
# Auto-detect from dataset path
|
||||
if self.use_rabc and not self.rabc_progress_path and isinstance(self.dataset, DatasetConfig):
|
||||
repo_id = self.dataset.repo_id
|
||||
if self.dataset.root:
|
||||
self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet")
|
||||
|
||||
@@ -289,7 +289,9 @@ def aggregate_datasets(
|
||||
|
||||
logging.info("Find all tasks")
|
||||
unique_tasks = pd.concat([m.tasks for m in all_metadata]).index.unique()
|
||||
dst_meta.tasks = pd.DataFrame({"task_index": range(len(unique_tasks))}, index=unique_tasks)
|
||||
dst_meta.tasks = pd.DataFrame(
|
||||
{"task_index": range(len(unique_tasks))}, index=pd.Index(unique_tasks, name="task")
|
||||
)
|
||||
|
||||
meta_idx = {"chunk": 0, "file": 0}
|
||||
data_idx = {"chunk": 0, "file": 0}
|
||||
|
||||
@@ -89,8 +89,8 @@ def delete_episodes(
|
||||
Args:
|
||||
dataset: The source LeRobotDataset.
|
||||
episode_indices: List of episode indices to delete.
|
||||
output_dir: Directory to save the new dataset. If None, uses default location.
|
||||
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
|
||||
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||
"""
|
||||
if not episode_indices:
|
||||
raise ValueError("No episodes to delete")
|
||||
@@ -152,7 +152,7 @@ def split_dataset(
|
||||
dataset: The source LeRobotDataset to split.
|
||||
splits: Either a dict mapping split names to episode indices, or a dict mapping
|
||||
split names to fractions (must sum to <= 1.0).
|
||||
output_dir: Base directory for output datasets. If None, uses default location.
|
||||
output_dir: Root directory where the split datasets will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
|
||||
Examples:
|
||||
Split by specific episodes
|
||||
@@ -243,8 +243,8 @@ def merge_datasets(
|
||||
|
||||
Args:
|
||||
datasets: List of LeRobotDatasets to merge.
|
||||
output_repo_id: Repository ID for the merged dataset.
|
||||
output_dir: Directory to save the merged dataset. If None, uses default location.
|
||||
output_repo_id: Merged dataset identifier.
|
||||
output_dir: Root directory where the merged dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/output_repo_id.
|
||||
"""
|
||||
if not datasets:
|
||||
raise ValueError("No datasets to merge")
|
||||
@@ -288,8 +288,8 @@ def modify_features(
|
||||
dataset: The source LeRobotDataset.
|
||||
add_features: Optional dict mapping feature names to (feature_values, feature_info) tuples.
|
||||
remove_features: Optional feature name(s) to remove. Can be a single string or list.
|
||||
output_dir: Directory to save the new dataset. If None, uses default location.
|
||||
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
|
||||
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||
|
||||
Returns:
|
||||
New dataset with features modified.
|
||||
@@ -390,8 +390,8 @@ def add_features(
|
||||
Args:
|
||||
dataset: The source LeRobotDataset.
|
||||
features: Dictionary mapping feature names to (feature_values, feature_info) tuples.
|
||||
output_dir: Directory to save the new dataset. If None, uses default location.
|
||||
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
|
||||
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||
|
||||
Returns:
|
||||
New dataset with all features added.
|
||||
@@ -427,8 +427,8 @@ def remove_feature(
|
||||
Args:
|
||||
dataset: The source LeRobotDataset.
|
||||
feature_names: Name(s) of features to remove. Can be a single string or list.
|
||||
output_dir: Directory to save the new dataset. If None, uses default location.
|
||||
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
|
||||
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||
|
||||
Returns:
|
||||
New dataset with features removed.
|
||||
@@ -1475,7 +1475,9 @@ def modify_tasks(
|
||||
|
||||
# Collect all unique tasks and create new task mapping
|
||||
unique_tasks = sorted(set(episode_to_task.values()))
|
||||
new_task_df = pd.DataFrame({"task_index": list(range(len(unique_tasks)))}, index=unique_tasks)
|
||||
new_task_df = pd.DataFrame(
|
||||
{"task_index": list(range(len(unique_tasks)))}, index=pd.Index(unique_tasks, name="task")
|
||||
)
|
||||
task_to_index = {task: idx for idx, task in enumerate(unique_tasks)}
|
||||
|
||||
logging.info(f"Modifying tasks in {dataset.repo_id}")
|
||||
@@ -1529,7 +1531,7 @@ def modify_tasks(
|
||||
|
||||
def convert_image_to_video_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
output_dir: Path,
|
||||
output_dir: Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
@@ -1548,8 +1550,8 @@ def convert_image_to_video_dataset(
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobot dataset with images
|
||||
output_dir: Directory to save the new video dataset
|
||||
repo_id: Repository ID for the new dataset (default: original_id + "_video")
|
||||
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||
vcodec: Video codec (default: libsvtav1)
|
||||
pix_fmt: Pixel format (default: yuv420p)
|
||||
g: Group of pictures size (default: 2)
|
||||
@@ -1600,6 +1602,7 @@ def convert_image_to_video_dataset(
|
||||
# Video info will be updated after episodes are encoded
|
||||
|
||||
# Create new metadata for video dataset
|
||||
output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
|
||||
new_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
fps=dataset.meta.fps,
|
||||
|
||||
@@ -18,13 +18,14 @@ from pprint import pformat
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.default import DatasetConfig, MultiDatasetConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
LeRobotDatasetMetadata,
|
||||
MultiLeRobotDataset,
|
||||
)
|
||||
from lerobot.datasets.multi_dataset import NewMultiLeRobotDataset
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.datasets.transforms import ImageTransforms
|
||||
from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
|
||||
@@ -68,66 +69,81 @@ def resolve_delta_timestamps(
|
||||
return delta_timestamps
|
||||
|
||||
|
||||
def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset:
|
||||
"""Handles the logic of setting up delta timestamps and image transforms before creating a dataset.
|
||||
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): A TrainPipelineConfig config which contains a DatasetConfig and a PreTrainedConfig.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: The MultiLeRobotDataset is currently deactivated.
|
||||
def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | NewMultiLeRobotDataset:
|
||||
"""Create a single or multi-dataset depending on the config type.
|
||||
|
||||
Returns:
|
||||
LeRobotDataset | MultiLeRobotDataset
|
||||
LeRobotDataset | NewMultiLeRobotDataset
|
||||
"""
|
||||
if isinstance(cfg.dataset, MultiDatasetConfig):
|
||||
return _make_multi_dataset(cfg)
|
||||
|
||||
return _make_single_dataset(cfg)
|
||||
|
||||
|
||||
def _make_single_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset:
|
||||
ds_cfg: DatasetConfig = cfg.dataset # type: ignore[assignment]
|
||||
image_transforms = (
|
||||
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
|
||||
ImageTransforms(ds_cfg.image_transforms) if ds_cfg.image_transforms.enable else None
|
||||
)
|
||||
ds_meta = LeRobotDatasetMetadata(ds_cfg.repo_id, root=ds_cfg.root, revision=ds_cfg.revision)
|
||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||
|
||||
if isinstance(cfg.dataset.repo_id, str):
|
||||
ds_meta = LeRobotDatasetMetadata(
|
||||
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
|
||||
)
|
||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||
if not cfg.dataset.streaming:
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
episodes=cfg.dataset.episodes,
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
)
|
||||
else:
|
||||
dataset = StreamingLeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
episodes=cfg.dataset.episodes,
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
max_num_shards=cfg.num_workers,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
||||
dataset = MultiLeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
# TODO(aliberts): add proper support for multi dataset
|
||||
# delta_timestamps=delta_timestamps,
|
||||
if not ds_cfg.streaming:
|
||||
dataset = LeRobotDataset(
|
||||
ds_cfg.repo_id,
|
||||
root=ds_cfg.root,
|
||||
episodes=ds_cfg.episodes,
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
revision=ds_cfg.revision,
|
||||
video_backend=ds_cfg.video_backend,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
)
|
||||
logging.info(
|
||||
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
|
||||
f"{pformat(dataset.repo_id_to_index, indent=2)}"
|
||||
else:
|
||||
dataset = StreamingLeRobotDataset(
|
||||
ds_cfg.repo_id,
|
||||
root=ds_cfg.root,
|
||||
episodes=ds_cfg.episodes,
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
revision=ds_cfg.revision,
|
||||
max_num_shards=cfg.num_workers,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
)
|
||||
|
||||
if cfg.dataset.use_imagenet_stats:
|
||||
if ds_cfg.use_imagenet_stats:
|
||||
for key in dataset.meta.camera_keys:
|
||||
for stats_type, stats in IMAGENET_STATS.items():
|
||||
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||
for stats_type, stats_val in IMAGENET_STATS.items():
|
||||
dataset.meta.stats[key][stats_type] = torch.tensor(stats_val, dtype=torch.float32)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def _make_multi_dataset(cfg: TrainPipelineConfig) -> NewMultiLeRobotDataset:
|
||||
multi_cfg: MultiDatasetConfig = cfg.dataset # type: ignore[assignment]
|
||||
image_transforms = (
|
||||
ImageTransforms(multi_cfg.image_transforms) if multi_cfg.image_transforms.enable else None
|
||||
)
|
||||
|
||||
dataset = NewMultiLeRobotDataset(
|
||||
configs=multi_cfg.datasets,
|
||||
image_transforms=image_transforms,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
)
|
||||
|
||||
logging.info(
|
||||
"MultiLeRobotDataset created with %d sub-datasets:\n%s",
|
||||
len(multi_cfg.datasets),
|
||||
pformat(
|
||||
{i: c.repo_id for i, c in enumerate(multi_cfg.datasets)},
|
||||
indent=2,
|
||||
),
|
||||
)
|
||||
|
||||
if multi_cfg.use_imagenet_stats:
|
||||
for key in dataset.meta.camera_keys:
|
||||
for stats_type, stats_val in IMAGENET_STATS.items():
|
||||
dataset.meta.stats[key][stats_type] = torch.tensor(stats_val, dtype=torch.float32)
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -314,7 +314,7 @@ class LeRobotDatasetMetadata:
|
||||
if self.tasks is None:
|
||||
new_tasks = tasks
|
||||
task_indices = range(len(tasks))
|
||||
self.tasks = pd.DataFrame({"task_index": task_indices}, index=tasks)
|
||||
self.tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(tasks, name="task"))
|
||||
else:
|
||||
new_tasks = [task for task in tasks if task not in self.tasks.index]
|
||||
new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks))
|
||||
|
||||
@@ -0,0 +1,364 @@
|
||||
"""MultiLeRobotDataset: joint training over heterogeneous LeRobot datasets.
|
||||
|
||||
Supports:
|
||||
- Per-dataset feature mapping (rename keys to a unified namespace)
|
||||
- Automatic zero-padding for features missing in some datasets
|
||||
- Per-dataset transform pipelines
|
||||
- Weighted sampling via dataset weights
|
||||
- Aggregated stats across all sub-datasets
|
||||
- A ``meta`` shim compatible with EpisodeAwareSampler and make_policy
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.data
|
||||
|
||||
from lerobot.configs.default import SubDatasetConfig
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.transforms import DatasetTransformPipeline
|
||||
|
||||
|
||||
class MultiDatasetMeta:
|
||||
"""Lightweight metadata shim that exposes the same interface as ``LeRobotDatasetMetadata``.
|
||||
|
||||
Built by aggregating the metadata of multiple sub-datasets after their
|
||||
feature keys have been mapped to a unified namespace.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
datasets: list[LeRobotDataset],
|
||||
feature_maps: list[dict[str, str]],
|
||||
):
|
||||
self._datasets = datasets
|
||||
self._feature_maps = feature_maps
|
||||
|
||||
self._unified_features = self._build_unified_features()
|
||||
self._episodes = self._build_episodes()
|
||||
self._stats = self._build_stats()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Feature union
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _build_unified_features(self) -> dict[str, dict]:
|
||||
"""Build feature dict as the *union* of all mapped feature keys."""
|
||||
unified: dict[str, dict] = {}
|
||||
for ds, fmap in zip(self._datasets, self._feature_maps):
|
||||
for original_key, feat_info in ds.meta.features.items():
|
||||
mapped_key = fmap.get(original_key, original_key)
|
||||
if mapped_key not in unified:
|
||||
unified[mapped_key] = dict(feat_info)
|
||||
else:
|
||||
existing_shape = tuple(unified[mapped_key]["shape"])
|
||||
new_shape = tuple(feat_info["shape"])
|
||||
if existing_shape != new_shape and unified[mapped_key]["dtype"] == feat_info["dtype"]:
|
||||
logging.warning(
|
||||
"Feature '%s' has shape %s in one dataset but %s in another. "
|
||||
"The larger shape will be used (padding applied automatically).",
|
||||
mapped_key,
|
||||
existing_shape,
|
||||
new_shape,
|
||||
)
|
||||
if np.prod(new_shape) > np.prod(existing_shape):
|
||||
unified[mapped_key] = dict(feat_info)
|
||||
return unified
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Episode metadata (global flat indexing)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _build_episodes(self) -> dict[str, list]:
|
||||
"""Concatenate episode boundaries across sub-datasets with frame offsets.
|
||||
|
||||
Produces the same column structure as ``load_episodes()`` so that
|
||||
``EpisodeAwareSampler`` and ``WeightedEpisodeAwareSampler`` can consume it.
|
||||
"""
|
||||
from_indices: list[int] = []
|
||||
to_indices: list[int] = []
|
||||
dataset_source: list[int] = []
|
||||
|
||||
frame_offset = 0
|
||||
for ds_idx, ds in enumerate(self._datasets):
|
||||
eps = ds.meta.episodes
|
||||
for ep in eps:
|
||||
from_indices.append(ep["dataset_from_index"] + frame_offset)
|
||||
to_indices.append(ep["dataset_to_index"] + frame_offset)
|
||||
dataset_source.append(ds_idx)
|
||||
frame_offset += ds.num_frames
|
||||
|
||||
return {
|
||||
"dataset_from_index": from_indices,
|
||||
"dataset_to_index": to_indices,
|
||||
"dataset_source": dataset_source,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Stats aggregation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _build_stats(self) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Aggregate stats across sub-datasets using mapped feature keys."""
|
||||
mapped_stats_list: list[dict[str, dict]] = []
|
||||
for ds, fmap in zip(self._datasets, self._feature_maps):
|
||||
reverse_map = {v: k for k, v in fmap.items()}
|
||||
mapped: dict[str, dict] = {}
|
||||
for unified_key in self._unified_features:
|
||||
original_key = reverse_map.get(unified_key, unified_key)
|
||||
if original_key in ds.meta.stats:
|
||||
mapped[unified_key] = ds.meta.stats[original_key]
|
||||
mapped_stats_list.append(mapped)
|
||||
|
||||
return aggregate_stats(mapped_stats_list)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Properties matching LeRobotDatasetMetadata API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def features(self) -> dict[str, dict]:
|
||||
return self._unified_features
|
||||
|
||||
@property
|
||||
def image_keys(self) -> list[str]:
|
||||
return [k for k, f in self._unified_features.items() if f["dtype"] == "image"]
|
||||
|
||||
@property
|
||||
def video_keys(self) -> list[str]:
|
||||
return [k for k, f in self._unified_features.items() if f["dtype"] == "video"]
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
return [k for k, f in self._unified_features.items() if f["dtype"] in ("video", "image")]
|
||||
|
||||
@property
|
||||
def names(self) -> dict[str, list | dict]:
|
||||
return {k: f["names"] for k, f in self._unified_features.items()}
|
||||
|
||||
@property
|
||||
def shapes(self) -> dict[str, tuple]:
|
||||
return {k: tuple(f["shape"]) for k, f in self._unified_features.items()}
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
fps_values = {ds.meta.fps for ds in self._datasets}
|
||||
if len(fps_values) > 1:
|
||||
logging.warning("Sub-datasets have different FPS values: %s. Using the first.", fps_values)
|
||||
return self._datasets[0].meta.fps
|
||||
|
||||
@property
|
||||
def stats(self) -> dict[str, dict[str, np.ndarray]]:
|
||||
return self._stats
|
||||
|
||||
@stats.setter
|
||||
def stats(self, value: dict):
|
||||
self._stats = value
|
||||
|
||||
@property
|
||||
def episodes(self) -> dict[str, list]:
|
||||
return self._episodes
|
||||
|
||||
@property
|
||||
def total_episodes(self) -> int:
|
||||
return sum(ds.meta.total_episodes for ds in self._datasets)
|
||||
|
||||
@property
|
||||
def total_frames(self) -> int:
|
||||
return sum(ds.meta.total_frames for ds in self._datasets)
|
||||
|
||||
@property
|
||||
def total_tasks(self) -> int:
|
||||
return sum(ds.meta.total_tasks for ds in self._datasets)
|
||||
|
||||
@property
|
||||
def info(self) -> dict:
|
||||
return {
|
||||
"fps": self.fps,
|
||||
"features": self._unified_features,
|
||||
"total_episodes": self.total_episodes,
|
||||
"total_frames": self.total_frames,
|
||||
"total_tasks": self.total_tasks,
|
||||
"codebase_version": "v3.0",
|
||||
}
|
||||
|
||||
|
||||
class NewMultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
"""Dataset that wraps multiple ``LeRobotDataset`` instances with feature mapping and padding.
|
||||
|
||||
Each sub-dataset can have different feature names and shapes. A per-dataset
|
||||
``feature_map`` renames keys into a shared namespace. Features that a given
|
||||
sub-dataset does not provide are zero-padded so every ``__getitem__`` returns
|
||||
the full unified feature set.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
configs: list[SubDatasetConfig],
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[str, list[float]] | None = None,
|
||||
tolerance_s: float = 1e-4,
|
||||
):
|
||||
super().__init__()
|
||||
self._configs = configs
|
||||
self.image_transforms = image_transforms
|
||||
|
||||
self._datasets: list[LeRobotDataset] = []
|
||||
self._feature_maps: list[dict[str, str]] = []
|
||||
self._transform_pipelines: list[DatasetTransformPipeline | None] = []
|
||||
self._weights: list[float] = []
|
||||
|
||||
for cfg in configs:
|
||||
ds = LeRobotDataset(
|
||||
repo_id=cfg.repo_id,
|
||||
root=cfg.root,
|
||||
episodes=cfg.episodes,
|
||||
image_transforms=image_transforms,
|
||||
delta_timestamps=delta_timestamps,
|
||||
tolerance_s=tolerance_s,
|
||||
revision=cfg.revision,
|
||||
video_backend=cfg.video_backend,
|
||||
)
|
||||
self._datasets.append(ds)
|
||||
self._feature_maps.append(cfg.feature_map or {})
|
||||
self._transform_pipelines.append(
|
||||
DatasetTransformPipeline(cfg.transforms) if cfg.transforms else None
|
||||
)
|
||||
self._weights.append(cfg.weight)
|
||||
|
||||
self._meta = MultiDatasetMeta(self._datasets, self._feature_maps)
|
||||
|
||||
# Pre-compute cumulative frame counts for fast index mapping.
|
||||
self._cumulative_frames: list[int] = []
|
||||
total = 0
|
||||
for ds in self._datasets:
|
||||
total += ds.num_frames
|
||||
self._cumulative_frames.append(total)
|
||||
|
||||
# Build reverse maps (unified_key -> original_key) per dataset for padding.
|
||||
self._reverse_maps: list[dict[str, str]] = []
|
||||
for fmap in self._feature_maps:
|
||||
self._reverse_maps.append({v: k for k, v in fmap.items()})
|
||||
|
||||
logging.info(
|
||||
"MultiLeRobotDataset: %d sub-datasets, %d total frames, %d total episodes, "
|
||||
"%d unified features",
|
||||
len(self._datasets),
|
||||
self.num_frames,
|
||||
self.num_episodes,
|
||||
len(self._meta.features),
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public interface
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def meta(self) -> MultiDatasetMeta:
|
||||
return self._meta
|
||||
|
||||
@property
|
||||
def dataset_weights(self) -> list[float]:
|
||||
return self._weights
|
||||
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
return self._cumulative_frames[-1] if self._cumulative_frames else 0
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return sum(ds.num_episodes for ds in self._datasets)
|
||||
|
||||
@property
|
||||
def episodes(self) -> list[int] | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
return self._meta.fps
|
||||
|
||||
@property
|
||||
def features(self) -> dict[str, dict]:
|
||||
return self._meta.features
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
return self._meta.camera_keys
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Indexing
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _locate(self, idx: int) -> tuple[int, int]:
|
||||
"""Map a global frame index to (dataset_index, local_index)."""
|
||||
for ds_idx, cum in enumerate(self._cumulative_frames):
|
||||
if idx < cum:
|
||||
local = idx - (self._cumulative_frames[ds_idx - 1] if ds_idx > 0 else 0)
|
||||
return ds_idx, local
|
||||
raise IndexError(f"Index {idx} out of range (total {self.num_frames})")
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.num_frames
|
||||
|
||||
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
||||
ds_idx, local_idx = self._locate(idx)
|
||||
item = self._datasets[ds_idx][local_idx]
|
||||
|
||||
# 1. Rename keys according to feature_map.
|
||||
fmap = self._feature_maps[ds_idx]
|
||||
if fmap:
|
||||
renamed: dict[str, torch.Tensor] = {}
|
||||
for key, value in item.items():
|
||||
renamed[fmap.get(key, key)] = value
|
||||
item = renamed
|
||||
|
||||
# 2. Apply per-dataset transform pipeline.
|
||||
pipeline = self._transform_pipelines[ds_idx]
|
||||
if pipeline is not None:
|
||||
item = pipeline(item)
|
||||
|
||||
# 3. Pad missing features with zeros.
|
||||
reverse_map = self._reverse_maps[ds_idx]
|
||||
ds_features = self._datasets[ds_idx].meta.features
|
||||
for unified_key, feat_info in self._meta.features.items():
|
||||
if unified_key in item:
|
||||
continue
|
||||
original_key = reverse_map.get(unified_key, unified_key)
|
||||
if original_key in ds_features:
|
||||
continue
|
||||
shape = tuple(feat_info["shape"])
|
||||
dtype = feat_info["dtype"]
|
||||
if dtype in ("video", "image"):
|
||||
# Camera tensors are (C, H, W) after transforms.
|
||||
c, h, w = (shape[2], shape[0], shape[1]) if len(shape) == 3 else (3, shape[0], shape[1])
|
||||
item[unified_key] = torch.zeros(c, h, w, dtype=torch.float32)
|
||||
elif dtype in ("float32", "float64"):
|
||||
item[unified_key] = torch.zeros(shape, dtype=torch.float32)
|
||||
elif dtype in ("int32", "int64"):
|
||||
item[unified_key] = torch.zeros(shape, dtype=torch.int64)
|
||||
elif dtype == "bool":
|
||||
item[unified_key] = torch.zeros(shape, dtype=torch.bool)
|
||||
else:
|
||||
item[unified_key] = torch.zeros(shape, dtype=torch.float32)
|
||||
item[f"{unified_key}_is_pad"] = torch.tensor(True)
|
||||
|
||||
# 4. Tag which dataset this sample came from.
|
||||
item["dataset_index"] = torch.tensor(ds_idx)
|
||||
return item
|
||||
|
||||
def __repr__(self) -> str:
|
||||
repo_ids = [c.repo_id for c in self._configs]
|
||||
return (
|
||||
f"NewMultiLeRobotDataset(\n"
|
||||
f" repo_ids={repo_ids},\n"
|
||||
f" num_frames={self.num_frames},\n"
|
||||
f" num_episodes={self.num_episodes},\n"
|
||||
f" unified_features={list(self._meta.features.keys())},\n"
|
||||
f" weights={self._weights},\n"
|
||||
f")"
|
||||
)
|
||||
@@ -12,14 +12,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.processor import DataProcessorPipeline, RobotAction, RobotObservation
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
|
||||
|
||||
def create_initial_features(
|
||||
@@ -41,3 +41,99 @@ def create_initial_features(
|
||||
if observation:
|
||||
features[PipelineFeatureType.OBSERVATION] = observation
|
||||
return features
|
||||
|
||||
|
||||
# Helper to filter state/action keys based on regex patterns.
|
||||
def should_keep(key: str, patterns: tuple[str]) -> bool:
|
||||
if patterns is None:
|
||||
return True
|
||||
return any(re.search(pat, key) for pat in patterns)
|
||||
|
||||
|
||||
def strip_prefix(key: str, prefixes_to_strip: tuple[str]) -> str:
|
||||
for prefix in prefixes_to_strip:
|
||||
if key.startswith(prefix):
|
||||
return key[len(prefix) :]
|
||||
return key
|
||||
|
||||
|
||||
# Define prefixes to strip from feature keys for clean names.
|
||||
# Handles both fully qualified (e.g., "action.state") and short (e.g., "state") forms.
|
||||
PREFIXES_TO_STRIP = tuple(
|
||||
f"{token}." for const in (ACTION, OBS_STATE, OBS_IMAGES) for token in (const, const.split(".")[-1])
|
||||
)
|
||||
|
||||
|
||||
def aggregate_pipeline_dataset_features(
|
||||
pipeline: DataProcessorPipeline,
|
||||
initial_features: dict[PipelineFeatureType, dict[str, Any]],
|
||||
*,
|
||||
use_videos: bool = True,
|
||||
patterns: Sequence[str] | None = None,
|
||||
) -> dict[str, dict]:
|
||||
"""
|
||||
Aggregates and filters pipeline features to create a dataset-ready features dictionary.
|
||||
|
||||
This function transforms initial features using the pipeline, categorizes them as action or observations
|
||||
(image or state), filters them based on `use_videos` and `patterns`, and finally
|
||||
formats them for use with a Hugging Face LeRobot Dataset.
|
||||
|
||||
Args:
|
||||
pipeline: The DataProcessorPipeline to apply.
|
||||
initial_features: A dictionary of raw feature specs for actions and observations.
|
||||
use_videos: If False, image features are excluded.
|
||||
patterns: A sequence of regex patterns to filter action and state features.
|
||||
Image features are not affected by this filter.
|
||||
|
||||
Returns:
|
||||
A dictionary of features formatted for a Hugging Face LeRobot Dataset.
|
||||
"""
|
||||
all_features = pipeline.transform_features(initial_features)
|
||||
|
||||
# Intermediate storage for categorized and filtered features.
|
||||
processed_features: dict[str, dict[str, Any]] = {
|
||||
ACTION: {},
|
||||
OBS_STR: {},
|
||||
}
|
||||
images_token = OBS_IMAGES.split(".")[-1]
|
||||
|
||||
# Iterate through all features transformed by the pipeline.
|
||||
for ptype, feats in all_features.items():
|
||||
if ptype not in [PipelineFeatureType.ACTION, PipelineFeatureType.OBSERVATION]:
|
||||
continue
|
||||
|
||||
for key, value in feats.items():
|
||||
# 1. Categorize the feature.
|
||||
is_action = ptype == PipelineFeatureType.ACTION
|
||||
# Observations are classified as images if their key matches image-related tokens or if the shape of the feature is 3.
|
||||
# All other observations are treated as state.
|
||||
is_image = not is_action and (
|
||||
(isinstance(value, tuple) and len(value) == 3)
|
||||
or (
|
||||
key.startswith(f"{OBS_IMAGES}.")
|
||||
or key.startswith(f"{images_token}.")
|
||||
or f".{images_token}." in key
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Apply filtering rules.
|
||||
if is_image and not use_videos:
|
||||
continue
|
||||
if not is_image and not should_keep(key, patterns):
|
||||
continue
|
||||
|
||||
# 3. Add the feature to the appropriate group with a clean name.
|
||||
name = strip_prefix(key, PREFIXES_TO_STRIP)
|
||||
if is_action:
|
||||
processed_features[ACTION][name] = value
|
||||
else:
|
||||
processed_features[OBS_STR][name] = value
|
||||
|
||||
# Convert the processed features into the final dataset format.
|
||||
dataset_features = {}
|
||||
if processed_features[ACTION]:
|
||||
dataset_features.update(hw_to_dataset_features(processed_features[ACTION], ACTION, use_videos))
|
||||
if processed_features[OBS_STR]:
|
||||
dataset_features.update(hw_to_dataset_features(processed_features[OBS_STR], OBS_STR, use_videos))
|
||||
|
||||
return dataset_features
|
||||
|
||||
@@ -59,3 +59,80 @@ class EpisodeAwareSampler:
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.indices)
|
||||
|
||||
|
||||
class WeightedEpisodeAwareSampler:
|
||||
"""Sampler that draws frames from multiple datasets according to per-dataset weights.
|
||||
|
||||
Each iteration first selects a sub-dataset proportionally to its weight, then
|
||||
uniformly samples a frame from that sub-dataset's valid index set. Episode
|
||||
boundary information is respected so that dropped frames are excluded.
|
||||
|
||||
Args:
|
||||
dataset_from_indices: Start index for each episode (global, flat).
|
||||
dataset_to_indices: End index (exclusive) for each episode (global, flat).
|
||||
dataset_membership: Which sub-dataset each episode belongs to (integer id).
|
||||
dataset_weights: Relative sampling weight per sub-dataset.
|
||||
episode_indices_to_use: If given, only episodes in this set are used.
|
||||
drop_n_first_frames: Frames to skip at the start of each episode.
|
||||
drop_n_last_frames: Frames to skip at the end of each episode.
|
||||
shuffle: Whether to shuffle within each epoch.
|
||||
num_samples: How many samples per epoch. Defaults to total valid frames.
|
||||
generator: Optional torch.Generator for reproducibility.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_from_indices: list[int],
|
||||
dataset_to_indices: list[int],
|
||||
dataset_membership: list[int],
|
||||
dataset_weights: list[float],
|
||||
episode_indices_to_use: list | None = None,
|
||||
drop_n_first_frames: int = 0,
|
||||
drop_n_last_frames: int = 0,
|
||||
shuffle: bool = False,
|
||||
num_samples: int | None = None,
|
||||
generator: torch.Generator | None = None,
|
||||
):
|
||||
n_datasets = max(dataset_membership) + 1 if dataset_membership else 0
|
||||
self._per_dataset_indices: list[list[int]] = [[] for _ in range(n_datasets)]
|
||||
|
||||
episodes_to_use = set(episode_indices_to_use) if episode_indices_to_use is not None else None
|
||||
|
||||
for ep_idx, (start, end, ds_id) in enumerate(
|
||||
zip(dataset_from_indices, dataset_to_indices, dataset_membership, strict=True)
|
||||
):
|
||||
if episodes_to_use is not None and ep_idx not in episodes_to_use:
|
||||
continue
|
||||
frame_range = range(start + drop_n_first_frames, end - drop_n_last_frames)
|
||||
self._per_dataset_indices[ds_id].extend(frame_range)
|
||||
|
||||
# Normalise weights (only over datasets that actually have frames).
|
||||
raw_weights = list(dataset_weights[:n_datasets])
|
||||
self._weights = torch.zeros(n_datasets)
|
||||
for i, w in enumerate(raw_weights):
|
||||
if len(self._per_dataset_indices[i]) > 0:
|
||||
self._weights[i] = w
|
||||
total_w = self._weights.sum()
|
||||
if total_w > 0:
|
||||
self._weights /= total_w
|
||||
|
||||
self._total_frames = sum(len(idx) for idx in self._per_dataset_indices)
|
||||
self._num_samples = num_samples if num_samples is not None else self._total_frames
|
||||
self.shuffle = shuffle
|
||||
self._generator = generator
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
if not self.shuffle:
|
||||
for ds_indices in self._per_dataset_indices:
|
||||
yield from ds_indices
|
||||
return
|
||||
|
||||
for _ in range(self._num_samples):
|
||||
ds_id = int(torch.multinomial(self._weights, 1, generator=self._generator).item())
|
||||
indices = self._per_dataset_indices[ds_id]
|
||||
local_idx = int(torch.randint(len(indices), (1,), generator=self._generator).item())
|
||||
yield indices[local_idx]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._num_samples
|
||||
|
||||
@@ -14,11 +14,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import collections
|
||||
import logging
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F_nn
|
||||
from torchvision.transforms import v2
|
||||
from torchvision.transforms.v2 import (
|
||||
Transform,
|
||||
@@ -258,3 +260,114 @@ class ImageTransforms(Transform):
|
||||
|
||||
def forward(self, *inputs: Any) -> Any:
|
||||
return self.tf(*inputs)
|
||||
|
||||
|
||||
|
||||
# Per-dataset transform pipeline (used by MultiLeRobotDataset)
|
||||
|
||||
@dataclass
|
||||
class DatasetTransformStepConfig:
|
||||
"""Config for a single per-dataset transform step."""
|
||||
|
||||
type: str
|
||||
kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
_DATASET_TRANSFORM_REGISTRY: dict[str, type["DatasetTransformStep"]] = {}
|
||||
|
||||
|
||||
def register_dataset_transform(name: str):
|
||||
"""Decorator to register a DatasetTransformStep by name."""
|
||||
|
||||
def decorator(cls: type["DatasetTransformStep"]) -> type["DatasetTransformStep"]:
|
||||
_DATASET_TRANSFORM_REGISTRY[name] = cls
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class DatasetTransformStep:
|
||||
"""Base class for a single per-dataset transform applied to a sample dict."""
|
||||
|
||||
def __call__(self, sample: dict) -> dict:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@register_dataset_transform("pad_action")
|
||||
class PadAction(DatasetTransformStep):
|
||||
"""Zero-pad the ``action`` tensor to *target_dim* along the last axis."""
|
||||
|
||||
def __init__(self, target_dim: int):
|
||||
self.target_dim = target_dim
|
||||
|
||||
def __call__(self, sample: dict) -> dict:
|
||||
action = sample.get("action")
|
||||
if action is None:
|
||||
return sample
|
||||
current = action.shape[-1]
|
||||
if current < self.target_dim:
|
||||
sample["action"] = F_nn.pad(action, (0, self.target_dim - current))
|
||||
return sample
|
||||
|
||||
|
||||
@register_dataset_transform("pad_state")
|
||||
class PadState(DatasetTransformStep):
|
||||
"""Zero-pad ``observation.state`` to *target_dim* along the last axis."""
|
||||
|
||||
def __init__(self, target_dim: int):
|
||||
self.target_dim = target_dim
|
||||
|
||||
def __call__(self, sample: dict) -> dict:
|
||||
state = sample.get("observation.state")
|
||||
if state is None:
|
||||
return sample
|
||||
current = state.shape[-1]
|
||||
if current < self.target_dim:
|
||||
sample["observation.state"] = F_nn.pad(state, (0, self.target_dim - current))
|
||||
return sample
|
||||
|
||||
|
||||
@register_dataset_transform("resize_images")
|
||||
class ResizeImages(DatasetTransformStep):
|
||||
"""Resize all image/video camera tensors to (height, width)."""
|
||||
|
||||
def __init__(self, height: int, width: int):
|
||||
self.size = (height, width)
|
||||
|
||||
def __call__(self, sample: dict) -> dict:
|
||||
for key in list(sample.keys()):
|
||||
if not key.startswith("observation.images."):
|
||||
continue
|
||||
img = sample[key]
|
||||
if not isinstance(img, torch.Tensor) or img.ndim < 3:
|
||||
continue
|
||||
sample[key] = F.resize(img, self.size, antialias=True)
|
||||
return sample
|
||||
|
||||
|
||||
class DatasetTransformPipeline:
|
||||
"""Sequential pipeline of DatasetTransformStep instances."""
|
||||
|
||||
def __init__(self, configs: list[DatasetTransformStepConfig] | None = None):
|
||||
self.steps: list[DatasetTransformStep] = []
|
||||
if configs:
|
||||
for cfg in configs:
|
||||
self.steps.append(self._build(cfg))
|
||||
|
||||
@staticmethod
|
||||
def _build(cfg: DatasetTransformStepConfig) -> DatasetTransformStep:
|
||||
cls = _DATASET_TRANSFORM_REGISTRY.get(cfg.type)
|
||||
if cls is None:
|
||||
raise ValueError(
|
||||
f"Unknown dataset transform '{cfg.type}'. "
|
||||
f"Available: {list(_DATASET_TRANSFORM_REGISTRY)}"
|
||||
)
|
||||
return cls(**cfg.kwargs)
|
||||
|
||||
def __call__(self, sample: dict) -> dict:
|
||||
for step in self.steps:
|
||||
sample = step(sample)
|
||||
return sample
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"DatasetTransformPipeline(steps={self.steps})"
|
||||
|
||||
@@ -21,7 +21,7 @@ from collections import deque
|
||||
from collections.abc import Iterable, Iterator
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Any, Generic, TypeVar
|
||||
from typing import Any
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
@@ -78,8 +78,6 @@ DEFAULT_FEATURES = {
|
||||
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
}
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float:
|
||||
metadata = pq.read_metadata(parquet_path)
|
||||
@@ -341,6 +339,7 @@ def write_tasks(tasks: pandas.DataFrame, local_dir: Path) -> None:
|
||||
|
||||
def load_tasks(local_dir: Path) -> pandas.DataFrame:
|
||||
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
|
||||
tasks.index.name = "task"
|
||||
return tasks
|
||||
|
||||
|
||||
@@ -1233,7 +1232,7 @@ class LookAheadError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Backtrackable(Generic[T]):
|
||||
class Backtrackable[T]:
|
||||
"""
|
||||
Wrap any iterator/iterable so you can step back up to `history` items
|
||||
and look ahead up to `lookahead` items.
|
||||
|
||||
@@ -36,8 +36,11 @@ Convert a local dataset (works in place):
|
||||
```bash
|
||||
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
|
||||
--repo-id=lerobot/pusht \
|
||||
--root=/path/to/local/dataset/directory
|
||||
--root=/path/to/local/dataset/directory \
|
||||
--push-to-hub=false
|
||||
|
||||
N.B. Path semantics (v2): --root is the exact dataset folder containing
|
||||
meta/, data/, videos/. When omitted, defaults to $HF_LEROBOT_HOME/{repo_id}.
|
||||
```
|
||||
|
||||
"""
|
||||
@@ -105,7 +108,7 @@ episodes.jsonl
|
||||
{"episode_index": 1, "tasks": ["Put the blue block in the green bowl"], "length": 266}
|
||||
|
||||
NEW
|
||||
meta/episodes/chunk-000/episodes_000.parquet
|
||||
meta/episodes/chunk-000/file_000.parquet
|
||||
episode_index | video_chunk_index | video_file_index | data_chunk_index | data_file_index | tasks | length
|
||||
-------------------------
|
||||
OLD
|
||||
@@ -113,15 +116,16 @@ tasks.jsonl
|
||||
{"task_index": 1, "task": "Put the blue block in the green bowl"}
|
||||
|
||||
NEW
|
||||
meta/tasks/chunk-000/file_000.parquet
|
||||
meta/tasks.parquet
|
||||
task_index | task
|
||||
-------------------------
|
||||
OLD
|
||||
episodes_stats.jsonl
|
||||
{"episode_index": 1, "stats": {"feature_name": {"min": ..., "max": ..., "mean": ..., "std": ..., "count": ...}}}
|
||||
|
||||
NEW
|
||||
meta/episodes_stats/chunk-000/file_000.parquet
|
||||
episode_index | mean | std | min | max
|
||||
meta/episodes/chunk-000/file_000.parquet
|
||||
episode_index | feature_name/min | feature_name/max | feature_name/mean | feature_name/std | feature_name/count
|
||||
-------------------------
|
||||
UPDATE
|
||||
meta/info.json
|
||||
@@ -170,7 +174,7 @@ def convert_tasks(root, new_root):
|
||||
tasks, _ = legacy_load_tasks(root)
|
||||
task_indices = tasks.keys()
|
||||
task_strings = tasks.values()
|
||||
df_tasks = pd.DataFrame({"task_index": task_indices}, index=task_strings)
|
||||
df_tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(task_strings, name="task"))
|
||||
write_tasks(df_tasks, new_root)
|
||||
|
||||
|
||||
@@ -201,7 +205,6 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
|
||||
|
||||
image_keys = get_image_keys(root)
|
||||
|
||||
ep_idx = 0
|
||||
chunk_idx = 0
|
||||
file_idx = 0
|
||||
size_in_mb = 0
|
||||
@@ -211,9 +214,23 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
|
||||
|
||||
logging.info(f"Converting data files from {len(ep_paths)} episodes")
|
||||
|
||||
for ep_path in tqdm.tqdm(ep_paths, desc="convert data files"):
|
||||
for ep_idx, ep_path in enumerate(tqdm.tqdm(ep_paths, desc="convert data files")):
|
||||
ep_size_in_mb = get_parquet_file_size_in_mb(ep_path)
|
||||
ep_num_frames = get_parquet_num_frames(ep_path)
|
||||
|
||||
# Check if we need to start a new file BEFORE creating metadata
|
||||
if size_in_mb + ep_size_in_mb >= data_file_size_in_mb and len(paths_to_cat) > 0:
|
||||
# Write the accumulated data files
|
||||
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
|
||||
|
||||
# Move to next file
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
|
||||
|
||||
# Reset for the next file
|
||||
size_in_mb = 0
|
||||
paths_to_cat = []
|
||||
|
||||
# Now create metadata with correct chunk/file indices
|
||||
ep_metadata = {
|
||||
"episode_index": ep_idx,
|
||||
"data/chunk_index": chunk_idx,
|
||||
@@ -224,20 +241,7 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
|
||||
size_in_mb += ep_size_in_mb
|
||||
num_frames += ep_num_frames
|
||||
episodes_metadata.append(ep_metadata)
|
||||
ep_idx += 1
|
||||
|
||||
if size_in_mb < data_file_size_in_mb:
|
||||
paths_to_cat.append(ep_path)
|
||||
continue
|
||||
|
||||
if paths_to_cat:
|
||||
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
|
||||
|
||||
# Reset for the next file
|
||||
size_in_mb = ep_size_in_mb
|
||||
paths_to_cat = [ep_path]
|
||||
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
|
||||
paths_to_cat.append(ep_path)
|
||||
|
||||
# Write remaining data if any
|
||||
if paths_to_cat:
|
||||
@@ -469,7 +473,7 @@ def convert_dataset(
|
||||
|
||||
# Set root based on whether local dataset path is provided
|
||||
use_local_dataset = False
|
||||
root = HF_LEROBOT_HOME / repo_id if root is None else Path(root) / repo_id
|
||||
root = HF_LEROBOT_HOME / repo_id if root is None else Path(root)
|
||||
if root.exists():
|
||||
validate_local_dataset_version(root)
|
||||
use_local_dataset = True
|
||||
@@ -553,7 +557,7 @@ if __name__ == "__main__":
|
||||
"--root",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Local directory to use for downloading/writing the dataset.",
|
||||
help="Local directory to use for downloading/writing the dataset. Defaults to $HF_LEROBOT_HOME/repo_id.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
|
||||
@@ -346,6 +346,105 @@ class LiberoEnv(EnvConfig):
|
||||
return kwargs
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("libero_plus")
|
||||
@dataclass
|
||||
class LiberoPlusEnv(LiberoEnv):
|
||||
"""Alias config for LIBERO-plus benchmarks.
|
||||
|
||||
LIBERO-plus keeps the same Python package/module names as LIBERO, so this
|
||||
config reuses the existing LIBERO env implementation while making intent explicit
|
||||
in experiment configs (`env.type=libero_plus`).
|
||||
"""
|
||||
|
||||
task: str = "libero_spatial"
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("robocasa")
|
||||
@dataclass
|
||||
class RoboCasaEnv(EnvConfig):
|
||||
"""RoboCasa kitchen composite-task environments.
|
||||
|
||||
Wraps ``robocasa.wrappers.gym_wrapper.RoboCasaGymEnv`` with a flat 12-D Box
|
||||
action space and a structured pixel + state observation dict.
|
||||
|
||||
Selected benchmark tasks (3 short + 2 long):
|
||||
Short: PickPlaceCounterToCabinet, PrepareToast, CoffeeSetupMug
|
||||
Long: PrepareCoffee, RestockPantry
|
||||
"""
|
||||
|
||||
task: str = "PickPlaceCounterToCabinet"
|
||||
tasks: list[str] | None = None # multi-task: list of task names (without robocasa/ prefix)
|
||||
fps: int = 20
|
||||
episode_length: int = 500
|
||||
image_size: int = 128
|
||||
split: str = "target" # "pretrain" or "target"
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(12,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: ACTION,
|
||||
"agentview_left": f"{OBS_IMAGES}.agentview_left",
|
||||
"agentview_right": f"{OBS_IMAGES}.agentview_right",
|
||||
"eye_in_hand": f"{OBS_IMAGES}.eye_in_hand",
|
||||
"robot_state": OBS_STATE,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
for cam in ("agentview_left", "agentview_right", "eye_in_hand"):
|
||||
self.features[cam] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(self.image_size, self.image_size, 3)
|
||||
)
|
||||
self.features["robot_state"] = PolicyFeature(type=FeatureType.STATE, shape=(16,))
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {"split": self.split}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("robomme")
|
||||
@dataclass
|
||||
class RoboMMEEnv(EnvConfig):
|
||||
"""RoboMME memory-augmented manipulation benchmark (ManiSkill/SAPIEN).
|
||||
|
||||
16 tasks across 4 suites: Counting, Permanence, Reference, Imitation.
|
||||
Uses BenchmarkEnvBuilder from the robomme package.
|
||||
"""
|
||||
|
||||
task: str = "PickXtimes"
|
||||
fps: int = 10
|
||||
episode_length: int = 300
|
||||
action_space: str = "joint_angle"
|
||||
dataset_split: str = "test"
|
||||
task_ids: list[int] | None = None
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(8,)),
|
||||
"front_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
|
||||
"wrist_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: ACTION,
|
||||
"front_rgb": f"{OBS_IMAGES}.front",
|
||||
"wrist_rgb": f"{OBS_IMAGES}.wrist",
|
||||
OBS_STATE: OBS_STATE,
|
||||
}
|
||||
)
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"action_space": self.action_space,
|
||||
"dataset": self.dataset_split,
|
||||
}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("metaworld")
|
||||
@dataclass
|
||||
class MetaworldEnv(EnvConfig):
|
||||
|
||||
@@ -20,11 +20,21 @@ import gymnasium as gym
|
||||
from gymnasium.envs.registration import registry as gym_registry
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, HubEnvConfig, IsaaclabArenaEnv, LiberoEnv, PushtEnv
|
||||
from lerobot.envs.configs import (
|
||||
AlohaEnv,
|
||||
EnvConfig,
|
||||
HubEnvConfig,
|
||||
IsaaclabArenaEnv,
|
||||
LiberoEnv,
|
||||
LiberoPlusEnv,
|
||||
PushtEnv,
|
||||
RoboCasaEnv,
|
||||
RoboMMEEnv,
|
||||
)
|
||||
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
|
||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.processor import ProcessorStep
|
||||
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep
|
||||
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep, RoboCasaProcessorStep
|
||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||
|
||||
|
||||
@@ -35,6 +45,12 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
return PushtEnv(**kwargs)
|
||||
elif env_type == "libero":
|
||||
return LiberoEnv(**kwargs)
|
||||
elif env_type == "libero_plus":
|
||||
return LiberoPlusEnv(**kwargs)
|
||||
elif env_type == "robocasa":
|
||||
return RoboCasaEnv(**kwargs)
|
||||
elif env_type == "robomme":
|
||||
return RoboMMEEnv(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
||||
|
||||
@@ -70,9 +86,13 @@ def make_env_pre_post_processors(
|
||||
return make_xvla_libero_pre_post_processors()
|
||||
|
||||
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
|
||||
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
|
||||
if isinstance(env_cfg, (LiberoEnv, LiberoPlusEnv)) or "libero" in env_cfg.type:
|
||||
preprocessor_steps.append(LiberoProcessorStep())
|
||||
|
||||
# For RoboCasa environments, add the RoboCasaProcessorStep to preprocessor
|
||||
if isinstance(env_cfg, RoboCasaEnv) or "robocasa" in env_cfg.type:
|
||||
preprocessor_steps.append(RoboCasaProcessorStep())
|
||||
|
||||
# For Isaaclab Arena environments, add the IsaaclabArenaProcessorStep
|
||||
if isinstance(env_cfg, IsaaclabArenaEnv) or "isaaclab_arena" in env_cfg.type:
|
||||
# Parse comma-separated keys (handle None for state-based policies)
|
||||
@@ -181,6 +201,33 @@ def make_env(
|
||||
control_mode=cfg.control_mode,
|
||||
episode_length=cfg.episode_length,
|
||||
)
|
||||
elif "robocasa" in cfg.type:
|
||||
from lerobot.envs.robocasa import create_robocasa_envs
|
||||
|
||||
tasks = cfg.tasks if cfg.tasks else [cfg.task]
|
||||
return create_robocasa_envs(
|
||||
tasks=tasks,
|
||||
n_envs=n_envs,
|
||||
image_size=cfg.image_size,
|
||||
split=cfg.split,
|
||||
episode_length=cfg.episode_length,
|
||||
gym_kwargs=cfg.gym_kwargs,
|
||||
env_cls=env_cls,
|
||||
)
|
||||
|
||||
elif "robomme" in cfg.type:
|
||||
from lerobot.envs.robomme import create_robomme_envs
|
||||
|
||||
return create_robomme_envs(
|
||||
task=cfg.task,
|
||||
n_envs=n_envs,
|
||||
action_space_type=cfg.action_space,
|
||||
dataset=cfg.dataset_split,
|
||||
episode_length=cfg.episode_length,
|
||||
task_ids=cfg.task_ids,
|
||||
env_cls=env_cls,
|
||||
)
|
||||
|
||||
elif "metaworld" in cfg.type:
|
||||
from lerobot.envs.metaworld import create_metaworld_envs
|
||||
|
||||
|
||||
@@ -26,8 +26,14 @@ import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from gymnasium import spaces
|
||||
from libero.libero import benchmark, get_libero_path
|
||||
from libero.libero.envs import OffScreenRenderEnv
|
||||
|
||||
try:
|
||||
from libero.libero import benchmark, get_libero_path
|
||||
from libero.libero.envs import OffScreenRenderEnv
|
||||
except ImportError:
|
||||
# LIBERO-plus may be installed from source with an extra nested package level.
|
||||
from libero.libero.libero import benchmark, get_libero_path
|
||||
from libero.libero.libero.envs import OffScreenRenderEnv
|
||||
|
||||
from lerobot.processor import RobotObservation
|
||||
|
||||
|
||||
@@ -0,0 +1,273 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from gymnasium import spaces
|
||||
|
||||
|
||||
# Action layout (flat 12D, normalized to [-1, 1]):
|
||||
# [0:3] end_effector_position (delta x, y, z)
|
||||
# [3:6] end_effector_rotation (delta roll, pitch, yaw)
|
||||
# [6:7] gripper_close (open=-1, close=+1)
|
||||
# [7:11] base_motion (x, y, theta, torso_height)
|
||||
# [11:12] control_mode (arm=-1, base=+1)
|
||||
ACTION_DIM = 12
|
||||
ACTION_LOW = -1.0
|
||||
ACTION_HIGH = 1.0
|
||||
|
||||
# Proprioceptive state layout (flat 16D):
|
||||
# [0:2] gripper_qpos
|
||||
# [2:5] base_position
|
||||
# [5:9] base_rotation (quaternion)
|
||||
# [9:12] end_effector_position_relative
|
||||
# [12:16] end_effector_rotation_relative (quaternion)
|
||||
STATE_DIM = 16
|
||||
|
||||
# Obs dict keys from RoboCasaGymEnv.get_observation()
|
||||
_CAM_KEYS = (
|
||||
"video.robot0_agentview_left",
|
||||
"video.robot0_agentview_right",
|
||||
"video.robot0_eye_in_hand",
|
||||
)
|
||||
_STATE_KEYS_ORDERED = (
|
||||
"state.gripper_qpos", # (2,)
|
||||
"state.base_position", # (3,)
|
||||
"state.base_rotation", # (4,)
|
||||
"state.end_effector_position_relative", # (3,)
|
||||
"state.end_effector_rotation_relative", # (4,)
|
||||
)
|
||||
|
||||
# Mapping from video.* key → short image name used in features_map
|
||||
CAM_KEY_TO_NAME = {
|
||||
"video.robot0_agentview_left": "agentview_left",
|
||||
"video.robot0_agentview_right": "agentview_right",
|
||||
"video.robot0_eye_in_hand": "eye_in_hand",
|
||||
}
|
||||
|
||||
|
||||
def _flat_to_action_dict(flat: np.ndarray) -> dict[str, np.ndarray]:
|
||||
"""Convert a 12D flat action array to the Dict format expected by RoboCasaGymEnv."""
|
||||
return {
|
||||
"action.end_effector_position": flat[0:3],
|
||||
"action.end_effector_rotation": flat[3:6],
|
||||
"action.gripper_close": flat[6:7],
|
||||
"action.base_motion": flat[7:11],
|
||||
"action.control_mode": flat[11:12],
|
||||
}
|
||||
|
||||
|
||||
class RoboCasaEnv(gym.Env):
|
||||
"""Thin wrapper around RoboCasaGymEnv that provides a flat Box action space
|
||||
and a structured observation dict compatible with LeRobot policies.
|
||||
|
||||
Observations returned by step/reset:
|
||||
{
|
||||
"pixels": {
|
||||
"agentview_left": (H, W, 3) uint8,
|
||||
"agentview_right": (H, W, 3) uint8,
|
||||
"eye_in_hand": (H, W, 3) uint8,
|
||||
},
|
||||
"robot_state": (16,) float32,
|
||||
}
|
||||
|
||||
Actions: flat float32 ndarray of shape (12,), normalized to [-1, 1].
|
||||
"""
|
||||
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 20}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task: str,
|
||||
split: str = "target",
|
||||
image_size: int = 128,
|
||||
render_mode: str = "rgb_array",
|
||||
episode_length: int = 500,
|
||||
**gym_kwargs: Any,
|
||||
):
|
||||
super().__init__()
|
||||
# Lazy import — robocasa is optional
|
||||
import robocasa.environments # noqa: F401 — registers all gym envs
|
||||
|
||||
self.task = task
|
||||
self.render_mode = render_mode
|
||||
self.image_size = image_size
|
||||
self._max_episode_steps = episode_length
|
||||
self._step_count = 0
|
||||
|
||||
self._env = gym.make(
|
||||
f"robocasa/{task}",
|
||||
split=split,
|
||||
camera_widths=image_size,
|
||||
camera_heights=image_size,
|
||||
**gym_kwargs,
|
||||
)
|
||||
|
||||
# Flat 12D Box action space
|
||||
self.action_space = spaces.Box(
|
||||
low=ACTION_LOW,
|
||||
high=ACTION_HIGH,
|
||||
shape=(ACTION_DIM,),
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
images = {
|
||||
name: spaces.Box(low=0, high=255, shape=(image_size, image_size, 3), dtype=np.uint8)
|
||||
for name in CAM_KEY_TO_NAME.values()
|
||||
}
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
"pixels": spaces.Dict(images),
|
||||
"robot_state": spaces.Box(
|
||||
low=-np.inf, high=np.inf, shape=(STATE_DIM,), dtype=np.float32
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
def _format_obs(self, raw_obs: dict) -> dict:
|
||||
pixels = {
|
||||
CAM_KEY_TO_NAME[k]: raw_obs[k]
|
||||
for k in _CAM_KEYS
|
||||
if k in raw_obs
|
||||
}
|
||||
state_parts = [
|
||||
np.asarray(raw_obs[k], dtype=np.float32)
|
||||
for k in _STATE_KEYS_ORDERED
|
||||
if k in raw_obs
|
||||
]
|
||||
robot_state = np.concatenate(state_parts) if state_parts else np.zeros(STATE_DIM, dtype=np.float32)
|
||||
return {"pixels": pixels, "robot_state": robot_state}
|
||||
|
||||
def reset(self, seed: int | None = None, **kwargs) -> tuple[dict, dict]:
|
||||
super().reset(seed=seed)
|
||||
self._step_count = 0
|
||||
raw_obs, info = self._env.reset(seed=seed)
|
||||
info.setdefault("is_success", False)
|
||||
info["task"] = self.task
|
||||
return self._format_obs(raw_obs), info
|
||||
|
||||
def step(self, action: np.ndarray) -> tuple[dict, float, bool, bool, dict]:
|
||||
if action.ndim != 1 or action.shape[0] != ACTION_DIM:
|
||||
raise ValueError(
|
||||
f"Expected 1-D action of shape ({ACTION_DIM},), got {action.shape}"
|
||||
)
|
||||
action_dict = _flat_to_action_dict(action)
|
||||
raw_obs, reward, terminated, truncated, info = self._env.step(action_dict)
|
||||
self._step_count += 1
|
||||
|
||||
is_success = bool(info.get("success", False))
|
||||
terminated = terminated or is_success
|
||||
if self._step_count >= self._max_episode_steps:
|
||||
truncated = True
|
||||
|
||||
info.update({"task": self.task, "is_success": is_success})
|
||||
obs = self._format_obs(raw_obs)
|
||||
|
||||
if terminated or truncated:
|
||||
info["final_info"] = {"task": self.task, "is_success": is_success}
|
||||
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
def render(self) -> np.ndarray | None:
|
||||
if self.render_mode == "rgb_array":
|
||||
return self._env.render()
|
||||
return None
|
||||
|
||||
def close(self) -> None:
|
||||
self._env.close()
|
||||
|
||||
|
||||
def _make_env_fns(
|
||||
*,
|
||||
task: str,
|
||||
n_envs: int,
|
||||
image_size: int,
|
||||
split: str,
|
||||
episode_length: int,
|
||||
gym_kwargs: dict[str, Any],
|
||||
) -> list[Callable[[], RoboCasaEnv]]:
|
||||
"""Build n_envs factory callables for a single task."""
|
||||
def _make(episode_index: int) -> RoboCasaEnv: # noqa: ARG001
|
||||
return RoboCasaEnv(
|
||||
task=task,
|
||||
split=split,
|
||||
image_size=image_size,
|
||||
episode_length=episode_length,
|
||||
**gym_kwargs,
|
||||
)
|
||||
|
||||
return [partial(_make, i) for i in range(n_envs)]
|
||||
|
||||
|
||||
def create_robocasa_envs(
|
||||
tasks: str | Sequence[str],
|
||||
n_envs: int,
|
||||
image_size: int = 128,
|
||||
split: str = "target",
|
||||
episode_length: int = 500,
|
||||
gym_kwargs: dict[str, Any] | None = None,
|
||||
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
|
||||
) -> dict[str, dict[int, Any]]:
|
||||
"""Create vectorized RoboCasa environments.
|
||||
|
||||
Args:
|
||||
tasks: A single task name or list of task names (without "robocasa/" prefix).
|
||||
E.g. "PickPlaceCounterToCabinet" or ["BoilPot", "PrepareCoffee"].
|
||||
n_envs: Number of parallel envs per task.
|
||||
image_size: Square image resolution for all cameras.
|
||||
split: RoboCasa dataset split — "pretrain" or "target".
|
||||
episode_length: Max steps per episode before truncation.
|
||||
gym_kwargs: Extra kwargs forwarded to each RoboCasaEnv.
|
||||
env_cls: Callable to wrap list of factory fns (SyncVectorEnv or AsyncVectorEnv).
|
||||
|
||||
Returns:
|
||||
dict[task_name][task_id=0] -> vec_env
|
||||
"""
|
||||
if env_cls is None or not callable(env_cls):
|
||||
raise ValueError("env_cls must be a callable wrapping a list of env factory callables.")
|
||||
if not isinstance(n_envs, int) or n_envs <= 0:
|
||||
raise ValueError(f"n_envs must be a positive int; got {n_envs}.")
|
||||
|
||||
if isinstance(tasks, str):
|
||||
task_list = [t.strip() for t in tasks.split(",") if t.strip()]
|
||||
else:
|
||||
task_list = [str(t).strip() for t in tasks if str(t).strip()]
|
||||
if not task_list:
|
||||
raise ValueError("`tasks` must contain at least one task name.")
|
||||
|
||||
gym_kwargs = dict(gym_kwargs or {})
|
||||
out: dict[str, dict[int, Any]] = defaultdict(dict)
|
||||
|
||||
print(f"Creating RoboCasa envs | tasks={task_list} | n_envs(per task)={n_envs} | split={split}")
|
||||
for task in task_list:
|
||||
fns = _make_env_fns(
|
||||
task=task,
|
||||
n_envs=n_envs,
|
||||
image_size=image_size,
|
||||
split=split,
|
||||
episode_length=episode_length,
|
||||
gym_kwargs=gym_kwargs,
|
||||
)
|
||||
out["robocasa"][len(out["robocasa"])] = env_cls(fns)
|
||||
print(f" Built vec env | task={task} | n_envs={n_envs}")
|
||||
|
||||
return {suite: dict(task_map) for suite, task_map in out.items()}
|
||||
@@ -0,0 +1,154 @@
|
||||
"""RoboMME environment wrapper for LeRobot evaluation.
|
||||
|
||||
Wraps the RoboMME ``BenchmarkEnvBuilder`` into a Gymnasium-compatible
|
||||
``VectorEnv`` suitable for ``lerobot_eval``.
|
||||
|
||||
RoboMME tasks:
|
||||
Counting: BinFill, PickXtimes, SwingXtimes, StopCube
|
||||
Permanence: VideoUnmask, VideoUnmaskSwap, ButtonUnmask, ButtonUnmaskSwap
|
||||
Reference: PickHighlight, VideoRepick, VideoPlaceButton, VideoPlaceOrder
|
||||
Imitation: MoveCube, InsertPeg, PatternLock, RouteStick
|
||||
|
||||
Install: pip install robomme (or from source: https://github.com/RoboMME/robomme_benchmark)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from gymnasium import spaces
|
||||
|
||||
ROBOMME_TASKS = [
|
||||
"BinFill", "PickXtimes", "SwingXtimes", "StopCube",
|
||||
"VideoUnmask", "VideoUnmaskSwap", "ButtonUnmask", "ButtonUnmaskSwap",
|
||||
"PickHighlight", "VideoRepick", "VideoPlaceButton", "VideoPlaceOrder",
|
||||
"MoveCube", "InsertPeg", "PatternLock", "RouteStick",
|
||||
]
|
||||
|
||||
|
||||
class RoboMMEGymEnv(gym.Env):
|
||||
"""Thin Gymnasium wrapper around a single RoboMME episode env."""
|
||||
|
||||
metadata = {"render_modes": ["rgb_array"]}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task: str = "PickXtimes",
|
||||
action_space_type: str = "joint_angle",
|
||||
dataset: str = "test",
|
||||
episode_idx: int = 0,
|
||||
max_steps: int = 300,
|
||||
):
|
||||
super().__init__()
|
||||
from robomme.env_record_wrapper import BenchmarkEnvBuilder
|
||||
|
||||
self._task = task
|
||||
self._action_space_type = action_space_type
|
||||
self._dataset = dataset
|
||||
self._episode_idx = episode_idx
|
||||
self._max_steps = max_steps
|
||||
|
||||
self._builder = BenchmarkEnvBuilder(
|
||||
env_id=task,
|
||||
dataset=dataset,
|
||||
action_space=action_space_type,
|
||||
gui_render=False,
|
||||
max_steps=max_steps,
|
||||
)
|
||||
self._env = None
|
||||
|
||||
action_dim = 8 if action_space_type == "joint_angle" else 7
|
||||
self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(action_dim,), dtype=np.float32)
|
||||
self.observation_space = spaces.Dict({
|
||||
"front_rgb": spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8),
|
||||
"wrist_rgb": spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8),
|
||||
"state": spaces.Box(-np.inf, np.inf, shape=(8,), dtype=np.float32),
|
||||
})
|
||||
|
||||
def reset(self, *, seed=None, options=None):
|
||||
super().reset(seed=seed)
|
||||
self._env = self._builder.make_env_for_episode(
|
||||
episode_idx=self._episode_idx, max_steps=self._max_steps,
|
||||
)
|
||||
obs, info = self._env.reset()
|
||||
return self._convert_obs(obs), self._convert_info(info)
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, terminated, truncated, info = self._env.step(action)
|
||||
|
||||
terminated_bool = bool(terminated.item()) if hasattr(terminated, "item") else bool(terminated)
|
||||
truncated_bool = bool(truncated.item()) if hasattr(truncated, "item") else bool(truncated)
|
||||
|
||||
status = info.get("status", "ongoing")
|
||||
is_success = status == "success"
|
||||
conv_info = self._convert_info(info)
|
||||
conv_info["is_success"] = is_success
|
||||
|
||||
return self._convert_obs(obs), float(reward), terminated_bool, truncated_bool, conv_info
|
||||
|
||||
def _convert_obs(self, obs: dict) -> dict:
|
||||
front_rgb = obs["front_rgb_list"][-1] if isinstance(obs["front_rgb_list"], list) else obs["front_rgb_list"]
|
||||
wrist_rgb = obs["wrist_rgb_list"][-1] if isinstance(obs["wrist_rgb_list"], list) else obs["wrist_rgb_list"]
|
||||
joint_state = obs["joint_state_list"][-1] if isinstance(obs["joint_state_list"], list) else obs["joint_state_list"]
|
||||
gripper_state = obs["gripper_state_list"][-1] if isinstance(obs["gripper_state_list"], list) else obs["gripper_state_list"]
|
||||
|
||||
front_rgb = np.asarray(front_rgb, dtype=np.uint8)
|
||||
wrist_rgb = np.asarray(wrist_rgb, dtype=np.uint8)
|
||||
joint = np.asarray(joint_state, dtype=np.float32).flatten()[:7]
|
||||
gripper = np.asarray(gripper_state, dtype=np.float32).flatten()[:1]
|
||||
state = np.concatenate([joint, gripper])
|
||||
|
||||
return {
|
||||
"front_rgb": front_rgb,
|
||||
"wrist_rgb": wrist_rgb,
|
||||
"state": state,
|
||||
}
|
||||
|
||||
def _convert_info(self, info: dict) -> dict:
|
||||
return {
|
||||
"status": info.get("status", "ongoing"),
|
||||
"task_goal": info.get("task_goal", ""),
|
||||
}
|
||||
|
||||
|
||||
def create_robomme_envs(
|
||||
task: str,
|
||||
n_envs: int = 1,
|
||||
action_space_type: str = "joint_angle",
|
||||
dataset: str = "test",
|
||||
episode_length: int = 300,
|
||||
task_ids: list[int] | None = None,
|
||||
env_cls=None,
|
||||
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
|
||||
"""Create vectorized RoboMME environments for evaluation.
|
||||
|
||||
Returns {suite_name: {task_id: VectorEnv}} matching lerobot's expected format.
|
||||
"""
|
||||
if env_cls is None:
|
||||
env_cls = gym.vector.SyncVectorEnv
|
||||
|
||||
if task_ids is None:
|
||||
task_ids = [0]
|
||||
|
||||
suite_name = "robomme"
|
||||
envs_by_task = {}
|
||||
|
||||
for task_id in task_ids:
|
||||
def _make_one(ep_idx=task_id):
|
||||
return RoboMMEGymEnv(
|
||||
task=task,
|
||||
action_space_type=action_space_type,
|
||||
dataset=dataset,
|
||||
episode_idx=ep_idx,
|
||||
max_steps=episode_length,
|
||||
)
|
||||
|
||||
vec = env_cls(
|
||||
[_make_one for _ in range(n_envs)],
|
||||
autoreset_mode=gym.vector.AutoresetMode.SAME_STEP,
|
||||
)
|
||||
envs_by_task[task_id] = vec
|
||||
|
||||
return {suite_name: envs_by_task}
|
||||
@@ -29,7 +29,7 @@ from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from functools import cached_property
|
||||
from pprint import pformat
|
||||
from typing import Protocol, TypeAlias
|
||||
from typing import Protocol
|
||||
|
||||
import serial
|
||||
from deepdiff import DeepDiff
|
||||
@@ -38,8 +38,8 @@ from tqdm import tqdm
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
||||
|
||||
NameOrID: TypeAlias = str | int
|
||||
Value: TypeAlias = int | float
|
||||
type NameOrID = str | int
|
||||
type Value = int | float
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1277,4 +1277,4 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
|
||||
|
||||
# Backward compatibility alias
|
||||
MotorsBus: TypeAlias = SerialMotorsBus
|
||||
MotorsBus = SerialMotorsBus
|
||||
|
||||
@@ -18,10 +18,9 @@ from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, TypedDict, Unpack
|
||||
|
||||
import torch
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType
|
||||
|
||||
@@ -4,17 +4,16 @@
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
|
||||
from typing import Optional
|
||||
|
||||
from transformers.image_processing_utils import (
|
||||
BatchFeature,
|
||||
get_patch_output_size,
|
||||
)
|
||||
from transformers.image_processing_utils_fast import (
|
||||
BaseImageProcessorFast,
|
||||
DefaultFastImageProcessorKwargs,
|
||||
ImagesKwargs,
|
||||
group_images_by_shape,
|
||||
reorder_images,
|
||||
)
|
||||
@@ -77,7 +76,7 @@ def crop(img: torch.Tensor, left: int, top: int, right: int, bottom: int) -> tor
|
||||
return img[:, top:bottom, left:right]
|
||||
|
||||
|
||||
class Eagle25VLFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
|
||||
class Eagle25VLFastImageProcessorKwargs(ImagesKwargs):
|
||||
max_dynamic_tiles: int | None
|
||||
min_dynamic_tiles: int | None
|
||||
use_thumbnail: bool | None
|
||||
@@ -165,11 +164,11 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
||||
|
||||
def _resize_for_patching(
|
||||
self,
|
||||
image: "torch.Tensor",
|
||||
image: torch.Tensor,
|
||||
target_resolution: tuple,
|
||||
interpolation: "F.InterpolationMode",
|
||||
interpolation: F.InterpolationMode,
|
||||
input_data_format: ChannelDimension,
|
||||
) -> "torch.Tensor":
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Resizes an image to a target resolution while maintaining aspect ratio.
|
||||
|
||||
@@ -219,8 +218,8 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
||||
return best_ratio
|
||||
|
||||
def _pad_for_patching(
|
||||
self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension
|
||||
) -> "torch.Tensor":
|
||||
self, image: torch.Tensor, target_resolution: tuple, input_data_format: ChannelDimension
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Pad an image to a target resolution while maintaining aspect ratio.
|
||||
"""
|
||||
@@ -236,15 +235,15 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
||||
|
||||
def _get_image_patches(
|
||||
self,
|
||||
image: "torch.Tensor",
|
||||
image: torch.Tensor,
|
||||
min_num: int,
|
||||
max_num: int,
|
||||
size: tuple,
|
||||
tile_size: int,
|
||||
use_thumbnail: bool,
|
||||
interpolation: "F.InterpolationMode",
|
||||
interpolation: F.InterpolationMode,
|
||||
pad_during_tiling: bool,
|
||||
) -> list["torch.Tensor"]:
|
||||
) -> list[torch.Tensor]:
|
||||
image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
|
||||
orig_height, orig_width = image_size
|
||||
aspect_ratio = orig_width / orig_height
|
||||
@@ -305,8 +304,8 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
||||
|
||||
def _pad_for_batching(
|
||||
self,
|
||||
pixel_values: list["torch.Tensor"],
|
||||
) -> list["torch.Tensor"]:
|
||||
pixel_values: list[torch.Tensor],
|
||||
) -> list[torch.Tensor]:
|
||||
"""
|
||||
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
|
||||
|
||||
@@ -327,14 +326,14 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: list["torch.Tensor"],
|
||||
images: list[torch.Tensor],
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
max_dynamic_tiles: int,
|
||||
min_dynamic_tiles: int,
|
||||
use_thumbnail: bool,
|
||||
pad_during_tiling: bool,
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
interpolation: F.InterpolationMode | None,
|
||||
do_center_crop: bool,
|
||||
crop_size: SizeDict,
|
||||
do_rescale: bool,
|
||||
|
||||
@@ -15,16 +15,16 @@
|
||||
# limitations under the License.
|
||||
|
||||
import builtins
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
@@ -32,13 +32,21 @@ from lerobot.utils.import_utils import _transformers_available
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
|
||||
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
|
||||
|
||||
from lerobot.policies.pi_gemma import (
|
||||
PaliGemmaForConditionalGenerationWithPiGemma,
|
||||
PiGemmaForCausalLM,
|
||||
_gated_residual,
|
||||
layernorm_forward,
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
modeling_gemma = None
|
||||
GemmaForCausalLM = None
|
||||
PaliGemmaForConditionalGeneration = None
|
||||
PiGemmaForCausalLM = None
|
||||
_gated_residual = None
|
||||
layernorm_forward = None
|
||||
PaliGemmaForConditionalGenerationWithPiGemma = None
|
||||
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
|
||||
@@ -191,7 +199,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
if images.dtype == torch.uint8:
|
||||
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
|
||||
elif images.dtype == torch.float32:
|
||||
resized_images = resized_images.clamp(-1.0, 1.0)
|
||||
resized_images = resized_images.clamp(0.0, 1.0)
|
||||
else:
|
||||
raise ValueError(f"Unsupported image dtype: {images.dtype}")
|
||||
|
||||
@@ -202,7 +210,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
pad_w1 = pad_w0 + remainder_w
|
||||
|
||||
# Pad
|
||||
constant_value = 0 if images.dtype == torch.uint8 else -1.0
|
||||
constant_value = 0 if images.dtype == torch.uint8 else 0.0
|
||||
padded_images = F.pad(
|
||||
resized_images,
|
||||
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
|
||||
@@ -221,14 +229,14 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
def compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||
):
|
||||
models = [paligemma.language_model, gemma_expert.model]
|
||||
models = [paligemma.model.language_model, gemma_expert.model]
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
|
||||
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||
gates.append(gate)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
@@ -254,10 +262,10 @@ def compute_layer_complete(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
|
||||
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
paligemma.language_model.layers[layer_idx].self_attn,
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
@@ -265,7 +273,7 @@ def compute_layer_complete(
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
|
||||
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
@@ -277,15 +285,15 @@ def compute_layer_complete(
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
|
||||
# first residual
|
||||
out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
|
||||
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
|
||||
after_first_residual = out_emb.clone()
|
||||
out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
|
||||
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
|
||||
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
|
||||
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
|
||||
out_emb = out_emb.to(dtype=torch.bfloat16)
|
||||
out_emb = layer.mlp(out_emb)
|
||||
# second residual
|
||||
out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001
|
||||
out_emb = _gated_residual(after_first_residual, out_emb, gate)
|
||||
outputs_embeds.append(out_emb)
|
||||
start_pos = end_pos
|
||||
return outputs_embeds
|
||||
@@ -358,7 +366,7 @@ class PaliGemmaWithExpertModel(
|
||||
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
|
||||
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
|
||||
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
|
||||
vlm_config_hf.text_config.torch_dtype = "float32"
|
||||
vlm_config_hf.text_config.dtype = "float32"
|
||||
vlm_config_hf.text_config.vocab_size = 257152
|
||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||
@@ -366,7 +374,7 @@ class PaliGemmaWithExpertModel(
|
||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||
vlm_config_hf.vision_config.projection_dim = 2048
|
||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||
vlm_config_hf.vision_config.torch_dtype = "float32"
|
||||
vlm_config_hf.vision_config.dtype = "float32"
|
||||
|
||||
action_expert_config_hf = CONFIG_MAPPING["gemma"](
|
||||
head_dim=action_expert_config.head_dim,
|
||||
@@ -377,13 +385,13 @@ class PaliGemmaWithExpertModel(
|
||||
num_key_value_heads=action_expert_config.num_kv_heads,
|
||||
vocab_size=257152,
|
||||
hidden_activation="gelu_pytorch_tanh",
|
||||
torch_dtype="float32",
|
||||
dtype="float32",
|
||||
use_adarms=use_adarms[1],
|
||||
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
|
||||
)
|
||||
|
||||
self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
|
||||
self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
|
||||
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
|
||||
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
|
||||
self.gemma_expert.model.embed_tokens = None
|
||||
|
||||
self.to_bfloat16_for_selected_params(precision)
|
||||
@@ -398,10 +406,11 @@ class PaliGemmaWithExpertModel(
|
||||
else:
|
||||
raise ValueError(f"Invalid precision: {precision}")
|
||||
|
||||
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
|
||||
# "same dtype" error). Align with PI05.
|
||||
params_to_keep_float32 = [
|
||||
"vision_tower.vision_model.embeddings.patch_embedding.weight",
|
||||
"vision_tower.vision_model.embeddings.patch_embedding.bias",
|
||||
"vision_tower.vision_model.embeddings.position_embedding.weight",
|
||||
"vision_tower",
|
||||
"multi_modal_projector",
|
||||
"input_layernorm",
|
||||
"post_attention_layernorm",
|
||||
"model.norm",
|
||||
@@ -413,8 +422,8 @@ class PaliGemmaWithExpertModel(
|
||||
|
||||
def _set_requires_grad(self):
|
||||
if self.freeze_vision_encoder:
|
||||
self.paligemma.vision_tower.eval()
|
||||
for param in self.paligemma.vision_tower.parameters():
|
||||
self.paligemma.model.vision_tower.eval()
|
||||
for param in self.paligemma.model.vision_tower.parameters():
|
||||
param.requires_grad = False
|
||||
if self.train_expert_only:
|
||||
self.paligemma.eval()
|
||||
@@ -424,15 +433,23 @@ class PaliGemmaWithExpertModel(
|
||||
def train(self, mode: bool = True):
|
||||
super().train(mode)
|
||||
if self.freeze_vision_encoder:
|
||||
self.paligemma.vision_tower.eval()
|
||||
self.paligemma.model.vision_tower.eval()
|
||||
if self.train_expert_only:
|
||||
self.paligemma.eval()
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
return self.paligemma.model.get_image_features(image)
|
||||
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). Align with PI05.
|
||||
out_dtype = image.dtype
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
image_outputs = self.paligemma.model.get_image_features(image)
|
||||
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
||||
if features.dtype != out_dtype:
|
||||
features = features.to(out_dtype)
|
||||
return features
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.language_model.embed_tokens(tokens)
|
||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -446,7 +463,7 @@ class PaliGemmaWithExpertModel(
|
||||
if adarms_cond is None:
|
||||
adarms_cond = [None, None]
|
||||
if inputs_embeds[1] is None:
|
||||
prefix_output = self.paligemma.language_model.forward(
|
||||
prefix_output = self.paligemma.model.language_model.forward(
|
||||
inputs_embeds=inputs_embeds[0],
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
@@ -470,7 +487,7 @@ class PaliGemmaWithExpertModel(
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
models = [self.paligemma.language_model, self.gemma_expert.model]
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
@@ -510,7 +527,7 @@ class PaliGemmaWithExpertModel(
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
@@ -576,29 +593,19 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Also compile the main forward pass used during training
|
||||
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
||||
|
||||
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
|
||||
|
||||
try:
|
||||
from transformers.models.siglip import check
|
||||
|
||||
if not check.check_whether_transformers_replace_is_installed_correctly():
|
||||
raise ValueError(msg)
|
||||
except ImportError:
|
||||
raise ValueError(msg) from None
|
||||
|
||||
def gradient_checkpointing_enable(self):
|
||||
"""Enable gradient checkpointing for memory optimization."""
|
||||
self.gradient_checkpointing_enabled = True
|
||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
|
||||
logging.info("Enabled gradient checkpointing for PI0Pytorch model")
|
||||
|
||||
def gradient_checkpointing_disable(self):
|
||||
"""Disable gradient checkpointing."""
|
||||
self.gradient_checkpointing_enabled = False
|
||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
||||
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
|
||||
|
||||
@@ -760,7 +767,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
|
||||
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
||||
@@ -834,7 +841,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
|
||||
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
_, past_key_values = self.paligemma_with_expert.forward(
|
||||
attention_mask=prefix_att_2d_masks_4d,
|
||||
@@ -908,6 +915,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
@@ -997,14 +1005,12 @@ class PI0Policy(PreTrainedPolicy):
|
||||
# Check if dataset_stats were provided in kwargs
|
||||
model = cls(config, **kwargs)
|
||||
|
||||
# Now manually load and remap the state dict
|
||||
# Load state dict (expects keys with "model." prefix)
|
||||
try:
|
||||
# Try to load the pytorch_model.bin or model.safetensors file
|
||||
print(f"Loading model from: {pretrained_name_or_path}")
|
||||
try:
|
||||
from transformers.utils import cached_file
|
||||
|
||||
# Try safetensors first
|
||||
resolved_file = cached_file(
|
||||
pretrained_name_or_path,
|
||||
"model.safetensors",
|
||||
@@ -1012,7 +1018,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
force_download=kwargs.get("force_download", False),
|
||||
resume_download=kwargs.get("resume_download"),
|
||||
proxies=kwargs.get("proxies"),
|
||||
use_auth_token=kwargs.get("use_auth_token"),
|
||||
token=kwargs.get("token"),
|
||||
revision=kwargs.get("revision"),
|
||||
local_files_only=kwargs.get("local_files_only", False),
|
||||
)
|
||||
@@ -1025,7 +1031,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
print("Returning model without loading pretrained weights")
|
||||
return model
|
||||
|
||||
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
|
||||
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
|
||||
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
|
||||
|
||||
# Then add "model." prefix for all keys that don't already have it
|
||||
@@ -1070,7 +1076,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
print("All keys loaded successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not remap state dict keys: {e}")
|
||||
print(f"Warning: Could not load state dict: {e}")
|
||||
|
||||
return model
|
||||
|
||||
@@ -1120,6 +1126,14 @@ class PI0Policy(PreTrainedPolicy):
|
||||
# Some checkpoints might have this, but current model expects different structure
|
||||
logging.warning(f"Vision embedding key might need handling: {key}")
|
||||
|
||||
if (
|
||||
key == "model.paligemma_with_expert.paligemma.lm_head.weight"
|
||||
or key == "paligemma_with_expert.paligemma.lm_head.weight"
|
||||
):
|
||||
fixed_state_dict[
|
||||
"model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
|
||||
] = value.clone()
|
||||
|
||||
fixed_state_dict[new_key] = value
|
||||
|
||||
return fixed_state_dict
|
||||
|
||||
@@ -15,16 +15,16 @@
|
||||
# limitations under the License.
|
||||
|
||||
import builtins
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
@@ -32,14 +32,20 @@ from lerobot.utils.import_utils import _transformers_available
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
|
||||
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
|
||||
|
||||
from lerobot.policies.pi_gemma import (
|
||||
PaliGemmaForConditionalGenerationWithPiGemma,
|
||||
PiGemmaForCausalLM,
|
||||
_gated_residual,
|
||||
layernorm_forward,
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
modeling_gemma = None
|
||||
GemmaForCausalLM = None
|
||||
PaliGemmaForConditionalGeneration = None
|
||||
|
||||
PiGemmaForCausalLM = None
|
||||
_gated_residual = None
|
||||
layernorm_forward = None
|
||||
PaliGemmaForConditionalGenerationWithPiGemma = None
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
@@ -92,10 +98,11 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
|
||||
alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
|
||||
beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
|
||||
# Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU
|
||||
alpha_t = torch.tensor(alpha, dtype=torch.float32)
|
||||
beta_t = torch.tensor(beta, dtype=torch.float32)
|
||||
dist = torch.distributions.Beta(alpha_t, beta_t)
|
||||
return dist.sample((bsize,))
|
||||
return dist.sample((bsize,)).to(device)
|
||||
|
||||
|
||||
def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy)
|
||||
@@ -189,7 +196,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
if images.dtype == torch.uint8:
|
||||
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
|
||||
elif images.dtype == torch.float32:
|
||||
resized_images = resized_images.clamp(-1.0, 1.0)
|
||||
resized_images = resized_images.clamp(0.0, 1.0)
|
||||
else:
|
||||
raise ValueError(f"Unsupported image dtype: {images.dtype}")
|
||||
|
||||
@@ -200,7 +207,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
pad_w1 = pad_w0 + remainder_w
|
||||
|
||||
# Pad
|
||||
constant_value = 0 if images.dtype == torch.uint8 else -1.0
|
||||
constant_value = 0 if images.dtype == torch.uint8 else 0.0
|
||||
padded_images = F.pad(
|
||||
resized_images,
|
||||
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
|
||||
@@ -219,14 +226,14 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
def compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||
):
|
||||
models = [paligemma.language_model, gemma_expert.model]
|
||||
models = [paligemma.model.language_model, gemma_expert.model]
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
|
||||
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||
gates.append(gate)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
@@ -252,10 +259,10 @@ def compute_layer_complete(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
|
||||
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
paligemma.language_model.layers[layer_idx].self_attn,
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
@@ -263,7 +270,7 @@ def compute_layer_complete(
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
|
||||
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
@@ -275,15 +282,15 @@ def compute_layer_complete(
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
|
||||
# first residual
|
||||
out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
|
||||
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
|
||||
after_first_residual = out_emb.clone()
|
||||
out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
|
||||
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
|
||||
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
|
||||
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
|
||||
out_emb = out_emb.to(dtype=torch.bfloat16)
|
||||
out_emb = layer.mlp(out_emb)
|
||||
# second residual
|
||||
out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001
|
||||
out_emb = _gated_residual(after_first_residual, out_emb, gate)
|
||||
outputs_embeds.append(out_emb)
|
||||
start_pos = end_pos
|
||||
return outputs_embeds
|
||||
@@ -356,7 +363,7 @@ class PaliGemmaWithExpertModel(
|
||||
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
|
||||
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
|
||||
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
|
||||
vlm_config_hf.text_config.torch_dtype = "float32"
|
||||
vlm_config_hf.text_config.dtype = "float32"
|
||||
vlm_config_hf.text_config.vocab_size = 257152
|
||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||
@@ -364,7 +371,7 @@ class PaliGemmaWithExpertModel(
|
||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||
vlm_config_hf.vision_config.projection_dim = 2048
|
||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||
vlm_config_hf.vision_config.torch_dtype = "float32"
|
||||
vlm_config_hf.vision_config.dtype = "float32"
|
||||
|
||||
action_expert_config_hf = CONFIG_MAPPING["gemma"](
|
||||
head_dim=action_expert_config.head_dim,
|
||||
@@ -375,13 +382,13 @@ class PaliGemmaWithExpertModel(
|
||||
num_key_value_heads=action_expert_config.num_kv_heads,
|
||||
vocab_size=257152,
|
||||
hidden_activation="gelu_pytorch_tanh",
|
||||
torch_dtype="float32",
|
||||
dtype="float32",
|
||||
use_adarms=use_adarms[1],
|
||||
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
|
||||
)
|
||||
|
||||
self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
|
||||
self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
|
||||
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
|
||||
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
|
||||
self.gemma_expert.model.embed_tokens = None
|
||||
|
||||
self.to_bfloat16_for_selected_params(precision)
|
||||
@@ -396,10 +403,11 @@ class PaliGemmaWithExpertModel(
|
||||
else:
|
||||
raise ValueError(f"Invalid precision: {precision}")
|
||||
|
||||
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
|
||||
# "same dtype" error). Saves memory vs full float32; more memory than only 3 params.
|
||||
params_to_keep_float32 = [
|
||||
"vision_tower.vision_model.embeddings.patch_embedding.weight",
|
||||
"vision_tower.vision_model.embeddings.patch_embedding.bias",
|
||||
"vision_tower.vision_model.embeddings.position_embedding.weight",
|
||||
"vision_tower",
|
||||
"multi_modal_projector",
|
||||
"input_layernorm",
|
||||
"post_attention_layernorm",
|
||||
"model.norm",
|
||||
@@ -411,8 +419,8 @@ class PaliGemmaWithExpertModel(
|
||||
|
||||
def _set_requires_grad(self):
|
||||
if self.freeze_vision_encoder:
|
||||
self.paligemma.vision_tower.eval()
|
||||
for param in self.paligemma.vision_tower.parameters():
|
||||
self.paligemma.model.vision_tower.eval()
|
||||
for param in self.paligemma.model.vision_tower.parameters():
|
||||
param.requires_grad = False
|
||||
if self.train_expert_only:
|
||||
self.paligemma.eval()
|
||||
@@ -422,15 +430,23 @@ class PaliGemmaWithExpertModel(
|
||||
def train(self, mode: bool = True):
|
||||
super().train(mode)
|
||||
if self.freeze_vision_encoder:
|
||||
self.paligemma.vision_tower.eval()
|
||||
self.paligemma.model.vision_tower.eval()
|
||||
if self.train_expert_only:
|
||||
self.paligemma.eval()
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
return self.paligemma.model.get_image_features(image)
|
||||
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32).
|
||||
out_dtype = image.dtype
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
image_outputs = self.paligemma.model.get_image_features(image)
|
||||
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
||||
if features.dtype != out_dtype:
|
||||
features = features.to(out_dtype)
|
||||
return features
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.language_model.embed_tokens(tokens)
|
||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -444,7 +460,7 @@ class PaliGemmaWithExpertModel(
|
||||
if adarms_cond is None:
|
||||
adarms_cond = [None, None]
|
||||
if inputs_embeds[1] is None:
|
||||
prefix_output = self.paligemma.language_model.forward(
|
||||
prefix_output = self.paligemma.model.language_model.forward(
|
||||
inputs_embeds=inputs_embeds[0],
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
@@ -468,7 +484,7 @@ class PaliGemmaWithExpertModel(
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
models = [self.paligemma.language_model, self.gemma_expert.model]
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
@@ -508,7 +524,7 @@ class PaliGemmaWithExpertModel(
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
@@ -573,29 +589,19 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Also compile the main forward pass used during training
|
||||
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
||||
|
||||
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
|
||||
|
||||
try:
|
||||
from transformers.models.siglip import check
|
||||
|
||||
if not check.check_whether_transformers_replace_is_installed_correctly():
|
||||
raise ValueError(msg)
|
||||
except ImportError:
|
||||
raise ValueError(msg) from None
|
||||
|
||||
def gradient_checkpointing_enable(self):
|
||||
"""Enable gradient checkpointing for memory optimization."""
|
||||
self.gradient_checkpointing_enabled = True
|
||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
|
||||
logging.info("Enabled gradient checkpointing for PI05Pytorch model")
|
||||
|
||||
def gradient_checkpointing_disable(self):
|
||||
"""Disable gradient checkpointing."""
|
||||
self.gradient_checkpointing_enabled = False
|
||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
||||
logging.info("Disabled gradient checkpointing for PI05Pytorch model")
|
||||
|
||||
@@ -737,7 +743,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
|
||||
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
||||
@@ -808,7 +814,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
|
||||
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
_, past_key_values = self.paligemma_with_expert.forward(
|
||||
attention_mask=prefix_att_2d_masks_4d,
|
||||
@@ -880,6 +886,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
@@ -969,14 +976,12 @@ class PI05Policy(PreTrainedPolicy):
|
||||
# Check if dataset_stats were provided in kwargs
|
||||
model = cls(config, **kwargs)
|
||||
|
||||
# Now manually load and remap the state dict
|
||||
# Load state dict (expects keys with "model." prefix)
|
||||
try:
|
||||
# Try to load the pytorch_model.bin or model.safetensors file
|
||||
print(f"Loading model from: {pretrained_name_or_path}")
|
||||
try:
|
||||
from transformers.utils import cached_file
|
||||
|
||||
# Try safetensors first
|
||||
resolved_file = cached_file(
|
||||
pretrained_name_or_path,
|
||||
"model.safetensors",
|
||||
@@ -984,7 +989,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
force_download=kwargs.get("force_download", False),
|
||||
resume_download=kwargs.get("resume_download"),
|
||||
proxies=kwargs.get("proxies"),
|
||||
use_auth_token=kwargs.get("use_auth_token"),
|
||||
token=kwargs.get("token"),
|
||||
revision=kwargs.get("revision"),
|
||||
local_files_only=kwargs.get("local_files_only", False),
|
||||
)
|
||||
@@ -997,7 +1002,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
print("Returning model without loading pretrained weights")
|
||||
return model
|
||||
|
||||
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
|
||||
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
|
||||
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
|
||||
|
||||
# Then add "model." prefix for all keys that don't already have it
|
||||
@@ -1009,8 +1014,6 @@ class PI05Policy(PreTrainedPolicy):
|
||||
new_key = f"model.{key}"
|
||||
remapped_state_dict[new_key] = value
|
||||
remap_count += 1
|
||||
if remap_count <= 10: # Only print first 10 to avoid spam
|
||||
print(f"Remapped: {key} -> {new_key}")
|
||||
else:
|
||||
remapped_state_dict[key] = value
|
||||
|
||||
@@ -1044,7 +1047,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
print("All keys loaded successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not remap state dict keys: {e}")
|
||||
print(f"Warning: Could not load state dict: {e}")
|
||||
|
||||
return model
|
||||
|
||||
@@ -1098,6 +1101,14 @@ class PI05Policy(PreTrainedPolicy):
|
||||
# Some checkpoints might have this, but current model expects different structure
|
||||
logging.warning(f"Vision embedding key might need handling: {key}")
|
||||
|
||||
if (
|
||||
key == "model.paligemma_with_expert.paligemma.lm_head.weight"
|
||||
or key == "paligemma_with_expert.paligemma.lm_head.weight"
|
||||
):
|
||||
fixed_state_dict[
|
||||
"model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
|
||||
] = value.clone()
|
||||
|
||||
fixed_state_dict[new_key] = value
|
||||
|
||||
return fixed_state_dict
|
||||
|
||||
@@ -23,7 +23,6 @@ import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pi05.modeling_pi05 import pad_vector
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
@@ -68,9 +67,6 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
|
||||
# TODO: check if this necessary
|
||||
state = deepcopy(state)
|
||||
|
||||
# Prepare state (pad to max_state_dim)
|
||||
state = pad_vector(state, self.max_state_dim)
|
||||
|
||||
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
state_np = state.cpu().numpy()
|
||||
|
||||
@@ -54,7 +54,7 @@ class PI0FastConfig(PreTrainedConfig):
|
||||
|
||||
tokenizer_max_length: int = 200 # see openpi `__post_init__`
|
||||
text_tokenizer_name: str = "google/paligemma-3b-pt-224"
|
||||
action_tokenizer_name: str = "physical-intelligence/fast"
|
||||
action_tokenizer_name: str = "lerobot/fast-action-tokenizer"
|
||||
temperature: float = 0.0
|
||||
max_decoding_steps: int = 256
|
||||
fast_skip_tokens: int = 128
|
||||
|
||||
@@ -19,13 +19,12 @@ import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.utils.import_utils import _scipy_available, _transformers_available
|
||||
|
||||
@@ -38,11 +37,16 @@ else:
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
|
||||
|
||||
from lerobot.policies.pi_gemma import (
|
||||
PaliGemmaForConditionalGenerationWithPiGemma,
|
||||
PiGemmaModel,
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
PaliGemmaForConditionalGeneration = None
|
||||
AutoTokenizer = None
|
||||
PiGemmaModel = None
|
||||
PaliGemmaForConditionalGenerationWithPiGemma = None
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
|
||||
@@ -121,7 +125,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
if images.dtype == torch.uint8:
|
||||
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
|
||||
elif images.dtype == torch.float32:
|
||||
resized_images = resized_images.clamp(-1.0, 1.0)
|
||||
resized_images = resized_images.clamp(0.0, 1.0)
|
||||
else:
|
||||
raise ValueError(f"Unsupported image dtype: {images.dtype}")
|
||||
|
||||
@@ -132,7 +136,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
pad_w1 = pad_w0 + remainder_w
|
||||
|
||||
# Pad
|
||||
constant_value = 0 if images.dtype == torch.uint8 else -1.0
|
||||
constant_value = 0 if images.dtype == torch.uint8 else 0.0
|
||||
padded_images = F.pad(
|
||||
resized_images,
|
||||
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
|
||||
@@ -206,16 +210,22 @@ class PI0FastPaliGemma(nn.Module):
|
||||
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
|
||||
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
|
||||
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
|
||||
vlm_config_hf.text_config.torch_dtype = "float32"
|
||||
vlm_config_hf.text_config.dtype = "float32"
|
||||
vlm_config_hf.text_config.vocab_size = 257152
|
||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||
vlm_config_hf.vision_config.projection_dim = 2048
|
||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||
vlm_config_hf.vision_config.torch_dtype = "float32"
|
||||
vlm_config_hf.vision_config.dtype = "float32"
|
||||
|
||||
self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
|
||||
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
|
||||
|
||||
# Use PI Gemma (AdaRMS) as language model when use_adarms[0] is True so that
|
||||
# forward(..., adarms_cond=...) is supported (same as pi0/pi05).
|
||||
if use_adarms[0]:
|
||||
text_config = self.paligemma.config.text_config
|
||||
self.paligemma.model.language_model = PiGemmaModel(text_config)
|
||||
|
||||
self.to_bfloat16_for_selected_params(precision)
|
||||
|
||||
@@ -228,10 +238,11 @@ class PI0FastPaliGemma(nn.Module):
|
||||
else:
|
||||
raise ValueError(f"Invalid precision: {precision}")
|
||||
|
||||
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
|
||||
# "same dtype" error). Align with PI05.
|
||||
params_to_keep_float32 = [
|
||||
"vision_tower.vision_model.embeddings.patch_embedding.weight",
|
||||
"vision_tower.vision_model.embeddings.patch_embedding.bias",
|
||||
"vision_tower.vision_model.embeddings.position_embedding.weight",
|
||||
"vision_tower",
|
||||
"multi_modal_projector",
|
||||
"input_layernorm",
|
||||
"post_attention_layernorm",
|
||||
"model.norm",
|
||||
@@ -242,10 +253,18 @@ class PI0FastPaliGemma(nn.Module):
|
||||
param.data = param.data.to(dtype=torch.float32)
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
return self.paligemma.model.get_image_features(image)
|
||||
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). Align with PI05.
|
||||
out_dtype = image.dtype
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
image_outputs = self.paligemma.model.get_image_features(image)
|
||||
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
||||
if features.dtype != out_dtype:
|
||||
features = features.to(out_dtype)
|
||||
return features
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.language_model.embed_tokens(tokens)
|
||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -259,7 +278,7 @@ class PI0FastPaliGemma(nn.Module):
|
||||
if adarms_cond is None:
|
||||
adarms_cond = [None, None]
|
||||
if inputs_embeds[1] is None:
|
||||
prefix_output = self.paligemma.language_model.forward(
|
||||
prefix_output = self.paligemma.model.language_model.forward(
|
||||
inputs_embeds=inputs_embeds[0],
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
@@ -306,24 +325,14 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
self.sample_actions_fast = torch.compile(self.sample_actions_fast, mode=config.compile_mode)
|
||||
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
||||
|
||||
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
|
||||
|
||||
try:
|
||||
from transformers.models.siglip import check
|
||||
|
||||
if not check.check_whether_transformers_replace_is_installed_correctly():
|
||||
raise ValueError(msg)
|
||||
except ImportError:
|
||||
raise ValueError(msg) from None
|
||||
|
||||
def gradient_checkpointing_enable(self):
|
||||
"""Enable gradient checkpointing for memory optimization."""
|
||||
self.gradient_checkpointing_enabled = True
|
||||
# Call the proper gradient_checkpointing_enable() method with use_reentrant=False for better memory efficiency
|
||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_enable(
|
||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing_enable(
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False}
|
||||
)
|
||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_enable(
|
||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing_enable(
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False}
|
||||
)
|
||||
logging.info("Enabled gradient checkpointing for PI0FastPytorch model")
|
||||
@@ -332,8 +341,8 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
"""Disable gradient checkpointing."""
|
||||
self.gradient_checkpointing_enabled = False
|
||||
# Call the proper gradient_checkpointing_disable() method
|
||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_disable()
|
||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_disable()
|
||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing_disable()
|
||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing_disable()
|
||||
logging.info("Disabled gradient checkpointing for PI0FastPytorch model")
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
@@ -523,7 +532,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
# Convert embeddings to bfloat16 if needed
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||
@@ -616,7 +625,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
)
|
||||
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||
@@ -714,7 +723,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
# Ensure correct precision (bfloat16/float32)
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||
@@ -897,14 +906,12 @@ class PI0FastPolicy(PreTrainedPolicy):
|
||||
# Check if dataset_stats were provided in kwargs
|
||||
model = cls(config, **kwargs)
|
||||
|
||||
# Now manually load and remap the state dict
|
||||
# Load state dict (expects keys with "model." prefix)
|
||||
try:
|
||||
# Try to load the pytorch_model.bin or model.safetensors file
|
||||
print(f"Loading model from: {pretrained_name_or_path}")
|
||||
try:
|
||||
from transformers.utils import cached_file
|
||||
|
||||
# Try safetensors first
|
||||
resolved_file = cached_file(
|
||||
pretrained_name_or_path,
|
||||
"model.safetensors",
|
||||
@@ -912,7 +919,7 @@ class PI0FastPolicy(PreTrainedPolicy):
|
||||
force_download=kwargs.get("force_download", False),
|
||||
resume_download=kwargs.get("resume_download"),
|
||||
proxies=kwargs.get("proxies"),
|
||||
use_auth_token=kwargs.get("use_auth_token"),
|
||||
token=kwargs.get("token"),
|
||||
revision=kwargs.get("revision"),
|
||||
local_files_only=kwargs.get("local_files_only", False),
|
||||
)
|
||||
@@ -925,8 +932,9 @@ class PI0FastPolicy(PreTrainedPolicy):
|
||||
print("Returning model without loading pretrained weights")
|
||||
return model
|
||||
|
||||
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
|
||||
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
|
||||
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
|
||||
|
||||
# Then add "model." prefix for all keys that don't already have it
|
||||
remapped_state_dict = {}
|
||||
remap_count = 0
|
||||
@@ -936,8 +944,6 @@ class PI0FastPolicy(PreTrainedPolicy):
|
||||
new_key = f"model.{key}"
|
||||
remapped_state_dict[new_key] = value
|
||||
remap_count += 1
|
||||
if remap_count <= 10: # Only print first 10 to avoid spam
|
||||
print(f"Remapped: {key} -> {new_key}")
|
||||
else:
|
||||
remapped_state_dict[key] = value
|
||||
|
||||
@@ -971,7 +977,7 @@ class PI0FastPolicy(PreTrainedPolicy):
|
||||
print("All keys loaded successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not remap state dict keys: {e}")
|
||||
print(f"Warning: Could not load state dict: {e}")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
|
||||
from lerobot.policies.pi0_fast.modeling_pi0_fast import pad_vector
|
||||
from lerobot.processor import (
|
||||
ActionTokenizerProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
@@ -69,9 +68,6 @@ class Pi0FastPrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
|
||||
# TODO: check if this necessary
|
||||
state = deepcopy(state)
|
||||
|
||||
# Prepare state (pad to max_state_dim)
|
||||
state = pad_vector(state, self.max_state_dim)
|
||||
|
||||
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
state_np = state.cpu().numpy()
|
||||
|
||||
@@ -0,0 +1,363 @@
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.masking_utils import create_causal_mask
|
||||
from transformers.modeling_layers import GradientCheckpointingLayer
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.gemma.modeling_gemma import (
|
||||
GemmaAttention,
|
||||
GemmaConfig,
|
||||
GemmaForCausalLM,
|
||||
GemmaMLP,
|
||||
GemmaModel,
|
||||
)
|
||||
from transformers.models.paligemma.modeling_paligemma import (
|
||||
PaliGemmaForConditionalGeneration,
|
||||
PaliGemmaModel,
|
||||
)
|
||||
else:
|
||||
GemmaAttention = None
|
||||
GemmaConfig = None
|
||||
GemmaForCausalLM = None
|
||||
GemmaMLP = None
|
||||
GemmaModel = None
|
||||
PaliGemmaModel = None
|
||||
PaliGemmaForConditionalGeneration = None
|
||||
DynamicCache = None
|
||||
GradientCheckpointingLayer = None
|
||||
BaseModelOutputWithPast = None
|
||||
create_causal_mask = None
|
||||
|
||||
|
||||
def _gated_residual(
|
||||
x: torch.Tensor | None,
|
||||
y: torch.Tensor | None,
|
||||
gate: torch.Tensor | None,
|
||||
) -> torch.Tensor | None:
|
||||
"""Gated residual: x + y when gate is None, else x + y * gate."""
|
||||
if x is None and y is None:
|
||||
return None
|
||||
if x is None or y is None:
|
||||
return x if x is not None else y
|
||||
if gate is None:
|
||||
return x + y
|
||||
return x + y * gate
|
||||
|
||||
|
||||
def layernorm_forward(
|
||||
layernorm: nn.Module,
|
||||
x: torch.Tensor,
|
||||
cond: torch.Tensor | None = None,
|
||||
):
|
||||
"""
|
||||
call layernorm and return hidden states and gate
|
||||
if cond is not None, use conditional norm
|
||||
otherwise, use normal gemma norm
|
||||
"""
|
||||
if cond is not None:
|
||||
return layernorm(x, cond=cond)
|
||||
else:
|
||||
return layernorm(x)
|
||||
|
||||
|
||||
class PiGemmaRMSNorm(nn.Module):
|
||||
"""
|
||||
Adaptive RMSNorm for PI Gemma (AdaRMS).
|
||||
When cond_dim is set, uses cond to modulate scale/shift/gate; otherwise behaves like standard GemmaRMSNorm.
|
||||
forward(x, cond=None) returns (output, gate) for use with _gated_residual.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.dim = dim
|
||||
self.cond_dim = cond_dim
|
||||
if cond_dim is not None:
|
||||
self.dense = nn.Linear(cond_dim, dim * 3, bias=True)
|
||||
nn.init.zeros_(self.dense.weight)
|
||||
else:
|
||||
self.weight = nn.Parameter(torch.zeros(dim))
|
||||
self.dense = None
|
||||
|
||||
def _norm(self, x):
|
||||
# Compute variance in float32 (like the source implementation)
|
||||
var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True)
|
||||
# Compute normalization in float32
|
||||
normed_inputs = x * torch.rsqrt(var + self.eps)
|
||||
return normed_inputs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cond: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
dtype = x.dtype
|
||||
normed = self._norm(x)
|
||||
if cond is None or self.dense is None:
|
||||
normed = normed * (1.0 + self.weight.float())
|
||||
return normed.type_as(x), None
|
||||
if cond.shape[-1] != self.cond_dim:
|
||||
raise ValueError(f"Expected cond dim {self.cond_dim}, got {cond.shape[-1]}")
|
||||
modulation = self.dense(cond)
|
||||
if len(x.shape) == 3:
|
||||
modulation = modulation.unsqueeze(1)
|
||||
scale, shift, gate = modulation.chunk(3, dim=-1)
|
||||
normed = normed * (1 + scale.float()) + shift.float()
|
||||
return normed.to(dtype), gate.to(dtype)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
if self.dense is not None:
|
||||
return f"dim={self.dim}, eps={self.eps}, adaptive=True, cond_dim={self.cond_dim}"
|
||||
return f"dim={self.dim}, eps={self.eps}"
|
||||
|
||||
|
||||
def _get_pi_gemma_decoder_layer_base():
|
||||
"""base for PiGemmaDecoderLayer"""
|
||||
|
||||
class _PiGemmaDecoderLayerBase(GradientCheckpointingLayer):
|
||||
"""Decoder layer that uses PiGemmaRMSNorm and _gated_residual, compatible with v5 Gemma."""
|
||||
|
||||
def __init__(self, config: GemmaConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)
|
||||
self.mlp = GemmaMLP(config)
|
||||
cond_dim = (
|
||||
getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None
|
||||
)
|
||||
self.input_layernorm = PiGemmaRMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
|
||||
)
|
||||
self.post_attention_layernorm = PiGemmaRMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values=None,
|
||||
use_cache: bool = False,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
adarms_cond: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states, gate = self.input_layernorm(hidden_states, cond=adarms_cond)
|
||||
hidden_states, _ = self.self_attn(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = _gated_residual(residual, hidden_states, gate)
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states, gate = self.post_attention_layernorm(hidden_states, cond=adarms_cond)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = _gated_residual(residual, hidden_states, gate)
|
||||
return hidden_states
|
||||
|
||||
return _PiGemmaDecoderLayerBase
|
||||
|
||||
|
||||
class PiGemmaModel(GemmaModel): # type: ignore[misc]
|
||||
"""
|
||||
GemmaModel extended with AdaRMS (adaptive RMSNorm) and gated residuals when config.use_adarms is True.
|
||||
"""
|
||||
|
||||
def __init__(self, config: GemmaConfig, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
# if not getattr(config, "use_adarms", False):
|
||||
# return
|
||||
cond_dim = getattr(config, "adarms_cond_dim", None)
|
||||
pi_gemma_decoder_layer_base = _get_pi_gemma_decoder_layer_base()
|
||||
self.layers = nn.ModuleList(
|
||||
[pi_gemma_decoder_layer_base(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.norm = PiGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: DynamicCache | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
adarms_cond: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> BaseModelOutputWithPast:
|
||||
"""
|
||||
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
|
||||
Condition for ADARMS.
|
||||
"""
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
import logging
|
||||
|
||||
logging.warning(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
# Convert to bfloat16 if the first layer uses bfloat16
|
||||
if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(torch.bfloat16)
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# normalized
|
||||
# Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
|
||||
# See https://github.com/huggingface/transformers/pull/29402
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
||||
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
adarms_cond=adarms_cond,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, adarms_cond)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=past_key_values if use_cache else None,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class PiGemmaForCausalLM(GemmaForCausalLM): # type: ignore[misc]
|
||||
"""
|
||||
Causal LM wrapper using PiGemmaModel as the backbone, for consistency with GemmaForCausalLM
|
||||
and the language model used in pi0_fast. Use this for the action expert in pi0/pi05.
|
||||
"""
|
||||
|
||||
def __init__(self, config: GemmaConfig, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
self.model = PiGemmaModel(config)
|
||||
|
||||
|
||||
class PaliGemmaModelWithPiGemma(PaliGemmaModel):
|
||||
"""PaliGemmaModel whose language_model is PiGemmaModel (custom decoder with PiGemmaRMSNorm and gated residuals)."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.language_model = PiGemmaModel(config.text_config)
|
||||
|
||||
|
||||
class PaliGemmaForConditionalGenerationWithPiGemma(PaliGemmaForConditionalGeneration):
|
||||
"""PaliGemmaForConditionalGeneration using PiGemma decoder for the language model."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = PaliGemmaModelWithPiGemma(config)
|
||||
|
||||
# Make modules available through conditional class for BC
|
||||
@property
|
||||
def language_model(self):
|
||||
return self.model.language_model
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PiGemmaModel",
|
||||
"PiGemmaForCausalLM",
|
||||
"PiGemmaRMSNorm",
|
||||
"_gated_residual",
|
||||
"layernorm_forward",
|
||||
"PaliGemmaModelWithPiGemma",
|
||||
"PaliGemmaForConditionalGenerationWithPiGemma",
|
||||
]
|
||||
@@ -19,7 +19,7 @@ import os
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import TypedDict, TypeVar
|
||||
from typing import TypedDict, TypeVar, Unpack
|
||||
|
||||
import packaging
|
||||
import safetensors
|
||||
@@ -28,7 +28,6 @@ from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
@@ -33,7 +33,7 @@ class RewardClassifierConfig(PreTrainedConfig):
|
||||
latent_dim: int = 256
|
||||
image_embedding_pooling_dim: int = 8
|
||||
dropout_rate: float = 0.1
|
||||
model_name: str = "helper2424/resnet10"
|
||||
model_name: str = "helper2424/resnet10" # TODO: This needs to be updated. The model on the Hub doesn't call self.post_init() in its __init__, which is required by transformers v5 to set all_tied_weights_keys. The from_pretrained call fails when it tries to access this attribute during _finalize_model_loading.
|
||||
device: str = "cpu"
|
||||
model_type: str = "cnn" # "transformer" or "cnn"
|
||||
num_cameras: int = 2
|
||||
|
||||
@@ -54,12 +54,11 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
|
||||
|
||||
import math
|
||||
from collections import deque
|
||||
from typing import TypedDict
|
||||
from typing import TypedDict, Unpack
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
|
||||
@@ -55,7 +55,7 @@ class WallXConfig(PreTrainedConfig):
|
||||
pretrained_name_or_path: str = "x-square-robot/wall-oss-flow"
|
||||
|
||||
# Tokenizer settings
|
||||
action_tokenizer_path: str | None = "physical-intelligence/fast"
|
||||
action_tokenizer_path: str | None = "lerobot/fast-action-tokenizer"
|
||||
|
||||
# Action prediction mode: "diffusion" or "fast"
|
||||
prediction_mode: str = "diffusion"
|
||||
|
||||
@@ -261,10 +261,15 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
and optional LoRA fine-tuning support.
|
||||
"""
|
||||
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||
config_class = Qwen2_5_VLConfig
|
||||
_no_split_modules = ["Qwen2_5_VLDecoderLayer_with_MoE", "Qwen2_5_VLVisionBlock"]
|
||||
|
||||
def init_weights(self):
|
||||
if getattr(self.model, "language_model", None) is not None:
|
||||
return
|
||||
super().init_weights()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
@@ -312,6 +317,11 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
processor.action_processor = action_tokenizer
|
||||
else:
|
||||
action_tokenizer = None
|
||||
|
||||
# add pad_token_id to config
|
||||
config.pad_token_id = processor.tokenizer.pad_token_id
|
||||
config.text_config.pad_token_id = processor.tokenizer.pad_token_id
|
||||
|
||||
# Initialize model with configuration and processor
|
||||
model = cls(config, processor=processor, action_tokenizer=action_tokenizer, **kwargs)
|
||||
|
||||
@@ -331,7 +341,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
force_download=kwargs.get("force_download", False),
|
||||
resume_download=kwargs.get("resume_download"),
|
||||
proxies=kwargs.get("proxies"),
|
||||
use_auth_token=kwargs.get("use_auth_token"),
|
||||
token=kwargs.get("token"),
|
||||
revision=kwargs.get("revision"),
|
||||
local_files_only=kwargs.get("local_files_only", False),
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
|
||||
window_size=112,
|
||||
out_hidden_size=3584,
|
||||
fullatt_block_indexes=[7, 15, 23, 31],
|
||||
initializer_range=0.02,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
@@ -38,6 +39,7 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
|
||||
self.window_size = window_size
|
||||
self.fullatt_block_indexes = fullatt_block_indexes
|
||||
self.out_hidden_size = out_hidden_size
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
|
||||
class Qwen2_5_VLConfig(PretrainedConfig):
|
||||
|
||||
@@ -11,7 +11,6 @@ from transformers.activations import ACT2FN
|
||||
from transformers.cache_utils import (
|
||||
Cache,
|
||||
DynamicCache,
|
||||
SlidingWindowCache,
|
||||
StaticCache,
|
||||
)
|
||||
from transformers.generation import GenerationMixin
|
||||
@@ -31,6 +30,15 @@ from transformers.utils import (
|
||||
|
||||
from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
|
||||
|
||||
|
||||
# TODO(Steven): SlidingWindowCache was removed in transformers v5. Define a placeholder so isinstance checks
|
||||
# always return False (which is the correct behavior when no sliding window cache is in use).
|
||||
class _SlidingWindowCachePlaceholder:
|
||||
pass
|
||||
|
||||
|
||||
SlidingWindowCache = _SlidingWindowCachePlaceholder
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.layers.rotary import apply_rotary_emb
|
||||
@@ -594,19 +602,40 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
|
||||
return hidden_states
|
||||
|
||||
|
||||
def _compute_default_rope_parameters_qwen2_5_vl(config, device=None):
|
||||
"""
|
||||
compute default rope parameters for Qwen2_5_VL
|
||||
"""
|
||||
base = config.text_config.rope_parameters["rope_theta"]
|
||||
dim = config.hidden_size // config.num_attention_heads
|
||||
inv_freq = 1.0 / (
|
||||
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
|
||||
)
|
||||
return inv_freq, 1.0
|
||||
|
||||
|
||||
class Qwen2_5_VLRotaryEmbedding(nn.Module):
|
||||
def __init__(self, config: Qwen2_5_VLConfig, device=None):
|
||||
super().__init__()
|
||||
# BC: "rope_type" was originally "type"
|
||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
elif hasattr(config, "rope_parameters") and config.rope_parameters is not None:
|
||||
self.rope_type = config.rope_parameters.get("rope_type", "default")
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
if self.rope_type == "default":
|
||||
self.rope_init_fn = _compute_default_rope_parameters_qwen2_5_vl
|
||||
self.rope_kwargs = {}
|
||||
else:
|
||||
rope_type_key = "linear" if self.rope_type == "linear" else self.rope_type
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type_key]
|
||||
self.rope_kwargs = {}
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
@@ -1567,7 +1596,7 @@ QWEN2_5_VL_INPUTS_DOCSTRING = r"""
|
||||
|
||||
|
||||
class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||
config_class = Qwen2_5_VLConfig
|
||||
_no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
|
||||
|
||||
|
||||
@@ -144,7 +144,7 @@ def preprocesser_call(
|
||||
"""
|
||||
# Process image inputs
|
||||
if images is not None and len(images) > 0:
|
||||
image_inputs = processor.image_processor(images=images, videos=None, return_tensors=return_tensors)
|
||||
image_inputs = processor.image_processor(images=images, return_tensors=return_tensors)
|
||||
image_grid_thw = image_inputs["image_grid_thw"]
|
||||
else:
|
||||
image_inputs = {}
|
||||
@@ -152,7 +152,7 @@ def preprocesser_call(
|
||||
|
||||
# Process video inputs
|
||||
if videos is not None:
|
||||
videos_inputs = processor.image_processor(images=None, videos=videos, return_tensors=return_tensors)
|
||||
videos_inputs = processor.image_processor(videos=videos, return_tensors=return_tensors)
|
||||
video_grid_thw = videos_inputs["video_grid_thw"]
|
||||
else:
|
||||
videos_inputs = {}
|
||||
|
||||
@@ -276,6 +276,8 @@ class Florence2LanguageConfig(PretrainedConfig):
|
||||
)
|
||||
|
||||
# ensure backward compatibility for BART CNN models
|
||||
if not hasattr(self, "forced_bos_token_id"):
|
||||
self.forced_bos_token_id = None
|
||||
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
|
||||
self.forced_bos_token_id = self.bos_token_id
|
||||
warnings.warn(
|
||||
|
||||
@@ -1951,7 +1951,10 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
|
||||
|
||||
|
||||
class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
_tied_weights_keys = {
|
||||
"encoder.embed_tokens.weight": "shared.weight",
|
||||
"decoder.embed_tokens.weight": "shared.weight",
|
||||
}
|
||||
|
||||
def __init__(self, config: Florence2LanguageConfig):
|
||||
super().__init__(config)
|
||||
@@ -2076,7 +2079,10 @@ class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
|
||||
|
||||
class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel, GenerationMixin):
|
||||
base_model_prefix = "model"
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
|
||||
_tied_weights_keys = {
|
||||
"model.encoder.embed_tokens.weight": "model.shared.weight",
|
||||
"model.decoder.embed_tokens.weight": "model.shared.weight",
|
||||
}
|
||||
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
|
||||
|
||||
def __init__(self, config: Florence2LanguageConfig):
|
||||
@@ -2436,11 +2442,10 @@ FLORENCE2_INPUTS_DOCSTRING = r"""
|
||||
FLORENCE2_START_DOCSTRING,
|
||||
)
|
||||
class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
||||
_tied_weights_keys = [
|
||||
"language_model.encoder.embed_tokens.weight",
|
||||
"language_model.decoder.embed_tokens.weight",
|
||||
"language_model.lm_head.weight",
|
||||
]
|
||||
_tied_weights_keys = {
|
||||
"language_model.model.encoder.embed_tokens.weight": "language_model.model.shared.weight",
|
||||
"language_model.model.decoder.embed_tokens.weight": "language_model.model.shared.weight",
|
||||
}
|
||||
|
||||
def __init__(self, config: Florence2Config):
|
||||
super().__init__(config)
|
||||
|
||||
@@ -30,6 +30,12 @@ from .core import (
|
||||
)
|
||||
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
|
||||
from .device_processor import DeviceProcessorStep
|
||||
from .factory import (
|
||||
make_default_processors,
|
||||
make_default_robot_action_processor,
|
||||
make_default_robot_observation_processor,
|
||||
make_default_teleop_action_processor,
|
||||
)
|
||||
from .gym_action_processor import (
|
||||
Numpy2TorchActionProcessorStep,
|
||||
Torch2NumpyActionProcessorStep,
|
||||
@@ -89,7 +95,11 @@ __all__ = [
|
||||
"ImageCropResizeProcessorStep",
|
||||
"InfoProcessorStep",
|
||||
"InterventionActionProcessorStep",
|
||||
"MapDeltaActionToRobotActionStep",
|
||||
"make_default_processors",
|
||||
"make_default_teleop_action_processor",
|
||||
"make_default_robot_action_processor",
|
||||
"make_default_robot_observation_processor",
|
||||
"MapDeltaActionToRobotActionStep",
|
||||
"MapTensorToDeltaActionDictStep",
|
||||
"NormalizerProcessorStep",
|
||||
"Numpy2TorchActionProcessorStep",
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, TypeAlias, TypedDict
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -36,10 +36,10 @@ class TransitionKey(str, Enum):
|
||||
COMPLEMENTARY_DATA = "complementary_data"
|
||||
|
||||
|
||||
PolicyAction: TypeAlias = torch.Tensor
|
||||
RobotAction: TypeAlias = dict[str, Any]
|
||||
EnvAction: TypeAlias = np.ndarray
|
||||
RobotObservation: TypeAlias = dict[str, Any]
|
||||
PolicyAction = torch.Tensor
|
||||
RobotAction = dict[str, Any]
|
||||
EnvAction = np.ndarray
|
||||
RobotObservation = dict[str, Any]
|
||||
|
||||
|
||||
EnvTransition = TypedDict(
|
||||
|
||||
@@ -153,6 +153,44 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="robocasa_processor")
|
||||
class RoboCasaProcessorStep(ObservationProcessorStep):
|
||||
"""
|
||||
Processes RoboCasa observations into LeRobot format.
|
||||
|
||||
The RoboCasaEnv wrapper returns:
|
||||
- ``pixels.<cam_name>``: (B, C, H, W) float32 images (already converted by vectorenv)
|
||||
- ``observation.robot_state``: (B, 16) float32 proprioception
|
||||
|
||||
This step remaps them to:
|
||||
- ``observation.images.<cam_name>`` (unchanged tensor)
|
||||
- ``observation.state`` (robot_state renamed)
|
||||
"""
|
||||
|
||||
def _process_observation(self, observation: dict) -> dict:
|
||||
processed = {}
|
||||
obs_prefix = OBS_PREFIX # "observation."
|
||||
|
||||
for key, value in observation.items():
|
||||
if key.startswith(f"{OBS_IMAGES}."):
|
||||
# Already in the right place; pass through
|
||||
processed[key] = value
|
||||
elif key == OBS_STATE or key == f"{obs_prefix}robot_state":
|
||||
# Rename robot_state → observation.state
|
||||
processed[OBS_STATE] = value.float() if hasattr(value, "float") else value
|
||||
|
||||
return processed
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
def observation(self, observation: dict) -> dict:
|
||||
return self._process_observation(observation)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="isaaclab_arena_processor")
|
||||
class IsaaclabArenaProcessorStep(ObservationProcessorStep):
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
from .converters import (
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
robot_action_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
@@ -25,44 +24,39 @@ from .core import RobotAction, RobotObservation
|
||||
from .pipeline import IdentityProcessorStep, RobotProcessorPipeline
|
||||
|
||||
|
||||
# ── Internal identity pipeline helpers (used by Robot/Teleoperator base classes) ──────────────────
|
||||
|
||||
|
||||
def _make_identity_observation_pipeline() -> RobotProcessorPipeline[RobotObservation, RobotObservation]:
|
||||
"""Identity pipeline for robot observations (get_observation output pipeline)."""
|
||||
return RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
|
||||
def _make_identity_robot_action_pipeline() -> RobotProcessorPipeline[
|
||||
def make_default_teleop_action_processor() -> RobotProcessorPipeline[
|
||||
tuple[RobotAction, RobotObservation], RobotAction
|
||||
]:
|
||||
"""Identity pipeline for robot action input (send_action input pipeline, takes (action, obs) tuple)."""
|
||||
return RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
teleop_action_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
return teleop_action_processor
|
||||
|
||||
|
||||
def _make_identity_teleop_action_pipeline() -> RobotProcessorPipeline[RobotAction, RobotAction]:
|
||||
"""Identity pipeline for teleop action output (get_action output pipeline, takes just action)."""
|
||||
return RobotProcessorPipeline[RobotAction, RobotAction](
|
||||
def make_default_robot_action_processor() -> RobotProcessorPipeline[
|
||||
tuple[RobotAction, RobotObservation], RobotAction
|
||||
]:
|
||||
robot_action_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=robot_action_to_transition,
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
return robot_action_processor
|
||||
|
||||
|
||||
def _make_identity_feedback_pipeline() -> RobotProcessorPipeline[dict, dict]:
|
||||
"""Identity pipeline for teleop feedback input (send_feedback input pipeline)."""
|
||||
return RobotProcessorPipeline[dict, dict](
|
||||
def make_default_robot_observation_processor() -> RobotProcessorPipeline[RobotObservation, RobotObservation]:
|
||||
robot_observation_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
return robot_observation_processor
|
||||
|
||||
|
||||
def make_default_processors():
|
||||
teleop_action_processor = make_default_teleop_action_processor()
|
||||
robot_action_processor = make_default_robot_action_processor()
|
||||
robot_observation_processor = make_default_robot_observation_processor()
|
||||
return (teleop_action_processor, robot_action_processor, robot_observation_processor)
|
||||
|
||||
@@ -19,17 +19,15 @@ from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
from .converters import from_tensor_to_numpy, to_tensor
|
||||
from .core import EnvTransition, PolicyAction, TransitionKey
|
||||
from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry, RobotObservation
|
||||
|
||||
@@ -39,7 +39,7 @@ from collections.abc import Callable, Iterable, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Generic, TypeAlias, TypedDict, TypeVar, cast
|
||||
from typing import Any, TypedDict, TypeVar, cast
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
@@ -251,7 +251,7 @@ class ProcessorMigrationError(Exception):
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
|
||||
class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
"""A sequential pipeline for processing data, integrated with the Hugging Face Hub.
|
||||
|
||||
This class chains together multiple `ProcessorStep` instances to form a complete
|
||||
@@ -1432,8 +1432,8 @@ class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
|
||||
|
||||
|
||||
# Type aliases for semantic clarity.
|
||||
RobotProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput]
|
||||
PolicyProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput]
|
||||
RobotProcessorPipeline = DataProcessorPipeline[TInput, TOutput]
|
||||
PolicyProcessorPipeline = DataProcessorPipeline[TInput, TOutput]
|
||||
|
||||
|
||||
class ObservationProcessorStep(ProcessorStep, ABC):
|
||||
|
||||
@@ -43,9 +43,12 @@ from lerobot.utils.import_utils import _transformers_available
|
||||
from .core import EnvTransition, RobotObservation, TransitionKey
|
||||
from .pipeline import ActionProcessorStep, ObservationProcessorStep, ProcessorStepRegistry
|
||||
|
||||
# Type-checking only import — do NOT import transformers at module level (it loads TF which blocks)
|
||||
if TYPE_CHECKING:
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoProcessor, AutoTokenizer
|
||||
else:
|
||||
AutoProcessor = None
|
||||
AutoTokenizer = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -103,7 +106,8 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
# Use provided tokenizer object directly
|
||||
self.input_tokenizer = self.tokenizer
|
||||
elif self.tokenizer_name is not None:
|
||||
from transformers import AutoTokenizer # lazy import to avoid TF deadlock at module load
|
||||
if AutoTokenizer is None:
|
||||
raise ImportError("AutoTokenizer is not available")
|
||||
self.input_tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -332,7 +336,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
Requires the `transformers` library to be installed.
|
||||
|
||||
Attributes:
|
||||
tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast").
|
||||
tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "lerobot/fast-action-tokenizer").
|
||||
tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored.
|
||||
trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers).
|
||||
action_tokenizer: The internal tokenizer/processor instance, loaded during initialization.
|
||||
@@ -366,12 +370,12 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
"Please install it with `pip install 'lerobot[transformers-dep]'` to use ActionTokenizerProcessorStep."
|
||||
)
|
||||
|
||||
from transformers import AutoProcessor, AutoTokenizer # lazy import to avoid TF deadlock at module load
|
||||
|
||||
if self.action_tokenizer_input_object is not None:
|
||||
self.action_tokenizer = self.action_tokenizer_input_object
|
||||
|
||||
elif self.action_tokenizer_name is not None:
|
||||
if AutoProcessor is None:
|
||||
raise ImportError("AutoProcessor is not available")
|
||||
self.action_tokenizer = AutoProcessor.from_pretrained(
|
||||
self.action_tokenizer_name, trust_remote_code=self.trust_remote_code
|
||||
)
|
||||
|
||||
@@ -102,11 +102,11 @@ class BiOpenArmFollower(Robot):
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def raw_observation_features(self) -> dict[str, type | tuple]:
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
@@ -136,7 +136,7 @@ class BiOpenArmFollower(Robot):
|
||||
)
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict = {}
|
||||
|
||||
# Add "left_" prefix
|
||||
@@ -150,7 +150,7 @@ class BiOpenArmFollower(Robot):
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def _send_action(
|
||||
def send_action(
|
||||
self,
|
||||
action: RobotAction,
|
||||
custom_kp: dict[str, float] | None = None,
|
||||
|
||||
@@ -86,11 +86,11 @@ class BiSOFollower(Robot):
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def raw_observation_features(self) -> dict[str, type | tuple]:
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
@@ -119,7 +119,7 @@ class BiSOFollower(Robot):
|
||||
self.right_arm.setup_motors()
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict = {}
|
||||
|
||||
# Add "left_" prefix
|
||||
@@ -133,7 +133,7 @@ class BiSOFollower(Robot):
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def _send_action(self, action: RobotAction) -> RobotAction:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
# Remove "left_" prefix
|
||||
left_action = {
|
||||
key.removeprefix("left_"): value for key, value in action.items() if key.startswith("left_")
|
||||
|
||||
@@ -147,7 +147,7 @@ class EarthRoverMiniPlus(Robot):
|
||||
pass
|
||||
|
||||
@cached_property
|
||||
def raw_observation_features(self) -> dict[str, type | tuple]:
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
"""Define the observation space for dataset recording.
|
||||
|
||||
Returns:
|
||||
@@ -184,7 +184,7 @@ class EarthRoverMiniPlus(Robot):
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
"""Define the action space.
|
||||
|
||||
Returns:
|
||||
@@ -198,7 +198,7 @@ class EarthRoverMiniPlus(Robot):
|
||||
}
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
"""Get current robot observation from SDK.
|
||||
|
||||
Returns:
|
||||
@@ -255,7 +255,7 @@ class EarthRoverMiniPlus(Robot):
|
||||
return observation
|
||||
|
||||
@check_if_not_connected
|
||||
def _send_action(self, action: RobotAction) -> RobotAction:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Send action to robot via SDK.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -71,11 +71,11 @@ class HopeJrArm(Robot):
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def raw_observation_features(self) -> dict[str, type | tuple]:
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
@@ -128,7 +128,7 @@ class HopeJrArm(Robot):
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict = self.bus.sync_read("Present_Position", self.other_motors)
|
||||
@@ -147,7 +147,7 @@ class HopeJrArm(Robot):
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def _send_action(self, action: RobotAction) -> RobotAction:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
|
||||
# Cap goal position when too far away from present position.
|
||||
|
||||
@@ -107,11 +107,11 @@ class HopeJrHand(Robot):
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def raw_observation_features(self) -> dict[str, type | tuple]:
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
@@ -158,7 +158,7 @@ class HopeJrHand(Robot):
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict = {}
|
||||
|
||||
# Read hand position
|
||||
@@ -178,7 +178,7 @@ class HopeJrHand(Robot):
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def _send_action(self, action: RobotAction) -> RobotAction:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
self.bus.sync_write("Goal_Position", goal_pos)
|
||||
return action
|
||||
|
||||
@@ -73,11 +73,11 @@ class KochFollower(Robot):
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def raw_observation_features(self) -> dict[str, type | tuple]:
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
@@ -182,7 +182,7 @@ class KochFollower(Robot):
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict = self.bus.sync_read("Present_Position")
|
||||
@@ -200,7 +200,7 @@ class KochFollower(Robot):
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def _send_action(self, action: RobotAction) -> RobotAction:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command arm to move to a target joint configuration.
|
||||
|
||||
The relative action magnitude may be clipped depending on the configuration parameter
|
||||
|
||||
@@ -98,11 +98,11 @@ class LeKiwi(Robot):
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def raw_observation_features(self) -> dict[str, type | tuple]:
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._state_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._state_ft
|
||||
|
||||
@property
|
||||
@@ -338,7 +338,7 @@ class LeKiwi(Robot):
|
||||
} # m/s and deg/s
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
# Read actuators position for arm and vel for base
|
||||
start = time.perf_counter()
|
||||
arm_pos = self.bus.sync_read("Present_Position", self.arm_motors)
|
||||
@@ -367,7 +367,7 @@ class LeKiwi(Robot):
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def _send_action(self, action: RobotAction) -> RobotAction:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command lekiwi to move to a target joint configuration.
|
||||
|
||||
The relative action magnitude may be clipped depending on the configuration parameter
|
||||
|
||||
@@ -98,11 +98,11 @@ class LeKiwiClient(Robot):
|
||||
return {name: (cfg.height, cfg.width, 3) for name, cfg in self.config.cameras.items()}
|
||||
|
||||
@cached_property
|
||||
def raw_observation_features(self) -> dict[str, type | tuple]:
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._state_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._state_ft
|
||||
|
||||
@property
|
||||
@@ -250,7 +250,7 @@ class LeKiwiClient(Robot):
|
||||
return new_frames, new_state
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
"""
|
||||
Capture observations from the remote robot: current follower arm positions,
|
||||
present wheel speeds (converted to body-frame velocities: x, y, theta),
|
||||
@@ -304,7 +304,7 @@ class LeKiwiClient(Robot):
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
def _send_action(self, action: RobotAction) -> RobotAction:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command lekiwi to move to a target joint configuration. Translates to motor space + sends over ZMQ
|
||||
|
||||
Args:
|
||||
|
||||
@@ -73,11 +73,11 @@ class OmxFollower(Robot):
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def raw_observation_features(self) -> dict[str, type | tuple]:
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
@@ -165,7 +165,7 @@ class OmxFollower(Robot):
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict = self.bus.sync_read("Present_Position")
|
||||
@@ -183,7 +183,7 @@ class OmxFollower(Robot):
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def _send_action(self, action: RobotAction) -> RobotAction:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command arm to move to a target joint configuration.
|
||||
|
||||
The relative action magnitude may be clipped depending on the configuration parameter
|
||||
|
||||
@@ -105,12 +105,12 @@ class OpenArmFollower(Robot):
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def raw_observation_features(self) -> dict[str, type | tuple]:
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
"""Combined observation features from motors and cameras."""
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
"""Action features."""
|
||||
return self._motors_ft
|
||||
|
||||
@@ -219,7 +219,7 @@ class OpenArmFollower(Robot):
|
||||
)
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
"""
|
||||
Get current observation from robot including position, velocity, and torque.
|
||||
|
||||
@@ -251,7 +251,7 @@ class OpenArmFollower(Robot):
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def _send_action(
|
||||
def send_action(
|
||||
self,
|
||||
action: RobotAction,
|
||||
custom_kp: dict[str, float] | None = None,
|
||||
|
||||
@@ -95,11 +95,11 @@ class Reachy2Robot(Robot):
|
||||
self.joints_dict: dict[str, str] = self._generate_joints_dict()
|
||||
|
||||
@property
|
||||
def raw_observation_features(self) -> dict[str, Any]:
|
||||
def observation_features(self) -> dict[str, Any]:
|
||||
return {**self.motors_features, **self.camera_features}
|
||||
|
||||
@property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self.motors_features
|
||||
|
||||
@property
|
||||
@@ -170,7 +170,7 @@ class Reachy2Robot(Robot):
|
||||
else:
|
||||
return {}
|
||||
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict: RobotObservation = {}
|
||||
|
||||
# Read Reachy 2 state
|
||||
@@ -184,7 +184,7 @@ class Reachy2Robot(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
def _send_action(self, action: RobotAction) -> RobotAction:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
if self.reachy is not None:
|
||||
if not self.is_connected:
|
||||
raise ConnectionError()
|
||||
|
||||
+30
-154
@@ -18,11 +18,8 @@ from pathlib import Path
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType
|
||||
from lerobot.motors import MotorCalibration
|
||||
from lerobot.processor.core import RobotAction, RobotObservation
|
||||
from lerobot.processor.factory import _make_identity_observation_pipeline, _make_identity_robot_action_pipeline
|
||||
from lerobot.processor.pipeline import RobotProcessorPipeline
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, ROBOTS
|
||||
|
||||
from .config import RobotConfig
|
||||
@@ -37,10 +34,6 @@ class Robot(abc.ABC):
|
||||
This class provides a standardized interface for interacting with physical robots.
|
||||
Subclasses must implement all abstract methods and properties to be usable.
|
||||
|
||||
Pipelines are first-class citizens: every robot carries an optional output pipeline
|
||||
(applied in get_observation()) and an optional input pipeline (applied in send_action()).
|
||||
Both default to identity (no-op), so existing robots work without any changes.
|
||||
|
||||
Attributes:
|
||||
config_class (RobotConfig): The expected configuration class for this robot.
|
||||
name (str): The unique robot name used to identify this robot type.
|
||||
@@ -62,12 +55,6 @@ class Robot(abc.ABC):
|
||||
if self.calibration_fpath.is_file():
|
||||
self._load_calibration()
|
||||
|
||||
# Pipeline interface — default to identity (no-op), swap via set_output/input_pipeline()
|
||||
self._output_pipeline: RobotProcessorPipeline = _make_identity_observation_pipeline()
|
||||
self._input_pipeline: RobotProcessorPipeline = _make_identity_robot_action_pipeline()
|
||||
# Cache of most recent raw observation; used by input_pipeline for IK initial guess
|
||||
self._last_raw_obs: RobotObservation = {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.id} {self.__class__.__name__}"
|
||||
|
||||
@@ -97,117 +84,40 @@ class Robot(abc.ABC):
|
||||
except Exception: # nosec B110
|
||||
pass
|
||||
|
||||
# ── Pipeline interface ────────────────────────────────────────────────────
|
||||
|
||||
def output_pipeline(self) -> RobotProcessorPipeline:
|
||||
"""
|
||||
Pipeline applied inside get_observation() to transform raw hardware observations.
|
||||
Default: identity (no-op). Override via set_output_pipeline() or subclassing.
|
||||
|
||||
Example: set a forward-kinematics pipeline to convert joint positions to EE pose.
|
||||
"""
|
||||
return self._output_pipeline
|
||||
|
||||
def input_pipeline(self) -> RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]:
|
||||
"""
|
||||
Pipeline applied inside send_action() to transform incoming actions before hardware write.
|
||||
Default: identity (no-op). Override via set_input_pipeline() or subclassing.
|
||||
|
||||
The pipeline receives a (action, last_raw_obs) tuple so IK solvers can use the
|
||||
current joint configuration as an initial guess.
|
||||
|
||||
Example: set an inverse-kinematics pipeline to convert EE commands to joint positions.
|
||||
"""
|
||||
return self._input_pipeline
|
||||
|
||||
def set_output_pipeline(self, pipeline: RobotProcessorPipeline) -> None:
|
||||
"""Set the observation output pipeline (applied in get_observation())."""
|
||||
self._output_pipeline = pipeline
|
||||
|
||||
def set_input_pipeline(self, pipeline: RobotProcessorPipeline) -> None:
|
||||
"""Set the action input pipeline (applied in send_action())."""
|
||||
self._input_pipeline = pipeline
|
||||
|
||||
# ── Feature properties ────────────────────────────────────────────────────
|
||||
|
||||
# TODO(aliberts): create a proper Feature class for this that links with datasets
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def observation_features(self) -> dict:
|
||||
"""
|
||||
Pipeline-transformed observation features.
|
||||
A dictionary describing the structure and types of the observations produced by the robot.
|
||||
Its structure (keys) should match the structure of what is returned by :pymeth:`get_observation`.
|
||||
Values for the dict should either be:
|
||||
- The type of the value if it's a simple value, e.g. `float` for single proprioceptive value (a joint's position/velocity)
|
||||
- A tuple representing the shape if it's an array-type value, e.g. `(height, width, channel)` for images
|
||||
|
||||
Applies output_pipeline().transform_features() to raw_observation_features so the
|
||||
returned dict matches what get_observation() actually returns to callers.
|
||||
|
||||
Use raw_observation_features to inspect hardware-level feature shapes.
|
||||
|
||||
Note: this property should be able to be called regardless of whether the robot
|
||||
is connected or not.
|
||||
"""
|
||||
from lerobot.datasets.pipeline_features import create_initial_features # lazy import
|
||||
|
||||
initial = create_initial_features(observation=self.raw_observation_features)
|
||||
transformed = self.output_pipeline().transform_features(initial)
|
||||
return transformed.get(PipelineFeatureType.OBSERVATION, {})
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def raw_observation_features(self) -> dict:
|
||||
"""
|
||||
Hardware-level observation features (before any pipeline transformation).
|
||||
|
||||
A dictionary describing the structure and types of the observations produced
|
||||
directly by the robot hardware. Its structure (keys) should match the structure
|
||||
of what is returned by :pymeth:`_get_observation`. Values should be:
|
||||
- The type if it's a simple value, e.g. ``float`` for joint position
|
||||
- A tuple representing the shape for array values, e.g. ``(H, W, C)`` for images
|
||||
|
||||
Note: this property should be able to be called regardless of whether the robot
|
||||
is connected or not.
|
||||
Note: this property should be able to be called regardless of whether the robot is connected or not.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def raw_action_features(self) -> dict:
|
||||
"""
|
||||
Hardware-level action features (before any pipeline transformation).
|
||||
|
||||
A dictionary describing the structure and types of the actions accepted directly
|
||||
by the robot hardware (i.e. what :pymeth:`_send_action` receives). Its structure
|
||||
(keys) should match the structure of what is expected by :pymeth:`_send_action`.
|
||||
Values should be the type of the value if it's a simple value, e.g. ``float`` for
|
||||
single proprioceptive value (a joint's goal position/velocity).
|
||||
|
||||
Note: this property should be able to be called regardless of whether the robot
|
||||
is connected or not.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict:
|
||||
"""
|
||||
Pipeline-transformed action features.
|
||||
A dictionary describing the structure and types of the actions expected by the robot. Its structure
|
||||
(keys) should match the structure of what is passed to :pymeth:`send_action`. Values for the dict
|
||||
should be the type of the value if it's a simple value, e.g. `float` for single proprioceptive value
|
||||
(a joint's goal position/velocity)
|
||||
|
||||
Applies input_pipeline().transform_features() to raw_action_features so the
|
||||
returned dict reflects what the input pipeline outputs to hardware.
|
||||
|
||||
Use raw_action_features to inspect hardware-level action feature shapes.
|
||||
|
||||
Note: this property should be able to be called regardless of whether the robot
|
||||
is connected or not.
|
||||
Note: this property should be able to be called regardless of whether the robot is connected or not.
|
||||
"""
|
||||
from lerobot.datasets.pipeline_features import create_initial_features # lazy import
|
||||
|
||||
initial = create_initial_features(action=self.raw_action_features)
|
||||
transformed = self.input_pipeline().transform_features(initial)
|
||||
return transformed.get(PipelineFeatureType.ACTION, {})
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_connected(self) -> bool:
|
||||
"""
|
||||
Whether the robot is currently connected or not. If ``False``, calling
|
||||
:pymeth:`get_observation` or :pymeth:`send_action` should raise an error.
|
||||
Whether the robot is currently connected or not. If `False`, calling :pymeth:`get_observation` or
|
||||
:pymeth:`send_action` should raise an error.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -225,7 +135,7 @@ class Robot(abc.ABC):
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_calibrated(self) -> bool:
|
||||
"""Whether the robot is currently calibrated or not. Should be always ``True`` if not applicable"""
|
||||
"""Whether the robot is currently calibrated or not. Should be always `True` if not applicable"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -243,7 +153,7 @@ class Robot(abc.ABC):
|
||||
Helper to load calibration data from the specified file.
|
||||
|
||||
Args:
|
||||
fpath (Path | None): Optional path to the calibration file. Defaults to ``self.calibration_fpath``.
|
||||
fpath (Path | None): Optional path to the calibration file. Defaults to `self.calibration_fpath`.
|
||||
"""
|
||||
fpath = self.calibration_fpath if fpath is None else fpath
|
||||
with open(fpath) as f, draccus.config_type("json"):
|
||||
@@ -254,7 +164,7 @@ class Robot(abc.ABC):
|
||||
Helper to save calibration data to the specified file.
|
||||
|
||||
Args:
|
||||
fpath (Path | None): Optional path to save the calibration file. Defaults to ``self.calibration_fpath``.
|
||||
fpath (Path | None): Optional path to save the calibration file. Defaults to `self.calibration_fpath`.
|
||||
"""
|
||||
fpath = self.calibration_fpath if fpath is None else fpath
|
||||
with open(fpath, "w") as f, draccus.config_type("json"):
|
||||
@@ -268,64 +178,30 @@ class Robot(abc.ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
# ── Template methods (concrete, call pipeline internally) ─────────────────
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_observation(self) -> RobotObservation:
|
||||
"""
|
||||
Retrieve the current observation from the robot and apply the output pipeline.
|
||||
|
||||
Calls :pymeth:`_get_observation` to get raw hardware data, caches it for use as
|
||||
IK initial guess in :pymeth:`send_action`, then applies :pymeth:`output_pipeline`.
|
||||
Retrieve the current observation from the robot.
|
||||
|
||||
Returns:
|
||||
RobotObservation: Pipeline-transformed observation. With the default identity
|
||||
pipeline this equals the raw observation from :pymeth:`_get_observation`.
|
||||
RobotObservation: A flat dictionary representing the robot's current sensory state. Its structure
|
||||
should match :pymeth:`observation_features`.
|
||||
"""
|
||||
raw = self._get_observation()
|
||||
self._last_raw_obs = raw
|
||||
return self.output_pipeline()(raw)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
"""
|
||||
Retrieve the raw observation directly from robot hardware.
|
||||
|
||||
Returns:
|
||||
RobotObservation: A flat dictionary representing the robot's current sensory
|
||||
state. Its structure should match :pymeth:`raw_observation_features`.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""
|
||||
Apply the input pipeline and send the resulting action to robot hardware.
|
||||
|
||||
The input pipeline receives ``(action, last_raw_obs)`` so IK solvers can use the
|
||||
cached joint configuration as an initial guess. With the default identity pipeline,
|
||||
the action is forwarded unchanged.
|
||||
Send an action command to the robot.
|
||||
|
||||
Args:
|
||||
action (RobotAction): Dictionary representing the desired action. Its structure
|
||||
should match :pymeth:`action_features`.
|
||||
action (RobotAction): Dictionary representing the desired action. Its structure should match
|
||||
:pymeth:`action_features`.
|
||||
|
||||
Returns:
|
||||
RobotAction: The action actually sent to the motors, potentially clipped or
|
||||
modified by the pipeline or hardware safety limits.
|
||||
"""
|
||||
transformed = self.input_pipeline()((action, self._last_raw_obs))
|
||||
return self._send_action(transformed)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""
|
||||
Send an action command directly to robot hardware.
|
||||
|
||||
Args:
|
||||
action (RobotAction): Dictionary of motor-level commands. Its structure should
|
||||
match what the hardware expects (typically motor positions/velocities).
|
||||
|
||||
Returns:
|
||||
RobotAction: The action actually sent, potentially clipped by safety limits.
|
||||
RobotAction: The action actually sent to the motors potentially clipped or modified, e.g. by
|
||||
safety limits on velocity.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TypeAlias
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
|
||||
@@ -50,5 +49,5 @@ class SOFollowerRobotConfig(RobotConfig, SOFollowerConfig):
|
||||
pass
|
||||
|
||||
|
||||
SO100FollowerConfig: TypeAlias = SOFollowerRobotConfig
|
||||
SO101FollowerConfig: TypeAlias = SOFollowerRobotConfig
|
||||
SO100FollowerConfig = SOFollowerRobotConfig
|
||||
SO101FollowerConfig = SOFollowerRobotConfig
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .ee_space import make_so10x_fk_observation_pipeline, make_so10x_ik_action_pipeline
|
||||
|
||||
__all__ = ["make_so10x_fk_observation_pipeline", "make_so10x_ik_action_pipeline"]
|
||||
@@ -1,147 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
End-effector space pipelines for SO-100/101 follower robots.
|
||||
|
||||
These factory functions return ready-to-use pipelines that convert between joint space
|
||||
and Cartesian end-effector space. Attach them to a robot with ``set_output_pipeline`` /
|
||||
``set_input_pipeline`` to enable EE-space recording and teleoperation.
|
||||
|
||||
Example::
|
||||
|
||||
from lerobot.robots.so_follower.pipelines import (
|
||||
make_so10x_fk_observation_pipeline,
|
||||
make_so10x_ik_action_pipeline,
|
||||
)
|
||||
|
||||
motor_names = list(follower.bus.motors.keys())
|
||||
follower.set_output_pipeline(make_so10x_fk_observation_pipeline(URDF_PATH, motor_names))
|
||||
follower.set_input_pipeline(make_so10x_ik_action_pipeline(URDF_PATH, motor_names))
|
||||
"""
|
||||
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
EEBoundsAndSafety,
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
|
||||
_DEFAULT_EE_BOUNDS = {"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}
|
||||
_DEFAULT_GRIPPER_FRAME = "gripper_frame_link"
|
||||
|
||||
|
||||
def make_so10x_fk_observation_pipeline(
|
||||
urdf_path: str,
|
||||
motor_names: list[str],
|
||||
*,
|
||||
target_frame_name: str = _DEFAULT_GRIPPER_FRAME,
|
||||
) -> RobotProcessorPipeline[RobotObservation, RobotObservation]:
|
||||
"""
|
||||
Create a forward-kinematics observation pipeline for SO-100/101 follower robots.
|
||||
|
||||
Converts raw joint positions (observation) into end-effector pose (position + orientation).
|
||||
Attach this to a follower robot via ``set_output_pipeline`` so that ``get_observation()``
|
||||
returns EE coordinates instead of raw joint angles.
|
||||
|
||||
Args:
|
||||
urdf_path: Path to the SO-100/101 URDF file used for kinematics.
|
||||
motor_names: Ordered list of motor names matching the URDF joint names.
|
||||
target_frame_name: Name of the end-effector frame in the URDF.
|
||||
|
||||
Returns:
|
||||
A RobotProcessorPipeline that maps joint observations to EE observations.
|
||||
|
||||
Example::
|
||||
|
||||
follower.set_output_pipeline(
|
||||
make_so10x_fk_observation_pipeline("./so101.urdf", motor_names)
|
||||
)
|
||||
obs = follower.get_observation() # now contains ee.x, ee.y, ee.z, ...
|
||||
"""
|
||||
kinematics = RobotKinematics(
|
||||
urdf_path=urdf_path,
|
||||
target_frame_name=target_frame_name,
|
||||
joint_names=motor_names,
|
||||
)
|
||||
return RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[ForwardKinematicsJointsToEE(kinematics=kinematics, motor_names=motor_names)],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
|
||||
def make_so10x_ik_action_pipeline(
|
||||
urdf_path: str,
|
||||
motor_names: list[str],
|
||||
*,
|
||||
target_frame_name: str = _DEFAULT_GRIPPER_FRAME,
|
||||
end_effector_bounds: dict | None = None,
|
||||
max_ee_step_m: float = 0.10,
|
||||
) -> RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]:
|
||||
"""
|
||||
Create an inverse-kinematics action pipeline for SO-100/101 follower robots.
|
||||
|
||||
Converts incoming end-effector pose commands into joint positions, applying safety
|
||||
bounds and step-size limits before solving IK. The current joint positions are used
|
||||
as the IK initial guess (taken from the cached ``_last_raw_obs``).
|
||||
|
||||
Attach this to a follower robot via ``set_input_pipeline`` so that ``send_action()``
|
||||
receives EE commands and translates them to motor positions before the hardware write.
|
||||
|
||||
Args:
|
||||
urdf_path: Path to the SO-100/101 URDF file used for kinematics.
|
||||
motor_names: Ordered list of motor names matching the URDF joint names.
|
||||
target_frame_name: Name of the end-effector frame in the URDF.
|
||||
end_effector_bounds: Dict with ``"min"`` and ``"max"`` lists (3D position bounds in metres).
|
||||
Defaults to ``{"min": [-1, -1, -1], "max": [1, 1, 1]}``.
|
||||
max_ee_step_m: Maximum allowed EE position change per step in metres.
|
||||
|
||||
Returns:
|
||||
A RobotProcessorPipeline that maps (EE action, raw obs) to joint action.
|
||||
|
||||
Example::
|
||||
|
||||
follower.set_input_pipeline(
|
||||
make_so10x_ik_action_pipeline("./so101.urdf", motor_names)
|
||||
)
|
||||
# send_action() now accepts ee.x, ee.y, ee.z, ee.wx, ee.wy, ee.wz, ee.gripper_vel
|
||||
"""
|
||||
kinematics = RobotKinematics(
|
||||
urdf_path=urdf_path,
|
||||
target_frame_name=target_frame_name,
|
||||
joint_names=motor_names,
|
||||
)
|
||||
bounds = end_effector_bounds or _DEFAULT_EE_BOUNDS
|
||||
return RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
EEBoundsAndSafety(end_effector_bounds=bounds, max_ee_step_m=max_ee_step_m),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics,
|
||||
motor_names=motor_names,
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
@@ -17,7 +17,6 @@
|
||||
import logging
|
||||
import time
|
||||
from functools import cached_property
|
||||
from typing import TypeAlias
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
@@ -74,11 +73,11 @@ class SOFollower(Robot):
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def raw_observation_features(self) -> dict[str, type | tuple]:
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def raw_action_features(self) -> dict[str, type]:
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
@@ -176,7 +175,7 @@ class SOFollower(Robot):
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def _get_observation(self) -> RobotObservation:
|
||||
def get_observation(self) -> RobotObservation:
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict = self.bus.sync_read("Present_Position")
|
||||
@@ -194,7 +193,7 @@ class SOFollower(Robot):
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def _send_action(self, action: RobotAction) -> RobotAction:
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command arm to move to a target joint configuration.
|
||||
|
||||
The relative action magnitude may be clipped depending on the configuration parameter
|
||||
@@ -230,5 +229,5 @@ class SOFollower(Robot):
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
|
||||
SO100Follower: TypeAlias = SOFollower
|
||||
SO101Follower: TypeAlias = SOFollower
|
||||
SO100Follower = SOFollower
|
||||
SO101Follower = SOFollower
|
||||
|
||||
@@ -16,3 +16,5 @@
|
||||
|
||||
from .config_unitree_g1 import UnitreeG1Config
|
||||
from .unitree_g1 import UnitreeG1
|
||||
|
||||
__all__ = ["UnitreeG1", "UnitreeG1Config"]
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user