Compare commits

...

12 Commits

Author SHA1 Message Date
Pepijn c75df5c3b9 clean up load, add inject stats and extend convert script for libero 2025-09-10 14:00:22 +02:00
Pepijn e2740fe555 add load script 2025-09-09 20:52:21 +02:00
Steven Palma d602e8169c fix(scripts): revert deletion of rs cam config import introduced by #1767 (#1876) 2025-09-08 18:29:39 +02:00
Steven Gong 49baccdccb Disable torque before applying calibration logic (#1889) 2025-09-08 11:38:13 +02:00
Gaëlle Lannuzel 6a3d57031a 2 add reachy 2 to updated lerobot (#1767)
* Start adding Reachy 2 (no camera)

* Fix joint shape

* Remove print

* Modify observation_features

* Fix observation state

* Try adding a fake Reachy teleoperator

* Saving test scripts

* Add reachy2camera to cameras

* Add teleop_left camera to observation

* Create test_reachy2_camera.py

* Update utils.py

* Add all rgb cameras

* Future depth work

* Try adding mobile_base velocity

* Update tests

* Update data_acquisition_server.py

* Update with use_external_commands

* Replay

* Usable with or without mobile base

* No need for new isntance

* Use same ip for cameras

* Remove useless imports

* Add resume

* Divide joints in multiple dicts

* Divide joinits into several dicts in teleoperator

* Fix forgotten method call

* Create test_robot_client.py

* Open gripper on start

* Add arguments for cameras

* Modify get_frame() requested size

* Call generate_joints_dict on _init_

* black + isort

* Add reachy2 in imports

* Add reachy2 dependencies

* Add documentation

* Update reachy2.mdx

* Update reachy2.mdx

* Clean files and add types

* Fix type in send_action

* Remove print

* Delete test files

* Clean code

* Update cameras

* Disconnect from camera

* Run pre-commit hooks

* Update pyproject.toml

* Create test_reachy2.py

* Fix generate_joints

* Update test_reachy2.py

* Update send_action test

* Update reachy2_cameras depth + CameraManager

* Update reachy2_camera tests

* Remove useless import and args

* Rename reachy2_teleoperator

* Create test_reachy2_teleoperator.py

* Fix remainging fake_teleoperator

* Remove useless elements

* Mock cameras in test_reachy2

* Delete commented lines

* Add use_present_position to teleoperator

* Add cameras tests

* Add check no part + test

* Use disable_torque_on_disconnect

* Use odometry for vel with present_position

* Update documentation

* Fix vel value type

* Use ensure_safe_goal_position

* Import joints dict from classes

* Update reachy2.mdx

* Update reachy2.mdx

* Update minimal version

* Update minimal version

* fix(tests) fixes for reachy2 tests; removing reachy2 references from the script

* Add reachy2_sdk fake as plugins

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2025-09-05 11:03:14 +02:00
Justin Huang d74494d92b Allow max_relative_target to be a float (#1837)
* Remove unused max_relative_target for stretch3

* Fix type annotation and allow integer max_relative_target values

* Configure max_relative_target to be floats instead of ints

* Update docs and types to reflect that max_relative_target can be a dict

* Remove unnecessary isinstance check for ints

* Fix typo in name

---------

Co-authored-by: Justin Huang <justin.huang@jpl.nasa.gov>
2025-09-05 09:58:47 +02:00
Pepijn 882c80d446 Lower limits by 50% for current and torque for gripper motor (#1809)
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2025-08-29 16:06:55 +02:00
Pepijn 61b0eeae4b Add feetech firmware update docs (#1793)
* Add feetech firmware update docs

* add bonus

* formatting

* adapt text

* feedback pr
2025-08-28 11:18:54 +02:00
mgiac-hexagon 577cd10974 Removed dupicate lines of code (#1709) 2025-08-25 12:39:32 +02:00
lxk b0923ab74b fix(dataset): Use provided episode_data in save_episode (#1740)
The 'episode_data' parameter was previously ignored, causing an error if provided. This change ensures it is correctly used, which allows for asynchronous episode saving by passing a copy of the episode buffer, preventing conflicts with the main data collection loop.
2025-08-22 15:24:02 +02:00
Jack Vial 7f70b78f32 Add missing encoding table entries for Koch arm (#1534) 2025-08-20 17:24:05 +02:00
Steven Palma 55198de096 fix(ci): rename libegl1-mesa in deb13 trixie (#1735) 2025-08-14 11:12:06 +02:00
43 changed files with 3306 additions and 65 deletions
+26
View File
@@ -0,0 +1,26 @@
#!/usr/bin/env python
"""Simple script to check buffer naming in the transformed model."""
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
# Load the model with strict=False to see what buffers we have
print("Loading model...")
policy = PI0Policy.from_pretrained("pepijn223/pi0_libero_lerobot", strict=False)
# Check what buffer keys exist
state_dict = policy.state_dict()
buffer_keys = [k for k in state_dict.keys() if "buffer" in k]
normalize_keys = [k for k in state_dict.keys() if "normalize" in k]
print("\nAll buffer keys:")
for key in buffer_keys:
print(f" {key}")
print("\nAll normalize keys:")
for key in normalize_keys:
print(f" {key}")
print("\nAll keys (first 20):")
for i, key in enumerate(state_dict.keys()):
if i < 20:
print(f" {key}")
+1 -1
View File
@@ -29,7 +29,7 @@ ENV DEBIAN_FRONTEND=noninteractive \
# Install system dependencies and uv (as root)
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential git curl libglib2.0-0 libegl1-mesa ffmpeg \
build-essential git curl libglib2.0-0 libegl1-mesa-dev ffmpeg \
libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
&& mv /root/.local/bin/uv /usr/local/bin/uv \
+4
View File
@@ -35,10 +35,14 @@
title: Koch v1.1
- local: lekiwi
title: LeKiwi
- local: reachy2
title: Reachy 2
title: "Robots"
- sections:
- local: notebooks
title: Notebooks
- local: feetech
title: Updating Feetech Firmware
title: "Resources"
- sections:
- local: contributing
+71
View File
@@ -0,0 +1,71 @@
# Feetech Motor Firmware Update
This tutorial guides you through updating the firmware of Feetech motors using the official Feetech software.
## Prerequisites
- Windows computer (Feetech software is only available for Windows)
- Feetech motor control board
- USB cable to connect the control board to your computer
- Feetech motors connected to the control board
## Step 1: Download Feetech Software
1. Visit the official Feetech software download page: [https://www.feetechrc.com/software.html](https://www.feetechrc.com/software.html)
2. Download the latest version of the Feetech debugging software (FD)
3. Install the software on your Windows computer
## Step 2: Hardware Setup
1. Connect your Feetech motors to the motor control board
2. Connect the motor control board to your Windows computer via USB cable
3. Ensure power is supplied to the motors
## Step 3: Configure Connection
1. Launch the Feetech debugging software
2. Select the correct COM port from the port dropdown menu
- If unsure which port to use, check Windows Device Manager under "Ports (COM & LPT)"
3. Set the appropriate baud rate (typically 1000000 for most Feetech motors)
4. Click "Open" to establish communication with the control board
## Step 4: Scan for Motors
1. Once connected, click the "Search" button to detect all connected motors
2. The software will automatically discover and list all motors on the bus
3. Each motor will appear with its ID number
## Step 5: Update Firmware
For each motor you want to update:
1. **Select the motor** from the list by clicking on it
2. **Click on Upgrade tab**:
3. **Click on Online button**:
- If an potential firmware update is found, it will be displayed in the box
4. **Click on Upgrade button**:
- The update progress will be displayed
## Step 6: Verify Update
1. After the update completes, the software should automatically refresh the motor information
2. Verify that the firmware version has been updated to the expected version
## Important Notes
⚠️ **Warning**: Do not disconnect power or USB during firmware updates, it will potentially brick the motor.
## Bonus: Motor Debugging on Linux/macOS
For debugging purposes only, you can use the open-source Feetech Debug Tool:
- **Repository**: [FT_SCServo_Debug_Qt](https://github.com/CarolinePascal/FT_SCServo_Debug_Qt/tree/fix/port-search-timer)
### Installation Instructions
Follow the instructions in the repository to install the tool, for Ubuntu you can directly install it, for MacOS you need to build it from source.
**Limitations:**
- This tool is for debugging and parameter adjustment only
- Firmware updates must still be done on Windows with official Feetech software
+288
View File
@@ -0,0 +1,288 @@
# Reachy 2
Reachy 2 is an open-source humanoid robot made by Pollen Robotics, specifically designed for the development of embodied AI and real-world applications.
Check out [Pollen Robotics website](https://www.pollen-robotics.com/reachy/), or access [Reachy 2 documentation](https://docs.pollen-robotics.com/) for more information on the platform!
## Teleoperate Reachy 2
Currently, there are two ways to teleoperate Reachy 2:
- Pollen Robotics VR teleoperation (not included in LeRobot).
- Robot-to-robot teleoperation (use one Reachy 2 to control another).
## Reachy 2 Simulation
**(Linux only)** You can run Reachy 2 in simulation (Gazebo or MuJoCo) using the provided [Docker image](https://hub.docker.com/r/pollenrobotics/reachy2_core).
1. Install [Docker Engine](https://docs.docker.com/engine/).
2. Run (for MuJoCo):
```
docker run --rm -it \
--name reachy \
--privileged \
--network host \
--ipc host \
--device-cgroup-rule='c 189:* rwm' \
--group-add audio \
-e ROS_DOMAIN_ID="$ROS_DOMAIN_ID" \
-e DISPLAY="$DISPLAY" \
-e RCUTILS_CONSOLE_OUTPUT_FORMAT="[{severity}]: {message}" \
-e REACHY2_CORE_SERVICE_FAKE="${REACHY2_CORE_SERVICE_FAKE:-true}" \
-v /dev:/dev \
-v "$HOME/.reachy_config":/home/reachy/.reachy_config_override \
-v "$HOME/.reachy.log":/home/reachy/.ros/log \
-v /usr/lib/x86_64-linux-gnu:/opt/host-libs \
--entrypoint /package/launch.sh \
pollenrobotics/reachy2_core:1.7.5.9_deploy \
start_rviz:=true start_sdk_server:=true mujoco:=true
```
> If MuJoCo runs slowly (low simulation frequency), append `-e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \` to the previous command to improve performance:
>
> ```
> docker run --rm -it \
> --name reachy \
> --privileged \
> --network host \
> --ipc host \
> --device-cgroup-rule='c 189:* rwm' \
> --group-add audio \
> -e ROS_DOMAIN_ID="$ROS_DOMAIN_ID" \
> -e DISPLAY="$DISPLAY" \
> -e RCUTILS_CONSOLE_OUTPUT_FORMAT="[{severity}]: {message}" \
> -e REACHY2_CORE_SERVICE_FAKE="${REACHY2_CORE_SERVICE_FAKE:-true}" \
> -e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \
> -v /dev:/dev \
> -v "$HOME/.reachy_config":/home/reachy/.reachy_config_override \
> -v "$HOME/.reachy.log":/home/reachy/.ros/log \
> -v /usr/lib/x86_64-linux-gnu:/opt/host-libs \
> --entrypoint /package/launch.sh \
> pollenrobotics/reachy2_core:1.7.5.9_deploy \
> start_rviz:=true start_sdk_server:=true mujoco:=true
> ```
## Setup
### Prerequisites
- On your robot, check the **service images** meet the minimum versions:
- **reachy2-core >= 1.7.5.2**
- **webrtc >= 2.0.1.1**
Then, if you want to use VR teleoperation:
- Install the [Reachy 2 teleoperation application](https://docs.pollen-robotics.com/teleoperation/teleoperation-introduction/discover-teleoperation/).
Use version **>=v1.2.0**
We recommend using two computers: one for teleoperation (Windows required) and another for recording with LeRobot.
### Install LeRobot
Follow the [installation instructions](https://github.com/huggingface/lerobot#installation) to install LeRobot.
Install LeRobot with Reachy 2 dependencies:
```bash
pip install -e ".[reachy2]"
```
### (Optional but recommended) Install pollen_data_acquisition_server
How you manage Reachy 2 recording sessions is up to you, but the **easiest** way is to use this server so you can control sessions directly from the VR teleoperation app.
> **Note:** Currently, only the VR teleoperation application works as a client for this server, so this step primarily targets teleoperation. Youre free to develop custom clients to manage sessions to your needs.
In your LeRobot environment, install the server from source:
```bash
git clone https://github.com/pollen-robotics/pollen_data_acquisition_server.git
cd pollen_data_acquisition_server
pip install -e .
```
Find the [pollen_data_acquisition_server documentation here](https://github.com/pollen-robotics/pollen_data_acquisition_server).
## Step 1: Recording
### Get Reachy 2 IP address
Before starting teleoperation and data recording, find the [robot's IP address](https://docs.pollen-robotics.com/getting-started/setup-reachy2/connect-reachy2/).
We strongly recommend connecting all devices (PC and robot) via **Ethernet**.
### Launch recording
There are two ways to manage recording sessions when using the Reachy 2 VR teleoperation application:
- **Using the data acquisition server (recommended for VR teleop)**: The VR app orchestrates sessions (via the server it tells LeRobot when to create datasets, start/stop episodes) while also controlling the robots motions.
- **Using LeRobots record script**: LeRobot owns session control and decides when to start/stop episodes. If you also use the VR teleop app, its only for motion control.
### Option 1: Using Pollen data acquisition server (recommended for VR teleop)
Make sure you have installed pollen_data_acquisition_server, as explained in the Setup section.
Launch the data acquisition server to be able to manage your session directly from the teleoperation application:
```bash
python -m pollen_data_acquisition_server.server
```
Then get into the teleoperation application and choose "Data acquisition session".
You can finally setup your session by following the screens displayed.
> Even without the VR app, you can use the `pollen_data_acquisition_server` with your own client implementation.
### Option 2: Using lerobot.record
Reachy 2 is fully supported by LeRobots recording features.
If you choose this option but still want to use the VR teleoperation application, select "Standard session" in the app.
**Example: start a recording without the mobile base:**
First add reachy2 and reachy2_teleoperator to the imports of the record script. Then you can use the following command:
```bash
python -m lerobot.record \
--robot.type=reachy2 \
--robot.ip_address=192.168.0.200 \
--robot.id=r2-0000 \
--robot.use_external_commands=true \
--robot.with_mobile_base=false \
--teleop.type=reachy2_teleoperator \
--teleop.ip_address=192.168.0.200 \
--teleop.with_mobile_base=false \
--dataset.repo_id=pollen_robotics/record_test \
--dataset.single_task="Reachy 2 recording test" \
--dataset.num_episodes=1 \
--dataset.episode_time_s=5 \
--dataset.fps=15 \
--dataset.push_to_hub=true \
--dataset.private=true \
--display_data=true
```
#### Specific Options
**Extended setup overview (all options included):**
```bash
python -m lerobot.record \
--robot.type=reachy2 \
--robot.ip_address=192.168.0.200 \
--robot.use_external_commands=true \
--robot.with_mobile_base=true \
--robot.with_l_arm=true \
--robot.with_r_arm=true \
--robot.with_neck=true \
--robot.with_antennas=true \
--robot.with_left_teleop_camera=true \
--robot.with_right_teleop_camera=true \
--robot.with_torso_camera=false \
--robot.disable_torque_on_disconnect=false \
--robot.max_relative_target=5.0 \
--teleop.type=reachy2_teleoperator \
--teleop.ip_address=192.168.0.200 \
--teleop.use_present_position=false \
--teleop.with_mobile_base=false \
--teleop.with_l_arm=true \
--teleop.with_r_arm=true \
--teleop.with_neck=true \
--teleop.with_antennas=true \
--dataset.repo_id=pollen_robotics/record_test \
--dataset.single_task="Reachy 2 recording test" \
--dataset.num_episodes=1 \
--dataset.episode_time_s=5 \
--dataset.fps=15 \
--dataset.push_to_hub=true \
--dataset.private=true \
--display_data=true
```
##### `--robot.use_external_commands`
Determine whether LeRobot robot.send_action() sends commands to the robot.
**Must** be set to false while using the VR teleoperation application, as the app already sends commands.
##### `--teleop.use_present_position`
Determine whether the teleoperator reads the goal or present position of the robot.
Must be set to true if a compliant Reachy 2 is used to control another one.
##### Use the relevant parts
From our initial tests, recording **all** joints when only some are moving can reduce model quality with certain policies.
To avoid this, you can exclude specific parts from recording and replay using:
````
--robot.with_<part>=false
```,
with `<part>` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`.
It determine whether the corresponding part is recorded in the observations. True if not set.
By default, **all parts are recorded**.
The same per-part mechanism is available in `reachy2_teleoperator` as well.
````
--teleop.with\_<part>
```
with `<part>` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`.
Determine whether the corresponding part is recorded in the actions. True if not set.
> **Important:** In a given session, the **enabled parts must match** on both the robot and the teleoperator.
For example, if the robot runs with `--robot.with_mobile_base=false`, the teleoperator must disable the same part `--teleoperator.with_mobile_base=false`.
##### Use the relevant cameras
You can do the same for **cameras**. By default, only the **teleoperation cameras** are recorded (both `left_teleop_camera` and `right_teleop_camera`). Enable or disable each camera with:
```
--robot.with_left_teleop_camera=<true|false>
--robot.with_right_teleop_camera=<true|false>
--robot.with_torso_camera=<true|false>
````
## Step 2: Replay
Make sure the robot is configured with the same parts as the dataset:
```bash
python -m lerobot.replay \
--robot.type=reachy2 \
--robot.ip_address=192.168.0.200 \
--robot.use_external_commands=false \
--robot.with_mobile_base=false \
--dataset.repo_id=pollen_robotics/record_test \
--dataset.episode=0
--display_data=true
````
## Step 3: Train
```bash
python -m lerobot.scripts.train \
--dataset.repo_id=pollen_robotics/record_test \
--policy.type=act \
--output_dir=outputs/train/reachy2_test \
--job_name=reachy2 \
--policy.device=mps \
--wandb.enable=true \
--policy.repo_id=pollen_robotics/record_test_policy
```
## Step 4: Evaluate
```bash
python -m lerobot.record \
--robot.type=reachy2 \
--robot.ip_address=192.168.0.200 \
--display_data=false \
--dataset.repo_id=pollen_robotics/eval_record_test \
--dataset.single_task="Evaluate reachy2 policy" \
--dataset.num_episodes=10 \
--policy.path=outputs/train/reachy2_test/checkpoints/last/pretrained_model
```
+347
View File
@@ -0,0 +1,347 @@
#!/usr/bin/env python
"""Script for Pi0 pretrained policy inference and Hub upload."""
import argparse
from datetime import datetime
import numpy as np
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
# Set seed
torch.manual_seed(42)
def parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(description="Pi0 policy inference and Hub upload")
parser.add_argument(
"--source-model-id",
type=str,
default="pepijn223/pi0_libero_lerobot",
help="Source model repository ID on Hugging Face Hub",
)
parser.add_argument(
"--dataset-id", type=str, default="pepijn223/libero", help="Dataset repository ID on Hugging Face Hub"
)
parser.add_argument(
"--output-model-id",
type=str,
required=True,
help="Output model repository ID to upload to (e.g., 'your-username/pi0-libero-fixed')",
)
parser.add_argument(
"--device", type=str, default="cpu", choices=["cpu", "cuda", "mps"], help="Device to run inference on"
)
parser.add_argument("--episode", type=int, default=0, help="Episode index to load from dataset")
parser.add_argument(
"--sample-idx", type=int, default=10, help="Sample index within episode to use for inference"
)
parser.add_argument("--private", action="store_true", help="Make the uploaded model private")
parser.add_argument(
"--commit-message", type=str, default=None, help="Custom commit message for the upload"
)
return parser.parse_args()
def _inject_normalization_stats(policy: PI0Policy, dataset_meta: LeRobotDatasetMetadata, key_mapping: dict):
"""Recreate normalization layers with proper stats from the dataset."""
from lerobot.policies.normalize import Normalize, Unnormalize
# Convert numpy stats to the format expected by normalization layers and remap keys
stats = {}
for dataset_key, stat_dict in dataset_meta.stats.items():
# Use mapped key if available, otherwise use original key
policy_key = key_mapping.get(dataset_key, dataset_key)
stats[policy_key] = {
stat_type: torch.from_numpy(stat_array) if isinstance(stat_array, np.ndarray) else stat_array
for stat_type, stat_array in stat_dict.items()
}
print(f"Available stats keys: {list(stats.keys())}")
print(
f"Policy expects keys: input={list(policy.config.input_features.keys())}, output={list(policy.config.output_features.keys())}"
)
# Recreate normalization layers with proper stats
normalize_inputs = Normalize(policy.config.input_features, policy.config.normalization_mapping, stats)
normalize_targets = Normalize(policy.config.output_features, policy.config.normalization_mapping, stats)
unnormalize_outputs = Unnormalize(
policy.config.output_features, policy.config.normalization_mapping, stats
)
# Replace the normalization layers on the policy
policy.normalize_inputs = normalize_inputs
policy.normalize_targets = normalize_targets
policy.unnormalize_outputs = unnormalize_outputs
print("Normalization layers recreated with dataset stats.")
def configure_policy_features(policy: PI0Policy, dataset: LeRobotDataset):
"""Configure policy input and output features based on dataset metadata."""
print(f"Dataset features: {list(dataset.meta.features.keys())}")
# Create a proper mapping from dataset keys to policy keys
dataset_to_policy_mapping = {}
# Handle images
if "image" in dataset.meta.features:
dataset_to_policy_mapping["image"] = "observation.images.image"
if "wrist_image" in dataset.meta.features:
dataset_to_policy_mapping["wrist_image"] = "observation.images.image2"
# Handle state
if "state" in dataset.meta.features:
dataset_to_policy_mapping["state"] = "observation.state"
# Handle actions
if "actions" in dataset.meta.features:
dataset_to_policy_mapping["actions"] = "action"
print(f"Key mapping: {dataset_to_policy_mapping}")
# Clear existing input features and reconfigure with proper mapping
policy.config.input_features = {}
policy.config.output_features = {}
# Map visual features
for dataset_key, policy_key in dataset_to_policy_mapping.items():
if dataset_key in ["image", "wrist_image"]:
feature_info = dataset.meta.features[dataset_key]
# Convert HWC to CHW format and resize
shape = (3, 224, 224) # Pi0 expects CHW format
policy.config.input_features[policy_key] = PolicyFeature(type=FeatureType.VISUAL, shape=shape)
# Map state features
for dataset_key, policy_key in dataset_to_policy_mapping.items():
if dataset_key == "state":
feature_info = dataset.meta.features[dataset_key]
shape = tuple(feature_info["shape"])
policy.config.input_features[policy_key] = PolicyFeature(type=FeatureType.STATE, shape=shape)
# Map action features
for dataset_key, policy_key in dataset_to_policy_mapping.items():
if dataset_key == "actions":
feature_info = dataset.meta.features[dataset_key]
shape = tuple(feature_info["shape"])
policy.config.output_features[policy_key] = PolicyFeature(type=FeatureType.ACTION, shape=shape)
print(f"Policy input_features: {list(policy.config.input_features.keys())}")
print(f"Policy output_features: {list(policy.config.output_features.keys())}")
print(f"Policy image_features: {list(policy.config.image_features.keys())}")
print(f"Policy action_feature: {policy.config.action_feature}")
return dataset_to_policy_mapping
def fix_buffer_naming(policy: PI0Policy):
"""Fix buffer naming issues in the loaded policy state dict."""
print("Fixing normalization buffer naming issues...")
state_dict = policy.state_dict()
corrected_state_dict = {}
fixes_applied = 0
for key, value in state_dict.items():
new_key = key
# Fix buffer naming: buffer_observation_state_mean -> buffer_observation_state.mean
if "buffer_observation_state_mean" in key:
new_key = key.replace("buffer_observation_state_mean", "buffer_observation_state.mean")
fixes_applied += 1
print(f" Fixed: {key} -> {new_key}")
elif "buffer_observation_state_std" in key:
new_key = key.replace("buffer_observation_state_std", "buffer_observation_state.std")
fixes_applied += 1
print(f" Fixed: {key} -> {new_key}")
# Remove image buffers that aren't expected (they cause conflicts)
elif "buffer_observation_image_mean" in key or "buffer_observation_image_std" in key:
print(f" Removed unexpected buffer: {key}")
continue # Skip this buffer
corrected_state_dict[new_key] = value
# Add missing action buffers with dummy values (will be replaced by dataset stats)
missing_buffers = [
"normalize_targets.buffer_action.mean",
"normalize_targets.buffer_action.std",
"unnormalize_outputs.buffer_action.mean",
"unnormalize_outputs.buffer_action.std",
]
for buffer_key in missing_buffers:
if buffer_key not in corrected_state_dict:
# Use dummy values - these will be overwritten by proper dataset stats later
if "mean" in buffer_key:
corrected_state_dict[buffer_key] = torch.zeros(8) # Assume 8-dim action
else: # std
corrected_state_dict[buffer_key] = torch.ones(8) # Assume 8-dim action
fixes_applied += 1
print(f" Added missing buffer: {buffer_key}")
print(f"Applied {fixes_applied} buffer fixes")
# Load the corrected state dict back into the policy
policy.load_state_dict(corrected_state_dict)
return policy
def main():
"""Main function to run the Pi0 inference and upload."""
args = parse_args()
# Load pretrained Pi0 model directly from Hugging Face Hub
print(f"Loading pretrained Pi0 model from {args.source_model_id}...")
# Load with strict=False to allow missing/unexpected keys, then fix them manually
policy = PI0Policy.from_pretrained(args.source_model_id, strict=False)
policy = fix_buffer_naming(policy)
policy.eval()
policy.to(args.device)
# Load dataset and get a sample
print(f"Loading dataset: {args.dataset_id}")
dataset = LeRobotDataset(args.dataset_id, episodes=[args.episode])
meta: LeRobotDatasetMetadata = dataset.meta
sample = dataset[args.sample_idx]
# Configure policy features
key_mapping = configure_policy_features(policy, dataset)
# Inject normalization stats with proper key mapping
_inject_normalization_stats(policy, meta, key_mapping)
# Prepare batch for PI0 (handle temporal dimensions)
batch = {}
# Map dataset sample keys to policy keys
reverse_mapping = {v: k for k, v in key_mapping.items()}
for policy_key in policy.config.input_features:
# Find the corresponding dataset key
dataset_key = reverse_mapping.get(policy_key, policy_key)
if dataset_key in sample:
data = sample[dataset_key]
# Handle image data: convert from HWC to CHW and normalize
if policy_key.startswith("observation.images."):
if data.dim() == 3 and data.shape[-1] == 3: # HWC format
data = data.permute(2, 0, 1) # Convert to CHW
# Normalize to [0, 1] range if needed
if data.dtype == torch.uint8:
data = data.float() / 255.0
# Resize to expected size if needed
if data.shape[-2:] != (224, 224):
import torch.nn.functional as F # noqa: N812
data = F.interpolate(
data.unsqueeze(0), size=(224, 224), mode="bilinear", align_corners=False
)[0]
# Remove temporal dimension if present
if data.dim() > len(policy.config.input_features[policy_key].shape):
data = data[0]
batch[policy_key] = data.unsqueeze(0) # Add batch dimension
# Debug: print what's in the sample
print(f"Sample keys: {list(sample.keys())}")
print(f"Batch keys prepared: {list(batch.keys())}")
# Pi0 requires task description - add a default if not available
if "task" in sample:
batch["task"] = [sample["task"]] # Keep as list of strings
else:
print("No task in sample, using default task description")
batch["task"] = ["Complete the manipulation task"]
print(f"Task: {batch['task'][0]}")
print(f"Final batch keys: {list(batch.keys())}")
# Run inference
with torch.no_grad():
action = policy.select_action(batch)
print(f"Predicted action shape: {action.shape}")
print(f"Predicted action: {action.tolist()}")
print("✅ Pi0 pretrained inference completed successfully!")
# Upload to Hugging Face Hub
print(f"\n📤 Uploading model to Hugging Face Hub: {args.output_model_id}")
# Create commit message
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
commit_message = (
args.commit_message
or f"Pi0 model with injected normalization stats from {args.dataset_id} - {timestamp}"
)
# Update model configuration with dataset info
policy.config.push_to_hub = True
policy.config.repo_id = args.output_model_id
policy.config.private = args.private
# Add metadata about the adaptation
adaptation_info = {
"source_model": args.source_model_id,
"dataset_used": args.dataset_id,
"adaptation_date": timestamp,
"stats_injected": True,
"key_mapping": key_mapping,
"inference_test_passed": True,
"sample_action_shape": list(action.shape),
}
try:
# Push to hub
policy.push_to_hub(
repo_id=args.output_model_id,
private=args.private,
commit_message=commit_message,
create_pr=False,
)
# Also save the adaptation info as a separate file
import json
import os
import tempfile
from huggingface_hub import HfApi
api = HfApi()
# Create a temporary file with adaptation info
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(adaptation_info, f, indent=2)
temp_path = f.name
try:
api.upload_file(
path_or_fileobj=temp_path,
path_in_repo="adaptation_info.json",
repo_id=args.output_model_id,
commit_message=f"Add adaptation metadata - {timestamp}",
)
finally:
os.unlink(temp_path)
print(f"✅ Model successfully uploaded to: https://huggingface.co/{args.output_model_id}")
print("📋 Adaptation info:")
for key, value in adaptation_info.items():
print(f" {key}: {value}")
except Exception as e:
print(f"❌ Error uploading to Hub: {e}")
raise
if __name__ == "__main__":
main()
+704
View File
@@ -0,0 +1,704 @@
import json
import os
import random
from datetime import datetime
import numpy as np
import torch
from huggingface_hub import hf_hub_download # noqa: E402
from safetensors.torch import load_file # noqa: E402
from transformers.model_debugging_utils import model_addition_debugger_context
from lerobot.configs.policies import FeatureType, PolicyFeature
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
RANDOM_SEED = 42 # Set to fixed value for reproducible results
def set_all_seeds(seed=42):
"""Set all random seeds for reproducible results."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
print(f"All random seeds set to {seed} for reproducible results (deterministic mode enabled)")
# Set seeds at the start
set_all_seeds(RANDOM_SEED)
config_model_path = "lerobot/pi0" # Use config from official model
official_model_path = "lerobot/pi0" # Official model
custom_model_path = "pepijn223/pi0_base_fp32" # Custom model to compare # pepijn223/pi0_base_fp32
device = "mps"
USE_FULL_TENSORS = True
SAVE_TENSORS_TO_DISK = False
# Model transformation and upload settings
SAVE_TRANSFORMED_MODEL = True # Set to True to save the transformed model
UPLOAD_TO_HUB = True # Set to True to upload to HuggingFace Hub
TRANSFORMED_MODEL_NAME = "pepijn223/pi0_base_fp32_lerobot_format" # Target repo name
COMMIT_MESSAGE = "Add transformed PI0 model with correct key format for lerobot"
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
debug_path = os.path.join("debug_outputs", f"pi0_debug_direct_{timestamp}")
os.makedirs(debug_path, exist_ok=True)
print(f"Model debugging enabled - outputs will be saved to: {debug_path}")
# Download and load the config manually to avoid draccus parsing issues
config_file = hf_hub_download(repo_id=config_model_path, filename="config.json")
with open(config_file) as f:
config_dict = json.load(f)
# Remove the 'type' field that causes draccus issues
if "type" in config_dict:
config_dict.pop("type")
print("Removed 'type' field from config")
# Create shared PI0Config
print("Creating shared PI0Config...")
shared_config = PI0Config(**config_dict)
def load_policy_with_weights(
model_path: str, config: PI0Config, model_name: str, apply_transformations: bool = False
):
"""Load a policy with specified weights but shared config."""
print(f"\n=== Loading {model_name} from {model_path} ===")
# Set deterministic seed before creating the policy to ensure identical initialization
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
policy = PI0Policy(config)
# Download and load weights
model_file = hf_hub_download(repo_id=model_path, filename="model.safetensors")
print(f"Downloaded {model_name} weights to: {model_file}")
# Load state dict and apply transformations
print(f"Investigating safetensors file: {model_file}")
# First, check what's in the metadata
try:
from safetensors import safe_open
with safe_open(model_file, framework="pt", device="cpu") as f:
metadata = f.metadata()
all_keys_in_file = f.keys()
print(f" Total keys in safetensors file: {len(list(all_keys_in_file))}")
# Check for embed_tokens in the file keys
embed_keys_in_file = [k for k in f.keys() if "embed_tokens" in k]
print(f" embed_tokens keys in safetensors: {embed_keys_in_file}")
if metadata:
print(f" Metadata exists: {list(metadata.keys()) if metadata else 'None'}")
except Exception as e:
print(f" Could not inspect safetensors file directly: {e}")
# Now load normally and see what we get
state_dict = load_file(model_file)
print(f" Keys loaded by load_file(): {len(state_dict)} keys")
# Check for embed_tokens in loaded state_dict
loaded_embed_keys = [k for k in state_dict.keys() if "embed_tokens" in k]
print(f" embed_tokens keys in loaded state_dict: {loaded_embed_keys}")
# Check if we need to add "model." prefix (for custom models that don't have it)
sample_key = next(iter(state_dict.keys()))
if not sample_key.startswith("model."):
print(f"Adding 'model.' prefix to all keys (detected format: {sample_key})")
state_dict = {f"model.{k}": v for k, v in state_dict.items()}
# IMPORTANT: Call PI0Policy._transform_state_dict_keys AFTER adding model. prefix
# This ensures tied weights logic can find the correct key pattern
transformed_state_dict = PI0Policy._transform_state_dict_keys(state_dict)
# Apply specific PaliGemma key transformations only for custom models
if apply_transformations:
print("Applying custom model key transformations...")
# First, let's debug what keys we actually have
all_keys = list(transformed_state_dict.keys())
sample_keys = all_keys[:10]
print(f"Sample keys to transform: {sample_keys}")
# Look for specific keys we need to transform and missing keys
embed_tokens_keys = [k for k in all_keys if "embed_tokens" in k]
embedding_keys = [k for k in all_keys if "embed" in k]
lm_head_keys = [k for k in all_keys if "lm_head" in k]
paligemma_keys = [
k for k in all_keys if "paligemma_with_expert.paligemma" in k and "gemma_expert" not in k
]
language_model_keys = [k for k in all_keys if "language_model" in k]
print(f"Found embed_tokens keys: {embed_tokens_keys}")
print(f"Found any embedding keys: {embedding_keys}")
print(f"Found lm_head keys: {lm_head_keys}")
print(
f"Found paligemma keys (non-expert): {paligemma_keys[:5]}{'...' if len(paligemma_keys) > 5 else ''}"
)
print(
f"Found language_model keys: {language_model_keys[:5]}{'...' if len(language_model_keys) > 5 else ''}"
)
print(f"Total keys in model: {len(all_keys)}")
# Check if the embed_tokens is in gemma_expert instead
gemma_expert_embed = [k for k in all_keys if "gemma_expert" in k and "embed_tokens" in k]
print(f"Found gemma_expert embed_tokens keys: {gemma_expert_embed}")
# Check what we're missing and what we actually have
expected_embed_key = "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
if expected_embed_key not in all_keys:
print(f" Missing expected embed_tokens key: {expected_embed_key}")
# Let's see what keys we actually have for debugging
print("Debugging: Looking for any embedding-related keys...")
all_embed_related = [k for k in all_keys if "embed" in k.lower()]
print(f"Keys containing 'embed': {all_embed_related}")
# Look for any keys that might contain embeddings
potential_embed_keys = [
k for k in all_keys if any(word in k for word in ["embed", "embedding", "token"])
]
print(f" Potential embedding keys: {potential_embed_keys}")
# Try to find a suitable replacement
if gemma_expert_embed:
print(f" Will try to copy from: {gemma_expert_embed[0]}")
else:
print(" No gemma_expert embed_tokens found either!")
# Check if there's an embed_tokens in the gemma_expert that we missed
gemma_keys = [k for k in all_keys if "gemma_expert" in k]
print(f" First 10 gemma_expert keys: {gemma_keys[:10]}")
# Check if there are any token-related keys in gemma_expert
token_keys = [k for k in all_keys if "gemma_expert" in k and "token" in k.lower()]
print(f" Gemma expert token-related keys: {token_keys}")
# Check for any keys that look like they might be embeddings
possible_embeds = [
k
for k in all_keys
if any(
pattern in k.lower() for pattern in ["embed_token", "embedding", "wte", "word_embed"]
)
]
print(f" Possible embedding alternatives: {possible_embeds}")
final_state_dict = {}
transformation_count = 0
for key, value in transformed_state_dict.items():
new_key = key
original_key = key
# Transform vision tower keys: ADD .model between paligemma and vision_tower
if "paligemma_with_expert.paligemma.vision_tower.vision_model" in new_key:
new_key = new_key.replace(
"paligemma_with_expert.paligemma.vision_tower.vision_model",
"paligemma_with_expert.paligemma.model.vision_tower.vision_model",
)
print(f"Transformed vision key: {original_key} -> {new_key}")
transformation_count += 1
# Transform multi_modal_projector keys: ADD .model between paligemma and multi_modal_projector
elif "paligemma_with_expert.paligemma.multi_modal_projector" in new_key:
new_key = new_key.replace(
"paligemma_with_expert.paligemma.multi_modal_projector",
"paligemma_with_expert.paligemma.model.multi_modal_projector",
)
print(f"Transformed multi_modal_projector key: {original_key} -> {new_key}")
transformation_count += 1
# NO transformation needed for language_model keys - they're already correct!
# The custom model already has: paligemma.model.language_model.* which is what we need
# NO transformation needed for lm_head - it should stay as paligemma.lm_head
final_state_dict[new_key] = value
print(f"Applied {transformation_count} key transformations")
transformed_state_dict = final_state_dict
else:
print("No transformations applied (official model format)")
# Debug: show what keys the policy expects vs what we have
policy_keys = set(policy.state_dict().keys())
provided_keys = set(transformed_state_dict.keys())
missing_in_provided = policy_keys - provided_keys
extra_in_provided = provided_keys - policy_keys
print(f"Policy expects {len(policy_keys)} keys, we provide {len(provided_keys)} keys")
if missing_in_provided:
print(
f" Missing from provided: {list(missing_in_provided)[:5]}{'...' if len(missing_in_provided) > 5 else ''}"
)
if extra_in_provided:
print(
f" Extra in provided: {list(extra_in_provided)[:5]}{'...' if len(extra_in_provided) > 5 else ''}"
)
# Load the weights into the policy
msg = policy.load_state_dict(transformed_state_dict, strict=True)
print(
f"{model_name} - Missing keys: {len(msg.missing_keys)}, Unexpected keys: {len(msg.unexpected_keys)}"
)
if msg.missing_keys:
print(
f" Actually missing keys: {list(msg.missing_keys)[:3]}{'...' if len(msg.missing_keys) > 3 else ''}"
)
if msg.unexpected_keys:
print(
f" Actually unexpected keys: {list(msg.unexpected_keys)[:3]}{'...' if len(msg.unexpected_keys) > 3 else ''}"
)
# Set deterministic mode and move to device
policy = policy.to(device)
policy.eval()
# Reset the policy to ensure identical internal state
policy.reset()
return policy
# Load both models with shared config
print("Loading both models with shared config...")
official_policy = load_policy_with_weights(
official_model_path, shared_config, "Official Model", apply_transformations=False
)
custom_policy = load_policy_with_weights(
custom_model_path, shared_config, "Custom Model", apply_transformations=True
)
print("\nBoth models loaded successfully!")
print(f"Shared config: {shared_config}")
print(f"Device: {device}")
# Configure input features for both policies since they're not set by default in pretrained models
def configure_policy_features(policy: PI0Policy):
"""Configure input and output features for a policy."""
policy.config.input_features[OBS_IMAGE] = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 224, 224), # Channel-first RGB image
)
policy.config.input_features[OBS_STATE] = PolicyFeature(
type=FeatureType.STATE,
shape=(8,), # 8-dimensional state vector
)
policy.config.output_features[ACTION] = PolicyFeature(
type=FeatureType.ACTION,
shape=(8,), # 8-dimensional action vector
)
# Add dummy normalization buffers to the policy (like openpi does with norm_stats)
if hasattr(policy, "normalize_inputs"):
# For observation.state (8-dim state vector)
policy.normalize_inputs.register_buffer(
f"buffer_{OBS_STATE.replace('.', '_')}_mean", torch.zeros(8, device=device)
)
policy.normalize_inputs.register_buffer(
f"buffer_{OBS_STATE.replace('.', '_')}_std", torch.ones(8, device=device)
)
# For observation.image (3x224x224 image)
policy.normalize_inputs.register_buffer(
f"buffer_{OBS_IMAGE.replace('.', '_')}_mean", torch.zeros(3, 224, 224, device=device)
)
policy.normalize_inputs.register_buffer(
f"buffer_{OBS_IMAGE.replace('.', '_')}_std", torch.ones(3, 224, 224, device=device)
)
print("Configuring features for both policies...")
configure_policy_features(official_policy)
configure_policy_features(custom_policy)
# Verify that the models have identical parameters
print("\n=== Model Parameter Comparison ===")
official_params = dict(official_policy.named_parameters())
custom_params = dict(custom_policy.named_parameters())
param_differences = []
for name in official_params.keys():
if name not in custom_params:
param_differences.append(f"Missing parameter in custom model: {name}")
else:
diff = torch.abs(official_params[name] - custom_params[name]).max().item()
if diff > 1e-8:
param_differences.append(f"Parameter {name}: max difference = {diff:.2e}")
for name in custom_params.keys():
if name not in official_params:
param_differences.append(f"Extra parameter in custom model: {name}")
if param_differences:
print("Parameter differences found:")
for diff in param_differences[:10]: # Show first 10 differences
print(f" {diff}")
if len(param_differences) > 10:
print(f" ... and {len(param_differences) - 10} more differences")
else:
print("All model parameters are identical!")
# Get the raw models for direct comparison
official_raw_model = official_policy.model
custom_raw_model = custom_policy.model
print("\n=== Model Details ===")
print(f"Official raw model type: {type(official_raw_model)}")
print(f"Custom raw model type: {type(custom_raw_model)}")
print(f"Official model device: {next(official_raw_model.parameters()).device}")
print(f"Custom model device: {next(custom_raw_model.parameters()).device}")
# Create lerobot-format input data (similar to DROID format from openpi example)
example = {
"joint_position": np.zeros(7, dtype=np.float32),
"gripper_position": np.array([0.0], dtype=np.float32),
"image": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
"task": "pick up the object",
}
print(f"\nProvided input keys: {list(example.keys())}")
print("\nPreparing inputs for direct model call...")
# Apply input transformation (similar to openpi's policy._input_transform)
transformed_example = {}
# Combine joint and gripper positions into state
transformed_example[OBS_STATE] = np.concatenate([example["joint_position"], example["gripper_position"]])
transformed_example[OBS_IMAGE] = example["image"]
transformed_example["task"] = example["task"]
# Convert to PyTorch tensors and add batch dimension (as openpi example does)
# Device is already defined above, use the official model device for consistency
pytorch_inputs = {}
for key, value in transformed_example.items():
if isinstance(value, np.ndarray):
tensor_value = torch.from_numpy(value).to(device)
# Add batch dimension
if tensor_value.dim() > 0:
tensor_value = tensor_value.unsqueeze(0)
pytorch_inputs[key] = tensor_value
elif isinstance(value, str):
pytorch_inputs[key] = [value] # Convert to list format expected by policy
else:
pytorch_inputs[key] = value
# Convert image from HWC to CHW format for lerobot
if OBS_IMAGE in pytorch_inputs:
img = pytorch_inputs[OBS_IMAGE]
if img.dim() == 4 and img.shape[-1] == 3: # BHWC -> BCHW
img = img.permute(0, 3, 1, 2)
# Convert to float and normalize to [0, 1] range
img = img.float() / 255.0
pytorch_inputs[OBS_IMAGE] = img
print(f"Transformed input keys: {list(pytorch_inputs.keys())}")
for key, value in pytorch_inputs.items():
if isinstance(value, torch.Tensor):
print(f" {key}: {value.shape} {value.dtype}")
else:
print(f" {key}: {type(value)} - {value}")
# Reset both policies (clears the action queue)
official_policy.reset()
custom_policy.reset()
# Prepare inputs using the official policy (both models should have same preprocessing)
print("Preparing inputs for both models...")
images, img_masks = official_policy.prepare_images(pytorch_inputs)
lang_tokens, lang_masks = official_policy.prepare_language(pytorch_inputs)
state = official_policy.prepare_state(pytorch_inputs)
print("Prepared inputs:")
print(f" Images: {len(images)} images")
print(f" Language tokens shape: {lang_tokens.shape}")
print(f" State shape: {state.shape}")
for i, img in enumerate(images):
print(f" Image {i} shape: {img.shape}")
for i, mask in enumerate(img_masks):
print(f" Image mask {i} shape: {mask.shape}")
# Compare both models with identical inputs
print("\n🚀 Running MODEL COMPARISON...")
# Force torch.no_grad for consistent comparison
with torch.no_grad():
# Ensure reproducible noise generation for both models
torch.manual_seed(RANDOM_SEED)
# Generate synthetic noise and time for the forward call
batch_size = 1
actions_shape = (
batch_size,
official_raw_model.config.n_action_steps,
official_raw_model.config.max_action_dim,
)
# Generate noise and time using direct PyTorch operations instead of model methods
# This avoids any potential model-specific randomness
torch.manual_seed(RANDOM_SEED)
noise = torch.normal(
mean=0.0,
std=1.0,
size=actions_shape,
dtype=torch.float32,
device=device,
)
# Generate time using the same distribution as PI0FlowMatching.sample_time
torch.manual_seed(RANDOM_SEED) # Reset for consistent time
beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
time_beta = beta_dist.sample((batch_size,)).to(device=device, dtype=torch.float32)
time = time_beta * 0.999 + 0.001
print("\n=== Generated Inputs ===")
print(f" Action shape: {actions_shape}")
print(f" Noise shape: {noise.shape}")
print(f" Time value: {time.item():.6f}")
print(f" Noise sample (first 5 values): {noise.flatten()[:5].tolist()}")
# Create dummy actions for forward pass (required for training forward)
dummy_actions = torch.zeros(actions_shape, dtype=torch.float32, device=device)
print("\n=== Running Forward Passes ===")
print("Running with model_addition_debugger_context for detailed analysis...")
# Create separate debug paths for each model
official_debug_path = os.path.join(debug_path, "official_model")
custom_debug_path = os.path.join(debug_path, "custom_model")
os.makedirs(official_debug_path, exist_ok=True)
os.makedirs(custom_debug_path, exist_ok=True)
# Set deterministic mode for forward pass
torch.manual_seed(RANDOM_SEED)
# Run official model with debugger
print("Running Official Model forward pass with debugger...")
with model_addition_debugger_context(
official_raw_model,
debug_path=official_debug_path,
do_prune_layers=False, # Output ALL layers
use_repr=not SAVE_TENSORS_TO_DISK,
):
official_loss = official_raw_model.forward(
images=images,
img_masks=img_masks,
lang_tokens=lang_tokens,
lang_masks=lang_masks,
state=state,
actions=dummy_actions,
noise=noise,
time=time,
)
# Reset seed before second forward pass to ensure any internal randomness is identical
torch.manual_seed(RANDOM_SEED)
# Run custom model with debugger
print("Running Custom Model forward pass with debugger...")
with model_addition_debugger_context(
custom_raw_model,
debug_path=custom_debug_path,
do_prune_layers=False, # Output ALL layers
use_repr=not SAVE_TENSORS_TO_DISK,
):
custom_loss = custom_raw_model.forward(
images=images,
img_masks=img_masks,
lang_tokens=lang_tokens,
lang_masks=lang_masks,
state=state,
actions=dummy_actions,
noise=noise,
time=time,
)
print(f"Official model debug outputs saved to: {official_debug_path}")
print(f"Custom model debug outputs saved to: {custom_debug_path}")
print("\n=== Output Comparison ===")
print(f"Official model loss shape: {official_loss.shape}")
print(f"Custom model loss shape: {custom_loss.shape}")
# Compare outputs
loss_diff = torch.abs(official_loss - custom_loss)
print("\n=== Detailed Comparison ===")
print("Loss difference stats:")
print(f" Mean absolute difference: {loss_diff.mean().item():.8f}")
print(f" Max absolute difference: {loss_diff.max().item():.8f}")
print(f" Min absolute difference: {loss_diff.min().item():.8f}")
print(f" Standard deviation of difference: {loss_diff.std().item():.8f}")
# Show some actual values for comparison
print("\nSample output values:")
print(f" Official model (first 5): {official_loss.flatten()[:5].tolist()}")
print(f" Custom model (first 5): {custom_loss.flatten()[:5].tolist()}")
print(f" Difference (first 5): {loss_diff.flatten()[:5].tolist()}")
# Determine if models are equivalent
are_equivalent = loss_diff.max().item() < 1e-6
print(f"\nModels are {'EQUIVALENT' if are_equivalent else 'DIFFERENT'}")
print(f" (Max difference: {loss_diff.max().item():.8f}, Threshold: 1e-6)")
print(f"\nDetailed debugging outputs saved to: {debug_path}")
# Save comparison results
comparison_results = {
"official_loss_stats": {
"shape": list(official_loss.shape),
"mean": official_loss.mean().item(),
"std": official_loss.std().item(),
"min": official_loss.min().item(),
"max": official_loss.max().item(),
},
"custom_loss_stats": {
"shape": list(custom_loss.shape),
"mean": custom_loss.mean().item(),
"std": custom_loss.std().item(),
"min": custom_loss.min().item(),
"max": custom_loss.max().item(),
},
"difference_stats": {
"mean_abs_diff": loss_diff.mean().item(),
"max_abs_diff": loss_diff.max().item(),
"min_abs_diff": loss_diff.min().item(),
"std_diff": loss_diff.std().item(),
"are_equivalent": are_equivalent,
},
}
comparison_file = os.path.join(debug_path, "model_comparison_results.json")
with open(comparison_file, "w") as f:
json.dump(comparison_results, f, indent=2)
print(f" Comparison results saved to: {comparison_file}")
# Save and upload transformed model if requested
if SAVE_TRANSFORMED_MODEL:
print("\nSaving Transformed Model...")
if are_equivalent:
print("Models are equivalent - proceeding with transformation and upload")
else:
print("Models are NOT equivalent, but proceeding with upload anyway")
print(f" Max difference: {loss_diff.max().item():.2e}")
print(" This might be useful for debugging or partial transformations")
# Create timestamp for README
transformation_timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
try:
# Use the already working custom policy as the base for transformation
print("Using already working custom policy as base for transformed model...")
# Deep copy the custom policy to create the transformed version
from copy import deepcopy
transformed_policy = deepcopy(custom_policy)
print("Custom policy copied successfully - no additional configuration needed")
# Save locally first
local_save_path = "./transformed_pi0_model"
print(f"Saving transformed model locally to: {local_save_path}")
transformed_policy.save_pretrained(local_save_path, safe_serialization=True)
# Save the tokenizer as well (required for complete model)
transformed_policy.language_tokenizer.save_pretrained(local_save_path)
# Create a README with transformation details
readme_content = f"""
# PI0 Model - LeRobot Compatible Format
This model is a transformed version of `{custom_model_path}` with key names corrected to match the official LeRobot PI0 format.
## Transformation Applied
The original model had a different key naming convention. This model applies the following transformations:
1. **Model prefix**: Added `model.` prefix to all parameter keys
2. **Tied weights**: Applied PI0Policy's built-in tied weights logic to create `embed_tokens.weight` from `lm_head.weight`
3. **Key structure**: Applied standard PI0 key transformations for compatibility
## Verification
{"This transformed model produces **identical outputs**" if are_equivalent else "This transformed model has **slightly different outputs**"} (max difference = {loss_diff.max().item():.2e}) compared to the official model `{official_model_path}` when tested with the same inputs.
{"**Models are EQUIVALENT** (difference < 1e-6)" if are_equivalent else "**Models are NOT equivalent** (difference >= 1e-6) - use with caution"}
## Usage
```python
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
# Load the model
policy = PI0Policy.from_pretrained("{TRANSFORMED_MODEL_NAME}")
# Use for inference
action = policy.select_action(observation_batch)
```
## Original Model
- **Source**: {custom_model_path}
- **Verified Against**: {official_model_path}
## Technical Details
- **Total Parameters**: {sum(p.numel() for p in transformed_policy.parameters()):,}
- **Model Type**: PI0FlowMatching with PaliGemma + Expert Gemma
- **Configuration**: Matches official PI0 configuration
"""
readme_path = os.path.join(local_save_path, "README.md")
with open(readme_path, "w") as f:
f.write(readme_content.strip())
print(f"Model saved locally to: {local_save_path}")
# Upload to HuggingFace Hub if requested
if UPLOAD_TO_HUB:
print(f"\nUploading to HuggingFace Hub: {TRANSFORMED_MODEL_NAME}")
try:
# Push to hub
transformed_policy.push_to_hub(
repo_id=TRANSFORMED_MODEL_NAME,
commit_message=COMMIT_MESSAGE,
private=False, # Make it public
safe_serialization=True,
)
print(f"Model successfully uploaded to: https://huggingface.co/{TRANSFORMED_MODEL_NAME}")
print("You can now use this model directly without any transformations!")
print("\n Usage:")
print(" from lerobot.policies.pi0.modeling_pi0 import PI0Policy")
print(f" policy = PI0Policy.from_pretrained('{TRANSFORMED_MODEL_NAME}')")
except Exception as upload_error:
print(f"Failed to upload to HuggingFace Hub: {upload_error}")
print(f"You can manually upload the model from: {local_save_path}")
print(" Or set UPLOAD_TO_HUB = False and upload later")
except Exception as e:
import traceback
print(f"Error saving transformed model: {str(e)}")
print("Full traceback:")
traceback.print_exc()
print("The model transformation logic works, but saving failed")
else:
print("\nModel transformation and upload disabled (SAVE_TRANSFORMED_MODEL = False)")
+2
View File
@@ -106,6 +106,7 @@ dynamixel = ["dynamixel-sdk>=3.7.31"]
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0"]
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1"]
reachy2 = ["reachy2_sdk>=1.0.14"]
kinematics = ["lerobot[placo-dep]"]
intelrealsense = [
"pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
@@ -141,6 +142,7 @@ all = [
"lerobot[gamepad]",
"lerobot[hopejr]",
"lerobot[lekiwi]",
"lerobot[reachy2]",
"lerobot[kinematics]",
"lerobot[intelrealsense]",
"lerobot[pi0]",
@@ -0,0 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_reachy2_camera import Reachy2CameraConfig
from .reachy2_camera import Reachy2Camera
@@ -0,0 +1,78 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..configs import CameraConfig, ColorMode
@CameraConfig.register_subclass("reachy2_camera")
@dataclass
class Reachy2CameraConfig(CameraConfig):
"""Configuration class for Reachy 2 camera devices.
This class provides configuration options for Reachy 2 cameras,
supporting both the teleop and depth cameras. It includes settings
for resolution, frame rate, color mode, and the selection of the cameras.
Example configurations:
```python
# Basic configurations
Reachy2CameraConfig(
name="teleop",
image_type="left",
ip_address="192.168.0.200", # IP address of the robot
fps=15,
width=640,
height=480,
color_mode=ColorMode.RGB,
) # Left teleop camera, 640x480 @ 15FPS
```
Attributes:
name: Name of the camera device. Can be "teleop" or "depth".
image_type: Type of image stream. For "teleop" camera, can be "left" or "right".
For "depth" camera, can be "rgb" or "depth". (depth is not supported yet)
fps: Requested frames per second for the color stream.
width: Requested frame width in pixels for the color stream.
height: Requested frame height in pixels for the color stream.
color_mode: Color mode for image output (RGB or BGR). Defaults to RGB.
ip_address: IP address of the robot. Defaults to "localhost".
port: Port number for the camera server. Defaults to 50065.
Note:
- Only 3-channel color output (RGB/BGR) is currently supported.
"""
name: str
image_type: str
color_mode: ColorMode = ColorMode.RGB
ip_address: str | None = "localhost"
port: int = 50065
# use_depth: bool = False
def __post_init__(self):
if self.name not in ["teleop", "depth"]:
raise ValueError(f"`name` is expected to be 'teleop' or 'depth', but {self.name} is provided.")
if (self.name == "teleop" and self.image_type not in ["left", "right"]) or (
self.name == "depth" and self.image_type not in ["rgb", "depth"]
):
raise ValueError(
f"`image_type` is expected to be 'left' or 'right' for teleop camera, and 'rgb' or 'depth' for depth camera, but {self.image_type} is provided."
)
if self.color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)
@@ -0,0 +1,288 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Provides the Reachy2Camera class for capturing frames from Reachy 2 cameras using Reachy 2's CameraManager.
"""
import logging
import os
import platform
import time
from threading import Event, Lock, Thread
from typing import Any
# Fix MSMF hardware transform compatibility for Windows before importing cv2
if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
import cv2
import numpy as np
from reachy2_sdk.media.camera import CameraView
from reachy2_sdk.media.camera_manager import CameraManager
from lerobot.errors import DeviceNotConnectedError
from ..camera import Camera
from .configuration_reachy2_camera import ColorMode, Reachy2CameraConfig
logger = logging.getLogger(__name__)
class Reachy2Camera(Camera):
"""
Manages Reachy 2 camera using Reachy 2 CameraManager.
This class provides a high-level interface to connect to, configure, and read
frames from Reachy 2 cameras. It supports both synchronous and asynchronous
frame reading.
An Reachy2Camera instance requires a camera name (e.g., "teleop") and an image
type (e.g., "left") to be specified in the configuration.
The camera's default settings (FPS, resolution, color mode) are used unless
overridden in the configuration.
"""
def __init__(self, config: Reachy2CameraConfig):
"""
Initializes the Reachy2Camera instance.
Args:
config: The configuration settings for the camera.
"""
super().__init__(config)
self.config = config
self.fps = config.fps
self.color_mode = config.color_mode
self.cam_manager: CameraManager | None = None
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: np.ndarray | None = None
self.new_frame_event: Event = Event()
def __str__(self) -> str:
return f"{self.__class__.__name__}({self.config.name}, {self.config.image_type})"
@property
def is_connected(self) -> bool:
"""Checks if the camera is currently connected and opened."""
if self.config.name == "teleop":
return self.cam_manager._grpc_connected and self.cam_manager.teleop if self.cam_manager else False
elif self.config.name == "depth":
return self.cam_manager._grpc_connected and self.cam_manager.depth if self.cam_manager else False
else:
raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")
def connect(self, warmup: bool = True):
"""
Connects to the Reachy2 CameraManager as specified in the configuration.
"""
self.cam_manager = CameraManager(host=self.config.ip_address, port=self.config.port)
self.cam_manager.initialize_cameras()
logger.info(f"{self} connected.")
@staticmethod
def find_cameras(ip_address: str = "localhost", port: int = 50065) -> list[dict[str, Any]]:
"""
Detects available Reachy 2 cameras.
Returns:
List[Dict[str, Any]]: A list of dictionaries,
where each dictionary contains 'name', 'stereo',
and the default profile properties (width, height, fps).
"""
initialized_cameras = []
camera_manager = CameraManager(host=ip_address, port=port)
for camera in [camera_manager.teleop, camera_manager.depth]:
if camera is None:
continue
height, width, _, _, _, _, _ = camera.get_parameters()
camera_info = {
"name": camera._cam_info.name,
"stereo": camera._cam_info.stereo,
"default_profile": {
"width": width,
"height": height,
"fps": 30,
},
}
initialized_cameras.append(camera_info)
camera_manager.disconnect()
return initialized_cameras
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
"""
Reads a single frame synchronously from the camera.
This is a blocking call.
Args:
color_mode (Optional[ColorMode]): If specified, overrides the default
color mode (`self.color_mode`) for this read operation (e.g.,
request RGB even if default is BGR).
Returns:
np.ndarray: The captured frame as a NumPy array in the format
(height, width, channels), using the specified or default
color mode and applying any configured rotation.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
start_time = time.perf_counter()
frame = None
if self.cam_manager is None:
raise DeviceNotConnectedError(f"{self} is not connected.")
else:
if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"):
if self.config.image_type == "left":
frame = self.cam_manager.teleop.get_frame(CameraView.LEFT, size=(640, 480))[0]
elif self.config.image_type == "right":
frame = self.cam_manager.teleop.get_frame(CameraView.RIGHT, size=(640, 480))[0]
elif self.config.name == "depth" and hasattr(self.cam_manager, "depth"):
if self.config.image_type == "depth":
frame = self.cam_manager.depth.get_depth_frame()[0]
elif self.config.image_type == "rgb":
frame = self.cam_manager.depth.get_frame(size=(640, 480))[0]
if frame is None:
return np.empty((0, 0, 3), dtype=np.uint8)
if self.config.color_mode == "rgb":
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
return frame
def _read_loop(self):
"""
Internal loop run by the background thread for asynchronous reading.
On each iteration:
1. Reads a color frame
2. Stores result in latest_frame (thread-safe)
3. Sets new_frame_event to notify listeners
Stops on DeviceNotConnectedError, logs other errors and continues.
"""
while not self.stop_event.is_set():
try:
color_image = self.read()
with self.frame_lock:
self.latest_frame = color_image
self.new_frame_event.set()
except DeviceNotConnectedError:
break
except Exception as e:
logger.warning(f"Error reading frame in background thread for {self}: {e}")
def _start_read_thread(self) -> None:
"""Starts or restarts the background read thread if it's not running."""
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=0.1)
if self.stop_event is not None:
self.stop_event.set()
self.stop_event = Event()
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
self.thread.daemon = True
self.thread.start()
def _stop_read_thread(self) -> None:
"""Signals the background read thread to stop and waits for it to join."""
if self.stop_event is not None:
self.stop_event.set()
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=2.0)
self.thread = None
self.stop_event = None
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
"""
Reads the latest available frame asynchronously.
This method retrieves the most recent frame captured by the background
read thread. It does not block waiting for the camera hardware directly,
but may wait up to timeout_ms for the background thread to provide a frame.
Args:
timeout_ms (float): Maximum time in milliseconds to wait for a frame
to become available. Defaults to 200ms (0.2 seconds).
Returns:
np.ndarray: The latest captured frame as a NumPy array in the format
(height, width, channels), processed according to configuration.
Raises:
DeviceNotConnectedError: If the camera is not connected.
TimeoutError: If no frame becomes available within the specified timeout.
RuntimeError: If an unexpected error occurs.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
self._start_read_thread()
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
thread_alive = self.thread is not None and self.thread.is_alive()
raise TimeoutError(
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
f"Read thread alive: {thread_alive}."
)
with self.frame_lock:
frame = self.latest_frame
self.new_frame_event.clear()
if frame is None:
raise RuntimeError(f"Internal error: Event set but no frame available for {self}.")
return frame
def disconnect(self):
"""
Stops the background read thread (if running).
Raises:
DeviceNotConnectedError: If the camera is already disconnected.
"""
if not self.is_connected and self.thread is None:
raise DeviceNotConnectedError(f"{self} not connected.")
if self.thread is not None:
self._stop_read_thread()
if self.cam_manager is not None:
self.cam_manager.disconnect()
logger.info(f"{self} disconnected.")
+7 -1
View File
@@ -37,8 +37,14 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s
from .realsense.camera_realsense import RealSenseCamera
cameras[key] = RealSenseCamera(cfg)
elif cfg.type == "reachy2_camera":
from .reachy2_camera.reachy2_camera import Reachy2Camera
cameras[key] = Reachy2Camera(cfg)
else:
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
raise ValueError(f"The camera type '{cfg.type}' is not valid.")
return cameras
+2
View File
@@ -825,6 +825,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""
if not episode_data:
episode_buffer = self.episode_buffer
else:
episode_buffer = episode_data
validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)
@@ -13,20 +13,22 @@
# limitations under the License.
"""
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
2.1. It will:
This script will help you download any LeRobot dataset from the hub, convert it to the latest format, and
upload it to your own repository. It will:
- Download the dataset from any source repository
- Generate per-episodes stats and writes them in `episodes_stats.jsonl`
- Check consistency between these new stats and the old ones.
- Remove the deprecated `stats.json`.
- Update codebase_version in `info.json`.
- Push this new version to the hub on the 'main' branch and tags it with "v2.1".
- Update codebase_version in `info.json` to the latest version
- Create proper version tags
- Push the converted dataset to your specified destination repository
Usage:
```bash
python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 \
--repo-id=aliberts/koch_tutorial
--source-repo-id=IPEC-COMMUNITY/libero_spatial_no_noops_1.0.0_lerobot \
--dest-repo-id=your-username/libero_spatial_converted \
--episodes=0,1,2,3,4
```
"""
@@ -37,8 +39,8 @@ import logging
from huggingface_hub import HfApi
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
from lerobot.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
from lerobot.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, write_info
from lerobot.datasets.v21.convert_stats import convert_stats
V20 = "v2.0"
V21 = "v2.1"
@@ -54,48 +56,133 @@ class SuppressWarnings:
def convert_dataset(
repo_id: str,
source_repo_id: str,
dest_repo_id: str | None = None,
episodes: str | None = None,
branch: str | None = None,
num_workers: int = 4,
force_cache_sync: bool = True,
):
with SuppressWarnings():
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
"""
Download a dataset from source_repo_id, convert it, and upload to dest_repo_id.
Args:
source_repo_id: Source repository to download from
dest_repo_id: Destination repository to upload to (defaults to source_repo_id)
episodes: Comma-separated list of episode indices to include (e.g. "0,1,2,3")
branch: Branch to upload to
num_workers: Number of workers for stats computation
force_cache_sync: Whether to force cache synchronization
"""
if dest_repo_id is None:
dest_repo_id = source_repo_id
# Parse episodes list if provided
episode_list = None
if episodes:
try:
episode_list = [int(ep.strip()) for ep in episodes.split(",")]
print(f"Loading episodes: {episode_list}")
except ValueError as e:
raise ValueError(
f"Invalid episodes format '{episodes}'. Use comma-separated integers like '0,1,2,3'"
) from e
print(f"Downloading dataset from: {source_repo_id}")
# Try to load the dataset with different approaches to handle versioning issues
dataset = None
load_attempts = [
{"revision": None}, # Try latest first
{"revision": V20}, # Try v2.0
{"revision": "main"}, # Try main branch
]
for attempt in load_attempts:
try:
print(f"Attempting to load with revision: {attempt['revision']}")
with SuppressWarnings():
dataset = LeRobotDataset(
source_repo_id, episodes=episode_list, force_cache_sync=force_cache_sync, **attempt
)
print("Successfully loaded dataset!")
break
except Exception as e:
print(f"Failed with revision {attempt['revision']}: {e}")
continue
if dataset is None:
raise RuntimeError(f"Could not load dataset {source_repo_id} with any revision")
# Clean up old stats if present
if (dataset.root / EPISODES_STATS_PATH).is_file():
(dataset.root / EPISODES_STATS_PATH).unlink()
print("Removed existing episodes_stats.jsonl")
print("Converting stats to new format...")
convert_stats(dataset, num_workers=num_workers)
ref_stats = load_stats(dataset.root)
check_aggregate_stats(dataset, ref_stats)
# Update dataset info
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
write_info(dataset.meta.info, dataset.root)
print(f"Updated codebase_version to {CODEBASE_VERSION}")
dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
# Change repo_id for destination if different
if dest_repo_id != source_repo_id:
print(f"Changing repository from {source_repo_id} to {dest_repo_id}")
dataset.repo_id = dest_repo_id
# delete old stats.json file
if (dataset.root / STATS_PATH).is_file:
print(f"Pushing converted dataset to: {dest_repo_id}")
dataset.push_to_hub(branch=branch, tag_version=False)
# Clean up old stats.json file locally and on hub
if (dataset.root / STATS_PATH).is_file():
(dataset.root / STATS_PATH).unlink()
print("Removed local stats.json file")
hub_api = HfApi()
if hub_api.file_exists(
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
):
hub_api.delete_file(
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
)
try:
if hub_api.file_exists(
repo_id=dest_repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
):
hub_api.delete_file(
path_in_repo=STATS_PATH, repo_id=dest_repo_id, revision=branch, repo_type="dataset"
)
print("Removed stats.json from hub")
except Exception as e:
print(f"Warning: Could not remove stats.json from hub: {e}")
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
# Create version tag
try:
hub_api.create_tag(dest_repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
print(f"Created tag {CODEBASE_VERSION} for {dest_repo_id}")
except Exception as e:
print(f"Warning: Could not create tag: {e}")
print(f"✅ Successfully converted and uploaded dataset to {dest_repo_id}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
description="Download, convert, and re-upload LeRobot datasets with proper versioning"
)
parser.add_argument(
"--repo-id",
"--source-repo-id",
type=str,
required=True,
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
help="Source repository identifier to download from (e.g. 'IPEC-COMMUNITY/libero_spatial_no_noops_1.0.0_lerobot')",
)
parser.add_argument(
"--dest-repo-id",
type=str,
default=None,
help="Destination repository identifier to upload to. Defaults to source-repo-id if not specified.",
)
parser.add_argument(
"--episodes",
type=str,
default=None,
help="Comma-separated list of episode indices to include (e.g. '0,1,2,3,4'). If not specified, all episodes are included.",
)
parser.add_argument(
"--branch",
@@ -109,6 +196,22 @@ if __name__ == "__main__":
default=4,
help="Number of workers for parallelizing stats compute. Defaults to 4.",
)
parser.add_argument(
"--no-cache-sync",
action="store_true",
help="Skip forcing cache synchronization (faster but may use cached data)",
)
args = parser.parse_args()
convert_dataset(**vars(args))
# Convert args to match function signature
convert_args = {
"source_repo_id": args.source_repo_id,
"dest_repo_id": args.dest_repo_id,
"episodes": args.episodes,
"branch": args.branch,
"num_workers": args.num_workers,
"force_cache_sync": not args.no_cache_sync,
}
convert_dataset(**convert_args)
+2
View File
@@ -107,6 +107,8 @@ X_SERIES_ENCODINGS_TABLE = {
"Goal_PWM": X_SERIES_CONTROL_TABLE["Goal_PWM"][1],
"Goal_Current": X_SERIES_CONTROL_TABLE["Goal_Current"][1],
"Goal_Velocity": X_SERIES_CONTROL_TABLE["Goal_Velocity"][1],
"Goal_Position": X_SERIES_CONTROL_TABLE["Goal_Position"][1],
"Present_Position": X_SERIES_CONTROL_TABLE["Present_Position"][1],
"Present_PWM": X_SERIES_CONTROL_TABLE["Present_PWM"][1],
"Present_Current": X_SERIES_CONTROL_TABLE["Present_Current"][1],
"Present_Velocity": X_SERIES_CONTROL_TABLE["Present_Velocity"][1],
+8 -1
View File
@@ -209,7 +209,14 @@ def record_loop(
(
t
for t in teleop
if isinstance(t, (so100_leader.SO100Leader, so101_leader.SO101Leader, koch_leader.KochLeader))
if isinstance(
t,
(
so100_leader.SO100Leader,
so101_leader.SO101Leader,
koch_leader.KochLeader,
),
)
),
None,
)
+1
View File
@@ -55,6 +55,7 @@ from lerobot.robots import ( # noqa: F401
hope_jr,
koch_follower,
make_robot_from_config,
reachy2,
so100_follower,
so101_follower,
)
@@ -29,10 +29,10 @@ class BiSO100FollowerConfig(RobotConfig):
# Optional
left_arm_disable_torque_on_disconnect: bool = True
left_arm_max_relative_target: int | None = None
left_arm_max_relative_target: float | dict[str, float] | None = None
left_arm_use_degrees: bool = False
right_arm_disable_torque_on_disconnect: bool = True
right_arm_max_relative_target: int | None = None
right_arm_max_relative_target: float | dict[str, float] | None = None
right_arm_use_degrees: bool = False
# cameras (shared between both arms)
+3 -3
View File
@@ -44,8 +44,8 @@ class HopeJrArmConfig(RobotConfig):
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
# names to the max_relative_target value for that motor.
max_relative_target: float | dict[str, float] | None = None
cameras: dict[str, CameraConfig] = field(default_factory=dict)
@@ -28,9 +28,9 @@ class KochFollowerConfig(RobotConfig):
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
# names to the max_relative_target value for that motor.
max_relative_target: float | dict[str, float] | None = None
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)
@@ -110,6 +110,7 @@ class KochFollower(Robot):
return self.bus.is_calibrated
def calibrate(self) -> None:
self.bus.disable_torque()
if self.calibration:
# Calibration file exists, ask user whether to use it or run new calibration
user_input = input(
@@ -120,7 +121,6 @@ class KochFollower(Robot):
self.bus.write_calibration(self.calibration)
return
logger.info(f"\nRunning calibration of {self}")
self.bus.disable_torque()
for motor in self.bus.motors:
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
+3 -3
View File
@@ -39,9 +39,9 @@ class LeKiwiConfig(RobotConfig):
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
# names to the max_relative_target value for that motor.
max_relative_target: float | dict[str, float] | None = None
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
+25
View File
@@ -0,0 +1,25 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_reachy2 import Reachy2RobotConfig
from .robot_reachy2 import (
REACHY2_ANTENNAS_JOINTS,
REACHY2_L_ARM_JOINTS,
REACHY2_NECK_JOINTS,
REACHY2_R_ARM_JOINTS,
REACHY2_VEL,
Reachy2Robot,
)
@@ -0,0 +1,107 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from lerobot.cameras.configs import ColorMode
from lerobot.cameras.reachy2_camera import Reachy2CameraConfig
from ..config import RobotConfig
@RobotConfig.register_subclass("reachy2")
@dataclass
class Reachy2RobotConfig(RobotConfig):
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors.
max_relative_target: float | None = None
# IP address of the Reachy 2 robot
ip_address: str | None = "localhost"
# If True, turn_off_smoothly() will be sent to the robot before disconnecting.
disable_torque_on_disconnect: bool = False
# Tag for external commands control
# Set to True if you use an external commands system to control the robot,
# such as the official teleoperation application: https://github.com/pollen-robotics/Reachy2Teleoperation
# If True, robot.send_action() will not send commands to the robot.
use_external_commands: bool = False
# Robot parts
# Set to False to not add the corresponding joints part to the robot list of joints.
# By default, all parts are set to True.
with_mobile_base: bool = True
with_l_arm: bool = True
with_r_arm: bool = True
with_neck: bool = True
with_antennas: bool = True
# Robot cameras
# Set to True if you want to use the corresponding cameras in the observations.
# By default, only the teleop cameras are used.
with_left_teleop_camera: bool = True
with_right_teleop_camera: bool = True
with_torso_camera: bool = False
cameras: dict[str, CameraConfig] = field(default_factory=dict)
def __post_init__(self) -> None:
# Add cameras with same ip_address as the robot
if self.with_left_teleop_camera:
self.cameras["teleop_left"] = Reachy2CameraConfig(
name="teleop",
image_type="left",
ip_address=self.ip_address,
fps=15,
width=640,
height=480,
color_mode=ColorMode.RGB,
)
if self.with_right_teleop_camera:
self.cameras["teleop_right"] = Reachy2CameraConfig(
name="teleop",
image_type="right",
ip_address=self.ip_address,
fps=15,
width=640,
height=480,
color_mode=ColorMode.RGB,
)
if self.with_torso_camera:
self.cameras["torso_rgb"] = Reachy2CameraConfig(
name="depth",
image_type="rgb",
ip_address=self.ip_address,
fps=15,
width=640,
height=480,
color_mode=ColorMode.RGB,
)
super().__post_init__()
if not (
self.with_mobile_base
or self.with_l_arm
or self.with_r_arm
or self.with_neck
or self.with_antennas
):
raise ValueError(
"No Reachy2Robot part used.\n"
"At least one part of the robot must be set to True "
"(with_mobile_base, with_l_arm, with_r_arm, with_neck, with_antennas)"
)
+230
View File
@@ -0,0 +1,230 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from typing import Any
import numpy as np
from reachy2_sdk import ReachySDK
from lerobot.cameras.utils import make_cameras_from_configs
from ..robot import Robot
from ..utils import ensure_safe_goal_position
from .configuration_reachy2 import Reachy2RobotConfig
# {lerobot_keys: reachy2_sdk_keys}
REACHY2_NECK_JOINTS = {
"neck_yaw.pos": "head.neck.yaw",
"neck_pitch.pos": "head.neck.pitch",
"neck_roll.pos": "head.neck.roll",
}
REACHY2_ANTENNAS_JOINTS = {
"l_antenna.pos": "head.l_antenna",
"r_antenna.pos": "head.r_antenna",
}
REACHY2_R_ARM_JOINTS = {
"r_shoulder_pitch.pos": "r_arm.shoulder.pitch",
"r_shoulder_roll.pos": "r_arm.shoulder.roll",
"r_elbow_yaw.pos": "r_arm.elbow.yaw",
"r_elbow_pitch.pos": "r_arm.elbow.pitch",
"r_wrist_roll.pos": "r_arm.wrist.roll",
"r_wrist_pitch.pos": "r_arm.wrist.pitch",
"r_wrist_yaw.pos": "r_arm.wrist.yaw",
"r_gripper.pos": "r_arm.gripper",
}
REACHY2_L_ARM_JOINTS = {
"l_shoulder_pitch.pos": "l_arm.shoulder.pitch",
"l_shoulder_roll.pos": "l_arm.shoulder.roll",
"l_elbow_yaw.pos": "l_arm.elbow.yaw",
"l_elbow_pitch.pos": "l_arm.elbow.pitch",
"l_wrist_roll.pos": "l_arm.wrist.roll",
"l_wrist_pitch.pos": "l_arm.wrist.pitch",
"l_wrist_yaw.pos": "l_arm.wrist.yaw",
"l_gripper.pos": "l_arm.gripper",
}
REACHY2_VEL = {
"mobile_base.vx": "vx",
"mobile_base.vy": "vy",
"mobile_base.vtheta": "vtheta",
}
class Reachy2Robot(Robot):
"""
[Reachy 2](https://www.pollen-robotics.com/reachy/), by Pollen Robotics.
"""
config_class = Reachy2RobotConfig
name = "reachy2"
def __init__(self, config: Reachy2RobotConfig):
super().__init__(config)
self.config = config
self.robot_type = self.config.type
self.use_external_commands = self.config.use_external_commands
self.reachy: None | ReachySDK = None
self.cameras = make_cameras_from_configs(config.cameras)
self.logs: dict[str, float] = {}
self.joints_dict: dict[str, str] = self._generate_joints_dict()
@property
def observation_features(self) -> dict[str, Any]:
return {**self.motors_features, **self.camera_features}
@property
def action_features(self) -> dict[str, type]:
return self.motors_features
@property
def camera_features(self) -> dict[str, tuple[int | None, int | None, int]]:
return {cam: (self.cameras[cam].height, self.cameras[cam].width, 3) for cam in self.cameras}
@property
def motors_features(self) -> dict[str, type]:
if self.config.with_mobile_base:
return {
**dict.fromkeys(
self.joints_dict.keys(),
float,
),
**dict.fromkeys(
REACHY2_VEL.keys(),
float,
),
}
else:
return dict.fromkeys(self.joints_dict.keys(), float)
@property
def is_connected(self) -> bool:
return self.reachy.is_connected() if self.reachy is not None else False
def connect(self, calibrate: bool = False) -> None:
self.reachy = ReachySDK(self.config.ip_address)
if not self.is_connected:
raise ConnectionError()
for cam in self.cameras.values():
cam.connect()
self.configure()
def configure(self) -> None:
if self.reachy is not None:
self.reachy.turn_on()
self.reachy.reset_default_limits()
@property
def is_calibrated(self) -> bool:
return True
def calibrate(self) -> None:
pass
def _generate_joints_dict(self) -> dict[str, str]:
joints = {}
if self.config.with_neck:
joints.update(REACHY2_NECK_JOINTS)
if self.config.with_l_arm:
joints.update(REACHY2_L_ARM_JOINTS)
if self.config.with_r_arm:
joints.update(REACHY2_R_ARM_JOINTS)
if self.config.with_antennas:
joints.update(REACHY2_ANTENNAS_JOINTS)
return joints
def _get_state(self) -> dict[str, float]:
if self.reachy is not None:
pos_dict = {k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items()}
if not self.config.with_mobile_base:
return pos_dict
vel_dict = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()}
return {**pos_dict, **vel_dict}
else:
return {}
def get_observation(self) -> dict[str, np.ndarray]:
obs_dict: dict[str, Any] = {}
# Read Reachy 2 state
before_read_t = time.perf_counter()
obs_dict.update(self._get_state())
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
# Capture images from cameras
for cam_key, cam in self.cameras.items():
obs_dict[cam_key] = cam.async_read()
return obs_dict
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
if self.reachy is not None:
if not self.is_connected:
raise ConnectionError()
before_write_t = time.perf_counter()
vel = {}
goal_pos = {}
for key, val in action.items():
if key not in self.joints_dict:
if key not in REACHY2_VEL:
raise KeyError(f"Key '{key}' is not a valid motor key in Reachy 2.")
else:
vel[REACHY2_VEL[key]] = float(val)
else:
if not self.use_external_commands and self.config.max_relative_target is not None:
goal_pos[key] = float(val)
goal_present_pos = {
key: (
goal_pos[key],
self.reachy.joints[self.joints_dict[key]].present_position,
)
}
safe_goal_pos = ensure_safe_goal_position(
goal_present_pos, float(self.config.max_relative_target)
)
val = safe_goal_pos[key]
self.reachy.joints[self.joints_dict[key]].goal_position = float(val)
if self.config.with_mobile_base:
self.reachy.mobile_base.set_goal_speed(vel["vx"], vel["vy"], vel["vtheta"])
# We don't send the goal positions if we control Reachy 2 externally
if not self.use_external_commands:
self.reachy.send_goal_positions()
if self.config.with_mobile_base:
self.reachy.mobile_base.send_speed_command()
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
return action
def disconnect(self) -> None:
if self.reachy is not None:
for cam in self.cameras.values():
cam.disconnect()
if self.config.disable_torque_on_disconnect:
self.reachy.turn_off_smoothly()
self.reachy.disconnect()
@@ -30,9 +30,9 @@ class SO100FollowerConfig(RobotConfig):
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
# names to the max_relative_target value for that motor.
max_relative_target: float | dict[str, float] | None = None
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)
@@ -161,6 +161,11 @@ class SO100Follower(Robot):
self.bus.write("I_Coefficient", motor, 0)
self.bus.write("D_Coefficient", motor, 32)
if motor == "gripper":
self.bus.write("Max_Torque_Limit", motor, 500) # 50% of max torque to avoid burnout
self.bus.write("Protection_Current", motor, 250) # 50% of max current to avoid burnout
self.bus.write("Overload_Torque", motor, 25) # 25% torque when overloaded
def setup_motors(self) -> None:
for motor in reversed(self.bus.motors):
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
@@ -30,9 +30,9 @@ class SO101FollowerConfig(RobotConfig):
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
# names to the max_relative_target value for that motor.
max_relative_target: float | dict[str, float] | None = None
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)
@@ -157,6 +157,13 @@ class SO101Follower(Robot):
self.bus.write("I_Coefficient", motor, 0)
self.bus.write("D_Coefficient", motor, 32)
if motor == "gripper":
self.bus.write(
"Max_Torque_Limit", motor, 500
) # 50% of the max torque limit to avoid burnout
self.bus.write("Protection_Current", motor, 250) # 50% of max current to avoid burnout
self.bus.write("Overload_Torque", motor, 25) # 25% torque when overloaded
def setup_motors(self) -> None:
for motor in reversed(self.bus.motors):
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
@@ -24,11 +24,6 @@ from ..config import RobotConfig
@RobotConfig.register_subclass("stretch3")
@dataclass
class Stretch3RobotConfig(RobotConfig):
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# cameras
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
+5 -1
View File
@@ -61,6 +61,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
from .bi_so100_follower import BiSO100Follower
return BiSO100Follower(config)
elif config.type == "reachy2":
from .reachy2 import Reachy2Robot
return Reachy2Robot(config)
elif config.type == "mock_robot":
from tests.mocks.mock_robot import MockRobot
@@ -70,7 +74,7 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
def ensure_safe_goal_position(
goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[float]
goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[str, float]
) -> dict[str, float]:
"""Caps relative action target magnitude for safety."""
+3 -3
View File
@@ -28,15 +28,15 @@ class ViperXConfig(RobotConfig):
# /!\ FOR SAFETY, READ THIS /!\
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
# names to the max_relative_target value for that motor.
# For Aloha, for every goal position request, motor rotations are capped at 5 degrees by default.
# When you feel more confident with teleoperation or running the policy, you can extend
# this safety limit and even removing it by setting it to `null`.
# Also, everything is expected to work safely out-of-the-box, but we highly advise to
# first try to teleoperate the grippers only (by commenting out the rest of the motors in this yaml),
# then to gradually add more motors (by uncommenting), until you can teleoperate both arms fully
max_relative_target: int | None = 5
max_relative_target: float | dict[str, float] = 5.0
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)
@@ -302,11 +302,6 @@ class RobotClient:
self.logger.debug(f"Current latest action: {latest_action}")
# Get queue state before changes
old_size, old_timesteps = self._inspect_action_queue()
if not old_timesteps:
old_timesteps = [latest_action] # queue was empty
# Get queue state before changes
old_size, old_timesteps = self._inspect_action_queue()
if not old_timesteps:
@@ -88,6 +88,7 @@ class KochLeader(Teleoperator):
return self.bus.is_calibrated
def calibrate(self) -> None:
self.bus.disable_torque()
if self.calibration:
# Calibration file exists, ask user whether to use it or run new calibration
user_input = input(
@@ -98,7 +99,6 @@ class KochLeader(Teleoperator):
self.bus.write_calibration(self.calibration)
return
logger.info(f"\nRunning calibration of {self}")
self.bus.disable_torque()
for motor in self.bus.motors:
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
@@ -0,0 +1,25 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_reachy2_teleoperator import Reachy2TeleoperatorConfig
from .reachy2_teleoperator import (
REACHY2_ANTENNAS_JOINTS,
REACHY2_L_ARM_JOINTS,
REACHY2_NECK_JOINTS,
REACHY2_R_ARM_JOINTS,
REACHY2_VEL,
Reachy2Teleoperator,
)
@@ -0,0 +1,51 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..config import TeleoperatorConfig
@TeleoperatorConfig.register_subclass("reachy2_teleoperator")
@dataclass
class Reachy2TeleoperatorConfig(TeleoperatorConfig):
# IP address of the Reachy 2 robot used as teleoperator
ip_address: str | None = "localhost"
# Whether to use the present position of the joints as actions
# if False, the goal position of the joints will be used
use_present_position: bool = False
# Which parts of the robot to use
with_mobile_base: bool = True
with_l_arm: bool = True
with_r_arm: bool = True
with_neck: bool = True
with_antennas: bool = True
def __post_init__(self):
if not (
self.with_mobile_base
or self.with_l_arm
or self.with_r_arm
or self.with_neck
or self.with_antennas
):
raise ValueError(
"No Reachy2Teleoperator part used.\n"
"At least one part of the robot must be set to True "
"(with_mobile_base, with_l_arm, with_r_arm, with_neck, with_antennas)"
)
@@ -0,0 +1,164 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from reachy2_sdk import ReachySDK
from ..teleoperator import Teleoperator
from .config_reachy2_teleoperator import Reachy2TeleoperatorConfig
logger = logging.getLogger(__name__)
# {lerobot_keys: reachy2_sdk_keys}
REACHY2_NECK_JOINTS = {
"neck_yaw.pos": "head.neck.yaw",
"neck_pitch.pos": "head.neck.pitch",
"neck_roll.pos": "head.neck.roll",
}
REACHY2_ANTENNAS_JOINTS = {
"l_antenna.pos": "head.l_antenna",
"r_antenna.pos": "head.r_antenna",
}
REACHY2_R_ARM_JOINTS = {
"r_shoulder_pitch.pos": "r_arm.shoulder.pitch",
"r_shoulder_roll.pos": "r_arm.shoulder.roll",
"r_elbow_yaw.pos": "r_arm.elbow.yaw",
"r_elbow_pitch.pos": "r_arm.elbow.pitch",
"r_wrist_roll.pos": "r_arm.wrist.roll",
"r_wrist_pitch.pos": "r_arm.wrist.pitch",
"r_wrist_yaw.pos": "r_arm.wrist.yaw",
"r_gripper.pos": "r_arm.gripper",
}
REACHY2_L_ARM_JOINTS = {
"l_shoulder_pitch.pos": "l_arm.shoulder.pitch",
"l_shoulder_roll.pos": "l_arm.shoulder.roll",
"l_elbow_yaw.pos": "l_arm.elbow.yaw",
"l_elbow_pitch.pos": "l_arm.elbow.pitch",
"l_wrist_roll.pos": "l_arm.wrist.roll",
"l_wrist_pitch.pos": "l_arm.wrist.pitch",
"l_wrist_yaw.pos": "l_arm.wrist.yaw",
"l_gripper.pos": "l_arm.gripper",
}
REACHY2_VEL = {
"mobile_base.vx": "vx",
"mobile_base.vy": "vy",
"mobile_base.vtheta": "vtheta",
}
class Reachy2Teleoperator(Teleoperator):
"""
[Reachy 2](https://www.pollen-robotics.com/reachy/), by Pollen Robotics.
"""
config_class = Reachy2TeleoperatorConfig
name = "reachy2_specific"
def __init__(self, config: Reachy2TeleoperatorConfig):
super().__init__(config)
self.config = config
self.reachy: None | ReachySDK = None
self.joints_dict: dict[str, str] = self._generate_joints_dict()
def _generate_joints_dict(self) -> dict[str, str]:
joints = {}
if self.config.with_neck:
joints.update(REACHY2_NECK_JOINTS)
if self.config.with_l_arm:
joints.update(REACHY2_L_ARM_JOINTS)
if self.config.with_r_arm:
joints.update(REACHY2_R_ARM_JOINTS)
if self.config.with_antennas:
joints.update(REACHY2_ANTENNAS_JOINTS)
return joints
@property
def action_features(self) -> dict[str, type]:
if self.config.with_mobile_base:
return {
**dict.fromkeys(
self.joints_dict.keys(),
float,
),
**dict.fromkeys(
REACHY2_VEL.keys(),
float,
),
}
else:
return dict.fromkeys(self.joints_dict.keys(), float)
@property
def feedback_features(self) -> dict[str, type]:
return {}
@property
def is_connected(self) -> bool:
return self.reachy.is_connected() if self.reachy is not None else False
def connect(self, calibrate: bool = True) -> None:
self.reachy = ReachySDK(self.config.ip_address)
if not self.is_connected:
raise ConnectionError()
logger.info(f"{self} connected.")
@property
def is_calibrated(self) -> bool:
return True
def calibrate(self) -> None:
pass
def configure(self) -> None:
pass
def get_action(self) -> dict[str, float]:
start = time.perf_counter()
if self.reachy and self.is_connected:
if self.config.use_present_position:
joint_action = {
k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items()
}
else:
joint_action = {k: self.reachy.joints[v].goal_position for k, v in self.joints_dict.items()}
if not self.config.with_mobile_base:
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
return joint_action
if self.config.use_present_position:
vel_action = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()}
else:
vel_action = {k: self.reachy.mobile_base.last_cmd_vel[v] for k, v in REACHY2_VEL.items()}
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
return {**joint_action, **vel_action}
def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError
def disconnect(self) -> None:
if self.reachy and self.is_connected:
self.reachy.disconnect()
+4
View File
@@ -65,5 +65,9 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
from .bi_so100_leader import BiSO100Leader
return BiSO100Leader(config)
elif config.type == "reachy2_teleoperator":
from .reachy2_teleoperator import Reachy2Teleoperator
return Reachy2Teleoperator(config)
else:
raise ValueError(config.type)
+177
View File
@@ -0,0 +1,177 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from lerobot.cameras.reachy2_camera import Reachy2Camera, Reachy2CameraConfig
from lerobot.errors import DeviceNotConnectedError
PARAMS = [
("teleop", "left"),
("teleop", "right"),
("depth", "rgb"),
# ("depth", "depth"), # Depth camera is not available yet
]
def _make_cam_manager_mock():
c = MagicMock(name="CameraManagerMock")
teleop = MagicMock(name="TeleopCam")
teleop.width = 640
teleop.height = 480
teleop.get_frame = MagicMock(
side_effect=lambda *_, **__: (
np.zeros((480, 640, 3), dtype=np.uint8),
time.time(),
)
)
depth = MagicMock(name="DepthCam")
depth.width = 640
depth.height = 480
depth.get_frame = MagicMock(
side_effect=lambda *_, **__: (
np.zeros((480, 640, 3), dtype=np.uint8),
time.time(),
)
)
c.is_connected.return_value = True
c.teleop = teleop
c.depth = depth
def _connect():
c.teleop = teleop
c.depth = depth
c.is_connected.return_value = True
def _disconnect():
c.teleop = None
c.depth = None
c.is_connected.return_value = False
c.connect = MagicMock(side_effect=_connect)
c.disconnect = MagicMock(side_effect=_disconnect)
# Mock methods
c.initialize_cameras = MagicMock()
return c
@pytest.fixture(
params=PARAMS,
# ids=["teleop-left", "teleop-right", "torso-rgb", "torso-depth"],
ids=["teleop-left", "teleop-right", "torso-rgb"],
)
def camera(request):
name, image_type = request.param
with (
patch(
"lerobot.cameras.reachy2_camera.reachy2_camera.CameraManager",
side_effect=lambda *a, **k: _make_cam_manager_mock(),
),
):
config = Reachy2CameraConfig(name=name, image_type=image_type)
cam = Reachy2Camera(config)
yield cam
if cam.is_connected:
cam.disconnect()
def test_connect(camera):
camera.connect()
assert camera.is_connected
camera.cam_manager.initialize_cameras.assert_called_once()
def test_read(camera):
camera.connect()
img = camera.read()
if camera.config.name == "teleop":
camera.cam_manager.teleop.get_frame.assert_called_once()
elif camera.config.name == "depth":
camera.cam_manager.depth.get_frame.assert_called_once()
assert isinstance(img, np.ndarray)
assert img.shape == (480, 640, 3)
def test_disconnect(camera):
camera.connect()
camera.disconnect()
assert not camera.is_connected
def test_async_read(camera):
camera.connect()
try:
img = camera.async_read()
assert camera.thread is not None
assert camera.thread.is_alive()
assert isinstance(img, np.ndarray)
finally:
if camera.is_connected:
camera.disconnect()
def test_async_read_timeout(camera):
camera.connect()
try:
with pytest.raises(TimeoutError):
camera.async_read(timeout_ms=0)
finally:
if camera.is_connected:
camera.disconnect()
def test_read_before_connect(camera):
with pytest.raises(DeviceNotConnectedError):
_ = camera.read()
def test_disconnect_before_connect(camera):
with pytest.raises(DeviceNotConnectedError):
camera.disconnect()
def test_async_read_before_connect(camera):
with pytest.raises(DeviceNotConnectedError):
_ = camera.async_read()
def test_wrong_camera_name():
with pytest.raises(ValueError):
_ = Reachy2CameraConfig(name="wrong-name", image_type="left")
def test_wrong_image_type():
with pytest.raises(ValueError):
_ = Reachy2CameraConfig(name="teleop", image_type="rgb")
with pytest.raises(ValueError):
_ = Reachy2CameraConfig(name="depth", image_type="left")
def test_wrong_color_mode():
with pytest.raises(ValueError):
_ = Reachy2CameraConfig(name="teleop", image_type="left", color_mode="wrong-color")
+1
View File
@@ -28,6 +28,7 @@ pytest_plugins = [
"tests.fixtures.files",
"tests.fixtures.hub",
"tests.fixtures.optimizers",
"tests.plugins.reachy2_sdk",
]
+30
View File
@@ -0,0 +1,30 @@
import sys
import types
from unittest.mock import MagicMock
def _install_reachy2_sdk_stub():
sdk = types.ModuleType("reachy2_sdk")
sdk.__path__ = []
sdk.ReachySDK = MagicMock(name="ReachySDK")
media = types.ModuleType("reachy2_sdk.media")
media.__path__ = []
camera = types.ModuleType("reachy2_sdk.media.camera")
camera.CameraView = MagicMock(name="CameraView")
camera_manager = types.ModuleType("reachy2_sdk.media.camera_manager")
camera_manager.CameraManager = MagicMock(name="CameraManager")
sdk.media = media
media.camera = camera
media.camera_manager = camera_manager
# Register in sys.modules
sys.modules.setdefault("reachy2_sdk", sdk)
sys.modules.setdefault("reachy2_sdk.media", media)
sys.modules.setdefault("reachy2_sdk.media.camera", camera)
sys.modules.setdefault("reachy2_sdk.media.camera_manager", camera_manager)
def pytest_sessionstart(session):
_install_reachy2_sdk_stub()
+326
View File
@@ -0,0 +1,326 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from lerobot.robots.reachy2 import (
REACHY2_ANTENNAS_JOINTS,
REACHY2_L_ARM_JOINTS,
REACHY2_NECK_JOINTS,
REACHY2_R_ARM_JOINTS,
REACHY2_VEL,
Reachy2Robot,
Reachy2RobotConfig,
)
# {lerobot_keys: reachy2_sdk_keys}
REACHY2_JOINTS = {
**REACHY2_NECK_JOINTS,
**REACHY2_ANTENNAS_JOINTS,
**REACHY2_R_ARM_JOINTS,
**REACHY2_L_ARM_JOINTS,
}
PARAMS = [
{}, # default config
{"with_mobile_base": False},
{"with_mobile_base": False, "with_l_arm": False, "with_antennas": False},
{"with_r_arm": False, "with_neck": False, "with_antennas": False},
{"use_external_commands": True, "disable_torque_on_disconnect": True},
{"use_external_commands": True, "with_mobile_base": False, "with_neck": False},
{"disable_torque_on_disconnect": False},
{"max_relative_target": 5},
{"with_right_teleop_camera": False},
{"with_left_teleop_camera": False, "with_right_teleop_camera": False},
{"with_left_teleop_camera": False, "with_torso_camera": True},
]
def _make_reachy2_sdk_mock():
class JointSpy:
__slots__ = (
"present_position",
"_goal_position",
"_on_set",
)
def __init__(self, present_position=0.0, on_set=None):
self.present_position = present_position
self._goal_position = present_position
self._on_set = on_set
@property
def goal_position(self):
return self._goal_position
@goal_position.setter
def goal_position(self, v):
self._goal_position = v
if self._on_set:
self._on_set()
r = MagicMock(name="ReachySDKMock")
r.is_connected.return_value = True
def _connect():
r.is_connected.return_value = True
def _disconnect():
r.is_connected.return_value = False
# Global counter of goal_position sets
r._goal_position_set_total = 0
def _on_any_goal_set():
r._goal_position_set_total += 1
# Mock joints with some dummy positions
joints = {
k: JointSpy(
present_position=float(i),
on_set=_on_any_goal_set,
)
for i, k in enumerate(REACHY2_JOINTS.values())
}
r.joints = joints
# Mock mobile base with some dummy odometry
r.mobile_base = MagicMock()
r.mobile_base.odometry = {
"x": 0.1,
"y": -0.2,
"theta": 21.3,
"vx": 0.001,
"vy": 0.002,
"vtheta": 0.0,
}
r.connect = MagicMock(side_effect=_connect)
r.disconnect = MagicMock(side_effect=_disconnect)
# Mock methods
r.turn_on = MagicMock()
r.reset_default_limits = MagicMock()
r.send_goal_positions = MagicMock()
r.turn_off_smoothly = MagicMock()
r.mobile_base.set_goal_speed = MagicMock()
r.mobile_base.send_speed_command = MagicMock()
return r
def _make_reachy2_camera_mock(*args, **kwargs):
cfg = args[0] if args else kwargs.get("config")
name = getattr(cfg, "name", kwargs.get("name", "cam"))
image_type = getattr(cfg, "image_type", kwargs.get("image_type", "cam"))
width = getattr(cfg, "width", kwargs.get("width", 640))
height = getattr(cfg, "height", kwargs.get("height", 480))
cam = MagicMock(name=f"Reachy2CameraMock:{name}")
cam.name = name
cam.image_type = image_type
cam.width = width
cam.height = height
cam.connect = MagicMock()
cam.disconnect = MagicMock()
cam.async_read = MagicMock(side_effect=lambda: np.zeros((height, width, 3), dtype=np.uint8))
return cam
@pytest.fixture(params=PARAMS, ids=lambda p: "default" if not p else ",".join(p.keys()))
def reachy2(request):
with (
patch(
"lerobot.robots.reachy2.robot_reachy2.ReachySDK",
side_effect=lambda *a, **k: _make_reachy2_sdk_mock(),
),
patch(
"lerobot.cameras.reachy2_camera.reachy2_camera.Reachy2Camera",
side_effect=_make_reachy2_camera_mock,
),
):
overrides = request.param
cfg = Reachy2RobotConfig(ip_address="192.168.0.200", **overrides)
robot = Reachy2Robot(cfg)
yield robot
if robot.is_connected:
robot.disconnect()
def test_connect_disconnect(reachy2):
assert not reachy2.is_connected
reachy2.connect()
assert reachy2.is_connected
reachy2.reachy.turn_on.assert_called_once()
reachy2.reachy.reset_default_limits.assert_called_once()
reachy2.disconnect()
assert not reachy2.is_connected
if reachy2.config.disable_torque_on_disconnect:
reachy2.reachy.turn_off_smoothly.assert_called_once()
else:
reachy2.reachy.turn_off_smoothly.assert_not_called()
reachy2.reachy.disconnect.assert_called_once()
def test_get_joints_dict(reachy2):
reachy2.connect()
if reachy2.config.with_neck:
assert "neck_yaw.pos" in reachy2.joints_dict
assert "neck_pitch.pos" in reachy2.joints_dict
assert "neck_roll.pos" in reachy2.joints_dict
else:
assert "neck_yaw.pos" not in reachy2.joints_dict
assert "neck_pitch.pos" not in reachy2.joints_dict
assert "neck_roll.pos" not in reachy2.joints_dict
if reachy2.config.with_antennas:
assert "l_antenna.pos" in reachy2.joints_dict
assert "r_antenna.pos" in reachy2.joints_dict
else:
assert "l_antenna.pos" not in reachy2.joints_dict
assert "r_antenna.pos" not in reachy2.joints_dict
if reachy2.config.with_r_arm:
assert "r_shoulder_pitch.pos" in reachy2.joints_dict
assert "r_shoulder_roll.pos" in reachy2.joints_dict
assert "r_elbow_yaw.pos" in reachy2.joints_dict
assert "r_elbow_pitch.pos" in reachy2.joints_dict
assert "r_wrist_roll.pos" in reachy2.joints_dict
assert "r_wrist_pitch.pos" in reachy2.joints_dict
assert "r_wrist_yaw.pos" in reachy2.joints_dict
assert "r_gripper.pos" in reachy2.joints_dict
else:
assert "r_shoulder_pitch.pos" not in reachy2.joints_dict
assert "r_shoulder_roll.pos" not in reachy2.joints_dict
assert "r_elbow_yaw.pos" not in reachy2.joints_dict
assert "r_elbow_pitch.pos" not in reachy2.joints_dict
assert "r_wrist_roll.pos" not in reachy2.joints_dict
assert "r_wrist_pitch.pos" not in reachy2.joints_dict
assert "r_wrist_yaw.pos" not in reachy2.joints_dict
assert "r_gripper.pos" not in reachy2.joints_dict
if reachy2.config.with_l_arm:
assert "l_shoulder_pitch.pos" in reachy2.joints_dict
assert "l_shoulder_roll.pos" in reachy2.joints_dict
assert "l_elbow_yaw.pos" in reachy2.joints_dict
assert "l_elbow_pitch.pos" in reachy2.joints_dict
assert "l_wrist_roll.pos" in reachy2.joints_dict
assert "l_wrist_pitch.pos" in reachy2.joints_dict
assert "l_wrist_yaw.pos" in reachy2.joints_dict
assert "l_gripper.pos" in reachy2.joints_dict
else:
assert "l_shoulder_pitch.pos" not in reachy2.joints_dict
assert "l_shoulder_roll.pos" not in reachy2.joints_dict
assert "l_elbow_yaw.pos" not in reachy2.joints_dict
assert "l_elbow_pitch.pos" not in reachy2.joints_dict
assert "l_wrist_roll.pos" not in reachy2.joints_dict
assert "l_wrist_pitch.pos" not in reachy2.joints_dict
assert "l_wrist_yaw.pos" not in reachy2.joints_dict
assert "l_gripper.pos" not in reachy2.joints_dict
def test_get_observation(reachy2):
reachy2.connect()
obs = reachy2.get_observation()
expected_keys = set(reachy2.joints_dict)
expected_keys.update(f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base)
expected_keys.update(reachy2.cameras.keys())
assert set(obs.keys()) == expected_keys
for motor in reachy2.joints_dict.keys():
assert obs[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position
if reachy2.config.with_mobile_base:
for vel in REACHY2_VEL.keys():
assert obs[vel] == reachy2.reachy.mobile_base.odometry[REACHY2_VEL[vel]]
if reachy2.config.with_left_teleop_camera:
assert obs["teleop_left"].shape == (
reachy2.config.cameras["teleop_left"].height,
reachy2.config.cameras["teleop_left"].width,
3,
)
if reachy2.config.with_right_teleop_camera:
assert obs["teleop_right"].shape == (
reachy2.config.cameras["teleop_right"].height,
reachy2.config.cameras["teleop_right"].width,
3,
)
if reachy2.config.with_torso_camera:
assert obs["torso_rgb"].shape == (
reachy2.config.cameras["torso_rgb"].height,
reachy2.config.cameras["torso_rgb"].width,
3,
)
def test_send_action(reachy2):
reachy2.connect()
action = {k: i * 10.0 for i, k in enumerate(reachy2.joints_dict.keys(), start=1)}
if reachy2.config.with_mobile_base:
action.update({k: i * 0.1 for i, k in enumerate(REACHY2_VEL.keys(), start=1)})
previous_present_position = {
k: reachy2.reachy.joints[REACHY2_JOINTS[k]].present_position for k in reachy2.joints_dict.keys()
}
returned = reachy2.send_action(action)
if reachy2.config.max_relative_target is None:
assert returned == action
assert reachy2.reachy._goal_position_set_total == len(reachy2.joints_dict)
for motor in reachy2.joints_dict.keys():
expected_pos = action[motor]
real_pos = reachy2.reachy.joints[REACHY2_JOINTS[motor]].goal_position
if reachy2.config.max_relative_target is None:
assert real_pos == expected_pos
else:
assert real_pos == previous_present_position[motor] + np.sign(expected_pos) * min(
abs(expected_pos - real_pos), reachy2.config.max_relative_target
)
if reachy2.config.with_mobile_base:
goal_speed = [i * 0.1 for i, _ in enumerate(REACHY2_VEL.keys(), start=1)]
reachy2.reachy.mobile_base.set_goal_speed.assert_called_once_with(*goal_speed)
if reachy2.config.use_external_commands:
reachy2.reachy.send_goal_positions.assert_not_called()
if reachy2.config.with_mobile_base:
reachy2.reachy.mobile_base.send_speed_command.assert_not_called()
else:
reachy2.reachy.send_goal_positions.assert_called_once()
if reachy2.config.with_mobile_base:
reachy2.reachy.mobile_base.send_speed_command.assert_called_once()
def test_no_part_declared():
with pytest.raises(ValueError):
_ = Reachy2RobotConfig(
ip_address="192.168.0.200",
with_mobile_base=False,
with_l_arm=False,
with_r_arm=False,
with_neck=False,
with_antennas=False,
)
@@ -0,0 +1,150 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock, patch
import pytest
from lerobot.teleoperators.reachy2_teleoperator import (
REACHY2_ANTENNAS_JOINTS,
REACHY2_L_ARM_JOINTS,
REACHY2_NECK_JOINTS,
REACHY2_R_ARM_JOINTS,
REACHY2_VEL,
Reachy2Teleoperator,
Reachy2TeleoperatorConfig,
)
# {lerobot_keys: reachy2_sdk_keys}
REACHY2_JOINTS = {
**REACHY2_NECK_JOINTS,
**REACHY2_ANTENNAS_JOINTS,
**REACHY2_R_ARM_JOINTS,
**REACHY2_L_ARM_JOINTS,
}
PARAMS = [
{}, # default config
{"with_mobile_base": False},
{"with_mobile_base": False, "with_l_arm": False, "with_antennas": False},
{"with_r_arm": False, "with_neck": False, "with_antennas": False},
{"with_mobile_base": False, "with_neck": False},
{"use_present_position": True},
]
def _make_reachy2_sdk_mock():
r = MagicMock(name="ReachySDKMock")
r.is_connected.return_value = True
def _connect():
r.is_connected.return_value = True
def _disconnect():
r.is_connected.return_value = False
# Mock joints with some dummy positions
joints = {
k: MagicMock(
present_position=float(i),
goal_position=float(i) + 0.5,
)
for i, k in enumerate(REACHY2_JOINTS.values())
}
r.joints = joints
# Mock mobile base with some dummy odometry
r.mobile_base = MagicMock()
r.mobile_base.last_cmd_vel = {
"vx": -0.2,
"vy": 0.2,
"vtheta": 11.0,
}
r.mobile_base.odometry = {
"x": 1.0,
"y": 2.0,
"theta": 20.0,
"vx": 0.1,
"vy": -0.1,
"vtheta": 8.0,
}
r.connect = MagicMock(side_effect=_connect)
r.disconnect = MagicMock(side_effect=_disconnect)
return r
@pytest.fixture(params=PARAMS, ids=lambda p: "default" if not p else ",".join(p.keys()))
def reachy2(request):
with (
patch(
"lerobot.teleoperators.reachy2_teleoperator.reachy2_teleoperator.ReachySDK",
side_effect=lambda *a, **k: _make_reachy2_sdk_mock(),
),
):
overrides = request.param
cfg = Reachy2TeleoperatorConfig(ip_address="192.168.0.200", **overrides)
robot = Reachy2Teleoperator(cfg)
yield robot
if robot.is_connected:
robot.disconnect()
def test_connect_disconnect(reachy2):
assert not reachy2.is_connected
reachy2.connect()
assert reachy2.is_connected
reachy2.disconnect()
assert not reachy2.is_connected
reachy2.reachy.disconnect.assert_called_once()
def test_get_action(reachy2):
reachy2.connect()
action = reachy2.get_action()
expected_keys = set(reachy2.joints_dict)
expected_keys.update(f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base)
assert set(action.keys()) == expected_keys
for motor in reachy2.joints_dict.keys():
if reachy2.config.use_present_position:
assert action[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position
else:
assert action[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].goal_position
if reachy2.config.with_mobile_base:
if reachy2.config.use_present_position:
for vel in REACHY2_VEL.keys():
assert action[vel] == reachy2.reachy.mobile_base.odometry[REACHY2_VEL[vel]]
else:
for vel in REACHY2_VEL.keys():
assert action[vel] == reachy2.reachy.mobile_base.last_cmd_vel[REACHY2_VEL[vel]]
def test_no_part_declared():
with pytest.raises(ValueError):
_ = Reachy2TeleoperatorConfig(
ip_address="192.168.0.200",
with_mobile_base=False,
with_l_arm=False,
with_r_arm=False,
with_neck=False,
with_antennas=False,
)