Compare commits

...

29 Commits

Author SHA1 Message Date
Martino Russi ee24f64ae5 add motion imitation 2025-12-17 16:00:43 +01:00
Martino Russi 123b9f7851 add motion imitation 2025-12-17 15:59:56 +01:00
Martino Russi a6c3a0fa09 Feat/add mj env (#2613)
* add sim support

* close fix threading issues
2025-12-15 16:22:27 +01:00
Woojin Wie c2fb644613 feat(robot): Add support for OMX robot (#2614)
* upload

* feat(omx): simplify motor initialization and remove default calibration files

* feat(omx): read motor positions without normalization for improved accuracy

* update calibration method for return factory value

Signed-off-by: Junha Cha <ckwnsgk1@gachon.ac.kr>

* change the drive mode

* refactor: clean up code by removing unnecessary blank lines in omx_follower and omx_leader modules

* feat(omx): update calibration method to set drive modes for motors

* feat(pyproject): add 'ROBOTIS' to extend-ignore-identifiers-re list

* feat(omx): enhance calibration method to write default drive modes to motors

* Update src/lerobot/robots/omx_follower/__init__.py

Add informations about the robot

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Signed-off-by: Woojin Wie <dnldnwls1123@gmail.com>

---------

Signed-off-by: Junha Cha <ckwnsgk1@gachon.ac.kr>
Signed-off-by: Woojin Wie <dnldnwls1123@gmail.com>
Co-authored-by: Junha02 <chajunha2023@naver.com>
Co-authored-by: Junha Cha <ckwnsgk1@gachon.ac.kr>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-12-15 15:50:29 +01:00
Jade Choghari 1d07a4aefd add auto in docs (#2645)
Signed-off-by: Jade Choghari <chogharijade@gmail.com>
2025-12-13 17:11:19 +01:00
Michel Aractingi ce348a3460 enable variable image sizes to pi0/pi0.5 (#2609)
* enable variable image sizes to pi0/pi0.5

* add square image assertion
2025-12-10 19:41:11 +01:00
Jade Choghari cb920235c4 docs: update X-VLA training strategies/commands (#2611) 2025-12-09 19:08:09 +01:00
Jade Choghari 7f40b3bf82 feat(dataset): add tool to convert images to video datasets (#2560)
* add video encoding tool

* style

* make it work

* more fixes
2025-12-08 18:50:21 +01:00
Michel Aractingi 2e9c9fd832 Replay while loop in sample actions with for loops (#2600) 2025-12-08 14:47:54 +01:00
Steven Palma f9cb5e659c chore(ci): skip workflows if not lerobot repository (#2601)
Co-authored-by: Alex Tyshka <atyshka15@gmail.com>
2025-12-08 12:44:36 +01:00
Michel Aractingi 0217e1e3ad Fix dataset aggreagation for multi video datasets' (#2550) 2025-12-05 16:09:25 +01:00
Vladislav Sovrasov d79dd6d31f Add a documentation page with a brief intro to hw backends (#2385) 2025-12-05 13:32:58 +01:00
Steven Palma 56b43cc888 fix(scripts): missing so101 import (#2577)
* fix(scripts): missing so101 import

Co-authored-by: Skyler <skylerwiernik@gmail.com>

* fix(scripts): move urdf to cli args

* refactor(scripts): improve find_joints_limits

---------

Co-authored-by: Skyler <skylerwiernik@gmail.com>
2025-12-03 18:20:26 +01:00
Kevin Thomas 77fe5a09ed fix(docs): argument typo (#2361)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-12-03 17:57:18 +01:00
Austin King 89ae7813a7 Reorganize assembly instructions setup before assembly (#2333)
Motors should be set up before the arm is assembled. 

Moving the entire motor setup section before the part cleaning and assembly section.

Signed-off-by: Austin King <shout@ozten.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-12-03 17:56:58 +01:00
./c² e003108cf8 Fix link to lerobot-train script in documentation (#2466)
* Fix link to lerobot-train script in documentation

Signed-off-by: ./c² <cagataycali@icloud.com>

* Update link to lerobot record script

Signed-off-by: ./c² <cagataycali@icloud.com>

---------

Signed-off-by: ./c² <cagataycali@icloud.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-12-03 15:46:26 +01:00
Steven Palma 5766eea377 fix(docs): remove duplicated package in install instructions (#2573) 2025-12-03 15:45:56 +01:00
Steven Palma f8a4cf225b feat(robots): add earth rover robot support (#2575)
Co-authored-by: somthecoder <sbaner64@gmail.com>
Co-authored-by: randomSmarts <Aarshsmittal@gmail.com>
Co-authored-by: Hassoonu <halsae2@illinois.edu>
Co-authored-by: Saketh06 <saketh.kantipudi@gmail.com>
Co-authored-by: sairajshetye <sairajshetye2@gmail.com>
Co-authored-by: Khalil Meftah <kmeftah.khalil@gmail.com>
2025-12-03 15:36:22 +01:00
Jade Choghari 43b0f17eb9 feat(policies): Add X-VLA (#2405)
* first commit

* more fixes

* add franka action

* update testing script

* add changes

* update files

* logits matching

* add imagenet as a norm type

* logits matching atol1e-2

* more eval fixes

* more changes

* xvla works on libero

* remove seed

* more refactoring

* more fixes

* more changes

* more changes

* more fixes

* migrate policy revert

* major pre-commit cleanup

* renaming

* revert to self.transformer

* refactor

* new changes

* clean

* update libero

* more changes

* make it work

* more changes:

* remove imagenet dependency

* style

* more

* more refactor

* remove proprio

* add loss

* more

* more

* add freeze/unfreeze options

* add testing

* upgrade transformers version

* update testing

* add installation

* remove .sh file

* fix testing

* silent linter in xvlatest

* fix failing test

* upgrade test, fix failing

* fix testing

* more fixes to testing

* require cuda in tests

* temp check

* add xvla docs

* fix styling

* update libero doc

* remove timm dep

* add different dtype support

* remove timm skip

* remove white lines

* Enhance X-VLA finetuning documentation with optimizer details (#2537)

Added detailed instructions for implementing a custom optimizer and modifying parameter retrieval for X-VLA finetuning.

Signed-off-by: Jinliang Zheng <54488861+2toinf@users.noreply.github.com>

* fix style

* iterate on review

* iterate on cpilot

* revert xvla dep

* free up ci

* test(xvla): remove main test (#2565)

* Add xvla custom optim and dtype (#2567)

* add custom optim

* add custom optim

* add auto mode

* more changes

* add identity to all

* add auto

* release

* add docs

* make image smaller docs

* smaller image in doc

* evan smaller image doc

* finalize doc

---------

Signed-off-by: Jinliang Zheng <54488861+2toinf@users.noreply.github.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Jinliang Zheng <54488861+2toinf@users.noreply.github.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-12-03 15:29:14 +01:00
Steven Palma b0b755471b Revert "Earth Rover Mini Plus integration (#2544)" (#2574)
This reverts commit 35c5a27352.
2025-12-03 14:43:07 +01:00
s1lent4gnt 35c5a27352 Earth Rover Mini Plus integration (#2544)
* feat: Add EarthRover Mini Plus robot integration with Frodobots SDK

* refactor: Clean up

* refactor: Remove VirtualCamera implementation for EarthRover Mini Plus integration

* fix: Reduce timeout for camera requests

* fix: Add empty cameras dict for compatibility with recording script

* refactor: Remove record.py script for EarthRover Mini Plus use lerobot_record instead

* refactor: Update documentation for EarthRover Mini Plus integration

* refactor keyboard teleoperation

* refactor: Remove angular velocity

* docs: Add documentation for EarthRover Mini Plus integration

* Add earthrover_mini_plus robot to replay and teleoperate scripts

* refactor: Update stop key from Space to X

* refactor: Implement caching for camera frames and robot telemetry data

* refactor

* refactor: Replace string literals with constants for action and observation keys

* Add Earth Rover Mini to robots section in documentation

Co-authored-by: somthecoder sbaner64@gmail.com
Co-authored-by: randomSmarts Aarshsmittal@gmail.com
Co-authored-by: Hassoonu halsae2@illinois.edu
Co-authored-by: Saketh06 saketh.kantipudi@gmail.com
Co-authored-by: sairajshetye sairajshetye2@gmail.com
2025-12-03 14:24:57 +01:00
vinoyang afb90e17e7 doc: fix wrong package name in installation doc (#2513) 2025-12-03 13:36:59 +01:00
Daniel San José Pro 9ec9ee781a feat(policies): Allow users to register 3rd party policies - pip install lerobot_policy_mypolicy (#2308)
* feat: Register external policies

* ruff fix

* move policy util functions to policy factory

* refactor register_third_party_devices -> register_third_party_plugins

* feat: Update docs with bring your own policies

* Improve docs for new policies

* fix: Inconsistent quotation marks

* fix: Remove print statement

* fix: wrong base class name in documentation

* fix: Handle better how the models are parsed

* fix: precommit passing

* Update docs/source/bring_your_own_policies.mdx

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Signed-off-by: Daniel San José Pro <42489409+danielsanjosepro@users.noreply.github.com>

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Signed-off-by: Daniel San José Pro <42489409+danielsanjosepro@users.noreply.github.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-12-03 12:09:24 +01:00
Md. Muhaimin Rahman 0b497fc37d Make transport module Mypy Compliant [issue#1731] (#2433)
* latest

* Delete =3.0.0

Signed-off-by: Md. Muhaimin Rahman <sezan92@gmail.com>

* Update src/lerobot/transport/utils.py

Signed-off-by: Md. Muhaimin Rahman <sezan92@gmail.com>

---------

Signed-off-by: Md. Muhaimin Rahman <sezan92@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-12-02 22:12:15 +01:00
Michel Aractingi 797cd2725a fix pi05 forward compile (#2551) 2025-12-02 11:01:43 +01:00
Steven Palma af4766b602 fix(ci): move hub artifacts to /mnt to avoid runners' No space left on device (#2564)
* fix(ci): move hub & lerobot artefacts to /mnt to avoid No space left on device in the future

* chore(ci): remove dh -h steps
2025-12-01 20:14:51 +01:00
Martino Russi 37f43df88a Feat/add unitree g1 robot (#2530)
* add unitree_g1_robot_class

* finish locomotion loading code

* precommit

* separate groot locomotion logic

* remove leftover locomotion variable, unify kp kd

* format config

* properly comment config, example locomotion and unitree_g1 class

* ready to review

* download policy from the hub in `examples/unitree_g1/gr00t_locomotion`

* fix linter

* make precommit happy, add ignore flags

* linter pt3

* linter pt4

* [done] make precommit happy

* fix linter 5

* add docs

* push utils

* feat(robots): add Unitree G1 humanoid support with ZMQ bridge (#2539)

* feat(robots): add Unitree G1 humanoid support with ZMQ bridge

- Use JSON + base64 serialization for secure communication instead of pickle
- Add documentation section
- Rename robot_server to run_g1_server
- Add dependecies to pyproject.toml

* nit in docs

* remove globals use

* cast robot data to int/float

* ensure robot is connected before changing mode

* temperature can be list, average in such case

---------

Co-authored-by: Martino Russi <nopyeps@gmail.com>

* style nit

* remove transform_imu_data

* remove scipy dependency

* modify toml, add external unitree_sdk2py dep

* return actions from send_action

* cleaning

* add instructions for local deployment

* Update src/lerobot/robots/unitree_g1/unitree_g1.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com>

* update config and readme

* update docs

* update docs

* remove torch import

* fix docs

* remove ip from docs

* add licence header

---------

Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-12-01 16:10:13 +01:00
Sota Nakamura 5f7b5f2817 remove the sampler cause the relative index is added (#2521)
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2025-11-30 22:28:32 +01:00
Steven Palma c55fbe1b3e chore(dependencies): Bump lerobot to 0.4.3 (#2540) 2025-11-28 10:39:02 +01:00
92 changed files with 13129 additions and 331 deletions
@@ -31,7 +31,8 @@ jobs:
name: Upload Preview and Comment
if: >
github.event.workflow_run.event == 'pull_request' &&
github.event.workflow_run.conclusion == 'success'
github.event.workflow_run.conclusion == 'success' &&
github.repository == 'huggingface/lerobot'
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
with:
package_name: lerobot
+4 -2
View File
@@ -42,7 +42,9 @@ jobs:
# This job builds and deploys the official documentation.
build_main_docs:
name: Build Main Docs
if: github.event_name == 'push' || github.event_name == 'workflow_dispatch'
if: >
(github.event_name == 'push' || github.event_name == 'workflow_dispatch') &&
github.repository == 'huggingface/lerobot'
permissions:
contents: read
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
@@ -58,7 +60,7 @@ jobs:
# The result of this job triggers the 'Upload PR Documentation' workflow.
build_pr_docs:
name: Build PR Docs
if: github.event_name == 'pull_request'
if: github.event_name == 'pull_request' && github.repository == 'huggingface/lerobot'
permissions:
contents: read
pull-requests: write
+7 -1
View File
@@ -45,7 +45,6 @@ permissions:
env:
UV_VERSION: "0.8.0"
PYTHON_VERSION: "3.10"
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu
# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
concurrency:
@@ -60,12 +59,19 @@ jobs:
runs-on: ubuntu-latest
env:
MUJOCO_GL: egl
HF_HOME: /mnt/cache/.cache/huggingface
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
lfs: true
# NOTE(Steven): Mount to `/mnt` to avoid the limited storage on `/home`. Consider cleaning default SDKs or using self-hosted runners for more space.
# (As of 2024-06-10, the runner's `/home` has only 6.2 GB free—8% of its 72 GB total.)
- name: Setup /mnt storage
run: sudo chown -R $USER:$USER /mnt
# TODO(Steven): Evaluate the need of these dependencies
- name: Install apt dependencies
run: |
+7
View File
@@ -58,12 +58,19 @@ jobs:
github.event_name == 'workflow_dispatch'
env:
MUJOCO_GL: egl
HF_HOME: /mnt/cache/.cache/huggingface
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
steps:
- uses: actions/checkout@v4
with:
lfs: true
persist-credentials: false
# NOTE(Steven): Mount to `/mnt` to avoid the limited storage on `/home`. Consider cleaning default SDKs or using self-hosted runners for more space.
# (As of 2024-06-10, the runner's `/home` has only 6.2 GB free—8% of its 72 GB total.)
- name: Setup /mnt storage
run: sudo chown -R $USER:$USER /mnt
- name: Install apt dependencies
run: |
sudo apt-get update && sudo apt-get install -y build-essential \
+2
View File
@@ -43,6 +43,7 @@ jobs:
name: Build CPU Docker for Nightly
runs-on:
group: aws-general-8-plus
if: github.repository == 'huggingface/lerobot'
outputs:
image_tag: ${{ env.DOCKER_IMAGE_NAME_CPU }}
steps:
@@ -77,6 +78,7 @@ jobs:
name: Build GPU Docker for Nightly
runs-on:
group: aws-general-8-plus
if: github.repository == 'huggingface/lerobot'
outputs:
image_tag: ${{ env.DOCKER_IMAGE_NAME_GPU }}
steps:
+1
View File
@@ -29,6 +29,7 @@ jobs:
build-and-publish:
name: Build and publish Python distributions
runs-on: ubuntu-latest
if: github.repository == 'huggingface/lerobot'
outputs:
version: ${{ steps.extract_info.outputs.tag_version }}
permissions:
+1
View File
@@ -45,6 +45,7 @@ jobs:
stale:
name: Close Stale Issues and PRs
runs-on: ubuntu-latest
if: github.repository == 'huggingface/lerobot'
permissions:
actions: write
contents: write # only for delete-branch option
+8
View File
@@ -43,14 +43,22 @@ jobs:
full-tests:
name: Full Unbound Tests
runs-on: ubuntu-latest
if: github.repository == 'huggingface/lerobot'
env:
MUJOCO_GL: egl
HF_HOME: /mnt/cache/.cache/huggingface
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
steps:
- uses: actions/checkout@v4
with:
lfs: true
persist-credentials: false
# NOTE(Steven): Mount to `/mnt` to avoid the limited storage on `/home`. Consider cleaning default SDKs or using self-hosted runners for more space.
# (As of 2024-06-10, the runner's `/home` has only 6.2 GB free—8% of its 72 GB total.)
- name: Setup /mnt storage
run: sudo chown -R $USER:$USER /mnt
- name: Install apt dependencies
run: |
sudo apt-get update && sudo apt-get install -y build-essential \
+12
View File
@@ -9,6 +9,8 @@
title: Imitation Learning for Robots
- local: cameras
title: Cameras
- local: bring_your_own_policies
title: Bring Your Own Policies
- local: integrate_hardware
title: Bring Your Own Hardware
- local: hilserl
@@ -37,6 +39,8 @@
title: π₀.₅ (Pi05)
- local: groot
title: NVIDIA GR00T N1.5
- local: xvla
title: X-VLA
title: "Policies"
- sections:
- local: async
@@ -79,11 +83,19 @@
title: Hope Jr
- local: reachy2
title: Reachy 2
- local: unitree_g1
title: Unitree G1
- local: earthrover_mini_plus
title: Earth Rover Mini
title: "Robots"
- sections:
- local: phone_teleop
title: Phone
title: "Teleoperators"
- sections:
- local: torch_accelerators
title: PyTorch accelerators
title: "Supported Hardware"
- sections:
- local: notebooks
title: Notebooks
+2 -2
View File
@@ -278,7 +278,7 @@ We found the default values of `actions_per_chunk` and `chunk_size_threshold` to
2. **Adjust your `fps` based on inference latency.** While the server generates a new action chunk, the client is not idle and is stepping through its current action queue. If the two processes happen at fundamentally different speeds, the client might end up with an empty queue. As such, you should reduce your fps if you consistently run out of actions in queue.
3. **Adjust `chunk_size_threshold`**.
- Values closer to `0.0` result in almost sequential behavior. Values closer to `1.0` → send observation every step (more bandwidth, relies on good world-model).
- We found values around 0.5-0.6 to work well. If you want to tweak this, spin up a `RobotClient` setting the `--debug-visualize-queue-size` to `True`. This will plot the action queue size evolution at runtime, and you can use it to find the value of `chunk_size_threshold` that works best for your setup.
- We found values around 0.5-0.6 to work well. If you want to tweak this, spin up a `RobotClient` setting the `--debug_visualize_queue_size` to `True`. This will plot the action queue size evolution at runtime, and you can use it to find the value of `chunk_size_threshold` that works best for your setup.
<p align="center">
<img
@@ -289,7 +289,7 @@ We found the default values of `actions_per_chunk` and `chunk_size_threshold` to
<p align="center">
<i>
The action queue size is plotted at runtime when the
`--debug-visualize-queue-size` flag is passed, for various levels of
`--debug_visualize_queue_size` flag is passed, for various levels of
`chunk_size_threshold` (`g` in the SmolVLA paper).
</i>
</p>
+175
View File
@@ -0,0 +1,175 @@
# Bring Your Own Policies
This tutorial explains how to integrate your own custom policy implementations into the LeRobot ecosystem, allowing you to leverage all LeRobot tools for training, evaluation, and deployment while using your own algorithms.
## Step 1: Create a Policy Package
Your custom policy should be organized as an installable Python package following LeRobot's plugin conventions.
### Package Structure
Create a package with the prefix `lerobot_policy_` (IMPORTANT!) followed by your policy name:
```bash
lerobot_policy_my_custom_policy/
├── pyproject.toml
└── src/
└── lerobot_policy_my_custom_policy/
├── __init__.py
├── configuration_my_custom_policy.py
├── modeling_my_custom_policy.py
└── processor_my_custom_policy.py
```
### Package Configuration
Set up your `pyproject.toml`:
```toml
[project]
name = "lerobot_policy_my_custom_policy"
version = "0.1.0"
dependencies = [
# your policy-specific dependencies
]
requires-python = ">= 3.11"
[build-system]
build-backend = # your-build-backend
requires = # your-build-system
```
## Step 2: Define the Policy Configuration
Create a configuration class that inherits from `PreTrainedConfig` and registers your policy type:
```python
# configuration_my_custom_policy.py
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
@PreTrainedConfig.register_subclass("my_custom_policy")
@dataclass
class MyCustomPolicyConfig(PreTrainedConfig):
"""Configuration class for MyCustomPolicy.
Args:
n_obs_steps: Number of observation steps to use as input
horizon: Action prediction horizon
n_action_steps: Number of action steps to execute
hidden_dim: Hidden dimension for the policy network
# Add your policy-specific parameters here
"""
# ...PreTrainedConfig fields...
pass
def __post_init__(self):
super().__post_init__()
# Add any validation logic here
def validate_features(self) -> None:
"""Validate input/output feature compatibility."""
# Implement validation logic for your policy's requirements
pass
```
## Step 3: Implement the Policy Class
Create your policy implementation by inheriting from LeRobot's base `PreTrainedPolicy` class:
```python
# modeling_my_custom_policy.py
import torch
import torch.nn as nn
from typing import Dict, Any
from lerobot.policies.pretrained import PreTrainedPolicy
from .configuration_my_custom_policy import MyCustomPolicyConfig
class MyCustomPolicy(PreTrainedPolicy):
config_class = MyCustomPolicyConfig
name = "my_custom_policy"
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: Dict[str, Any] = None):
super().__init__(config, dataset_stats)
...
```
## Step 4: Add Data Processors
Create processor functions:
```python
# processor_my_custom_policy.py
from typing import Dict, Any
import torch
def make_my_custom_policy_pre_post_processors(
config,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Create preprocessing and postprocessing functions for your policy."""
pass # Define your preprocessing and postprocessing logic here
```
## Step 5: Package Initialization
Expose your classes in the package's `__init__.py`:
```python
# __init__.py
"""Custom policy package for LeRobot."""
try:
import lerobot # noqa: F401
except ImportError:
raise ImportError(
"lerobot is not installed. Please install lerobot to use this policy package."
)
from .configuration_my_custom_policy import MyCustomPolicyConfig
from .modeling_my_custom_policy import MyCustomPolicy
from .processor_my_custom_policy import make_my_custom_policy_pre_post_processors
__all__ = [
"MyCustomPolicyConfig",
"MyCustomPolicy",
"make_my_custom_policy_pre_post_processors",
]
```
## Step 6: Installation and Usage
### Install Your Policy Package
```bash
cd lerobot_policy_my_custom_policy
pip install -e .
# Or install from PyPI if published
pip install lerobot_policy_my_custom_policy
```
### Use Your Policy
Once installed, your policy automatically integrates with LeRobot's training and evaluation tools:
```bash
lerobot-train \
--policy.type my_custom_policy \
--env.type pusht \
--steps 200000
```
## Examples and Community Contributions
Check out these example policy implementations:
- [DiTFlow Policy](https://github.com/danielsanjosepro/lerobot_policy_ditflow) - Diffusion Transformer policy with flow-matching objective. Try it out in this example: [DiTFlow Example](https://github.com/danielsanjosepro/test_lerobot_policy_ditflow)
Share your policy implementations with the community! 🤗
+206
View File
@@ -0,0 +1,206 @@
# EarthRover Mini Plus
The EarthRover Mini Plus is a fully open source mobile robot that connects through the cloud using the Frodobots SDK. This lets you control the robot and record datasets for training AI models.
## What You Need
### Hardware
- EarthRover Mini robot
- Computer with Python 3.10 or newer
- Internet connection
### Setting Up the Frodobots SDK
The robot needs the [Frodobots SDK](https://github.com/Frodobots/earth-rovers-sdk) running on your computer. Here's how:
1. Download and install the SDK:
```bash
git clone https://github.com/Frodobots/earth-rovers-sdk.git
cd earth-rovers-sdk
pip install -r requirements.txt
```
2. Start the SDK:
```bash
hypercorn main:app --reload
```
3. Open your web browser and go to `http://localhost:8000`, then click "Join"
The SDK gives you:
- Live video from front and rear cameras
> [!IMPORTANT]
> The SDK must be running before you can use the robot.
## Install LeRobot
Follow our [Installation Guide](./installation) to install LeRobot.
In addition to the base installation, install the EarthRover Mini dependencies:
```bash
pip install -e .
```
## How It Works
The robot uses the internet to communicate:
- **Movement commands**: Sent through the SDK
- **Camera video**: Received from the SDK
- **Robot info**: Battery, location, speed from the SDK
You don't need to plug anything in - it all works through the SDK.
## Calibration
No calibration needed! The robot is ready to use as soon as the SDK is running.
## Controlling the Robot
You control the robot using your keyboard - just like playing a video game with WASD keys.
### Keyboard Controls
| Key | Action |
| --- | -------------------------------- |
| W | Move forward |
| S | Move backward |
| A | Turn left (with forward motion) |
| D | Turn right (with forward motion) |
| Q | Rotate left in place |
| E | Rotate right in place |
| X | Stop all movement |
| +/= | Increase speed |
| - | Decrease speed |
| ESC | Disconnect |
### Speed Settings
You can adjust how fast the robot moves:
- **Forward/backward speed**: Default is full speed (1.0)
- **Turning speed**: Default is full speed (1.0)
- **Speed changes**: Use +/- keys to adjust by 0.1 each time
### Try It Out
Test driving the robot before recording data:
```python
from lerobot.robots.earthrover_mini_plus import EarthRoverMiniPlus, EarthRoverMiniPlusConfig
from lerobot.teleoperators.keyboard import KeyboardRoverTeleop, KeyboardRoverTeleopConfig
# Initialize robot
robot_config = EarthRoverMiniPlusConfig()
robot = EarthRoverMiniPlus(robot_config)
# Initialize teleoperator
teleop_config = KeyboardRoverTeleopConfig(
linear_speed=1.0,
angular_speed=1.0,
speed_increment=0.1
)
teleop = KeyboardRoverTeleop(teleop_config)
# Connect
robot.connect()
teleop.connect()
# Teleoperate (use keyboard controls)
try:
while True:
action = teleop.get_action()
robot.send_action(action)
except KeyboardInterrupt:
pass
finally:
robot.disconnect()
teleop.disconnect()
```
> [!TIP]
> If you're using a Mac, you might need to give Terminal permission to access your keyboard for teleoperation. Go to System Preferences > Security & Privacy > Input Monitoring and check the box for Terminal.
## Recording Data
Once you can drive the robot well, you can start recording data to train AI models. The system records:
- **What you do**: How you move the robot (forward, backward, turning)
- **What the robot sees**:
- Videos from both cameras
- Robot speed and direction
- Battery level and location
- GPS position and signal
- Other sensor data
- **When it happened**: Timestamps for everything
### Setting Up Hugging Face
We use Hugging Face to store your data online. First, log in with your token from [Hugging Face settings](https://huggingface.co/settings/tokens):
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Store your Hugging Face username:
```bash
HF_USER=$(huggingface-cli whoami | head -n 1)
echo $HF_USER
```
### Start Recording
Use the standard recording command:
```bash
python src/lerobot/scripts/lerobot_record.py \
--robot.type=earthrover_mini_plus \
--teleop.type=keyboard_rover \
--dataset.repo_id=your_username/dataset_name \
--dataset.num_episodes=2 \
--dataset.fps=10 \
--dataset.single_task="Navigate around obstacles" \
--display_data=true
```
Replace `your_username/dataset_name` with your Hugging Face username and a name for your dataset.
### What Gets Saved
Your dataset includes:
**Your Actions (2 things)**:
- How much you moved forward/backward
- How much you turned left/right
**Robot Observations (12 things)**:
- Front camera video
- Rear camera video
- Current speed
- Battery level
- Which way the robot is facing
- GPS location (latitude, longitude, signal strength)
- Network signal strength
- Vibration level
- Lamp status (on/off)
### Where Your Data Goes
On your computer: `~/.cache/huggingface/lerobot/{repo-id}`
After recording, your data automatically uploads to your Hugging Face page:
```bash
echo https://huggingface.co/datasets/${HF_USER}/earthrover-navigation
```
Your dataset will be tagged with `LeRobot` for community discovery.
+2 -2
View File
@@ -428,7 +428,7 @@ Your robot should replicate movements similar to those you recorded. For example
## Train a policy
To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_train.py) script. A few arguments are required. Here is an example command:
```bash
lerobot-train \
@@ -485,7 +485,7 @@ huggingface-cli upload ${HF_USER}/act_so101_test${CKPT} \
## Run inference and evaluate your policy
You can use the `record` script from [`lerobot/record.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes:
You can use the `record` script from [`lerobot-record`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes:
<hfoptions id="eval">
<hfoption id="Command">
+1 -1
View File
@@ -90,7 +90,7 @@ If you encounter build errors, you may need to install additional dependencies:
To install these for linux run:
```bash
sudo apt-get install cmake build-essential python-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev pkg-config
sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev
```
For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)
+5
View File
@@ -62,6 +62,11 @@ lerobot-eval \
- Pass a comma-separated list to `--env.task` for multi-suite evaluation.
### Control Mode
LIBERO now supports two control modes: relative and absolute. This matters because different VLA checkpoints are trained with different mode of action to output hence control parameterizations.
You can switch them with: `env.control_mode = "relative"` and `env.control_mode = "absolute"`
### Policy inputs and outputs
When using LIBERO through LeRobot, policies interact with the environment via **observations** and **actions**:
+125 -125
View File
@@ -30,131 +30,6 @@ The follower arm uses 6x STS3215 motors with 1/345 gearing. The leader, however,
| Wrist Roll | 5 | 1 / 147 |
| Gripper | 6 | 1 / 147 |
### Clean Parts
Remove all support material from the 3D-printed parts. The easiest way to do this is using a small screwdriver to get underneath the support material.
It is advisable to install one 3-pin cable in the motor after placing them before continuing assembly.
### Joint 1
- Place the first motor into the base.
- Fasten the motor with 4 M2x6mm screws (smallest screws). Two from the top and two from the bottom.
- Slide over the first motor holder and fasten it using two M2x6mm screws (one on each side).
- Install both motor horns, securing the top horn with a M3x6mm screw.
- Attach the shoulder part.
- Tighten the shoulder part with 4 M3x6mm screws on top and 4 M3x6mm screws on the bottom
- Add the shoulder motor holder.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint1_v2.mp4"
type="video/mp4"
/>
</video>
</div>
### Joint 2
- Slide the second motor in from the top.
- Fasten the second motor with 4 M2x6mm screws.
- Attach both motor horns to motor 2, again use the M3x6mm horn screw.
- Attach the upper arm with 4 M3x6mm screws on each side.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint2_v2.mp4"
type="video/mp4"
/>
</video>
</div>
### Joint 3
- Insert motor 3 and fasten using 4 M2x6mm screws
- Attach both motor horns to motor 3 and secure one again with a M3x6mm horn screw.
- Connect the forearm to motor 3 using 4 M3x6mm screws on each side.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint3_v2.mp4"
type="video/mp4"
/>
</video>
</div>
### Joint 4
- Slide over motor holder 4.
- Slide in motor 4.
- Fasten motor 4 with 4 M2x6mm screws and attach its motor horns, use a M3x6mm horn screw.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint4_v2.mp4"
type="video/mp4"
/>
</video>
</div>
### Joint 5
- Insert motor 5 into the wrist holder and secure it with 2 M2x6mm front screws.
- Install only one motor horn on the wrist motor and secure it with a M3x6mm horn screw.
- Secure the wrist to motor 4 using 4 M3x6mm screws on both sides.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint5_v2.mp4"
type="video/mp4"
/>
</video>
</div>
### Gripper / Handle
<hfoptions id="assembly">
<hfoption id="Follower">
- Attach the gripper to motor 5, attach it to the motor horn on the wrist using 4 M3x6mm screws.
- Insert the gripper motor and secure it with 2 M2x6mm screws on each side.
- Attach the motor horns and again use a M3x6mm horn screw.
- Install the gripper claw and secure it with 4 M3x6mm screws on both sides.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Gripper_v2.mp4"
type="video/mp4"
/>
</video>
</div>
</hfoption>
<hfoption id="Leader">
- Mount the leader holder onto the wrist and secure it with 4 M3x6mm screws.
- Attach the handle to motor 5 using 1 M2x6mm screw.
- Insert the gripper motor, secure it with 2 M2x6mm screws on each side, attach a motor horn using a M3x6mm horn screw.
- Attach the follower trigger with 4 M3x6mm screws.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Leader_v2.mp4"
type="video/mp4"
/>
</video>
</div>
</hfoption>
</hfoptions>
## Configure the motors
### 1. Find the USB ports associated with each arm
@@ -340,6 +215,131 @@ leader.setup_motors()
</hfoption>
</hfoptions>
### Clean Parts
Remove all support material from the 3D-printed parts. The easiest way to do this is using a small screwdriver to get underneath the support material.
It is advisable to install one 3-pin cable in the motor after placing them before continuing assembly.
### Joint 1
- Place the first motor into the base.
- Fasten the motor with 4 M2x6mm screws (smallest screws). Two from the top and two from the bottom.
- Slide over the first motor holder and fasten it using two M2x6mm screws (one on each side).
- Install both motor horns, securing the top horn with a M3x6mm screw.
- Attach the shoulder part.
- Tighten the shoulder part with 4 M3x6mm screws on top and 4 M3x6mm screws on the bottom
- Add the shoulder motor holder.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint1_v2.mp4"
type="video/mp4"
/>
</video>
</div>
### Joint 2
- Slide the second motor in from the top.
- Fasten the second motor with 4 M2x6mm screws.
- Attach both motor horns to motor 2, again use the M3x6mm horn screw.
- Attach the upper arm with 4 M3x6mm screws on each side.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint2_v2.mp4"
type="video/mp4"
/>
</video>
</div>
### Joint 3
- Insert motor 3 and fasten using 4 M2x6mm screws
- Attach both motor horns to motor 3 and secure one again with a M3x6mm horn screw.
- Connect the forearm to motor 3 using 4 M3x6mm screws on each side.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint3_v2.mp4"
type="video/mp4"
/>
</video>
</div>
### Joint 4
- Slide over motor holder 4.
- Slide in motor 4.
- Fasten motor 4 with 4 M2x6mm screws and attach its motor horns, use a M3x6mm horn screw.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint4_v2.mp4"
type="video/mp4"
/>
</video>
</div>
### Joint 5
- Insert motor 5 into the wrist holder and secure it with 2 M2x6mm front screws.
- Install only one motor horn on the wrist motor and secure it with a M3x6mm horn screw.
- Secure the wrist to motor 4 using 4 M3x6mm screws on both sides.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint5_v2.mp4"
type="video/mp4"
/>
</video>
</div>
### Gripper / Handle
<hfoptions id="assembly">
<hfoption id="Follower">
- Attach the gripper to motor 5, attach it to the motor horn on the wrist using 4 M3x6mm screws.
- Insert the gripper motor and secure it with 2 M2x6mm screws on each side.
- Attach the motor horns and again use a M3x6mm horn screw.
- Install the gripper claw and secure it with 4 M3x6mm screws on both sides.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Gripper_v2.mp4"
type="video/mp4"
/>
</video>
</div>
</hfoption>
<hfoption id="Leader">
- Mount the leader holder onto the wrist and secure it with 4 M3x6mm screws.
- Attach the handle to motor 5 using 1 M2x6mm screw.
- Insert the gripper motor, secure it with 2 M2x6mm screws on each side, attach a motor horn using a M3x6mm horn screw.
- Attach the follower trigger with 4 M3x6mm screws.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Leader_v2.mp4"
type="video/mp4"
/>
</video>
</div>
</hfoption>
</hfoptions>
## Calibrate
Next, you'll need to calibrate your robot to ensure that the leader and follower arms have the same position values when they are in the same physical position.
+42
View File
@@ -0,0 +1,42 @@
# PyTorch accelerators
LeRobot supports multiple hardware acceleration options for both training and inference.
These options include:
- **CPU**: CPU executes all computations, no dedicated accelerator is used
- **CUDA**: acceleration with NVIDIA & AMD GPUs
- **MPS**: acceleration with Apple Silicon GPUs
- **XPU**: acceleration with Intel integrated and discrete GPUs
## Getting Started
To use particular accelerator, a suitable version of PyTorch should be installed.
For CPU, CUDA, and MPS backends follow instructions provided on [PyTorch installation page](https://pytorch.org/get-started/locally).
For XPU backend, follow instructions from [PyTorch documentation](https://docs.pytorch.org/docs/stable/notes/get_start_xpu.html).
### Verifying the installation
After installation, accelerator availability can be verified by running
```python
import torch
print(torch.<backend_name>.is_available()) # <backend_name> is cuda, mps, or xpu
```
## How to run training or evaluation
To select the desired accelerator, use the `--policy.device` flag when running `lerobot-train` or `lerobot-eval`. For example, to use MPS on Apple Silicon, run:
```bash
lerobot-train
--policy.device=mps ...
```
```bash
lerobot-eval \
--policy.device=mps ...
```
However, in most cases, presence of an accelerator is detected automatically and `policy.device` parameter can be omitted from CLI commands.
+208
View File
@@ -0,0 +1,208 @@
# Unitree G1 Robot Setup and Control
This guide covers the complete setup process for the Unitree G1 humanoid, from initial connection to running gr00t_wbc locomotion.
## About the Unitree G1
We offer support for both 29 and 23 DOF G1. We introduce:
- **`unitree g1` robot class, handling low level communication with the humanoid**
- **ZMQ socket bridge** for remote communication over WiFi, allowing one to deploy policies remotely instead of over ethernet or directly on the Orin
- **GR00T locomotion policy** for bipedal walking and balance
- **MuJoCo simulation mode** for testing policies without the physical robot
---
## Part 1: Connect to Robot over Ethernet
### Step 1: Configure Your Computer's Ethernet Interface
Set a static IP on the same subnet as the robot:
```bash
# Replace 'enp131s0' with your ethernet interface name (check with `ip a`)
sudo ip addr flush dev enp131s0
sudo ip addr add 192.168.123.200/24 dev enp131s0
sudo ip link set enp131s0 up
```
**Note**: The robot's Ethernet IP is fixed at `192.168.123.164`. Your computer must use `192.168.123.x` where x ≠ 164.
### Step 2: SSH into the Robot
```bash
ssh unitree@192.168.123.164
# Password: 123
```
You should now be connected to the robot's onboard computer.
---
## Part 2: Enable WiFi on the Robot
Once connected via Ethernet, follow these steps to enable WiFi:
### Step 1: Enable WiFi Hardware
```bash
# Unblock WiFi radio
sudo rfkill unblock wifi
sudo rfkill unblock all
# Bring up WiFi interface
sudo ip link set wlan0 up
# Enable NetworkManager control
sudo nmcli radio wifi on
sudo nmcli device set wlan0 managed yes
sudo systemctl restart NetworkManager
```
### Step 2: Enable Internet Forwarding
**On your laptop:**
```bash
# Enable IP forwarding
sudo sysctl -w net.ipv4.ip_forward=1
# Set up NAT (replace wlp132s0f0 with your WiFi interface)
sudo iptables -t nat -A POSTROUTING -o wlp132s0f0 -s 192.168.123.0/24 -j MASQUERADE
sudo iptables -A FORWARD -i wlp132s0f0 -o enp131s0 -m state --state RELATED,ESTABLISHED -j ACCEPT
sudo iptables -A FORWARD -i enp131s0 -o wlp132s0f0 -j ACCEPT
```
**On the robot:**
```bash
# Add laptop as default gateway
sudo ip route del default 2>/dev/null || true
sudo ip route add default via 192.168.123.200 dev eth0
echo "nameserver 8.8.8.8" | sudo tee /etc/resolv.conf
# Test connection
ping -c 3 8.8.8.8
```
### Step 3: Connect to WiFi Network
```bash
# List available networks
nmcli device wifi list
# Connect to your WiFi (example)
sudo nmcli connection add type wifi ifname wlan0 con-name "YourNetwork" ssid "YourNetwork"
sudo nmcli connection modify "YourNetwork" wifi-sec.key-mgmt wpa-psk
sudo nmcli connection modify "YourNetwork" wifi-sec.psk "YourPassword"
sudo nmcli connection modify "YourNetwork" connection.autoconnect yes
sudo nmcli connection up "YourNetwork"
# Check WiFi IP address
ip a show wlan0
```
### Step 4: SSH Over WiFi
Once connected to WiFi, note the robot's IP address and disconnect the Ethernet cable. You can now SSH over WiFi:
```bash
ssh unitree@<YOUR_ROBOT_IP>
# Password: 123
```
Replace `<YOUR_ROBOT_IP>` with your robot's actual WiFi IP address (e.g., `172.18.129.215`).
---
## Part 3: Robot Server Setup
### Step 1: Install LeRobot on the Orin
SSH into the robot and install LeRobot:
```bash
ssh unitree@<YOUR_ROBOT_IP>
conda create -y -n lerobot python=3.10
conda activate lerobot
git clone https://github.com/huggingface/lerobot.git
cd lerobot
pip install -e '.[unitree_g1]'
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
cd unitree_sdk2_python && pip install -e .
```
**Note**: The Unitree SDK requires CycloneDDS v0.10.2 to be installed. See the [Unitree SDK documentation](https://github.com/unitreerobotics/unitree_sdk2_python) for details.
### Step 2: Run the Robot Server
On the robot:
```bash
python src/lerobot/robots/unitree_g1/run_g1_server.py
```
**Important**: Keep this terminal running. The server must be active for remote control.
---
## Part 4: Running GR00T Locomotion
With the robot server running, you can now control the robot from your laptop.
### Step 1: Install LeRobot on your machine
```bash
conda create -y -n lerobot python=3.10
conda activate lerobot
git clone https://github.com/huggingface/lerobot.git
cd lerobot
pip install -e '.[unitree_g1]'
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
cd unitree_sdk2_python && pip install -e .
```
### Step 2: Update Robot IP in Config
Edit the config file to match your robot's WiFi IP:
```python
# In src/lerobot/robots/unitree_g1/config_unitree_g1.py
robot_ip: str = "<YOUR_ROBOT_IP>" # Replace with your robot's WiFi IP.
```
**Note**: When running directly on the G1 (not remotely), set `robot_ip: str = "127.0.0.1"` instead.
### Step 3: Run the Locomotion Policy
```bash
# Run GR00T locomotion controller
python examples/unitree_g1/gr00t_locomotion.py --repo-id "nepyope/GR00T-WholeBodyControl_g1"
```
### Step 4: Control with Remote
- **Left stick**: Forward/backward and left/right movement
- **Right stick**: Rotation
- **R1 button**: Raise waist height
- **R2 button**: Lower waist height
Press `Ctrl+C` to stop the policy.
---
## Extra: Running in Simulation Mode (MuJoCo)
You can now test and develop policies without a physical robot using MuJoCo. to do so set `is_simulation=True` in config.
## Additional Resources
- [Unitree SDK Documentation](https://github.com/unitreerobotics/unitree_sdk2_python)
- [GR00T Policy Repository](https://huggingface.co/nepyope/GR00T-WholeBodyControl_g1)
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
- [Unitree_IL_Lerobot](https://github.com/unitreerobotics/unitree_IL_lerobot)
---
_Last updated: December 2025_
+66 -3
View File
@@ -11,13 +11,14 @@ LeRobot provides several utilities for manipulating datasets:
3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids`
4. **Add Features** - Add new features to a dataset
5. **Remove Features** - Remove features from a dataset
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage
The core implementation is in `lerobot.datasets.dataset_tools`.
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
## Command-Line Tool: lerobot-edit-dataset
`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, and remove features.
`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, remove features, and convert image datasets to video format.
Run `lerobot-edit-dataset --help` for more information on the configuration of each operation.
@@ -86,9 +87,71 @@ lerobot-edit-dataset \
--operation.feature_names "['observation.images.top']"
```
#### Convert to Video
Convert an image-based dataset to video format, creating a new LeRobotDataset where images are stored as videos. This is useful for reducing storage requirements and improving data loading performance. The new dataset will have the exact same structure as the original, but with images encoded as MP4 videos in the proper LeRobot format.
```bash
# Local-only: Save to a custom output directory (no hub push)
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_to_video \
--operation.output_dir /path/to/output/pusht_video
# Save with new repo_id (local storage)
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--new_repo_id lerobot/pusht_video \
--operation.type convert_to_video
# Convert and push to Hugging Face Hub
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--new_repo_id lerobot/pusht_video \
--operation.type convert_to_video \
--push_to_hub true
# Convert with custom video codec and quality settings
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_to_video \
--operation.output_dir outputs/pusht_video \
--operation.vcodec libsvtav1 \
--operation.pix_fmt yuv420p \
--operation.g 2 \
--operation.crf 30
# Convert only specific episodes
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_to_video \
--operation.output_dir outputs/pusht_video \
--operation.episode_indices "[0, 1, 2, 5, 10]"
# Convert with multiple workers for parallel processing
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_to_video \
--operation.output_dir outputs/pusht_video \
--operation.num_workers 8
```
**Parameters:**
- `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`)
- `vcodec`: Video codec to use - options: `h264`, `hevc`, `libsvtav1` (default: `libsvtav1`)
- `pix_fmt`: Pixel format - options: `yuv420p`, `yuv444p` (default: `yuv420p`)
- `g`: Group of pictures (GOP) size - lower values give better quality but larger files (default: 2)
- `crf`: Constant rate factor - lower values give better quality but larger files, 0 is lossless (default: 30)
- `fast_decode`: Fast decode tuning option (default: 0)
- `episode_indices`: List of specific episodes to convert (default: all episodes)
- `num_workers`: Number of parallel workers for processing (default: 4)
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved.
### Push to Hub
Add the `--push_to_hub` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
Add the `--push_to_hub true` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
```bash
lerobot-edit-dataset \
@@ -96,7 +159,7 @@ lerobot-edit-dataset \
--new_repo_id lerobot/pusht_after_deletion \
--operation.type delete_episodes \
--operation.episode_indices "[0, 2, 5]" \
--push_to_hub
--push_to_hub true
```
There is also a tool for adding features to a dataset that is not yet covered in `lerobot-edit-dataset`.
+528
View File
@@ -0,0 +1,528 @@
# X-VLA: The First Soft-Prompted Robot Foundation Model for Any Robot, Any Task
## Overview
For years, robotics has aspired to build agents that can follow natural human instructions and operate dexterously across many environments and robot bodies. Recent breakthroughs in LLMs and VLMs suggest a path forward: extend these foundation-model architectures to embodied control by grounding them in actions. This has led to the rise of Vision-Language-Action (VLA) models, with the hope that a single generalist model could combine broad semantic understanding with robust manipulation skills.
But training such models is difficult. Robot data is fragmented across platforms, sensors, embodiments, and collection protocols. Heterogeneity appears everywhere: different arm configurations, different action spaces, different camera setups, different visual domains, and different task distributions. These inconsistencies create major distribution shifts that make pretraining unstable and adaptation unreliable.
Inspired by meta-learning and prompt learning, we ask: **"What if a VLA model could learn the structure of each robot and dataset the same way LLMs learn tasks, through prompts?"**
**X-VLA** is a soft-prompted, flow-matching VLA framework that treats each hardware setup as a "task" and encodes it using a small set of learnable embeddings. These **Soft Prompts** capture embodiment and domain-specific variations, guiding the Transformer from the earliest stages of multimodal fusion. With this mechanism, X-VLA can reconcile diverse robot morphologies, data types, and sensor setups within a single unified architecture.
<p align="center">
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-architecture.png"
alt="XVLA Architecture"
style="max-width: 100%; height: auto; width: 800px;"
/>
</p>
Built from pure Transformer encoders, X-VLA scales naturally with model size and dataset diversity. Across 6 simulation benchmarks and 3 real robots, Soft Prompts consistently outperform existing methods in handling hardware and domain differences. X-VLA-0.9B, trained on 290K episodes spanning seven robotic platforms, learns an embodiment-agnostic generalist policy in Phase I, and adapts efficiently to new robots in Phase II simply by learning a new set of prompts, while keeping the backbone frozen.
<p align="center">
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-architecture2.png"
alt="XVLA Architecture 2"
style="width: 60%; height: auto;"
/>
</p>
With only 1% of parameters tuned (9M), X-VLA-0.9B achieves near-π₀ performance on LIBERO and Simpler-WidowX, despite using **300× fewer trainable parameters**. It also demonstrates strong real-world dexterity with minimal demonstrations, including folding cloths in under two minutes.
<p align="center">
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-fold.png"
alt="XVLA fold visualization"
style="width: 95%; max-width: 1100px; height: auto;"
/>
</p>
X-VLA shows that generalist robot intelligence does not require increasingly complex architectures, only the right way to absorb heterogeneity. Soft Prompts offer a simple, scalable mechanism for unifying diverse robotic data, paving the way toward adaptable, cross-embodiment robot foundation models.
## Installation
After installing LeRobot, install the X-VLA dependencies:
```bash
pip install -e .[xvla]
```
After the new release, you'll be able to do:
```bash
pip install lerobot[xvla]
```
## Quick Start
### Basic Usage
To use X-VLA in your LeRobot configuration, specify the policy type as:
```bash
policy.type=xvla
```
### Evaluating Pre-trained Checkpoints
Example evaluation with LIBERO:
```bash
lerobot-eval \
--policy.path="lerobot/xvla-libero" \
--env.type=libero \
--env.task=libero_spatial,libero_goal,libero_10 \
--env.control_mode=absolute \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--env.episode_length=800 \
--seed=142
```
## Available Checkpoints
### 🎯 Base Model
**[lerobot/xvla-base](https://huggingface.co/lerobot/xvla-base)**
A 0.9B parameter instantiation of X-VLA, trained with a carefully designed data processing and learning recipe. The training pipeline consists of two phases:
- **Phase I: Pretraining** - Pretrained on 290K episodes from Droid, Robomind, and Agibot, spanning seven platforms across five types of robotic arms (single-arm to bi-manual setups). By leveraging soft prompts to absorb embodiment-specific variations, the model learns an embodiment-agnostic generalist policy.
- **Phase II: Domain Adaptation** - Adapted to deployable policies for target domains. A new set of soft prompts is introduced and optimized to encode the hardware configuration of the novel domain, while the pretrained backbone remains frozen.
### Simulation Checkpoints
**[lerobot/xvla-libero](https://huggingface.co/lerobot/xvla-libero)**
Achieves 93% success rate on LIBERO benchmarks. Fine-tuned from the base model for simulation tasks.
**[lerobot/xvla-widowx](https://huggingface.co/lerobot/xvla-widowx)**
Fine-tuned on BridgeData for pick-and-place experiments on compact WidowX platforms. Demonstrates robust manipulation capabilities.
### 🤖 Real-World Checkpoints
**[lerobot/xvla-folding](https://huggingface.co/lerobot/xvla-folding)**
A fine-tuned dexterous manipulation model trained on the high-quality Soft-FOLD cloth folding dataset. Achieves 100% success rate over 2 hours of continuous cloth folding.
**[lerobot/xvla-agibot-world](https://huggingface.co/lerobot/xvla-agibot-world)**
Optimized for AgileX robot dexterous manipulation tasks.
**[lerobot/xvla-google-robot](https://huggingface.co/lerobot/xvla-google-robot)**
Adapted for Google Robot platforms.
## Training X-VLA
### Recommended Training Configuration
When fine-tuning X-VLA for a new embodiment or task, we recommend not freezing the VLM, and also setting the `policy.dtype=bfloat16` to not hit OOM errors.
```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/xvla_training \
--job_name=xvla_training \
--policy.path="lerobot/xvla-base" \
--policy.repo_id="HF_USER/xvla-your-robot" \
--policy.dtype=bfloat16 \
--policy.action_mode=auto \
--steps=20000 \
--policy.device=cuda \
--policy.freeze_vision_encoder=false \
--policy.freeze_language_encoder=false \
--policy.train_policy_transformer=true \
--policy.train_soft_prompts=true \
```
### Training Parameters Explained
| Parameter | Default | Description |
| -------------------------- | ------- | ---------------------------------------------- |
| `freeze_vision_encoder` | `false` | Do not freeze the VLM vision encoder weights |
| `freeze_language_encoder` | `false` | Do not freeze the VLM language encoder weights |
| `train_policy_transformer` | `true` | Allow policy transformer layers to train |
| `train_soft_prompts` | `true` | Allow soft prompts to train |
**💡 Best Practice**: For Phase II adaptation to new embodiments, do not freeze the VLM encoders and also train the policy transformer and soft prompts.
### Example: Training on Bimanual Robot
```bash
lerobot-train \
--dataset.repo_id=pepijn223/bimanual-so100-handover-cube \
--output_dir=./outputs/xvla_bimanual \
--job_name=xvla_so101_training \
--policy.path="lerobot/xvla-base" \
--policy.dtype=bfloat16 \
--policy.repo_id="YOUR_USERNAME/xvla-biso101" \
--steps=3000 \
--policy.device=cuda \
--policy.action_mode=so101_bimanual \
--policy.freeze_vision_encoder=false \
--policy.freeze_language_encoder=false \
--policy.train_policy_transformer=true \
--policy.train_soft_prompts=true
```
💡 **Best Performance:** If you have sufficient computational resources and want to achieve best X-VLA finetuning performance, you should follow the official finetuning strategy:
**🔥 Full-finetune all components with a custom learning-rate scheme**
To ensure stable optimization, the Vision-Language Model (VLM) must be trained with only 1/10 of the base learning rate, while all other components use the full LR.
This LR ratio is crucial for achieving strong and stable finetuning performance. This is already done for you by default.
❕Note
Completely matching the official reported performance may require an additional warm-up LR schedule for soft-prompts, which can bring minor improvements.
We encourage implementing this in your customized training pipeline for optimal results.
## Core Concepts
### 1. Action Modes
X-VLA uses an **Action Registry** system to handle different action spaces and embodiments. The `action_mode` parameter defines how actions are processed, what loss functions are used, and how predictions are post-processed.
#### Available Action Modes
| Action Mode | Action Dim | Description | Use Case |
| ---------------- | ----------------------- | ------------------------------------------- | ------------------------------------ |
| `ee6d` | 20 | End-effector with xyz, 6D rotation, gripper | Dual-arm setups with spatial control |
| `joint` | 14 | Joint-space with gripper | Direct joint control robots |
| `agibot_ee6d` | 20 | AGI-bot variant with MSE loss | AGI-bot platforms |
| `so101_bimanual` | 20 (model), 12 (real) | SO101 bimanual robot | Bimanual manipulation tasks |
| `auto` | 20 (model), auto (real) | Auto-detects action dim from dataset | **Recommended** for new robots |
#### Why Action Modes Matter
When you have a pretrained checkpoint like `lerobot/xvla-base` trained with `action_dim=20`, and you want to train on a dataset with a different action dimension (e.g., 14 for bimanual arms), you can't simply trim the action dimension. The action mode orchestrates:
1. **Loss Computation**: Different loss functions for different action components (MSE for joints, BCE for grippers, etc.)
2. **Preprocessing**: Zeroing out gripper channels, padding dimensions
3. **Postprocessing**: Applying sigmoid to gripper logits, trimming padding
#### Example: BimanualSO101 Action Space
The `so101_bimanual` action mode handles the mismatch between model output (20D) and real robot control (12D):
```python
# Model outputs 20 dimensions for compatibility
dim_action = 20
# Real robot only needs 12 dimensions
# [left_arm (6), right_arm (6)] = [joints (5) + gripper (1)] × 2
REAL_DIM = 12
# Preprocessing: Pad 12D actions to 20D for training
# Postprocessing: Trim 20D predictions to 12D for deployment
```
See the [action_hub.py](/home/jade_choghari/robot/lerobot/src/lerobot/policies/xvla/action_hub.py) implementation for details.
#### Auto Action Mode (Recommended)
The `auto` action mode is the easiest way to use X-VLA with any robot. It automatically detects your dataset's action dimension and handles padding/trimming:
```bash
lerobot-train \
--policy.path="lerobot/xvla-base" \
--policy.action_mode=auto \
--policy.max_action_dim=20 \
...
```
**How it works:**
- Reads `action_feature.shape[-1]` from your dataset (e.g., 7 for Franka)
- Model outputs `max_action_dim` (default 20) for pretrained compatibility
- Loss is computed **only on the real dimensions**: `MSE(pred[:,:,:real_dim], target[:,:,:real_dim])`
- Postprocess trims output back to `real_dim` for robot control
This eliminates the need to create custom action modes for most robots.
### 2. Domain IDs
Domain IDs are learnable identifiers for different robot configurations and camera setups. They allow X-VLA to distinguish between:
- Different robots (Robot 1 vs Robot 2)
- Different camera configurations (cam1 vs cam2)
- Different combinations (Robot1-cam1-cam2 vs Robot1-cam1 vs Robot2-cam1)
#### Setting Domain IDs
**During Training**: By default, domain_id is set to 0 for general training.
**During Evaluation**: Specify the domain_id that matches your checkpoint's training configuration.
```python
# Example: LIBERO checkpoint uses domain_id=3
domain_id = 3
```
The domain_id is automatically added to observations by the `XVLAAddDomainIdProcessorStep` in the preprocessing pipeline.
The `lerobot/xvla-base` model has been trained on the following domain IDs. It is recommended to choose one that most resembles your robot/configuration:
#### Fine-tuning Datasets
| Dataset Name | Domain ID |
| ---------------- | --------- |
| Bridge | 0 |
| RT1 | 1 |
| Calvin | 2 |
| libero | 3 |
| widowx-air | 4 |
| AIR-AGILEX-HQ | 5 |
| robotwin2_abs_ee | 6 |
| robotwin2_clean | 6 |
| robocasa-human | 7 |
| VLABench | 8 |
| AGIBOT-challenge | 9 |
| AIR-AGILEX | 10 |
| AIRBOT | 18 |
### 3. Processor Steps
X-VLA requires specific preprocessing and postprocessing steps for proper operation.
#### Required Preprocessing Steps
1. **XVLAImageToFloatProcessorStep**: Converts images from [0, 255] to [0, 1] range
2. **XVLAImageNetNormalizeProcessorStep**: Applies ImageNet normalization (required for VLM backbone)
3. **XVLAAddDomainIdProcessorStep**: Adds domain_id to observations
#### Example Custom Processor
For LIBERO environments, a custom processor handles the specific observation format:
```python
from lerobot.policies.xvla.processor_xvla import LiberoProcessorStep
processor = LiberoProcessorStep()
# Handles robot_state dictionary, converts rotation matrices to 6D representation
# Applies 180° image rotation for camera convention
```
### 4. Configuration Parameters
Key configuration parameters for X-VLA:
```python
# Observation and action
n_obs_steps: int = 1 # Number of observation timesteps
chunk_size: int = 32 # Action sequence length
n_action_steps: int = 32 # Number of action steps to execute
# Model architecture
hidden_size: int = 1024 # Transformer hidden dimension
depth: int = 24 # Number of transformer layers
num_heads: int = 16 # Number of attention heads
num_domains: int = 30 # Maximum number of domain IDs
len_soft_prompts: int = 32 # Length of soft prompt embeddings
# Action space
action_mode: str = "ee6d" # Action space type (use "auto" for auto-detection)
use_proprio: bool = True # Use proprioceptive state
max_state_dim: int = 32 # Maximum state dimension
max_action_dim: int = 20 # Max action dim for padding (used by "auto" mode)
# Vision
num_image_views: int | None # Number of camera views
resize_imgs_with_padding: tuple[int, int] | None # Target image size with padding
# Training
num_denoising_steps: int = 10 # Flow matching denoising steps
```
## Creating Custom Action Modes
If your robot has a unique action space, you can create a custom action mode:
### Step 1: Define Your Action Space
```python
from lerobot.policies.xvla.action_hub import BaseActionSpace, register_action
import torch.nn as nn
@register_action("my_custom_robot")
class MyCustomActionSpace(BaseActionSpace):
"""Custom action space for my robot."""
dim_action = 15 # Your robot's action dimension
gripper_idx = (7, 14) # Gripper channel indices
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
self.bce = nn.BCEWithLogitsLoss()
def compute_loss(self, pred, target):
"""Define your loss computation."""
# Example: MSE for joints, BCE for grippers
joints_loss = self.mse(pred[:, :, :7], target[:, :, :7])
gripper_loss = self.bce(pred[:, :, self.gripper_idx],
target[:, :, self.gripper_idx])
return {
"joints_loss": joints_loss,
"gripper_loss": gripper_loss,
}
def preprocess(self, proprio, action, mode="train"):
"""Preprocess actions before training."""
# Example: Zero out grippers in proprioception
proprio_m = proprio.clone()
action_m = action.clone() if action is not None else None
proprio_m[..., self.gripper_idx] = 0.0
if action_m is not None:
action_m[..., self.gripper_idx] = 0.0
return proprio_m, action_m
def postprocess(self, action):
"""Post-process predictions for deployment."""
# Example: Apply sigmoid to gripper logits
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
return action
```
### Step 2: Use Your Custom Action Mode
```bash
lerobot-train \
--policy.action_mode=my_custom_robot \
--dataset.repo_id=YOUR_DATASET \
--policy.path="lerobot/xvla-base" \
...
```
## Advanced Topics
### Multi-Camera Support
X-VLA supports multiple camera views through the `num_image_views` parameter:
```python
# Configure for 3 camera views
policy.num_image_views=3
# Add empty cameras if you have fewer physical cameras
policy.empty_cameras=1 # Adds 1 zero-padded camera view
```
### Custom Preprocessing Pipeline
Create a custom preprocessing pipeline for your environment:
```python
from lerobot.processor import PolicyProcessorPipeline
from lerobot.policies.xvla.processor_xvla import (
XVLAImageToFloatProcessorStep,
XVLAImageNetNormalizeProcessorStep,
XVLAAddDomainIdProcessorStep,
)
# Build custom pipeline
preprocessor = PolicyProcessorPipeline(
steps=[
YourCustomProcessorStep(), # Your custom processing
XVLAImageToFloatProcessorStep(), # Required: convert to float
XVLAImageNetNormalizeProcessorStep(), # Required: ImageNet norm
XVLAAddDomainIdProcessorStep(domain_id=5), # Your domain ID
]
)
```
### Handling Different Action Dimensions
When your dataset has fewer action dimensions than the pretrained model:
**Option 1 (Recommended)**: Use `auto` action mode
```bash
# Automatically detects your dataset's action dimension
# Works with any robot without custom code
policy.action_mode=auto
policy.max_action_dim=20 # Match pretrained model
```
**Option 2**: Use a predefined action mode with built-in padding
```python
# Model expects 20D, dataset has 12D
# Action mode handles padding internally
action_mode = "so101_bimanual" # Pads 12 → 20
```
**Option 2**: Create a custom action mode that maps dimensions explicitly
```python
@register_action("my_mapped_action")
class MappedActionSpace(BaseActionSpace):
dim_action = 20
REAL_DIM = 12
def _pad_to_model_dim(self, x):
# Custom padding logic
...
```
## Troubleshooting
### Common Issues
**Issue**: "Action dimension mismatch"
- **Solution**: Check that your `action_mode` matches your robot's action space. Create a custom action mode if needed.
**Issue**: "Image values outside [0, 1] range"
- **Solution**: Ensure images are preprocessed with `XVLAImageToFloatProcessorStep` before normalization.
**Issue**: "Domain ID not found"
- **Solution**: Make sure `XVLAAddDomainIdProcessorStep` is in your preprocessing pipeline with the correct domain_id.
**Issue**: "Low success rate on new embodiment"
- **Solution**:
1. Verify your action_mode is correct
2. Check that soft prompts are being trained (`train_soft_prompts=True`)
3. Ensure proper preprocessing (ImageNet normalization, domain_id)
4. Consider increasing training steps
**Issue**: "Out of memory during training"
- **Solution**:
1. Reduce `chunk_size` (e.g., from 32 to 16)
2. Enable gradient checkpointing
3. Reduce batch size
4. Freeze more components
## Citation
If you use X-VLA in your research, please cite:
```bibtex
@article{zheng2025x,
title = {X-VLA: Soft-Prompted Transformer as Scalable Cross-Embodiment Vision-Language-Action Model},
author = {Zheng, Jinliang and Li, Jianxiong and Wang, Zhihao and Liu, Dongxiu and Kang, Xirui
and Feng, Yuchun and Zheng, Yinan and Zou, Jiayin and Chen, Yilun and Zeng, Jia and others},
journal = {arXiv preprint arXiv:2510.10274},
year = {2025}
}
```
## Additional Resources
- [X-VLA Paper](https://arxiv.org/pdf/2510.10274)
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
- [Action Registry Implementation](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/action_hub.py)
- [Processor Implementation](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/processor_xvla.py)
- [Model Configuration](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/configuration_xvla.py)
## Contributing
We welcome contributions! If you've implemented a new action mode or processor for your robot, please consider submitting a PR to help the community.
+454
View File
@@ -0,0 +1,454 @@
#!/usr/bin/env python3
"""
WBT (Whole Body Tracking) Dance Policy for Unitree G1
Uses ONNX model with motion data baked in.
Pattern matches gr00t_locomotion.py - uses UnitreeG1 robot class.
Usage:
python examples/unitree_g1/dance.py
"""
import argparse
import json
import logging
import threading
import time
from xml.etree import ElementTree
import numpy as np
import onnx
import onnxruntime as ort
import pinocchio as pin
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# =============================================================================
# CONFIGURATION
# =============================================================================
DANCE_ONNX_PATH = "examples/unitree_g1/fastsac_g1_29dof_dancing.onnx"
CONTROL_DT = 0.02 # 50 Hz
NUM_DOFS = 29
# Default joint positions (holosoma training defaults)
DEFAULT_DOF_POS = np.array([
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # Left leg (6)
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # Right leg (6)
0.0, 0.0, 0.0, # Waist (3)
0.2, 0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # Left arm (7)
0.2, -0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # Right arm (7)
], dtype=np.float32)
# Stiff hold KP/KD (for initialization)
STIFF_KP = np.array([
150, 150, 200, 200, 40, 40,
150, 150, 200, 200, 40, 40,
200, 200, 100,
100, 100, 100, 100, 50, 50, 50,
100, 100, 100, 100, 50, 50, 50,
], dtype=np.float32)
STIFF_KD = np.array([
2.5, 2.5, 2.5, 2.5, 2.5, 2.5,
2.5, 2.5, 2.5, 2.5, 2.5, 2.5,
5.0, 5.0, 5.0,
2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5,
2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5,
], dtype=np.float32)
# Joints to freeze at 0 with high KP
FROZEN_JOINTS = [13, 14, 20, 21, 27, 28]
FROZEN_KP = 500.0
FROZEN_KD = 5.0
# =============================================================================
# QUATERNION UTILITIES
# =============================================================================
def quat_inverse(q):
return np.concatenate((q[:, 0:1], -q[:, 1:]), axis=1)
def quat_mul(a, b):
a, b = a.reshape(-1, 4), b.reshape(-1, 4)
w1, x1, y1, z1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
w2, x2, y2, z2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
ww = (z1 + x1) * (x2 + y2)
yy = (w1 - y1) * (w2 + z2)
zz = (w1 + y1) * (w2 - z2)
xx = ww + yy + zz
qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))
w = qq - ww + (z1 - y1) * (y2 - z2)
x = qq - xx + (x1 + w1) * (x2 + w2)
y = qq - yy + (w1 - x1) * (y2 + z2)
z = qq - zz + (z1 + y1) * (w2 - x2)
return np.stack([w, x, y, z]).T.reshape(a.shape)
def subtract_frame_transforms(q01, q02):
return quat_mul(quat_inverse(q01), q02)
def matrix_from_quat(q):
r, i, j, k = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
two_s = 2.0 / (q * q).sum(-1)
o = np.stack((
1 - two_s * (j*j + k*k), two_s * (i*j - k*r), two_s * (i*k + j*r),
two_s * (i*j + k*r), 1 - two_s * (i*i + k*k), two_s * (j*k - i*r),
two_s * (i*k - j*r), two_s * (j*k + i*r), 1 - two_s * (i*i + j*j),
), -1)
return o.reshape(q.shape[:-1] + (3, 3))
def xyzw_to_wxyz(xyzw):
return np.concatenate([xyzw[:, -1:], xyzw[:, :3]], axis=1)
def quat_to_rpy(q):
w, x, y, z = q
roll = np.arctan2(2*(w*x + y*z), 1 - 2*(x**2 + y**2))
pitch = np.arcsin(np.clip(2*(w*y - z*x), -1, 1))
yaw = np.arctan2(2*(w*z + x*y), 1 - 2*(y**2 + z**2))
return roll, pitch, yaw
def rpy_to_quat(rpy):
roll, pitch, yaw = rpy
cy, sy = np.cos(yaw*0.5), np.sin(yaw*0.5)
cp, sp = np.cos(pitch*0.5), np.sin(pitch*0.5)
cr, sr = np.cos(roll*0.5), np.sin(roll*0.5)
return np.array([cr*cp*cy + sr*sp*sy, sr*cp*cy - cr*sp*sy,
cr*sp*cy + sr*cp*sy, cr*cp*sy - sr*sp*cy])
# =============================================================================
# PINOCCHIO FK
# =============================================================================
DOF_NAMES = (
"left_hip_pitch_joint", "left_hip_roll_joint", "left_hip_yaw_joint",
"left_knee_joint", "left_ankle_pitch_joint", "left_ankle_roll_joint",
"right_hip_pitch_joint", "right_hip_roll_joint", "right_hip_yaw_joint",
"right_knee_joint", "right_ankle_pitch_joint", "right_ankle_roll_joint",
"waist_yaw_joint", "waist_roll_joint", "waist_pitch_joint",
"left_shoulder_pitch_joint", "left_shoulder_roll_joint", "left_shoulder_yaw_joint", "left_elbow_joint",
"left_wrist_roll_joint", "left_wrist_pitch_joint", "left_wrist_yaw_joint",
"right_shoulder_pitch_joint", "right_shoulder_roll_joint", "right_shoulder_yaw_joint", "right_elbow_joint",
"right_wrist_roll_joint", "right_wrist_pitch_joint", "right_wrist_yaw_joint",
)
class PinocchioFK:
"""Pinocchio forward kinematics for torso_link orientation."""
def __init__(self, urdf_text: str):
root = ElementTree.fromstring(urdf_text)
for parent in root.iter():
for child in list(parent):
if child.tag.split("}")[-1] in {"visual", "collision"}:
parent.remove(child)
xml_text = '<?xml version="1.0"?>\n' + ElementTree.tostring(root, encoding="unicode")
self.model = pin.buildModelFromXML(xml_text, pin.JointModelFreeFlyer())
self.data = self.model.createData()
pin_names = [n for n in self.model.names if n not in ["universe", "root_joint"]]
self.idx_map = np.array([DOF_NAMES.index(n) for n in pin_names])
self.ref_frame_id = self.model.getFrameId("torso_link")
logger.info(f"Pinocchio FK: {len(pin_names)} joints, torso_link frame={self.ref_frame_id}")
def get_torso_quat(self, pos, quat_wxyz, dof_pos):
"""Get torso_link orientation in world frame."""
quat_xyzw = np.array([quat_wxyz[1], quat_wxyz[2], quat_wxyz[3], quat_wxyz[0]])
config = np.concatenate([pos, quat_xyzw, dof_pos[self.idx_map]])
pin.framesForwardKinematics(self.model, self.data, config)
coeffs = pin.Quaternion(self.data.oMf[self.ref_frame_id].rotation).coeffs()
return np.array([coeffs[3], coeffs[0], coeffs[1], coeffs[2]]).reshape(1, 4)
# =============================================================================
# DANCE CONTROLLER
# =============================================================================
class DanceController:
"""
Handles WBT dance policy for the Unitree G1 robot.
This controller manages:
- 29-joint observation processing
- Pinocchio FK for torso orientation
- Policy inference with motion data from ONNX
"""
def __init__(self, policy, robot, pinocchio_fk, motor_kp, motor_kd, action_scale):
self.policy = policy
self.robot = robot
self.pinocchio_fk = pinocchio_fk
self.motor_kp = motor_kp
self.motor_kd = motor_kd
self.action_scale = action_scale
self.obs_dim = policy.get_inputs()[0].shape[1]
self.last_action = np.zeros((1, NUM_DOFS), dtype=np.float32)
self.motion_command = None
self.ref_quat_xyzw = None
self.timestep = 0
self.yaw_offset = 0.0
# Get initial motion data from ONNX
dummy = np.zeros((1, self.obs_dim), dtype=np.float32)
outs = self.policy.run(["joint_pos", "joint_vel", "ref_quat_xyzw"],
{"obs": dummy, "time_step": np.array([[0]], dtype=np.float32)})
self.motion_command = np.concatenate(outs[0:2], axis=1)
self.ref_quat_xyzw = outs[2]
self.motion_start_pose = outs[0].flatten()
# Thread management
self.dance_running = False
self.dance_thread = None
logger.info(f"DanceController: obs_dim={self.obs_dim}, action_scale={action_scale}")
def capture_yaw_offset(self):
"""Capture robot's current yaw for relative tracking."""
robot_state = self.robot.lowstate_buffer.get_data()
if robot_state and self.pinocchio_fk:
quat = np.array(robot_state.imu_state.quaternion, dtype=np.float32)
dof = np.array([robot_state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
torso_q = self.pinocchio_fk.get_torso_quat(np.zeros(3), quat, dof)
_, _, self.yaw_offset = quat_to_rpy(torso_q.flatten())
logger.info(f"Captured yaw offset: {np.degrees(self.yaw_offset):.1f}°")
def _remove_yaw_offset(self, quat_wxyz):
"""Remove stored yaw offset from orientation."""
if abs(self.yaw_offset) < 1e-6:
return quat_wxyz
yaw_q = rpy_to_quat((0, 0, -self.yaw_offset)).reshape(1, 4)
return quat_mul(yaw_q, quat_wxyz)
def run_step(self):
"""Single dance step - reads state, runs policy, sends commands."""
robot_state = self.robot.lowstate_buffer.get_data()
if robot_state is None:
return
# Read robot state
quat = np.array(robot_state.imu_state.quaternion, dtype=np.float32)
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
dof_pos = np.array([robot_state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
dof_vel = np.array([robot_state.motor_state[i].dq for i in range(NUM_DOFS)], dtype=np.float32)
# Compute motion_ref_ori_b using FK
if self.pinocchio_fk:
torso_q = self.pinocchio_fk.get_torso_quat(np.zeros(3), quat, dof_pos)
torso_q = self._remove_yaw_offset(torso_q)
motion_ori = xyzw_to_wxyz(self.ref_quat_xyzw)
rel_quat = subtract_frame_transforms(torso_q, motion_ori)
ori_b = matrix_from_quat(rel_quat)[..., :2].reshape(1, -1)
else:
ori_b = np.zeros((1, 6), dtype=np.float32)
dof_rel = (dof_pos - DEFAULT_DOF_POS).reshape(1, -1)
# Build observation (alphabetical order)
obs_dict = {
"actions": self.last_action,
"base_ang_vel": ang_vel.reshape(1, 3),
"dof_pos": dof_rel,
"dof_vel": dof_vel.reshape(1, -1),
"motion_command": self.motion_command,
"motion_ref_ori_b": ori_b,
}
obs = np.concatenate([obs_dict[k].astype(np.float32) for k in sorted(obs_dict.keys())], axis=1)
obs = np.clip(obs, -100, 100)
# Run policy
outs = self.policy.run(["actions", "joint_pos", "joint_vel", "ref_quat_xyzw"],
{"obs": obs, "time_step": np.array([[self.timestep]], dtype=np.float32)})
action = np.clip(outs[0], -100, 100)
self.motion_command = np.concatenate(outs[1:3], axis=1)
self.ref_quat_xyzw = outs[3]
self.last_action = action.copy()
# Compute target positions
target_pos = DEFAULT_DOF_POS + action.flatten() * self.action_scale
# Send commands
for i in range(NUM_DOFS):
if i in FROZEN_JOINTS:
self.robot.msg.motor_cmd[i].q = 0.0
self.robot.msg.motor_cmd[i].kp = FROZEN_KP
self.robot.msg.motor_cmd[i].kd = FROZEN_KD
else:
self.robot.msg.motor_cmd[i].q = float(target_pos[i])
self.robot.msg.motor_cmd[i].kp = self.motor_kp[i]
self.robot.msg.motor_cmd[i].kd = self.motor_kd[i]
self.robot.msg.motor_cmd[i].qd = 0
self.robot.msg.motor_cmd[i].tau = 0
self.robot.send_action(self.robot.msg)
self.timestep += 1
def _dance_thread_loop(self):
"""Background thread that runs the dance policy."""
logger.info("Dance thread started")
while self.dance_running:
start_time = time.time()
try:
self.run_step()
except Exception as e:
logger.error(f"Error in dance loop: {e}")
import traceback
traceback.print_exc()
elapsed = time.time() - start_time
sleep_time = max(0, CONTROL_DT - elapsed)
time.sleep(sleep_time)
logger.info("Dance thread stopped")
def start_dance_thread(self):
"""Start the dance control thread."""
if self.dance_running:
logger.warning("Dance thread already running")
return
# Reset state for fresh start
self.timestep = 0
self.last_action.fill(0)
# Re-get initial motion data
dummy = np.zeros((1, self.obs_dim), dtype=np.float32)
outs = self.policy.run(["joint_pos", "joint_vel", "ref_quat_xyzw"],
{"obs": dummy, "time_step": np.array([[0]], dtype=np.float32)})
self.motion_command = np.concatenate(outs[0:2], axis=1)
self.ref_quat_xyzw = outs[2]
self.capture_yaw_offset()
logger.info("Starting dance control thread...")
self.dance_running = True
self.dance_thread = threading.Thread(target=self._dance_thread_loop, daemon=True)
self.dance_thread.start()
def stop_dance_thread(self):
"""Stop the dance control thread."""
if not self.dance_running:
return
logger.info("Stopping dance control thread...")
self.dance_running = False
if self.dance_thread:
self.dance_thread.join(timeout=2.0)
logger.info("Dance control thread stopped")
def reset_to_motion_pose(self, duration: float = 3.0):
"""Move robot to initial motion pose over given duration."""
logger.info(f"Moving to dance start pose ({duration}s)...")
robot_state = self.robot.lowstate_buffer.get_data()
init_pos = np.array([robot_state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
target_pos = self.motion_start_pose
num_steps = int(duration / CONTROL_DT)
for step in range(num_steps):
alpha = step / num_steps
interp = init_pos * (1 - alpha) + target_pos * alpha
for i in range(NUM_DOFS):
if i in FROZEN_JOINTS:
self.robot.msg.motor_cmd[i].q = 0.0
self.robot.msg.motor_cmd[i].kp = FROZEN_KP
self.robot.msg.motor_cmd[i].kd = FROZEN_KD
else:
self.robot.msg.motor_cmd[i].q = float(interp[i])
self.robot.msg.motor_cmd[i].kp = STIFF_KP[i]
self.robot.msg.motor_cmd[i].kd = STIFF_KD[i]
self.robot.msg.motor_cmd[i].qd = 0
self.robot.msg.motor_cmd[i].tau = 0
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
self.robot.lowcmd_publisher.Write(self.robot.msg)
time.sleep(CONTROL_DT)
logger.info("At dance start pose!")
# =============================================================================
# MAIN
# =============================================================================
def load_dance_policy(onnx_path: str):
"""Load dance policy and extract metadata."""
logger.info(f"Loading dance policy: {onnx_path}")
policy = ort.InferenceSession(onnx_path)
model = onnx.load(onnx_path)
metadata = {p.key: json.loads(p.value) for p in model.metadata_props}
motor_kp = np.array(metadata.get("kp", STIFF_KP), dtype=np.float32)
motor_kd = np.array(metadata.get("kd", STIFF_KD), dtype=np.float32)
action_scale = float(metadata.get("action_scale", 1.0))
urdf_text = metadata.get("robot_urdf", None)
logger.info(f" Obs dim: {policy.get_inputs()[0].shape[1]}")
logger.info(f" Action scale: {action_scale}")
logger.info(f" KP range: [{motor_kp.min():.1f}, {motor_kp.max():.1f}]")
# Build Pinocchio FK if URDF available
pinocchio_fk = None
if urdf_text:
logger.info(" Building Pinocchio FK from URDF...")
pinocchio_fk = PinocchioFK(urdf_text)
else:
logger.warning(" No URDF in metadata - FK will not work!")
return policy, pinocchio_fk, motor_kp, motor_kd, action_scale
def main():
parser = argparse.ArgumentParser(description="WBT Dance Policy for Unitree G1")
parser.add_argument("--onnx", type=str, default=DANCE_ONNX_PATH, help="Path to dance ONNX model")
parser.add_argument("--sim", action="store_true", help="Run in simulation mode")
args = parser.parse_args()
print("=" * 70)
print("💃 WBT DANCE POLICY")
print("=" * 70)
# Load policy
policy, pinocchio_fk, motor_kp, motor_kd, action_scale = load_dance_policy(args.onnx)
# Initialize robot
logger.info("Initializing robot...")
config = UnitreeG1Config()
robot = UnitreeG1(config)
logger.info("Robot connected!")
# Create controller
controller = DanceController(policy, robot, pinocchio_fk, motor_kp, motor_kd, action_scale)
try:
# Move to start pose
controller.reset_to_motion_pose(duration=3.0)
# Start dancing
controller.start_dance_thread()
logger.info("Dancing! Press Ctrl+C to stop.")
print("-" * 70)
# Log status periodically
while True:
time.sleep(2.0)
logger.info(f"timestep={controller.timestep}")
except KeyboardInterrupt:
print("\n\nStopping...")
finally:
controller.stop_dance_thread()
robot.disconnect()
print("\nDone!")
if __name__ == "__main__":
main()
Binary file not shown.
+347
View File
@@ -0,0 +1,347 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example: GR00T Locomotion with Pre-loaded Policies
This example demonstrates the NEW pattern for loading GR00T policies externally
and passing them to the robot class.
"""
import argparse
import logging
import threading
import time
from collections import deque
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
logger = logging.getLogger(__name__)
GROOT_DEFAULT_ANGLES = np.zeros(29, dtype=np.float32)
GROOT_DEFAULT_ANGLES[[0, 6]] = -0.1 # hip pitch
GROOT_DEFAULT_ANGLES[[3, 9]] = 0.3 # knee
GROOT_DEFAULT_ANGLES[[4, 10]] = -0.2 # ankle pitch
MISSING_JOINTS = []
G1_MODEL = "g1_23" # or "g1_29"
if G1_MODEL == "g1_23":
MISSING_JOINTS = [12, 14, 20, 21, 27, 28] # waist yaw/pitch, wrist pitch/yaw
LOCOMOTION_ACTION_SCALE = 0.25
LOCOMOTION_CONTROL_DT = 0.02
ANG_VEL_SCALE: float = 0.25
DOF_POS_SCALE: float = 1.0
DOF_VEL_SCALE: float = 0.05
CMD_SCALE: list = [2.0, 2.0, 0.25]
DEFAULT_GROOT_REPO_ID = "nepyope/GR00T-WholeBodyControl_g1"
def load_groot_policies(
repo_id: str = DEFAULT_GROOT_REPO_ID,
) -> tuple[ort.InferenceSession, ort.InferenceSession]:
"""Load GR00T dual-policy system (Balance + Walk) from Hugging Face Hub.
Args:
repo_id: Hugging Face Hub repository ID containing the ONNX policies.
"""
logger.info(f"Loading GR00T dual-policy system from Hugging Face Hub ({repo_id})...")
# Download ONNX policies from Hugging Face Hub
balance_path = hf_hub_download(
repo_id=repo_id,
filename="GR00T-WholeBodyControl-Balance.onnx",
)
walk_path = hf_hub_download(
repo_id=repo_id,
filename="GR00T-WholeBodyControl-Walk.onnx",
)
# Load ONNX policies
policy_balance = ort.InferenceSession(balance_path)
policy_walk = ort.InferenceSession(walk_path)
logger.info("GR00T policies loaded successfully")
return policy_balance, policy_walk
class GrootLocomotionController:
"""
Handles GR00T-style locomotion control for the Unitree G1 robot.
This controller manages:
- Dual-policy system (Balance + Walk)
- 29-joint observation processing
- 15D action output (legs + waist)
- Policy inference and motor command generation
"""
def __init__(self, policy_balance, policy_walk, robot, config):
self.policy_balance = policy_balance
self.policy_walk = policy_walk
self.robot = robot
self.config = config
self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32) # vx, vy, theta_dot
# GR00T-specific state
self.groot_qj_all = np.zeros(29, dtype=np.float32)
self.groot_dqj_all = np.zeros(29, dtype=np.float32)
self.groot_action = np.zeros(15, dtype=np.float32)
self.groot_obs_single = np.zeros(86, dtype=np.float32)
self.groot_obs_history = deque(maxlen=6)
self.groot_obs_stacked = np.zeros(516, dtype=np.float32)
self.groot_height_cmd = 0.74 # Default base height
self.groot_orientation_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)
# input to gr00t is 6 frames (6*86D=516)
for _ in range(6):
self.groot_obs_history.append(np.zeros(86, dtype=np.float32))
# Thread management
self.locomotion_running = False
self.locomotion_thread = None
logger.info("GrootLocomotionController initialized")
def groot_locomotion_run(self):
# get current observation
robot_state = self.robot.get_observation()
if robot_state is None:
return
# get command from remote controller
if robot_state.wireless_remote is not None:
self.robot.remote_controller.set(robot_state.wireless_remote)
if self.robot.remote_controller.button[0]: # R1 - raise waist
self.groot_height_cmd += 0.001
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
if self.robot.remote_controller.button[4]: # R2 - lower waist
self.groot_height_cmd -= 0.001
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
else:
self.robot.remote_controller.lx = 0.0
self.robot.remote_controller.ly = 0.0
self.robot.remote_controller.rx = 0.0
self.robot.remote_controller.ry = 0.0
self.locomotion_cmd[0] = self.robot.remote_controller.ly # forward/backward
self.locomotion_cmd[1] = self.robot.remote_controller.lx * -1 # left/right
self.locomotion_cmd[2] = self.robot.remote_controller.rx * -1 # rotation rate
for i in range(29):
self.groot_qj_all[i] = robot_state.motor_state[i].q
self.groot_dqj_all[i] = robot_state.motor_state[i].dq
# adapt observation for g1_23dof
for idx in MISSING_JOINTS:
self.groot_qj_all[idx] = 0.0
self.groot_dqj_all[idx] = 0.0
# Scale joint positions and velocities
qj_obs = self.groot_qj_all.copy()
dqj_obs = self.groot_dqj_all.copy()
# express imu data in gravity frame of reference
quat = robot_state.imu_state.quaternion
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
gravity_orientation = self.robot.get_gravity_orientation(quat)
# scale joint positions and velocities before policy inference
qj_obs = (qj_obs - GROOT_DEFAULT_ANGLES) * DOF_POS_SCALE
dqj_obs = dqj_obs * DOF_VEL_SCALE
ang_vel_scaled = ang_vel * ANG_VEL_SCALE
# build single frame observation
self.groot_obs_single[:3] = self.locomotion_cmd * np.array(CMD_SCALE)
self.groot_obs_single[3] = self.groot_height_cmd
self.groot_obs_single[4:7] = self.groot_orientation_cmd
self.groot_obs_single[7:10] = ang_vel_scaled
self.groot_obs_single[10:13] = gravity_orientation
self.groot_obs_single[13:42] = qj_obs
self.groot_obs_single[42:71] = dqj_obs
self.groot_obs_single[71:86] = self.groot_action # 15D previous actions
# Add to history and stack observations (6 frames × 86D = 516D)
self.groot_obs_history.append(self.groot_obs_single.copy())
# Stack all 6 frames into 516D vector
for i, obs_frame in enumerate(self.groot_obs_history):
start_idx = i * 86
end_idx = start_idx + 86
self.groot_obs_stacked[start_idx:end_idx] = obs_frame
# Run policy inference (ONNX) with 516D stacked observation
cmd_magnitude = np.linalg.norm(self.locomotion_cmd)
selected_policy = (
self.policy_balance if cmd_magnitude < 0.05 else self.policy_walk
) # balance/standing policy for small commands, walking policy for movement commands
# run policy inference
ort_inputs = {selected_policy.get_inputs()[0].name: np.expand_dims(self.groot_obs_stacked, axis=0)}
ort_outs = selected_policy.run(None, ort_inputs)
self.groot_action = ort_outs[0].squeeze()
# transform action back to target joint positions
target_dof_pos_15 = GROOT_DEFAULT_ANGLES[:15] + self.groot_action * LOCOMOTION_ACTION_SCALE
# command motors
for i in range(15):
motor_idx = i
self.robot.msg.motor_cmd[motor_idx].q = target_dof_pos_15[i]
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = self.robot.kp[motor_idx]
self.robot.msg.motor_cmd[motor_idx].kd = self.robot.kd[motor_idx]
self.robot.msg.motor_cmd[motor_idx].tau = 0
# adapt action for g1_23dof
for joint_idx in MISSING_JOINTS:
self.robot.msg.motor_cmd[joint_idx].q = 0.0
self.robot.msg.motor_cmd[joint_idx].qd = 0
self.robot.msg.motor_cmd[joint_idx].kp = self.robot.kp[joint_idx]
self.robot.msg.motor_cmd[joint_idx].kd = self.robot.kd[joint_idx]
self.robot.msg.motor_cmd[joint_idx].tau = 0
# send action to robot
self.robot.send_action(self.robot.msg)
def _locomotion_thread_loop(self):
"""Background thread that runs the locomotion policy at specified rate."""
logger.info("Locomotion thread started")
while self.locomotion_running:
start_time = time.time()
try:
self.groot_locomotion_run()
except Exception as e:
logger.error(f"Error in locomotion loop: {e}")
# Sleep to maintain control rate
elapsed = time.time() - start_time
sleep_time = max(0, LOCOMOTION_CONTROL_DT - elapsed)
time.sleep(sleep_time)
logger.info("Locomotion thread stopped")
def start_locomotion_thread(self):
if self.locomotion_running:
logger.warning("Locomotion thread already running")
return
logger.info("Starting locomotion control thread...")
self.locomotion_running = True
self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True)
self.locomotion_thread.start()
logger.info("Locomotion control thread started!")
def stop_locomotion_thread(self):
if not self.locomotion_running:
return
logger.info("Stopping locomotion control thread...")
self.locomotion_running = False
if self.locomotion_thread:
self.locomotion_thread.join(timeout=2.0)
logger.info("Locomotion control thread stopped")
def reset_robot(self):
"""Move robot legs to default standing position over 2 seconds (arms are not moved)."""
total_time = 3.0
num_step = int(total_time / self.robot.control_dt)
# Only control legs, not arms (first 12 joints)
default_pos = GROOT_DEFAULT_ANGLES # First 12 values are leg angles
dof_size = len(default_pos)
# Get current lowstate
robot_state = self.robot.get_observation()
# Record the current leg positions
init_dof_pos = np.zeros(dof_size, dtype=np.float32)
for i in range(dof_size):
init_dof_pos[i] = robot_state.motor_state[i].q
# Move legs to default pos
for i in range(num_step):
alpha = i / num_step
for motor_idx in range(dof_size):
target_pos = default_pos[motor_idx]
self.robot.msg.motor_cmd[motor_idx].q = (
init_dof_pos[motor_idx] * (1 - alpha) + target_pos * alpha
)
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = self.robot.kp[motor_idx]
self.robot.msg.motor_cmd[motor_idx].kd = self.robot.kd[motor_idx]
self.robot.msg.motor_cmd[motor_idx].tau = 0
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
self.robot.lowcmd_publisher.Write(self.robot.msg)
time.sleep(self.robot.control_dt)
logger.info("Reached default position (legs only)")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="GR00T Locomotion Controller for Unitree G1")
parser.add_argument(
"--repo-id",
type=str,
default=DEFAULT_GROOT_REPO_ID,
help=f"Hugging Face Hub repo ID for GR00T policies (default: {DEFAULT_GROOT_REPO_ID})",
)
args = parser.parse_args()
# load policies
policy_balance, policy_walk = load_groot_policies(repo_id=args.repo_id)
# initialize robot
config = UnitreeG1Config()
robot = UnitreeG1(config)
# initialize gr00t locomotion controller
groot_controller = GrootLocomotionController(
policy_balance=policy_balance,
policy_walk=policy_walk,
robot=robot,
config=config,
)
# reset legs and start locomotion thread
try:
groot_controller.reset_robot()
groot_controller.start_locomotion_thread()
# log status
logger.info("Robot initialized with GR00T locomotion policies")
logger.info("Locomotion controller running in background thread")
logger.info("Press Ctrl+C to stop")
# keep robot alive
while True:
time.sleep(1.0)
except KeyboardInterrupt:
print("\nStopping locomotion...")
groot_controller.stop_locomotion_thread()
print("Done!")
+479
View File
@@ -0,0 +1,479 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example: Holosoma Whole-Body Locomotion (23-DOF and 29-DOF)
This example demonstrates loading Holosoma whole-body locomotion policies
and running them on the Unitree G1 robot.
Supports both:
- 23-DOF native policies (82D observations, 23D actions)
- 29-DOF policies (100D observations, 29D actions)
"""
import argparse
import logging
import threading
import time
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# =============================================================================
# 29-DOF Configuration
# =============================================================================
# fmt: off
HOLOSOMA_29DOF_DEFAULT_ANGLES = np.array([
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # left leg
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # right leg
0.0, 0.0, 0.0, # waist (yaw, roll, pitch)
0.2, 0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # left arm
0.2, -0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # right arm
], dtype=np.float32)
HOLOSOMA_29DOF_KP = np.array([
40.179238471, 99.098427777, 40.179238471, 99.098427777, 28.501246196, 28.501246196, # left leg
40.179238471, 99.098427777, 40.179238471, 99.098427777, 28.501246196, 28.501246196, # right leg
40.179238471, 28.501246196, 28.501246196, # waist
14.250623098, 14.250623098, 14.250623098, 14.250623098, 14.250623098, 16.778327481, 16.778327481, # left arm
14.250623098, 14.250623098, 14.250623098, 14.250623098, 14.250623098, 16.778327481, 16.778327481, # right arm
], dtype=np.float32)
HOLOSOMA_29DOF_KD = np.array([
2.557889765, 6.308801854, 2.557889765, 6.308801854, 1.814445687, 1.814445687, # left leg
2.557889765, 6.308801854, 2.557889765, 6.308801854, 1.814445687, 1.814445687, # right leg
2.557889765, 1.814445687, 1.814445687, # waist
0.907222843, 0.907222843, 0.907222843, 0.907222843, 0.907222843, 1.068141502, 1.068141502, # left arm
0.907222843, 0.907222843, 0.907222843, 0.907222843, 0.907222843, 1.068141502, 1.068141502, # right arm
], dtype=np.float32)
# =============================================================================
# 23-DOF Configuration (native G1-23: no waist_roll/pitch, no wrist_pitch/yaw)
# Derived from 29-DOF Holosoma values
# =============================================================================
# Joint order: 6 left leg, 6 right leg, 1 waist_yaw, 5 left arm, 5 right arm
HOLOSOMA_23DOF_DEFAULT_ANGLES = np.array([
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # left leg (from 29-DOF)
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # right leg (from 29-DOF)
0.0, # waist_yaw only (from 29-DOF)
0.2, 0.2, 0.0, 0.6, 0.0, # left arm first 5 joints (from 29-DOF)
0.2, -0.2, 0.0, 0.6, 0.0, # right arm first 5 joints (from 29-DOF)
], dtype=np.float32)
HOLOSOMA_23DOF_KP = np.array([
40.179238471, 99.098427777, 40.179238471, 99.098427777, 28.501246196, 28.501246196, # left leg
40.179238471, 99.098427777, 40.179238471, 99.098427777, 28.501246196, 28.501246196, # right leg
40.179238471, # waist_yaw
14.250623098, 14.250623098, 14.250623098, 14.250623098, 14.250623098, # left arm
14.250623098, 14.250623098, 14.250623098, 14.250623098, 14.250623098, # right arm
], dtype=np.float32)
HOLOSOMA_23DOF_KD = np.array([
2.557889765, 6.308801854, 2.557889765, 6.308801854, 1.814445687, 1.814445687, # left leg
2.557889765, 6.308801854, 2.557889765, 6.308801854, 1.814445687, 1.814445687, # right leg
2.557889765, # waist_yaw
0.907222843, 0.907222843, 0.907222843, 0.907222843, 0.907222843, # left arm
0.907222843, 0.907222843, 0.907222843, 0.907222843, 0.907222843, # right arm
], dtype=np.float32)
# Maps 23-DOF policy index → 29-DOF motor index
# 23-DOF: legs(0-11), waist_yaw(12), L_arm(13-17), R_arm(18-22)
# 29-DOF: legs(0-11), waist(12-14), L_arm(15-21), R_arm(22-28)
DOF_23_TO_MOTOR_MAP = [
0, 1, 2, 3, 4, 5, # left leg → motor 0-5
6, 7, 8, 9, 10, 11, # right leg → motor 6-11
12, # waist_yaw → motor 12
15, 16, 17, 18, 19, # left arm (skip wrist_pitch/yaw) → motor 15-19
22, 23, 24, 25, 26, # right arm (skip wrist_pitch/yaw) → motor 22-26
]
# fmt: on
# Control parameters
LOCOMOTION_CONTROL_DT = 0.02 # 50Hz
LOCOMOTION_ACTION_SCALE = 0.25
ANG_VEL_SCALE = 0.25
DOF_POS_SCALE = 1.0
DOF_VEL_SCALE = 0.05
GAIT_PERIOD = 1.0
DEFAULT_HOLOSOMA_REPO_ID = "nepyope/holosoma_locomotion"
def load_holosoma_policy(
repo_id: str = DEFAULT_HOLOSOMA_REPO_ID,
policy_name: str = "fastsac",
local_path: str | None = None,
) -> tuple[ort.InferenceSession, int]:
"""Load Holosoma policy and detect observation dimension.
Returns:
(policy, obs_dim) tuple where obs_dim is 82 (23-DOF) or 100 (29-DOF)
"""
if local_path is not None:
logger.info(f"Loading policy from local path: {local_path}")
policy_path = local_path
else:
logger.info(f"Loading policy from Hugging Face Hub: {repo_id}")
policy_path = hf_hub_download(repo_id=repo_id, filename=f"{policy_name}_g1_29dof.onnx")
policy = ort.InferenceSession(policy_path)
# Detect observation dimension from model input shape
input_shape = policy.get_inputs()[0].shape
obs_dim = input_shape[1] if len(input_shape) > 1 else input_shape[0]
logger.info(f"Policy loaded successfully")
logger.info(f" Input: {policy.get_inputs()[0].name}, shape: {input_shape} → obs_dim={obs_dim}")
logger.info(f" Output: {policy.get_outputs()[0].name}, shape: {policy.get_outputs()[0].shape}")
return policy, obs_dim
class HolosomaLocomotionController:
"""
Handles Holosoma whole-body locomotion for Unitree G1.
Supports both 23-DOF (82D obs) and 29-DOF (100D obs) policies.
"""
def __init__(self, policy, robot, config, obs_dim: int = 100):
self.policy = policy
self.robot = robot
self.config = config
self.obs_dim = obs_dim
# Detect policy type from observation dimension
self.is_23dof = (obs_dim == 82)
self.num_dof = 23 if self.is_23dof else 29
# Velocity commands
self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)
# State variables sized for policy type
self.qj = np.zeros(self.num_dof, dtype=np.float32)
self.dqj = np.zeros(self.num_dof, dtype=np.float32)
self.locomotion_action = np.zeros(self.num_dof, dtype=np.float32)
self.locomotion_obs = np.zeros(obs_dim, dtype=np.float32)
self.last_unscaled_action = np.zeros(self.num_dof, dtype=np.float32)
# Select config based on DOF
if self.is_23dof:
self.default_angles = HOLOSOMA_23DOF_DEFAULT_ANGLES
self.kp = HOLOSOMA_23DOF_KP
self.kd = HOLOSOMA_23DOF_KD
self.motor_map = DOF_23_TO_MOTOR_MAP
else:
self.default_angles = HOLOSOMA_29DOF_DEFAULT_ANGLES
self.kp = HOLOSOMA_29DOF_KP
self.kd = HOLOSOMA_29DOF_KD
self.motor_map = list(range(29)) # Identity map for 29-DOF
# Phase state for gait
self.phase = np.zeros((1, 2), dtype=np.float32)
self.phase[0, 0] = 0.0
self.phase[0, 1] = np.pi
self.phase_dt = 2 * np.pi / (50.0 * GAIT_PERIOD)
self.is_standing = False
self.counter = 0
self.locomotion_running = False
self.locomotion_thread = None
logger.info(f"HolosomaLocomotionController initialized")
logger.info(f" Mode: {'23-DOF (82D obs)' if self.is_23dof else '29-DOF (100D obs)'}")
logger.info(f" Action dim: {self.num_dof}")
def holosoma_locomotion_run(self):
"""Main locomotion loop - handles both 23-DOF and 29-DOF."""
self.counter += 1
if self.counter == 1:
print("\n" + "=" * 60)
print(f"🚀 RUNNING HOLOSOMA {self.num_dof}-DOF LOCOMOTION POLICY")
print(f" {self.obs_dim}D observations → {self.num_dof}D actions")
print("=" * 60 + "\n")
robot_state = self.robot.get_observation()
if robot_state is None:
return
# Remote controller
if robot_state.wireless_remote is not None:
self.robot.remote_controller.set(robot_state.wireless_remote)
else:
self.robot.remote_controller.lx = 0.0
self.robot.remote_controller.ly = 0.0
self.robot.remote_controller.rx = 0.0
self.robot.remote_controller.ry = 0.0
# Deadzone
ly = self.robot.remote_controller.ly if abs(self.robot.remote_controller.ly) > 0.1 else 0.0
lx = self.robot.remote_controller.lx if abs(self.robot.remote_controller.lx) > 0.1 else 0.0
rx = self.robot.remote_controller.rx if abs(self.robot.remote_controller.rx) > 0.1 else 0.0
self.locomotion_cmd[0] = ly
self.locomotion_cmd[1] = -lx
self.locomotion_cmd[2] = -rx
# Read joint states using motor map
for i in range(self.num_dof):
motor_idx = self.motor_map[i]
self.qj[i] = robot_state.motor_state[motor_idx].q
self.dqj[i] = robot_state.motor_state[motor_idx].dq
# IMU
quat = robot_state.imu_state.quaternion
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
gravity_orientation = self.robot.get_gravity_orientation(quat)
# Scale observations
qj_obs = (self.qj - self.default_angles) * DOF_POS_SCALE
dqj_obs = self.dqj * DOF_VEL_SCALE
ang_vel_scaled = ang_vel * ANG_VEL_SCALE
# Phase update
cmd_norm = np.linalg.norm(self.locomotion_cmd[:2])
ang_cmd_norm = np.abs(self.locomotion_cmd[2])
if cmd_norm < 0.01 and ang_cmd_norm < 0.01:
self.phase[0, :] = np.pi * np.ones(2)
self.is_standing = True
elif self.is_standing:
self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
self.is_standing = False
else:
phase_tp1 = self.phase + self.phase_dt
self.phase = np.fmod(phase_tp1 + np.pi, 2 * np.pi) - np.pi
sin_phase = np.sin(self.phase[0, :])
cos_phase = np.cos(self.phase[0, :])
# Build observation (format depends on DOF)
if self.is_23dof:
# 82D: [23 actions, 3 ang_vel, 1 cmd_yaw, 2 cmd_lin, 2 cos, 23 pos, 23 vel, 3 grav, 2 sin]
self.locomotion_obs[0:23] = self.last_unscaled_action
self.locomotion_obs[23:26] = ang_vel_scaled
self.locomotion_obs[26] = self.locomotion_cmd[2]
self.locomotion_obs[27:29] = self.locomotion_cmd[:2]
self.locomotion_obs[29:31] = cos_phase
self.locomotion_obs[31:54] = qj_obs
self.locomotion_obs[54:77] = dqj_obs
self.locomotion_obs[77:80] = gravity_orientation
self.locomotion_obs[80:82] = sin_phase
else:
# 100D: [29 actions, 3 ang_vel, 1 cmd_yaw, 2 cmd_lin, 2 cos, 29 pos, 29 vel, 3 grav, 2 sin]
self.locomotion_obs[0:29] = self.last_unscaled_action
self.locomotion_obs[29:32] = ang_vel_scaled
self.locomotion_obs[32] = self.locomotion_cmd[2]
self.locomotion_obs[33:35] = self.locomotion_cmd[:2]
self.locomotion_obs[35:37] = cos_phase
self.locomotion_obs[37:66] = qj_obs
self.locomotion_obs[66:95] = dqj_obs
self.locomotion_obs[95:98] = gravity_orientation
self.locomotion_obs[98:100] = sin_phase
# Policy inference
obs_input = self.locomotion_obs.reshape(1, -1).astype(np.float32)
ort_inputs = {self.policy.get_inputs()[0].name: obs_input}
ort_outs = self.policy.run(None, ort_inputs)
raw_action = ort_outs[0].squeeze()
clipped_action = np.clip(raw_action, -100.0, 100.0)
self.last_unscaled_action = clipped_action.copy()
self.locomotion_action = clipped_action * LOCOMOTION_ACTION_SCALE
# Debug
if self.counter <= 3:
print(f"\n[Holosoma Debug #{self.counter}]")
print(f" Phase: ({self.phase[0, 0]:.3f}, {self.phase[0, 1]:.3f})")
print(f" Cmd: ({self.locomotion_cmd[0]:.2f}, {self.locomotion_cmd[1]:.2f}, {self.locomotion_cmd[2]:.2f})")
print(f" Action range: [{raw_action.min():.3f}, {raw_action.max():.3f}]")
# Compute target positions
target_dof_pos = self.default_angles + self.locomotion_action
# Send commands to motors via motor map
for i in range(self.num_dof):
motor_idx = self.motor_map[i]
self.robot.msg.motor_cmd[motor_idx].q = target_dof_pos[i]
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = self.kp[i]
self.robot.msg.motor_cmd[motor_idx].kd = self.kd[i]
self.robot.msg.motor_cmd[motor_idx].tau = 0
# For 23-DOF: zero out missing joints (waist_roll/pitch, wrist_pitch/yaw)
if self.is_23dof:
missing_motors = [13, 14, 20, 21, 27, 28] # waist_roll, waist_pitch, wrist_pitch/yaw
for motor_idx in missing_motors:
self.robot.msg.motor_cmd[motor_idx].q = 0.0
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = 40.0
self.robot.msg.motor_cmd[motor_idx].kd = 2.0
self.robot.msg.motor_cmd[motor_idx].tau = 0
self.robot.send_action(self.robot.msg)
def _locomotion_thread_loop(self):
logger.info("Locomotion thread started")
while self.locomotion_running:
start_time = time.time()
try:
self.holosoma_locomotion_run()
except Exception as e:
logger.error(f"Error in locomotion loop: {e}")
import traceback
traceback.print_exc()
elapsed = time.time() - start_time
sleep_time = max(0, LOCOMOTION_CONTROL_DT - elapsed)
time.sleep(sleep_time)
logger.info("Locomotion thread stopped")
def start_locomotion_thread(self):
if self.locomotion_running:
logger.warning("Locomotion thread already running")
return
logger.info("Starting locomotion control thread...")
self.locomotion_running = True
self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True)
self.locomotion_thread.start()
logger.info("Locomotion control thread started!")
def stop_locomotion_thread(self):
if not self.locomotion_running:
return
logger.info("Stopping locomotion control thread...")
self.locomotion_running = False
if self.locomotion_thread:
self.locomotion_thread.join(timeout=2.0)
logger.info("Locomotion control thread stopped")
def reset_robot(self):
"""Move joints to default position."""
logger.info(f"Moving {self.num_dof} joints to default position...")
total_time = 3.0
num_step = int(total_time / self.robot.control_dt)
robot_state = self.robot.get_observation()
# Record current positions
init_dof_pos = np.zeros(self.num_dof, dtype=np.float32)
for i in range(self.num_dof):
motor_idx = self.motor_map[i]
init_dof_pos[i] = robot_state.motor_state[motor_idx].q
# Interpolate to target
for step in range(num_step):
alpha = step / num_step
for i in range(self.num_dof):
motor_idx = self.motor_map[i]
target = self.default_angles[i]
self.robot.msg.motor_cmd[motor_idx].q = init_dof_pos[i] * (1 - alpha) + target * alpha
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = self.kp[i]
self.robot.msg.motor_cmd[motor_idx].kd = self.kd[i]
self.robot.msg.motor_cmd[motor_idx].tau = 0
# Zero missing joints for 23-DOF
if self.is_23dof:
for motor_idx in [13, 14, 20, 21, 27, 28]:
self.robot.msg.motor_cmd[motor_idx].q = 0.0
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = 40.0
self.robot.msg.motor_cmd[motor_idx].kd = 2.0
self.robot.msg.motor_cmd[motor_idx].tau = 0
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
self.robot.lowcmd_publisher.Write(self.robot.msg)
time.sleep(self.robot.control_dt)
logger.info(f"Reached default position ({self.num_dof} joints)")
# Hold for 2 seconds
logger.info("Holding default position for 2 seconds...")
hold_steps = int(2.0 / self.robot.control_dt)
for _ in range(hold_steps):
for i in range(self.num_dof):
motor_idx = self.motor_map[i]
self.robot.msg.motor_cmd[motor_idx].q = self.default_angles[i]
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = self.kp[i]
self.robot.msg.motor_cmd[motor_idx].kd = self.kd[i]
self.robot.msg.motor_cmd[motor_idx].tau = 0
if self.is_23dof:
for motor_idx in [13, 14, 20, 21, 27, 28]:
self.robot.msg.motor_cmd[motor_idx].q = 0.0
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = 40.0
self.robot.msg.motor_cmd[motor_idx].kd = 2.0
self.robot.msg.motor_cmd[motor_idx].tau = 0
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
self.robot.lowcmd_publisher.Write(self.robot.msg)
time.sleep(self.robot.control_dt)
logger.info("Ready to start locomotion!")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Holosoma Locomotion Controller for Unitree G1")
parser.add_argument("--repo-id", type=str, default=DEFAULT_HOLOSOMA_REPO_ID)
parser.add_argument("--policy", type=str, default="fastsac", choices=["fastsac", "ppo"])
parser.add_argument("--local-path", type=str, default=None, help="Path to local ONNX file")
args = parser.parse_args()
# Load policy and detect dimensions
policy, obs_dim = load_holosoma_policy(
repo_id=args.repo_id,
policy_name=args.policy,
local_path=args.local_path,
)
# Initialize robot
config = UnitreeG1Config()
robot = UnitreeG1(config)
# Initialize controller with detected obs_dim
controller = HolosomaLocomotionController(
policy=policy,
robot=robot,
config=config,
obs_dim=obs_dim,
)
try:
#controller.reset_robot()
controller.start_locomotion_thread()
logger.info(f"Robot initialized with Holosoma {'23-DOF' if obs_dim == 82 else '29-DOF'} policy")
logger.info("Use remote controller: LY=fwd/back, LX=left/right, RX=rotate")
logger.info("Press Ctrl+C to stop")
while True:
time.sleep(1.0)
except KeyboardInterrupt:
print("\nStopping locomotion...")
controller.stop_locomotion_thread()
print("Done!")
+607
View File
@@ -0,0 +1,607 @@
#!/usr/bin/env python3
"""
Locomotion ↔ Dance Toggle for Unitree G1
Press Enter to instantly switch between locomotion and dance modes.
- Starts in LOCOMOTION mode (joystick control)
- Press Enter → DANCE mode (resets to frame 0)
- Press Enter → LOCOMOTION mode
- Repeat...
Auto-recovery feature:
- If robot tilts beyond threshold during dance, auto-switches to locomotion
- When robot recovers (tilt below recovery threshold), resumes dance from where it left off
Usage:
python examples/unitree_g1/locomotion_to_dance.py
python examples/unitree_g1/locomotion_to_dance.py --tilt-threshold 25 --recovery-threshold 10
"""
import argparse
import json
import logging
import select
import sys
import threading
import time
from xml.etree import ElementTree
import numpy as np
import onnx
import onnxruntime as ort
import pinocchio as pin
from huggingface_hub import hf_hub_download
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# =============================================================================
# CONFIGURATION
# =============================================================================
NUM_DOFS = 29
CONTROL_DT = 0.02 # 50Hz
# Locomotion config
DEFAULT_HOLOSOMA_REPO_ID = "nepyope/holosoma_locomotion"
LOCOMOTION_ACTION_SCALE = 0.25
ANG_VEL_SCALE = 0.25
DOF_POS_SCALE = 1.0
DOF_VEL_SCALE = 0.05
GAIT_PERIOD = 1.0
# Dance config
DANCE_ONNX_PATH = "examples/unitree_g1/fastsac_g1_29dof_dancing.onnx"
FROZEN_JOINTS = [13, 14, 20, 21, 27, 28]
FROZEN_KP = 500.0
FROZEN_KD = 5.0
# fmt: off
# 29-DOF defaults (holosoma training)
DEFAULT_29DOF_ANGLES = np.array([
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # left leg
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # right leg
0.0, 0.0, 0.0, # waist
0.2, 0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # left arm
0.2, -0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # right arm
], dtype=np.float32)
DEFAULT_29DOF_KP = np.array([
40.179, 99.098, 40.179, 99.098, 28.501, 28.501,
40.179, 99.098, 40.179, 99.098, 28.501, 28.501,
40.179, 28.501, 28.501,
14.251, 14.251, 14.251, 14.251, 14.251, 16.778, 16.778,
14.251, 14.251, 14.251, 14.251, 14.251, 16.778, 16.778,
], dtype=np.float32)
DEFAULT_29DOF_KD = np.array([
2.558, 6.309, 2.558, 6.309, 1.814, 1.814,
2.558, 6.309, 2.558, 6.309, 1.814, 1.814,
2.558, 1.814, 1.814,
0.907, 0.907, 0.907, 0.907, 0.907, 1.068, 1.068,
0.907, 0.907, 0.907, 0.907, 0.907, 1.068, 1.068,
], dtype=np.float32)
# 23-DOF config (no waist_roll/pitch, no wrist_pitch/yaw)
DEFAULT_23DOF_ANGLES = np.array([
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # left leg
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # right leg
0.0, # waist_yaw only
0.2, 0.2, 0.0, 0.6, 0.0, # left arm (5 joints)
0.2, -0.2, 0.0, 0.6, 0.0, # right arm (5 joints)
], dtype=np.float32)
DEFAULT_23DOF_KP = np.array([
40.179, 99.098, 40.179, 99.098, 28.501, 28.501,
40.179, 99.098, 40.179, 99.098, 28.501, 28.501,
40.179,
14.251, 14.251, 14.251, 14.251, 14.251,
14.251, 14.251, 14.251, 14.251, 14.251,
], dtype=np.float32)
DEFAULT_23DOF_KD = np.array([
2.558, 6.309, 2.558, 6.309, 1.814, 1.814,
2.558, 6.309, 2.558, 6.309, 1.814, 1.814,
2.558,
0.907, 0.907, 0.907, 0.907, 0.907,
0.907, 0.907, 0.907, 0.907, 0.907,
], dtype=np.float32)
# 23-DOF policy index → 29-DOF motor index
DOF_23_TO_MOTOR = [
0, 1, 2, 3, 4, 5, # left leg
6, 7, 8, 9, 10, 11, # right leg
12, # waist_yaw
15, 16, 17, 18, 19, # left arm (skip wrist_pitch/yaw)
22, 23, 24, 25, 26, # right arm (skip wrist_pitch/yaw)
]
MISSING_23DOF_MOTORS = [13, 14, 20, 21, 27, 28]
# fmt: on
# =============================================================================
# QUATERNION UTILITIES
# =============================================================================
def quat_inverse(q):
return np.concatenate((q[:, 0:1], -q[:, 1:]), axis=1)
def quat_mul(a, b):
a, b = a.reshape(-1, 4), b.reshape(-1, 4)
w1, x1, y1, z1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
w2, x2, y2, z2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
ww = (z1 + x1) * (x2 + y2)
yy = (w1 - y1) * (w2 + z2)
zz = (w1 + y1) * (w2 - z2)
xx = ww + yy + zz
qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))
w = qq - ww + (z1 - y1) * (y2 - z2)
x = qq - xx + (x1 + w1) * (x2 + w2)
y = qq - yy + (w1 - x1) * (y2 + z2)
z = qq - zz + (z1 + y1) * (w2 - x2)
return np.stack([w, x, y, z]).T.reshape(a.shape)
def subtract_frame_transforms(q01, q02):
return quat_mul(quat_inverse(q01), q02)
def matrix_from_quat(q):
r, i, j, k = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
two_s = 2.0 / (q * q).sum(-1)
o = np.stack((
1 - two_s * (j*j + k*k), two_s * (i*j - k*r), two_s * (i*k + j*r),
two_s * (i*j + k*r), 1 - two_s * (i*i + k*k), two_s * (j*k - i*r),
two_s * (i*k - j*r), two_s * (j*k + i*r), 1 - two_s * (i*i + j*j),
), -1)
return o.reshape(q.shape[:-1] + (3, 3))
def xyzw_to_wxyz(xyzw):
return np.concatenate([xyzw[:, -1:], xyzw[:, :3]], axis=1)
def quat_to_rpy(q):
w, x, y, z = q
roll = np.arctan2(2*(w*x + y*z), 1 - 2*(x**2 + y**2))
pitch = np.arcsin(np.clip(2*(w*y - z*x), -1, 1))
yaw = np.arctan2(2*(w*z + x*y), 1 - 2*(y**2 + z**2))
return roll, pitch, yaw
def rpy_to_quat(rpy):
roll, pitch, yaw = rpy
cy, sy = np.cos(yaw*0.5), np.sin(yaw*0.5)
cp, sp = np.cos(pitch*0.5), np.sin(pitch*0.5)
cr, sr = np.cos(roll*0.5), np.sin(roll*0.5)
return np.array([cr*cp*cy + sr*sp*sy, sr*cp*cy - cr*sp*sy,
cr*sp*cy + sr*cp*sy, cr*cp*sy - sr*sp*cy])
# =============================================================================
# PINOCCHIO FK
# =============================================================================
DOF_NAMES = (
"left_hip_pitch_joint", "left_hip_roll_joint", "left_hip_yaw_joint",
"left_knee_joint", "left_ankle_pitch_joint", "left_ankle_roll_joint",
"right_hip_pitch_joint", "right_hip_roll_joint", "right_hip_yaw_joint",
"right_knee_joint", "right_ankle_pitch_joint", "right_ankle_roll_joint",
"waist_yaw_joint", "waist_roll_joint", "waist_pitch_joint",
"left_shoulder_pitch_joint", "left_shoulder_roll_joint", "left_shoulder_yaw_joint", "left_elbow_joint",
"left_wrist_roll_joint", "left_wrist_pitch_joint", "left_wrist_yaw_joint",
"right_shoulder_pitch_joint", "right_shoulder_roll_joint", "right_shoulder_yaw_joint", "right_elbow_joint",
"right_wrist_roll_joint", "right_wrist_pitch_joint", "right_wrist_yaw_joint",
)
class PinocchioFK:
def __init__(self, urdf_text: str):
root = ElementTree.fromstring(urdf_text)
for parent in root.iter():
for child in list(parent):
if child.tag.split("}")[-1] in {"visual", "collision"}:
parent.remove(child)
xml_text = '<?xml version="1.0"?>\n' + ElementTree.tostring(root, encoding="unicode")
self.model = pin.buildModelFromXML(xml_text, pin.JointModelFreeFlyer())
self.data = self.model.createData()
pin_names = [n for n in self.model.names if n not in ["universe", "root_joint"]]
self.idx_map = np.array([DOF_NAMES.index(n) for n in pin_names])
self.ref_frame_id = self.model.getFrameId("torso_link")
def get_torso_quat(self, pos, quat_wxyz, dof_pos):
quat_xyzw = np.array([quat_wxyz[1], quat_wxyz[2], quat_wxyz[3], quat_wxyz[0]])
config = np.concatenate([pos, quat_xyzw, dof_pos[self.idx_map]])
pin.framesForwardKinematics(self.model, self.data, config)
coeffs = pin.Quaternion(self.data.oMf[self.ref_frame_id].rotation).coeffs()
return np.array([coeffs[3], coeffs[0], coeffs[1], coeffs[2]]).reshape(1, 4)
def get_torso_tilt(self, pos, quat_wxyz, dof_pos):
"""Get torso tilt angle from upright (degrees). Uses roll and pitch."""
torso_q = self.get_torso_quat(pos, quat_wxyz, dof_pos)
roll, pitch, _ = quat_to_rpy(torso_q.flatten())
# Tilt is the angle from vertical - combine roll and pitch
tilt_rad = np.sqrt(roll**2 + pitch**2)
return np.degrees(tilt_rad), np.degrees(roll), np.degrees(pitch)
# =============================================================================
# LOCOMOTION CONTROLLER
# =============================================================================
class LocomotionController:
"""Holosoma whole-body locomotion (23-DOF or 29-DOF)."""
def __init__(self, policy, robot, obs_dim: int):
self.policy = policy
self.robot = robot
self.obs_dim = obs_dim
# Detect DOF mode
self.is_23dof = (obs_dim == 82)
self.num_dof = 23 if self.is_23dof else 29
if self.is_23dof:
self.default_angles = DEFAULT_23DOF_ANGLES
self.kp = DEFAULT_23DOF_KP
self.kd = DEFAULT_23DOF_KD
self.motor_map = DOF_23_TO_MOTOR
logger.info("Locomotion: 23-DOF (82D obs)")
else:
self.default_angles = DEFAULT_29DOF_ANGLES
self.kp = DEFAULT_29DOF_KP
self.kd = DEFAULT_29DOF_KD
self.motor_map = list(range(29))
logger.info("Locomotion: 29-DOF (100D obs)")
self.cmd = np.zeros(3, dtype=np.float32)
self.qj = np.zeros(self.num_dof, dtype=np.float32)
self.dqj = np.zeros(self.num_dof, dtype=np.float32)
self.obs = np.zeros(obs_dim, dtype=np.float32)
self.last_action = np.zeros(self.num_dof, dtype=np.float32)
self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
self.phase_dt = 2 * np.pi / (50.0 * GAIT_PERIOD)
self.is_standing = True
def run_step(self):
"""Single locomotion step."""
state = self.robot.lowstate_buffer.get_data()
if state is None:
return
# Joystick
if state.wireless_remote is not None:
self.robot.remote_controller.set(state.wireless_remote)
ly = self.robot.remote_controller.ly if abs(self.robot.remote_controller.ly) > 0.1 else 0.0
lx = self.robot.remote_controller.lx if abs(self.robot.remote_controller.lx) > 0.1 else 0.0
rx = self.robot.remote_controller.rx if abs(self.robot.remote_controller.rx) > 0.1 else 0.0
self.cmd[0], self.cmd[1], self.cmd[2] = ly, -lx, -rx
# Read joints via motor map
for i in range(self.num_dof):
self.qj[i] = state.motor_state[self.motor_map[i]].q
self.dqj[i] = state.motor_state[self.motor_map[i]].dq
# IMU
quat = state.imu_state.quaternion
ang_vel = np.array(state.imu_state.gyroscope, dtype=np.float32)
gravity = self.robot.get_gravity_orientation(quat)
# Scale
qj_obs = (self.qj - self.default_angles) * DOF_POS_SCALE
dqj_obs = self.dqj * DOF_VEL_SCALE
ang_vel_s = ang_vel * ANG_VEL_SCALE
# Phase
cmd_mag = np.linalg.norm(self.cmd[:2])
ang_mag = abs(self.cmd[2])
if cmd_mag < 0.01 and ang_mag < 0.01:
self.phase[0, :] = np.pi
self.is_standing = True
elif self.is_standing:
self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
self.is_standing = False
else:
self.phase = np.fmod(self.phase + self.phase_dt + np.pi, 2*np.pi) - np.pi
sin_ph, cos_ph = np.sin(self.phase[0]), np.cos(self.phase[0])
# Build obs
if self.is_23dof:
self.obs[0:23] = self.last_action
self.obs[23:26] = ang_vel_s
self.obs[26] = self.cmd[2]
self.obs[27:29] = self.cmd[:2]
self.obs[29:31] = cos_ph
self.obs[31:54] = qj_obs
self.obs[54:77] = dqj_obs
self.obs[77:80] = gravity
self.obs[80:82] = sin_ph
else:
self.obs[0:29] = self.last_action
self.obs[29:32] = ang_vel_s
self.obs[32] = self.cmd[2]
self.obs[33:35] = self.cmd[:2]
self.obs[35:37] = cos_ph
self.obs[37:66] = qj_obs
self.obs[66:95] = dqj_obs
self.obs[95:98] = gravity
self.obs[98:100] = sin_ph
# Inference
obs_in = self.obs.reshape(1, -1).astype(np.float32)
ort_in = {self.policy.get_inputs()[0].name: obs_in}
raw_action = self.policy.run(None, ort_in)[0].squeeze()
clipped = np.clip(raw_action, -100.0, 100.0)
self.last_action = clipped.copy()
scaled = clipped * LOCOMOTION_ACTION_SCALE
target = self.default_angles + scaled
# Send commands
for i in range(self.num_dof):
motor_idx = self.motor_map[i]
self.robot.msg.motor_cmd[motor_idx].q = float(target[i])
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = self.kp[i]
self.robot.msg.motor_cmd[motor_idx].kd = self.kd[i]
self.robot.msg.motor_cmd[motor_idx].tau = 0
# Zero missing joints for 23-DOF
if self.is_23dof:
for idx in MISSING_23DOF_MOTORS:
self.robot.msg.motor_cmd[idx].q = 0.0
self.robot.msg.motor_cmd[idx].qd = 0
self.robot.msg.motor_cmd[idx].kp = 40.0
self.robot.msg.motor_cmd[idx].kd = 2.0
self.robot.msg.motor_cmd[idx].tau = 0
self.robot.send_action(self.robot.msg)
def reset(self):
"""Reset state for fresh start."""
self.last_action.fill(0)
self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
self.is_standing = True
# =============================================================================
# DANCE CONTROLLER
# =============================================================================
class DanceController:
"""WBT dance policy with FK for torso tracking."""
def __init__(self, policy, robot, pinocchio_fk, motor_kp, motor_kd, action_scale):
self.policy = policy
self.robot = robot
self.pinocchio_fk = pinocchio_fk
self.motor_kp = motor_kp
self.motor_kd = motor_kd
self.action_scale = action_scale
self.obs_dim = policy.get_inputs()[0].shape[1]
self.last_action = np.zeros((1, NUM_DOFS), dtype=np.float32)
self.motion_command = None
self.ref_quat_xyzw = None
self.timestep = 0
self.yaw_offset = 0.0
logger.info(f"Dance: obs_dim={self.obs_dim}, action_scale={action_scale}")
def initialize(self, reset_to_frame_0: bool = True):
"""Initialize dance. If reset_to_frame_0=True, starts from frame 0. Otherwise resumes."""
if reset_to_frame_0:
self.timestep = 0
self.last_action.fill(0)
# Get initial motion data at frame 0
dummy = np.zeros((1, self.obs_dim), dtype=np.float32)
outs = self.policy.run(["joint_pos", "joint_vel", "ref_quat_xyzw"],
{"obs": dummy, "time_step": np.array([[0]], dtype=np.float32)})
self.motion_command = np.concatenate(outs[0:2], axis=1)
self.ref_quat_xyzw = outs[2]
logger.info("Dance: reset to frame 0")
else:
# Resume from current timestep - just update motion command for current frame
dummy = np.zeros((1, self.obs_dim), dtype=np.float32)
outs = self.policy.run(["joint_pos", "joint_vel", "ref_quat_xyzw"],
{"obs": dummy, "time_step": np.array([[self.timestep]], dtype=np.float32)})
self.motion_command = np.concatenate(outs[0:2], axis=1)
self.ref_quat_xyzw = outs[2]
logger.info(f"Dance: resuming from frame {self.timestep}")
# Capture yaw offset
state = self.robot.lowstate_buffer.get_data()
if state and self.pinocchio_fk:
quat = np.array(state.imu_state.quaternion, dtype=np.float32)
dof = np.array([state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
torso_q = self.pinocchio_fk.get_torso_quat(np.zeros(3), quat, dof)
_, _, self.yaw_offset = quat_to_rpy(torso_q.flatten())
logger.info(f"Dance yaw offset: {np.degrees(self.yaw_offset):.1f}°")
def _remove_yaw_offset(self, quat_wxyz):
if abs(self.yaw_offset) < 1e-6:
return quat_wxyz
yaw_q = rpy_to_quat((0, 0, -self.yaw_offset)).reshape(1, 4)
return quat_mul(yaw_q, quat_wxyz)
def run_step(self):
"""Single dance step."""
state = self.robot.lowstate_buffer.get_data()
if state is None:
return
quat = np.array(state.imu_state.quaternion, dtype=np.float32)
ang_vel = np.array(state.imu_state.gyroscope, dtype=np.float32)
dof_pos = np.array([state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
dof_vel = np.array([state.motor_state[i].dq for i in range(NUM_DOFS)], dtype=np.float32)
# FK for torso orientation
if self.pinocchio_fk:
torso_q = self.pinocchio_fk.get_torso_quat(np.zeros(3), quat, dof_pos)
torso_q = self._remove_yaw_offset(torso_q)
motion_ori = xyzw_to_wxyz(self.ref_quat_xyzw)
rel_quat = subtract_frame_transforms(torso_q, motion_ori)
ori_b = matrix_from_quat(rel_quat)[..., :2].reshape(1, -1)
else:
ori_b = np.zeros((1, 6), dtype=np.float32)
dof_rel = (dof_pos - DEFAULT_29DOF_ANGLES).reshape(1, -1)
# Build obs (alphabetical)
obs_dict = {
"actions": self.last_action,
"base_ang_vel": ang_vel.reshape(1, 3),
"dof_pos": dof_rel,
"dof_vel": dof_vel.reshape(1, -1),
"motion_command": self.motion_command,
"motion_ref_ori_b": ori_b,
}
obs = np.concatenate([obs_dict[k].astype(np.float32) for k in sorted(obs_dict.keys())], axis=1)
obs = np.clip(obs, -100, 100)
# Inference
outs = self.policy.run(["actions", "joint_pos", "joint_vel", "ref_quat_xyzw"],
{"obs": obs, "time_step": np.array([[self.timestep]], dtype=np.float32)})
action = np.clip(outs[0], -100, 100)
self.motion_command = np.concatenate(outs[1:3], axis=1)
self.ref_quat_xyzw = outs[3]
self.last_action = action.copy()
target = DEFAULT_29DOF_ANGLES + action.flatten() * self.action_scale
# Send commands
for i in range(NUM_DOFS):
if i in FROZEN_JOINTS:
self.robot.msg.motor_cmd[i].q = 0.0
self.robot.msg.motor_cmd[i].kp = FROZEN_KP
self.robot.msg.motor_cmd[i].kd = FROZEN_KD
else:
self.robot.msg.motor_cmd[i].q = float(target[i])
self.robot.msg.motor_cmd[i].kp = self.motor_kp[i]
self.robot.msg.motor_cmd[i].kd = self.motor_kd[i]
self.robot.msg.motor_cmd[i].qd = 0
self.robot.msg.motor_cmd[i].tau = 0
self.robot.send_action(self.robot.msg)
self.timestep += 1
# =============================================================================
# MAIN
# =============================================================================
def main():
parser = argparse.ArgumentParser(description="Locomotion ↔ Dance Toggle")
parser.add_argument("--loco-repo", type=str, default=DEFAULT_HOLOSOMA_REPO_ID)
parser.add_argument("--dance-onnx", type=str, default=DANCE_ONNX_PATH)
args = parser.parse_args()
print("=" * 70)
print("🚶 LOCOMOTION ↔ 💃 DANCE")
print("=" * 70)
print("Press ENTER to toggle between modes")
print("=" * 70)
# Load locomotion policy
logger.info("Loading locomotion policy...")
loco_path = hf_hub_download(repo_id=args.loco_repo, filename="fastsac_g1_29dof.onnx")
loco_policy = ort.InferenceSession(loco_path)
loco_obs_dim = loco_policy.get_inputs()[0].shape[1]
logger.info(f"Locomotion: {loco_obs_dim}D obs")
# Load dance policy
logger.info("Loading dance policy...")
dance_policy = ort.InferenceSession(args.dance_onnx)
dance_model = onnx.load(args.dance_onnx)
dance_meta = {p.key: json.loads(p.value) for p in dance_model.metadata_props}
dance_kp = np.array(dance_meta.get("kp", DEFAULT_29DOF_KP), dtype=np.float32)
dance_kd = np.array(dance_meta.get("kd", DEFAULT_29DOF_KD), dtype=np.float32)
dance_action_scale = float(dance_meta.get("action_scale", 1.0))
logger.info(f"Dance: {dance_policy.get_inputs()[0].shape[1]}D obs, scale={dance_action_scale}")
# Build Pinocchio FK
pinocchio_fk = None
if "robot_urdf" in dance_meta:
logger.info("Building Pinocchio FK...")
pinocchio_fk = PinocchioFK(dance_meta["robot_urdf"])
# Initialize robot
logger.info("Initializing robot...")
config = UnitreeG1Config()
robot = UnitreeG1(config)
logger.info("Robot connected!")
# Create controllers
loco_ctrl = LocomotionController(loco_policy, robot, loco_obs_dim)
dance_ctrl = DanceController(dance_policy, robot, pinocchio_fk, dance_kp, dance_kd, dance_action_scale)
# State
mode = "locomotion"
toggle_event = threading.Event()
shutdown = threading.Event()
# Input thread
def input_loop():
while not shutdown.is_set():
if select.select([sys.stdin], [], [], 0.1)[0]:
sys.stdin.readline()
toggle_event.set()
input_thread = threading.Thread(target=input_loop, daemon=True)
input_thread.start()
print("\n🚶 LOCOMOTION MODE - Use joystick to walk")
print(" Press ENTER to switch to DANCE")
print("-" * 70)
step = 0
try:
while not shutdown.is_set():
t0 = time.time()
# Check toggle
if toggle_event.is_set():
toggle_event.clear()
if mode == "locomotion":
mode = "dance"
dance_ctrl.initialize()
print("\n" + "=" * 70)
print("💃 DANCE MODE (frame 0)")
print(" Press ENTER to switch to LOCOMOTION")
print("=" * 70)
else:
mode = "locomotion"
loco_ctrl.reset()
print("\n" + "=" * 70)
print("🚶 LOCOMOTION MODE")
print(" Press ENTER to switch to DANCE")
print("=" * 70)
# Run controller
if mode == "locomotion":
loco_ctrl.run_step()
else:
dance_ctrl.run_step()
# Log
if step % 100 == 0:
if mode == "locomotion":
print(f"[LOCO ] step={step:5d} cmd=[{loco_ctrl.cmd[0]:.2f},{loco_ctrl.cmd[1]:.2f},{loco_ctrl.cmd[2]:.2f}]")
else:
print(f"[DANCE] step={step:5d} timestep={dance_ctrl.timestep}")
step += 1
elapsed = time.time() - t0
if elapsed < CONTROL_DT:
time.sleep(CONTROL_DT - elapsed)
except KeyboardInterrupt:
print("\n\nStopping...")
finally:
shutdown.set()
robot.disconnect()
print("Done!")
if __name__ == "__main__":
main()
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,447 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example: Unitree RL 12-DOF Legs-Only Locomotion (TorchScript)
This example demonstrates loading a 12-DOF legs-only locomotion policy
(TorchScript .pt format) and running it on the Unitree G1 robot.
Key characteristics:
- Single TorchScript policy (.pt)
- 47D observations, 12D actions (legs only)
- Phase-based gait timing
- Arms and waist held at fixed positions
"""
import argparse
import logging
import threading
import time
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from scipy.spatial.transform import Rotation as R
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 12-DOF leg joint configuration
# Joint order: [L_hip_pitch, L_hip_roll, L_hip_yaw, L_knee, L_ankle_pitch, L_ankle_roll,
# R_hip_pitch, R_hip_roll, R_hip_yaw, R_knee, R_ankle_pitch, R_ankle_roll]
LEG_JOINT_INDICES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
# Default leg angles for standing
DEFAULT_LEG_ANGLES = np.array([
-0.1, 0.0, 0.0, 0.3, -0.2, 0.0, # left leg
-0.1, 0.0, 0.0, 0.3, -0.2, 0.0, # right leg
], dtype=np.float32)
# KP/KD for leg joints
LEG_KPS = np.array([150, 150, 150, 300, 40, 40, 150, 150, 150, 300, 40, 40], dtype=np.float32)
LEG_KDS = np.array([6, 6, 6, 4, 2, 2, 6, 6, 6, 4, 2, 2], dtype=np.float32)
# Waist configuration (held at zero)
WAIST_JOINT_INDICES = [12, 13, 14] # yaw, roll, pitch
WAIST_KPS = np.array([250, 250, 250], dtype=np.float32)
WAIST_KDS = np.array([5, 5, 5], dtype=np.float32)
# Arm configuration (indices 15-28, held at initial position)
ARM_JOINT_INDICES = list(range(15, 29))
ARM_KPS = np.array([80, 80, 80, 80, 40, 40, 40, # left arm (shoulder + wrist)
80, 80, 80, 80, 40, 40, 40], dtype=np.float32) # right arm
ARM_KDS = np.array([3, 3, 3, 3, 1.5, 1.5, 1.5,
3, 3, 3, 3, 1.5, 1.5, 1.5], dtype=np.float32)
# Control parameters
LOCOMOTION_CONTROL_DT = 0.02 # 50Hz control rate
LOCOMOTION_ACTION_SCALE = 0.25
ANG_VEL_SCALE = 0.25
DOF_POS_SCALE = 1.0
DOF_VEL_SCALE = 0.05
CMD_SCALE = np.array([2.0, 2.0, 0.25], dtype=np.float32)
MAX_CMD = np.array([0.8, 0.5, 1.57], dtype=np.float32) # max vx, vy, yaw_rate
# Gait parameters
GAIT_PERIOD = 0.8 # seconds
DEFAULT_REPO_ID = "nepyope/unitree_rl_locomotion"
def load_torchscript_policy(
repo_id: str = DEFAULT_REPO_ID,
filename: str = "motion.pt",
) -> torch.jit.ScriptModule:
"""Load TorchScript locomotion policy from Hugging Face Hub.
Args:
repo_id: Hugging Face Hub repository ID containing the policy.
filename: Policy filename (default: motion.pt).
"""
logger.info(f"Loading TorchScript policy from Hugging Face Hub ({repo_id}/{filename})...")
policy_path = hf_hub_download(
repo_id=repo_id,
filename=filename,
)
policy = torch.jit.load(policy_path)
policy.eval()
logger.info("TorchScript policy loaded successfully")
return policy
class UnitreeRLLocomotionController:
"""
Handles 12-DOF legs-only locomotion control for the Unitree G1 robot.
This controller manages:
- Single TorchScript policy
- 47D observations (single frame)
- 12D action output (legs only)
- Arms and waist held at fixed positions
- Phase-based gait timing
"""
def __init__(self, policy, robot, config):
self.policy = policy
self.robot = robot
self.config = config
# Velocity commands (vx, vy, yaw_rate)
self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)
# State variables (12 DOF legs)
self.qj = np.zeros(12, dtype=np.float32)
self.dqj = np.zeros(12, dtype=np.float32)
self.locomotion_action = np.zeros(12, dtype=np.float32)
self.locomotion_obs = np.zeros(47, dtype=np.float32)
# Initial arm positions (captured on reset)
self.initial_arm_positions = np.zeros(14, dtype=np.float32)
# Counter for phase calculation
self.counter = 0
# Thread management
self.locomotion_running = False
self.locomotion_thread = None
logger.info("UnitreeRLLocomotionController initialized")
logger.info(" Observation dim: 47, Action dim: 12 (legs only)")
def locomotion_run(self):
"""12-DOF legs-only locomotion policy loop."""
self.counter += 1
if self.counter == 1:
print("\n" + "=" * 60)
print("🚀 RUNNING UNITREE RL 12-DOF LOCOMOTION POLICY")
print(" 47D observations → 12D actions (legs only)")
print(" Arms and waist held at fixed positions")
print("=" * 60 + "\n")
# Get current observation
robot_state = self.robot.get_observation()
if robot_state is None:
return
# Get command from remote controller
if robot_state.wireless_remote is not None:
self.robot.remote_controller.set(robot_state.wireless_remote)
else:
self.robot.remote_controller.lx = 0.0
self.robot.remote_controller.ly = 0.0
self.robot.remote_controller.rx = 0.0
self.robot.remote_controller.ry = 0.0
self.locomotion_cmd[0] = self.robot.remote_controller.ly # forward/backward
self.locomotion_cmd[1] = self.robot.remote_controller.lx * -1 # left/right (inverted)
self.locomotion_cmd[2] = self.robot.remote_controller.rx * -1 # yaw (inverted)
# Get leg joint positions and velocities (12 DOF)
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
self.qj[i] = robot_state.motor_state[motor_idx].q
self.dqj[i] = robot_state.motor_state[motor_idx].dq
# Get IMU data
quat = robot_state.imu_state.quaternion
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
# Scale observations
gravity_orientation = self.robot.get_gravity_orientation(quat)
qj_obs = (self.qj - DEFAULT_LEG_ANGLES) * DOF_POS_SCALE
dqj_obs = self.dqj * DOF_VEL_SCALE
ang_vel_scaled = ang_vel * ANG_VEL_SCALE
# Calculate phase
count = self.counter * LOCOMOTION_CONTROL_DT
phase = (count % GAIT_PERIOD) / GAIT_PERIOD
sin_phase = np.sin(2 * np.pi * phase)
cos_phase = np.cos(2 * np.pi * phase)
# Build 47D observation vector
# [0:3] - angular velocity (scaled)
# [3:6] - gravity orientation
# [6:9] - velocity command (scaled)
# [9:21] - joint positions (12D, relative to default)
# [21:33] - joint velocities (12D, scaled)
# [33:45] - previous actions (12D)
# [45] - sin_phase
# [46] - cos_phase
self.locomotion_obs[0:3] = ang_vel_scaled
self.locomotion_obs[3:6] = gravity_orientation
self.locomotion_obs[6:9] = self.locomotion_cmd * CMD_SCALE * MAX_CMD
self.locomotion_obs[9:21] = qj_obs
self.locomotion_obs[21:33] = dqj_obs
self.locomotion_obs[33:45] = self.locomotion_action
self.locomotion_obs[45] = sin_phase
self.locomotion_obs[46] = cos_phase
# Run policy inference (TorchScript)
obs_tensor = torch.from_numpy(self.locomotion_obs).unsqueeze(0).float()
with torch.no_grad():
action_tensor = self.policy(obs_tensor)
self.locomotion_action = action_tensor.squeeze().numpy()
# Transform action to target joint positions
target_leg_pos = DEFAULT_LEG_ANGLES + self.locomotion_action * LOCOMOTION_ACTION_SCALE
# Debug logging (first 3 iterations)
if self.counter <= 3:
print(f"\n[Unitree RL Debug #{self.counter}]")
print(f" Phase: {phase:.3f} (sin={sin_phase:.3f}, cos={cos_phase:.3f})")
print(f" Cmd (vx, vy, yaw): ({self.locomotion_cmd[0]:.2f}, {self.locomotion_cmd[1]:.2f}, {self.locomotion_cmd[2]:.2f})")
print(f" Action range: [{self.locomotion_action.min():.3f}, {self.locomotion_action.max():.3f}]")
# Send commands to LEG motors (0-11)
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
self.robot.msg.motor_cmd[motor_idx].q = target_leg_pos[i]
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = LEG_KPS[i]
self.robot.msg.motor_cmd[motor_idx].kd = LEG_KDS[i]
self.robot.msg.motor_cmd[motor_idx].tau = 0
# Hold WAIST motors at zero (12, 13, 14)
for i, motor_idx in enumerate(WAIST_JOINT_INDICES):
self.robot.msg.motor_cmd[motor_idx].q = 0.0
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = WAIST_KPS[i]
self.robot.msg.motor_cmd[motor_idx].kd = WAIST_KDS[i]
self.robot.msg.motor_cmd[motor_idx].tau = 0
# Hold ARM motors at initial position (15-28)
for i, motor_idx in enumerate(ARM_JOINT_INDICES):
self.robot.msg.motor_cmd[motor_idx].q = self.initial_arm_positions[i]
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = ARM_KPS[i]
self.robot.msg.motor_cmd[motor_idx].kd = ARM_KDS[i]
self.robot.msg.motor_cmd[motor_idx].tau = 0
# Send command
self.robot.send_action(self.robot.msg)
def _locomotion_thread_loop(self):
"""Background thread that runs the locomotion policy at specified rate."""
logger.info("Locomotion thread started")
while self.locomotion_running:
start_time = time.time()
try:
self.locomotion_run()
except Exception as e:
logger.error(f"Error in locomotion loop: {e}")
import traceback
traceback.print_exc()
# Sleep to maintain control rate
elapsed = time.time() - start_time
sleep_time = max(0, LOCOMOTION_CONTROL_DT - elapsed)
time.sleep(sleep_time)
logger.info("Locomotion thread stopped")
def start_locomotion_thread(self):
if self.locomotion_running:
logger.warning("Locomotion thread already running")
return
logger.info("Starting locomotion control thread...")
self.locomotion_running = True
self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True)
self.locomotion_thread.start()
logger.info("Locomotion control thread started!")
def stop_locomotion_thread(self):
if not self.locomotion_running:
return
logger.info("Stopping locomotion control thread...")
self.locomotion_running = False
if self.locomotion_thread:
self.locomotion_thread.join(timeout=2.0)
logger.info("Locomotion control thread stopped")
def reset_robot(self):
"""Move legs to default standing position over 2 seconds (arms are captured and held)."""
logger.info("Moving legs to default position...")
total_time = 2.0
num_step = int(total_time / self.robot.control_dt)
# Get current state
robot_state = self.robot.get_observation()
# Capture initial arm positions (to hold during locomotion)
for i, motor_idx in enumerate(ARM_JOINT_INDICES):
self.initial_arm_positions[i] = robot_state.motor_state[motor_idx].q
logger.info(f"Captured initial arm positions: {self.initial_arm_positions[:4]}...")
# Record current leg positions
init_leg_pos = np.zeros(12, dtype=np.float32)
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
init_leg_pos[i] = robot_state.motor_state[motor_idx].q
# Interpolate legs to default position
for step in range(num_step):
alpha = step / num_step
# Interpolate leg positions
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
target_pos = DEFAULT_LEG_ANGLES[i]
self.robot.msg.motor_cmd[motor_idx].q = (
init_leg_pos[i] * (1 - alpha) + target_pos * alpha
)
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = LEG_KPS[i]
self.robot.msg.motor_cmd[motor_idx].kd = LEG_KDS[i]
self.robot.msg.motor_cmd[motor_idx].tau = 0
# Hold waist at zero
for i, motor_idx in enumerate(WAIST_JOINT_INDICES):
self.robot.msg.motor_cmd[motor_idx].q = 0.0
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = WAIST_KPS[i]
self.robot.msg.motor_cmd[motor_idx].kd = WAIST_KDS[i]
self.robot.msg.motor_cmd[motor_idx].tau = 0
# Hold arms at initial position
for i, motor_idx in enumerate(ARM_JOINT_INDICES):
self.robot.msg.motor_cmd[motor_idx].q = self.initial_arm_positions[i]
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = ARM_KPS[i]
self.robot.msg.motor_cmd[motor_idx].kd = ARM_KDS[i]
self.robot.msg.motor_cmd[motor_idx].tau = 0
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
self.robot.lowcmd_publisher.Write(self.robot.msg)
time.sleep(self.robot.control_dt)
logger.info("Reached default leg position")
# Hold position for 2 seconds
logger.info("Holding default position for 2 seconds...")
hold_time = 2.0
num_hold_steps = int(hold_time / self.robot.control_dt)
for _ in range(num_hold_steps):
# Hold legs at default
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
self.robot.msg.motor_cmd[motor_idx].q = DEFAULT_LEG_ANGLES[i]
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = LEG_KPS[i]
self.robot.msg.motor_cmd[motor_idx].kd = LEG_KDS[i]
self.robot.msg.motor_cmd[motor_idx].tau = 0
# Hold waist at zero
for i, motor_idx in enumerate(WAIST_JOINT_INDICES):
self.robot.msg.motor_cmd[motor_idx].q = 0.0
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = WAIST_KPS[i]
self.robot.msg.motor_cmd[motor_idx].kd = WAIST_KDS[i]
self.robot.msg.motor_cmd[motor_idx].tau = 0
# Hold arms at initial position
for i, motor_idx in enumerate(ARM_JOINT_INDICES):
self.robot.msg.motor_cmd[motor_idx].q = self.initial_arm_positions[i]
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = ARM_KPS[i]
self.robot.msg.motor_cmd[motor_idx].kd = ARM_KDS[i]
self.robot.msg.motor_cmd[motor_idx].tau = 0
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
self.robot.lowcmd_publisher.Write(self.robot.msg)
time.sleep(self.robot.control_dt)
logger.info("Ready to start locomotion!")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Unitree RL 12-DOF Locomotion Controller for Unitree G1")
parser.add_argument(
"--repo-id",
type=str,
default=DEFAULT_REPO_ID,
help=f"Hugging Face Hub repo ID for policy (default: {DEFAULT_REPO_ID})",
)
parser.add_argument(
"--filename",
type=str,
default="motion.pt",
help="Policy filename (default: motion.pt)",
)
args = parser.parse_args()
# Load policy
policy = load_torchscript_policy(repo_id=args.repo_id, filename=args.filename)
# Initialize robot
config = UnitreeG1Config()
robot = UnitreeG1(config)
# Initialize locomotion controller
locomotion_controller = UnitreeRLLocomotionController(
policy=policy,
robot=robot,
config=config,
)
# Reset robot and start locomotion thread
try:
locomotion_controller.reset_robot()
locomotion_controller.start_locomotion_thread()
# Log status
logger.info("Robot initialized with Unitree RL locomotion policy")
logger.info("Locomotion controller running in background thread")
logger.info("Use remote controller to command velocity:")
logger.info(" Left stick Y: forward/backward")
logger.info(" Left stick X: left/right")
logger.info(" Right stick X: rotate")
logger.info("Press Ctrl+C to stop")
# Keep robot alive
while True:
time.sleep(1.0)
except KeyboardInterrupt:
print("\nStopping locomotion...")
locomotion_controller.stop_locomotion_thread()
print("Done!")
+11 -4
View File
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
[project]
name = "lerobot"
version = "0.4.2"
version = "0.4.3"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
readme = "README.md"
license = { text = "Apache-2.0" }
@@ -107,6 +107,10 @@ dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
unitree_g1 = [
"pyzmq>=26.2.1,<28.0.0",
"onnxruntime>=1.16.0"
]
reachy2 = ["reachy2_sdk>=1.0.14,<1.1.0"]
kinematics = ["lerobot[placo-dep]"]
intelrealsense = [
@@ -129,6 +133,7 @@ groot = [
"ninja>=1.11.1,<2.0.0",
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
]
xvla = ["lerobot[transformers-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features
@@ -157,6 +162,7 @@ all = [
"lerobot[pi]",
"lerobot[smolvla]",
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
"lerobot[xvla]",
"lerobot[hilserl]",
"lerobot[async]",
"lerobot[dev]",
@@ -257,6 +263,7 @@ default.extend-ignore-identifiers-re = [
"ein",
"thw",
"inpt",
"ROBOTIS",
]
# TODO: Uncomment when ready to use
@@ -356,9 +363,9 @@ ignore_errors = false
# module = "lerobot.async_inference.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.transport.*"
# ignore_errors = false
[[tool.mypy.overrides]]
module = "lerobot.transport.*"
ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.scripts.*"
+1 -1
View File
@@ -26,4 +26,4 @@ DEFAULT_OBS_QUEUE_TIMEOUT = 2
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"]
# TODO: Add all other robots
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so100_follower"]
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so100_follower", "omx_follower"]
@@ -54,6 +54,7 @@ from lerobot.robots import ( # noqa: F401
bi_so100_follower,
koch_follower,
make_robot_from_config,
omx_follower,
so100_follower,
so101_follower,
)
+57 -20
View File
@@ -136,21 +136,40 @@ def update_meta_data(
df["_orig_chunk"] = df[orig_chunk_col].copy()
df["_orig_file"] = df[orig_file_col].copy()
# Update chunk and file indices to point to destination
df[orig_chunk_col] = video_idx["chunk"]
df[orig_file_col] = video_idx["file"]
# Apply per-source-file timestamp offsets
# Get mappings for this video key
src_to_offset = video_idx.get("src_to_offset", {})
if src_to_offset:
# Apply offset based on original source file
src_to_dst = video_idx.get("src_to_dst", {})
# Apply per-source-file mappings
if src_to_dst:
# Map each episode to its correct destination file and apply offset
for idx in df.index:
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
# Convert to Python int to avoid numpy type mismatch in dict lookup
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
# Get destination chunk/file for this source file
dst_chunk, dst_file = src_to_dst.get(src_key, (video_idx["chunk"], video_idx["file"]))
df.at[idx, orig_chunk_col] = dst_chunk
df.at[idx, orig_file_col] = dst_file
# Apply timestamp offset
offset = src_to_offset.get(src_key, 0)
df.at[idx, f"videos/{key}/from_timestamp"] += offset
df.at[idx, f"videos/{key}/to_timestamp"] += offset
elif src_to_offset:
# Fallback: use same destination for all, but apply per-file offsets
df[orig_chunk_col] = video_idx["chunk"]
df[orig_file_col] = video_idx["file"]
for idx in df.index:
# Convert to Python int to avoid numpy type mismatch in dict lookup
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
offset = src_to_offset.get(src_key, 0)
df.at[idx, f"videos/{key}/from_timestamp"] += offset
df.at[idx, f"videos/{key}/to_timestamp"] += offset
else:
# Fallback to simple offset (for backward compatibility)
df[orig_chunk_col] = video_idx["chunk"]
df[orig_file_col] = video_idx["file"]
df[f"videos/{key}/from_timestamp"] = (
df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
)
@@ -268,6 +287,12 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
videos_idx[key]["episode_duration"] = 0
# Track offset for each source (chunk, file) pair
videos_idx[key]["src_to_offset"] = {}
# Track destination (chunk, file) for each source (chunk, file) pair
videos_idx[key]["src_to_dst"] = {}
# Initialize dst_file_durations if not present
# dst_file_durations tracks duration of each destination file
if "dst_file_durations" not in videos_idx[key]:
videos_idx[key]["dst_file_durations"] = {}
for key, video_idx in videos_idx.items():
unique_chunk_file_pairs = {
@@ -282,9 +307,13 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
chunk_idx = video_idx["chunk"]
file_idx = video_idx["file"]
current_offset = video_idx["latest_duration"]
dst_file_durations = video_idx["dst_file_durations"]
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
# Convert to Python int to ensure consistent dict keys
src_chunk_idx = int(src_chunk_idx)
src_file_idx = int(src_file_idx)
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
video_key=key,
chunk_index=src_chunk_idx,
@@ -298,14 +327,17 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
)
src_duration = get_video_duration_in_s(src_path)
dst_key = (chunk_idx, file_idx)
if not dst_path.exists():
# Store offset before incrementing
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
# New destination file: offset is 0
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
dst_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(str(src_path), str(dst_path))
# Track duration of this destination file
dst_file_durations[dst_key] = src_duration
videos_idx[key]["episode_duration"] += src_duration
current_offset += src_duration
continue
# Check file sizes before appending
@@ -313,10 +345,11 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
dst_size = get_file_size_in_mb(dst_path)
if dst_size + src_size >= video_files_size_in_mb:
# Rotate to a new file, this source becomes start of new destination
# So its offset should be 0
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
# Rotate to a new file - offset is 0
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
dst_key = (chunk_idx, file_idx)
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format(
video_key=key,
chunk_index=chunk_idx,
@@ -324,16 +357,20 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
)
dst_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(str(src_path), str(dst_path))
# Reset offset for next file
current_offset = src_duration
# Track duration of this new destination file
dst_file_durations[dst_key] = src_duration
else:
# Append to existing video file - use current accumulated offset
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
# Append to existing destination file
# Offset is the current duration of this destination file
current_dst_duration = dst_file_durations.get(dst_key, 0)
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_dst_duration
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
concatenate_video_files(
[dst_path, src_path],
dst_path,
)
current_offset += src_duration
# Update duration of this destination file
dst_file_durations[dst_key] = current_dst_duration + src_duration
videos_idx[key]["episode_duration"] += src_duration
+2 -1
View File
@@ -245,7 +245,7 @@ class HILSerlRobotEnvConfig(EnvConfig):
class LiberoEnv(EnvConfig):
task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
fps: int = 30
episode_length: int = 520
episode_length: int | None = None
obs_type: str = "pixels_agent_pos"
render_mode: str = "rgb_array"
camera_name: str = "agentview_image,robot0_eye_in_hand_image"
@@ -272,6 +272,7 @@ class LiberoEnv(EnvConfig):
LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2",
}
)
control_mode: str = "relative" # or "absolute"
def __post_init__(self):
if self.obs_type == "pixels":
+9
View File
@@ -19,8 +19,10 @@ from typing import Any
import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry
from lerobot.configs.policies import PreTrainedConfig
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import ProcessorStep
from lerobot.processor.env_processor import LiberoProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline
@@ -39,6 +41,7 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
def make_env_pre_post_processors(
env_cfg: EnvConfig,
policy_cfg: PreTrainedConfig,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
@@ -61,6 +64,10 @@ def make_env_pre_post_processors(
# Preprocessor and Postprocessor steps are Identity for most environments
preprocessor_steps: list[ProcessorStep] = []
postprocessor_steps: list[ProcessorStep] = []
if isinstance(policy_cfg, XVLAConfig):
from lerobot.policies.xvla.processor_xvla import make_xvla_libero_pre_post_processors
return make_xvla_libero_pre_post_processors()
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
@@ -136,6 +143,8 @@ def make_env(
init_states=cfg.init_states,
gym_kwargs=cfg.gym_kwargs,
env_cls=env_cls,
control_mode=cfg.control_mode,
episode_length=cfg.episode_length,
)
elif "metaworld" in cfg.type:
from lerobot.envs.metaworld import create_metaworld_envs
+26 -5
View File
@@ -80,10 +80,7 @@ def get_libero_dummy_action():
return [0, 0, 0, 0, 0, 0, -1]
OBS_STATE_DIM = 8
ACTION_DIM = 7
AGENT_POS_LOW = -1000.0
AGENT_POS_HIGH = 1000.0
ACTION_LOW = -1.0
ACTION_HIGH = 1.0
TASK_SUITE_MAX_STEPS: dict[str, int] = {
@@ -103,6 +100,7 @@ class LiberoEnv(gym.Env):
task_suite: Any,
task_id: int,
task_suite_name: str,
episode_length: int | None = None,
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
obs_type: str = "pixels",
render_mode: str = "rgb_array",
@@ -114,6 +112,7 @@ class LiberoEnv(gym.Env):
episode_index: int = 0,
camera_name_mapping: dict[str, str] | None = None,
num_steps_wait: int = 10,
control_mode: str = "relative",
):
super().__init__()
self.task_id = task_id
@@ -141,14 +140,19 @@ class LiberoEnv(gym.Env):
self.camera_name_mapping = camera_name_mapping
self.num_steps_wait = num_steps_wait
self.episode_index = episode_index
self.episode_length = episode_length
# Load once and keep
self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
self._init_state_id = self.episode_index # tie each sub-env to a fixed init state
self._env = self._make_envs_task(task_suite, self.task_id)
default_steps = 500
self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
self._max_episode_steps = (
TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
if self.episode_length is None
else self.episode_length
)
self.control_mode = control_mode
images = {}
for cam in self.camera_name:
images[self.camera_name_mapping[cam]] = spaces.Box(
@@ -296,6 +300,15 @@ class LiberoEnv(gym.Env):
# Increasing this value can improve determinism and reproducibility across resets.
for _ in range(self.num_steps_wait):
raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())
if self.control_mode == "absolute":
for robot in self._env.robots:
robot.controller.use_delta = False
elif self.control_mode == "relative":
for robot in self._env.robots:
robot.controller.use_delta = True
else:
raise ValueError(f"Invalid control mode: {self.control_mode}")
observation = self._format_raw_obs(raw_obs)
info = {"is_success": False}
return observation, info
@@ -341,8 +354,10 @@ def _make_env_fns(
task_id: int,
n_envs: int,
camera_names: list[str],
episode_length: int | None,
init_states: bool,
gym_kwargs: Mapping[str, Any],
control_mode: str,
) -> list[Callable[[], LiberoEnv]]:
"""Build n_envs factory callables for a single (suite, task_id)."""
@@ -354,7 +369,9 @@ def _make_env_fns(
task_suite_name=suite_name,
camera_name=camera_names,
init_states=init_states,
episode_length=episode_length,
episode_index=episode_index,
control_mode=control_mode,
**local_kwargs,
)
@@ -374,6 +391,8 @@ def create_libero_envs(
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
init_states: bool = True,
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
control_mode: str = "relative",
episode_length: int | None = None,
) -> dict[str, dict[int, Any]]:
"""
Create vectorized LIBERO environments with a consistent return shape.
@@ -415,12 +434,14 @@ def create_libero_envs(
for tid in selected:
fns = _make_env_fns(
suite=suite,
episode_length=episode_length,
suite_name=suite_name,
task_id=tid,
n_envs=n_envs,
camera_names=camera_names,
init_states=init_states,
gym_kwargs=gym_kwargs,
control_mode=control_mode,
)
out[suite_name][tid] = env_cls(fns)
print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")
+101
View File
@@ -104,6 +104,107 @@ class SGDConfig(OptimizerConfig):
return torch.optim.SGD(params, **kwargs)
@OptimizerConfig.register_subclass("xvla-adamw")
@dataclass
class XVLAAdamWConfig(OptimizerConfig):
"""Custom AdamW optimizer for XVLA with differential learning rates.
The Vision-Language Model (VLM) is trained with 1/10 of the base learning rate
for stable optimization, while all other components use the full LR.
This LR ratio is crucial for achieving strong and stable finetuning performance.
Soft-prompts can optionally use a separate learning rate with warm-up support.
Set `soft_prompt_lr_scale` to a value < 1.0 (e.g., 0.1) to start soft-prompts
at a lower LR. Combine with a warmup scheduler for optimal results.
Note:
Completely matching official reported performance may require an additional
warm-up LR schedule for soft-prompts, which can bring minor improvements.
When `soft_prompt_warmup_lr_scale` is set, soft-prompts start at
`lr * soft_prompt_warmup_lr_scale` and should be warmed up via the scheduler.
Parameter Groups:
- Group 0 (vlm): VLM parameters at lr * 0.1, weight_decay * 0.1
- Group 1 (soft_prompts): Soft-prompt parameters at lr * soft_prompt_lr_scale
- Group 2 (other): All other parameters at full lr
"""
lr: float = 1e-4
betas: tuple[float, float] = (0.9, 0.99)
eps: float = 1e-8
weight_decay: float = 0.0
grad_clip_norm: float = 10.0
# Soft-prompt specific settings
soft_prompt_lr_scale: float = 1.0 # Scale factor for soft-prompt LR (1.0 = same as base LR)
soft_prompt_warmup_lr_scale: float | None = None # If set, start soft-prompts at this scale (e.g., 0.01)
def build(self, params: dict) -> torch.optim.Optimizer:
"""
Build AdamW optimizer with differential learning rates.
Expects `named_parameters()` as input (dict of name -> param).
Applies:
- lr * 0.1 for all VLM-related parameters
- lr * soft_prompt_lr_scale for soft-prompt parameters (with optional warmup)
- full lr for all other parameters
Args:
params: Dictionary of parameter names to parameters (from named_parameters())
Returns:
AdamW optimizer with parameter groups for VLM, soft-prompts, and other components
"""
assert isinstance(params, dict), "Custom LR optimizer requires `named_parameters()` as inputs."
vlm_group, soft_prompt_group, other_group = [], [], []
for name, p in params.items():
if not p.requires_grad:
continue
if "vlm" in name.lower():
vlm_group.append(p)
elif "soft_prompt" in name.lower():
soft_prompt_group.append(p)
else:
other_group.append(p)
# Determine soft-prompt LR
soft_prompt_lr = self.lr * self.soft_prompt_lr_scale
if self.soft_prompt_warmup_lr_scale is not None:
# Start at warmup scale, scheduler will warm up to soft_prompt_lr
soft_prompt_lr = self.lr * self.soft_prompt_warmup_lr_scale
param_groups = [
{
"params": vlm_group,
"lr": self.lr * 0.1,
"weight_decay": self.weight_decay * 0.1,
"name": "vlm",
},
{
"params": soft_prompt_group,
"lr": soft_prompt_lr,
"weight_decay": self.weight_decay,
"name": "soft_prompts",
},
{
"params": other_group,
"lr": self.lr,
"weight_decay": self.weight_decay,
"name": "other",
},
]
# Filter out empty groups
param_groups = [g for g in param_groups if len(g["params"]) > 0]
return torch.optim.AdamW(
param_groups,
betas=self.betas,
eps=self.eps,
)
@OptimizerConfig.register_subclass("multi_adam")
@dataclass
class MultiAdamConfig(OptimizerConfig):
+2
View File
@@ -21,6 +21,7 @@ from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
__all__ = [
"ACTConfig",
@@ -31,4 +32,5 @@ __all__ = [
"TDMPCConfig",
"VQBeTConfig",
"GrootConfig",
"XVLAConfig",
]
+96 -5
View File
@@ -16,6 +16,7 @@
from __future__ import annotations
import importlib
import logging
from typing import Any, TypedDict
@@ -40,6 +41,7 @@ from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.utils import validate_visual_features_consistency
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.processor.converters import (
batch_to_transition,
@@ -107,8 +109,15 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.groot.modeling_groot import GrootPolicy
return GrootPolicy
elif name == "xvla":
from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
return XVLAPolicy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
try:
return _get_policy_cls_from_policy_name(name=name)
except Exception as e:
raise ValueError(f"Policy type '{name}' is not available.") from e
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
@@ -150,8 +159,14 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return RewardClassifierConfig(**kwargs)
elif policy_type == "groot":
return GrootConfig(**kwargs)
elif policy_type == "xvla":
return XVLAConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
return config_cls(**kwargs)
except Exception as e:
raise ValueError(f"Policy type '{policy_type}' is not available.") from e
class ProcessorConfigKwargs(TypedDict, total=False):
@@ -329,9 +344,24 @@ def make_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, XVLAConfig):
from lerobot.policies.xvla.processor_xvla import (
make_xvla_pre_post_processors,
)
processors = make_xvla_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
try:
processors = _make_processors_from_policy_config(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
except Exception as e:
raise ValueError(f"Processor for policy type '{policy_cfg.type}' is not implemented.") from e
return processors
@@ -400,8 +430,7 @@ def make_policy(
raise ValueError("env_cfg cannot be None when ds_meta is not provided")
features = env_to_policy_features(env_cfg)
if not cfg.output_features:
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
if not cfg.input_features:
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
kwargs["config"] = cfg
@@ -425,3 +454,65 @@ def make_policy(
# TODO: (jadechoghari) - add a check_state(cfg, features) and check_action(cfg, features)
return policy
def _get_policy_cls_from_policy_name(name: str) -> type[PreTrainedConfig]:
"""Get policy class from its registered name using dynamic imports.
This is used as a helper function to import policies from 3rd party lerobot plugins.
Args:
name: The name of the policy.
Returns:
The policy class corresponding to the given name.
"""
if name not in PreTrainedConfig.get_known_choices():
raise ValueError(
f"Unknown policy name '{name}'. Available policies: {PreTrainedConfig.get_known_choices()}"
)
config_cls = PreTrainedConfig.get_choice_class(name)
config_cls_name = config_cls.__name__
model_name = config_cls_name.removesuffix("Config") # e.g., DiffusionConfig -> Diffusion
if model_name == config_cls_name:
raise ValueError(
f"The config class name '{config_cls_name}' does not follow the expected naming convention."
f"Make sure it ends with 'Config'!"
)
cls_name = model_name + "Policy" # e.g., DiffusionConfig -> DiffusionPolicy
module_path = config_cls.__module__.replace(
"configuration_", "modeling_"
) # e.g., configuration_diffusion -> modeling_diffusion
module = importlib.import_module(module_path)
policy_cls = getattr(module, cls_name)
return policy_cls
def _make_processors_from_policy_config(
config: PreTrainedConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[Any, Any]:
"""Create pre- and post-processors from a policy configuration using dynamic imports.
This is used as a helper function to import processor factories from 3rd party lerobot plugins.
Args:
config: The policy configuration object.
dataset_stats: Dataset statistics for normalization.
Returns:
A tuple containing the input (pre-processor) and output (post-processor) pipelines.
"""
policy_type = config.type
function_name = f"make_{policy_type}_pre_post_processors"
module_path = config.__class__.__module__.replace(
"configuration_", "processor_"
) # e.g., configuration_diffusion -> processor_diffusion
logging.debug(
f"Instantiating pre/post processors using function '{function_name}' from module '{module_path}'"
)
module = importlib.import_module(module_path)
function = getattr(module, function_name)
return function(config, dataset_stats=dataset_stats)
@@ -23,6 +23,8 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.constants import OBS_IMAGES
DEFAULT_IMAGE_SIZE = 224
@PreTrainedConfig.register_subclass("pi0")
@dataclass
@@ -51,7 +53,10 @@ class PI0Config(PreTrainedConfig):
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
image_resolution: tuple[int, int] = (
DEFAULT_IMAGE_SIZE,
DEFAULT_IMAGE_SIZE,
) # see openpi `preprocessing_pytorch.py`
# Add empty images. Used to add empty cameras when no image features are present.
empty_cameras: int = 0
+14 -13
View File
@@ -41,7 +41,7 @@ else:
PaliGemmaForConditionalGeneration = None
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.utils.constants import (
@@ -337,6 +337,7 @@ class PaliGemmaWithExpertModel(
action_expert_config,
use_adarms=None,
precision: Literal["bfloat16", "float32"] = "bfloat16",
image_size: int = DEFAULT_IMAGE_SIZE,
):
if use_adarms is None:
use_adarms = [False, False]
@@ -356,6 +357,7 @@ class PaliGemmaWithExpertModel(
vlm_config_hf.text_config.vocab_size = 257152
vlm_config_hf.text_config.use_adarms = use_adarms[0]
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
vlm_config_hf.vision_config.image_size = image_size
vlm_config_hf.vision_config.intermediate_size = 4304
vlm_config_hf.vision_config.projection_dim = 2048
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
@@ -519,11 +521,17 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
paligemma_config = get_gemma_config(config.paligemma_variant)
action_expert_config = get_gemma_config(config.action_expert_variant)
if config.image_resolution[0] != config.image_resolution[1]:
raise ValueError(
f"PaliGemma expects square image resolution, invalid resolution: {config.image_resolution}"
)
self.paligemma_with_expert = PaliGemmaWithExpertModel(
paligemma_config,
action_expert_config,
use_adarms=[False, False],
precision=config.dtype,
image_size=config.image_resolution[0],
)
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
@@ -812,16 +820,13 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
)
dt = -1.0 / num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)
x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
for step in range(num_steps):
time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
state=state,
prefix_pad_masks=prefix_pad_masks,
@@ -846,15 +851,11 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
else:
v_t = denoise_step_partial_call(x_t)
# Euler step
x_t += dt * v_t
x_t = x_t + dt * v_t
# Record x_t and v_t after Euler step
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
time += dt
return x_t
def denoise_step(
@@ -22,6 +22,8 @@ from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
DEFAULT_IMAGE_SIZE = 224
@PreTrainedConfig.register_subclass("pi05")
@dataclass
@@ -50,7 +52,10 @@ class PI05Config(PreTrainedConfig):
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
image_resolution: tuple[int, int] = (
DEFAULT_IMAGE_SIZE,
DEFAULT_IMAGE_SIZE,
) # see openpi `preprocessing_pytorch.py`
# Add empty images. Used to add empty cameras when no image features are present.
empty_cameras: int = 0
+16 -13
View File
@@ -41,7 +41,7 @@ else:
PaliGemmaForConditionalGeneration = None
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.utils.constants import (
@@ -336,6 +336,7 @@ class PaliGemmaWithExpertModel(
action_expert_config,
use_adarms=None,
precision: Literal["bfloat16", "float32"] = "bfloat16",
image_size: int = DEFAULT_IMAGE_SIZE,
):
if use_adarms is None:
use_adarms = [False, False]
@@ -355,6 +356,7 @@ class PaliGemmaWithExpertModel(
vlm_config_hf.text_config.vocab_size = 257152
vlm_config_hf.text_config.use_adarms = use_adarms[0]
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
vlm_config_hf.vision_config.image_size = image_size
vlm_config_hf.vision_config.intermediate_size = 4304
vlm_config_hf.vision_config.projection_dim = 2048
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
@@ -518,11 +520,17 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
paligemma_config = get_gemma_config(config.paligemma_variant)
action_expert_config = get_gemma_config(config.action_expert_variant)
if config.image_resolution[0] != config.image_resolution[1]:
raise ValueError(
f"PaliGemma expects square image resolution, invalid resolution: {config.image_resolution}"
)
self.paligemma_with_expert = PaliGemmaWithExpertModel(
paligemma_config,
action_expert_config,
use_adarms=[False, True],
precision=config.dtype,
image_size=config.image_resolution[0],
)
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
@@ -538,6 +546,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
if config.compile_model:
torch.set_float32_matmul_precision("high")
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
# Also compile the main forward pass used during training
self.forward = torch.compile(self.forward, mode=config.compile_mode)
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
@@ -785,16 +795,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
)
dt = -1.0 / num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)
x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
for step in range(num_steps):
time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
@@ -818,15 +825,11 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
else:
v_t = denoise_step_partial_call(x_t)
# Euler step
x_t += dt * v_t
x_t = x_t + dt * v_t
# Record x_t and v_t after Euler step
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
time += dt
return x_t
def denoise_step(
@@ -783,18 +783,15 @@ class VLAFlowMatching(nn.Module):
use_cache=self.config.use_cache,
fill_kv_cache=True,
)
dt = -1.0 / self.config.num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)
num_steps = self.config.num_steps
dt = -1.0 / num_steps
x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
for step in range(num_steps):
time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
x_t=input_x_t,
prefix_pad_masks=prefix_pad_masks,
@@ -818,15 +815,11 @@ class VLAFlowMatching(nn.Module):
else:
v_t = denoise_step_partial_call(x_t)
# Euler step
x_t += dt * v_t
x_t = x_t + dt * v_t
# Record x_t and v_t after Euler step (other params are recorded in rtc_processor.denoise_step)
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
time += dt
return x_t
def denoise_step(
+6
View File
@@ -0,0 +1,6 @@
# register the processor steps
from lerobot.policies.xvla.processor_xvla import (
XVLAAddDomainIdProcessorStep,
XVLAImageNetNormalizeProcessorStep,
XVLAImageToFloatProcessorStep,
)
+588
View File
@@ -0,0 +1,588 @@
# ------------------------------------------------------------------------------
# Copyright 2025 2toINF and HuggingFace Inc. (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------
from __future__ import annotations
from collections.abc import Iterable
import torch
import torch.nn as nn
# =============================================================================
# Registry
# =============================================================================
ACTION_REGISTRY: dict[str, type[BaseActionSpace]] = {}
def register_action(name: str):
"""Decorator for registering a new action space."""
def _wrap(cls):
key = name.lower()
if key in ACTION_REGISTRY:
raise KeyError(f"ActionSpace '{key}' already registered -> {ACTION_REGISTRY[key]}")
ACTION_REGISTRY[key] = cls
cls.name = key
return cls
return _wrap
def build_action_space(name: str, **kwargs) -> BaseActionSpace:
"""Instantiate a registered action space by name."""
key = name.lower()
if key not in ACTION_REGISTRY:
raise KeyError(f"Unknown action space '{name}'. Available: {list(ACTION_REGISTRY.keys())}")
return ACTION_REGISTRY[key](**kwargs)
# =============================================================================
# Base class
# =============================================================================
class BaseActionSpace(nn.Module):
"""
Abstract base class for all action-space definitions.
Each subclass defines:
- `dim_action`: dimension of the action vector.
- `gripper_idx`: indices of gripper channels.
- `compute_loss(pred, target)`: supervised loss for this space.
- `preprocess(proprio, action, mode)`: pre-step modifications.
- `postprocess(action)`: post-step corrections (e.g. apply sigmoid).
"""
name: str = "base"
dim_action: int = 0
gripper_idx: tuple[int, ...] = ()
def __init__(self):
super().__init__()
# ---------------------------------------------------------------------
# Core supervised loss
# ---------------------------------------------------------------------
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
raise NotImplementedError
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
"""Alias for compute_loss."""
return self.compute_loss(pred, target)
# ---------------------------------------------------------------------
# Space-level hooks
# ---------------------------------------------------------------------
def preprocess(
self,
proprio: torch.Tensor,
action: torch.Tensor,
mode: str = "train",
) -> tuple[torch.Tensor, torch.Tensor]:
"""Default: return unchanged."""
return proprio, action
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""Default: return unchanged."""
return action
# =============================================================================
# Utilities
# =============================================================================
def _ensure_indices_valid(dim_action: int, idx: Iterable[int], name: str) -> None:
bad = [i for i in idx if i < 0 or i >= dim_action]
if bad:
raise IndexError(f"{name} contains out-of-range indices {bad} for action dim dim_action={dim_action}")
# =============================================================================
# Implementations
# =============================================================================
@register_action("ee6d")
class EE6DActionSpace(BaseActionSpace):
"""End-effector layout with xyz, 6D rotation, and gripper channels."""
dim_action = 20
gripper_idx = (9, 19)
GRIPPER_SCALE = 1.0
XYZ_SCALE = 500.0
ROT_SCALE = 10.0
POS_IDX_1 = (0, 1, 2)
POS_IDX_2 = (10, 11, 12)
ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
ROT_IDX_2 = (13, 14, 15, 16, 17, 18)
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
self.bce = nn.BCEWithLogitsLoss()
def compute_loss(self, pred, target):
assert pred.shape == target.shape, "pred/target shapes must match"
batch_size, seq_len, action_dim = pred.shape
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
# Gripper BCE
g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE
# XYZ position
pos_loss = (
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1])
+ self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
) * self.XYZ_SCALE
# Rotation 6D
rot_loss = (
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1])
+ self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
) * self.ROT_SCALE
return {
"position_loss": pos_loss,
"rotate6D_loss": rot_loss,
"gripper_loss": gripper_loss,
}
def preprocess(self, proprio, action, mode="train"):
"""Zero-out gripper channels in proprio/action."""
proprio_m = proprio.clone()
action_m = action.clone()
proprio_m[..., self.gripper_idx] = 0.0
action_m[..., self.gripper_idx] = 0.0
return proprio_m, action_m
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""Apply sigmoid to gripper logits."""
if action.size(-1) > max(self.gripper_idx):
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
return action
@register_action("joint")
class JointActionSpace(BaseActionSpace):
"""Joint-space layout with joints + gripper only."""
dim_action = 14
gripper_idx = (6, 13)
GRIPPER_SCALE = 0.1
JOINTS_SCALE = 1.0
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
self.bce = nn.BCEWithLogitsLoss()
def compute_loss(self, pred, target):
assert pred.shape == target.shape
batch_size, seq_len, action_dim = pred.shape
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE
joints_idx = tuple(i for i in range(action_dim) if i not in set(self.gripper_idx))
joints_loss = self.mse(pred[:, :, joints_idx], target[:, :, joints_idx]) * self.JOINTS_SCALE
return {
"joints_loss": joints_loss,
"gripper_loss": gripper_loss,
}
def preprocess(self, proprio, action, mode="train"):
"""Zero-out gripper channels in proprio/action."""
proprio_m = proprio.clone()
action_m = action.clone()
proprio_m[..., self.gripper_idx] = 0.0
action_m[..., self.gripper_idx] = 0.0
return proprio_m, action_m
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""Apply sigmoid to gripper logits."""
if action.size(-1) > max(self.gripper_idx):
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
return action
@register_action("agibot_ee6d")
class AGIBOTEE6DActionSpace(BaseActionSpace):
"""AGI-bot variant of EE6DActionSpace using MSE for all components."""
dim_action = 20
gripper_idx = (9, 19)
GRIPPER_SCALE = 10.0
XYZ_SCALE = 500.0
ROT_SCALE = 10.0
POS_IDX_1 = (0, 1, 2)
POS_IDX_2 = (10, 11, 12)
ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
ROT_IDX_2 = (13, 14, 15, 16, 17, 18)
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
def compute_loss(self, pred, target):
assert pred.shape == target.shape
batch_size, seq_len, action_dim = pred.shape
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
gripper_loss = (
self.mse(pred[:, :, self.gripper_idx], target[:, :, self.gripper_idx]) * self.GRIPPER_SCALE
)
pos_loss = (
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1])
+ self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
) * self.XYZ_SCALE
rot_loss = (
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1])
+ self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
) * self.ROT_SCALE
return {
"position_loss": pos_loss,
"rotate6D_loss": rot_loss,
"gripper_loss": gripper_loss,
}
def preprocess(self, proprio, action, mode="train"):
"""No preprocessing applied in AGIBOT variant."""
return proprio, action
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""AGIBOT does not postprocess."""
return action
@register_action("franka_joint7")
class FrankaJoint7ActionSpace(BaseActionSpace):
"""
Franka Panda joint-space: 7 joints, with gripper.
- Real robot action dim: 7
- Model-facing dim: 20 (padded with zeros)
compatible with pretrained VLA models expecting 20D.
"""
dim_action = 20 # model dimension
REAL_DIM = 7 # actual Franka joints
JOINTS_SCALE = 1.0
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
def _pad_to_model_dim(self, x: torch.Tensor) -> torch.Tensor:
"""Pad 7 → 20 dims (zeros for the dummy channels)."""
if x is None:
return None
if x.size(-1) == self.dim_action:
return x
if x.size(-1) != self.REAL_DIM:
raise ValueError(
f"Expected last dim to be {self.REAL_DIM} or {self.dim_action}, got {x.size(-1)}"
)
pad_shape = list(x.shape[:-1]) + [self.dim_action - self.REAL_DIM] # 13 zeros
pad = x.new_zeros(pad_shape)
return torch.cat([x, pad], dim=-1)
def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
"""Trim model output 20 → 7 dims."""
return x[..., : self.REAL_DIM]
def compute_loss(self, pred, target):
"""
pred : [B, T, 20]
target : [B, T, 7] or [B, T, 20]
Only compute MSE on the first 7 dims.
"""
pred = self._pad_to_model_dim(pred)
target = self._pad_to_model_dim(target)
assert pred.shape == target.shape
joints_loss = (
self.mse(
pred[:, :, : self.REAL_DIM], # use only the first 7 joints
target[:, :, : self.REAL_DIM],
)
* self.JOINTS_SCALE
)
return {"joints_loss": joints_loss}
def preprocess(self, proprio, action, mode="train"):
"""
During training:
- Pad [7] [20]
"""
return proprio, self._pad_to_model_dim(action)
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""
After model prediction:
- Trim [20] [7] for real robot control.
"""
return self._trim_to_real_dim(action)
@register_action("auto")
class AutoActionSpace(BaseActionSpace):
"""
Auto-detecting action space that adapts to any action dimension.
- Auto-detects the real action dimension from the policy feature
- Model outputs max_dim for compatibility with pretrained models
- Loss is computed only on the first real_dim dimensions
- Postprocess trims output back to real_dim
Args:
real_dim: The actual action dimension from the dataset/policy feature
max_dim: The model's output dimension for pretrained VLA compatibility
"""
JOINTS_SCALE = 1.0
def __init__(self, real_dim: int, max_dim: int):
super().__init__()
self.real_dim = real_dim
self.dim_action = max_dim # Model-facing dimension
self.mse = nn.MSELoss()
def _pad_to_model_dim(self, x: torch.Tensor) -> torch.Tensor:
"""Pad real_dim → max_dim (zeros for the dummy channels)."""
if x is None:
return None
if x.size(-1) == self.dim_action:
return x
if x.size(-1) != self.real_dim:
# If dimension doesn't match either, pad/trim to real_dim first
if x.size(-1) < self.real_dim:
pad_shape = list(x.shape[:-1]) + [self.real_dim - x.size(-1)]
pad = x.new_zeros(pad_shape)
x = torch.cat([x, pad], dim=-1)
else:
x = x[..., : self.real_dim]
pad_shape = list(x.shape[:-1]) + [self.dim_action - self.real_dim]
pad = x.new_zeros(pad_shape)
return torch.cat([x, pad], dim=-1)
def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
"""Trim model output max_dim → real_dim."""
return x[..., : self.real_dim]
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
"""
Compute loss only on the first real_dim dimensions.
pred: [B, T, max_dim] from the model
target: [B, T, real_dim] or [B, T, max_dim]
Loss = MSE(pred[:,:,:real_dim], target[:,:,:real_dim])
"""
pred = self._pad_to_model_dim(pred)
target = self._pad_to_model_dim(target)
assert pred.shape == target.shape, f"Shape mismatch: pred {pred.shape} vs target {target.shape}"
# only compute loss on the real dimensions
joints_loss = (
self.mse(
pred[:, :, : self.real_dim],
target[:, :, : self.real_dim],
)
* self.JOINTS_SCALE
)
return {"joints_loss": joints_loss}
def preprocess(self, proprio: torch.Tensor, action: torch.Tensor, mode: str = "train"):
"""
Pad action from real_dim to max_dim for the model.
"""
return proprio, self._pad_to_model_dim(action)
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""
Trim model output from max_dim to real_dim for real robot control.
"""
return self._trim_to_real_dim(action)
@register_action("so101_bimanual")
class BimanualSO101ActionSpace(BaseActionSpace):
"""
Bimanual SO101 robot: 2 arms with 5 joints each + gripper.
Layout (real robot):
[left_arm (5 joints + gripper), right_arm (5 joints + gripper)]
- Left arm: shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper
- Right arm: shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper
Real action dim: 12
Model-facing dim: 20 (extra 8 dummy dims at the end)
"""
# Model output / training dimension (to match pretrained policy)
dim_action = 20
# Real robot action dimension
REAL_DIM = 12
# Indices of real vs dummy channels
REAL_IDXS = tuple(range(REAL_DIM)) # 0..11
DUMMY_IDXS = tuple(range(REAL_DIM, dim_action)) # 12..19
# Grippers live in the real part
gripper_idx = (5, 11) # left_gripper at idx 5, right_gripper at idx 11
GRIPPER_SCALE = 1.0
JOINTS_SCALE = 1.0
# Indices for left and right arm joints (excluding grippers)
LEFT_ARM_JOINTS = (0, 1, 2, 3, 4)
RIGHT_ARM_JOINTS = (6, 7, 8, 9, 10)
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
self.bce = nn.BCEWithLogitsLoss()
# ---------- helpers ----------
def _pad_to_model_dim(self, x: torch.Tensor) -> torch.Tensor:
"""If last dim is REAL_DIM (12), pad zeros to reach dim_action (20)."""
if x is None:
return None
if x.size(-1) == self.dim_action:
return x
if x.size(-1) != self.REAL_DIM:
raise ValueError(
f"Expected last dim to be {self.REAL_DIM} or {self.dim_action}, got {x.size(-1)}"
)
pad_shape = list(x.shape[:-1]) + [self.dim_action - self.REAL_DIM]
pad = x.new_zeros(pad_shape)
return torch.cat([x, pad], dim=-1)
def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
"""Keep only the first REAL_DIM (12) dims for the real robot."""
return x[..., : self.REAL_DIM]
# ---------- loss ----------
def compute_loss(self, pred, target):
"""
pred: [B, T, 20] from the model
target: [B, T, 12] or [B, T, 20]
We pad target 20 and compute loss only on the real dims.
"""
# Ensure both are [B, T, 20]
pred = self._pad_to_model_dim(pred)
target = self._pad_to_model_dim(target)
assert pred.shape == target.shape
# ---- MSE for all real dims (011) ----
real_dims = 12
joints_loss = (
self.mse(
pred[:, :, :real_dims],
target[:, :, :real_dims],
)
* self.JOINTS_SCALE
)
left_arm_loss = self.mse(pred[:, :, :6], target[:, :, :6])
right_arm_loss = self.mse(pred[:, :, 6:12], target[:, :, 6:12])
gripper_loss = (
self.mse(
pred[:, :, [5, 11]],
target[:, :, [5, 11]],
)
* self.GRIPPER_SCALE
)
return {
"joints_loss": joints_loss,
"gripper_loss": gripper_loss,
"left_arm_loss": left_arm_loss,
"right_arm_loss": right_arm_loss,
}
# ---------- preprocess / postprocess ----------
def preprocess(self, proprio, action, mode="train"):
"""
- If proprio/action are 12-dim, pad them to 20 for the model.
- Zero-out gripper channels in proprio/action to focus learning on joints.
"""
proprio_m = self._pad_to_model_dim(proprio.clone())
action_m = self._pad_to_model_dim(action.clone()) if action is not None else None
proprio_m[..., self.gripper_idx] = 0.0
if action_m is not None:
action_m[..., self.gripper_idx] = 0.0
return proprio_m, action_m
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""
- Model outputs [*, 20]
- Apply sigmoid to gripper logits
- Return only the first 12 dims for the real robot:
["left_shoulder_pan.pos",
"left_shoulder_lift.pos",
"left_elbow_flex.pos",
"left_wrist_flex.pos",
"left_wrist_roll.pos",
"left_gripper.pos",
"right_shoulder_pan.pos",
"right_shoulder_lift.pos",
"right_elbow_flex.pos",
"right_wrist_flex.pos",
"right_wrist_roll.pos",
"right_gripper.pos"]
"""
# Ensure we at least have the real dims + grippers
if action.size(-1) < self.REAL_DIM:
raise ValueError(f"Expected at least {self.REAL_DIM} dims in action, got {action.size(-1)}")
# Apply sigmoid on gripper channels in model space (indices 5 and 11)
if action.size(-1) > max(self.gripper_idx):
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
# Return only the real 12-dim control vector for the env
return self._trim_to_real_dim(action)
# =============================================================================
# Exports
# =============================================================================
__all__ = [
"BaseActionSpace",
"build_action_space",
"register_action",
"EE6DActionSpace",
"JointActionSpace",
"AGIBOTEE6DActionSpace",
"FrankaJoint7ActionSpace",
"AutoActionSpace",
"BimanualSO101ActionSpace",
"ACTION_REGISTRY",
]
@@ -0,0 +1,353 @@
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
""" Florence-2 configuration"""
logger = logging.get_logger(__name__)
class Florence2VisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel
according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Florence2VisionModel architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
drop_path_rate (`float`, *optional*, defaults to 0.1):
The dropout rate of the drop path layer.
patch_size (`List[int]`, *optional*, defaults to [7, 3, 3, 3]):
The patch size of the image.
patch_stride (`List[int]`, *optional*, defaults to [4, 2, 2, 2]):
The patch stride of the image.
patch_padding (`List[int]`, *optional*, defaults to [3, 1, 1, 1]):
The patch padding of the image.
patch_prenorm (`List[bool]`, *optional*, defaults to [false, true, true, true]):
Whether to apply layer normalization before the patch embedding layer.
enable_checkpoint (`bool`, *optional*, defaults to False):
Whether to enable checkpointing.
dim_embed (`List[int]`, *optional*, defaults to [256, 512, 1024, 2048]):
The dimension of the embedding layer.
num_heads (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
The number of attention heads.
num_groups (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
The number of groups.
depths (`List[int]`, *optional*, defaults to [1, 1, 9, 1]):
The depth of the model.
window_size (`int`, *optional*, defaults to 12):
The window size of the model.
projection_dim (`int`, *optional*, defaults to 1024):
The dimension of the projection layer.
visual_temporal_embedding (`dict`, *optional*):
The configuration of the visual temporal embedding.
image_pos_embed (`dict`, *optional*):
The configuration of the image position embedding.
image_feature_source (`List[str]`, *optional*, defaults to ["spatial_avg_pool", "temporal_avg_pool"]):
The source of the image feature.
Example:
```python
>>> from transformers import Florence2VisionConfig, Florence2VisionModel
>>> # Initializing a Florence2 Vision style configuration
>>> configuration = Florence2VisionConfig()
>>> # Initializing a model (with random weights)
>>> model = Florence2VisionModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "davit"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
drop_path_rate=0.1,
patch_size=None,
patch_stride=None,
patch_padding=None,
patch_prenorm=None,
enable_checkpoint=False,
dim_embed=None,
num_heads=None,
num_groups=None,
depths=None,
window_size=12,
projection_dim=1024,
visual_temporal_embedding=None,
image_pos_embed=None,
image_feature_source=None,
**kwargs,
):
self.drop_path_rate = drop_path_rate
self.patch_size = patch_size if patch_size is not None else [7, 3, 3, 3]
self.patch_stride = patch_stride if patch_stride is not None else [4, 2, 2, 2]
self.patch_padding = patch_padding if patch_padding is not None else [3, 1, 1, 1]
self.patch_prenorm = patch_prenorm if patch_prenorm is not None else [False, True, True, True]
self.enable_checkpoint = enable_checkpoint
self.dim_embed = dim_embed if dim_embed is not None else [256, 512, 1024, 2048]
self.num_heads = num_heads if num_heads is not None else [8, 16, 32, 64]
self.num_groups = num_groups if num_groups is not None else [8, 16, 32, 64]
self.depths = depths if depths is not None else [1, 1, 9, 1]
self.window_size = window_size
self.projection_dim = projection_dim
if visual_temporal_embedding is None:
visual_temporal_embedding = {
"type": "COSINE",
"max_temporal_embeddings": 100,
}
self.visual_temporal_embedding = visual_temporal_embedding
if image_pos_embed is None:
image_pos_embed = {
"type": "learned_abs_2d",
"max_pos_embeddings": 1000,
}
self.image_pos_embed = image_pos_embed
self.image_feature_source = (
image_feature_source
if image_feature_source is not None
else ["spatial_avg_pool", "temporal_avg_pool"]
)
super().__init__(**kwargs)
class Florence2LanguageConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Florence2LanguagePreTrainedModel`]. It is used to instantiate a BART
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the BART
[facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 51289):
Vocabulary size of the Florence2Language model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Florence2LanguageModel`].
d_model (`int`, *optional*, defaults to 1024):
Dimensionality of the layers and the pooler layer.
encoder_layers (`int`, *optional*, defaults to 12):
Number of encoder layers.
decoder_layers (`int`, *optional*, defaults to 12):
Number of decoder layers.
encoder_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
decoder_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
decoder_ffn_dim (`int`, *optional*, defaults to 4096):
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
encoder_ffn_dim (`int`, *optional*, defaults to 4096):
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"silu"` and `"gelu_new"` are supported.
dropout (`float`, *optional*, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
activation_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for activations inside the fully connected layer.
classifier_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for classifier.
max_position_embeddings (`int`, *optional*, defaults to 1024):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
init_std (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
encoder_layerdrop (`float`, *optional*, defaults to 0.0):
The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
for more details.
decoder_layerdrop (`float`, *optional*, defaults to 0.0):
The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
for more details.
scale_embedding (`bool`, *optional*, defaults to `False`):
Scale embeddings by diving by sqrt(d_model).
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models).
num_labels (`int`, *optional*, defaults to 3):
The number of labels to use in [`Florence2LanguageForSequenceClassification`].
forced_eos_token_id (`int`, *optional*, defaults to 2):
The id of the token to force as the last generated token when `max_length` is reached. Usually set to
`eos_token_id`.
Example:
```python
>>> from transformers import Florence2LanguageConfig, Florence2LanguageModel
>>> # Initializing a Florence2 Language style configuration
>>> configuration = Florence2LanguageConfig()
>>> # Initializing a model (with random weights)
>>> model = Florence2LanguageModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "florence2_language"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
def __init__(
self,
vocab_size=51289,
max_position_embeddings=1024,
encoder_layers=12,
encoder_ffn_dim=4096,
encoder_attention_heads=16,
decoder_layers=12,
decoder_ffn_dim=4096,
decoder_attention_heads=16,
encoder_layerdrop=0.0,
decoder_layerdrop=0.0,
activation_function="gelu",
d_model=1024,
dropout=0.1,
attention_dropout=0.0,
activation_dropout=0.0,
init_std=0.02,
classifier_dropout=0.0,
scale_embedding=False,
use_cache=True,
num_labels=3,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
is_encoder_decoder=True,
decoder_start_token_id=2,
forced_eos_token_id=2,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.d_model = d_model
self.encoder_ffn_dim = encoder_ffn_dim
self.encoder_layers = encoder_layers
self.encoder_attention_heads = encoder_attention_heads
self.decoder_ffn_dim = decoder_ffn_dim
self.decoder_layers = decoder_layers
self.decoder_attention_heads = decoder_attention_heads
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
self.activation_function = activation_function
self.init_std = init_std
self.encoder_layerdrop = encoder_layerdrop
self.decoder_layerdrop = decoder_layerdrop
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
super().__init__(
num_labels=num_labels,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
decoder_start_token_id=decoder_start_token_id,
forced_eos_token_id=forced_eos_token_id,
**kwargs,
)
# ensure backward compatibility for BART CNN models
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
self.forced_bos_token_id = self.bos_token_id
warnings.warn(
f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
"The config can simply be saved and uploaded again to be fixed.",
stacklevel=2,
)
class Florence2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Florence2ForConditionalGeneration`]. It is used to instantiate an
Florence-2 model according to the specified arguments, defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vision_config (`Florence2VisionConfig`, *optional*):
Custom vision config or dict
text_config (`Union[AutoConfig, dict]`, *optional*):
The config object of the text backbone.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
vocab_size (`int`, *optional*, defaults to 51289):
Vocabulary size of the Florence2model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`~Florence2ForConditionalGeneration`]
projection_dim (`int`, *optional*, defaults to 1024):
Dimension of the multimodal projection space.
Example:
```python
>>> from transformers import Florence2ForConditionalGeneration, Florence2Config, CLIPVisionConfig, BartConfig
>>> # Initializing a clip-like vision config
>>> vision_config = CLIPVisionConfig()
>>> # Initializing a Bart config
>>> text_config = BartConfig()
>>> # Initializing a Florence-2 configuration
>>> configuration = Florence2Config(vision_config, text_config)
>>> # Initializing a model from the florence-2 configuration
>>> model = Florence2ForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "florence2"
is_composition = False
def __init__(
self,
vision_config=None,
text_config=None,
ignore_index=-100,
vocab_size=51289,
projection_dim=1024,
**kwargs,
):
self.ignore_index = ignore_index
self.vocab_size = vocab_size
self.projection_dim = projection_dim
if vision_config is not None:
vision_config = Florence2VisionConfig(**vision_config)
self.vision_config = vision_config
self.text_config = text_config
if text_config is not None:
self.text_config = Florence2LanguageConfig(**text_config)
super().__init__(**kwargs)
@@ -0,0 +1,203 @@
#!/usr/bin/env python
# ------------------------------------------------------------------------------
# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import XVLAAdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import OBS_IMAGES
# Conditional import for type checking and lazy loading
from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from .configuration_florence2 import Florence2Config
else:
Florence2Config = None
@PreTrainedConfig.register_subclass("xvla")
@dataclass
class XVLAConfig(PreTrainedConfig):
"""
Configuration class for the XVLA (Extended Vision-Language-Action) policy so it can
plug into the LeRobot training stack.
The config mirrors the knobs exposed in the original XVLA repository but also
declares the input/output feature contract required by LeRobot.
"""
# Input / output structure
n_obs_steps: int = 1
chunk_size: int = 32
n_action_steps: int = 32
dtype: str = "float32" # Options: "bfloat16", "float32"
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
)
# Florence2 backbone and tokenizer configuration
florence_config: dict[str, Any] = field(default_factory=dict)
tokenizer_name: str = "facebook/bart-large"
tokenizer_max_length: int = 64
tokenizer_padding_side: str = "right"
pad_language_to: str = "max_length"
# Transformer head
hidden_size: int = 1024
depth: int = 24
num_heads: int = 16
mlp_ratio: float = 4.0
num_domains: int = 30
len_soft_prompts: int = 32
dim_time: int = 32
max_len_seq: int = 512
use_hetero_proj: bool = False
# Action & proprioception
action_mode: str = "ee6d"
num_denoising_steps: int = 10
use_proprio: bool = True
max_state_dim: int = 32
max_action_dim: int = 20 # Maximum action dimension for padding (used by "auto" action mode)
domain_feature_key: str | None = None
# Vision preprocessing
resize_imgs_with_padding: tuple[int, int] | None = None
num_image_views: int | None = None
empty_cameras: int = 0
# Freezing options for VLM components
# By default, VLM encoders are frozen and only policy transformer + soft prompts train
freeze_vision_encoder: bool = False # Freeze VLM vision encoder weights
freeze_language_encoder: bool = False # Freeze VLM language encoder weights
train_policy_transformer: bool = True # Allow policy transformer to train
train_soft_prompts: bool = True # Allow soft prompts to train
# Training presets
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.9, 0.99)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 0.0
optimizer_grad_clip_norm: float = 10.0
# Soft-prompt LR settings (for optional warm-up)
optimizer_soft_prompt_lr_scale: float = 1.0 # Scale factor for soft-prompt LR
optimizer_soft_prompt_warmup_lr_scale: float | None = None # Start scale for warmup (e.g., 0.01)
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
def __post_init__(self) -> None:
super().__post_init__()
if self.chunk_size <= 0:
raise ValueError("`chunk_size` must be strictly positive.")
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"`n_action_steps` ({self.n_action_steps}) must be <= `chunk_size` ({self.chunk_size})."
)
if self.num_image_views is not None and self.num_image_views <= 0:
raise ValueError("`num_image_views` must be > 0 when specified.")
if self.dtype not in ["bfloat16", "float32"]:
raise ValueError(f"Invalid dtype: {self.dtype}")
self._florence_config_obj: Florence2Config | None = None
def get_florence_config(self) -> Florence2Config:
"""
Build (and cache) the Florence2 transformer config that should back the VLM.
"""
if self._florence_config_obj is None:
config_dict = dict(self.florence_config)
if "vision_config" not in config_dict or config_dict["vision_config"] is None:
raise ValueError("vision_config is required")
if "text_config" not in config_dict or config_dict["text_config"] is None:
raise ValueError("text_config is required")
self._florence_config_obj = Florence2Config(**config_dict)
return self._florence_config_obj
def validate_features(self) -> None:
if not self.image_features:
raise ValueError("XVLA requires at least one visual feature in the inputs.")
if self.use_proprio and self.robot_state_feature is None:
raise ValueError("`use_proprio=True` requires a proprioceptive state feature.")
if self.num_image_views is None:
self.num_image_views = len(self.image_features) + self.empty_cameras
else:
self.num_image_views = max(self.num_image_views, len(self.image_features) + self.empty_cameras)
if self.empty_cameras > 0:
height, width = (480, 640)
if self.resize_imgs_with_padding is not None:
height, width = self.resize_imgs_with_padding
for idx in range(self.empty_cameras):
key = f"{OBS_IMAGES}.empty_camera_{idx}"
if key not in self.input_features:
self.input_features[key] = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, height, width),
)
def get_optimizer_preset(self) -> XVLAAdamWConfig:
"""Return the XVLA-specific optimizer with differential learning rates.
This optimizer applies:
- 1/10 LR for VLM parameters (stable optimization)
- Full LR for transformer/action head
- Configurable LR for soft-prompts (with optional warm-up)
"""
return XVLAAdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
soft_prompt_lr_scale=self.optimizer_soft_prompt_lr_scale,
soft_prompt_warmup_lr_scale=self.optimizer_soft_prompt_warmup_lr_scale,
)
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> list[int] | None:
return None
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> list[int] | None:
return None
File diff suppressed because it is too large Load Diff
+548
View File
@@ -0,0 +1,548 @@
#!/usr/bin/env python
# ------------------------------------------------------------------------------
# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------
from __future__ import annotations
import builtins
import logging
import os
from collections import deque
from pathlib import Path
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_TOKENS, OBS_STATE
from .action_hub import build_action_space
from .configuration_florence2 import Florence2Config
from .configuration_xvla import XVLAConfig
from .modeling_florence2 import Florence2ForConditionalGeneration
from .soft_transformer import SoftPromptedTransformer
class XVLAModel(nn.Module):
"""
XVLA backbone that stitches Florence-2 embeddings with the temporal/action transformer head.
"""
def __init__(
self,
config: XVLAConfig,
florence_config: Florence2Config,
proprio_dim: int,
) -> None:
super().__init__()
self.config = config
self.chunk_size: int = config.chunk_size
self.use_proprio: bool = config.use_proprio
# Build action space with auto-detection for "auto" mode
if config.action_mode.lower() == "auto":
# Auto-detect real action dim from config.action_feature
real_dim = (
config.action_feature.shape[-1]
if config.action_feature is not None
else config.max_action_dim
)
self.action_space = build_action_space(
config.action_mode.lower(),
real_dim=real_dim,
max_dim=config.max_action_dim,
)
else:
self.action_space = build_action_space(config.action_mode.lower())
self.dim_action = self.action_space.dim_action
self.dim_proprio = proprio_dim
self.vlm = Florence2ForConditionalGeneration(florence_config)
if hasattr(self.vlm, "language_model"):
lm = self.vlm.language_model
if hasattr(lm, "model") and hasattr(lm.model, "decoder"):
del lm.model.decoder
if hasattr(lm, "lm_head"):
del lm.lm_head
projection_dim = getattr(self.vlm.config, "projection_dim", None)
if projection_dim is None:
raise ValueError("Florence2 config must provide `projection_dim` for multimodal fusion.")
self.transformer = SoftPromptedTransformer(
hidden_size=config.hidden_size,
multi_modal_input_size=projection_dim,
depth=config.depth,
num_heads=config.num_heads,
mlp_ratio=config.mlp_ratio,
num_domains=config.num_domains,
dim_action=self.dim_action,
dim_propio=self.dim_proprio,
len_soft_prompts=config.len_soft_prompts,
dim_time=config.dim_time,
max_len_seq=config.max_len_seq,
use_hetero_proj=config.use_hetero_proj,
)
# Apply freezing based on config
self._apply_freezing()
# Apply dtype casting based on config
self._apply_dtype()
def _get_target_dtype(self) -> torch.dtype:
"""Get the target dtype based on config."""
if self.config.dtype == "bfloat16":
return torch.bfloat16
return torch.float32
def _apply_dtype(self) -> None:
"""
Apply dtype casting to model components based on config.
"""
target_dtype = self._get_target_dtype()
self.to(dtype=target_dtype)
def _apply_freezing(self) -> None:
"""
Freeze VLM vision and language encoders based on config options.
Keep only policy transformer and soft prompts trainable.
"""
# Freeze vision encoder
if self.config.freeze_vision_encoder and hasattr(self.vlm, "vision_tower"):
for param in self.vlm.vision_tower.parameters():
param.requires_grad = False
# Freeze language encoder
if self.config.freeze_language_encoder and hasattr(self.vlm, "language_model"):
lm = self.vlm.language_model
# Freeze encoder
if hasattr(lm, "model") and hasattr(lm.model, "encoder"):
for param in lm.model.encoder.parameters():
param.requires_grad = False
# Freeze shared embeddings
if hasattr(lm, "model") and hasattr(lm.model, "shared"):
for param in lm.model.shared.parameters():
param.requires_grad = False
# Freeze or unfreeze policy transformer
if not self.config.train_policy_transformer:
for name, param in self.transformer.named_parameters():
if "soft_prompts" not in name:
param.requires_grad = False
# Freeze or unfreeze soft prompts
if not self.config.train_soft_prompts and hasattr(self.transformer, "soft_prompt_hub"):
for param in self.transformer.soft_prompt_hub.parameters():
param.requires_grad = False
def forward_vlm(
self,
input_ids: torch.LongTensor,
pixel_values: torch.FloatTensor,
image_mask: torch.Tensor,
) -> dict[str, torch.Tensor]:
"""
Encode text and multi-view images via Florence2 encoder.
"""
batch_size, num_views = pixel_values.shape[:2]
flat_mask = image_mask.view(-1).to(dtype=torch.bool)
flat_images = pixel_values.flatten(0, 1)
num_valid = int(flat_mask.sum().item())
if num_valid == 0:
raise ValueError("At least one image view must be valid per batch.")
valid_images = flat_images[flat_mask]
valid_feats = self.vlm._encode_image(valid_images)
tokens_per_view, hidden_dim = valid_feats.shape[1:]
image_features = valid_feats.new_zeros((batch_size * num_views, tokens_per_view, hidden_dim))
image_features[flat_mask] = valid_feats
image_features = image_features.view(batch_size, num_views, tokens_per_view, hidden_dim)
inputs_embeds = self.vlm.get_input_embeddings()(input_ids)
merged_embeds, attention_mask = self.vlm._merge_input_ids_with_image_features(
image_features[:, 0],
inputs_embeds,
)
enc_out = self.vlm.language_model.model.encoder(
attention_mask=attention_mask,
inputs_embeds=merged_embeds,
)[0]
aux_visual_inputs = image_features[:, 1:].reshape(batch_size, -1, hidden_dim)
return {"vlm_features": enc_out, "aux_visual_inputs": aux_visual_inputs}
def forward(
self,
input_ids: torch.LongTensor,
image_input: torch.FloatTensor,
image_mask: torch.Tensor,
domain_id: torch.LongTensor,
proprio: torch.Tensor,
action: torch.Tensor,
) -> dict[str, torch.Tensor]:
"""
Forward pass for the XVLA model.
"""
target_dtype = self._get_target_dtype()
image_input = image_input.to(dtype=target_dtype)
proprio = proprio.to(dtype=target_dtype)
action = action.to(dtype=target_dtype)
enc = self.forward_vlm(input_ids, image_input, image_mask)
batch_size = input_ids.shape[0]
t = (
torch.rand(1, device=input_ids.device, dtype=target_dtype)
+ torch.arange(batch_size, device=input_ids.device, dtype=target_dtype) / batch_size
) % (1 - 1e-5)
action_noisy = torch.randn_like(action) * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
proprio_m, action_noisy_m = self.action_space.preprocess(proprio, action_noisy)
pred_action = self.transformer(
domain_id=domain_id,
action_with_noise=action_noisy_m,
t=t,
proprio=proprio_m,
**enc,
)
return self.action_space.compute_loss(pred_action, action)
@torch.no_grad()
def generate_actions(
self,
input_ids: torch.LongTensor,
image_input: torch.FloatTensor,
image_mask: torch.Tensor,
domain_id: torch.LongTensor,
proprio: torch.Tensor,
steps: int,
) -> torch.Tensor:
self.eval()
target_dtype = self._get_target_dtype()
image_input = image_input.to(dtype=target_dtype)
proprio = proprio.to(dtype=target_dtype)
enc = self.forward_vlm(input_ids, image_input, image_mask)
batch_size = input_ids.shape[0]
action_dim = self.dim_action
x1 = torch.randn(batch_size, self.chunk_size, action_dim, device=proprio.device, dtype=target_dtype)
action = torch.zeros_like(x1)
steps = max(1, int(steps))
for i in range(steps, 0, -1):
t = torch.full((batch_size,), i / steps, device=proprio.device, dtype=target_dtype)
x_t = x1 * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
proprio_m, x_t_m = self.action_space.preprocess(proprio, x_t)
action = self.transformer(
domain_id=domain_id,
action_with_noise=x_t_m,
proprio=proprio_m,
t=t,
**enc,
)
return self.action_space.postprocess(action)
class XVLAPolicy(PreTrainedPolicy):
"""LeRobot-compliant wrapper built around the XVLA model."""
config_class = XVLAConfig
name = "xvla"
def __init__(self, config: XVLAConfig):
super().__init__(config)
config.validate_features()
florence_config = config.get_florence_config()
proprio_dim = config.max_state_dim if config.use_proprio else 0
self.model = XVLAModel(config=config, florence_config=florence_config, proprio_dim=proprio_dim)
self.reset()
def reset(self) -> None:
self._queues = {
ACTION: deque(maxlen=self.config.n_action_steps),
}
def get_optim_params(self) -> dict:
"""Return trainable named parameters for optimization.
Returns a dict of name -> param for all trainable parameters.
This enables the xvla-adamw optimizer to apply differential learning rates
based on parameter names (e.g., 1/10 LR for VLM components).
"""
return dict(filter(lambda kv: kv[1].requires_grad, self.named_parameters()))
def _prepare_state(self, batch: dict[str, Tensor], batch_size: int, device: torch.device) -> Tensor:
if not self.config.use_proprio or OBS_STATE not in batch:
return torch.zeros(batch_size, 0, device=device)
state = batch[OBS_STATE]
if state.ndim > 2:
state = state[:, -1, :]
return pad_vector(state, self.model.dim_proprio)
def _prepare_images(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
present_img_keys = [key for key in self.config.image_features if key in batch]
if len(present_img_keys) == 0:
raise ValueError(
"All image features are missing from the batch. "
f"Batch keys: {list(batch.keys())}, expected at least one of {list(self.config.image_features)}."
)
images = []
masks = []
for key in present_img_keys:
img = batch[key][:, -1] if batch[key].ndim == 5 else batch[key]
if self.config.resize_imgs_with_padding is not None:
img = resize_with_pad(img, *self.config.resize_imgs_with_padding)
images.append(img)
masks.append(torch.ones(img.size(0), dtype=torch.bool, device=img.device))
stacked_imgs = torch.stack(images, dim=1)
stacked_masks = torch.stack(masks, dim=1)
total_views = self.config.num_image_views or stacked_imgs.size(1)
total_views = max(total_views, stacked_imgs.size(1))
num_pad = total_views - stacked_imgs.size(1)
if num_pad > 0:
pad_shape = (stacked_imgs.size(0), num_pad, *stacked_imgs.shape[2:])
pad_imgs = stacked_imgs.new_zeros(pad_shape)
pad_masks = stacked_masks.new_zeros((stacked_masks.size(0), num_pad))
stacked_imgs = torch.cat([stacked_imgs, pad_imgs], dim=1)
stacked_masks = torch.cat([stacked_masks, pad_masks], dim=1)
return stacked_imgs, stacked_masks
def _get_domain_id(self, batch: dict[str, Tensor], batch_size: int, device: torch.device) -> Tensor:
candidate = None
if self.config.domain_feature_key and self.config.domain_feature_key in batch:
candidate = batch[self.config.domain_feature_key]
elif "domain_id" in batch:
candidate = batch["domain_id"]
if candidate is None:
return torch.zeros(batch_size, dtype=torch.long, device=device)
if not isinstance(candidate, torch.Tensor):
candidate = torch.as_tensor(candidate, device=device)
else:
candidate = candidate.to(device=device)
if candidate.ndim == 0:
candidate = candidate.expand(batch_size)
if candidate.ndim > 1:
candidate = candidate.view(candidate.shape[0], -1)[:, 0]
if candidate.shape[0] != batch_size:
candidate = candidate.expand(batch_size)
return candidate.to(dtype=torch.long)
def _prepare_action_targets(self, batch: dict[str, Tensor]) -> Tensor:
if ACTION not in batch:
raise ValueError("Batch is missing action targets required for training.")
actions = batch[ACTION]
if actions.ndim == 2:
actions = actions.unsqueeze(1)
actions = pad_tensor_along_dim(actions, self.config.chunk_size, dim=1)
if actions.shape[-1] != self.model.dim_action:
actions = pad_vector(actions, self.model.dim_action)
return actions
def _build_model_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
input_ids = batch[OBS_LANGUAGE_TOKENS]
batch_size = input_ids.shape[0]
images, image_mask = self._prepare_images(batch)
domain_id = self._get_domain_id(batch, batch_size, images.device)
proprio = self._prepare_state(batch, batch_size, images.device)
return {
"input_ids": input_ids,
"image_input": images,
"image_mask": image_mask,
"domain_id": domain_id,
"proprio": proprio,
}
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
inputs = self._build_model_inputs(batch)
targets = self._prepare_action_targets(batch)
losses = self.model(action=targets, **inputs)
total_loss = sum(losses.values())
log_dict = {k: v.detach().item() for k, v in losses.items()}
log_dict["loss"] = total_loss.detach().item()
return total_loss, log_dict
def _get_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
inputs = self._build_model_inputs(batch)
actions = self.model.generate_actions(**inputs, steps=self.config.num_denoising_steps)
return actions
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: # noqa: ARG002
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
return self._get_action_chunk(batch)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: # noqa: ARG002
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
if len(self._queues[ACTION]) == 0:
actions = self._get_action_chunk(batch)
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
return self._queues[ACTION].popleft()
@classmethod
def from_pretrained(
cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
config: PreTrainedConfig | None = None,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
strict: bool = False,
**kwargs,
):
"""
Loads XVLA model weights with:
- automatic prefix 'model.' added to all keys
- skip list for layers that should remain randomly initialized
"""
import safetensors.torch
# step 1: load config
# TODO: jadechoghari, fix this
if config is None:
config = PreTrainedConfig.from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
**kwargs,
)
model_id = str(pretrained_name_or_path)
instance = cls(config, **kwargs)
# step 2: locate model.safetensors
if os.path.isdir(model_id):
logging.info("Loading weights from local directory")
model_file = os.path.join(model_id, "model.safetensors")
else:
try:
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
model_file = hf_hub_download(
repo_id=model_id,
filename="model.safetensors",
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except HfHubHTTPError as e:
raise FileNotFoundError(f"model.safetensors not found on the Hub at {model_id}") from e
logging.info(f"Loading checkpoint from {model_file}")
# step 3: load state dict
state_dict = safetensors.torch.load_file(model_file)
encoder_key = "model.vlm.language_model.model.encoder.embed_tokens.weight"
shared_key = "model.vlm.language_model.model.shared.weight"
if encoder_key in state_dict:
state_dict[shared_key] = state_dict[encoder_key]
# or deepcopy
# step 4: load into instance
instance.load_state_dict(state_dict, strict=True)
logging.info("Loaded XVLA checkpoint")
# step 5: finalize
# Reapply dtype after loading state dict
instance.model._apply_dtype()
instance.to(config.device)
instance.eval()
return instance
def resize_with_pad(img: torch.Tensor, height: int, width: int, pad_value: float = 0.0) -> torch.Tensor:
if img.ndim != 4:
raise ValueError(f"(b,c,h,w) expected, but got {img.shape}")
current_height, current_width = img.shape[2:]
if current_height == height and current_width == width:
return img
ratio = max(current_width / width, current_height / height)
resized_height = int(current_height / ratio)
resized_width = int(current_width / ratio)
resized_img = F.interpolate(
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
)
pad_height = max(0, height - resized_height)
pad_width = max(0, width - resized_width)
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
return padded_img
def pad_vector(vector: Tensor, new_dim: int) -> Tensor:
if vector.shape[-1] == new_dim:
return vector
if new_dim == 0:
shape = list(vector.shape)
shape[-1] = 0
return vector.new_zeros(*shape)
shape = list(vector.shape)
current_dim = shape[-1]
shape[-1] = new_dim
new_vector = vector.new_zeros(*shape)
length = min(current_dim, new_dim)
new_vector[..., :length] = vector[..., :length]
return new_vector
def pad_tensor_along_dim(tensor: Tensor, target_len: int, dim: int = 1) -> Tensor:
current_len = tensor.size(dim)
if current_len == target_len:
return tensor
if current_len > target_len:
slices = [slice(None)] * tensor.dim()
slices[dim] = slice(0, target_len)
return tensor[tuple(slices)]
pad_shape = list(tensor.shape)
pad_shape[dim] = target_len - current_len
pad_tensor = tensor.new_zeros(pad_shape)
return torch.cat([tensor, pad_tensor], dim=dim)
+554
View File
@@ -0,0 +1,554 @@
# ------------------------------------------------------------------------------
# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.datasets.factory import IMAGENET_STATS
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.policies.xvla.utils import rotate6d_to_axis_angle
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
ObservationProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.utils.constants import (
OBS_IMAGES,
OBS_STATE,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
def make_xvla_pre_post_processors(
config: XVLAConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Build the LeRobot processor pipelines for XVLA.
"""
features = {**config.input_features, **config.output_features}
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
TokenizerProcessorStep(
tokenizer_name=config.tokenizer_name,
max_length=config.tokenizer_max_length,
padding=config.pad_language_to,
padding_side=config.tokenizer_padding_side,
),
XVLAImageToFloatProcessorStep(),
XVLAImageNetNormalizeProcessorStep(),
XVLAAddDomainIdProcessorStep(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features=features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
# Custom XVLA processor steps
@dataclass
class LiberoProcessorStep(ObservationProcessorStep):
"""
Processes LIBERO observations into the LeRobot format.
This step handles the specific observation structure from LIBERO environments,
which includes nested robot_state dictionaries and image observations.
**State Processing:**
- Processes the `robot_state` dictionary which contains nested end-effector,
gripper, and joint information.
- Extracts and concatenates:
- End-effector position (3D)
- End-effector quaternion converted to axis-angle (3D)
- Gripper joint positions (2D)
- Maps the concatenated state to `"observation.state"`.
**Image Processing:**
- Rotates images by 180 degrees by flipping both height and width dimensions.
- This accounts for the HuggingFaceVLA/libero camera orientation convention.
"""
def _process_observation(self, observation):
"""
Processes both image and robot_state observations from LIBERO.
"""
processed_obs = observation.copy()
for key in list(processed_obs.keys()):
if key.startswith(f"{OBS_IMAGES}."):
img = processed_obs[key]
if key == f"{OBS_IMAGES}.image":
# Flip both H and W
img = torch.flip(img, dims=[2, 3])
processed_obs[key] = img
# Process robot_state into a flat state vector
if "observation.robot_state" in processed_obs:
robot_state = processed_obs.pop("observation.robot_state")
# Extract components
eef_pos = robot_state["eef"]["pos"] # (B, 3,)
eef_mat = robot_state["eef"]["mat"] # (B, 3, 3)
eef_rot6d = self._mat_to_rotate6d(eef_mat) # (B, 6)
extra = torch.zeros((eef_pos.shape[0], 1), dtype=torch.float32, device=eef_pos.device)
proprio_state = torch.cat((eef_pos, eef_rot6d, extra), dim=-1) # (B, 10)
state = torch.cat((proprio_state, torch.zeros_like(proprio_state)), dim=-1) # (B, 20)
# ensure float32
state = state.float()
if state.dim() == 1:
state = state.unsqueeze(0)
processed_obs[OBS_STATE] = state
return processed_obs
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Transforms feature keys from the LIBERO format to the LeRobot standard.
"""
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {}
# copy over non-STATE features
for ft, feats in features.items():
if ft != PipelineFeatureType.STATE:
new_features[ft] = feats.copy()
# rebuild STATE features
state_feats = {}
# add our new flattened state
state_feats["observation.state"] = PolicyFeature(
key="observation.state",
shape=(20,),
dtype="float32",
)
new_features[PipelineFeatureType.STATE] = state_feats
return new_features
def _mat_to_rotate6d(self, rot_mats: torch.Tensor) -> torch.Tensor:
"""
Convert batched rotation matrices (B, 3, 3) into 6D rotation representation (B, 6).
Args:
rot_mats (Tensor): Rotation matrices of shape (B, 3, 3)
Returns:
Tensor: 6D rotation representation, shape (B, 6)
Raises:
TypeError: if input is not a torch tensor
ValueError: if shape is not (B, 3, 3)
"""
if not isinstance(rot_mats, torch.Tensor):
raise TypeError(f"mat_to_rot6d expects a torch.Tensor, got {type(rot_mats)}")
if rot_mats.ndim != 3 or rot_mats.shape[1:] != (3, 3):
raise ValueError(f"mat_to_rot6d expects shape (B, 3, 3), got {tuple(rot_mats.shape)}")
rot_mats = rot_mats.to(torch.float32)
col1 = rot_mats[:, :3, 0] # (B, 3)
col2 = rot_mats[:, :3, 1] # (B, 3)
rot6d = torch.cat([col1, col2], dim=-1) # (B, 6)
return rot6d
def observation(self, observation):
return self._process_observation(observation)
@dataclass
@ProcessorStepRegistry.register(name="xvla_image_scale")
class XVLAImageScaleProcessorStep(ProcessorStep):
"""Scale image observations by 255 to convert from [0, 1] to [0, 255] range.
This processor step multiplies all image observations by 255, which is required
for XVLA models that expect images in uint8-like range.
Args:
image_keys: List of observation keys that contain images to scale.
If None, will automatically detect keys starting with "observation.images."
"""
image_keys: list[str] | None = None
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Scale image observations by 255."""
new_transition = transition.copy()
obs = new_transition.get(TransitionKey.OBSERVATION, {})
if obs is None:
return new_transition
# Make a copy of observations to avoid modifying the original
obs = obs.copy()
# Determine which keys to scale
keys_to_scale = self.image_keys
if keys_to_scale is None:
# Auto-detect image keys
keys_to_scale = [k for k in obs if k.startswith("observation.images.")]
# Scale each image
for key in keys_to_scale:
if key in obs and isinstance(obs[key], torch.Tensor):
obs[key] = obs[key] * 255
new_transition[TransitionKey.OBSERVATION] = obs
return new_transition
def transform_features(self, features):
"""Image scaling doesn't change feature structure."""
return features
def get_config(self) -> dict[str, Any]:
"""Return serializable configuration."""
return {
"image_keys": self.image_keys,
}
@dataclass
@ProcessorStepRegistry.register(name="xvla_image_to_float")
class XVLAImageToFloatProcessorStep(ProcessorStep):
"""Convert image observations from [0, 255] to [0, 1] range.
This processor step divides image observations by 255 to convert from uint8-like
range [0, 255] to float range [0, 1]. This is typically used when loading images
that are stored as uint8 values.
Args:
image_keys: List of observation keys that contain images to convert.
If None, will automatically detect keys starting with "observation.images."
validate_range: If True, validates that input values are in [0, 255] range (default: True)
Raises:
ValueError: If validate_range is True and image values are not in [0, 255] range.
"""
image_keys: list[str] | None = None
validate_range: bool = True
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Convert image observations from [0, 255] to [0, 1]."""
new_transition = transition.copy()
obs = new_transition.get(TransitionKey.OBSERVATION, {})
if obs is None:
return new_transition
# Make a copy of observations to avoid modifying the original
obs = obs.copy()
# Determine which keys to convert
keys_to_convert = self.image_keys
if keys_to_convert is None:
# Auto-detect image keys
keys_to_convert = [k for k in obs if k.startswith("observation.images.")]
# Convert each image
for key in keys_to_convert:
if key in obs and isinstance(obs[key], torch.Tensor):
tensor = obs[key]
min_val = tensor.min().item()
max_val = tensor.max().item()
if max_val <= 1.0:
obs[key] = tensor.float() # ensure float dtype, but no division
continue
# Validate that values are in [0, 255] range if requested
if self.validate_range and (min_val < 0.0 or max_val > 255.0):
raise ValueError(
f"Image '{key}' has values outside [0, 255] range: "
f"min={min_val:.4f}, max={max_val:.4f}. "
f"Cannot convert to [0, 1] range."
)
# Convert to float and divide by 255
obs[key] = tensor.float() / 255.0
new_transition[TransitionKey.OBSERVATION] = obs
return new_transition
def transform_features(self, features):
"""Image conversion doesn't change feature structure."""
return features
def get_config(self) -> dict[str, Any]:
"""Return serializable configuration."""
return {
"image_keys": self.image_keys,
"validate_range": self.validate_range,
}
@dataclass
@ProcessorStepRegistry.register(name="xvla_imagenet_normalize")
class XVLAImageNetNormalizeProcessorStep(ProcessorStep):
"""Normalize image observations using ImageNet statistics.
This processor step applies ImageNet normalization (mean and std) to image observations.
It validates that input values are in the [0, 1] range before normalizing.
The normalization formula is: (image - mean) / std
Args:
image_keys: List of observation keys that contain images to normalize.
If None, will automatically detect keys starting with "observation.images."
Raises:
ValueError: If image values are not in the [0, 1] range.
"""
image_keys: list[str] | None = None
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Normalize image observations using ImageNet statistics."""
new_transition = transition.copy()
obs = new_transition.get(TransitionKey.OBSERVATION, {})
if obs is None:
return new_transition
# Make a copy of observations to avoid modifying the original
obs = obs.copy()
# Determine which keys to normalize
keys_to_normalize = self.image_keys
if keys_to_normalize is None:
# Auto-detect image keys
keys_to_normalize = [k for k in obs if k.startswith("observation.images.")]
# Normalize each image
for key in keys_to_normalize:
if key in obs and isinstance(obs[key], torch.Tensor):
tensor = obs[key]
# Validate that values are in [0, 1] range
min_val = tensor.min().item()
max_val = tensor.max().item()
if min_val < 0.0 or max_val > 1.0:
raise ValueError(
f"Image '{key}' has values outside [0, 1] range: "
f"min={min_val:.4f}, max={max_val:.4f}. "
f"ImageNet normalization requires input values in [0, 1]."
)
# Apply ImageNet normalization
mean = torch.tensor(IMAGENET_STATS["mean"], device=tensor.device, dtype=tensor.dtype)
std = torch.tensor(IMAGENET_STATS["std"], device=tensor.device, dtype=tensor.dtype)
# Expand mean/std to match tensor dims (e.g., BCHW or BNCHW)
while mean.dim() < tensor.dim():
mean = mean.unsqueeze(0)
std = std.unsqueeze(0)
# Normalize: (image - mean) / std
obs[key] = (tensor - mean) / std
new_transition[TransitionKey.OBSERVATION] = obs
return new_transition
def transform_features(self, features):
"""ImageNet normalization doesn't change feature structure."""
return features
def get_config(self) -> dict[str, Any]:
"""Return serializable configuration."""
return {
"image_keys": self.image_keys,
}
@dataclass
@ProcessorStepRegistry.register(name="xvla_add_domain_id")
class XVLAAddDomainIdProcessorStep(ProcessorStep):
"""Add domain_id to complementary data.
This processor step adds a domain_id tensor to the complementary data,
which is used by XVLA to identify different robot embodiments or task domains.
Args:
domain_id: The domain ID to add (default: 3)
"""
domain_id: int = 0
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Add domain_id to complementary data."""
new_transition = transition.copy()
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
comp = {} if comp is None else comp.copy()
# Infer batch size from observation tensors
obs = new_transition.get(TransitionKey.OBSERVATION, {})
batch_size = 1
if obs:
for v in obs.values():
if isinstance(v, torch.Tensor):
batch_size = v.shape[0]
break
# Add domain_id tensor
comp["domain_id"] = torch.tensor([int(self.domain_id)] * batch_size, dtype=torch.long)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp
return new_transition
def transform_features(self, features):
"""Domain ID addition doesn't change feature structure."""
return features
def get_config(self) -> dict[str, Any]:
"""Return serializable configuration."""
return {
"domain_id": self.domain_id,
}
@dataclass
@ProcessorStepRegistry.register(name="xvla_rotation_6d_to_axis_angle")
class XVLARotation6DToAxisAngleProcessorStep(ProcessorStep):
"""Convert 6D rotation representation to axis-angle and reorganize action dimensions.
This processor step takes actions with 6D rotation representation and converts them to
axis-angle representation, reorganizing the action dimensions as:
- action[:, :3] -> target_eef (end-effector position)
- action[:, 3:9] -> 6D rotation (converted to axis-angle, 3D)
- action[:, 9:10] -> gripper action
Final output: [target_eef (3), axis_angle (3), gripper (1)] = 7D action
Args:
expected_action_dim: Expected input action dimension (default: 10, supports 6D rotation + extras)
"""
expected_action_dim: int = 10
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Convert 6D rotation to axis-angle in action."""
new_transition = transition.copy()
action = new_transition.get(TransitionKey.ACTION)
if action is None or not isinstance(action, torch.Tensor):
return new_transition
# Convert to numpy for processing
device = action.device
dtype = action.dtype
action_np = action.cpu().numpy()
# Extract components
# action shape: (B, D) where D >= 10
target_eef = action_np[:, :3] # (B, 3)
rotation_6d = action_np[:, 3:9] # (B, 6)
target_act = action_np[:, 9:10] # (B, 1)
# Convert 6D rotation to axis-angle
target_axis = rotate6d_to_axis_angle(rotation_6d) # (B, 3)
# Concatenate: [eef (3), axis_angle (3), gripper (1)] = 7D
action_np = np.concatenate([target_eef, target_axis, target_act], axis=-1)
# Convert gripper action to -1 or 1
action_np[:, -1] = np.where(action_np[:, -1] > 0.5, 1.0, -1.0)
# Convert back to tensor
action = torch.from_numpy(action_np).to(device=device, dtype=dtype)
new_transition[TransitionKey.ACTION] = action
return new_transition
def transform_features(self, features):
"""Rotation conversion changes action dimension from 10 to 7."""
# Note: This is a simplified version. In practice, you might want to
# update the action feature shape in the features dict.
return features
def get_config(self) -> dict[str, Any]:
"""Return serializable configuration."""
return {
"expected_action_dim": self.expected_action_dim,
}
def make_xvla_libero_pre_post_processors() -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Build the LeRobot processor pipelines for XVLA with LIBERO environment.
"""
pre_processor_steps: list[ProcessorStep] = []
post_processor_steps: list[ProcessorStep] = []
pre_processor_steps.extend(
[LiberoProcessorStep(), XVLAImageNetNormalizeProcessorStep(), XVLAAddDomainIdProcessorStep()]
)
post_processor_steps.extend([XVLARotation6DToAxisAngleProcessorStep()])
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=pre_processor_steps,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=post_processor_steps,
),
)
@@ -0,0 +1,415 @@
# ------------------------------------------------------------------------------
# Copyright 2025 2toINF (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------
from __future__ import annotations
import math
from collections.abc import Iterable
from functools import partial
from typing import Final
import torch
import torch.nn as nn
import torch.nn.functional as functional
# ------------------------------- Small utils ----------------------------------
def _to_2tuple(x) -> tuple:
"""Minimal replacement for timm.layers.to_2tuple."""
if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
t = tuple(x)
return (t[0], t[1]) if len(t) >= 2 else (t[0], t[0])
return (x, x)
def _has_sdp_attention() -> bool:
"""Check if we can use PyTorch fused scaled_dot_product_attention."""
return hasattr(functional, "scaled_dot_product_attention")
# ---------------------------------- MLP --------------------------------------
class Mlp(nn.Module):
"""
MLP used in ViT-style blocks.
Supports Linear or 1x1 Conv 'linear_layer' for token/channel mixing.
"""
def __init__(
self,
in_features: int,
hidden_features: int | None = None,
out_features: int | None = None,
norm_layer: type[nn.Module] | None = None,
bias: bool | tuple[bool, bool] = True,
drop: float | tuple[float, float] = 0.0,
use_conv: bool = False,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = _to_2tuple(bias)
drop_probs = _to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = nn.GELU(approximate="tanh")
self.drop1 = nn.Dropout(drop_probs[0])
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Expect [B, T, C] for Linear variant; caller is responsible for shapes.
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.norm(x)
x = self.fc2(x)
x = self.drop2(x)
return x
# -------------------------------- Attention ----------------------------------
class Attention(nn.Module):
"""
Multi-Head Self-Attention with optional fused SDPA fallback.
If PyTorch provides `scaled_dot_product_attention`, it will be used
(usually faster and more stable); otherwise we use a manual implementation.
"""
fused_attn: Final[bool]
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
norm_layer: type[nn.Module] = nn.LayerNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.fused_attn = _has_sdp_attention()
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Parameters
----------
x : Tensor, shape [batch_size, seq_len, channels]
Input sequence.
Returns
-------
Tensor, shape [batch_size, seq_len, channels]
Output sequence after MHSA + projection.
"""
batch_size, seq_len, channels = x.shape
qkv = (
self.qkv(x)
.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
.permute(2, 0, 3, 1, 4) # 3 x [batch_size, num_heads, seq_len, head_dim]
)
q, k, v = qkv.unbind(0) # each: [batch_size, num_heads, seq_len, head_dim]
q, k = self.q_norm(q), self.k_norm(k)
if self.fused_attn:
x = functional.scaled_dot_product_attention(
q,
k,
v,
dropout_p=self.attn_drop.p if self.training else 0.0,
) # [batch_size, num_heads, seq_len, head_dim]
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1) # [batch_size, num_heads, seq_len, seq_len]
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v # [batch_size, num_heads, seq_len, head_dim]
x = x.transpose(1, 2).reshape(batch_size, seq_len, channels) # [batch_size, seq_len, channels]
x = self.proj(x)
x = self.proj_drop(x)
return x
# ------------------------------- Utilities -----------------------------------
def basic_init(module: nn.Module) -> None:
"""
Apply a basic initialization scheme to Linear layers.
- Weight: Xavier uniform initialization.
- Bias: Set to zero.
"""
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0.0)
def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 100) -> torch.Tensor:
"""
Create sinusoidal timestep embeddings.
Parameters
----------
t : torch.Tensor
Shape [B]. Each element is a timestep index, may be fractional.
dim : int
Dimensionality of the output embedding.
max_period : int, default=100
Controls the minimum frequency of the sinusoids.
Returns
-------
torch.Tensor
Shape [B, dim]. Sinusoidal embeddings.
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=t.dtype, device=t.device) / half
)
args = t[:, None] * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2 == 1:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
# ------------------------------- Core Layers ----------------------------------
class DomainAwareLinear(nn.Module):
"""
Linear layer with domain-conditioned parameters (per-sample).
Each domain has its own weight and bias vectors, stored in embeddings.
"""
def __init__(self, input_size: int, output_size: int, num_domains: int = 20) -> None:
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.fc = nn.Embedding(num_domains, output_size * input_size)
self.bias = nn.Embedding(num_domains, output_size)
nn.init.xavier_uniform_(self.fc.weight)
nn.init.zeros_(self.bias.weight)
def forward(self, x: torch.Tensor, domain_id: torch.LongTensor) -> torch.Tensor:
"""
Parameters
----------
x : Tensor
[B, I] or [B, T, I]
domain_id : LongTensor
[B], domain indices.
Returns
-------
Tensor
[batch_size, output_size] or [batch_size, seq_len, output_size]
"""
batch_size = domain_id.shape[0]
squeeze_seq = False
if x.dim() == 2:
x = x.unsqueeze(1)
squeeze_seq = True
weight = self.fc(domain_id).view(batch_size, self.input_size, self.output_size)
bias = self.bias(domain_id).view(batch_size, self.output_size)
y = torch.matmul(x, weight) + bias.view(batch_size, 1, self.output_size)
if squeeze_seq:
y = y.squeeze(1)
return y
class TransformerBlock(nn.Module):
"""
Standard Transformer block (pre-LN): LN MHSA residual, LN MLP residual.
"""
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0) -> None:
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size)
self.norm2 = nn.LayerNorm(hidden_size)
self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, attn_drop=0.1)
self.mlp = Mlp(
in_features=hidden_size,
hidden_features=int(hidden_size * mlp_ratio),
drop=0.1,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Parameters
----------
x : Tensor, [B, T, H]
Returns
-------
Tensor, [B, T, H]
"""
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
# --------------------------- Main Model ---------------------------------------
class SoftPromptedTransformer(nn.Module):
"""
Multi-modal, domain-aware Transformer with optional soft prompts.
See parameter and forward I/O descriptions inside the docstrings.
"""
def __init__(
self,
hidden_size: int = 768,
multi_modal_input_size: int = 768,
depth: int = 24,
num_heads: int = 16,
mlp_ratio: float = 4.0,
num_domains: int = 20,
dim_action: int = 20,
dim_propio: int = 20,
dim_time: int = 32,
len_soft_prompts: int = 32,
max_len_seq: int = 512,
use_hetero_proj: bool = False,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.dim_action = dim_action
self.dim_time = dim_time
self.len_soft_prompts = len_soft_prompts
self.use_hetero_proj = use_hetero_proj
self.blocks = nn.ModuleList(
[TransformerBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)]
)
if use_hetero_proj:
self.vlm_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)
self.aux_visual_proj = DomainAwareLinear(
multi_modal_input_size, hidden_size, num_domains=num_domains
)
else:
self.vlm_proj = nn.Linear(multi_modal_input_size, hidden_size)
self.aux_visual_proj = nn.Linear(multi_modal_input_size, hidden_size)
self.pos_emb = nn.Parameter(torch.zeros(1, max_len_seq, hidden_size), requires_grad=True)
nn.init.normal_(self.pos_emb, std=0.02)
self.norm = nn.LayerNorm(hidden_size)
self.action_encoder = DomainAwareLinear(
dim_action + dim_time + dim_propio, hidden_size, num_domains=num_domains
)
self.action_decoder = DomainAwareLinear(hidden_size, dim_action, num_domains=num_domains)
if len_soft_prompts > 0:
self.soft_prompt_hub = nn.Embedding(num_domains, len_soft_prompts * hidden_size)
nn.init.normal_(self.soft_prompt_hub.weight, std=0.02)
self.apply(basic_init)
def forward(
self,
domain_id: torch.LongTensor,
vlm_features: torch.Tensor,
aux_visual_inputs: torch.Tensor,
action_with_noise: torch.Tensor,
proprio: torch.Tensor,
t: torch.Tensor,
) -> torch.Tensor:
"""
Forward pass.
Inputs
------
domain_id : [B]
vlm_features : [B, T_vlm, D]
aux_visual_inputs : [B, T_aux, D]
action_with_noise : [B, T_action, dim_action]
proprio : [B, dim_propio]
t : [B]
Returns
-------
Tensor
Predicted actions, [batch_size, num_actions, dim_action]
"""
batch_size, num_actions = action_with_noise.shape[:2]
# Encode (action + proprio + time) → tokens
time_emb = timestep_embedding(t, self.dim_time) # [batch_size, dim_time]
time_tokens = time_emb.unsqueeze(1).expand(batch_size, num_actions, self.dim_time)
proprio_tokens = proprio.unsqueeze(1).expand(batch_size, num_actions, proprio.shape[-1])
action_tokens = torch.cat([action_with_noise, proprio_tokens, time_tokens], dim=-1)
x = self.action_encoder(action_tokens, domain_id) # [batch_size, num_actions, hidden_size]
# Project visual streams and concatenate
if self.use_hetero_proj:
x = torch.cat(
[
x,
self.vlm_proj(vlm_features, domain_id),
self.aux_visual_proj(aux_visual_inputs, domain_id),
],
dim=1,
)
else:
x = torch.cat([x, self.vlm_proj(vlm_features), self.aux_visual_proj(aux_visual_inputs)], dim=1)
# Add positional embeddings (truncate if needed)
seq_len = x.shape[1]
if seq_len > self.pos_emb.shape[1]:
raise ValueError(f"Sequence length {seq_len} exceeds max_len_seq={self.pos_emb.shape[1]}.")
x = x + self.pos_emb[:, :seq_len, :]
# Append soft prompts
if self.len_soft_prompts > 0:
soft_prompts = self.soft_prompt_hub(domain_id).view(
batch_size, self.len_soft_prompts, self.hidden_size
)
x = torch.cat([x, soft_prompts], dim=1)
# Transformer backbone
for block in self.blocks:
x = block(x)
# Decode only the action segment
return self.action_decoder(self.norm(x[:, :num_actions]), domain_id)
+138
View File
@@ -0,0 +1,138 @@
import math
import numpy as np
def mat2quat(rmat):
"""
Converts given rotation matrix to quaternion.
Args:
rmat (np.array): 3x3 rotation matrix
Returns:
np.array: (x,y,z,w) float quaternion angles
"""
mat = np.asarray(rmat).astype(np.float32)[:3, :3]
m00 = mat[0, 0]
m01 = mat[0, 1]
m02 = mat[0, 2]
m10 = mat[1, 0]
m11 = mat[1, 1]
m12 = mat[1, 2]
m20 = mat[2, 0]
m21 = mat[2, 1]
m22 = mat[2, 2]
# symmetric matrix k
k = np.array(
[
[m00 - m11 - m22, np.float32(0.0), np.float32(0.0), np.float32(0.0)],
[m01 + m10, m11 - m00 - m22, np.float32(0.0), np.float32(0.0)],
[m02 + m20, m12 + m21, m22 - m00 - m11, np.float32(0.0)],
[m21 - m12, m02 - m20, m10 - m01, m00 + m11 + m22],
]
)
k /= 3.0
# quaternion is Eigen vector of k that corresponds to largest eigenvalue
w, v = np.linalg.eigh(k)
inds = np.array([3, 0, 1, 2])
q1 = v[inds, np.argmax(w)]
if q1[0] < 0.0:
np.negative(q1, q1)
inds = np.array([1, 2, 3, 0])
return q1[inds]
def quat2axisangle(quat):
"""
Converts quaternion to axis-angle format.
Returns a unit vector direction scaled by its angle in radians.
Args:
quat (np.array): (x,y,z,w) vec4 float angles
Returns:
np.array: (ax,ay,az) axis-angle exponential coordinates
"""
# clip quaternion
if quat[3] > 1.0:
quat[3] = 1.0
elif quat[3] < -1.0:
quat[3] = -1.0
den = np.sqrt(1.0 - quat[3] * quat[3])
if math.isclose(den, 0.0):
# This is (close to) a zero degree rotation, immediately return
return np.zeros(3)
return (quat[:3] * 2.0 * math.acos(quat[3])) / den
def rotate6d_to_axis_angle(r6d):
"""
r6d: np.ndarray, shape (N, 6)
return: np.ndarray, shape (N, 3), axis-angle vectors
"""
flag = 0
if len(r6d.shape) == 1:
r6d = r6d[None, ...]
flag = 1
a1 = r6d[:, 0:3]
a2 = r6d[:, 3:6]
# b1
b1 = a1 / (np.linalg.norm(a1, axis=-1, keepdims=True) + 1e-6)
# b2
dot_prod = np.sum(b1 * a2, axis=-1, keepdims=True)
b2_orth = a2 - dot_prod * b1
b2 = b2_orth / (np.linalg.norm(b2_orth, axis=-1, keepdims=True) + 1e-6)
# b3
b3 = np.cross(b1, b2, axis=-1)
rotation_matrix = np.stack([b1, b2, b3], axis=-1) # shape: (N, 3, 3)
axis_angle_list = []
for i in range(rotation_matrix.shape[0]):
quat = mat2quat(rotation_matrix[i])
axis_angle = quat2axisangle(quat)
axis_angle_list.append(axis_angle)
axis_angle_array = np.stack(axis_angle_list, axis=0) # shape: (N, 3)
if flag == 1:
axis_angle_array = axis_angle_array[0]
return axis_angle_array
def mat_to_rotate6d(abs_action):
if len(abs_action.shape) == 2:
return np.concatenate([abs_action[:3, 0], abs_action[:3, 1]], axis=-1)
elif len(abs_action.shape) == 3:
return np.concatenate([abs_action[:, :3, 0], abs_action[:, :3, 1]], axis=-1)
else:
raise NotImplementedError
def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor
@@ -0,0 +1,20 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_earthrover_mini_plus import EarthRoverMiniPlusConfig
from .robot_earthrover_mini_plus import EarthRoverMiniPlus
__all__ = ["EarthRoverMiniPlus", "EarthRoverMiniPlusConfig"]
@@ -0,0 +1,35 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration for EarthRover Mini Plus robot."""
from dataclasses import dataclass
from ..config import RobotConfig
@RobotConfig.register_subclass("earthrover_mini_plus")
@dataclass
class EarthRoverMiniPlusConfig(RobotConfig):
"""Configuration for EarthRover Mini Plus robot using Frodobots SDK.
This robot uses cloud-based control via the Frodobots SDK HTTP API.
Camera frames are accessed directly through SDK HTTP endpoints.
Attributes:
sdk_url: URL of the Frodobots SDK server (default: http://localhost:8000)
"""
sdk_url: str = "http://localhost:8000"
@@ -0,0 +1 @@
../../../../docs/source/earthrover_mini_plus.mdx
@@ -0,0 +1,473 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""EarthRover Mini Plus robot using Frodobots SDK."""
import base64
import logging
from functools import cached_property
from typing import Any
import cv2
import numpy as np
import requests
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
from .config_earthrover_mini_plus import EarthRoverMiniPlusConfig
logger = logging.getLogger(__name__)
# Action feature keys
ACTION_LINEAR_VEL = "linear.vel"
ACTION_ANGULAR_VEL = "angular.vel"
# Observation feature keys
OBS_FRONT = "front"
OBS_REAR = "rear"
OBS_LINEAR_VEL = "linear.vel"
OBS_BATTERY_LEVEL = "battery.level"
OBS_ORIENTATION_DEG = "orientation.deg"
OBS_GPS_LATITUDE = "gps.latitude"
OBS_GPS_LONGITUDE = "gps.longitude"
OBS_GPS_SIGNAL = "gps.signal"
OBS_SIGNAL_LEVEL = "signal.level"
OBS_VIBRATION = "vibration"
OBS_LAMP_STATE = "lamp.state"
class EarthRoverMiniPlus(Robot):
"""
EarthRover Mini Plus robot controlled via Frodobots SDK HTTP API.
This robot uses cloud-based control through the Frodobots SDK instead of direct
hardware connection. Cameras stream via WebRTC through Agora cloud, and control
commands are sent via HTTP POST requests.
The robot supports:
- Dual cameras (front and rear) accessed via SDK HTTP endpoints
- Linear and angular velocity control
- Battery and orientation telemetry
Attributes:
config: Robot configuration
sdk_base_url: URL of the Frodobots SDK server (default: http://localhost:8000)
"""
config_class = EarthRoverMiniPlusConfig
name = "earthrover_mini_plus"
def __init__(self, config: EarthRoverMiniPlusConfig):
"""Initialize EarthRover Mini Plus robot.
Args:
config: Robot configuration including SDK URL
"""
super().__init__(config)
self.config = config
self.sdk_base_url = "http://localhost:8000"
# Empty cameras dict for compatibility with recording script
# Cameras are accessed directly via SDK, not through Camera objects
self.cameras = {}
self._is_connected = False
# Cache for camera frames (fallback when requests fail)
self._last_front_frame = None
self._last_rear_frame = None
# Cache for robot telemetry data (fallback when requests fail)
self._last_robot_data = None
logger.info(f"Initialized {self.name} with SDK at {self.sdk_base_url}")
@property
def is_connected(self) -> bool:
"""Check if robot is connected to SDK."""
return self._is_connected
def connect(self, calibrate: bool = True) -> None:
"""Connect to robot via Frodobots SDK.
Args:
calibrate: Not used for SDK-based robot (kept for API compatibility)
Raises:
DeviceAlreadyConnectedError: If robot is already connected
DeviceNotConnectedError: If cannot connect to SDK server
"""
if self._is_connected:
raise DeviceAlreadyConnectedError(f"{self.name} is already connected")
# Verify SDK is running and accessible
try:
response = requests.get(f"{self.sdk_base_url}/data", timeout=10.0)
if response.status_code != 200:
raise DeviceNotConnectedError(
f"Cannot connect to SDK at {self.sdk_base_url}. "
"Make sure it's running: hypercorn main:app --reload"
)
except requests.RequestException as e:
raise DeviceNotConnectedError(f"Cannot connect to SDK at {self.sdk_base_url}: {e}") from e
self._is_connected = True
logger.info(f"{self.name} connected to SDK")
if calibrate:
self.calibrate()
def calibrate(self) -> None:
"""Calibration not needed for SDK-based robot."""
logger.info("Calibration not required for SDK-based robot")
@property
def is_calibrated(self) -> bool:
"""SDK robot doesn't require calibration.
Returns:
bool: Always True for SDK-based robots
"""
return True
def configure(self) -> None:
"""Configure robot (no-op for SDK-based robot)."""
pass
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
"""Define the observation space for dataset recording.
Returns:
dict: Observation features with types/shapes:
- front: (480, 640, 3) - Front camera RGB image
- rear: (480, 640, 3) - Rear camera RGB image
- linear.vel: float - Current speed (0-1, SDK reports only positive speeds)
- battery.level: float - Battery level (0-1, normalized from 0-100)
- orientation.deg: float - Robot orientation (0-1, normalized from raw value)
- gps.latitude: float - GPS latitude coordinate
- gps.longitude: float - GPS longitude coordinate
- gps.signal: float - GPS signal strength (0-1, normalized from percentage)
- signal.level: float - Network signal level (0-1, normalized from 0-5)
- vibration: float - Vibration sensor reading
- lamp.state: float - Lamp state (0=off, 1=on)
"""
return {
# Cameras (height, width, channels)
OBS_FRONT: (480, 640, 3),
OBS_REAR: (480, 640, 3),
# Motion state
OBS_LINEAR_VEL: float,
# Robot state
OBS_BATTERY_LEVEL: float,
OBS_ORIENTATION_DEG: float,
# GPS
OBS_GPS_LATITUDE: float,
OBS_GPS_LONGITUDE: float,
OBS_GPS_SIGNAL: float,
# Sensors
OBS_SIGNAL_LEVEL: float,
OBS_VIBRATION: float,
OBS_LAMP_STATE: float,
}
@cached_property
def action_features(self) -> dict[str, type]:
"""Define the action space.
Returns:
dict: Action features with types:
- linear.vel: float - Target linear velocity
- angular.vel: float - Target angular velocity
"""
return {
ACTION_LINEAR_VEL: float,
ACTION_ANGULAR_VEL: float,
}
def get_observation(self) -> dict[str, Any]:
"""Get current robot observation from SDK.
Returns:
dict: Observation containing:
- front: Front camera image (480, 640, 3) in RGB format
- rear: Rear camera image (480, 640, 3) in RGB format
- linear.vel: Current speed (0-1, SDK reports only positive speeds)
- battery.level: Battery level (0-1, normalized from 0-100)
- orientation.deg: Robot orientation (0-1, normalized from raw value)
- gps.latitude: GPS latitude coordinate
- gps.longitude: GPS longitude coordinate
- gps.signal: GPS signal strength (0-1, normalized from percentage)
- signal.level: Network signal level (0-1, normalized from 0-5)
- vibration: Vibration sensor reading
- lamp.state: Lamp state (0=off, 1=on)
Raises:
DeviceNotConnectedError: If robot is not connected
Note:
Camera frames are retrieved from SDK endpoints /v2/front and /v2/rear.
Frames are decoded from base64 and converted from BGR to RGB format.
Robot telemetry is retrieved from /data endpoint.
All SDK values are normalized to appropriate ranges for dataset recording.
"""
if not self._is_connected:
raise DeviceNotConnectedError(f"{self.name} is not connected")
observation = {}
# Get camera images from SDK
frames = self._get_camera_frames()
observation[OBS_FRONT] = frames["front"]
observation[OBS_REAR] = frames["rear"]
# Get robot state from SDK
robot_data = self._get_robot_data()
# Motion state
observation[OBS_LINEAR_VEL] = robot_data["speed"] / 100.0 # Normalize 0-100 to 0-1
# Robot state
observation[OBS_BATTERY_LEVEL] = robot_data["battery"] / 100.0 # Normalize 0-100 to 0-1
observation[OBS_ORIENTATION_DEG] = robot_data["orientation"] / 360.0 # Normalize to 0-1
# GPS data
observation[OBS_GPS_LATITUDE] = robot_data["latitude"]
observation[OBS_GPS_LONGITUDE] = robot_data["longitude"]
observation[OBS_GPS_SIGNAL] = robot_data["gps_signal"] / 100.0 # Normalize percentage to 0-1
# Sensors
observation[OBS_SIGNAL_LEVEL] = robot_data["signal_level"] / 5.0 # Normalize 0-5 to 0-1
observation[OBS_VIBRATION] = robot_data["vibration"]
observation[OBS_LAMP_STATE] = float(robot_data["lamp"]) # 0 or 1
return observation
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
"""Send action to robot via SDK.
Args:
action: Action dict with keys:
- linear.vel: Target linear velocity (-1 to 1)
- angular.vel: Target angular velocity (-1 to 1)
Returns:
dict: The action that was sent (matches action_features keys)
Raises:
DeviceNotConnectedError: If robot is not connected
Note:
Actions are sent to SDK via POST /control endpoint.
SDK expects commands in range [-1, 1].
"""
if not self._is_connected:
raise DeviceNotConnectedError(f"{self.name} is not connected")
# Extract action values and convert to float
linear = float(action.get(ACTION_LINEAR_VEL, 0.0))
angular = float(action.get(ACTION_ANGULAR_VEL, 0.0))
# Send command to SDK
try:
self._send_command_to_sdk(linear, angular)
except Exception as e:
logger.error(f"Error sending action: {e}")
# Return action in format matching action_features
return {
ACTION_LINEAR_VEL: linear,
ACTION_ANGULAR_VEL: angular,
}
def disconnect(self) -> None:
"""Disconnect from robot.
Stops the robot and closes connection to SDK.
Raises:
DeviceNotConnectedError: If robot is not connected
"""
if not self._is_connected:
raise DeviceNotConnectedError(f"{self.name} is not connected")
# Stop the robot before disconnecting
try:
self._send_command_to_sdk(0.0, 0.0)
except Exception as e:
logger.warning(f"Failed to stop robot during disconnect: {e}")
self._is_connected = False
logger.info(f"{self.name} disconnected")
# Private helper methods for SDK communication
def _get_camera_frames(self) -> dict[str, np.ndarray]:
"""Get camera frames from SDK using v2 endpoints with caching fallback.
Returns:
dict: Dictionary with 'front' and 'rear' keys containing:
- Current frame (if request succeeds)
- Cached frame (if request fails but cache exists)
- Zero array (if request fails and no cache exists yet)
Note:
Uses /v2/front and /v2/rear endpoints which are 15x faster than /screenshot.
Images are base64 encoded, resized to 640x480, and converted from BGR to RGB.
If request fails, returns the last successfully retrieved frame (cached).
"""
frames = {}
# Get front camera
try:
response = requests.get(f"{self.sdk_base_url}/v2/front", timeout=2.0)
if response.status_code == 200:
data = response.json()
if "front_frame" in data and data["front_frame"]:
front_img = self._decode_base64_image(data["front_frame"])
if front_img is not None:
# Resize and convert BGR to RGB
front_img = cv2.resize(front_img, (640, 480))
front_rgb = cv2.cvtColor(front_img, cv2.COLOR_BGR2RGB)
frames["front"] = front_rgb
# Cache the successful frame
self._last_front_frame = front_rgb
except Exception as e:
logger.warning(f"Error fetching front camera: {e}")
# Fallback: use cache or zero array
if "front" not in frames:
if self._last_front_frame is not None:
frames["front"] = self._last_front_frame
else:
frames["front"] = np.zeros((480, 640, 3), dtype=np.uint8)
# Get rear camera
try:
response = requests.get(f"{self.sdk_base_url}/v2/rear", timeout=2.0)
if response.status_code == 200:
data = response.json()
if "rear_frame" in data and data["rear_frame"]:
rear_img = self._decode_base64_image(data["rear_frame"])
if rear_img is not None:
# Resize and convert BGR to RGB
rear_img = cv2.resize(rear_img, (640, 480))
rear_rgb = cv2.cvtColor(rear_img, cv2.COLOR_BGR2RGB)
frames["rear"] = rear_rgb
# Cache the successful frame
self._last_rear_frame = rear_rgb
except Exception as e:
logger.warning(f"Error fetching rear camera: {e}")
# Fallback: use cache or zero array
if "rear" not in frames:
if self._last_rear_frame is not None:
frames["rear"] = self._last_rear_frame
else:
frames["rear"] = np.zeros((480, 640, 3), dtype=np.uint8)
return frames
def _decode_base64_image(self, base64_string: str) -> np.ndarray | None:
"""Decode base64 string to image.
Args:
base64_string: Base64 encoded image string
Returns:
np.ndarray: Decoded image in BGR format (OpenCV default), or None if decoding fails
"""
try:
img_bytes = base64.b64decode(base64_string)
nparr = np.frombuffer(img_bytes, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
return img # Return in BGR format (OpenCV default)
except Exception as e:
logger.error(f"Error decoding image: {e}")
return None
def _get_robot_data(self) -> dict:
"""Get robot telemetry data from SDK.
Returns:
dict: Robot telemetry data including battery, speed, orientation, GPS, etc:
- Current data (if request succeeds)
- Cached data (if request fails but cache exists)
- Default values (if request fails and no cache exists yet)
Note:
Uses /data endpoint which provides comprehensive robot state.
If request fails, returns the last successfully retrieved data (cached).
"""
try:
response = requests.get(f"{self.sdk_base_url}/data", timeout=2.0)
if response.status_code == 200:
data = response.json()
# Cache the successful data
self._last_robot_data = data
return data
except Exception as e:
logger.warning(f"Error fetching robot data: {e}")
# Fallback: use cache or default values
if self._last_robot_data is not None:
return self._last_robot_data
else:
# Return dict with default values (used only on first failure before any cache exists)
return {
"speed": 0,
"battery": 0,
"orientation": 0,
"latitude": 0.0,
"longitude": 0.0,
"gps_signal": 0,
"signal_level": 0,
"vibration": 0.0,
"lamp": 0,
}
def _send_command_to_sdk(self, linear: float, angular: float, lamp: int = 0) -> bool:
"""Send control command to SDK.
Args:
linear: Linear velocity command (-1 to 1)
angular: Angular velocity command (-1 to 1)
lamp: Lamp control (0=off, 1=on)
Returns:
bool: True if command sent successfully, False otherwise
Note:
Uses POST /control endpoint. Commands are sent as JSON payload.
"""
try:
payload = {
"command": {
"linear": linear,
"angular": angular,
"lamp": lamp,
}
}
response = requests.post(
f"{self.sdk_base_url}/control",
json=payload,
timeout=1.0,
)
return response.status_code == 200
except Exception as e:
logger.error(f"Error sending command: {e}")
return False
@@ -0,0 +1,21 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# OMX is a fully open-source robot from ROBOTIS.
# More information at: https://ai.robotis.com/omx/introduction_omx.html
from .config_omx_follower import OmxFollowerConfig
from .omx_follower import OmxFollower
@@ -0,0 +1,39 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from ..config import RobotConfig
@RobotConfig.register_subclass("omx_follower")
@dataclass
class OmxFollowerConfig(RobotConfig):
# Port to connect to the arm
port: str
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
# names to the max_relative_target value for that motor.
max_relative_target: float | dict[str, float] | None = None
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)
# Set to `True` for backward compatibility with previous policies/dataset
use_degrees: bool = False
@@ -0,0 +1,225 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from functools import cached_property
from typing import Any
from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.dynamixel import (
DriveMode,
DynamixelMotorsBus,
OperatingMode,
)
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
from ..utils import ensure_safe_goal_position
from .config_omx_follower import OmxFollowerConfig
logger = logging.getLogger(__name__)
class OmxFollower(Robot):
"""
- [OMX](https://github.com/ROBOTIS-GIT/open_manipulator),
expansion, developed by Woojin Wie and Junha Cha from [ROBOTIS](https://ai.robotis.com/)
"""
config_class = OmxFollowerConfig
name = "omx_follower"
def __init__(self, config: OmxFollowerConfig):
super().__init__(config)
self.config = config
norm_mode_body = MotorNormMode.DEGREES if config.use_degrees else MotorNormMode.RANGE_M100_100
self.bus = DynamixelMotorsBus(
port=self.config.port,
motors={
"shoulder_pan": Motor(11, "xl430-w250", norm_mode_body),
"shoulder_lift": Motor(12, "xl430-w250", norm_mode_body),
"elbow_flex": Motor(13, "xl430-w250", norm_mode_body),
"wrist_flex": Motor(14, "xl330-m288", norm_mode_body),
"wrist_roll": Motor(15, "xl330-m288", norm_mode_body),
"gripper": Motor(16, "xl330-m288", MotorNormMode.RANGE_0_100),
},
calibration=self.calibration,
)
self.cameras = make_cameras_from_configs(config.cameras)
@property
def _motors_ft(self) -> dict[str, type]:
return {f"{motor}.pos": float for motor in self.bus.motors}
@property
def _cameras_ft(self) -> dict[str, tuple]:
return {
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
return {**self._motors_ft, **self._cameras_ft}
@cached_property
def action_features(self) -> dict[str, type]:
return self._motors_ft
@property
def is_connected(self) -> bool:
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
def connect(self, calibrate: bool = True) -> None:
"""
For OMX robots that come pre-calibrated:
- If default calibration from package doesn't match motors, read from motors and save
- This allows using pre-calibrated robots without manual calibration
- If no calibration file exists, use factory default values (homing_offset=0, range_min=0, range_max=4095)
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
self.bus.connect()
if not self.is_calibrated and calibrate:
logger.info(
"Mismatch between calibration values in the motor and the calibration file or no calibration file found"
)
self.calibrate()
for cam in self.cameras.values():
cam.connect()
self.configure()
logger.info(f"{self} connected.")
@property
def is_calibrated(self) -> bool:
return self.bus.is_calibrated
def calibrate(self) -> None:
self.bus.disable_torque()
logger.info(f"\nUsing factory default calibration values for {self}")
logger.info(f"\nWriting default configuration of {self} to the motors")
for motor in self.bus.motors:
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
for motor in self.bus.motors:
self.bus.write("Drive_Mode", motor, DriveMode.NON_INVERTED.value)
self.calibration = {}
for motor, m in self.bus.motors.items():
self.calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=0,
homing_offset=0,
range_min=0,
range_max=4095,
)
self.bus.write_calibration(self.calibration)
self._save_calibration()
logger.info(f"Calibration saved to {self.calibration_fpath}")
def configure(self) -> None:
with self.bus.torque_disabled():
self.bus.configure_motors()
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos
# can't rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling
# the arm, you could end up with a servo with a position 0 or 4095 at a crucial point
for motor in self.bus.motors:
if motor != "gripper":
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
# Use 'position control current based' for gripper to be limited by the limit of the current. For
# the follower gripper, it means it can grasp an object without forcing too much even tho, its
# goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch).
# For the leader gripper, it means we can use it as a physical trigger, since we can force with
# our finger to make it move, and it will move back to its original target position when we
# release the force.
self.bus.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value)
# Set better PID values to close the gap between recorded states and actions
# TODO(rcadene): Implement an automatic procedure to set optimal PID values for each motor
self.bus.write("Position_P_Gain", "elbow_flex", 1500)
self.bus.write("Position_I_Gain", "elbow_flex", 0)
self.bus.write("Position_D_Gain", "elbow_flex", 600)
def setup_motors(self) -> None:
for motor in reversed(self.bus.motors):
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
self.bus.setup_motor(motor)
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
def get_observation(self) -> dict[str, Any]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Read arm position
start = time.perf_counter()
obs_dict = self.bus.sync_read("Present_Position")
obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()}
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.async_read()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
return obs_dict
def send_action(self, action: dict[str, float]) -> dict[str, float]:
"""Command arm to move to a target joint configuration.
The relative action magnitude may be clipped depending on the configuration parameter
`max_relative_target`. In this case, the action sent differs from original action.
Thus, this function always returns the action actually sent.
Args:
action (dict[str, float]): The goal positions for the motors.
Returns:
dict[str, float]: The action sent to the motors, potentially clipped.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
# Cap goal position when too far away from present position.
# /!\ Slower fps expected due to reading from the follower.
if self.config.max_relative_target is not None:
present_pos = self.bus.sync_read("Present_Position")
goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()}
goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
# Send goal position to the arm
self.bus.sync_write("Goal_Position", goal_pos)
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
def disconnect(self):
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self.bus.disconnect(self.config.disable_torque_on_disconnect)
for cam in self.cameras.values():
cam.disconnect()
logger.info(f"{self} disconnected.")
+18
View File
@@ -0,0 +1,18 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_unitree_g1 import UnitreeG1Config
from .unitree_g1 import UnitreeG1
@@ -0,0 +1,63 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from ..config import RobotConfig
_GAINS: dict[str, dict[str, list[float]]] = {
"left_leg": {
"kp": [150, 150, 150, 300, 40, 40],
"kd": [2, 2, 2, 4, 2, 2],
}, # pitch, roll, yaw, knee, ankle_pitch, ankle_roll
"right_leg": {"kp": [150, 150, 150, 300, 40, 40], "kd": [2, 2, 2, 4, 2, 2]},
"waist": {"kp": [250, 250, 250], "kd": [5, 5, 5]}, # yaw, roll, pitch
"left_arm": {"kp": [80, 80, 80, 80], "kd": [3, 3, 3, 3]}, # shoulder_pitch/roll/yaw, elbow
"left_wrist": {"kp": [40, 40, 40], "kd": [1.5, 1.5, 1.5]}, # roll, pitch, yaw
"right_arm": {"kp": [80, 80, 80, 80], "kd": [3, 3, 3, 3]},
"right_wrist": {"kp": [40, 40, 40], "kd": [1.5, 1.5, 1.5]},
"other": {"kp": [80, 80, 80, 80, 80, 80], "kd": [3, 3, 3, 3, 3, 3]},
}
def _build_gains() -> tuple[list[float], list[float]]:
"""Build kp and kd lists from body-part groupings."""
kp = [v for g in _GAINS.values() for v in g["kp"]]
kd = [v for g in _GAINS.values() for v in g["kd"]]
return kp, kd
_DEFAULT_KP, _DEFAULT_KD = _build_gains()
@RobotConfig.register_subclass("unitree_g1")
@dataclass
class UnitreeG1Config(RobotConfig):
kp: list[float] = field(default_factory=lambda: _DEFAULT_KP.copy())
kd: list[float] = field(default_factory=lambda: _DEFAULT_KD.copy())
control_dt: float = 1.0 / 250.0 # 250Hz
# launch mujoco simulation
is_simulation: bool = False
# socket config for ZMQ bridge
robot_ip: str = "172.18.129.215"
# cameras (optional)
cameras: dict[str, CameraConfig] = field(default_factory=dict)
+89
View File
@@ -0,0 +1,89 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import IntEnum
# ruff: noqa: N801, N815
NUM_MOTORS = 35
class G1_29_JointArmIndex(IntEnum):
# Left arm
kLeftShoulderPitch = 15
kLeftShoulderRoll = 16
kLeftShoulderYaw = 17
kLeftElbow = 18
kLeftWristRoll = 19
kLeftWristPitch = 20
kLeftWristyaw = 21
# Right arm
kRightShoulderPitch = 22
kRightShoulderRoll = 23
kRightShoulderYaw = 24
kRightElbow = 25
kRightWristRoll = 26
kRightWristPitch = 27
kRightWristYaw = 28
class G1_29_JointIndex(IntEnum):
# Left leg
kLeftHipPitch = 0
kLeftHipRoll = 1
kLeftHipYaw = 2
kLeftKnee = 3
kLeftAnklePitch = 4
kLeftAnkleRoll = 5
# Right leg
kRightHipPitch = 6
kRightHipRoll = 7
kRightHipYaw = 8
kRightKnee = 9
kRightAnklePitch = 10
kRightAnkleRoll = 11
kWaistYaw = 12
kWaistRoll = 13
kWaistPitch = 14
# Left arm
kLeftShoulderPitch = 15
kLeftShoulderRoll = 16
kLeftShoulderYaw = 17
kLeftElbow = 18
kLeftWristRoll = 19
kLeftWristPitch = 20
kLeftWristyaw = 21
# Right arm
kRightShoulderPitch = 22
kRightShoulderRoll = 23
kRightShoulderYaw = 24
kRightElbow = 25
kRightWristRoll = 26
kRightWristPitch = 27
kRightWristYaw = 28
# not used
kNotUsedJoint0 = 29
kNotUsedJoint1 = 30
kNotUsedJoint2 = 31
kNotUsedJoint3 = 32
kNotUsedJoint4 = 33
kNotUsedJoint5 = 34
@@ -0,0 +1,302 @@
#!/usr/bin/env python
"""
Standalone keyboard control script for Unitree G1 robot.
This script provides keyboard-based velocity control for the G1 robot's
locomotion system. It can be run alongside the main robot control to
provide manual movement commands.
Usage:
python keyboard_control.py [--robot-ip IP] [--simulation]
Controls:
W/S: Forward/Backward
A/D: Strafe Left/Right
Q/E: Rotate Left/Right
R/F: Raise/Lower Height (GR00T policies only)
Z: Stop (zero all velocity commands)
ESC/Ctrl+C: Exit
"""
import argparse
import sys
import select
import time
import numpy as np
# Terminal handling for non-blocking keyboard input
try:
import termios
import tty
HAS_TERMIOS = True
except ImportError:
HAS_TERMIOS = False
print("Warning: termios not available. Keyboard controls require Linux/macOS.")
class KeyboardController:
"""Handles keyboard input and converts to locomotion commands."""
def __init__(self, callback=None):
"""
Initialize keyboard controller.
Args:
callback: Optional function called when commands change.
Signature: callback(vx, vy, yaw, height)
"""
self.callback = callback
self.running = False
# Locomotion commands
self.vx = 0.0 # Forward/backward velocity
self.vy = 0.0 # Left/right velocity (strafe)
self.yaw = 0.0 # Rotation rate
self.height = 0.74 # Base height (for GR00T policies)
# Command limits
self.vx_limit = (-0.8, 0.8)
self.vy_limit = (-0.5, 0.5)
self.yaw_limit = (-1.0, 1.0)
self.height_limit = (0.50, 1.00)
# Increments per keypress
self.vx_increment = 0.4
self.vy_increment = 0.25
self.yaw_increment = 0.5
self.height_increment = 0.05
self._old_terminal_settings = None
def get_commands(self) -> tuple[float, float, float, float]:
"""Get current command values as tuple (vx, vy, yaw, height)."""
return (self.vx, self.vy, self.yaw, self.height)
def get_commands_array(self) -> np.ndarray:
"""Get velocity commands as numpy array [vx, vy, yaw]."""
return np.array([self.vx, self.vy, self.yaw], dtype=np.float32)
def reset_commands(self):
"""Reset all commands to zero (stop)."""
self.vx = 0.0
self.vy = 0.0
self.yaw = 0.0
self._notify_callback()
def _clamp(self, value: float, limits: tuple[float, float]) -> float:
"""Clamp value to limits."""
return max(limits[0], min(limits[1], value))
def _notify_callback(self):
"""Call callback with current commands if set."""
if self.callback:
self.callback(self.vx, self.vy, self.yaw, self.height)
def process_key(self, key: str) -> bool:
"""
Process a single key press and update commands.
Args:
key: Single character key that was pressed.
Returns:
True if key was handled, False otherwise.
"""
key = key.lower()
handled = True
if key == 'w':
self.vx = self._clamp(self.vx + self.vx_increment, self.vx_limit)
elif key == 's':
self.vx = self._clamp(self.vx - self.vx_increment, self.vx_limit)
elif key == 'a':
self.vy = self._clamp(self.vy + self.vy_increment, self.vy_limit)
elif key == 'd':
self.vy = self._clamp(self.vy - self.vy_increment, self.vy_limit)
elif key == 'q':
self.yaw = self._clamp(self.yaw + self.yaw_increment, self.yaw_limit)
elif key == 'e':
self.yaw = self._clamp(self.yaw - self.yaw_increment, self.yaw_limit)
elif key == 'r':
self.height = self._clamp(self.height + self.height_increment, self.height_limit)
elif key == 'f':
self.height = self._clamp(self.height - self.height_increment, self.height_limit)
elif key == 'z':
self.reset_commands()
return True # Already notified in reset_commands
else:
handled = False
if handled:
self._notify_callback()
return handled
def _setup_terminal(self):
"""Set terminal to raw mode for single character input."""
if HAS_TERMIOS:
self._old_terminal_settings = termios.tcgetattr(sys.stdin)
tty.setcbreak(sys.stdin.fileno())
def _restore_terminal(self):
"""Restore terminal to original settings."""
if HAS_TERMIOS and self._old_terminal_settings is not None:
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, self._old_terminal_settings)
self._old_terminal_settings = None
def run(self):
"""Run the keyboard listener loop (blocking)."""
if not HAS_TERMIOS:
print("Error: Keyboard controls require termios (Linux/macOS)")
return
self.running = True
self._print_controls()
try:
self._setup_terminal()
while self.running:
# Check for keyboard input with timeout
if select.select([sys.stdin], [], [], 0.1)[0]:
key = sys.stdin.read(1)
# Handle escape sequences (arrow keys, etc.)
if key == '\x1b': # ESC
self.running = False
break
if self.process_key(key):
self._print_status()
except KeyboardInterrupt:
print("\nInterrupted by user")
finally:
self._restore_terminal()
print("\nKeyboard controls stopped")
def stop(self):
"""Stop the keyboard listener."""
self.running = False
def _print_controls(self):
"""Print control instructions."""
print("\n" + "=" * 60)
print("KEYBOARD CONTROLS ACTIVE")
print("=" * 60)
print(" W/S: Forward/Backward")
print(" A/D: Strafe Left/Right")
print(" Q/E: Rotate Left/Right")
print(" R/F: Raise/Lower Height (±5cm)")
print(" Z: Stop (zero all commands)")
print(" ESC: Exit")
print("=" * 60 + "\n")
def _print_status(self):
"""Print current command status."""
print(f"[CMD] vx={self.vx:+.2f}, vy={self.vy:+.2f}, yaw={self.yaw:+.2f} | height={self.height:.3f}m")
class RobotKeyboardController(KeyboardController):
"""Keyboard controller that directly updates a robot's locomotion commands."""
def __init__(self, robot):
"""
Initialize with a UnitreeG1 robot instance.
Args:
robot: UnitreeG1 robot instance with locomotion_cmd attribute.
"""
super().__init__()
self.robot = robot
# Initialize from robot's current state if available
if hasattr(robot, 'locomotion_cmd'):
self.vx = robot.locomotion_cmd[0]
self.vy = robot.locomotion_cmd[1]
self.yaw = robot.locomotion_cmd[2]
if hasattr(robot, 'groot_height_cmd'):
self.height = robot.groot_height_cmd
def _notify_callback(self):
"""Update robot's locomotion commands directly."""
if hasattr(self.robot, 'locomotion_cmd'):
self.robot.locomotion_cmd[0] = self.vx
self.robot.locomotion_cmd[1] = self.vy
self.robot.locomotion_cmd[2] = self.yaw
if hasattr(self.robot, 'groot_height_cmd'):
self.robot.groot_height_cmd = self.height
def start_keyboard_control_thread(robot) -> tuple:
"""
Start keyboard controls for a robot in a background thread.
Args:
robot: UnitreeG1 robot instance.
Returns:
Tuple of (controller, thread) for later stopping.
"""
import threading
controller = RobotKeyboardController(robot)
thread = threading.Thread(target=controller.run, daemon=True)
thread.start()
return controller, thread
def stop_keyboard_control_thread(controller, thread, timeout: float = 2.0):
"""
Stop the keyboard control thread.
Args:
controller: KeyboardController instance.
thread: Thread running the controller.
timeout: Max time to wait for thread to stop.
"""
controller.stop()
thread.join(timeout=timeout)
def main():
"""Standalone keyboard control with optional robot connection."""
parser = argparse.ArgumentParser(description="Keyboard control for Unitree G1")
parser.add_argument("--standalone", action="store_true",
help="Run in standalone mode (just print commands, no robot)")
args = parser.parse_args()
if args.standalone:
# Standalone mode - just demonstrate keyboard input
def print_callback(vx, vy, yaw, height):
print(f" → Would send: vx={vx:+.2f}, vy={vy:+.2f}, yaw={yaw:+.2f}, height={height:.3f}")
controller = KeyboardController(callback=print_callback)
print("Running in STANDALONE mode (no robot connection)")
controller.run()
else:
print("To use with a robot, import and use RobotKeyboardController:")
print("")
print(" from lerobot.robots.unitree_g1.keyboard_control import (")
print(" RobotKeyboardController,")
print(" start_keyboard_control_thread,")
print(" stop_keyboard_control_thread")
print(" )")
print("")
print(" # Start keyboard controls")
print(" controller, thread = start_keyboard_control_thread(robot)")
print("")
print(" # ... robot runs ...")
print("")
print(" # Stop keyboard controls")
print(" stop_keyboard_control_thread(controller, thread)")
print("")
print("Or run with --standalone to test keyboard input without a robot.")
if __name__ == "__main__":
main()
@@ -0,0 +1,212 @@
#!/usr/bin/env python3
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
DDS-to-ZMQ bridge server for Unitree G1 robot.
This server runs on the robot and forwards:
- Robot state (LowState) from DDS to ZMQ (for remote clients)
- Robot commands (LowCmd) from ZMQ to DDS (from remote clients)
Uses JSON for secure serialization instead of pickle.
"""
import base64
import contextlib
import json
import threading
import time
from typing import Any
import zmq
from unitree_sdk2py.comm.motion_switcher.motion_switcher_client import MotionSwitcherClient
from unitree_sdk2py.core.channel import ChannelFactoryInitialize, ChannelPublisher, ChannelSubscriber
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as hg_LowCmd, LowState_ as hg_LowState
from unitree_sdk2py.utils.crc import CRC
# DDS topic names follow Unitree SDK naming conventions
# ruff: noqa: N816
kTopicLowCommand_Debug = "rt/lowcmd" # action to robot
kTopicLowState = "rt/lowstate" # observation from robot
LOWCMD_PORT = 6000
LOWSTATE_PORT = 6001
NUM_MOTORS = 35
def lowstate_to_dict(msg: hg_LowState) -> dict[str, Any]:
"""Convert LowState SDK message to a JSON-serializable dictionary."""
motor_states = []
for i in range(NUM_MOTORS):
temp = msg.motor_state[i].temperature
avg_temp = float(sum(temp) / len(temp)) if isinstance(temp, list) else float(temp)
motor_states.append(
{
"q": float(msg.motor_state[i].q),
"dq": float(msg.motor_state[i].dq),
"tau_est": float(msg.motor_state[i].tau_est),
"temperature": avg_temp,
}
)
return {
"motor_state": motor_states,
"imu_state": {
"quaternion": [float(x) for x in msg.imu_state.quaternion],
"gyroscope": [float(x) for x in msg.imu_state.gyroscope],
"accelerometer": [float(x) for x in msg.imu_state.accelerometer],
"rpy": [float(x) for x in msg.imu_state.rpy],
"temperature": float(msg.imu_state.temperature),
},
# Encode bytes as base64 for JSON compatibility
"wireless_remote": base64.b64encode(bytes(msg.wireless_remote)).decode("ascii"),
"mode_machine": int(msg.mode_machine),
}
def dict_to_lowcmd(data: dict[str, Any]) -> hg_LowCmd:
"""Convert dictionary back to LowCmd SDK message."""
cmd = unitree_hg_msg_dds__LowCmd_()
cmd.mode_pr = data.get("mode_pr", 0)
cmd.mode_machine = data.get("mode_machine", 0)
for i, motor_data in enumerate(data.get("motor_cmd", [])):
cmd.motor_cmd[i].mode = motor_data.get("mode", 0)
cmd.motor_cmd[i].q = motor_data.get("q", 0.0)
cmd.motor_cmd[i].dq = motor_data.get("dq", 0.0)
cmd.motor_cmd[i].kp = motor_data.get("kp", 0.0)
cmd.motor_cmd[i].kd = motor_data.get("kd", 0.0)
cmd.motor_cmd[i].tau = motor_data.get("tau", 0.0)
return cmd
def state_forward_loop(
lowstate_sub: ChannelSubscriber,
lowstate_sock: zmq.Socket,
state_period: float,
shutdown_event: threading.Event,
) -> None:
"""Read observation from DDS and forward to ZMQ clients."""
last_state_time = 0.0
while not shutdown_event.is_set():
# read from DDS
msg = lowstate_sub.Read()
if msg is None:
continue
now = time.time()
# optional downsampling (if robot dds rate > state_period)
if now - last_state_time >= state_period:
# Convert to dict and serialize with JSON
state_dict = lowstate_to_dict(msg)
payload = json.dumps({"topic": kTopicLowState, "data": state_dict}).encode("utf-8")
# if no subscribers / tx buffer full, just drop
with contextlib.suppress(zmq.Again):
lowstate_sock.send(payload, zmq.NOBLOCK)
last_state_time = now
def cmd_forward_loop(
lowcmd_sock: zmq.Socket,
lowcmd_pub_debug: ChannelPublisher,
crc: CRC,
) -> None:
"""Receive commands from ZMQ and forward to DDS."""
while True:
try:
payload = lowcmd_sock.recv()
except zmq.ContextTerminated:
break
msg_dict = json.loads(payload.decode("utf-8"))
topic = msg_dict.get("topic", "")
cmd_data = msg_dict.get("data", {})
# Reconstruct LowCmd object from dict
cmd = dict_to_lowcmd(cmd_data)
# recompute crc
cmd.crc = crc.Crc(cmd)
if topic == kTopicLowCommand_Debug:
lowcmd_pub_debug.Write(cmd)
def main() -> None:
"""Main entry point for the robot server bridge."""
# initialize DDS
ChannelFactoryInitialize(0)
# stop all active publishers on the robot
msc = MotionSwitcherClient()
msc.SetTimeout(5.0)
msc.Init()
status, result = msc.CheckMode()
while result is not None and "name" in result and result["name"]:
msc.ReleaseMode()
status, result = msc.CheckMode()
time.sleep(1.0)
crc = CRC()
# initialize DDS publisher
lowcmd_pub_debug = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
lowcmd_pub_debug.Init()
# initialize DDS subscriber
lowstate_sub = ChannelSubscriber(kTopicLowState, hg_LowState)
lowstate_sub.Init()
# initialize ZMQ
ctx = zmq.Context.instance()
# receive commands from remote client
lowcmd_sock = ctx.socket(zmq.PULL)
lowcmd_sock.bind(f"tcp://0.0.0.0:{LOWCMD_PORT}")
# publish state to remote clients
lowstate_sock = ctx.socket(zmq.PUB)
lowstate_sock.bind(f"tcp://0.0.0.0:{LOWSTATE_PORT}")
state_period = 0.002 # ~500 hz
shutdown_event = threading.Event()
# start observation forwarding in background thread
t_state = threading.Thread(
target=state_forward_loop,
args=(lowstate_sub, lowstate_sock, state_period, shutdown_event),
)
t_state.start()
print("bridge running (lowstate -> zmq, lowcmd -> dds)")
# run command forwarding in main thread
try:
cmd_forward_loop(lowcmd_sock, lowcmd_pub_debug, crc)
except KeyboardInterrupt:
print("shutting down bridge...")
finally:
shutdown_event.set()
ctx.term() # terminates blocking zmq.recv() calls
t_state.join(timeout=2.0)
if __name__ == "__main__":
main()
+284
View File
@@ -0,0 +1,284 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import struct
import threading
import time
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any
import numpy as np
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import (
LowCmd_ as hg_LowCmd,
LowState_ as hg_LowState,
)
from unitree_sdk2py.utils.crc import CRC
from lerobot.envs.factory import make_env
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
from ..robot import Robot
from .config_unitree_g1 import UnitreeG1Config
logger = logging.getLogger(__name__)
# DDS topic names follow Unitree SDK naming conventions
# ruff: noqa: N816
kTopicLowCommand_Debug = "rt/lowcmd"
kTopicLowState = "rt/lowstate"
G1_29_Num_Motors = 35
G1_23_Num_Motors = 35
H1_2_Num_Motors = 35
H1_Num_Motors = 20
@dataclass
class MotorState:
q: float | None = None # position
dq: float | None = None # velocity
tau_est: float | None = None # estimated torque
temperature: float | None = None # motor temperature
@dataclass
class IMUState:
quaternion: np.ndarray | None = None # [w, x, y, z]
gyroscope: np.ndarray | None = None # [x, y, z] angular velocity (rad/s)
accelerometer: np.ndarray | None = None # [x, y, z] linear acceleration (m/s²)
rpy: np.ndarray | None = None # [roll, pitch, yaw] (rad)
temperature: float | None = None # IMU temperature
# g1 observation class
@dataclass
class G1_29_LowState: # noqa: N801
motor_state: list[MotorState] = field(
default_factory=lambda: [MotorState() for _ in range(G1_29_Num_Motors)]
)
imu_state: IMUState = field(default_factory=IMUState)
wireless_remote: Any = None # Raw wireless remote data
mode_machine: int = 0 # Robot mode
class DataBuffer:
def __init__(self):
self.data = None
self.lock = threading.Lock()
def get_data(self):
with self.lock:
return self.data
def set_data(self, data):
with self.lock:
self.data = data
class UnitreeG1(Robot):
config_class = UnitreeG1Config
name = "unitree_g1"
# unitree remote controller
class RemoteController:
def __init__(self):
self.lx = 0
self.ly = 0
self.rx = 0
self.ry = 0
self.button = [0] * 16
def set(self, data):
# wireless_remote
keys = struct.unpack("H", data[2:4])[0]
for i in range(16):
self.button[i] = (keys & (1 << i)) >> i
self.lx = struct.unpack("f", data[4:8])[0]
self.rx = struct.unpack("f", data[8:12])[0]
self.ry = struct.unpack("f", data[12:16])[0]
self.ly = struct.unpack("f", data[20:24])[0]
def __init__(self, config: UnitreeG1Config):
super().__init__(config)
logger.info("Initialize UnitreeG1...")
self.config = config
self.control_dt = config.control_dt
if config.is_simulation:
from unitree_sdk2py.core.channel import (
ChannelFactoryInitialize,
ChannelPublisher,
ChannelSubscriber,
)
else:
from lerobot.robots.unitree_g1.unitree_sdk2_socket import (
ChannelFactoryInitialize,
ChannelPublisher,
ChannelSubscriber,
)
# connect robot
self.ChannelFactoryInitialize = ChannelFactoryInitialize
self.connect()
# initialize direct motor control interface
self.lowcmd_publisher = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
self.lowcmd_publisher.Init()
self.lowstate_subscriber = ChannelSubscriber(kTopicLowState, hg_LowState)
self.lowstate_subscriber.Init()
self.lowstate_buffer = DataBuffer()
# initialize subscribe thread to read robot state
self._shutdown_event = threading.Event()
self.subscribe_thread = threading.Thread(target=self._subscribe_motor_state)
self.subscribe_thread.start()
while not self.is_connected:
time.sleep(0.1)
# initialize hg's lowcmd msg
self.crc = CRC()
self.msg = unitree_hg_msg_dds__LowCmd_()
self.msg.mode_pr = 0
# Wait for first state message to arrive
lowstate = None
while lowstate is None:
lowstate = self.lowstate_buffer.get_data()
if lowstate is None:
time.sleep(0.01)
logger.warning("[UnitreeG1] Waiting for robot state...")
logger.warning("[UnitreeG1] Connected to robot.")
self.msg.mode_machine = lowstate.mode_machine
# initialize all motors with unified kp/kd from config
self.kp = np.array(config.kp, dtype=np.float32)
self.kd = np.array(config.kd, dtype=np.float32)
for id in G1_29_JointIndex:
self.msg.motor_cmd[id].mode = 1
self.msg.motor_cmd[id].kp = self.kp[id.value]
self.msg.motor_cmd[id].kd = self.kd[id.value]
self.msg.motor_cmd[id].q = lowstate.motor_state[id.value].q
# Initialize remote controller
self.remote_controller = self.RemoteController()
def _subscribe_motor_state(self): # polls robot state @ 250Hz
while not self._shutdown_event.is_set():
start_time = time.time()
msg = self.lowstate_subscriber.Read()
if msg is not None:
lowstate = G1_29_LowState()
# Capture motor states
for id in range(G1_29_Num_Motors):
lowstate.motor_state[id].q = msg.motor_state[id].q
lowstate.motor_state[id].dq = msg.motor_state[id].dq
lowstate.motor_state[id].tau_est = msg.motor_state[id].tau_est
lowstate.motor_state[id].temperature = msg.motor_state[id].temperature
# Capture IMU state
lowstate.imu_state.quaternion = list(msg.imu_state.quaternion)
lowstate.imu_state.gyroscope = list(msg.imu_state.gyroscope)
lowstate.imu_state.accelerometer = list(msg.imu_state.accelerometer)
lowstate.imu_state.rpy = list(msg.imu_state.rpy)
lowstate.imu_state.temperature = msg.imu_state.temperature
# Capture wireless remote data
lowstate.wireless_remote = msg.wireless_remote
# Capture mode_machine
lowstate.mode_machine = msg.mode_machine
self.lowstate_buffer.set_data(lowstate)
current_time = time.time()
all_t_elapsed = current_time - start_time
sleep_time = max(0, (self.control_dt - all_t_elapsed)) # maintain constant control dt
time.sleep(sleep_time)
@cached_property
def action_features(self) -> dict[str, type]:
return {f"{G1_29_JointIndex(motor).name}.pos": float for motor in G1_29_JointIndex}
def calibrate(self) -> None: # robot is already calibrated
pass
def configure(self) -> None:
pass
def connect(self, calibrate: bool = True) -> None: # connect to DDS
if self.config.is_simulation:
self.ChannelFactoryInitialize(0, "lo")
self.mujoco_env = make_env("lerobot/unitree-g1-mujoco", trust_remote_code=True)
else:
self.ChannelFactoryInitialize(0)
def disconnect(self):
self._shutdown_event.set()
self.subscribe_thread.join(timeout=2.0)
if self.config.is_simulation:
self.mujoco_env["hub_env"][0].envs[0].kill_sim()
def get_observation(self) -> dict[str, Any]:
return self.lowstate_buffer.get_data()
@property
def is_calibrated(self) -> bool:
return True
@property
def is_connected(self) -> bool:
return self.lowstate_buffer.get_data() is not None
@property
def _motors_ft(self) -> dict[str, type]:
return {f"{G1_29_JointIndex(motor).name}.pos": float for motor in G1_29_JointIndex}
@property
def _cameras_ft(self) -> dict[str, tuple]:
return {
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
return {**self._motors_ft, **self._cameras_ft}
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
self.msg.crc = self.crc.Crc(action)
self.lowcmd_publisher.Write(action)
return action
def get_gravity_orientation(self, quaternion): # get gravity orientation from quaternion
"""Get gravity orientation from quaternion."""
qw = quaternion[0]
qx = quaternion[1]
qy = quaternion[2]
qz = quaternion[3]
gravity_orientation = np.zeros(3)
gravity_orientation[0] = 2 * (-qz * qx + qw * qy)
gravity_orientation[1] = -2 * (qz * qy + qw * qx)
gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz)
return gravity_orientation
@@ -0,0 +1,168 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import json
from typing import Any
import zmq
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
_ctx: zmq.Context | None = None
_lowcmd_sock: zmq.Socket | None = None
_lowstate_sock: zmq.Socket | None = None
LOWCMD_PORT = 6000
LOWSTATE_PORT = 6001
# DDS topic names follow Unitree SDK naming conventions
# ruff: noqa: N816
kTopicLowCommand_Debug = "rt/lowcmd"
class LowStateMsg:
"""
Wrapper class that mimics the Unitree SDK LowState_ message structure.
Reconstructs the message from deserialized JSON data to maintain
compatibility with existing code that expects SDK message objects.
"""
class MotorState:
"""Motor state data for a single joint."""
def __init__(self, data: dict[str, Any]) -> None:
self.q: float = data.get("q", 0.0)
self.dq: float = data.get("dq", 0.0)
self.tau_est: float = data.get("tau_est", 0.0)
self.temperature: float = data.get("temperature", 0.0)
class IMUState:
"""IMU sensor data."""
def __init__(self, data: dict[str, Any]) -> None:
self.quaternion: list[float] = data.get("quaternion", [1.0, 0.0, 0.0, 0.0])
self.gyroscope: list[float] = data.get("gyroscope", [0.0, 0.0, 0.0])
self.accelerometer: list[float] = data.get("accelerometer", [0.0, 0.0, 0.0])
self.rpy: list[float] = data.get("rpy", [0.0, 0.0, 0.0])
self.temperature: float = data.get("temperature", 0.0)
def __init__(self, data: dict[str, Any]) -> None:
"""Initialize from deserialized JSON data."""
self.motor_state = [self.MotorState(m) for m in data.get("motor_state", [])]
self.imu_state = self.IMUState(data.get("imu_state", {}))
# Decode base64-encoded wireless_remote bytes
wireless_b64 = data.get("wireless_remote", "")
self.wireless_remote: bytes = base64.b64decode(wireless_b64) if wireless_b64 else b""
self.mode_machine: int = data.get("mode_machine", 0)
def lowcmd_to_dict(topic: str, msg: Any) -> dict[str, Any]:
"""Convert LowCmd message to a JSON-serializable dictionary."""
motor_cmds = []
# Iterate over all motor commands in the message
for i in range(len(msg.motor_cmd)):
motor_cmds.append(
{
"mode": int(msg.motor_cmd[i].mode),
"q": float(msg.motor_cmd[i].q),
"dq": float(msg.motor_cmd[i].dq),
"kp": float(msg.motor_cmd[i].kp),
"kd": float(msg.motor_cmd[i].kd),
"tau": float(msg.motor_cmd[i].tau),
}
)
return {
"topic": topic,
"data": {
"mode_pr": int(msg.mode_pr),
"mode_machine": int(msg.mode_machine),
"motor_cmd": motor_cmds,
},
}
def ChannelFactoryInitialize(*args: Any, **kwargs: Any) -> None: # noqa: N802
"""
Initialize ZMQ sockets for robot communication.
This function mimics the Unitree SDK's ChannelFactoryInitialize but uses
ZMQ sockets to connect to the robot server bridge instead of DDS.
"""
global _ctx, _lowcmd_sock, _lowstate_sock
# read socket config
config = UnitreeG1Config()
robot_ip = config.robot_ip
ctx = zmq.Context.instance()
_ctx = ctx
# lowcmd: send robot commands
lowcmd_sock = ctx.socket(zmq.PUSH)
lowcmd_sock.setsockopt(zmq.CONFLATE, 1) # keep only last message
lowcmd_sock.connect(f"tcp://{robot_ip}:{LOWCMD_PORT}")
_lowcmd_sock = lowcmd_sock
# lowstate: receive robot observations
lowstate_sock = ctx.socket(zmq.SUB)
lowstate_sock.setsockopt(zmq.CONFLATE, 1) # keep only last message
lowstate_sock.connect(f"tcp://{robot_ip}:{LOWSTATE_PORT}")
lowstate_sock.setsockopt_string(zmq.SUBSCRIBE, "")
_lowstate_sock = lowstate_sock
class ChannelPublisher:
"""ZMQ-based publisher that sends commands to the robot server."""
def __init__(self, topic: str, msg_type: type) -> None:
self.topic = topic
self.msg_type = msg_type
def Init(self) -> None: # noqa: N802
"""Initialize the publisher (no-op for ZMQ)."""
pass
def Write(self, msg: Any) -> None: # noqa: N802
"""Serialize and send a command message to the robot."""
if _lowcmd_sock is None:
raise RuntimeError("ChannelFactoryInitialize must be called first")
payload = json.dumps(lowcmd_to_dict(self.topic, msg)).encode("utf-8")
_lowcmd_sock.send(payload)
class ChannelSubscriber:
"""ZMQ-based subscriber that receives state from the robot server."""
def __init__(self, topic: str, msg_type: type) -> None:
self.topic = topic
self.msg_type = msg_type
def Init(self) -> None: # noqa: N802
"""Initialize the subscriber (no-op for ZMQ)."""
pass
def Read(self) -> LowStateMsg: # noqa: N802
"""Receive and deserialize a state message from the robot."""
if _lowstate_sock is None:
raise RuntimeError("ChannelFactoryInitialize must be called first")
payload = _lowstate_sock.recv()
msg_dict = json.loads(payload.decode("utf-8"))
return LowStateMsg(msg_dict.get("data", {}))
+4
View File
@@ -28,6 +28,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
from .koch_follower import KochFollower
return KochFollower(config)
elif config.type == "omx_follower":
from .omx_follower import OmxFollower
return OmxFollower(config)
elif config.type == "so100_follower":
from .so100_follower import SO100Follower
+4 -2
View File
@@ -40,6 +40,7 @@ from lerobot.robots import ( # noqa: F401
koch_follower,
lekiwi,
make_robot_from_config,
omx_follower,
so100_follower,
so101_follower,
)
@@ -49,10 +50,11 @@ from lerobot.teleoperators import ( # noqa: F401
homunculus,
koch_leader,
make_teleoperator_from_config,
omx_leader,
so100_leader,
so101_leader,
)
from lerobot.utils.import_utils import register_third_party_devices
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.utils import init_logging
@@ -84,7 +86,7 @@ def calibrate(cfg: CalibrateConfig):
def main():
register_third_party_devices()
register_third_party_plugins()
calibrate()
@@ -65,7 +65,6 @@ import argparse
import gc
import logging
import time
from collections.abc import Iterator
from pathlib import Path
import numpy as np
@@ -78,19 +77,6 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
class EpisodeSampler(torch.utils.data.Sampler):
def __init__(self, dataset: LeRobotDataset, episode_index: int):
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
self.frame_ids = range(from_idx, to_idx)
def __iter__(self) -> Iterator:
return iter(self.frame_ids)
def __len__(self) -> int:
return len(self.frame_ids)
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
assert chw_float32_torch.dtype == torch.float32
assert chw_float32_torch.ndim == 3
@@ -119,12 +105,10 @@ def visualize_dataset(
repo_id = dataset.repo_id
logging.info("Loading dataloader")
episode_sampler = EpisodeSampler(dataset, episode_index)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=num_workers,
batch_size=batch_size,
sampler=episode_sampler,
)
logging.info("Starting Rerun")
+455 -5
View File
@@ -18,7 +18,8 @@
Edit LeRobot datasets using various transformation tools.
This script allows you to delete episodes, split datasets, merge datasets,
and remove features. When new_repo_id is specified, creates a new dataset.
remove features, and convert image datasets to video format.
When new_repo_id is specified, creates a new dataset.
Usage Examples:
@@ -65,6 +66,25 @@ Remove camera feature:
--operation.type remove_feature \
--operation.feature_names "['observation.images.top']"
Convert image dataset to video format (saves locally):
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_to_video \
--operation.output_dir /path/to/output/pusht_video
Convert image dataset and save with new repo_id:
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht_image \
--new_repo_id lerobot/pusht_video \
--operation.type convert_to_video
Convert and push to hub:
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht_image \
--new_repo_id lerobot/pusht_video \
--operation.type convert_to_video \
--push_to_hub true
Using JSON config file:
python -m lerobot.scripts.lerobot_edit_dataset \
--config_path path/to/edit_config.json
@@ -72,9 +92,13 @@ Using JSON config file:
import logging
import shutil
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
import pandas as pd
from tqdm import tqdm
from lerobot.configs import parser
from lerobot.datasets.dataset_tools import (
delete_episodes,
@@ -82,8 +106,10 @@ from lerobot.datasets.dataset_tools import (
remove_feature,
split_dataset,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import HF_LEROBOT_HOME
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import write_stats, write_tasks
from lerobot.datasets.video_utils import encode_video_frames, get_video_info
from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE
from lerobot.utils.utils import init_logging
@@ -111,10 +137,23 @@ class RemoveFeatureConfig:
feature_names: list[str] | None = None
@dataclass
class ConvertToVideoConfig:
type: str = "convert_to_video"
output_dir: str | None = None
vcodec: str = "libsvtav1"
pix_fmt: str = "yuv420p"
g: int = 2
crf: int = 30
fast_decode: int = 0
episode_indices: list[int] | None = None
num_workers: int = 4
@dataclass
class EditDatasetConfig:
repo_id: str
operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig
operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertToVideoConfig
root: str | None = None
new_repo_id: str | None = None
push_to_hub: bool = False
@@ -258,6 +297,415 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
def save_episode_images_for_video(
dataset: LeRobotDataset,
imgs_dir: Path,
img_key: str,
episode_index: int,
num_workers: int = 4,
) -> None:
"""Save images from a specific episode and camera to disk for video encoding.
Args:
dataset: The LeRobot dataset to extract images from
imgs_dir: Directory to save images to
img_key: The image key (camera) to extract
episode_index: Index of the episode to save
num_workers: Number of threads for parallel image saving
"""
# Create directory
imgs_dir.mkdir(parents=True, exist_ok=True)
# Get dataset without torch format for PIL image access
hf_dataset = dataset.hf_dataset.with_format(None)
# Select only this camera's images
imgs_dataset = hf_dataset.select_columns(img_key)
# Get episode start and end indices
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
# Get all items for this episode
episode_dataset = imgs_dataset.select(range(from_idx, to_idx))
# Define function to save a single image
def save_single_image(i_item_tuple):
i, item = i_item_tuple
img = item[img_key]
# Use frame-XXXXXX.png format to match encode_video_frames expectations
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
return i
# Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png)
items = list(enumerate(episode_dataset))
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = [executor.submit(save_single_image, item) for item in items]
for future in as_completed(futures):
future.result() # This will raise any exceptions that occurred
def encode_episode_videos(
dataset: LeRobotDataset,
new_meta: LeRobotDatasetMetadata,
episode_index: int,
vcodec: str,
pix_fmt: str,
g: int,
crf: int,
fast_decode: int,
temp_dir: Path,
num_image_workers: int = 4,
) -> dict[str, dict]:
"""Encode videos for a single episode and return video metadata.
Args:
dataset: Source dataset with images
new_meta: Metadata object for the new video dataset
episode_index: Episode index to process
vcodec: Video codec
pix_fmt: Pixel format
g: Group of pictures size
crf: Constant rate factor
fast_decode: Fast decode tuning
temp_dir: Temporary directory for images
num_image_workers: Number of workers for saving images
Returns:
Dictionary mapping video keys to their metadata (chunk_index, file_index, timestamps)
"""
hf_dataset = dataset.hf_dataset.with_format(None)
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
video_metadata = {}
fps = int(dataset.fps) # Convert to int for PyAV compatibility
episode_length = dataset.meta.episodes["length"][episode_index]
episode_duration = episode_length / dataset.fps # Use original fps for duration calculation
for img_key in img_keys:
# Save images temporarily
imgs_dir = temp_dir / f"episode_{episode_index:06d}" / img_key
save_episode_images_for_video(dataset, imgs_dir, img_key, episode_index, num_image_workers)
# Determine chunk and file indices
# For simplicity, we'll put each episode in its own file
chunk_idx = episode_index // new_meta.chunks_size
file_idx = episode_index % new_meta.chunks_size
# Create video path in the new dataset structure
video_path = new_meta.root / new_meta.video_path.format(
video_key=img_key, chunk_index=chunk_idx, file_index=file_idx
)
video_path.parent.mkdir(parents=True, exist_ok=True)
# Encode video
encode_video_frames(
imgs_dir=imgs_dir,
video_path=video_path,
fps=fps,
vcodec=vcodec,
pix_fmt=pix_fmt,
g=g,
crf=crf,
fast_decode=fast_decode,
overwrite=True,
)
# Clean up temporary images
shutil.rmtree(imgs_dir)
# Store video metadata
video_metadata[img_key] = {
f"videos/{img_key}/chunk_index": chunk_idx,
f"videos/{img_key}/file_index": file_idx,
f"videos/{img_key}/from_timestamp": 0.0,
f"videos/{img_key}/to_timestamp": episode_duration,
}
return video_metadata
def convert_dataset_to_videos(
dataset: LeRobotDataset,
output_dir: Path,
repo_id: str | None = None,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
g: int = 2,
crf: int = 30,
fast_decode: int = 0,
episode_indices: list[int] | None = None,
num_workers: int = 4,
) -> LeRobotDataset:
"""Convert image-based dataset to video-based dataset.
Creates a new LeRobotDataset with videos instead of images, following the proper
LeRobot dataset structure with videos stored in chunked MP4 files.
Args:
dataset: The source LeRobot dataset with images
output_dir: Directory to save the new video dataset
repo_id: Repository ID for the new dataset (default: original_id + "_video")
vcodec: Video codec (default: libsvtav1)
pix_fmt: Pixel format (default: yuv420p)
g: Group of pictures size (default: 2)
crf: Constant rate factor (default: 30)
fast_decode: Fast decode tuning (default: 0)
episode_indices: List of episode indices to convert (None = all episodes)
num_workers: Number of threads for parallel processing (default: 4)
Returns:
New LeRobotDataset with videos
"""
# Check that it's an image dataset
if len(dataset.meta.video_keys) > 0:
raise ValueError(
f"This operation is for image datasets only. Video dataset provided: {dataset.repo_id}"
)
# Get all image keys
hf_dataset = dataset.hf_dataset.with_format(None)
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
if len(img_keys) == 0:
raise ValueError(f"No image keys found in dataset {dataset.repo_id}")
# Determine which episodes to process
if episode_indices is None:
episode_indices = list(range(dataset.meta.total_episodes))
if repo_id is None:
repo_id = f"{dataset.repo_id}_video"
logging.info(
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
)
logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}")
# Create new features dict, converting image features to video features
new_features = {}
for key, value in dataset.meta.features.items():
if key not in img_keys:
new_features[key] = value
else:
# Convert image key to video format
new_features[key] = value.copy()
new_features[key]["dtype"] = "video" # Change dtype from "image" to "video"
# Video info will be updated after episodes are encoded
# Create new metadata for video dataset
new_meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
fps=dataset.meta.fps,
features=new_features,
robot_type=dataset.meta.robot_type,
root=output_dir,
use_videos=True,
chunks_size=dataset.meta.chunks_size,
data_files_size_in_mb=dataset.meta.data_files_size_in_mb,
video_files_size_in_mb=dataset.meta.video_files_size_in_mb,
)
# Create temporary directory for image extraction
temp_dir = output_dir / "temp_images"
temp_dir.mkdir(parents=True, exist_ok=True)
# Process each episode
all_episode_metadata = []
try:
for ep_idx in tqdm(episode_indices, desc="Converting episodes to videos"):
# Get episode metadata from source
src_episode = dataset.meta.episodes[ep_idx]
# Encode videos for this episode
video_metadata = encode_episode_videos(
dataset=dataset,
new_meta=new_meta,
episode_index=ep_idx,
vcodec=vcodec,
pix_fmt=pix_fmt,
g=g,
crf=crf,
fast_decode=fast_decode,
temp_dir=temp_dir,
num_image_workers=num_workers,
)
# Build episode metadata
episode_meta = {
"episode_index": ep_idx,
"length": src_episode["length"],
"dataset_from_index": ep_idx * src_episode["length"],
"dataset_to_index": (ep_idx + 1) * src_episode["length"],
}
# Add video metadata
for img_key in img_keys:
episode_meta.update(video_metadata[img_key])
# Add data chunk/file info (using same structure as source)
if "data/chunk_index" in src_episode:
episode_meta["data/chunk_index"] = src_episode["data/chunk_index"]
episode_meta["data/file_index"] = src_episode["data/file_index"]
all_episode_metadata.append(episode_meta)
# Copy and transform data files (removing image columns)
_copy_data_without_images(dataset, new_meta, episode_indices, img_keys)
# Save episode metadata
episodes_df = pd.DataFrame(all_episode_metadata)
episodes_path = new_meta.root / "meta" / "episodes" / "chunk-000" / "file-000.parquet"
episodes_path.parent.mkdir(parents=True, exist_ok=True)
episodes_df.to_parquet(episodes_path, index=False)
# Update metadata info
new_meta.info["total_episodes"] = len(episode_indices)
new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata)
new_meta.info["total_tasks"] = dataset.meta.total_tasks
new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"}
# Update video info for all image keys (now videos)
# We need to manually set video info since update_video_info() checks video_keys first
for img_key in img_keys:
if not new_meta.features[img_key].get("info", None):
video_path = new_meta.root / new_meta.video_path.format(
video_key=img_key, chunk_index=0, file_index=0
)
new_meta.info["features"][img_key]["info"] = get_video_info(video_path)
from lerobot.datasets.utils import write_info
write_info(new_meta.info, new_meta.root)
# Copy stats and tasks
if dataset.meta.stats is not None:
# Remove image stats
new_stats = {k: v for k, v in dataset.meta.stats.items() if k not in img_keys}
write_stats(new_stats, new_meta.root)
if dataset.meta.tasks is not None:
write_tasks(dataset.meta.tasks, new_meta.root)
finally:
# Clean up temporary directory
if temp_dir.exists():
shutil.rmtree(temp_dir)
logging.info(f"✓ Completed converting {dataset.repo_id} to video format")
logging.info(f"New dataset saved to: {output_dir}")
# Return new dataset
return LeRobotDataset(repo_id=repo_id, root=output_dir)
def _copy_data_without_images(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
episode_indices: list[int],
img_keys: list[str],
) -> None:
"""Copy data files without image columns.
Args:
src_dataset: Source dataset
dst_meta: Destination metadata
episode_indices: Episodes to include
img_keys: Image keys to remove
"""
from lerobot.datasets.utils import DATA_DIR
data_dir = src_dataset.root / DATA_DIR
parquet_files = sorted(data_dir.glob("*/*.parquet"))
if not parquet_files:
raise ValueError(f"No parquet files found in {data_dir}")
episode_set = set(episode_indices)
for src_path in tqdm(parquet_files, desc="Processing data files"):
df = pd.read_parquet(src_path).reset_index(drop=True)
# Filter to only include selected episodes
df = df[df["episode_index"].isin(episode_set)].copy()
if len(df) == 0:
continue
# Remove image columns
columns_to_drop = [col for col in img_keys if col in df.columns]
if columns_to_drop:
df = df.drop(columns=columns_to_drop)
# Get chunk and file indices from path
relative_path = src_path.relative_to(src_dataset.root)
chunk_dir = relative_path.parts[1]
file_name = relative_path.parts[2]
chunk_idx = int(chunk_dir.split("-")[1])
file_idx = int(file_name.split("-")[1].split(".")[0])
# Write to destination without pandas index
dst_path = dst_meta.root / f"data/chunk-{chunk_idx:03d}/file-{file_idx:03d}.parquet"
dst_path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(dst_path, index=False)
def handle_convert_to_video(cfg: EditDatasetConfig) -> None:
# Note: Parser may create any config type with the right fields, so we access fields directly
# instead of checking isinstance()
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
# Determine output directory and repo_id
# Priority: 1) new_repo_id, 2) operation.output_dir, 3) auto-generated name
output_dir_config = getattr(cfg.operation, "output_dir", None)
if cfg.new_repo_id:
# Use new_repo_id for both local storage and hub push
output_repo_id = cfg.new_repo_id
output_dir = Path(cfg.root) / cfg.new_repo_id if cfg.root else HF_LEROBOT_HOME / cfg.new_repo_id
logging.info(f"Saving to new dataset: {cfg.new_repo_id}")
elif output_dir_config:
# Use custom output directory for local-only storage
output_dir = Path(output_dir_config)
# Extract repo name from output_dir for the dataset
output_repo_id = output_dir.name
logging.info(f"Saving to local directory: {output_dir}")
else:
# Auto-generate name: append "_video" to original repo_id
output_repo_id = f"{cfg.repo_id}_video"
output_dir = Path(cfg.root) / output_repo_id if cfg.root else HF_LEROBOT_HOME / output_repo_id
logging.info(f"Saving to auto-generated location: {output_dir}")
logging.info(f"Converting dataset {cfg.repo_id} to video format")
new_dataset = convert_dataset_to_videos(
dataset=dataset,
output_dir=output_dir,
repo_id=output_repo_id,
vcodec=getattr(cfg.operation, "vcodec", "libsvtav1"),
pix_fmt=getattr(cfg.operation, "pix_fmt", "yuv420p"),
g=getattr(cfg.operation, "g", 2),
crf=getattr(cfg.operation, "crf", 30),
fast_decode=getattr(cfg.operation, "fast_decode", 0),
episode_indices=getattr(cfg.operation, "episode_indices", None),
num_workers=getattr(cfg.operation, "num_workers", 4),
)
logging.info("Video dataset created successfully!")
logging.info(f"Location: {output_dir}")
logging.info(f"Episodes: {new_dataset.meta.total_episodes}")
logging.info(f"Frames: {new_dataset.meta.total_frames}")
if cfg.push_to_hub:
logging.info(f"Pushing to hub as {output_repo_id}...")
new_dataset.push_to_hub()
logging.info("✓ Successfully pushed to hub!")
else:
logging.info("Dataset saved locally (not pushed to hub)")
@parser.wrap()
def edit_dataset(cfg: EditDatasetConfig) -> None:
operation_type = cfg.operation.type
@@ -270,10 +718,12 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
handle_merge(cfg)
elif operation_type == "remove_feature":
handle_remove_feature(cfg)
elif operation_type == "convert_to_video":
handle_convert_to_video(cfg)
else:
raise ValueError(
f"Unknown operation type: {operation_type}\n"
f"Available operations: delete_episodes, split, merge, remove_feature"
f"Available operations: delete_episodes, split, merge, remove_feature, convert_to_video"
)
+3 -1
View File
@@ -82,6 +82,7 @@ from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import (
@@ -533,7 +534,7 @@ def eval_main(cfg: EvalPipelineConfig):
)
# Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env, policy_cfg=cfg.policy)
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy_all(
@@ -792,6 +793,7 @@ def eval_policy_all(
def main():
init_logging()
register_third_party_plugins()
eval_main()
+135 -43
View File
@@ -15,18 +15,23 @@
# limitations under the License.
"""
Simple script to control a robot from teleoperation.
Script to find joint limits and end-effector bounds via teleoperation.
Example:
```shell
lerobot-find-joint-limits \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.port=/dev/tty.usbmodem58760432981 \
--robot.id=black \
--teleop.type=so100_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \
--teleop.id=blue
--teleop.port=/dev/tty.usbmodem58760434471 \
--teleop.id=blue \
--urdf_path=<user>/SO-ARM100-main/Simulation/SO101/so101_new_calib.urdf \
--target_frame_name=gripper \
--teleop_time_s=30 \
--warmup_time_s=5 \
--control_loop_fps=30
```
"""
@@ -41,14 +46,18 @@ from lerobot.robots import ( # noqa: F401
RobotConfig,
koch_follower,
make_robot_from_config,
omx_follower,
so100_follower,
so101_follower,
)
from lerobot.teleoperators import ( # noqa: F401
TeleoperatorConfig,
gamepad,
koch_leader,
make_teleoperator_from_config,
omx_leader,
so100_leader,
so101_leader,
)
from lerobot.utils.robot_utils import precise_sleep
@@ -57,10 +66,19 @@ from lerobot.utils.robot_utils import precise_sleep
class FindJointLimitsConfig:
teleop: TeleoperatorConfig
robot: RobotConfig
# Limit the maximum frames per second. By default, no limit.
# Path to URDF file for kinematics
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
urdf_path: str
target_frame_name: str = "gripper"
# Duration of the recording phase in seconds
teleop_time_s: float = 30
# Display all cameras on screen
display_data: bool = False
# Duration of the warmup phase in seconds
warmup_time_s: float = 5
# Control loop frequency
control_loop_fps: int = 30
@draccus.wrap()
@@ -68,53 +86,127 @@ def find_joint_and_ee_bounds(cfg: FindJointLimitsConfig):
teleop = make_teleoperator_from_config(cfg.teleop)
robot = make_robot_from_config(cfg.robot)
print(f"Connecting to robot: {cfg.robot.type}...")
teleop.connect()
robot.connect()
print("Devices connected.")
start_episode_t = time.perf_counter()
robot_type = getattr(robot.config, "robot_type", "so101")
if "so100" in robot_type or "so101" in robot_type:
# Note to be compatible with the rest of the codebase,
# we are using the new calibration method for so101 and so100
robot_type = "so_new_calibration"
kinematics = RobotKinematics(cfg.robot.urdf_path, cfg.robot.target_frame_name)
# Initialize Kinematics
try:
kinematics = RobotKinematics(cfg.urdf_path, cfg.target_frame_name)
except Exception as e:
print(f"Error initializing kinematics: {e}")
print("Ensure URDF path and target frame name are correct.")
robot.disconnect()
teleop.disconnect()
return
# Initialize min/max values
observation = robot.get_observation()
joint_positions = np.array([observation[f"{key}.pos"] for key in robot.bus.motors])
ee_pos = kinematics.forward_kinematics(joint_positions)[:3, 3]
# Initialize variables
max_pos = None
min_pos = None
max_ee = None
min_ee = None
max_pos = joint_positions.copy()
min_pos = joint_positions.copy()
max_ee = ee_pos.copy()
min_ee = ee_pos.copy()
start_t = time.perf_counter()
warmup_done = False
while True:
action = teleop.get_action()
robot.send_action(action)
print("\n" + "=" * 40)
print(f" WARMUP PHASE ({cfg.warmup_time_s}s)")
print(" Move the robot freely to ensure control works.")
print(" Data is NOT being recorded yet.")
print("=" * 40 + "\n")
observation = robot.get_observation()
joint_positions = np.array([observation[f"{key}.pos"] for key in robot.bus.motors])
ee_pos = kinematics.forward_kinematics(joint_positions)[:3, 3]
try:
while True:
t0 = time.perf_counter()
# Skip initial warmup period
if (time.perf_counter() - start_episode_t) < 5:
continue
# 1. Teleoperation Control Loop
action = teleop.get_action()
robot.send_action(action)
# Update min/max values
max_ee = np.maximum(max_ee, ee_pos)
min_ee = np.minimum(min_ee, ee_pos)
max_pos = np.maximum(max_pos, joint_positions)
min_pos = np.minimum(min_pos, joint_positions)
# 2. Read Observations
observation = robot.get_observation()
joint_positions = np.array([observation[f"{key}.pos"] for key in robot.bus.motors])
if time.perf_counter() - start_episode_t > cfg.teleop_time_s:
print(f"Max ee position {np.round(max_ee, 4).tolist()}")
print(f"Min ee position {np.round(min_ee, 4).tolist()}")
print(f"Max joint pos position {np.round(max_pos, 4).tolist()}")
print(f"Min joint pos position {np.round(min_pos, 4).tolist()}")
break
# 3. Calculate Kinematics
# Forward kinematics to get (x, y, z) translation
ee_pos = kinematics.forward_kinematics(joint_positions)[:3, 3]
precise_sleep(0.01)
current_time = time.perf_counter()
elapsed = current_time - start_t
# 4. Handle Phases
if elapsed < cfg.warmup_time_s:
# Still in warmup
pass
else:
# Phase Transition: Warmup -> Recording
if not warmup_done:
print("\n" + "=" * 40)
print(" RECORDING STARTED")
print(" Move robot to ALL joint limits.")
print(" Press Ctrl+C to stop early and save results.")
print("=" * 40 + "\n")
# Initialize limits with current position at start of recording
max_pos = joint_positions.copy()
min_pos = joint_positions.copy()
max_ee = ee_pos.copy()
min_ee = ee_pos.copy()
warmup_done = True
# Update Limits
max_ee = np.maximum(max_ee, ee_pos)
min_ee = np.minimum(min_ee, ee_pos)
max_pos = np.maximum(max_pos, joint_positions)
min_pos = np.minimum(min_pos, joint_positions)
# Time check
recording_time = elapsed - cfg.warmup_time_s
remaining = cfg.teleop_time_s - recording_time
# Simple throttle for print statements (every ~1 sec)
if int(recording_time * 100) % 100 == 0:
print(f"Time remaining: {remaining:.1f}s", end="\r")
if recording_time > cfg.teleop_time_s:
print("\nTime limit reached.")
break
precise_sleep(max(1.0 / cfg.control_loop_fps - (time.perf_counter() - t0), 0.0))
except KeyboardInterrupt:
print("\n\nInterrupted by user. Stopping safely...")
finally:
# Safety: Disconnect devices
print("\nDisconnecting devices...")
robot.disconnect()
teleop.disconnect()
# Results Output
if max_pos is not None:
print("\n" + "=" * 40)
print("FINAL RESULTS")
print("=" * 40)
# Rounding for readability
r_max_ee = np.round(max_ee, 4).tolist()
r_min_ee = np.round(min_ee, 4).tolist()
r_max_pos = np.round(max_pos, 4).tolist()
r_min_pos = np.round(min_pos, 4).tolist()
print("\n# End Effector Bounds (x, y, z):")
print(f"max_ee = {r_max_ee}")
print(f"min_ee = {r_min_ee}")
print("\n# Joint Position Limits (radians):")
print(f"max_pos = {r_max_pos}")
print(f"min_pos = {r_min_pos}")
else:
print("No data recorded (exited during warmup).")
def main():
+29 -12
View File
@@ -93,12 +93,15 @@ from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_so100_follower,
earthrover_mini_plus,
hope_jr,
koch_follower,
make_robot_from_config,
omx_follower,
so100_follower,
so101_follower,
)
from lerobot.robots.unitree_g1 import config_unitree_g1 # noqa: F401
from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
@@ -106,6 +109,7 @@ from lerobot.teleoperators import ( # noqa: F401
homunculus,
koch_leader,
make_teleoperator_from_config,
omx_leader,
so100_leader,
so101_leader,
)
@@ -118,7 +122,7 @@ from lerobot.utils.control_utils import (
sanity_check_dataset_name,
sanity_check_dataset_robot_compatibility,
)
from lerobot.utils.import_utils import register_third_party_devices
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import (
get_safe_torch_device,
@@ -194,9 +198,8 @@ class RecordConfig:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
if self.teleop is None and self.policy is None:
raise ValueError("Choose a policy, a teleoperator or both to control the robot")
# Note: teleop and policy can both be None for robots with built-in control (e.g. unitree_g1)
# This is validated in record() after the robot is instantiated
@classmethod
def __get_path_fields__(cls) -> list[str]:
@@ -269,7 +272,12 @@ def record_loop(
for t in teleop
if isinstance(
t,
(so100_leader.SO100Leader | so101_leader.SO101Leader | koch_leader.KochLeader),
(
so100_leader.SO100Leader
| so101_leader.SO101Leader
| koch_leader.KochLeader
| omx_leader.OmxLeader
),
)
),
None,
@@ -332,6 +340,13 @@ def record_loop(
base_action = robot._from_keyboard_to_base_action(keyboard_action)
act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
act_processed_teleop = teleop_action_processor((act, obs))
elif policy is None and teleop is None and dataset is not None:
# Observation-only recording (robot controls itself, e.g. unitree_g1)
# Record observations, extract action-relevant values (positions) from obs
# Filter obs_processed to only include keys that match action_features
action_keys = set(robot.action_features.keys())
action_values = {k: v for k, v in obs_processed.items() if k in action_keys}
robot_action_to_send = None
else:
logging.info(
"No policy or teleoperator provided, skipping action generation."
@@ -344,15 +359,17 @@ def record_loop(
if policy is not None and act_processed_policy is not None:
action_values = act_processed_policy
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
else:
elif teleop is not None:
action_values = act_processed_teleop
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
# else: observation-only mode, action_values already set above
# Send action to robot
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset. action = postprocessor.process(action)
# TODO(steven, pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
_sent_action = robot.send_action(robot_action_to_send)
# Send action to robot (skip if observation-only mode)
if robot_action_to_send is not None:
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset. action = postprocessor.process(action)
# TODO(steven, pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
_sent_action = robot.send_action(robot_action_to_send)
# Write to dataset
if dataset is not None:
@@ -512,7 +529,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
def main():
register_third_party_devices()
register_third_party_plugins()
record()
+4 -2
View File
@@ -54,14 +54,16 @@ from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_so100_follower,
earthrover_mini_plus,
hope_jr,
koch_follower,
make_robot_from_config,
omx_follower,
so100_follower,
so101_follower,
)
from lerobot.utils.constants import ACTION
from lerobot.utils.import_utils import register_third_party_devices
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import (
init_logging,
@@ -127,7 +129,7 @@ def replay(cfg: ReplayConfig):
def main():
register_third_party_devices()
register_third_party_plugins()
replay()
@@ -33,6 +33,7 @@ from lerobot.robots import ( # noqa: F401
koch_follower,
lekiwi,
make_robot_from_config,
omx_follower,
so100_follower,
so101_follower,
)
@@ -40,6 +41,7 @@ from lerobot.teleoperators import ( # noqa: F401
TeleoperatorConfig,
koch_leader,
make_teleoperator_from_config,
omx_leader,
so100_leader,
so101_leader,
)
@@ -47,6 +49,8 @@ from lerobot.teleoperators import ( # noqa: F401
COMPATIBLE_DEVICES = [
"koch_follower",
"koch_leader",
"omx_follower",
"omx_leader",
"so100_follower",
"so100_leader",
"so101_follower",
+6 -2
View File
@@ -71,9 +71,11 @@ from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_so100_follower,
earthrover_mini_plus,
hope_jr,
koch_follower,
make_robot_from_config,
omx_follower,
so100_follower,
so101_follower,
)
@@ -83,12 +85,14 @@ from lerobot.teleoperators import ( # noqa: F401
bi_so100_leader,
gamepad,
homunculus,
keyboard,
koch_leader,
make_teleoperator_from_config,
omx_leader,
so100_leader,
so101_leader,
)
from lerobot.utils.import_utils import register_third_party_devices
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import init_logging, move_cursor_up
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
@@ -217,7 +221,7 @@ def teleoperate(cfg: TeleoperateConfig):
def main():
register_third_party_devices()
register_third_party_plugins()
teleoperate()
+5 -1
View File
@@ -36,6 +36,7 @@ from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.rl.wandb_utils import WandBLogger
from lerobot.scripts.lerobot_eval import eval_policy_all
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed
from lerobot.utils.train_utils import (
@@ -260,7 +261,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
if cfg.env is not None:
logging.info(f"{cfg.env.task=}")
logging.info("Creating environment processors")
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
env_preprocessor, env_postprocessor = make_env_pre_post_processors(
env_cfg=cfg.env, policy_cfg=cfg.policy
)
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
logging.info(f"{dataset.num_episodes=}")
@@ -446,6 +449,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
def main():
register_third_party_plugins()
train()
@@ -14,12 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_keyboard import KeyboardEndEffectorTeleopConfig, KeyboardTeleopConfig
from .teleop_keyboard import KeyboardEndEffectorTeleop, KeyboardTeleop
from .configuration_keyboard import (
KeyboardEndEffectorTeleopConfig,
KeyboardRoverTeleopConfig,
KeyboardTeleopConfig,
)
from .teleop_keyboard import KeyboardEndEffectorTeleop, KeyboardRoverTeleop, KeyboardTeleop
__all__ = [
"KeyboardTeleopConfig",
"KeyboardTeleop",
"KeyboardEndEffectorTeleopConfig",
"KeyboardEndEffectorTeleop",
"KeyboardRoverTeleopConfig",
"KeyboardRoverTeleop",
]
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration for keyboard teleoperators."""
from dataclasses import dataclass
@@ -30,4 +31,38 @@ class KeyboardTeleopConfig(TeleoperatorConfig):
@TeleoperatorConfig.register_subclass("keyboard_ee")
@dataclass
class KeyboardEndEffectorTeleopConfig(KeyboardTeleopConfig):
"""Configuration for keyboard end-effector teleoperator.
Used for controlling robot end-effectors with keyboard inputs.
Attributes:
use_gripper: Whether to include gripper control in actions
"""
use_gripper: bool = True
@TeleoperatorConfig.register_subclass("keyboard_rover")
@dataclass
class KeyboardRoverTeleopConfig(TeleoperatorConfig):
"""Configuration for keyboard rover teleoperator.
Used for controlling mobile robots like EarthRover Mini Plus with WASD controls.
Attributes:
linear_speed: Default linear velocity magnitude (-1 to 1 range for SDK robots)
angular_speed: Default angular velocity magnitude (-1 to 1 range for SDK robots)
speed_increment: Amount to increase/decrease speed with +/- keys
turn_assist_ratio: Forward motion multiplier when turning with A/D keys (0.0-1.0)
angular_speed_ratio: Ratio of angular to linear speed for synchronized adjustments
min_linear_speed: Minimum linear speed when decreasing (prevents zero speed)
min_angular_speed: Minimum angular speed when decreasing (prevents zero speed)
"""
linear_speed: float = 1.0
angular_speed: float = 1.0
speed_increment: float = 0.1
turn_assist_ratio: float = 0.3
angular_speed_ratio: float = 0.6
min_linear_speed: float = 0.1
min_angular_speed: float = 0.05
@@ -25,7 +25,11 @@ from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnected
from ..teleoperator import Teleoperator
from ..utils import TeleopEvents
from .configuration_keyboard import KeyboardEndEffectorTeleopConfig, KeyboardTeleopConfig
from .configuration_keyboard import (
KeyboardEndEffectorTeleopConfig,
KeyboardRoverTeleopConfig,
KeyboardTeleopConfig,
)
PYNPUT_AVAILABLE = True
try:
@@ -289,3 +293,158 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
TeleopEvents.SUCCESS: success,
TeleopEvents.RERECORD_EPISODE: rerecord_episode,
}
class KeyboardRoverTeleop(KeyboardTeleop):
"""
Keyboard teleoperator for mobile robots like EarthRover Mini Plus.
Provides intuitive WASD-style controls for driving a mobile robot:
- Linear movement (forward/backward)
- Angular movement (turning/rotation)
- Speed adjustment
- Emergency stop
Keyboard Controls:
Movement:
- W: Move forward
- S: Move backward
- A: Turn left (with forward motion)
- D: Turn right (with forward motion)
- Q: Rotate left in place
- E: Rotate right in place
- X: Emergency stop
Speed Control:
- +/=: Increase speed
- -: Decrease speed
System:
- ESC: Disconnect teleoperator
Attributes:
config: Teleoperator configuration
current_linear_speed: Current linear velocity magnitude
current_angular_speed: Current angular velocity magnitude
Example:
```python
from lerobot.teleoperators.keyboard import KeyboardRoverTeleop, KeyboardRoverTeleopConfig
teleop = KeyboardRoverTeleop(
KeyboardRoverTeleopConfig(linear_speed=1.0, angular_speed=1.0, speed_increment=0.1)
)
teleop.connect()
while teleop.is_connected:
action = teleop.get_action()
robot.send_action(action)
```
"""
config_class = KeyboardRoverTeleopConfig
name = "keyboard_rover"
def __init__(self, config: KeyboardRoverTeleopConfig):
super().__init__(config)
# Add rover-specific speed settings
self.current_linear_speed = config.linear_speed
self.current_angular_speed = config.angular_speed
@property
def action_features(self) -> dict:
"""Return action format for rover (linear and angular velocities)."""
return {
"linear.vel": float,
"angular.vel": float,
}
@property
def is_calibrated(self) -> bool:
"""Rover teleop doesn't require calibration."""
return True
def _drain_pressed_keys(self):
"""Update current_pressed state from event queue without clearing held keys"""
while not self.event_queue.empty():
key_char, is_pressed = self.event_queue.get_nowait()
if is_pressed:
self.current_pressed[key_char] = True
else:
# Only remove key if it's being released
self.current_pressed.pop(key_char, None)
def get_action(self) -> dict[str, Any]:
"""
Get the current action based on pressed keys.
Returns:
dict with 'linear.vel' and 'angular.vel' keys
"""
before_read_t = time.perf_counter()
if not self.is_connected:
raise DeviceNotConnectedError(
"KeyboardRoverTeleop is not connected. You need to run `connect()` before `get_action()`."
)
self._drain_pressed_keys()
linear_velocity = 0.0
angular_velocity = 0.0
# Check which keys are currently pressed (not released)
active_keys = {key for key, is_pressed in self.current_pressed.items() if is_pressed}
# Linear movement (W/S) - these take priority
if "w" in active_keys:
linear_velocity = self.current_linear_speed
elif "s" in active_keys:
linear_velocity = -self.current_linear_speed
# Turning (A/D/Q/E)
if "d" in active_keys:
angular_velocity = -self.current_angular_speed
if linear_velocity == 0: # If not moving forward/back, add slight forward motion
linear_velocity = self.current_linear_speed * self.config.turn_assist_ratio
elif "a" in active_keys:
angular_velocity = self.current_angular_speed
if linear_velocity == 0: # If not moving forward/back, add slight forward motion
linear_velocity = self.current_linear_speed * self.config.turn_assist_ratio
elif "q" in active_keys:
angular_velocity = self.current_angular_speed
linear_velocity = 0 # Rotate in place
elif "e" in active_keys:
angular_velocity = -self.current_angular_speed
linear_velocity = 0 # Rotate in place
# Stop (X) - overrides everything
if "x" in active_keys:
linear_velocity = 0
angular_velocity = 0
# Speed adjustment
if "+" in active_keys or "=" in active_keys:
self.current_linear_speed += self.config.speed_increment
self.current_angular_speed += self.config.speed_increment * self.config.angular_speed_ratio
logging.info(
f"Speed increased: linear={self.current_linear_speed:.2f}, angular={self.current_angular_speed:.2f}"
)
if "-" in active_keys:
self.current_linear_speed = max(
self.config.min_linear_speed, self.current_linear_speed - self.config.speed_increment
)
self.current_angular_speed = max(
self.config.min_angular_speed,
self.current_angular_speed - self.config.speed_increment * self.config.angular_speed_ratio,
)
logging.info(
f"Speed decreased: linear={self.current_linear_speed:.2f}, angular={self.current_angular_speed:.2f}"
)
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
return {
"linear.vel": linear_velocity,
"angular.vel": angular_velocity,
}
@@ -0,0 +1,18 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_omx_leader import OmxLeaderConfig
from .omx_leader import OmxLeader
@@ -0,0 +1,30 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..config import TeleoperatorConfig
@TeleoperatorConfig.register_subclass("omx_leader")
@dataclass
class OmxLeaderConfig(TeleoperatorConfig):
# Port to connect to the arm
port: str
# Sets the arm in torque mode with the gripper motor set to this value. This makes it possible to squeeze
# the gripper and have it spring back to an open position on its own.
gripper_open_pos: float = 37.0
@@ -0,0 +1,165 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.dynamixel import (
DriveMode,
DynamixelMotorsBus,
OperatingMode,
)
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..teleoperator import Teleoperator
from .config_omx_leader import OmxLeaderConfig
logger = logging.getLogger(__name__)
class OmxLeader(Teleoperator):
"""
- [OMX](https://github.com/ROBOTIS-GIT/open_manipulator),
expansion, developed by Woojin Wie and Junha Cha from [ROBOTIS](https://ai.robotis.com/)
"""
config_class = OmxLeaderConfig
name = "omx_leader"
def __init__(self, config: OmxLeaderConfig):
super().__init__(config)
self.config = config
self.bus = DynamixelMotorsBus(
port=self.config.port,
motors={
"shoulder_pan": Motor(1, "xl330-m288", MotorNormMode.RANGE_M100_100),
"shoulder_lift": Motor(2, "xl330-m288", MotorNormMode.RANGE_M100_100),
"elbow_flex": Motor(3, "xl330-m288", MotorNormMode.RANGE_M100_100),
"wrist_flex": Motor(4, "xl330-m288", MotorNormMode.RANGE_M100_100),
"wrist_roll": Motor(5, "xl330-m288", MotorNormMode.RANGE_M100_100),
"gripper": Motor(6, "xl330-m077", MotorNormMode.RANGE_0_100),
},
calibration=self.calibration,
)
@property
def action_features(self) -> dict[str, type]:
return {f"{motor}.pos": float for motor in self.bus.motors}
@property
def feedback_features(self) -> dict[str, type]:
return {}
@property
def is_connected(self) -> bool:
return self.bus.is_connected
def connect(self, calibrate: bool = True) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
self.bus.connect()
if not self.is_calibrated and calibrate:
logger.info(
"Mismatch between calibration values in the motor and the calibration file or no calibration file found"
)
self.calibrate()
self.configure()
logger.info(f"{self} connected.")
@property
def is_calibrated(self) -> bool:
return self.bus.is_calibrated
def calibrate(self) -> None:
self.bus.disable_torque()
logger.info(f"\nUsing factory default calibration values for {self}")
logger.info(f"\nWriting default configuration of {self} to the motors")
for motor in self.bus.motors:
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
for motor in self.bus.motors:
if motor == "gripper":
self.bus.write("Drive_Mode", motor, DriveMode.INVERTED.value)
else:
self.bus.write("Drive_Mode", motor, DriveMode.NON_INVERTED.value)
drive_modes = {motor: 1 if motor == "gripper" else 0 for motor in self.bus.motors}
self.calibration = {}
for motor, m in self.bus.motors.items():
self.calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=drive_modes[motor],
homing_offset=0,
range_min=0,
range_max=4095,
)
self.bus.write_calibration(self.calibration)
self._save_calibration()
logger.info(f"Calibration saved to {self.calibration_fpath}")
def configure(self) -> None:
self.bus.disable_torque()
self.bus.configure_motors()
for motor in self.bus.motors:
if motor != "gripper":
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos
# can't rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while
# assembling the arm, you could end up with a servo with a position 0 or 4095 at a crucial
# point
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
# Use 'position control current based' for gripper to be limited by the limit of the current.
# For the follower gripper, it means it can grasp an object without forcing too much even tho,
# its goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch).
# For the leader gripper, it means we can use it as a physical trigger, since we can force with our finger
# to make it move, and it will move back to its original target position when we release the force.
self.bus.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value)
# Set gripper's goal pos in current position mode so that we can use it as a trigger.
self.bus.enable_torque("gripper")
if self.is_calibrated:
self.bus.write("Goal_Position", "gripper", self.config.gripper_open_pos)
def setup_motors(self) -> None:
for motor in reversed(self.bus.motors):
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
self.bus.setup_motor(motor)
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
def get_action(self) -> dict[str, float]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
start = time.perf_counter()
action = self.bus.sync_read("Present_Position")
action = {f"{motor}.pos": val for motor, val in action.items()}
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
return action
def send_feedback(self, feedback: dict[str, float]) -> None:
# TODO(rcadene, aliberts): Implement force feedback
raise NotImplementedError
def disconnect(self) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self.bus.disconnect()
logger.info(f"{self} disconnected.")
+4
View File
@@ -41,6 +41,10 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
from .koch_leader import KochLeader
return KochLeader(config)
elif config.type == "omx_leader":
from .omx_leader import OmxLeader
return OmxLeader(config)
elif config.type == "so100_leader":
from .so100_leader import SO100Leader
+29 -26
View File
@@ -19,7 +19,7 @@ import io
import json
import logging
import pickle # nosec B403: Safe usage for internal serialization only
from multiprocessing import Event
from multiprocessing.synchronize import Event as MpEvent
from queue import Queue
from typing import Any
@@ -28,6 +28,9 @@ import torch
from lerobot.transport import services_pb2
from lerobot.utils.transition import Transition
# FIX for protobuf: Assign the enum to a variable and ignore the type error once
TransferState = services_pb2.TransferState # type: ignore[attr-defined]
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
@@ -40,8 +43,8 @@ def bytes_buffer_size(buffer: io.BytesIO) -> int:
def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True):
buffer = io.BytesIO(buffer)
size_in_bytes = bytes_buffer_size(buffer)
bytes_buffer: io.BytesIO = io.BytesIO(buffer)
size_in_bytes = bytes_buffer_size(bytes_buffer)
sent_bytes = 0
@@ -50,15 +53,15 @@ def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = ""
logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with")
while sent_bytes < size_in_bytes:
transfer_state = services_pb2.TransferState.TRANSFER_MIDDLE
transfer_state = TransferState.TRANSFER_MIDDLE
if sent_bytes + CHUNK_SIZE >= size_in_bytes:
transfer_state = services_pb2.TransferState.TRANSFER_END
transfer_state = TransferState.TRANSFER_END
elif sent_bytes == 0:
transfer_state = services_pb2.TransferState.TRANSFER_BEGIN
transfer_state = TransferState.TRANSFER_BEGIN
size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
chunk = buffer.read(size_to_read)
chunk = bytes_buffer.read(size_to_read)
yield message_class(transfer_state=transfer_state, data=chunk)
sent_bytes += size_to_read
@@ -67,7 +70,7 @@ def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = ""
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
def receive_bytes_in_chunks(iterator, queue: Queue | None, shutdown_event: Event, log_prefix: str = ""):
def receive_bytes_in_chunks(iterator, queue: Queue | None, shutdown_event: MpEvent, log_prefix: str = ""):
bytes_buffer = io.BytesIO()
step = 0
@@ -78,17 +81,17 @@ def receive_bytes_in_chunks(iterator, queue: Queue | None, shutdown_event: Event
logging.info(f"{log_prefix} Shutting down receiver")
return
if item.transfer_state == services_pb2.TransferState.TRANSFER_BEGIN:
if item.transfer_state == TransferState.TRANSFER_BEGIN:
bytes_buffer.seek(0)
bytes_buffer.truncate(0)
bytes_buffer.write(item.data)
logging.debug(f"{log_prefix} Received data at step 0")
step = 0
elif item.transfer_state == services_pb2.TransferState.TRANSFER_MIDDLE:
elif item.transfer_state == TransferState.TRANSFER_MIDDLE:
bytes_buffer.write(item.data)
step += 1
logging.debug(f"{log_prefix} Received data at step {step}")
elif item.transfer_state == services_pb2.TransferState.TRANSFER_END:
elif item.transfer_state == TransferState.TRANSFER_END:
bytes_buffer.write(item.data)
logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
@@ -109,17 +112,17 @@ def receive_bytes_in_chunks(iterator, queue: Queue | None, shutdown_event: Event
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
"""Convert model state dict to flat array for transmission"""
buffer = io.BytesIO()
bytes_buffer = io.BytesIO()
torch.save(state_dict, buffer)
torch.save(state_dict, bytes_buffer)
return buffer.getvalue()
return bytes_buffer.getvalue()
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
buffer = io.BytesIO(buffer)
buffer.seek(0)
return torch.load(buffer, weights_only=True)
bytes_buffer = io.BytesIO(buffer)
bytes_buffer.seek(0)
return torch.load(bytes_buffer, weights_only=True)
def python_object_to_bytes(python_object: Any) -> bytes:
@@ -127,24 +130,24 @@ def python_object_to_bytes(python_object: Any) -> bytes:
def bytes_to_python_object(buffer: bytes) -> Any:
buffer = io.BytesIO(buffer)
buffer.seek(0)
obj = pickle.load(buffer) # nosec B301: Safe usage of pickle.load
bytes_buffer = io.BytesIO(buffer)
bytes_buffer.seek(0)
obj = pickle.load(bytes_buffer) # nosec B301: Safe usage of pickle.load
# Add validation checks here
return obj
def bytes_to_transitions(buffer: bytes) -> list[Transition]:
buffer = io.BytesIO(buffer)
buffer.seek(0)
transitions = torch.load(buffer, weights_only=True)
bytes_buffer = io.BytesIO(buffer)
bytes_buffer.seek(0)
transitions = torch.load(bytes_buffer, weights_only=True)
return transitions
def transitions_to_bytes(transitions: list[Transition]) -> bytes:
buffer = io.BytesIO()
torch.save(transitions, buffer)
return buffer.getvalue()
bytes_buffer = io.BytesIO()
torch.save(transitions, bytes_buffer)
return bytes_buffer.getvalue()
def grpc_channel_options(
+3 -3
View File
@@ -130,14 +130,14 @@ def make_device_from_device_class(config: ChoiceRegistry) -> Any:
)
def register_third_party_devices() -> None:
def register_third_party_plugins() -> None:
"""
Discover and import third-party lerobot_* plugins so they can register themselves.
Scans top-level modules on sys.path for packages starting with
'lerobot_robot_', 'lerobot_camera_' or 'lerobot_teleoperator_' and imports them.
'lerobot_robot_', 'lerobot_camera_', 'lerobot_teleoperator_' or 'lerobot_policy_' and imports them.
"""
prefixes = ("lerobot_robot_", "lerobot_camera_", "lerobot_teleoperator_")
prefixes = ("lerobot_robot_", "lerobot_camera_", "lerobot_teleoperator_", "lerobot_policy_")
imported: list[str] = []
failed: list[str] = []
+105
View File
@@ -29,6 +29,7 @@ from lerobot.datasets.dataset_tools import (
remove_feature,
split_dataset,
)
from lerobot.scripts.lerobot_edit_dataset import convert_dataset_to_videos
@pytest.fixture
@@ -1047,3 +1048,107 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path):
assert new_chunk_indices == original_chunk_indices, "Chunk indices should be preserved"
assert new_file_indices == original_file_indices, "File indices should be preserved"
assert "reward" in modified_dataset.meta.features
def test_convert_dataset_to_videos(tmp_path):
"""Test converting lerobot/pusht_image dataset to video format."""
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Load the actual lerobot/pusht_image dataset (only first 2 episodes for speed)
source_dataset = LeRobotDataset("lerobot/pusht_image", episodes=[0, 1])
output_dir = tmp_path / "pusht_video"
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(output_dir)
# Verify source dataset has images, not videos
assert len(source_dataset.meta.video_keys) == 0
assert "observation.image" in source_dataset.meta.features
# Convert to video dataset (only first 2 episodes for speed)
video_dataset = convert_dataset_to_videos(
dataset=source_dataset,
output_dir=output_dir,
repo_id="lerobot/pusht_video",
vcodec="libsvtav1",
pix_fmt="yuv420p",
g=2,
crf=30,
episode_indices=[0, 1],
num_workers=2,
)
# Verify new dataset has videos
assert len(video_dataset.meta.video_keys) > 0
assert "observation.image" in video_dataset.meta.video_keys
# Verify correct number of episodes and frames (2 episodes)
assert video_dataset.meta.total_episodes == 2
# Compare against the actual number of frames in the loaded episodes, not metadata total
assert len(video_dataset) == len(source_dataset)
# Verify video files exist
for ep_idx in range(video_dataset.meta.total_episodes):
for video_key in video_dataset.meta.video_keys:
video_path = video_dataset.root / video_dataset.meta.get_video_file_path(ep_idx, video_key)
assert video_path.exists(), f"Video file should exist: {video_path}"
# Verify we can load the dataset and access it
assert len(video_dataset) == video_dataset.meta.total_frames
# Test that we can actually get an item from the video dataset
item = video_dataset[0]
assert "observation.image" in item
assert "action" in item
# Cleanup
import shutil
if output_dir.exists():
shutil.rmtree(output_dir)
def test_convert_dataset_to_videos_subset_episodes(tmp_path):
"""Test converting only specific episodes from lerobot/pusht_image to video format."""
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Load the actual lerobot/pusht_image dataset (only first 3 episodes)
source_dataset = LeRobotDataset("lerobot/pusht_image", episodes=[0, 1, 2])
output_dir = tmp_path / "pusht_video_subset"
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(output_dir)
# Convert only episode 0 to video (subset of loaded episodes)
episode_indices = [0]
video_dataset = convert_dataset_to_videos(
dataset=source_dataset,
output_dir=output_dir,
repo_id="lerobot/pusht_video_subset",
episode_indices=episode_indices,
num_workers=2,
)
# Verify correct number of episodes
assert video_dataset.meta.total_episodes == len(episode_indices)
# Verify video files exist for selected episodes
assert len(video_dataset.meta.video_keys) > 0
assert "observation.image" in video_dataset.meta.video_keys
# Cleanup
import shutil
if output_dir.exists():
shutil.rmtree(output_dir)
@@ -0,0 +1,318 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test script to verify XVLA policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
# ruff: noqa: E402
import random
from copy import deepcopy
from typing import Any
import numpy as np
import pytest
import torch
pytest.importorskip("transformers")
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
from lerobot.policies.xvla.processor_xvla import make_xvla_pre_post_processors
from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # noqa: E402
from tests.utils import require_cuda # noqa: E402
# Constants
DUMMY_ACTION_DIM = 7 # Standard robot arm action dimension
DUMMY_STATE_DIM = 20 # Proprioceptive state dimension
IMAGE_HEIGHT = 224
IMAGE_WIDTH = 224
NUM_VIEWS = 2 # Number of camera views
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_PATH_LEROBOT = "lerobot/xvla-widowx"
LIBERO_DOMAIN_ID = 0 # Domain ID for examples purposes
# Expected values from original XVLA implementation (reference values)
EXPECTED_ACTIONS_SHAPE = (30, 20)
EXPECTED_ACTIONS_MEAN = 0.117606
EXPECTED_ACTIONS_STD = 0.245411
EXPECTED_ACTIONS_FIRST_5 = torch.tensor([0.2742, 0.4977, 0.0500, 0.7040, -0.2653])
def set_seed_all(seed: int):
"""Set random seed for all RNG sources to ensure reproducibility."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Set deterministic behavior
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True, warn_only=True)
def instantiate_lerobot_xvla(
from_pretrained: bool = False,
model_path: str = MODEL_PATH_LEROBOT,
) -> tuple[
Any, # Policy
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Instantiate LeRobot XVLA policy with preprocessor and postprocessor."""
if from_pretrained:
policy = XVLAPolicy.from_pretrained(
pretrained_name_or_path=model_path,
strict=False,
)
else:
config = XVLAConfig(
base_model_path=model_path,
n_action_steps=DUMMY_ACTION_DIM,
chunk_size=DUMMY_ACTION_DIM,
device=DEVICE,
num_image_views=NUM_VIEWS,
) # add resize_imgs_with_padding=IMAGE_SIZE, IMAGE_SIZE?
policy = XVLAPolicy(config)
policy.to(DEVICE)
policy.config.device = DEVICE
preprocessor, postprocessor = make_xvla_pre_post_processors(
config=policy.config,
dataset_stats=None, # Pass None for dataset_stats to disable normalization (original XVLA doesn't normalize)
)
return policy, preprocessor, postprocessor
def create_dummy_data(device=DEVICE):
"""Create dummy data for testing both implementations."""
batch_size = 1
prompt = "Pick up the red block and place it in the bin"
# Create random RGB images in [0, 255] uint8 range (as PIL images would be)
# Then convert to [0, 1] float32 range for LeRobot
def fake_rgb(h, w):
arr = np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)
t = torch.from_numpy(arr).permute(2, 0, 1) # CHW
return t
batch = {
f"{OBS_IMAGES}.image": torch.stack(
[fake_rgb(IMAGE_HEIGHT, IMAGE_WIDTH) for _ in range(batch_size)]
).to(device),
f"{OBS_IMAGES}.image2": torch.stack(
[fake_rgb(IMAGE_HEIGHT, IMAGE_WIDTH) for _ in range(batch_size)]
).to(device),
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
"task": [prompt for _ in range(batch_size)],
}
return batch
# Pytest fixtures
@pytest.fixture(scope="module")
def xvla_components():
"""Fixture to instantiate and provide all XVLA components for tests."""
print(f"\nTesting with DEVICE='{DEVICE}'")
print("\n[Setup] Instantiating LeRobot XVLA policy...")
policy_obj, preprocessor_obj, postprocessor_obj = instantiate_lerobot_xvla(from_pretrained=True)
print("✔️ Model loaded successfully")
yield policy_obj, preprocessor_obj, postprocessor_obj
@pytest.fixture(scope="module")
def policy(xvla_components):
"""Fixture to provide the XVLA policy for tests."""
return xvla_components[0]
@pytest.fixture(scope="module")
def preprocessor(xvla_components):
"""Fixture to provide the XVLA preprocessor for tests."""
return xvla_components[1]
@require_cuda
def test_xvla_preprocessor_alignment(policy, preprocessor):
"""Test that LeRobot XVLA preprocessor produces expected outputs."""
print("\n" + "=" * 80)
print("Test: XVLA Preprocessor Outputs")
print("=" * 80)
set_seed_all(42)
print("\nCreating dummy data...")
batch = create_dummy_data()
print("\n[LeRobot] Preprocessing...")
lerobot_observation = preprocessor(deepcopy(batch))
lerobot_inputs = policy._build_model_inputs(lerobot_observation)
print("\nVerifying preprocessor outputs:")
print("-" * 80)
# Expected shapes from tester.txt
expected_shapes = {
"domain_id": (1,),
"input_ids": (1, 50),
"proprio": (1, 20),
"image_mask": (1, 2),
"image_input": (1, 2, 3, 224, 224),
}
for key, expected_shape in expected_shapes.items():
if key in lerobot_inputs:
actual_shape = tuple(lerobot_inputs[key].shape)
print(f"\nKey: {key}")
print(f"Expected shape: {expected_shape}")
print(f"Actual shape: {actual_shape}")
if actual_shape == expected_shape:
print("Shape matches!")
else:
print("Shape mismatch!")
assert actual_shape == expected_shape, f"Shape mismatch for {key}"
else:
print(f"\nKey '{key}' not found in inputs!")
print("\nAll preprocessor outputs have correct shapes!")
@require_cuda
def test_xvla_action_generation(policy, preprocessor):
"""Test XVLA LeRobot implementation generates expected actions."""
print("\n" + "=" * 80)
print("Test: XVLA Action Generation Against Expected Values")
print("=" * 80)
set_seed_all(42)
print("\nCreating dummy data...")
batch = create_dummy_data()
print("\n[LeRobot] Running inference...")
lerobot_observation = preprocessor(deepcopy(batch))
lerobot_inputs = policy._build_model_inputs(lerobot_observation)
# Reset seed for inference
torch.manual_seed(42)
with torch.no_grad():
lerobot_actions = policy.model.generate_actions(**lerobot_inputs, steps=10)
lerobot_actions = lerobot_actions.squeeze(0).float().cpu()
print(f"LeRobot actions shape: {lerobot_actions.shape}")
print(f"LeRobot actions mean: {lerobot_actions.mean().item():.6f}")
print(f"LeRobot actions std: {lerobot_actions.std().item():.6f}")
print(f"LeRobot actions first 5: {lerobot_actions[0, :5]}")
print("\nExpected values (from original XVLA):")
print(f"Expected actions shape: {EXPECTED_ACTIONS_SHAPE}")
print(f"Expected actions mean: {EXPECTED_ACTIONS_MEAN:.6f}")
print(f"Expected actions std: {EXPECTED_ACTIONS_STD:.6f}")
print(f"Expected actions first 5: {EXPECTED_ACTIONS_FIRST_5}")
print("\nAction Comparison:")
print("-" * 80)
# Compare shapes
actual_shape = tuple(lerobot_actions.shape)
assert actual_shape == EXPECTED_ACTIONS_SHAPE, (
f"Shape mismatch: {actual_shape} vs {EXPECTED_ACTIONS_SHAPE}"
)
print(f"✔️ Shape matches: {actual_shape}")
# Compare statistics
actual_mean = lerobot_actions.mean().item()
actual_std = lerobot_actions.std().item()
mean_diff = abs(actual_mean - EXPECTED_ACTIONS_MEAN)
std_diff = abs(actual_std - EXPECTED_ACTIONS_STD)
print(f"\nMean: {actual_mean:.6f} (expected: {EXPECTED_ACTIONS_MEAN:.6f}, diff: {mean_diff:.6e})")
print(f"Std: {actual_std:.6f} (expected: {EXPECTED_ACTIONS_STD:.6f}, diff: {std_diff:.6e})")
# Compare first 5 actions
actual_first_5 = lerobot_actions[0, :5]
first_5_diff = torch.abs(actual_first_5 - EXPECTED_ACTIONS_FIRST_5)
print("\nFirst 5 actions comparison:")
print(f" Actual: {actual_first_5}")
print(f" Expected: {EXPECTED_ACTIONS_FIRST_5}")
print(f" Max diff: {first_5_diff.max().item():.6e}")
print(f" Mean diff: {first_5_diff.mean().item():.6e}")
# Check with different tolerances
tolerances = [1e-5, 1e-4, 1e-3, 1e-2]
for tol in tolerances:
is_close = torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tol)
status = "Success" if is_close else "Failure"
print(f"{status}: First 5 actions close (atol={tol}): {is_close}")
# Assert with reasonable tolerance
tolerance = 1e-3
assert torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tolerance), (
f"First 5 actions differ by more than tolerance ({tolerance})"
)
print(f"\nSuccess: Actions match expected values within tolerance ({tolerance})!")
@require_cuda
def test_xvla_inference_reproducibility(policy, preprocessor):
"""Test that XVLA inference is reproducible with the same seed."""
print("\n" + "=" * 80)
print("Test: XVLA Inference Reproducibility")
print("=" * 80)
print("\nCreating dummy data...")
batch = create_dummy_data()
# First inference
print("\n[Run 1] Running inference...")
set_seed_all(42)
lerobot_observation = preprocessor(deepcopy(batch))
lerobot_inputs = policy._build_model_inputs(lerobot_observation)
with torch.no_grad():
actions_1 = policy.model.generate_actions(**lerobot_inputs, steps=10)
actions_1 = actions_1.squeeze(0).float().cpu()
# Second inference with same seed
print("\n[Run 2] Running inference with same seed...")
set_seed_all(42)
lerobot_observation = preprocessor(deepcopy(batch))
lerobot_inputs = policy._build_model_inputs(lerobot_observation)
with torch.no_grad():
actions_2 = policy.model.generate_actions(**lerobot_inputs, steps=10)
actions_2 = actions_2.squeeze(0).float().cpu()
print("\nComparing two runs:")
print("-" * 80)
if torch.allclose(actions_1, actions_2, atol=1e-8):
print("Inference is perfectly reproducible!")
else:
diff = torch.abs(actions_1 - actions_2)
print("Small differences detected:")
print(f" Max diff: {diff.max().item():.6e}")
print(f" Mean diff: {diff.mean().item():.6e}")
assert torch.allclose(actions_1, actions_2, atol=1e-6), "Inference should be reproducible!"
print("\nInference is reproducible!")