mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 00:29:52 +00:00
Compare commits
48 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| db5c26f07d | |||
| 8904768db4 | |||
| b0efa73520 | |||
| 00b662de02 | |||
| 5c51a74484 | |||
| db8547e35d | |||
| c17d949531 | |||
| 1e131f93f8 | |||
| 2fb5c7add0 | |||
| 4f2ef024d8 | |||
| 6139b133ca | |||
| 85de893fa7 | |||
| a4c66e530b | |||
| a225127527 | |||
| e489ba24fc | |||
| d324ffe810 | |||
| 1a24f770d3 | |||
| 92fba37225 | |||
| 3e45120272 | |||
| f0d2b37beb | |||
| cbc8bfb2e6 | |||
| 0d1be72dc8 | |||
| 96b7c212c4 | |||
| 4303b3c930 | |||
| 63dca86df8 | |||
| 8a0cc3d664 | |||
| 8bb8ed4803 | |||
| 095856b06a | |||
| 563f42bdb1 | |||
| 8fff0fde7c | |||
| 04de496547 | |||
| baf9b50365 | |||
| a0fdbf037a | |||
| c085531b17 | |||
| c7c6205332 | |||
| 4e54be1334 | |||
| fde9d08281 | |||
| 46044fed75 | |||
| 975dcad918 | |||
| d0b58190da | |||
| 9a5ab8ffab | |||
| 7541d72130 | |||
| 0317a15bf1 | |||
| f138e5948a | |||
| 8fef4ddab8 | |||
| 18d9cb5ac4 | |||
| 5095ab0845 | |||
| dac1efd13d |
@@ -44,7 +44,7 @@ permissions:
|
||||
# Sets up the environment variables
|
||||
env:
|
||||
UV_VERSION: "0.8.0"
|
||||
PYTHON_VERSION: "3.10"
|
||||
PYTHON_VERSION: "3.12"
|
||||
|
||||
# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
|
||||
concurrency:
|
||||
@@ -61,6 +61,7 @@ jobs:
|
||||
MUJOCO_GL: egl
|
||||
HF_HOME: /mnt/cache/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
@@ -89,5 +90,11 @@ jobs:
|
||||
- name: Install lerobot with test extras
|
||||
run: uv sync --extra "test"
|
||||
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
uv run hf auth whoami
|
||||
|
||||
- name: Run pytest
|
||||
run: uv run pytest tests -vv --maxfail=10
|
||||
|
||||
@@ -37,7 +37,7 @@ permissions:
|
||||
# Sets up the environment variables
|
||||
env:
|
||||
UV_VERSION: "0.8.0"
|
||||
PYTHON_VERSION: "3.10"
|
||||
PYTHON_VERSION: "3.12"
|
||||
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu
|
||||
|
||||
# Ensures that only the latest action is built, canceling older runs.
|
||||
@@ -60,6 +60,7 @@ jobs:
|
||||
MUJOCO_GL: egl
|
||||
HF_HOME: /mnt/cache/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
@@ -87,6 +88,12 @@ jobs:
|
||||
- name: Install lerobot with all extras
|
||||
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
|
||||
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
uv run hf auth whoami
|
||||
|
||||
- name: Run pytest (all extras)
|
||||
run: uv run pytest tests -vv --maxfail=10
|
||||
|
||||
@@ -162,6 +169,7 @@ jobs:
|
||||
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
|
||||
TORCH_HOME: /home/user_lerobot/.cache/torch
|
||||
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
container:
|
||||
image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images]
|
||||
options: --gpus all --shm-size "16gb"
|
||||
@@ -173,6 +181,13 @@ jobs:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
hf auth whoami
|
||||
- name: Fix ptxas permissions
|
||||
run: chmod +x /lerobot/.venv/lib/python3.12/site-packages/triton/backends/nvidia/bin/ptxas
|
||||
- name: Run pytest on GPU
|
||||
run: pytest tests -vv --maxfail=10
|
||||
- name: Run end-to-end tests
|
||||
|
||||
@@ -28,7 +28,7 @@ on:
|
||||
# Sets up the environment variables
|
||||
env:
|
||||
UV_VERSION: "0.8.0"
|
||||
PYTHON_VERSION: "3.10"
|
||||
PYTHON_VERSION: "3.12"
|
||||
DOCKER_IMAGE_NAME_CPU: huggingface/lerobot-cpu:latest
|
||||
DOCKER_IMAGE_NAME_GPU: huggingface/lerobot-gpu:latest
|
||||
|
||||
@@ -119,6 +119,7 @@ jobs:
|
||||
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
|
||||
TORCH_HOME: /home/user_lerobot/.cache/torch
|
||||
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
container:
|
||||
image: ${{ needs.build-docker-cpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
|
||||
options: --shm-size "16gb"
|
||||
@@ -130,6 +131,11 @@ jobs:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
hf auth whoami
|
||||
- name: Run pytest on CPU
|
||||
run: pytest tests -vv --maxfail=10
|
||||
- name: Run end-to-end tests
|
||||
@@ -146,6 +152,7 @@ jobs:
|
||||
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
|
||||
TORCH_HOME: /home/user_lerobot/.cache/torch
|
||||
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
container:
|
||||
image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
|
||||
options: --gpus all --shm-size "16gb"
|
||||
@@ -157,6 +164,11 @@ jobs:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
hf auth whoami
|
||||
- name: Run pytest on GPU
|
||||
run: pytest tests -vv --maxfail=10
|
||||
- name: Run end-to-end tests
|
||||
@@ -174,6 +186,7 @@ jobs:
|
||||
TORCH_HOME: /home/user_lerobot/.cache/torch
|
||||
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
|
||||
CUDA_VISIBLE_DEVICES: "0,1,2,3"
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
container:
|
||||
image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
|
||||
options: --gpus all --shm-size "16gb"
|
||||
@@ -185,12 +198,15 @@ jobs:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
hf auth whoami
|
||||
- name: Verify GPU availability
|
||||
run: |
|
||||
nvidia-smi
|
||||
python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
|
||||
|
||||
- name: Run multi-GPU training tests
|
||||
# TODO(Steven): Investigate why motors tests are failing in multi-GPU setup
|
||||
run: pytest tests -vv --maxfail=10 --ignore=tests/motors/
|
||||
timeout-minutes: 10
|
||||
run: pytest -vv tests/training/
|
||||
|
||||
@@ -50,7 +50,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.10'
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Run pre-commit hooks
|
||||
uses: pre-commit/action@v3.0.1 # zizmor: ignore[unpinned-uses]
|
||||
|
||||
@@ -22,7 +22,7 @@ on:
|
||||
# Sets up the environment variables
|
||||
env:
|
||||
UV_VERSION: "0.8.0"
|
||||
PYTHON_VERSION: "3.10"
|
||||
PYTHON_VERSION: "3.12"
|
||||
|
||||
jobs:
|
||||
# This job builds the Python package and publishes it to PyPI
|
||||
@@ -45,7 +45,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.10'
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Extract Version
|
||||
id: extract_info
|
||||
@@ -83,14 +83,6 @@ jobs:
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Remove Tags with Git dependencies
|
||||
# TODO(Steven): Temporary patch to remove pi from PyPi 0.4.0 release due to its reliance on git dependencies.
|
||||
run: |
|
||||
echo "::info:: Checking for Git dependencies to remove from pyproject.toml..."
|
||||
grep -E '@ git\+https|lerobot\[pi\]' pyproject.toml | sed 's/^/::warning:: Removing line: /' || true
|
||||
sed -E -i '/@ git\+https|lerobot\[pi\]/d' pyproject.toml
|
||||
echo "::info:: Git dependencies removed. Proceeding with build."
|
||||
|
||||
- name: Install build dependencies
|
||||
run: python -m pip install build
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ permissions:
|
||||
# Sets up the environment variables
|
||||
env:
|
||||
UV_VERSION: "0.8.0"
|
||||
PYTHON_VERSION: "3.10"
|
||||
PYTHON_VERSION: "3.12"
|
||||
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu:unbound
|
||||
|
||||
# Ensures that only the latest action is built, canceling older runs.
|
||||
@@ -48,6 +48,7 @@ jobs:
|
||||
MUJOCO_GL: egl
|
||||
HF_HOME: /mnt/cache/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
@@ -79,7 +80,11 @@ jobs:
|
||||
|
||||
- name: Install lerobot with all extras
|
||||
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
|
||||
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
uv run hf auth whoami
|
||||
- name: Run pytest (all extras)
|
||||
run: uv run pytest tests -vv
|
||||
|
||||
@@ -137,6 +142,7 @@ jobs:
|
||||
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
|
||||
TORCH_HOME: /home/user_lerobot/.cache/torch
|
||||
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
container:
|
||||
image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images]
|
||||
options: --gpus all --shm-size "16gb"
|
||||
@@ -148,6 +154,11 @@ jobs:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
hf auth whoami
|
||||
- name: Run pytest on GPU
|
||||
run: pytest tests -vv
|
||||
- name: Run end-to-end tests
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
default_language_version:
|
||||
python: python3.10
|
||||
python: python3.12
|
||||
|
||||
exclude: "tests/artifacts/.*\\.safetensors$"
|
||||
|
||||
@@ -55,7 +55,7 @@ repos:
|
||||
rev: v3.21.0
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
args: [--py310-plus]
|
||||
args: [--py312-plus]
|
||||
|
||||
##### Markdown Quality #####
|
||||
- repo: https://github.com/rbubley/mirrors-prettier
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
# AI Usage Policy
|
||||
|
||||
The LeRobot project welcomes contributions from everyone, and we have a few guidelines regarding AI usage to ensure high code quality, clear communication, and a healthy open-source ecosystem:
|
||||
|
||||
- **Please disclose significant AI assistance.** If you used AI tools (e.g., Copilot, Claude, Cursor, ChatGPT) to generate a substantial portion of your code or text, let us know in your PR description. Transparency helps us review your changes more effectively.
|
||||
- **Own your code (The Human-in-the-Loop).** You must fully understand all the changes you are proposing. If you cannot explain what your AI-assisted code does or how it interacts with LeRobot's broader architecture, please take the time to learn and test it before submitting.
|
||||
- **Keep issues and discussions focused.** You are welcome to use AI to help draft issues or PR descriptions, but please review and edit them carefully before posting. AI can often be overly verbose; trimming the noise and getting straight to the point helps our maintainers address your needs faster.
|
||||
|
||||
Our core maintainers also use AI tools to aid their workflows, but they do so while bringing deep contextual knowledge of the LeRobot codebase to validate the output. We ask all contributors to apply that same level of rigor.
|
||||
|
||||
## Remember the Human Maintainers
|
||||
|
||||
Please remember that LeRobot is maintained by a dedicated team of humans.
|
||||
|
||||
Every discussion, issue, and pull request is read and reviewed by real people. While AI tools can generate thousands of lines of code in seconds, reviewing that code still takes human time and energy. Submitting unverified or low-effort AI output puts an unfair burden on our maintainers.
|
||||
|
||||
Today, the quality of the AI output still heavily depends on the developer driving the tool. We ask that you respect our maintainers' time by thoroughly vetting, testing, and refining your submissions.
|
||||
|
||||
## AI is Welcome Here
|
||||
|
||||
LeRobot operates at the cutting edge of AI and robotics, and many of our maintainers actively embrace AI coding assistants as valuable productivity tools. We are a pro-AI project!
|
||||
|
||||
Our reason for having an AI policy is not an anti-AI stance. Rather, it exists to ensure that AI is used to enhance human contributions, not replace them with unverified noise. It's about how the tools are used, not the tools themselves.
|
||||
|
||||
We value the unique human insight you bring to the LeRobot community. Let AI empower your workflow, but always let your own judgment take the wheel.
|
||||
+1
-1
@@ -2,7 +2,7 @@
|
||||
|
||||
Everyone is welcome to contribute, and we value everybody's contribution. Code is not the only way to help the community. Answering questions, helping others, reaching out, and improving the documentation are immensely valuable.
|
||||
|
||||
Whichever way you choose to contribute, please be mindful to respect our [code of conduct](./CODE_OF_CONDUCT.md).
|
||||
Whichever way you choose to contribute, please be mindful to respect our [code of conduct](./CODE_OF_CONDUCT.md) and our [AI policy](./AI_POLICY.md).
|
||||
|
||||
## Ways to Contribute
|
||||
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
include src/lerobot/templates/lerobot_modelcard_template.md
|
||||
include src/lerobot/datasets/card_template.md
|
||||
include src/lerobot/envs/metaworld_config.json
|
||||
|
||||
@@ -135,7 +135,7 @@ Learn how to implement your own simulation environment or benchmark and distribu
|
||||
|
||||
## Citation
|
||||
|
||||
If you use LeRobot in your research, please cite:
|
||||
If you use LeRobot in your project, please cite the GitHub repository to acknowledge the ongoing development and contributors:
|
||||
|
||||
```bibtex
|
||||
@misc{cadene2024lerobot,
|
||||
@@ -146,6 +146,23 @@ If you use LeRobot in your research, please cite:
|
||||
}
|
||||
```
|
||||
|
||||
If you are referencing our research or the academic paper, please also cite our ICLR publication:
|
||||
|
||||
<details>
|
||||
<summary><b>ICLR 2026 Paper</b></summary>
|
||||
|
||||
```bibtex
|
||||
@inproceedings{cadenelerobot,
|
||||
title={LeRobot: An Open-Source Library for End-to-End Robot Learning},
|
||||
author={Cadene, Remi and Alibert, Simon and Capuano, Francesco and Aractingi, Michel and Zouitine, Adil and Kooijmans, Pepijn and Choghari, Jade and Russi, Martino and Pascal, Caroline and Palma, Steven and Shukor, Mustafa and Moss, Jess and Soare, Alexander and Aubakirova, Dana and Lhoest, Quentin and Gallou\'edec, Quentin and Wolf, Thomas},
|
||||
booktitle={The Fourteenth International Conference on Learning Representations},
|
||||
year={2026},
|
||||
url={https://arxiv.org/abs/2602.22818}
|
||||
}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## Contribute
|
||||
|
||||
We welcome contributions from everyone in the community! To get started, please read our [CONTRIBUTING.md](./CONTRIBUTING.md) guide. Whether you're adding a new feature, improving documentation, or fixing a bug, your help and feedback are invaluable. We're incredibly excited about the future of open-source robotics and can't wait to work with you on what's next—thank you for your support!
|
||||
|
||||
@@ -24,7 +24,7 @@ ARG OS_VERSION=22.04
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
|
||||
|
||||
# Define Python version argument
|
||||
ARG PYTHON_VERSION=3.10
|
||||
ARG PYTHON_VERSION=3.12
|
||||
|
||||
# Configure environment variables
|
||||
ENV DEBIAN_FRONTEND=noninteractive \
|
||||
@@ -85,6 +85,8 @@ RUN if [ "$UNBOUND_DEPS" = "true" ]; then \
|
||||
|
||||
RUN uv pip install --no-cache ".[all]"
|
||||
|
||||
RUN chmod +x /lerobot/.venv/lib/python${PYTHON_VERSION}/site-packages/triton/backends/nvidia/bin/ptxas
|
||||
|
||||
# Copy the rest of the application source code
|
||||
# Make sure to have the git-LFS files for testing
|
||||
COPY --chown=user_lerobot:user_lerobot . .
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
# docker run -it --rm lerobot-user
|
||||
|
||||
# Configure the base image
|
||||
ARG PYTHON_VERSION=3.10
|
||||
ARG PYTHON_VERSION=3.12
|
||||
FROM python:${PYTHON_VERSION}-slim
|
||||
|
||||
# Configure environment variables
|
||||
|
||||
@@ -17,8 +17,6 @@
|
||||
title: Train RL in Simulation
|
||||
- local: multi_gpu_training
|
||||
title: Multi GPU training
|
||||
- local: hil_collection
|
||||
title: Human In the Loop Data Collection
|
||||
- local: peft_training
|
||||
title: Training with PEFT (e.g., LoRA)
|
||||
title: "Tutorials"
|
||||
|
||||
@@ -48,7 +48,7 @@ python -m lerobot.async_inference.robot_client \
|
||||
--task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act`
|
||||
--policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc)
|
||||
--pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base)
|
||||
--policy_device=mps \ # POLICY: the device to run the policy on, on the server
|
||||
--policy_device=mps \ # POLICY: the device to run the policy on, on the server (cuda, mps, xpu, cpu)
|
||||
--actions_per_chunk=50 \ # POLICY: the number of actions to output at once
|
||||
--chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server
|
||||
--aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions
|
||||
|
||||
@@ -32,7 +32,7 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
# your policy-specific dependencies
|
||||
]
|
||||
requires-python = ">= 3.11"
|
||||
requires-python = ">= 3.12"
|
||||
|
||||
[build-system]
|
||||
build-backend = # your-build-backend
|
||||
@@ -82,7 +82,7 @@ Create your policy implementation by inheriting from LeRobot's base `PreTrainedP
|
||||
# modeling_my_custom_policy.py
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Dict, Any
|
||||
from typing import Any
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
||||
@@ -91,7 +91,7 @@ class MyCustomPolicy(PreTrainedPolicy):
|
||||
config_class = MyCustomPolicyConfig
|
||||
name = "my_custom_policy"
|
||||
|
||||
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: Dict[str, Any] = None):
|
||||
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: dict[str, Any] = None):
|
||||
super().__init__(config, dataset_stats)
|
||||
...
|
||||
```
|
||||
@@ -102,7 +102,7 @@ Create processor functions:
|
||||
|
||||
```python
|
||||
# processor_my_custom_policy.py
|
||||
from typing import Dict, Any
|
||||
from typing import Any
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ The EarthRover Mini Plus is a fully open source mobile robot that connects throu
|
||||
### Hardware
|
||||
|
||||
- EarthRover Mini robot
|
||||
- Computer with Python 3.10 or newer
|
||||
- Computer with Python 3.12 or newer
|
||||
- Internet connection
|
||||
|
||||
### Setting Up the Frodobots SDK
|
||||
@@ -170,13 +170,13 @@ Once you can drive the robot well, you can start recording data to train AI mode
|
||||
We use Hugging Face to store your data online. First, log in with your token from [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Store your Hugging Face username:
|
||||
|
||||
```bash
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
|
||||
@@ -155,10 +155,10 @@ Upload your repository to Hugging Face:
|
||||
pip install huggingface_hub
|
||||
|
||||
# Login to Hugging Face
|
||||
huggingface-cli login
|
||||
hf auth login
|
||||
|
||||
# Create a new repository
|
||||
huggingface-cli repo create my-custom-env --type space --org my-org
|
||||
hf repo create my-org/my-custom-env
|
||||
|
||||
# Initialize git and push
|
||||
git init
|
||||
|
||||
@@ -1,237 +0,0 @@
|
||||
# Human-In-the-Loop Data Collection
|
||||
|
||||
Human-In-the-Loop (HIL) data collection lets you improve a trained policy by deploying it on a real robot while a human operator monitors and intervenes when needed. The intervention data — recovery movements and corrections — is recorded alongside autonomous segments, producing a richer training dataset that teaches the policy how to handle failures.
|
||||
|
||||
---
|
||||
|
||||
## Why Human-In-the-Loop?
|
||||
|
||||
Standard behavioral cloning trains policies on successful demonstrations only. During deployment, small errors can compound and push the robot into states never seen during training (distribution shift). HIL data collection addresses this by:
|
||||
|
||||
- Running the trained policy on the real robot
|
||||
- Having a human intervene when the robot is about to fail
|
||||
- Recording the human's recovery and correction as training data
|
||||
- Fine-tuning the policy on the combined dataset
|
||||
|
||||
This produces a policy that not only knows how to perform the task, but also how to recover when things go wrong.
|
||||
|
||||
---
|
||||
|
||||
## How It Works
|
||||
|
||||
During a HIL session, the human operator follows this loop within each episode:
|
||||
|
||||
1. **Watch** the policy run autonomously
|
||||
2. **Pause** when failure is imminent — the robot holds its position
|
||||
3. **Take control** — teleoperate the robot back to a good state (recovery), then correct the behavior
|
||||
4. **Return control to the policy** — the policy resumes autonomous execution
|
||||
5. Repeat steps 2–4 as many times as needed during the episode
|
||||
6. **End the episode** when the task is complete, save and move on to the next rollout
|
||||
|
||||
Both autonomous and human-controlled segments are recorded. The policy and human can alternate control multiple times within a single episode, and the episode continues from the current state after each handoff (no reset required just because intervention happened). This captures autonomous execution, recovery, and correction in one continuous trajectory. After collection, the combined dataset (original demonstrations + HIL data) is used to fine-tune the policy.
|
||||
|
||||
This process can be repeated iteratively: deploy, collect, fine-tune, repeat — each round targeting the current policy's failure modes.
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────────────┐
|
||||
│ Policy v0 (trained on demos) │
|
||||
│ ↓ │
|
||||
│ HIL Collection (target current failure modes) → Fine-tune → Policy v1 │
|
||||
│ ↓ │
|
||||
│ HIL Collection (target new failure modes) → Fine-tune → Policy v2 │
|
||||
│ ↓ │
|
||||
│ ... (repeat until satisfactory performance) │
|
||||
└─────────────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Hardware Requirements
|
||||
|
||||
### Teleoperator Requirements
|
||||
|
||||
The HIL data collection scripts require **teleoperators with active motors** that can:
|
||||
|
||||
- Enable/disable torque programmatically
|
||||
- Move to target positions (to mirror the robot state when pausing)
|
||||
|
||||
**Compatible teleoperators:**
|
||||
|
||||
- `so101_leader` - SO-101 Leader Arm
|
||||
- `openarms_mini` - OpenArms Mini (via third-party plugin)
|
||||
|
||||
---
|
||||
|
||||
## Scripts
|
||||
|
||||
Two scripts are provided depending on your policy's inference speed:
|
||||
|
||||
| Script | Use Case | Models |
|
||||
| ---------------------------- | ------------------------------------------ | --------------------- |
|
||||
| `hil_data_collection.py` | Standard synchronous inference | ACT, Diffusion Policy |
|
||||
| `hil_data_collection_rtc.py` | Real-Time Chunking for high-latency models | Pi0, Pi0.5, SmolVLA |
|
||||
|
||||
---
|
||||
|
||||
## Step-by-Step Guide
|
||||
|
||||
### Step 1: Pre-train a Base Policy
|
||||
|
||||
First, train a policy on your demonstration dataset:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/demo-dataset \
|
||||
--policy.type=pi0 \
|
||||
--output_dir=outputs/pretrain \
|
||||
--batch_size=32 \
|
||||
--steps=50000
|
||||
```
|
||||
|
||||
### Step 2: Collect HIL Data
|
||||
|
||||
**Standard inference (ACT, Diffusion Policy):**
|
||||
|
||||
```bash
|
||||
python examples/rac/hil_data_collection.py \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=your-username/hil-dataset \
|
||||
--dataset.single_task="Pick up the cube and place it in the bowl" \
|
||||
--dataset.num_episodes=50
|
||||
```
|
||||
|
||||
**With RTC for large models (Pi0, Pi0.5, SmolVLA):**
|
||||
|
||||
For models with high inference latency, use the RTC script for smooth execution:
|
||||
|
||||
```bash
|
||||
python examples/rac/hil_data_collection_rtc.py \
|
||||
--robot.type=so100_follower \
|
||||
--teleop.type=so100_leader \
|
||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=your-username/hil-rtc-dataset \
|
||||
--dataset.single_task="Pick up the cube" \
|
||||
--rtc.execution_horizon=20 \
|
||||
--interpolation=true
|
||||
```
|
||||
|
||||
**Controls (Conceptual):**
|
||||
|
||||
The interaction model is:
|
||||
|
||||
- **Pause input**: pause autonomous policy execution
|
||||
- **Takeover input**: transfer control to the human operator and record intervention data
|
||||
- **Return-to-policy input**: hand control back to the policy and continue the same episode
|
||||
- **Episode control inputs**: save/re-record/stop/reset as needed
|
||||
|
||||
Exact key/pedal bindings can differ across scripts and hardware integrations. Use each script's printed controls as the source of truth for the concrete mapping on your setup.
|
||||
|
||||
**The HIL Protocol:**
|
||||
|
||||
1. Watch the policy run autonomously (teleop is idle/free)
|
||||
2. When you see imminent failure, trigger the **pause input**
|
||||
- Policy stops
|
||||
- Teleoperator moves to match robot position (torque enabled)
|
||||
- No frames recorded during pause
|
||||
3. Trigger the **takeover input** to take control
|
||||
- Teleoperator torque disabled, free to move
|
||||
- **Recovery**: Teleoperate the robot back to a good state
|
||||
- **Correction**: Correct the behavior
|
||||
- All movements are recorded
|
||||
4. Trigger the **return-to-policy input**
|
||||
- Policy resumes autonomous execution from the current state
|
||||
- You can intervene again at any time (repeat steps 2–4)
|
||||
5. End and save the episode when the task is complete (or episode time limit is reached)
|
||||
6. **Reset**: Teleop moves to robot position, you can move the robot to the starting position
|
||||
7. Start the next episode
|
||||
|
||||
**Foot Pedal Setup (Linux):**
|
||||
|
||||
If using a USB foot pedal (PCsensor FootSwitch), ensure access:
|
||||
|
||||
```bash
|
||||
sudo setfacl -m u:$USER:rw /dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd
|
||||
```
|
||||
|
||||
### Step 3: Fine-tune the Policy
|
||||
|
||||
Fine-tune on the combined demonstration + HIL data:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/hil-dataset \
|
||||
--policy.type=pi0 \
|
||||
--policy.pretrained_path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--output_dir=outputs/hil_finetune \
|
||||
--steps=20000
|
||||
```
|
||||
|
||||
Then deploy the fine-tuned policy and repeat from Step 2 to target its remaining failure modes.
|
||||
|
||||
---
|
||||
|
||||
## Tips for Effective HIL Collection
|
||||
|
||||
### When to Intervene
|
||||
|
||||
Intervene when you see:
|
||||
|
||||
- Robot about to make an irreversible mistake
|
||||
- Robot hesitating or showing uncertain behavior
|
||||
- Robot deviating from the expected trajectory
|
||||
|
||||
### Recovery: Teleoperating Back to a Good State
|
||||
|
||||
During recovery, teleoperate the robot back to a state where:
|
||||
|
||||
- The robot is in a familiar, in-distribution configuration
|
||||
- The current subtask can still be completed
|
||||
- The recovery trajectory itself is informative training data
|
||||
|
||||
### Quality of Corrections
|
||||
|
||||
During correction:
|
||||
|
||||
- Provide **confident, clean** trajectories
|
||||
- Complete the current subtask fully
|
||||
- Don't overcorrect or add unnecessary movements
|
||||
|
||||
---
|
||||
|
||||
## Related Work
|
||||
|
||||
This HIL data collection approach builds on ideas from interactive imitation learning, including DAgger (Ross et al., 2011), HG-DAgger (Kelly et al., 2019), RaC (Hu et al., 2025), and RECAP (Physical Intelligence, 2025). See those works for a deeper treatment of the theory behind human-in-the-loop policy improvement.
|
||||
|
||||
```bibtex
|
||||
@article{ross2011dagger,
|
||||
title={A Reduction of Imitation Learning and Structured Prediction to No-Regret Online Learning},
|
||||
author={Ross, Stéphane and Gordon, Geoffrey and Bagnell, Drew},
|
||||
journal={Proceedings of the Fourteenth International Conference on Artificial Intelligence and Statistics},
|
||||
year={2011}
|
||||
}
|
||||
|
||||
@article{kelly2019hgdagger,
|
||||
title={HG-DAgger: Interactive Imitation Learning with Human Experts},
|
||||
author={Kelly, Michael and Sidrane, Chelsea and Driggs-Campbell, Katherine and Kochenderfer, Mykel J},
|
||||
journal={arXiv preprint arXiv:1810.02890},
|
||||
year={2019}
|
||||
}
|
||||
|
||||
@article{hu2025rac,
|
||||
title={RaC: Robot Learning for Long-Horizon Tasks by Scaling Recovery and Correction},
|
||||
author={Hu, Zheyuan and Wu, Robyn and Enock, Naveen and Li, Jasmine and Kadakia, Riya and Erickson, Zackory and Kumar, Aviral},
|
||||
journal={arXiv preprint arXiv:2509.07953},
|
||||
year={2025}
|
||||
}
|
||||
|
||||
@article{pi2025recap,
|
||||
title={π0.6: a VLA That Learns From Experience},
|
||||
author={Physical Intelligence},
|
||||
year={2025}
|
||||
}
|
||||
```
|
||||
@@ -159,7 +159,7 @@ We use the Hugging Face hub features for uploading your dataset. If you haven't
|
||||
Add your token to the CLI by running this command:
|
||||
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Then store your Hugging Face repository name in a variable:
|
||||
@@ -327,7 +327,7 @@ You can look for other LeRobot datasets on the hub by searching for `LeRobot` [t
|
||||
You can also push your local dataset to the Hub manually, running:
|
||||
|
||||
```bash
|
||||
huggingface-cli upload ${HF_USER}/record-test ~/.cache/huggingface/lerobot/{repo-id} --repo-type dataset
|
||||
hf upload ${HF_USER}/record-test ~/.cache/huggingface/lerobot/{repo-id} --repo-type dataset
|
||||
```
|
||||
|
||||
#### Record function
|
||||
@@ -491,7 +491,7 @@ If your local computer doesn't have a powerful GPU you could utilize Google Cola
|
||||
Once training is done, upload the latest checkpoint with:
|
||||
|
||||
```bash
|
||||
huggingface-cli upload ${HF_USER}/act_so101_test \
|
||||
hf upload ${HF_USER}/act_so101_test \
|
||||
outputs/train/act_so101_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
@@ -499,7 +499,7 @@ You can also upload intermediate checkpoints with:
|
||||
|
||||
```bash
|
||||
CKPT=010000
|
||||
huggingface-cli upload ${HF_USER}/act_so101_test${CKPT} \
|
||||
hf upload ${HF_USER}/act_so101_test${CKPT} \
|
||||
outputs/train/act_so101_test/checkpoints/${CKPT}/pretrained_model
|
||||
```
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# Installation
|
||||
|
||||
This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.10 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-).
|
||||
This guide uses `conda` (via miniforge) to manage environments (recommended). If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.12 and `ffmpeg` installed with the `libsvtav1` encoder, then skip ahead to [Environment Setup](#step-2-environment-setup).
|
||||
|
||||
## Step 1: Install [`miniforge`](https://conda-forge.org/download/)
|
||||
## Step 1 (`conda` only): Install [`miniforge`](https://conda-forge.org/download/)
|
||||
|
||||
```bash
|
||||
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
|
||||
@@ -11,22 +11,47 @@ bash Miniforge3-$(uname)-$(uname -m).sh
|
||||
|
||||
## Step 2: Environment Setup
|
||||
|
||||
Create a virtual environment with Python 3.10, using conda:
|
||||
Create a virtual environment with Python 3.12:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
<hfoptions id="create_venv">
|
||||
<hfoption id="conda">
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda create -y -n lerobot python=3.12
|
||||
```
|
||||
|
||||
Then activate your conda environment, you have to do this each time you open a shell to use lerobot:
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="uv">
|
||||
```bash
|
||||
uv python install 3.12
|
||||
uv venv --python 3.12
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
Then activate your virtual environment, you have to do this each time you open a shell to use lerobot:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
<hfoptions id="activate_venv">
|
||||
<hfoption id="conda">```bash
|
||||
conda activate lerobot
|
||||
```</hfoption>
|
||||
<hfoption id="uv">
|
||||
```bash
|
||||
# Linux/macOSsource
|
||||
source .venv/bin/activate
|
||||
# Windows PowerShell
|
||||
source .venv\Scripts\Activate.ps1
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
When using `conda`, install `ffmpeg` in your environment:
|
||||
|
||||
```bash
|
||||
conda install ffmpeg -c conda-forge
|
||||
ffmpeg -version # ffmpeg 8.X is not yet supported !
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
@@ -47,6 +72,9 @@ conda install ffmpeg -c conda-forge
|
||||
> conda install evdev -c conda-forge
|
||||
> ```
|
||||
|
||||
> [!IMPORTANT]
|
||||
> If you are using `uv` you will have to install `ffmpeg` system-wide (outside of the virtual environment). You rely on `uv` and `torchcodec` ability to dynamically link to the system `ffmpeg`.
|
||||
|
||||
## Step 3: Install LeRobot 🤗
|
||||
|
||||
### From Source
|
||||
@@ -60,23 +88,45 @@ cd lerobot
|
||||
|
||||
Then, install the library in editable mode. This is useful if you plan to contribute to the code.
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
<hfoptions id="install_lerobot_src">
|
||||
<hfoption id="conda">
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="uv">
|
||||
```bash
|
||||
uv pip install -e .
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
### Installation from PyPI
|
||||
|
||||
**Core Library:**
|
||||
Install the base package with:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
<hfoptions id="install_lerobot_pypi">
|
||||
<hfoption id="conda">
|
||||
```bash
|
||||
pip install lerobot
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="uv">
|
||||
```bash
|
||||
uv pip install lerobot
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
_This installs only the default dependencies._
|
||||
|
||||
**Extra Features:**
|
||||
To install additional functionality, use one of the following:
|
||||
To install additional functionality, use one of the following (If you are using `uv`, replace `pip install` with `uv pip install` in the commands below.):
|
||||
|
||||
```bash
|
||||
pip install 'lerobot[all]' # All available features
|
||||
@@ -90,13 +140,10 @@ _Replace `[...]` with your desired features._
|
||||
For a full list of optional dependencies, see:
|
||||
https://pypi.org/project/lerobot/
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install pi, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
|
||||
To install these for linux run:
|
||||
To install these for Linux run:
|
||||
|
||||
```bash
|
||||
sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev
|
||||
@@ -106,7 +153,7 @@ For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/
|
||||
|
||||
## Optional dependencies
|
||||
|
||||
LeRobot provides optional extras for specific functionalities. Multiple extras can be combined (e.g., `.[aloha,feetech]`). For all available extras, refer to `pyproject.toml`.
|
||||
LeRobot provides optional extras for specific functionalities. Multiple extras can be combined (e.g., `.[aloha,feetech]`). For all available extras, refer to `pyproject.toml`. If you are using `uv`, replace `pip install` with `uv pip install` in the commands below.
|
||||
|
||||
### Simulations
|
||||
|
||||
|
||||
@@ -279,13 +279,13 @@ We use the Hugging Face hub features for uploading your dataset. If you haven't
|
||||
Add your token to the CLI by running this command:
|
||||
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Then store your Hugging Face repository name in a variable:
|
||||
|
||||
```bash
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
|
||||
@@ -34,11 +34,6 @@ As described by Physical Intelligence, while AI has achieved remarkable success
|
||||
pip install -e ".[pi]"
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
|
||||
>
|
||||
> This will be solved in the next patch release
|
||||
|
||||
## Training Data and Capabilities
|
||||
|
||||
π₀ is trained on the largest robot interaction dataset to date, combining three key data sources:
|
||||
|
||||
@@ -36,11 +36,6 @@ This diverse training mixture creates a "curriculum" that enables generalization
|
||||
pip install -e ".[pi]"
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
|
||||
>
|
||||
> This will be solved in the next patch release
|
||||
|
||||
## Usage
|
||||
|
||||
To use π₀.₅ in your LeRobot configuration, specify the policy type as:
|
||||
|
||||
+10
-15
@@ -43,16 +43,11 @@ This approach can transform **any existing VLM** into a VLA by training it to pr
|
||||
pip install -e ".[pi]"
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install the pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
|
||||
>
|
||||
> This will be solved in the next patch release
|
||||
|
||||
## Training a Custom FAST Tokenizer
|
||||
|
||||
You have two options for the FAST tokenizer:
|
||||
|
||||
1. **Use the pre-trained tokenizer**: The `physical-intelligence/fast` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer.
|
||||
1. **Use the pre-trained tokenizer**: The `lerobot/fast-action-tokenizer` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer.
|
||||
|
||||
2. **Train your own tokenizer**: For maximum performance on your specific dataset, you can finetune the tokenizer on your own data.
|
||||
|
||||
@@ -114,15 +109,15 @@ lerobot-train \
|
||||
|
||||
### Key Training Parameters
|
||||
|
||||
| Parameter | Description | Default |
|
||||
| -------------------------------------- | -------------------------------------------------- | ---------------------------- |
|
||||
| `--policy.gradient_checkpointing=true` | Reduces memory usage significantly during training | `false` |
|
||||
| `--policy.dtype=bfloat16` | Use mixed precision training for efficiency | `float32` |
|
||||
| `--policy.chunk_size` | Number of action steps to predict (action horizon) | `50` |
|
||||
| `--policy.n_action_steps` | Number of action steps to execute | `50` |
|
||||
| `--policy.max_action_tokens` | Maximum number of FAST tokens per action chunk | `256` |
|
||||
| `--policy.action_tokenizer_name` | FAST tokenizer to use | `physical-intelligence/fast` |
|
||||
| `--policy.compile_model=true` | Enable torch.compile for faster training | `false` |
|
||||
| Parameter | Description | Default |
|
||||
| -------------------------------------- | -------------------------------------------------- | ------------------------------- |
|
||||
| `--policy.gradient_checkpointing=true` | Reduces memory usage significantly during training | `false` |
|
||||
| `--policy.dtype=bfloat16` | Use mixed precision training for efficiency | `float32` |
|
||||
| `--policy.chunk_size` | Number of action steps to predict (action horizon) | `50` |
|
||||
| `--policy.n_action_steps` | Number of action steps to execute | `50` |
|
||||
| `--policy.max_action_tokens` | Maximum number of FAST tokens per action chunk | `256` |
|
||||
| `--policy.action_tokenizer_name` | FAST tokenizer to use | `lerobot/fast-action-tokenizer` |
|
||||
| `--policy.compile_model=true` | Enable torch.compile for faster training | `false` |
|
||||
|
||||
## Inference
|
||||
|
||||
|
||||
+140
-186
@@ -1,23 +1,49 @@
|
||||
# Unitree G1
|
||||
|
||||
This guide covers the complete setup process for the Unitree G1 humanoid, from initial connection to running gr00t_wbc locomotion.
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/unitree_thumbnail.jpg"
|
||||
alt="Unitree G1 locomanipulation demo"
|
||||
style={{ width: "100%" }}
|
||||
/>
|
||||
|
||||
## About
|
||||
|
||||
We support both 29 and 23 DOF G1 EDU version. We introduce:
|
||||
|
||||
- **`unitree g1` robot class, handling low level read/write from/to the humanoid**
|
||||
- **ZMQ socket bridge** for remote communication and camera streaming, allowing for remote policy deployment over wlan, eth or directly on the robot
|
||||
- **Locomotion policies** from NVIDIA gr00t and Amazon FAR Holosoma
|
||||
- **Simulation mode** for testing policies without the physical robot in mujoco
|
||||
The Unitree G1 humanoid is now supported in LeRobot! You can teleoperate, train locomanipulation policies, test in sim, and more. Both 29 and 23 DoF variants are supported.
|
||||
|
||||
---
|
||||
|
||||
## Connection guide
|
||||
## Part 1: Getting Started
|
||||
|
||||
### Step 1: Configure Ethernet Interface
|
||||
### Install LeRobot on Your Machine
|
||||
|
||||
Set a static IP on the same subnet as the robot:
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.12
|
||||
conda activate lerobot
|
||||
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
|
||||
cd unitree_sdk2_python && pip install -e .
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
pip install -e '.[unitree_g1]'
|
||||
```
|
||||
|
||||
### Test the Installation (Simulation)
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=true \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.id=wbc_unitree \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
This will launch a [MuJoCo sim instance](https://huggingface.co/lerobot/unitree-g1-mujoco/tree/main) for the G1.
|
||||
|
||||
- Press `9` to release the robot
|
||||
- Press `7` / `8` to increase / decrease waist height
|
||||
|
||||
### Connect to the Robot
|
||||
|
||||
The G1's Ethernet IP is fixed at `192.168.123.164`. Your machine must have a static IP on the same subnet: `192.168.123.x` where `x ≠ 164`.
|
||||
|
||||
```bash
|
||||
# Replace 'enp131s0' with your ethernet interface name (check with `ip a`)
|
||||
@@ -26,272 +52,200 @@ sudo ip addr add 192.168.123.200/24 dev enp131s0
|
||||
sudo ip link set enp131s0 up
|
||||
```
|
||||
|
||||
**Note**: The G1's Ethernet IP is fixed at `192.168.123.164`. Your computer must use `192.168.123.x` with x ≠ 164.
|
||||
|
||||
### Step 2: SSH into the Robot
|
||||
### SSH into the Robot
|
||||
|
||||
```bash
|
||||
ssh unitree@192.168.123.164
|
||||
# Password: 123
|
||||
```
|
||||
|
||||
You should now be connected to the G1's Orin.
|
||||
### Install LeRobot on the G1
|
||||
|
||||
From the robot:
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.12
|
||||
conda activate lerobot
|
||||
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
|
||||
cd unitree_sdk2_python && pip install -e .
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
pip install -e '.[unitree_g1]'
|
||||
```
|
||||
|
||||
> **Note:** The Unitree SDK requires CycloneDDS v0.10.2. See the [Unitree SDK docs](https://github.com/unitreerobotics/unitree_sdk2_python) for details.
|
||||
|
||||
---
|
||||
|
||||
## Part 2: Enable WiFi on the Robot
|
||||
|
||||
Wlan0 is disabled by default on the G1. To enable it:
|
||||
|
||||
### Step 1: Enable WiFi Hardware
|
||||
Wi-Fi connectivity is blocked by default on the G1. To activate:
|
||||
|
||||
```bash
|
||||
sudo rfkill unblock wifi
|
||||
sudo rfkill unblock all
|
||||
|
||||
# Bring up wlan0
|
||||
sudo ip link set wlan0 up
|
||||
|
||||
# Enable NetworkManager control of wlan0
|
||||
sudo nmcli radio wifi on
|
||||
sudo nmcli device set wlan0 managed yes
|
||||
sudo systemctl restart NetworkManager
|
||||
```
|
||||
|
||||
### Step 2: Enable Internet Forwarding
|
||||
|
||||
**On your laptop:**
|
||||
**On your laptop** (share internet via Ethernet):
|
||||
|
||||
```bash
|
||||
# Enable IP forwarding
|
||||
sudo sysctl -w net.ipv4.ip_forward=1
|
||||
|
||||
# Set up NAT (replace wlp132s0f0 with your WiFi interface)
|
||||
# Replace wlp132s0f0 with your WiFi interface name
|
||||
sudo iptables -t nat -A POSTROUTING -o wlp132s0f0 -s 192.168.123.0/24 -j MASQUERADE
|
||||
sudo iptables -A FORWARD -i wlp132s0f0 -o enp131s0 -m state --state RELATED,ESTABLISHED -j ACCEPT
|
||||
sudo iptables -A FORWARD -i enp131s0 -o wlp132s0f0 -j ACCEPT
|
||||
```
|
||||
|
||||
**On the G1:**
|
||||
**On the G1** (set default route through your laptop):
|
||||
|
||||
```bash
|
||||
# Add laptop as default gateway
|
||||
sudo ip route del default 2>/dev/null || true
|
||||
sudo ip route add default via 192.168.123.200 dev eth0
|
||||
echo "nameserver 8.8.8.8" | sudo tee /etc/resolv.conf
|
||||
|
||||
# Test connection
|
||||
# Verify
|
||||
ping -c 3 8.8.8.8
|
||||
```
|
||||
|
||||
### Step 3: Connect to WiFi Network
|
||||
**Connect to a WiFi network:**
|
||||
|
||||
```bash
|
||||
# List available networks
|
||||
nmcli device wifi list
|
||||
|
||||
# Connect to your WiFi (example)
|
||||
sudo nmcli connection add type wifi ifname wlan0 con-name "YourNetwork" ssid "YourNetwork"
|
||||
sudo nmcli connection modify "YourNetwork" wifi-sec.key-mgmt wpa-psk
|
||||
sudo nmcli connection modify "YourNetwork" wifi-sec.psk "YourPassword"
|
||||
sudo nmcli connection modify "YourNetwork" connection.autoconnect yes
|
||||
sudo nmcli connection up "YourNetwork"
|
||||
|
||||
# Check WiFi IP address
|
||||
ip a show wlan0
|
||||
```
|
||||
|
||||
### Step 4: SSH Over WiFi
|
||||
|
||||
Once connected to WiFi, note the robot's IP address and disconnect the Ethernet cable. You can now SSH over WiFi:
|
||||
You can now SSH over WiFi:
|
||||
|
||||
```bash
|
||||
ssh unitree@<YOUR_ROBOT_IP>
|
||||
ssh unitree@<ROBOT_WIFI_IP>
|
||||
# Password: 123
|
||||
```
|
||||
|
||||
Replace `<YOUR_ROBOT_IP>` with your robot's actual WiFi IP address.
|
||||
|
||||
---
|
||||
|
||||
## Part 3: Robot Server Setup
|
||||
## Part 3: Teleoperation & Locomotion
|
||||
|
||||
### Step 1: Install LeRobot on the Orin
|
||||
|
||||
SSH into the robot and install LeRobot:
|
||||
|
||||
```bash
|
||||
ssh unitree@<YOUR_ROBOT_IP>
|
||||
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
pip install -e '.[unitree_g1]'
|
||||
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
|
||||
cd unitree_sdk2_python && pip install -e .
|
||||
```
|
||||
|
||||
**Note**: The Unitree SDK requires CycloneDDS v0.10.2 to be installed. See the [Unitree SDK documentation](https://github.com/unitreerobotics/unitree_sdk2_python) for details.
|
||||
|
||||
### Step 2: Run the Robot Server
|
||||
### Run the Robot Server
|
||||
|
||||
On the robot:
|
||||
|
||||
```bash
|
||||
python src/lerobot/robots/unitree_g1/run_g1_server.py
|
||||
python src/lerobot/robots/unitree_g1/run_g1_server.py --camera
|
||||
```
|
||||
|
||||
**Important**: Keep this terminal running. The server must be active for remote control.
|
||||
### Run the Locomotion Policy
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=false \
|
||||
--robot.robot_ip=<ROBOT_IP> \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.id=wbc_unitree \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "<ROBOT_IP>", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
--display_data=true \
|
||||
--robot.controller=HolosomaLocomotionController
|
||||
```
|
||||
|
||||
We support both [HolosomaLocomotionController](https://github.com/amazon-far/holosoma) and [GrootLocomotionController](https://github.com/NVlabs/GR00T-WholeBodyControl).
|
||||
|
||||
---
|
||||
|
||||
## Part 4: Controlling the robot
|
||||
## Part 4: Loco-Manipulation with the Homunculus Exoskeleton
|
||||
|
||||
With the robot server running, you can now control the robot remotely. Let's launch a locomotion policy
|
||||
We provide a loco-manipulation solution via the Homunculus Exoskeleton — an open-source 7 DoF exoskeleton for whole-body control. Assembly instructions [here](https://github.com/nepyope/hmc_exo).
|
||||
|
||||
### Step 1: Install LeRobot on your machine
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
pip install -e '.[unitree_g1]'
|
||||
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
|
||||
cd unitree_sdk2_python && pip install -e .
|
||||
```
|
||||
|
||||
### Step 2: Update Robot IP in Config
|
||||
|
||||
Edit the config file to match your robot's WiFi IP:
|
||||
|
||||
```python
|
||||
# In src/lerobot/robots/unitree_g1/config_unitree_g1.py
|
||||
robot_ip: str = "<YOUR_ROBOT_IP>" # Replace with your robot's WiFi IP.
|
||||
```
|
||||
|
||||
### Step 3: Run the Locomotion Policy
|
||||
|
||||
```bash
|
||||
# Run GR00T locomotion controller
|
||||
python examples/unitree_g1/gr00t_locomotion.py --repo-id "nepyope/GR00T-WholeBodyControl_g1"
|
||||
|
||||
# Run Holosoma locomotion controller
|
||||
python examples/unitree_g1/holosoma_locomotion.py
|
||||
|
||||
```
|
||||
|
||||
Press `Ctrl+C` to stop the policy.
|
||||
|
||||
---
|
||||
|
||||
## Running in Simulation Mode (MuJoCo)
|
||||
|
||||
You can test policies before deploying on the physical robot using MuJoCo simulation. Set `is_simulation=True` in config or pass `--robot.is_simulation=true` via CLI.
|
||||
|
||||
### Calibrate Exoskeleton Teleoperator
|
||||
### Calibrate
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo
|
||||
```
|
||||
|
||||
### Teleoperate in Simulation
|
||||
During calibration move each joint through its entire range. After fitting, move the joint in a neutral position and press `n` to advance.
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=true \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--fps=100
|
||||
```
|
||||
|
||||
### Record Dataset in Simulation
|
||||
### Record a Dataset
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=true \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--dataset.repo_id=your-username/dataset-name \
|
||||
--dataset.single_task="Test" \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.reset_time_s=5 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=true \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--dataset.repo_id=your-username/dataset-name \
|
||||
--dataset.single_task="Test" \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.reset_time_s=5 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2
|
||||
```
|
||||
|
||||
Example simulation dataset: [nepyope/teleop_test_sim](https://huggingface.co/datasets/nepyope/teleop_test_sim)
|
||||
> **Note:** Omit `--teleop.left_arm_config.port` and `--teleop.right_arm_config.port` if you're only using the joystick.
|
||||
|
||||
Example dataset: [nepyope/unitree_box_move_blue_full](https://huggingface.co/datasets/nepyope/unitree_box_move_blue_full)
|
||||
|
||||
---
|
||||
|
||||
## Running on Real Robot
|
||||
## Part 5: Training & Inference
|
||||
|
||||
Once the robot server is running on the G1 (see Part 3), you can teleoperate and record on the real robot.
|
||||
|
||||
### Start the Camera Server
|
||||
|
||||
On the robot, start the ZMQ image server:
|
||||
### Train
|
||||
|
||||
```bash
|
||||
python src/lerobot/cameras/zmq/image_server.py
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/dataset-name \
|
||||
--policy.type=pi05 \
|
||||
--output_dir=./outputs/pi05_training \
|
||||
--job_name=pi05_training \
|
||||
--policy.repo_id=your-username/your-repo-id \
|
||||
--policy.pretrained_path=lerobot/pi05_base \
|
||||
--policy.compile_model=true \
|
||||
--policy.gradient_checkpointing=true \
|
||||
--wandb.enable=true \
|
||||
--policy.dtype=bfloat16 \
|
||||
--policy.freeze_vision_encoder=false \
|
||||
--policy.train_expert_only=false \
|
||||
--steps=3000 \
|
||||
--policy.device=cuda \
|
||||
--batch_size=32
|
||||
```
|
||||
|
||||
Keep this running in a separate terminal for camera streaming during recording.
|
||||
### Inference with RTC
|
||||
|
||||
### Teleoperate Real Robot
|
||||
Once trained, we recommend deploying policies using inference-time RTC:
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=false \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--fps=100
|
||||
python examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=your-username/your-repo-id \
|
||||
--policy.device=cuda \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=false \
|
||||
--robot.controller=HolosomaLocomotionController \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "<ROBOT_IP>", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
--task="task_description" \
|
||||
--duration=1000 \
|
||||
--fps=30 \
|
||||
--rtc.enabled=true
|
||||
```
|
||||
|
||||
### Record Dataset on Real Robot
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=false \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--dataset.repo_id=your-username/dataset-name \
|
||||
--dataset.single_task="Test" \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.reset_time_s=5 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
```
|
||||
|
||||
**Note**: Update `server_address` to match your robot's camera server IP.
|
||||
|
||||
Example real robot dataset: [nepyope/teleop_test_real](https://huggingface.co/datasets/nepyope/teleop_test_real)
|
||||
|
||||
---
|
||||
|
||||
## Additional Resources
|
||||
@@ -300,8 +254,8 @@ Example real robot dataset: [nepyope/teleop_test_real](https://huggingface.co/da
|
||||
- [GR00T-WholeBodyControl](https://github.com/NVlabs/GR00T-WholeBodyControl)
|
||||
- [Holosoma](https://github.com/amazon-far/holosoma)
|
||||
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
|
||||
- [Unitree_IL_Lerobot](https://github.com/unitreerobotics/unitree_IL_lerobot)
|
||||
- [Unitree IL LeRobot](https://github.com/unitreerobotics/unitree_IL_lerobot)
|
||||
|
||||
---
|
||||
|
||||
_Last updated: December 2025_
|
||||
_Last updated: March 2026_
|
||||
|
||||
@@ -57,7 +57,7 @@ class DatasetReplayConfig:
|
||||
repo_id: str
|
||||
# Episode to replay.
|
||||
episode: int
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
root: str | Path | None = None
|
||||
# Limit the frames per second. By default, uses the policy fps.
|
||||
fps: int = 30
|
||||
|
||||
@@ -0,0 +1,490 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
SLURM-distributed SARM RA-BC annotation pipeline.
|
||||
|
||||
Computes SARM progress values for all frames in a dataset, distributed across
|
||||
SLURM workers, then merges the shards into a single sarm_progress.parquet.
|
||||
|
||||
Two subcommands, each a separate SLURM submission:
|
||||
|
||||
compute – N workers, each computes progress for a subset of episodes
|
||||
aggregate – 1 worker, merges N shards into sarm_progress.parquet, pushes to hub
|
||||
|
||||
Usage:
|
||||
python slurm_compute_rabc.py compute \\
|
||||
--repo-id user/dataset --reward-model-path user/sarm_model \\
|
||||
--stride 10 --device cpu --workers 50 --partition cpu
|
||||
|
||||
python slurm_compute_rabc.py aggregate \\
|
||||
--repo-id user/dataset --reward-model-path user/sarm_model \\
|
||||
--partition cpu --push-to-hub
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
|
||||
|
||||
class ComputeProgressShards(PipelineStep):
|
||||
"""Each worker computes SARM progress for its assigned episodes."""
|
||||
|
||||
def __init__(
|
||||
self, repo_id, reward_model_path, stride=1, head_mode="sparse", device="cpu", shard_dir="rabc_shards"
|
||||
):
|
||||
super().__init__()
|
||||
if stride < 1:
|
||||
raise ValueError(f"stride must be >= 1, got {stride}")
|
||||
self.repo_id = repo_id
|
||||
self.reward_model_path = reward_model_path
|
||||
self.stride = stride
|
||||
self.head_mode = head_mode
|
||||
self.device = device
|
||||
self.shard_dir = shard_dir
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.policies.sarm.compute_rabc_weights import (
|
||||
generate_all_frame_indices,
|
||||
interpolate_progress,
|
||||
load_sarm_resources,
|
||||
)
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
init_logging()
|
||||
|
||||
dataset, reward_model, preprocess = load_sarm_resources(
|
||||
self.repo_id,
|
||||
self.reward_model_path,
|
||||
self.device,
|
||||
)
|
||||
|
||||
if hasattr(preprocess, "eval"):
|
||||
preprocess.eval()
|
||||
for step in preprocess.steps:
|
||||
if hasattr(step, "eval"):
|
||||
step.eval()
|
||||
|
||||
image_key = reward_model.config.image_key
|
||||
state_key = reward_model.config.state_key
|
||||
frame_gap = reward_model.config.frame_gap
|
||||
center_idx = reward_model.config.n_obs_steps // 2
|
||||
|
||||
dual_mode = reward_model.config.uses_dual_heads
|
||||
compute_sparse = self.head_mode in ("sparse", "both") or not dual_mode
|
||||
compute_dense = self.head_mode in ("dense", "both") and dual_mode
|
||||
|
||||
my_episodes = list(range(dataset.num_episodes))[rank::world_size]
|
||||
if not my_episodes:
|
||||
logging.info(f"Rank {rank}: no episodes assigned")
|
||||
return
|
||||
logging.info(f"Rank {rank}: {len(my_episodes)} / {dataset.num_episodes} episodes")
|
||||
|
||||
all_rows = []
|
||||
|
||||
for ep_idx in tqdm(my_episodes, desc=f"Rank {rank}"):
|
||||
ep = dataset.meta.episodes[ep_idx]
|
||||
ep_start, ep_end = ep["dataset_from_index"], ep["dataset_to_index"]
|
||||
task = dataset[ep_start].get("task", "perform the task")
|
||||
|
||||
all_ep_indices = generate_all_frame_indices(ep_start, ep_end, frame_gap)
|
||||
if self.stride > 1:
|
||||
compute_indices = [i for i in all_ep_indices if (i - ep_start) % self.stride == 0]
|
||||
if (ep_end - 1) not in compute_indices:
|
||||
compute_indices.append(ep_end - 1)
|
||||
compute_indices = sorted(set(compute_indices))
|
||||
else:
|
||||
compute_indices = all_ep_indices
|
||||
|
||||
frame_results = {}
|
||||
for qi in tqdm(compute_indices, desc=f" Ep {ep_idx}", leave=False):
|
||||
try:
|
||||
sample = dataset[qi]
|
||||
batch = {
|
||||
image_key: sample[image_key],
|
||||
"task": task,
|
||||
"index": qi,
|
||||
"episode_index": ep_idx,
|
||||
}
|
||||
if state_key in sample:
|
||||
batch[state_key] = sample[state_key]
|
||||
|
||||
with torch.no_grad():
|
||||
processed = preprocess(batch)
|
||||
vf = processed["video_features"].to(self.device)
|
||||
tf = processed["text_features"].to(self.device)
|
||||
sf = processed.get("state_features")
|
||||
if sf is not None:
|
||||
sf = sf.to(self.device)
|
||||
lengths = processed.get("lengths")
|
||||
|
||||
sparse_val = dense_val = np.nan
|
||||
if compute_sparse:
|
||||
r = reward_model.calculate_rewards(
|
||||
text_embeddings=tf,
|
||||
video_embeddings=vf,
|
||||
state_features=sf,
|
||||
lengths=lengths,
|
||||
return_all_frames=True,
|
||||
head_mode="sparse",
|
||||
)
|
||||
sparse_val = float(r[0, center_idx] if r.ndim == 2 else r[center_idx])
|
||||
if compute_dense:
|
||||
r = reward_model.calculate_rewards(
|
||||
text_embeddings=tf,
|
||||
video_embeddings=vf,
|
||||
state_features=sf,
|
||||
lengths=lengths,
|
||||
return_all_frames=True,
|
||||
head_mode="dense",
|
||||
)
|
||||
dense_val = float(r[0, center_idx] if r.ndim == 2 else r[center_idx])
|
||||
|
||||
frame_results[qi] = (sparse_val, dense_val)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed frame {qi}: {e}")
|
||||
|
||||
if not frame_results:
|
||||
logging.warning(f"Episode {ep_idx}: all frames failed, skipping")
|
||||
continue
|
||||
|
||||
# Interpolate to all frames in this episode
|
||||
computed_idx = np.array(sorted(frame_results.keys()))
|
||||
all_frame_arr = np.arange(ep_start, ep_end)
|
||||
|
||||
sparse_vals = np.array([frame_results[i][0] for i in computed_idx]) if compute_sparse else None
|
||||
dense_vals = np.array([frame_results[i][1] for i in computed_idx]) if compute_dense else None
|
||||
|
||||
if self.stride > 1 and len(computed_idx) > 1:
|
||||
if compute_sparse:
|
||||
sparse_vals = interpolate_progress(computed_idx, sparse_vals, all_frame_arr)
|
||||
if compute_dense:
|
||||
dense_vals = interpolate_progress(computed_idx, dense_vals, all_frame_arr)
|
||||
output_frames = all_frame_arr
|
||||
else:
|
||||
# Use only successfully computed frames to avoid indexing mismatch on failures
|
||||
output_frames = computed_idx
|
||||
|
||||
for i, fi in enumerate(output_frames):
|
||||
row = {"index": int(fi), "episode_index": ep_idx, "frame_index": int(fi - ep_start)}
|
||||
if compute_sparse:
|
||||
row["progress_sparse"] = float(sparse_vals[i])
|
||||
if compute_dense:
|
||||
row["progress_dense"] = float(dense_vals[i])
|
||||
all_rows.append(row)
|
||||
|
||||
if all_rows:
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame(all_rows).sort_values("index").reset_index(drop=True)
|
||||
table = pa.Table.from_pandas(df, preserve_index=False)
|
||||
table = table.replace_schema_metadata({b"reward_model_path": self.reward_model_path.encode()})
|
||||
shard_dir = Path(self.shard_dir)
|
||||
shard_dir.mkdir(parents=True, exist_ok=True)
|
||||
out = shard_dir / f"shard_{rank:05d}.parquet"
|
||||
pq.write_table(table, out)
|
||||
logging.info(f"Rank {rank}: saved {len(df)} rows to {out}")
|
||||
|
||||
|
||||
class AggregateProgress(PipelineStep):
|
||||
"""Merge all shard parquets into final sarm_progress.parquet."""
|
||||
|
||||
def __init__(self, repo_id, reward_model_path, shard_dir="rabc_shards", push_to_hub=False):
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self.reward_model_path = reward_model_path
|
||||
self.shard_dir = shard_dir
|
||||
self.push_to_hub = push_to_hub
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
init_logging()
|
||||
if rank != 0:
|
||||
return
|
||||
|
||||
shard_dir = Path(self.shard_dir)
|
||||
shards = sorted(shard_dir.glob("shard_*.parquet"))
|
||||
if not shards:
|
||||
raise FileNotFoundError(f"No shards found in {shard_dir}")
|
||||
|
||||
# Log shard modification time range to help detect stale files
|
||||
mtimes = [os.path.getmtime(s) for s in shards]
|
||||
oldest = datetime.datetime.fromtimestamp(min(mtimes)).isoformat(timespec="seconds")
|
||||
newest = datetime.datetime.fromtimestamp(max(mtimes)).isoformat(timespec="seconds")
|
||||
logging.info(f"Aggregating {len(shards)} shards (oldest: {oldest}, newest: {newest})")
|
||||
|
||||
df = pd.concat([pd.read_parquet(s) for s in shards], ignore_index=True)
|
||||
df = df.sort_values("index").reset_index(drop=True)
|
||||
|
||||
table = pa.Table.from_pandas(df, preserve_index=False)
|
||||
table = table.replace_schema_metadata({b"reward_model_path": self.reward_model_path.encode()})
|
||||
|
||||
temp_ds = LeRobotDataset(self.repo_id, download_videos=False)
|
||||
out_path = Path(temp_ds.root) / "sarm_progress.parquet"
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
pq.write_table(table, out_path)
|
||||
logging.info(f"Saved {len(df)} rows to {out_path}")
|
||||
|
||||
for col in ["progress_sparse", "progress_dense"]:
|
||||
if col in df.columns:
|
||||
v = df[col].dropna()
|
||||
logging.info(
|
||||
f"{col}: mean={v.mean():.4f} std={v.std():.4f} min={v.min():.4f} max={v.max():.4f}"
|
||||
)
|
||||
|
||||
if self.push_to_hub:
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
api = HfApi()
|
||||
hub_path = "sarm_progress.parquet"
|
||||
logging.info(f"Uploading to {self.repo_id}/{hub_path}")
|
||||
api.upload_file(
|
||||
path_or_fileobj=str(out_path),
|
||||
path_in_repo=hub_path,
|
||||
repo_id=self.repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
logging.info(f"Uploaded: https://huggingface.co/datasets/{self.repo_id}/blob/main/{hub_path}")
|
||||
|
||||
|
||||
def make_compute_executor(
|
||||
repo_id,
|
||||
reward_model_path,
|
||||
stride,
|
||||
head_mode,
|
||||
device,
|
||||
shard_dir,
|
||||
logs_dir,
|
||||
job_name,
|
||||
slurm,
|
||||
workers,
|
||||
partition,
|
||||
cpus_per_task,
|
||||
mem_per_cpu,
|
||||
):
|
||||
kwargs = {
|
||||
"pipeline": [
|
||||
ComputeProgressShards(repo_id, reward_model_path, stride, head_mode, device, str(shard_dir)),
|
||||
],
|
||||
"logging_dir": str(logs_dir / job_name),
|
||||
}
|
||||
|
||||
if slurm:
|
||||
kwargs.update(
|
||||
{
|
||||
"job_name": job_name,
|
||||
"tasks": workers,
|
||||
"workers": workers,
|
||||
"time": "24:00:00",
|
||||
"partition": partition,
|
||||
"cpus_per_task": cpus_per_task,
|
||||
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
|
||||
}
|
||||
)
|
||||
return SlurmPipelineExecutor(**kwargs)
|
||||
|
||||
kwargs.update({"tasks": workers, "workers": 1})
|
||||
return LocalPipelineExecutor(**kwargs)
|
||||
|
||||
|
||||
def make_aggregate_executor(
|
||||
repo_id,
|
||||
reward_model_path,
|
||||
shard_dir,
|
||||
logs_dir,
|
||||
job_name,
|
||||
slurm,
|
||||
partition,
|
||||
cpus_per_task,
|
||||
mem_per_cpu,
|
||||
push_to_hub,
|
||||
):
|
||||
kwargs = {
|
||||
"pipeline": [
|
||||
AggregateProgress(repo_id, reward_model_path, str(shard_dir), push_to_hub),
|
||||
],
|
||||
"logging_dir": str(logs_dir / job_name),
|
||||
}
|
||||
|
||||
if slurm:
|
||||
kwargs.update(
|
||||
{
|
||||
"job_name": job_name,
|
||||
"tasks": 1,
|
||||
"workers": 1,
|
||||
"time": "02:00:00",
|
||||
"partition": partition,
|
||||
"cpus_per_task": cpus_per_task,
|
||||
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
|
||||
}
|
||||
)
|
||||
return SlurmPipelineExecutor(**kwargs)
|
||||
|
||||
kwargs.update({"tasks": 1, "workers": 1})
|
||||
return LocalPipelineExecutor(**kwargs)
|
||||
|
||||
|
||||
def _add_shared_args(p):
|
||||
p.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Hugging Face repository identifier, e.g. 'user/dataset'.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--shard-dir",
|
||||
type=Path,
|
||||
default=Path("rabc_shards"),
|
||||
help="Directory to read/write per-rank parquet shards.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--logs-dir",
|
||||
type=Path,
|
||||
default=Path("logs"),
|
||||
help="Directory for datatrove logs.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--job-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="SLURM job name (defaults to rabc_<subcommand>).",
|
||||
)
|
||||
p.add_argument(
|
||||
"--slurm",
|
||||
type=int,
|
||||
default=1,
|
||||
help="1 = submit via SLURM; 0 = run locally (useful for debugging).",
|
||||
)
|
||||
p.add_argument(
|
||||
"--partition",
|
||||
type=str,
|
||||
default=None,
|
||||
help="SLURM partition to submit to.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--cpus-per-task",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of CPUs per SLURM task.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--mem-per-cpu",
|
||||
type=str,
|
||||
default="4G",
|
||||
help="Memory per CPU, e.g. '4G' or '1950M'.",
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="SLURM-distributed SARM RA-BC annotation pipeline",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
sub = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
# compute subcommand
|
||||
cp = sub.add_parser(
|
||||
"compute",
|
||||
help="Distribute progress computation across SLURM workers.",
|
||||
)
|
||||
_add_shared_args(cp)
|
||||
cp.add_argument(
|
||||
"--reward-model-path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path or HF repo id of the SARM reward model.",
|
||||
)
|
||||
cp.add_argument(
|
||||
"--stride",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Compute every Nth frame; intermediate frames are interpolated (must be >= 1).",
|
||||
)
|
||||
cp.add_argument(
|
||||
"--head-mode",
|
||||
type=str,
|
||||
default="sparse",
|
||||
choices=["sparse", "dense", "both"],
|
||||
help="Which reward head(s) to compute.",
|
||||
)
|
||||
cp.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cpu",
|
||||
help="Device for reward model inference, e.g. 'cpu' or 'cuda'.",
|
||||
)
|
||||
cp.add_argument(
|
||||
"--workers",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Number of parallel SLURM tasks (one shard per worker).",
|
||||
)
|
||||
|
||||
# aggregate subcommand
|
||||
ap = sub.add_parser(
|
||||
"aggregate",
|
||||
help="Merge per-rank shards into a single sarm_progress.parquet.",
|
||||
)
|
||||
_add_shared_args(ap)
|
||||
ap.add_argument(
|
||||
"--reward-model-path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path or HF repo id of the SARM reward model (stored in parquet metadata).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--push-to-hub",
|
||||
action="store_true",
|
||||
help="Upload sarm_progress.parquet to the Hugging Face Hub after aggregation.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
job_name = args.job_name or f"rabc_{args.command}"
|
||||
kwargs = vars(args)
|
||||
kwargs["slurm"] = kwargs.pop("slurm") == 1
|
||||
kwargs["job_name"] = job_name
|
||||
command = kwargs.pop("command")
|
||||
|
||||
executor = make_compute_executor(**kwargs) if command == "compute" else make_aggregate_executor(**kwargs)
|
||||
|
||||
executor.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,351 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Human-in-the-Loop (HIL) Data Collection with Policy Rollout.
|
||||
|
||||
Implements the RaC paradigm (Hu et al., 2025) for LeRobot with standard synchronous
|
||||
inference. For large models with high inference latency, use hil_data_collection_rtc.py.
|
||||
|
||||
The workflow:
|
||||
1. Policy runs autonomously
|
||||
2. Press SPACE to pause - robot holds position
|
||||
3. Press 'c' to take control - human provides RECOVERY + CORRECTION
|
||||
4. Press → to end episode (save and continue to next)
|
||||
5. Reset, then do next rollout
|
||||
|
||||
Keyboard Controls:
|
||||
SPACE - Pause policy (robot holds position, no recording)
|
||||
c - Take control (start correction, recording resumes)
|
||||
→ - End episode (save and continue to next)
|
||||
← - Re-record episode
|
||||
ESC - Stop recording and push dataset to hub
|
||||
|
||||
Usage:
|
||||
python examples/rac/hil_data_collection.py \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--policy.path=outputs/train/my_policy/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=my_user/hil_dataset \
|
||||
--dataset.single_task="Pick up the cube"
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from hil_utils import (
|
||||
HILDatasetConfig,
|
||||
init_keyboard_listener,
|
||||
make_identity_processors,
|
||||
print_controls,
|
||||
reset_loop,
|
||||
teleop_disable_torque,
|
||||
teleop_smooth_move_to,
|
||||
)
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc import ActionInterpolator
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.processor.rename_processor import rename_stats
|
||||
from lerobot.robots import Robot, RobotConfig, make_robot_from_config
|
||||
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.control_utils import is_headless, predict_action
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HILConfig:
|
||||
robot: RobotConfig
|
||||
teleop: TeleoperatorConfig
|
||||
dataset: HILDatasetConfig
|
||||
policy: PreTrainedConfig | None = None
|
||||
interpolation_multiplier: int = 1 # Control rate multiplier (1=off, 2=2x, 3=3x)
|
||||
display_data: bool = True
|
||||
play_sounds: bool = True
|
||||
resume: bool = False
|
||||
device: str = "cuda"
|
||||
|
||||
def __post_init__(self):
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
if self.policy is None:
|
||||
raise ValueError("policy.path is required")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
return ["policy"]
|
||||
|
||||
|
||||
@safe_stop_image_writer
|
||||
def rollout_loop(
|
||||
robot: Robot,
|
||||
teleop: Teleoperator,
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
dataset: LeRobotDataset,
|
||||
events: dict,
|
||||
cfg: HILConfig,
|
||||
):
|
||||
"""Rollout loop with standard synchronous inference."""
|
||||
fps = cfg.dataset.fps
|
||||
device = get_safe_torch_device(cfg.device)
|
||||
|
||||
policy.reset()
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
|
||||
frame_buffer = []
|
||||
teleop_disable_torque(teleop)
|
||||
|
||||
was_paused = False
|
||||
waiting_for_takeover = False
|
||||
last_action: dict[str, Any] | None = None
|
||||
robot_action: dict[str, Any] = {}
|
||||
action_keys = sorted(robot.action_features.keys())
|
||||
|
||||
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
||||
control_interval = interpolator.get_control_interval(fps)
|
||||
|
||||
timestamp = 0
|
||||
start_t = time.perf_counter()
|
||||
|
||||
while timestamp < cfg.dataset.episode_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
break
|
||||
|
||||
# Transition to paused state
|
||||
if events["policy_paused"] and not was_paused:
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {
|
||||
k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
|
||||
}
|
||||
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
||||
events["start_next_episode"] = False
|
||||
waiting_for_takeover = True
|
||||
was_paused = True
|
||||
interpolator.reset()
|
||||
|
||||
# Takeover
|
||||
if waiting_for_takeover and events["start_next_episode"]:
|
||||
teleop_disable_torque(teleop)
|
||||
events["start_next_episode"] = False
|
||||
events["correction_active"] = True
|
||||
waiting_for_takeover = False
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
|
||||
obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR)
|
||||
|
||||
if events["correction_active"]:
|
||||
robot_action = teleop.get_action()
|
||||
robot.send_action(robot_action)
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task})
|
||||
|
||||
elif waiting_for_takeover or events["policy_paused"]:
|
||||
if last_action:
|
||||
robot.send_action(last_action)
|
||||
|
||||
else:
|
||||
# Policy execution with optional interpolation
|
||||
if interpolator.needs_new_action():
|
||||
action_values = predict_action(
|
||||
observation=obs_frame,
|
||||
policy=policy,
|
||||
device=device,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=cfg.dataset.single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
robot_action = make_robot_action(action_values, dataset.features)
|
||||
action_tensor = torch.tensor([robot_action[k] for k in action_keys])
|
||||
interpolator.add(action_tensor)
|
||||
|
||||
interp_action = interpolator.get()
|
||||
if interp_action is not None:
|
||||
robot_action = {k: interp_action[i].item() for i, k in enumerate(action_keys)}
|
||||
robot.send_action(robot_action)
|
||||
last_action = robot_action
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task})
|
||||
|
||||
if cfg.display_data and robot_action:
|
||||
log_rerun_data(observation=obs_filtered, action=robot_action)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_time := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_time)
|
||||
timestamp = time.perf_counter() - start_t
|
||||
|
||||
teleop_disable_torque(teleop)
|
||||
|
||||
for frame in frame_buffer:
|
||||
dataset.add_frame(frame)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def hil_collect(cfg: HILConfig) -> LeRobotDataset:
|
||||
"""Main HIL data collection function."""
|
||||
init_logging()
|
||||
logger.info(pformat(cfg.__dict__))
|
||||
|
||||
if cfg.display_data:
|
||||
init_rerun(session_name="hil_collection")
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
teleop = make_teleoperator_from_config(cfg.teleop)
|
||||
|
||||
teleop_proc, obs_proc = make_identity_processors()
|
||||
|
||||
dataset_features = combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_proc,
|
||||
initial_features=create_initial_features(action=robot.action_features),
|
||||
use_videos=cfg.dataset.video,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=obs_proc,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=cfg.dataset.video,
|
||||
),
|
||||
)
|
||||
|
||||
dataset = None
|
||||
listener = None
|
||||
|
||||
try:
|
||||
if cfg.resume:
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
if hasattr(robot, "cameras") and robot.cameras:
|
||||
dataset.start_image_writer(
|
||||
num_processes=cfg.dataset.num_image_writer_processes,
|
||||
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
)
|
||||
else:
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.dataset.fps,
|
||||
root=cfg.dataset.root,
|
||||
robot_type=robot.name,
|
||||
features=dataset_features,
|
||||
use_videos=cfg.dataset.video,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
|
||||
* len(robot.cameras if hasattr(robot, "cameras") else []),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
|
||||
policy = make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.device},
|
||||
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
|
||||
},
|
||||
)
|
||||
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
listener, events = init_keyboard_listener()
|
||||
|
||||
print_controls(rtc=False)
|
||||
print(f" Policy: {cfg.policy.pretrained_path}")
|
||||
print(f" Task: {cfg.dataset.single_task}")
|
||||
print(f" Interpolation: {cfg.interpolation_multiplier}x\n")
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
recorded = 0
|
||||
while recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
log_say(f"Episode {dataset.num_episodes}", cfg.play_sounds)
|
||||
|
||||
rollout_loop(
|
||||
robot=robot,
|
||||
teleop=teleop,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
events=events,
|
||||
cfg=cfg,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording", cfg.play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
recorded += 1
|
||||
|
||||
if recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
reset_loop(robot, teleop, events, cfg.dataset.fps)
|
||||
|
||||
finally:
|
||||
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
||||
|
||||
if dataset:
|
||||
dataset.finalize()
|
||||
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
if teleop.is_connected:
|
||||
teleop.disconnect()
|
||||
|
||||
if not is_headless() and listener:
|
||||
listener.stop()
|
||||
|
||||
if cfg.dataset.push_to_hub:
|
||||
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
|
||||
register_third_party_plugins()
|
||||
hil_collect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,513 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Human-in-the-Loop (HIL) Data Collection with Real-Time Chunking (RTC).
|
||||
|
||||
Implements the RaC paradigm (Hu et al., 2025) with RTC for large flow-matching models
|
||||
(Pi0, Pi0.5, SmolVLA) that have high inference latency. RTC generates action chunks
|
||||
asynchronously in a background thread for smooth robot control.
|
||||
|
||||
For fast models (ACT, Diffusion), use hil_data_collection.py instead.
|
||||
|
||||
The workflow:
|
||||
1. Policy runs autonomously with RTC
|
||||
2. Press SPACE to pause - robot holds position
|
||||
3. Press 'c' to take control - human provides RECOVERY + CORRECTION
|
||||
4. Press → to end episode (save and continue to next)
|
||||
5. Reset, then do next rollout
|
||||
|
||||
Keyboard Controls:
|
||||
SPACE - Pause policy (robot holds position, no recording)
|
||||
c - Take control (start correction, recording resumes)
|
||||
→ - End episode (save and continue to next)
|
||||
← - Re-record episode
|
||||
ESC - Stop recording and push dataset to hub
|
||||
|
||||
Usage:
|
||||
python examples/rac/hil_data_collection_rtc.py \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--policy.path=outputs/train/pi0_policy/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=my_user/hil_rtc_dataset \
|
||||
--dataset.single_task="Pick up the cube" \
|
||||
--rtc.execution_horizon=20
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pprint import pformat
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from hil_utils import (
|
||||
HILDatasetConfig,
|
||||
init_keyboard_listener,
|
||||
make_identity_processors,
|
||||
print_controls,
|
||||
reset_loop,
|
||||
teleop_disable_torque,
|
||||
teleop_smooth_move_to,
|
||||
)
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.processor.rename_processor import rename_stats
|
||||
from lerobot.robots import Robot, RobotConfig, make_robot_from_config
|
||||
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.control_utils import is_headless
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import init_logging, log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HILRTCConfig:
|
||||
robot: RobotConfig
|
||||
teleop: TeleoperatorConfig
|
||||
dataset: HILDatasetConfig
|
||||
policy: PreTrainedConfig | None = None
|
||||
rtc: RTCConfig = field(default_factory=lambda: RTCConfig(enabled=True, execution_horizon=20))
|
||||
interpolation_multiplier: int = 2 # Control rate multiplier (1=off, 2=2x, 3=3x)
|
||||
display_data: bool = True
|
||||
play_sounds: bool = True
|
||||
resume: bool = False
|
||||
device: str = "cuda"
|
||||
use_torch_compile: bool = False # First compile takes minutes, disable for real-time
|
||||
|
||||
def __post_init__(self):
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
if self.policy is None:
|
||||
raise ValueError("policy.path is required")
|
||||
self.rtc.enabled = True
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
return ["policy"]
|
||||
|
||||
|
||||
class ThreadSafeRobot:
|
||||
"""Thread-safe wrapper for robot operations."""
|
||||
|
||||
def __init__(self, robot: Robot):
|
||||
self._robot = robot
|
||||
self._lock = Lock()
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
return self._robot.get_observation()
|
||||
|
||||
def send_action(self, action: dict) -> None:
|
||||
with self._lock:
|
||||
self._robot.send_action(action)
|
||||
|
||||
@property
|
||||
def observation_features(self) -> dict:
|
||||
return self._robot.observation_features
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict:
|
||||
return self._robot.action_features
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._robot.name
|
||||
|
||||
@property
|
||||
def robot_type(self) -> str:
|
||||
return self._robot.robot_type
|
||||
|
||||
@property
|
||||
def cameras(self):
|
||||
return getattr(self._robot, "cameras", {})
|
||||
|
||||
|
||||
def rtc_inference_thread(
|
||||
policy: PreTrainedPolicy,
|
||||
obs_holder: dict,
|
||||
hw_features: dict,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
queue_holder: dict,
|
||||
shutdown_event: Event,
|
||||
policy_active: Event,
|
||||
cfg: HILRTCConfig,
|
||||
):
|
||||
"""Background thread for RTC action chunk generation."""
|
||||
latency_tracker = LatencyTracker()
|
||||
time_per_chunk = 1.0 / cfg.dataset.fps
|
||||
threshold = 30
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
if not policy_active.is_set():
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
queue = queue_holder.get("queue")
|
||||
obs = obs_holder.get("obs")
|
||||
if queue is None or obs is None:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
if queue.qsize() <= threshold:
|
||||
try:
|
||||
current_time = time.perf_counter()
|
||||
idx_before = queue.get_action_index()
|
||||
prev_actions = queue.get_left_over()
|
||||
|
||||
latency = latency_tracker.max()
|
||||
delay = math.ceil(latency / time_per_chunk) if latency else 0
|
||||
|
||||
obs_batch = build_dataset_frame(hw_features, obs, prefix="observation")
|
||||
for name in obs_batch:
|
||||
obs_batch[name] = torch.from_numpy(obs_batch[name])
|
||||
if "image" in name:
|
||||
obs_batch[name] = obs_batch[name].float() / 255
|
||||
obs_batch[name] = obs_batch[name].permute(2, 0, 1).contiguous()
|
||||
obs_batch[name] = obs_batch[name].unsqueeze(0).to(cfg.device)
|
||||
|
||||
obs_batch["task"] = [cfg.dataset.single_task]
|
||||
obs_batch["robot_type"] = obs_holder.get("robot_type", "unknown")
|
||||
|
||||
preprocessed = preprocessor(obs_batch)
|
||||
actions = policy.predict_action_chunk(
|
||||
preprocessed, inference_delay=delay, prev_chunk_left_over=prev_actions
|
||||
)
|
||||
|
||||
original = actions.squeeze(0).clone()
|
||||
processed = postprocessor(actions).squeeze(0)
|
||||
new_latency = time.perf_counter() - current_time
|
||||
new_delay = math.ceil(new_latency / time_per_chunk)
|
||||
latency_tracker.add(new_latency)
|
||||
queue.merge(original, processed, new_delay, idx_before)
|
||||
logger.debug(f"[RTC] Inference latency={new_latency:.2f}s, queue={queue.qsize()}")
|
||||
except Exception as e:
|
||||
logger.error(f"[RTC] Error: {e}")
|
||||
time.sleep(0.5)
|
||||
else:
|
||||
time.sleep(0.01)
|
||||
|
||||
|
||||
@safe_stop_image_writer
|
||||
def rollout_loop(
|
||||
robot: ThreadSafeRobot,
|
||||
teleop: Teleoperator,
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
dataset: LeRobotDataset,
|
||||
events: dict,
|
||||
cfg: HILRTCConfig,
|
||||
queue_holder: dict,
|
||||
obs_holder: dict,
|
||||
policy_active: Event,
|
||||
hw_features: dict,
|
||||
):
|
||||
"""Rollout loop with RTC for asynchronous inference."""
|
||||
fps = cfg.dataset.fps
|
||||
|
||||
policy.reset()
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
|
||||
frame_buffer = []
|
||||
teleop_disable_torque(teleop)
|
||||
|
||||
was_paused = False
|
||||
waiting_for_takeover = False
|
||||
last_action: dict[str, Any] | None = None
|
||||
action_keys = [k for k in robot.action_features if k.endswith(".pos")]
|
||||
|
||||
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
||||
control_interval = interpolator.get_control_interval(fps)
|
||||
|
||||
robot_action: dict[str, Any] = {}
|
||||
timestamp = 0
|
||||
start_t = time.perf_counter()
|
||||
|
||||
while timestamp < cfg.dataset.episode_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
break
|
||||
|
||||
# Transition to paused state
|
||||
if events["policy_paused"] and not was_paused:
|
||||
policy_active.clear()
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {
|
||||
k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
|
||||
}
|
||||
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
||||
events["start_next_episode"] = False
|
||||
waiting_for_takeover = True
|
||||
was_paused = True
|
||||
interpolator.reset()
|
||||
|
||||
# Takeover
|
||||
if waiting_for_takeover and events["start_next_episode"]:
|
||||
teleop_disable_torque(teleop)
|
||||
events["start_next_episode"] = False
|
||||
events["correction_active"] = True
|
||||
waiting_for_takeover = False
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
|
||||
obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR)
|
||||
|
||||
obs_holder["obs"] = obs_filtered
|
||||
|
||||
if events["correction_active"]:
|
||||
robot_action = teleop.get_action()
|
||||
robot.send_action(robot_action)
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task})
|
||||
|
||||
elif waiting_for_takeover or events["policy_paused"]:
|
||||
if last_action:
|
||||
robot.send_action(last_action)
|
||||
|
||||
else:
|
||||
# Policy execution with RTC
|
||||
if not policy_active.is_set():
|
||||
policy_active.set()
|
||||
|
||||
queue = queue_holder["queue"]
|
||||
|
||||
if interpolator.needs_new_action():
|
||||
new_action = queue.get() if queue else None
|
||||
if new_action is not None:
|
||||
interpolator.add(new_action.cpu())
|
||||
|
||||
action_tensor = interpolator.get()
|
||||
if action_tensor is not None:
|
||||
robot_action = {
|
||||
k: action_tensor[i].item() for i, k in enumerate(action_keys) if i < len(action_tensor)
|
||||
}
|
||||
robot.send_action(robot_action)
|
||||
last_action = robot_action
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task})
|
||||
|
||||
if cfg.display_data and robot_action:
|
||||
log_rerun_data(observation=obs_filtered, action=robot_action)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_time := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_time)
|
||||
timestamp = time.perf_counter() - start_t
|
||||
|
||||
policy_active.clear()
|
||||
teleop_disable_torque(teleop)
|
||||
|
||||
for frame in frame_buffer:
|
||||
dataset.add_frame(frame)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def hil_rtc_collect(cfg: HILRTCConfig) -> LeRobotDataset:
|
||||
"""Main HIL data collection function with RTC."""
|
||||
init_logging()
|
||||
logger.info(pformat(cfg.__dict__))
|
||||
|
||||
if cfg.display_data:
|
||||
init_rerun(session_name="hil_rtc_collection")
|
||||
|
||||
robot_raw = make_robot_from_config(cfg.robot)
|
||||
teleop = make_teleoperator_from_config(cfg.teleop)
|
||||
|
||||
teleop_proc, obs_proc = make_identity_processors()
|
||||
|
||||
dataset_features = combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_proc,
|
||||
initial_features=create_initial_features(action=robot_raw.action_features),
|
||||
use_videos=cfg.dataset.video,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=obs_proc,
|
||||
initial_features=create_initial_features(observation=robot_raw.observation_features),
|
||||
use_videos=cfg.dataset.video,
|
||||
),
|
||||
)
|
||||
|
||||
dataset = None
|
||||
listener = None
|
||||
shutdown_event = Event()
|
||||
policy_active = Event()
|
||||
rtc_thread = None
|
||||
|
||||
try:
|
||||
if cfg.resume:
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
if hasattr(robot_raw, "cameras") and robot_raw.cameras:
|
||||
dataset.start_image_writer(
|
||||
num_processes=cfg.dataset.num_image_writer_processes,
|
||||
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot_raw.cameras),
|
||||
)
|
||||
else:
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.dataset.fps,
|
||||
root=cfg.dataset.root,
|
||||
robot_type=robot_raw.name,
|
||||
features=dataset_features,
|
||||
use_videos=cfg.dataset.video,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
|
||||
* len(robot_raw.cameras if hasattr(robot_raw, "cameras") else []),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
|
||||
# Load policy with RTC
|
||||
policy_class = get_policy_class(cfg.policy.type)
|
||||
policy_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
|
||||
if hasattr(policy_config, "compile_model"):
|
||||
policy_config.compile_model = cfg.use_torch_compile
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=policy_config)
|
||||
policy.config.rtc_config = cfg.rtc
|
||||
if hasattr(policy, "init_rtc_processor"):
|
||||
policy.init_rtc_processor()
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.device},
|
||||
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
|
||||
},
|
||||
)
|
||||
|
||||
robot_raw.connect()
|
||||
robot = ThreadSafeRobot(robot_raw)
|
||||
teleop.connect()
|
||||
listener, events = init_keyboard_listener()
|
||||
|
||||
queue_holder = {"queue": ActionQueue(cfg.rtc)}
|
||||
obs_holder = {"obs": None, "robot_type": robot.robot_type}
|
||||
hw_features = hw_to_dataset_features(robot_raw.observation_features, "observation")
|
||||
|
||||
rtc_thread = Thread(
|
||||
target=rtc_inference_thread,
|
||||
args=(
|
||||
policy,
|
||||
obs_holder,
|
||||
hw_features,
|
||||
preprocessor,
|
||||
postprocessor,
|
||||
queue_holder,
|
||||
shutdown_event,
|
||||
policy_active,
|
||||
cfg,
|
||||
),
|
||||
daemon=True,
|
||||
)
|
||||
rtc_thread.start()
|
||||
|
||||
print_controls(rtc=True)
|
||||
print(f" Policy: {cfg.policy.pretrained_path}")
|
||||
print(f" Task: {cfg.dataset.single_task}")
|
||||
print(f" Interpolation: {cfg.interpolation_multiplier}x\n")
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
recorded = 0
|
||||
while recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
log_say(f"Episode {dataset.num_episodes}", cfg.play_sounds)
|
||||
|
||||
queue_holder["queue"] = ActionQueue(cfg.rtc)
|
||||
|
||||
rollout_loop(
|
||||
robot=robot,
|
||||
teleop=teleop,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
events=events,
|
||||
cfg=cfg,
|
||||
queue_holder=queue_holder,
|
||||
obs_holder=obs_holder,
|
||||
policy_active=policy_active,
|
||||
hw_features=hw_features,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording", cfg.play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
recorded += 1
|
||||
|
||||
if recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
reset_loop(robot, teleop, events, cfg.dataset.fps)
|
||||
|
||||
finally:
|
||||
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
||||
|
||||
shutdown_event.set()
|
||||
policy_active.clear()
|
||||
|
||||
if rtc_thread and rtc_thread.is_alive():
|
||||
rtc_thread.join(timeout=2.0)
|
||||
|
||||
if dataset:
|
||||
dataset.finalize()
|
||||
|
||||
if robot_raw.is_connected:
|
||||
robot_raw.disconnect()
|
||||
if teleop.is_connected:
|
||||
teleop.disconnect()
|
||||
|
||||
if not is_headless() and listener:
|
||||
listener.stop()
|
||||
|
||||
if cfg.dataset.push_to_hub:
|
||||
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
|
||||
register_third_party_plugins()
|
||||
hil_rtc_collect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,206 +0,0 @@
|
||||
"""Shared utilities for Human-in-the-Loop data collection scripts."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.processor import (
|
||||
IdentityProcessorStep,
|
||||
RobotAction,
|
||||
RobotObservation,
|
||||
RobotProcessorPipeline,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots import Robot
|
||||
from lerobot.teleoperators import Teleoperator
|
||||
from lerobot.utils.control_utils import is_headless
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HILDatasetConfig:
|
||||
repo_id: str
|
||||
single_task: str
|
||||
root: str | Path | None = None
|
||||
fps: int = 30
|
||||
episode_time_s: float = 120
|
||||
num_episodes: int = 50
|
||||
video: bool = True
|
||||
push_to_hub: bool = True
|
||||
private: bool = False
|
||||
tags: list[str] | None = None
|
||||
num_image_writer_processes: int = 0
|
||||
num_image_writer_threads_per_camera: int = 4
|
||||
video_encoding_batch_size: int = 1
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
def teleop_has_motor_control(teleop: Teleoperator) -> bool:
|
||||
"""Check if teleoperator has motor control capabilities."""
|
||||
return hasattr(teleop, "bus") and hasattr(teleop.bus, "disable_torque")
|
||||
|
||||
|
||||
def teleop_disable_torque(teleop: Teleoperator) -> None:
|
||||
"""Disable teleop torque if supported."""
|
||||
if teleop_has_motor_control(teleop):
|
||||
teleop.bus.disable_torque()
|
||||
|
||||
|
||||
def teleop_enable_torque(teleop: Teleoperator) -> None:
|
||||
"""Enable teleop torque if supported."""
|
||||
if teleop_has_motor_control(teleop):
|
||||
teleop.bus.enable_torque()
|
||||
|
||||
|
||||
def teleop_smooth_move_to(teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50):
|
||||
"""Smoothly move teleop to target position if motor control is available."""
|
||||
if not teleop_has_motor_control(teleop):
|
||||
logger.warning("Teleop does not support motor control - cannot mirror robot position")
|
||||
return
|
||||
|
||||
teleop_enable_torque(teleop)
|
||||
current = teleop.get_action()
|
||||
steps = int(duration_s * fps)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {}
|
||||
for k in current:
|
||||
if k in target_pos:
|
||||
interp[k] = current[k] * (1 - t) + target_pos[k] * t
|
||||
else:
|
||||
interp[k] = current[k]
|
||||
teleop.bus.sync_write("Goal_Position", {k.replace(".pos", ""): v for k, v in interp.items()})
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
def init_keyboard_listener():
|
||||
"""Initialize keyboard listener with HIL controls."""
|
||||
events = {
|
||||
"exit_early": False,
|
||||
"rerecord_episode": False,
|
||||
"stop_recording": False,
|
||||
"policy_paused": False,
|
||||
"correction_active": False,
|
||||
"in_reset": False,
|
||||
"start_next_episode": False,
|
||||
}
|
||||
|
||||
if is_headless():
|
||||
logger.warning("Headless environment - keyboard controls unavailable")
|
||||
return None, events
|
||||
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if events["in_reset"]:
|
||||
if key in [keyboard.Key.space, keyboard.Key.right]:
|
||||
print("\n[HIL] Starting next episode...")
|
||||
events["start_next_episode"] = True
|
||||
elif hasattr(key, "char") and key.char == "c":
|
||||
events["start_next_episode"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
print("[HIL] ESC - Stop recording, pushing to hub...")
|
||||
events["stop_recording"] = True
|
||||
events["start_next_episode"] = True
|
||||
else:
|
||||
if key == keyboard.Key.space:
|
||||
if not events["policy_paused"] and not events["correction_active"]:
|
||||
print("\n[HIL] ⏸ PAUSED - Press 'c' to take control")
|
||||
events["policy_paused"] = True
|
||||
elif hasattr(key, "char") and key.char == "c":
|
||||
if events["policy_paused"] and not events["correction_active"]:
|
||||
print("\n[HIL] ▶ Taking control...")
|
||||
events["start_next_episode"] = True
|
||||
elif key == keyboard.Key.right:
|
||||
print("[HIL] → End episode")
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
print("[HIL] ← Re-record episode")
|
||||
events["rerecord_episode"] = True
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
print("[HIL] ESC - Stop recording...")
|
||||
events["stop_recording"] = True
|
||||
events["exit_early"] = True
|
||||
except Exception as e:
|
||||
print(f"Key error: {e}")
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
return listener, events
|
||||
|
||||
|
||||
def make_identity_processors():
|
||||
"""Create identity processors for recording."""
|
||||
teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
return teleop_proc, obs_proc
|
||||
|
||||
|
||||
def reset_loop(robot: Robot, teleop: Teleoperator, events: dict, fps: int):
|
||||
"""Reset period where human repositions environment."""
|
||||
print("\n" + "=" * 60)
|
||||
print(" [HIL] RESET")
|
||||
print("=" * 60)
|
||||
|
||||
events["in_reset"] = True
|
||||
events["start_next_episode"] = False
|
||||
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
|
||||
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
||||
|
||||
print(" Press any key to enable teleoperation")
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
precise_sleep(0.05)
|
||||
|
||||
if events["stop_recording"]:
|
||||
return
|
||||
|
||||
events["start_next_episode"] = False
|
||||
teleop_disable_torque(teleop)
|
||||
print(" Teleop enabled - press any key to start episode")
|
||||
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
loop_start = time.perf_counter()
|
||||
action = teleop.get_action()
|
||||
robot.send_action(action)
|
||||
precise_sleep(1 / fps - (time.perf_counter() - loop_start))
|
||||
|
||||
events["in_reset"] = False
|
||||
events["start_next_episode"] = False
|
||||
events["exit_early"] = False
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
|
||||
|
||||
def print_controls(rtc: bool = False):
|
||||
"""Print control instructions."""
|
||||
print("\n" + "=" * 60)
|
||||
print(" Human-in-the-Loop Data Collection" + (" (RTC)" if rtc else ""))
|
||||
print("=" * 60)
|
||||
print()
|
||||
print(" Controls:")
|
||||
print(" SPACE - Pause policy")
|
||||
print(" c - Take control")
|
||||
print(" → - End episode")
|
||||
print(" ESC - Stop and push to hub")
|
||||
print("=" * 60 + "\n")
|
||||
@@ -78,12 +78,15 @@ from torch import Tensor
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||
from lerobot.processor.factory import (
|
||||
make_default_robot_action_processor,
|
||||
make_default_robot_observation_processor,
|
||||
@@ -95,6 +98,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
bi_so_follower,
|
||||
koch_follower,
|
||||
so_follower,
|
||||
unitree_g1,
|
||||
)
|
||||
from lerobot.robots.utils import make_robot_from_config
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
@@ -149,7 +153,6 @@ class RTCDemoConfig(HubMixin):
|
||||
# Demo parameters
|
||||
duration: float = 30.0 # Duration to run the demo (seconds)
|
||||
fps: float = 10.0 # Action execution frequency (Hz)
|
||||
interpolation_multiplier: int = 1 # Control rate multiplier (1=off, 2=2x, 3=3x)
|
||||
|
||||
# Compute device
|
||||
device: str | None = None # Device to run on (cuda, cpu, auto)
|
||||
@@ -350,22 +353,20 @@ def actor_control(
|
||||
logger.info("[ACTOR] Starting actor thread")
|
||||
|
||||
action_count = 0
|
||||
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
||||
action_interval = interpolator.get_control_interval(cfg.fps)
|
||||
action_interval = 1.0 / cfg.fps
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if interpolator.needs_new_action():
|
||||
new_action = action_queue.get()
|
||||
if new_action is not None:
|
||||
interpolator.add(new_action.cpu())
|
||||
# Try to get an action from the queue with timeout
|
||||
action = action_queue.get()
|
||||
|
||||
action = interpolator.get()
|
||||
if action is not None:
|
||||
action = action.cpu()
|
||||
action_dict = {key: action[i].item() for i, key in enumerate(robot.action_features())}
|
||||
action_processed = robot_action_processor((action_dict, None))
|
||||
robot.send_action(action_processed)
|
||||
|
||||
action_count += 1
|
||||
|
||||
dt_s = time.perf_counter() - start_time
|
||||
|
||||
+59
-117
@@ -25,11 +25,11 @@ discord = "https://discord.gg/s3KuuzsPFb"
|
||||
|
||||
[project]
|
||||
name = "lerobot"
|
||||
version = "0.4.4"
|
||||
version = "0.5.1"
|
||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||
dynamic = ["readme"]
|
||||
license = { text = "Apache-2.0" }
|
||||
requires-python = ">=3.10"
|
||||
requires-python = ">=3.12"
|
||||
authors = [
|
||||
{ name = "Rémi Cadène", email = "re.cadene@gmail.com" },
|
||||
{ name = "Simon Alibert", email = "alibert.sim@gmail.com" },
|
||||
@@ -50,7 +50,8 @@ classifiers = [
|
||||
"Intended Audience :: Education",
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Topic :: Software Development :: Build Tools",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
]
|
||||
@@ -61,26 +62,28 @@ dependencies = [
|
||||
# Hugging Face dependencies
|
||||
"datasets>=4.0.0,<5.0.0",
|
||||
"diffusers>=0.27.2,<0.36.0",
|
||||
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
|
||||
"huggingface-hub>=1.0.0,<2.0.0",
|
||||
"accelerate>=1.10.0,<2.0.0",
|
||||
|
||||
# Core dependencies
|
||||
"numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless.
|
||||
"setuptools>=71.0.0,<81.0.0",
|
||||
"cmake>=3.29.0.1,<4.2.0",
|
||||
"packaging>=24.2,<26.0",
|
||||
|
||||
"torch>=2.2.1,<2.11.0",
|
||||
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
|
||||
"torchvision>=0.21.0,<0.26.0",
|
||||
|
||||
"einops>=0.8.0,<0.9.0",
|
||||
"opencv-python-headless>=4.9.0,<4.13.0",
|
||||
"av>=15.0.0,<16.0.0",
|
||||
"jsonlines>=4.0.0,<5.0.0",
|
||||
"packaging>=24.2,<26.0",
|
||||
"pynput>=1.7.7,<1.9.0",
|
||||
"pynput>=1.7.8,<1.9.0",
|
||||
"pyserial>=3.5,<4.0",
|
||||
|
||||
"wandb>=0.24.0,<0.25.0",
|
||||
|
||||
"torch>=2.2.1,<2.11.0", # TODO: Bump dependency
|
||||
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency
|
||||
"torchvision>=0.21.0,<0.26.0", # TODO: Bump dependency
|
||||
|
||||
"draccus==0.10.0", # TODO: Remove ==
|
||||
"draccus==0.10.0", # TODO: Relax version constraint
|
||||
"gymnasium>=1.1.1,<2.0.0",
|
||||
"rerun-sdk>=0.24.0,<0.27.0",
|
||||
|
||||
@@ -95,10 +98,14 @@ dependencies = [
|
||||
|
||||
# Common
|
||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.10.0"]
|
||||
transformers-dep = ["transformers>=4.57.1,<5.0.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.9.17"]
|
||||
transformers-dep = ["transformers>=5.3.0,<6.0.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
|
||||
qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"]
|
||||
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"] # NOTE: Explicitly listing contourpy helps the resolver converge faster.
|
||||
|
||||
# Motors
|
||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
|
||||
@@ -112,34 +119,36 @@ gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
|
||||
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
|
||||
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
|
||||
unitree_g1 = [
|
||||
"unitree-sdk2==1.0.1",
|
||||
"pyzmq>=26.2.1,<28.0.0",
|
||||
"onnxruntime>=1.16.0,<2.0.0",
|
||||
"pin>=3.0.0,<4.0.0",
|
||||
"meshcat>=0.3.0,<0.4.0",
|
||||
"matplotlib>=3.9.0,<4.0.0",
|
||||
"lerobot[matplotlib-dep]",
|
||||
"lerobot[pygame-dep]",
|
||||
"casadi>=3.6.0,<4.0.0",
|
||||
]
|
||||
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
||||
kinematics = ["lerobot[placo-dep]"]
|
||||
intelrealsense = [
|
||||
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
|
||||
"pyrealsense2-macosx>=2.54,<2.55.0 ; sys_platform == 'darwin'",
|
||||
"pyrealsense2-macosx>=2.54,<2.57.0 ; sys_platform == 'darwin'",
|
||||
]
|
||||
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
|
||||
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0", "lerobot[scipy-dep]"]
|
||||
|
||||
# Policies
|
||||
wallx = [
|
||||
"transformers==4.49.0",
|
||||
"peft==0.17.1",
|
||||
"scipy==1.15.3",
|
||||
"torchdiffeq==0.2.5",
|
||||
"qwen_vl_utils==0.0.11"
|
||||
"lerobot[transformers-dep]",
|
||||
"lerobot[peft]",
|
||||
"lerobot[scipy-dep]",
|
||||
"torchdiffeq>=0.2.4,<0.3.0",
|
||||
"lerobot[qwen-vl-utils-dep]",
|
||||
]
|
||||
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi", "scipy>=1.10.1,<1.15"]
|
||||
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
|
||||
groot = [
|
||||
"lerobot[transformers-dep]",
|
||||
"peft>=0.13.0,<1.0.0",
|
||||
"lerobot[peft]",
|
||||
"dm-tree>=0.1.8,<1.0.0",
|
||||
"timm>=1.0.0,<1.1.0",
|
||||
"safetensors>=0.4.3,<1.0.0",
|
||||
@@ -148,13 +157,13 @@ groot = [
|
||||
"ninja>=1.11.1,<2.0.0",
|
||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||
]
|
||||
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.14,<0.1.0"]
|
||||
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
|
||||
peft = ["lerobot[transformers-dep]", "peft>=0.18.0,<1.0.0"]
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
|
||||
@@ -162,13 +171,24 @@ test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0
|
||||
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||
|
||||
# Simulation
|
||||
aloha = ["gym-aloha>=0.1.2,<0.2.0"]
|
||||
# NOTE: Explicitly listing scipy helps flatten the dependecy tree.
|
||||
aloha = ["gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
|
||||
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0"]
|
||||
metaworld = ["metaworld==3.0.0"]
|
||||
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
libero_plus = [
|
||||
"lerobot[transformers-dep]",
|
||||
"libero @ git+https://github.com/sylvestf/LIBERO-plus.git@main ; sys_platform == 'linux'",
|
||||
"lerobot[scipy-dep]",
|
||||
]
|
||||
metaworld = ["metaworld==3.0.0", "lerobot[scipy-dep]"]
|
||||
|
||||
# All
|
||||
all = [
|
||||
# NOTE(resolver hint): scipy is pulled in transitively via lerobot[scipy-dep] through
|
||||
# multiple extras (aloha, metaworld, pi, wallx, phone). Listing it explicitly
|
||||
# helps pip's resolver converge by constraining scipy early, before it encounters
|
||||
# the loose scipy requirements from transitive deps like dm-control and metaworld.
|
||||
"scipy>=1.14.0,<2.0.0",
|
||||
"lerobot[dynamixel]",
|
||||
"lerobot[gamepad]",
|
||||
"lerobot[hopejr]",
|
||||
@@ -176,8 +196,8 @@ all = [
|
||||
"lerobot[reachy2]",
|
||||
"lerobot[kinematics]",
|
||||
"lerobot[intelrealsense]",
|
||||
# "lerobot[wallx]",
|
||||
# "lerobot[pi]", TODO(Pepijn): Update pi to transformers v5
|
||||
"lerobot[wallx]",
|
||||
"lerobot[pi]",
|
||||
"lerobot[smolvla]",
|
||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||
"lerobot[xvla]",
|
||||
@@ -189,10 +209,11 @@ all = [
|
||||
"lerobot[aloha]",
|
||||
"lerobot[pusht]",
|
||||
"lerobot[phone]",
|
||||
"lerobot[libero]",
|
||||
"lerobot[libero]; sys_platform == 'linux'",
|
||||
"lerobot[metaworld]",
|
||||
"lerobot[sarm]",
|
||||
"lerobot[peft]",
|
||||
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
@@ -214,11 +235,14 @@ lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
[tool.setuptools.package-data]
|
||||
lerobot = ["envs/*.json"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py310"
|
||||
target-version = "py312"
|
||||
line-length = 110
|
||||
exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"]
|
||||
|
||||
@@ -310,7 +334,7 @@ default.extend-ignore-identifiers-re = [
|
||||
# Uncomment [tool.mypy] first, then uncomment individual module overrides as they get proper type annotations
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.10"
|
||||
python_version = "3.12"
|
||||
ignore_missing_imports = true
|
||||
follow_imports = "skip"
|
||||
# warn_return_any = true
|
||||
@@ -394,85 +418,3 @@ ignore_errors = false
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.scripts.*"
|
||||
# ignore_errors = false
|
||||
|
||||
[tool.uv]
|
||||
# wallx requires transformers==4.49.0 which conflicts with other extras that need >=4.53.0
|
||||
conflicts = [
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "transformers-dep" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "pi" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "smolvla" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "groot" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "xvla" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "sarm" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "hilserl" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "libero" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "peft" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "all" },
|
||||
],
|
||||
# pi uses custom branch which conflicts with transformers-dep
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "transformers-dep" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "smolvla" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "groot" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "xvla" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "sarm" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "hilserl" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "libero" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "peft" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "all" },
|
||||
],
|
||||
]
|
||||
|
||||
+170
-271
@@ -1,76 +1,73 @@
|
||||
#
|
||||
# This file is autogenerated by pip-compile with Python 3.10
|
||||
# This file is autogenerated by pip-compile with Python 3.12
|
||||
# by the following command:
|
||||
#
|
||||
# pip-compile --output-file=requirements-macos.txt requirements.in
|
||||
#
|
||||
-e .[all]
|
||||
# via -[all]
|
||||
absl-py==2.3.1
|
||||
absl-py==2.4.0
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# dm-tree
|
||||
# labmaze
|
||||
# mujoco
|
||||
# tensorboard
|
||||
accelerate==1.11.0
|
||||
accelerate==1.13.0
|
||||
# via
|
||||
# lerobot
|
||||
# peft
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.13.1
|
||||
aiohttp==3.13.3
|
||||
# via fsspec
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
annotated-doc==0.0.4
|
||||
# via
|
||||
# fastapi
|
||||
# typer
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
antlr4-python3-runtime==4.9.3
|
||||
# via
|
||||
# hydra-core
|
||||
# omegaconf
|
||||
anyio==4.11.0
|
||||
anyio==4.12.1
|
||||
# via
|
||||
# httpx
|
||||
# starlette
|
||||
# watchfiles
|
||||
asttokens==3.0.0
|
||||
asttokens==3.0.1
|
||||
# via stack-data
|
||||
async-timeout==5.0.1
|
||||
# via aiohttp
|
||||
attrs==25.4.0
|
||||
# via
|
||||
# aiohttp
|
||||
# dm-tree
|
||||
# jsonlines
|
||||
# jsonschema
|
||||
# referencing
|
||||
# rerun-sdk
|
||||
av==15.1.0
|
||||
# via lerobot
|
||||
bddl==1.0.1
|
||||
# via libero
|
||||
certifi==2025.10.5
|
||||
# via
|
||||
# lerobot
|
||||
# qwen-vl-utils
|
||||
certifi==2026.2.25
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
# requests
|
||||
# sentry-sdk
|
||||
cffi==2.0.0
|
||||
# via pymunk
|
||||
cfgv==3.4.0
|
||||
cfgv==3.5.0
|
||||
# via pre-commit
|
||||
charset-normalizer==3.4.4
|
||||
charset-normalizer==3.4.5
|
||||
# via requests
|
||||
click==8.3.0
|
||||
click==8.3.1
|
||||
# via
|
||||
# typer
|
||||
# uvicorn
|
||||
# wandb
|
||||
cloudpickle==3.1.1
|
||||
# via
|
||||
# gymnasium
|
||||
# libero
|
||||
cmake==4.1.0
|
||||
cloudpickle==3.1.2
|
||||
# via gymnasium
|
||||
cmake==4.1.3
|
||||
# via lerobot
|
||||
cmeel==0.57.3
|
||||
cmeel==0.59.0
|
||||
# via
|
||||
# cmeel-assimp
|
||||
# cmeel-boost
|
||||
@@ -108,15 +105,17 @@ cmeel-zlib==1.3.1
|
||||
# via cmeel-assimp
|
||||
coal-library==3.0.1
|
||||
# via pin
|
||||
contourpy==1.3.2
|
||||
# via matplotlib
|
||||
coverage[toml]==7.11.0
|
||||
contourpy==1.3.3
|
||||
# via
|
||||
# lerobot
|
||||
# matplotlib
|
||||
coverage[toml]==7.13.4
|
||||
# via pytest-cov
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
datasets==4.1.1
|
||||
datasets==4.6.1
|
||||
# via lerobot
|
||||
debugpy==1.8.17
|
||||
debugpy==1.8.20
|
||||
# via lerobot
|
||||
decorator==5.2.1
|
||||
# via ipython
|
||||
@@ -130,7 +129,7 @@ dill==0.4.0
|
||||
# multiprocess
|
||||
distlib==0.4.0
|
||||
# via virtualenv
|
||||
dm-control==1.0.34
|
||||
dm-control==1.0.37
|
||||
# via gym-aloha
|
||||
dm-env==1.6
|
||||
# via dm-control
|
||||
@@ -138,69 +137,55 @@ dm-tree==0.1.9
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# lerobot
|
||||
docopt==0.6.2
|
||||
# via num2words
|
||||
draccus==0.10.0
|
||||
# via lerobot
|
||||
dynamixel-sdk==3.8.4
|
||||
# via lerobot
|
||||
easydict==1.13
|
||||
# via libero
|
||||
egl-probe @ git+https://github.com/huggingface/egl_probe.git
|
||||
# via
|
||||
# libero
|
||||
# robomimic
|
||||
eigenpy==3.10.3
|
||||
# via coal-library
|
||||
einops==0.8.1
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
einops==0.8.2
|
||||
# via lerobot
|
||||
eiquadprog==1.2.9
|
||||
# via placo
|
||||
etils[epath,epy]==1.13.0
|
||||
etils[epath,epy]==1.14.0
|
||||
# via mujoco
|
||||
exceptiongroup==1.3.0
|
||||
# via
|
||||
# anyio
|
||||
# ipython
|
||||
# pytest
|
||||
executing==2.2.1
|
||||
# via stack-data
|
||||
faker==34.0.2
|
||||
# via lerobot
|
||||
farama-notifications==0.0.4
|
||||
# via gymnasium
|
||||
fastapi==0.119.1
|
||||
# via teleop
|
||||
fastjsonschema==2.21.2
|
||||
# via nbformat
|
||||
fastapi==0.135.1
|
||||
# via
|
||||
# lerobot
|
||||
# teleop
|
||||
feetech-servo-sdk==1.0.0
|
||||
# via lerobot
|
||||
filelock==3.20.0
|
||||
filelock==3.25.0
|
||||
# via
|
||||
# datasets
|
||||
# diffusers
|
||||
# huggingface-hub
|
||||
# python-discovery
|
||||
# torch
|
||||
# transformers
|
||||
# virtualenv
|
||||
fonttools==4.60.1
|
||||
fonttools==4.61.1
|
||||
# via matplotlib
|
||||
frozenlist==1.8.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec[http]==2025.9.0
|
||||
fsspec[http]==2026.2.0
|
||||
# via
|
||||
# datasets
|
||||
# etils
|
||||
# huggingface-hub
|
||||
# torch
|
||||
future==1.0.0
|
||||
# via libero
|
||||
gitdb==4.0.12
|
||||
# via gitpython
|
||||
gitpython==3.1.45
|
||||
gitpython==3.1.46
|
||||
# via wandb
|
||||
glfw==2.10.0
|
||||
# via
|
||||
@@ -212,7 +197,6 @@ grpcio==1.73.1
|
||||
# lerobot
|
||||
# reachy2-sdk
|
||||
# reachy2-sdk-api
|
||||
# tensorboard
|
||||
grpcio-tools==1.73.1
|
||||
# via
|
||||
# lerobot
|
||||
@@ -223,71 +207,67 @@ gym-hil==0.1.13
|
||||
# via lerobot
|
||||
gym-pusht==0.1.6
|
||||
# via lerobot
|
||||
gymnasium==1.2.1
|
||||
gymnasium==1.2.3
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
# libero
|
||||
# metaworld
|
||||
h11==0.16.0
|
||||
# via uvicorn
|
||||
h5py==3.15.1
|
||||
# via robomimic
|
||||
# via
|
||||
# httpcore
|
||||
# uvicorn
|
||||
hebi-py==2.11.0
|
||||
# via lerobot
|
||||
hf-transfer==0.1.9
|
||||
# via huggingface-hub
|
||||
hf-xet==1.1.10
|
||||
hf-xet==1.3.2
|
||||
# via huggingface-hub
|
||||
hidapi==0.14.0.post4
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
httpcore==1.0.9
|
||||
# via httpx
|
||||
httptools==0.7.1
|
||||
# via uvicorn
|
||||
huggingface-hub[cli,hf-transfer]==0.35.3
|
||||
httpx==0.28.1
|
||||
# via
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
huggingface-hub==1.6.0
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# tokenizers
|
||||
# transformers
|
||||
hydra-core==1.3.2
|
||||
# via libero
|
||||
identify==2.6.15
|
||||
identify==2.6.17
|
||||
# via pre-commit
|
||||
idna==3.11
|
||||
# via
|
||||
# anyio
|
||||
# httpx
|
||||
# requests
|
||||
# yarl
|
||||
imageio[ffmpeg]==2.37.0
|
||||
imageio[ffmpeg]==2.37.2
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# lerobot
|
||||
# metaworld
|
||||
# robomimic
|
||||
# scikit-image
|
||||
imageio-ffmpeg==0.6.0
|
||||
# via
|
||||
# imageio
|
||||
# robomimic
|
||||
importlib-metadata==8.7.0
|
||||
# via imageio
|
||||
importlib-metadata==8.7.1
|
||||
# via diffusers
|
||||
importlib-resources==6.5.2
|
||||
# via etils
|
||||
iniconfig==2.3.0
|
||||
# via pytest
|
||||
inquirerpy==0.3.4
|
||||
# via huggingface-hub
|
||||
ipython==8.37.0
|
||||
ipython==9.11.0
|
||||
# via meshcat
|
||||
ipython-pygments-lexers==1.1.1
|
||||
# via ipython
|
||||
ischedule==1.2.7
|
||||
# via placo
|
||||
jedi==0.19.2
|
||||
@@ -296,44 +276,24 @@ jinja2==3.1.6
|
||||
# via torch
|
||||
jsonlines==4.0.0
|
||||
# via lerobot
|
||||
jsonschema==4.25.1
|
||||
# via nbformat
|
||||
jsonschema-specifications==2025.9.1
|
||||
# via jsonschema
|
||||
jupyter-core==5.9.1
|
||||
# via nbformat
|
||||
jupytext==1.18.1
|
||||
# via bddl
|
||||
kiwisolver==1.4.9
|
||||
# via matplotlib
|
||||
labmaze==1.0.6
|
||||
# via dm-control
|
||||
lazy-loader==0.4
|
||||
lazy-loader==0.5
|
||||
# via scikit-image
|
||||
libero @ git+https://github.com/huggingface/lerobot-libero.git@main
|
||||
# via lerobot
|
||||
llvmlite==0.45.1
|
||||
# via numba
|
||||
librt==0.8.1
|
||||
# via mypy
|
||||
lxml==6.0.2
|
||||
# via dm-control
|
||||
markdown==3.9
|
||||
# via tensorboard
|
||||
markdown-it-py==4.0.0
|
||||
# via
|
||||
# jupytext
|
||||
# mdit-py-plugins
|
||||
# via rich
|
||||
markupsafe==3.0.3
|
||||
# via
|
||||
# jinja2
|
||||
# werkzeug
|
||||
matplotlib==3.10.7
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
# via jinja2
|
||||
matplotlib==3.10.8
|
||||
# via lerobot
|
||||
matplotlib-inline==0.2.1
|
||||
# via ipython
|
||||
mdit-py-plugins==0.5.0
|
||||
# via jupytext
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mergedeep==1.3.4
|
||||
@@ -346,41 +306,35 @@ mock-serial==0.0.1
|
||||
# via lerobot
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
mujoco==3.3.7
|
||||
mujoco==3.5.0
|
||||
# via
|
||||
# dm-control
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# libero
|
||||
# metaworld
|
||||
# robosuite
|
||||
multidict==6.7.0
|
||||
multidict==6.7.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
multiprocess==0.70.16
|
||||
multiprocess==0.70.18
|
||||
# via datasets
|
||||
mypy==1.19.1
|
||||
# via lerobot
|
||||
mypy-extensions==1.1.0
|
||||
# via typing-inspect
|
||||
nbformat==5.10.4
|
||||
# via jupytext
|
||||
networkx==3.4.2
|
||||
# via
|
||||
# bddl
|
||||
# mypy
|
||||
# typing-inspect
|
||||
networkx==3.6.1
|
||||
# via
|
||||
# scikit-image
|
||||
# torch
|
||||
ninja==1.13.0
|
||||
# via lerobot
|
||||
nodeenv==1.9.1
|
||||
nodeenv==1.10.0
|
||||
# via pre-commit
|
||||
num2words==0.5.14
|
||||
# via lerobot
|
||||
numba==0.62.1
|
||||
# via robosuite
|
||||
numpy==2.2.6
|
||||
# via
|
||||
# accelerate
|
||||
# bddl
|
||||
# cmeel-boost
|
||||
# contourpy
|
||||
# datasets
|
||||
@@ -389,16 +343,14 @@ numpy==2.2.6
|
||||
# dm-env
|
||||
# dm-tree
|
||||
# gymnasium
|
||||
# h5py
|
||||
# hebi-py
|
||||
# imageio
|
||||
# labmaze
|
||||
# libero
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# metaworld
|
||||
# mujoco
|
||||
# numba
|
||||
# opencv-python
|
||||
# opencv-python-headless
|
||||
# pandas
|
||||
@@ -406,26 +358,18 @@ numpy==2.2.6
|
||||
# pyquaternion
|
||||
# reachy2-sdk
|
||||
# rerun-sdk
|
||||
# robomimic
|
||||
# robosuite
|
||||
# scikit-image
|
||||
# scipy
|
||||
# shapely
|
||||
# teleop
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# tifffile
|
||||
# torchvision
|
||||
# transformers
|
||||
# transforms3d
|
||||
omegaconf==2.3.0
|
||||
# via hydra-core
|
||||
opencv-python==4.12.0.88
|
||||
opencv-python==4.13.0.92
|
||||
# via
|
||||
# gym-pusht
|
||||
# libero
|
||||
# reachy2-sdk
|
||||
# robosuite
|
||||
opencv-python-headless==4.12.0.88
|
||||
# via lerobot
|
||||
orderly-set==5.5.0
|
||||
@@ -435,97 +379,87 @@ packaging==25.0
|
||||
# accelerate
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
# hydra-core
|
||||
# jupytext
|
||||
# lazy-loader
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# peft
|
||||
# pytest
|
||||
# qwen-vl-utils
|
||||
# reachy2-sdk
|
||||
# scikit-image
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# transformers
|
||||
# wandb
|
||||
pandas==2.3.3
|
||||
# via
|
||||
# datasets
|
||||
# lerobot
|
||||
parso==0.8.5
|
||||
parso==0.8.6
|
||||
# via jedi
|
||||
peft==0.17.1
|
||||
pathspec==1.0.4
|
||||
# via mypy
|
||||
peft==0.18.1
|
||||
# via lerobot
|
||||
pexpect==4.9.0
|
||||
# via ipython
|
||||
pfzy==0.3.4
|
||||
# via inquirerpy
|
||||
pillow==12.0.0
|
||||
pillow==12.1.1
|
||||
# via
|
||||
# diffusers
|
||||
# imageio
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# qwen-vl-utils
|
||||
# rerun-sdk
|
||||
# robosuite
|
||||
# scikit-image
|
||||
# tensorboard
|
||||
# torchvision
|
||||
pin==3.4.0
|
||||
# via placo
|
||||
placo==0.9.14
|
||||
placo==0.9.16
|
||||
# via lerobot
|
||||
platformdirs==4.5.0
|
||||
platformdirs==4.9.4
|
||||
# via
|
||||
# jupyter-core
|
||||
# python-discovery
|
||||
# virtualenv
|
||||
# wandb
|
||||
pluggy==1.6.0
|
||||
# via
|
||||
# pytest
|
||||
# pytest-cov
|
||||
pre-commit==4.3.0
|
||||
pre-commit==4.5.1
|
||||
# via lerobot
|
||||
prompt-toolkit==3.0.52
|
||||
# via
|
||||
# inquirerpy
|
||||
# ipython
|
||||
# via ipython
|
||||
propcache==0.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
protobuf==6.31.0
|
||||
protobuf==6.31.1
|
||||
# via
|
||||
# dm-control
|
||||
# grpcio-tools
|
||||
# lerobot
|
||||
# reachy2-sdk
|
||||
# reachy2-sdk-api
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# wandb
|
||||
psutil==7.1.1
|
||||
psutil==7.2.2
|
||||
# via
|
||||
# accelerate
|
||||
# imageio
|
||||
# peft
|
||||
# robomimic
|
||||
ptyprocess==0.7.0
|
||||
# via pexpect
|
||||
pure-eval==0.2.3
|
||||
# via stack-data
|
||||
pyarrow==21.0.0
|
||||
pyarrow==23.0.1
|
||||
# via
|
||||
# datasets
|
||||
# rerun-sdk
|
||||
pycparser==2.23
|
||||
pycparser==3.0
|
||||
# via cffi
|
||||
pydantic==2.12.3
|
||||
pydantic==2.12.5
|
||||
# via
|
||||
# fastapi
|
||||
# wandb
|
||||
pydantic-core==2.41.4
|
||||
pydantic-core==2.41.5
|
||||
# via pydantic
|
||||
pygame==2.6.1
|
||||
# via
|
||||
@@ -535,33 +469,35 @@ pygame==2.6.1
|
||||
pygments==2.19.2
|
||||
# via
|
||||
# ipython
|
||||
# ipython-pygments-lexers
|
||||
# pytest
|
||||
# rich
|
||||
pymunk==6.11.1
|
||||
# via
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
pyngrok==7.4.1
|
||||
pyngrok==7.5.1
|
||||
# via meshcat
|
||||
pynput==1.8.1
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
pyobjc-core==12.0
|
||||
pyobjc-core==12.1
|
||||
# via
|
||||
# pyobjc-framework-applicationservices
|
||||
# pyobjc-framework-cocoa
|
||||
# pyobjc-framework-coretext
|
||||
# pyobjc-framework-quartz
|
||||
pyobjc-framework-applicationservices==12.0
|
||||
pyobjc-framework-applicationservices==12.1
|
||||
# via pynput
|
||||
pyobjc-framework-cocoa==12.0
|
||||
pyobjc-framework-cocoa==12.1
|
||||
# via
|
||||
# pyobjc-framework-applicationservices
|
||||
# pyobjc-framework-coretext
|
||||
# pyobjc-framework-quartz
|
||||
pyobjc-framework-coretext==12.0
|
||||
pyobjc-framework-coretext==12.1
|
||||
# via pyobjc-framework-applicationservices
|
||||
pyobjc-framework-quartz==12.0
|
||||
pyobjc-framework-quartz==12.1
|
||||
# via
|
||||
# pynput
|
||||
# pyobjc-framework-applicationservices
|
||||
@@ -570,13 +506,13 @@ pyopengl==3.1.10
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
pyparsing==3.2.5
|
||||
pyparsing==3.3.2
|
||||
# via
|
||||
# dm-control
|
||||
# matplotlib
|
||||
pyquaternion==0.9.9
|
||||
# via reachy2-sdk
|
||||
pyrealsense2-macosx==2.54.2
|
||||
pyrealsense2-macosx==2.56.5
|
||||
# via lerobot
|
||||
pyserial==3.5
|
||||
# via
|
||||
@@ -585,7 +521,6 @@ pyserial==3.5
|
||||
# lerobot
|
||||
pytest==8.4.2
|
||||
# via
|
||||
# bddl
|
||||
# lerobot
|
||||
# pytest-cov
|
||||
# pytest-timeout
|
||||
@@ -596,11 +531,14 @@ pytest-timeout==2.4.0
|
||||
# via lerobot
|
||||
python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# faker
|
||||
# matplotlib
|
||||
# pandas
|
||||
python-dotenv==1.1.1
|
||||
python-discovery==1.1.1
|
||||
# via virtualenv
|
||||
python-dotenv==1.2.2
|
||||
# via uvicorn
|
||||
pytz==2025.2
|
||||
pytz==2026.1.post1
|
||||
# via pandas
|
||||
pyyaml==6.0.3
|
||||
# via
|
||||
@@ -609,13 +547,10 @@ pyyaml==6.0.3
|
||||
# draccus
|
||||
# hebi-py
|
||||
# huggingface-hub
|
||||
# jupytext
|
||||
# omegaconf
|
||||
# peft
|
||||
# pre-commit
|
||||
# pyngrok
|
||||
# pyyaml-include
|
||||
# timm
|
||||
# transformers
|
||||
# uvicorn
|
||||
# wandb
|
||||
@@ -625,15 +560,13 @@ pyzmq==27.1.0
|
||||
# via
|
||||
# lerobot
|
||||
# meshcat
|
||||
reachy2-sdk==1.0.14
|
||||
qwen-vl-utils==0.0.14
|
||||
# via lerobot
|
||||
reachy2-sdk==1.0.15
|
||||
# via lerobot
|
||||
reachy2-sdk-api==1.0.21
|
||||
# via reachy2-sdk
|
||||
referencing==0.37.0
|
||||
# via
|
||||
# jsonschema
|
||||
# jsonschema-specifications
|
||||
regex==2025.10.23
|
||||
regex==2026.2.28
|
||||
# via
|
||||
# diffusers
|
||||
# transformers
|
||||
@@ -642,184 +575,150 @@ requests==2.32.5
|
||||
# datasets
|
||||
# diffusers
|
||||
# dm-control
|
||||
# huggingface-hub
|
||||
# qwen-vl-utils
|
||||
# teleop
|
||||
# transformers
|
||||
# wandb
|
||||
rerun-sdk==0.26.1
|
||||
rerun-sdk==0.26.2
|
||||
# via lerobot
|
||||
rhoban-cmeel-jsoncpp==1.9.4.9
|
||||
# via placo
|
||||
robomimic==0.2.0
|
||||
# via libero
|
||||
robosuite==1.4.0
|
||||
# via libero
|
||||
rpds-py==0.28.0
|
||||
# via
|
||||
# jsonschema
|
||||
# referencing
|
||||
safetensors==0.6.2
|
||||
rich==14.3.3
|
||||
# via typer
|
||||
safetensors==0.7.0
|
||||
# via
|
||||
# accelerate
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# transformers
|
||||
scikit-image==0.25.2
|
||||
# via
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
scipy==1.15.3
|
||||
scipy==1.17.1
|
||||
# via
|
||||
# dm-control
|
||||
# lerobot
|
||||
# metaworld
|
||||
# robosuite
|
||||
# scikit-image
|
||||
sentry-sdk==2.42.1
|
||||
# torchdiffeq
|
||||
sentry-sdk==2.54.0
|
||||
# via wandb
|
||||
shapely==2.1.2
|
||||
# via gym-pusht
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
six==1.17.0
|
||||
# via
|
||||
# pynput
|
||||
# python-dateutil
|
||||
smmap==5.0.2
|
||||
smmap==5.0.3
|
||||
# via gitdb
|
||||
sniffio==1.3.1
|
||||
# via anyio
|
||||
stack-data==0.6.3
|
||||
# via ipython
|
||||
starlette==0.48.0
|
||||
starlette==0.52.1
|
||||
# via fastapi
|
||||
sympy==1.14.0
|
||||
# via torch
|
||||
teleop==0.1.2
|
||||
teleop==0.1.4
|
||||
# via lerobot
|
||||
tensorboard==2.20.0
|
||||
# via robomimic
|
||||
tensorboard-data-server==0.7.2
|
||||
# via tensorboard
|
||||
tensorboardx==2.6.4
|
||||
# via robomimic
|
||||
termcolor==3.1.0
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
thop==0.1.1.post2209072238
|
||||
# via libero
|
||||
tifffile==2025.5.10
|
||||
termcolor==3.3.0
|
||||
# via lerobot
|
||||
tifffile==2026.3.3
|
||||
# via scikit-image
|
||||
timm==1.0.20
|
||||
# via lerobot
|
||||
tokenizers==0.22.1
|
||||
tokenizers==0.22.2
|
||||
# via transformers
|
||||
toml==0.10.2
|
||||
# via draccus
|
||||
tomli==2.3.0
|
||||
# via
|
||||
# cmeel
|
||||
# coverage
|
||||
# jupytext
|
||||
# pytest
|
||||
torch==2.7.1
|
||||
torch==2.10.0
|
||||
# via
|
||||
# accelerate
|
||||
# lerobot
|
||||
# peft
|
||||
# robomimic
|
||||
# thop
|
||||
# timm
|
||||
# torchdiffeq
|
||||
# torchvision
|
||||
torchcodec==0.5
|
||||
torchcodec==0.10.0
|
||||
# via lerobot
|
||||
torchvision==0.22.1
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
# timm
|
||||
tornado==6.5.2
|
||||
torchdiffeq==0.2.5
|
||||
# via lerobot
|
||||
torchvision==0.25.0
|
||||
# via lerobot
|
||||
tornado==6.5.4
|
||||
# via meshcat
|
||||
tqdm==4.67.1
|
||||
tqdm==4.67.3
|
||||
# via
|
||||
# datasets
|
||||
# dm-control
|
||||
# huggingface-hub
|
||||
# peft
|
||||
# robomimic
|
||||
# transformers
|
||||
traitlets==5.14.3
|
||||
# via
|
||||
# ipython
|
||||
# jupyter-core
|
||||
# matplotlib-inline
|
||||
# nbformat
|
||||
transformers==4.57.1
|
||||
transformers==5.3.0
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
# peft
|
||||
transforms3d==0.4.2
|
||||
# via teleop
|
||||
typer==0.24.1
|
||||
# via
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
typing-extensions==4.15.0
|
||||
# via
|
||||
# aiosignal
|
||||
# anyio
|
||||
# etils
|
||||
# exceptiongroup
|
||||
# faker
|
||||
# fastapi
|
||||
# gymnasium
|
||||
# huggingface-hub
|
||||
# ipython
|
||||
# multidict
|
||||
# mypy
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# referencing
|
||||
# rerun-sdk
|
||||
# starlette
|
||||
# torch
|
||||
# typing-inspect
|
||||
# typing-inspection
|
||||
# uvicorn
|
||||
# virtualenv
|
||||
# wandb
|
||||
typing-inspect==0.9.0
|
||||
# via draccus
|
||||
typing-inspection==0.4.2
|
||||
# via pydantic
|
||||
tzdata==2025.2
|
||||
# via
|
||||
# fastapi
|
||||
# pydantic
|
||||
tzdata==2025.3
|
||||
# via pandas
|
||||
u-msgpack-python==2.8.0
|
||||
# via meshcat
|
||||
urllib3==2.5.0
|
||||
urllib3==2.6.3
|
||||
# via
|
||||
# requests
|
||||
# sentry-sdk
|
||||
uvicorn[standard]==0.38.0
|
||||
uvicorn[standard]==0.41.0
|
||||
# via teleop
|
||||
uvloop==0.22.1
|
||||
# via uvicorn
|
||||
virtualenv==20.35.3
|
||||
virtualenv==21.1.0
|
||||
# via pre-commit
|
||||
wandb==0.21.4
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
wandb==0.24.2
|
||||
# via lerobot
|
||||
watchfiles==1.1.1
|
||||
# via uvicorn
|
||||
wcwidth==0.2.14
|
||||
wcwidth==0.6.0
|
||||
# via prompt-toolkit
|
||||
websocket-client==1.9.0
|
||||
# via teleop
|
||||
websockets==15.0.1
|
||||
websockets==16.0
|
||||
# via uvicorn
|
||||
werkzeug==3.1.3
|
||||
# via tensorboard
|
||||
wrapt==2.0.0
|
||||
wrapt==2.1.2
|
||||
# via dm-tree
|
||||
xxhash==3.6.0
|
||||
# via datasets
|
||||
yarl==1.22.0
|
||||
yarl==1.23.0
|
||||
# via aiohttp
|
||||
zipp==3.23.0
|
||||
# via
|
||||
|
||||
+209
-188
@@ -1,12 +1,12 @@
|
||||
#
|
||||
# This file is autogenerated by pip-compile with Python 3.10
|
||||
# This file is autogenerated by pip-compile with Python 3.12
|
||||
# by the following command:
|
||||
#
|
||||
# pip-compile --output-file=requirements-ubuntu.txt requirements.in
|
||||
#
|
||||
-e .[all]
|
||||
# via -[all]
|
||||
absl-py==2.3.1
|
||||
absl-py==2.4.0
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
@@ -14,30 +14,33 @@ absl-py==2.3.1
|
||||
# labmaze
|
||||
# mujoco
|
||||
# tensorboard
|
||||
accelerate==1.11.0
|
||||
accelerate==1.13.0
|
||||
# via
|
||||
# lerobot
|
||||
# peft
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.13.1
|
||||
aiohttp==3.13.3
|
||||
# via fsspec
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
annotated-doc==0.0.4
|
||||
# via
|
||||
# fastapi
|
||||
# typer
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
antlr4-python3-runtime==4.9.3
|
||||
# via
|
||||
# hydra-core
|
||||
# omegaconf
|
||||
anyio==4.11.0
|
||||
anyio==4.12.1
|
||||
# via
|
||||
# httpx
|
||||
# starlette
|
||||
# watchfiles
|
||||
asttokens==3.0.0
|
||||
asttokens==3.0.1
|
||||
# via stack-data
|
||||
async-timeout==5.0.1
|
||||
# via aiohttp
|
||||
attrs==25.4.0
|
||||
# via
|
||||
# aiohttp
|
||||
@@ -47,30 +50,35 @@ attrs==25.4.0
|
||||
# referencing
|
||||
# rerun-sdk
|
||||
av==15.1.0
|
||||
# via lerobot
|
||||
bddl==1.0.1
|
||||
# via libero
|
||||
certifi==2025.10.5
|
||||
# via
|
||||
# lerobot
|
||||
# qwen-vl-utils
|
||||
bddl==1.0.1
|
||||
# via hf-libero
|
||||
certifi==2026.2.25
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
# requests
|
||||
# sentry-sdk
|
||||
cffi==2.0.0
|
||||
# via pymunk
|
||||
cfgv==3.4.0
|
||||
cfgv==3.5.0
|
||||
# via pre-commit
|
||||
charset-normalizer==3.4.4
|
||||
charset-normalizer==3.4.5
|
||||
# via requests
|
||||
click==8.3.0
|
||||
click==8.3.1
|
||||
# via
|
||||
# typer
|
||||
# uvicorn
|
||||
# wandb
|
||||
cloudpickle==3.1.1
|
||||
cloudpickle==3.1.2
|
||||
# via
|
||||
# gymnasium
|
||||
# libero
|
||||
cmake==4.1.0
|
||||
# hf-libero
|
||||
cmake==4.1.3
|
||||
# via lerobot
|
||||
cmeel==0.57.3
|
||||
cmeel==0.59.0
|
||||
# via
|
||||
# cmeel-assimp
|
||||
# cmeel-boost
|
||||
@@ -108,20 +116,24 @@ cmeel-zlib==1.3.1
|
||||
# via cmeel-assimp
|
||||
coal-library==3.0.1
|
||||
# via pin
|
||||
contourpy==1.3.2
|
||||
# via matplotlib
|
||||
coverage[toml]==7.11.0
|
||||
contourpy==1.3.3
|
||||
# via
|
||||
# lerobot
|
||||
# matplotlib
|
||||
coverage[toml]==7.13.4
|
||||
# via pytest-cov
|
||||
cuda-bindings==12.9.4
|
||||
# via torch
|
||||
cuda-pathfinder==1.4.1
|
||||
# via cuda-bindings
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
datasets==4.1.1
|
||||
datasets==4.6.1
|
||||
# via lerobot
|
||||
debugpy==1.8.17
|
||||
debugpy==1.8.20
|
||||
# via lerobot
|
||||
decorator==5.2.1
|
||||
# via ipython
|
||||
decord==0.6.0
|
||||
# via lerobot
|
||||
deepdiff==8.6.1
|
||||
# via lerobot
|
||||
diffusers==0.35.2
|
||||
@@ -132,7 +144,7 @@ dill==0.4.0
|
||||
# multiprocess
|
||||
distlib==0.4.0
|
||||
# via virtualenv
|
||||
dm-control==1.0.34
|
||||
dm-control==1.0.37
|
||||
# via gym-aloha
|
||||
dm-env==1.6
|
||||
# via dm-control
|
||||
@@ -140,7 +152,6 @@ dm-tree==0.1.9
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# lerobot
|
||||
docopt==0.6.2
|
||||
# via num2words
|
||||
draccus==0.10.0
|
||||
@@ -148,66 +159,60 @@ draccus==0.10.0
|
||||
dynamixel-sdk==3.8.4
|
||||
# via lerobot
|
||||
easydict==1.13
|
||||
# via libero
|
||||
egl-probe @ git+https://github.com/huggingface/egl_probe.git
|
||||
# via
|
||||
# libero
|
||||
# robomimic
|
||||
# via hf-libero
|
||||
egl-probe==1.0.2
|
||||
# via robomimic
|
||||
eigenpy==3.10.3
|
||||
# via coal-library
|
||||
einops==0.8.1
|
||||
einops==0.8.2
|
||||
# via
|
||||
# flash-attn
|
||||
# hf-libero
|
||||
# lerobot
|
||||
# libero
|
||||
eiquadprog==1.2.9
|
||||
# via placo
|
||||
etils[epath,epy]==1.13.0
|
||||
etils[epath,epy]==1.14.0
|
||||
# via mujoco
|
||||
evdev==1.9.2
|
||||
evdev==1.9.3
|
||||
# via pynput
|
||||
exceptiongroup==1.3.0
|
||||
# via
|
||||
# anyio
|
||||
# ipython
|
||||
# pytest
|
||||
executing==2.2.1
|
||||
# via stack-data
|
||||
faker==34.0.2
|
||||
# via lerobot
|
||||
farama-notifications==0.0.4
|
||||
# via gymnasium
|
||||
fastapi==0.119.1
|
||||
# via teleop
|
||||
fastapi==0.135.1
|
||||
# via
|
||||
# lerobot
|
||||
# teleop
|
||||
fastjsonschema==2.21.2
|
||||
# via nbformat
|
||||
feetech-servo-sdk==1.0.0
|
||||
# via lerobot
|
||||
filelock==3.20.0
|
||||
filelock==3.25.0
|
||||
# via
|
||||
# datasets
|
||||
# diffusers
|
||||
# huggingface-hub
|
||||
# python-discovery
|
||||
# torch
|
||||
# transformers
|
||||
# virtualenv
|
||||
flash-attn==2.8.3
|
||||
# via lerobot
|
||||
fonttools==4.60.1
|
||||
fonttools==4.61.1
|
||||
# via matplotlib
|
||||
frozenlist==1.8.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec[http]==2025.9.0
|
||||
fsspec[http]==2026.2.0
|
||||
# via
|
||||
# datasets
|
||||
# etils
|
||||
# huggingface-hub
|
||||
# torch
|
||||
future==1.0.0
|
||||
# via libero
|
||||
# via hf-libero
|
||||
gitdb==4.0.12
|
||||
# via gitpython
|
||||
gitpython==3.1.45
|
||||
gitpython==3.1.46
|
||||
# via wandb
|
||||
glfw==2.10.0
|
||||
# via
|
||||
@@ -230,50 +235,60 @@ gym-hil==0.1.13
|
||||
# via lerobot
|
||||
gym-pusht==0.1.6
|
||||
# via lerobot
|
||||
gymnasium==1.2.1
|
||||
gymnasium==1.2.3
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gym-pusht
|
||||
# hf-libero
|
||||
# lerobot
|
||||
# libero
|
||||
# metaworld
|
||||
h11==0.16.0
|
||||
# via uvicorn
|
||||
h5py==3.15.1
|
||||
# via
|
||||
# httpcore
|
||||
# uvicorn
|
||||
h5py==3.16.0
|
||||
# via robomimic
|
||||
hebi-py==2.11.0
|
||||
# via lerobot
|
||||
hf-transfer==0.1.9
|
||||
# via huggingface-hub
|
||||
hf-xet==1.1.10
|
||||
hf-egl-probe==1.0.2
|
||||
# via hf-libero
|
||||
hf-libero==0.1.3
|
||||
# via lerobot
|
||||
hf-xet==1.3.2
|
||||
# via huggingface-hub
|
||||
hidapi==0.14.0.post4
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
httpcore==1.0.9
|
||||
# via httpx
|
||||
httptools==0.7.1
|
||||
# via uvicorn
|
||||
huggingface-hub[cli,hf-transfer]==0.35.3
|
||||
httpx==0.28.1
|
||||
# via
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
huggingface-hub==1.6.0
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# tokenizers
|
||||
# transformers
|
||||
hydra-core==1.3.2
|
||||
# via libero
|
||||
identify==2.6.15
|
||||
# via hf-libero
|
||||
identify==2.6.17
|
||||
# via pre-commit
|
||||
idna==3.11
|
||||
# via
|
||||
# anyio
|
||||
# httpx
|
||||
# requests
|
||||
# yarl
|
||||
imageio[ffmpeg]==2.37.0
|
||||
imageio[ffmpeg]==2.37.2
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
@@ -285,16 +300,14 @@ imageio-ffmpeg==0.6.0
|
||||
# via
|
||||
# imageio
|
||||
# robomimic
|
||||
importlib-metadata==8.7.0
|
||||
importlib-metadata==8.7.1
|
||||
# via diffusers
|
||||
importlib-resources==6.5.2
|
||||
# via etils
|
||||
iniconfig==2.3.0
|
||||
# via pytest
|
||||
inquirerpy==0.3.4
|
||||
# via huggingface-hub
|
||||
ipython==8.37.0
|
||||
ipython==9.11.0
|
||||
# via meshcat
|
||||
ipython-pygments-lexers==1.1.1
|
||||
# via ipython
|
||||
ischedule==1.2.7
|
||||
# via placo
|
||||
jedi==0.19.2
|
||||
@@ -303,40 +316,41 @@ jinja2==3.1.6
|
||||
# via torch
|
||||
jsonlines==4.0.0
|
||||
# via lerobot
|
||||
jsonschema==4.25.1
|
||||
jsonschema==4.26.0
|
||||
# via nbformat
|
||||
jsonschema-specifications==2025.9.1
|
||||
# via jsonschema
|
||||
jupyter-core==5.9.1
|
||||
# via nbformat
|
||||
jupytext==1.18.1
|
||||
jupytext==1.19.1
|
||||
# via bddl
|
||||
kiwisolver==1.4.9
|
||||
# via matplotlib
|
||||
labmaze==1.0.6
|
||||
# via dm-control
|
||||
lazy-loader==0.4
|
||||
lazy-loader==0.5
|
||||
# via scikit-image
|
||||
libero @ git+https://github.com/huggingface/lerobot-libero.git@main
|
||||
# via lerobot
|
||||
llvmlite==0.45.1
|
||||
librt==0.8.1
|
||||
# via mypy
|
||||
llvmlite==0.46.0
|
||||
# via numba
|
||||
lxml==6.0.2
|
||||
# via dm-control
|
||||
markdown==3.9
|
||||
markdown==3.10.2
|
||||
# via tensorboard
|
||||
markdown-it-py==4.0.0
|
||||
# via
|
||||
# jupytext
|
||||
# mdit-py-plugins
|
||||
# rich
|
||||
markupsafe==3.0.3
|
||||
# via
|
||||
# jinja2
|
||||
# werkzeug
|
||||
matplotlib==3.10.7
|
||||
matplotlib==3.10.8
|
||||
# via
|
||||
# hf-libero
|
||||
# lerobot
|
||||
# libero
|
||||
matplotlib-inline==0.2.1
|
||||
# via ipython
|
||||
mdit-py-plugins==0.5.0
|
||||
@@ -353,36 +367,38 @@ mock-serial==0.0.1
|
||||
# via lerobot
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
mujoco==3.3.7
|
||||
mujoco==3.5.0
|
||||
# via
|
||||
# dm-control
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# libero
|
||||
# hf-libero
|
||||
# metaworld
|
||||
# robosuite
|
||||
multidict==6.7.0
|
||||
multidict==6.7.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
multiprocess==0.70.16
|
||||
multiprocess==0.70.18
|
||||
# via datasets
|
||||
mypy==1.19.1
|
||||
# via lerobot
|
||||
mypy-extensions==1.1.0
|
||||
# via typing-inspect
|
||||
# via
|
||||
# mypy
|
||||
# typing-inspect
|
||||
nbformat==5.10.4
|
||||
# via jupytext
|
||||
networkx==3.4.2
|
||||
networkx==3.6.1
|
||||
# via
|
||||
# bddl
|
||||
# scikit-image
|
||||
# torch
|
||||
ninja==1.13.0
|
||||
# via lerobot
|
||||
nodeenv==1.9.1
|
||||
nodeenv==1.10.0
|
||||
# via pre-commit
|
||||
num2words==0.5.14
|
||||
# via lerobot
|
||||
numba==0.62.1
|
||||
numba==0.64.0
|
||||
# via robosuite
|
||||
numpy==2.2.6
|
||||
# via
|
||||
@@ -391,7 +407,6 @@ numpy==2.2.6
|
||||
# cmeel-boost
|
||||
# contourpy
|
||||
# datasets
|
||||
# decord
|
||||
# diffusers
|
||||
# dm-control
|
||||
# dm-env
|
||||
@@ -399,9 +414,10 @@ numpy==2.2.6
|
||||
# gymnasium
|
||||
# h5py
|
||||
# hebi-py
|
||||
# hf-libero
|
||||
# imageio
|
||||
# labmaze
|
||||
# libero
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# metaworld
|
||||
@@ -426,49 +442,51 @@ numpy==2.2.6
|
||||
# torchvision
|
||||
# transformers
|
||||
# transforms3d
|
||||
nvidia-cublas-cu12==12.6.4.1
|
||||
nvidia-cublas-cu12==12.8.4.1
|
||||
# via
|
||||
# nvidia-cudnn-cu12
|
||||
# nvidia-cusolver-cu12
|
||||
# torch
|
||||
nvidia-cuda-cupti-cu12==12.6.80
|
||||
nvidia-cuda-cupti-cu12==12.8.90
|
||||
# via torch
|
||||
nvidia-cuda-nvrtc-cu12==12.6.77
|
||||
nvidia-cuda-nvrtc-cu12==12.8.93
|
||||
# via torch
|
||||
nvidia-cuda-runtime-cu12==12.6.77
|
||||
nvidia-cuda-runtime-cu12==12.8.90
|
||||
# via torch
|
||||
nvidia-cudnn-cu12==9.5.1.17
|
||||
nvidia-cudnn-cu12==9.10.2.21
|
||||
# via torch
|
||||
nvidia-cufft-cu12==11.3.0.4
|
||||
nvidia-cufft-cu12==11.3.3.83
|
||||
# via torch
|
||||
nvidia-cufile-cu12==1.11.1.6
|
||||
nvidia-cufile-cu12==1.13.1.3
|
||||
# via torch
|
||||
nvidia-curand-cu12==10.3.7.77
|
||||
nvidia-curand-cu12==10.3.9.90
|
||||
# via torch
|
||||
nvidia-cusolver-cu12==11.7.1.2
|
||||
nvidia-cusolver-cu12==11.7.3.90
|
||||
# via torch
|
||||
nvidia-cusparse-cu12==12.5.4.2
|
||||
nvidia-cusparse-cu12==12.5.8.93
|
||||
# via
|
||||
# nvidia-cusolver-cu12
|
||||
# torch
|
||||
nvidia-cusparselt-cu12==0.6.3
|
||||
nvidia-cusparselt-cu12==0.7.1
|
||||
# via torch
|
||||
nvidia-nccl-cu12==2.26.2
|
||||
nvidia-nccl-cu12==2.27.5
|
||||
# via torch
|
||||
nvidia-nvjitlink-cu12==12.6.85
|
||||
nvidia-nvjitlink-cu12==12.8.93
|
||||
# via
|
||||
# nvidia-cufft-cu12
|
||||
# nvidia-cusolver-cu12
|
||||
# nvidia-cusparse-cu12
|
||||
# torch
|
||||
nvidia-nvtx-cu12==12.6.77
|
||||
nvidia-nvshmem-cu12==3.4.5
|
||||
# via torch
|
||||
nvidia-nvtx-cu12==12.8.90
|
||||
# via torch
|
||||
omegaconf==2.3.0
|
||||
# via hydra-core
|
||||
opencv-python==4.12.0.88
|
||||
opencv-python==4.13.0.92
|
||||
# via
|
||||
# gym-pusht
|
||||
# libero
|
||||
# hf-libero
|
||||
# reachy2-sdk
|
||||
# robosuite
|
||||
opencv-python-headless==4.12.0.88
|
||||
@@ -487,6 +505,7 @@ packaging==25.0
|
||||
# matplotlib
|
||||
# peft
|
||||
# pytest
|
||||
# qwen-vl-utils
|
||||
# reachy2-sdk
|
||||
# scikit-image
|
||||
# tensorboard
|
||||
@@ -497,21 +516,21 @@ pandas==2.3.3
|
||||
# via
|
||||
# datasets
|
||||
# lerobot
|
||||
parso==0.8.5
|
||||
parso==0.8.6
|
||||
# via jedi
|
||||
peft==0.17.1
|
||||
pathspec==1.0.4
|
||||
# via mypy
|
||||
peft==0.18.1
|
||||
# via lerobot
|
||||
pexpect==4.9.0
|
||||
# via ipython
|
||||
pfzy==0.3.4
|
||||
# via inquirerpy
|
||||
pillow==12.0.0
|
||||
pillow==12.1.1
|
||||
# via
|
||||
# diffusers
|
||||
# imageio
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# qwen-vl-utils
|
||||
# rerun-sdk
|
||||
# robosuite
|
||||
# scikit-image
|
||||
@@ -519,28 +538,27 @@ pillow==12.0.0
|
||||
# torchvision
|
||||
pin==3.4.0
|
||||
# via placo
|
||||
placo==0.9.14
|
||||
placo==0.9.16
|
||||
# via lerobot
|
||||
platformdirs==4.5.0
|
||||
platformdirs==4.9.4
|
||||
# via
|
||||
# jupyter-core
|
||||
# python-discovery
|
||||
# virtualenv
|
||||
# wandb
|
||||
pluggy==1.6.0
|
||||
# via
|
||||
# pytest
|
||||
# pytest-cov
|
||||
pre-commit==4.3.0
|
||||
pre-commit==4.5.1
|
||||
# via lerobot
|
||||
prompt-toolkit==3.0.52
|
||||
# via
|
||||
# inquirerpy
|
||||
# ipython
|
||||
# via ipython
|
||||
propcache==0.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
protobuf==6.31.0
|
||||
protobuf==6.31.1
|
||||
# via
|
||||
# dm-control
|
||||
# grpcio-tools
|
||||
@@ -550,7 +568,7 @@ protobuf==6.31.0
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# wandb
|
||||
psutil==7.1.1
|
||||
psutil==7.2.2
|
||||
# via
|
||||
# accelerate
|
||||
# imageio
|
||||
@@ -560,17 +578,17 @@ ptyprocess==0.7.0
|
||||
# via pexpect
|
||||
pure-eval==0.2.3
|
||||
# via stack-data
|
||||
pyarrow==21.0.0
|
||||
pyarrow==23.0.1
|
||||
# via
|
||||
# datasets
|
||||
# rerun-sdk
|
||||
pycparser==2.23
|
||||
pycparser==3.0
|
||||
# via cffi
|
||||
pydantic==2.12.3
|
||||
pydantic==2.12.5
|
||||
# via
|
||||
# fastapi
|
||||
# wandb
|
||||
pydantic-core==2.41.4
|
||||
pydantic-core==2.41.5
|
||||
# via pydantic
|
||||
pygame==2.6.1
|
||||
# via
|
||||
@@ -580,12 +598,14 @@ pygame==2.6.1
|
||||
pygments==2.19.2
|
||||
# via
|
||||
# ipython
|
||||
# ipython-pygments-lexers
|
||||
# pytest
|
||||
# rich
|
||||
pymunk==6.11.1
|
||||
# via
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
pyngrok==7.4.1
|
||||
pyngrok==7.5.1
|
||||
# via meshcat
|
||||
pynput==1.8.1
|
||||
# via
|
||||
@@ -595,7 +615,7 @@ pyopengl==3.1.10
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
pyparsing==3.2.5
|
||||
pyparsing==3.3.2
|
||||
# via
|
||||
# dm-control
|
||||
# matplotlib
|
||||
@@ -621,13 +641,16 @@ pytest-timeout==2.4.0
|
||||
# via lerobot
|
||||
python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# faker
|
||||
# matplotlib
|
||||
# pandas
|
||||
python-dotenv==1.1.1
|
||||
python-discovery==1.1.1
|
||||
# via virtualenv
|
||||
python-dotenv==1.2.2
|
||||
# via uvicorn
|
||||
python-xlib==0.33
|
||||
# via pynput
|
||||
pytz==2025.2
|
||||
pytz==2026.1.post1
|
||||
# via pandas
|
||||
pyyaml==6.0.3
|
||||
# via
|
||||
@@ -642,7 +665,6 @@ pyyaml==6.0.3
|
||||
# pre-commit
|
||||
# pyngrok
|
||||
# pyyaml-include
|
||||
# timm
|
||||
# transformers
|
||||
# uvicorn
|
||||
# wandb
|
||||
@@ -652,7 +674,9 @@ pyzmq==27.1.0
|
||||
# via
|
||||
# lerobot
|
||||
# meshcat
|
||||
reachy2-sdk==1.0.14
|
||||
qwen-vl-utils==0.0.14
|
||||
# via lerobot
|
||||
reachy2-sdk==1.0.15
|
||||
# via lerobot
|
||||
reachy2-sdk-api==1.0.21
|
||||
# via reachy2-sdk
|
||||
@@ -660,7 +684,7 @@ referencing==0.37.0
|
||||
# via
|
||||
# jsonschema
|
||||
# jsonschema-specifications
|
||||
regex==2025.10.23
|
||||
regex==2026.2.28
|
||||
# via
|
||||
# diffusers
|
||||
# transformers
|
||||
@@ -669,60 +693,62 @@ requests==2.32.5
|
||||
# datasets
|
||||
# diffusers
|
||||
# dm-control
|
||||
# huggingface-hub
|
||||
# qwen-vl-utils
|
||||
# teleop
|
||||
# transformers
|
||||
# wandb
|
||||
rerun-sdk==0.26.1
|
||||
rerun-sdk==0.26.2
|
||||
# via lerobot
|
||||
rhoban-cmeel-jsoncpp==1.9.4.9
|
||||
# via placo
|
||||
rich==14.3.3
|
||||
# via typer
|
||||
robomimic==0.2.0
|
||||
# via libero
|
||||
# via hf-libero
|
||||
robosuite==1.4.0
|
||||
# via libero
|
||||
rpds-py==0.28.0
|
||||
# via hf-libero
|
||||
rpds-py==0.30.0
|
||||
# via
|
||||
# jsonschema
|
||||
# referencing
|
||||
safetensors==0.6.2
|
||||
safetensors==0.7.0
|
||||
# via
|
||||
# accelerate
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# transformers
|
||||
scikit-image==0.25.2
|
||||
# via
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
scipy==1.15.3
|
||||
scipy==1.17.1
|
||||
# via
|
||||
# dm-control
|
||||
# lerobot
|
||||
# metaworld
|
||||
# robosuite
|
||||
# scikit-image
|
||||
sentry-sdk==2.42.1
|
||||
# torchdiffeq
|
||||
sentry-sdk==2.54.0
|
||||
# via wandb
|
||||
shapely==2.1.2
|
||||
# via gym-pusht
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
six==1.17.0
|
||||
# via
|
||||
# pynput
|
||||
# python-dateutil
|
||||
# python-xlib
|
||||
smmap==5.0.2
|
||||
smmap==5.0.3
|
||||
# via gitdb
|
||||
sniffio==1.3.1
|
||||
# via anyio
|
||||
stack-data==0.6.3
|
||||
# via ipython
|
||||
starlette==0.48.0
|
||||
starlette==0.52.1
|
||||
# via fastapi
|
||||
sympy==1.14.0
|
||||
# via torch
|
||||
teleop==0.1.2
|
||||
teleop==0.1.4
|
||||
# via lerobot
|
||||
tensorboard==2.20.0
|
||||
# via robomimic
|
||||
@@ -730,46 +756,38 @@ tensorboard-data-server==0.7.2
|
||||
# via tensorboard
|
||||
tensorboardx==2.6.4
|
||||
# via robomimic
|
||||
termcolor==3.1.0
|
||||
termcolor==3.3.0
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
thop==0.1.1.post2209072238
|
||||
# via libero
|
||||
tifffile==2025.5.10
|
||||
# via hf-libero
|
||||
tifffile==2026.3.3
|
||||
# via scikit-image
|
||||
timm==1.0.20
|
||||
# via lerobot
|
||||
tokenizers==0.22.1
|
||||
tokenizers==0.22.2
|
||||
# via transformers
|
||||
toml==0.10.2
|
||||
# via draccus
|
||||
tomli==2.3.0
|
||||
# via
|
||||
# cmeel
|
||||
# coverage
|
||||
# jupytext
|
||||
# pytest
|
||||
torch==2.7.1
|
||||
torch==2.10.0
|
||||
# via
|
||||
# accelerate
|
||||
# flash-attn
|
||||
# lerobot
|
||||
# peft
|
||||
# robomimic
|
||||
# thop
|
||||
# timm
|
||||
# torchdiffeq
|
||||
# torchvision
|
||||
torchcodec==0.5
|
||||
torchcodec==0.10.0
|
||||
# via lerobot
|
||||
torchvision==0.22.1
|
||||
torchdiffeq==0.2.5
|
||||
# via lerobot
|
||||
torchvision==0.25.0
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
# timm
|
||||
tornado==6.5.2
|
||||
tornado==6.5.4
|
||||
# via meshcat
|
||||
tqdm==4.67.1
|
||||
tqdm==4.67.3
|
||||
# via
|
||||
# datasets
|
||||
# dm-control
|
||||
@@ -783,26 +801,29 @@ traitlets==5.14.3
|
||||
# jupyter-core
|
||||
# matplotlib-inline
|
||||
# nbformat
|
||||
transformers==4.57.1
|
||||
transformers==5.3.0
|
||||
# via
|
||||
# hf-libero
|
||||
# lerobot
|
||||
# libero
|
||||
# peft
|
||||
transforms3d==0.4.2
|
||||
# via teleop
|
||||
triton==3.3.1
|
||||
triton==3.6.0
|
||||
# via torch
|
||||
typer==0.24.1
|
||||
# via
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
typing-extensions==4.15.0
|
||||
# via
|
||||
# aiosignal
|
||||
# anyio
|
||||
# etils
|
||||
# exceptiongroup
|
||||
# faker
|
||||
# fastapi
|
||||
# gymnasium
|
||||
# huggingface-hub
|
||||
# ipython
|
||||
# multidict
|
||||
# mypy
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# referencing
|
||||
@@ -811,46 +832,46 @@ typing-extensions==4.15.0
|
||||
# torch
|
||||
# typing-inspect
|
||||
# typing-inspection
|
||||
# uvicorn
|
||||
# virtualenv
|
||||
# wandb
|
||||
typing-inspect==0.9.0
|
||||
# via draccus
|
||||
typing-inspection==0.4.2
|
||||
# via pydantic
|
||||
tzdata==2025.2
|
||||
# via
|
||||
# fastapi
|
||||
# pydantic
|
||||
tzdata==2025.3
|
||||
# via pandas
|
||||
u-msgpack-python==2.8.0
|
||||
# via meshcat
|
||||
urllib3==2.5.0
|
||||
urllib3==2.6.3
|
||||
# via
|
||||
# requests
|
||||
# sentry-sdk
|
||||
uvicorn[standard]==0.38.0
|
||||
uvicorn[standard]==0.41.0
|
||||
# via teleop
|
||||
uvloop==0.22.1
|
||||
# via uvicorn
|
||||
virtualenv==20.35.3
|
||||
virtualenv==21.1.0
|
||||
# via pre-commit
|
||||
wandb==0.21.4
|
||||
wandb==0.24.2
|
||||
# via
|
||||
# hf-libero
|
||||
# lerobot
|
||||
# libero
|
||||
watchfiles==1.1.1
|
||||
# via uvicorn
|
||||
wcwidth==0.2.14
|
||||
wcwidth==0.6.0
|
||||
# via prompt-toolkit
|
||||
websocket-client==1.9.0
|
||||
# via teleop
|
||||
websockets==15.0.1
|
||||
websockets==16.0
|
||||
# via uvicorn
|
||||
werkzeug==3.1.3
|
||||
werkzeug==3.1.6
|
||||
# via tensorboard
|
||||
wrapt==2.0.0
|
||||
wrapt==2.1.2
|
||||
# via dm-tree
|
||||
xxhash==3.6.0
|
||||
# via datasets
|
||||
yarl==1.22.0
|
||||
yarl==1.23.0
|
||||
# via aiohttp
|
||||
zipp==3.23.0
|
||||
# via
|
||||
|
||||
+4
-4
@@ -1,9 +1,9 @@
|
||||
# requirements.in
|
||||
|
||||
# requirements-macos.txt was generated on macOS and is platform-specific (macOS 26.0.1 25A362 arm64).
|
||||
# Darwin MacBook-Pro.local 25.0.0 Darwin Kernel Version 25.0.0: Wed Sep 17 21:42:08 PDT 2025; root:xnu-12377.1.9~141/RELEASE_ARM64_T8132 arm64
|
||||
# requirements-macos.txt was generated on macOS and is platform-specific (macOS 26.3.1 25D2128 arm64).
|
||||
# Darwin MacBook-Pro.local 25.3.0 Darwin Kernel Version 25.3.0: Wed Jan 28 20:54:55 PST 2026; root:xnu-12377.91.3~2/RELEASE_ARM64_T8132 arm64
|
||||
|
||||
# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.3 LTS x86_64).
|
||||
# Linux mlerobot-linux 6.14.0-33-generic #33~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Sep 19 17:02:30 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
|
||||
# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.4 LTS x86_64).
|
||||
# Linux lerobot-linux 6.17.0-14-generic #14~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Thu Jan 15 15:52:10 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
|
||||
|
||||
-e .[all]
|
||||
|
||||
@@ -63,9 +63,9 @@ from lerobot.transport import (
|
||||
services_pb2_grpc, # type: ignore
|
||||
)
|
||||
from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
|
||||
from .configs import RobotClientConfig
|
||||
from .constants import SUPPORTED_ROBOTS
|
||||
from .helpers import (
|
||||
Action,
|
||||
FPSTracker,
|
||||
@@ -485,8 +485,9 @@ class RobotClient:
|
||||
def async_client(cfg: RobotClientConfig):
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
if cfg.robot.type not in SUPPORTED_ROBOTS:
|
||||
raise ValueError(f"Robot {cfg.robot.type} not yet supported!")
|
||||
# TODO: Assert if checking robot support is still needed with the plugin system
|
||||
# if cfg.robot.type not in SUPPORTED_ROBOTS:
|
||||
# raise ValueError(f"Robot {cfg.robot.type} not yet supported!")
|
||||
|
||||
client = RobotClient(cfg)
|
||||
|
||||
@@ -512,4 +513,5 @@ def async_client(cfg: RobotClientConfig):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
register_third_party_plugins()
|
||||
async_client() # run the client
|
||||
|
||||
@@ -181,7 +181,7 @@ class ZMQCamera(Camera):
|
||||
try:
|
||||
message = self.socket.recv_string()
|
||||
except Exception as e:
|
||||
# Check for ZMQ timeout (EAGAIN/Again) without requiring global zmq import
|
||||
# zmq is lazy-imported in connect(), so check by name to avoid a top-level import
|
||||
if type(e).__name__ == "Again":
|
||||
raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e
|
||||
raise
|
||||
|
||||
@@ -23,6 +23,7 @@ import base64
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
@@ -42,10 +43,57 @@ def encode_image(image: np.ndarray, quality: int = 80) -> str:
|
||||
return base64.b64encode(buffer).decode("utf-8")
|
||||
|
||||
|
||||
class CameraCaptureThread:
|
||||
"""Background thread that continuously captures and encodes frames from a camera."""
|
||||
|
||||
def __init__(self, camera: OpenCVCamera, name: str):
|
||||
self.camera = camera
|
||||
self.name = name
|
||||
self.latest_encoded: str | None = None # Pre-encoded JPEG as base64
|
||||
self.latest_timestamp: float = 0.0
|
||||
self.frame_lock = threading.Lock()
|
||||
self.running = False
|
||||
self.thread: threading.Thread | None = None
|
||||
|
||||
def start(self):
|
||||
"""Start the capture thread."""
|
||||
self.running = True
|
||||
self.thread = threading.Thread(target=self._capture_loop, daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
def stop(self):
|
||||
"""Stop the capture thread."""
|
||||
self.running = False
|
||||
if self.thread:
|
||||
self.thread.join(timeout=1.0)
|
||||
|
||||
def _capture_loop(self):
|
||||
"""Continuously capture and encode frames at the camera's native rate."""
|
||||
while self.running:
|
||||
try:
|
||||
frame = self.camera.read() # Blocks at camera's native rate
|
||||
timestamp = time.time()
|
||||
# Encode immediately in capture thread (this is the slow part)
|
||||
encoded = encode_image(frame)
|
||||
with self.frame_lock:
|
||||
self.latest_encoded = encoded
|
||||
self.latest_timestamp = timestamp
|
||||
except Exception as e:
|
||||
logger.warning(f"Camera {self.name} capture error: {e}")
|
||||
time.sleep(0.01)
|
||||
|
||||
def get_latest(self) -> tuple[str | None, float]:
|
||||
"""Get the latest encoded frame and its timestamp."""
|
||||
with self.frame_lock:
|
||||
return self.latest_encoded, self.latest_timestamp
|
||||
|
||||
|
||||
class ImageServer:
|
||||
def __init__(self, config: dict, port: int = 5555):
|
||||
# fps controls the publish loop rate (how often frames are sent over ZMQ), not the camera capture rate
|
||||
self.fps = config.get("fps", 30)
|
||||
self.cameras: dict[str, OpenCVCamera] = {}
|
||||
self.capture_threads: dict[str, CameraCaptureThread] = {}
|
||||
|
||||
for name, cfg in config.get("cameras", {}).items():
|
||||
shape = cfg.get("shape", [480, 640])
|
||||
@@ -61,6 +109,10 @@ class ImageServer:
|
||||
self.cameras[name] = camera
|
||||
logger.info(f"Camera {name}: {shape[1]}x{shape[0]}")
|
||||
|
||||
# Create capture thread for this camera
|
||||
capture_thread = CameraCaptureThread(camera, name)
|
||||
self.capture_threads[name] = capture_thread
|
||||
|
||||
# ZMQ PUB socket
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(zmq.PUB)
|
||||
@@ -73,6 +125,18 @@ class ImageServer:
|
||||
def run(self):
|
||||
frame_count = 0
|
||||
frame_times = deque(maxlen=60)
|
||||
last_published_ts: dict[str, float] = {}
|
||||
|
||||
# Start all capture threads
|
||||
for capture_thread in self.capture_threads.values():
|
||||
capture_thread.start()
|
||||
|
||||
# Wait for first frames to be captured and encoded
|
||||
logger.info("Waiting for cameras to start capturing...")
|
||||
for name, capture_thread in self.capture_threads.items():
|
||||
while capture_thread.get_latest()[0] is None:
|
||||
time.sleep(0.01)
|
||||
logger.info(f"Camera {name} ready (capture + encode in background)")
|
||||
|
||||
try:
|
||||
while True:
|
||||
@@ -80,10 +144,12 @@ class ImageServer:
|
||||
|
||||
# Build message
|
||||
message = {"timestamps": {}, "images": {}}
|
||||
for name, cam in self.cameras.items():
|
||||
frame = cam.read() # Returns RGB
|
||||
message["timestamps"][name] = time.time()
|
||||
message["images"][name] = encode_image(frame)
|
||||
for name, capture_thread in self.capture_threads.items():
|
||||
encoded, timestamp = capture_thread.get_latest()
|
||||
if encoded is not None and timestamp > last_published_ts.get(name, 0.0):
|
||||
message["timestamps"][name] = timestamp
|
||||
message["images"][name] = encoded
|
||||
last_published_ts[name] = timestamp
|
||||
|
||||
# Send as JSON string (suppress if buffer full)
|
||||
with contextlib.suppress(zmq.Again):
|
||||
@@ -102,6 +168,8 @@ class ImageServer:
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
for capture_thread in self.capture_threads.values():
|
||||
capture_thread.stop()
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
self.socket.close()
|
||||
|
||||
@@ -27,7 +27,7 @@ class DatasetConfig:
|
||||
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
|
||||
# datasets are provided.
|
||||
repo_id: str
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
root: str | None = None
|
||||
episodes: list[int] | None = None
|
||||
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
||||
|
||||
@@ -50,6 +50,9 @@ class TrainPipelineConfig(HubMixin):
|
||||
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
||||
# AND for the evaluation environments.
|
||||
seed: int | None = 1000
|
||||
# Set to True to use deterministic cuDNN algorithms for reproducibility.
|
||||
# This disables cudnn.benchmark and may reduce training speed by ~10-20%.
|
||||
cudnn_deterministic: bool = False
|
||||
# Number of workers for the dataloader.
|
||||
num_workers: int = 4
|
||||
batch_size: int = 8
|
||||
|
||||
@@ -289,7 +289,9 @@ def aggregate_datasets(
|
||||
|
||||
logging.info("Find all tasks")
|
||||
unique_tasks = pd.concat([m.tasks for m in all_metadata]).index.unique()
|
||||
dst_meta.tasks = pd.DataFrame({"task_index": range(len(unique_tasks))}, index=unique_tasks)
|
||||
dst_meta.tasks = pd.DataFrame(
|
||||
{"task_index": range(len(unique_tasks))}, index=pd.Index(unique_tasks, name="task")
|
||||
)
|
||||
|
||||
meta_idx = {"chunk": 0, "file": 0}
|
||||
data_idx = {"chunk": 0, "file": 0}
|
||||
|
||||
@@ -7,6 +7,13 @@
|
||||
|
||||
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
|
||||
|
||||
{% if repo_id is defined and repo_id %}
|
||||
<a class="flex" href="https://huggingface.co/spaces/lerobot/visualize_dataset?path={{ repo_id }}">
|
||||
<img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl.svg"/>
|
||||
<img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl-dark.svg"/>
|
||||
</a>
|
||||
{% endif %}
|
||||
|
||||
## Dataset Description
|
||||
|
||||
{{ dataset_description | default("", true) }}
|
||||
|
||||
@@ -89,8 +89,8 @@ def delete_episodes(
|
||||
Args:
|
||||
dataset: The source LeRobotDataset.
|
||||
episode_indices: List of episode indices to delete.
|
||||
output_dir: Directory to save the new dataset. If None, uses default location.
|
||||
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
|
||||
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||
"""
|
||||
if not episode_indices:
|
||||
raise ValueError("No episodes to delete")
|
||||
@@ -152,7 +152,7 @@ def split_dataset(
|
||||
dataset: The source LeRobotDataset to split.
|
||||
splits: Either a dict mapping split names to episode indices, or a dict mapping
|
||||
split names to fractions (must sum to <= 1.0).
|
||||
output_dir: Base directory for output datasets. If None, uses default location.
|
||||
output_dir: Root directory where the split datasets will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
|
||||
Examples:
|
||||
Split by specific episodes
|
||||
@@ -243,8 +243,8 @@ def merge_datasets(
|
||||
|
||||
Args:
|
||||
datasets: List of LeRobotDatasets to merge.
|
||||
output_repo_id: Repository ID for the merged dataset.
|
||||
output_dir: Directory to save the merged dataset. If None, uses default location.
|
||||
output_repo_id: Merged dataset identifier.
|
||||
output_dir: Root directory where the merged dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/output_repo_id.
|
||||
"""
|
||||
if not datasets:
|
||||
raise ValueError("No datasets to merge")
|
||||
@@ -288,8 +288,8 @@ def modify_features(
|
||||
dataset: The source LeRobotDataset.
|
||||
add_features: Optional dict mapping feature names to (feature_values, feature_info) tuples.
|
||||
remove_features: Optional feature name(s) to remove. Can be a single string or list.
|
||||
output_dir: Directory to save the new dataset. If None, uses default location.
|
||||
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
|
||||
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||
|
||||
Returns:
|
||||
New dataset with features modified.
|
||||
@@ -390,8 +390,8 @@ def add_features(
|
||||
Args:
|
||||
dataset: The source LeRobotDataset.
|
||||
features: Dictionary mapping feature names to (feature_values, feature_info) tuples.
|
||||
output_dir: Directory to save the new dataset. If None, uses default location.
|
||||
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
|
||||
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||
|
||||
Returns:
|
||||
New dataset with all features added.
|
||||
@@ -427,8 +427,8 @@ def remove_feature(
|
||||
Args:
|
||||
dataset: The source LeRobotDataset.
|
||||
feature_names: Name(s) of features to remove. Can be a single string or list.
|
||||
output_dir: Directory to save the new dataset. If None, uses default location.
|
||||
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
|
||||
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||
|
||||
Returns:
|
||||
New dataset with features removed.
|
||||
@@ -567,20 +567,22 @@ def _copy_and_reindex_data(
|
||||
def _keep_episodes_from_video_with_av(
|
||||
input_path: Path,
|
||||
output_path: Path,
|
||||
episodes_to_keep: list[tuple[float, float]],
|
||||
episodes_to_keep: list[tuple[int, int]],
|
||||
fps: float,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
) -> None:
|
||||
"""Keep only specified episodes from a video file using PyAV.
|
||||
|
||||
This function decodes frames from specified time ranges and re-encodes them with
|
||||
This function decodes frames from specified frame ranges and re-encodes them with
|
||||
properly reset timestamps to ensure monotonic progression.
|
||||
|
||||
Args:
|
||||
input_path: Source video file path.
|
||||
output_path: Destination video file path.
|
||||
episodes_to_keep: List of (start_time, end_time) tuples for episodes to keep.
|
||||
episodes_to_keep: List of (start_frame, end_frame) tuples for episodes to keep.
|
||||
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
|
||||
is inclusive and end_frame is exclusive.
|
||||
fps: Frame rate of the video.
|
||||
vcodec: Video codec to use for encoding.
|
||||
pix_fmt: Pixel format for output video.
|
||||
@@ -622,9 +624,10 @@ def _keep_episodes_from_video_with_av(
|
||||
|
||||
# Create set of (start, end) ranges for fast lookup.
|
||||
# Convert to a sorted list for efficient checking.
|
||||
time_ranges = sorted(episodes_to_keep)
|
||||
frame_ranges = sorted(episodes_to_keep)
|
||||
|
||||
# Track frame index for setting PTS and current range being processed.
|
||||
src_frame_count = 0
|
||||
frame_count = 0
|
||||
range_idx = 0
|
||||
|
||||
@@ -634,21 +637,20 @@ def _keep_episodes_from_video_with_av(
|
||||
if frame is None:
|
||||
continue
|
||||
|
||||
# Get frame timestamp.
|
||||
frame_time = float(frame.pts * frame.time_base) if frame.pts is not None else 0.0
|
||||
|
||||
# Check if frame is in any of our desired time ranges.
|
||||
# Check if frame is in any of our desired frame ranges.
|
||||
# Skip ranges that have already passed.
|
||||
while range_idx < len(time_ranges) and frame_time >= time_ranges[range_idx][1]:
|
||||
while range_idx < len(frame_ranges) and src_frame_count >= frame_ranges[range_idx][1]:
|
||||
range_idx += 1
|
||||
|
||||
# If we've passed all ranges, stop processing.
|
||||
if range_idx >= len(time_ranges):
|
||||
if range_idx >= len(frame_ranges):
|
||||
break
|
||||
|
||||
# Check if frame is in current range.
|
||||
start_ts, end_ts = time_ranges[range_idx]
|
||||
if frame_time < start_ts:
|
||||
start_frame = frame_ranges[range_idx][0]
|
||||
|
||||
if src_frame_count < start_frame:
|
||||
src_frame_count += 1
|
||||
continue
|
||||
|
||||
# Frame is in range - create a new frame with reset timestamps.
|
||||
@@ -661,6 +663,7 @@ def _keep_episodes_from_video_with_av(
|
||||
for pkt in v_out.encode(new_frame):
|
||||
out.mux(pkt)
|
||||
|
||||
src_frame_count += 1
|
||||
frame_count += 1
|
||||
|
||||
# Flush encoder.
|
||||
@@ -749,15 +752,17 @@ def _copy_and_reindex_videos(
|
||||
f"videos/{video_key}/to_timestamp"
|
||||
]
|
||||
else:
|
||||
# Build list of time ranges to keep, in sorted order.
|
||||
# Build list of frame ranges to keep, in sorted order.
|
||||
sorted_keep_episodes = sorted(episodes_in_file, key=lambda x: episode_mapping[x])
|
||||
episodes_to_keep_ranges: list[tuple[float, float]] = []
|
||||
|
||||
episodes_to_keep_ranges: list[tuple[int, int]] = []
|
||||
for old_idx in sorted_keep_episodes:
|
||||
src_ep = src_dataset.meta.episodes[old_idx]
|
||||
from_ts = src_ep[f"videos/{video_key}/from_timestamp"]
|
||||
to_ts = src_ep[f"videos/{video_key}/to_timestamp"]
|
||||
episodes_to_keep_ranges.append((from_ts, to_ts))
|
||||
from_frame = round(src_ep[f"videos/{video_key}/from_timestamp"] * src_dataset.meta.fps)
|
||||
to_frame = round(src_ep[f"videos/{video_key}/to_timestamp"] * src_dataset.meta.fps)
|
||||
assert src_ep["length"] == to_frame - from_frame, (
|
||||
f"Episode length mismatch: {src_ep['length']} vs {to_frame - from_frame}"
|
||||
)
|
||||
episodes_to_keep_ranges.append((from_frame, to_frame))
|
||||
|
||||
# Use PyAV filters to efficiently re-encode only the desired segments.
|
||||
assert src_dataset.meta.video_path is not None
|
||||
@@ -1470,7 +1475,9 @@ def modify_tasks(
|
||||
|
||||
# Collect all unique tasks and create new task mapping
|
||||
unique_tasks = sorted(set(episode_to_task.values()))
|
||||
new_task_df = pd.DataFrame({"task_index": list(range(len(unique_tasks)))}, index=unique_tasks)
|
||||
new_task_df = pd.DataFrame(
|
||||
{"task_index": list(range(len(unique_tasks)))}, index=pd.Index(unique_tasks, name="task")
|
||||
)
|
||||
task_to_index = {task: idx for idx, task in enumerate(unique_tasks)}
|
||||
|
||||
logging.info(f"Modifying tasks in {dataset.repo_id}")
|
||||
@@ -1524,7 +1531,7 @@ def modify_tasks(
|
||||
|
||||
def convert_image_to_video_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
output_dir: Path,
|
||||
output_dir: Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
@@ -1543,8 +1550,8 @@ def convert_image_to_video_dataset(
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobot dataset with images
|
||||
output_dir: Directory to save the new video dataset
|
||||
repo_id: Repository ID for the new dataset (default: original_id + "_video")
|
||||
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||
vcodec: Video codec (default: libsvtav1)
|
||||
pix_fmt: Pixel format (default: yuv420p)
|
||||
g: Group of pictures size (default: 2)
|
||||
@@ -1595,6 +1602,7 @@ def convert_image_to_video_dataset(
|
||||
# Video info will be updated after episodes are encoded
|
||||
|
||||
# Create new metadata for video dataset
|
||||
output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
|
||||
new_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
fps=dataset.meta.fps,
|
||||
|
||||
@@ -314,7 +314,7 @@ class LeRobotDatasetMetadata:
|
||||
if self.tasks is None:
|
||||
new_tasks = tasks
|
||||
task_indices = range(len(tasks))
|
||||
self.tasks = pd.DataFrame({"task_index": task_indices}, index=tasks)
|
||||
self.tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(tasks, name="task"))
|
||||
else:
|
||||
new_tasks = [task for task in tasks if task not in self.tasks.index]
|
||||
new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks))
|
||||
@@ -664,11 +664,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
for the README).
|
||||
|
||||
Args:
|
||||
repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
|
||||
will be stored under root/repo_id.
|
||||
root (Path | None, optional): Local directory to use for downloading/writing files. You can also
|
||||
set the HF_LEROBOT_HOME environment variable to point to a different location. Defaults to
|
||||
'~/.cache/huggingface/lerobot'.
|
||||
repo_id (str): This is the repo id that will be used to fetch the dataset.
|
||||
root (Path | None, optional): Local directory where the dataset will be downloaded and
|
||||
stored. If set, all dataset files will be stored directly under this path. If not set, the
|
||||
dataset files will be stored under $HF_LEROBOT_HOME/repo_id (configurable via the
|
||||
HF_LEROBOT_HOME environment variable).
|
||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||
their episode_index in this list. Defaults to None.
|
||||
image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
|
||||
@@ -747,7 +747,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Check if cached dataset contains all requested episodes
|
||||
if not self._check_cached_episodes_sufficient():
|
||||
raise FileNotFoundError("Cached dataset doesn't contain all requested episodes")
|
||||
except (AssertionError, FileNotFoundError, NotADirectoryError):
|
||||
except (FileNotFoundError, NotADirectoryError):
|
||||
if is_valid_version(self.revision):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
self.download(download_videos)
|
||||
@@ -839,7 +839,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
hub_api.upload_folder(**upload_kwargs)
|
||||
|
||||
card = create_lerobot_dataset_card(
|
||||
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
|
||||
tags=tags, dataset_info=self.meta.info, license=license, repo_id=self.repo_id, **card_kwargs
|
||||
)
|
||||
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
|
||||
|
||||
@@ -1771,11 +1771,12 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
|
||||
extra_keys = set(ds.features).difference(intersection_features)
|
||||
logging.warning(
|
||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
||||
"other datasets."
|
||||
)
|
||||
self.disabled_features.update(extra_keys)
|
||||
if extra_keys:
|
||||
logging.warning(
|
||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
||||
"other datasets."
|
||||
)
|
||||
self.disabled_features.update(extra_keys)
|
||||
|
||||
self.image_transforms = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
|
||||
@@ -21,7 +21,7 @@ from collections import deque
|
||||
from collections.abc import Iterable, Iterator
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Any, Generic, TypeVar
|
||||
from typing import Any
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
@@ -78,8 +78,6 @@ DEFAULT_FEATURES = {
|
||||
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
}
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float:
|
||||
metadata = pq.read_metadata(parquet_path)
|
||||
@@ -341,6 +339,7 @@ def write_tasks(tasks: pandas.DataFrame, local_dir: Path) -> None:
|
||||
|
||||
def load_tasks(local_dir: Path) -> pandas.DataFrame:
|
||||
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
|
||||
tasks.index.name = "task"
|
||||
return tasks
|
||||
|
||||
|
||||
@@ -1233,7 +1232,7 @@ class LookAheadError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Backtrackable(Generic[T]):
|
||||
class Backtrackable[T]:
|
||||
"""
|
||||
Wrap any iterator/iterable so you can step back up to `history` items
|
||||
and look ahead up to `lookahead` items.
|
||||
|
||||
@@ -36,8 +36,11 @@ Convert a local dataset (works in place):
|
||||
```bash
|
||||
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
|
||||
--repo-id=lerobot/pusht \
|
||||
--root=/path/to/local/dataset/directory
|
||||
--root=/path/to/local/dataset/directory \
|
||||
--push-to-hub=false
|
||||
|
||||
N.B. Path semantics (v2): --root is the exact dataset folder containing
|
||||
meta/, data/, videos/. When omitted, defaults to $HF_LEROBOT_HOME/{repo_id}.
|
||||
```
|
||||
|
||||
"""
|
||||
@@ -105,7 +108,7 @@ episodes.jsonl
|
||||
{"episode_index": 1, "tasks": ["Put the blue block in the green bowl"], "length": 266}
|
||||
|
||||
NEW
|
||||
meta/episodes/chunk-000/episodes_000.parquet
|
||||
meta/episodes/chunk-000/file_000.parquet
|
||||
episode_index | video_chunk_index | video_file_index | data_chunk_index | data_file_index | tasks | length
|
||||
-------------------------
|
||||
OLD
|
||||
@@ -113,15 +116,16 @@ tasks.jsonl
|
||||
{"task_index": 1, "task": "Put the blue block in the green bowl"}
|
||||
|
||||
NEW
|
||||
meta/tasks/chunk-000/file_000.parquet
|
||||
meta/tasks.parquet
|
||||
task_index | task
|
||||
-------------------------
|
||||
OLD
|
||||
episodes_stats.jsonl
|
||||
{"episode_index": 1, "stats": {"feature_name": {"min": ..., "max": ..., "mean": ..., "std": ..., "count": ...}}}
|
||||
|
||||
NEW
|
||||
meta/episodes_stats/chunk-000/file_000.parquet
|
||||
episode_index | mean | std | min | max
|
||||
meta/episodes/chunk-000/file_000.parquet
|
||||
episode_index | feature_name/min | feature_name/max | feature_name/mean | feature_name/std | feature_name/count
|
||||
-------------------------
|
||||
UPDATE
|
||||
meta/info.json
|
||||
@@ -170,7 +174,7 @@ def convert_tasks(root, new_root):
|
||||
tasks, _ = legacy_load_tasks(root)
|
||||
task_indices = tasks.keys()
|
||||
task_strings = tasks.values()
|
||||
df_tasks = pd.DataFrame({"task_index": task_indices}, index=task_strings)
|
||||
df_tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(task_strings, name="task"))
|
||||
write_tasks(df_tasks, new_root)
|
||||
|
||||
|
||||
@@ -201,7 +205,6 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
|
||||
|
||||
image_keys = get_image_keys(root)
|
||||
|
||||
ep_idx = 0
|
||||
chunk_idx = 0
|
||||
file_idx = 0
|
||||
size_in_mb = 0
|
||||
@@ -211,9 +214,23 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
|
||||
|
||||
logging.info(f"Converting data files from {len(ep_paths)} episodes")
|
||||
|
||||
for ep_path in tqdm.tqdm(ep_paths, desc="convert data files"):
|
||||
for ep_idx, ep_path in enumerate(tqdm.tqdm(ep_paths, desc="convert data files")):
|
||||
ep_size_in_mb = get_parquet_file_size_in_mb(ep_path)
|
||||
ep_num_frames = get_parquet_num_frames(ep_path)
|
||||
|
||||
# Check if we need to start a new file BEFORE creating metadata
|
||||
if size_in_mb + ep_size_in_mb >= data_file_size_in_mb and len(paths_to_cat) > 0:
|
||||
# Write the accumulated data files
|
||||
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
|
||||
|
||||
# Move to next file
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
|
||||
|
||||
# Reset for the next file
|
||||
size_in_mb = 0
|
||||
paths_to_cat = []
|
||||
|
||||
# Now create metadata with correct chunk/file indices
|
||||
ep_metadata = {
|
||||
"episode_index": ep_idx,
|
||||
"data/chunk_index": chunk_idx,
|
||||
@@ -224,20 +241,7 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
|
||||
size_in_mb += ep_size_in_mb
|
||||
num_frames += ep_num_frames
|
||||
episodes_metadata.append(ep_metadata)
|
||||
ep_idx += 1
|
||||
|
||||
if size_in_mb < data_file_size_in_mb:
|
||||
paths_to_cat.append(ep_path)
|
||||
continue
|
||||
|
||||
if paths_to_cat:
|
||||
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
|
||||
|
||||
# Reset for the next file
|
||||
size_in_mb = ep_size_in_mb
|
||||
paths_to_cat = [ep_path]
|
||||
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
|
||||
paths_to_cat.append(ep_path)
|
||||
|
||||
# Write remaining data if any
|
||||
if paths_to_cat:
|
||||
@@ -469,7 +473,7 @@ def convert_dataset(
|
||||
|
||||
# Set root based on whether local dataset path is provided
|
||||
use_local_dataset = False
|
||||
root = HF_LEROBOT_HOME / repo_id if root is None else Path(root) / repo_id
|
||||
root = HF_LEROBOT_HOME / repo_id if root is None else Path(root)
|
||||
if root.exists():
|
||||
validate_local_dataset_version(root)
|
||||
use_local_dataset = True
|
||||
@@ -553,7 +557,7 @@ if __name__ == "__main__":
|
||||
"--root",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Local directory to use for downloading/writing the dataset.",
|
||||
help="Local directory to use for downloading/writing the dataset. Defaults to $HF_LEROBOT_HOME/repo_id.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
|
||||
@@ -227,16 +227,17 @@ def decode_video_frames_torchvision(
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
is_within_tol = min_ < tolerance_s
|
||||
assert is_within_tol.all(), (
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
"It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
"To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
f"\nbackend: {backend}"
|
||||
)
|
||||
if not is_within_tol.all():
|
||||
raise FrameTimestampError(
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
" It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
" This might be due to synchronization issues with timestamps during data collection."
|
||||
" To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
f"\nbackend: {backend}"
|
||||
)
|
||||
|
||||
# get closest frames to the query timestamps
|
||||
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
||||
@@ -248,7 +249,11 @@ def decode_video_frames_torchvision(
|
||||
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
||||
closest_frames = closest_frames.type(torch.float32) / 255
|
||||
|
||||
assert len(timestamps) == len(closest_frames)
|
||||
if len(timestamps) != len(closest_frames):
|
||||
raise FrameTimestampError(
|
||||
f"Number of retrieved frames ({len(closest_frames)}) does not match "
|
||||
f"number of queried timestamps ({len(timestamps)})"
|
||||
)
|
||||
return closest_frames
|
||||
|
||||
|
||||
@@ -353,15 +358,16 @@ def decode_video_frames_torchcodec(
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
is_within_tol = min_ < tolerance_s
|
||||
assert is_within_tol.all(), (
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
"It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
"To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
)
|
||||
if not is_within_tol.all():
|
||||
raise FrameTimestampError(
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
" It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
" This might be due to synchronization issues with timestamps during data collection."
|
||||
" To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
)
|
||||
|
||||
# get closest frames to the query timestamps
|
||||
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
||||
|
||||
@@ -346,6 +346,65 @@ class LiberoEnv(EnvConfig):
|
||||
return kwargs
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("libero_plus")
|
||||
@dataclass
|
||||
class LiberoPlusEnv(LiberoEnv):
|
||||
"""Alias config for LIBERO-plus benchmarks.
|
||||
|
||||
LIBERO-plus keeps the same Python package/module names as LIBERO, so this
|
||||
config reuses the existing LIBERO env implementation while making intent explicit
|
||||
in experiment configs (`env.type=libero_plus`).
|
||||
"""
|
||||
|
||||
task: str = "libero_spatial"
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("robocasa")
|
||||
@dataclass
|
||||
class RoboCasaEnv(EnvConfig):
|
||||
"""RoboCasa kitchen composite-task environments.
|
||||
|
||||
Wraps ``robocasa.wrappers.gym_wrapper.RoboCasaGymEnv`` with a flat 12-D Box
|
||||
action space and a structured pixel + state observation dict.
|
||||
|
||||
Selected benchmark tasks (3 short + 2 long):
|
||||
Short: PickPlaceCounterToCabinet, PrepareToast, CoffeeSetupMug
|
||||
Long: PrepareCoffee, RestockPantry
|
||||
"""
|
||||
|
||||
task: str = "PickPlaceCounterToCabinet"
|
||||
tasks: list[str] | None = None # multi-task: list of task names (without robocasa/ prefix)
|
||||
fps: int = 20
|
||||
episode_length: int = 500
|
||||
image_size: int = 128
|
||||
split: str = "target" # "pretrain" or "target"
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(12,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: ACTION,
|
||||
"agentview_left": f"{OBS_IMAGES}.agentview_left",
|
||||
"agentview_right": f"{OBS_IMAGES}.agentview_right",
|
||||
"eye_in_hand": f"{OBS_IMAGES}.eye_in_hand",
|
||||
"robot_state": OBS_STATE,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
for cam in ("agentview_left", "agentview_right", "eye_in_hand"):
|
||||
self.features[cam] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(self.image_size, self.image_size, 3)
|
||||
)
|
||||
self.features["robot_state"] = PolicyFeature(type=FeatureType.STATE, shape=(16,))
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {"split": self.split}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("metaworld")
|
||||
@dataclass
|
||||
class MetaworldEnv(EnvConfig):
|
||||
|
||||
@@ -20,11 +20,20 @@ import gymnasium as gym
|
||||
from gymnasium.envs.registration import registry as gym_registry
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, HubEnvConfig, IsaaclabArenaEnv, LiberoEnv, PushtEnv
|
||||
from lerobot.envs.configs import (
|
||||
AlohaEnv,
|
||||
EnvConfig,
|
||||
HubEnvConfig,
|
||||
IsaaclabArenaEnv,
|
||||
LiberoEnv,
|
||||
LiberoPlusEnv,
|
||||
PushtEnv,
|
||||
RoboCasaEnv,
|
||||
)
|
||||
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
|
||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.processor import ProcessorStep
|
||||
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep
|
||||
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep, RoboCasaProcessorStep
|
||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||
|
||||
|
||||
@@ -35,6 +44,10 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
return PushtEnv(**kwargs)
|
||||
elif env_type == "libero":
|
||||
return LiberoEnv(**kwargs)
|
||||
elif env_type == "libero_plus":
|
||||
return LiberoPlusEnv(**kwargs)
|
||||
elif env_type == "robocasa":
|
||||
return RoboCasaEnv(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
||||
|
||||
@@ -70,9 +83,13 @@ def make_env_pre_post_processors(
|
||||
return make_xvla_libero_pre_post_processors()
|
||||
|
||||
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
|
||||
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
|
||||
if isinstance(env_cfg, (LiberoEnv, LiberoPlusEnv)) or "libero" in env_cfg.type:
|
||||
preprocessor_steps.append(LiberoProcessorStep())
|
||||
|
||||
# For RoboCasa environments, add the RoboCasaProcessorStep to preprocessor
|
||||
if isinstance(env_cfg, RoboCasaEnv) or "robocasa" in env_cfg.type:
|
||||
preprocessor_steps.append(RoboCasaProcessorStep())
|
||||
|
||||
# For Isaaclab Arena environments, add the IsaaclabArenaProcessorStep
|
||||
if isinstance(env_cfg, IsaaclabArenaEnv) or "isaaclab_arena" in env_cfg.type:
|
||||
# Parse comma-separated keys (handle None for state-based policies)
|
||||
@@ -181,6 +198,20 @@ def make_env(
|
||||
control_mode=cfg.control_mode,
|
||||
episode_length=cfg.episode_length,
|
||||
)
|
||||
elif "robocasa" in cfg.type:
|
||||
from lerobot.envs.robocasa import create_robocasa_envs
|
||||
|
||||
tasks = cfg.tasks if cfg.tasks else [cfg.task]
|
||||
return create_robocasa_envs(
|
||||
tasks=tasks,
|
||||
n_envs=n_envs,
|
||||
image_size=cfg.image_size,
|
||||
split=cfg.split,
|
||||
episode_length=cfg.episode_length,
|
||||
gym_kwargs=cfg.gym_kwargs,
|
||||
env_cls=env_cls,
|
||||
)
|
||||
|
||||
elif "metaworld" in cfg.type:
|
||||
from lerobot.envs.metaworld import create_metaworld_envs
|
||||
|
||||
|
||||
@@ -26,8 +26,14 @@ import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from gymnasium import spaces
|
||||
from libero.libero import benchmark, get_libero_path
|
||||
from libero.libero.envs import OffScreenRenderEnv
|
||||
|
||||
try:
|
||||
from libero.libero import benchmark, get_libero_path
|
||||
from libero.libero.envs import OffScreenRenderEnv
|
||||
except ImportError:
|
||||
# LIBERO-plus may be installed from source with an extra nested package level.
|
||||
from libero.libero.libero import benchmark, get_libero_path
|
||||
from libero.libero.libero.envs import OffScreenRenderEnv
|
||||
|
||||
from lerobot.processor import RobotObservation
|
||||
|
||||
|
||||
@@ -0,0 +1,273 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from gymnasium import spaces
|
||||
|
||||
|
||||
# Action layout (flat 12D, normalized to [-1, 1]):
|
||||
# [0:3] end_effector_position (delta x, y, z)
|
||||
# [3:6] end_effector_rotation (delta roll, pitch, yaw)
|
||||
# [6:7] gripper_close (open=-1, close=+1)
|
||||
# [7:11] base_motion (x, y, theta, torso_height)
|
||||
# [11:12] control_mode (arm=-1, base=+1)
|
||||
ACTION_DIM = 12
|
||||
ACTION_LOW = -1.0
|
||||
ACTION_HIGH = 1.0
|
||||
|
||||
# Proprioceptive state layout (flat 16D):
|
||||
# [0:2] gripper_qpos
|
||||
# [2:5] base_position
|
||||
# [5:9] base_rotation (quaternion)
|
||||
# [9:12] end_effector_position_relative
|
||||
# [12:16] end_effector_rotation_relative (quaternion)
|
||||
STATE_DIM = 16
|
||||
|
||||
# Obs dict keys from RoboCasaGymEnv.get_observation()
|
||||
_CAM_KEYS = (
|
||||
"video.robot0_agentview_left",
|
||||
"video.robot0_agentview_right",
|
||||
"video.robot0_eye_in_hand",
|
||||
)
|
||||
_STATE_KEYS_ORDERED = (
|
||||
"state.gripper_qpos", # (2,)
|
||||
"state.base_position", # (3,)
|
||||
"state.base_rotation", # (4,)
|
||||
"state.end_effector_position_relative", # (3,)
|
||||
"state.end_effector_rotation_relative", # (4,)
|
||||
)
|
||||
|
||||
# Mapping from video.* key → short image name used in features_map
|
||||
CAM_KEY_TO_NAME = {
|
||||
"video.robot0_agentview_left": "agentview_left",
|
||||
"video.robot0_agentview_right": "agentview_right",
|
||||
"video.robot0_eye_in_hand": "eye_in_hand",
|
||||
}
|
||||
|
||||
|
||||
def _flat_to_action_dict(flat: np.ndarray) -> dict[str, np.ndarray]:
|
||||
"""Convert a 12D flat action array to the Dict format expected by RoboCasaGymEnv."""
|
||||
return {
|
||||
"action.end_effector_position": flat[0:3],
|
||||
"action.end_effector_rotation": flat[3:6],
|
||||
"action.gripper_close": flat[6:7],
|
||||
"action.base_motion": flat[7:11],
|
||||
"action.control_mode": flat[11:12],
|
||||
}
|
||||
|
||||
|
||||
class RoboCasaEnv(gym.Env):
|
||||
"""Thin wrapper around RoboCasaGymEnv that provides a flat Box action space
|
||||
and a structured observation dict compatible with LeRobot policies.
|
||||
|
||||
Observations returned by step/reset:
|
||||
{
|
||||
"pixels": {
|
||||
"agentview_left": (H, W, 3) uint8,
|
||||
"agentview_right": (H, W, 3) uint8,
|
||||
"eye_in_hand": (H, W, 3) uint8,
|
||||
},
|
||||
"robot_state": (16,) float32,
|
||||
}
|
||||
|
||||
Actions: flat float32 ndarray of shape (12,), normalized to [-1, 1].
|
||||
"""
|
||||
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 20}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task: str,
|
||||
split: str = "target",
|
||||
image_size: int = 128,
|
||||
render_mode: str = "rgb_array",
|
||||
episode_length: int = 500,
|
||||
**gym_kwargs: Any,
|
||||
):
|
||||
super().__init__()
|
||||
# Lazy import — robocasa is optional
|
||||
import robocasa.environments # noqa: F401 — registers all gym envs
|
||||
|
||||
self.task = task
|
||||
self.render_mode = render_mode
|
||||
self.image_size = image_size
|
||||
self._max_episode_steps = episode_length
|
||||
self._step_count = 0
|
||||
|
||||
self._env = gym.make(
|
||||
f"robocasa/{task}",
|
||||
split=split,
|
||||
camera_widths=image_size,
|
||||
camera_heights=image_size,
|
||||
**gym_kwargs,
|
||||
)
|
||||
|
||||
# Flat 12D Box action space
|
||||
self.action_space = spaces.Box(
|
||||
low=ACTION_LOW,
|
||||
high=ACTION_HIGH,
|
||||
shape=(ACTION_DIM,),
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
images = {
|
||||
name: spaces.Box(low=0, high=255, shape=(image_size, image_size, 3), dtype=np.uint8)
|
||||
for name in CAM_KEY_TO_NAME.values()
|
||||
}
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
"pixels": spaces.Dict(images),
|
||||
"robot_state": spaces.Box(
|
||||
low=-np.inf, high=np.inf, shape=(STATE_DIM,), dtype=np.float32
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
def _format_obs(self, raw_obs: dict) -> dict:
|
||||
pixels = {
|
||||
CAM_KEY_TO_NAME[k]: raw_obs[k]
|
||||
for k in _CAM_KEYS
|
||||
if k in raw_obs
|
||||
}
|
||||
state_parts = [
|
||||
np.asarray(raw_obs[k], dtype=np.float32)
|
||||
for k in _STATE_KEYS_ORDERED
|
||||
if k in raw_obs
|
||||
]
|
||||
robot_state = np.concatenate(state_parts) if state_parts else np.zeros(STATE_DIM, dtype=np.float32)
|
||||
return {"pixels": pixels, "robot_state": robot_state}
|
||||
|
||||
def reset(self, seed: int | None = None, **kwargs) -> tuple[dict, dict]:
|
||||
super().reset(seed=seed)
|
||||
self._step_count = 0
|
||||
raw_obs, info = self._env.reset(seed=seed)
|
||||
info.setdefault("is_success", False)
|
||||
info["task"] = self.task
|
||||
return self._format_obs(raw_obs), info
|
||||
|
||||
def step(self, action: np.ndarray) -> tuple[dict, float, bool, bool, dict]:
|
||||
if action.ndim != 1 or action.shape[0] != ACTION_DIM:
|
||||
raise ValueError(
|
||||
f"Expected 1-D action of shape ({ACTION_DIM},), got {action.shape}"
|
||||
)
|
||||
action_dict = _flat_to_action_dict(action)
|
||||
raw_obs, reward, terminated, truncated, info = self._env.step(action_dict)
|
||||
self._step_count += 1
|
||||
|
||||
is_success = bool(info.get("success", False))
|
||||
terminated = terminated or is_success
|
||||
if self._step_count >= self._max_episode_steps:
|
||||
truncated = True
|
||||
|
||||
info.update({"task": self.task, "is_success": is_success})
|
||||
obs = self._format_obs(raw_obs)
|
||||
|
||||
if terminated or truncated:
|
||||
info["final_info"] = {"task": self.task, "is_success": is_success}
|
||||
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
def render(self) -> np.ndarray | None:
|
||||
if self.render_mode == "rgb_array":
|
||||
return self._env.render()
|
||||
return None
|
||||
|
||||
def close(self) -> None:
|
||||
self._env.close()
|
||||
|
||||
|
||||
def _make_env_fns(
|
||||
*,
|
||||
task: str,
|
||||
n_envs: int,
|
||||
image_size: int,
|
||||
split: str,
|
||||
episode_length: int,
|
||||
gym_kwargs: dict[str, Any],
|
||||
) -> list[Callable[[], RoboCasaEnv]]:
|
||||
"""Build n_envs factory callables for a single task."""
|
||||
def _make(episode_index: int) -> RoboCasaEnv: # noqa: ARG001
|
||||
return RoboCasaEnv(
|
||||
task=task,
|
||||
split=split,
|
||||
image_size=image_size,
|
||||
episode_length=episode_length,
|
||||
**gym_kwargs,
|
||||
)
|
||||
|
||||
return [partial(_make, i) for i in range(n_envs)]
|
||||
|
||||
|
||||
def create_robocasa_envs(
|
||||
tasks: str | Sequence[str],
|
||||
n_envs: int,
|
||||
image_size: int = 128,
|
||||
split: str = "target",
|
||||
episode_length: int = 500,
|
||||
gym_kwargs: dict[str, Any] | None = None,
|
||||
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
|
||||
) -> dict[str, dict[int, Any]]:
|
||||
"""Create vectorized RoboCasa environments.
|
||||
|
||||
Args:
|
||||
tasks: A single task name or list of task names (without "robocasa/" prefix).
|
||||
E.g. "PickPlaceCounterToCabinet" or ["BoilPot", "PrepareCoffee"].
|
||||
n_envs: Number of parallel envs per task.
|
||||
image_size: Square image resolution for all cameras.
|
||||
split: RoboCasa dataset split — "pretrain" or "target".
|
||||
episode_length: Max steps per episode before truncation.
|
||||
gym_kwargs: Extra kwargs forwarded to each RoboCasaEnv.
|
||||
env_cls: Callable to wrap list of factory fns (SyncVectorEnv or AsyncVectorEnv).
|
||||
|
||||
Returns:
|
||||
dict[task_name][task_id=0] -> vec_env
|
||||
"""
|
||||
if env_cls is None or not callable(env_cls):
|
||||
raise ValueError("env_cls must be a callable wrapping a list of env factory callables.")
|
||||
if not isinstance(n_envs, int) or n_envs <= 0:
|
||||
raise ValueError(f"n_envs must be a positive int; got {n_envs}.")
|
||||
|
||||
if isinstance(tasks, str):
|
||||
task_list = [t.strip() for t in tasks.split(",") if t.strip()]
|
||||
else:
|
||||
task_list = [str(t).strip() for t in tasks if str(t).strip()]
|
||||
if not task_list:
|
||||
raise ValueError("`tasks` must contain at least one task name.")
|
||||
|
||||
gym_kwargs = dict(gym_kwargs or {})
|
||||
out: dict[str, dict[int, Any]] = defaultdict(dict)
|
||||
|
||||
print(f"Creating RoboCasa envs | tasks={task_list} | n_envs(per task)={n_envs} | split={split}")
|
||||
for task in task_list:
|
||||
fns = _make_env_fns(
|
||||
task=task,
|
||||
n_envs=n_envs,
|
||||
image_size=image_size,
|
||||
split=split,
|
||||
episode_length=episode_length,
|
||||
gym_kwargs=gym_kwargs,
|
||||
)
|
||||
out["robocasa"][len(out["robocasa"])] = env_cls(fns)
|
||||
print(f" Built vec env | task={task} | n_envs={n_envs}")
|
||||
|
||||
return {suite: dict(task_map) for suite, task_map in out.items()}
|
||||
@@ -29,7 +29,7 @@ from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from functools import cached_property
|
||||
from pprint import pformat
|
||||
from typing import Protocol, TypeAlias
|
||||
from typing import Protocol
|
||||
|
||||
import serial
|
||||
from deepdiff import DeepDiff
|
||||
@@ -38,8 +38,8 @@ from tqdm import tqdm
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
||||
|
||||
NameOrID: TypeAlias = str | int
|
||||
Value: TypeAlias = int | float
|
||||
type NameOrID = str | int
|
||||
type Value = int | float
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1277,4 +1277,4 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
|
||||
|
||||
# Backward compatibility alias
|
||||
MotorsBus: TypeAlias = SerialMotorsBus
|
||||
MotorsBus = SerialMotorsBus
|
||||
|
||||
@@ -55,10 +55,16 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||
within the image size. If None, no cropping is done.
|
||||
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
|
||||
mode).
|
||||
resize_shape: (H, W) shape to resize images to as a preprocessing step for the vision
|
||||
backbone. If None, no resizing is done and the original image resolution is used.
|
||||
crop_ratio: Ratio in (0, 1] used to derive the crop size from resize_shape
|
||||
(crop_h = int(resize_shape[0] * crop_ratio), likewise for width).
|
||||
Set to 1.0 to disable cropping. Only takes effect when resize_shape is not None.
|
||||
crop_shape: (H, W) shape to crop images to. When resize_shape is set and crop_ratio < 1.0,
|
||||
this is computed automatically. Can also be set directly for legacy configs that use
|
||||
crop-only (without resize). If None and no derivation applies, no cropping is done.
|
||||
crop_is_random: Whether the crop should be random at training time (it's always a center
|
||||
crop in eval mode).
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
|
||||
`None` means no pretrained weights.
|
||||
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||
@@ -114,7 +120,9 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
vision_backbone: str = "resnet18"
|
||||
crop_shape: tuple[int, int] | None = (84, 84)
|
||||
resize_shape: tuple[int, int] | None = None
|
||||
crop_ratio: float = 1.0
|
||||
crop_shape: tuple[int, int] | None = None
|
||||
crop_is_random: bool = True
|
||||
pretrained_backbone_weights: str | None = None
|
||||
use_group_norm: bool = True
|
||||
@@ -139,6 +147,10 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
# Inference
|
||||
num_inference_steps: int | None = None
|
||||
|
||||
# Optimization
|
||||
compile_model: bool = False
|
||||
compile_mode: str = "reduce-overhead"
|
||||
|
||||
# Loss computation
|
||||
do_mask_loss_for_padding: bool = False
|
||||
|
||||
@@ -171,6 +183,25 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
f"Got {self.noise_scheduler_type}."
|
||||
)
|
||||
|
||||
if self.resize_shape is not None and (
|
||||
len(self.resize_shape) != 2 or any(d <= 0 for d in self.resize_shape)
|
||||
):
|
||||
raise ValueError(f"`resize_shape` must be a pair of positive integers. Got {self.resize_shape}.")
|
||||
if not (0 < self.crop_ratio <= 1.0):
|
||||
raise ValueError(f"`crop_ratio` must be in (0, 1]. Got {self.crop_ratio}.")
|
||||
|
||||
if self.resize_shape is not None:
|
||||
if self.crop_ratio < 1.0:
|
||||
self.crop_shape = (
|
||||
int(self.resize_shape[0] * self.crop_ratio),
|
||||
int(self.resize_shape[1] * self.crop_ratio),
|
||||
)
|
||||
else:
|
||||
# Explicitly disable cropping for resize+ratio path when crop_ratio == 1.0.
|
||||
self.crop_shape = None
|
||||
if self.crop_shape is not None and (self.crop_shape[0] <= 0 or self.crop_shape[1] <= 0):
|
||||
raise ValueError(f"`crop_shape` must have positive dimensions. Got {self.crop_shape}.")
|
||||
|
||||
# Check that the horizon size and U-Net downsampling is compatible.
|
||||
# U-Net downsamples by 2 with each stage.
|
||||
downsampling_factor = 2 ** len(self.down_dims)
|
||||
@@ -198,13 +229,12 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
if len(self.image_features) == 0 and self.env_state_feature is None:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
if self.crop_shape is not None:
|
||||
if self.resize_shape is None and self.crop_shape is not None:
|
||||
for key, image_ft in self.image_features.items():
|
||||
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {image_ft.shape} for "
|
||||
f"`{key}`."
|
||||
f"`crop_shape` should fit within the image shapes. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {image_ft.shape} for `{key}`."
|
||||
)
|
||||
|
||||
# Check that all input images have the same shape.
|
||||
|
||||
@@ -142,6 +142,9 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
for key in self.config.image_features:
|
||||
if self.config.n_obs_steps == 1 and batch[key].ndim == 4:
|
||||
batch[key] = batch[key].unsqueeze(1)
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
# no output_dict so returning None
|
||||
@@ -182,6 +185,11 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
|
||||
|
||||
if config.compile_model:
|
||||
# Compile the U-Net. "reduce-overhead" is preferred for the small-batch repetitive loops
|
||||
# common in diffusion inference.
|
||||
self.unet = torch.compile(self.unet, mode=config.compile_mode)
|
||||
|
||||
self.noise_scheduler = _make_noise_scheduler(
|
||||
config.noise_scheduler_type,
|
||||
num_train_timesteps=config.num_train_timesteps,
|
||||
@@ -446,12 +454,18 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
def __init__(self, config: DiffusionConfig):
|
||||
super().__init__()
|
||||
# Set up optional preprocessing.
|
||||
if config.crop_shape is not None:
|
||||
if config.resize_shape is not None:
|
||||
self.resize = torchvision.transforms.Resize(config.resize_shape)
|
||||
else:
|
||||
self.resize = None
|
||||
|
||||
crop_shape = config.crop_shape
|
||||
if crop_shape is not None:
|
||||
self.do_crop = True
|
||||
# Always use center crop for eval
|
||||
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
|
||||
self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
|
||||
if config.crop_is_random:
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape)
|
||||
else:
|
||||
self.maybe_random_crop = self.center_crop
|
||||
else:
|
||||
@@ -477,13 +491,16 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
|
||||
# Set up pooling and final layers.
|
||||
# Use a dry run to get the feature map shape.
|
||||
# The dummy input should take the number of image channels from `config.image_features` and it should
|
||||
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
|
||||
# height and width from `config.image_features`.
|
||||
# The dummy shape mirrors the runtime preprocessing order: resize -> crop.
|
||||
|
||||
# Note: we have a check in the config class to make sure all images have the same shape.
|
||||
images_shape = next(iter(config.image_features.values())).shape
|
||||
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
|
||||
if config.crop_shape is not None:
|
||||
dummy_shape_h_w = config.crop_shape
|
||||
elif config.resize_shape is not None:
|
||||
dummy_shape_h_w = config.resize_shape
|
||||
else:
|
||||
dummy_shape_h_w = images_shape[1:]
|
||||
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
|
||||
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
|
||||
|
||||
@@ -499,7 +516,10 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
Returns:
|
||||
(B, D) image feature.
|
||||
"""
|
||||
# Preprocess: maybe crop (if it was set up in the __init__).
|
||||
# Preprocess: resize if configured, then crop if configured.
|
||||
|
||||
if self.resize is not None:
|
||||
x = self.resize(x)
|
||||
if self.do_crop:
|
||||
if self.training: # noqa: SIM108
|
||||
x = self.maybe_random_crop(x)
|
||||
|
||||
@@ -18,10 +18,9 @@ from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, TypedDict, Unpack
|
||||
|
||||
import torch
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType
|
||||
|
||||
@@ -4,17 +4,16 @@
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
|
||||
from typing import Optional
|
||||
|
||||
from transformers.image_processing_utils import (
|
||||
BatchFeature,
|
||||
get_patch_output_size,
|
||||
)
|
||||
from transformers.image_processing_utils_fast import (
|
||||
BaseImageProcessorFast,
|
||||
DefaultFastImageProcessorKwargs,
|
||||
ImagesKwargs,
|
||||
group_images_by_shape,
|
||||
reorder_images,
|
||||
)
|
||||
@@ -77,7 +76,7 @@ def crop(img: torch.Tensor, left: int, top: int, right: int, bottom: int) -> tor
|
||||
return img[:, top:bottom, left:right]
|
||||
|
||||
|
||||
class Eagle25VLFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
|
||||
class Eagle25VLFastImageProcessorKwargs(ImagesKwargs):
|
||||
max_dynamic_tiles: int | None
|
||||
min_dynamic_tiles: int | None
|
||||
use_thumbnail: bool | None
|
||||
@@ -165,11 +164,11 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
||||
|
||||
def _resize_for_patching(
|
||||
self,
|
||||
image: "torch.Tensor",
|
||||
image: torch.Tensor,
|
||||
target_resolution: tuple,
|
||||
interpolation: "F.InterpolationMode",
|
||||
interpolation: F.InterpolationMode,
|
||||
input_data_format: ChannelDimension,
|
||||
) -> "torch.Tensor":
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Resizes an image to a target resolution while maintaining aspect ratio.
|
||||
|
||||
@@ -219,8 +218,8 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
||||
return best_ratio
|
||||
|
||||
def _pad_for_patching(
|
||||
self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension
|
||||
) -> "torch.Tensor":
|
||||
self, image: torch.Tensor, target_resolution: tuple, input_data_format: ChannelDimension
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Pad an image to a target resolution while maintaining aspect ratio.
|
||||
"""
|
||||
@@ -236,15 +235,15 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
||||
|
||||
def _get_image_patches(
|
||||
self,
|
||||
image: "torch.Tensor",
|
||||
image: torch.Tensor,
|
||||
min_num: int,
|
||||
max_num: int,
|
||||
size: tuple,
|
||||
tile_size: int,
|
||||
use_thumbnail: bool,
|
||||
interpolation: "F.InterpolationMode",
|
||||
interpolation: F.InterpolationMode,
|
||||
pad_during_tiling: bool,
|
||||
) -> list["torch.Tensor"]:
|
||||
) -> list[torch.Tensor]:
|
||||
image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
|
||||
orig_height, orig_width = image_size
|
||||
aspect_ratio = orig_width / orig_height
|
||||
@@ -305,8 +304,8 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
||||
|
||||
def _pad_for_batching(
|
||||
self,
|
||||
pixel_values: list["torch.Tensor"],
|
||||
) -> list["torch.Tensor"]:
|
||||
pixel_values: list[torch.Tensor],
|
||||
) -> list[torch.Tensor]:
|
||||
"""
|
||||
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
|
||||
|
||||
@@ -327,14 +326,14 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: list["torch.Tensor"],
|
||||
images: list[torch.Tensor],
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
max_dynamic_tiles: int,
|
||||
min_dynamic_tiles: int,
|
||||
use_thumbnail: bool,
|
||||
pad_during_tiling: bool,
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
interpolation: F.InterpolationMode | None,
|
||||
do_center_crop: bool,
|
||||
crop_size: SizeDict,
|
||||
do_rescale: bool,
|
||||
|
||||
@@ -15,16 +15,16 @@
|
||||
# limitations under the License.
|
||||
|
||||
import builtins
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
@@ -32,13 +32,21 @@ from lerobot.utils.import_utils import _transformers_available
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
|
||||
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
|
||||
|
||||
from lerobot.policies.pi_gemma import (
|
||||
PaliGemmaForConditionalGenerationWithPiGemma,
|
||||
PiGemmaForCausalLM,
|
||||
_gated_residual,
|
||||
layernorm_forward,
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
modeling_gemma = None
|
||||
GemmaForCausalLM = None
|
||||
PaliGemmaForConditionalGeneration = None
|
||||
PiGemmaForCausalLM = None
|
||||
_gated_residual = None
|
||||
layernorm_forward = None
|
||||
PaliGemmaForConditionalGenerationWithPiGemma = None
|
||||
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
|
||||
@@ -191,7 +199,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
if images.dtype == torch.uint8:
|
||||
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
|
||||
elif images.dtype == torch.float32:
|
||||
resized_images = resized_images.clamp(-1.0, 1.0)
|
||||
resized_images = resized_images.clamp(0.0, 1.0)
|
||||
else:
|
||||
raise ValueError(f"Unsupported image dtype: {images.dtype}")
|
||||
|
||||
@@ -202,7 +210,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
pad_w1 = pad_w0 + remainder_w
|
||||
|
||||
# Pad
|
||||
constant_value = 0 if images.dtype == torch.uint8 else -1.0
|
||||
constant_value = 0 if images.dtype == torch.uint8 else 0.0
|
||||
padded_images = F.pad(
|
||||
resized_images,
|
||||
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
|
||||
@@ -221,14 +229,14 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
def compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||
):
|
||||
models = [paligemma.language_model, gemma_expert.model]
|
||||
models = [paligemma.model.language_model, gemma_expert.model]
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
|
||||
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||
gates.append(gate)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
@@ -254,10 +262,10 @@ def compute_layer_complete(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
|
||||
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
paligemma.language_model.layers[layer_idx].self_attn,
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
@@ -265,7 +273,7 @@ def compute_layer_complete(
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
|
||||
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
@@ -277,15 +285,15 @@ def compute_layer_complete(
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
|
||||
# first residual
|
||||
out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
|
||||
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
|
||||
after_first_residual = out_emb.clone()
|
||||
out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
|
||||
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
|
||||
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
|
||||
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
|
||||
out_emb = out_emb.to(dtype=torch.bfloat16)
|
||||
out_emb = layer.mlp(out_emb)
|
||||
# second residual
|
||||
out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001
|
||||
out_emb = _gated_residual(after_first_residual, out_emb, gate)
|
||||
outputs_embeds.append(out_emb)
|
||||
start_pos = end_pos
|
||||
return outputs_embeds
|
||||
@@ -358,7 +366,7 @@ class PaliGemmaWithExpertModel(
|
||||
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
|
||||
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
|
||||
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
|
||||
vlm_config_hf.text_config.torch_dtype = "float32"
|
||||
vlm_config_hf.text_config.dtype = "float32"
|
||||
vlm_config_hf.text_config.vocab_size = 257152
|
||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||
@@ -366,7 +374,7 @@ class PaliGemmaWithExpertModel(
|
||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||
vlm_config_hf.vision_config.projection_dim = 2048
|
||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||
vlm_config_hf.vision_config.torch_dtype = "float32"
|
||||
vlm_config_hf.vision_config.dtype = "float32"
|
||||
|
||||
action_expert_config_hf = CONFIG_MAPPING["gemma"](
|
||||
head_dim=action_expert_config.head_dim,
|
||||
@@ -377,13 +385,13 @@ class PaliGemmaWithExpertModel(
|
||||
num_key_value_heads=action_expert_config.num_kv_heads,
|
||||
vocab_size=257152,
|
||||
hidden_activation="gelu_pytorch_tanh",
|
||||
torch_dtype="float32",
|
||||
dtype="float32",
|
||||
use_adarms=use_adarms[1],
|
||||
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
|
||||
)
|
||||
|
||||
self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
|
||||
self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
|
||||
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
|
||||
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
|
||||
self.gemma_expert.model.embed_tokens = None
|
||||
|
||||
self.to_bfloat16_for_selected_params(precision)
|
||||
@@ -398,10 +406,11 @@ class PaliGemmaWithExpertModel(
|
||||
else:
|
||||
raise ValueError(f"Invalid precision: {precision}")
|
||||
|
||||
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
|
||||
# "same dtype" error). Align with PI05.
|
||||
params_to_keep_float32 = [
|
||||
"vision_tower.vision_model.embeddings.patch_embedding.weight",
|
||||
"vision_tower.vision_model.embeddings.patch_embedding.bias",
|
||||
"vision_tower.vision_model.embeddings.position_embedding.weight",
|
||||
"vision_tower",
|
||||
"multi_modal_projector",
|
||||
"input_layernorm",
|
||||
"post_attention_layernorm",
|
||||
"model.norm",
|
||||
@@ -413,8 +422,8 @@ class PaliGemmaWithExpertModel(
|
||||
|
||||
def _set_requires_grad(self):
|
||||
if self.freeze_vision_encoder:
|
||||
self.paligemma.vision_tower.eval()
|
||||
for param in self.paligemma.vision_tower.parameters():
|
||||
self.paligemma.model.vision_tower.eval()
|
||||
for param in self.paligemma.model.vision_tower.parameters():
|
||||
param.requires_grad = False
|
||||
if self.train_expert_only:
|
||||
self.paligemma.eval()
|
||||
@@ -424,15 +433,23 @@ class PaliGemmaWithExpertModel(
|
||||
def train(self, mode: bool = True):
|
||||
super().train(mode)
|
||||
if self.freeze_vision_encoder:
|
||||
self.paligemma.vision_tower.eval()
|
||||
self.paligemma.model.vision_tower.eval()
|
||||
if self.train_expert_only:
|
||||
self.paligemma.eval()
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
return self.paligemma.model.get_image_features(image)
|
||||
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). Align with PI05.
|
||||
out_dtype = image.dtype
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
image_outputs = self.paligemma.model.get_image_features(image)
|
||||
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
||||
if features.dtype != out_dtype:
|
||||
features = features.to(out_dtype)
|
||||
return features
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.language_model.embed_tokens(tokens)
|
||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -446,7 +463,7 @@ class PaliGemmaWithExpertModel(
|
||||
if adarms_cond is None:
|
||||
adarms_cond = [None, None]
|
||||
if inputs_embeds[1] is None:
|
||||
prefix_output = self.paligemma.language_model.forward(
|
||||
prefix_output = self.paligemma.model.language_model.forward(
|
||||
inputs_embeds=inputs_embeds[0],
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
@@ -470,7 +487,7 @@ class PaliGemmaWithExpertModel(
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
models = [self.paligemma.language_model, self.gemma_expert.model]
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
@@ -510,7 +527,7 @@ class PaliGemmaWithExpertModel(
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
@@ -576,29 +593,19 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Also compile the main forward pass used during training
|
||||
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
||||
|
||||
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
|
||||
|
||||
try:
|
||||
from transformers.models.siglip import check
|
||||
|
||||
if not check.check_whether_transformers_replace_is_installed_correctly():
|
||||
raise ValueError(msg)
|
||||
except ImportError:
|
||||
raise ValueError(msg) from None
|
||||
|
||||
def gradient_checkpointing_enable(self):
|
||||
"""Enable gradient checkpointing for memory optimization."""
|
||||
self.gradient_checkpointing_enabled = True
|
||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
|
||||
logging.info("Enabled gradient checkpointing for PI0Pytorch model")
|
||||
|
||||
def gradient_checkpointing_disable(self):
|
||||
"""Disable gradient checkpointing."""
|
||||
self.gradient_checkpointing_enabled = False
|
||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
||||
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
|
||||
|
||||
@@ -760,7 +767,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
|
||||
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
||||
@@ -834,7 +841,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
|
||||
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
_, past_key_values = self.paligemma_with_expert.forward(
|
||||
attention_mask=prefix_att_2d_masks_4d,
|
||||
@@ -908,6 +915,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
@@ -997,14 +1005,12 @@ class PI0Policy(PreTrainedPolicy):
|
||||
# Check if dataset_stats were provided in kwargs
|
||||
model = cls(config, **kwargs)
|
||||
|
||||
# Now manually load and remap the state dict
|
||||
# Load state dict (expects keys with "model." prefix)
|
||||
try:
|
||||
# Try to load the pytorch_model.bin or model.safetensors file
|
||||
print(f"Loading model from: {pretrained_name_or_path}")
|
||||
try:
|
||||
from transformers.utils import cached_file
|
||||
|
||||
# Try safetensors first
|
||||
resolved_file = cached_file(
|
||||
pretrained_name_or_path,
|
||||
"model.safetensors",
|
||||
@@ -1012,7 +1018,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
force_download=kwargs.get("force_download", False),
|
||||
resume_download=kwargs.get("resume_download"),
|
||||
proxies=kwargs.get("proxies"),
|
||||
use_auth_token=kwargs.get("use_auth_token"),
|
||||
token=kwargs.get("token"),
|
||||
revision=kwargs.get("revision"),
|
||||
local_files_only=kwargs.get("local_files_only", False),
|
||||
)
|
||||
@@ -1025,7 +1031,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
print("Returning model without loading pretrained weights")
|
||||
return model
|
||||
|
||||
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
|
||||
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
|
||||
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
|
||||
|
||||
# Then add "model." prefix for all keys that don't already have it
|
||||
@@ -1070,7 +1076,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
print("All keys loaded successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not remap state dict keys: {e}")
|
||||
print(f"Warning: Could not load state dict: {e}")
|
||||
|
||||
return model
|
||||
|
||||
@@ -1120,6 +1126,14 @@ class PI0Policy(PreTrainedPolicy):
|
||||
# Some checkpoints might have this, but current model expects different structure
|
||||
logging.warning(f"Vision embedding key might need handling: {key}")
|
||||
|
||||
if (
|
||||
key == "model.paligemma_with_expert.paligemma.lm_head.weight"
|
||||
or key == "paligemma_with_expert.paligemma.lm_head.weight"
|
||||
):
|
||||
fixed_state_dict[
|
||||
"model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
|
||||
] = value.clone()
|
||||
|
||||
fixed_state_dict[new_key] = value
|
||||
|
||||
return fixed_state_dict
|
||||
|
||||
@@ -15,16 +15,16 @@
|
||||
# limitations under the License.
|
||||
|
||||
import builtins
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
@@ -32,14 +32,20 @@ from lerobot.utils.import_utils import _transformers_available
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
|
||||
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
|
||||
|
||||
from lerobot.policies.pi_gemma import (
|
||||
PaliGemmaForConditionalGenerationWithPiGemma,
|
||||
PiGemmaForCausalLM,
|
||||
_gated_residual,
|
||||
layernorm_forward,
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
modeling_gemma = None
|
||||
GemmaForCausalLM = None
|
||||
PaliGemmaForConditionalGeneration = None
|
||||
|
||||
PiGemmaForCausalLM = None
|
||||
_gated_residual = None
|
||||
layernorm_forward = None
|
||||
PaliGemmaForConditionalGenerationWithPiGemma = None
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
@@ -92,10 +98,11 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
|
||||
alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
|
||||
beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
|
||||
# Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU
|
||||
alpha_t = torch.tensor(alpha, dtype=torch.float32)
|
||||
beta_t = torch.tensor(beta, dtype=torch.float32)
|
||||
dist = torch.distributions.Beta(alpha_t, beta_t)
|
||||
return dist.sample((bsize,))
|
||||
return dist.sample((bsize,)).to(device)
|
||||
|
||||
|
||||
def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy)
|
||||
@@ -189,7 +196,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
if images.dtype == torch.uint8:
|
||||
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
|
||||
elif images.dtype == torch.float32:
|
||||
resized_images = resized_images.clamp(-1.0, 1.0)
|
||||
resized_images = resized_images.clamp(0.0, 1.0)
|
||||
else:
|
||||
raise ValueError(f"Unsupported image dtype: {images.dtype}")
|
||||
|
||||
@@ -200,7 +207,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
pad_w1 = pad_w0 + remainder_w
|
||||
|
||||
# Pad
|
||||
constant_value = 0 if images.dtype == torch.uint8 else -1.0
|
||||
constant_value = 0 if images.dtype == torch.uint8 else 0.0
|
||||
padded_images = F.pad(
|
||||
resized_images,
|
||||
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
|
||||
@@ -219,14 +226,14 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
def compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||
):
|
||||
models = [paligemma.language_model, gemma_expert.model]
|
||||
models = [paligemma.model.language_model, gemma_expert.model]
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
|
||||
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||
gates.append(gate)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
@@ -252,10 +259,10 @@ def compute_layer_complete(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
|
||||
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
paligemma.language_model.layers[layer_idx].self_attn,
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
@@ -263,7 +270,7 @@ def compute_layer_complete(
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
|
||||
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
@@ -275,15 +282,15 @@ def compute_layer_complete(
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
|
||||
# first residual
|
||||
out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
|
||||
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
|
||||
after_first_residual = out_emb.clone()
|
||||
out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
|
||||
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
|
||||
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
|
||||
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
|
||||
out_emb = out_emb.to(dtype=torch.bfloat16)
|
||||
out_emb = layer.mlp(out_emb)
|
||||
# second residual
|
||||
out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001
|
||||
out_emb = _gated_residual(after_first_residual, out_emb, gate)
|
||||
outputs_embeds.append(out_emb)
|
||||
start_pos = end_pos
|
||||
return outputs_embeds
|
||||
@@ -356,7 +363,7 @@ class PaliGemmaWithExpertModel(
|
||||
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
|
||||
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
|
||||
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
|
||||
vlm_config_hf.text_config.torch_dtype = "float32"
|
||||
vlm_config_hf.text_config.dtype = "float32"
|
||||
vlm_config_hf.text_config.vocab_size = 257152
|
||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||
@@ -364,7 +371,7 @@ class PaliGemmaWithExpertModel(
|
||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||
vlm_config_hf.vision_config.projection_dim = 2048
|
||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||
vlm_config_hf.vision_config.torch_dtype = "float32"
|
||||
vlm_config_hf.vision_config.dtype = "float32"
|
||||
|
||||
action_expert_config_hf = CONFIG_MAPPING["gemma"](
|
||||
head_dim=action_expert_config.head_dim,
|
||||
@@ -375,13 +382,13 @@ class PaliGemmaWithExpertModel(
|
||||
num_key_value_heads=action_expert_config.num_kv_heads,
|
||||
vocab_size=257152,
|
||||
hidden_activation="gelu_pytorch_tanh",
|
||||
torch_dtype="float32",
|
||||
dtype="float32",
|
||||
use_adarms=use_adarms[1],
|
||||
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
|
||||
)
|
||||
|
||||
self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
|
||||
self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
|
||||
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
|
||||
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
|
||||
self.gemma_expert.model.embed_tokens = None
|
||||
|
||||
self.to_bfloat16_for_selected_params(precision)
|
||||
@@ -396,10 +403,11 @@ class PaliGemmaWithExpertModel(
|
||||
else:
|
||||
raise ValueError(f"Invalid precision: {precision}")
|
||||
|
||||
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
|
||||
# "same dtype" error). Saves memory vs full float32; more memory than only 3 params.
|
||||
params_to_keep_float32 = [
|
||||
"vision_tower.vision_model.embeddings.patch_embedding.weight",
|
||||
"vision_tower.vision_model.embeddings.patch_embedding.bias",
|
||||
"vision_tower.vision_model.embeddings.position_embedding.weight",
|
||||
"vision_tower",
|
||||
"multi_modal_projector",
|
||||
"input_layernorm",
|
||||
"post_attention_layernorm",
|
||||
"model.norm",
|
||||
@@ -411,8 +419,8 @@ class PaliGemmaWithExpertModel(
|
||||
|
||||
def _set_requires_grad(self):
|
||||
if self.freeze_vision_encoder:
|
||||
self.paligemma.vision_tower.eval()
|
||||
for param in self.paligemma.vision_tower.parameters():
|
||||
self.paligemma.model.vision_tower.eval()
|
||||
for param in self.paligemma.model.vision_tower.parameters():
|
||||
param.requires_grad = False
|
||||
if self.train_expert_only:
|
||||
self.paligemma.eval()
|
||||
@@ -422,15 +430,23 @@ class PaliGemmaWithExpertModel(
|
||||
def train(self, mode: bool = True):
|
||||
super().train(mode)
|
||||
if self.freeze_vision_encoder:
|
||||
self.paligemma.vision_tower.eval()
|
||||
self.paligemma.model.vision_tower.eval()
|
||||
if self.train_expert_only:
|
||||
self.paligemma.eval()
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
return self.paligemma.model.get_image_features(image)
|
||||
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32).
|
||||
out_dtype = image.dtype
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
image_outputs = self.paligemma.model.get_image_features(image)
|
||||
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
||||
if features.dtype != out_dtype:
|
||||
features = features.to(out_dtype)
|
||||
return features
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.language_model.embed_tokens(tokens)
|
||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -444,7 +460,7 @@ class PaliGemmaWithExpertModel(
|
||||
if adarms_cond is None:
|
||||
adarms_cond = [None, None]
|
||||
if inputs_embeds[1] is None:
|
||||
prefix_output = self.paligemma.language_model.forward(
|
||||
prefix_output = self.paligemma.model.language_model.forward(
|
||||
inputs_embeds=inputs_embeds[0],
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
@@ -468,7 +484,7 @@ class PaliGemmaWithExpertModel(
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
models = [self.paligemma.language_model, self.gemma_expert.model]
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
@@ -508,7 +524,7 @@ class PaliGemmaWithExpertModel(
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
@@ -573,29 +589,19 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Also compile the main forward pass used during training
|
||||
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
||||
|
||||
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
|
||||
|
||||
try:
|
||||
from transformers.models.siglip import check
|
||||
|
||||
if not check.check_whether_transformers_replace_is_installed_correctly():
|
||||
raise ValueError(msg)
|
||||
except ImportError:
|
||||
raise ValueError(msg) from None
|
||||
|
||||
def gradient_checkpointing_enable(self):
|
||||
"""Enable gradient checkpointing for memory optimization."""
|
||||
self.gradient_checkpointing_enabled = True
|
||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
|
||||
logging.info("Enabled gradient checkpointing for PI05Pytorch model")
|
||||
|
||||
def gradient_checkpointing_disable(self):
|
||||
"""Disable gradient checkpointing."""
|
||||
self.gradient_checkpointing_enabled = False
|
||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
||||
logging.info("Disabled gradient checkpointing for PI05Pytorch model")
|
||||
|
||||
@@ -737,7 +743,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
|
||||
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
||||
@@ -808,7 +814,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
|
||||
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
_, past_key_values = self.paligemma_with_expert.forward(
|
||||
attention_mask=prefix_att_2d_masks_4d,
|
||||
@@ -880,6 +886,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
@@ -969,14 +976,12 @@ class PI05Policy(PreTrainedPolicy):
|
||||
# Check if dataset_stats were provided in kwargs
|
||||
model = cls(config, **kwargs)
|
||||
|
||||
# Now manually load and remap the state dict
|
||||
# Load state dict (expects keys with "model." prefix)
|
||||
try:
|
||||
# Try to load the pytorch_model.bin or model.safetensors file
|
||||
print(f"Loading model from: {pretrained_name_or_path}")
|
||||
try:
|
||||
from transformers.utils import cached_file
|
||||
|
||||
# Try safetensors first
|
||||
resolved_file = cached_file(
|
||||
pretrained_name_or_path,
|
||||
"model.safetensors",
|
||||
@@ -984,7 +989,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
force_download=kwargs.get("force_download", False),
|
||||
resume_download=kwargs.get("resume_download"),
|
||||
proxies=kwargs.get("proxies"),
|
||||
use_auth_token=kwargs.get("use_auth_token"),
|
||||
token=kwargs.get("token"),
|
||||
revision=kwargs.get("revision"),
|
||||
local_files_only=kwargs.get("local_files_only", False),
|
||||
)
|
||||
@@ -997,7 +1002,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
print("Returning model without loading pretrained weights")
|
||||
return model
|
||||
|
||||
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
|
||||
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
|
||||
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
|
||||
|
||||
# Then add "model." prefix for all keys that don't already have it
|
||||
@@ -1009,8 +1014,6 @@ class PI05Policy(PreTrainedPolicy):
|
||||
new_key = f"model.{key}"
|
||||
remapped_state_dict[new_key] = value
|
||||
remap_count += 1
|
||||
if remap_count <= 10: # Only print first 10 to avoid spam
|
||||
print(f"Remapped: {key} -> {new_key}")
|
||||
else:
|
||||
remapped_state_dict[key] = value
|
||||
|
||||
@@ -1044,7 +1047,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
print("All keys loaded successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not remap state dict keys: {e}")
|
||||
print(f"Warning: Could not load state dict: {e}")
|
||||
|
||||
return model
|
||||
|
||||
@@ -1098,6 +1101,14 @@ class PI05Policy(PreTrainedPolicy):
|
||||
# Some checkpoints might have this, but current model expects different structure
|
||||
logging.warning(f"Vision embedding key might need handling: {key}")
|
||||
|
||||
if (
|
||||
key == "model.paligemma_with_expert.paligemma.lm_head.weight"
|
||||
or key == "paligemma_with_expert.paligemma.lm_head.weight"
|
||||
):
|
||||
fixed_state_dict[
|
||||
"model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
|
||||
] = value.clone()
|
||||
|
||||
fixed_state_dict[new_key] = value
|
||||
|
||||
return fixed_state_dict
|
||||
|
||||
@@ -23,7 +23,6 @@ import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pi05.modeling_pi05 import pad_vector
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
@@ -68,9 +67,6 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
|
||||
# TODO: check if this necessary
|
||||
state = deepcopy(state)
|
||||
|
||||
# Prepare state (pad to max_state_dim)
|
||||
state = pad_vector(state, self.max_state_dim)
|
||||
|
||||
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
state_np = state.cpu().numpy()
|
||||
|
||||
@@ -54,7 +54,7 @@ class PI0FastConfig(PreTrainedConfig):
|
||||
|
||||
tokenizer_max_length: int = 200 # see openpi `__post_init__`
|
||||
text_tokenizer_name: str = "google/paligemma-3b-pt-224"
|
||||
action_tokenizer_name: str = "physical-intelligence/fast"
|
||||
action_tokenizer_name: str = "lerobot/fast-action-tokenizer"
|
||||
temperature: float = 0.0
|
||||
max_decoding_steps: int = 256
|
||||
fast_skip_tokens: int = 128
|
||||
|
||||
@@ -19,13 +19,12 @@ import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.utils.import_utils import _scipy_available, _transformers_available
|
||||
|
||||
@@ -38,11 +37,16 @@ else:
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
|
||||
|
||||
from lerobot.policies.pi_gemma import (
|
||||
PaliGemmaForConditionalGenerationWithPiGemma,
|
||||
PiGemmaModel,
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
PaliGemmaForConditionalGeneration = None
|
||||
AutoTokenizer = None
|
||||
PiGemmaModel = None
|
||||
PaliGemmaForConditionalGenerationWithPiGemma = None
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
|
||||
@@ -121,7 +125,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
if images.dtype == torch.uint8:
|
||||
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
|
||||
elif images.dtype == torch.float32:
|
||||
resized_images = resized_images.clamp(-1.0, 1.0)
|
||||
resized_images = resized_images.clamp(0.0, 1.0)
|
||||
else:
|
||||
raise ValueError(f"Unsupported image dtype: {images.dtype}")
|
||||
|
||||
@@ -132,7 +136,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
pad_w1 = pad_w0 + remainder_w
|
||||
|
||||
# Pad
|
||||
constant_value = 0 if images.dtype == torch.uint8 else -1.0
|
||||
constant_value = 0 if images.dtype == torch.uint8 else 0.0
|
||||
padded_images = F.pad(
|
||||
resized_images,
|
||||
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
|
||||
@@ -206,16 +210,22 @@ class PI0FastPaliGemma(nn.Module):
|
||||
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
|
||||
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
|
||||
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
|
||||
vlm_config_hf.text_config.torch_dtype = "float32"
|
||||
vlm_config_hf.text_config.dtype = "float32"
|
||||
vlm_config_hf.text_config.vocab_size = 257152
|
||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||
vlm_config_hf.vision_config.projection_dim = 2048
|
||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||
vlm_config_hf.vision_config.torch_dtype = "float32"
|
||||
vlm_config_hf.vision_config.dtype = "float32"
|
||||
|
||||
self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
|
||||
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
|
||||
|
||||
# Use PI Gemma (AdaRMS) as language model when use_adarms[0] is True so that
|
||||
# forward(..., adarms_cond=...) is supported (same as pi0/pi05).
|
||||
if use_adarms[0]:
|
||||
text_config = self.paligemma.config.text_config
|
||||
self.paligemma.model.language_model = PiGemmaModel(text_config)
|
||||
|
||||
self.to_bfloat16_for_selected_params(precision)
|
||||
|
||||
@@ -228,10 +238,11 @@ class PI0FastPaliGemma(nn.Module):
|
||||
else:
|
||||
raise ValueError(f"Invalid precision: {precision}")
|
||||
|
||||
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
|
||||
# "same dtype" error). Align with PI05.
|
||||
params_to_keep_float32 = [
|
||||
"vision_tower.vision_model.embeddings.patch_embedding.weight",
|
||||
"vision_tower.vision_model.embeddings.patch_embedding.bias",
|
||||
"vision_tower.vision_model.embeddings.position_embedding.weight",
|
||||
"vision_tower",
|
||||
"multi_modal_projector",
|
||||
"input_layernorm",
|
||||
"post_attention_layernorm",
|
||||
"model.norm",
|
||||
@@ -242,10 +253,18 @@ class PI0FastPaliGemma(nn.Module):
|
||||
param.data = param.data.to(dtype=torch.float32)
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
return self.paligemma.model.get_image_features(image)
|
||||
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). Align with PI05.
|
||||
out_dtype = image.dtype
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
image_outputs = self.paligemma.model.get_image_features(image)
|
||||
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
||||
if features.dtype != out_dtype:
|
||||
features = features.to(out_dtype)
|
||||
return features
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.language_model.embed_tokens(tokens)
|
||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -259,7 +278,7 @@ class PI0FastPaliGemma(nn.Module):
|
||||
if adarms_cond is None:
|
||||
adarms_cond = [None, None]
|
||||
if inputs_embeds[1] is None:
|
||||
prefix_output = self.paligemma.language_model.forward(
|
||||
prefix_output = self.paligemma.model.language_model.forward(
|
||||
inputs_embeds=inputs_embeds[0],
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
@@ -306,24 +325,14 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
self.sample_actions_fast = torch.compile(self.sample_actions_fast, mode=config.compile_mode)
|
||||
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
||||
|
||||
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
|
||||
|
||||
try:
|
||||
from transformers.models.siglip import check
|
||||
|
||||
if not check.check_whether_transformers_replace_is_installed_correctly():
|
||||
raise ValueError(msg)
|
||||
except ImportError:
|
||||
raise ValueError(msg) from None
|
||||
|
||||
def gradient_checkpointing_enable(self):
|
||||
"""Enable gradient checkpointing for memory optimization."""
|
||||
self.gradient_checkpointing_enabled = True
|
||||
# Call the proper gradient_checkpointing_enable() method with use_reentrant=False for better memory efficiency
|
||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_enable(
|
||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing_enable(
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False}
|
||||
)
|
||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_enable(
|
||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing_enable(
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False}
|
||||
)
|
||||
logging.info("Enabled gradient checkpointing for PI0FastPytorch model")
|
||||
@@ -332,8 +341,8 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
"""Disable gradient checkpointing."""
|
||||
self.gradient_checkpointing_enabled = False
|
||||
# Call the proper gradient_checkpointing_disable() method
|
||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_disable()
|
||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_disable()
|
||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing_disable()
|
||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing_disable()
|
||||
logging.info("Disabled gradient checkpointing for PI0FastPytorch model")
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
@@ -523,7 +532,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
# Convert embeddings to bfloat16 if needed
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||
@@ -616,7 +625,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
)
|
||||
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||
@@ -714,7 +723,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
# Ensure correct precision (bfloat16/float32)
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||
@@ -897,14 +906,12 @@ class PI0FastPolicy(PreTrainedPolicy):
|
||||
# Check if dataset_stats were provided in kwargs
|
||||
model = cls(config, **kwargs)
|
||||
|
||||
# Now manually load and remap the state dict
|
||||
# Load state dict (expects keys with "model." prefix)
|
||||
try:
|
||||
# Try to load the pytorch_model.bin or model.safetensors file
|
||||
print(f"Loading model from: {pretrained_name_or_path}")
|
||||
try:
|
||||
from transformers.utils import cached_file
|
||||
|
||||
# Try safetensors first
|
||||
resolved_file = cached_file(
|
||||
pretrained_name_or_path,
|
||||
"model.safetensors",
|
||||
@@ -912,7 +919,7 @@ class PI0FastPolicy(PreTrainedPolicy):
|
||||
force_download=kwargs.get("force_download", False),
|
||||
resume_download=kwargs.get("resume_download"),
|
||||
proxies=kwargs.get("proxies"),
|
||||
use_auth_token=kwargs.get("use_auth_token"),
|
||||
token=kwargs.get("token"),
|
||||
revision=kwargs.get("revision"),
|
||||
local_files_only=kwargs.get("local_files_only", False),
|
||||
)
|
||||
@@ -925,8 +932,9 @@ class PI0FastPolicy(PreTrainedPolicy):
|
||||
print("Returning model without loading pretrained weights")
|
||||
return model
|
||||
|
||||
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
|
||||
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
|
||||
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
|
||||
|
||||
# Then add "model." prefix for all keys that don't already have it
|
||||
remapped_state_dict = {}
|
||||
remap_count = 0
|
||||
@@ -936,8 +944,6 @@ class PI0FastPolicy(PreTrainedPolicy):
|
||||
new_key = f"model.{key}"
|
||||
remapped_state_dict[new_key] = value
|
||||
remap_count += 1
|
||||
if remap_count <= 10: # Only print first 10 to avoid spam
|
||||
print(f"Remapped: {key} -> {new_key}")
|
||||
else:
|
||||
remapped_state_dict[key] = value
|
||||
|
||||
@@ -971,7 +977,7 @@ class PI0FastPolicy(PreTrainedPolicy):
|
||||
print("All keys loaded successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not remap state dict keys: {e}")
|
||||
print(f"Warning: Could not load state dict: {e}")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
|
||||
from lerobot.policies.pi0_fast.modeling_pi0_fast import pad_vector
|
||||
from lerobot.processor import (
|
||||
ActionTokenizerProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
@@ -69,9 +68,6 @@ class Pi0FastPrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
|
||||
# TODO: check if this necessary
|
||||
state = deepcopy(state)
|
||||
|
||||
# Prepare state (pad to max_state_dim)
|
||||
state = pad_vector(state, self.max_state_dim)
|
||||
|
||||
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
state_np = state.cpu().numpy()
|
||||
|
||||
@@ -0,0 +1,363 @@
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.masking_utils import create_causal_mask
|
||||
from transformers.modeling_layers import GradientCheckpointingLayer
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.gemma.modeling_gemma import (
|
||||
GemmaAttention,
|
||||
GemmaConfig,
|
||||
GemmaForCausalLM,
|
||||
GemmaMLP,
|
||||
GemmaModel,
|
||||
)
|
||||
from transformers.models.paligemma.modeling_paligemma import (
|
||||
PaliGemmaForConditionalGeneration,
|
||||
PaliGemmaModel,
|
||||
)
|
||||
else:
|
||||
GemmaAttention = None
|
||||
GemmaConfig = None
|
||||
GemmaForCausalLM = None
|
||||
GemmaMLP = None
|
||||
GemmaModel = None
|
||||
PaliGemmaModel = None
|
||||
PaliGemmaForConditionalGeneration = None
|
||||
DynamicCache = None
|
||||
GradientCheckpointingLayer = None
|
||||
BaseModelOutputWithPast = None
|
||||
create_causal_mask = None
|
||||
|
||||
|
||||
def _gated_residual(
|
||||
x: torch.Tensor | None,
|
||||
y: torch.Tensor | None,
|
||||
gate: torch.Tensor | None,
|
||||
) -> torch.Tensor | None:
|
||||
"""Gated residual: x + y when gate is None, else x + y * gate."""
|
||||
if x is None and y is None:
|
||||
return None
|
||||
if x is None or y is None:
|
||||
return x if x is not None else y
|
||||
if gate is None:
|
||||
return x + y
|
||||
return x + y * gate
|
||||
|
||||
|
||||
def layernorm_forward(
|
||||
layernorm: nn.Module,
|
||||
x: torch.Tensor,
|
||||
cond: torch.Tensor | None = None,
|
||||
):
|
||||
"""
|
||||
call layernorm and return hidden states and gate
|
||||
if cond is not None, use conditional norm
|
||||
otherwise, use normal gemma norm
|
||||
"""
|
||||
if cond is not None:
|
||||
return layernorm(x, cond=cond)
|
||||
else:
|
||||
return layernorm(x)
|
||||
|
||||
|
||||
class PiGemmaRMSNorm(nn.Module):
|
||||
"""
|
||||
Adaptive RMSNorm for PI Gemma (AdaRMS).
|
||||
When cond_dim is set, uses cond to modulate scale/shift/gate; otherwise behaves like standard GemmaRMSNorm.
|
||||
forward(x, cond=None) returns (output, gate) for use with _gated_residual.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.dim = dim
|
||||
self.cond_dim = cond_dim
|
||||
if cond_dim is not None:
|
||||
self.dense = nn.Linear(cond_dim, dim * 3, bias=True)
|
||||
nn.init.zeros_(self.dense.weight)
|
||||
else:
|
||||
self.weight = nn.Parameter(torch.zeros(dim))
|
||||
self.dense = None
|
||||
|
||||
def _norm(self, x):
|
||||
# Compute variance in float32 (like the source implementation)
|
||||
var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True)
|
||||
# Compute normalization in float32
|
||||
normed_inputs = x * torch.rsqrt(var + self.eps)
|
||||
return normed_inputs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cond: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
dtype = x.dtype
|
||||
normed = self._norm(x)
|
||||
if cond is None or self.dense is None:
|
||||
normed = normed * (1.0 + self.weight.float())
|
||||
return normed.type_as(x), None
|
||||
if cond.shape[-1] != self.cond_dim:
|
||||
raise ValueError(f"Expected cond dim {self.cond_dim}, got {cond.shape[-1]}")
|
||||
modulation = self.dense(cond)
|
||||
if len(x.shape) == 3:
|
||||
modulation = modulation.unsqueeze(1)
|
||||
scale, shift, gate = modulation.chunk(3, dim=-1)
|
||||
normed = normed * (1 + scale.float()) + shift.float()
|
||||
return normed.to(dtype), gate.to(dtype)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
if self.dense is not None:
|
||||
return f"dim={self.dim}, eps={self.eps}, adaptive=True, cond_dim={self.cond_dim}"
|
||||
return f"dim={self.dim}, eps={self.eps}"
|
||||
|
||||
|
||||
def _get_pi_gemma_decoder_layer_base():
|
||||
"""base for PiGemmaDecoderLayer"""
|
||||
|
||||
class _PiGemmaDecoderLayerBase(GradientCheckpointingLayer):
|
||||
"""Decoder layer that uses PiGemmaRMSNorm and _gated_residual, compatible with v5 Gemma."""
|
||||
|
||||
def __init__(self, config: GemmaConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)
|
||||
self.mlp = GemmaMLP(config)
|
||||
cond_dim = (
|
||||
getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None
|
||||
)
|
||||
self.input_layernorm = PiGemmaRMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
|
||||
)
|
||||
self.post_attention_layernorm = PiGemmaRMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values=None,
|
||||
use_cache: bool = False,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
adarms_cond: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states, gate = self.input_layernorm(hidden_states, cond=adarms_cond)
|
||||
hidden_states, _ = self.self_attn(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = _gated_residual(residual, hidden_states, gate)
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states, gate = self.post_attention_layernorm(hidden_states, cond=adarms_cond)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = _gated_residual(residual, hidden_states, gate)
|
||||
return hidden_states
|
||||
|
||||
return _PiGemmaDecoderLayerBase
|
||||
|
||||
|
||||
class PiGemmaModel(GemmaModel): # type: ignore[misc]
|
||||
"""
|
||||
GemmaModel extended with AdaRMS (adaptive RMSNorm) and gated residuals when config.use_adarms is True.
|
||||
"""
|
||||
|
||||
def __init__(self, config: GemmaConfig, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
# if not getattr(config, "use_adarms", False):
|
||||
# return
|
||||
cond_dim = getattr(config, "adarms_cond_dim", None)
|
||||
pi_gemma_decoder_layer_base = _get_pi_gemma_decoder_layer_base()
|
||||
self.layers = nn.ModuleList(
|
||||
[pi_gemma_decoder_layer_base(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.norm = PiGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: DynamicCache | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
adarms_cond: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> BaseModelOutputWithPast:
|
||||
"""
|
||||
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
|
||||
Condition for ADARMS.
|
||||
"""
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
import logging
|
||||
|
||||
logging.warning(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
# Convert to bfloat16 if the first layer uses bfloat16
|
||||
if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(torch.bfloat16)
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# normalized
|
||||
# Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
|
||||
# See https://github.com/huggingface/transformers/pull/29402
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
||||
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
adarms_cond=adarms_cond,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, adarms_cond)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=past_key_values if use_cache else None,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class PiGemmaForCausalLM(GemmaForCausalLM): # type: ignore[misc]
|
||||
"""
|
||||
Causal LM wrapper using PiGemmaModel as the backbone, for consistency with GemmaForCausalLM
|
||||
and the language model used in pi0_fast. Use this for the action expert in pi0/pi05.
|
||||
"""
|
||||
|
||||
def __init__(self, config: GemmaConfig, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
self.model = PiGemmaModel(config)
|
||||
|
||||
|
||||
class PaliGemmaModelWithPiGemma(PaliGemmaModel):
|
||||
"""PaliGemmaModel whose language_model is PiGemmaModel (custom decoder with PiGemmaRMSNorm and gated residuals)."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.language_model = PiGemmaModel(config.text_config)
|
||||
|
||||
|
||||
class PaliGemmaForConditionalGenerationWithPiGemma(PaliGemmaForConditionalGeneration):
|
||||
"""PaliGemmaForConditionalGeneration using PiGemma decoder for the language model."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = PaliGemmaModelWithPiGemma(config)
|
||||
|
||||
# Make modules available through conditional class for BC
|
||||
@property
|
||||
def language_model(self):
|
||||
return self.model.language_model
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PiGemmaModel",
|
||||
"PiGemmaForCausalLM",
|
||||
"PiGemmaRMSNorm",
|
||||
"_gated_residual",
|
||||
"layernorm_forward",
|
||||
"PaliGemmaModelWithPiGemma",
|
||||
"PaliGemmaForConditionalGenerationWithPiGemma",
|
||||
]
|
||||
@@ -19,7 +19,7 @@ import os
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import TypedDict, TypeVar
|
||||
from typing import TypedDict, TypeVar, Unpack
|
||||
|
||||
import packaging
|
||||
import safetensors
|
||||
@@ -28,7 +28,6 @@ from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
@@ -1,117 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Action interpolation for smoother robot control.
|
||||
|
||||
Provides configurable Nx control rate by interpolating between consecutive actions.
|
||||
Useful with RTC and action-chunking policies to reduce jerkiness.
|
||||
"""
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class ActionInterpolator:
|
||||
"""Interpolates between consecutive actions for smoother control.
|
||||
|
||||
When enabled with multiplier N, produces N actions per policy action
|
||||
by linearly interpolating between the previous and current action.
|
||||
|
||||
Example with multiplier=3:
|
||||
prev_action -> [1/3 interpolated, 2/3 interpolated, current_action]
|
||||
|
||||
This effectively multiplies the control rate for smoother motion.
|
||||
|
||||
Usage:
|
||||
interpolator = ActionInterpolator(multiplier=2) # 2x control rate
|
||||
|
||||
# In control loop:
|
||||
if interpolator.needs_new_action():
|
||||
new_action = queue.get()
|
||||
if new_action:
|
||||
interpolator.add(new_action.cpu())
|
||||
|
||||
action = interpolator.get()
|
||||
if action:
|
||||
robot.send_action(action)
|
||||
"""
|
||||
|
||||
def __init__(self, multiplier: int = 1):
|
||||
"""Initialize the interpolator.
|
||||
|
||||
Args:
|
||||
multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.)
|
||||
"""
|
||||
if multiplier < 1:
|
||||
raise ValueError(f"multiplier must be >= 1, got {multiplier}")
|
||||
self.multiplier = multiplier
|
||||
self._prev: Tensor | None = None
|
||||
self._buffer: list[Tensor] = []
|
||||
self._idx = 0
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
"""Whether interpolation is active (multiplier > 1)."""
|
||||
return self.multiplier > 1
|
||||
|
||||
def reset(self):
|
||||
"""Reset interpolation state (call between episodes)."""
|
||||
self._prev = None
|
||||
self._buffer = []
|
||||
self._idx = 0
|
||||
|
||||
def needs_new_action(self) -> bool:
|
||||
"""Check if a new action is needed from the queue."""
|
||||
return self._idx >= len(self._buffer)
|
||||
|
||||
def add(self, action: Tensor) -> None:
|
||||
"""Add a new action and compute interpolated sequence.
|
||||
|
||||
Args:
|
||||
action: New action tensor from policy/queue (already on CPU).
|
||||
"""
|
||||
if self.multiplier > 1 and self._prev is not None:
|
||||
self._buffer = []
|
||||
for i in range(1, self.multiplier + 1):
|
||||
t = i / self.multiplier
|
||||
interp = self._prev + t * (action - self._prev)
|
||||
self._buffer.append(interp)
|
||||
else:
|
||||
self._buffer = [action]
|
||||
self._prev = action
|
||||
self._idx = 0
|
||||
|
||||
def get(self) -> Tensor | None:
|
||||
"""Get the next interpolated action.
|
||||
|
||||
Returns:
|
||||
Next action tensor, or None if buffer is exhausted.
|
||||
"""
|
||||
if self._idx >= len(self._buffer):
|
||||
return None
|
||||
action = self._buffer[self._idx]
|
||||
self._idx += 1
|
||||
return action
|
||||
|
||||
def get_control_interval(self, fps: float) -> float:
|
||||
"""Get the control interval based on interpolation multiplier.
|
||||
|
||||
Args:
|
||||
fps: Base frames per second.
|
||||
|
||||
Returns:
|
||||
Control interval in seconds (divided by multiplier).
|
||||
"""
|
||||
return 1.0 / (fps * self.multiplier)
|
||||
@@ -33,7 +33,7 @@ class RewardClassifierConfig(PreTrainedConfig):
|
||||
latent_dim: int = 256
|
||||
image_embedding_pooling_dim: int = 8
|
||||
dropout_rate: float = 0.1
|
||||
model_name: str = "helper2424/resnet10"
|
||||
model_name: str = "helper2424/resnet10" # TODO: This needs to be updated. The model on the Hub doesn't call self.post_init() in its __init__, which is required by transformers v5 to set all_tied_weights_keys. The from_pretrained call fails when it tries to access this attribute during _finalize_model_loading.
|
||||
device: str = "cpu"
|
||||
model_type: str = "cnn" # "transformer" or "cnn"
|
||||
num_cameras: int = 2
|
||||
|
||||
@@ -277,9 +277,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
|
||||
# When language is perturbed, targets are zero so perturbed samples don't contribute to progress loss
|
||||
if self.dataset_meta is not None:
|
||||
episodes_df = None
|
||||
if self.sparse_subtask_names != ["task"]:
|
||||
episodes_df = self.dataset_meta.episodes.to_pandas()
|
||||
episodes_df = self.dataset_meta.episodes.to_pandas()
|
||||
|
||||
# Generate sparse targets
|
||||
if self.sparse_temporal_proportions is not None:
|
||||
|
||||
@@ -106,6 +106,9 @@ class SmolVLAConfig(PreTrainedConfig):
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
compile_model: bool = False # Whether to use torch.compile for model optimization
|
||||
compile_mode: str = "max-autotune" # Torch compile mode
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@@ -54,12 +54,11 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
|
||||
|
||||
import math
|
||||
from collections import deque
|
||||
from typing import TypedDict
|
||||
from typing import TypedDict, Unpack
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
@@ -593,6 +592,12 @@ class VLAFlowMatching(nn.Module):
|
||||
self.prefix_length = self.config.prefix_length
|
||||
self.rtc_processor = rtc_processor
|
||||
|
||||
# Compile model if requested
|
||||
if config.compile_model:
|
||||
torch.set_float32_matmul_precision("high")
|
||||
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
|
||||
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
||||
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
|
||||
@@ -77,7 +77,6 @@ class SmolVLMWithExpertModel(nn.Module):
|
||||
print(f"Loading {model_id} weights ...")
|
||||
self.vlm = AutoModelForImageTextToText.from_pretrained(
|
||||
model_id,
|
||||
device_map=device,
|
||||
torch_dtype="bfloat16",
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
||||
@@ -55,7 +55,7 @@ class WallXConfig(PreTrainedConfig):
|
||||
pretrained_name_or_path: str = "x-square-robot/wall-oss-flow"
|
||||
|
||||
# Tokenizer settings
|
||||
action_tokenizer_path: str | None = "physical-intelligence/fast"
|
||||
action_tokenizer_path: str | None = "lerobot/fast-action-tokenizer"
|
||||
|
||||
# Action prediction mode: "diffusion" or "fast"
|
||||
prediction_mode: str = "diffusion"
|
||||
|
||||
@@ -261,10 +261,15 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
and optional LoRA fine-tuning support.
|
||||
"""
|
||||
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||
config_class = Qwen2_5_VLConfig
|
||||
_no_split_modules = ["Qwen2_5_VLDecoderLayer_with_MoE", "Qwen2_5_VLVisionBlock"]
|
||||
|
||||
def init_weights(self):
|
||||
if getattr(self.model, "language_model", None) is not None:
|
||||
return
|
||||
super().init_weights()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
@@ -312,6 +317,11 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
processor.action_processor = action_tokenizer
|
||||
else:
|
||||
action_tokenizer = None
|
||||
|
||||
# add pad_token_id to config
|
||||
config.pad_token_id = processor.tokenizer.pad_token_id
|
||||
config.text_config.pad_token_id = processor.tokenizer.pad_token_id
|
||||
|
||||
# Initialize model with configuration and processor
|
||||
model = cls(config, processor=processor, action_tokenizer=action_tokenizer, **kwargs)
|
||||
|
||||
@@ -331,7 +341,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
force_download=kwargs.get("force_download", False),
|
||||
resume_download=kwargs.get("resume_download"),
|
||||
proxies=kwargs.get("proxies"),
|
||||
use_auth_token=kwargs.get("use_auth_token"),
|
||||
token=kwargs.get("token"),
|
||||
revision=kwargs.get("revision"),
|
||||
local_files_only=kwargs.get("local_files_only", False),
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
|
||||
window_size=112,
|
||||
out_hidden_size=3584,
|
||||
fullatt_block_indexes=[7, 15, 23, 31],
|
||||
initializer_range=0.02,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
@@ -38,6 +39,7 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
|
||||
self.window_size = window_size
|
||||
self.fullatt_block_indexes = fullatt_block_indexes
|
||||
self.out_hidden_size = out_hidden_size
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
|
||||
class Qwen2_5_VLConfig(PretrainedConfig):
|
||||
|
||||
@@ -11,7 +11,6 @@ from transformers.activations import ACT2FN
|
||||
from transformers.cache_utils import (
|
||||
Cache,
|
||||
DynamicCache,
|
||||
SlidingWindowCache,
|
||||
StaticCache,
|
||||
)
|
||||
from transformers.generation import GenerationMixin
|
||||
@@ -31,6 +30,15 @@ from transformers.utils import (
|
||||
|
||||
from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
|
||||
|
||||
|
||||
# TODO(Steven): SlidingWindowCache was removed in transformers v5. Define a placeholder so isinstance checks
|
||||
# always return False (which is the correct behavior when no sliding window cache is in use).
|
||||
class _SlidingWindowCachePlaceholder:
|
||||
pass
|
||||
|
||||
|
||||
SlidingWindowCache = _SlidingWindowCachePlaceholder
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.layers.rotary import apply_rotary_emb
|
||||
@@ -594,19 +602,40 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
|
||||
return hidden_states
|
||||
|
||||
|
||||
def _compute_default_rope_parameters_qwen2_5_vl(config, device=None):
|
||||
"""
|
||||
compute default rope parameters for Qwen2_5_VL
|
||||
"""
|
||||
base = config.text_config.rope_parameters["rope_theta"]
|
||||
dim = config.hidden_size // config.num_attention_heads
|
||||
inv_freq = 1.0 / (
|
||||
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
|
||||
)
|
||||
return inv_freq, 1.0
|
||||
|
||||
|
||||
class Qwen2_5_VLRotaryEmbedding(nn.Module):
|
||||
def __init__(self, config: Qwen2_5_VLConfig, device=None):
|
||||
super().__init__()
|
||||
# BC: "rope_type" was originally "type"
|
||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
elif hasattr(config, "rope_parameters") and config.rope_parameters is not None:
|
||||
self.rope_type = config.rope_parameters.get("rope_type", "default")
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
if self.rope_type == "default":
|
||||
self.rope_init_fn = _compute_default_rope_parameters_qwen2_5_vl
|
||||
self.rope_kwargs = {}
|
||||
else:
|
||||
rope_type_key = "linear" if self.rope_type == "linear" else self.rope_type
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type_key]
|
||||
self.rope_kwargs = {}
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
@@ -1567,7 +1596,7 @@ QWEN2_5_VL_INPUTS_DOCSTRING = r"""
|
||||
|
||||
|
||||
class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||
config_class = Qwen2_5_VLConfig
|
||||
_no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
|
||||
|
||||
|
||||
@@ -144,7 +144,7 @@ def preprocesser_call(
|
||||
"""
|
||||
# Process image inputs
|
||||
if images is not None and len(images) > 0:
|
||||
image_inputs = processor.image_processor(images=images, videos=None, return_tensors=return_tensors)
|
||||
image_inputs = processor.image_processor(images=images, return_tensors=return_tensors)
|
||||
image_grid_thw = image_inputs["image_grid_thw"]
|
||||
else:
|
||||
image_inputs = {}
|
||||
@@ -152,7 +152,7 @@ def preprocesser_call(
|
||||
|
||||
# Process video inputs
|
||||
if videos is not None:
|
||||
videos_inputs = processor.image_processor(images=None, videos=videos, return_tensors=return_tensors)
|
||||
videos_inputs = processor.image_processor(videos=videos, return_tensors=return_tensors)
|
||||
video_grid_thw = videos_inputs["video_grid_thw"]
|
||||
else:
|
||||
videos_inputs = {}
|
||||
|
||||
@@ -276,6 +276,8 @@ class Florence2LanguageConfig(PretrainedConfig):
|
||||
)
|
||||
|
||||
# ensure backward compatibility for BART CNN models
|
||||
if not hasattr(self, "forced_bos_token_id"):
|
||||
self.forced_bos_token_id = None
|
||||
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
|
||||
self.forced_bos_token_id = self.bos_token_id
|
||||
warnings.warn(
|
||||
|
||||
@@ -1951,7 +1951,10 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
|
||||
|
||||
|
||||
class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
_tied_weights_keys = {
|
||||
"encoder.embed_tokens.weight": "shared.weight",
|
||||
"decoder.embed_tokens.weight": "shared.weight",
|
||||
}
|
||||
|
||||
def __init__(self, config: Florence2LanguageConfig):
|
||||
super().__init__(config)
|
||||
@@ -2076,7 +2079,10 @@ class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
|
||||
|
||||
class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel, GenerationMixin):
|
||||
base_model_prefix = "model"
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
|
||||
_tied_weights_keys = {
|
||||
"model.encoder.embed_tokens.weight": "model.shared.weight",
|
||||
"model.decoder.embed_tokens.weight": "model.shared.weight",
|
||||
}
|
||||
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
|
||||
|
||||
def __init__(self, config: Florence2LanguageConfig):
|
||||
@@ -2436,11 +2442,10 @@ FLORENCE2_INPUTS_DOCSTRING = r"""
|
||||
FLORENCE2_START_DOCSTRING,
|
||||
)
|
||||
class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
||||
_tied_weights_keys = [
|
||||
"language_model.encoder.embed_tokens.weight",
|
||||
"language_model.decoder.embed_tokens.weight",
|
||||
"language_model.lm_head.weight",
|
||||
]
|
||||
_tied_weights_keys = {
|
||||
"language_model.model.encoder.embed_tokens.weight": "language_model.model.shared.weight",
|
||||
"language_model.model.decoder.embed_tokens.weight": "language_model.model.shared.weight",
|
||||
}
|
||||
|
||||
def __init__(self, config: Florence2Config):
|
||||
super().__init__(config)
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, TypeAlias, TypedDict
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -36,10 +36,10 @@ class TransitionKey(str, Enum):
|
||||
COMPLEMENTARY_DATA = "complementary_data"
|
||||
|
||||
|
||||
PolicyAction: TypeAlias = torch.Tensor
|
||||
RobotAction: TypeAlias = dict[str, Any]
|
||||
EnvAction: TypeAlias = np.ndarray
|
||||
RobotObservation: TypeAlias = dict[str, Any]
|
||||
PolicyAction = torch.Tensor
|
||||
RobotAction = dict[str, Any]
|
||||
EnvAction = np.ndarray
|
||||
RobotObservation = dict[str, Any]
|
||||
|
||||
|
||||
EnvTransition = TypedDict(
|
||||
|
||||
@@ -153,6 +153,44 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="robocasa_processor")
|
||||
class RoboCasaProcessorStep(ObservationProcessorStep):
|
||||
"""
|
||||
Processes RoboCasa observations into LeRobot format.
|
||||
|
||||
The RoboCasaEnv wrapper returns:
|
||||
- ``pixels.<cam_name>``: (B, C, H, W) float32 images (already converted by vectorenv)
|
||||
- ``observation.robot_state``: (B, 16) float32 proprioception
|
||||
|
||||
This step remaps them to:
|
||||
- ``observation.images.<cam_name>`` (unchanged tensor)
|
||||
- ``observation.state`` (robot_state renamed)
|
||||
"""
|
||||
|
||||
def _process_observation(self, observation: dict) -> dict:
|
||||
processed = {}
|
||||
obs_prefix = OBS_PREFIX # "observation."
|
||||
|
||||
for key, value in observation.items():
|
||||
if key.startswith(f"{OBS_IMAGES}."):
|
||||
# Already in the right place; pass through
|
||||
processed[key] = value
|
||||
elif key == OBS_STATE or key == f"{obs_prefix}robot_state":
|
||||
# Rename robot_state → observation.state
|
||||
processed[OBS_STATE] = value.float() if hasattr(value, "float") else value
|
||||
|
||||
return processed
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
def observation(self, observation: dict) -> dict:
|
||||
return self._process_observation(observation)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="isaaclab_arena_processor")
|
||||
class IsaaclabArenaProcessorStep(ObservationProcessorStep):
|
||||
|
||||
@@ -39,7 +39,7 @@ from collections.abc import Callable, Iterable, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Generic, TypeAlias, TypedDict, TypeVar, cast
|
||||
from typing import Any, TypedDict, TypeVar, cast
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
@@ -251,7 +251,7 @@ class ProcessorMigrationError(Exception):
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
|
||||
class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
"""A sequential pipeline for processing data, integrated with the Hugging Face Hub.
|
||||
|
||||
This class chains together multiple `ProcessorStep` instances to form a complete
|
||||
@@ -1432,8 +1432,8 @@ class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
|
||||
|
||||
|
||||
# Type aliases for semantic clarity.
|
||||
RobotProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput]
|
||||
PolicyProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput]
|
||||
RobotProcessorPipeline = DataProcessorPipeline[TInput, TOutput]
|
||||
PolicyProcessorPipeline = DataProcessorPipeline[TInput, TOutput]
|
||||
|
||||
|
||||
class ObservationProcessorStep(ProcessorStep, ABC):
|
||||
|
||||
@@ -336,7 +336,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
Requires the `transformers` library to be installed.
|
||||
|
||||
Attributes:
|
||||
tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast").
|
||||
tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "lerobot/fast-action-tokenizer").
|
||||
tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored.
|
||||
trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers).
|
||||
action_tokenizer: The internal tokenizer/processor instance, loaded during initialization.
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TypeAlias
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
|
||||
@@ -50,5 +49,5 @@ class SOFollowerRobotConfig(RobotConfig, SOFollowerConfig):
|
||||
pass
|
||||
|
||||
|
||||
SO100FollowerConfig: TypeAlias = SOFollowerRobotConfig
|
||||
SO101FollowerConfig: TypeAlias = SOFollowerRobotConfig
|
||||
SO100FollowerConfig = SOFollowerRobotConfig
|
||||
SO101FollowerConfig = SOFollowerRobotConfig
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
import logging
|
||||
import time
|
||||
from functools import cached_property
|
||||
from typing import TypeAlias
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
@@ -230,5 +229,5 @@ class SOFollower(Robot):
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
|
||||
SO100Follower: TypeAlias = SOFollower
|
||||
SO101Follower: TypeAlias = SOFollower
|
||||
SO100Follower = SOFollower
|
||||
SO101Follower = SOFollower
|
||||
|
||||
@@ -16,3 +16,5 @@
|
||||
|
||||
from .config_unitree_g1 import UnitreeG1Config
|
||||
from .unitree_g1 import UnitreeG1
|
||||
|
||||
__all__ = ["UnitreeG1", "UnitreeG1Config"]
|
||||
|
||||
@@ -27,11 +27,10 @@ _GAINS: dict[str, dict[str, list[float]]] = {
|
||||
}, # pitch, roll, yaw, knee, ankle_pitch, ankle_roll
|
||||
"right_leg": {"kp": [150, 150, 150, 300, 40, 40], "kd": [2, 2, 2, 4, 2, 2]},
|
||||
"waist": {"kp": [250, 250, 250], "kd": [5, 5, 5]}, # yaw, roll, pitch
|
||||
"left_arm": {"kp": [80, 80, 80, 80], "kd": [3, 3, 3, 3]}, # shoulder_pitch/roll/yaw, elbow
|
||||
"left_arm": {"kp": [50, 50, 80, 80], "kd": [3, 3, 3, 3]}, # shoulder_pitch/roll/yaw, elbow
|
||||
"left_wrist": {"kp": [40, 40, 40], "kd": [1.5, 1.5, 1.5]}, # roll, pitch, yaw
|
||||
"right_arm": {"kp": [80, 80, 80, 80], "kd": [3, 3, 3, 3]},
|
||||
"right_arm": {"kp": [50, 50, 80, 80], "kd": [3, 3, 3, 3]},
|
||||
"right_wrist": {"kp": [40, 40, 40], "kd": [1.5, 1.5, 1.5]},
|
||||
"other": {"kp": [80, 80, 80, 80, 80, 80], "kd": [3, 3, 3, 3, 3, 3]},
|
||||
}
|
||||
|
||||
|
||||
@@ -68,3 +67,7 @@ class UnitreeG1Config(RobotConfig):
|
||||
|
||||
# Compensates for gravity on the unitree's arms using the arm ik solver
|
||||
gravity_compensation: bool = False
|
||||
|
||||
# Lower-body controller class name, e.g. "GrootLocomotionController" or
|
||||
# "HolosomaLocomotionController". None disables it.
|
||||
controller: str | None = None
|
||||
|
||||
+20
-46
@@ -16,13 +16,11 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
parent2_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(parent2_dir)
|
||||
|
||||
|
||||
class WeightedMovingFilter:
|
||||
@@ -31,18 +29,14 @@ class WeightedMovingFilter:
|
||||
self._weights = np.array(weights)
|
||||
self._data_size = data_size
|
||||
self._filtered_data = np.zeros(self._data_size)
|
||||
self._data_queue = []
|
||||
self._data_queue = deque(maxlen=self._window_size)
|
||||
|
||||
def _apply_filter(self):
|
||||
if len(self._data_queue) < self._window_size:
|
||||
return self._data_queue[-1]
|
||||
|
||||
data_array = np.array(self._data_queue)
|
||||
temp_filtered_data = np.zeros(self._data_size)
|
||||
for i in range(self._data_size):
|
||||
temp_filtered_data[i] = np.convolve(data_array[:, i], self._weights, mode="valid")[-1]
|
||||
|
||||
return temp_filtered_data
|
||||
return data_array.T @ self._weights
|
||||
|
||||
def add_data(self, new_data):
|
||||
assert len(new_data) == self._data_size
|
||||
@@ -52,9 +46,6 @@ class WeightedMovingFilter:
|
||||
): # skip duplicate data
|
||||
return
|
||||
|
||||
if len(self._data_queue) >= self._window_size:
|
||||
self._data_queue.pop(0)
|
||||
|
||||
self._data_queue.append(new_data)
|
||||
self._filtered_data = self._apply_filter()
|
||||
|
||||
@@ -71,8 +62,6 @@ class G1_29_ArmIK: # noqa: N801
|
||||
from pinocchio import casadi as cpin
|
||||
|
||||
self._pin = pin
|
||||
np.set_printoptions(precision=5, suppress=True, linewidth=200)
|
||||
|
||||
self.unit_test = unit_test
|
||||
|
||||
self.repo_path = snapshot_download("lerobot/unitree-g1-mujoco")
|
||||
@@ -249,50 +238,35 @@ class G1_29_ArmIK: # noqa: N801
|
||||
self.opti.set_value(self.param_tf_r, right_wrist)
|
||||
self.opti.set_value(self.var_q_last, self.init_data) # for smooth
|
||||
|
||||
converged = True
|
||||
try:
|
||||
self.opti.solve()
|
||||
|
||||
sol_q = self.opti.value(self.var_q)
|
||||
self.smooth_filter.add_data(sol_q)
|
||||
sol_q = self.smooth_filter.filtered_data
|
||||
|
||||
if current_lr_arm_motor_dq is not None:
|
||||
v = current_lr_arm_motor_dq * 0.0
|
||||
else:
|
||||
v = (sol_q - self.init_data) * 0.0
|
||||
|
||||
self.init_data = sol_q
|
||||
|
||||
sol_tauff = self._pin.rnea(
|
||||
self.reduced_robot.model,
|
||||
self.reduced_robot.data,
|
||||
sol_q,
|
||||
v,
|
||||
np.zeros(self.reduced_robot.model.nv),
|
||||
)
|
||||
|
||||
return sol_q, sol_tauff
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ERROR in convergence, plotting debug info.{e}")
|
||||
|
||||
converged = False
|
||||
logger.error(f"IK convergence error: {e}")
|
||||
sol_q = self.opti.debug.value(self.var_q)
|
||||
self.smooth_filter.add_data(sol_q)
|
||||
sol_q = self.smooth_filter.filtered_data
|
||||
|
||||
if current_lr_arm_motor_dq is not None:
|
||||
v = current_lr_arm_motor_dq * 0.0
|
||||
else:
|
||||
v = (sol_q - self.init_data) * 0.0
|
||||
|
||||
self.init_data = sol_q
|
||||
self.smooth_filter.add_data(sol_q)
|
||||
sol_q = self.smooth_filter.filtered_data
|
||||
self.init_data = sol_q
|
||||
|
||||
if not converged:
|
||||
logger.error(
|
||||
f"sol_q:{sol_q} \nmotorstate: \n{current_lr_arm_motor_q} \nleft_pose: \n{left_wrist} \nright_pose: \n{right_wrist}"
|
||||
)
|
||||
|
||||
return current_lr_arm_motor_q, np.zeros(self.reduced_robot.model.nv)
|
||||
|
||||
sol_tauff = self._pin.rnea(
|
||||
self.reduced_robot.model,
|
||||
self.reduced_robot.data,
|
||||
sol_q,
|
||||
np.zeros(self.reduced_robot.model.nv),
|
||||
np.zeros(self.reduced_robot.model.nv),
|
||||
)
|
||||
|
||||
return sol_q, sol_tauff
|
||||
|
||||
def solve_tau(self, current_lr_arm_motor_q=None, current_lr_arm_motor_dq=None):
|
||||
try:
|
||||
q_g1 = np.array(current_lr_arm_motor_q, dtype=float)
|
||||
@@ -14,12 +14,34 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
from enum import IntEnum
|
||||
|
||||
import numpy as np
|
||||
|
||||
# ruff: noqa: N801, N815
|
||||
|
||||
NUM_MOTORS = 29
|
||||
|
||||
REMOTE_AXES = ("remote.lx", "remote.ly", "remote.rx", "remote.ry")
|
||||
REMOTE_BUTTONS = tuple(f"remote.button.{i}" for i in range(16))
|
||||
REMOTE_KEYS = REMOTE_AXES + REMOTE_BUTTONS
|
||||
|
||||
|
||||
def default_remote_input() -> dict[str, float]:
|
||||
"""Return a zeroed-out remote input dict (axes + buttons)."""
|
||||
return dict.fromkeys(REMOTE_KEYS, 0.0)
|
||||
|
||||
|
||||
def get_gravity_orientation(quaternion: list[float] | np.ndarray) -> np.ndarray:
|
||||
"""Get gravity orientation from quaternion [w, x, y, z]."""
|
||||
qw, qx, qy, qz = quaternion
|
||||
gravity_orientation = np.zeros(3, dtype=np.float32)
|
||||
gravity_orientation[0] = 2 * (-qz * qx + qw * qy)
|
||||
gravity_orientation[1] = -2 * (qz * qy + qw * qx)
|
||||
gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz)
|
||||
return gravity_orientation
|
||||
|
||||
|
||||
class G1_29_JointArmIndex(IntEnum):
|
||||
# Left arm
|
||||
@@ -29,7 +51,7 @@ class G1_29_JointArmIndex(IntEnum):
|
||||
kLeftElbow = 18
|
||||
kLeftWristRoll = 19
|
||||
kLeftWristPitch = 20
|
||||
kLeftWristyaw = 21
|
||||
kLeftWristYaw = 21
|
||||
|
||||
# Right arm
|
||||
kRightShoulderPitch = 22
|
||||
@@ -41,6 +63,21 @@ class G1_29_JointArmIndex(IntEnum):
|
||||
kRightWristYaw = 28
|
||||
|
||||
|
||||
def make_locomotion_controller(name: str | None):
|
||||
"""Instantiate a locomotion controller by class name. Returns None if name is None."""
|
||||
if name is None:
|
||||
return None
|
||||
controllers = {
|
||||
"GrootLocomotionController": "lerobot.robots.unitree_g1.gr00t_locomotion",
|
||||
"HolosomaLocomotionController": "lerobot.robots.unitree_g1.holosoma_locomotion",
|
||||
}
|
||||
module_path = controllers.get(name)
|
||||
if module_path is None:
|
||||
raise ValueError(f"Unknown controller: {name!r}. Available: {list(controllers)}")
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, name)()
|
||||
|
||||
|
||||
class G1_29_JointIndex(IntEnum):
|
||||
# Left leg
|
||||
kLeftHipPitch = 0
|
||||
@@ -69,7 +106,7 @@ class G1_29_JointIndex(IntEnum):
|
||||
kLeftElbow = 18
|
||||
kLeftWristRoll = 19
|
||||
kLeftWristPitch = 20
|
||||
kLeftWristyaw = 21
|
||||
kLeftWristYaw = 21
|
||||
|
||||
# Right arm
|
||||
kRightShoulderPitch = 22
|
||||
|
||||
+52
-105
@@ -14,20 +14,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
|
||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||
from lerobot.robots.unitree_g1.g1_utils import (
|
||||
REMOTE_AXES,
|
||||
REMOTE_BUTTONS,
|
||||
G1_29_JointIndex,
|
||||
get_gravity_orientation,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -36,18 +36,13 @@ GROOT_DEFAULT_ANGLES[[0, 6]] = -0.1 # Hip pitch
|
||||
GROOT_DEFAULT_ANGLES[[3, 9]] = 0.3 # Knee
|
||||
GROOT_DEFAULT_ANGLES[[4, 10]] = -0.2 # Ankle pitch
|
||||
|
||||
MISSING_JOINTS = []
|
||||
G1_MODEL = "g1_23" # Or "g1_29"
|
||||
if G1_MODEL == "g1_23":
|
||||
MISSING_JOINTS = [12, 14, 20, 21, 27, 28] # Waist yaw/pitch, wrist pitch/yaw
|
||||
|
||||
# Control parameters
|
||||
ACTION_SCALE = 0.25
|
||||
CONTROL_DT = 0.02 # 50Hz
|
||||
ANG_VEL_SCALE: float = 0.25
|
||||
DOF_POS_SCALE: float = 1.0
|
||||
DOF_VEL_SCALE: float = 0.05
|
||||
CMD_SCALE: list = [2.0, 2.0, 0.25]
|
||||
CMD_SCALE: list[float] = [2.0, 2.0, 0.25]
|
||||
|
||||
|
||||
DEFAULT_GROOT_REPO_ID = "nepyope/GR00T-WholeBodyControl_g1"
|
||||
@@ -85,11 +80,11 @@ def load_groot_policies(
|
||||
class GrootLocomotionController:
|
||||
"""GR00T lower-body locomotion controller for the Unitree G1."""
|
||||
|
||||
def __init__(self, policy_balance, policy_walk, robot, config):
|
||||
self.policy_balance = policy_balance
|
||||
self.policy_walk = policy_walk
|
||||
self.robot = robot
|
||||
self.config = config
|
||||
control_dt = CONTROL_DT # Expose for unitree_g1.py
|
||||
|
||||
def __init__(self):
|
||||
# Load policies
|
||||
self.policy_balance, self.policy_walk = load_groot_policies()
|
||||
|
||||
self.cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32) # vx, vy, theta_dot
|
||||
|
||||
@@ -109,45 +104,60 @@ class GrootLocomotionController:
|
||||
|
||||
logger.info("GrootLocomotionController initialized")
|
||||
|
||||
def run_step(self):
|
||||
# Get current observation
|
||||
obs = self.robot.get_observation()
|
||||
def reset(self) -> None:
|
||||
"""Reset internal state for a new episode."""
|
||||
self.cmd[:] = 0.0
|
||||
self.groot_qj_all[:] = 0.0
|
||||
self.groot_dqj_all[:] = 0.0
|
||||
self.groot_action[:] = 0.0
|
||||
self.groot_obs_single[:] = 0.0
|
||||
self.groot_obs_stacked[:] = 0.0
|
||||
self.groot_height_cmd = 0.74
|
||||
self.groot_orientation_cmd[:] = 0.0
|
||||
self.groot_obs_history.clear()
|
||||
for _ in range(6):
|
||||
self.groot_obs_history.append(np.zeros(86, dtype=np.float32))
|
||||
|
||||
if not obs:
|
||||
return
|
||||
def run_step(self, action: dict, lowstate) -> dict:
|
||||
"""Run one step of the locomotion controller.
|
||||
|
||||
# Get command from remote controller
|
||||
if obs["remote.buttons"][0]: # R1 - raise waist
|
||||
Args:
|
||||
action: Action dict containing remote.lx/ly/rx/ry and buttons
|
||||
lowstate: Robot lowstate containing motor positions/velocities and IMU
|
||||
|
||||
Returns:
|
||||
Action dict for lower body joints (0-14)
|
||||
"""
|
||||
if lowstate is None:
|
||||
return {}
|
||||
|
||||
buttons = [int(action.get(k, 0)) for k in REMOTE_BUTTONS]
|
||||
if buttons[0]: # R1 - raise waist
|
||||
self.groot_height_cmd += 0.001
|
||||
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
||||
if obs["remote.buttons"][4]: # R2 - lower waist
|
||||
if buttons[4]: # R2 - lower waist
|
||||
self.groot_height_cmd -= 0.001
|
||||
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
||||
|
||||
self.cmd[0] = obs["remote.ly"] # Forward/backward
|
||||
self.cmd[1] = obs["remote.lx"] * -1 # Left/right
|
||||
self.cmd[2] = obs["remote.rx"] * -1 # Rotation rate
|
||||
lx, ly, rx, _ry = (action.get(k, 0.0) for k in REMOTE_AXES)
|
||||
self.cmd[0] = ly # Forward/backward
|
||||
self.cmd[1] = -lx # Left/right (negated)
|
||||
self.cmd[2] = -rx # Rotation rate (negated)
|
||||
|
||||
# Get joint positions and velocities from flat dict
|
||||
# Get joint positions and velocities from lowstate
|
||||
for motor in G1_29_JointIndex:
|
||||
name = motor.name
|
||||
idx = motor.value
|
||||
self.groot_qj_all[idx] = obs[f"{name}.q"]
|
||||
self.groot_dqj_all[idx] = obs[f"{name}.dq"]
|
||||
|
||||
# Adapt observation for g1_23dof
|
||||
for idx in MISSING_JOINTS:
|
||||
self.groot_qj_all[idx] = 0.0
|
||||
self.groot_dqj_all[idx] = 0.0
|
||||
self.groot_qj_all[idx] = lowstate.motor_state[idx].q
|
||||
self.groot_dqj_all[idx] = lowstate.motor_state[idx].dq
|
||||
|
||||
# Scale joint positions and velocities
|
||||
qj_obs = self.groot_qj_all.copy()
|
||||
dqj_obs = self.groot_dqj_all.copy()
|
||||
|
||||
# Express IMU data in gravity frame of reference
|
||||
quat = [obs["imu.quat.w"], obs["imu.quat.x"], obs["imu.quat.y"], obs["imu.quat.z"]]
|
||||
ang_vel = np.array([obs["imu.gyro.x"], obs["imu.gyro.y"], obs["imu.gyro.z"]], dtype=np.float32)
|
||||
gravity_orientation = self.robot.get_gravity_orientation(quat)
|
||||
quat = lowstate.imu_state.quaternion
|
||||
ang_vel = np.array(lowstate.imu_state.gyroscope, dtype=np.float32)
|
||||
gravity_orientation = get_gravity_orientation(quat)
|
||||
|
||||
# Scale joint positions and velocities before policy inference
|
||||
qj_obs = (qj_obs - GROOT_DEFAULT_ANGLES) * DOF_POS_SCALE
|
||||
@@ -186,73 +196,10 @@ class GrootLocomotionController:
|
||||
# Transform action back to target joint positions
|
||||
target_dof_pos_15 = GROOT_DEFAULT_ANGLES[:15] + self.groot_action * ACTION_SCALE
|
||||
|
||||
# Build action dict (only first 15 joints for GR00T)
|
||||
# Build action dict
|
||||
action_dict = {}
|
||||
for i in range(15):
|
||||
motor_name = G1_29_JointIndex(i).name
|
||||
action_dict[f"{motor_name}.q"] = float(target_dof_pos_15[i])
|
||||
|
||||
# Zero out missing joints for g1_23dof
|
||||
for joint_idx in MISSING_JOINTS:
|
||||
motor_name = G1_29_JointIndex(joint_idx).name
|
||||
action_dict[f"{motor_name}.q"] = 0.0
|
||||
|
||||
# Send action to robot
|
||||
self.robot.send_action(action_dict)
|
||||
|
||||
|
||||
def run(repo_id: str = DEFAULT_GROOT_REPO_ID) -> None:
|
||||
"""Main function to run the GR00T locomotion controller.
|
||||
|
||||
Args:
|
||||
repo_id: Hugging Face Hub repository ID for GR00T policies.
|
||||
"""
|
||||
# Load policies
|
||||
policy_balance, policy_walk = load_groot_policies(repo_id=repo_id)
|
||||
|
||||
# Initialize robot
|
||||
config = UnitreeG1Config()
|
||||
robot = UnitreeG1(config)
|
||||
|
||||
robot.connect()
|
||||
|
||||
# Initialize gr00T locomotion controller
|
||||
groot_controller = GrootLocomotionController(
|
||||
policy_balance=policy_balance,
|
||||
policy_walk=policy_walk,
|
||||
robot=robot,
|
||||
config=config,
|
||||
)
|
||||
|
||||
try:
|
||||
robot.reset(CONTROL_DT, GROOT_DEFAULT_ANGLES)
|
||||
|
||||
logger.info("Use joystick: LY=fwd/back, LX=left/right, RX=rotate, R1=raise waist, R2=lower waist")
|
||||
logger.info("Press Ctrl+C to stop")
|
||||
|
||||
# Run step
|
||||
while not robot._shutdown_event.is_set():
|
||||
start_time = time.time()
|
||||
groot_controller.run_step()
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, CONTROL_DT - elapsed)
|
||||
time.sleep(sleep_time)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Stopping locomotion...")
|
||||
finally:
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
logger.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="GR00T Locomotion Controller for Unitree G1")
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default=DEFAULT_GROOT_REPO_ID,
|
||||
help=f"Hugging Face Hub repo ID for GR00T policies (default: {DEFAULT_GROOT_REPO_ID})",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
run(repo_id=args.repo_id)
|
||||
return action_dict
|
||||
+62
-112
@@ -14,21 +14,21 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
|
||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||
from lerobot.robots.unitree_g1.g1_utils import (
|
||||
REMOTE_AXES,
|
||||
G1_29_JointArmIndex,
|
||||
G1_29_JointIndex,
|
||||
get_gravity_orientation,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_ANGLES = np.zeros(29, dtype=np.float32)
|
||||
@@ -40,18 +40,13 @@ DEFAULT_ANGLES[16] = 0.2 # Left shoulder roll
|
||||
DEFAULT_ANGLES[23] = -0.2 # Right shoulder roll
|
||||
DEFAULT_ANGLES[[18, 25]] = 0.6 # Elbow
|
||||
|
||||
MISSING_JOINTS = []
|
||||
G1_MODEL = "g1_23" # Or "g1_29"
|
||||
if G1_MODEL == "g1_23":
|
||||
MISSING_JOINTS = [12, 14, 20, 21, 27, 28] # Waist yaw/pitch, wrist pitch/yaw
|
||||
|
||||
# Control parameters
|
||||
ACTION_SCALE = 0.25
|
||||
CONTROL_DT = 0.02 # 50Hz
|
||||
CONTROL_DT = 0.005 # 200Hz
|
||||
ANG_VEL_SCALE = 0.25
|
||||
DOF_POS_SCALE = 1.0
|
||||
DOF_VEL_SCALE = 0.05
|
||||
GAIT_PERIOD = 1.0
|
||||
GAIT_PERIOD = 0.5
|
||||
|
||||
|
||||
DEFAULT_HOLOSOMA_REPO_ID = "nepyope/holosoma_locomotion"
|
||||
@@ -87,7 +82,7 @@ def load_policy(
|
||||
logger.info(f"Policy loaded: {policy.get_inputs()[0].shape} → {policy.get_outputs()[0].shape}")
|
||||
|
||||
# Extract KP/KD from ONNX metadata
|
||||
model = onnx.load(policy_path)
|
||||
model = onnx.load(policy_path, load_external_data=False)
|
||||
metadata = {prop.key: prop.value for prop in model.metadata_props}
|
||||
|
||||
if "kp" not in metadata or "kd" not in metadata:
|
||||
@@ -101,15 +96,13 @@ def load_policy(
|
||||
|
||||
|
||||
class HolosomaLocomotionController:
|
||||
"""Holosoma whole-body locomotion controller for Unitree G1."""
|
||||
"""Holosoma lower-body locomotion controller for Unitree G1."""
|
||||
|
||||
def __init__(self, policy, robot, kp: np.ndarray, kd: np.ndarray):
|
||||
self.policy = policy
|
||||
self.robot = robot
|
||||
control_dt = CONTROL_DT # Expose for unitree_g1.py
|
||||
|
||||
# Override robot's PD gains with policy gains
|
||||
self.robot.kp = kp
|
||||
self.robot.kd = kd
|
||||
def __init__(self):
|
||||
# Load policy and gains
|
||||
self.policy, self.kp, self.kd = load_policy()
|
||||
|
||||
self.cmd = np.zeros(3, dtype=np.float32)
|
||||
|
||||
@@ -124,35 +117,55 @@ class HolosomaLocomotionController:
|
||||
self.phase_dt = 2 * np.pi / ((1.0 / CONTROL_DT) * GAIT_PERIOD)
|
||||
self.is_standing = True
|
||||
|
||||
def run_step(self):
|
||||
# Get current observation
|
||||
obs = self.robot.get_observation()
|
||||
logger.info("HolosomaLocomotionController initialized")
|
||||
|
||||
if not obs:
|
||||
return
|
||||
def reset(self) -> None:
|
||||
"""Reset internal state for a new episode."""
|
||||
self.cmd[:] = 0.0
|
||||
self.qj[:] = 0.0
|
||||
self.dqj[:] = 0.0
|
||||
self.obs[:] = 0.0
|
||||
self.last_action[:] = 0.0
|
||||
self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
|
||||
self.is_standing = True
|
||||
|
||||
# Get command from remote controller
|
||||
ly = obs["remote.ly"] if abs(obs["remote.ly"]) > 0.1 else 0.0
|
||||
lx = obs["remote.lx"] if abs(obs["remote.lx"]) > 0.1 else 0.0
|
||||
rx = obs["remote.rx"] if abs(obs["remote.rx"]) > 0.1 else 0.0
|
||||
def run_step(self, action: dict, lowstate) -> dict:
|
||||
"""Run one step of the locomotion controller.
|
||||
|
||||
Args:
|
||||
action: Action dict containing remote.lx/ly/rx/ry
|
||||
lowstate: Robot lowstate containing motor positions/velocities and IMU
|
||||
|
||||
Returns:
|
||||
Action dict for lower body joints (0-14)
|
||||
"""
|
||||
if lowstate is None:
|
||||
return {}
|
||||
|
||||
lx, ly, rx, _ry = (action.get(k, 0.0) for k in REMOTE_AXES)
|
||||
ly = ly if abs(ly) > 0.1 else 0.0
|
||||
lx = lx if abs(lx) > 0.1 else 0.0
|
||||
rx = rx if abs(rx) > 0.1 else 0.0
|
||||
ly = np.clip(ly, -0.3, 0.3)
|
||||
lx = np.clip(lx, -0.3, 0.3)
|
||||
self.cmd[:] = [ly, -lx, -rx]
|
||||
|
||||
# Get joint positions and velocities
|
||||
# Get joint positions and velocities from lowstate
|
||||
for motor in G1_29_JointIndex:
|
||||
name = motor.name
|
||||
idx = motor.value
|
||||
self.qj[idx] = obs[f"{name}.q"]
|
||||
self.dqj[idx] = obs[f"{name}.dq"]
|
||||
self.qj[idx] = lowstate.motor_state[idx].q
|
||||
self.dqj[idx] = lowstate.motor_state[idx].dq
|
||||
|
||||
# Adapt observation for g1_23dof
|
||||
for idx in MISSING_JOINTS:
|
||||
self.qj[idx] = 0.0
|
||||
self.dqj[idx] = 0.0
|
||||
# Hide arm positions from policy (show DEFAULT_ANGLES instead)
|
||||
# This prevents policy from reacting to teleop arm movements
|
||||
for arm_joint in G1_29_JointArmIndex:
|
||||
self.qj[arm_joint.value] = DEFAULT_ANGLES[arm_joint.value]
|
||||
self.dqj[arm_joint.value] = 0.0
|
||||
|
||||
# Express IMU data in gravity frame of reference
|
||||
quat = [obs["imu.quat.w"], obs["imu.quat.x"], obs["imu.quat.y"], obs["imu.quat.z"]]
|
||||
ang_vel = np.array([obs["imu.gyro.x"], obs["imu.gyro.y"], obs["imu.gyro.z"]], dtype=np.float32)
|
||||
gravity = self.robot.get_gravity_orientation(quat)
|
||||
quat = lowstate.imu_state.quaternion
|
||||
ang_vel = np.array(lowstate.imu_state.gyroscope, dtype=np.float32)
|
||||
gravity = get_gravity_orientation(quat)
|
||||
|
||||
# Scale joint positions and velocities before policy inference
|
||||
qj_obs = (self.qj - DEFAULT_ANGLES) * DOF_POS_SCALE
|
||||
@@ -186,79 +199,16 @@ class HolosomaLocomotionController:
|
||||
# Run policy inference
|
||||
ort_in = {self.policy.get_inputs()[0].name: self.obs.reshape(1, -1).astype(np.float32)}
|
||||
raw_action = self.policy.run(None, ort_in)[0].squeeze()
|
||||
action = np.clip(raw_action, -100.0, 100.0)
|
||||
self.last_action = action.copy()
|
||||
policy_action = np.clip(raw_action, -100.0, 100.0)
|
||||
self.last_action = policy_action.copy()
|
||||
|
||||
# Transform action back to target joint positions
|
||||
target = DEFAULT_ANGLES + action * ACTION_SCALE
|
||||
target = DEFAULT_ANGLES + policy_action * ACTION_SCALE
|
||||
|
||||
# Build action dict
|
||||
# Build action dict (first 15 joints only)
|
||||
action_dict = {}
|
||||
for motor in G1_29_JointIndex:
|
||||
action_dict[f"{motor.name}.q"] = float(target[motor.value])
|
||||
for i in range(15):
|
||||
motor_name = G1_29_JointIndex(i).name
|
||||
action_dict[f"{motor_name}.q"] = float(target[i])
|
||||
|
||||
# Zero out missing joints for g1_23dof
|
||||
for joint_idx in MISSING_JOINTS:
|
||||
motor_name = G1_29_JointIndex(joint_idx).name
|
||||
action_dict[f"{motor_name}.q"] = 0.0
|
||||
|
||||
# Send action to robot
|
||||
self.robot.send_action(action_dict)
|
||||
|
||||
|
||||
def run(repo_id: str = DEFAULT_HOLOSOMA_REPO_ID, policy_type: str = "fastsac") -> None:
|
||||
"""Main function to run the Holosoma locomotion controller.
|
||||
|
||||
Args:
|
||||
repo_id: Hugging Face Hub repository ID for Holosoma policies.
|
||||
policy_type: Policy type to use ('fastsac' or 'ppo').
|
||||
"""
|
||||
# Load policy and gains
|
||||
policy, kp, kd = load_policy(repo_id=repo_id, policy_type=policy_type)
|
||||
|
||||
# Initialize robot
|
||||
config = UnitreeG1Config()
|
||||
robot = UnitreeG1(config)
|
||||
robot.connect()
|
||||
|
||||
holosoma_controller = HolosomaLocomotionController(policy, robot, kp, kd)
|
||||
|
||||
try:
|
||||
robot.reset(CONTROL_DT, DEFAULT_ANGLES)
|
||||
|
||||
logger.info("Use joystick: LY=fwd/back, LX=left/right, RX=rotate")
|
||||
logger.info("Press Ctrl+C to stop")
|
||||
|
||||
# Run step
|
||||
while not robot._shutdown_event.is_set():
|
||||
start_time = time.time()
|
||||
holosoma_controller.run_step()
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, CONTROL_DT - elapsed)
|
||||
time.sleep(sleep_time)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Stopping locomotion...")
|
||||
finally:
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
logger.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Holosoma Locomotion Controller for Unitree G1")
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default=DEFAULT_HOLOSOMA_REPO_ID,
|
||||
help=f"Hugging Face Hub repo ID for Holosoma policies (default: {DEFAULT_HOLOSOMA_REPO_ID})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--policy",
|
||||
type=str,
|
||||
choices=["fastsac", "ppo"],
|
||||
default="fastsac",
|
||||
help="Policy type to use: 'fastsac' (default) or 'ppo'",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
run(repo_id=args.repo_id, policy_type=args.policy)
|
||||
return action_dict
|
||||
@@ -24,6 +24,7 @@ This server runs on the robot and forwards:
|
||||
Uses JSON for secure serialization instead of pickle.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import contextlib
|
||||
import json
|
||||
@@ -38,6 +39,8 @@ from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
|
||||
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as hg_LowCmd, LowState_ as hg_LowState
|
||||
from unitree_sdk2py.utils.crc import CRC
|
||||
|
||||
from lerobot.cameras.zmq.image_server import ImageServer
|
||||
|
||||
# DDS topic names follow Unitree SDK naming conventions
|
||||
# ruff: noqa: N816
|
||||
kTopicLowCommand_Debug = "rt/lowcmd" # action to robot
|
||||
@@ -150,6 +153,32 @@ def cmd_forward_loop(
|
||||
|
||||
def main() -> None:
|
||||
"""Main entry point for the robot server bridge."""
|
||||
parser = argparse.ArgumentParser(description="DDS-to-ZMQ bridge server for Unitree G1")
|
||||
parser.add_argument("--camera", action="store_true", help="Also launch camera server")
|
||||
parser.add_argument("--camera-device", type=int, default=4, help="Camera device ID (default: 4)")
|
||||
parser.add_argument("--camera-fps", type=int, default=30, help="Camera FPS (default: 30)")
|
||||
parser.add_argument("--camera-width", type=int, default=640, help="Camera width (default: 640)")
|
||||
parser.add_argument("--camera-height", type=int, default=480, help="Camera height (default: 480)")
|
||||
parser.add_argument("--camera-port", type=int, default=5555, help="Camera ZMQ port (default: 5555)")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Optionally start camera server in background thread
|
||||
camera_thread = None
|
||||
if args.camera:
|
||||
camera_config = {
|
||||
"fps": args.camera_fps,
|
||||
"cameras": {
|
||||
"head_camera": {
|
||||
"device_id": args.camera_device,
|
||||
"shape": [args.camera_height, args.camera_width],
|
||||
}
|
||||
},
|
||||
}
|
||||
camera_server = ImageServer(camera_config, port=args.camera_port)
|
||||
camera_thread = threading.Thread(target=camera_server.run, daemon=True)
|
||||
camera_thread.start()
|
||||
print(f"Camera server started on port {args.camera_port} (device {args.camera_device})")
|
||||
|
||||
# initialize DDS
|
||||
ChannelFactoryInitialize(0)
|
||||
|
||||
@@ -206,6 +235,8 @@ def main() -> None:
|
||||
shutdown_event.set()
|
||||
ctx.term() # terminates blocking zmq.recv() calls
|
||||
t_state.join(timeout=2.0)
|
||||
if camera_thread is not None:
|
||||
camera_thread.join(timeout=2.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -14,27 +14,67 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import struct
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointArmIndex, G1_29_JointIndex
|
||||
from lerobot.robots.unitree_g1.robot_kinematic_processor import G1_29_ArmIK
|
||||
from lerobot.robots.unitree_g1.g1_kinematics import G1_29_ArmIK
|
||||
from lerobot.robots.unitree_g1.g1_utils import (
|
||||
REMOTE_AXES,
|
||||
REMOTE_KEYS,
|
||||
G1_29_JointArmIndex,
|
||||
G1_29_JointIndex,
|
||||
default_remote_input,
|
||||
make_locomotion_controller,
|
||||
)
|
||||
from lerobot.utils.import_utils import _unitree_sdk_available
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_unitree_g1 import UnitreeG1Config
|
||||
|
||||
if TYPE_CHECKING or _unitree_sdk_available:
|
||||
from unitree_sdk2py.core.channel import (
|
||||
ChannelFactoryInitialize as _SDKChannelFactoryInitialize,
|
||||
ChannelPublisher as _SDKChannelPublisher,
|
||||
ChannelSubscriber as _SDKChannelSubscriber,
|
||||
)
|
||||
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
|
||||
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import (
|
||||
LowCmd_ as hg_LowCmd,
|
||||
LowState_ as hg_LowState,
|
||||
)
|
||||
from unitree_sdk2py.utils.crc import CRC
|
||||
else:
|
||||
_SDKChannelFactoryInitialize = None
|
||||
_SDKChannelPublisher = None
|
||||
_SDKChannelSubscriber = None
|
||||
unitree_hg_msg_dds__LowCmd_ = None
|
||||
hg_LowCmd = None
|
||||
hg_LowState = None
|
||||
CRC = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class LocomotionController(Protocol):
|
||||
control_dt: float
|
||||
|
||||
def run_step(self, action: dict, lowstate) -> dict: ...
|
||||
|
||||
def reset(self) -> None: ...
|
||||
|
||||
|
||||
# DDS topic names follow Unitree SDK naming conventions
|
||||
# ruff: noqa: N816
|
||||
kTopicLowCommand_Debug = "rt/lowcmd"
|
||||
@@ -63,7 +103,7 @@ class IMUState:
|
||||
class G1_29_LowState: # noqa: N801
|
||||
motor_state: list[MotorState] = field(default_factory=lambda: [MotorState() for _ in G1_29_JointIndex])
|
||||
imu_state: IMUState = field(default_factory=IMUState)
|
||||
wireless_remote: Any = None # Raw wireless remote data
|
||||
wireless_remote: bytes | None = None # Raw wireless remote data
|
||||
mode_machine: int = 0 # Robot mode
|
||||
|
||||
|
||||
@@ -71,25 +111,6 @@ class UnitreeG1(Robot):
|
||||
config_class = UnitreeG1Config
|
||||
name = "unitree_g1"
|
||||
|
||||
# unitree remote controller
|
||||
class RemoteController:
|
||||
def __init__(self):
|
||||
self.lx = 0
|
||||
self.ly = 0
|
||||
self.rx = 0
|
||||
self.ry = 0
|
||||
self.button = [0] * 16
|
||||
|
||||
def set(self, data):
|
||||
# wireless_remote
|
||||
keys = struct.unpack("H", data[2:4])[0]
|
||||
for i in range(16):
|
||||
self.button[i] = (keys & (1 << i)) >> i
|
||||
self.lx = struct.unpack("f", data[4:8])[0]
|
||||
self.rx = struct.unpack("f", data[8:12])[0]
|
||||
self.ry = struct.unpack("f", data[12:16])[0]
|
||||
self.ly = struct.unpack("f", data[20:24])[0]
|
||||
|
||||
def __init__(self, config: UnitreeG1Config):
|
||||
super().__init__(config)
|
||||
|
||||
@@ -103,11 +124,9 @@ class UnitreeG1(Robot):
|
||||
|
||||
# Import channel classes based on mode
|
||||
if config.is_simulation:
|
||||
from unitree_sdk2py.core.channel import (
|
||||
ChannelFactoryInitialize,
|
||||
ChannelPublisher,
|
||||
ChannelSubscriber,
|
||||
)
|
||||
self._ChannelFactoryInitialize = _SDKChannelFactoryInitialize
|
||||
self._ChannelPublisher = _SDKChannelPublisher
|
||||
self._ChannelSubscriber = _SDKChannelSubscriber
|
||||
else:
|
||||
from lerobot.robots.unitree_g1.unitree_sdk2_socket import (
|
||||
ChannelFactoryInitialize,
|
||||
@@ -115,22 +134,30 @@ class UnitreeG1(Robot):
|
||||
ChannelSubscriber,
|
||||
)
|
||||
|
||||
# Store for use in connect()
|
||||
self._ChannelFactoryInitialize = ChannelFactoryInitialize
|
||||
self._ChannelPublisher = ChannelPublisher
|
||||
self._ChannelSubscriber = ChannelSubscriber
|
||||
self._ChannelFactoryInitialize = ChannelFactoryInitialize
|
||||
self._ChannelPublisher = ChannelPublisher
|
||||
self._ChannelSubscriber = ChannelSubscriber
|
||||
|
||||
# Initialize state variables
|
||||
self.sim_env = None
|
||||
self._env_wrapper = None
|
||||
self._lowstate = None
|
||||
self._lowstate_lock = threading.Lock()
|
||||
self._shutdown_event = threading.Event()
|
||||
self.subscribe_thread = None
|
||||
self.remote_controller = self.RemoteController()
|
||||
|
||||
self.arm_ik = G1_29_ArmIK()
|
||||
self.arm_ik = G1_29_ArmIK() if config.gravity_compensation else None
|
||||
|
||||
def _subscribe_motor_state(self): # polls robot state @ 250Hz
|
||||
# Lower-body controller loaded dynamically
|
||||
self.controller: LocomotionController | None = make_locomotion_controller(config.controller)
|
||||
|
||||
# Controller thread state
|
||||
self._controller_thread = None
|
||||
self._controller_action_lock = threading.Lock()
|
||||
self.controller_input = default_remote_input()
|
||||
self.controller_output = {}
|
||||
|
||||
def _subscribe_lowstate(self): # polls robot state @ 250Hz
|
||||
while not self._shutdown_event.is_set():
|
||||
start_time = time.time()
|
||||
|
||||
@@ -143,11 +170,11 @@ class UnitreeG1(Robot):
|
||||
lowstate = G1_29_LowState()
|
||||
|
||||
# Capture motor states using jointindex
|
||||
for id in G1_29_JointIndex:
|
||||
lowstate.motor_state[id].q = msg.motor_state[id].q
|
||||
lowstate.motor_state[id].dq = msg.motor_state[id].dq
|
||||
lowstate.motor_state[id].tau_est = msg.motor_state[id].tau_est
|
||||
lowstate.motor_state[id].temperature = msg.motor_state[id].temperature
|
||||
for joint in G1_29_JointIndex:
|
||||
lowstate.motor_state[joint].q = msg.motor_state[joint].q
|
||||
lowstate.motor_state[joint].dq = msg.motor_state[joint].dq
|
||||
lowstate.motor_state[joint].tau_est = msg.motor_state[joint].tau_est
|
||||
lowstate.motor_state[joint].temperature = msg.motor_state[joint].temperature
|
||||
|
||||
# Capture IMU state
|
||||
lowstate.imu_state.quaternion = list(msg.imu_state.quaternion)
|
||||
@@ -162,31 +189,106 @@ class UnitreeG1(Robot):
|
||||
# Capture mode_machine
|
||||
lowstate.mode_machine = msg.mode_machine
|
||||
|
||||
self._lowstate = lowstate
|
||||
with self._lowstate_lock:
|
||||
self._lowstate = lowstate
|
||||
|
||||
current_time = time.time()
|
||||
all_t_elapsed = current_time - start_time
|
||||
sleep_time = max(0, (self.control_dt - all_t_elapsed)) # maintain constant control dt
|
||||
time.sleep(sleep_time)
|
||||
|
||||
def publish_lowcmd(
|
||||
self,
|
||||
action: RobotAction,
|
||||
kp: np.ndarray | list[float] | None = None,
|
||||
kd: np.ndarray | list[float] | None = None,
|
||||
tau: np.ndarray | list[float] | None = None,
|
||||
) -> None: # writes robot command whenever requested
|
||||
for motor in G1_29_JointIndex:
|
||||
key = f"{motor.name}.q"
|
||||
if key in action:
|
||||
self.msg.motor_cmd[motor.value].q = action[key]
|
||||
self.msg.motor_cmd[motor.value].qd = 0
|
||||
self.msg.motor_cmd[motor.value].kp = (
|
||||
kp[motor.value] if kp is not None else self.kp[motor.value]
|
||||
)
|
||||
self.msg.motor_cmd[motor.value].kd = (
|
||||
kd[motor.value] if kd is not None else self.kd[motor.value]
|
||||
)
|
||||
self.msg.motor_cmd[motor.value].tau = tau[motor.value] if tau is not None else 0.0
|
||||
|
||||
self.msg.crc = self.crc.Crc(self.msg)
|
||||
self.lowcmd_publisher.Write(self.msg)
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
return {
|
||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return {f"{G1_29_JointIndex(motor).name}.q": float for motor in G1_29_JointIndex}
|
||||
if self.controller is None:
|
||||
return {f"{G1_29_JointIndex(motor).name}.q": float for motor in G1_29_JointIndex}
|
||||
|
||||
def calibrate(self) -> None: # robot is already calibrated
|
||||
arm_features = {f"{G1_29_JointArmIndex(motor).name}.q": float for motor in G1_29_JointArmIndex}
|
||||
remote_features = dict.fromkeys(REMOTE_AXES, float)
|
||||
return {**arm_features, **remote_features}
|
||||
|
||||
def _controller_loop(self):
|
||||
"""Background thread that runs controller at policy's control_dt."""
|
||||
control_dt = self.controller.control_dt
|
||||
logger.info(f"Controller loop starting with control_dt={control_dt} ({1.0 / control_dt:.1f}Hz)")
|
||||
|
||||
loop_count = 0
|
||||
last_log_time = time.time()
|
||||
|
||||
while not self._shutdown_event.is_set():
|
||||
start_time = time.time()
|
||||
|
||||
with self._lowstate_lock:
|
||||
lowstate = self._lowstate
|
||||
|
||||
if lowstate is not None and self.controller is not None:
|
||||
loop_count += 1
|
||||
if time.time() - last_log_time >= 5.0: # Log every 5 seconds
|
||||
actual_hz = loop_count / (time.time() - last_log_time)
|
||||
logger.info(
|
||||
f"Controller actual rate: {actual_hz:.1f}Hz (target: {1.0 / control_dt:.1f}Hz)"
|
||||
)
|
||||
loop_count = 0
|
||||
last_log_time = time.time()
|
||||
# Read controller input snapshot
|
||||
with self._controller_action_lock:
|
||||
controller_input = dict(self.controller_input)
|
||||
|
||||
# Run controller step
|
||||
controller_action = self.controller.run_step(controller_input, lowstate)
|
||||
|
||||
# Write controller output snapshot
|
||||
with self._controller_action_lock:
|
||||
self.controller_output = dict(controller_action)
|
||||
|
||||
ctrl_kp = self.controller.kp if hasattr(self.controller, "kp") else None
|
||||
ctrl_kd = self.controller.kd if hasattr(self.controller, "kd") else None
|
||||
self.publish_lowcmd(controller_action, kp=ctrl_kp, kd=ctrl_kd)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, control_dt - elapsed)
|
||||
time.sleep(sleep_time)
|
||||
|
||||
def calibrate(self) -> None:
|
||||
# TODO: implement g1_29 calibration
|
||||
pass
|
||||
|
||||
def configure(self) -> None:
|
||||
pass
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None: # connect to DDS
|
||||
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
|
||||
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import (
|
||||
LowCmd_ as hg_LowCmd,
|
||||
LowState_ as hg_LowState,
|
||||
)
|
||||
from unitree_sdk2py.utils.crc import CRC
|
||||
|
||||
# Initialize DDS channel and simulation environment
|
||||
if self.config.is_simulation:
|
||||
self._ChannelFactoryInitialize(0, "lo")
|
||||
@@ -194,7 +296,7 @@ class UnitreeG1(Robot):
|
||||
# Extract the actual gym env from the dict structure
|
||||
self.sim_env = self._env_wrapper["hub_env"][0].envs[0]
|
||||
else:
|
||||
self._ChannelFactoryInitialize(0)
|
||||
self._ChannelFactoryInitialize(0, config=self.config)
|
||||
|
||||
# Initialize direct motor control interface
|
||||
self.lowcmd_publisher = self._ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
|
||||
@@ -203,7 +305,7 @@ class UnitreeG1(Robot):
|
||||
self.lowstate_subscriber.Init()
|
||||
|
||||
# Start subscribe thread to read robot state
|
||||
self.subscribe_thread = threading.Thread(target=self._subscribe_motor_state)
|
||||
self.subscribe_thread = threading.Thread(target=self._subscribe_lowstate)
|
||||
self.subscribe_thread.start()
|
||||
|
||||
# Connect cameras
|
||||
@@ -220,25 +322,53 @@ class UnitreeG1(Robot):
|
||||
|
||||
# Wait for first state message to arrive
|
||||
lowstate = None
|
||||
deadline = time.time() + 10.0
|
||||
while lowstate is None:
|
||||
lowstate = self._lowstate
|
||||
with self._lowstate_lock:
|
||||
lowstate = self._lowstate
|
||||
if lowstate is None:
|
||||
if time.time() > deadline:
|
||||
raise TimeoutError("Timed out waiting for robot state (10s)")
|
||||
logger.warning("[UnitreeG1] Waiting for robot state...")
|
||||
time.sleep(0.01)
|
||||
logger.warning("[UnitreeG1] Waiting for robot state...")
|
||||
logger.warning("[UnitreeG1] Connected to robot.")
|
||||
logger.info("[UnitreeG1] Connected to robot.")
|
||||
self.msg.mode_machine = lowstate.mode_machine
|
||||
|
||||
# Initialize all motors with unified kp/kd from config
|
||||
self.kp = np.array(self.config.kp, dtype=np.float32)
|
||||
self.kd = np.array(self.config.kd, dtype=np.float32)
|
||||
|
||||
for id in G1_29_JointIndex:
|
||||
self.msg.motor_cmd[id].mode = 1
|
||||
self.msg.motor_cmd[id].kp = self.kp[id.value]
|
||||
self.msg.motor_cmd[id].kd = self.kd[id.value]
|
||||
self.msg.motor_cmd[id].q = lowstate.motor_state[id.value].q
|
||||
for joint in G1_29_JointIndex:
|
||||
self.msg.motor_cmd[joint].mode = 1
|
||||
self.msg.motor_cmd[joint].kp = self.kp[joint.value]
|
||||
self.msg.motor_cmd[joint].kd = self.kd[joint.value]
|
||||
self.msg.motor_cmd[joint].q = lowstate.motor_state[joint.value].q
|
||||
|
||||
# Start controller thread if enabled
|
||||
if self.controller is not None:
|
||||
self._controller_thread = threading.Thread(target=self._controller_loop, daemon=True)
|
||||
self._controller_thread.start()
|
||||
fps = int(1.0 / self.controller.control_dt)
|
||||
logger.info(f"Controller thread started ({fps}Hz)")
|
||||
|
||||
def _send_zero_torque(self) -> None:
|
||||
"""Send a zero-gain command to make joints passive before shutting down."""
|
||||
try:
|
||||
with self._lowstate_lock:
|
||||
lowstate = self._lowstate
|
||||
if lowstate is None:
|
||||
return
|
||||
action = {f"{motor.name}.q": lowstate.motor_state[motor.value].q for motor in G1_29_JointIndex}
|
||||
zero_gains = np.zeros(29, dtype=np.float32)
|
||||
self.publish_lowcmd(action, kp=zero_gains, kd=zero_gains, tau=zero_gains)
|
||||
logger.info("Sent zero-torque command for safe shutdown")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send zero-torque on disconnect: {e}")
|
||||
|
||||
def disconnect(self):
|
||||
# Put robot in passive mode before stopping threads
|
||||
if not self.config.is_simulation:
|
||||
self._send_zero_torque()
|
||||
|
||||
# Signal thread to stop and unblock any waits
|
||||
self._shutdown_event.set()
|
||||
|
||||
@@ -248,6 +378,12 @@ class UnitreeG1(Robot):
|
||||
if self.subscribe_thread.is_alive():
|
||||
logger.warning("Subscribe thread did not stop cleanly")
|
||||
|
||||
# Wait for controller thread to finish
|
||||
if self._controller_thread is not None:
|
||||
self._controller_thread.join(timeout=2.0)
|
||||
if self._controller_thread.is_alive():
|
||||
logger.warning("Controller thread did not stop cleanly")
|
||||
|
||||
# Close simulation environment
|
||||
if self.config.is_simulation and self.sim_env is not None:
|
||||
try:
|
||||
@@ -274,7 +410,8 @@ class UnitreeG1(Robot):
|
||||
cam.disconnect()
|
||||
|
||||
def get_observation(self) -> RobotObservation:
|
||||
lowstate = self._lowstate
|
||||
with self._lowstate_lock:
|
||||
lowstate = self._lowstate
|
||||
if lowstate is None:
|
||||
return {}
|
||||
|
||||
@@ -313,14 +450,9 @@ class UnitreeG1(Robot):
|
||||
obs["imu.rpy.pitch"] = lowstate.imu_state.rpy[1]
|
||||
obs["imu.rpy.yaw"] = lowstate.imu_state.rpy[2]
|
||||
|
||||
# Controller - parse wireless_remote and add to obs
|
||||
if lowstate.wireless_remote and len(lowstate.wireless_remote) >= 24:
|
||||
self.remote_controller.set(lowstate.wireless_remote)
|
||||
obs["remote.buttons"] = self.remote_controller.button.copy()
|
||||
obs["remote.lx"] = self.remote_controller.lx
|
||||
obs["remote.ly"] = self.remote_controller.ly
|
||||
obs["remote.rx"] = self.remote_controller.rx
|
||||
obs["remote.ry"] = self.remote_controller.ry
|
||||
# Wireless remote (raw bytes for teleoperator)
|
||||
if lowstate.wireless_remote:
|
||||
obs["wireless_remote"] = lowstate.wireless_remote
|
||||
|
||||
# Cameras - read images from ZMQ cameras
|
||||
for cam_name, cam in self._cameras.items():
|
||||
@@ -328,73 +460,63 @@ class UnitreeG1(Robot):
|
||||
|
||||
return obs
|
||||
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
action_to_publish = action
|
||||
if self.controller is not None:
|
||||
# Controller thread owns legs/waist. Here we only update joystick inputs
|
||||
# and publish arm targets from the teleoperator.
|
||||
self._update_controller_action(action)
|
||||
arm_prefixes = tuple(j.name for j in G1_29_JointArmIndex)
|
||||
action_to_publish = {
|
||||
key: value
|
||||
for key, value in action.items()
|
||||
if key.endswith(".q") and key.startswith(arm_prefixes)
|
||||
}
|
||||
|
||||
tau = None
|
||||
if self.config.gravity_compensation and self.arm_ik is not None:
|
||||
tau = np.zeros(29, dtype=np.float32)
|
||||
action_np = np.array(
|
||||
[
|
||||
action_to_publish.get(f"{joint.name}.q", self.msg.motor_cmd[joint.value].q)
|
||||
for joint in G1_29_JointArmIndex
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
arm_tau = self.arm_ik.solve_tau(action_np)
|
||||
arm_start_idx = G1_29_JointArmIndex.kLeftShoulderPitch.value
|
||||
for joint in G1_29_JointArmIndex:
|
||||
local_idx = joint.value - arm_start_idx
|
||||
tau[joint.value] = arm_tau[local_idx]
|
||||
|
||||
self.publish_lowcmd(action_to_publish, tau=tau)
|
||||
return action
|
||||
|
||||
def _update_controller_action(self, action: RobotAction) -> None:
|
||||
"""Update controller input state from incoming teleop action."""
|
||||
with self._controller_action_lock:
|
||||
for key in REMOTE_KEYS:
|
||||
if key in action:
|
||||
self.controller_input[key] = action[key]
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._lowstate is not None
|
||||
with self._lowstate_lock:
|
||||
return self._lowstate is not None
|
||||
|
||||
@property
|
||||
def _motors_ft(self) -> dict[str, type]:
|
||||
"""Joint positions for all 29 joints."""
|
||||
return {f"{G1_29_JointIndex(motor).name}.q": float for motor in G1_29_JointIndex}
|
||||
|
||||
@property
|
||||
def cameras(self) -> dict:
|
||||
return self._cameras
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
return {
|
||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
for motor in G1_29_JointIndex:
|
||||
key = f"{motor.name}.q"
|
||||
if key in action:
|
||||
self.msg.motor_cmd[motor.value].q = action[key]
|
||||
self.msg.motor_cmd[motor.value].qd = 0
|
||||
self.msg.motor_cmd[motor.value].kp = self.kp[motor.value]
|
||||
self.msg.motor_cmd[motor.value].kd = self.kd[motor.value]
|
||||
self.msg.motor_cmd[motor.value].tau = 0
|
||||
|
||||
if self.config.gravity_compensation:
|
||||
# Build action_np from motor commands (arm joints are indices 15-28, local indices 0-13)
|
||||
action_np = np.zeros(14)
|
||||
arm_start_idx = G1_29_JointArmIndex.kLeftShoulderPitch.value # 15
|
||||
for joint in G1_29_JointArmIndex:
|
||||
local_idx = joint.value - arm_start_idx
|
||||
action_np[local_idx] = self.msg.motor_cmd[joint.value].q
|
||||
tau = self.arm_ik.solve_tau(action_np)
|
||||
|
||||
# Apply tau back to motor commands
|
||||
for joint in G1_29_JointArmIndex:
|
||||
local_idx = joint.value - arm_start_idx
|
||||
self.msg.motor_cmd[joint.value].tau = tau[local_idx]
|
||||
|
||||
self.msg.crc = self.crc.Crc(self.msg)
|
||||
self.lowcmd_publisher.Write(self.msg)
|
||||
return action
|
||||
|
||||
def get_gravity_orientation(self, quaternion): # get gravity orientation from quaternion
|
||||
"""Get gravity orientation from quaternion."""
|
||||
qw = quaternion[0]
|
||||
qx = quaternion[1]
|
||||
qy = quaternion[2]
|
||||
qz = quaternion[3]
|
||||
|
||||
gravity_orientation = np.zeros(3)
|
||||
gravity_orientation[0] = 2 * (-qz * qx + qw * qy)
|
||||
gravity_orientation[1] = -2 * (qz * qy + qw * qx)
|
||||
gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz)
|
||||
return gravity_orientation
|
||||
|
||||
def reset(
|
||||
self,
|
||||
control_dt: float | None = None,
|
||||
@@ -407,15 +529,9 @@ class UnitreeG1(Robot):
|
||||
|
||||
if self.config.is_simulation and self.sim_env is not None:
|
||||
self.sim_env.reset()
|
||||
|
||||
for motor in G1_29_JointIndex:
|
||||
self.msg.motor_cmd[motor.value].q = default_positions[motor.value]
|
||||
self.msg.motor_cmd[motor.value].qd = 0
|
||||
self.msg.motor_cmd[motor.value].kp = self.kp[motor.value]
|
||||
self.msg.motor_cmd[motor.value].kd = self.kd[motor.value]
|
||||
self.msg.motor_cmd[motor.value].tau = 0
|
||||
self.msg.crc = self.crc.Crc(self.msg)
|
||||
self.lowcmd_publisher.Write(self.msg)
|
||||
self.publish_lowcmd(
|
||||
{f"{motor.name}.q": float(default_positions[motor.value]) for motor in G1_29_JointIndex}
|
||||
)
|
||||
else:
|
||||
total_time = 3.0
|
||||
num_steps = int(total_time / control_dt)
|
||||
@@ -446,4 +562,8 @@ class UnitreeG1(Robot):
|
||||
sleep_time = max(0, control_dt - elapsed)
|
||||
time.sleep(sleep_time)
|
||||
|
||||
# Reset controller internal state (gait phase, obs history, etc.)
|
||||
if self.controller is not None and hasattr(self.controller, "reset"):
|
||||
self.controller.reset()
|
||||
|
||||
logger.info("Reached default position")
|
||||
|
||||
@@ -22,6 +22,8 @@ import zmq
|
||||
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
|
||||
# Module-level ZMQ state mirrors the Unitree SDK's global ChannelFactory Singleton.
|
||||
# Only one robot connection per process is supported.
|
||||
_ctx: zmq.Context | None = None
|
||||
_lowcmd_sock: zmq.Socket | None = None
|
||||
_lowstate_sock: zmq.Socket | None = None
|
||||
@@ -97,17 +99,22 @@ def lowcmd_to_dict(topic: str, msg: Any) -> dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def ChannelFactoryInitialize(*args: Any, **kwargs: Any) -> None: # noqa: N802
|
||||
def ChannelFactoryInitialize(domain_id: int = 0, config: Any = None) -> None: # noqa: N802
|
||||
"""
|
||||
Initialize ZMQ sockets for robot communication.
|
||||
|
||||
This function mimics the Unitree SDK's ChannelFactoryInitialize but uses
|
||||
ZMQ sockets to connect to the robot server bridge instead of DDS.
|
||||
|
||||
Args:
|
||||
domain_id: Ignored (for API compatibility with Unitree SDK)
|
||||
config: UnitreeG1Config instance with robot_ip
|
||||
"""
|
||||
global _ctx, _lowcmd_sock, _lowstate_sock
|
||||
|
||||
# read socket config
|
||||
config = UnitreeG1Config()
|
||||
if config is None:
|
||||
config = UnitreeG1Config()
|
||||
robot_ip = config.robot_ip
|
||||
|
||||
ctx = zmq.Context.instance()
|
||||
|
||||
@@ -56,6 +56,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
openarm_mini,
|
||||
so_leader,
|
||||
unitree_g1,
|
||||
)
|
||||
|
||||
@@ -132,10 +132,13 @@ def visualize_dataset(
|
||||
|
||||
logging.info("Logging to Rerun")
|
||||
|
||||
first_index = None
|
||||
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
|
||||
if first_index is None:
|
||||
first_index = batch["index"][0].item()
|
||||
# iterate over the batch
|
||||
for i in range(len(batch["index"])):
|
||||
rr.set_time("frame_index", sequence=batch["frame_index"][i].item())
|
||||
rr.set_time("frame_index", sequence=batch["index"][i].item() - first_index)
|
||||
rr.set_time("timestamp", timestamp=batch["timestamp"][i].item())
|
||||
|
||||
# display each camera image
|
||||
|
||||
@@ -21,6 +21,9 @@ This script allows you to delete episodes, split datasets, merge datasets,
|
||||
remove features, modify tasks, and convert image datasets to video format.
|
||||
When new_repo_id is specified, creates a new dataset.
|
||||
|
||||
Path semantics (v2): --root and --new_root are exact dataset folders containing
|
||||
meta/, data/, videos/. When omitted, defaults to $HF_LEROBOT_HOME/{repo_id}.
|
||||
|
||||
Usage Examples:
|
||||
|
||||
Delete episodes 0, 2, and 5 from a dataset:
|
||||
@@ -29,16 +32,31 @@ Delete episodes 0, 2, and 5 from a dataset:
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]"
|
||||
|
||||
Delete episodes and save to a new dataset:
|
||||
Delete episodes from a local dataset at a specific path:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--new_repo_id lerobot/pusht_filtered \
|
||||
--root /path/to/pusht \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]"
|
||||
|
||||
Split dataset by fractions:
|
||||
Delete episodes and save to a new dataset at a specific path and with a new repo_id:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--new_repo_id lerobot/pusht_filtered \
|
||||
--new_root /path/to/pusht_filtered \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]"
|
||||
|
||||
Split dataset by fractions (pusht_train, pusht_val):
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type split \
|
||||
--operation.splits '{"train": 0.8, "val": 0.2}'
|
||||
|
||||
Split dataset by fractions and save split datasets to a specific folder (base_folder/train, base_folder/val):
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--new_root /path/to/base_folder \
|
||||
--operation.type split \
|
||||
--operation.splits '{"train": 0.8, "val": 0.2}'
|
||||
|
||||
@@ -56,15 +74,29 @@ Split into more than two splits:
|
||||
|
||||
Merge multiple datasets:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_merged \
|
||||
--new_repo_id lerobot/pusht_merged \
|
||||
--operation.type merge \
|
||||
--operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']"
|
||||
|
||||
Merge multiple datasets to a specific output path:
|
||||
lerobot-edit-dataset \
|
||||
--new_repo_id lerobot/pusht_merged \
|
||||
--new_root /path/to/pusht_merged \
|
||||
--operation.type merge \
|
||||
--operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']"
|
||||
|
||||
Merge multiple datasets from a list of local dataset paths:
|
||||
lerobot-edit-dataset \
|
||||
--new_repo_id lerobot/pusht_merged \
|
||||
--operation.type merge \
|
||||
--operation.repo_ids "['pusht_train', 'pusht_val']" \
|
||||
--operation.roots "['/path/to/pusht_train', '/path/to/pusht_val']"
|
||||
|
||||
Remove camera feature:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type remove_feature \
|
||||
--operation.feature_names "['observation.images.top']"
|
||||
--operation.feature_names "['observation.image']"
|
||||
|
||||
Modify tasks - set a single task for all episodes (WARNING: modifies in-place):
|
||||
lerobot-edit-dataset \
|
||||
@@ -88,8 +120,8 @@ Modify tasks - set default task with overrides for specific episodes (WARNING: m
|
||||
Convert image dataset to video format and save locally:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir /path/to/output/pusht_video
|
||||
--new_root /path/to/output/pusht_video \
|
||||
--operation.type convert_image_to_video
|
||||
|
||||
Convert image dataset to video format and save with new repo_id:
|
||||
lerobot-edit-dataset \
|
||||
@@ -167,6 +199,7 @@ class SplitConfig(OperationConfig):
|
||||
@dataclass
|
||||
class MergeConfig(OperationConfig):
|
||||
repo_ids: list[str] | None = None
|
||||
roots: list[str] | None = None
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("remove_feature")
|
||||
@@ -200,36 +233,46 @@ class ConvertImageToVideoConfig(OperationConfig):
|
||||
@OperationConfig.register_subclass("info")
|
||||
@dataclass
|
||||
class InfoConfig(OperationConfig):
|
||||
type: str = "info"
|
||||
show_features: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class EditDatasetConfig:
|
||||
repo_id: str
|
||||
# Operation configuration.
|
||||
operation: OperationConfig
|
||||
# Input dataset identifier. Always required unless for Merge operation.
|
||||
repo_id: str | None = None
|
||||
# Root directory where the input dataset is stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
root: str | None = None
|
||||
# Edited dataset identifier. When both new_repo_id (resp. new_root) and repo_id (resp. root) are identical, modifications are applied in-place and a backup of the original dataset is created. Required for Merge operation.
|
||||
new_repo_id: str | None = None
|
||||
# Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/new_repo_id. For Split operation, this is the base directory for the split datasets.
|
||||
new_root: str | None = None
|
||||
# Upload dataset to Hugging Face hub.
|
||||
push_to_hub: bool = False
|
||||
|
||||
|
||||
def get_output_path(repo_id: str, new_repo_id: str | None, root: Path | None) -> tuple[str, Path]:
|
||||
if new_repo_id:
|
||||
output_repo_id = new_repo_id
|
||||
output_dir = root / new_repo_id if root else HF_LEROBOT_HOME / new_repo_id
|
||||
else:
|
||||
output_repo_id = repo_id
|
||||
dataset_path = root / repo_id if root else HF_LEROBOT_HOME / repo_id
|
||||
old_path = Path(str(dataset_path) + "_old")
|
||||
def get_output_path(
|
||||
repo_id: str,
|
||||
new_repo_id: str | None,
|
||||
root: Path | str | None,
|
||||
new_root: Path | str | None,
|
||||
) -> tuple[str, Path]:
|
||||
input_path = Path(root) if root else HF_LEROBOT_HOME / repo_id
|
||||
|
||||
if dataset_path.exists():
|
||||
if old_path.exists():
|
||||
shutil.rmtree(old_path)
|
||||
shutil.move(str(dataset_path), str(old_path))
|
||||
output_repo_id = new_repo_id if new_repo_id else repo_id
|
||||
output_path = Path(new_root) if new_root else HF_LEROBOT_HOME / output_repo_id
|
||||
|
||||
output_dir = dataset_path
|
||||
# In case of in-place modification, create a backup of the original dataset (if it exists)
|
||||
if output_path == input_path:
|
||||
backup_path = input_path.with_name(input_path.name + "_old")
|
||||
|
||||
return output_repo_id, output_dir
|
||||
if input_path.exists():
|
||||
if backup_path.exists():
|
||||
shutil.rmtree(backup_path)
|
||||
shutil.move(input_path, backup_path)
|
||||
|
||||
return output_repo_id, output_path
|
||||
|
||||
|
||||
def handle_delete_episodes(cfg: EditDatasetConfig) -> None:
|
||||
@@ -241,11 +284,15 @@ def handle_delete_episodes(cfg: EditDatasetConfig) -> None:
|
||||
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
output_repo_id, output_dir = get_output_path(
|
||||
cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
|
||||
cfg.repo_id,
|
||||
new_repo_id=cfg.new_repo_id,
|
||||
root=cfg.root,
|
||||
new_root=cfg.new_root,
|
||||
)
|
||||
|
||||
if cfg.new_repo_id is None:
|
||||
dataset.root = Path(str(dataset.root) + "_old")
|
||||
# In case of in-place modification, make the dataset point to the backup directory
|
||||
if output_dir == dataset.root:
|
||||
dataset.root = dataset.root.with_name(dataset.root.name + "_old")
|
||||
|
||||
logging.info(f"Deleting episodes {cfg.operation.episode_indices} from {cfg.repo_id}")
|
||||
new_dataset = delete_episodes(
|
||||
@@ -272,19 +319,27 @@ def handle_split(cfg: EditDatasetConfig) -> None:
|
||||
"splits dict must be specified with split names as keys and fractions/episode lists as values"
|
||||
)
|
||||
|
||||
if cfg.new_repo_id is not None:
|
||||
logging.warning(
|
||||
"split uses the original dataset identifier --repo_id to generate split names. The --new_repo_id parameter is ignored."
|
||||
)
|
||||
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
|
||||
logging.info(f"Splitting dataset {cfg.repo_id} with splits: {cfg.operation.splits}")
|
||||
split_datasets = split_dataset(dataset, splits=cfg.operation.splits)
|
||||
split_datasets = split_dataset(
|
||||
dataset,
|
||||
splits=cfg.operation.splits,
|
||||
output_dir=cfg.new_root,
|
||||
)
|
||||
|
||||
for split_name, split_ds in split_datasets.items():
|
||||
split_repo_id = f"{cfg.repo_id}_{split_name}"
|
||||
logging.info(
|
||||
f"{split_name}: {split_ds.meta.total_episodes} episodes, {split_ds.meta.total_frames} frames"
|
||||
)
|
||||
|
||||
if cfg.push_to_hub:
|
||||
logging.info(f"Pushing {split_name} split to hub as {split_repo_id}")
|
||||
logging.info(f"Pushing {split_name} split to hub as {split_ds.repo_id}")
|
||||
LeRobotDataset(split_ds.repo_id, root=split_ds.root).push_to_hub()
|
||||
|
||||
|
||||
@@ -295,18 +350,29 @@ def handle_merge(cfg: EditDatasetConfig) -> None:
|
||||
if not cfg.operation.repo_ids:
|
||||
raise ValueError("repo_ids must be specified for merge operation")
|
||||
|
||||
if not cfg.repo_id:
|
||||
raise ValueError("repo_id must be specified as the output repository for merged dataset")
|
||||
if cfg.repo_id is not None or cfg.root is not None:
|
||||
logging.warning(
|
||||
"merge uses --new_repo_id and --new_root for the merged dataset. The --repo_id and --root parameters are ignored."
|
||||
)
|
||||
|
||||
logging.info(f"Loading {len(cfg.operation.repo_ids)} datasets to merge")
|
||||
datasets = [LeRobotDataset(repo_id, root=cfg.root) for repo_id in cfg.operation.repo_ids]
|
||||
if cfg.operation.roots:
|
||||
if len(cfg.operation.roots) != len(cfg.operation.repo_ids):
|
||||
raise ValueError("repo_ids and roots must have the same length for merge operation")
|
||||
logging.info(f"Loading {len(cfg.operation.roots)} datasets to merge")
|
||||
datasets = [
|
||||
LeRobotDataset(repo_id=repo_id, root=root)
|
||||
for repo_id, root in zip(cfg.operation.repo_ids, cfg.operation.roots, strict=True)
|
||||
]
|
||||
else:
|
||||
logging.info(f"Loading {len(cfg.operation.repo_ids)} datasets to merge")
|
||||
datasets = [LeRobotDataset(repo_id) for repo_id in cfg.operation.repo_ids]
|
||||
|
||||
output_dir = Path(cfg.root) / cfg.repo_id if cfg.root else HF_LEROBOT_HOME / cfg.repo_id
|
||||
output_dir = Path(cfg.new_root) if cfg.new_root else HF_LEROBOT_HOME / cfg.new_repo_id
|
||||
|
||||
logging.info(f"Merging datasets into {cfg.repo_id}")
|
||||
logging.info(f"Merging datasets into {cfg.new_repo_id}")
|
||||
merged_dataset = merge_datasets(
|
||||
datasets,
|
||||
output_repo_id=cfg.repo_id,
|
||||
output_repo_id=cfg.new_repo_id,
|
||||
output_dir=output_dir,
|
||||
)
|
||||
|
||||
@@ -316,7 +382,7 @@ def handle_merge(cfg: EditDatasetConfig) -> None:
|
||||
)
|
||||
|
||||
if cfg.push_to_hub:
|
||||
logging.info(f"Pushing to hub as {cfg.repo_id}")
|
||||
logging.info(f"Pushing to hub as {cfg.new_repo_id}")
|
||||
LeRobotDataset(merged_dataset.repo_id, root=output_dir).push_to_hub()
|
||||
|
||||
|
||||
@@ -329,11 +395,15 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
|
||||
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
output_repo_id, output_dir = get_output_path(
|
||||
cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
|
||||
cfg.repo_id,
|
||||
new_repo_id=cfg.new_repo_id,
|
||||
root=cfg.root,
|
||||
new_root=cfg.new_root,
|
||||
)
|
||||
|
||||
if cfg.new_repo_id is None:
|
||||
dataset.root = Path(str(dataset.root) + "_old")
|
||||
# In case of in-place modification, make the dataset point to the backup directory
|
||||
if output_dir == dataset.root:
|
||||
dataset.root = dataset.root.with_name(dataset.root.name + "_old")
|
||||
|
||||
logging.info(f"Removing features {cfg.operation.feature_names} from {cfg.repo_id}")
|
||||
new_dataset = remove_feature(
|
||||
@@ -361,9 +431,10 @@ def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
|
||||
if new_task is None and episode_tasks_raw is None:
|
||||
raise ValueError("Must specify at least one of new_task or episode_tasks for modify_tasks operation")
|
||||
|
||||
# Warn about in-place modification behavior
|
||||
if cfg.new_repo_id is not None:
|
||||
logging.warning("modify_tasks modifies datasets in-place. The --new_repo_id parameter is ignored.")
|
||||
if cfg.new_repo_id is not None or cfg.new_root is not None:
|
||||
logging.warning(
|
||||
"modify_tasks modifies datasets in-place. The --new_repo_id and --new_root parameters are ignored."
|
||||
)
|
||||
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
logging.warning(f"Modifying dataset in-place at {dataset.root}. Original data will be overwritten.")
|
||||
@@ -399,32 +470,30 @@ def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
|
||||
# Determine output directory and repo_id
|
||||
# Priority: 1) new_repo_id, 2) operation.output_dir, 3) auto-generated name
|
||||
# Priority: 1) new_root, 2) new_repo_id, 3) operation.output_dir, 4) auto-generated name
|
||||
output_dir_config = getattr(cfg.operation, "output_dir", None)
|
||||
if output_dir_config:
|
||||
logging.warning(
|
||||
"--operation.output_dir is deprecated and will be removed in future versions. "
|
||||
"Please use --new_root instead."
|
||||
)
|
||||
|
||||
if cfg.new_repo_id:
|
||||
# Use new_repo_id for both local storage and hub push
|
||||
if cfg.new_root:
|
||||
output_dir = Path(cfg.new_root)
|
||||
output_repo_id = cfg.new_repo_id or f"{cfg.repo_id}_video"
|
||||
logging.info(f"Saving to new_root: {output_dir} as {output_repo_id}")
|
||||
elif cfg.new_repo_id:
|
||||
output_repo_id = cfg.new_repo_id
|
||||
# Place new dataset as a sibling to the original dataset
|
||||
# Get the parent of the actual dataset root (not cfg.root which might be the lerobot cache dir)
|
||||
# Extract just the dataset name (after last slash) for the local directory
|
||||
local_dir_name = cfg.new_repo_id.split("/")[-1]
|
||||
output_dir = dataset.root.parent / local_dir_name
|
||||
output_dir = HF_LEROBOT_HOME / cfg.new_repo_id
|
||||
logging.info(f"Saving to new dataset: {cfg.new_repo_id} at {output_dir}")
|
||||
elif output_dir_config:
|
||||
# Use custom output directory for local-only storage
|
||||
output_dir = Path(output_dir_config)
|
||||
# Extract repo name from output_dir for the dataset
|
||||
output_repo_id = output_dir.name
|
||||
logging.info(f"Saving to local directory: {output_dir}")
|
||||
logging.info(f"Saving to local directory: {output_dir} as {output_repo_id}")
|
||||
else:
|
||||
# Auto-generate name: append "_video" to original repo_id
|
||||
output_repo_id = f"{cfg.repo_id}_video"
|
||||
# Place new dataset as a sibling to the original dataset
|
||||
# Extract just the dataset name (after last slash) for the local directory
|
||||
local_dir_name = output_repo_id.split("/")[-1]
|
||||
output_dir = dataset.root.parent / local_dir_name
|
||||
logging.info(f"Saving to auto-generated location: {output_dir}")
|
||||
output_dir = HF_LEROBOT_HOME / output_repo_id
|
||||
logging.info(f"Saving to auto-generated location: {output_dir} as {output_repo_id}")
|
||||
|
||||
logging.info(f"Converting dataset {cfg.repo_id} to video format")
|
||||
|
||||
@@ -499,8 +568,20 @@ def handle_info(cfg: EditDatasetConfig):
|
||||
sys.stdout.write(f"{feature_dump_str}\n")
|
||||
|
||||
|
||||
def _validate_config(cfg: EditDatasetConfig) -> None:
|
||||
if isinstance(cfg.operation, MergeConfig):
|
||||
if not cfg.new_repo_id:
|
||||
raise ValueError("--new_repo_id is required for merge operation (the merged dataset identifier)")
|
||||
else:
|
||||
if not cfg.repo_id:
|
||||
raise ValueError(
|
||||
f"--repo_id is required for {cfg.operation.type} operation (the input dataset identifier)"
|
||||
)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
_validate_config(cfg)
|
||||
operation_type = cfg.operation.type
|
||||
|
||||
if operation_type == "delete_episodes":
|
||||
|
||||
@@ -61,6 +61,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
openarm_mini,
|
||||
so_leader,
|
||||
)
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
@@ -74,8 +74,6 @@ from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.cameras import ( # noqa: F401
|
||||
CameraConfig, # noqa: F401
|
||||
)
|
||||
@@ -92,7 +90,6 @@ from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc import ActionInterpolator
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.processor import (
|
||||
PolicyAction,
|
||||
@@ -128,6 +125,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
openarm_mini,
|
||||
reachy2_teleoperator,
|
||||
so_leader,
|
||||
unitree_g1,
|
||||
@@ -157,7 +155,7 @@ class DatasetRecordConfig:
|
||||
repo_id: str
|
||||
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
|
||||
single_task: str
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
root: str | Path | None = None
|
||||
# Limit the frames per second.
|
||||
fps: int = 30
|
||||
@@ -228,9 +226,6 @@ class RecordConfig:
|
||||
play_sounds: bool = True
|
||||
# Resume recording on an existing dataset.
|
||||
resume: bool = False
|
||||
# Action interpolation multiplier for smoother policy control (1=off, 2=2x, 3=3x)
|
||||
# Only applies when using a policy (not teleop)
|
||||
interpolation_multiplier: int = 1
|
||||
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
@@ -303,7 +298,6 @@ def record_loop(
|
||||
control_time_s: int | None = None,
|
||||
single_task: str | None = None,
|
||||
display_data: bool = False,
|
||||
interpolator: ActionInterpolator | None = None,
|
||||
display_compressed_images: bool = False,
|
||||
):
|
||||
if dataset is not None and dataset.fps != fps:
|
||||
@@ -340,14 +334,7 @@ def record_loop(
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
|
||||
# Reset interpolator if provided
|
||||
if interpolator is not None:
|
||||
interpolator.reset()
|
||||
|
||||
# Calculate control interval based on interpolation
|
||||
use_interpolation = interpolator is not None and interpolator.enabled and policy is not None
|
||||
control_interval = interpolator.get_control_interval(fps) if interpolator else 1 / fps
|
||||
|
||||
no_action_count = 0
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < control_time_s:
|
||||
@@ -368,58 +355,26 @@ def record_loop(
|
||||
|
||||
# Get action from either policy or teleop
|
||||
if policy is not None and preprocessor is not None and postprocessor is not None:
|
||||
# With interpolation: only call policy when interpolator needs new action
|
||||
if use_interpolation:
|
||||
# Get action keys from robot
|
||||
action_keys = sorted(robot.action_features)
|
||||
action_values = predict_action(
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=get_safe_torch_device(policy.config.device),
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
|
||||
if interpolator.needs_new_action():
|
||||
action_values = predict_action(
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=get_safe_torch_device(policy.config.device),
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
act_processed_policy = make_robot_action(action_values, dataset.features)
|
||||
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
|
||||
|
||||
# Convert to tensor for interpolator
|
||||
action_tensor = torch.tensor([robot_action_to_send[k] for k in action_keys])
|
||||
interpolator.add(action_tensor)
|
||||
|
||||
# Get interpolated action
|
||||
interp_action = interpolator.get()
|
||||
if interp_action is not None:
|
||||
robot_action_to_send = {k: interp_action[i].item() for i, k in enumerate(action_keys)}
|
||||
action_values = robot_action_to_send
|
||||
else:
|
||||
# No action available yet, skip this iteration
|
||||
continue
|
||||
else:
|
||||
action_values = predict_action(
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=get_safe_torch_device(policy.config.device),
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
|
||||
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
|
||||
act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
|
||||
|
||||
elif policy is None and isinstance(teleop, Teleoperator):
|
||||
if robot.name == "unitree_g1":
|
||||
teleop.send_feedback(obs)
|
||||
act = teleop.get_action()
|
||||
|
||||
# Applies a pipeline to the raw teleop action, default is IdentityProcessor
|
||||
act_processed_teleop = teleop_action_processor((act, obs))
|
||||
action_values = act_processed_teleop
|
||||
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
|
||||
|
||||
elif policy is None and isinstance(teleop, list):
|
||||
arm_action = teleop_arm.get_action()
|
||||
@@ -428,15 +383,23 @@ def record_loop(
|
||||
base_action = robot._from_keyboard_to_base_action(keyboard_action)
|
||||
act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
||||
act_processed_teleop = teleop_action_processor((act, obs))
|
||||
else:
|
||||
no_action_count += 1
|
||||
if no_action_count == 1 or no_action_count % 10 == 0:
|
||||
logging.warning(
|
||||
"No policy or teleoperator provided, skipping action generation. "
|
||||
"This is likely to happen when resetting the environment without a teleop device. "
|
||||
"The robot won't be at its rest position at the start of the next episode."
|
||||
)
|
||||
continue
|
||||
|
||||
# Applies a pipeline to the action, default is IdentityProcessor
|
||||
if policy is not None and act_processed_policy is not None:
|
||||
action_values = act_processed_policy
|
||||
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
|
||||
else:
|
||||
action_values = act_processed_teleop
|
||||
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
|
||||
else:
|
||||
logging.info(
|
||||
"No policy or teleoperator provided, skipping action generation."
|
||||
"This is likely to happen when resetting the environment without a teleop device."
|
||||
"The robot won't be at its rest position at the start of the next episode."
|
||||
)
|
||||
continue
|
||||
|
||||
# Send action to robot
|
||||
# Action can eventually be clipped using `max_relative_target`,
|
||||
@@ -457,7 +420,7 @@ def record_loop(
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
|
||||
sleep_time_s: float = control_interval - dt_s
|
||||
sleep_time_s: float = 1 / fps - dt_s
|
||||
if sleep_time_s < 0:
|
||||
logging.warning(
|
||||
f"Record loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
||||
@@ -544,7 +507,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
preprocessor = None
|
||||
postprocessor = None
|
||||
interpolator = None
|
||||
if cfg.policy is not None:
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
@@ -555,10 +517,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
|
||||
},
|
||||
)
|
||||
# Create interpolator for smoother policy control
|
||||
if cfg.interpolation_multiplier > 1:
|
||||
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
||||
logging.info(f"Action interpolation enabled: {cfg.interpolation_multiplier}x control rate")
|
||||
|
||||
robot.connect()
|
||||
if teleop is not None:
|
||||
@@ -590,7 +548,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
control_time_s=cfg.dataset.episode_time_s,
|
||||
single_task=cfg.dataset.single_task,
|
||||
display_data=cfg.display_data,
|
||||
interpolator=interpolator,
|
||||
display_compressed_images=display_compressed_images,
|
||||
)
|
||||
|
||||
@@ -601,10 +558,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
):
|
||||
log_say("Reset the environment", cfg.play_sounds)
|
||||
|
||||
# reset g1 robot
|
||||
if robot.name == "unitree_g1":
|
||||
robot.reset()
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
|
||||
@@ -80,7 +80,7 @@ class DatasetReplayConfig:
|
||||
repo_id: str
|
||||
# Episode to replay.
|
||||
episode: int
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
root: str | Path | None = None
|
||||
# Limit the frames per second. By default, uses the policy fps.
|
||||
fps: int = 30
|
||||
|
||||
@@ -43,6 +43,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
koch_leader,
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_mini,
|
||||
so_leader,
|
||||
)
|
||||
|
||||
@@ -51,6 +52,7 @@ COMPATIBLE_DEVICES = [
|
||||
"koch_leader",
|
||||
"omx_follower",
|
||||
"omx_leader",
|
||||
"openarm_mini",
|
||||
"so100_follower",
|
||||
"so100_leader",
|
||||
"so101_follower",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user