mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 22:59:50 +00:00
Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c75df5c3b9 | |||
| e2740fe555 | |||
| d602e8169c | |||
| 49baccdccb | |||
| 6a3d57031a | |||
| d74494d92b | |||
| 882c80d446 | |||
| 61b0eeae4b | |||
| 577cd10974 | |||
| b0923ab74b | |||
| 7f70b78f32 | |||
| 55198de096 | |||
| 0878c6880f | |||
| 11e6bd762a | |||
| ce3b9f627e | |||
| c66cd40176 |
@@ -30,7 +30,7 @@ pytest -sx tests/test_stuff.py::test_something
|
||||
```
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train --some.option=true
|
||||
lerobot-train --some.option=true
|
||||
```
|
||||
|
||||
## SECTION TO REMOVE BEFORE SUBMITTING YOUR PR
|
||||
|
||||
@@ -29,8 +29,8 @@ on:
|
||||
env:
|
||||
UV_VERSION: "0.8.0"
|
||||
PYTHON_VERSION: "3.10"
|
||||
DOCKER_IMAGE_NAME_CPU: huggingface/lerobot-gpu:latest
|
||||
DOCKER_IMAGE_NAME_GPU: huggingface/lerobot-cpu:latest
|
||||
DOCKER_IMAGE_NAME_CPU: huggingface/lerobot-cpu:latest
|
||||
DOCKER_IMAGE_NAME_GPU: huggingface/lerobot-gpu:latest
|
||||
|
||||
# Ensures that only the latest commit is built, canceling older runs.
|
||||
concurrency:
|
||||
|
||||
@@ -44,7 +44,7 @@ test-end-to-end:
|
||||
${MAKE} DEVICE=$(DEVICE) test-smolvla-ete-eval
|
||||
|
||||
test-act-ete-train:
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=act \
|
||||
--policy.dim_model=64 \
|
||||
--policy.n_action_steps=20 \
|
||||
@@ -68,12 +68,12 @@ test-act-ete-train:
|
||||
--output_dir=tests/outputs/act/
|
||||
|
||||
test-act-ete-train-resume:
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--config_path=tests/outputs/act/checkpoints/000002/pretrained_model/train_config.json \
|
||||
--resume=true
|
||||
|
||||
test-act-ete-eval:
|
||||
python -m lerobot.scripts.eval \
|
||||
lerobot-eval \
|
||||
--policy.path=tests/outputs/act/checkpoints/000004/pretrained_model \
|
||||
--policy.device=$(DEVICE) \
|
||||
--env.type=aloha \
|
||||
@@ -82,7 +82,7 @@ test-act-ete-eval:
|
||||
--eval.batch_size=1
|
||||
|
||||
test-diffusion-ete-train:
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=diffusion \
|
||||
--policy.down_dims='[64,128,256]' \
|
||||
--policy.diffusion_step_embed_dim=32 \
|
||||
@@ -106,7 +106,7 @@ test-diffusion-ete-train:
|
||||
--output_dir=tests/outputs/diffusion/
|
||||
|
||||
test-diffusion-ete-eval:
|
||||
python -m lerobot.scripts.eval \
|
||||
lerobot-eval \
|
||||
--policy.path=tests/outputs/diffusion/checkpoints/000002/pretrained_model \
|
||||
--policy.device=$(DEVICE) \
|
||||
--env.type=pusht \
|
||||
@@ -115,7 +115,7 @@ test-diffusion-ete-eval:
|
||||
--eval.batch_size=1
|
||||
|
||||
test-tdmpc-ete-train:
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=tdmpc \
|
||||
--policy.device=$(DEVICE) \
|
||||
--policy.push_to_hub=false \
|
||||
@@ -137,7 +137,7 @@ test-tdmpc-ete-train:
|
||||
--output_dir=tests/outputs/tdmpc/
|
||||
|
||||
test-tdmpc-ete-eval:
|
||||
python -m lerobot.scripts.eval \
|
||||
lerobot-eval \
|
||||
--policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
|
||||
--policy.device=$(DEVICE) \
|
||||
--env.type=xarm \
|
||||
@@ -148,7 +148,7 @@ test-tdmpc-ete-eval:
|
||||
|
||||
|
||||
test-smolvla-ete-train:
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=smolvla \
|
||||
--policy.n_action_steps=20 \
|
||||
--policy.chunk_size=20 \
|
||||
@@ -171,7 +171,7 @@ test-smolvla-ete-train:
|
||||
--output_dir=tests/outputs/smolvla/
|
||||
|
||||
test-smolvla-ete-eval:
|
||||
python -m lerobot.scripts.eval \
|
||||
lerobot-eval \
|
||||
--policy.path=tests/outputs/smolvla/checkpoints/000004/pretrained_model \
|
||||
--policy.device=$(DEVICE) \
|
||||
--env.type=aloha \
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://github.com/huggingface/lerobot/actions/workflows/nighty.yml?query=branch%3Amain)
|
||||
[](https://github.com/huggingface/lerobot/actions/workflows/nightly.yml?query=branch%3Amain)
|
||||
[](https://www.python.org/downloads/)
|
||||
[](https://github.com/huggingface/lerobot/blob/main/LICENSE)
|
||||
[](https://pypi.org/project/lerobot/)
|
||||
@@ -276,7 +276,7 @@ Check out [example 2](https://github.com/huggingface/lerobot/blob/main/examples/
|
||||
We also provide a more capable script to parallelize the evaluation over multiple environments during the same rollout. Here is an example with a pretrained model hosted on [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht):
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.eval \
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/diffusion_pusht \
|
||||
--env.type=pusht \
|
||||
--eval.batch_size=10 \
|
||||
@@ -288,10 +288,10 @@ python -m lerobot.scripts.eval \
|
||||
Note: After training your own policy, you can re-evaluate the checkpoints with:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.eval --policy.path={OUTPUT_DIR}/checkpoints/last/pretrained_model
|
||||
lerobot-eval --policy.path={OUTPUT_DIR}/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
See `python -m lerobot.scripts.eval --help` for more instructions.
|
||||
See `lerobot-eval --help` for more instructions.
|
||||
|
||||
### Train your own policy
|
||||
|
||||
@@ -303,7 +303,7 @@ A link to the wandb logs for the run will also show up in yellow in your termina
|
||||
|
||||
\<img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/wandb.png" alt="WandB logs example"\>
|
||||
|
||||
Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `--eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python -m lerobot.scripts.eval --help` for more instructions.
|
||||
Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `--eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `lerobot-eval --help` for more instructions.
|
||||
|
||||
#### Reproduce state-of-the-art (SOTA)
|
||||
|
||||
@@ -311,7 +311,7 @@ We provide some pretrained policies on our [hub page](https://huggingface.co/ler
|
||||
You can reproduce their training by loading the config from their run. Simply running:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train --config_path=lerobot/diffusion_pusht
|
||||
lerobot-train --config_path=lerobot/diffusion_pusht
|
||||
```
|
||||
|
||||
reproduces SOTA results for Diffusion Policy on the PushT task.
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
#!/usr/bin/env python
|
||||
"""Simple script to check buffer naming in the transformed model."""
|
||||
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
# Load the model with strict=False to see what buffers we have
|
||||
print("Loading model...")
|
||||
policy = PI0Policy.from_pretrained("pepijn223/pi0_libero_lerobot", strict=False)
|
||||
|
||||
# Check what buffer keys exist
|
||||
state_dict = policy.state_dict()
|
||||
buffer_keys = [k for k in state_dict.keys() if "buffer" in k]
|
||||
normalize_keys = [k for k in state_dict.keys() if "normalize" in k]
|
||||
|
||||
print("\nAll buffer keys:")
|
||||
for key in buffer_keys:
|
||||
print(f" {key}")
|
||||
|
||||
print("\nAll normalize keys:")
|
||||
for key in normalize_keys:
|
||||
print(f" {key}")
|
||||
|
||||
print("\nAll keys (first 20):")
|
||||
for i, key in enumerate(state_dict.keys()):
|
||||
if i < 20:
|
||||
print(f" {key}")
|
||||
@@ -29,7 +29,7 @@ ENV DEBIAN_FRONTEND=noninteractive \
|
||||
|
||||
# Install system dependencies and uv (as root)
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential git curl libglib2.0-0 libegl1-mesa ffmpeg \
|
||||
build-essential git curl libglib2.0-0 libegl1-mesa-dev ffmpeg \
|
||||
libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \
|
||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
|
||||
&& mv /root/.local/bin/uv /usr/local/bin/uv \
|
||||
|
||||
@@ -35,10 +35,14 @@
|
||||
title: Koch v1.1
|
||||
- local: lekiwi
|
||||
title: LeKiwi
|
||||
- local: reachy2
|
||||
title: Reachy 2
|
||||
title: "Robots"
|
||||
- sections:
|
||||
- local: notebooks
|
||||
title: Notebooks
|
||||
- local: feetech
|
||||
title: Updating Feetech Firmware
|
||||
title: "Resources"
|
||||
- sections:
|
||||
- local: contributing
|
||||
|
||||
@@ -9,7 +9,7 @@ To instantiate a camera, you need a camera identifier. This identifier might cha
|
||||
To find the camera indices of the cameras plugged into your system, run the following script:
|
||||
|
||||
```bash
|
||||
python -m lerobot.find_cameras opencv # or realsense for Intel Realsense cameras
|
||||
lerobot-find-cameras opencv # or realsense for Intel Realsense cameras
|
||||
```
|
||||
|
||||
The output will look something like this if you have two cameras connected:
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
# Feetech Motor Firmware Update
|
||||
|
||||
This tutorial guides you through updating the firmware of Feetech motors using the official Feetech software.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Windows computer (Feetech software is only available for Windows)
|
||||
- Feetech motor control board
|
||||
- USB cable to connect the control board to your computer
|
||||
- Feetech motors connected to the control board
|
||||
|
||||
## Step 1: Download Feetech Software
|
||||
|
||||
1. Visit the official Feetech software download page: [https://www.feetechrc.com/software.html](https://www.feetechrc.com/software.html)
|
||||
2. Download the latest version of the Feetech debugging software (FD)
|
||||
3. Install the software on your Windows computer
|
||||
|
||||
## Step 2: Hardware Setup
|
||||
|
||||
1. Connect your Feetech motors to the motor control board
|
||||
2. Connect the motor control board to your Windows computer via USB cable
|
||||
3. Ensure power is supplied to the motors
|
||||
|
||||
## Step 3: Configure Connection
|
||||
|
||||
1. Launch the Feetech debugging software
|
||||
2. Select the correct COM port from the port dropdown menu
|
||||
- If unsure which port to use, check Windows Device Manager under "Ports (COM & LPT)"
|
||||
3. Set the appropriate baud rate (typically 1000000 for most Feetech motors)
|
||||
4. Click "Open" to establish communication with the control board
|
||||
|
||||
## Step 4: Scan for Motors
|
||||
|
||||
1. Once connected, click the "Search" button to detect all connected motors
|
||||
2. The software will automatically discover and list all motors on the bus
|
||||
3. Each motor will appear with its ID number
|
||||
|
||||
## Step 5: Update Firmware
|
||||
|
||||
For each motor you want to update:
|
||||
|
||||
1. **Select the motor** from the list by clicking on it
|
||||
2. **Click on Upgrade tab**:
|
||||
3. **Click on Online button**:
|
||||
- If an potential firmware update is found, it will be displayed in the box
|
||||
4. **Click on Upgrade button**:
|
||||
- The update progress will be displayed
|
||||
|
||||
## Step 6: Verify Update
|
||||
|
||||
1. After the update completes, the software should automatically refresh the motor information
|
||||
2. Verify that the firmware version has been updated to the expected version
|
||||
|
||||
## Important Notes
|
||||
|
||||
⚠️ **Warning**: Do not disconnect power or USB during firmware updates, it will potentially brick the motor.
|
||||
|
||||
## Bonus: Motor Debugging on Linux/macOS
|
||||
|
||||
For debugging purposes only, you can use the open-source Feetech Debug Tool:
|
||||
|
||||
- **Repository**: [FT_SCServo_Debug_Qt](https://github.com/CarolinePascal/FT_SCServo_Debug_Qt/tree/fix/port-search-timer)
|
||||
|
||||
### Installation Instructions
|
||||
|
||||
Follow the instructions in the repository to install the tool, for Ubuntu you can directly install it, for MacOS you need to build it from source.
|
||||
|
||||
**Limitations:**
|
||||
|
||||
- This tool is for debugging and parameter adjustment only
|
||||
- Firmware updates must still be done on Windows with official Feetech software
|
||||
@@ -412,7 +412,7 @@ Example configuration for training the [reward classifier](https://huggingface.c
|
||||
To train the classifier, use the `train.py` script with your configuration:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train --config_path path/to/reward_classifier_train_config.json
|
||||
lerobot-train --config_path path/to/reward_classifier_train_config.json
|
||||
```
|
||||
|
||||
**Deploying and Testing the Model**
|
||||
@@ -458,7 +458,7 @@ The reward classifier will automatically provide rewards based on the visual inp
|
||||
3. **Train the classifier**:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train --config_path src/lerobot/configs/reward_classifier_train_config.json
|
||||
lerobot-train --config_path src/lerobot/configs/reward_classifier_train_config.json
|
||||
```
|
||||
|
||||
4. **Test the classifier**:
|
||||
|
||||
+11
-11
@@ -19,7 +19,7 @@ pip install -e ".[hopejr]"
|
||||
Before starting calibration and operation, you need to identify the USB ports for each HopeJR component. Run this script to find the USB ports for the arm, hand, glove, and exoskeleton:
|
||||
|
||||
```bash
|
||||
python -m lerobot.find_port
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
This will display the available USB ports and their associated devices. Make note of the port paths (e.g., `/dev/tty.usbmodem58760433331`, `/dev/tty.usbmodem11301`) as you'll need to specify them in the `--robot.port` and `--teleop.port` parameters when recording data, replaying episodes, or running teleoperation scripts.
|
||||
@@ -31,7 +31,7 @@ Before performing teleoperation, HopeJR's limbs need to be calibrated. Calibrati
|
||||
### 1.1 Calibrate Robot Hand
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--robot.type=hope_jr_hand \
|
||||
--robot.port=/dev/tty.usbmodem58760432281 \
|
||||
--robot.id=blue \
|
||||
@@ -81,7 +81,7 @@ Once you have set the appropriate boundaries for all joints, click "Save" to sav
|
||||
### 1.2 Calibrate Teleoperator Glove
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--teleop.type=homunculus_glove \
|
||||
--teleop.port=/dev/tty.usbmodem11201 \
|
||||
--teleop.id=red \
|
||||
@@ -120,7 +120,7 @@ Once calibration is complete, the system will save the calibration to `/Users/yo
|
||||
### 1.3 Calibrate Robot Arm
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--robot.type=hope_jr_arm \
|
||||
--robot.port=/dev/tty.usbserial-1110 \
|
||||
--robot.id=white
|
||||
@@ -146,7 +146,7 @@ Use the calibration interface to set the range boundaries for each joint. Move e
|
||||
### 1.4 Calibrate Teleoperator Exoskeleton
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--teleop.type=homunculus_arm \
|
||||
--teleop.port=/dev/tty.usbmodem11201 \
|
||||
--teleop.id=black
|
||||
@@ -178,7 +178,7 @@ Due to global variable conflicts in the Feetech middleware, teleoperation for ar
|
||||
### Hand
|
||||
|
||||
```bash
|
||||
python -m lerobot.teleoperate \
|
||||
lerobot-teleoperate \
|
||||
--robot.type=hope_jr_hand \
|
||||
--robot.port=/dev/tty.usbmodem58760432281 \
|
||||
--robot.id=blue \
|
||||
@@ -194,7 +194,7 @@ python -m lerobot.teleoperate \
|
||||
### Arm
|
||||
|
||||
```bash
|
||||
python -m lerobot.teleoperate \
|
||||
lerobot-teleoperate \
|
||||
--robot.type=hope_jr_arm \
|
||||
--robot.port=/dev/tty.usbserial-1110 \
|
||||
--robot.id=white \
|
||||
@@ -214,7 +214,7 @@ Record, Replay and Train with Hope-JR is still experimental.
|
||||
This step records the dataset, which can be seen as an example [here](https://huggingface.co/datasets/nepyope/hand_record_test_with_video_data/settings).
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
lerobot-record \
|
||||
--robot.type=hope_jr_hand \
|
||||
--robot.port=/dev/tty.usbmodem58760432281 \
|
||||
--robot.id=right \
|
||||
@@ -236,7 +236,7 @@ python -m lerobot.record \
|
||||
### Replay
|
||||
|
||||
```bash
|
||||
python -m lerobot.replay \
|
||||
lerobot-replay \
|
||||
--robot.type=hope_jr_hand \
|
||||
--robot.port=/dev/tty.usbmodem58760432281 \
|
||||
--robot.id=right \
|
||||
@@ -248,7 +248,7 @@ python -m lerobot.replay \
|
||||
### Train
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/hopejr_hand \
|
||||
@@ -263,7 +263,7 @@ python -m lerobot.scripts.train \
|
||||
This training run can be viewed as an example [here](https://wandb.ai/tino/lerobot/runs/rp0k8zvw?nw=nwusertino).
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
lerobot-record \
|
||||
--robot.type=hope_jr_hand \
|
||||
--robot.port=/dev/tty.usbmodem58760432281 \
|
||||
--robot.id=right \
|
||||
|
||||
@@ -45,7 +45,7 @@ Note that the `id` associated with a robot is used to store the calibration file
|
||||
<hfoptions id="teleoperate_so101">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.teleoperate \
|
||||
lerobot-teleoperate \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
@@ -101,7 +101,7 @@ With `rerun`, you can teleoperate again while simultaneously visualizing the cam
|
||||
<hfoptions id="teleoperate_koch_camera">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.teleoperate \
|
||||
lerobot-teleoperate \
|
||||
--robot.type=koch_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
@@ -174,7 +174,7 @@ Now you can record a dataset. To record 5 episodes and upload your dataset to th
|
||||
<hfoptions id="record">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
lerobot-record \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
@@ -376,7 +376,7 @@ You can replay the first episode on your robot with either the command below or
|
||||
<hfoptions id="replay">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.replay \
|
||||
lerobot-replay \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
@@ -428,10 +428,10 @@ Your robot should replicate movements similar to those you recorded. For example
|
||||
|
||||
## Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/so101_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_so101_test \
|
||||
@@ -453,7 +453,7 @@ Training should take several hours. You will find checkpoints in `outputs/train/
|
||||
To resume training from a checkpoint, below is an example command to resume from `last` checkpoint of the `act_so101_test` policy:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--config_path=outputs/train/act_so101_test/checkpoints/last/pretrained_model/train_config.json \
|
||||
--resume=true
|
||||
```
|
||||
@@ -490,7 +490,7 @@ You can use the `record` script from [`lerobot/record.py`](https://github.com/hu
|
||||
<hfoptions id="eval">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
lerobot-record \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM1 \
|
||||
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \
|
||||
|
||||
@@ -96,10 +96,10 @@ If you uploaded your dataset to the hub you can [visualize your dataset online](
|
||||
|
||||
## Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/il_gym \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/il_sim_test \
|
||||
|
||||
@@ -31,7 +31,7 @@ pip install -e ".[dynamixel]"
|
||||
To find the port for each bus servo adapter, run this script:
|
||||
|
||||
```bash
|
||||
python -m lerobot.find_port
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
<hfoptions id="example">
|
||||
@@ -98,7 +98,7 @@ For a visual reference on how to set the motor ids please refer to [this video](
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.setup_motors \
|
||||
lerobot-setup-motors \
|
||||
--robot.type=koch_follower \
|
||||
--robot.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step
|
||||
```
|
||||
@@ -174,7 +174,7 @@ Do the same steps for the leader arm but modify the command or script accordingl
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.setup_motors \
|
||||
lerobot-setup-motors \
|
||||
--teleop.type=koch_leader \
|
||||
--teleop.port=/dev/tty.usbmodem575E0031751 \ # <- paste here the port found at previous step
|
||||
```
|
||||
@@ -211,7 +211,7 @@ Run the following command or API example to calibrate the follower arm:
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--robot.type=koch_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
|
||||
--robot.id=my_awesome_follower_arm # <- Give the robot a unique name
|
||||
@@ -249,7 +249,7 @@ Do the same steps to calibrate the leader arm, run the following command or API
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--teleop.type=koch_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
|
||||
--teleop.id=my_awesome_leader_arm # <- Give the robot a unique name
|
||||
|
||||
@@ -60,7 +60,7 @@ First, we will assemble the two SO100/SO101 arms. One to attach to the mobile ba
|
||||
To find the port for each bus servo adapter, run this script:
|
||||
|
||||
```bash
|
||||
python -m lerobot.find_port
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
<hfoptions id="example">
|
||||
@@ -116,7 +116,7 @@ The instructions for configuring the motors can be found in the SO101 [docs](./s
|
||||
You can run this command to setup motors for LeKiwi. It will first setup the motors for arm (id 6..1) and then setup motors for wheels (9,8,7)
|
||||
|
||||
```bash
|
||||
python -m lerobot.setup_motors \
|
||||
lerobot-setup-motors \
|
||||
--robot.type=lekiwi \
|
||||
--robot.port=/dev/tty.usbmodem58760431551 # <- paste here the port found at previous step
|
||||
```
|
||||
@@ -174,7 +174,7 @@ The calibration process is very important because it allows a neural network tra
|
||||
Make sure the arm is connected to the Raspberry Pi and run this script or API example (on the Raspberry Pi via SSH) to launch calibration of the follower arm:
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--robot.type=lekiwi \
|
||||
--robot.id=my_awesome_kiwi # <- Give the robot a unique name
|
||||
```
|
||||
@@ -193,7 +193,7 @@ Then, to calibrate the leader arm (which is attached to the laptop/pc). Run the
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
|
||||
--teleop.id=my_awesome_leader_arm # <- Give the robot a unique name
|
||||
|
||||
@@ -0,0 +1,288 @@
|
||||
# Reachy 2
|
||||
|
||||
Reachy 2 is an open-source humanoid robot made by Pollen Robotics, specifically designed for the development of embodied AI and real-world applications.
|
||||
Check out [Pollen Robotics website](https://www.pollen-robotics.com/reachy/), or access [Reachy 2 documentation](https://docs.pollen-robotics.com/) for more information on the platform!
|
||||
|
||||
## Teleoperate Reachy 2
|
||||
|
||||
Currently, there are two ways to teleoperate Reachy 2:
|
||||
|
||||
- Pollen Robotics’ VR teleoperation (not included in LeRobot).
|
||||
- Robot-to-robot teleoperation (use one Reachy 2 to control another).
|
||||
|
||||
## Reachy 2 Simulation
|
||||
|
||||
**(Linux only)** You can run Reachy 2 in simulation (Gazebo or MuJoCo) using the provided [Docker image](https://hub.docker.com/r/pollenrobotics/reachy2_core).
|
||||
|
||||
1. Install [Docker Engine](https://docs.docker.com/engine/).
|
||||
2. Run (for MuJoCo):
|
||||
|
||||
```
|
||||
docker run --rm -it \
|
||||
--name reachy \
|
||||
--privileged \
|
||||
--network host \
|
||||
--ipc host \
|
||||
--device-cgroup-rule='c 189:* rwm' \
|
||||
--group-add audio \
|
||||
-e ROS_DOMAIN_ID="$ROS_DOMAIN_ID" \
|
||||
-e DISPLAY="$DISPLAY" \
|
||||
-e RCUTILS_CONSOLE_OUTPUT_FORMAT="[{severity}]: {message}" \
|
||||
-e REACHY2_CORE_SERVICE_FAKE="${REACHY2_CORE_SERVICE_FAKE:-true}" \
|
||||
-v /dev:/dev \
|
||||
-v "$HOME/.reachy_config":/home/reachy/.reachy_config_override \
|
||||
-v "$HOME/.reachy.log":/home/reachy/.ros/log \
|
||||
-v /usr/lib/x86_64-linux-gnu:/opt/host-libs \
|
||||
--entrypoint /package/launch.sh \
|
||||
pollenrobotics/reachy2_core:1.7.5.9_deploy \
|
||||
start_rviz:=true start_sdk_server:=true mujoco:=true
|
||||
```
|
||||
|
||||
> If MuJoCo runs slowly (low simulation frequency), append `-e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \` to the previous command to improve performance:
|
||||
>
|
||||
> ```
|
||||
> docker run --rm -it \
|
||||
> --name reachy \
|
||||
> --privileged \
|
||||
> --network host \
|
||||
> --ipc host \
|
||||
> --device-cgroup-rule='c 189:* rwm' \
|
||||
> --group-add audio \
|
||||
> -e ROS_DOMAIN_ID="$ROS_DOMAIN_ID" \
|
||||
> -e DISPLAY="$DISPLAY" \
|
||||
> -e RCUTILS_CONSOLE_OUTPUT_FORMAT="[{severity}]: {message}" \
|
||||
> -e REACHY2_CORE_SERVICE_FAKE="${REACHY2_CORE_SERVICE_FAKE:-true}" \
|
||||
> -e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \
|
||||
> -v /dev:/dev \
|
||||
> -v "$HOME/.reachy_config":/home/reachy/.reachy_config_override \
|
||||
> -v "$HOME/.reachy.log":/home/reachy/.ros/log \
|
||||
> -v /usr/lib/x86_64-linux-gnu:/opt/host-libs \
|
||||
> --entrypoint /package/launch.sh \
|
||||
> pollenrobotics/reachy2_core:1.7.5.9_deploy \
|
||||
> start_rviz:=true start_sdk_server:=true mujoco:=true
|
||||
> ```
|
||||
|
||||
## Setup
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- On your robot, check the **service images** meet the minimum versions:
|
||||
- **reachy2-core >= 1.7.5.2**
|
||||
- **webrtc >= 2.0.1.1**
|
||||
|
||||
Then, if you want to use VR teleoperation:
|
||||
|
||||
- Install the [Reachy 2 teleoperation application](https://docs.pollen-robotics.com/teleoperation/teleoperation-introduction/discover-teleoperation/).
|
||||
Use version **>=v1.2.0**
|
||||
|
||||
We recommend using two computers: one for teleoperation (Windows required) and another for recording with LeRobot.
|
||||
|
||||
### Install LeRobot
|
||||
|
||||
Follow the [installation instructions](https://github.com/huggingface/lerobot#installation) to install LeRobot.
|
||||
|
||||
Install LeRobot with Reachy 2 dependencies:
|
||||
|
||||
```bash
|
||||
pip install -e ".[reachy2]"
|
||||
```
|
||||
|
||||
### (Optional but recommended) Install pollen_data_acquisition_server
|
||||
|
||||
How you manage Reachy 2 recording sessions is up to you, but the **easiest** way is to use this server so you can control sessions directly from the VR teleoperation app.
|
||||
|
||||
> **Note:** Currently, only the VR teleoperation application works as a client for this server, so this step primarily targets teleoperation. You’re free to develop custom clients to manage sessions to your needs.
|
||||
|
||||
In your LeRobot environment, install the server from source:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/pollen-robotics/pollen_data_acquisition_server.git
|
||||
cd pollen_data_acquisition_server
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Find the [pollen_data_acquisition_server documentation here](https://github.com/pollen-robotics/pollen_data_acquisition_server).
|
||||
|
||||
## Step 1: Recording
|
||||
|
||||
### Get Reachy 2 IP address
|
||||
|
||||
Before starting teleoperation and data recording, find the [robot's IP address](https://docs.pollen-robotics.com/getting-started/setup-reachy2/connect-reachy2/).
|
||||
We strongly recommend connecting all devices (PC and robot) via **Ethernet**.
|
||||
|
||||
### Launch recording
|
||||
|
||||
There are two ways to manage recording sessions when using the Reachy 2 VR teleoperation application:
|
||||
|
||||
- **Using the data acquisition server (recommended for VR teleop)**: The VR app orchestrates sessions (via the server it tells LeRobot when to create datasets, start/stop episodes) while also controlling the robot’s motions.
|
||||
- **Using LeRobot’s record script**: LeRobot owns session control and decides when to start/stop episodes. If you also use the VR teleop app, it’s only for motion control.
|
||||
|
||||
### Option 1: Using Pollen data acquisition server (recommended for VR teleop)
|
||||
|
||||
Make sure you have installed pollen_data_acquisition_server, as explained in the Setup section.
|
||||
|
||||
Launch the data acquisition server to be able to manage your session directly from the teleoperation application:
|
||||
|
||||
```bash
|
||||
python -m pollen_data_acquisition_server.server
|
||||
```
|
||||
|
||||
Then get into the teleoperation application and choose "Data acquisition session".
|
||||
You can finally setup your session by following the screens displayed.
|
||||
|
||||
> Even without the VR app, you can use the `pollen_data_acquisition_server` with your own client implementation.
|
||||
|
||||
### Option 2: Using lerobot.record
|
||||
|
||||
Reachy 2 is fully supported by LeRobot’s recording features.
|
||||
If you choose this option but still want to use the VR teleoperation application, select "Standard session" in the app.
|
||||
|
||||
**Example: start a recording without the mobile base:**
|
||||
First add reachy2 and reachy2_teleoperator to the imports of the record script. Then you can use the following command:
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
--robot.type=reachy2 \
|
||||
--robot.ip_address=192.168.0.200 \
|
||||
--robot.id=r2-0000 \
|
||||
--robot.use_external_commands=true \
|
||||
--robot.with_mobile_base=false \
|
||||
--teleop.type=reachy2_teleoperator \
|
||||
--teleop.ip_address=192.168.0.200 \
|
||||
--teleop.with_mobile_base=false \
|
||||
--dataset.repo_id=pollen_robotics/record_test \
|
||||
--dataset.single_task="Reachy 2 recording test" \
|
||||
--dataset.num_episodes=1 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.fps=15 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.private=true \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
#### Specific Options
|
||||
|
||||
**Extended setup overview (all options included):**
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
--robot.type=reachy2 \
|
||||
--robot.ip_address=192.168.0.200 \
|
||||
--robot.use_external_commands=true \
|
||||
--robot.with_mobile_base=true \
|
||||
--robot.with_l_arm=true \
|
||||
--robot.with_r_arm=true \
|
||||
--robot.with_neck=true \
|
||||
--robot.with_antennas=true \
|
||||
--robot.with_left_teleop_camera=true \
|
||||
--robot.with_right_teleop_camera=true \
|
||||
--robot.with_torso_camera=false \
|
||||
--robot.disable_torque_on_disconnect=false \
|
||||
--robot.max_relative_target=5.0 \
|
||||
--teleop.type=reachy2_teleoperator \
|
||||
--teleop.ip_address=192.168.0.200 \
|
||||
--teleop.use_present_position=false \
|
||||
--teleop.with_mobile_base=false \
|
||||
--teleop.with_l_arm=true \
|
||||
--teleop.with_r_arm=true \
|
||||
--teleop.with_neck=true \
|
||||
--teleop.with_antennas=true \
|
||||
--dataset.repo_id=pollen_robotics/record_test \
|
||||
--dataset.single_task="Reachy 2 recording test" \
|
||||
--dataset.num_episodes=1 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.fps=15 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.private=true \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
##### `--robot.use_external_commands`
|
||||
|
||||
Determine whether LeRobot robot.send_action() sends commands to the robot.
|
||||
**Must** be set to false while using the VR teleoperation application, as the app already sends commands.
|
||||
|
||||
##### `--teleop.use_present_position`
|
||||
|
||||
Determine whether the teleoperator reads the goal or present position of the robot.
|
||||
Must be set to true if a compliant Reachy 2 is used to control another one.
|
||||
|
||||
##### Use the relevant parts
|
||||
|
||||
From our initial tests, recording **all** joints when only some are moving can reduce model quality with certain policies.
|
||||
To avoid this, you can exclude specific parts from recording and replay using:
|
||||
|
||||
````
|
||||
--robot.with_<part>=false
|
||||
```,
|
||||
with `<part>` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`.
|
||||
It determine whether the corresponding part is recorded in the observations. True if not set.
|
||||
|
||||
By default, **all parts are recorded**.
|
||||
|
||||
The same per-part mechanism is available in `reachy2_teleoperator` as well.
|
||||
|
||||
````
|
||||
|
||||
--teleop.with\_<part>
|
||||
|
||||
```
|
||||
with `<part>` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`.
|
||||
Determine whether the corresponding part is recorded in the actions. True if not set.
|
||||
|
||||
> **Important:** In a given session, the **enabled parts must match** on both the robot and the teleoperator.
|
||||
For example, if the robot runs with `--robot.with_mobile_base=false`, the teleoperator must disable the same part `--teleoperator.with_mobile_base=false`.
|
||||
|
||||
##### Use the relevant cameras
|
||||
|
||||
You can do the same for **cameras**. By default, only the **teleoperation cameras** are recorded (both `left_teleop_camera` and `right_teleop_camera`). Enable or disable each camera with:
|
||||
|
||||
```
|
||||
|
||||
--robot.with_left_teleop_camera=<true|false>
|
||||
--robot.with_right_teleop_camera=<true|false>
|
||||
--robot.with_torso_camera=<true|false>
|
||||
|
||||
````
|
||||
|
||||
|
||||
## Step 2: Replay
|
||||
|
||||
Make sure the robot is configured with the same parts as the dataset:
|
||||
|
||||
```bash
|
||||
python -m lerobot.replay \
|
||||
--robot.type=reachy2 \
|
||||
--robot.ip_address=192.168.0.200 \
|
||||
--robot.use_external_commands=false \
|
||||
--robot.with_mobile_base=false \
|
||||
--dataset.repo_id=pollen_robotics/record_test \
|
||||
--dataset.episode=0
|
||||
--display_data=true
|
||||
````
|
||||
|
||||
## Step 3: Train
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
--dataset.repo_id=pollen_robotics/record_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/reachy2_test \
|
||||
--job_name=reachy2 \
|
||||
--policy.device=mps \
|
||||
--wandb.enable=true \
|
||||
--policy.repo_id=pollen_robotics/record_test_policy
|
||||
```
|
||||
|
||||
## Step 4: Evaluate
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
--robot.type=reachy2 \
|
||||
--robot.ip_address=192.168.0.200 \
|
||||
--display_data=false \
|
||||
--dataset.repo_id=pollen_robotics/eval_record_test \
|
||||
--dataset.single_task="Evaluate reachy2 policy" \
|
||||
--dataset.num_episodes=10 \
|
||||
--policy.path=outputs/train/reachy2_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
@@ -54,7 +54,7 @@ If you don't have a gpu device, you can train using our notebook on [.
|
||||
|
||||
```bash
|
||||
cd lerobot && python -m lerobot.scripts.train \
|
||||
cd lerobot && lerobot-train \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--dataset.repo_id=${HF_USER}/mydataset \
|
||||
--batch_size=64 \
|
||||
@@ -73,7 +73,7 @@ cd lerobot && python -m lerobot.scripts.train \
|
||||
Fine-tuning is an art. For a complete overview of the options for finetuning, run
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train --help
|
||||
lerobot-train --help
|
||||
```
|
||||
|
||||
<p align="center">
|
||||
@@ -97,7 +97,7 @@ Similarly for when recording an episode, it is recommended that you are logged i
|
||||
Once you are logged in, you can run inference in your setup by doing:
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
lerobot-record \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \ # <- Use your port
|
||||
--robot.id=my_blue_follower_arm \ # <- Use your robot id
|
||||
|
||||
@@ -26,7 +26,7 @@ Unlike the SO-101, the motor connectors are not easily accessible once the arm i
|
||||
To find the port for each bus servo adapter, run this script:
|
||||
|
||||
```bash
|
||||
python -m lerobot.find_port
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
<hfoptions id="example">
|
||||
@@ -93,7 +93,7 @@ For a visual reference on how to set the motor ids please refer to [this video](
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.setup_motors \
|
||||
lerobot-setup-motors \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 # <- paste here the port found at previous step
|
||||
```
|
||||
@@ -168,7 +168,7 @@ Do the same steps for the leader arm.
|
||||
<hfoptions id="setup_motors">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.setup_motors \
|
||||
lerobot-setup-motors \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step
|
||||
```
|
||||
@@ -568,7 +568,7 @@ Run the following command or API example to calibrate the follower arm:
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
|
||||
--robot.id=my_awesome_follower_arm # <- Give the robot a unique name
|
||||
@@ -606,7 +606,7 @@ Do the same steps to calibrate the leader arm, run the following command or API
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
|
||||
--teleop.id=my_awesome_leader_arm # <- Give the robot a unique name
|
||||
|
||||
@@ -162,7 +162,7 @@ It is advisable to install one 3-pin cable in the motor after placing them befor
|
||||
To find the port for each bus servo adapter, connect MotorBus to your computer via USB and power. Run the following script and disconnect the MotorBus when prompted:
|
||||
|
||||
```bash
|
||||
python -m lerobot.find_port
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
<hfoptions id="example">
|
||||
@@ -240,7 +240,7 @@ Connect the usb cable from your computer and the power supply to the follower ar
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.setup_motors \
|
||||
lerobot-setup-motors \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 # <- paste here the port found at previous step
|
||||
```
|
||||
@@ -316,7 +316,7 @@ Do the same steps for the leader arm.
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.setup_motors \
|
||||
lerobot-setup-motors \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step
|
||||
```
|
||||
@@ -353,7 +353,7 @@ Run the following command or API example to calibrate the follower arm:
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
|
||||
--robot.id=my_awesome_follower_arm # <- Give the robot a unique name
|
||||
@@ -402,7 +402,7 @@ Do the same steps to calibrate the leader arm, run the following command or API
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
|
||||
--teleop.id=my_awesome_leader_arm # <- Give the robot a unique name
|
||||
|
||||
@@ -62,7 +62,7 @@ By default, every field takes its default value specified in the dataclass. If a
|
||||
Let's say that we want to train [Diffusion Policy](../src/lerobot/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=lerobot/pusht \
|
||||
--policy.type=diffusion \
|
||||
--env.type=pusht
|
||||
@@ -77,7 +77,7 @@ Let's break this down:
|
||||
Let's see another example. Let's say you've been training [ACT](../src/lerobot/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=act \
|
||||
--dataset.repo_id=lerobot/aloha_sim_insertion_human \
|
||||
--env.type=aloha \
|
||||
@@ -90,7 +90,7 @@ We now want to train a different policy for aloha on another task. We'll change
|
||||
Looking at the [`AlohaEnv`](../src/lerobot/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=act \
|
||||
--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||
--env.type=aloha \
|
||||
@@ -127,7 +127,7 @@ Now, let's assume that we want to reproduce the run just above. That run has pro
|
||||
We can then simply load the config values from this file using:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \
|
||||
--output_dir=outputs/train/act_aloha_transfer_2
|
||||
```
|
||||
@@ -137,7 +137,7 @@ python -m lerobot.scripts.train \
|
||||
Similarly to Hydra, we can still override some parameters in the CLI if we want to, e.g.:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \
|
||||
--output_dir=outputs/train/act_aloha_transfer_2
|
||||
--policy.n_action_steps=80
|
||||
@@ -148,7 +148,7 @@ python -m lerobot.scripts.train \
|
||||
`--config_path` can also accept the repo_id of a repo on the hub that contains a `train_config.json` file, e.g. running:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train --config_path=lerobot/diffusion_pusht
|
||||
lerobot-train --config_path=lerobot/diffusion_pusht
|
||||
```
|
||||
|
||||
will start a training run with the same configuration used for training [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht)
|
||||
@@ -160,7 +160,7 @@ Being able to resume a training run is important in case it crashed or aborted f
|
||||
Let's reuse the command from the previous run and add a few more options:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=act \
|
||||
--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||
--env.type=aloha \
|
||||
@@ -179,7 +179,7 @@ INFO 2025-01-24 16:10:56 ts/train.py:263 Checkpoint policy after step 100
|
||||
Now let's simulate a crash by killing the process (hit `ctrl`+`c`). We can then simply resume this run from the last checkpoint available with:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \
|
||||
--resume=true
|
||||
```
|
||||
@@ -190,7 +190,7 @@ Another reason for which you might want to resume a run is simply to extend trai
|
||||
You could double the number of steps of the previous run with:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \
|
||||
--resume=true \
|
||||
--steps=200000
|
||||
@@ -224,7 +224,7 @@ In addition to the features currently in Draccus, we've added a special `.path`
|
||||
For example, we could fine-tune a [policy pre-trained on the aloha transfer task](https://huggingface.co/lerobot/act_aloha_sim_transfer_cube_human) on the aloha insertion task. We can achieve this with:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/act_aloha_sim_transfer_cube_human \
|
||||
--dataset.repo_id=lerobot/aloha_sim_insertion_human \
|
||||
--env.type=aloha \
|
||||
@@ -270,7 +270,7 @@ We'll summarize here the main use cases to remember from this tutorial.
|
||||
#### Train a policy from scratch – CLI
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=act \ # <- select 'act' policy
|
||||
--env.type=pusht \ # <- select 'pusht' environment
|
||||
--dataset.repo_id=lerobot/pusht # <- train on this dataset
|
||||
@@ -279,7 +279,7 @@ python -m lerobot.scripts.train \
|
||||
#### Train a policy from scratch - config file + CLI
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--config_path=path/to/pretrained_model \ # <- can also be a repo_id
|
||||
--policy.n_action_steps=80 # <- you may still override values
|
||||
```
|
||||
@@ -287,7 +287,7 @@ python -m lerobot.scripts.train \
|
||||
#### Resume/continue a training run
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--config_path=checkpoint/pretrained_model/ \
|
||||
--resume=true \
|
||||
--steps=200000 # <- you can change some training parameters
|
||||
@@ -296,7 +296,7 @@ python -m lerobot.scripts.train \
|
||||
#### Fine-tuning
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/act_aloha_sim_transfer_cube_human \ # <- can also be a local path to a checkpoint
|
||||
--dataset.repo_id=lerobot/aloha_sim_insertion_human \
|
||||
--env.type=aloha \
|
||||
|
||||
@@ -18,7 +18,7 @@ Replays the actions of an episode from a dataset on a robot.
|
||||
Example:
|
||||
|
||||
```shell
|
||||
python -m lerobot.replay \
|
||||
lerobot-replay \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=black \
|
||||
|
||||
+347
@@ -0,0 +1,347 @@
|
||||
#!/usr/bin/env python
|
||||
"""Script for Pi0 pretrained policy inference and Hub upload."""
|
||||
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
# Set seed
|
||||
torch.manual_seed(42)
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments."""
|
||||
parser = argparse.ArgumentParser(description="Pi0 policy inference and Hub upload")
|
||||
parser.add_argument(
|
||||
"--source-model-id",
|
||||
type=str,
|
||||
default="pepijn223/pi0_libero_lerobot",
|
||||
help="Source model repository ID on Hugging Face Hub",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-id", type=str, default="pepijn223/libero", help="Dataset repository ID on Hugging Face Hub"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-model-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Output model repository ID to upload to (e.g., 'your-username/pi0-libero-fixed')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="cpu", choices=["cpu", "cuda", "mps"], help="Device to run inference on"
|
||||
)
|
||||
parser.add_argument("--episode", type=int, default=0, help="Episode index to load from dataset")
|
||||
parser.add_argument(
|
||||
"--sample-idx", type=int, default=10, help="Sample index within episode to use for inference"
|
||||
)
|
||||
parser.add_argument("--private", action="store_true", help="Make the uploaded model private")
|
||||
parser.add_argument(
|
||||
"--commit-message", type=str, default=None, help="Custom commit message for the upload"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _inject_normalization_stats(policy: PI0Policy, dataset_meta: LeRobotDatasetMetadata, key_mapping: dict):
|
||||
"""Recreate normalization layers with proper stats from the dataset."""
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
|
||||
# Convert numpy stats to the format expected by normalization layers and remap keys
|
||||
stats = {}
|
||||
for dataset_key, stat_dict in dataset_meta.stats.items():
|
||||
# Use mapped key if available, otherwise use original key
|
||||
policy_key = key_mapping.get(dataset_key, dataset_key)
|
||||
|
||||
stats[policy_key] = {
|
||||
stat_type: torch.from_numpy(stat_array) if isinstance(stat_array, np.ndarray) else stat_array
|
||||
for stat_type, stat_array in stat_dict.items()
|
||||
}
|
||||
|
||||
print(f"Available stats keys: {list(stats.keys())}")
|
||||
print(
|
||||
f"Policy expects keys: input={list(policy.config.input_features.keys())}, output={list(policy.config.output_features.keys())}"
|
||||
)
|
||||
|
||||
# Recreate normalization layers with proper stats
|
||||
normalize_inputs = Normalize(policy.config.input_features, policy.config.normalization_mapping, stats)
|
||||
|
||||
normalize_targets = Normalize(policy.config.output_features, policy.config.normalization_mapping, stats)
|
||||
|
||||
unnormalize_outputs = Unnormalize(
|
||||
policy.config.output_features, policy.config.normalization_mapping, stats
|
||||
)
|
||||
|
||||
# Replace the normalization layers on the policy
|
||||
policy.normalize_inputs = normalize_inputs
|
||||
policy.normalize_targets = normalize_targets
|
||||
policy.unnormalize_outputs = unnormalize_outputs
|
||||
|
||||
print("Normalization layers recreated with dataset stats.")
|
||||
|
||||
|
||||
def configure_policy_features(policy: PI0Policy, dataset: LeRobotDataset):
|
||||
"""Configure policy input and output features based on dataset metadata."""
|
||||
print(f"Dataset features: {list(dataset.meta.features.keys())}")
|
||||
|
||||
# Create a proper mapping from dataset keys to policy keys
|
||||
dataset_to_policy_mapping = {}
|
||||
|
||||
# Handle images
|
||||
if "image" in dataset.meta.features:
|
||||
dataset_to_policy_mapping["image"] = "observation.images.image"
|
||||
if "wrist_image" in dataset.meta.features:
|
||||
dataset_to_policy_mapping["wrist_image"] = "observation.images.image2"
|
||||
|
||||
# Handle state
|
||||
if "state" in dataset.meta.features:
|
||||
dataset_to_policy_mapping["state"] = "observation.state"
|
||||
|
||||
# Handle actions
|
||||
if "actions" in dataset.meta.features:
|
||||
dataset_to_policy_mapping["actions"] = "action"
|
||||
|
||||
print(f"Key mapping: {dataset_to_policy_mapping}")
|
||||
|
||||
# Clear existing input features and reconfigure with proper mapping
|
||||
policy.config.input_features = {}
|
||||
policy.config.output_features = {}
|
||||
|
||||
# Map visual features
|
||||
for dataset_key, policy_key in dataset_to_policy_mapping.items():
|
||||
if dataset_key in ["image", "wrist_image"]:
|
||||
feature_info = dataset.meta.features[dataset_key]
|
||||
# Convert HWC to CHW format and resize
|
||||
shape = (3, 224, 224) # Pi0 expects CHW format
|
||||
policy.config.input_features[policy_key] = PolicyFeature(type=FeatureType.VISUAL, shape=shape)
|
||||
|
||||
# Map state features
|
||||
for dataset_key, policy_key in dataset_to_policy_mapping.items():
|
||||
if dataset_key == "state":
|
||||
feature_info = dataset.meta.features[dataset_key]
|
||||
shape = tuple(feature_info["shape"])
|
||||
policy.config.input_features[policy_key] = PolicyFeature(type=FeatureType.STATE, shape=shape)
|
||||
|
||||
# Map action features
|
||||
for dataset_key, policy_key in dataset_to_policy_mapping.items():
|
||||
if dataset_key == "actions":
|
||||
feature_info = dataset.meta.features[dataset_key]
|
||||
shape = tuple(feature_info["shape"])
|
||||
policy.config.output_features[policy_key] = PolicyFeature(type=FeatureType.ACTION, shape=shape)
|
||||
|
||||
print(f"Policy input_features: {list(policy.config.input_features.keys())}")
|
||||
print(f"Policy output_features: {list(policy.config.output_features.keys())}")
|
||||
print(f"Policy image_features: {list(policy.config.image_features.keys())}")
|
||||
print(f"Policy action_feature: {policy.config.action_feature}")
|
||||
|
||||
return dataset_to_policy_mapping
|
||||
|
||||
|
||||
def fix_buffer_naming(policy: PI0Policy):
|
||||
"""Fix buffer naming issues in the loaded policy state dict."""
|
||||
print("Fixing normalization buffer naming issues...")
|
||||
|
||||
state_dict = policy.state_dict()
|
||||
corrected_state_dict = {}
|
||||
fixes_applied = 0
|
||||
|
||||
for key, value in state_dict.items():
|
||||
new_key = key
|
||||
|
||||
# Fix buffer naming: buffer_observation_state_mean -> buffer_observation_state.mean
|
||||
if "buffer_observation_state_mean" in key:
|
||||
new_key = key.replace("buffer_observation_state_mean", "buffer_observation_state.mean")
|
||||
fixes_applied += 1
|
||||
print(f" Fixed: {key} -> {new_key}")
|
||||
elif "buffer_observation_state_std" in key:
|
||||
new_key = key.replace("buffer_observation_state_std", "buffer_observation_state.std")
|
||||
fixes_applied += 1
|
||||
print(f" Fixed: {key} -> {new_key}")
|
||||
# Remove image buffers that aren't expected (they cause conflicts)
|
||||
elif "buffer_observation_image_mean" in key or "buffer_observation_image_std" in key:
|
||||
print(f" Removed unexpected buffer: {key}")
|
||||
continue # Skip this buffer
|
||||
|
||||
corrected_state_dict[new_key] = value
|
||||
|
||||
# Add missing action buffers with dummy values (will be replaced by dataset stats)
|
||||
missing_buffers = [
|
||||
"normalize_targets.buffer_action.mean",
|
||||
"normalize_targets.buffer_action.std",
|
||||
"unnormalize_outputs.buffer_action.mean",
|
||||
"unnormalize_outputs.buffer_action.std",
|
||||
]
|
||||
|
||||
for buffer_key in missing_buffers:
|
||||
if buffer_key not in corrected_state_dict:
|
||||
# Use dummy values - these will be overwritten by proper dataset stats later
|
||||
if "mean" in buffer_key:
|
||||
corrected_state_dict[buffer_key] = torch.zeros(8) # Assume 8-dim action
|
||||
else: # std
|
||||
corrected_state_dict[buffer_key] = torch.ones(8) # Assume 8-dim action
|
||||
fixes_applied += 1
|
||||
print(f" Added missing buffer: {buffer_key}")
|
||||
|
||||
print(f"Applied {fixes_applied} buffer fixes")
|
||||
|
||||
# Load the corrected state dict back into the policy
|
||||
policy.load_state_dict(corrected_state_dict)
|
||||
return policy
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to run the Pi0 inference and upload."""
|
||||
args = parse_args()
|
||||
|
||||
# Load pretrained Pi0 model directly from Hugging Face Hub
|
||||
print(f"Loading pretrained Pi0 model from {args.source_model_id}...")
|
||||
|
||||
# Load with strict=False to allow missing/unexpected keys, then fix them manually
|
||||
policy = PI0Policy.from_pretrained(args.source_model_id, strict=False)
|
||||
policy = fix_buffer_naming(policy)
|
||||
policy.eval()
|
||||
policy.to(args.device)
|
||||
|
||||
# Load dataset and get a sample
|
||||
print(f"Loading dataset: {args.dataset_id}")
|
||||
dataset = LeRobotDataset(args.dataset_id, episodes=[args.episode])
|
||||
meta: LeRobotDatasetMetadata = dataset.meta
|
||||
sample = dataset[args.sample_idx]
|
||||
|
||||
# Configure policy features
|
||||
key_mapping = configure_policy_features(policy, dataset)
|
||||
|
||||
# Inject normalization stats with proper key mapping
|
||||
_inject_normalization_stats(policy, meta, key_mapping)
|
||||
|
||||
# Prepare batch for PI0 (handle temporal dimensions)
|
||||
batch = {}
|
||||
|
||||
# Map dataset sample keys to policy keys
|
||||
reverse_mapping = {v: k for k, v in key_mapping.items()}
|
||||
|
||||
for policy_key in policy.config.input_features:
|
||||
# Find the corresponding dataset key
|
||||
dataset_key = reverse_mapping.get(policy_key, policy_key)
|
||||
|
||||
if dataset_key in sample:
|
||||
data = sample[dataset_key]
|
||||
|
||||
# Handle image data: convert from HWC to CHW and normalize
|
||||
if policy_key.startswith("observation.images."):
|
||||
if data.dim() == 3 and data.shape[-1] == 3: # HWC format
|
||||
data = data.permute(2, 0, 1) # Convert to CHW
|
||||
# Normalize to [0, 1] range if needed
|
||||
if data.dtype == torch.uint8:
|
||||
data = data.float() / 255.0
|
||||
# Resize to expected size if needed
|
||||
if data.shape[-2:] != (224, 224):
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
|
||||
data = F.interpolate(
|
||||
data.unsqueeze(0), size=(224, 224), mode="bilinear", align_corners=False
|
||||
)[0]
|
||||
|
||||
# Remove temporal dimension if present
|
||||
if data.dim() > len(policy.config.input_features[policy_key].shape):
|
||||
data = data[0]
|
||||
|
||||
batch[policy_key] = data.unsqueeze(0) # Add batch dimension
|
||||
|
||||
# Debug: print what's in the sample
|
||||
print(f"Sample keys: {list(sample.keys())}")
|
||||
print(f"Batch keys prepared: {list(batch.keys())}")
|
||||
|
||||
# Pi0 requires task description - add a default if not available
|
||||
if "task" in sample:
|
||||
batch["task"] = [sample["task"]] # Keep as list of strings
|
||||
else:
|
||||
print("No task in sample, using default task description")
|
||||
batch["task"] = ["Complete the manipulation task"]
|
||||
|
||||
print(f"Task: {batch['task'][0]}")
|
||||
print(f"Final batch keys: {list(batch.keys())}")
|
||||
|
||||
# Run inference
|
||||
with torch.no_grad():
|
||||
action = policy.select_action(batch)
|
||||
print(f"Predicted action shape: {action.shape}")
|
||||
print(f"Predicted action: {action.tolist()}")
|
||||
|
||||
print("✅ Pi0 pretrained inference completed successfully!")
|
||||
|
||||
# Upload to Hugging Face Hub
|
||||
print(f"\n📤 Uploading model to Hugging Face Hub: {args.output_model_id}")
|
||||
|
||||
# Create commit message
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
commit_message = (
|
||||
args.commit_message
|
||||
or f"Pi0 model with injected normalization stats from {args.dataset_id} - {timestamp}"
|
||||
)
|
||||
|
||||
# Update model configuration with dataset info
|
||||
policy.config.push_to_hub = True
|
||||
policy.config.repo_id = args.output_model_id
|
||||
policy.config.private = args.private
|
||||
|
||||
# Add metadata about the adaptation
|
||||
adaptation_info = {
|
||||
"source_model": args.source_model_id,
|
||||
"dataset_used": args.dataset_id,
|
||||
"adaptation_date": timestamp,
|
||||
"stats_injected": True,
|
||||
"key_mapping": key_mapping,
|
||||
"inference_test_passed": True,
|
||||
"sample_action_shape": list(action.shape),
|
||||
}
|
||||
|
||||
try:
|
||||
# Push to hub
|
||||
policy.push_to_hub(
|
||||
repo_id=args.output_model_id,
|
||||
private=args.private,
|
||||
commit_message=commit_message,
|
||||
create_pr=False,
|
||||
)
|
||||
|
||||
# Also save the adaptation info as a separate file
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
api = HfApi()
|
||||
|
||||
# Create a temporary file with adaptation info
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(adaptation_info, f, indent=2)
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
api.upload_file(
|
||||
path_or_fileobj=temp_path,
|
||||
path_in_repo="adaptation_info.json",
|
||||
repo_id=args.output_model_id,
|
||||
commit_message=f"Add adaptation metadata - {timestamp}",
|
||||
)
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
print(f"✅ Model successfully uploaded to: https://huggingface.co/{args.output_model_id}")
|
||||
print("📋 Adaptation info:")
|
||||
for key, value in adaptation_info.items():
|
||||
print(f" {key}: {value}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error uploading to Hub: {e}")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+704
@@ -0,0 +1,704 @@
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download # noqa: E402
|
||||
from safetensors.torch import load_file # noqa: E402
|
||||
from transformers.model_debugging_utils import model_addition_debugger_context
|
||||
|
||||
from lerobot.configs.policies import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
RANDOM_SEED = 42 # Set to fixed value for reproducible results
|
||||
|
||||
|
||||
def set_all_seeds(seed=42):
|
||||
"""Set all random seeds for reproducible results."""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
torch.use_deterministic_algorithms(True)
|
||||
print(f"All random seeds set to {seed} for reproducible results (deterministic mode enabled)")
|
||||
|
||||
|
||||
# Set seeds at the start
|
||||
set_all_seeds(RANDOM_SEED)
|
||||
|
||||
config_model_path = "lerobot/pi0" # Use config from official model
|
||||
official_model_path = "lerobot/pi0" # Official model
|
||||
custom_model_path = "pepijn223/pi0_base_fp32" # Custom model to compare # pepijn223/pi0_base_fp32
|
||||
device = "mps"
|
||||
|
||||
USE_FULL_TENSORS = True
|
||||
SAVE_TENSORS_TO_DISK = False
|
||||
|
||||
# Model transformation and upload settings
|
||||
SAVE_TRANSFORMED_MODEL = True # Set to True to save the transformed model
|
||||
UPLOAD_TO_HUB = True # Set to True to upload to HuggingFace Hub
|
||||
TRANSFORMED_MODEL_NAME = "pepijn223/pi0_base_fp32_lerobot_format" # Target repo name
|
||||
COMMIT_MESSAGE = "Add transformed PI0 model with correct key format for lerobot"
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
debug_path = os.path.join("debug_outputs", f"pi0_debug_direct_{timestamp}")
|
||||
os.makedirs(debug_path, exist_ok=True)
|
||||
print(f"Model debugging enabled - outputs will be saved to: {debug_path}")
|
||||
|
||||
# Download and load the config manually to avoid draccus parsing issues
|
||||
config_file = hf_hub_download(repo_id=config_model_path, filename="config.json")
|
||||
with open(config_file) as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
# Remove the 'type' field that causes draccus issues
|
||||
if "type" in config_dict:
|
||||
config_dict.pop("type")
|
||||
print("Removed 'type' field from config")
|
||||
|
||||
# Create shared PI0Config
|
||||
print("Creating shared PI0Config...")
|
||||
shared_config = PI0Config(**config_dict)
|
||||
|
||||
|
||||
def load_policy_with_weights(
|
||||
model_path: str, config: PI0Config, model_name: str, apply_transformations: bool = False
|
||||
):
|
||||
"""Load a policy with specified weights but shared config."""
|
||||
print(f"\n=== Loading {model_name} from {model_path} ===")
|
||||
|
||||
# Set deterministic seed before creating the policy to ensure identical initialization
|
||||
torch.manual_seed(RANDOM_SEED)
|
||||
np.random.seed(RANDOM_SEED)
|
||||
random.seed(RANDOM_SEED)
|
||||
|
||||
policy = PI0Policy(config)
|
||||
|
||||
# Download and load weights
|
||||
model_file = hf_hub_download(repo_id=model_path, filename="model.safetensors")
|
||||
print(f"Downloaded {model_name} weights to: {model_file}")
|
||||
|
||||
# Load state dict and apply transformations
|
||||
print(f"Investigating safetensors file: {model_file}")
|
||||
|
||||
# First, check what's in the metadata
|
||||
try:
|
||||
from safetensors import safe_open
|
||||
|
||||
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata()
|
||||
all_keys_in_file = f.keys()
|
||||
print(f" Total keys in safetensors file: {len(list(all_keys_in_file))}")
|
||||
|
||||
# Check for embed_tokens in the file keys
|
||||
embed_keys_in_file = [k for k in f.keys() if "embed_tokens" in k]
|
||||
print(f" embed_tokens keys in safetensors: {embed_keys_in_file}")
|
||||
|
||||
if metadata:
|
||||
print(f" Metadata exists: {list(metadata.keys()) if metadata else 'None'}")
|
||||
except Exception as e:
|
||||
print(f" Could not inspect safetensors file directly: {e}")
|
||||
|
||||
# Now load normally and see what we get
|
||||
state_dict = load_file(model_file)
|
||||
print(f" Keys loaded by load_file(): {len(state_dict)} keys")
|
||||
|
||||
# Check for embed_tokens in loaded state_dict
|
||||
loaded_embed_keys = [k for k in state_dict.keys() if "embed_tokens" in k]
|
||||
print(f" embed_tokens keys in loaded state_dict: {loaded_embed_keys}")
|
||||
|
||||
# Check if we need to add "model." prefix (for custom models that don't have it)
|
||||
sample_key = next(iter(state_dict.keys()))
|
||||
if not sample_key.startswith("model."):
|
||||
print(f"Adding 'model.' prefix to all keys (detected format: {sample_key})")
|
||||
state_dict = {f"model.{k}": v for k, v in state_dict.items()}
|
||||
|
||||
# IMPORTANT: Call PI0Policy._transform_state_dict_keys AFTER adding model. prefix
|
||||
# This ensures tied weights logic can find the correct key pattern
|
||||
transformed_state_dict = PI0Policy._transform_state_dict_keys(state_dict)
|
||||
|
||||
# Apply specific PaliGemma key transformations only for custom models
|
||||
if apply_transformations:
|
||||
print("Applying custom model key transformations...")
|
||||
|
||||
# First, let's debug what keys we actually have
|
||||
all_keys = list(transformed_state_dict.keys())
|
||||
sample_keys = all_keys[:10]
|
||||
print(f"Sample keys to transform: {sample_keys}")
|
||||
|
||||
# Look for specific keys we need to transform and missing keys
|
||||
embed_tokens_keys = [k for k in all_keys if "embed_tokens" in k]
|
||||
embedding_keys = [k for k in all_keys if "embed" in k]
|
||||
lm_head_keys = [k for k in all_keys if "lm_head" in k]
|
||||
paligemma_keys = [
|
||||
k for k in all_keys if "paligemma_with_expert.paligemma" in k and "gemma_expert" not in k
|
||||
]
|
||||
language_model_keys = [k for k in all_keys if "language_model" in k]
|
||||
|
||||
print(f"Found embed_tokens keys: {embed_tokens_keys}")
|
||||
print(f"Found any embedding keys: {embedding_keys}")
|
||||
print(f"Found lm_head keys: {lm_head_keys}")
|
||||
print(
|
||||
f"Found paligemma keys (non-expert): {paligemma_keys[:5]}{'...' if len(paligemma_keys) > 5 else ''}"
|
||||
)
|
||||
print(
|
||||
f"Found language_model keys: {language_model_keys[:5]}{'...' if len(language_model_keys) > 5 else ''}"
|
||||
)
|
||||
print(f"Total keys in model: {len(all_keys)}")
|
||||
|
||||
# Check if the embed_tokens is in gemma_expert instead
|
||||
gemma_expert_embed = [k for k in all_keys if "gemma_expert" in k and "embed_tokens" in k]
|
||||
print(f"Found gemma_expert embed_tokens keys: {gemma_expert_embed}")
|
||||
|
||||
# Check what we're missing and what we actually have
|
||||
expected_embed_key = "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
|
||||
if expected_embed_key not in all_keys:
|
||||
print(f" Missing expected embed_tokens key: {expected_embed_key}")
|
||||
|
||||
# Let's see what keys we actually have for debugging
|
||||
print("Debugging: Looking for any embedding-related keys...")
|
||||
all_embed_related = [k for k in all_keys if "embed" in k.lower()]
|
||||
print(f"Keys containing 'embed': {all_embed_related}")
|
||||
|
||||
# Look for any keys that might contain embeddings
|
||||
potential_embed_keys = [
|
||||
k for k in all_keys if any(word in k for word in ["embed", "embedding", "token"])
|
||||
]
|
||||
print(f" Potential embedding keys: {potential_embed_keys}")
|
||||
|
||||
# Try to find a suitable replacement
|
||||
if gemma_expert_embed:
|
||||
print(f" Will try to copy from: {gemma_expert_embed[0]}")
|
||||
else:
|
||||
print(" No gemma_expert embed_tokens found either!")
|
||||
|
||||
# Check if there's an embed_tokens in the gemma_expert that we missed
|
||||
gemma_keys = [k for k in all_keys if "gemma_expert" in k]
|
||||
print(f" First 10 gemma_expert keys: {gemma_keys[:10]}")
|
||||
|
||||
# Check if there are any token-related keys in gemma_expert
|
||||
token_keys = [k for k in all_keys if "gemma_expert" in k and "token" in k.lower()]
|
||||
print(f" Gemma expert token-related keys: {token_keys}")
|
||||
|
||||
# Check for any keys that look like they might be embeddings
|
||||
possible_embeds = [
|
||||
k
|
||||
for k in all_keys
|
||||
if any(
|
||||
pattern in k.lower() for pattern in ["embed_token", "embedding", "wte", "word_embed"]
|
||||
)
|
||||
]
|
||||
print(f" Possible embedding alternatives: {possible_embeds}")
|
||||
|
||||
final_state_dict = {}
|
||||
transformation_count = 0
|
||||
|
||||
for key, value in transformed_state_dict.items():
|
||||
new_key = key
|
||||
original_key = key
|
||||
|
||||
# Transform vision tower keys: ADD .model between paligemma and vision_tower
|
||||
if "paligemma_with_expert.paligemma.vision_tower.vision_model" in new_key:
|
||||
new_key = new_key.replace(
|
||||
"paligemma_with_expert.paligemma.vision_tower.vision_model",
|
||||
"paligemma_with_expert.paligemma.model.vision_tower.vision_model",
|
||||
)
|
||||
print(f"Transformed vision key: {original_key} -> {new_key}")
|
||||
transformation_count += 1
|
||||
|
||||
# Transform multi_modal_projector keys: ADD .model between paligemma and multi_modal_projector
|
||||
elif "paligemma_with_expert.paligemma.multi_modal_projector" in new_key:
|
||||
new_key = new_key.replace(
|
||||
"paligemma_with_expert.paligemma.multi_modal_projector",
|
||||
"paligemma_with_expert.paligemma.model.multi_modal_projector",
|
||||
)
|
||||
print(f"Transformed multi_modal_projector key: {original_key} -> {new_key}")
|
||||
transformation_count += 1
|
||||
|
||||
# NO transformation needed for language_model keys - they're already correct!
|
||||
# The custom model already has: paligemma.model.language_model.* which is what we need
|
||||
|
||||
# NO transformation needed for lm_head - it should stay as paligemma.lm_head
|
||||
|
||||
final_state_dict[new_key] = value
|
||||
|
||||
print(f"Applied {transformation_count} key transformations")
|
||||
transformed_state_dict = final_state_dict
|
||||
else:
|
||||
print("No transformations applied (official model format)")
|
||||
|
||||
# Debug: show what keys the policy expects vs what we have
|
||||
policy_keys = set(policy.state_dict().keys())
|
||||
provided_keys = set(transformed_state_dict.keys())
|
||||
|
||||
missing_in_provided = policy_keys - provided_keys
|
||||
extra_in_provided = provided_keys - policy_keys
|
||||
|
||||
print(f"Policy expects {len(policy_keys)} keys, we provide {len(provided_keys)} keys")
|
||||
if missing_in_provided:
|
||||
print(
|
||||
f" Missing from provided: {list(missing_in_provided)[:5]}{'...' if len(missing_in_provided) > 5 else ''}"
|
||||
)
|
||||
if extra_in_provided:
|
||||
print(
|
||||
f" Extra in provided: {list(extra_in_provided)[:5]}{'...' if len(extra_in_provided) > 5 else ''}"
|
||||
)
|
||||
|
||||
# Load the weights into the policy
|
||||
msg = policy.load_state_dict(transformed_state_dict, strict=True)
|
||||
print(
|
||||
f"{model_name} - Missing keys: {len(msg.missing_keys)}, Unexpected keys: {len(msg.unexpected_keys)}"
|
||||
)
|
||||
|
||||
if msg.missing_keys:
|
||||
print(
|
||||
f" Actually missing keys: {list(msg.missing_keys)[:3]}{'...' if len(msg.missing_keys) > 3 else ''}"
|
||||
)
|
||||
if msg.unexpected_keys:
|
||||
print(
|
||||
f" Actually unexpected keys: {list(msg.unexpected_keys)[:3]}{'...' if len(msg.unexpected_keys) > 3 else ''}"
|
||||
)
|
||||
|
||||
# Set deterministic mode and move to device
|
||||
policy = policy.to(device)
|
||||
policy.eval()
|
||||
|
||||
# Reset the policy to ensure identical internal state
|
||||
policy.reset()
|
||||
|
||||
return policy
|
||||
|
||||
|
||||
# Load both models with shared config
|
||||
print("Loading both models with shared config...")
|
||||
official_policy = load_policy_with_weights(
|
||||
official_model_path, shared_config, "Official Model", apply_transformations=False
|
||||
)
|
||||
custom_policy = load_policy_with_weights(
|
||||
custom_model_path, shared_config, "Custom Model", apply_transformations=True
|
||||
)
|
||||
|
||||
print("\nBoth models loaded successfully!")
|
||||
print(f"Shared config: {shared_config}")
|
||||
print(f"Device: {device}")
|
||||
|
||||
|
||||
# Configure input features for both policies since they're not set by default in pretrained models
|
||||
def configure_policy_features(policy: PI0Policy):
|
||||
"""Configure input and output features for a policy."""
|
||||
policy.config.input_features[OBS_IMAGE] = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 224, 224), # Channel-first RGB image
|
||||
)
|
||||
|
||||
policy.config.input_features[OBS_STATE] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(8,), # 8-dimensional state vector
|
||||
)
|
||||
|
||||
policy.config.output_features[ACTION] = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(8,), # 8-dimensional action vector
|
||||
)
|
||||
|
||||
# Add dummy normalization buffers to the policy (like openpi does with norm_stats)
|
||||
if hasattr(policy, "normalize_inputs"):
|
||||
# For observation.state (8-dim state vector)
|
||||
policy.normalize_inputs.register_buffer(
|
||||
f"buffer_{OBS_STATE.replace('.', '_')}_mean", torch.zeros(8, device=device)
|
||||
)
|
||||
policy.normalize_inputs.register_buffer(
|
||||
f"buffer_{OBS_STATE.replace('.', '_')}_std", torch.ones(8, device=device)
|
||||
)
|
||||
|
||||
# For observation.image (3x224x224 image)
|
||||
policy.normalize_inputs.register_buffer(
|
||||
f"buffer_{OBS_IMAGE.replace('.', '_')}_mean", torch.zeros(3, 224, 224, device=device)
|
||||
)
|
||||
policy.normalize_inputs.register_buffer(
|
||||
f"buffer_{OBS_IMAGE.replace('.', '_')}_std", torch.ones(3, 224, 224, device=device)
|
||||
)
|
||||
|
||||
|
||||
print("Configuring features for both policies...")
|
||||
configure_policy_features(official_policy)
|
||||
configure_policy_features(custom_policy)
|
||||
|
||||
# Verify that the models have identical parameters
|
||||
print("\n=== Model Parameter Comparison ===")
|
||||
official_params = dict(official_policy.named_parameters())
|
||||
custom_params = dict(custom_policy.named_parameters())
|
||||
|
||||
param_differences = []
|
||||
for name in official_params.keys():
|
||||
if name not in custom_params:
|
||||
param_differences.append(f"Missing parameter in custom model: {name}")
|
||||
else:
|
||||
diff = torch.abs(official_params[name] - custom_params[name]).max().item()
|
||||
if diff > 1e-8:
|
||||
param_differences.append(f"Parameter {name}: max difference = {diff:.2e}")
|
||||
|
||||
for name in custom_params.keys():
|
||||
if name not in official_params:
|
||||
param_differences.append(f"Extra parameter in custom model: {name}")
|
||||
|
||||
if param_differences:
|
||||
print("Parameter differences found:")
|
||||
for diff in param_differences[:10]: # Show first 10 differences
|
||||
print(f" {diff}")
|
||||
if len(param_differences) > 10:
|
||||
print(f" ... and {len(param_differences) - 10} more differences")
|
||||
else:
|
||||
print("All model parameters are identical!")
|
||||
|
||||
|
||||
# Get the raw models for direct comparison
|
||||
official_raw_model = official_policy.model
|
||||
custom_raw_model = custom_policy.model
|
||||
print("\n=== Model Details ===")
|
||||
print(f"Official raw model type: {type(official_raw_model)}")
|
||||
print(f"Custom raw model type: {type(custom_raw_model)}")
|
||||
print(f"Official model device: {next(official_raw_model.parameters()).device}")
|
||||
print(f"Custom model device: {next(custom_raw_model.parameters()).device}")
|
||||
|
||||
# Create lerobot-format input data (similar to DROID format from openpi example)
|
||||
example = {
|
||||
"joint_position": np.zeros(7, dtype=np.float32),
|
||||
"gripper_position": np.array([0.0], dtype=np.float32),
|
||||
"image": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
|
||||
"task": "pick up the object",
|
||||
}
|
||||
|
||||
print(f"\nProvided input keys: {list(example.keys())}")
|
||||
|
||||
print("\nPreparing inputs for direct model call...")
|
||||
|
||||
# Apply input transformation (similar to openpi's policy._input_transform)
|
||||
transformed_example = {}
|
||||
# Combine joint and gripper positions into state
|
||||
transformed_example[OBS_STATE] = np.concatenate([example["joint_position"], example["gripper_position"]])
|
||||
transformed_example[OBS_IMAGE] = example["image"]
|
||||
transformed_example["task"] = example["task"]
|
||||
|
||||
# Convert to PyTorch tensors and add batch dimension (as openpi example does)
|
||||
# Device is already defined above, use the official model device for consistency
|
||||
pytorch_inputs = {}
|
||||
for key, value in transformed_example.items():
|
||||
if isinstance(value, np.ndarray):
|
||||
tensor_value = torch.from_numpy(value).to(device)
|
||||
# Add batch dimension
|
||||
if tensor_value.dim() > 0:
|
||||
tensor_value = tensor_value.unsqueeze(0)
|
||||
pytorch_inputs[key] = tensor_value
|
||||
elif isinstance(value, str):
|
||||
pytorch_inputs[key] = [value] # Convert to list format expected by policy
|
||||
else:
|
||||
pytorch_inputs[key] = value
|
||||
|
||||
# Convert image from HWC to CHW format for lerobot
|
||||
if OBS_IMAGE in pytorch_inputs:
|
||||
img = pytorch_inputs[OBS_IMAGE]
|
||||
if img.dim() == 4 and img.shape[-1] == 3: # BHWC -> BCHW
|
||||
img = img.permute(0, 3, 1, 2)
|
||||
# Convert to float and normalize to [0, 1] range
|
||||
img = img.float() / 255.0
|
||||
pytorch_inputs[OBS_IMAGE] = img
|
||||
|
||||
print(f"Transformed input keys: {list(pytorch_inputs.keys())}")
|
||||
for key, value in pytorch_inputs.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
print(f" {key}: {value.shape} {value.dtype}")
|
||||
else:
|
||||
print(f" {key}: {type(value)} - {value}")
|
||||
|
||||
# Reset both policies (clears the action queue)
|
||||
official_policy.reset()
|
||||
custom_policy.reset()
|
||||
|
||||
# Prepare inputs using the official policy (both models should have same preprocessing)
|
||||
print("Preparing inputs for both models...")
|
||||
images, img_masks = official_policy.prepare_images(pytorch_inputs)
|
||||
lang_tokens, lang_masks = official_policy.prepare_language(pytorch_inputs)
|
||||
state = official_policy.prepare_state(pytorch_inputs)
|
||||
|
||||
print("Prepared inputs:")
|
||||
print(f" Images: {len(images)} images")
|
||||
print(f" Language tokens shape: {lang_tokens.shape}")
|
||||
print(f" State shape: {state.shape}")
|
||||
for i, img in enumerate(images):
|
||||
print(f" Image {i} shape: {img.shape}")
|
||||
for i, mask in enumerate(img_masks):
|
||||
print(f" Image mask {i} shape: {mask.shape}")
|
||||
|
||||
# Compare both models with identical inputs
|
||||
print("\n🚀 Running MODEL COMPARISON...")
|
||||
|
||||
# Force torch.no_grad for consistent comparison
|
||||
with torch.no_grad():
|
||||
# Ensure reproducible noise generation for both models
|
||||
torch.manual_seed(RANDOM_SEED)
|
||||
|
||||
# Generate synthetic noise and time for the forward call
|
||||
batch_size = 1
|
||||
actions_shape = (
|
||||
batch_size,
|
||||
official_raw_model.config.n_action_steps,
|
||||
official_raw_model.config.max_action_dim,
|
||||
)
|
||||
|
||||
# Generate noise and time using direct PyTorch operations instead of model methods
|
||||
# This avoids any potential model-specific randomness
|
||||
torch.manual_seed(RANDOM_SEED)
|
||||
noise = torch.normal(
|
||||
mean=0.0,
|
||||
std=1.0,
|
||||
size=actions_shape,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Generate time using the same distribution as PI0FlowMatching.sample_time
|
||||
torch.manual_seed(RANDOM_SEED) # Reset for consistent time
|
||||
beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
|
||||
time_beta = beta_dist.sample((batch_size,)).to(device=device, dtype=torch.float32)
|
||||
time = time_beta * 0.999 + 0.001
|
||||
|
||||
print("\n=== Generated Inputs ===")
|
||||
print(f" Action shape: {actions_shape}")
|
||||
print(f" Noise shape: {noise.shape}")
|
||||
print(f" Time value: {time.item():.6f}")
|
||||
print(f" Noise sample (first 5 values): {noise.flatten()[:5].tolist()}")
|
||||
|
||||
# Create dummy actions for forward pass (required for training forward)
|
||||
dummy_actions = torch.zeros(actions_shape, dtype=torch.float32, device=device)
|
||||
|
||||
print("\n=== Running Forward Passes ===")
|
||||
|
||||
print("Running with model_addition_debugger_context for detailed analysis...")
|
||||
# Create separate debug paths for each model
|
||||
official_debug_path = os.path.join(debug_path, "official_model")
|
||||
custom_debug_path = os.path.join(debug_path, "custom_model")
|
||||
os.makedirs(official_debug_path, exist_ok=True)
|
||||
os.makedirs(custom_debug_path, exist_ok=True)
|
||||
# Set deterministic mode for forward pass
|
||||
torch.manual_seed(RANDOM_SEED)
|
||||
# Run official model with debugger
|
||||
print("Running Official Model forward pass with debugger...")
|
||||
with model_addition_debugger_context(
|
||||
official_raw_model,
|
||||
debug_path=official_debug_path,
|
||||
do_prune_layers=False, # Output ALL layers
|
||||
use_repr=not SAVE_TENSORS_TO_DISK,
|
||||
):
|
||||
official_loss = official_raw_model.forward(
|
||||
images=images,
|
||||
img_masks=img_masks,
|
||||
lang_tokens=lang_tokens,
|
||||
lang_masks=lang_masks,
|
||||
state=state,
|
||||
actions=dummy_actions,
|
||||
noise=noise,
|
||||
time=time,
|
||||
)
|
||||
# Reset seed before second forward pass to ensure any internal randomness is identical
|
||||
torch.manual_seed(RANDOM_SEED)
|
||||
# Run custom model with debugger
|
||||
print("Running Custom Model forward pass with debugger...")
|
||||
with model_addition_debugger_context(
|
||||
custom_raw_model,
|
||||
debug_path=custom_debug_path,
|
||||
do_prune_layers=False, # Output ALL layers
|
||||
use_repr=not SAVE_TENSORS_TO_DISK,
|
||||
):
|
||||
custom_loss = custom_raw_model.forward(
|
||||
images=images,
|
||||
img_masks=img_masks,
|
||||
lang_tokens=lang_tokens,
|
||||
lang_masks=lang_masks,
|
||||
state=state,
|
||||
actions=dummy_actions,
|
||||
noise=noise,
|
||||
time=time,
|
||||
)
|
||||
|
||||
print(f"Official model debug outputs saved to: {official_debug_path}")
|
||||
print(f"Custom model debug outputs saved to: {custom_debug_path}")
|
||||
|
||||
print("\n=== Output Comparison ===")
|
||||
print(f"Official model loss shape: {official_loss.shape}")
|
||||
print(f"Custom model loss shape: {custom_loss.shape}")
|
||||
|
||||
# Compare outputs
|
||||
loss_diff = torch.abs(official_loss - custom_loss)
|
||||
|
||||
print("\n=== Detailed Comparison ===")
|
||||
print("Loss difference stats:")
|
||||
print(f" Mean absolute difference: {loss_diff.mean().item():.8f}")
|
||||
print(f" Max absolute difference: {loss_diff.max().item():.8f}")
|
||||
print(f" Min absolute difference: {loss_diff.min().item():.8f}")
|
||||
print(f" Standard deviation of difference: {loss_diff.std().item():.8f}")
|
||||
|
||||
# Show some actual values for comparison
|
||||
print("\nSample output values:")
|
||||
print(f" Official model (first 5): {official_loss.flatten()[:5].tolist()}")
|
||||
print(f" Custom model (first 5): {custom_loss.flatten()[:5].tolist()}")
|
||||
print(f" Difference (first 5): {loss_diff.flatten()[:5].tolist()}")
|
||||
|
||||
# Determine if models are equivalent
|
||||
are_equivalent = loss_diff.max().item() < 1e-6
|
||||
print(f"\nModels are {'EQUIVALENT' if are_equivalent else 'DIFFERENT'}")
|
||||
print(f" (Max difference: {loss_diff.max().item():.8f}, Threshold: 1e-6)")
|
||||
|
||||
print(f"\nDetailed debugging outputs saved to: {debug_path}")
|
||||
# Save comparison results
|
||||
comparison_results = {
|
||||
"official_loss_stats": {
|
||||
"shape": list(official_loss.shape),
|
||||
"mean": official_loss.mean().item(),
|
||||
"std": official_loss.std().item(),
|
||||
"min": official_loss.min().item(),
|
||||
"max": official_loss.max().item(),
|
||||
},
|
||||
"custom_loss_stats": {
|
||||
"shape": list(custom_loss.shape),
|
||||
"mean": custom_loss.mean().item(),
|
||||
"std": custom_loss.std().item(),
|
||||
"min": custom_loss.min().item(),
|
||||
"max": custom_loss.max().item(),
|
||||
},
|
||||
"difference_stats": {
|
||||
"mean_abs_diff": loss_diff.mean().item(),
|
||||
"max_abs_diff": loss_diff.max().item(),
|
||||
"min_abs_diff": loss_diff.min().item(),
|
||||
"std_diff": loss_diff.std().item(),
|
||||
"are_equivalent": are_equivalent,
|
||||
},
|
||||
}
|
||||
|
||||
comparison_file = os.path.join(debug_path, "model_comparison_results.json")
|
||||
with open(comparison_file, "w") as f:
|
||||
json.dump(comparison_results, f, indent=2)
|
||||
print(f" Comparison results saved to: {comparison_file}")
|
||||
|
||||
# Save and upload transformed model if requested
|
||||
if SAVE_TRANSFORMED_MODEL:
|
||||
print("\nSaving Transformed Model...")
|
||||
if are_equivalent:
|
||||
print("Models are equivalent - proceeding with transformation and upload")
|
||||
else:
|
||||
print("Models are NOT equivalent, but proceeding with upload anyway")
|
||||
print(f" Max difference: {loss_diff.max().item():.2e}")
|
||||
print(" This might be useful for debugging or partial transformations")
|
||||
|
||||
# Create timestamp for README
|
||||
transformation_timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
try:
|
||||
# Use the already working custom policy as the base for transformation
|
||||
print("Using already working custom policy as base for transformed model...")
|
||||
|
||||
# Deep copy the custom policy to create the transformed version
|
||||
from copy import deepcopy
|
||||
|
||||
transformed_policy = deepcopy(custom_policy)
|
||||
|
||||
print("Custom policy copied successfully - no additional configuration needed")
|
||||
|
||||
# Save locally first
|
||||
local_save_path = "./transformed_pi0_model"
|
||||
print(f"Saving transformed model locally to: {local_save_path}")
|
||||
transformed_policy.save_pretrained(local_save_path, safe_serialization=True)
|
||||
|
||||
# Save the tokenizer as well (required for complete model)
|
||||
transformed_policy.language_tokenizer.save_pretrained(local_save_path)
|
||||
|
||||
# Create a README with transformation details
|
||||
readme_content = f"""
|
||||
# PI0 Model - LeRobot Compatible Format
|
||||
|
||||
This model is a transformed version of `{custom_model_path}` with key names corrected to match the official LeRobot PI0 format.
|
||||
|
||||
## Transformation Applied
|
||||
|
||||
The original model had a different key naming convention. This model applies the following transformations:
|
||||
|
||||
1. **Model prefix**: Added `model.` prefix to all parameter keys
|
||||
2. **Tied weights**: Applied PI0Policy's built-in tied weights logic to create `embed_tokens.weight` from `lm_head.weight`
|
||||
3. **Key structure**: Applied standard PI0 key transformations for compatibility
|
||||
|
||||
## Verification
|
||||
|
||||
{"This transformed model produces **identical outputs**" if are_equivalent else "This transformed model has **slightly different outputs**"} (max difference = {loss_diff.max().item():.2e}) compared to the official model `{official_model_path}` when tested with the same inputs.
|
||||
{"**Models are EQUIVALENT** (difference < 1e-6)" if are_equivalent else "**Models are NOT equivalent** (difference >= 1e-6) - use with caution"}
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
# Load the model
|
||||
policy = PI0Policy.from_pretrained("{TRANSFORMED_MODEL_NAME}")
|
||||
|
||||
# Use for inference
|
||||
action = policy.select_action(observation_batch)
|
||||
```
|
||||
|
||||
## Original Model
|
||||
|
||||
- **Source**: {custom_model_path}
|
||||
- **Verified Against**: {official_model_path}
|
||||
|
||||
## Technical Details
|
||||
|
||||
- **Total Parameters**: {sum(p.numel() for p in transformed_policy.parameters()):,}
|
||||
- **Model Type**: PI0FlowMatching with PaliGemma + Expert Gemma
|
||||
- **Configuration**: Matches official PI0 configuration
|
||||
"""
|
||||
|
||||
readme_path = os.path.join(local_save_path, "README.md")
|
||||
with open(readme_path, "w") as f:
|
||||
f.write(readme_content.strip())
|
||||
|
||||
print(f"Model saved locally to: {local_save_path}")
|
||||
|
||||
# Upload to HuggingFace Hub if requested
|
||||
if UPLOAD_TO_HUB:
|
||||
print(f"\nUploading to HuggingFace Hub: {TRANSFORMED_MODEL_NAME}")
|
||||
|
||||
try:
|
||||
# Push to hub
|
||||
transformed_policy.push_to_hub(
|
||||
repo_id=TRANSFORMED_MODEL_NAME,
|
||||
commit_message=COMMIT_MESSAGE,
|
||||
private=False, # Make it public
|
||||
safe_serialization=True,
|
||||
)
|
||||
|
||||
print(f"Model successfully uploaded to: https://huggingface.co/{TRANSFORMED_MODEL_NAME}")
|
||||
print("You can now use this model directly without any transformations!")
|
||||
print("\n Usage:")
|
||||
print(" from lerobot.policies.pi0.modeling_pi0 import PI0Policy")
|
||||
print(f" policy = PI0Policy.from_pretrained('{TRANSFORMED_MODEL_NAME}')")
|
||||
|
||||
except Exception as upload_error:
|
||||
print(f"Failed to upload to HuggingFace Hub: {upload_error}")
|
||||
print(f"You can manually upload the model from: {local_save_path}")
|
||||
print(" Or set UPLOAD_TO_HUB = False and upload later")
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
print(f"Error saving transformed model: {str(e)}")
|
||||
print("Full traceback:")
|
||||
traceback.print_exc()
|
||||
print("The model transformation logic works, but saving failed")
|
||||
|
||||
else:
|
||||
print("\nModel transformation and upload disabled (SAVE_TRANSFORMED_MODEL = False)")
|
||||
+3
-1
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
|
||||
|
||||
[project]
|
||||
name = "lerobot"
|
||||
version = "0.3.3"
|
||||
version = "0.3.4"
|
||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||
readme = "README.md"
|
||||
license = { text = "Apache-2.0" }
|
||||
@@ -106,6 +106,7 @@ dynamixel = ["dynamixel-sdk>=3.7.31"]
|
||||
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0"]
|
||||
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
|
||||
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1"]
|
||||
reachy2 = ["reachy2_sdk>=1.0.14"]
|
||||
kinematics = ["lerobot[placo-dep]"]
|
||||
intelrealsense = [
|
||||
"pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
|
||||
@@ -141,6 +142,7 @@ all = [
|
||||
"lerobot[gamepad]",
|
||||
"lerobot[hopejr]",
|
||||
"lerobot[lekiwi]",
|
||||
"lerobot[reachy2]",
|
||||
"lerobot[kinematics]",
|
||||
"lerobot[intelrealsense]",
|
||||
"lerobot[pi0]",
|
||||
|
||||
@@ -18,7 +18,7 @@ Helper to recalibrate your device (robot or teleoperator).
|
||||
Example:
|
||||
|
||||
```shell
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--teleop.id=blue
|
||||
|
||||
@@ -60,7 +60,7 @@ class OpenCVCamera(Camera):
|
||||
or port changes, especially on Linux. Use the provided utility script to find
|
||||
available camera indices or paths:
|
||||
```bash
|
||||
python -m lerobot.find_cameras opencv
|
||||
lerobot-find-cameras opencv
|
||||
```
|
||||
|
||||
The camera's default settings (FPS, resolution, color mode) are used unless
|
||||
@@ -165,8 +165,7 @@ class OpenCVCamera(Camera):
|
||||
self.videocapture.release()
|
||||
self.videocapture = None
|
||||
raise ConnectionError(
|
||||
f"Failed to open {self}."
|
||||
f"Run `python -m lerobot.find_cameras opencv` to find available cameras."
|
||||
f"Failed to open {self}.Run `lerobot-find-cameras opencv` to find available cameras."
|
||||
)
|
||||
|
||||
self._configure_capture_settings()
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_reachy2_camera import Reachy2CameraConfig
|
||||
from .reachy2_camera import Reachy2Camera
|
||||
@@ -0,0 +1,78 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..configs import CameraConfig, ColorMode
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("reachy2_camera")
|
||||
@dataclass
|
||||
class Reachy2CameraConfig(CameraConfig):
|
||||
"""Configuration class for Reachy 2 camera devices.
|
||||
|
||||
This class provides configuration options for Reachy 2 cameras,
|
||||
supporting both the teleop and depth cameras. It includes settings
|
||||
for resolution, frame rate, color mode, and the selection of the cameras.
|
||||
|
||||
Example configurations:
|
||||
```python
|
||||
# Basic configurations
|
||||
Reachy2CameraConfig(
|
||||
name="teleop",
|
||||
image_type="left",
|
||||
ip_address="192.168.0.200", # IP address of the robot
|
||||
fps=15,
|
||||
width=640,
|
||||
height=480,
|
||||
color_mode=ColorMode.RGB,
|
||||
) # Left teleop camera, 640x480 @ 15FPS
|
||||
```
|
||||
|
||||
Attributes:
|
||||
name: Name of the camera device. Can be "teleop" or "depth".
|
||||
image_type: Type of image stream. For "teleop" camera, can be "left" or "right".
|
||||
For "depth" camera, can be "rgb" or "depth". (depth is not supported yet)
|
||||
fps: Requested frames per second for the color stream.
|
||||
width: Requested frame width in pixels for the color stream.
|
||||
height: Requested frame height in pixels for the color stream.
|
||||
color_mode: Color mode for image output (RGB or BGR). Defaults to RGB.
|
||||
ip_address: IP address of the robot. Defaults to "localhost".
|
||||
port: Port number for the camera server. Defaults to 50065.
|
||||
|
||||
Note:
|
||||
- Only 3-channel color output (RGB/BGR) is currently supported.
|
||||
"""
|
||||
|
||||
name: str
|
||||
image_type: str
|
||||
color_mode: ColorMode = ColorMode.RGB
|
||||
ip_address: str | None = "localhost"
|
||||
port: int = 50065
|
||||
# use_depth: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.name not in ["teleop", "depth"]:
|
||||
raise ValueError(f"`name` is expected to be 'teleop' or 'depth', but {self.name} is provided.")
|
||||
if (self.name == "teleop" and self.image_type not in ["left", "right"]) or (
|
||||
self.name == "depth" and self.image_type not in ["rgb", "depth"]
|
||||
):
|
||||
raise ValueError(
|
||||
f"`image_type` is expected to be 'left' or 'right' for teleop camera, and 'rgb' or 'depth' for depth camera, but {self.image_type} is provided."
|
||||
)
|
||||
|
||||
if self.color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
@@ -0,0 +1,288 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Provides the Reachy2Camera class for capturing frames from Reachy 2 cameras using Reachy 2's CameraManager.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import time
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
|
||||
# Fix MSMF hardware transform compatibility for Windows before importing cv2
|
||||
if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
|
||||
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
|
||||
import cv2
|
||||
import numpy as np
|
||||
from reachy2_sdk.media.camera import CameraView
|
||||
from reachy2_sdk.media.camera_manager import CameraManager
|
||||
|
||||
from lerobot.errors import DeviceNotConnectedError
|
||||
|
||||
from ..camera import Camera
|
||||
from .configuration_reachy2_camera import ColorMode, Reachy2CameraConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Reachy2Camera(Camera):
|
||||
"""
|
||||
Manages Reachy 2 camera using Reachy 2 CameraManager.
|
||||
|
||||
This class provides a high-level interface to connect to, configure, and read
|
||||
frames from Reachy 2 cameras. It supports both synchronous and asynchronous
|
||||
frame reading.
|
||||
|
||||
An Reachy2Camera instance requires a camera name (e.g., "teleop") and an image
|
||||
type (e.g., "left") to be specified in the configuration.
|
||||
|
||||
The camera's default settings (FPS, resolution, color mode) are used unless
|
||||
overridden in the configuration.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Reachy2CameraConfig):
|
||||
"""
|
||||
Initializes the Reachy2Camera instance.
|
||||
|
||||
Args:
|
||||
config: The configuration settings for the camera.
|
||||
"""
|
||||
super().__init__(config)
|
||||
|
||||
self.config = config
|
||||
|
||||
self.fps = config.fps
|
||||
self.color_mode = config.color_mode
|
||||
|
||||
self.cam_manager: CameraManager | None = None
|
||||
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: np.ndarray | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}({self.config.name}, {self.config.image_type})"
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Checks if the camera is currently connected and opened."""
|
||||
if self.config.name == "teleop":
|
||||
return self.cam_manager._grpc_connected and self.cam_manager.teleop if self.cam_manager else False
|
||||
elif self.config.name == "depth":
|
||||
return self.cam_manager._grpc_connected and self.cam_manager.depth if self.cam_manager else False
|
||||
else:
|
||||
raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")
|
||||
|
||||
def connect(self, warmup: bool = True):
|
||||
"""
|
||||
Connects to the Reachy2 CameraManager as specified in the configuration.
|
||||
"""
|
||||
self.cam_manager = CameraManager(host=self.config.ip_address, port=self.config.port)
|
||||
self.cam_manager.initialize_cameras()
|
||||
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@staticmethod
|
||||
def find_cameras(ip_address: str = "localhost", port: int = 50065) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Detects available Reachy 2 cameras.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: A list of dictionaries,
|
||||
where each dictionary contains 'name', 'stereo',
|
||||
and the default profile properties (width, height, fps).
|
||||
"""
|
||||
initialized_cameras = []
|
||||
camera_manager = CameraManager(host=ip_address, port=port)
|
||||
|
||||
for camera in [camera_manager.teleop, camera_manager.depth]:
|
||||
if camera is None:
|
||||
continue
|
||||
|
||||
height, width, _, _, _, _, _ = camera.get_parameters()
|
||||
|
||||
camera_info = {
|
||||
"name": camera._cam_info.name,
|
||||
"stereo": camera._cam_info.stereo,
|
||||
"default_profile": {
|
||||
"width": width,
|
||||
"height": height,
|
||||
"fps": 30,
|
||||
},
|
||||
}
|
||||
initialized_cameras.append(camera_info)
|
||||
|
||||
camera_manager.disconnect()
|
||||
return initialized_cameras
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
|
||||
This is a blocking call.
|
||||
|
||||
Args:
|
||||
color_mode (Optional[ColorMode]): If specified, overrides the default
|
||||
color mode (`self.color_mode`) for this read operation (e.g.,
|
||||
request RGB even if default is BGR).
|
||||
|
||||
Returns:
|
||||
np.ndarray: The captured frame as a NumPy array in the format
|
||||
(height, width, channels), using the specified or default
|
||||
color mode and applying any configured rotation.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
frame = None
|
||||
|
||||
if self.cam_manager is None:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
else:
|
||||
if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"):
|
||||
if self.config.image_type == "left":
|
||||
frame = self.cam_manager.teleop.get_frame(CameraView.LEFT, size=(640, 480))[0]
|
||||
elif self.config.image_type == "right":
|
||||
frame = self.cam_manager.teleop.get_frame(CameraView.RIGHT, size=(640, 480))[0]
|
||||
elif self.config.name == "depth" and hasattr(self.cam_manager, "depth"):
|
||||
if self.config.image_type == "depth":
|
||||
frame = self.cam_manager.depth.get_depth_frame()[0]
|
||||
elif self.config.image_type == "rgb":
|
||||
frame = self.cam_manager.depth.get_frame(size=(640, 480))[0]
|
||||
|
||||
if frame is None:
|
||||
return np.empty((0, 0, 3), dtype=np.uint8)
|
||||
|
||||
if self.config.color_mode == "rgb":
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
read_duration_ms = (time.perf_counter() - start_time) * 1e3
|
||||
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
|
||||
|
||||
return frame
|
||||
|
||||
def _read_loop(self):
|
||||
"""
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
|
||||
On each iteration:
|
||||
1. Reads a color frame
|
||||
2. Stores result in latest_frame (thread-safe)
|
||||
3. Sets new_frame_event to notify listeners
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
"""
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
color_image = self.read()
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = color_image
|
||||
self.new_frame_event.set()
|
||||
|
||||
except DeviceNotConnectedError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Error reading frame in background thread for {self}: {e}")
|
||||
|
||||
def _start_read_thread(self) -> None:
|
||||
"""Starts or restarts the background read thread if it's not running."""
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=0.1)
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
|
||||
self.stop_event = Event()
|
||||
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
def _stop_read_thread(self) -> None:
|
||||
"""Signals the background read thread to stop and waits for it to join."""
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
|
||||
"""
|
||||
Reads the latest available frame asynchronously.
|
||||
|
||||
This method retrieves the most recent frame captured by the background
|
||||
read thread. It does not block waiting for the camera hardware directly,
|
||||
but may wait up to timeout_ms for the background thread to provide a frame.
|
||||
|
||||
Args:
|
||||
timeout_ms (float): Maximum time in milliseconds to wait for a frame
|
||||
to become available. Defaults to 200ms (0.2 seconds).
|
||||
|
||||
Returns:
|
||||
np.ndarray: The latest captured frame as a NumPy array in the format
|
||||
(height, width, channels), processed according to configuration.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
TimeoutError: If no frame becomes available within the specified timeout.
|
||||
RuntimeError: If an unexpected error occurs.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
self._start_read_thread()
|
||||
|
||||
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
|
||||
thread_alive = self.thread is not None and self.thread.is_alive()
|
||||
raise TimeoutError(
|
||||
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
|
||||
f"Read thread alive: {thread_alive}."
|
||||
)
|
||||
|
||||
with self.frame_lock:
|
||||
frame = self.latest_frame
|
||||
self.new_frame_event.clear()
|
||||
|
||||
if frame is None:
|
||||
raise RuntimeError(f"Internal error: Event set but no frame available for {self}.")
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self):
|
||||
"""
|
||||
Stops the background read thread (if running).
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is already disconnected.
|
||||
"""
|
||||
if not self.is_connected and self.thread is None:
|
||||
raise DeviceNotConnectedError(f"{self} not connected.")
|
||||
|
||||
if self.thread is not None:
|
||||
self._stop_read_thread()
|
||||
|
||||
if self.cam_manager is not None:
|
||||
self.cam_manager.disconnect()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
@@ -51,7 +51,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
Use the provided utility script to find available camera indices and default profiles:
|
||||
```bash
|
||||
python -m lerobot.find_cameras realsense
|
||||
lerobot-find-cameras realsense
|
||||
```
|
||||
|
||||
A `RealSenseCamera` instance requires a configuration object specifying the
|
||||
@@ -176,8 +176,7 @@ class RealSenseCamera(Camera):
|
||||
self.rs_profile = None
|
||||
self.rs_pipeline = None
|
||||
raise ConnectionError(
|
||||
f"Failed to open {self}."
|
||||
"Run `python -m lerobot.find_cameras realsense` to find available cameras."
|
||||
f"Failed to open {self}.Run `lerobot-find-cameras realsense` to find available cameras."
|
||||
) from e
|
||||
|
||||
self._configure_capture_settings()
|
||||
|
||||
@@ -37,8 +37,14 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s
|
||||
from .realsense.camera_realsense import RealSenseCamera
|
||||
|
||||
cameras[key] = RealSenseCamera(cfg)
|
||||
|
||||
elif cfg.type == "reachy2_camera":
|
||||
from .reachy2_camera.reachy2_camera import Reachy2Camera
|
||||
|
||||
cameras[key] = Reachy2Camera(cfg)
|
||||
|
||||
else:
|
||||
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
|
||||
raise ValueError(f"The camera type '{cfg.type}' is not valid.")
|
||||
|
||||
return cameras
|
||||
|
||||
|
||||
@@ -825,6 +825,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
if not episode_data:
|
||||
episode_buffer = self.episode_buffer
|
||||
else:
|
||||
episode_buffer = episode_data
|
||||
|
||||
validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)
|
||||
|
||||
|
||||
@@ -13,20 +13,22 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
|
||||
2.1. It will:
|
||||
This script will help you download any LeRobot dataset from the hub, convert it to the latest format, and
|
||||
upload it to your own repository. It will:
|
||||
|
||||
- Download the dataset from any source repository
|
||||
- Generate per-episodes stats and writes them in `episodes_stats.jsonl`
|
||||
- Check consistency between these new stats and the old ones.
|
||||
- Remove the deprecated `stats.json`.
|
||||
- Update codebase_version in `info.json`.
|
||||
- Push this new version to the hub on the 'main' branch and tags it with "v2.1".
|
||||
- Update codebase_version in `info.json` to the latest version
|
||||
- Create proper version tags
|
||||
- Push the converted dataset to your specified destination repository
|
||||
|
||||
Usage:
|
||||
|
||||
```bash
|
||||
python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 \
|
||||
--repo-id=aliberts/koch_tutorial
|
||||
--source-repo-id=IPEC-COMMUNITY/libero_spatial_no_noops_1.0.0_lerobot \
|
||||
--dest-repo-id=your-username/libero_spatial_converted \
|
||||
--episodes=0,1,2,3,4
|
||||
```
|
||||
|
||||
"""
|
||||
@@ -37,8 +39,8 @@ import logging
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
|
||||
from lerobot.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
|
||||
from lerobot.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, write_info
|
||||
from lerobot.datasets.v21.convert_stats import convert_stats
|
||||
|
||||
V20 = "v2.0"
|
||||
V21 = "v2.1"
|
||||
@@ -54,48 +56,133 @@ class SuppressWarnings:
|
||||
|
||||
|
||||
def convert_dataset(
|
||||
repo_id: str,
|
||||
source_repo_id: str,
|
||||
dest_repo_id: str | None = None,
|
||||
episodes: str | None = None,
|
||||
branch: str | None = None,
|
||||
num_workers: int = 4,
|
||||
force_cache_sync: bool = True,
|
||||
):
|
||||
with SuppressWarnings():
|
||||
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
|
||||
"""
|
||||
Download a dataset from source_repo_id, convert it, and upload to dest_repo_id.
|
||||
|
||||
Args:
|
||||
source_repo_id: Source repository to download from
|
||||
dest_repo_id: Destination repository to upload to (defaults to source_repo_id)
|
||||
episodes: Comma-separated list of episode indices to include (e.g. "0,1,2,3")
|
||||
branch: Branch to upload to
|
||||
num_workers: Number of workers for stats computation
|
||||
force_cache_sync: Whether to force cache synchronization
|
||||
"""
|
||||
if dest_repo_id is None:
|
||||
dest_repo_id = source_repo_id
|
||||
|
||||
# Parse episodes list if provided
|
||||
episode_list = None
|
||||
if episodes:
|
||||
try:
|
||||
episode_list = [int(ep.strip()) for ep in episodes.split(",")]
|
||||
print(f"Loading episodes: {episode_list}")
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"Invalid episodes format '{episodes}'. Use comma-separated integers like '0,1,2,3'"
|
||||
) from e
|
||||
|
||||
print(f"Downloading dataset from: {source_repo_id}")
|
||||
|
||||
# Try to load the dataset with different approaches to handle versioning issues
|
||||
dataset = None
|
||||
load_attempts = [
|
||||
{"revision": None}, # Try latest first
|
||||
{"revision": V20}, # Try v2.0
|
||||
{"revision": "main"}, # Try main branch
|
||||
]
|
||||
|
||||
for attempt in load_attempts:
|
||||
try:
|
||||
print(f"Attempting to load with revision: {attempt['revision']}")
|
||||
with SuppressWarnings():
|
||||
dataset = LeRobotDataset(
|
||||
source_repo_id, episodes=episode_list, force_cache_sync=force_cache_sync, **attempt
|
||||
)
|
||||
print("Successfully loaded dataset!")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Failed with revision {attempt['revision']}: {e}")
|
||||
continue
|
||||
|
||||
if dataset is None:
|
||||
raise RuntimeError(f"Could not load dataset {source_repo_id} with any revision")
|
||||
|
||||
# Clean up old stats if present
|
||||
if (dataset.root / EPISODES_STATS_PATH).is_file():
|
||||
(dataset.root / EPISODES_STATS_PATH).unlink()
|
||||
print("Removed existing episodes_stats.jsonl")
|
||||
|
||||
print("Converting stats to new format...")
|
||||
convert_stats(dataset, num_workers=num_workers)
|
||||
ref_stats = load_stats(dataset.root)
|
||||
check_aggregate_stats(dataset, ref_stats)
|
||||
|
||||
# Update dataset info
|
||||
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
|
||||
write_info(dataset.meta.info, dataset.root)
|
||||
print(f"Updated codebase_version to {CODEBASE_VERSION}")
|
||||
|
||||
dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
|
||||
# Change repo_id for destination if different
|
||||
if dest_repo_id != source_repo_id:
|
||||
print(f"Changing repository from {source_repo_id} to {dest_repo_id}")
|
||||
dataset.repo_id = dest_repo_id
|
||||
|
||||
# delete old stats.json file
|
||||
if (dataset.root / STATS_PATH).is_file:
|
||||
print(f"Pushing converted dataset to: {dest_repo_id}")
|
||||
dataset.push_to_hub(branch=branch, tag_version=False)
|
||||
|
||||
# Clean up old stats.json file locally and on hub
|
||||
if (dataset.root / STATS_PATH).is_file():
|
||||
(dataset.root / STATS_PATH).unlink()
|
||||
print("Removed local stats.json file")
|
||||
|
||||
hub_api = HfApi()
|
||||
if hub_api.file_exists(
|
||||
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
|
||||
):
|
||||
hub_api.delete_file(
|
||||
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
|
||||
)
|
||||
try:
|
||||
if hub_api.file_exists(
|
||||
repo_id=dest_repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
|
||||
):
|
||||
hub_api.delete_file(
|
||||
path_in_repo=STATS_PATH, repo_id=dest_repo_id, revision=branch, repo_type="dataset"
|
||||
)
|
||||
print("Removed stats.json from hub")
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not remove stats.json from hub: {e}")
|
||||
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
# Create version tag
|
||||
try:
|
||||
hub_api.create_tag(dest_repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
print(f"Created tag {CODEBASE_VERSION} for {dest_repo_id}")
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not create tag: {e}")
|
||||
|
||||
print(f"✅ Successfully converted and uploaded dataset to {dest_repo_id}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Download, convert, and re-upload LeRobot datasets with proper versioning"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
"--source-repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
|
||||
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
||||
help="Source repository identifier to download from (e.g. 'IPEC-COMMUNITY/libero_spatial_no_noops_1.0.0_lerobot')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dest-repo-id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Destination repository identifier to upload to. Defaults to source-repo-id if not specified.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--episodes",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Comma-separated list of episode indices to include (e.g. '0,1,2,3,4'). If not specified, all episodes are included.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--branch",
|
||||
@@ -109,6 +196,22 @@ if __name__ == "__main__":
|
||||
default=4,
|
||||
help="Number of workers for parallelizing stats compute. Defaults to 4.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-cache-sync",
|
||||
action="store_true",
|
||||
help="Skip forcing cache synchronization (faster but may use cached data)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_dataset(**vars(args))
|
||||
|
||||
# Convert args to match function signature
|
||||
convert_args = {
|
||||
"source_repo_id": args.source_repo_id,
|
||||
"dest_repo_id": args.dest_repo_id,
|
||||
"episodes": args.episodes,
|
||||
"branch": args.branch,
|
||||
"num_workers": args.num_workers,
|
||||
"force_cache_sync": not args.no_cache_sync,
|
||||
}
|
||||
|
||||
convert_dataset(**convert_args)
|
||||
|
||||
@@ -20,7 +20,7 @@ Helper to find the camera devices available in your system.
|
||||
Example:
|
||||
|
||||
```shell
|
||||
python -m lerobot.find_cameras
|
||||
lerobot-find-cameras
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ Helper to find the USB port associated with your MotorsBus.
|
||||
Example:
|
||||
|
||||
```shell
|
||||
python -m lerobot.find_port
|
||||
lerobot-find-port
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
@@ -107,6 +107,8 @@ X_SERIES_ENCODINGS_TABLE = {
|
||||
"Goal_PWM": X_SERIES_CONTROL_TABLE["Goal_PWM"][1],
|
||||
"Goal_Current": X_SERIES_CONTROL_TABLE["Goal_Current"][1],
|
||||
"Goal_Velocity": X_SERIES_CONTROL_TABLE["Goal_Velocity"][1],
|
||||
"Goal_Position": X_SERIES_CONTROL_TABLE["Goal_Position"][1],
|
||||
"Present_Position": X_SERIES_CONTROL_TABLE["Present_Position"][1],
|
||||
"Present_PWM": X_SERIES_CONTROL_TABLE["Present_PWM"][1],
|
||||
"Present_Current": X_SERIES_CONTROL_TABLE["Present_Current"][1],
|
||||
"Present_Velocity": X_SERIES_CONTROL_TABLE["Present_Velocity"][1],
|
||||
|
||||
@@ -222,7 +222,7 @@ class MotorsBus(abc.ABC):
|
||||
A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
|
||||
To find the port, you can run our utility script:
|
||||
```bash
|
||||
python -m lerobot.find_port.py
|
||||
lerobot-find-port.py
|
||||
>>> Finding all available ports for the MotorsBus.
|
||||
>>> ["/dev/tty.usbmodem575E0032081", "/dev/tty.usbmodem575E0031751"]
|
||||
>>> Remove the usb cable from your MotorsBus and press Enter when done.
|
||||
@@ -446,7 +446,7 @@ class MotorsBus(abc.ABC):
|
||||
except (FileNotFoundError, OSError, serial.SerialException) as e:
|
||||
raise ConnectionError(
|
||||
f"\nCould not connect on port '{self.port}'. Make sure you are using the correct port."
|
||||
"\nTry running `python -m lerobot.find_port`\n"
|
||||
"\nTry running `lerobot-find-port`\n"
|
||||
) from e
|
||||
|
||||
@abc.abstractmethod
|
||||
|
||||
@@ -30,7 +30,7 @@ pip install -e ".[pi0]"
|
||||
|
||||
Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/pi0 \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
@@ -38,7 +38,7 @@ python -m lerobot.scripts.train \
|
||||
Example of finetuning the pi0 neural network with PaliGemma and expert Gemma
|
||||
pretrained with VLM default parameters before pi0 finetuning:
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=pi0 \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
|
||||
@@ -25,14 +25,14 @@ Disclaimer: It is not expected to perform as well as the original implementation
|
||||
|
||||
Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`):
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/pi0fast_base \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
|
||||
Example of training the pi0+FAST neural network with from scratch:
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=pi0fast \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
|
||||
@@ -28,7 +28,7 @@ pip install -e ".[smolvla]"
|
||||
|
||||
Example of finetuning the smolvla pretrained model (`smolvla_base`):
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
|
||||
--batch_size=64 \
|
||||
@@ -38,7 +38,7 @@ python -m lerobot.scripts.train \
|
||||
Example of finetuning a smolVLA. SmolVLA is composed of a pretrained VLM,
|
||||
and an action expert.
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=smolvla \
|
||||
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
|
||||
--batch_size=64 \
|
||||
|
||||
+10
-3
@@ -18,7 +18,7 @@ Records a dataset. Actions for the robot can be either generated by teleoperatio
|
||||
Example:
|
||||
|
||||
```shell
|
||||
python -m lerobot.record \
|
||||
lerobot-record \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{laptop: {type: opencv, camera_index: 0, width: 640, height: 480}}" \
|
||||
@@ -36,7 +36,7 @@ python -m lerobot.record \
|
||||
|
||||
Example recording with bimanual so100:
|
||||
```shell
|
||||
python -m lerobot.record \
|
||||
lerobot-record \
|
||||
--robot.type=bi_so100_follower \
|
||||
--robot.left_arm_port=/dev/tty.usbmodem5A460851411 \
|
||||
--robot.right_arm_port=/dev/tty.usbmodem5A460812391 \
|
||||
@@ -209,7 +209,14 @@ def record_loop(
|
||||
(
|
||||
t
|
||||
for t in teleop
|
||||
if isinstance(t, (so100_leader.SO100Leader, so101_leader.SO101Leader, koch_leader.KochLeader))
|
||||
if isinstance(
|
||||
t,
|
||||
(
|
||||
so100_leader.SO100Leader,
|
||||
so101_leader.SO101Leader,
|
||||
koch_leader.KochLeader,
|
||||
),
|
||||
)
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ Replays the actions of an episode from a dataset on a robot.
|
||||
Examples:
|
||||
|
||||
```shell
|
||||
python -m lerobot.replay \
|
||||
lerobot-replay \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=black \
|
||||
@@ -28,7 +28,7 @@ python -m lerobot.replay \
|
||||
|
||||
Example replay with bimanual so100:
|
||||
```shell
|
||||
python -m lerobot.replay \
|
||||
lerobot-replay \
|
||||
--robot.type=bi_so100_follower \
|
||||
--robot.left_arm_port=/dev/tty.usbmodem5A460851411 \
|
||||
--robot.right_arm_port=/dev/tty.usbmodem5A460812391 \
|
||||
@@ -55,6 +55,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
hope_jr,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
reachy2,
|
||||
so100_follower,
|
||||
so101_follower,
|
||||
)
|
||||
|
||||
@@ -29,10 +29,10 @@ class BiSO100FollowerConfig(RobotConfig):
|
||||
|
||||
# Optional
|
||||
left_arm_disable_torque_on_disconnect: bool = True
|
||||
left_arm_max_relative_target: int | None = None
|
||||
left_arm_max_relative_target: float | dict[str, float] | None = None
|
||||
left_arm_use_degrees: bool = False
|
||||
right_arm_disable_torque_on_disconnect: bool = True
|
||||
right_arm_max_relative_target: int | None = None
|
||||
right_arm_max_relative_target: float | dict[str, float] | None = None
|
||||
right_arm_use_degrees: bool = False
|
||||
|
||||
# cameras (shared between both arms)
|
||||
|
||||
@@ -44,8 +44,8 @@ class HopeJrArmConfig(RobotConfig):
|
||||
disable_torque_on_disconnect: bool = True
|
||||
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
|
||||
# names to the max_relative_target value for that motor.
|
||||
max_relative_target: float | dict[str, float] | None = None
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -28,9 +28,9 @@ class KochFollowerConfig(RobotConfig):
|
||||
disable_torque_on_disconnect: bool = True
|
||||
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
|
||||
# names to the max_relative_target value for that motor.
|
||||
max_relative_target: float | dict[str, float] | None = None
|
||||
|
||||
# cameras
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -110,6 +110,7 @@ class KochFollower(Robot):
|
||||
return self.bus.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.bus.disable_torque()
|
||||
if self.calibration:
|
||||
# Calibration file exists, ask user whether to use it or run new calibration
|
||||
user_input = input(
|
||||
@@ -120,7 +121,6 @@ class KochFollower(Robot):
|
||||
self.bus.write_calibration(self.calibration)
|
||||
return
|
||||
logger.info(f"\nRunning calibration of {self}")
|
||||
self.bus.disable_torque()
|
||||
for motor in self.bus.motors:
|
||||
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
|
||||
|
||||
|
||||
@@ -39,9 +39,9 @@ class LeKiwiConfig(RobotConfig):
|
||||
disable_torque_on_disconnect: bool = True
|
||||
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
|
||||
# names to the max_relative_target value for that motor.
|
||||
max_relative_target: float | dict[str, float] | None = None
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
|
||||
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_reachy2 import Reachy2RobotConfig
|
||||
from .robot_reachy2 import (
|
||||
REACHY2_ANTENNAS_JOINTS,
|
||||
REACHY2_L_ARM_JOINTS,
|
||||
REACHY2_NECK_JOINTS,
|
||||
REACHY2_R_ARM_JOINTS,
|
||||
REACHY2_VEL,
|
||||
Reachy2Robot,
|
||||
)
|
||||
@@ -0,0 +1,107 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
from lerobot.cameras.configs import ColorMode
|
||||
from lerobot.cameras.reachy2_camera import Reachy2CameraConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("reachy2")
|
||||
@dataclass
|
||||
class Reachy2RobotConfig(RobotConfig):
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors.
|
||||
max_relative_target: float | None = None
|
||||
|
||||
# IP address of the Reachy 2 robot
|
||||
ip_address: str | None = "localhost"
|
||||
|
||||
# If True, turn_off_smoothly() will be sent to the robot before disconnecting.
|
||||
disable_torque_on_disconnect: bool = False
|
||||
|
||||
# Tag for external commands control
|
||||
# Set to True if you use an external commands system to control the robot,
|
||||
# such as the official teleoperation application: https://github.com/pollen-robotics/Reachy2Teleoperation
|
||||
# If True, robot.send_action() will not send commands to the robot.
|
||||
use_external_commands: bool = False
|
||||
|
||||
# Robot parts
|
||||
# Set to False to not add the corresponding joints part to the robot list of joints.
|
||||
# By default, all parts are set to True.
|
||||
with_mobile_base: bool = True
|
||||
with_l_arm: bool = True
|
||||
with_r_arm: bool = True
|
||||
with_neck: bool = True
|
||||
with_antennas: bool = True
|
||||
|
||||
# Robot cameras
|
||||
# Set to True if you want to use the corresponding cameras in the observations.
|
||||
# By default, only the teleop cameras are used.
|
||||
with_left_teleop_camera: bool = True
|
||||
with_right_teleop_camera: bool = True
|
||||
with_torso_camera: bool = False
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Add cameras with same ip_address as the robot
|
||||
if self.with_left_teleop_camera:
|
||||
self.cameras["teleop_left"] = Reachy2CameraConfig(
|
||||
name="teleop",
|
||||
image_type="left",
|
||||
ip_address=self.ip_address,
|
||||
fps=15,
|
||||
width=640,
|
||||
height=480,
|
||||
color_mode=ColorMode.RGB,
|
||||
)
|
||||
if self.with_right_teleop_camera:
|
||||
self.cameras["teleop_right"] = Reachy2CameraConfig(
|
||||
name="teleop",
|
||||
image_type="right",
|
||||
ip_address=self.ip_address,
|
||||
fps=15,
|
||||
width=640,
|
||||
height=480,
|
||||
color_mode=ColorMode.RGB,
|
||||
)
|
||||
if self.with_torso_camera:
|
||||
self.cameras["torso_rgb"] = Reachy2CameraConfig(
|
||||
name="depth",
|
||||
image_type="rgb",
|
||||
ip_address=self.ip_address,
|
||||
fps=15,
|
||||
width=640,
|
||||
height=480,
|
||||
color_mode=ColorMode.RGB,
|
||||
)
|
||||
|
||||
super().__post_init__()
|
||||
|
||||
if not (
|
||||
self.with_mobile_base
|
||||
or self.with_l_arm
|
||||
or self.with_r_arm
|
||||
or self.with_neck
|
||||
or self.with_antennas
|
||||
):
|
||||
raise ValueError(
|
||||
"No Reachy2Robot part used.\n"
|
||||
"At least one part of the robot must be set to True "
|
||||
"(with_mobile_base, with_l_arm, with_r_arm, with_neck, with_antennas)"
|
||||
)
|
||||
@@ -0,0 +1,230 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from reachy2_sdk import ReachySDK
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
from .configuration_reachy2 import Reachy2RobotConfig
|
||||
|
||||
# {lerobot_keys: reachy2_sdk_keys}
|
||||
REACHY2_NECK_JOINTS = {
|
||||
"neck_yaw.pos": "head.neck.yaw",
|
||||
"neck_pitch.pos": "head.neck.pitch",
|
||||
"neck_roll.pos": "head.neck.roll",
|
||||
}
|
||||
|
||||
REACHY2_ANTENNAS_JOINTS = {
|
||||
"l_antenna.pos": "head.l_antenna",
|
||||
"r_antenna.pos": "head.r_antenna",
|
||||
}
|
||||
|
||||
REACHY2_R_ARM_JOINTS = {
|
||||
"r_shoulder_pitch.pos": "r_arm.shoulder.pitch",
|
||||
"r_shoulder_roll.pos": "r_arm.shoulder.roll",
|
||||
"r_elbow_yaw.pos": "r_arm.elbow.yaw",
|
||||
"r_elbow_pitch.pos": "r_arm.elbow.pitch",
|
||||
"r_wrist_roll.pos": "r_arm.wrist.roll",
|
||||
"r_wrist_pitch.pos": "r_arm.wrist.pitch",
|
||||
"r_wrist_yaw.pos": "r_arm.wrist.yaw",
|
||||
"r_gripper.pos": "r_arm.gripper",
|
||||
}
|
||||
|
||||
REACHY2_L_ARM_JOINTS = {
|
||||
"l_shoulder_pitch.pos": "l_arm.shoulder.pitch",
|
||||
"l_shoulder_roll.pos": "l_arm.shoulder.roll",
|
||||
"l_elbow_yaw.pos": "l_arm.elbow.yaw",
|
||||
"l_elbow_pitch.pos": "l_arm.elbow.pitch",
|
||||
"l_wrist_roll.pos": "l_arm.wrist.roll",
|
||||
"l_wrist_pitch.pos": "l_arm.wrist.pitch",
|
||||
"l_wrist_yaw.pos": "l_arm.wrist.yaw",
|
||||
"l_gripper.pos": "l_arm.gripper",
|
||||
}
|
||||
|
||||
REACHY2_VEL = {
|
||||
"mobile_base.vx": "vx",
|
||||
"mobile_base.vy": "vy",
|
||||
"mobile_base.vtheta": "vtheta",
|
||||
}
|
||||
|
||||
|
||||
class Reachy2Robot(Robot):
|
||||
"""
|
||||
[Reachy 2](https://www.pollen-robotics.com/reachy/), by Pollen Robotics.
|
||||
"""
|
||||
|
||||
config_class = Reachy2RobotConfig
|
||||
name = "reachy2"
|
||||
|
||||
def __init__(self, config: Reachy2RobotConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.config = config
|
||||
self.robot_type = self.config.type
|
||||
self.use_external_commands = self.config.use_external_commands
|
||||
|
||||
self.reachy: None | ReachySDK = None
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
|
||||
self.logs: dict[str, float] = {}
|
||||
|
||||
self.joints_dict: dict[str, str] = self._generate_joints_dict()
|
||||
|
||||
@property
|
||||
def observation_features(self) -> dict[str, Any]:
|
||||
return {**self.motors_features, **self.camera_features}
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self.motors_features
|
||||
|
||||
@property
|
||||
def camera_features(self) -> dict[str, tuple[int | None, int | None, int]]:
|
||||
return {cam: (self.cameras[cam].height, self.cameras[cam].width, 3) for cam in self.cameras}
|
||||
|
||||
@property
|
||||
def motors_features(self) -> dict[str, type]:
|
||||
if self.config.with_mobile_base:
|
||||
return {
|
||||
**dict.fromkeys(
|
||||
self.joints_dict.keys(),
|
||||
float,
|
||||
),
|
||||
**dict.fromkeys(
|
||||
REACHY2_VEL.keys(),
|
||||
float,
|
||||
),
|
||||
}
|
||||
else:
|
||||
return dict.fromkeys(self.joints_dict.keys(), float)
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.reachy.is_connected() if self.reachy is not None else False
|
||||
|
||||
def connect(self, calibrate: bool = False) -> None:
|
||||
self.reachy = ReachySDK(self.config.ip_address)
|
||||
if not self.is_connected:
|
||||
raise ConnectionError()
|
||||
|
||||
for cam in self.cameras.values():
|
||||
cam.connect()
|
||||
|
||||
self.configure()
|
||||
|
||||
def configure(self) -> None:
|
||||
if self.reachy is not None:
|
||||
self.reachy.turn_on()
|
||||
self.reachy.reset_default_limits()
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return True
|
||||
|
||||
def calibrate(self) -> None:
|
||||
pass
|
||||
|
||||
def _generate_joints_dict(self) -> dict[str, str]:
|
||||
joints = {}
|
||||
if self.config.with_neck:
|
||||
joints.update(REACHY2_NECK_JOINTS)
|
||||
if self.config.with_l_arm:
|
||||
joints.update(REACHY2_L_ARM_JOINTS)
|
||||
if self.config.with_r_arm:
|
||||
joints.update(REACHY2_R_ARM_JOINTS)
|
||||
if self.config.with_antennas:
|
||||
joints.update(REACHY2_ANTENNAS_JOINTS)
|
||||
return joints
|
||||
|
||||
def _get_state(self) -> dict[str, float]:
|
||||
if self.reachy is not None:
|
||||
pos_dict = {k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items()}
|
||||
if not self.config.with_mobile_base:
|
||||
return pos_dict
|
||||
vel_dict = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()}
|
||||
return {**pos_dict, **vel_dict}
|
||||
else:
|
||||
return {}
|
||||
|
||||
def get_observation(self) -> dict[str, np.ndarray]:
|
||||
obs_dict: dict[str, Any] = {}
|
||||
|
||||
# Read Reachy 2 state
|
||||
before_read_t = time.perf_counter()
|
||||
obs_dict.update(self._get_state())
|
||||
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
||||
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
if self.reachy is not None:
|
||||
if not self.is_connected:
|
||||
raise ConnectionError()
|
||||
|
||||
before_write_t = time.perf_counter()
|
||||
|
||||
vel = {}
|
||||
goal_pos = {}
|
||||
for key, val in action.items():
|
||||
if key not in self.joints_dict:
|
||||
if key not in REACHY2_VEL:
|
||||
raise KeyError(f"Key '{key}' is not a valid motor key in Reachy 2.")
|
||||
else:
|
||||
vel[REACHY2_VEL[key]] = float(val)
|
||||
else:
|
||||
if not self.use_external_commands and self.config.max_relative_target is not None:
|
||||
goal_pos[key] = float(val)
|
||||
goal_present_pos = {
|
||||
key: (
|
||||
goal_pos[key],
|
||||
self.reachy.joints[self.joints_dict[key]].present_position,
|
||||
)
|
||||
}
|
||||
safe_goal_pos = ensure_safe_goal_position(
|
||||
goal_present_pos, float(self.config.max_relative_target)
|
||||
)
|
||||
val = safe_goal_pos[key]
|
||||
self.reachy.joints[self.joints_dict[key]].goal_position = float(val)
|
||||
|
||||
if self.config.with_mobile_base:
|
||||
self.reachy.mobile_base.set_goal_speed(vel["vx"], vel["vy"], vel["vtheta"])
|
||||
|
||||
# We don't send the goal positions if we control Reachy 2 externally
|
||||
if not self.use_external_commands:
|
||||
self.reachy.send_goal_positions()
|
||||
if self.config.with_mobile_base:
|
||||
self.reachy.mobile_base.send_speed_command()
|
||||
|
||||
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
|
||||
return action
|
||||
|
||||
def disconnect(self) -> None:
|
||||
if self.reachy is not None:
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
if self.config.disable_torque_on_disconnect:
|
||||
self.reachy.turn_off_smoothly()
|
||||
self.reachy.disconnect()
|
||||
@@ -30,9 +30,9 @@ class SO100FollowerConfig(RobotConfig):
|
||||
disable_torque_on_disconnect: bool = True
|
||||
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
|
||||
# names to the max_relative_target value for that motor.
|
||||
max_relative_target: float | dict[str, float] | None = None
|
||||
|
||||
# cameras
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -161,6 +161,11 @@ class SO100Follower(Robot):
|
||||
self.bus.write("I_Coefficient", motor, 0)
|
||||
self.bus.write("D_Coefficient", motor, 32)
|
||||
|
||||
if motor == "gripper":
|
||||
self.bus.write("Max_Torque_Limit", motor, 500) # 50% of max torque to avoid burnout
|
||||
self.bus.write("Protection_Current", motor, 250) # 50% of max current to avoid burnout
|
||||
self.bus.write("Overload_Torque", motor, 25) # 25% torque when overloaded
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
for motor in reversed(self.bus.motors):
|
||||
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
|
||||
|
||||
@@ -30,9 +30,9 @@ class SO101FollowerConfig(RobotConfig):
|
||||
disable_torque_on_disconnect: bool = True
|
||||
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
|
||||
# names to the max_relative_target value for that motor.
|
||||
max_relative_target: float | dict[str, float] | None = None
|
||||
|
||||
# cameras
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -157,6 +157,13 @@ class SO101Follower(Robot):
|
||||
self.bus.write("I_Coefficient", motor, 0)
|
||||
self.bus.write("D_Coefficient", motor, 32)
|
||||
|
||||
if motor == "gripper":
|
||||
self.bus.write(
|
||||
"Max_Torque_Limit", motor, 500
|
||||
) # 50% of the max torque limit to avoid burnout
|
||||
self.bus.write("Protection_Current", motor, 250) # 50% of max current to avoid burnout
|
||||
self.bus.write("Overload_Torque", motor, 25) # 25% torque when overloaded
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
for motor in reversed(self.bus.motors):
|
||||
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
|
||||
|
||||
@@ -24,11 +24,6 @@ from ..config import RobotConfig
|
||||
@RobotConfig.register_subclass("stretch3")
|
||||
@dataclass
|
||||
class Stretch3RobotConfig(RobotConfig):
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
# cameras
|
||||
cameras: dict[str, CameraConfig] = field(
|
||||
default_factory=lambda: {
|
||||
|
||||
@@ -61,6 +61,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
|
||||
from .bi_so100_follower import BiSO100Follower
|
||||
|
||||
return BiSO100Follower(config)
|
||||
elif config.type == "reachy2":
|
||||
from .reachy2 import Reachy2Robot
|
||||
|
||||
return Reachy2Robot(config)
|
||||
elif config.type == "mock_robot":
|
||||
from tests.mocks.mock_robot import MockRobot
|
||||
|
||||
@@ -70,7 +74,7 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
|
||||
|
||||
|
||||
def ensure_safe_goal_position(
|
||||
goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[float]
|
||||
goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[str, float]
|
||||
) -> dict[str, float]:
|
||||
"""Caps relative action target magnitude for safety."""
|
||||
|
||||
|
||||
@@ -141,10 +141,10 @@ python lerobot/scripts/control_robot.py \
|
||||
|
||||
## Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](../src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
To train a policy to control your robot, use the [`lerobot-train`](../src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/aloha_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_aloha_test \
|
||||
|
||||
@@ -28,15 +28,15 @@ class ViperXConfig(RobotConfig):
|
||||
|
||||
# /!\ FOR SAFETY, READ THIS /!\
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
|
||||
# names to the max_relative_target value for that motor.
|
||||
# For Aloha, for every goal position request, motor rotations are capped at 5 degrees by default.
|
||||
# When you feel more confident with teleoperation or running the policy, you can extend
|
||||
# this safety limit and even removing it by setting it to `null`.
|
||||
# Also, everything is expected to work safely out-of-the-box, but we highly advise to
|
||||
# first try to teleoperate the grippers only (by commenting out the rest of the motors in this yaml),
|
||||
# then to gradually add more motors (by uncommenting), until you can teleoperate both arms fully
|
||||
max_relative_target: int | None = 5
|
||||
max_relative_target: float | dict[str, float] = 5.0
|
||||
|
||||
# cameras
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -21,7 +21,7 @@ You want to evaluate a model from the hub (eg: https://huggingface.co/lerobot/di
|
||||
for 10 episodes.
|
||||
|
||||
```
|
||||
python -m lerobot.scripts.eval \
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/diffusion_pusht \
|
||||
--env.type=pusht \
|
||||
--eval.batch_size=10 \
|
||||
@@ -32,7 +32,7 @@ python -m lerobot.scripts.eval \
|
||||
|
||||
OR, you want to evaluate a model checkpoint from the LeRobot training script for 10 episodes.
|
||||
```
|
||||
python -m lerobot.scripts.eval \
|
||||
lerobot-eval \
|
||||
--policy.path=outputs/train/diffusion_pusht/checkpoints/005000/pretrained_model \
|
||||
--env.type=pusht \
|
||||
--eval.batch_size=10 \
|
||||
|
||||
@@ -302,11 +302,6 @@ class RobotClient:
|
||||
|
||||
self.logger.debug(f"Current latest action: {latest_action}")
|
||||
|
||||
# Get queue state before changes
|
||||
old_size, old_timesteps = self._inspect_action_queue()
|
||||
if not old_timesteps:
|
||||
old_timesteps = [latest_action] # queue was empty
|
||||
|
||||
# Get queue state before changes
|
||||
old_size, old_timesteps = self._inspect_action_queue()
|
||||
if not old_timesteps:
|
||||
|
||||
@@ -18,7 +18,7 @@ Helper to set motor ids and baudrate.
|
||||
Example:
|
||||
|
||||
```shell
|
||||
python -m lerobot.setup_motors \
|
||||
lerobot-setup-motors \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem575E0031751
|
||||
```
|
||||
|
||||
@@ -18,7 +18,7 @@ Simple script to control a robot from teleoperation.
|
||||
Example:
|
||||
|
||||
```shell
|
||||
python -m lerobot.teleoperate \
|
||||
lerobot-teleoperate \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
|
||||
@@ -32,7 +32,7 @@ python -m lerobot.teleoperate \
|
||||
Example teleoperation with bimanual so100:
|
||||
|
||||
```shell
|
||||
python -m lerobot.teleoperate \
|
||||
lerobot-teleoperate \
|
||||
--robot.type=bi_so100_follower \
|
||||
--robot.left_arm_port=/dev/tty.usbmodem5A460851411 \
|
||||
--robot.right_arm_port=/dev/tty.usbmodem5A460812391 \
|
||||
|
||||
@@ -88,6 +88,7 @@ class KochLeader(Teleoperator):
|
||||
return self.bus.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.bus.disable_torque()
|
||||
if self.calibration:
|
||||
# Calibration file exists, ask user whether to use it or run new calibration
|
||||
user_input = input(
|
||||
@@ -98,7 +99,6 @@ class KochLeader(Teleoperator):
|
||||
self.bus.write_calibration(self.calibration)
|
||||
return
|
||||
logger.info(f"\nRunning calibration of {self}")
|
||||
self.bus.disable_torque()
|
||||
for motor in self.bus.motors:
|
||||
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
|
||||
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_reachy2_teleoperator import Reachy2TeleoperatorConfig
|
||||
from .reachy2_teleoperator import (
|
||||
REACHY2_ANTENNAS_JOINTS,
|
||||
REACHY2_L_ARM_JOINTS,
|
||||
REACHY2_NECK_JOINTS,
|
||||
REACHY2_R_ARM_JOINTS,
|
||||
REACHY2_VEL,
|
||||
Reachy2Teleoperator,
|
||||
)
|
||||
@@ -0,0 +1,51 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..config import TeleoperatorConfig
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("reachy2_teleoperator")
|
||||
@dataclass
|
||||
class Reachy2TeleoperatorConfig(TeleoperatorConfig):
|
||||
# IP address of the Reachy 2 robot used as teleoperator
|
||||
ip_address: str | None = "localhost"
|
||||
|
||||
# Whether to use the present position of the joints as actions
|
||||
# if False, the goal position of the joints will be used
|
||||
use_present_position: bool = False
|
||||
|
||||
# Which parts of the robot to use
|
||||
with_mobile_base: bool = True
|
||||
with_l_arm: bool = True
|
||||
with_r_arm: bool = True
|
||||
with_neck: bool = True
|
||||
with_antennas: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if not (
|
||||
self.with_mobile_base
|
||||
or self.with_l_arm
|
||||
or self.with_r_arm
|
||||
or self.with_neck
|
||||
or self.with_antennas
|
||||
):
|
||||
raise ValueError(
|
||||
"No Reachy2Teleoperator part used.\n"
|
||||
"At least one part of the robot must be set to True "
|
||||
"(with_mobile_base, with_l_arm, with_r_arm, with_neck, with_antennas)"
|
||||
)
|
||||
@@ -0,0 +1,164 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from reachy2_sdk import ReachySDK
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_reachy2_teleoperator import Reachy2TeleoperatorConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# {lerobot_keys: reachy2_sdk_keys}
|
||||
REACHY2_NECK_JOINTS = {
|
||||
"neck_yaw.pos": "head.neck.yaw",
|
||||
"neck_pitch.pos": "head.neck.pitch",
|
||||
"neck_roll.pos": "head.neck.roll",
|
||||
}
|
||||
|
||||
REACHY2_ANTENNAS_JOINTS = {
|
||||
"l_antenna.pos": "head.l_antenna",
|
||||
"r_antenna.pos": "head.r_antenna",
|
||||
}
|
||||
|
||||
REACHY2_R_ARM_JOINTS = {
|
||||
"r_shoulder_pitch.pos": "r_arm.shoulder.pitch",
|
||||
"r_shoulder_roll.pos": "r_arm.shoulder.roll",
|
||||
"r_elbow_yaw.pos": "r_arm.elbow.yaw",
|
||||
"r_elbow_pitch.pos": "r_arm.elbow.pitch",
|
||||
"r_wrist_roll.pos": "r_arm.wrist.roll",
|
||||
"r_wrist_pitch.pos": "r_arm.wrist.pitch",
|
||||
"r_wrist_yaw.pos": "r_arm.wrist.yaw",
|
||||
"r_gripper.pos": "r_arm.gripper",
|
||||
}
|
||||
|
||||
REACHY2_L_ARM_JOINTS = {
|
||||
"l_shoulder_pitch.pos": "l_arm.shoulder.pitch",
|
||||
"l_shoulder_roll.pos": "l_arm.shoulder.roll",
|
||||
"l_elbow_yaw.pos": "l_arm.elbow.yaw",
|
||||
"l_elbow_pitch.pos": "l_arm.elbow.pitch",
|
||||
"l_wrist_roll.pos": "l_arm.wrist.roll",
|
||||
"l_wrist_pitch.pos": "l_arm.wrist.pitch",
|
||||
"l_wrist_yaw.pos": "l_arm.wrist.yaw",
|
||||
"l_gripper.pos": "l_arm.gripper",
|
||||
}
|
||||
|
||||
REACHY2_VEL = {
|
||||
"mobile_base.vx": "vx",
|
||||
"mobile_base.vy": "vy",
|
||||
"mobile_base.vtheta": "vtheta",
|
||||
}
|
||||
|
||||
|
||||
class Reachy2Teleoperator(Teleoperator):
|
||||
"""
|
||||
[Reachy 2](https://www.pollen-robotics.com/reachy/), by Pollen Robotics.
|
||||
"""
|
||||
|
||||
config_class = Reachy2TeleoperatorConfig
|
||||
name = "reachy2_specific"
|
||||
|
||||
def __init__(self, config: Reachy2TeleoperatorConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.reachy: None | ReachySDK = None
|
||||
|
||||
self.joints_dict: dict[str, str] = self._generate_joints_dict()
|
||||
|
||||
def _generate_joints_dict(self) -> dict[str, str]:
|
||||
joints = {}
|
||||
if self.config.with_neck:
|
||||
joints.update(REACHY2_NECK_JOINTS)
|
||||
if self.config.with_l_arm:
|
||||
joints.update(REACHY2_L_ARM_JOINTS)
|
||||
if self.config.with_r_arm:
|
||||
joints.update(REACHY2_R_ARM_JOINTS)
|
||||
if self.config.with_antennas:
|
||||
joints.update(REACHY2_ANTENNAS_JOINTS)
|
||||
return joints
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
if self.config.with_mobile_base:
|
||||
return {
|
||||
**dict.fromkeys(
|
||||
self.joints_dict.keys(),
|
||||
float,
|
||||
),
|
||||
**dict.fromkeys(
|
||||
REACHY2_VEL.keys(),
|
||||
float,
|
||||
),
|
||||
}
|
||||
else:
|
||||
return dict.fromkeys(self.joints_dict.keys(), float)
|
||||
|
||||
@property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.reachy.is_connected() if self.reachy is not None else False
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.reachy = ReachySDK(self.config.ip_address)
|
||||
if not self.is_connected:
|
||||
raise ConnectionError()
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return True
|
||||
|
||||
def calibrate(self) -> None:
|
||||
pass
|
||||
|
||||
def configure(self) -> None:
|
||||
pass
|
||||
|
||||
def get_action(self) -> dict[str, float]:
|
||||
start = time.perf_counter()
|
||||
|
||||
if self.reachy and self.is_connected:
|
||||
if self.config.use_present_position:
|
||||
joint_action = {
|
||||
k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items()
|
||||
}
|
||||
else:
|
||||
joint_action = {k: self.reachy.joints[v].goal_position for k, v in self.joints_dict.items()}
|
||||
|
||||
if not self.config.with_mobile_base:
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||
return joint_action
|
||||
|
||||
if self.config.use_present_position:
|
||||
vel_action = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()}
|
||||
else:
|
||||
vel_action = {k: self.reachy.mobile_base.last_cmd_vel[v] for k, v in REACHY2_VEL.items()}
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||
return {**joint_action, **vel_action}
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def disconnect(self) -> None:
|
||||
if self.reachy and self.is_connected:
|
||||
self.reachy.disconnect()
|
||||
@@ -65,5 +65,9 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
|
||||
from .bi_so100_leader import BiSO100Leader
|
||||
|
||||
return BiSO100Leader(config)
|
||||
elif config.type == "reachy2_teleoperator":
|
||||
from .reachy2_teleoperator import Reachy2Teleoperator
|
||||
|
||||
return Reachy2Teleoperator(config)
|
||||
else:
|
||||
raise ValueError(config.type)
|
||||
|
||||
@@ -44,7 +44,7 @@ Below is the short version on how to train and run inference/eval:
|
||||
### Train from scratch
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/<dataset> \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/<desired_policy_repo_id> \
|
||||
@@ -59,7 +59,7 @@ _Writes checkpoints to `outputs/train/<desired_policy_repo_id>/checkpoints/`._
|
||||
### Evaluate the policy/run inference
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
lerobot-record \
|
||||
--robot.type=so100_follower \
|
||||
--dataset.repo_id=<hf_user>/eval_<dataset> \
|
||||
--policy.path=<hf_user>/<desired_policy_repo_id> \
|
||||
|
||||
@@ -17,10 +17,9 @@ import time
|
||||
|
||||
|
||||
def busy_wait(seconds):
|
||||
if platform.system() == "Darwin":
|
||||
# On Mac, `time.sleep` is not accurate and we need to use this while loop trick,
|
||||
if platform.system() == "Darwin" or platform.system() == "Windows":
|
||||
# On Mac and Windows, `time.sleep` is not accurate and we need to use this while loop trick,
|
||||
# but it consumes CPU cycles.
|
||||
# TODO(rcadene): find an alternative: from python 11, time.sleep is precise
|
||||
end_time = time.perf_counter() + seconds
|
||||
while time.perf_counter() < end_time:
|
||||
pass
|
||||
|
||||
@@ -0,0 +1,177 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.cameras.reachy2_camera import Reachy2Camera, Reachy2CameraConfig
|
||||
from lerobot.errors import DeviceNotConnectedError
|
||||
|
||||
PARAMS = [
|
||||
("teleop", "left"),
|
||||
("teleop", "right"),
|
||||
("depth", "rgb"),
|
||||
# ("depth", "depth"), # Depth camera is not available yet
|
||||
]
|
||||
|
||||
|
||||
def _make_cam_manager_mock():
|
||||
c = MagicMock(name="CameraManagerMock")
|
||||
|
||||
teleop = MagicMock(name="TeleopCam")
|
||||
teleop.width = 640
|
||||
teleop.height = 480
|
||||
teleop.get_frame = MagicMock(
|
||||
side_effect=lambda *_, **__: (
|
||||
np.zeros((480, 640, 3), dtype=np.uint8),
|
||||
time.time(),
|
||||
)
|
||||
)
|
||||
|
||||
depth = MagicMock(name="DepthCam")
|
||||
depth.width = 640
|
||||
depth.height = 480
|
||||
depth.get_frame = MagicMock(
|
||||
side_effect=lambda *_, **__: (
|
||||
np.zeros((480, 640, 3), dtype=np.uint8),
|
||||
time.time(),
|
||||
)
|
||||
)
|
||||
|
||||
c.is_connected.return_value = True
|
||||
c.teleop = teleop
|
||||
c.depth = depth
|
||||
|
||||
def _connect():
|
||||
c.teleop = teleop
|
||||
c.depth = depth
|
||||
c.is_connected.return_value = True
|
||||
|
||||
def _disconnect():
|
||||
c.teleop = None
|
||||
c.depth = None
|
||||
c.is_connected.return_value = False
|
||||
|
||||
c.connect = MagicMock(side_effect=_connect)
|
||||
c.disconnect = MagicMock(side_effect=_disconnect)
|
||||
|
||||
# Mock methods
|
||||
c.initialize_cameras = MagicMock()
|
||||
|
||||
return c
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=PARAMS,
|
||||
# ids=["teleop-left", "teleop-right", "torso-rgb", "torso-depth"],
|
||||
ids=["teleop-left", "teleop-right", "torso-rgb"],
|
||||
)
|
||||
def camera(request):
|
||||
name, image_type = request.param
|
||||
with (
|
||||
patch(
|
||||
"lerobot.cameras.reachy2_camera.reachy2_camera.CameraManager",
|
||||
side_effect=lambda *a, **k: _make_cam_manager_mock(),
|
||||
),
|
||||
):
|
||||
config = Reachy2CameraConfig(name=name, image_type=image_type)
|
||||
cam = Reachy2Camera(config)
|
||||
yield cam
|
||||
if cam.is_connected:
|
||||
cam.disconnect()
|
||||
|
||||
|
||||
def test_connect(camera):
|
||||
camera.connect()
|
||||
assert camera.is_connected
|
||||
camera.cam_manager.initialize_cameras.assert_called_once()
|
||||
|
||||
|
||||
def test_read(camera):
|
||||
camera.connect()
|
||||
|
||||
img = camera.read()
|
||||
if camera.config.name == "teleop":
|
||||
camera.cam_manager.teleop.get_frame.assert_called_once()
|
||||
elif camera.config.name == "depth":
|
||||
camera.cam_manager.depth.get_frame.assert_called_once()
|
||||
assert isinstance(img, np.ndarray)
|
||||
assert img.shape == (480, 640, 3)
|
||||
|
||||
|
||||
def test_disconnect(camera):
|
||||
camera.connect()
|
||||
|
||||
camera.disconnect()
|
||||
assert not camera.is_connected
|
||||
|
||||
|
||||
def test_async_read(camera):
|
||||
camera.connect()
|
||||
try:
|
||||
img = camera.async_read()
|
||||
|
||||
assert camera.thread is not None
|
||||
assert camera.thread.is_alive()
|
||||
assert isinstance(img, np.ndarray)
|
||||
finally:
|
||||
if camera.is_connected:
|
||||
camera.disconnect()
|
||||
|
||||
|
||||
def test_async_read_timeout(camera):
|
||||
camera.connect()
|
||||
try:
|
||||
with pytest.raises(TimeoutError):
|
||||
camera.async_read(timeout_ms=0)
|
||||
finally:
|
||||
if camera.is_connected:
|
||||
camera.disconnect()
|
||||
|
||||
|
||||
def test_read_before_connect(camera):
|
||||
with pytest.raises(DeviceNotConnectedError):
|
||||
_ = camera.read()
|
||||
|
||||
|
||||
def test_disconnect_before_connect(camera):
|
||||
with pytest.raises(DeviceNotConnectedError):
|
||||
camera.disconnect()
|
||||
|
||||
|
||||
def test_async_read_before_connect(camera):
|
||||
with pytest.raises(DeviceNotConnectedError):
|
||||
_ = camera.async_read()
|
||||
|
||||
|
||||
def test_wrong_camera_name():
|
||||
with pytest.raises(ValueError):
|
||||
_ = Reachy2CameraConfig(name="wrong-name", image_type="left")
|
||||
|
||||
|
||||
def test_wrong_image_type():
|
||||
with pytest.raises(ValueError):
|
||||
_ = Reachy2CameraConfig(name="teleop", image_type="rgb")
|
||||
with pytest.raises(ValueError):
|
||||
_ = Reachy2CameraConfig(name="depth", image_type="left")
|
||||
|
||||
|
||||
def test_wrong_color_mode():
|
||||
with pytest.raises(ValueError):
|
||||
_ = Reachy2CameraConfig(name="teleop", image_type="left", color_mode="wrong-color")
|
||||
@@ -28,6 +28,7 @@ pytest_plugins = [
|
||||
"tests.fixtures.files",
|
||||
"tests.fixtures.hub",
|
||||
"tests.fixtures.optimizers",
|
||||
"tests.plugins.reachy2_sdk",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
import sys
|
||||
import types
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
def _install_reachy2_sdk_stub():
|
||||
sdk = types.ModuleType("reachy2_sdk")
|
||||
sdk.__path__ = []
|
||||
sdk.ReachySDK = MagicMock(name="ReachySDK")
|
||||
|
||||
media = types.ModuleType("reachy2_sdk.media")
|
||||
media.__path__ = []
|
||||
camera = types.ModuleType("reachy2_sdk.media.camera")
|
||||
camera.CameraView = MagicMock(name="CameraView")
|
||||
camera_manager = types.ModuleType("reachy2_sdk.media.camera_manager")
|
||||
camera_manager.CameraManager = MagicMock(name="CameraManager")
|
||||
|
||||
sdk.media = media
|
||||
media.camera = camera
|
||||
media.camera_manager = camera_manager
|
||||
|
||||
# Register in sys.modules
|
||||
sys.modules.setdefault("reachy2_sdk", sdk)
|
||||
sys.modules.setdefault("reachy2_sdk.media", media)
|
||||
sys.modules.setdefault("reachy2_sdk.media.camera", camera)
|
||||
sys.modules.setdefault("reachy2_sdk.media.camera_manager", camera_manager)
|
||||
|
||||
|
||||
def pytest_sessionstart(session):
|
||||
_install_reachy2_sdk_stub()
|
||||
@@ -0,0 +1,326 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.robots.reachy2 import (
|
||||
REACHY2_ANTENNAS_JOINTS,
|
||||
REACHY2_L_ARM_JOINTS,
|
||||
REACHY2_NECK_JOINTS,
|
||||
REACHY2_R_ARM_JOINTS,
|
||||
REACHY2_VEL,
|
||||
Reachy2Robot,
|
||||
Reachy2RobotConfig,
|
||||
)
|
||||
|
||||
# {lerobot_keys: reachy2_sdk_keys}
|
||||
REACHY2_JOINTS = {
|
||||
**REACHY2_NECK_JOINTS,
|
||||
**REACHY2_ANTENNAS_JOINTS,
|
||||
**REACHY2_R_ARM_JOINTS,
|
||||
**REACHY2_L_ARM_JOINTS,
|
||||
}
|
||||
|
||||
PARAMS = [
|
||||
{}, # default config
|
||||
{"with_mobile_base": False},
|
||||
{"with_mobile_base": False, "with_l_arm": False, "with_antennas": False},
|
||||
{"with_r_arm": False, "with_neck": False, "with_antennas": False},
|
||||
{"use_external_commands": True, "disable_torque_on_disconnect": True},
|
||||
{"use_external_commands": True, "with_mobile_base": False, "with_neck": False},
|
||||
{"disable_torque_on_disconnect": False},
|
||||
{"max_relative_target": 5},
|
||||
{"with_right_teleop_camera": False},
|
||||
{"with_left_teleop_camera": False, "with_right_teleop_camera": False},
|
||||
{"with_left_teleop_camera": False, "with_torso_camera": True},
|
||||
]
|
||||
|
||||
|
||||
def _make_reachy2_sdk_mock():
|
||||
class JointSpy:
|
||||
__slots__ = (
|
||||
"present_position",
|
||||
"_goal_position",
|
||||
"_on_set",
|
||||
)
|
||||
|
||||
def __init__(self, present_position=0.0, on_set=None):
|
||||
self.present_position = present_position
|
||||
self._goal_position = present_position
|
||||
self._on_set = on_set
|
||||
|
||||
@property
|
||||
def goal_position(self):
|
||||
return self._goal_position
|
||||
|
||||
@goal_position.setter
|
||||
def goal_position(self, v):
|
||||
self._goal_position = v
|
||||
if self._on_set:
|
||||
self._on_set()
|
||||
|
||||
r = MagicMock(name="ReachySDKMock")
|
||||
r.is_connected.return_value = True
|
||||
|
||||
def _connect():
|
||||
r.is_connected.return_value = True
|
||||
|
||||
def _disconnect():
|
||||
r.is_connected.return_value = False
|
||||
|
||||
# Global counter of goal_position sets
|
||||
r._goal_position_set_total = 0
|
||||
|
||||
def _on_any_goal_set():
|
||||
r._goal_position_set_total += 1
|
||||
|
||||
# Mock joints with some dummy positions
|
||||
joints = {
|
||||
k: JointSpy(
|
||||
present_position=float(i),
|
||||
on_set=_on_any_goal_set,
|
||||
)
|
||||
for i, k in enumerate(REACHY2_JOINTS.values())
|
||||
}
|
||||
r.joints = joints
|
||||
|
||||
# Mock mobile base with some dummy odometry
|
||||
r.mobile_base = MagicMock()
|
||||
r.mobile_base.odometry = {
|
||||
"x": 0.1,
|
||||
"y": -0.2,
|
||||
"theta": 21.3,
|
||||
"vx": 0.001,
|
||||
"vy": 0.002,
|
||||
"vtheta": 0.0,
|
||||
}
|
||||
|
||||
r.connect = MagicMock(side_effect=_connect)
|
||||
r.disconnect = MagicMock(side_effect=_disconnect)
|
||||
|
||||
# Mock methods
|
||||
r.turn_on = MagicMock()
|
||||
r.reset_default_limits = MagicMock()
|
||||
r.send_goal_positions = MagicMock()
|
||||
r.turn_off_smoothly = MagicMock()
|
||||
r.mobile_base.set_goal_speed = MagicMock()
|
||||
r.mobile_base.send_speed_command = MagicMock()
|
||||
|
||||
return r
|
||||
|
||||
|
||||
def _make_reachy2_camera_mock(*args, **kwargs):
|
||||
cfg = args[0] if args else kwargs.get("config")
|
||||
name = getattr(cfg, "name", kwargs.get("name", "cam"))
|
||||
image_type = getattr(cfg, "image_type", kwargs.get("image_type", "cam"))
|
||||
width = getattr(cfg, "width", kwargs.get("width", 640))
|
||||
height = getattr(cfg, "height", kwargs.get("height", 480))
|
||||
|
||||
cam = MagicMock(name=f"Reachy2CameraMock:{name}")
|
||||
cam.name = name
|
||||
cam.image_type = image_type
|
||||
cam.width = width
|
||||
cam.height = height
|
||||
cam.connect = MagicMock()
|
||||
cam.disconnect = MagicMock()
|
||||
cam.async_read = MagicMock(side_effect=lambda: np.zeros((height, width, 3), dtype=np.uint8))
|
||||
return cam
|
||||
|
||||
|
||||
@pytest.fixture(params=PARAMS, ids=lambda p: "default" if not p else ",".join(p.keys()))
|
||||
def reachy2(request):
|
||||
with (
|
||||
patch(
|
||||
"lerobot.robots.reachy2.robot_reachy2.ReachySDK",
|
||||
side_effect=lambda *a, **k: _make_reachy2_sdk_mock(),
|
||||
),
|
||||
patch(
|
||||
"lerobot.cameras.reachy2_camera.reachy2_camera.Reachy2Camera",
|
||||
side_effect=_make_reachy2_camera_mock,
|
||||
),
|
||||
):
|
||||
overrides = request.param
|
||||
cfg = Reachy2RobotConfig(ip_address="192.168.0.200", **overrides)
|
||||
robot = Reachy2Robot(cfg)
|
||||
yield robot
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
def test_connect_disconnect(reachy2):
|
||||
assert not reachy2.is_connected
|
||||
|
||||
reachy2.connect()
|
||||
assert reachy2.is_connected
|
||||
|
||||
reachy2.reachy.turn_on.assert_called_once()
|
||||
reachy2.reachy.reset_default_limits.assert_called_once()
|
||||
|
||||
reachy2.disconnect()
|
||||
assert not reachy2.is_connected
|
||||
|
||||
if reachy2.config.disable_torque_on_disconnect:
|
||||
reachy2.reachy.turn_off_smoothly.assert_called_once()
|
||||
else:
|
||||
reachy2.reachy.turn_off_smoothly.assert_not_called()
|
||||
reachy2.reachy.disconnect.assert_called_once()
|
||||
|
||||
|
||||
def test_get_joints_dict(reachy2):
|
||||
reachy2.connect()
|
||||
|
||||
if reachy2.config.with_neck:
|
||||
assert "neck_yaw.pos" in reachy2.joints_dict
|
||||
assert "neck_pitch.pos" in reachy2.joints_dict
|
||||
assert "neck_roll.pos" in reachy2.joints_dict
|
||||
else:
|
||||
assert "neck_yaw.pos" not in reachy2.joints_dict
|
||||
assert "neck_pitch.pos" not in reachy2.joints_dict
|
||||
assert "neck_roll.pos" not in reachy2.joints_dict
|
||||
|
||||
if reachy2.config.with_antennas:
|
||||
assert "l_antenna.pos" in reachy2.joints_dict
|
||||
assert "r_antenna.pos" in reachy2.joints_dict
|
||||
else:
|
||||
assert "l_antenna.pos" not in reachy2.joints_dict
|
||||
assert "r_antenna.pos" not in reachy2.joints_dict
|
||||
|
||||
if reachy2.config.with_r_arm:
|
||||
assert "r_shoulder_pitch.pos" in reachy2.joints_dict
|
||||
assert "r_shoulder_roll.pos" in reachy2.joints_dict
|
||||
assert "r_elbow_yaw.pos" in reachy2.joints_dict
|
||||
assert "r_elbow_pitch.pos" in reachy2.joints_dict
|
||||
assert "r_wrist_roll.pos" in reachy2.joints_dict
|
||||
assert "r_wrist_pitch.pos" in reachy2.joints_dict
|
||||
assert "r_wrist_yaw.pos" in reachy2.joints_dict
|
||||
assert "r_gripper.pos" in reachy2.joints_dict
|
||||
else:
|
||||
assert "r_shoulder_pitch.pos" not in reachy2.joints_dict
|
||||
assert "r_shoulder_roll.pos" not in reachy2.joints_dict
|
||||
assert "r_elbow_yaw.pos" not in reachy2.joints_dict
|
||||
assert "r_elbow_pitch.pos" not in reachy2.joints_dict
|
||||
assert "r_wrist_roll.pos" not in reachy2.joints_dict
|
||||
assert "r_wrist_pitch.pos" not in reachy2.joints_dict
|
||||
assert "r_wrist_yaw.pos" not in reachy2.joints_dict
|
||||
assert "r_gripper.pos" not in reachy2.joints_dict
|
||||
|
||||
if reachy2.config.with_l_arm:
|
||||
assert "l_shoulder_pitch.pos" in reachy2.joints_dict
|
||||
assert "l_shoulder_roll.pos" in reachy2.joints_dict
|
||||
assert "l_elbow_yaw.pos" in reachy2.joints_dict
|
||||
assert "l_elbow_pitch.pos" in reachy2.joints_dict
|
||||
assert "l_wrist_roll.pos" in reachy2.joints_dict
|
||||
assert "l_wrist_pitch.pos" in reachy2.joints_dict
|
||||
assert "l_wrist_yaw.pos" in reachy2.joints_dict
|
||||
assert "l_gripper.pos" in reachy2.joints_dict
|
||||
else:
|
||||
assert "l_shoulder_pitch.pos" not in reachy2.joints_dict
|
||||
assert "l_shoulder_roll.pos" not in reachy2.joints_dict
|
||||
assert "l_elbow_yaw.pos" not in reachy2.joints_dict
|
||||
assert "l_elbow_pitch.pos" not in reachy2.joints_dict
|
||||
assert "l_wrist_roll.pos" not in reachy2.joints_dict
|
||||
assert "l_wrist_pitch.pos" not in reachy2.joints_dict
|
||||
assert "l_wrist_yaw.pos" not in reachy2.joints_dict
|
||||
assert "l_gripper.pos" not in reachy2.joints_dict
|
||||
|
||||
|
||||
def test_get_observation(reachy2):
|
||||
reachy2.connect()
|
||||
obs = reachy2.get_observation()
|
||||
|
||||
expected_keys = set(reachy2.joints_dict)
|
||||
expected_keys.update(f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base)
|
||||
expected_keys.update(reachy2.cameras.keys())
|
||||
assert set(obs.keys()) == expected_keys
|
||||
|
||||
for motor in reachy2.joints_dict.keys():
|
||||
assert obs[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position
|
||||
if reachy2.config.with_mobile_base:
|
||||
for vel in REACHY2_VEL.keys():
|
||||
assert obs[vel] == reachy2.reachy.mobile_base.odometry[REACHY2_VEL[vel]]
|
||||
if reachy2.config.with_left_teleop_camera:
|
||||
assert obs["teleop_left"].shape == (
|
||||
reachy2.config.cameras["teleop_left"].height,
|
||||
reachy2.config.cameras["teleop_left"].width,
|
||||
3,
|
||||
)
|
||||
if reachy2.config.with_right_teleop_camera:
|
||||
assert obs["teleop_right"].shape == (
|
||||
reachy2.config.cameras["teleop_right"].height,
|
||||
reachy2.config.cameras["teleop_right"].width,
|
||||
3,
|
||||
)
|
||||
if reachy2.config.with_torso_camera:
|
||||
assert obs["torso_rgb"].shape == (
|
||||
reachy2.config.cameras["torso_rgb"].height,
|
||||
reachy2.config.cameras["torso_rgb"].width,
|
||||
3,
|
||||
)
|
||||
|
||||
|
||||
def test_send_action(reachy2):
|
||||
reachy2.connect()
|
||||
|
||||
action = {k: i * 10.0 for i, k in enumerate(reachy2.joints_dict.keys(), start=1)}
|
||||
if reachy2.config.with_mobile_base:
|
||||
action.update({k: i * 0.1 for i, k in enumerate(REACHY2_VEL.keys(), start=1)})
|
||||
|
||||
previous_present_position = {
|
||||
k: reachy2.reachy.joints[REACHY2_JOINTS[k]].present_position for k in reachy2.joints_dict.keys()
|
||||
}
|
||||
returned = reachy2.send_action(action)
|
||||
|
||||
if reachy2.config.max_relative_target is None:
|
||||
assert returned == action
|
||||
|
||||
assert reachy2.reachy._goal_position_set_total == len(reachy2.joints_dict)
|
||||
for motor in reachy2.joints_dict.keys():
|
||||
expected_pos = action[motor]
|
||||
real_pos = reachy2.reachy.joints[REACHY2_JOINTS[motor]].goal_position
|
||||
if reachy2.config.max_relative_target is None:
|
||||
assert real_pos == expected_pos
|
||||
else:
|
||||
assert real_pos == previous_present_position[motor] + np.sign(expected_pos) * min(
|
||||
abs(expected_pos - real_pos), reachy2.config.max_relative_target
|
||||
)
|
||||
|
||||
if reachy2.config.with_mobile_base:
|
||||
goal_speed = [i * 0.1 for i, _ in enumerate(REACHY2_VEL.keys(), start=1)]
|
||||
reachy2.reachy.mobile_base.set_goal_speed.assert_called_once_with(*goal_speed)
|
||||
|
||||
if reachy2.config.use_external_commands:
|
||||
reachy2.reachy.send_goal_positions.assert_not_called()
|
||||
if reachy2.config.with_mobile_base:
|
||||
reachy2.reachy.mobile_base.send_speed_command.assert_not_called()
|
||||
else:
|
||||
reachy2.reachy.send_goal_positions.assert_called_once()
|
||||
if reachy2.config.with_mobile_base:
|
||||
reachy2.reachy.mobile_base.send_speed_command.assert_called_once()
|
||||
|
||||
|
||||
def test_no_part_declared():
|
||||
with pytest.raises(ValueError):
|
||||
_ = Reachy2RobotConfig(
|
||||
ip_address="192.168.0.200",
|
||||
with_mobile_base=False,
|
||||
with_l_arm=False,
|
||||
with_r_arm=False,
|
||||
with_neck=False,
|
||||
with_antennas=False,
|
||||
)
|
||||
@@ -0,0 +1,150 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.teleoperators.reachy2_teleoperator import (
|
||||
REACHY2_ANTENNAS_JOINTS,
|
||||
REACHY2_L_ARM_JOINTS,
|
||||
REACHY2_NECK_JOINTS,
|
||||
REACHY2_R_ARM_JOINTS,
|
||||
REACHY2_VEL,
|
||||
Reachy2Teleoperator,
|
||||
Reachy2TeleoperatorConfig,
|
||||
)
|
||||
|
||||
# {lerobot_keys: reachy2_sdk_keys}
|
||||
REACHY2_JOINTS = {
|
||||
**REACHY2_NECK_JOINTS,
|
||||
**REACHY2_ANTENNAS_JOINTS,
|
||||
**REACHY2_R_ARM_JOINTS,
|
||||
**REACHY2_L_ARM_JOINTS,
|
||||
}
|
||||
|
||||
PARAMS = [
|
||||
{}, # default config
|
||||
{"with_mobile_base": False},
|
||||
{"with_mobile_base": False, "with_l_arm": False, "with_antennas": False},
|
||||
{"with_r_arm": False, "with_neck": False, "with_antennas": False},
|
||||
{"with_mobile_base": False, "with_neck": False},
|
||||
{"use_present_position": True},
|
||||
]
|
||||
|
||||
|
||||
def _make_reachy2_sdk_mock():
|
||||
r = MagicMock(name="ReachySDKMock")
|
||||
r.is_connected.return_value = True
|
||||
|
||||
def _connect():
|
||||
r.is_connected.return_value = True
|
||||
|
||||
def _disconnect():
|
||||
r.is_connected.return_value = False
|
||||
|
||||
# Mock joints with some dummy positions
|
||||
joints = {
|
||||
k: MagicMock(
|
||||
present_position=float(i),
|
||||
goal_position=float(i) + 0.5,
|
||||
)
|
||||
for i, k in enumerate(REACHY2_JOINTS.values())
|
||||
}
|
||||
r.joints = joints
|
||||
|
||||
# Mock mobile base with some dummy odometry
|
||||
r.mobile_base = MagicMock()
|
||||
r.mobile_base.last_cmd_vel = {
|
||||
"vx": -0.2,
|
||||
"vy": 0.2,
|
||||
"vtheta": 11.0,
|
||||
}
|
||||
r.mobile_base.odometry = {
|
||||
"x": 1.0,
|
||||
"y": 2.0,
|
||||
"theta": 20.0,
|
||||
"vx": 0.1,
|
||||
"vy": -0.1,
|
||||
"vtheta": 8.0,
|
||||
}
|
||||
|
||||
r.connect = MagicMock(side_effect=_connect)
|
||||
r.disconnect = MagicMock(side_effect=_disconnect)
|
||||
|
||||
return r
|
||||
|
||||
|
||||
@pytest.fixture(params=PARAMS, ids=lambda p: "default" if not p else ",".join(p.keys()))
|
||||
def reachy2(request):
|
||||
with (
|
||||
patch(
|
||||
"lerobot.teleoperators.reachy2_teleoperator.reachy2_teleoperator.ReachySDK",
|
||||
side_effect=lambda *a, **k: _make_reachy2_sdk_mock(),
|
||||
),
|
||||
):
|
||||
overrides = request.param
|
||||
cfg = Reachy2TeleoperatorConfig(ip_address="192.168.0.200", **overrides)
|
||||
robot = Reachy2Teleoperator(cfg)
|
||||
yield robot
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
def test_connect_disconnect(reachy2):
|
||||
assert not reachy2.is_connected
|
||||
|
||||
reachy2.connect()
|
||||
assert reachy2.is_connected
|
||||
|
||||
reachy2.disconnect()
|
||||
assert not reachy2.is_connected
|
||||
|
||||
reachy2.reachy.disconnect.assert_called_once()
|
||||
|
||||
|
||||
def test_get_action(reachy2):
|
||||
reachy2.connect()
|
||||
action = reachy2.get_action()
|
||||
|
||||
expected_keys = set(reachy2.joints_dict)
|
||||
expected_keys.update(f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base)
|
||||
assert set(action.keys()) == expected_keys
|
||||
|
||||
for motor in reachy2.joints_dict.keys():
|
||||
if reachy2.config.use_present_position:
|
||||
assert action[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position
|
||||
else:
|
||||
assert action[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].goal_position
|
||||
if reachy2.config.with_mobile_base:
|
||||
if reachy2.config.use_present_position:
|
||||
for vel in REACHY2_VEL.keys():
|
||||
assert action[vel] == reachy2.reachy.mobile_base.odometry[REACHY2_VEL[vel]]
|
||||
else:
|
||||
for vel in REACHY2_VEL.keys():
|
||||
assert action[vel] == reachy2.reachy.mobile_base.last_cmd_vel[REACHY2_VEL[vel]]
|
||||
|
||||
|
||||
def test_no_part_declared():
|
||||
with pytest.raises(ValueError):
|
||||
_ = Reachy2TeleoperatorConfig(
|
||||
ip_address="192.168.0.200",
|
||||
with_mobile_base=False,
|
||||
with_l_arm=False,
|
||||
with_r_arm=False,
|
||||
with_neck=False,
|
||||
with_antennas=False,
|
||||
)
|
||||
Reference in New Issue
Block a user