Compare commits


1 Commits

Author SHA1 Message Date
Michel Aractingi 74b7cd246e add check for cfg.policy in force_cpu line 2026-01-19 13:54:44 +01:00
88 changed files with 921 additions and 4257 deletions
+1 -12
@@ -18,11 +18,6 @@ name: Documentation
on:
# Allows running this workflow manually from the Actions tab
workflow_dispatch:
inputs:
version:
description: 'Version tag (e.g. v0.1.2) - Leave empty for standard main build'
required: false
type: string
# Triggers the workflow on push events to main for the docs folder
push:
@@ -59,13 +54,7 @@ jobs:
with:
commit_sha: ${{ github.sha }}
package: lerobot
additional_args: >-
--not_python_module
${{
(github.event_name == 'release' && format('--version {0}', github.event.release.tag_name)) ||
(inputs.version != '' && format('--version {0}', inputs.version)) ||
''
}}
additional_args: --not_python_module ${{ github.event_name == 'release' && format('--version {0}', github.event.release.tag_name) || '' }}
secrets:
token: ${{ secrets.HUGGINGFACE_PUSH }}
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
+2 -2
@@ -20,8 +20,8 @@ on:
workflow_dispatch:
# Run on the 1st and 15th of every month at 02:00 UTC
# schedule:
# - cron: '0 2 1,15 * *'
schedule:
- cron: '0 2 1,15 * *'
permissions:
contents: read
+1 -1
@@ -14,7 +14,7 @@ You can contribute in many ways:
- **Documentation:** Improve examples, guides, and docstrings.
- **Feedback:** Submit tickets related to bugs or desired new features.
If you are unsure where to start, join our [Discord Channel](https://discord.gg/q8Dzzpym3f).
If you are unsure where to start, join our [Discord Channel](https://discord.gg/JkrYNdmw).
## Development Setup
-1
@@ -128,7 +128,6 @@ Learn how to implement your own simulation environment or benchmark and distribu
## Resources
- **[Documentation](https://huggingface.co/docs/lerobot/index):** The complete guide to tutorials & API.
- **[Chinese Tutorials: LeRobot+SO-ARM101中文教程-同济子豪兄](https://zihao-ai.feishu.cn/wiki/space/7589642043471924447)** Detailed documentation covering assembly, teleoperation, datasets, training, and deployment. Verified by Seed Studio and 5 global hackathon players.
- **[Discord](https://discord.gg/q8Dzzpym3f):** Join the `LeRobot` server to discuss with the community.
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
-48
@@ -1,48 +0,0 @@
# Security Policy
## Project Status & Philosophy
`lerobot` has so far been primarily a research and prototyping tool, which is why deployment security hasn't been a strong focus until now. As `lerobot` continues to be adopted and deployed in production, we are paying much closer attention to these kinds of issues.
Fortunately, being an open-source project, the community can also help by reporting and fixing vulnerabilities. We appreciate your efforts to responsibly disclose your findings and will make every effort to acknowledge your contributions.
## Reporting a Vulnerability
To report a security issue, please use the GitHub Security Advisory ["Report a Vulnerability"](https://github.com/huggingface/lerobot/security/advisories/new) tab.
The `lerobot` team will send a response indicating the next steps in handling your report. After the initial reply to your report, the security team will keep you informed of the progress towards a fix and full announcement, and may ask for additional information or guidance.
#### Hugging Face Security Team
Since this project is part of the Hugging Face ecosystem, feel free to submit vulnerability reports directly to: **[security@huggingface.co](mailto:security@huggingface.co)**. Someone from the HF security team will review the report and recommend next steps.
#### Open Source Disclosures
If reporting a vulnerability specific to the open-source codebase (and not the underlying Hub infrastructure), you may also use [Huntr](https://huntr.com), a vulnerability disclosure program for open source software.
## Supported Versions
Currently, we treat `lerobot` as a rolling release. We prioritize security updates for the latest available version (`main` branch).
| Version | Supported |
| -------- | --------- |
| Latest | ✅ |
| < Latest | ❌ |
## Secure Usage Guidelines
`lerobot` is tightly coupled to the Hugging Face Hub for sharing data and pretrained policies. When downloading artifacts uploaded by others, you expose yourself to risks. Please read below for recommendations to keep your runtime and robot environment safe.
### Remote Artefacts (Weights & Policies)
Models and policies uploaded to the Hugging Face Hub come in different formats. We heavily recommend uploading and downloading models in the [`safetensors`](https://github.com/huggingface/safetensors) format.
`safetensors` was developed specifically to prevent arbitrary code execution on your system, which is critical when running software on physical hardware/robots.
To avoid loading models from unsafe formats (e.g., `pickle`), you should ensure you are prioritizing `safetensors` files.
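As a minimal sketch of that preference, assuming you already have the list of file names in a repository, a small filter can keep only `safetensors` weights and refuse pickle-style checkpoints. The helper name `pick_weight_files` and the set of pickle-style extensions are illustrative assumptions, not part of `lerobot`.

```python
from pathlib import PurePosixPath

# Extensions commonly used for pickle-based checkpoints.
# Illustrative assumption, not an exhaustive or official list.
PICKLE_SUFFIXES = {".bin", ".pt", ".pth", ".pkl", ".pickle", ".ckpt"}


def pick_weight_files(filenames):
    """Return the `.safetensors` files, or raise if only pickle-style weights exist."""
    safe = [f for f in filenames if PurePosixPath(f).suffix == ".safetensors"]
    if safe:
        return safe
    unsafe = [f for f in filenames if PurePosixPath(f).suffix in PICKLE_SUFFIXES]
    if unsafe:
        # Fail loudly instead of silently falling back to an unsafe format.
        raise ValueError(f"Only pickle-based weights found: {unsafe}")
    return []
```

Raising instead of falling back keeps the decision to load an unsafe format explicit.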
### Remote Code
Some models or environments on the Hub may require `trust_remote_code=True` to run custom architecture code.
Please **always** verify the content of the modeling files when using this argument. We recommend setting a specific `revision` (commit hash) when loading remote code to ensure you protect yourself from unverified updates to the repository.
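One way to enforce the pinning recommendation is to check that a `revision` is a full commit hash before passing it on. This is a sketch under the assumption that the repository uses 40-character SHA-1 hashes; the helper name `is_pinned_revision` is hypothetical.

```python
import re
from typing import Optional

# A full Git SHA-1 commit hash is 40 lowercase hexadecimal characters.
_FULL_SHA = re.compile(r"[0-9a-f]{40}")


def is_pinned_revision(revision: Optional[str]) -> bool:
    """Return True only for a full commit hash, not a branch, tag, or short hash."""
    return revision is not None and _FULL_SHA.fullmatch(revision) is not None
```

A branch name such as `main` or an abbreviated hash is rejected, because both can end up pointing at content you have not reviewed.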
-6
@@ -57,8 +57,6 @@
title: Use Async Inference
- local: rtc
title: Real-Time Chunking (RTC)
- local: training_time_rtc
title: Training-Time RTC
title: "Inference"
- sections:
- local: envhub
@@ -101,8 +99,6 @@
title: Unitree G1
- local: earthrover_mini_plus
title: Earth Rover Mini
- local: omx
title: OMX
title: "Robots"
- sections:
- local: phone_teleop
@@ -117,8 +113,6 @@
title: Notebooks
- local: feetech
title: Updating Feetech Firmware
- local: damiao
title: Damiao Motors and CAN Bus
title: "Resources"
- sections:
- local: contributing
-1
@@ -195,7 +195,6 @@ client_cfg = RobotClientConfig(
robot=robot_cfg,
server_address="localhost:8080",
policy_device="mps",
client_device="cpu",
policy_type="smolvla",
pretrained_name_or_path="<user>/smolvla_async",
chunk_size_threshold=0.5,
-165
@@ -1,165 +0,0 @@
# Damiao Motors and CAN Bus
This guide covers setup and usage of Damiao motors with LeRobot via CAN bus communication.
Currently, only Linux is supported, as the OpenArms CAN adapter only has drivers for Linux.
## Linux CAN Setup
Before using Damiao motors, you need to set up the CAN interface on your Linux system.
### Install CAN Utilities
```bash
sudo apt-get install can-utils
```
### Configure CAN Interface (Manual)
For standard CAN FD (recommended for OpenArms):
```bash
sudo ip link set can0 down
sudo ip link set can0 type can bitrate 1000000 dbitrate 5000000 fd on
sudo ip link set can0 up
```
For standard CAN (without FD):
```bash
sudo ip link set can0 down
sudo ip link set can0 type can bitrate 1000000
sudo ip link set can0 up
```
### Configure CAN Interface (Using LeRobot)
LeRobot provides a utility script to setup and test CAN interfaces:
```bash
# Setup multiple interfaces (e.g., OpenArms Followers with 2 CAN buses)
lerobot-setup-can --mode=setup --interfaces=can0,can1
```
## Debugging CAN Communication
Use the built-in debug tools to test motor communication:
```bash
# Test motors on all interfaces
lerobot-setup-can --mode=test --interfaces=can0,can1
# Run speed/latency test
lerobot-setup-can --mode=speed --interfaces=can0
```
The test mode will scan for motors (IDs 0x01-0x08) and report which ones respond. Example output:
```
can0: UP (CAN FD)
Motor 0x01 (joint_1): ✓ FOUND
→ Response 0x11 [FD]: 00112233...
Motor 0x02 (joint_2): ✓ FOUND
Motor 0x03 (joint_3): ✗ No response
...
Summary: 2/8 motors found
```
## Usage
### Basic Setup
```python
from lerobot.motors import Motor
from lerobot.motors.damiao import DamiaoMotorsBus
# Define your motors with send/receive CAN IDs
motors = {
"joint_1": Motor(id=0x01, motor_type_str="dm8009", recv_id=0x11),
"joint_2": Motor(id=0x02, motor_type_str="dm4340", recv_id=0x12),
"joint_3": Motor(id=0x03, motor_type_str="dm4310", recv_id=0x13),
}
# Create the bus
bus = DamiaoMotorsBus(
port="can0", # Linux socketcan interface
motors=motors,
)
# Connect
bus.connect()
```
### Reading Motor States
```python
# Read single motor position (degrees)
position = bus.read("Present_Position", "joint_1")
# Read from multiple motors
positions = bus.sync_read("Present_Position") # All motors
positions = bus.sync_read("Present_Position", ["joint_1", "joint_2"])
# Read all states at once (position, velocity, torque)
states = bus.sync_read_all_states()
# Returns: {'joint_1': {'position': 45.2, 'velocity': 1.3, 'torque': 0.5}, ...}
```
### Writing Motor Commands
```python
# Enable torque
bus.enable_torque()
# Set goal position (degrees)
bus.write("Goal_Position", "joint_1", 45.0)
# Set positions for multiple motors
bus.sync_write("Goal_Position", {
"joint_1": 45.0,
"joint_2": -30.0,
"joint_3": 90.0,
})
# Disable torque
bus.disable_torque()
```
## Configuration Options
| Parameter | Default | Description |
| -------------- | --------- | ----------------------------------------------------------- |
| `port` | - | CAN interface (`can0`) or serial port (`/dev/cu.usbmodem*`) |
| `use_can_fd` | `True` | Enable CAN FD for higher data rates |
| `bitrate` | `1000000` | Nominal bitrate (1 Mbps) |
| `data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) |
## Motor Configuration
Each motor requires:
- `id`: CAN ID for sending commands
- `motor_type`: One of the supported motor types (e.g., `"dm8009"`, `"dm4340"`)
- `recv_id`: CAN ID for receiving responses
OpenArms default IDs follow the pattern: send ID `0x0N`, receive ID `0x1N` where N is the joint number.
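The ID pattern can be written as a small helper. This is only a sketch: the function name is hypothetical, and the 1 to 8 joint range is an assumption taken from the scan range (`0x01`-`0x08`) mentioned in the debugging section.

```python
def openarms_default_ids(joint_number):
    """Return (send_id, recv_id) for joint N, following send 0x0N / receive 0x1N."""
    if not 1 <= joint_number <= 8:
        # Assumed range, based on the 0x01-0x08 scan described above.
        raise ValueError(f"joint_number must be in 1..8, got {joint_number}")
    send_id = joint_number          # 0x0N
    recv_id = 0x10 + joint_number   # 0x1N
    return send_id, recv_id
```

For example, `openarms_default_ids(3)` yields `(0x03, 0x13)`, which matches the `joint_3` entry in the Basic Setup example.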
## Troubleshooting
### No Response from Motors
1. **Check power**
2. **Verify CAN wiring**: Check CAN-H, CAN-L, and GND connections
3. **Check motor IDs**: Use Damiao Debugging Tools to verify/configure IDs
4. **Test CAN interface**: Run `candump can0` to see if messages are being received
5. **Run diagnostics**: `lerobot-setup-can --mode=test --interfaces=can0`
### Motor Timeout Parameter
If motors were configured with timeout=0, they won't respond to commands. Use Damiao Debugging Tools to set a non-zero timeout value.
### Verify CAN FD Status
```bash
ip -d link show can0 | grep fd
```
-6
@@ -1,11 +1,5 @@
# EarthRover Mini Plus
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Earth_Rover_Mini_5_240c9adc-4f9e-44b7-982f-5d1dc24af1d8.png.webp"
alt="EarthRover Mini Plus"
width="70%"
/>
The EarthRover Mini Plus is a fully open source mobile robot that connects through the cloud using the Frodobots SDK. This lets you control the robot and record datasets for training AI models.
## What You Need
-6
@@ -1,11 +1,5 @@
# LeKiwi
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/1740517739083.jpeg"
alt="LeKiwi"
width="70%"
/>
In the steps below, we explain how to assemble the LeKiwi mobile robot.
## Source the parts
-1
@@ -42,7 +42,6 @@ lerobot-eval \
```
- `--env.task` picks the suite (`libero_object`, `libero_spatial`, etc.).
- `--env.task_ids` picks task ids to run (`[0]`, `[1,2,3]`, etc.). Omit this flag (or set it to `null`) to run all tasks in the suite.
- `--eval.batch_size` controls how many environments run in parallel.
- `--eval.n_episodes` sets how many episodes to run in total.
-197
@@ -1,197 +0,0 @@
## Order and Assemble the parts
First, assemble the OMX hardware following the official assembly guide.
OMX Assembly Guide: https://ai.robotis.com/omx/assembly_guide_omx.html
OMX robots are shipped preconfigured from the factory. Motor IDs, communication parameters, and joint offsets are already set, so no additional motor setup or calibration is required before using LeRobot.
## Install LeRobot 🤗
To install LeRobot, follow our [Installation Guide](./installation)
In addition to these instructions, you need to install the Dynamixel SDK:
```bash
pip install -e ".[dynamixel]"
```
## Connect the robot
To find the port for each bus servo adapter, run this script:
```bash
lerobot-find-port
```
When prompted, disconnect the USB cable from either the leader or follower arm and press Enter. The output will show 'The port of this MotorsBus is [port]', which identifies the port for the disconnected arm. Repeat for the other arm to identify both ports.
<hfoptions id="find_port">
<hfoption id="Mac">
Example output on macOS:
```
Finding all available ports for the MotorBus.
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
Remove the USB cable from your MotorsBus and press Enter when done.
[...Disconnect corresponding leader or follower arm and press Enter...]
The port of this MotorsBus is /dev/tty.usbmodem575E0032081
Reconnect the USB cable.
```
Here the detected port is `/dev/tty.usbmodem575E0032081`, which corresponds to your leader or follower arm.
</hfoption>
<hfoption id="Linux">
On Linux, we strongly recommend using udev rules to assign persistent and human-readable device names to the OMX leader and follower arms. This avoids issues where device names such as ttyACM0 and ttyACM1 change when the robot is unplugged, replugged, or when the system is rebooted.
#### 1. Find your device serial numbers
You should have obtained the port numbers like ../../ttyACM? for the leader and follower using `lerobot-find-port`. You can match those results with the serial numbers using the `ls -l /dev/serial/by-id/` command.
To create udev rules, you need the unique serial number for each OMX device. The easiest way is to list devices under:
```bash
ls -l /dev/serial/by-id/
```
You will see output similar to:
```bash
usb-ROBOTIS_OpenRB-150_228BDD7B503059384C2E3120FF0A2B19-if00 -> ../../ttyACM0
usb-ROBOTIS_OpenRB-150_67E1ED68503059384C2E3120FF092234-if00 -> ../../ttyACM1
```
In each line, the serial number is the long string after `usb-ROBOTIS_OpenRB-150_` and before `-if00`.
Follower serial: `228BDD7B503059384C2E3120FF0A2B19`
Leader serial: `67E1ED68503059384C2E3120FF092234`
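If you prefer to extract the serial numbers programmatically, the rule above can be sketched with a regular expression. The helper name `extract_serial` is hypothetical, and the pattern assumes the `usb-ROBOTIS_OpenRB-150_<serial>-if00` naming shown in the example output.

```python
import re

# Matches the by-id name shown above and captures the serial between the
# fixed prefix and the "-if00" suffix.
_BY_ID = re.compile(r"usb-ROBOTIS_OpenRB-150_([0-9A-Za-z]+)-if00")


def extract_serial(line):
    """Return the serial number found in one `ls -l /dev/serial/by-id/` line, or None."""
    match = _BY_ID.search(line)
    return match.group(1) if match else None
```

Using `search` rather than `fullmatch` lets the helper accept a whole `ls -l` output line, including the `-> ../../ttyACM0` part.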
#### 2. Create the udev rule
Create a new udev rule file:
```bash
sudo nano /etc/udev/rules.d/99-omx.rules
```
Paste the following lines, replacing the serial numbers with the values you found above:
```bash
SUBSYSTEM=="tty", ATTRS{idVendor}=="0403", ATTRS{serial}=="228BDD7B503059384C2E3120FF0A2B19", SYMLINK+="omx_follower"
SUBSYSTEM=="tty", ATTRS{idVendor}=="0403", ATTRS{serial}=="67E1ED68503059384C2E3120FF092234", SYMLINK+="omx_leader"
```
Save the file and reload udev rules:
```bash
sudo udevadm control --reload-rules
sudo udevadm trigger
```
Now unplug and replug both devices once.
#### 3. Verify the symlinks
Check that the persistent device names exist:
```bash
ls -l /dev/omx_follower /dev/omx_leader
```
You should see them pointing to ttyACM\* devices:
```bash
/dev/omx_follower -> ttyACM*
/dev/omx_leader -> ttyACM*
```
These names remain stable across reboots and reconnections.
</hfoption>
</hfoptions>
## Teleoperate
After identifying the correct ports, you can directly teleoperate the follower arm using the leader arm.
<hfoptions id="teleoperate">
<hfoption id="Mac">
### Teleoperate without camera
```bash
lerobot-teleoperate \
--robot.type=omx_follower \
--robot.port=<your_follower_port> \
--robot.id=omx_follower_arm \
--teleop.type=omx_leader \
--teleop.port=<your_leader_port> \
--teleop.id=omx_leader_arm
```
During teleoperation, motions of the leader arm are mirrored in real time by the follower arm. Because OMX is already preconfigured, teleoperation can begin immediately without any calibration steps.
### Teleoperate with camera
You can also enable camera input during teleoperation by providing a camera configuration for the follower arm.
```bash
lerobot-teleoperate \
--robot.type=omx_follower \
--robot.port=<your_follower_port> \
--robot.id=omx_follower_arm \
--robot.cameras="{front: {type: opencv, index_or_path: '/dev/video0', width: 640, height: 480, fps: 30}}" \
--teleop.type=omx_leader \
--teleop.port=<your_leader_port> \
--teleop.id=omx_leader_arm \
--display_data=true
```
When the camera is enabled, the camera stream is displayed in real time and synchronized with the robot state. This setup is useful for visual monitoring and can be reused later for demonstration recording and imitation learning.
</hfoption>
<hfoption id="Linux">
### Teleoperate without camera
```bash
lerobot-teleoperate \
--robot.type=omx_follower \
--robot.port=/dev/omx_follower \
--robot.id=omx_follower_arm \
--teleop.type=omx_leader \
--teleop.port=/dev/omx_leader \
--teleop.id=omx_leader_arm
```
During teleoperation, motions of the leader arm are mirrored in real time by the follower arm. Because OMX is already preconfigured, teleoperation can begin immediately without any calibration steps.
### Teleoperate with camera
You can also enable camera input during teleoperation by providing a camera configuration for the follower arm.
```bash
lerobot-teleoperate \
--robot.type=omx_follower \
--robot.port=/dev/omx_follower \
--robot.id=omx_follower_arm \
--robot.cameras="{front: {type: opencv, index_or_path: '/dev/video0', width: 640, height: 480, fps: 30}}" \
--teleop.type=omx_leader \
--teleop.port=/dev/omx_leader \
--teleop.id=omx_leader_arm \
--display_data=true
```
When the camera is enabled, the camera stream is displayed in real time and synchronized with the robot state. This setup is useful for visual monitoring and can be reused later for demonstration recording and imitation learning.
</hfoption>
</hfoptions>
Congrats 🎉, your robot is all set to learn a task on its own.
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/robotis).
-13
@@ -1,18 +1,5 @@
# SO-101
<div style="display: flex; align-items: center; gap: 10px;">
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/SO101_Follower.webp"
alt="SO-101"
width="60%"
/>
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/SO101_Leader.webp"
alt="SO-101"
width="60%"
/>
</div>
In the steps below, we explain how to assemble our flagship robot, the SO-101.
## Source the parts
-86
@@ -1,86 +0,0 @@
# Training-Time RTC
Training-Time RTC teaches the model to handle inference delay during training.
It feeds the **ground-truth action prefix** to the model and trains only on the remaining postfix actions.
This keeps chunk transitions smooth without doing any inference-time inpainting.
Based on: [Training-Time Action Conditioning for Efficient Real-Time Chunking](https://arxiv.org/abs/2512.05964).
LeRobot supports this for `pi0`, `pi05` and `smolvla` without changing model parameters.
---
## How It Works
### At Training Time
- Sample a delay `d` per batch element.
- Keep the first `d` action steps as **ground truth** (no noise).
- Add noise only to the postfix actions.
- Set the flow-matching timestep to **1.0** for prefix tokens and normal timesteps for postfix tokens.
- Mask the loss to only train on the postfix.
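The steps above can be sketched in plain Python for a single one-dimensional action chunk. This is not the LeRobot implementation: the function names are hypothetical, and the interpolation `x_t = t * action + (1 - t) * noise` assumes a convention in which timestep 1.0 means noise-free data, chosen because the text pairs the ground-truth prefix with timestep 1.0.

```python
import random


def build_rtc_training_inputs(actions, delay, t, rng):
    """Build noisy inputs, per-step timesteps, and a loss mask for one chunk.

    actions: ground-truth action values for the chunk
    delay:   number of prefix steps kept as ground truth (d)
    t:       flow-matching timestep used for the postfix steps
    rng:     a random.Random instance used to draw Gaussian noise
    """
    if not 0 <= delay <= len(actions):
        raise ValueError("delay must lie between 0 and the chunk length")
    x_t, timesteps, loss_mask = [], [], []
    for i, action in enumerate(actions):
        if i < delay:
            # Prefix: keep ground truth, timestep 1.0, excluded from the loss.
            x_t.append(action)
            timesteps.append(1.0)
            loss_mask.append(0.0)
        else:
            # Postfix: noised input, normal timestep, included in the loss.
            noise = rng.gauss(0.0, 1.0)
            x_t.append(t * action + (1.0 - t) * noise)
            timesteps.append(t)
            loss_mask.append(1.0)
    return x_t, timesteps, loss_mask


def masked_mse(pred, target, loss_mask):
    """Mean squared error over the positions where the mask is 1."""
    kept = sum(loss_mask)
    if kept == 0:
        return 0.0
    total = sum(m * (p - q) ** 2 for p, q, m in zip(pred, target, loss_mask))
    return total / kept
```

With `delay=2` on a four-step chunk, the first two positions pass through unchanged and contribute nothing to the loss, so errors on the prefix never produce gradients.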
### At Inference Time
When `rtc_training_config.enabled=true`, the model uses training-time RTC inference:
- Replace prefix positions in `x_t` with previous chunk's leftover actions.
- Set timestep to **1.0** for prefix positions.
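The inference-side substitution can be sketched the same way. The helper name `apply_prefix` is hypothetical, and clamping the delay to the number of available leftover actions is an assumption made for this sketch; the argument names `prev_chunk_left_over` and `inference_delay` mirror the ones used in the inference example in this guide.

```python
def apply_prefix(x_t, prev_chunk_left_over, inference_delay, t):
    """Overwrite the first positions of x_t with leftover actions from the previous chunk.

    Returns the updated sequence and the per-step timesteps
    (1.0 on the prefix, t everywhere else).
    """
    # Assumption: never overwrite more steps than we have leftover actions for.
    d = min(inference_delay, len(prev_chunk_left_over), len(x_t))
    new_x = list(prev_chunk_left_over[:d]) + list(x_t[d:])
    timesteps = [1.0] * d + [t] * (len(x_t) - d)
    return new_x, timesteps
```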
---
## Quick Start (CLI)
```bash
lerobot-train \
--policy.type=pi0 \
--dataset.repo_id=your/dataset \
--policy.rtc_training_config.enabled=true \
--policy.rtc_training_config.min_delay=0 \
--policy.rtc_training_config.max_delay=6 \
--policy.rtc_training_config.delay_distribution=UNIFORM
```
---
## Inference with Training-Time RTC
After training with `rtc_training_config`, use the same config at inference. The model will automatically use training-time RTC inference:
```python
policy = PI0Policy.from_pretrained("path/to/trained/model")
# rtc_training_config is loaded from the saved config
actions = policy.predict_action_chunk(
batch,
inference_delay=5, # estimated delay in timesteps
prev_chunk_left_over=previous_actions, # from previous chunk
)
```
---
## Key Parameters
`RTCTrainingConfig` is available on the policy config (`pi0`, `pi05`, `smolvla`, `xvla`):
- **`enabled`**: Toggle training-time RTC (both training and inference).
- **`min_delay` / `max_delay`**: Delay range (inclusive).
- **`delay_distribution`**:
- `UNIFORM`: uniform in `[min_delay, max_delay]`
- `EXP`: exponentially decayed distribution over delays
- **`exp_decay`**: Exponential decay factor for `EXP` sampling.
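To illustrate the difference between the two distributions, here is a sketch of a delay sampler. It is not the LeRobot code: the function name, the default `exp_decay=0.5`, and the weighting `exp_decay ** (d - min_delay)` for `EXP` are assumptions chosen to match the description "exponentially decayed distribution over delays".

```python
import random


def sample_delay(min_delay, max_delay, distribution="UNIFORM", exp_decay=0.5, rng=random):
    """Sample an integer delay in [min_delay, max_delay], both ends inclusive."""
    delays = list(range(min_delay, max_delay + 1))
    if distribution == "UNIFORM":
        weights = [1.0] * len(delays)
    elif distribution == "EXP":
        # Assumed form: each extra step of delay is exp_decay times as likely.
        weights = [exp_decay ** (d - min_delay) for d in delays]
    else:
        raise ValueError(f"Unknown delay distribution: {distribution}")
    return rng.choices(delays, weights=weights, k=1)[0]
```

Under this assumed weighting, `EXP` with `exp_decay < 1` concentrates samples on small delays, which is consistent with the recommendation to use `EXP` when you want more supervision on smaller delays.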
---
## Notes and Recommendations
- Start with `min_delay=0` and `max_delay` around your expected worst-case inference delay.
- Use `EXP` if you want more supervision on smaller delays.
---
## Related Docs
- [Real-Time Chunking (Inference-Time RTC)](./rtc)
- [Pi0](./pi0), [Pi0.5](./pi05), [SmolVLA](./smolvla)
+6 -13
@@ -95,26 +95,26 @@ Convert an image-based dataset to video format, creating a new LeRobotDataset wh
# Local-only: Save to a custom output directory (no hub push)
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_image_to_video \
--operation.type convert_to_video \
--operation.output_dir /path/to/output/pusht_video
# Save with new repo_id (local storage)
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--new_repo_id lerobot/pusht_video \
--operation.type convert_image_to_video
--operation.type convert_to_video
# Convert and push to Hugging Face Hub
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--new_repo_id lerobot/pusht_video \
--operation.type convert_image_to_video \
--operation.type convert_to_video \
--push_to_hub true
# Convert with custom video codec and quality settings
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_image_to_video \
--operation.type convert_to_video \
--operation.output_dir outputs/pusht_video \
--operation.vcodec libsvtav1 \
--operation.pix_fmt yuv420p \
@@ -124,23 +124,16 @@ lerobot-edit-dataset \
# Convert only specific episodes
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_image_to_video \
--operation.type convert_to_video \
--operation.output_dir outputs/pusht_video \
--operation.episode_indices "[0, 1, 2, 5, 10]"
# Convert with multiple workers for parallel processing
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_image_to_video \
--operation.type convert_to_video \
--operation.output_dir outputs/pusht_video \
--operation.num_workers 8
# For memory-constrained systems, users can now specify limits:
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_to_video \
--operation.max_episodes_per_batch 50 \
--operation.max_frames_per_batch 10000
```
**Parameters:**
@@ -30,7 +30,6 @@ def main():
robot=robot_cfg,
server_address=server_address,
policy_device="mps",
client_device="cpu",
policy_type="act",
pretrained_name_or_path="<user>/robot_learning_tutorial_act",
chunk_size_threshold=0.5, # g
+1 -4
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
[project]
name = "lerobot"
version = "0.4.4"
version = "0.4.3"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
dynamic = ["readme"]
license = { text = "Apache-2.0" }
@@ -102,7 +102,6 @@ grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
# Motors
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
damiao = ["python-can>=4.2.0,<5.0.0"]
# Robots
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
@@ -204,7 +203,6 @@ lerobot-info="lerobot.scripts.lerobot_info:main"
lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
# ---------------- Tool Configurations ----------------
[tool.setuptools.packages.find]
@@ -280,7 +278,6 @@ default.extend-ignore-identifiers-re = [
"thw",
"inpt",
"ROBOTIS",
"OT_VALUE"
]
# TODO: Uncomment when ready to use
-10
@@ -126,12 +126,6 @@ class RobotClientConfig:
# Device configuration
policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"})
client_device: str = field(
default="cpu",
metadata={
"help": "Device to move actions to after receiving from server (e.g., for downstream planners)"
},
)
# Control behavior configuration
chunk_size_threshold: float = field(default=0.5, metadata={"help": "Threshold for chunk size control"})
@@ -167,9 +161,6 @@ class RobotClientConfig:
if not self.policy_device:
raise ValueError("policy_device cannot be empty")
if not self.client_device:
raise ValueError("client_device cannot be empty")
if self.chunk_size_threshold < 0 or self.chunk_size_threshold > 1:
raise ValueError(f"chunk_size_threshold must be between 0 and 1, got {self.chunk_size_threshold}")
@@ -193,7 +184,6 @@ class RobotClientConfig:
"policy_type": self.policy_type,
"pretrained_name_or_path": self.pretrained_name_or_path,
"policy_device": self.policy_device,
"client_device": self.client_device,
"chunk_size_threshold": self.chunk_size_threshold,
"fps": self.fps,
"actions_per_chunk": self.actions_per_chunk,
+1 -1
@@ -23,7 +23,7 @@ DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS
DEFAULT_OBS_QUEUE_TIMEOUT = 2
# All action chunking policies
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05", "groot"]
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"]
# TODO: Add all other robots
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so_follower", "omx_follower"]
+2 -3
@@ -18,7 +18,6 @@ import os
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import torch
@@ -40,8 +39,8 @@ from lerobot.utils.utils import init_logging
Action = torch.Tensor
# observation as received from the robot (can be numpy arrays, floats, etc.)
RawObservation = dict[str, Any]
# observation as received from the robot
RawObservation = dict[str, torch.Tensor]
# observation as those recorded in LeRobot dataset (keys are different)
LeRobotObservation = dict[str, torch.Tensor]
@@ -381,8 +381,6 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
action_tensor = torch.stack(processed_actions, dim=1).squeeze(0)
self.logger.debug(f"Postprocessed action shape: {action_tensor.shape}")
action_tensor = action_tensor.detach().cpu()
"""5. Convert to TimedAction list"""
action_chunk = self._time_action_chunk(
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
+2 -18
@@ -25,7 +25,6 @@ python src/lerobot/async_inference/robot_client.py \
--policy_type=act \
--pretrained_name_or_path=user/model \
--policy_device=mps \
--client_device=cpu \
--actions_per_chunk=50 \
--chunk_size_threshold=0.5 \
--aggregate_fn_name=weighted_average \
@@ -41,7 +40,6 @@ from collections.abc import Callable
from dataclasses import asdict
from pprint import pformat
from queue import Queue
from typing import Any
import draccus
import grpc
@@ -49,6 +47,7 @@ import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.processor import RobotAction
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
@@ -286,21 +285,6 @@ class RobotClient:
timed_actions = pickle.loads(actions_chunk.data) # nosec
deserialize_time = time.perf_counter() - deserialize_start
# Log device type of received actions
if len(timed_actions) > 0:
received_device = timed_actions[0].get_action().device.type
self.logger.debug(f"Received actions on device: {received_device}")
# Move actions to client_device (e.g., for downstream planners that need GPU)
client_device = self.config.client_device
if client_device != "cpu":
for timed_action in timed_actions:
if timed_action.get_action().device.type != client_device:
timed_action.action = timed_action.get_action().to(client_device)
self.logger.debug(f"Converted actions to device: {client_device}")
else:
self.logger.debug(f"Actions kept on device: {client_device}")
self.action_chunk_size = max(self.action_chunk_size, len(timed_actions))
# Calculate network latency if we have matching observations
@@ -367,7 +351,7 @@ class RobotClient:
action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)}
return action
def control_loop_action(self, verbose: bool = False) -> dict[str, Any]:
def control_loop_action(self, verbose: bool = False) -> RobotAction:
"""Reading and performing actions in local queue"""
# Lock only for queue operations
-5
@@ -50,8 +50,3 @@ class RTCAttentionSchedule(str, Enum):
ONES = "ONES"
LINEAR = "LINEAR"
EXP = "EXP"
class RTCTrainingDelayDistribution(str, Enum):
UNIFORM = "UNIFORM"
EXP = "EXP"
+5 -24
@@ -19,7 +19,6 @@ import logging
import shutil
from pathlib import Path
import datasets
import pandas as pd
import tqdm
@@ -33,7 +32,6 @@ from lerobot.datasets.utils import (
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_PATH,
get_file_size_in_mb,
get_hf_features_from_features,
get_parquet_file_size_in_mb,
to_parquet_with_hf_images,
update_chunk_file_indices,
@@ -404,21 +402,12 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
}
unique_chunk_file_ids = sorted(unique_chunk_file_ids)
contains_images = len(dst_meta.image_keys) > 0
# retrieve features schema for proper image typing in parquet
hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None
for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
chunk_index=src_chunk_idx, file_index=src_file_idx
)
if contains_images:
# Use HuggingFace datasets to read source data to preserve image format
src_ds = datasets.Dataset.from_parquet(str(src_path))
df = src_ds.to_pandas()
else:
df = pd.read_parquet(src_path)
df = pd.read_parquet(src_path)
df = update_data_df(df, src_meta, dst_meta)
data_idx = append_or_create_parquet_file(
@@ -428,9 +417,8 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
data_files_size_in_mb,
chunk_size,
DEFAULT_DATA_PATH,
contains_images=contains_images,
contains_images=len(dst_meta.image_keys) > 0,
aggr_root=dst_meta.root,
hf_features=hf_features,
)
return data_idx
@@ -500,7 +488,6 @@ def append_or_create_parquet_file(
default_path: str,
contains_images: bool = False,
aggr_root: Path = None,
hf_features: datasets.Features | None = None,
):
"""Appends data to an existing parquet file or creates a new one based on size constraints.
@@ -516,7 +503,6 @@ def append_or_create_parquet_file(
default_path: Format string for generating file paths.
contains_images: Whether the data contains images requiring special handling.
aggr_root: Root path for the aggregated dataset.
hf_features: Optional HuggingFace Features schema for proper image typing.
Returns:
dict: Updated index dictionary with current chunk and file indices.
@@ -526,7 +512,7 @@ def append_or_create_parquet_file(
if not dst_path.exists():
dst_path.parent.mkdir(parents=True, exist_ok=True)
if contains_images:
to_parquet_with_hf_images(df, dst_path, features=hf_features)
to_parquet_with_hf_images(df, dst_path)
else:
df.to_parquet(dst_path)
return idx
@@ -541,17 +527,12 @@ def append_or_create_parquet_file(
final_df = df
target_path = new_path
else:
if contains_images:
# Use HuggingFace datasets to read existing data to preserve image format
existing_ds = datasets.Dataset.from_parquet(str(dst_path))
existing_df = existing_ds.to_pandas()
else:
existing_df = pd.read_parquet(dst_path)
existing_df = pd.read_parquet(dst_path)
final_df = pd.concat([existing_df, df], ignore_index=True)
target_path = dst_path
if contains_images:
to_parquet_with_hf_images(final_df, target_path, features=hf_features)
to_parquet_with_hf_images(final_df, target_path)
else:
final_df.to_parquet(target_path)
+1 -561
@@ -26,7 +26,6 @@ This module provides utilities for:
import logging
import shutil
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import datasets
@@ -52,8 +51,7 @@ from lerobot.datasets.utils import (
write_stats,
write_tasks,
)
from lerobot.datasets.video_utils import encode_video_frames, get_video_info
from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE
from lerobot.utils.constants import HF_LEROBOT_HOME
def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict:
@@ -1085,561 +1083,3 @@ def _copy_episodes_metadata_and_stats(
else:
if src_dataset.meta.stats:
write_stats(src_dataset.meta.stats, dst_meta.root)
def _save_episode_images_for_video(
dataset: LeRobotDataset,
imgs_dir: Path,
img_key: str,
episode_index: int,
num_workers: int = 4,
) -> None:
"""Save images from a specific episode and camera to disk for video encoding.
Args:
dataset: The LeRobot dataset to extract images from
imgs_dir: Directory to save images to
img_key: The image key (camera) to extract
episode_index: Index of the episode to save
num_workers: Number of threads for parallel image saving
"""
# Create directory
imgs_dir.mkdir(parents=True, exist_ok=True)
# Get dataset without torch format for PIL image access
hf_dataset = dataset.hf_dataset.with_format(None)
# Select only this camera's images
imgs_dataset = hf_dataset.select_columns(img_key)
# Get episode start and end indices
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
# Get all items for this episode
episode_dataset = imgs_dataset.select(range(from_idx, to_idx))
# Define function to save a single image
def save_single_image(i_item_tuple):
i, item = i_item_tuple
img = item[img_key]
# Use frame-XXXXXX.png format to match encode_video_frames expectations
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
return i
# Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png)
items = list(enumerate(episode_dataset))
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = [executor.submit(save_single_image, item) for item in items]
for future in as_completed(futures):
future.result() # This will raise any exceptions that occurred
def _save_batch_episodes_images(
dataset: LeRobotDataset,
imgs_dir: Path,
img_key: str,
episode_indices: list[int],
num_workers: int = 4,
) -> list[float]:
"""Save images from multiple episodes to disk for batch video encoding.
Args:
dataset: The LeRobot dataset to extract images from
imgs_dir: Directory to save images to
img_key: The image key (camera) to extract
episode_indices: List of episode indices to save
num_workers: Number of threads for parallel image saving
Returns:
List of episode durations in seconds
"""
imgs_dir.mkdir(parents=True, exist_ok=True)
hf_dataset = dataset.hf_dataset.with_format(None)
imgs_dataset = hf_dataset.select_columns(img_key)
# Define function to save a single image with global frame index
# Defined once outside the loop to avoid repeated closure creation
def save_single_image(i_item_tuple, base_frame_idx, img_key_param):
i, item = i_item_tuple
img = item[img_key_param]
# Use global frame index for naming
img.save(str(imgs_dir / f"frame-{base_frame_idx + i:06d}.png"), quality=100)
return i
episode_durations = []
frame_idx = 0
for ep_idx in episode_indices:
# Get episode range
from_idx = dataset.meta.episodes["dataset_from_index"][ep_idx]
to_idx = dataset.meta.episodes["dataset_to_index"][ep_idx]
episode_length = to_idx - from_idx
episode_durations.append(episode_length / dataset.fps)
# Get episode images
episode_dataset = imgs_dataset.select(range(from_idx, to_idx))
# Save images
items = list(enumerate(episode_dataset))
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = [executor.submit(save_single_image, item, frame_idx, img_key) for item in items]
for future in as_completed(futures):
future.result()
frame_idx += episode_length
return episode_durations
def _iter_episode_batches(
episode_indices: list[int],
episode_lengths: dict[int, int],
size_per_frame_mb: float,
video_file_size_limit: float,
max_episodes: int | None,
max_frames: int | None,
):
"""Generator that yields batches of episode indices for video encoding.
Groups episodes into batches that respect size and memory constraints:
- Stays under video file size limit
- Respects maximum episodes per batch (if specified)
- Respects maximum frames per batch (if specified)
Args:
episode_indices: List of episode indices to batch
episode_lengths: Dictionary mapping episode index to episode length
size_per_frame_mb: Estimated size per frame in MB
video_file_size_limit: Maximum video file size in MB
max_episodes: Maximum number of episodes per batch (None = no limit)
max_frames: Maximum number of frames per batch (None = no limit)
Yields:
List of episode indices for each batch
"""
batch_episodes = []
estimated_size = 0.0
total_frames = 0
for ep_idx in episode_indices:
ep_length = episode_lengths[ep_idx]
ep_estimated_size = ep_length * size_per_frame_mb
# we check if adding this episode would exceed any constraint
would_exceed_size = estimated_size > 0 and estimated_size + ep_estimated_size >= video_file_size_limit
would_exceed_episodes = max_episodes is not None and len(batch_episodes) >= max_episodes
would_exceed_frames = max_frames is not None and total_frames + ep_length > max_frames
if batch_episodes and (would_exceed_size or would_exceed_episodes or would_exceed_frames):
# yield current batch before adding this episode
yield batch_episodes
# start a new batch with current episode
batch_episodes = [ep_idx]
estimated_size = ep_estimated_size
total_frames = ep_length
else:
# add to current batch
batch_episodes.append(ep_idx)
estimated_size += ep_estimated_size
total_frames += ep_length
# yield final batch if not empty
if batch_episodes:
yield batch_episodes
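Review note: the greedy batching in `_iter_episode_batches` can be exercised in isolation. The following is a minimal sketch, not the function itself: `iter_batches` and the sample lengths are hypothetical, and the per-frame size is assumed to be known up front rather than calibrated.

```python
from __future__ import annotations

from collections.abc import Iterator


def iter_batches(
    episode_indices: list[int],
    episode_lengths: dict[int, int],
    size_per_frame_mb: float,
    size_limit_mb: float,
    max_episodes: int | None = None,
    max_frames: int | None = None,
) -> Iterator[list[int]]:
    """Greedily group episodes, flushing the batch when the next episode would break a limit."""
    batch: list[int] = []
    size_mb = 0.0
    frames = 0
    for ep in episode_indices:
        n = episode_lengths[ep]
        ep_mb = n * size_per_frame_mb
        too_big = size_mb + ep_mb >= size_limit_mb
        too_many_episodes = max_episodes is not None and len(batch) >= max_episodes
        too_many_frames = max_frames is not None and frames + n > max_frames
        # An empty batch always accepts the episode, so an oversized episode
        # still ends up in a batch of its own instead of being dropped.
        if batch and (too_big or too_many_episodes or too_many_frames):
            yield batch
            batch, size_mb, frames = [], 0.0, 0
        batch.append(ep)
        size_mb += ep_mb
        frames += n
    if batch:
        yield batch


lengths = {0: 100, 1: 100, 2: 100, 3: 50}
batches = list(iter_batches([0, 1, 2, 3], lengths, size_per_frame_mb=0.5, size_limit_mb=120.0))
print(batches)  # [[0, 1], [2, 3]]
```

The size check is an estimate only, so the encoded file can still end up slightly above or below the limit.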
def _estimate_frame_size_via_calibration(
dataset: LeRobotDataset,
img_key: str,
episode_indices: list[int],
temp_dir: Path,
fps: int,
vcodec: str,
pix_fmt: str,
g: int,
crf: int,
fast_decode: int,
num_calibration_frames: int = 30,
) -> float:
"""Estimate MB per frame by encoding a small calibration sample.
Encodes a representative sample of frames using the exact codec parameters
to measure actual compression ratio, which is more accurate than heuristics.
Args:
dataset: Source dataset with images.
img_key: Image key to calibrate (e.g., "observation.images.top").
episode_indices: List of episode indices being processed.
temp_dir: Temporary directory for calibration files.
fps: Frames per second for video encoding.
vcodec: Video codec (libsvtav1, h264, hevc).
pix_fmt: Pixel format (yuv420p, etc.).
g: GOP size (group of pictures).
crf: Constant Rate Factor (quality).
fast_decode: Fast decode tuning parameter.
num_calibration_frames: Number of frames to use for calibration (default: 30).
Returns:
Estimated size in MB per frame based on actual encoding.
"""
calibration_dir = temp_dir / "calibration" / img_key
calibration_dir.mkdir(parents=True, exist_ok=True)
try:
# Select a representative episode (prefer middle episode if available)
calibration_ep_idx = episode_indices[len(episode_indices) // 2]
# Get episode range
from_idx = dataset.meta.episodes["dataset_from_index"][calibration_ep_idx]
to_idx = dataset.meta.episodes["dataset_to_index"][calibration_ep_idx]
episode_length = to_idx - from_idx
# Use up to num_calibration_frames from this episode
num_frames = min(num_calibration_frames, episode_length)
# Get frames from dataset
hf_dataset = dataset.hf_dataset.with_format(None)
sample_indices = range(from_idx, from_idx + num_frames)
# Save calibration frames
for i, idx in enumerate(sample_indices):
img = hf_dataset[idx][img_key]
img.save(str(calibration_dir / f"frame-{i:06d}.png"), quality=100)
# Encode calibration video
calibration_video_path = calibration_dir / "calibration.mp4"
encode_video_frames(
imgs_dir=calibration_dir,
video_path=calibration_video_path,
fps=fps,
vcodec=vcodec,
pix_fmt=pix_fmt,
g=g,
crf=crf,
fast_decode=fast_decode,
overwrite=True,
)
# Measure actual compressed size
video_size_bytes = calibration_video_path.stat().st_size
video_size_mb = video_size_bytes / BYTES_PER_MIB
size_per_frame_mb = video_size_mb / num_frames
logging.info(
f" Calibration: {num_frames} frames -> {video_size_mb:.2f} MB "
f"= {size_per_frame_mb:.4f} MB/frame for {img_key}"
)
return size_per_frame_mb
finally:
# Clean up calibration files
if calibration_dir.exists():
shutil.rmtree(calibration_dir)
def _copy_data_without_images(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
episode_indices: list[int],
img_keys: list[str],
) -> None:
"""Copy data files without image columns.
Args:
src_dataset: Source dataset
dst_meta: Destination metadata
episode_indices: Episodes to include
img_keys: Image keys to remove
"""
from lerobot.datasets.utils import DATA_DIR
data_dir = src_dataset.root / DATA_DIR
parquet_files = sorted(data_dir.glob("*/*.parquet"))
if not parquet_files:
raise ValueError(f"No parquet files found in {data_dir}")
episode_set = set(episode_indices)
for src_path in tqdm(parquet_files, desc="Processing data files"):
df = pd.read_parquet(src_path).reset_index(drop=True)
# Filter to only include selected episodes
df = df[df["episode_index"].isin(episode_set)].copy()
if len(df) == 0:
continue
# Remove image columns
columns_to_drop = [col for col in img_keys if col in df.columns]
if columns_to_drop:
df = df.drop(columns=columns_to_drop)
# Get chunk and file indices from path
relative_path = src_path.relative_to(src_dataset.root)
chunk_dir = relative_path.parts[1]
file_name = relative_path.parts[2]
chunk_idx = int(chunk_dir.split("-")[1])
file_idx = int(file_name.split("-")[1].split(".")[0])
# Write to destination without pandas index
dst_path = dst_meta.root / f"data/chunk-{chunk_idx:03d}/file-{file_idx:03d}.parquet"
dst_path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(dst_path, index=False)
# Video conversion constants
BYTES_PER_KIB = 1024
BYTES_PER_MIB = BYTES_PER_KIB * BYTES_PER_KIB
def convert_image_to_video_dataset(
dataset: LeRobotDataset,
output_dir: Path,
repo_id: str | None = None,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
g: int = 2,
crf: int = 30,
fast_decode: int = 0,
episode_indices: list[int] | None = None,
num_workers: int = 4,
max_episodes_per_batch: int | None = None,
max_frames_per_batch: int | None = None,
) -> LeRobotDataset:
"""Convert an image dataset to a video dataset.
Creates a new LeRobotDataset with images encoded as videos, following the proper
LeRobot dataset structure with videos stored in chunked MP4 files.
Args:
dataset: The source LeRobot dataset with images
output_dir: Directory to save the new video dataset
repo_id: Repository ID for the new dataset (default: original_id + "_video")
vcodec: Video codec (default: libsvtav1)
pix_fmt: Pixel format (default: yuv420p)
g: Group of pictures size (default: 2)
crf: Constant rate factor (default: 30)
fast_decode: Fast decode tuning (default: 0)
episode_indices: List of episode indices to convert (None = all episodes)
num_workers: Number of threads for parallel processing (default: 4)
max_episodes_per_batch: Maximum episodes per video batch to avoid memory issues (None = no limit)
max_frames_per_batch: Maximum frames per video batch to avoid memory issues (None = no limit)
Returns:
New LeRobotDataset with images encoded as videos
"""
# Check that it's an image dataset
if len(dataset.meta.video_keys) > 0:
raise ValueError(
f"This operation is for image datasets only. Video dataset provided: {dataset.repo_id}"
)
# Get all image keys
hf_dataset = dataset.hf_dataset.with_format(None)
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
if len(img_keys) == 0:
raise ValueError(f"No image keys found in dataset {dataset.repo_id}")
# Determine which episodes to process
if episode_indices is None:
episode_indices = list(range(dataset.meta.total_episodes))
if repo_id is None:
repo_id = f"{dataset.repo_id}_video"
logging.info(
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
)
logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}")
# Create new features dict, converting image features to video features
new_features = {}
for key, value in dataset.meta.features.items():
if key not in img_keys:
new_features[key] = value
else:
# Convert image key to video format
new_features[key] = value.copy()
new_features[key]["dtype"] = "video" # Change dtype from "image" to "video"
# Video info will be updated after episodes are encoded
# Create new metadata for video dataset
new_meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
fps=dataset.meta.fps,
features=new_features,
robot_type=dataset.meta.robot_type,
root=output_dir,
use_videos=True,
chunks_size=dataset.meta.chunks_size,
data_files_size_in_mb=dataset.meta.data_files_size_in_mb,
video_files_size_in_mb=dataset.meta.video_files_size_in_mb,
)
# Create temporary directory for image extraction
temp_dir = output_dir / "temp_images"
temp_dir.mkdir(parents=True, exist_ok=True)
# Process all episodes and batch encode videos
# Use dictionary for O(1) episode metadata lookups instead of O(n) linear search
all_episode_metadata = {}
fps = int(dataset.fps)
try:
# Build episode metadata entries first
logging.info("Building episode metadata...")
cumulative_frame_idx = 0
for ep_idx in episode_indices:
src_episode = dataset.meta.episodes[ep_idx]
ep_length = src_episode["length"]
ep_meta = {
"episode_index": ep_idx,
"length": ep_length,
"dataset_from_index": cumulative_frame_idx,
"dataset_to_index": cumulative_frame_idx + ep_length,
}
if "data/chunk_index" in src_episode:
ep_meta["data/chunk_index"] = src_episode["data/chunk_index"]
ep_meta["data/file_index"] = src_episode["data/file_index"]
all_episode_metadata[ep_idx] = ep_meta
cumulative_frame_idx += ep_length
# Process each camera and batch encode multiple episodes together
video_file_size_limit = new_meta.video_files_size_in_mb
# Pre-compute episode lengths for batching
episode_lengths = {ep_idx: dataset.meta.episodes["length"][ep_idx] for ep_idx in episode_indices}
for img_key in tqdm(img_keys, desc="Processing cameras"):
# Estimate size per frame by encoding a small calibration sample
# This provides accurate compression ratio for the specific codec parameters
size_per_frame_mb = _estimate_frame_size_via_calibration(
dataset=dataset,
img_key=img_key,
episode_indices=episode_indices,
temp_dir=temp_dir,
fps=fps,
vcodec=vcodec,
pix_fmt=pix_fmt,
g=g,
crf=crf,
fast_decode=fast_decode,
)
logging.info(f"Processing camera: {img_key}")
chunk_idx, file_idx = 0, 0
cumulative_timestamp = 0.0
# Process episodes in batches to stay under size limit
for batch_episodes in _iter_episode_batches(
episode_indices=episode_indices,
episode_lengths=episode_lengths,
size_per_frame_mb=size_per_frame_mb,
video_file_size_limit=video_file_size_limit,
max_episodes=max_episodes_per_batch,
max_frames=max_frames_per_batch,
):
total_frames_in_batch = sum(episode_lengths[idx] for idx in batch_episodes)
logging.info(
f" Encoding batch of {len(batch_episodes)} episodes "
f"({batch_episodes[0]}-{batch_episodes[-1]}) = {total_frames_in_batch} frames"
)
# Save images for all episodes in this batch
imgs_dir = temp_dir / f"batch_{chunk_idx}_{file_idx}" / img_key
episode_durations = _save_batch_episodes_images(
dataset=dataset,
imgs_dir=imgs_dir,
img_key=img_key,
episode_indices=batch_episodes,
num_workers=num_workers,
)
# Encode all batched episodes into single video
video_path = new_meta.root / new_meta.video_path.format(
video_key=img_key, chunk_index=chunk_idx, file_index=file_idx
)
video_path.parent.mkdir(parents=True, exist_ok=True)
encode_video_frames(
imgs_dir=imgs_dir,
video_path=video_path,
fps=fps,
vcodec=vcodec,
pix_fmt=pix_fmt,
g=g,
crf=crf,
fast_decode=fast_decode,
overwrite=True,
)
# Clean up temporary images
shutil.rmtree(imgs_dir)
# Update metadata for each episode in the batch
for ep_idx, duration in zip(batch_episodes, episode_durations, strict=True):
from_timestamp = cumulative_timestamp
to_timestamp = cumulative_timestamp + duration
cumulative_timestamp = to_timestamp
# Find episode metadata entry and add video metadata (O(1) dictionary lookup)
ep_meta = all_episode_metadata[ep_idx]
ep_meta[f"videos/{img_key}/chunk_index"] = chunk_idx
ep_meta[f"videos/{img_key}/file_index"] = file_idx
ep_meta[f"videos/{img_key}/from_timestamp"] = from_timestamp
ep_meta[f"videos/{img_key}/to_timestamp"] = to_timestamp
# Move to next video file for next batch
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, new_meta.chunks_size)
cumulative_timestamp = 0.0
# Copy and transform data files (removing image columns)
_copy_data_without_images(dataset, new_meta, episode_indices, img_keys)
# Save episode metadata
episodes_df = pd.DataFrame(list(all_episode_metadata.values()))
episodes_path = new_meta.root / "meta" / "episodes" / "chunk-000" / "file-000.parquet"
episodes_path.parent.mkdir(parents=True, exist_ok=True)
episodes_df.to_parquet(episodes_path, index=False)
# Update metadata info
new_meta.info["total_episodes"] = len(episode_indices)
new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata.values())
new_meta.info["total_tasks"] = dataset.meta.total_tasks
new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"}
# Update video info for all image keys (now videos)
# We need to manually set video info since update_video_info() checks video_keys first
for img_key in img_keys:
if not new_meta.features[img_key].get("info", None):
video_path = new_meta.root / new_meta.video_path.format(
video_key=img_key, chunk_index=0, file_index=0
)
new_meta.info["features"][img_key]["info"] = get_video_info(video_path)
write_info(new_meta.info, new_meta.root)
# Copy stats and tasks
if dataset.meta.stats is not None:
# Remove image stats
new_stats = {k: v for k, v in dataset.meta.stats.items() if k not in img_keys}
write_stats(new_stats, new_meta.root)
if dataset.meta.tasks is not None:
write_tasks(dataset.meta.tasks, new_meta.root)
finally:
# Clean up temporary directory
if temp_dir.exists():
shutil.rmtree(temp_dir)
logging.info(f"Completed converting {dataset.repo_id} to video format")
logging.info(f"New dataset saved to: {output_dir}")
# Return new dataset
return LeRobotDataset(repo_id=repo_id, root=output_dir)
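Review note: when several episodes share one video file, each episode's `from_timestamp` and `to_timestamp` are running offsets into that file, reset for every new file. A minimal sketch of that bookkeeping follows; `episode_time_spans` is a hypothetical helper name, and the frame counts are made up for illustration.

```python
from __future__ import annotations


def episode_time_spans(episode_lengths: list[int], fps: int) -> list[tuple[float, float]]:
    """Return (from_timestamp, to_timestamp) in seconds for episodes concatenated into one video."""
    spans: list[tuple[float, float]] = []
    start = 0.0
    for length in episode_lengths:
        end = start + length / fps  # episode duration in seconds
        spans.append((start, end))
        start = end  # the next episode begins where this one ends
    return spans


spans = episode_time_spans([30, 60, 15], fps=30)
print(spans)  # [(0.0, 1.0), (1.0, 3.0), (3.0, 3.5)]
```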
+5 -20
@@ -935,30 +935,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
else:
return get_hf_features_from_features(self.features)
def _get_query_indices(
self, abs_idx: int, ep_idx: int
) -> tuple[dict[str, list[int]], dict[str, torch.Tensor]]:
"""Compute query indices for delta timestamps.
Args:
abs_idx: The absolute index in the full dataset (not the relative index in filtered episodes).
ep_idx: The episode index.
Returns:
A tuple of (query_indices, padding) where:
- query_indices: Dict mapping keys to lists of absolute indices to query
- padding: Dict mapping "{key}_is_pad" to boolean tensors indicating padded positions
"""
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
ep = self.meta.episodes[ep_idx]
ep_start = ep["dataset_from_index"]
ep_end = ep["dataset_to_index"]
query_indices = {
key: [max(ep_start, min(ep_end - 1, abs_idx + delta)) for delta in delta_idx]
key: [max(ep_start, min(ep_end - 1, idx + delta)) for delta in delta_idx]
for key, delta_idx in self.delta_indices.items()
}
padding = { # Pad values outside of current episode range
f"{key}_is_pad": torch.BoolTensor(
[(abs_idx + delta < ep_start) | (abs_idx + delta >= ep_end) for delta in delta_idx]
[(idx + delta < ep_start) | (idx + delta >= ep_end) for delta in delta_idx]
)
for key, delta_idx in self.delta_indices.items()
}
@@ -1050,12 +1037,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
self._ensure_hf_dataset_loaded()
item = self.hf_dataset[idx]
ep_idx = item["episode_index"].item()
# Use the absolute index from the dataset for delta timestamp calculations
abs_idx = item["index"].item()
query_indices = None
if self.delta_indices is not None:
query_indices, padding = self._get_query_indices(abs_idx, ep_idx)
query_indices, padding = self._get_query_indices(idx, ep_idx)
query_result = self._query_hf_dataset(query_indices)
item = {**item, **padding}
for key, val in query_result.items():
@@ -1513,7 +1498,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
episode_index = self.episode_buffer["episode_index"]
if isinstance(episode_index, np.ndarray):
episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0]
for cam_key in self.meta.image_keys:
for cam_key in self.meta.camera_keys:
img_dir = self._get_image_file_dir(episode_index, cam_key)
if img_dir.is_dir():
shutil.rmtree(img_dir)
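Review note: the two variants of `_get_query_indices` in this file differ only in which index they are given. The sketch below uses plain lists instead of tensors and a hypothetical helper name. It assumes, as the removed docstring states, that episode bounds are absolute row positions, and shows what happens when a relative index is passed instead.

```python
def query_indices_with_padding(idx, ep_start, ep_end, deltas):
    """Clamp idx + delta into [ep_start, ep_end) and flag positions that fell outside."""
    clamped = [max(ep_start, min(ep_end - 1, idx + d)) for d in deltas]
    is_pad = [(idx + d < ep_start) or (idx + d >= ep_end) for d in deltas]
    return clamped, is_pad


deltas = [-2, -1, 0, 1]

# Episode occupies absolute rows [100, 110); the frame is the second one of the episode.
clamped, is_pad = query_indices_with_padding(101, 100, 110, deltas)
print(clamped, is_pad)  # [100, 100, 101, 102] [True, False, False, False]

# Same frame addressed by its relative position (1) in a dataset filtered to this episode:
# every offset lands before ep_start, so all positions are clamped and marked as padding.
rel_clamped, rel_is_pad = query_indices_with_padding(1, 100, 110, deltas)
print(rel_clamped, rel_is_pad)  # [100, 100, 100, 100] [True, True, True, True]
```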
+2 -11
@@ -1172,21 +1172,12 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
)
def to_parquet_with_hf_images(
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
) -> None:
def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path) -> None:
"""Write to parquet a pandas DataFrame that contains images encoded by HF datasets.
This way, it can be loaded by HF dataset and correctly formatted images are returned.
Args:
df: DataFrame to write to parquet.
path: Path to write the parquet file.
features: Optional HuggingFace Features schema. If provided, ensures image columns
are properly typed as Image() in the parquet schema.
"""
# TODO(qlhoest): replace this weird syntax by `df.to_parquet(path)` only
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
ds.to_parquet(path)
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
def item_to_torch(item: dict) -> dict:
+4 -5
@@ -260,7 +260,6 @@ class HILSerlRobotEnvConfig(EnvConfig):
@dataclass
class LiberoEnv(EnvConfig):
task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
task_ids: list[int] | None = None
fps: int = 30
episode_length: int | None = None
obs_type: str = "pixels_agent_pos"
@@ -339,10 +338,10 @@ class LiberoEnv(EnvConfig):
@property
def gym_kwargs(self) -> dict:
kwargs: dict[str, Any] = {"obs_type": self.obs_type, "render_mode": self.render_mode}
if self.task_ids is not None:
kwargs["task_ids"] = self.task_ids
return kwargs
return {
"obs_type": self.obs_type,
"render_mode": self.render_mode,
}
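Review note: one of the two `gym_kwargs` variants in this hunk forwards `task_ids` only when it is set. A minimal sketch of that optional-kwarg pattern follows. `ExampleEnvConfig` is a hypothetical stand-in rather than the real config class, and the `render_mode` default is a placeholder assumption.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass
class ExampleEnvConfig:  # hypothetical stand-in for the env config in the diff
    obs_type: str = "pixels_agent_pos"
    render_mode: str = "rgb_array"  # placeholder default, assumed for illustration
    task_ids: list[int] | None = None

    @property
    def gym_kwargs(self) -> dict[str, Any]:
        kwargs: dict[str, Any] = {"obs_type": self.obs_type, "render_mode": self.render_mode}
        # Only forward task_ids when the user set it, so the env keeps its own default otherwise.
        if self.task_ids is not None:
            kwargs["task_ids"] = self.task_ids
        return kwargs


default_kwargs = ExampleEnvConfig().gym_kwargs
subset_kwargs = ExampleEnvConfig(task_ids=[0, 3]).gym_kwargs
print(default_kwargs)
print(subset_kwargs)
```

Omitting the key entirely, rather than passing `None`, avoids overriding whatever default the environment constructor defines.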
@EnvConfig.register_subclass("metaworld")
+2 -2
@@ -293,9 +293,9 @@ class LiberoEnv(gym.Env):
def reset(self, seed=None, **kwargs):
super().reset(seed=seed)
self._env.seed(seed)
raw_obs = self._env.reset()
if self.init_states and self._init_states is not None:
raw_obs = self._env.set_init_state(self._init_states[self._init_state_id])
self._env.set_init_state(self._init_states[self._init_state_id])
raw_obs = self._env.reset()
# After reset, objects may be unstable (slightly floating, intersecting, etc.).
# Step the simulator with a no-op action for a few frames so everything settles.
+1 -5
@@ -14,8 +14,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .motors_bus import (
Motor,
MotorCalibration,
MotorNormMode,
)
from .motors_bus import Motor, MotorCalibration, MotorNormMode, MotorsBus
+1 -1
@@ -18,7 +18,7 @@ from dataclasses import dataclass
os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1"
from .motors_bus import MotorCalibration, MotorsBus
from lerobot.motors import MotorCalibration, MotorsBus
BAR_LEN, BAR_THICKNESS = 450, 8
HANDLE_R = 10
-18
@@ -1,18 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .damiao import DamiaoMotorsBus
from .tables import *
-808
@@ -1,808 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Portions of this file are derived from DM_Control_Python by cmjang.
# Licensed under the MIT License; see `LICENSE` for the full text:
# https://github.com/cmjang/DM_Control_Python
import logging
import time
from contextlib import contextmanager
from copy import deepcopy
from functools import cached_property
from typing import TYPE_CHECKING, Any, TypedDict
from lerobot.utils.import_utils import _can_available
if TYPE_CHECKING or _can_available:
import can
else:
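# `can` is unbound when python-can is missing, so build a stub namespace rather than
# assigning attributes on an undefined name (which would raise NameError).
from types import SimpleNamespace
can = SimpleNamespace(Message=object, interface=None)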
import numpy as np
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import enter_pressed, move_cursor_up
from ..motors_bus import Motor, MotorCalibration, MotorsBusBase, NameOrID, Value
from .tables import (
AVAILABLE_BAUDRATES,
CAN_CMD_DISABLE,
CAN_CMD_ENABLE,
CAN_CMD_REFRESH,
CAN_CMD_SET_ZERO,
CAN_PARAM_ID,
DEFAULT_BAUDRATE,
DEFAULT_TIMEOUT_MS,
MIT_KD_RANGE,
MIT_KP_RANGE,
MOTOR_LIMIT_PARAMS,
MotorType,
)
logger = logging.getLogger(__name__)
LONG_TIMEOUT_SEC = 0.1
MEDIUM_TIMEOUT_SEC = 0.01
SHORT_TIMEOUT_SEC = 0.001
PRECISE_TIMEOUT_SEC = 0.0001
class MotorState(TypedDict):
position: float
velocity: float
torque: float
temp_mos: float
temp_rotor: float
class DamiaoMotorsBus(MotorsBusBase):
"""
The Damiao implementation for a MotorsBus using CAN bus communication.
This class uses python-can for CAN bus communication with Damiao motors.
For more info, see:
- python-can documentation: https://python-can.readthedocs.io/en/stable/
- Seeed Studio documentation: https://wiki.seeedstudio.com/damiao_series/
- DM_Control_Python repo: https://github.com/cmjang/DM_Control_Python
"""
# CAN-specific settings
available_baudrates = deepcopy(AVAILABLE_BAUDRATES)
default_baudrate = DEFAULT_BAUDRATE
default_timeout = DEFAULT_TIMEOUT_MS
def __init__(
self,
port: str,
motors: dict[str, Motor],
calibration: dict[str, MotorCalibration] | None = None,
can_interface: str = "auto",
use_can_fd: bool = True,
bitrate: int = 1000000,
data_bitrate: int | None = 5000000,
):
"""
Initialize the Damiao motors bus.
Args:
port: CAN interface name (e.g., "can0" for Linux, "/dev/cu.usbmodem*" for macOS)
motors: Dictionary mapping motor names to Motor objects
calibration: Optional calibration data
can_interface: CAN interface type - "auto" (default), "socketcan" (Linux), or "slcan" (macOS/serial)
use_can_fd: Whether to use CAN FD mode (default: True for OpenArms)
bitrate: Nominal bitrate in bps (default: 1000000 = 1 Mbps)
data_bitrate: Data bitrate for CAN FD in bps (default: 5000000 = 5 Mbps), ignored if use_can_fd is False
"""
super().__init__(port, motors, calibration)
self.port = port
self.can_interface = can_interface
self.use_can_fd = use_can_fd
self.bitrate = bitrate
self.data_bitrate = data_bitrate
self.canbus: can.interface.Bus | None = None
self._is_connected = False
# Map motor names to CAN IDs
self._motor_can_ids: dict[str, int] = {}
self._recv_id_to_motor: dict[int, str] = {}
self._motor_types: dict[str, MotorType] = {}
for name, motor in self.motors.items():
if motor.motor_type_str is None:
raise ValueError(f"Motor '{name}' is missing required 'motor_type'")
self._motor_types[name] = getattr(MotorType, motor.motor_type_str.upper().replace("-", "_"))
# Map recv_id to motor name for filtering responses
if motor.recv_id is not None:
self._recv_id_to_motor[motor.recv_id] = name
# State cache for handling packet drops safely
self._last_known_states: dict[str, MotorState] = {
name: {
"position": 0.0,
"velocity": 0.0,
"torque": 0.0,
"temp_mos": 0.0,
"temp_rotor": 0.0,
}
for name in self.motors
}
# Dynamic gains storage
# Defaults: Kp=10.0 (Stiffness), Kd=0.5 (Damping)
self._gains: dict[str, dict[str, float]] = {name: {"kp": 10.0, "kd": 0.5} for name in self.motors}
@property
def is_connected(self) -> bool:
"""Check if the CAN bus is connected."""
return self._is_connected and self.canbus is not None
def connect(self, handshake: bool = True) -> None:
"""
Open the CAN bus and initialize communication.
Args:
handshake: If True, ping all motors to verify they're present
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(
f"{self.__class__.__name__}('{self.port}') is already connected."
)
try:
# Auto-detect interface type based on port name
if self.can_interface == "auto":
if self.port.startswith("/dev/"):
self.can_interface = "slcan"
logger.info(f"Auto-detected slcan interface for port {self.port}")
else:
self.can_interface = "socketcan"
logger.info(f"Auto-detected socketcan interface for port {self.port}")
# Connect to CAN bus
kwargs = {
"channel": self.port,
"bitrate": self.bitrate,
"interface": self.can_interface,
}
if self.can_interface == "socketcan" and self.use_can_fd and self.data_bitrate is not None:
kwargs.update({"data_bitrate": self.data_bitrate, "fd": True})
logger.info(
f"Connecting to {self.port} with CAN FD (bitrate={self.bitrate}, data_bitrate={self.data_bitrate})"
)
else:
logger.info(f"Connecting to {self.port} with {self.can_interface} (bitrate={self.bitrate})")
self.canbus = can.interface.Bus(**kwargs)
self._is_connected = True
if handshake:
self._handshake()
logger.debug(f"{self.__class__.__name__} connected via {self.can_interface}.")
except Exception as e:
self._is_connected = False
raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e
def _handshake(self) -> None:
"""
Verify all motors are present and populate initial state cache.
Raises ConnectionError if any motor fails to respond.
"""
logger.info("Starting handshake with motors...")
missing_motors = []
for motor_name in self.motors:
msg = self._refresh_motor(motor_name)
if msg is None:
missing_motors.append(motor_name)
else:
self._process_response(motor_name, msg)
time.sleep(MEDIUM_TIMEOUT_SEC)
if missing_motors:
raise ConnectionError(
f"Handshake failed. The following motors did not respond: {missing_motors}. "
"Check power (24V) and CAN wiring."
)
logger.info("Handshake successful. All motors ready.")
def disconnect(self, disable_torque: bool = True) -> None:
"""
Close the CAN bus connection.
Args:
disable_torque: If True, disable torque on all motors before disconnecting
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self.__class__.__name__}('{self.port}') is not connected.")
if disable_torque:
try:
self.disable_torque()
except Exception as e:
logger.warning(f"Failed to disable torque during disconnect: {e}")
if self.canbus:
self.canbus.shutdown()
self.canbus = None
self._is_connected = False
logger.debug(f"{self.__class__.__name__} disconnected.")
def configure_motors(self) -> None:
"""Configure all motors with default settings."""
# Damiao motors don't require much configuration in MIT mode
# Just ensure they're enabled
for motor in self.motors:
self._send_simple_command(motor, CAN_CMD_ENABLE)
time.sleep(MEDIUM_TIMEOUT_SEC)
def _send_simple_command(self, motor: NameOrID, command_byte: int) -> None:
"""Helper to send simple 8-byte commands (Enable, Disable, Zero)."""
motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor)
recv_id = self._get_motor_recv_id(motor)
data = [0xFF] * 7 + [command_byte]
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self.canbus.send(msg)
if msg := self._recv_motor_response(expected_recv_id=recv_id):
self._process_response(motor_name, msg)
else:
logger.debug(f"No response from {motor_name} after command 0x{command_byte:02X}")
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Enable torque on selected motors."""
target_motors = self._get_motors_list(motors)
for motor in target_motors:
for attempt in range(num_retry + 1):
try:
self._send_simple_command(motor, CAN_CMD_ENABLE)
break
except Exception:
if attempt == num_retry:
raise
time.sleep(MEDIUM_TIMEOUT_SEC)
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Disable torque on selected motors."""
target_motors = self._get_motors_list(motors)
for motor in target_motors:
for attempt in range(num_retry + 1):
try:
self._send_simple_command(motor, CAN_CMD_DISABLE)
break
except Exception:
if attempt == num_retry:
raise
time.sleep(MEDIUM_TIMEOUT_SEC)
@contextmanager
def torque_disabled(self, motors: str | list[str] | None = None):
"""
Context manager that guarantees torque is re-enabled.
This helper is useful to temporarily disable torque when configuring motors.
"""
self.disable_torque(motors)
try:
yield
finally:
self.enable_torque(motors)
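The `try`/`finally` shape of `torque_disabled` is what makes the re-enable guarantee hold even when the body raises. A minimal sketch of that behavior, assuming a hypothetical `FakeBus` that only records calls (it is an illustration, not part of the driver):

```python
from contextlib import contextmanager


class FakeBus:
    """Hypothetical stand-in for the motor bus; it only records calls."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def disable_torque(self) -> None:
        self.calls.append("disable")

    def enable_torque(self) -> None:
        self.calls.append("enable")

    @contextmanager
    def torque_disabled(self):
        self.disable_torque()
        try:
            yield
        finally:
            # Runs on normal exit and also when the body raises.
            self.enable_torque()


bus = FakeBus()
try:
    with bus.torque_disabled():
        raise RuntimeError("configuration step failed")
except RuntimeError:
    pass

print(bus.calls)
```

Torque is re-enabled before the exception propagates, so a failed configuration step does not leave the motors limp.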
def set_zero_position(self, motors: str | list[str] | None = None) -> None:
"""Set current position as zero for selected motors."""
target_motors = self._get_motors_list(motors)
for motor in target_motors:
self._send_simple_command(motor, CAN_CMD_SET_ZERO)
time.sleep(MEDIUM_TIMEOUT_SEC)
def _refresh_motor(self, motor: NameOrID) -> can.Message | None:
"""Refresh motor status and return the response."""
motor_id = self._get_motor_id(motor)
recv_id = self._get_motor_recv_id(motor)
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
self.canbus.send(msg)
return self._recv_motor_response(expected_recv_id=recv_id)
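The two payload layouts used above can be compared side by side. This is a sketch built only from what the source shows: simple commands go to the motor's own ID with seven `0xFF` filler bytes, while a refresh goes to `CAN_PARAM_ID` and carries the motor ID in the first two bytes. The helper names `simple_command_payload` and `refresh_payload` are hypothetical.

```python
CAN_CMD_ENABLE = 0xFC   # values copied from the Damiao tables in this diff
CAN_CMD_REFRESH = 0xCC
CAN_PARAM_ID = 0x7FF    # arbitration ID used for refresh requests


def simple_command_payload(command_byte: int) -> bytes:
    """Seven 0xFF filler bytes followed by the command byte."""
    return bytes([0xFF] * 7 + [command_byte])


def refresh_payload(motor_id: int) -> bytes:
    """Motor ID (low byte, high byte), the refresh command, then zero padding."""
    return bytes([motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0])


enable = simple_command_payload(CAN_CMD_ENABLE)
refresh = refresh_payload(0x01)
print(enable.hex(" "))   # sent with arbitration_id = motor ID
print(refresh.hex(" "))  # sent with arbitration_id = CAN_PARAM_ID
```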
def _recv_motor_response(
self, expected_recv_id: int | None = None, timeout: float = 0.001
) -> can.Message | None:
"""
Receive a response from a motor.
Args:
expected_recv_id: If provided, only return messages from this CAN ID
timeout: Timeout in seconds (default: 1ms for high-speed operation)
Returns:
CAN message if received, None otherwise
"""
try:
start_time = time.monotonic()
messages_seen = []
while time.monotonic() - start_time < timeout:
msg = self.canbus.recv(timeout=PRECISE_TIMEOUT_SEC)
if msg:
messages_seen.append(f"0x{msg.arbitration_id:02X}")
if expected_recv_id is None or msg.arbitration_id == expected_recv_id:
return msg
logger.debug(
f"Ignoring message from 0x{msg.arbitration_id:02X}, expected 0x{expected_recv_id:02X}"
)
if logger.isEnabledFor(logging.DEBUG):
# expected_recv_id may be None, which cannot be formatted with :02X
expected = "any ID" if expected_recv_id is None else f"0x{expected_recv_id:02X}"
if messages_seen:
logger.debug(f"Received {len(messages_seen)} msgs from {set(messages_seen)}, expected {expected}")
else:
logger.debug(f"No CAN messages received (expected {expected})")
except Exception as e:
logger.debug(f"Failed to receive CAN message: {e}")
return None
def _recv_all_responses(
self, expected_recv_ids: list[int], timeout: float = 0.002
) -> dict[int, can.Message]:
"""
Efficiently receive responses from multiple motors at once.
Uses the OpenArms pattern: collect all available messages within timeout.
Args:
expected_recv_ids: List of CAN IDs we expect responses from
timeout: Total timeout in seconds (default: 2ms)
Returns:
Dictionary mapping recv_id to CAN message
"""
responses = {}
expected_set = set(expected_recv_ids)
start_time = time.monotonic()
try:
while len(responses) < len(expected_recv_ids) and (time.monotonic() - start_time) < timeout:
# 100us poll timeout
msg = self.canbus.recv(timeout=PRECISE_TIMEOUT_SEC)
if msg and msg.arbitration_id in expected_set:
responses[msg.arbitration_id] = msg
if len(responses) == len(expected_recv_ids):
break
except Exception as e:
logger.debug(f"Error receiving responses: {e}")
return responses
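The collect-until-deadline pattern in `_recv_all_responses` can be exercised without hardware. The sketch below assumes a hypothetical `FakeBus` whose `recv()` returns bare arbitration IDs instead of `can.Message` objects; `collect_responses` is likewise an illustrative name, not the driver's API.

```python
import time
from collections import deque
from typing import Optional


class FakeBus:
    """Hypothetical bus: recv() pops queued arbitration IDs, or returns None when empty."""

    def __init__(self, ids: list[int]) -> None:
        self._queue = deque(ids)

    def recv(self, timeout: float) -> Optional[int]:
        if self._queue:
            return self._queue.popleft()
        time.sleep(timeout)  # emulate a blocking poll that timed out
        return None


def collect_responses(bus: FakeBus, expected_ids: list[int], timeout: float = 0.002) -> dict[int, int]:
    """Gather one message per expected ID until all arrive or the deadline passes."""
    expected = set(expected_ids)
    responses: dict[int, int] = {}
    deadline = time.monotonic() + timeout
    while len(responses) < len(expected) and time.monotonic() < deadline:
        msg_id = bus.recv(timeout=0.0001)
        if msg_id is not None and msg_id in expected:
            responses[msg_id] = msg_id  # a real bus would store the message object
    return responses


# 0x99 is an unrelated frame and is ignored; 0x13 never answers.
got = collect_responses(FakeBus([0x11, 0x99, 0x12]), [0x11, 0x12, 0x13], timeout=0.05)
print(sorted(hex(i) for i in got))
```

The missing ID simply stays absent from the result, which is why the callers in the source treat a missing entry as a dropped packet and fall back to the cached state.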
def _encode_mit_packet(
self,
motor_type: MotorType,
kp: float,
kd: float,
position_degrees: float,
velocity_deg_per_sec: float,
torque: float,
) -> list[int]:
"""Helper to encode control parameters into 8 bytes for MIT mode."""
# Convert degrees to radians
position_rad = np.radians(position_degrees)
velocity_rad_per_sec = np.radians(velocity_deg_per_sec)
# Get motor limits
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
# Encode parameters
kp_uint = self._float_to_uint(kp, *MIT_KP_RANGE, 12)
kd_uint = self._float_to_uint(kd, *MIT_KD_RANGE, 12)
q_uint = self._float_to_uint(position_rad, -pmax, pmax, 16)
dq_uint = self._float_to_uint(velocity_rad_per_sec, -vmax, vmax, 12)
tau_uint = self._float_to_uint(torque, -tmax, tmax, 12)
# Pack data
data = [0] * 8
data[0] = (q_uint >> 8) & 0xFF
data[1] = q_uint & 0xFF
data[2] = dq_uint >> 4
data[3] = ((dq_uint & 0xF) << 4) | ((kp_uint >> 8) & 0xF)
data[4] = kp_uint & 0xFF
data[5] = kd_uint >> 4
data[6] = ((kd_uint & 0xF) << 4) | ((tau_uint >> 8) & 0xF)
data[7] = tau_uint & 0xFF
return data
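The byte layout in `_encode_mit_packet` packs 16 + 12 + 12 + 12 + 12 bits into exactly 8 bytes. A sketch that works on already-quantized integers makes the nibble sharing easier to follow; `unpack_mit` is a hypothetical inverse written only for this illustration and does not appear in the source.

```python
def pack_mit(q: int, dq: int, kp: int, kd: int, tau: int) -> list[int]:
    """Pack position (16 bits) and velocity, kp, kd, torque (12 bits each) into 8 bytes."""
    return [
        (q >> 8) & 0xFF,
        q & 0xFF,
        dq >> 4,
        ((dq & 0xF) << 4) | ((kp >> 8) & 0xF),  # low nibble of dq, high nibble of kp
        kp & 0xFF,
        kd >> 4,
        ((kd & 0xF) << 4) | ((tau >> 8) & 0xF),  # low nibble of kd, high nibble of tau
        tau & 0xFF,
    ]


def unpack_mit(data: list[int]) -> tuple[int, int, int, int, int]:
    """Inverse of pack_mit, used here only to check the layout."""
    q = (data[0] << 8) | data[1]
    dq = (data[2] << 4) | (data[3] >> 4)
    kp = ((data[3] & 0xF) << 8) | data[4]
    kd = (data[5] << 4) | (data[6] >> 4)
    tau = ((data[6] & 0xF) << 8) | data[7]
    return q, dq, kp, kd, tau


fields = (0xABCD, 0x123, 0x456, 0x789, 0xFED)
frame = pack_mit(*fields)
print(" ".join(f"{b:02X}" for b in frame))
roundtrip = unpack_mit(frame)
```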
def _mit_control(
self,
motor: NameOrID,
kp: float,
kd: float,
position_degrees: float,
velocity_deg_per_sec: float,
torque: float,
) -> None:
"""Send MIT control command to a motor."""
motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor)
motor_type = self._motor_types[motor_name]
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self.canbus.send(msg)
recv_id = self._get_motor_recv_id(motor)
if msg := self._recv_motor_response(expected_recv_id=recv_id):
self._process_response(motor_name, msg)
else:
logger.debug(f"No response from {motor_name} after MIT control command")
def _mit_control_batch(
self,
commands: dict[NameOrID, tuple[float, float, float, float, float]],
) -> None:
"""
Send MIT control commands to multiple motors in batch.
Sends all commands first, then collects responses.
Args:
commands: Dict mapping motor name/ID to (kp, kd, position_deg, velocity_deg/s, torque)
Example: {'joint_1': (10.0, 0.5, 45.0, 0.0, 0.0), ...}
"""
if not commands:
return
recv_id_to_motor: dict[int, str] = {}
# Step 1: Send all MIT control commands
for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items():
motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor)
motor_type = self._motor_types[motor_name]
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self.canbus.send(msg)
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
# Step 2: Collect responses and update state cache
responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=SHORT_TIMEOUT_SEC)
for recv_id, motor_name in recv_id_to_motor.items():
if msg := responses.get(recv_id):
self._process_response(motor_name, msg)
def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
"""Convert float to unsigned integer for CAN transmission."""
x = max(x_min, min(x_max, x)) # Clamp to range
span = x_max - x_min
data_norm = (x - x_min) / span
return int(data_norm * ((1 << bits) - 1))
def _uint_to_float(self, x: int, x_min: float, x_max: float, bits: int) -> float:
"""Convert unsigned integer from CAN to float."""
span = x_max - x_min
data_norm = float(x) / ((1 << bits) - 1)
return data_norm * span + x_min
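The two conversion helpers form a lossy round trip: values are clamped to the range, then truncated to the nearest lower code. A standalone sketch of the same arithmetic, assuming the 12.5 rad position limit from the motor table and free functions in place of the methods:

```python
def float_to_uint(x: float, x_min: float, x_max: float, bits: int) -> int:
    """Clamp x to [x_min, x_max] and map it linearly onto 0 .. 2**bits - 1."""
    x = max(x_min, min(x_max, x))
    return int((x - x_min) / (x_max - x_min) * ((1 << bits) - 1))


def uint_to_float(x: int, x_min: float, x_max: float, bits: int) -> float:
    """Map an unsigned code back to a float in [x_min, x_max]."""
    return float(x) / ((1 << bits) - 1) * (x_max - x_min) + x_min


P_MAX = 12.5  # rad, taken from MOTOR_LIMIT_PARAMS
BITS = 16
step = 2 * P_MAX / ((1 << BITS) - 1)  # size of one quantization step

max_error = 0.0
for value in (-20.0, -1.0, 0.0, 3.14159, 12.5):
    code = float_to_uint(value, -P_MAX, P_MAX, BITS)
    restored = uint_to_float(code, -P_MAX, P_MAX, BITS)
    clamped = max(-P_MAX, min(P_MAX, value))
    max_error = max(max_error, abs(restored - clamped))
    print(f"{value:>9.5f} -> {code:>5d} -> {restored:>9.5f}")
```

Because `int()` truncates rather than rounds, the restored value is never above the clamped input and the error stays below one step (about 0.00038 rad at 16 bits).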
def _decode_motor_state(
self, data: bytearray | bytes, motor_type: MotorType
) -> tuple[float, float, float, int, int]:
"""
Decode motor state from CAN data.
Returns: (position_deg, velocity_deg_s, torque, temp_mos, temp_rotor)
"""
if len(data) < 8:
raise ValueError(f"Invalid motor state data: expected at least 8 bytes, got {len(data)}")
# Extract encoded values
q_uint = (data[1] << 8) | data[2]
dq_uint = (data[3] << 4) | (data[4] >> 4)
tau_uint = ((data[4] & 0x0F) << 8) | data[5]
t_mos = data[6]
t_rotor = data[7]
# Get motor limits
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
# Decode to physical values
position_rad = self._uint_to_float(q_uint, -pmax, pmax, 16)
velocity_rad_per_sec = self._uint_to_float(dq_uint, -vmax, vmax, 12)
torque = self._uint_to_float(tau_uint, -tmax, tmax, 12)
return np.degrees(position_rad), np.degrees(velocity_rad_per_sec), torque, t_mos, t_rotor
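A worked decode of one feedback frame shows how the bit fields map to physical values. The sketch below uses the DM4310 limits from the table (12.5 rad, 30 rad/s, 10 N·m) and a hand-built frame; byte 0 is not read by the decoder in the source, so its value here is an arbitrary placeholder, and `decode_state` is an illustrative free function using `math` in place of NumPy.

```python
import math


def uint_to_float(x: int, x_min: float, x_max: float, bits: int) -> float:
    return float(x) / ((1 << bits) - 1) * (x_max - x_min) + x_min


def decode_state(data: bytes, pmax: float, vmax: float, tmax: float):
    """Return (position_deg, velocity_deg_s, torque, temp_mos, temp_rotor)."""
    if len(data) < 8:
        raise ValueError(f"expected at least 8 bytes, got {len(data)}")
    q_uint = (data[1] << 8) | data[2]             # 16-bit position
    dq_uint = (data[3] << 4) | (data[4] >> 4)     # 12-bit velocity
    tau_uint = ((data[4] & 0x0F) << 8) | data[5]  # 12-bit torque
    position = uint_to_float(q_uint, -pmax, pmax, 16)
    velocity = uint_to_float(dq_uint, -vmax, vmax, 12)
    torque = uint_to_float(tau_uint, -tmax, tmax, 12)
    return math.degrees(position), math.degrees(velocity), torque, data[6], data[7]


# Position at full scale, velocity at minimum, torque near zero, temperatures 30 and 32.
frame = bytes([0x11, 0xFF, 0xFF, 0x00, 0x08, 0x00, 30, 32])
pos_deg, vel_deg_s, torque, t_mos, t_rotor = decode_state(frame, 12.5, 30.0, 10.0)
print(pos_deg, vel_deg_s, torque, t_mos, t_rotor)
```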
def _process_response(self, motor: str, msg: can.Message) -> None:
"""Decode a message and update the motor state cache."""
try:
motor_type = self._motor_types[motor]
pos, vel, torque, t_mos, t_rotor = self._decode_motor_state(msg.data, motor_type)
self._last_known_states[motor] = {
"position": pos,
"velocity": vel,
"torque": torque,
"temp_mos": float(t_mos),
"temp_rotor": float(t_rotor),
}
except Exception as e:
logger.warning(f"Failed to decode response from {motor}: {e}")
def read(self, data_name: str, motor: str) -> Value:
"""Read a value from a single motor. Positions are always in degrees."""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Refresh motor to get latest state
msg = self._refresh_motor(motor)
if msg is None:
motor_id = self._get_motor_id(motor)
recv_id = self._get_motor_recv_id(motor)
raise ConnectionError(
f"No response from motor '{motor}' (send ID: 0x{motor_id:02X}, recv ID: 0x{recv_id:02X}). "
"Check that: 1) Motor is powered (24V), 2) CAN wiring is correct, "
"3) Motor IDs are configured correctly using Damiao Debugging Tools"
)
self._process_response(motor, msg)
return self._get_cached_value(motor, data_name)
def _get_cached_value(self, motor: str, data_name: str) -> Value:
"""Retrieve a specific value from the cache."""
state = self._last_known_states[motor]
mapping: dict[str, Any] = {
"Present_Position": state["position"],
"Present_Velocity": state["velocity"],
"Present_Torque": state["torque"],
"Temperature_MOS": state["temp_mos"],
"Temperature_Rotor": state["temp_rotor"],
}
if data_name not in mapping:
raise ValueError(f"Unknown data_name: {data_name}")
return mapping[data_name]
def write(
self,
data_name: str,
motor: str,
value: Value,
) -> None:
"""
Write a value to a single motor. Positions are always in degrees.
Can write 'Goal_Position', 'Kp', or 'Kd'.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if data_name in ("Kp", "Kd"):
self._gains[motor][data_name.lower()] = float(value)
elif data_name == "Goal_Position":
kp = self._gains[motor]["kp"]
kd = self._gains[motor]["kd"]
self._mit_control(motor, kp, kd, float(value), 0.0, 0.0)
else:
raise ValueError(f"Writing {data_name} not supported in MIT mode")
def sync_read(
self,
data_name: str,
motors: str | list[str] | None = None,
) -> dict[str, Value]:
"""
Read the same value from multiple motors simultaneously.
"""
target_motors = self._get_motors_list(motors)
self._batch_refresh(target_motors)
result = {}
for motor in target_motors:
result[motor] = self._get_cached_value(motor, data_name)
return result
def sync_read_all_states(
self,
motors: str | list[str] | None = None,
*,
num_retry: int = 0,
) -> dict[str, MotorState]:
"""
Read all cached motor states (position, velocity, torque, temperatures) from multiple motors in one refresh cycle.
Returns:
Dictionary mapping motor names to state dicts with keys: 'position', 'velocity', 'torque', 'temp_mos', 'temp_rotor'
Example: {'joint_1': {'position': 45.2, 'velocity': 1.3, 'torque': 0.5, 'temp_mos': 30.0, 'temp_rotor': 32.0}, ...}
"""
target_motors = self._get_motors_list(motors)
self._batch_refresh(target_motors)
result = {}
for motor in target_motors:
result[motor] = self._last_known_states[motor].copy()
return result
def _batch_refresh(self, motors: list[str]) -> None:
"""Internal helper to refresh a list of motors and update cache."""
# Send refresh commands
for motor in motors:
motor_id = self._get_motor_id(motor)
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
self.canbus.send(msg)
# No inter-frame delay is applied here; re-enable the sleep below if bus congestion causes dropped frames.
# precise_sleep(PRECISE_SLEEP_SEC)
# Collect responses
expected_recv_ids = [self._get_motor_recv_id(m) for m in motors]
responses = self._recv_all_responses(expected_recv_ids, timeout=MEDIUM_TIMEOUT_SEC)
# Update cache
for motor in motors:
recv_id = self._get_motor_recv_id(motor)
msg = responses.get(recv_id)
if msg:
self._process_response(motor, msg)
else:
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
"""
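Write values to multiple motors simultaneously. Positions are always in degrees.
`values` must be a dict mapping motor names to values; a bare scalar is not handled by this implementation.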
"""
if data_name in ("Kp", "Kd"):
key = data_name.lower()
for motor, val in values.items():
self._gains[motor][key] = float(val)
elif data_name == "Goal_Position":
# Step 1: Send all MIT control commands
recv_id_to_motor: dict[int, str] = {}
for motor, value_degrees in values.items():
motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor)
motor_type = self._motor_types[motor_name]
kp = self._gains[motor]["kp"]
kd = self._gains[motor]["kd"]
data = self._encode_mit_packet(motor_type, kp, kd, float(value_degrees), 0.0, 0.0)
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self.canbus.send(msg)
precise_sleep(PRECISE_TIMEOUT_SEC)
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
# Step 2: Collect responses and update state cache
responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=MEDIUM_TIMEOUT_SEC)
for recv_id, motor_name in recv_id_to_motor.items():
if msg := responses.get(recv_id):
self._process_response(motor_name, msg)
else:
# Fall back to individual writes
for motor, value in values.items():
self.write(data_name, motor, value)
def read_calibration(self) -> dict[str, MotorCalibration]:
"""Read calibration data from motors."""
# Damiao motors don't store calibration internally
# Return existing calibration or empty dict
return self.calibration if self.calibration else {}
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
"""Write calibration data to motors."""
# Damiao motors don't store calibration internally
# Just cache it in memory
if cache:
self.calibration = calibration_dict
def record_ranges_of_motion(
self,
motors: NameOrID | list[NameOrID] | None = None,
display_values: bool = True,
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
"""
Interactively record the min/max values of each motor in degrees.
Move the joints by hand (with torque disabled) while the method streams live positions.
Press Enter to finish.
"""
target_motors = self._get_motors_list(motors)
self.disable_torque(target_motors)
time.sleep(LONG_TIMEOUT_SEC)
start_positions = self.sync_read("Present_Position", target_motors)
mins = start_positions.copy()
maxes = start_positions.copy()
print("\nMove joints through their full range of motion. Press ENTER when done.")
user_pressed_enter = False
while not user_pressed_enter:
positions = self.sync_read("Present_Position", target_motors)
for motor in target_motors:
if motor in positions:
mins[motor] = min(positions[motor], mins.get(motor, positions[motor]))
maxes[motor] = max(positions[motor], maxes.get(motor, positions[motor]))
if display_values:
print("\n" + "=" * 50)
print(f"{'MOTOR':<20} | {'MIN (deg)':>12} | {'POS (deg)':>12} | {'MAX (deg)':>12}")
print("-" * 50)
for motor in target_motors:
if motor in positions:
print(
f"{motor:<20} | {mins[motor]:>12.1f} | {positions[motor]:>12.1f} | {maxes[motor]:>12.1f}"
)
if enter_pressed():
user_pressed_enter = True
if display_values and not user_pressed_enter:
move_cursor_up(len(target_motors) + 4)
time.sleep(LONG_TIMEOUT_SEC)
self.enable_torque(target_motors)
for motor in target_motors:
if (motor in mins) and (motor in maxes) and (int(abs(maxes[motor] - mins[motor])) < 5):
raise ValueError(f"Motor {motor} has insufficient range of motion (< 5 degrees)")
return mins, maxes
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
"""Convert motor specification to list of motor names."""
if motors is None:
return list(self.motors.keys())
elif isinstance(motors, str):
return [motors]
elif isinstance(motors, list):
return motors
else:
raise TypeError(f"Invalid motors type: {type(motors)}")
def _get_motor_id(self, motor: NameOrID) -> int:
"""Get CAN ID for a motor."""
if isinstance(motor, str):
if motor in self.motors:
return self.motors[motor].id
else:
raise ValueError(f"Unknown motor: {motor}")
else:
return motor
def _get_motor_name(self, motor: NameOrID) -> str:
"""Get motor name from name or ID."""
if isinstance(motor, str):
return motor
else:
for name, m in self.motors.items():
if m.id == motor:
return name
raise ValueError(f"Unknown motor ID: {motor}")
def _get_motor_recv_id(self, motor: NameOrID) -> int:
"""Get motor recv_id from name or ID."""
motor_name = self._get_motor_name(motor)
motor_obj = self.motors.get(motor_name)
if motor_obj and motor_obj.recv_id is not None:
return motor_obj.recv_id
else:
raise ValueError(f"Motor '{motor_name}' is unknown or doesn't have a valid recv_id (None).")
@property
def is_calibrated(self) -> bool:
"""Check if motors are calibrated."""
return bool(self.calibration)
-209
@@ -1,209 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration tables for Damiao motors."""
from enum import IntEnum
# Motor type definitions
class MotorType(IntEnum):
DM3507 = 0
DM4310 = 1
DM4310_48V = 2
DM4340 = 3
DM4340_48V = 4
DM6006 = 5
DM8006 = 6
DM8009 = 7
DM10010L = 8
DM10010 = 9
DMH3510 = 10
DMH6215 = 11
DMG6220 = 12
# Control modes
class ControlMode(IntEnum):
MIT = 1
POS_VEL = 2
VEL = 3
TORQUE_POS = 4
# Motor variable IDs (RID)
class MotorVariable(IntEnum):
UV_VALUE = 0
KT_VALUE = 1
OT_VALUE = 2
OC_VALUE = 3
ACC = 4
DEC = 5
MAX_SPD = 6
MST_ID = 7
ESC_ID = 8
TIMEOUT = 9
CTRL_MODE = 10
DAMP = 11
INERTIA = 12
HW_VER = 13
SW_VER = 14
SN = 15
NPP = 16
RS = 17
LS = 18
FLUX = 19
GR = 20
PMAX = 21
VMAX = 22
TMAX = 23
I_BW = 24
KP_ASR = 25
KI_ASR = 26
KP_APR = 27
KI_APR = 28
OV_VALUE = 29
GREF = 30
DETA = 31
V_BW = 32
IQ_C1 = 33
VL_C1 = 34
CAN_BR = 35
SUB_VER = 36
U_OFF = 50
V_OFF = 51
K1 = 52
K2 = 53
M_OFF = 54
DIR = 55
P_M = 80
XOUT = 81
# Motor limit parameters [PMAX, VMAX, TMAX]
# PMAX: Maximum position (rad)
# VMAX: Maximum velocity (rad/s)
# TMAX: Maximum torque (N·m)
MOTOR_LIMIT_PARAMS = {
MotorType.DM3507: (12.5, 30, 10),
MotorType.DM4310: (12.5, 30, 10),
MotorType.DM4310_48V: (12.5, 50, 10),
MotorType.DM4340: (12.5, 8, 28),
MotorType.DM4340_48V: (12.5, 10, 28),
MotorType.DM6006: (12.5, 45, 20),
MotorType.DM8006: (12.5, 45, 40),
MotorType.DM8009: (12.5, 45, 54),
MotorType.DM10010L: (12.5, 25, 200),
MotorType.DM10010: (12.5, 20, 200),
MotorType.DMH3510: (12.5, 280, 1),
MotorType.DMH6215: (12.5, 45, 10),
MotorType.DMG6220: (12.5, 45, 10),
}
# Motor model names
MODEL_NAMES = {
MotorType.DM3507: "dm3507",
MotorType.DM4310: "dm4310",
MotorType.DM4310_48V: "dm4310_48v",
MotorType.DM4340: "dm4340",
MotorType.DM4340_48V: "dm4340_48v",
MotorType.DM6006: "dm6006",
MotorType.DM8006: "dm8006",
MotorType.DM8009: "dm8009",
MotorType.DM10010L: "dm10010l",
MotorType.DM10010: "dm10010",
MotorType.DMH3510: "dmh3510",
MotorType.DMH6215: "dmh6215",
MotorType.DMG6220: "dmg6220",
}
# Motor resolution table (encoder counts per revolution)
MODEL_RESOLUTION = {
"dm3507": 65536,
"dm4310": 65536,
"dm4310_48v": 65536,
"dm4340": 65536,
"dm4340_48v": 65536,
"dm6006": 65536,
"dm8006": 65536,
"dm8009": 65536,
"dm10010l": 65536,
"dm10010": 65536,
"dmh3510": 65536,
"dmh6215": 65536,
"dmg6220": 65536,
}
# CAN baudrates supported by Damiao motors
AVAILABLE_BAUDRATES = [
125000, # 0: 125 kbps
200000, # 1: 200 kbps
250000, # 2: 250 kbps
500000, # 3: 500 kbps
1000000, # 4: 1 Mbps (default for OpenArms)
2000000, # 5: 2 Mbps
2500000, # 6: 2.5 Mbps
3200000, # 7: 3.2 Mbps
4000000, # 8: 4 Mbps
5000000, # 9: 5 Mbps
]
DEFAULT_BAUDRATE = 1000000 # 1 Mbps is standard for OpenArms
# Default timeout in milliseconds
DEFAULT_TIMEOUT_MS = 1000
# OpenArms specific configurations
# Based on: https://docs.openarm.dev/software/setup/configure-test
# OpenArms has 7 DOF per arm (14 total for dual arm)
OPENARMS_ARM_MOTOR_IDS = {
"joint_1": {"send": 0x01, "recv": 0x11}, # J1 - Shoulder pan
"joint_2": {"send": 0x02, "recv": 0x12}, # J2 - Shoulder lift
"joint_3": {"send": 0x03, "recv": 0x13}, # J3 - Elbow flex
"joint_4": {"send": 0x04, "recv": 0x14}, # J4 - Wrist flex
"joint_5": {"send": 0x05, "recv": 0x15}, # J5 - Wrist roll
"joint_6": {"send": 0x06, "recv": 0x16}, # J6 - Wrist pitch
"joint_7": {"send": 0x07, "recv": 0x17}, # J7 - Wrist rotation
}
OPENARMS_GRIPPER_MOTOR_IDS = {
"gripper": {"send": 0x08, "recv": 0x18}, # J8 - Gripper
}
# Default motor types for OpenArms
OPENARMS_DEFAULT_MOTOR_TYPES = {
"joint_1": MotorType.DM8009, # Shoulder pan - high torque
"joint_2": MotorType.DM8009, # Shoulder lift - high torque
"joint_3": MotorType.DM4340, # Shoulder rotation
"joint_4": MotorType.DM4340, # Elbow flex
"joint_5": MotorType.DM4310, # Wrist roll
"joint_6": MotorType.DM4310, # Wrist pitch
"joint_7": MotorType.DM4310, # Wrist rotation
"gripper": MotorType.DM4310, # Gripper
}
# MIT control parameter ranges
MIT_KP_RANGE = (0.0, 500.0)
MIT_KD_RANGE = (0.0, 5.0)
# CAN frame command IDs
CAN_CMD_ENABLE = 0xFC
CAN_CMD_DISABLE = 0xFD
CAN_CMD_SET_ZERO = 0xFE
CAN_CMD_REFRESH = 0xCC
CAN_CMD_QUERY_PARAM = 0x33
CAN_CMD_WRITE_PARAM = 0x55
CAN_CMD_SAVE_PARAM = 0xAA
# CAN ID for parameter operations
CAN_PARAM_ID = 0x7FF
+6 -5
@@ -22,8 +22,9 @@ import logging
from copy import deepcopy
from enum import Enum
from ..encoding_utils import decode_twos_complement, encode_twos_complement
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
from lerobot.motors.encoding_utils import decode_twos_complement, encode_twos_complement
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
from .tables import (
AVAILABLE_BAUDRATES,
MODEL_BAUDRATE_TABLE,
@@ -99,7 +100,7 @@ def _split_into_byte_chunks(value: int, length: int) -> list[int]:
return data
class DynamixelMotorsBus(SerialMotorsBus):
class DynamixelMotorsBus(MotorsBus):
"""
The Dynamixel implementation for a MotorsBus. It relies on the python dynamixel sdk to communicate with
the motors. For more info, see the Dynamixel SDK Documentation:
@@ -202,9 +203,9 @@ class DynamixelMotorsBus(SerialMotorsBus):
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None:
def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None:
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry)
self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry)
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
+7 -6
@@ -17,8 +17,9 @@ from copy import deepcopy
from enum import Enum
from pprint import pformat
from ..encoding_utils import decode_sign_magnitude, encode_sign_magnitude
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
from lerobot.motors.encoding_utils import decode_sign_magnitude, encode_sign_magnitude
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
from .tables import (
FIRMWARE_MAJOR_VERSION,
FIRMWARE_MINOR_VERSION,
@@ -95,7 +96,7 @@ def patch_setPacketTimeout(self, packet_length): # noqa: N802
self.packet_timeout = (self.tx_time_per_byte * packet_length) + (self.tx_time_per_byte * 3.0) + 50
class FeetechMotorsBus(SerialMotorsBus):
class FeetechMotorsBus(MotorsBus):
"""
The FeetechMotorsBus class allows to efficiently read and write to the attached motors. It relies on the
python feetech sdk to communicate with the motors, which is itself based on the dynamixel sdk.
@@ -297,11 +298,11 @@ class FeetechMotorsBus(SerialMotorsBus):
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
self.write("Lock", motor, 0, num_retry=num_retry)
def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None:
def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None:
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry)
self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry)
addr, length = get_address(self.model_ctrl_table, model, "Lock")
self._write(addr, length, motor, 0, num_retry=num_retry)
self._write(addr, length, motor_id, 0, num_retry=num_retry)
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
-1
@@ -205,7 +205,6 @@ MODEL_BAUDRATE_TABLE = {
# Sign-Magnitude encoding bits
STS_SMS_SERIES_ENCODINGS_TABLE = {
"Present_Load": 10,
"Homing_Offset": 11,
"Goal_Position": 15,
"Goal_Velocity": 15,
+35 -98
@@ -19,8 +19,6 @@
# TODO(aliberts): Add block noqa when feature below is available
# https://github.com/astral-sh/ruff/issues/3711
from __future__ import annotations
import abc
import logging
from contextlib import contextmanager
@@ -34,7 +32,7 @@ import serial
from deepdiff import DeepDiff
from tqdm import tqdm
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.utils import enter_pressed, move_cursor_up
NameOrID: TypeAlias = str | int
@@ -43,81 +41,6 @@ Value: TypeAlias = int | float
logger = logging.getLogger(__name__)
class MotorsBusBase(abc.ABC):
"""
Base class for all motor bus implementations.
This is a minimal interface that all motor buses must implement, regardless of their
communication protocol (serial, CAN, etc.).
"""
def __init__(
self,
port: str,
motors: dict[str, Motor],
calibration: dict[str, MotorCalibration] | None = None,
):
self.port = port
self.motors = motors
self.calibration = calibration if calibration else {}
@abc.abstractmethod
def connect(self, handshake: bool = True) -> None:
"""Establish connection to the motors."""
pass
@abc.abstractmethod
def disconnect(self, disable_torque: bool = True) -> None:
"""Disconnect from the motors."""
pass
@property
@abc.abstractmethod
def is_connected(self) -> bool:
"""Check if connected to the motors."""
pass
@abc.abstractmethod
def read(self, data_name: str, motor: str) -> Value:
"""Read a value from a single motor."""
pass
@abc.abstractmethod
def write(self, data_name: str, motor: str, value: Value) -> None:
"""Write a value to a single motor."""
pass
@abc.abstractmethod
def sync_read(self, data_name: str, motors: str | list[str] | None = None) -> dict[str, Value]:
"""Read a value from multiple motors."""
pass
@abc.abstractmethod
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
"""Write values to multiple motors."""
pass
@abc.abstractmethod
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Enable torque on selected motors."""
pass
@abc.abstractmethod
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Disable torque on selected motors."""
pass
@abc.abstractmethod
def read_calibration(self) -> dict[str, MotorCalibration]:
"""Read calibration parameters from the motors."""
pass
@abc.abstractmethod
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
"""Write calibration parameters to the motors."""
pass
def get_ctrl_table(model_ctrl_table: dict[str, dict], model: str) -> dict[str, tuple[int, int]]:
ctrl_table = model_ctrl_table.get(model)
if ctrl_table is None:
@@ -174,8 +97,6 @@ class Motor:
id: int
model: str
norm_mode: MotorNormMode
motor_type_str: str | None = None
recv_id: int | None = None
class PortHandler(Protocol):
@@ -282,15 +203,15 @@ class GroupSyncWrite(Protocol):
def txPacket(self): ...
class SerialMotorsBus(MotorsBusBase):
class MotorsBus(abc.ABC):
"""
A SerialMotorsBus allows to efficiently read and write to motors connected via serial communication.
A MotorsBus allows to efficiently read and write to the attached motors.
It represents several motors daisy-chained together and connected through a serial port.
There are currently two implementations of this class:
There are currently two implementations of this abstract class:
- DynamixelMotorsBus
- FeetechMotorsBus
This class is specifically for serial-based motor protocols (Dynamixel, Feetech, etc.).
Note: This class may evolve in the future should we add support for other types of bus.
A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
To find the port, you can run our utility script:
@@ -339,7 +260,9 @@ class SerialMotorsBus(MotorsBusBase):
motors: dict[str, Motor],
calibration: dict[str, MotorCalibration] | None = None,
):
super().__init__(port, motors, calibration)
self.port = port
self.motors = motors
self.calibration = calibration if calibration else {}
self.port_handler: PortHandler
self.packet_handler: PacketHandler
@@ -488,7 +411,6 @@ class SerialMotorsBus(MotorsBusBase):
"""bool: `True` if the underlying serial port is open."""
return self.port_handler.is_open
@check_if_already_connected
def connect(self, handshake: bool = True) -> None:
"""Open the serial port and initialise communication.
@@ -500,6 +422,10 @@ class SerialMotorsBus(MotorsBusBase):
DeviceAlreadyConnectedError: The port is already open.
ConnectionError: The underlying SDK failed to open the port or the handshake did not succeed.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(
f"{self.__class__.__name__}('{self.port}') is already connected. Do not call `{self.__class__.__name__}.connect()` twice."
)
self._connect(handshake)
self.set_timeout()
@@ -521,7 +447,6 @@ class SerialMotorsBus(MotorsBusBase):
def _handshake(self) -> None:
pass
@check_if_not_connected
def disconnect(self, disable_torque: bool = True) -> None:
"""Close the serial port (optionally disabling torque first).
@@ -530,6 +455,10 @@ class SerialMotorsBus(MotorsBusBase):
closing the port. This can prevent damaging motors if they are left applying resisting torque
after disconnect.
"""
if not self.is_connected:
raise DeviceNotConnectedError(
f"{self.__class__.__name__}('{self.port}') is not connected. Try running `{self.__class__.__name__}.connect()` first."
)
if disable_torque:
self.port_handler.clearPort()
@@ -609,7 +538,7 @@ class SerialMotorsBus(MotorsBusBase):
self.set_baudrate(self.default_baudrate)
@abc.abstractmethod
def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
def _find_single_motor(self, motor: str, initial_baudrate: int | None) -> tuple[int, int]:
pass
@abc.abstractmethod
@@ -622,13 +551,13 @@ class SerialMotorsBus(MotorsBusBase):
pass
@abc.abstractmethod
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
"""Disable torque on selected motors.
Disabling torque allows writing to the motors' permanent memory area (EPROM/EEPROM).
Args:
motors ( str | list[str] | None, optional): Target motors. Accepts a motor name, an ID, a
motors (int | str | list[str] | None, optional): Target motors. Accepts a motor name, an ID, a
list of names or `None` to affect every registered motor. Defaults to `None`.
num_retry (int, optional): Number of additional retry attempts on communication failure.
Defaults to 0.
@@ -978,7 +907,6 @@ class SerialMotorsBus(MotorsBusBase):
"""
pass
@check_if_not_connected
def read(
self,
data_name: str,
@@ -999,6 +927,10 @@ class SerialMotorsBus(MotorsBusBase):
Returns:
Value: Raw or normalised value depending on *normalize*.
"""
if not self.is_connected:
raise DeviceNotConnectedError(
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
)
id_ = self.motors[motor].id
model = self.motors[motor].model
@@ -1049,7 +981,6 @@ class SerialMotorsBus(MotorsBusBase):
return value, comm, error
@check_if_not_connected
def write(
self, data_name: str, motor: str, value: Value, *, normalize: bool = True, num_retry: int = 0
) -> None:
@@ -1068,6 +999,10 @@ class SerialMotorsBus(MotorsBusBase):
normalize (bool, optional): Enable or disable normalisation. Defaults to `True`.
num_retry (int, optional): Retry attempts. Defaults to `0`.
"""
if not self.is_connected:
raise DeviceNotConnectedError(
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
)
id_ = self.motors[motor].id
model = self.motors[motor].model
@@ -1109,7 +1044,6 @@ class SerialMotorsBus(MotorsBusBase):
return comm, error
@check_if_not_connected
def sync_read(
self,
data_name: str,
@@ -1129,6 +1063,10 @@ class SerialMotorsBus(MotorsBusBase):
Returns:
dict[str, Value]: Mapping *motor name → value*.
"""
if not self.is_connected:
raise DeviceNotConnectedError(
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
)
self._assert_protocol_is_compatible("sync_read")
@@ -1201,7 +1139,6 @@ class SerialMotorsBus(MotorsBusBase):
# for id_ in motor_ids:
# value = self.sync_reader.getData(id_, address, length)
@check_if_not_connected
def sync_write(
self,
data_name: str,
@@ -1223,6 +1160,10 @@ class SerialMotorsBus(MotorsBusBase):
normalize (bool, optional): If `True` (default) convert values from the user range to raw units.
num_retry (int, optional): Retry attempts. Defaults to `0`.
"""
if not self.is_connected:
raise DeviceNotConnectedError(
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
)
ids_values = self._get_ids_values_dict(values)
models = [self._id_to_model(id_) for id_ in ids_values]
@@ -1271,7 +1212,3 @@ class SerialMotorsBus(MotorsBusBase):
for id_, value in ids_values.items():
data = self._serialize_data(value, length)
self.sync_writer.addParam(id_, data)
# Backward compatibility alias
MotorsBus: TypeAlias = SerialMotorsBus
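Review note: the hunks above drop the `@check_if_already_connected` / `@check_if_not_connected` decorators and repeat an inline `is_connected` check in each method instead. The decorator's implementation is not part of this diff, so the following is only a minimal sketch of how such a guard could look. `FakeBus` and the local `DeviceNotConnectedError` class are hypothetical stand-ins for illustration, not lerobot code.

```python
import functools


class DeviceNotConnectedError(ConnectionError):
    """Stand-in for the error type named in the diff (assumption)."""


def check_if_not_connected(method):
    """Hypothetical guard: raise before calling `method` if the bus is closed."""

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if not self.is_connected:
            raise DeviceNotConnectedError(
                f"{self.__class__.__name__}('{self.port}') is not connected. "
                f"You need to run `{self.__class__.__name__}.connect()`."
            )
        return method(self, *args, **kwargs)

    return wrapper


class FakeBus:
    """Toy bus used only for illustration; not part of lerobot."""

    def __init__(self, port: str):
        self.port = port
        self.is_connected = False

    def connect(self) -> None:
        self.is_connected = True

    @check_if_not_connected
    def read(self, data_name: str) -> int:
        return 0  # placeholder value instead of a real register read
```

Both styles raise the same error. The decorator keeps the message in one place, while the inline form makes the check visible in each method body at the cost of repetition.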
+1 -130
@@ -32,22 +32,16 @@ Notes:
from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
"""
import builtins
import os
from collections import deque
from pathlib import Path
from typing import TypeVar
import torch
from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.groot.groot_n1 import GR00TN15
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import ACTION, OBS_IMAGES
T = TypeVar("T", bound="GrootPolicy")
from lerobot.utils.constants import ACTION
class GrootPolicy(PreTrainedPolicy):
@@ -96,129 +90,6 @@ class GrootPolicy(PreTrainedPolicy):
"""Reset policy state when environment resets."""
self._action_queue = deque([], maxlen=self.config.n_action_steps)
@classmethod
def from_pretrained(
cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
config: GrootConfig | None = None,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
strict: bool = True,
**kwargs,
) -> T:
"""Load Groot policy from pretrained model.
Handles two cases:
1. Base GR00T models (e.g., 'nvidia/GR00T-N1.5-3B') - loads the raw model
2. Fine-tuned LeRobot checkpoints - loads config and weights from safetensors
Args:
pretrained_name_or_path: Path to the GR00T model or fine-tuned checkpoint
config: Optional GrootConfig. If None, loads from checkpoint or creates default
force_download: Force download even if cached
resume_download: Resume interrupted download
proxies: Proxy settings
token: HuggingFace authentication token
cache_dir: Cache directory path
local_files_only: Only use local files
revision: Specific model revision
strict: Strict state dict loading
**kwargs: Additional arguments (passed to config)
Returns:
Initialized GrootPolicy instance with loaded model
"""
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
print(
"The Groot policy is a wrapper around Nvidia's GR00T N1.5 model.\n"
f"Loading pretrained model from: {pretrained_name_or_path}"
)
model_id = str(pretrained_name_or_path)
is_finetuned_checkpoint = False
# Check if this is a fine-tuned LeRobot checkpoint (has model.safetensors)
try:
if os.path.isdir(model_id):
is_finetuned_checkpoint = os.path.exists(os.path.join(model_id, SAFETENSORS_SINGLE_FILE))
else:
# Try to download the safetensors file to check if it exists
try:
hf_hub_download(
repo_id=model_id,
filename=SAFETENSORS_SINGLE_FILE,
revision=revision,
cache_dir=cache_dir,
force_download=False, # Just check, don't force download
proxies=proxies,
token=token,
local_files_only=local_files_only,
)
is_finetuned_checkpoint = True
except HfHubHTTPError:
is_finetuned_checkpoint = False
except Exception:
is_finetuned_checkpoint = False
if is_finetuned_checkpoint:
# This is a fine-tuned LeRobot checkpoint - use parent class loading
print("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
return super().from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
config=config,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
strict=strict,
**kwargs,
)
# This is a base GR00T model - load it fresh
print("Detected base GR00T model, loading from HuggingFace...")
if config is None:
# Create default config with the pretrained path
config = GrootConfig(base_model_path=str(pretrained_name_or_path))
# Add minimal visual feature required for validation
# validate_features() will automatically add state and action features
# These are placeholders - actual robot features come from the preprocessor
if not config.input_features:
config.input_features = {
f"{OBS_IMAGES}.camera": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 224, 224), # Default image size from config
),
}
else:
# Override the base_model_path with the provided path
config.base_model_path = str(pretrained_name_or_path)
# Pass through any additional config overrides from kwargs
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
# Create a fresh policy instance - this will automatically load the GR00T model
# in __init__ via _create_groot_model()
policy = cls(config)
policy.eval()
return policy
def get_optim_params(self) -> dict:
return self.parameters()
@@ -20,7 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig, RTCTrainingConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
DEFAULT_IMAGE_SIZE = 224
@@ -50,9 +50,8 @@ class PI0Config(PreTrainedConfig):
min_period: float = 4e-3
max_period: float = 4.0
# Real-Time Chunking (RTC) configurations
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
rtc_training_config: RTCTrainingConfig | None = None
image_resolution: tuple[int, int] = (
DEFAULT_IMAGE_SIZE,
+19 -85
@@ -44,12 +44,6 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.policies.rtc.training_time import (
apply_rtc_training_time,
apply_training_time_rtc_inference,
masked_mean,
sample_rtc_delay,
)
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -85,8 +79,8 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim not in (1, 2):
raise ValueError("The time tensor is expected to be of shape `(batch_size,)` or `(batch_size, T)`.")
if time.ndim != 1:
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
dtype = get_safe_dtype(torch.float64, device.type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
@@ -94,14 +88,8 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
# Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi
if time.ndim == 1:
sin_input = scaling_factor[None, :] * time[:, None]
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
time_flat = time.reshape(-1)
sin_input = scaling_factor[None, :] * time_flat[:, None]
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
return pos_emb.reshape(*time.shape, dimension)
sin_input = scaling_factor[None, :] * time[:, None]
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
@@ -617,9 +605,6 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _training_time_rtc_inference_enabled(self):
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training:
@@ -729,10 +714,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
if time_emb.dim() == 2:
time_emb = time_emb[:, None, :].expand_as(action_emb)
elif time_emb.shape[:2] != action_emb.shape[:2]:
raise ValueError(f"Expected time_emb shape {action_emb.shape[:2]}, got {time_emb.shape[:2]}")
time_emb = time_emb[:, None, :].expand_as(action_emb)
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
def mlp_func(action_time_emb):
@@ -768,12 +750,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
if time is None:
time = self.sample_time(actions.shape[0], actions.device)
if time.ndim == 1:
time_expanded = time[:, None, None]
elif time.ndim == 2:
time_expanded = time[:, :, None]
else:
raise ValueError(f"Expected time shape (B,) or (B, T), got {time.shape}")
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
@@ -869,37 +846,24 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
dt = -1.0 / num_steps
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
use_training_time_rtc = self._training_time_rtc_inference_enabled()
x_t = noise
for step in range(num_steps):
time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
if use_training_time_rtc:
x_t_cond, time_tensor = apply_training_time_rtc_inference(
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
)
v_t = self.denoise_step(
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
state=state,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=x_t_cond,
timestep=time_tensor,
x_t=input_x_t,
timestep=current_timestep,
)
elif self._rtc_enabled():
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
state=state,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=input_x_t,
timestep=current_timestep,
)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
@@ -910,14 +874,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
execution_horizon=execution_horizon,
)
else:
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
v_t = self.denoise_step(
state=state,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=x_t,
timestep=time_tensor,
)
v_t = denoise_step_partial_call(x_t)
x_t = x_t + dt * v_t
@@ -1320,19 +1277,7 @@ class PI0Policy(PreTrainedPolicy):
actions = self.prepare_action(batch)
# Compute loss
postfix_mask = None
rtc_cfg = self.config.rtc_training_config
if rtc_cfg is not None and rtc_cfg.enabled and self.training:
batch_size = actions.shape[0]
time = self.model.sample_time(batch_size, actions.device)
noise = self.model.sample_noise(actions.shape, actions.device)
delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device)
time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1])
losses = self.model.forward(
images, img_masks, lang_tokens, lang_masks, state, actions, noise=noise, time=time
)
else:
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
# Truncate losses to actual action dimensions
original_action_dim = self.config.output_features[ACTION].shape[0]
@@ -1344,22 +1289,11 @@ class PI0Policy(PreTrainedPolicy):
if reduction == "none":
# Return per-sample losses (B,) by averaging over time and action dims
per_sample_loss = masked_mean(losses, postfix_mask, reduce_dims=(1, 2))
per_sample_loss = losses.mean(dim=(1, 2))
loss_dict["loss"] = per_sample_loss.mean().item()
return per_sample_loss, loss_dict
else:
# Default: return scalar mean loss
loss = masked_mean(losses, postfix_mask, reduce_dims=(0, 1, 2))
loss = losses.mean()
loss_dict["loss"] = loss.item()
return loss, loss_dict
def _get_default_peft_targets(self) -> dict[str, any]:
"""Return default PEFT target modules for PI0 fine-tuning."""
common_projections = (
"state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
)
target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))"
return {
"target_modules": target_modules,
"modules_to_save": [],
}
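Review note: the sampling loop in this file integrates from `time = 1.0` (pure noise) towards `0` with `dt = -1.0 / num_steps` and `x_t = x_t + dt * v_t`, and the training target is `u_t = noise - actions`. The sketch below uses plain Python lists instead of tensors to show why that sign convention recovers the actions when the predicted velocity equals the target. `euler_denoise` and `ideal_velocity` are hypothetical names; the real model predicts `v_t` with a network.

```python
def euler_denoise(noise, velocity_fn, num_steps):
    """Integrate from t=1 (noise) to t=0 with fixed-size Euler steps."""
    dt = -1.0 / num_steps
    x_t = list(noise)
    for step in range(num_steps):
        time = 1.0 + step * dt
        v_t = velocity_fn(x_t, time)
        x_t = [x + dt * v for x, v in zip(x_t, v_t)]
    return x_t


actions = [0.5, -1.0, 2.0]
noise = [1.5, 0.25, -0.75]


def ideal_velocity(x_t, time):
    """Target velocity of the straight-line path: u_t = noise - actions."""
    return [n - a for n, a in zip(noise, actions)]


recovered = euler_denoise(noise, ideal_velocity, num_steps=10)
```

With a constant velocity field, the ten steps add up to `noise - (noise - actions)`, so `recovered` matches `actions` up to floating-point error.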
@@ -20,7 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig, RTCTrainingConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
DEFAULT_IMAGE_SIZE = 224
@@ -52,7 +52,6 @@ class PI05Config(PreTrainedConfig):
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
rtc_training_config: RTCTrainingConfig | None = None
image_resolution: tuple[int, int] = (
DEFAULT_IMAGE_SIZE,
+18 -77
@@ -44,12 +44,6 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.policies.rtc.training_time import (
apply_rtc_training_time,
apply_training_time_rtc_inference,
masked_mean,
sample_rtc_delay,
)
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -84,8 +78,8 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim not in (1, 2):
raise ValueError("The time tensor is expected to be of shape `(batch_size,)` or `(batch_size, T)`.")
if time.ndim != 1:
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
dtype = get_safe_dtype(torch.float64, device.type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
@@ -93,14 +87,8 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
# Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi
if time.ndim == 1:
sin_input = scaling_factor[None, :] * time[:, None]
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
time_flat = time.reshape(-1)
sin_input = scaling_factor[None, :] * time_flat[:, None]
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
return pos_emb.reshape(*time.shape, dimension)
sin_input = scaling_factor[None, :] * time[:, None]
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
@@ -614,9 +602,6 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _training_time_rtc_inference_enabled(self):
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training:
@@ -744,12 +729,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
if time is None:
time = self.sample_time(actions.shape[0], actions.device)
if time.ndim == 1:
time_expanded = time[:, None, None]
elif time.ndim == 2:
time_expanded = time[:, :, None]
else:
raise ValueError(f"Expected time shape (B,) or (B, T), got {time.shape}")
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
@@ -840,35 +820,23 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
dt = -1.0 / num_steps
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
use_training_time_rtc = self._training_time_rtc_inference_enabled()
x_t = noise
for step in range(num_steps):
time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
if use_training_time_rtc:
x_t_cond, time_tensor = apply_training_time_rtc_inference(
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
)
v_t = self.denoise_step(
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=x_t_cond,
timestep=time_tensor,
x_t=input_x_t,
timestep=current_timestep,
)
elif self._rtc_enabled():
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=input_x_t,
timestep=current_timestep,
)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
@@ -879,13 +847,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
execution_horizon=execution_horizon,
)
else:
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
v_t = self.denoise_step(
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=x_t,
timestep=time_tensor,
)
v_t = denoise_step_partial_call(x_t)
x_t = x_t + dt * v_t
@@ -1288,17 +1250,7 @@ class PI05Policy(PreTrainedPolicy):
actions = self.prepare_action(batch)
# Compute loss (no separate state needed for PI05)
postfix_mask = None
rtc_cfg = self.config.rtc_training_config
if rtc_cfg is not None and rtc_cfg.enabled and self.training:
batch_size = actions.shape[0]
time = self.model.sample_time(batch_size, actions.device)
noise = self.model.sample_noise(actions.shape, actions.device)
delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device)
time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1])
losses = self.model.forward(images, img_masks, tokens, masks, actions, noise=noise, time=time)
else:
losses = self.model.forward(images, img_masks, tokens, masks, actions)
losses = self.model.forward(images, img_masks, tokens, masks, actions)
# Truncate losses to actual action dimensions
original_action_dim = self.config.output_features[ACTION].shape[0]
@@ -1310,22 +1262,11 @@ class PI05Policy(PreTrainedPolicy):
if reduction == "none":
# Return per-sample losses (B,) by averaging over time and action dims
per_sample_loss = masked_mean(losses, postfix_mask, reduce_dims=(1, 2))
per_sample_loss = losses.mean(dim=(1, 2))
loss_dict["loss"] = per_sample_loss.mean().item()
return per_sample_loss, loss_dict
else:
# Default: return scalar mean loss
loss = masked_mean(losses, postfix_mask, reduce_dims=(0, 1, 2))
loss = losses.mean()
loss_dict["loss"] = loss.item()
return loss, loss_dict
def _get_default_peft_targets(self) -> dict[str, any]:
"""Return default PEFT target modules for PI0.5 fine-tuning."""
common_projections = (
"state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
)
target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))"
return {
"target_modules": target_modules,
"modules_to_save": [],
}
-164
@@ -13,7 +13,6 @@
# limitations under the License.
import abc
import builtins
import dataclasses
import logging
import os
from importlib.resources import files
@@ -266,166 +265,3 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
card = ModelCard.from_template(card_data, template_str=template_card)
card.validate()
return card
def wrap_with_peft(
self,
peft_config=None,
peft_cli_overrides: dict | None = None,
) -> "PreTrainedPolicy":
"""
Wrap this policy with PEFT adapters for parameter-efficient fine-tuning.
This method is the single entry point for PEFT integration. Subclasses should
override `_get_default_peft_targets()` to provide default target modules, and
`_validate_peft_config()` for policy-specific validation.
Args:
peft_config: Optional PEFT adapter configuration (e.g., LoraConfig).
If provided, used directly (with CLI overrides applied).
peft_cli_overrides: Optional dict of CLI overrides (method_type, target_modules, r, etc.)
These are merged with policy defaults to build the final config.
"""
from peft import get_peft_model
# If user provided a complete config, use it directly (with overrides)
if peft_config is not None:
final_config = peft_config
if peft_cli_overrides:
final_config = self._apply_peft_cli_overrides(final_config, peft_cli_overrides)
else:
# Build config from defaults + CLI overrides
final_config = self._build_peft_config(peft_cli_overrides or {})
# Validate the configuration
self._validate_peft_config(final_config)
# Freeze base parameters, only adapter params will be trained
for p in self.parameters():
p.requires_grad_(False)
# Store pretrained path for PEFT's base_model_name_or_path
if self.config.pretrained_path:
self.name_or_path = str(self.config.pretrained_path)
# Wrap with PEFT
peft_model = get_peft_model(self, final_config)
# Mark config as using PEFT for proper loading later
peft_model.config.use_peft = True
logging.info(f"Wrapped {self.name} with PEFT ({type(final_config).__name__})")
return peft_model
def _get_default_peft_targets(self) -> dict[str, any] | None:
"""
Return default PEFT target modules for this policy.
Override this in subclasses to provide policy-specific defaults. These defaults
are PEFT-method agnostic - they only specify which modules to target.
"""
return None
def _validate_peft_config(self, peft_config) -> None:
"""
Validate the PEFT configuration for this policy.
Override this in subclasses to add policy-specific validation or warnings.
The default implementation checks that a pretrained_path exists.
Args:
peft_config: The PEFT configuration to validate.
Raises:
ValueError: If the configuration is invalid.
"""
if not self.config.pretrained_path:
raise ValueError(
"Training from scratch using PEFT is unlikely to yield good results. "
"Supply a `policy.pretrained_path` to fine-tune an existing model."
)
def _preprocess_peft_cli_overrides(self, cli_overrides: dict, peft_method_type) -> dict:
"""
Preprocess CLI overrides: rename keys and handle method-specific init_type.
Args:
cli_overrides: Dict of CLI options (will be copied, not mutated).
peft_method_type: The PeftType enum value for the PEFT method.
Returns:
Preprocessed dict with renamed keys and init_type mapped to method-specific key.
"""
from peft import PeftType
cli_overrides = cli_overrides.copy()
# Handle the full_training_modules -> modules_to_save rename
if "full_training_modules" in cli_overrides:
cli_overrides["modules_to_save"] = cli_overrides.pop("full_training_modules")
# Remove method_type as it's handled separately
cli_overrides.pop("method_type", None)
# Handle init_type specially based on PEFT method
init_type = cli_overrides.pop("init_type", None)
if init_type is not None:
if peft_method_type == PeftType.LORA:
cli_overrides["init_lora_weights"] = init_type
elif peft_method_type == PeftType.MISS:
cli_overrides["init_weights"] = init_type
else:
raise ValueError(f"Init type '{init_type}' unknown for PEFT method {peft_method_type}.")
return cli_overrides
def _build_peft_config(self, cli_overrides: dict):
"""Build a PEFT config from policy defaults and CLI overrides."""
from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType
# Determine PEFT method type (default to LORA)
method_type_str = cli_overrides.get("method_type") or "lora"
peft_method_type = PeftType[method_type_str.upper()]
peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type]
# Preprocess CLI overrides
cli_overrides = self._preprocess_peft_cli_overrides(cli_overrides, peft_method_type)
# Start with policy defaults, apply CLI overrides
config_dict = dict(self._get_default_peft_targets() or {})
for key, value in cli_overrides.items():
if value is not None:
config_dict[key] = value
# Ensure we have target_modules
if not config_dict.get("target_modules"):
raise ValueError(
f"Policy '{self.name}' does not define default target_modules. "
"Please pass --peft.target_modules explicitly."
)
return peft_config_cls(**config_dict)
def _apply_peft_cli_overrides(self, peft_config, cli_overrides: dict):
"""Apply CLI overrides to an existing PEFT config."""
from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType
# Get method type from existing config or CLI override
method_type_str = cli_overrides.get("method_type")
if method_type_str:
peft_method_type = PeftType[method_type_str.upper()]
peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type]
else:
peft_method_type = PeftType(peft_config.peft_type)
peft_config_cls = type(peft_config)
# Preprocess CLI overrides
cli_overrides = self._preprocess_peft_cli_overrides(cli_overrides, peft_method_type)
# Start with existing config, apply CLI overrides
config_dict = {k: v for k, v in dataclasses.asdict(peft_config).items() if not k.startswith("_")}
for key, value in cli_overrides.items():
if value is not None:
config_dict[key] = value
return peft_config_cls(**config_dict)
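Review note: the removed `_build_peft_config` path merges policy defaults with CLI overrides, renames `full_training_modules` to `modules_to_save`, ignores `None` values, and fails when no `target_modules` remain. The sketch below reproduces only that dictionary logic in plain Python, without importing `peft`. `merge_peft_options` is a hypothetical helper name, and the `init_type` mapping is deliberately left out.

```python
def merge_peft_options(defaults, cli_overrides):
    """Merge policy defaults with CLI overrides (dict logic only, no peft)."""
    overrides = dict(cli_overrides)  # copy so the caller's dict is not mutated

    # CLI name -> PEFT config name
    if "full_training_modules" in overrides:
        overrides["modules_to_save"] = overrides.pop("full_training_modules")

    # The method type selects the config class and is not a config field.
    overrides.pop("method_type", None)

    merged = dict(defaults or {})
    for key, value in overrides.items():
        if value is not None:  # unset CLI options keep the policy default
            merged[key] = value

    if not merged.get("target_modules"):
        raise ValueError("target_modules must come from defaults or overrides")
    return merged


merged = merge_peft_options(
    {"target_modules": "q_proj", "modules_to_save": []},
    {"method_type": "lora", "r": 8, "target_modules": None,
     "full_training_modules": ["head"]},
)
```

Here `merged` keeps the default `target_modules`, takes `r` from the CLI, and stores the renamed `modules_to_save` value.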
+1 -20
@@ -23,7 +23,7 @@ Based on:
from dataclasses import dataclass
from lerobot.configs.types import RTCAttentionSchedule, RTCTrainingDelayDistribution
from lerobot.configs.types import RTCAttentionSchedule
@dataclass
@@ -53,22 +53,3 @@ class RTCConfig:
raise ValueError(f"max_guidance_weight must be positive, got {self.max_guidance_weight}")
if self.debug_maxlen <= 0:
raise ValueError(f"debug_maxlen must be positive, got {self.debug_maxlen}")
@dataclass
class RTCTrainingConfig:
"""Configuration for training-time RTC action prefix conditioning."""
enabled: bool = False
min_delay: int = 0
max_delay: int = 0
delay_distribution: RTCTrainingDelayDistribution = RTCTrainingDelayDistribution.UNIFORM
exp_decay: float = 1.0
def __post_init__(self):
if self.min_delay < 0:
raise ValueError(f"min_delay must be >= 0, got {self.min_delay}")
if self.max_delay < self.min_delay:
raise ValueError(f"max_delay ({self.max_delay}) must be >= min_delay ({self.min_delay})")
if self.exp_decay <= 0:
raise ValueError(f"exp_decay must be positive, got {self.exp_decay}")
-110
@@ -1,110 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
from lerobot.configs.types import RTCTrainingDelayDistribution
from lerobot.policies.rtc.configuration_rtc import RTCTrainingConfig
def sample_rtc_delay(cfg: RTCTrainingConfig, batch_size: int, device: torch.device) -> torch.Tensor:
if cfg.max_delay == cfg.min_delay:
return torch.full((batch_size,), cfg.min_delay, device=device, dtype=torch.long)
if cfg.delay_distribution == RTCTrainingDelayDistribution.UNIFORM:
return torch.randint(cfg.min_delay, cfg.max_delay + 1, (batch_size,), device=device, dtype=torch.long)
delay_values = torch.arange(cfg.min_delay, cfg.max_delay + 1, device=device, dtype=torch.long)
weights = torch.exp(-cfg.exp_decay * delay_values.to(dtype=torch.float32))
probs = weights / weights.sum()
samples = torch.multinomial(probs, batch_size, replacement=True)
return delay_values[samples]
def apply_rtc_training_time(
time: torch.Tensor, delay: torch.Tensor, seq_len: int
) -> tuple[torch.Tensor, torch.Tensor]:
device = time.device
delay = torch.clamp(delay, max=seq_len)
prefix_mask = torch.arange(seq_len, device=device)[None, :] < delay[:, None]
time_tokens = time[:, None].expand(-1, seq_len)
time_tokens = time_tokens.masked_fill(prefix_mask, 0.0)
postfix_mask = ~prefix_mask
return time_tokens, postfix_mask
def masked_mean(
losses: torch.Tensor, mask: torch.Tensor | None, reduce_dims: tuple[int, ...], eps: float = 1e-8
) -> torch.Tensor:
if mask is None:
return losses.mean(dim=reduce_dims)
mask = mask.to(dtype=losses.dtype)
while mask.dim() < losses.dim():
mask = mask.unsqueeze(-1)
masked = losses * mask
denom = mask.sum(dim=reduce_dims).clamp_min(eps)
return masked.sum(dim=reduce_dims) / denom
def apply_training_time_rtc_inference(
x_t: torch.Tensor,
time: float,
inference_delay: int | None,
prev_chunk_left_over: torch.Tensor | None,
chunk_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Apply training-time RTC conditioning during inference.
Based on Algorithm 1 from "Training-Time Action Conditioning for Efficient Real-Time Chunking".
At each denoising step:
1. Replace prefix positions in x_t with ground truth from previous chunk
2. Create per-token timesteps with 1.0 for prefix positions
Args:
x_t: Current noisy actions (B, T, D)
time: Current flow matching timestep (scalar)
inference_delay: Number of prefix actions to condition on
prev_chunk_left_over: Previous chunk's leftover actions (B, T, D)
chunk_size: Total chunk size T
Returns:
x_t_conditioned: x_t with prefix replaced by previous actions
time_per_token: Per-token timesteps (B, T) with 1.0 for prefix
"""
batch_size = x_t.shape[0]
device = x_t.device
if inference_delay is None or inference_delay <= 0 or prev_chunk_left_over is None:
time_scalar = torch.full((batch_size,), time, device=device, dtype=torch.float32)
return x_t, time_scalar
delay = min(inference_delay, chunk_size)
prefix_mask = torch.arange(chunk_size, device=device)[None, :] < delay
x_t_conditioned = torch.where(
prefix_mask[:, :, None].expand_as(x_t),
prev_chunk_left_over[:, :chunk_size, :],
x_t,
)
time_per_token = torch.full((batch_size, chunk_size), time, device=device, dtype=torch.float32)
time_per_token = time_per_token.masked_fill(prefix_mask, 1.0)
return x_t_conditioned, time_per_token
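A minimal, dependency-free sketch of the training-side helpers above (delay weights, per-token timesteps, masked mean), written for a single sample with plain Python lists instead of tensors. The names `delay_probabilities`, `per_token_times`, and `masked_mean_1d` are illustrative and are not part of lerobot; the logic is only meant to mirror the tensor code shown in this file.

```python
import math


def delay_probabilities(min_delay: int, max_delay: int, exp_decay: float):
    """Normalized exp(-exp_decay * d) weights over the integer delays in [min_delay, max_delay]."""
    delays = list(range(min_delay, max_delay + 1))
    weights = [math.exp(-exp_decay * d) for d in delays]
    total = sum(weights)
    return delays, [w / total for w in weights]


def per_token_times(time: float, delay: int, seq_len: int):
    """Zero the timestep on the first `delay` tokens and return the postfix mask."""
    delay = min(delay, seq_len)  # same clamp as apply_rtc_training_time
    prefix = [i < delay for i in range(seq_len)]
    times = [0.0 if is_prefix else time for is_prefix in prefix]
    postfix = [not is_prefix for is_prefix in prefix]
    return times, postfix


def masked_mean_1d(losses, mask, eps: float = 1e-8) -> float:
    """Average only the entries where mask is True; eps guards an all-False mask."""
    total = sum(loss for loss, keep in zip(losses, mask) if keep)
    count = sum(1.0 for keep in mask if keep)
    return total / max(count, eps)
```

With `time=0.7`, `delay=2`, and `seq_len=5`, the first two tokens get timestep 0.0 and are excluded from the loss average, so only the last three positions contribute.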
@@ -20,7 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.policies.rtc.configuration_rtc import RTCConfig, RTCTrainingConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.constants import OBS_IMAGES
@@ -103,9 +103,8 @@ class SmolVLAConfig(PreTrainedConfig):
min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding
max_period: float = 4.0
# Real-Time Chunking (RTC) configurations
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
rtc_training_config: RTCTrainingConfig | None = None
def __post_init__(self):
super().__post_init__()
@@ -63,12 +63,6 @@ from typing_extensions import Unpack
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.policies.rtc.training_time import (
apply_rtc_training_time,
apply_training_time_rtc_inference,
masked_mean,
sample_rtc_delay,
)
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
from lerobot.policies.utils import (
@@ -91,8 +85,8 @@ def create_sinusoidal_pos_embedding(
if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim not in (1, 2):
raise ValueError("The time tensor is expected to be of shape `(batch_size,)` or `(batch_size, T)`.")
if time.ndim != 1:
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
dtype = get_safe_dtype(torch.float64, device.type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
@@ -100,14 +94,9 @@ def create_sinusoidal_pos_embedding(
# Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi
if time.ndim == 1:
sin_input = scaling_factor[None, :] * time[:, None]
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
time_flat = time.reshape(-1)
sin_input = scaling_factor[None, :] * time_flat[:, None]
sin_input = scaling_factor[None, :] * time[:, None]
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
return pos_emb.reshape(*time.shape, dimension)
return pos_emb
def make_att_2d_masks(pad_masks, att_masks):
@@ -386,16 +375,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
actions = self.prepare_action(batch)
postfix_mask = None
rtc_cfg = self.config.rtc_training_config
if rtc_cfg is not None and rtc_cfg.enabled and self.training:
batch_size = actions.shape[0]
if time is None:
time = self.model.sample_time(batch_size, actions.device)
if noise is None:
noise = self.model.sample_noise(actions.shape, actions.device)
delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device)
time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1])
actions_is_pad = batch.get("actions_id_pad")
loss_dict = {}
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
@@ -405,7 +384,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
in_episode_bound = ~actions_is_pad
losses = losses * in_episode_bound.unsqueeze(-1)
loss_dict["losses_after_in_ep_bound"] = losses.clone()
postfix_mask = in_episode_bound if postfix_mask is None else (postfix_mask & in_episode_bound)
# Remove padding
losses = losses[:, :, : self.config.max_action_dim]
@@ -413,12 +391,12 @@ class SmolVLAPolicy(PreTrainedPolicy):
if reduction == "none":
# Return per-sample losses (B,) by averaging over time and action dims
per_sample_loss = masked_mean(losses, postfix_mask, reduce_dims=(1, 2))
per_sample_loss = losses.mean(dim=(1, 2))
loss_dict["loss"] = per_sample_loss.mean().item()
return per_sample_loss, loss_dict
else:
# Default: return scalar mean loss
loss = masked_mean(losses, postfix_mask, reduce_dims=(0, 1, 2))
loss = losses.mean()
loss_dict["loss"] = loss.item()
return loss, loss_dict
@@ -502,28 +480,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
return actions
def _get_default_peft_targets(self) -> dict[str, any]:
"""Return default PEFT target modules for SmolVLA fine-tuning."""
common_projections = (
"state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
)
target_modules = rf"(model\.vlm_with_expert\.lm_expert\..*\.(q|v)_proj|model\.({common_projections}))"
return {
"target_modules": target_modules,
"modules_to_save": [],
}
def _validate_peft_config(self, peft_config) -> None:
"""Validate PEFT configuration for SmolVLA."""
super()._validate_peft_config(peft_config)
if not self.config.load_vlm_weights:
import logging
logging.warning(
"Training SmolVLA from scratch using PEFT. This is unlikely to yield good results. "
"Set `load_vlm_weights=True` to fine-tune the existing policy."
)
def pad_tensor(tensor, max_len, pad_value=0):
"""
@@ -618,9 +574,6 @@ class VLAFlowMatching(nn.Module):
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _training_time_rtc_inference_enabled(self):
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
def set_requires_grad(self):
for params in self.state_proj.parameters():
params.requires_grad = self.config.train_state_proj
@@ -756,10 +709,7 @@ class VLAFlowMatching(nn.Module):
)
time_emb = time_emb.type(dtype=dtype)
if time_emb.dim() == 2:
time_emb = time_emb[:, None, :].expand_as(action_emb)
elif time_emb.shape[:2] != action_emb.shape[:2]:
raise ValueError(f"Expected time_emb shape {action_emb.shape[:2]}, got {time_emb.shape[:2]}")
time_emb = time_emb[:, None, :].expand_as(action_emb)
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
action_time_emb = self.action_time_mlp_in(action_time_emb)
@@ -791,12 +741,7 @@ class VLAFlowMatching(nn.Module):
if time is None:
time = self.sample_time(actions.shape[0], actions.device)
if time.ndim == 1:
time_expanded = time[:, None, None]
elif time.ndim == 2:
time_expanded = time[:, :, None]
else:
raise ValueError(f"Expected time shape (B,) or (B, T), got {time.shape}")
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
@@ -859,35 +804,23 @@ class VLAFlowMatching(nn.Module):
num_steps = self.config.num_steps
dt = -1.0 / num_steps
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
use_training_time_rtc = self._training_time_rtc_inference_enabled()
x_t = noise
for step in range(num_steps):
time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
if use_training_time_rtc:
x_t_cond, time_tensor = apply_training_time_rtc_inference(
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
)
v_t = self.denoise_step(
x_t=x_t_cond,
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
x_t=input_x_t,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
timestep=time_tensor,
timestep=current_timestep,
)
elif self._rtc_enabled():
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
x_t=input_x_t,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
timestep=current_timestep,
)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
@@ -898,13 +831,7 @@ class VLAFlowMatching(nn.Module):
execution_horizon=execution_horizon,
)
else:
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
v_t = self.denoise_step(
x_t=x_t,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
timestep=time_tensor,
)
v_t = denoise_step_partial_call(x_t)
x_t = x_t + dt * v_t
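The sampling loop above is a fixed-step Euler integration from `time = 1.0` (pure noise) towards `time = 0.0` (actions). The sketch below reproduces that loop on plain Python lists; `euler_denoise` and `constant_velocity` are hypothetical helpers, and the constant velocity `noise - actions` stands in for the trained network (it is the training target `u_t` in the forward pass), so this is an idealized illustration rather than the model's behavior.

```python
def constant_velocity(noise, actions):
    """Oracle velocity field u_t = noise - actions, independent of x_t and time."""
    u_t = [n - a for n, a in zip(noise, actions)]

    def velocity_fn(x_t, time):
        return u_t

    return velocity_fn


def euler_denoise(noise, velocity_fn, num_steps: int):
    """Integrate x_t from time 1.0 down to 0.0 with step dt = -1 / num_steps."""
    dt = -1.0 / num_steps
    x_t = list(noise)
    for step in range(num_steps):
        time = 1.0 + step * dt
        v_t = velocity_fn(x_t, time)
        # Same update as in the loop above: x_t = x_t + dt * v_t
        x_t = [x + dt * v for x, v in zip(x_t, v_t)]
    return x_t
```

Under this oracle velocity the integration recovers the target actions up to floating-point error, regardless of `num_steps`.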
@@ -24,8 +24,7 @@ import numpy as np
import requests
from lerobot.processor import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
from .config_earthrover_mini_plus import EarthRoverMiniPlusConfig
@@ -100,7 +99,6 @@ class EarthRoverMiniPlus(Robot):
"""Check if robot is connected to SDK."""
return self._is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
"""Connect to robot via Frodobots SDK.
@@ -111,6 +109,8 @@ class EarthRoverMiniPlus(Robot):
DeviceAlreadyConnectedError: If robot is already connected
DeviceNotConnectedError: If cannot connect to SDK server
"""
if self._is_connected:
raise DeviceAlreadyConnectedError(f"{self.name} is already connected")
# Verify SDK is running and accessible
try:
@@ -197,7 +197,6 @@ class EarthRoverMiniPlus(Robot):
ACTION_ANGULAR_VEL: float,
}
@check_if_not_connected
def get_observation(self) -> RobotObservation:
"""Get current robot observation from SDK.
@@ -224,6 +223,8 @@ class EarthRoverMiniPlus(Robot):
Robot telemetry is retrieved from /data endpoint.
All SDK values are normalized to appropriate ranges for dataset recording.
"""
if not self._is_connected:
raise DeviceNotConnectedError(f"{self.name} is not connected")
observation = {}
@@ -254,7 +255,6 @@ class EarthRoverMiniPlus(Robot):
return observation
@check_if_not_connected
def send_action(self, action: RobotAction) -> RobotAction:
"""Send action to robot via SDK.
@@ -272,6 +272,8 @@ class EarthRoverMiniPlus(Robot):
Actions are sent to SDK via POST /control endpoint.
SDK expects commands in range [-1, 1].
"""
if not self._is_connected:
raise DeviceNotConnectedError(f"{self.name} is not connected")
# Extract action values and convert to float
linear = float(action.get(ACTION_LINEAR_VEL, 0.0))
@@ -289,7 +291,6 @@ class EarthRoverMiniPlus(Robot):
ACTION_ANGULAR_VEL: angular,
}
@check_if_not_connected
def disconnect(self) -> None:
"""Disconnect from robot.
@@ -298,6 +299,8 @@ class EarthRoverMiniPlus(Robot):
Raises:
DeviceNotConnectedError: If robot is not connected
"""
if not self._is_connected:
raise DeviceNotConnectedError(f"{self.name} is not connected")
# Stop the robot before disconnecting
try:
+12 -5
@@ -25,7 +25,7 @@ from lerobot.motors.feetech import (
FeetechMotorsBus,
)
from lerobot.processor import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
from ..utils import ensure_safe_goal_position
@@ -82,12 +82,13 @@ class HopeJrArm(Robot):
def is_connected(self) -> bool:
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
"""
We assume that at connection time, arm is in a rest position,
and torque can be safely disabled to run calibration.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
self.bus.connect(handshake=False)
if not self.is_calibrated and calibrate:
@@ -127,8 +128,10 @@ class HopeJrArm(Robot):
self.bus.setup_motor(motor)
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
@check_if_not_connected
def get_observation(self) -> RobotObservation:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Read arm position
start = time.perf_counter()
obs_dict = self.bus.sync_read("Present_Position", self.other_motors)
@@ -146,8 +149,10 @@ class HopeJrArm(Robot):
return obs_dict
@check_if_not_connected
def send_action(self, action: RobotAction) -> RobotAction:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
# Cap goal position when too far away from present position.
@@ -160,8 +165,10 @@ class HopeJrArm(Robot):
self.bus.sync_write("Goal_Position", goal_pos)
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
@check_if_not_connected
def disconnect(self):
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self.bus.disconnect(self.config.disable_torque_on_disconnect)
for cam in self.cameras.values():
cam.disconnect()
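The hunks in this file swap the `@check_if_already_connected` / `@check_if_not_connected` decorators for inline guards. The implementation in `lerobot.utils.decorators` is not shown in this diff, so the following is only a sketch of how such guard decorators could be written. The exception classes and `FakeArm` are local stand-ins defined here for illustration, not the lerobot ones.

```python
import functools


class DeviceAlreadyConnectedError(ConnectionError):
    """Local stand-in for the lerobot error of the same name."""


class DeviceNotConnectedError(ConnectionError):
    """Local stand-in for the lerobot error of the same name."""


def check_if_already_connected(method):
    """Raise before calling `method` if the device is already connected."""

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if self.is_connected:
            raise DeviceAlreadyConnectedError(f"{self} already connected")
        return method(self, *args, **kwargs)

    return wrapper


def check_if_not_connected(method):
    """Raise before calling `method` if the device is not connected."""

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if not self.is_connected:
            raise DeviceNotConnectedError(f"{self} is not connected.")
        return method(self, *args, **kwargs)

    return wrapper


class FakeArm:
    """Hypothetical device used only to exercise the decorators."""

    def __init__(self):
        self._connected = False

    @property
    def is_connected(self) -> bool:
        return self._connected

    @check_if_already_connected
    def connect(self) -> None:
        self._connected = True

    @check_if_not_connected
    def disconnect(self) -> None:
        self._connected = False
```

Both forms raise the same errors; the decorator keeps the guard in one place, while the inline check keeps each method self-contained.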
+13 -5
@@ -25,7 +25,7 @@ from lerobot.motors.feetech import (
FeetechMotorsBus,
)
from lerobot.processor import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
from .config_hope_jr import HopeJrHandConfig
@@ -118,8 +118,10 @@ class HopeJrHand(Robot):
def is_connected(self) -> bool:
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
self.bus.connect()
if not self.is_calibrated and calibrate:
self.calibrate()
@@ -157,8 +159,10 @@ class HopeJrHand(Robot):
self.bus.setup_motor(motor)
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
@check_if_not_connected
def get_observation(self) -> RobotObservation:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
obs_dict = {}
# Read hand position
@@ -177,14 +181,18 @@ class HopeJrHand(Robot):
return obs_dict
@check_if_not_connected
def send_action(self, action: RobotAction) -> RobotAction:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
self.bus.sync_write("Goal_Position", goal_pos)
return action
@check_if_not_connected
def disconnect(self):
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self.bus.disconnect(self.config.disable_torque_on_disconnect)
for cam in self.cameras.values():
cam.disconnect()
@@ -25,7 +25,7 @@ from lerobot.motors.dynamixel import (
OperatingMode,
)
from lerobot.processor import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
from ..utils import ensure_safe_goal_position
@@ -84,12 +84,13 @@ class KochFollower(Robot):
def is_connected(self) -> bool:
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
"""
We assume that at connection time, arm is in a rest position,
and torque can be safely disabled to run calibration.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
self.bus.connect()
if not self.is_calibrated and calibrate:
@@ -181,8 +182,10 @@ class KochFollower(Robot):
self.bus.setup_motor(motor)
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
@check_if_not_connected
def get_observation(self) -> RobotObservation:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Read arm position
start = time.perf_counter()
obs_dict = self.bus.sync_read("Present_Position")
@@ -199,7 +202,6 @@ class KochFollower(Robot):
return obs_dict
@check_if_not_connected
def send_action(self, action: RobotAction) -> RobotAction:
"""Command arm to move to a target joint configuration.
@@ -213,6 +215,8 @@ class KochFollower(Robot):
Returns:
RobotAction: The action sent to the motors, potentially clipped.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
@@ -227,8 +231,10 @@ class KochFollower(Robot):
self.bus.sync_write("Goal_Position", goal_pos)
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
@check_if_not_connected
def disconnect(self):
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self.bus.disconnect(self.config.disable_torque_on_disconnect)
for cam in self.cameras.values():
cam.disconnect()
+12 -5
@@ -29,7 +29,7 @@ from lerobot.motors.feetech import (
OperatingMode,
)
from lerobot.processor import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
from ..utils import ensure_safe_goal_position
@@ -109,8 +109,10 @@ class LeKiwi(Robot):
def is_connected(self) -> bool:
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
self.bus.connect()
if not self.is_calibrated and calibrate:
logger.info(
@@ -337,8 +339,10 @@ class LeKiwi(Robot):
"theta.vel": theta,
} # m/s and deg/s
@check_if_not_connected
def get_observation(self) -> RobotObservation:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Read actuators position for arm and vel for base
start = time.perf_counter()
arm_pos = self.bus.sync_read("Present_Position", self.arm_motors)
@@ -366,7 +370,6 @@ class LeKiwi(Robot):
return obs_dict
@check_if_not_connected
def send_action(self, action: RobotAction) -> RobotAction:
"""Command lekiwi to move to a target joint configuration.
@@ -380,6 +383,8 @@ class LeKiwi(Robot):
Returns:
RobotAction: the action sent to the motors, potentially clipped.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
arm_goal_pos = {k: v for k, v in action.items() if k.endswith(".pos")}
base_goal_vel = {k: v for k, v in action.items() if k.endswith(".vel")}
@@ -407,8 +412,10 @@ class LeKiwi(Robot):
self.bus.sync_write("Goal_Velocity", dict.fromkeys(self.base_motors, 0), num_retry=5)
logger.info("Base motors stopped")
@check_if_not_connected
def disconnect(self):
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self.stop_base()
self.bus.disconnect(self.config.disable_torque_on_disconnect)
for cam in self.cameras.values():
+16 -6
@@ -24,8 +24,7 @@ import numpy as np
from lerobot.processor import RobotAction, RobotObservation
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
from .config_lekiwi import LeKiwiClientConfig
@@ -113,10 +112,14 @@ class LeKiwiClient(Robot):
def is_calibrated(self) -> bool:
pass
@check_if_already_connected
def connect(self) -> None:
"""Establishes ZMQ sockets with the remote mobile robot"""
if self._is_connected:
raise DeviceAlreadyConnectedError(
"LeKiwi Daemon is already connected. Do not run `robot.connect()` twice."
)
zmq = self._zmq
self.zmq_context = zmq.Context()
self.zmq_cmd_socket = self.zmq_context.socket(zmq.PUSH)
@@ -249,13 +252,14 @@ class LeKiwiClient(Robot):
return new_frames, new_state
@check_if_not_connected
def get_observation(self) -> RobotObservation:
"""
Capture observations from the remote robot: current follower arm positions,
present wheel speeds (converted to body-frame velocities: x, y, theta),
and a camera frame. Receives data over ZMQ and translates it to body-frame velocities.
"""
if not self._is_connected:
raise DeviceNotConnectedError("LeKiwiClient is not connected. You need to run `robot.connect()`.")
frames, obs_dict = self._get_data()
@@ -303,7 +307,6 @@ class LeKiwiClient(Robot):
def configure(self):
pass
@check_if_not_connected
def send_action(self, action: RobotAction) -> RobotAction:
"""Command lekiwi to move to a target joint configuration. Translates to motor space + sends over ZMQ
@@ -315,6 +318,10 @@ class LeKiwiClient(Robot):
Returns:
RobotAction: the action sent to the motors, potentially clipped.
"""
if not self._is_connected:
raise DeviceNotConnectedError(
"LeKiwiClient is not connected. You need to run `robot.connect()`."
)
self.zmq_cmd_socket.send_string(json.dumps(action)) # action is in motor space
@@ -325,10 +332,13 @@ class LeKiwiClient(Robot):
action_sent[ACTION] = actions
return action_sent
@check_if_not_connected
def disconnect(self):
"""Cleans ZMQ comms"""
if not self._is_connected:
raise DeviceNotConnectedError(
"LeKiwi is not connected. You need to run `robot.connect()` before disconnecting."
)
self.zmq_observation_socket.close()
self.zmq_cmd_socket.close()
self.zmq_context.term()
@@ -26,7 +26,7 @@ from lerobot.motors.dynamixel import (
OperatingMode,
)
from lerobot.processor import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
from ..utils import ensure_safe_goal_position
@@ -84,7 +84,6 @@ class OmxFollower(Robot):
def is_connected(self) -> bool:
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
"""
For OMX robots that come pre-calibrated:
@@ -92,6 +91,8 @@ class OmxFollower(Robot):
- This allows using pre-calibrated robots without manual calibration
- If no calibration file exists, use factory default values (homing_offset=0, range_min=0, range_max=4095)
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
self.bus.connect()
if not self.is_calibrated and calibrate:
@@ -164,8 +165,10 @@ class OmxFollower(Robot):
self.bus.setup_motor(motor)
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
@check_if_not_connected
def get_observation(self) -> RobotObservation:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Read arm position
start = time.perf_counter()
obs_dict = self.bus.sync_read("Present_Position")
@@ -182,7 +185,6 @@ class OmxFollower(Robot):
return obs_dict
@check_if_not_connected
def send_action(self, action: RobotAction) -> RobotAction:
"""Command arm to move to a target joint configuration.
@@ -196,6 +198,8 @@ class OmxFollower(Robot):
Returns:
RobotAction: The action sent to the motors, potentially clipped.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
@@ -210,8 +214,10 @@ class OmxFollower(Robot):
self.bus.sync_write("Goal_Position", goal_pos)
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
@check_if_not_connected
def disconnect(self):
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self.bus.disconnect(self.config.disable_torque_on_disconnect)
for cam in self.cameras.values():
cam.disconnect()
-26
@@ -58,32 +58,6 @@ class Robot(abc.ABC):
def __str__(self) -> str:
return f"{self.id} {self.__class__.__name__}"
def __enter__(self):
"""
Context manager entry.
Automatically connects to the robot.
"""
self.connect()
return self
def __exit__(self, exc_type, exc_value, traceback) -> None:
"""
Context manager exit.
Automatically disconnects, ensuring resources are released even on error.
"""
self.disconnect()
def __del__(self) -> None:
"""
Destructor safety net.
Attempts to disconnect if the object is garbage collected without cleanup.
"""
try:
if self.is_connected:
self.disconnect()
except Exception: # nosec B110
pass
# TODO(aliberts): create a proper Feature class for this that links with datasets
@property
@abc.abstractmethod
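The `__enter__` / `__exit__` pair in this hunk is what allows a robot to be used in a `with` block. A small sketch of that protocol, using a hypothetical `FakeRobot` that only records calls (it is not a lerobot class and talks to no hardware):

```python
class FakeRobot:
    """Hypothetical stand-in that records connect/disconnect calls."""

    def __init__(self):
        self.is_connected = False
        self.calls = []

    def connect(self) -> None:
        self.is_connected = True
        self.calls.append("connect")

    def disconnect(self) -> None:
        self.is_connected = False
        self.calls.append("disconnect")

    def __enter__(self):
        # Entering the `with` block connects and hands back the instance.
        self.connect()
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Runs on normal exit and on error; returning None lets exceptions propagate.
        self.disconnect()


def run_with_failure(robot: FakeRobot) -> bool:
    """Return True if an error inside the block still propagates to the caller."""
    try:
        with robot:
            raise RuntimeError("simulated failure inside the block")
    except RuntimeError:
        return True
    return False
```

Even when the body of the `with` block raises, `disconnect()` is still called before the exception reaches the caller.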
+11 -5
@@ -26,7 +26,7 @@ from lerobot.motors.feetech import (
OperatingMode,
)
from lerobot.processor import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
from ..utils import ensure_safe_goal_position
@@ -85,12 +85,13 @@ class SOFollower(Robot):
def is_connected(self) -> bool:
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
"""
We assume that at connection time, arm is in a rest position,
and torque can be safely disabled to run calibration.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
self.bus.connect()
if not self.is_calibrated and calibrate:
@@ -175,8 +176,10 @@ class SOFollower(Robot):
self.bus.setup_motor(motor)
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
@check_if_not_connected
def get_observation(self) -> RobotObservation:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Read arm position
start = time.perf_counter()
obs_dict = self.bus.sync_read("Present_Position")
@@ -193,7 +196,6 @@ class SOFollower(Robot):
return obs_dict
@check_if_not_connected
def send_action(self, action: RobotAction) -> RobotAction:
"""Command arm to move to a target joint configuration.
@@ -207,6 +209,8 @@ class SOFollower(Robot):
Returns:
RobotAction: the action sent to the motors, potentially clipped.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
@@ -221,8 +225,10 @@ class SOFollower(Robot):
self.bus.sync_write("Goal_Position", goal_pos)
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
@check_if_not_connected
def disconnect(self):
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self.bus.disconnect(self.config.disable_torque_on_disconnect)
for cam in self.cameras.values():
cam.disconnect()
+379 -32
@@ -66,23 +66,23 @@ Remove camera feature:
--operation.type remove_feature \
--operation.feature_names "['observation.images.top']"
Convert image dataset to video format and save locally:
Convert image dataset to video format (saves locally):
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_image_to_video \
--operation.type convert_to_video \
--operation.output_dir /path/to/output/pusht_video
Convert image dataset to video format and save with new repo_id:
Convert image dataset and save with new repo_id:
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht_image \
--new_repo_id lerobot/pusht_video \
--operation.type convert_image_to_video
--operation.type convert_to_video
Convert image dataset to video format and push to hub:
Convert and push to hub:
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht_image \
--new_repo_id lerobot/pusht_video \
--operation.type convert_image_to_video \
--operation.type convert_to_video \
--push_to_hub true
Using JSON config file:
@@ -92,19 +92,24 @@ Using JSON config file:
import logging
import shutil
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
import pandas as pd
from tqdm import tqdm
from lerobot.configs import parser
from lerobot.datasets.dataset_tools import (
convert_image_to_video_dataset,
delete_episodes,
merge_datasets,
remove_feature,
split_dataset,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import HF_LEROBOT_HOME
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import write_stats, write_tasks
from lerobot.datasets.video_utils import encode_video_frames, get_video_info
from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE
from lerobot.utils.utils import init_logging
@@ -133,8 +138,8 @@ class RemoveFeatureConfig:
@dataclass
class ConvertImageToVideoConfig:
type: str = "convert_image_to_video"
class ConvertToVideoConfig:
type: str = "convert_to_video"
output_dir: str | None = None
vcodec: str = "libsvtav1"
pix_fmt: str = "yuv420p"
@@ -143,16 +148,12 @@ class ConvertImageToVideoConfig:
fast_decode: int = 0
episode_indices: list[int] | None = None
num_workers: int = 4
max_episodes_per_batch: int | None = None
max_frames_per_batch: int | None = None
@dataclass
class EditDatasetConfig:
repo_id: str
operation: (
DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertImageToVideoConfig
)
operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertToVideoConfig
root: str | None = None
new_repo_id: str | None = None
push_to_hub: bool = False
@@ -296,7 +297,362 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
def save_episode_images_for_video(
dataset: LeRobotDataset,
imgs_dir: Path,
img_key: str,
episode_index: int,
num_workers: int = 4,
) -> None:
"""Save images from a specific episode and camera to disk for video encoding.
Args:
dataset: The LeRobot dataset to extract images from
imgs_dir: Directory to save images to
img_key: The image key (camera) to extract
episode_index: Index of the episode to save
num_workers: Number of threads for parallel image saving
"""
# Create directory
imgs_dir.mkdir(parents=True, exist_ok=True)
# Get dataset without torch format for PIL image access
hf_dataset = dataset.hf_dataset.with_format(None)
# Select only this camera's images
imgs_dataset = hf_dataset.select_columns(img_key)
# Get episode start and end indices
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
# Get all items for this episode
episode_dataset = imgs_dataset.select(range(from_idx, to_idx))
# Define function to save a single image
def save_single_image(i_item_tuple):
i, item = i_item_tuple
img = item[img_key]
# Use frame-XXXXXX.png format to match encode_video_frames expectations
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
return i
# Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png)
items = list(enumerate(episode_dataset))
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = [executor.submit(save_single_image, item) for item in items]
for future in as_completed(futures):
future.result() # This will raise any exceptions that occurred
def encode_episode_videos(
dataset: LeRobotDataset,
new_meta: LeRobotDatasetMetadata,
episode_index: int,
vcodec: str,
pix_fmt: str,
g: int,
crf: int,
fast_decode: int,
temp_dir: Path,
num_image_workers: int = 4,
) -> dict[str, dict]:
"""Encode videos for a single episode and return video metadata.
Args:
dataset: Source dataset with images
new_meta: Metadata object for the new video dataset
episode_index: Episode index to process
vcodec: Video codec
pix_fmt: Pixel format
g: Group of pictures size
crf: Constant rate factor
fast_decode: Fast decode tuning
temp_dir: Temporary directory for images
num_image_workers: Number of workers for saving images
Returns:
Dictionary mapping video keys to their metadata (chunk_index, file_index, timestamps)
"""
hf_dataset = dataset.hf_dataset.with_format(None)
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
video_metadata = {}
fps = int(dataset.fps) # Convert to int for PyAV compatibility
episode_length = dataset.meta.episodes["length"][episode_index]
episode_duration = episode_length / dataset.fps # Use original fps for duration calculation
for img_key in img_keys:
# Save images temporarily
imgs_dir = temp_dir / f"episode_{episode_index:06d}" / img_key
save_episode_images_for_video(dataset, imgs_dir, img_key, episode_index, num_image_workers)
# Determine chunk and file indices
# For simplicity, we'll put each episode in its own file
chunk_idx = episode_index // new_meta.chunks_size
file_idx = episode_index % new_meta.chunks_size
# Create video path in the new dataset structure
video_path = new_meta.root / new_meta.video_path.format(
video_key=img_key, chunk_index=chunk_idx, file_index=file_idx
)
video_path.parent.mkdir(parents=True, exist_ok=True)
# Encode video
encode_video_frames(
imgs_dir=imgs_dir,
video_path=video_path,
fps=fps,
vcodec=vcodec,
pix_fmt=pix_fmt,
g=g,
crf=crf,
fast_decode=fast_decode,
overwrite=True,
)
# Clean up temporary images
shutil.rmtree(imgs_dir)
# Store video metadata
video_metadata[img_key] = {
f"videos/{img_key}/chunk_index": chunk_idx,
f"videos/{img_key}/file_index": file_idx,
f"videos/{img_key}/from_timestamp": 0.0,
f"videos/{img_key}/to_timestamp": episode_duration,
}
return video_metadata
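The chunk and file placement computed in `encode_episode_videos` reduces to the quotient and remainder of the episode index by `chunks_size`. A minimal sketch, assuming one video file per episode as the comment in the function states; `locate_episode_file` is a hypothetical helper name, and the `chunks_size` of 1000 is an illustrative value only:

```python
def locate_episode_file(episode_index: int, chunks_size: int) -> tuple[int, int]:
    """Return (chunk_index, file_index) for an episode stored in its own file."""
    if chunks_size <= 0:
        raise ValueError("chunks_size must be positive")
    # divmod is equivalent to (episode_index // chunks_size, episode_index % chunks_size)
    return divmod(episode_index, chunks_size)


# Illustrative value only; the real size would come from the dataset metadata.
placements = [locate_episode_file(i, chunks_size=1000) for i in (0, 999, 1000, 2345)]
```

Under this scheme each chunk directory holds at most `chunks_size` files, so the file index wraps back to zero whenever a new chunk starts.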
def convert_dataset_to_videos(
dataset: LeRobotDataset,
output_dir: Path,
repo_id: str | None = None,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
g: int = 2,
crf: int = 30,
fast_decode: int = 0,
episode_indices: list[int] | None = None,
num_workers: int = 4,
) -> LeRobotDataset:
"""Convert image-based dataset to video-based dataset.
Creates a new LeRobotDataset with videos instead of images, following the proper
LeRobot dataset structure with videos stored in chunked MP4 files.
Args:
dataset: The source LeRobot dataset with images
output_dir: Directory to save the new video dataset
repo_id: Repository ID for the new dataset (default: original_id + "_video")
vcodec: Video codec (default: libsvtav1)
pix_fmt: Pixel format (default: yuv420p)
g: Group of pictures size (default: 2)
crf: Constant rate factor (default: 30)
fast_decode: Fast decode tuning (default: 0)
episode_indices: List of episode indices to convert (None = all episodes)
num_workers: Number of threads for parallel processing (default: 4)
Returns:
New LeRobotDataset with videos
"""
# Check that it's an image dataset
if len(dataset.meta.video_keys) > 0:
raise ValueError(
f"This operation is for image datasets only. Video dataset provided: {dataset.repo_id}"
)
# Get all image keys
hf_dataset = dataset.hf_dataset.with_format(None)
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
if len(img_keys) == 0:
raise ValueError(f"No image keys found in dataset {dataset.repo_id}")
# Determine which episodes to process
if episode_indices is None:
episode_indices = list(range(dataset.meta.total_episodes))
if repo_id is None:
repo_id = f"{dataset.repo_id}_video"
logging.info(
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
)
logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}")
# Create new features dict, converting image features to video features
new_features = {}
for key, value in dataset.meta.features.items():
if key not in img_keys:
new_features[key] = value
else:
# Convert image key to video format
new_features[key] = value.copy()
new_features[key]["dtype"] = "video" # Change dtype from "image" to "video"
# Video info will be updated after episodes are encoded
# Create new metadata for video dataset
new_meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
fps=dataset.meta.fps,
features=new_features,
robot_type=dataset.meta.robot_type,
root=output_dir,
use_videos=True,
chunks_size=dataset.meta.chunks_size,
data_files_size_in_mb=dataset.meta.data_files_size_in_mb,
video_files_size_in_mb=dataset.meta.video_files_size_in_mb,
)
# Create temporary directory for image extraction
temp_dir = output_dir / "temp_images"
temp_dir.mkdir(parents=True, exist_ok=True)
# Process each episode
all_episode_metadata = []
try:
for ep_idx in tqdm(episode_indices, desc="Converting episodes to videos"):
# Get episode metadata from source
src_episode = dataset.meta.episodes[ep_idx]
# Encode videos for this episode
video_metadata = encode_episode_videos(
dataset=dataset,
new_meta=new_meta,
episode_index=ep_idx,
vcodec=vcodec,
pix_fmt=pix_fmt,
g=g,
crf=crf,
fast_decode=fast_decode,
temp_dir=temp_dir,
num_image_workers=num_workers,
)
# Build episode metadata
episode_meta = {
"episode_index": ep_idx,
"length": src_episode["length"],
# Reuse the source frame range: ep_idx * length only holds when every episode has the same length
"dataset_from_index": src_episode["dataset_from_index"],
"dataset_to_index": src_episode["dataset_to_index"],
}
# Add video metadata
for img_key in img_keys:
episode_meta.update(video_metadata[img_key])
# Add data chunk/file info (using same structure as source)
if "data/chunk_index" in src_episode:
episode_meta["data/chunk_index"] = src_episode["data/chunk_index"]
episode_meta["data/file_index"] = src_episode["data/file_index"]
all_episode_metadata.append(episode_meta)
# Copy and transform data files (removing image columns)
_copy_data_without_images(dataset, new_meta, episode_indices, img_keys)
# Save episode metadata
episodes_df = pd.DataFrame(all_episode_metadata)
episodes_path = new_meta.root / "meta" / "episodes" / "chunk-000" / "file-000.parquet"
episodes_path.parent.mkdir(parents=True, exist_ok=True)
episodes_df.to_parquet(episodes_path, index=False)
# Update metadata info
new_meta.info["total_episodes"] = len(episode_indices)
new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata)
new_meta.info["total_tasks"] = dataset.meta.total_tasks
new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"}
# Update video info for all image keys (now videos)
# We need to manually set video info since update_video_info() checks video_keys first
for img_key in img_keys:
if not new_meta.features[img_key].get("info", None):
video_path = new_meta.root / new_meta.video_path.format(
video_key=img_key, chunk_index=0, file_index=0
)
new_meta.info["features"][img_key]["info"] = get_video_info(video_path)
from lerobot.datasets.utils import write_info
write_info(new_meta.info, new_meta.root)
# Copy stats and tasks
if dataset.meta.stats is not None:
# Remove image stats
new_stats = {k: v for k, v in dataset.meta.stats.items() if k not in img_keys}
write_stats(new_stats, new_meta.root)
if dataset.meta.tasks is not None:
write_tasks(dataset.meta.tasks, new_meta.root)
finally:
# Clean up temporary directory
if temp_dir.exists():
shutil.rmtree(temp_dir)
logging.info(f"✓ Completed converting {dataset.repo_id} to video format")
logging.info(f"New dataset saved to: {output_dir}")
# Return new dataset
return LeRobotDataset(repo_id=repo_id, root=output_dir)
def _copy_data_without_images(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
episode_indices: list[int],
img_keys: list[str],
) -> None:
"""Copy data files without image columns.
Args:
src_dataset: Source dataset
dst_meta: Destination metadata
episode_indices: Episodes to include
img_keys: Image keys to remove
"""
from lerobot.datasets.utils import DATA_DIR
data_dir = src_dataset.root / DATA_DIR
parquet_files = sorted(data_dir.glob("*/*.parquet"))
if not parquet_files:
raise ValueError(f"No parquet files found in {data_dir}")
episode_set = set(episode_indices)
for src_path in tqdm(parquet_files, desc="Processing data files"):
df = pd.read_parquet(src_path).reset_index(drop=True)
# Filter to only include selected episodes
df = df[df["episode_index"].isin(episode_set)].copy()
if len(df) == 0:
continue
# Remove image columns
columns_to_drop = [col for col in img_keys if col in df.columns]
if columns_to_drop:
df = df.drop(columns=columns_to_drop)
# Get chunk and file indices from path
relative_path = src_path.relative_to(src_dataset.root)
chunk_dir = relative_path.parts[1]
file_name = relative_path.parts[2]
chunk_idx = int(chunk_dir.split("-")[1])
file_idx = int(file_name.split("-")[1].split(".")[0])
# Write to destination without pandas index
dst_path = dst_meta.root / f"data/chunk-{chunk_idx:03d}/file-{file_idx:03d}.parquet"
dst_path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(dst_path, index=False)
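`_copy_data_without_images` recovers the chunk and file indices from the source path and rebuilds the destination path from them. A small sketch of that round trip, assuming the `data/chunk-XXX/file-YYY.parquet` layout used in the function; both helper names are hypothetical:

```python
from pathlib import PurePosixPath


def parse_chunk_and_file_index(relative_path: PurePosixPath) -> tuple[int, int]:
    """Extract (chunk_idx, file_idx) from a path shaped like data/chunk-XXX/file-YYY.parquet."""
    _, chunk_dir, file_name = relative_path.parts
    chunk_idx = int(chunk_dir.split("-")[1])
    # .stem drops the ".parquet" suffix before splitting on the dash
    file_idx = int(PurePosixPath(file_name).stem.split("-")[1])
    return chunk_idx, file_idx


def build_destination(chunk_idx: int, file_idx: int) -> str:
    """Rebuild the relative path with the same zero-padded naming."""
    return f"data/chunk-{chunk_idx:03d}/file-{file_idx:03d}.parquet"


src = PurePosixPath("data/chunk-002/file-015.parquet")
indices = parse_chunk_and_file_index(src)
round_trip = build_destination(*indices)
```

Because the indices are parsed and then re-formatted with the same padding, the destination mirrors the source layout exactly.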
def handle_convert_to_video(cfg: EditDatasetConfig) -> None:
# Note: Parser may create any config type with the right fields, so we access fields directly
# instead of checking isinstance()
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
@@ -308,12 +664,8 @@ def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
if cfg.new_repo_id:
# Use new_repo_id for both local storage and hub push
output_repo_id = cfg.new_repo_id
# Place new dataset as a sibling to the original dataset
# Get the parent of the actual dataset root (not cfg.root which might be the lerobot cache dir)
# Extract just the dataset name (after last slash) for the local directory
local_dir_name = cfg.new_repo_id.split("/")[-1]
output_dir = dataset.root.parent / local_dir_name
logging.info(f"Saving to new dataset: {cfg.new_repo_id} at {output_dir}")
output_dir = Path(cfg.root) / cfg.new_repo_id if cfg.root else HF_LEROBOT_HOME / cfg.new_repo_id
logging.info(f"Saving to new dataset: {cfg.new_repo_id}")
elif output_dir_config:
# Use custom output directory for local-only storage
output_dir = Path(output_dir_config)
@@ -323,15 +675,12 @@ def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
else:
# Auto-generate name: append "_video" to original repo_id
output_repo_id = f"{cfg.repo_id}_video"
# Place new dataset as a sibling to the original dataset
# Extract just the dataset name (after last slash) for the local directory
local_dir_name = output_repo_id.split("/")[-1]
output_dir = dataset.root.parent / local_dir_name
output_dir = Path(cfg.root) / output_repo_id if cfg.root else HF_LEROBOT_HOME / output_repo_id
logging.info(f"Saving to auto-generated location: {output_dir}")
logging.info(f"Converting dataset {cfg.repo_id} to video format")
new_dataset = convert_image_to_video_dataset(
new_dataset = convert_dataset_to_videos(
dataset=dataset,
output_dir=output_dir,
repo_id=output_repo_id,
@@ -342,8 +691,6 @@ def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
fast_decode=getattr(cfg.operation, "fast_decode", 0),
episode_indices=getattr(cfg.operation, "episode_indices", None),
num_workers=getattr(cfg.operation, "num_workers", 4),
max_episodes_per_batch=getattr(cfg.operation, "max_episodes_per_batch", None),
max_frames_per_batch=getattr(cfg.operation, "max_frames_per_batch", None),
)
logging.info("Video dataset created successfully!")
@@ -371,8 +718,8 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
handle_merge(cfg)
elif operation_type == "remove_feature":
handle_remove_feature(cfg)
elif operation_type == "convert_image_to_video":
handle_convert_image_to_video(cfg)
elif operation_type == "convert_to_video":
handle_convert_to_video(cfg)
else:
raise ValueError(
f"Unknown operation type: {operation_type}\n"
-360
View File
@@ -1,360 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Set up and debug CAN interfaces for Damiao motors (e.g., OpenArms).
Examples:
Set up CAN interfaces with CAN FD:
```shell
lerobot-setup-can --mode=setup --interfaces=can0,can1,can2,can3
```
Test motors on a single interface:
```shell
lerobot-setup-can --mode=test --interfaces=can0
```
Test motors on all interfaces:
```shell
lerobot-setup-can --mode=test --interfaces=can0,can1,can2,can3
```
Speed test:
```shell
lerobot-setup-can --mode=speed --interfaces=can0
```
"""
import subprocess
import sys
import time
from dataclasses import dataclass, field
import draccus
from lerobot.utils.import_utils import is_package_available
MOTOR_NAMES = {
0x01: "joint_1",
0x02: "joint_2",
0x03: "joint_3",
0x04: "joint_4",
0x05: "joint_5",
0x06: "joint_6",
0x07: "joint_7",
0x08: "gripper",
}
@dataclass
class CANSetupConfig:
mode: str = "test"
interfaces: str = "can0" # Comma-separated, e.g. "can0,can1,can2,can3"
bitrate: int = 1000000
data_bitrate: int = 5000000
use_fd: bool = True
motor_ids: list[int] = field(default_factory=lambda: list(range(0x01, 0x09)))
timeout: float = 1.0
speed_iterations: int = 100
def get_interfaces(self) -> list[str]:
return [i.strip() for i in self.interfaces.split(",") if i.strip()]
def check_interface_status(interface: str) -> tuple[bool, str, bool]:
"""Check if CAN interface is UP and configured."""
try:
result = subprocess.run(["ip", "link", "show", interface], capture_output=True, text=True) # nosec B607
if result.returncode != 0:
return False, "Interface not found", False
output = result.stdout
is_up = "UP" in output
is_fd = "fd on" in output.lower() or "canfd" in output.lower()
status = "UP" if is_up else "DOWN"
if is_fd:
status += " (CAN FD)"
return is_up, status, is_fd
except FileNotFoundError:
return False, "ip command not found", False
def setup_interface(interface: str, bitrate: int, data_bitrate: int, use_fd: bool) -> bool:
"""Configure a CAN interface."""
try:
subprocess.run(["sudo", "ip", "link", "set", interface, "down"], check=False, capture_output=True) # nosec B607
cmd = ["sudo", "ip", "link", "set", interface, "type", "can", "bitrate", str(bitrate)]
if use_fd:
cmd.extend(["dbitrate", str(data_bitrate), "fd", "on"])
result = subprocess.run(cmd, capture_output=True, text=True) # nosec B607
if result.returncode != 0:
print(f" ✗ Failed to configure: {result.stderr}")
return False
result = subprocess.run( # nosec B607
["sudo", "ip", "link", "set", interface, "up"], capture_output=True, text=True
)
if result.returncode != 0:
print(f" ✗ Failed to bring up: {result.stderr}")
return False
return True
except Exception as e:
print(f" ✗ Error: {e}")
return False
def test_motor(bus, motor_id: int, timeout: float, use_fd: bool):
"""Test a single motor and return responses."""
import can
enable_msg = can.Message(
arbitration_id=motor_id,
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC],
is_extended_id=False,
is_fd=use_fd,
)
try:
bus.send(enable_msg)
except Exception as e:
return None, f"Send error: {e}"
responses = []
start_time = time.time()
while time.time() - start_time < timeout:
msg = bus.recv(timeout=0.1)
if msg:
responses.append((msg.arbitration_id, msg.data.hex(), getattr(msg, "is_fd", False)))
disable_msg = can.Message(
arbitration_id=motor_id,
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFD],
is_extended_id=False,
is_fd=use_fd,
)
try:
bus.send(disable_msg)
except Exception:
print(f"Error sending message to motor 0x{motor_id:02X}")
return responses, None
def test_interface(cfg: CANSetupConfig, interface: str):
"""Test all motors on a CAN interface."""
import can
is_up, status, _ = check_interface_status(interface)
print(f"\n{interface}: {status}")
if not is_up:
print(f" ⚠ Interface is not UP. Run: lerobot-setup-can --mode=setup --interfaces {interface}")
return {}
try:
kwargs = {"channel": interface, "interface": "socketcan", "bitrate": cfg.bitrate}
if cfg.use_fd:
kwargs.update({"data_bitrate": cfg.data_bitrate, "fd": True})
bus = can.interface.Bus(**kwargs)
except Exception as e:
print(f" ✗ Connection failed: {e}")
return {}
results = {}
try:
while bus.recv(timeout=0.01):
pass
for motor_id in cfg.motor_ids:
motor_name = MOTOR_NAMES.get(motor_id, f"motor_0x{motor_id:02X}")
responses, error = test_motor(bus, motor_id, cfg.timeout, cfg.use_fd)
if error:
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ {error}")
results[motor_id] = {"found": False, "error": error}
elif responses:
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✓ FOUND")
for resp_id, data, is_fd in responses:
fd_flag = " [FD]" if is_fd else ""
print(f" → Response 0x{resp_id:02X}{fd_flag}: {data}")
results[motor_id] = {"found": True, "responses": responses}
else:
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ No response")
results[motor_id] = {"found": False}
time.sleep(0.05)
finally:
bus.shutdown()
found = sum(1 for r in results.values() if r.get("found"))
print(f"\n Summary: {found}/{len(cfg.motor_ids)} motors found")
return results
def speed_test(cfg: CANSetupConfig, interface: str):
"""Test communication speed with motors."""
import can
is_up, status, _ = check_interface_status(interface)
if not is_up:
print(f"{interface}: {status} - skipping")
return
print(f"\n{interface}: Running speed test ({cfg.speed_iterations} iterations)...")
try:
kwargs = {"channel": interface, "interface": "socketcan", "bitrate": cfg.bitrate}
if cfg.use_fd:
kwargs.update({"data_bitrate": cfg.data_bitrate, "fd": True})
bus = can.interface.Bus(**kwargs)
except Exception as e:
print(f" ✗ Connection failed: {e}")
return
responding_motor = None
for motor_id in cfg.motor_ids:
responses, _ = test_motor(bus, motor_id, 0.5, cfg.use_fd)
if responses:
responding_motor = motor_id
break
if not responding_motor:
print(" ✗ No responding motors found")
bus.shutdown()
return
print(f" Testing with motor 0x{responding_motor:02X}...")
latencies = []
for _ in range(cfg.speed_iterations):
start = time.perf_counter()
msg = can.Message(
arbitration_id=responding_motor,
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC],
is_extended_id=False,
is_fd=cfg.use_fd,
)
bus.send(msg)
resp = bus.recv(timeout=0.1)
if resp:
latencies.append((time.perf_counter() - start) * 1000)
bus.shutdown()
if latencies:
avg_latency = sum(latencies) / len(latencies)
hz = 1000.0 / avg_latency if avg_latency > 0 else 0
print(f" ✓ Success rate: {len(latencies)}/{cfg.speed_iterations}")
print(f" ✓ Avg latency: {avg_latency:.2f} ms")
print(f" ✓ Max frequency: {hz:.1f} Hz")
else:
print(" ✗ No successful responses")
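The summary printed by `speed_test` is derived from the round-trip latencies alone. A minimal sketch of that arithmetic, separated from the CAN bus so it can be checked in isolation; `summarize_latencies` is a hypothetical helper, not part of the script:

```python
from typing import Optional


def summarize_latencies(latencies_ms: list[float], iterations: int) -> Optional[dict[str, float]]:
    """Summarize round-trip latencies (in milliseconds) the way the speed test reports them."""
    if not latencies_ms:
        return None  # corresponds to the "No successful responses" branch
    avg_latency = sum(latencies_ms) / len(latencies_ms)
    return {
        "success_rate": len(latencies_ms) / iterations,
        "avg_latency_ms": avg_latency,
        # One request/response cycle every avg_latency ms gives 1000 / avg_latency cycles per second
        "max_frequency_hz": 1000.0 / avg_latency if avg_latency > 0 else 0.0,
    }


summary = summarize_latencies([2.0, 4.0, 6.0], iterations=4)
```

The reported frequency is an upper bound for a single blocking request/response loop. It ignores iterations that timed out, which only lower the success rate.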
def run_setup(cfg: CANSetupConfig):
"""Set up CAN interfaces."""
print("=" * 50)
print("CAN Interface Setup")
print("=" * 50)
print(f"Mode: {'CAN FD' if cfg.use_fd else 'CAN 2.0'}")
print(f"Bitrate: {cfg.bitrate / 1_000_000:.1f} Mbps")
if cfg.use_fd:
print(f"Data bitrate: {cfg.data_bitrate / 1_000_000:.1f} Mbps")
print()
interfaces = cfg.get_interfaces()
for interface in interfaces:
print(f"Configuring {interface}...")
if setup_interface(interface, cfg.bitrate, cfg.data_bitrate, cfg.use_fd):
is_up, status, _ = check_interface_status(interface)
print(f"{interface}: {status}")
else:
print(f"{interface}: Failed")
print("\nSetup complete!")
print("\nNext: Test motors with:")
print(f" lerobot-setup-can --mode=test --interfaces {','.join(interfaces)}")
def run_test(cfg: CANSetupConfig):
"""Test motors on CAN interfaces."""
print("=" * 50)
print("CAN Motor Test")
print("=" * 50)
print(f"Testing motors 0x{min(cfg.motor_ids):02X}-0x{max(cfg.motor_ids):02X}")
print(f"Mode: {'CAN FD' if cfg.use_fd else 'CAN 2.0'}")
print()
interfaces = cfg.get_interfaces()
all_results = {}
for interface in interfaces:
all_results[interface] = test_interface(cfg, interface)
total_found = sum(sum(1 for r in res.values() if r.get("found")) for res in all_results.values())
print("\n" + "=" * 50)
print("Summary")
print("=" * 50)
print(f"Total motors found: {total_found}")
if total_found == 0:
print("\n⚠ No motors found! Check:")
print(" 1. Motors are powered (24V)")
print(" 2. CAN wiring (CANH, CANL, GND)")
print(" 3. Motor timeout parameter > 0 (use Damiao tools)")
print(" 4. 120Ω termination at both cable ends")
print(f" 5. Interface configured: lerobot-setup-can --mode=setup --interfaces {interfaces[0]}")
def run_speed(cfg: CANSetupConfig):
"""Run speed tests on CAN interfaces."""
print("=" * 50)
print("CAN Speed Test")
print("=" * 50)
for interface in cfg.get_interfaces():
speed_test(cfg, interface)
@draccus.wrap()
def setup_can(cfg: CANSetupConfig):
if not is_package_available("can"):
print("Error: python-can not installed. Install with: pip install python-can")
sys.exit(1)
if cfg.mode == "setup":
run_setup(cfg)
elif cfg.mode == "test":
run_test(cfg)
elif cfg.mode == "speed":
run_speed(cfg)
else:
print(f"Unknown mode: {cfg.mode}")
print("Available modes: setup, test, speed")
sys.exit(1)
def main():
setup_can()
if __name__ == "__main__":
main()
+92 -6
View File
@@ -148,6 +148,92 @@ def update_policy(
return train_metrics, output_dict
def get_default_peft_configuration(policy_type):
"""Build a basic PEFT configuration for the given policy type assuming that we train a policy from a checkpoint."""
common_projections = "state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
if policy_type == "smolvla":
return {
"target_modules": rf"(model\.vlm_with_expert\.lm_expert\..*\.(q|v)_proj|model\.({common_projections}))",
"modules_to_save": [],
}
elif policy_type in ("pi0", "pi05"):
return {
"target_modules": rf"(.*\.gemma_expert\..*\.self_attn.(q|v)_proj|model\.({common_projections}))",
"modules_to_save": [],
}
return {"modules_to_save": None}
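The `target_modules` strings returned above are regular expressions over module names. The sketch below shows which names the `smolvla` pattern would select, assuming the pattern is applied with full-match semantics (to my understanding PEFT does this for string `target_modules`, but treat it as an assumption). The module names are hypothetical examples, not taken from a real policy:

```python
import re

# Pattern copied from the smolvla branch of get_default_peft_configuration.
COMMON_PROJECTIONS = "state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
SMOLVLA_TARGETS = rf"(model\.vlm_with_expert\.lm_expert\..*\.(q|v)_proj|model\.({COMMON_PROJECTIONS}))"

# Hypothetical module names for illustration; real names come from policy.named_modules().
candidates = [
    "model.vlm_with_expert.lm_expert.layers.0.self_attn.q_proj",
    "model.vlm_with_expert.lm_expert.layers.0.self_attn.k_proj",
    "model.vlm_with_expert.lm_expert.layers.0.self_attn.v_proj",
    "model.vlm_with_expert.vlm.layers.0.self_attn.q_proj",
    "model.state_proj",
]


def is_targeted(pattern: str, module_name: str) -> bool:
    """Return True when the whole module name matches the pattern."""
    return re.fullmatch(pattern, module_name) is not None


selected = [name for name in candidates if is_targeted(SMOLVLA_TARGETS, name)]
```

Only the query and value projections of the action expert and the listed top-level projections are selected. The key projection and the VLM backbone are left out.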
def wrap_policy_in_peft_model(cfg, policy):
from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType, get_peft_model
# Disable all gradients because we'll only train the parameters selected by the PEFT method.
# Layers that should receive gradients anyway need to be listed in `modules_to_save`.
for p in policy.parameters():
p.requires_grad_(False)
if not cfg.policy.pretrained_path:
raise ValueError(
"Training from scratch using PEFT. This is unlikely to yield good results. "
"Supply a `policy.path` to fine-tune an existing model."
)
if cfg.policy.type == "smolvla" and not cfg.policy.load_vlm_weights:
logging.warning(
"Training SmolVLA from scratch using PEFT. This is unlikely to yield good results. Set "
"`load_vlm_weights=True` to fine-tune the existing policy."
)
peft_config_policy = get_default_peft_configuration(cfg.policy.type)
peft_config_cli = dataclasses.asdict(cfg.peft) if cfg.peft else {}
peft_config_cli["modules_to_save"] = peft_config_cli["full_training_modules"] # compatibility with PEFT
peft_method_type = PeftType[peft_config_cli["method_type"].upper()]
peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type]
# Handle specific CLI overrides
for key in ["target_modules", "modules_to_save", "r"]:
if peft_config_cli[key] is not None:
peft_config_policy[key] = peft_config_cli[key]
if "target_modules" not in peft_config_policy:
raise ValueError(
f"There is no default `target_modules` value for policy {cfg.policy.type}. Please pass it manually."
)
# The name of the init argument depends on the PEFT method in use.
# If your PEFT method is not handled below, an error is raised.
if peft_config_cli["init_type"] is not None:
if peft_method_type == "LORA":
peft_config_policy["init_lora_weights"] = peft_config_cli["init_type"]
elif peft_method_type == "MISS":
peft_config_policy["init_weights"] = peft_config_cli["init_type"]
else:
raise ValueError(
f"Init type {peft_config_cli['init_type']} unknown for PEFT method {peft_method_type}."
)
# PEFT uses this attribute to set adapter_config.base_name_or_path which we use for loading the
# correct base model in `make_policy` since in a PEFT loading setting we only get the path to the
# adapter, not the base model.
if policy.config.pretrained_path:
policy.name_or_path = str(policy.config.pretrained_path)
# Finally wrap the policy in a PEFT model
policy = get_peft_model(
policy,
peft_config_cls(**peft_config_policy),
)
# Make sure that the config is tagged as using PEFT so that the loading code can take the
# appropriate steps to use the adapter weights and the PEFT config instead of the full model weights.
policy.config.use_peft = True
return policy
@parser.wrap()
def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
"""
@@ -177,7 +263,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
# Force the device to be CPU when policy.device is set to CPU.
force_cpu = cfg.policy.device == "cpu"
# Note (maractin): cfg.policy may be None before validate() fully loads from pretrained_path
force_cpu = cfg.policy is not None and cfg.policy.device == "cpu"
accelerator = Accelerator(
step_scheduler_with_optimizer=False,
kwargs_handlers=[ddp_kwargs],
@@ -225,8 +312,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
# On real-world data, no need to create an environment as evaluations are done outside train.py,
# using the eval.py instead, with gym_dora environment and dora-rs.
eval_env = None
if cfg.eval_freq > 0 and cfg.env is not None and is_main_process:
logging.info("Creating env")
if cfg.eval_freq > 0 and cfg.env is not None:
if is_main_process:
logging.info("Creating env")
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
if is_main_process:
@@ -239,9 +327,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
if cfg.peft is not None:
logging.info("Using PEFT! Wrapping model.")
# Convert CLI peft config to dict for overrides
peft_cli_overrides = dataclasses.asdict(cfg.peft)
policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides)
policy = wrap_policy_in_peft_model(cfg, policy)
# Wait for all processes to finish policy creation before continuing
accelerator.wait_for_everyone()
@@ -18,7 +18,7 @@ import logging
from functools import cached_property
from lerobot.teleoperators.so_leader import SOLeaderTeleopConfig
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
from ..so_leader import SOLeader
from ..teleoperator import Teleoperator
@@ -92,8 +92,10 @@ class BiSOLeader(Teleoperator):
self.left_arm.setup_motors()
self.right_arm.setup_motors()
@check_if_not_connected
def get_action(self) -> dict[str, float]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
action_dict = {}
# Add "left_" prefix
@@ -21,7 +21,7 @@ from typing import Any
import numpy as np
from lerobot.processor import RobotAction
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
from ..teleoperator import Teleoperator
from ..utils import TeleopEvents
@@ -86,8 +86,10 @@ class GamepadTeleop(Teleoperator):
self.gamepad = Gamepad()
self.gamepad.start()
@check_if_not_connected
def get_action(self) -> RobotAction:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Update the controller to get fresh inputs
self.gamepad.update()
@@ -22,7 +22,7 @@ from pprint import pformat
import serial
from lerobot.motors.motors_bus import MotorCalibration, MotorNormMode
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.utils import enter_pressed, move_cursor_up
from ..teleoperator import Teleoperator
@@ -93,8 +93,10 @@ class HomunculusArm(Teleoperator):
with self.serial_lock:
return self.serial.is_open and self.thread.is_alive()
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
if not self.serial.is_open:
self.serial.open()
self.thread.start()
@@ -297,16 +299,20 @@ class HomunculusArm(Teleoperator):
except Exception as e:
logger.debug(f"Error reading frame in background thread for {self}: {e}")
@check_if_not_connected
def get_action(self) -> dict[str, float]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
joint_positions = self._read()
return {f"{joint}.pos": pos for joint, pos in joint_positions.items()}
def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError
@check_if_not_connected
def disconnect(self) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self.stop_event.set()
self.thread.join(timeout=1)
self.serial.close()
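The teleoperator hunks swap between a `@check_if_not_connected` decorator and an explicit `is_connected` guard at the top of each method. The real decorator in `lerobot.utils.decorators` is not shown in this diff, so the following is only a sketch of how such a guard could be written. `DeviceNotConnectedError` here is a local stand-in and `FakeTeleop` is a hypothetical class:

```python
from functools import wraps


class DeviceNotConnectedError(ConnectionError):
    """Local stand-in for lerobot.utils.errors.DeviceNotConnectedError."""


def check_if_not_connected(method):
    """Raise before calling the method when the device reports it is not connected."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        if not self.is_connected:
            raise DeviceNotConnectedError(f"{self} is not connected.")
        return method(self, *args, **kwargs)

    return wrapper


class FakeTeleop:
    """Hypothetical teleoperator used only to exercise the guard."""

    def __init__(self) -> None:
        self.is_connected = False

    @check_if_not_connected
    def get_action(self) -> dict[str, float]:
        return {"joint.pos": 0.0}


teleop = FakeTeleop()
try:
    teleop.get_action()
    raised = False
except DeviceNotConnectedError:
    raised = True

teleop.is_connected = True
action = teleop.get_action()
```

Both forms behave the same as long as the inline guard actually raises. Constructing the exception without `raise` is a silent no-op.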
@@ -24,7 +24,7 @@ import serial
from lerobot.motors import MotorCalibration
from lerobot.motors.motors_bus import MotorNormMode
from lerobot.teleoperators.homunculus.joints_translation import homunculus_glove_to_hope_jr_hand
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.utils import enter_pressed, move_cursor_up
from ..teleoperator import Teleoperator
@@ -119,8 +119,10 @@ class HomunculusGlove(Teleoperator):
with self.serial_lock:
return self.serial.is_open and self.thread.is_alive()
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
if not self.serial.is_open:
self.serial.open()
self.thread.start()
@@ -323,8 +325,10 @@ class HomunculusGlove(Teleoperator):
except Exception as e:
logger.debug(f"Error reading frame in background thread for {self}: {e}")
@check_if_not_connected
def get_action(self) -> dict[str, float]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
joint_positions = self._read()
return homunculus_glove_to_hope_jr_hand(
{f"{joint}.pos": pos for joint, pos in joint_positions.items()}
@@ -333,8 +337,10 @@ class HomunculusGlove(Teleoperator):
def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError
@check_if_not_connected
def disconnect(self) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self.stop_event.set()
self.thread.join(timeout=1)
self.serial.close()
@@ -22,7 +22,7 @@ from queue import Queue
from typing import Any
from lerobot.processor import RobotAction
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..teleoperator import Teleoperator
from ..utils import TeleopEvents
@@ -86,8 +86,12 @@ class KeyboardTeleop(Teleoperator):
def is_calibrated(self) -> bool:
pass
@check_if_already_connected
def connect(self) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(
"Keyboard is already connected. Do not run `robot.connect()` twice."
)
if PYNPUT_AVAILABLE:
logging.info("pynput is available - enabling local keyboard listener.")
self.listener = keyboard.Listener(
@@ -121,10 +125,14 @@ class KeyboardTeleop(Teleoperator):
def configure(self):
pass
@check_if_not_connected
def get_action(self) -> RobotAction:
before_read_t = time.perf_counter()
if not self.is_connected:
raise DeviceNotConnectedError(
"KeyboardTeleop is not connected. You need to run `connect()` before `get_action()`."
)
self._drain_pressed_keys()
# Generate action based on current key states
@@ -136,8 +144,11 @@ class KeyboardTeleop(Teleoperator):
def send_feedback(self, feedback: dict[str, Any]) -> None:
pass
@check_if_not_connected
def disconnect(self) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(
"KeyboardTeleop is not connected. You need to run `robot.connect()` before `disconnect()`."
)
if self.listener is not None:
self.listener.stop()
@@ -171,8 +182,12 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
"names": {"delta_x": 0, "delta_y": 1, "delta_z": 2},
}
@check_if_not_connected
def get_action(self) -> RobotAction:
if not self.is_connected:
raise DeviceNotConnectedError(
"KeyboardTeleop is not connected. You need to run `connect()` before `get_action()`."
)
self._drain_pressed_keys()
delta_x = 0.0
delta_y = 0.0
@@ -360,7 +375,6 @@ class KeyboardRoverTeleop(KeyboardTeleop):
# Only remove key if it's being released
self.current_pressed.pop(key_char, None)
@check_if_not_connected
def get_action(self) -> RobotAction:
"""
Get the current action based on pressed keys.
@@ -370,6 +384,11 @@ class KeyboardRoverTeleop(KeyboardTeleop):
"""
before_read_t = time.perf_counter()
if not self.is_connected:
raise DeviceNotConnectedError(
"KeyboardRoverTeleop is not connected. You need to run `connect()` before `get_action()`."
)
self._drain_pressed_keys()
linear_velocity = 0.0
@@ -23,7 +23,7 @@ from lerobot.motors.dynamixel import (
DynamixelMotorsBus,
OperatingMode,
)
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..teleoperator import Teleoperator
from .config_koch_leader import KochLeaderConfig
@@ -69,8 +69,10 @@ class KochLeader(Teleoperator):
def is_connected(self) -> bool:
return self.bus.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
self.bus.connect()
if not self.is_calibrated and calibrate:
logger.info(
@@ -159,8 +161,10 @@ class KochLeader(Teleoperator):
self.bus.setup_motor(motor)
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
@check_if_not_connected
def get_action(self) -> dict[str, float]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
start = time.perf_counter()
action = self.bus.sync_read("Present_Position")
action = {f"{motor}.pos": val for motor, val in action.items()}
@@ -172,7 +176,9 @@ class KochLeader(Teleoperator):
# TODO(rcadene, aliberts): Implement force feedback
raise NotImplementedError
@check_if_not_connected
def disconnect(self) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self.bus.disconnect()
logger.info(f"{self} disconnected.")
@@ -27,4 +27,4 @@ class OmxLeaderConfig(TeleoperatorConfig):
# Sets the arm in torque mode with the gripper motor set to this value. This makes it possible to squeeze
# the gripper and have it spring back to an open position on its own.
gripper_open_pos: float = 60.0
gripper_open_pos: float = 37.0
@@ -23,7 +23,7 @@ from lerobot.motors.dynamixel import (
DynamixelMotorsBus,
OperatingMode,
)
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..teleoperator import Teleoperator
from .config_omx_leader import OmxLeaderConfig
@@ -68,8 +68,10 @@ class OmxLeader(Teleoperator):
def is_connected(self) -> bool:
return self.bus.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
self.bus.connect()
if not self.is_calibrated and calibrate:
logger.info(
@@ -103,7 +105,7 @@ class OmxLeader(Teleoperator):
self.calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=drive_modes[motor],
homing_offset=0 if motor != "gripper" else 100,
homing_offset=0,
range_min=0,
range_max=4095,
)
@@ -123,20 +125,12 @@ class OmxLeader(Teleoperator):
# point
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
if motor == "gripper":
self.bus.write("Drive_Mode", motor, DriveMode.INVERTED.value)
else:
self.bus.write("Drive_Mode", motor, DriveMode.NON_INVERTED.value)
# Use 'position control current based' for gripper to be limited by the limit of the current.
# For the follower gripper, it means it can grasp an object without forcing too much even though
# its goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch).
# For the leader gripper, it means we can use it as a physical trigger, since we can force with our finger
# to make it move, and it will move back to its original target position when we release the force.
self.bus.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value)
self.bus.write("Current_Limit", "gripper", 100)
self.bus.write("Goal_Current", "gripper", 100)
self.bus.write("Homing_Offset", "gripper", 100)
# Set gripper's goal pos in current position mode so that we can use it as a trigger.
self.bus.enable_torque("gripper")
if self.is_calibrated:
@@ -148,8 +142,10 @@ class OmxLeader(Teleoperator):
self.bus.setup_motor(motor)
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
@check_if_not_connected
def get_action(self) -> dict[str, float]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
start = time.perf_counter()
action = self.bus.sync_read("Present_Position")
action = {f"{motor}.pos": val for motor, val in action.items()}
@@ -161,7 +157,9 @@ class OmxLeader(Teleoperator):
# TODO(rcadene, aliberts): Implement force feedback
raise NotImplementedError
@check_if_not_connected
def disconnect(self) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self.bus.disconnect()
logger.info(f"{self} disconnected.")
@@ -28,7 +28,7 @@ from teleop import Teleop
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.rotation import Rotation
logger = logging.getLogger(__name__)
@@ -81,8 +81,10 @@ class IOSPhone(BasePhone, Teleoperator):
def is_connected(self) -> bool:
return self._group is not None
@check_if_already_connected
def connect(self) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
logger.info("Connecting to IPhone, make sure to open the HEBI Mobile I/O app.")
lookup = hebi.Lookup()
time.sleep(2.0)
@@ -162,8 +164,10 @@ class IOSPhone(BasePhone, Teleoperator):
pos = ar_pos - rot.apply(self.config.camera_offset)
return True, pos, rot, pose
@check_if_not_connected
def get_action(self) -> dict:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
has_pose, raw_position, raw_rotation, fb_pose = self._read_current_pose()
if not has_pose or not self.is_calibrated:
return {}
@@ -203,8 +207,10 @@ class IOSPhone(BasePhone, Teleoperator):
"phone.enabled": self._enabled,
}
@check_if_not_connected
def disconnect(self) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self._group = None
@@ -224,8 +230,10 @@ class AndroidPhone(BasePhone, Teleoperator):
def is_connected(self) -> bool:
return self._teleop is not None
@check_if_already_connected
def connect(self) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
logger.info("Starting teleop stream for Android...")
self._teleop = Teleop()
self._teleop.subscribe(self._android_callback)
@@ -313,8 +321,10 @@ class AndroidPhone(BasePhone, Teleoperator):
self._latest_pose = pose
self._latest_message = message
@check_if_not_connected
def get_action(self) -> dict:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
ok, raw_pos, raw_rot, pose = self._read_current_pose()
if not ok or not self.is_calibrated:
return {}
@@ -346,8 +356,10 @@ class AndroidPhone(BasePhone, Teleoperator):
"phone.enabled": self._enabled,
}
@check_if_not_connected
def disconnect(self) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self._teleop = None
if self._teleop_thread and self._teleop_thread.is_alive():
self._teleop_thread.join(timeout=1.0)
@@ -26,8 +26,7 @@ if TYPE_CHECKING or _reachy2_sdk_available:
else:
ReachySDK = None
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..teleoperator import Teleoperator
from .config_reachy2_teleoperator import Reachy2TeleoperatorConfig
@@ -127,8 +126,10 @@ class Reachy2Teleoperator(Teleoperator):
def is_connected(self) -> bool:
return self.reachy.is_connected() if self.reachy is not None else False
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
self.reachy = ReachySDK(self.config.ip_address)
if not self.is_connected:
@@ -145,10 +146,12 @@ class Reachy2Teleoperator(Teleoperator):
def configure(self) -> None:
pass
@check_if_not_connected
def get_action(self) -> dict[str, float]:
start = time.perf_counter()
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
joint_action: dict[str, float] = {}
vel_action: dict[str, float] = {}
@@ -23,7 +23,7 @@ from lerobot.motors.feetech import (
FeetechMotorsBus,
OperatingMode,
)
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..teleoperator import Teleoperator
from .config_so_leader import SOLeaderTeleopConfig
@@ -66,8 +66,10 @@ class SOLeader(Teleoperator):
def is_connected(self) -> bool:
return self.bus.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
self.bus.connect()
if not self.is_calibrated and calibrate:
logger.info(
@@ -137,8 +139,10 @@ class SOLeader(Teleoperator):
self.bus.setup_motor(motor)
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
@check_if_not_connected
def get_action(self) -> dict[str, float]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
start = time.perf_counter()
action = self.bus.sync_read("Present_Position")
action = {f"{motor}.pos": val for motor, val in action.items()}
@@ -150,8 +154,10 @@ class SOLeader(Teleoperator):
# TODO: Implement force feedback
raise NotImplementedError
@check_if_not_connected
def disconnect(self) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self.bus.disconnect()
logger.info(f"{self} disconnected.")
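A connection guard in `disconnect()` only takes effect if the exception is actually raised; building the exception object without `raise` is a silent no-op and execution falls through to the bus call. The sketch below illustrates the difference. `FakeBus`, both helper functions, and the locally defined exception class are hypothetical stand-ins, not lerobot code.

```python
class DeviceNotConnectedError(ConnectionError):
    """Hypothetical stand-in for lerobot.utils.errors.DeviceNotConnectedError."""


class FakeBus:
    """Hypothetical bus that is never connected."""

    is_connected = False


def disconnect_without_raise(bus):
    if not bus.is_connected:
        # The exception object is created and immediately discarded.
        DeviceNotConnectedError("not connected")
    return "bus.disconnect() reached"


def disconnect_with_raise(bus):
    if not bus.is_connected:
        # The guard stops execution here.
        raise DeviceNotConnectedError("not connected")
    return "bus.disconnect() reached"


def guard_fired(fn, bus):
    """Return True if calling fn(bus) raised the guard exception."""
    try:
        fn(bus)
    except DeviceNotConnectedError:
        return True
    return False
```

With a disconnected bus, only the second variant prevents the call from reaching the bus.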
-26
@@ -58,32 +58,6 @@ class Teleoperator(abc.ABC):
def __str__(self) -> str:
return f"{self.id} {self.__class__.__name__}"
def __enter__(self):
"""
Context manager entry.
Automatically connects to the teleoperator.
"""
self.connect()
return self
def __exit__(self, exc_type, exc_value, traceback) -> None:
"""
Context manager exit.
Automatically disconnects, ensuring resources are released even on error.
"""
self.disconnect()
def __del__(self) -> None:
"""
Destructor safety net.
Attempts to disconnect if the object is garbage collected without cleanup.
"""
try:
if self.is_connected:
self.disconnect()
except Exception: # nosec B110
pass
@property
@abc.abstractmethod
def action_features(self) -> dict:
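For reference, the `__enter__` / `__exit__` / `__del__` trio shown in this hunk is the usual Python context-manager pattern. The sketch below shows how such a class behaves in a `with` block, including when the body raises. `FakeTeleop` and `run_with_failure` are hypothetical names used only for illustration; they do not exist in lerobot.

```python
class FakeTeleop:
    """Hypothetical stand-in for a Teleoperator subclass."""

    def __init__(self):
        self.is_connected = False

    def connect(self):
        self.is_connected = True

    def disconnect(self):
        self.is_connected = False

    def __enter__(self):
        # Connect on entry and hand the object to the `with ... as` target.
        self.connect()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Runs on normal exit and on error; returning None re-raises any error.
        self.disconnect()

    def __del__(self):
        # Best-effort cleanup if the object is collected while still connected.
        try:
            if self.is_connected:
                self.disconnect()
        except Exception:
            pass


def run_with_failure(teleop):
    """Enter the context, fail inside it, and report the final connection state."""
    try:
        with teleop:
            raise RuntimeError("simulated failure mid-session")
    except RuntimeError:
        pass
    return teleop.is_connected
```

Without these methods, callers pair `connect()` and `disconnect()` explicitly, typically in a `try` / `finally` block.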
-41
@@ -1,41 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import wraps
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
def check_if_not_connected(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if not self.is_connected:
raise DeviceNotConnectedError(
f"{self.__class__.__name__} is not connected. Run `.connect()` first."
)
return func(self, *args, **kwargs)
return wrapper
def check_if_already_connected(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self.__class__.__name__} is already connected.")
return func(self, *args, **kwargs)
return wrapper
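The two decorators in this file replace the inline `if self.is_connected` checks seen in the teleoperator hunks. The sketch below reproduces them against a hypothetical `FakeDevice` to show the resulting behavior. The exception classes are defined locally as stand-ins for `lerobot.utils.errors`, and `outcome` is an illustrative helper.

```python
from functools import wraps


class DeviceNotConnectedError(ConnectionError):
    """Stand-in for lerobot.utils.errors.DeviceNotConnectedError."""


class DeviceAlreadyConnectedError(ConnectionError):
    """Stand-in for lerobot.utils.errors.DeviceAlreadyConnectedError."""


def check_if_not_connected(func):
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if not self.is_connected:
            raise DeviceNotConnectedError(
                f"{self.__class__.__name__} is not connected. Run `.connect()` first."
            )
        return func(self, *args, **kwargs)

    return wrapper


def check_if_already_connected(func):
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if self.is_connected:
            raise DeviceAlreadyConnectedError(f"{self.__class__.__name__} is already connected.")
        return func(self, *args, **kwargs)

    return wrapper


class FakeDevice:
    """Hypothetical device exposing the `is_connected` attribute the decorators read."""

    def __init__(self):
        self.is_connected = False

    @check_if_already_connected
    def connect(self):
        self.is_connected = True

    @check_if_not_connected
    def disconnect(self):
        self.is_connected = False


def outcome(fn):
    """Call fn and return 'ok' or the name of the raised connection error."""
    try:
        fn()
    except ConnectionError as err:
        return type(err).__name__
    return "ok"
```

Because the wrappers use `functools.wraps`, the decorated methods keep their original `__name__` and docstring.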
+8 -17
@@ -21,23 +21,12 @@ from typing import Any
from draccus.choice_types import ChoiceRegistry
def is_package_available(
pkg_name: str, import_name: str | None = None, return_version: bool = False
) -> tuple[bool, str] | bool:
"""
def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
"""Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py
Check if the package spec exists and grab its version to avoid importing a local directory.
Args:
pkg_name: The name of the package as installed via pip (e.g. "python-can").
import_name: The actual name used to import the package (e.g. "can").
Defaults to pkg_name if not provided.
return_version: Whether to return the version string.
**Note:** this doesn't work for all packages.
"""
if import_name is None:
import_name = pkg_name
# Check if the module spec exists using the import name
package_exists = importlib.util.find_spec(import_name) is not None
package_exists = importlib.util.find_spec(pkg_name) is not None
package_version = "N/A"
if package_exists:
try:
@@ -48,7 +37,7 @@ def is_package_available(
# Fallback method: Only for "torch" and versions containing "dev"
if pkg_name == "torch":
try:
package = importlib.import_module(import_name)
package = importlib.import_module(pkg_name)
temp_version = getattr(package, "__version__", "N/A")
# Check if the version contains "dev"
if "dev" in temp_version:
@@ -59,6 +48,9 @@ def is_package_available(
except ImportError:
# If the package can't be imported, it's not available
package_exists = False
elif pkg_name == "grpc":
package = importlib.import_module(pkg_name)
package_version = getattr(package, "__version__", "N/A")
else:
# For packages other than "torch", don't attempt the fallback and set as not available
package_exists = False
@@ -73,7 +65,6 @@ _transformers_available = is_package_available("transformers")
_peft_available = is_package_available("peft")
_scipy_available = is_package_available("scipy")
_reachy2_sdk_available = is_package_available("reachy2_sdk")
_can_available = is_package_available("python-can", "can")
def make_device_from_device_class(config: ChoiceRegistry) -> Any:
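The `import_name` parameter in this hunk matters because a distribution name (what pip installs, e.g. `python-can`) can differ from the module name (what `import` uses, e.g. `can`). The following is a simplified sketch of that split, not the library's implementation: `is_package_available_sketch` is a hypothetical name, and the `torch` and `grpc` fallbacks from the real function are omitted.

```python
import importlib.metadata
import importlib.util


def is_package_available_sketch(pkg_name, import_name=None):
    """Return (available, version) using separate import and distribution names."""
    if import_name is None:
        import_name = pkg_name

    # The module lookup uses the import name (e.g. "can").
    if importlib.util.find_spec(import_name) is None:
        return False, "N/A"

    # The version lookup uses the distribution name (e.g. "python-can").
    try:
        return True, importlib.metadata.version(pkg_name)
    except importlib.metadata.PackageNotFoundError:
        # A module that resolves without installed metadata is treated as
        # unavailable, which guards against picking up a local directory.
        return False, "N/A"
```

A call such as `is_package_available_sketch("python-can", "can")` would look up the module `can` and the metadata for `python-can`; with a single name, a hyphenated distribution name can never match a module.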
+2 -8
@@ -144,18 +144,12 @@ def test_async_inference_e2e(monkeypatch):
client = RobotClient(client_config)
assert client.start(), "Client failed initial handshake with the server"
# Track action chunks received and verify device type
action_chunks_received = {"count": 0, "actions_on_cpu": True}
# Track action chunks received without modifying RobotClient
action_chunks_received = {"count": 0}
original_aggregate = client._aggregate_action_queues
def counting_aggregate(*args, **kwargs):
action_chunks_received["count"] += 1
# Check that all received actions are on CPU
if args:
for timed_action in args[0]: # args[0] is the list of TimedAction
action_tensor = timed_action.get_action()
if action_tensor.device.type != "cpu":
action_chunks_received["actions_on_cpu"] = False
return original_aggregate(*args, **kwargs)
monkeypatch.setattr(client, "_aggregate_action_queues", counting_aggregate)
+1 -1
@@ -62,7 +62,7 @@ class MockPolicy:
@pytest.fixture
@require_package("grpcio", "grpc")
@require_package("grpc")
def policy_server():
"""Fresh `PolicyServer` instance with a stubbed-out policy model."""
# Import only when the test actually runs (after decorator check)
-145
@@ -16,7 +16,6 @@
from unittest.mock import patch
import datasets
import torch
from lerobot.datasets.aggregate import aggregate_datasets
@@ -381,147 +380,3 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
for key in aggr_ds.meta.video_keys:
assert key in item, f"Video key {key} missing from item {i}"
assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}"
def assert_image_schema_preserved(aggr_ds):
"""Test that HuggingFace Image feature schema is preserved in aggregated parquet files.
This verifies the fix for a bug where image columns were written with a generic
struct schema {'bytes': Value('binary'), 'path': Value('string')} instead of
the proper Image() feature type, causing HuggingFace Hub viewer to display
raw dict objects instead of image thumbnails.
"""
image_keys = aggr_ds.meta.image_keys
if not image_keys:
return
# Check that parquet files have proper Image schema
data_dir = aggr_ds.root / "data"
parquet_files = list(data_dir.rglob("*.parquet"))
assert len(parquet_files) > 0, "No parquet files found in aggregated dataset"
for parquet_file in parquet_files:
# Load with HuggingFace datasets to check schema
ds = datasets.Dataset.from_parquet(str(parquet_file))
for image_key in image_keys:
feature = ds.features.get(image_key)
assert feature is not None, f"Image key '{image_key}' not found in parquet schema"
assert isinstance(feature, datasets.Image), (
f"Image key '{image_key}' should have Image() feature type, "
f"but got {type(feature).__name__}: {feature}. "
"This indicates image schema was not preserved during aggregation."
)
def assert_image_frames_integrity(aggr_ds, ds_0, ds_1):
"""Test that image frames are correctly preserved after aggregation."""
image_keys = aggr_ds.meta.image_keys
if not image_keys:
return
def images_equal(img1, img2):
return torch.allclose(img1, img2)
# Test the section corresponding to the first dataset (ds_0)
for i in range(len(ds_0)):
assert aggr_ds[i]["index"] == i, (
f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}"
)
for key in image_keys:
assert images_equal(aggr_ds[i][key], ds_0[i][key]), (
f"Image frames at position {i} should be equal between aggregated and ds_0"
)
# Test the section corresponding to the second dataset (ds_1)
for i in range(len(ds_0), len(ds_0) + len(ds_1)):
assert aggr_ds[i]["index"] == i, (
f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}"
)
for key in image_keys:
assert images_equal(aggr_ds[i][key], ds_1[i - len(ds_0)][key]), (
f"Image frames at position {i} should be equal between aggregated and ds_1"
)
def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
"""Test aggregation of image-based datasets preserves HuggingFace Image schema.
This test specifically verifies that:
1. Image-based datasets can be aggregated correctly
2. The HuggingFace Image() feature type is preserved in parquet files
3. Image data integrity is maintained across aggregation
4. Images can be properly decoded after aggregation
This catches the bug where to_parquet_with_hf_images() was not passing
the features schema, causing image columns to be written as generic
struct types instead of Image() types.
"""
ds_0_num_frames = 50
ds_1_num_frames = 75
ds_0_num_episodes = 2
ds_1_num_episodes = 3
# Create two image-based datasets (use_videos=False)
ds_0 = lerobot_dataset_factory(
root=tmp_path / "image_0",
repo_id=f"{DUMMY_REPO_ID}_image_0",
total_episodes=ds_0_num_episodes,
total_frames=ds_0_num_frames,
use_videos=False, # Image-based dataset
)
ds_1 = lerobot_dataset_factory(
root=tmp_path / "image_1",
repo_id=f"{DUMMY_REPO_ID}_image_1",
total_episodes=ds_1_num_episodes,
total_frames=ds_1_num_frames,
use_videos=False, # Image-based dataset
)
# Verify source datasets have image keys
assert len(ds_0.meta.image_keys) > 0, "ds_0 should have image keys"
assert len(ds_1.meta.image_keys) > 0, "ds_1 should have image keys"
# Aggregate the datasets
aggregate_datasets(
repo_ids=[ds_0.repo_id, ds_1.repo_id],
roots=[ds_0.root, ds_1.root],
aggr_repo_id=f"{DUMMY_REPO_ID}_image_aggr",
aggr_root=tmp_path / "image_aggr",
)
# Load the aggregated dataset
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "image_aggr")
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_image_aggr", root=tmp_path / "image_aggr")
# Verify aggregated dataset has image keys
assert len(aggr_ds.meta.image_keys) > 0, "Aggregated dataset should have image keys"
assert aggr_ds.meta.image_keys == ds_0.meta.image_keys, "Image keys should match source datasets"
# Run standard aggregation assertions
expected_total_episodes = ds_0_num_episodes + ds_1_num_episodes
expected_total_frames = ds_0_num_frames + ds_1_num_frames
assert_episode_and_frame_counts(aggr_ds, expected_total_episodes, expected_total_frames)
assert_dataset_content_integrity(aggr_ds, ds_0, ds_1)
assert_metadata_consistency(aggr_ds, ds_0, ds_1)
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
# Image-specific assertions
assert_image_schema_preserved(aggr_ds)
assert_image_frames_integrity(aggr_ds, ds_0, ds_1)
# Verify images can be accessed and have correct shape
sample_item = aggr_ds[0]
for image_key in aggr_ds.meta.image_keys:
img = sample_item[image_key]
assert isinstance(img, torch.Tensor), f"Image {image_key} should be a tensor"
assert img.dim() == 3, f"Image {image_key} should have 3 dimensions (C, H, W)"
assert img.shape[0] == 3, f"Image {image_key} should have 3 channels"
assert_dataset_iteration_works(aggr_ds)
+5 -5
@@ -29,7 +29,7 @@ from lerobot.datasets.dataset_tools import (
remove_feature,
split_dataset,
)
from lerobot.scripts.lerobot_edit_dataset import convert_image_to_video_dataset
from lerobot.scripts.lerobot_edit_dataset import convert_dataset_to_videos
@pytest.fixture
@@ -1050,7 +1050,7 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path):
assert "reward" in modified_dataset.meta.features
def test_convert_image_to_video_dataset(tmp_path):
def test_convert_dataset_to_videos(tmp_path):
"""Test converting lerobot/pusht_image dataset to video format."""
from lerobot.datasets.lerobot_dataset import LeRobotDataset
@@ -1071,7 +1071,7 @@ def test_convert_image_to_video_dataset(tmp_path):
assert "observation.image" in source_dataset.meta.features
# Convert to video dataset (only first 2 episodes for speed)
video_dataset = convert_image_to_video_dataset(
video_dataset = convert_dataset_to_videos(
dataset=source_dataset,
output_dir=output_dir,
repo_id="lerobot/pusht_video",
@@ -1113,7 +1113,7 @@ def test_convert_image_to_video_dataset(tmp_path):
shutil.rmtree(output_dir)
def test_convert_image_to_video_dataset_subset_episodes(tmp_path):
def test_convert_dataset_to_videos_subset_episodes(tmp_path):
"""Test converting only specific episodes from lerobot/pusht_image to video format."""
from lerobot.datasets.lerobot_dataset import LeRobotDataset
@@ -1132,7 +1132,7 @@ def test_convert_image_to_video_dataset_subset_episodes(tmp_path):
# Convert only episode 0 to video (subset of loaded episodes)
episode_indices = [0]
video_dataset = convert_image_to_video_dataset(
video_dataset = convert_dataset_to_videos(
dataset=source_dataset,
output_dir=output_dir,
repo_id="lerobot/pusht_video_subset",
-258
@@ -352,65 +352,6 @@ def test_image_array_to_pil_image_wrong_range_float_0_255():
image_array_to_pil_image(image)
def test_tmp_image_deletion(tmp_path, empty_lerobot_dataset_factory):
"""Verify temporary image directories are removed for image features after saving episode."""
# Image feature: images should be deleted after saving episode
image_key = "image"
features_image = {
image_key: {"dtype": "image", "shape": DUMMY_CHW, "names": ["channels", "height", "width"]}
}
ds_img = empty_lerobot_dataset_factory(root=tmp_path / "img", features=features_image)
ds_img.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
ds_img.save_episode()
img_dir = ds_img._get_image_file_dir(0, image_key)
assert not img_dir.exists(), "Temporary image directory should be removed for image features"
def test_tmp_video_deletion(tmp_path, empty_lerobot_dataset_factory):
"""Verify temporary image directories are removed for video encoding when `batch_encoding_size == 1`."""
# Video feature: when batch_encoding_size == 1 temporary images should be deleted
vid_key = "video"
features_video = {
vid_key: {"dtype": "video", "shape": DUMMY_CHW, "names": ["channels", "height", "width"]}
}
ds_vid = empty_lerobot_dataset_factory(root=tmp_path / "vid", features=features_video)
ds_vid.batch_encoding_size = 1
ds_vid.add_frame({vid_key: np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
ds_vid.save_episode()
vid_img_dir = ds_vid._get_image_file_dir(0, vid_key)
assert not vid_img_dir.exists(), (
"Temporary image directory should be removed when batch_encoding_size == 1"
)
def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
"""Verify temporary image directories are removed appropriately when both image and video features are present."""
image_key = "image"
vid_key = "video"
features_mixed = {
image_key: {"dtype": "image", "shape": DUMMY_CHW, "names": ["channels", "height", "width"]},
vid_key: {"dtype": "video", "shape": DUMMY_HWC, "names": ["height", "width", "channels"]},
}
ds_mixed = empty_lerobot_dataset_factory(
root=tmp_path / "mixed", features=features_mixed, batch_encoding_size=2
)
ds_mixed.add_frame(
{
"image": np.random.rand(*DUMMY_CHW),
"video": np.random.rand(*DUMMY_HWC),
"task": "Dummy task",
}
)
ds_mixed.save_episode()
img_dir = ds_mixed._get_image_file_dir(0, image_key)
vid_img_dir = ds_mixed._get_image_file_dir(0, vid_key)
assert not img_dir.exists(), "Temporary image directory should be removed for image features"
assert vid_img_dir.exists(), (
"Temporary image directory should not be removed for video features when batch_encoding_size == 2"
)
# TODO(aliberts):
# - [ ] test various attributes & state from init and create
# - [ ] test init with episodes and check num_frames
@@ -1451,202 +1392,3 @@ def test_valid_video_codecs_constant():
assert "hevc" in VALID_VIDEO_CODECS
assert "libsvtav1" in VALID_VIDEO_CODECS
assert len(VALID_VIDEO_CODECS) == 3
def test_delta_timestamps_with_episodes_filter(tmp_path, empty_lerobot_dataset_factory):
"""Regression test for bug where delta_timestamps incorrectly marked all frames as padded when using episodes filter.
The bug occurred because _get_query_indices was using the relative index (idx) in the filtered dataset
instead of the absolute index when comparing against episode boundaries (ep_start, ep_end).
"""
features = {
"observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]},
"action": {"dtype": "float32", "shape": (2,), "names": ["vx", "vy"]},
}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
# Create 3 episodes with 10 frames each
frames_per_episode = 10
for ep_idx in range(3):
for frame_idx in range(frames_per_episode):
dataset.add_frame(
{
"observation.state": torch.tensor([ep_idx, frame_idx], dtype=torch.float32),
"action": torch.randn(2),
"task": f"task_{ep_idx}",
}
)
dataset.save_episode()
dataset.finalize()
# Load only episode 1 (middle episode) with delta_timestamps
delta_ts = {"observation.state": [0.0]} # Just the current frame
filtered_dataset = LeRobotDataset(
dataset.repo_id,
root=dataset.root,
episodes=[1],
delta_timestamps=delta_ts,
)
# Verify the filtered dataset has the correct length
assert len(filtered_dataset) == frames_per_episode
# Check that no frames are marked as padded (since delta=0 should always be valid)
for idx in range(len(filtered_dataset)):
frame = filtered_dataset[idx]
assert frame["observation.state_is_pad"].item() is False, f"Frame {idx} incorrectly marked as padded"
# Verify we're getting data from episode 1
assert frame["episode_index"].item() == 1
def test_delta_timestamps_padding_at_episode_boundaries(tmp_path, empty_lerobot_dataset_factory):
"""Test that delta_timestamps correctly marks padding at episode boundaries when using episodes filter."""
features = {
"observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]},
"action": {"dtype": "float32", "shape": (2,), "names": ["vx", "vy"]},
}
dataset = empty_lerobot_dataset_factory(
root=tmp_path / "test", features=features, use_videos=False, fps=10
)
# Create 3 episodes with 5 frames each
frames_per_episode = 5
for ep_idx in range(3):
for frame_idx in range(frames_per_episode):
dataset.add_frame(
{
"observation.state": torch.tensor([ep_idx, frame_idx], dtype=torch.float32),
"action": torch.randn(2),
"task": f"task_{ep_idx}",
}
)
dataset.save_episode()
dataset.finalize()
# Load only episode 1 with delta_timestamps that go beyond episode boundaries
# fps=10, so 0.1s = 1 frame offset
delta_ts = {"observation.state": [-0.2, -0.1, 0.0, 0.1, 0.2]} # -2, -1, 0, +1, +2 frames
filtered_dataset = LeRobotDataset(
dataset.repo_id,
root=dataset.root,
episodes=[1],
delta_timestamps=delta_ts,
tolerance_s=0.04, # Slightly less than half a frame at 10fps
)
assert len(filtered_dataset) == frames_per_episode
# Check padding at the start of the episode (first frame)
first_frame = filtered_dataset[0]
is_pad = first_frame["observation.state_is_pad"].tolist()
# At frame 0 of episode 1: delta -2 and -1 should be padded, 0, +1, +2 should not
assert is_pad == [True, True, False, False, False], f"First frame padding incorrect: {is_pad}"
# Check middle frame (no padding expected)
mid_frame = filtered_dataset[2]
is_pad = mid_frame["observation.state_is_pad"].tolist()
assert is_pad == [False, False, False, False, False], f"Middle frame padding incorrect: {is_pad}"
# Check padding at the end of the episode (last frame)
last_frame = filtered_dataset[4]
is_pad = last_frame["observation.state_is_pad"].tolist()
# At frame 4 of episode 1: delta -2, -1, 0 should not be padded, +1, +2 should be
assert is_pad == [False, False, False, True, True], f"Last frame padding incorrect: {is_pad}"
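The padding behavior these tests expect can be sketched as a small pure function. This illustrates the clamping rule described in the docstrings and is not the library's `_get_query_indices`; the function name and signature are assumptions. Episode bounds are absolute frame indices with an exclusive `ep_end`.

```python
def query_indices_and_padding(abs_idx, ep_start, ep_end, delta_timestamps, fps):
    """Return (clamped query indices, is_pad flags) for one frame.

    abs_idx must be the frame's absolute index in the full dataset,
    not its position in the filtered view.
    """
    frame_offsets = [round(delta * fps) for delta in delta_timestamps]
    targets = [abs_idx + offset for offset in frame_offsets]
    # Out-of-episode targets are clamped to the nearest valid frame...
    indices = [max(ep_start, min(ep_end - 1, target)) for target in targets]
    # ...and flagged as padding.
    is_pad = [target < ep_start or target >= ep_end for target in targets]
    return indices, is_pad
```

For episode 1 of three 5-frame episodes (`ep_start=5`, `ep_end=10`) at 10 fps, passing the absolute index 5 for the first frame pads only the two past offsets. Passing the relative index 0 instead flags every offset as padding, which is the regression described in `test_delta_timestamps_with_episodes_filter`.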
def test_delta_timestamps_multiple_episodes_filter(tmp_path, empty_lerobot_dataset_factory):
"""Test delta_timestamps with multiple non-consecutive episodes selected."""
features = {
"observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]},
}
dataset = empty_lerobot_dataset_factory(
root=tmp_path / "test", features=features, use_videos=False, fps=10
)
# Create 5 episodes with 5 frames each
frames_per_episode = 5
for ep_idx in range(5):
for frame_idx in range(frames_per_episode):
dataset.add_frame(
{
"observation.state": torch.tensor([ep_idx, frame_idx], dtype=torch.float32),
"task": f"task_{ep_idx}",
}
)
dataset.save_episode()
dataset.finalize()
# Load episodes 1 and 3 (non-consecutive)
delta_ts = {"observation.state": [0.0]}
filtered_dataset = LeRobotDataset(
dataset.repo_id,
root=dataset.root,
episodes=[1, 3],
delta_timestamps=delta_ts,
)
assert len(filtered_dataset) == 2 * frames_per_episode
# All frames should have valid (non-padded) data for delta=0
for idx in range(len(filtered_dataset)):
frame = filtered_dataset[idx]
assert frame["observation.state_is_pad"].item() is False
# Verify we're getting the correct episodes
episode_indices = [filtered_dataset[i]["episode_index"].item() for i in range(len(filtered_dataset))]
expected_episodes = [1] * frames_per_episode + [3] * frames_per_episode
assert episode_indices == expected_episodes
def test_delta_timestamps_query_returns_correct_values(tmp_path, empty_lerobot_dataset_factory):
"""Test that delta_timestamps returns the correct observation values, not just correct padding."""
features = {
"observation.state": {"dtype": "float32", "shape": (1,), "names": ["x"]},
}
dataset = empty_lerobot_dataset_factory(
root=tmp_path / "test", features=features, use_videos=False, fps=10
)
# Create 2 episodes with known values
# Episode 0: frames with values 0, 1, 2, 3, 4
# Episode 1: frames with values 10, 11, 12, 13, 14
frames_per_episode = 5
for ep_idx in range(2):
for frame_idx in range(frames_per_episode):
value = ep_idx * 10 + frame_idx
dataset.add_frame(
{
"observation.state": torch.tensor([value], dtype=torch.float32),
"task": f"task_{ep_idx}",
}
)
dataset.save_episode()
dataset.finalize()
# Load episode 1 with delta that looks at previous frame
delta_ts = {"observation.state": [-0.1, 0.0]} # Previous frame and current frame
filtered_dataset = LeRobotDataset(
dataset.repo_id,
root=dataset.root,
episodes=[1],
delta_timestamps=delta_ts,
tolerance_s=0.04,
)
# Check frame 2 of episode 1 (which has absolute index 7, value 12)
frame = filtered_dataset[2]
state_values = frame["observation.state"].tolist()
# Should get [11, 12] - the previous and current values within episode 1
assert state_values == [11.0, 12.0], f"Expected [11.0, 12.0], got {state_values}"
# Check first frame - previous frame should be clamped to episode start (padded)
first_frame = filtered_dataset[0]
state_values = first_frame["observation.state"].tolist()
is_pad = first_frame["observation.state_is_pad"].tolist()
# Previous frame is outside episode, so it's clamped to first frame and marked as padded
assert state_values == [10.0, 10.0], f"Expected [10.0, 10.0], got {state_values}"
assert is_pad == [True, False], f"Expected [True, False], got {is_pad}"
+16 -6
@@ -22,7 +22,7 @@ from lerobot.cameras import CameraConfig, make_cameras_from_configs
from lerobot.motors.motors_bus import Motor, MotorNormMode
from lerobot.processor import RobotAction, RobotObservation
from lerobot.robots import Robot, RobotConfig
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from tests.mocks.mock_motors_bus import MockMotorsBus
@@ -98,8 +98,10 @@ class MockRobot(Robot):
def is_connected(self) -> bool:
return self._is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
self._is_connected = True
if calibrate:
self.calibrate()
@@ -108,15 +110,19 @@ class MockRobot(Robot):
def is_calibrated(self) -> bool:
return self._is_calibrated
@check_if_not_connected
def calibrate(self) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self._is_calibrated = True
def configure(self) -> None:
pass
@check_if_not_connected
def get_observation(self) -> RobotObservation:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.config.random_values:
return {f"{motor}.pos": random.uniform(-100, 100) for motor in self.motors}
else:
@@ -124,10 +130,14 @@ class MockRobot(Robot):
f"{motor}.pos": val for motor, val in zip(self.motors, self.config.static_values, strict=True)
}
@check_if_not_connected
def send_action(self, action: RobotAction) -> RobotAction:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
return action
@check_if_not_connected
def disconnect(self) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self._is_connected = False
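The hunks above swap between inline connection checks and the `check_if_already_connected` / `check_if_not_connected` decorators. The source of `lerobot.utils.decorators` is not shown in this diff, so the following is only a minimal sketch of how such decorators could be written to reproduce the inline checks. The exception classes and `ToyDevice` are local stand-ins, not the library's own.

```python
from functools import wraps


class DeviceNotConnectedError(ConnectionError):
    """Local stand-in for the error imported from lerobot.utils.errors."""


class DeviceAlreadyConnectedError(ConnectionError):
    """Local stand-in for the error imported from lerobot.utils.errors."""


def check_if_not_connected(func):
    """Raise before calling func when the device is not connected."""

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if not self.is_connected:
            raise DeviceNotConnectedError(f"{self} is not connected.")
        return func(self, *args, **kwargs)

    return wrapper


def check_if_already_connected(func):
    """Raise before calling func when the device is already connected."""

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if self.is_connected:
            raise DeviceAlreadyConnectedError(f"{self} already connected")
        return func(self, *args, **kwargs)

    return wrapper


class ToyDevice:
    """Hypothetical device used only to exercise the decorators."""

    def __init__(self):
        self._is_connected = False

    @property
    def is_connected(self):
        return self._is_connected

    @check_if_already_connected
    def connect(self):
        self._is_connected = True

    @check_if_not_connected
    def disconnect(self):
        self._is_connected = False
```

Either form raises the same error types, so tests that assert on `DeviceNotConnectedError` or `DeviceAlreadyConnectedError` should behave identically under both; the decorator form mainly removes the repeated guard from each method body.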
+16 -7
@@ -21,7 +21,7 @@ from typing import Any
from lerobot.processor import RobotAction
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
@TeleoperatorConfig.register_subclass("mock_teleop")
@@ -68,8 +68,10 @@ class MockTeleop(Teleoperator):
def is_connected(self) -> bool:
return self._is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
self._is_connected = True
if calibrate:
self.calibrate()
@@ -78,15 +80,19 @@ class MockTeleop(Teleoperator):
def is_calibrated(self) -> bool:
return self._is_calibrated
@check_if_not_connected
def calibrate(self) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self._is_calibrated = True
def configure(self) -> None:
pass
@check_if_not_connected
def get_action(self) -> RobotAction:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.config.random_values:
return {f"{motor}.pos": random.uniform(-100, 100) for motor in self.motors}
else:
@@ -94,9 +100,12 @@ class MockTeleop(Teleoperator):
f"{motor}.pos": val for motor, val in zip(self.motors, self.config.static_values, strict=True)
}
@check_if_not_connected
def send_feedback(self, feedback: dict[str, Any]) -> None: ...
def send_feedback(self, feedback: dict[str, Any]) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
@check_if_not_connected
def disconnect(self) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self._is_connected = False
-66
@@ -1,66 +0,0 @@
"""Minimal test script for Damiao motor with ID 3."""
import pytest
from lerobot.utils.import_utils import _can_available
if not _can_available:
pytest.skip("python-can not available", allow_module_level=True)
from lerobot.motors import Motor
from lerobot.motors.damiao import DamiaoMotorsBus
@pytest.mark.skip(reason="Requires physical Damiao motor and CAN interface")
def test_damiao_motor():
motors = {
"joint_3": Motor(
id=0x03,
model="damiao",
norm_mode="degrees",
motor_type_str="dm4310",
recv_id=0x13,
),
}
bus = DamiaoMotorsBus(port="can0", motors=motors)
try:
print("Connecting...")
bus.connect()
print("✓ Connected")
print("Enabling torque...")
bus.enable_torque()
print("✓ Torque enabled")
print("Reading all states...")
states = bus.sync_read_all_states()
print(f"✓ States: {states}")
print("Reading position...")
positions = bus.sync_read("Present_Position")
print(f"✓ Position: {positions}")
print("Testing MIT control batch...")
current_pos = states["joint_3"]["position"]
commands = {"joint_3": (10.0, 0.5, current_pos, 0.0, 0.0)}
bus._mit_control_batch(commands)
print("✓ MIT control batch sent")
print("Disabling torque...")
bus.disable_torque()
print("✓ Torque disabled")
print("Setting zero position...")
bus.set_zero_position()
print("✓ Zero position set")
finally:
print("Disconnecting...")
bus.disconnect(disable_torque=True)
print("✓ Disconnected")
if __name__ == "__main__":
test_damiao_motor()
@@ -1,50 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for training-time RTC helpers."""
import torch
from lerobot.configs.types import RTCTrainingDelayDistribution
from lerobot.policies.rtc.configuration_rtc import RTCTrainingConfig
from lerobot.policies.rtc.training_time import apply_rtc_training_time, sample_rtc_delay
def test_rtc_training_config_defaults():
config = RTCTrainingConfig()
assert config.enabled is False
assert config.min_delay == 0
assert config.max_delay == 0
assert config.delay_distribution == RTCTrainingDelayDistribution.UNIFORM
assert config.exp_decay == 1.0
def test_sample_rtc_delay_uniform_range():
cfg = RTCTrainingConfig(enabled=True, min_delay=1, max_delay=4)
delays = sample_rtc_delay(cfg, batch_size=100, device=torch.device("cpu"))
assert delays.min().item() >= 1
assert delays.max().item() <= 4
def test_apply_rtc_training_time_prefix_mask():
time = torch.tensor([0.5])
delays = torch.tensor([2])
time_tokens, postfix_mask = apply_rtc_training_time(time, delays, seq_len=4)
assert time_tokens.shape == (1, 4)
assert postfix_mask.shape == (1, 4)
# Delay=2 means the first two steps are prefix (time forced to 0.0) and only the last two are postfix.
assert torch.allclose(time_tokens[0], torch.tensor([0.0, 0.0, 0.5, 0.5]))
assert torch.equal(postfix_mask[0], torch.tensor([False, False, True, True]))
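The last test above pins down the prefix/postfix split: steps before the delay get time 0.0 and are excluded from the postfix mask. A list-based sketch of that behaviour, written without torch, is given below. It is inferred solely from the assertions in the test; `rtc_time_tokens` is a hypothetical name and not the actual `apply_rtc_training_time` function.

```python
def rtc_time_tokens(times, delays, seq_len):
    """Build per-step time tokens and a postfix mask for each batch item.

    times and delays are plain Python lists with one entry per batch item
    (an assumption of this sketch; the real helper takes tensors).
    """
    tokens, masks = [], []
    for t, delay in zip(times, delays):
        # Steps at or after the delay belong to the postfix.
        row_mask = [step >= delay for step in range(seq_len)]
        # Prefix steps are forced to time 0.0; postfix steps keep the sampled time.
        tokens.append([t if in_postfix else 0.0 for in_postfix in row_mask])
        masks.append(row_mask)
    return tokens, masks


time_tokens, postfix_mask = rtc_time_tokens([0.5], [2], seq_len=4)
```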
+5 -5
@@ -64,7 +64,7 @@ def close_service_stub(channel, server):
server.stop(None)
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_establish_learner_connection_success():
from lerobot.rl.actor import establish_learner_connection
@@ -81,7 +81,7 @@ def test_establish_learner_connection_success():
close_service_stub(channel, server)
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_establish_learner_connection_failure():
from lerobot.rl.actor import establish_learner_connection
@@ -100,7 +100,7 @@ def test_establish_learner_connection_failure():
close_service_stub(channel, server)
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_push_transitions_to_transport_queue():
from lerobot.rl.actor import push_transitions_to_transport_queue
from lerobot.transport.utils import bytes_to_transitions
@@ -135,7 +135,7 @@ def test_push_transitions_to_transport_queue():
assert_transitions_equal(deserialized_transition, transitions[i])
@require_package("grpcio", "grpc")
@require_package("grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_transitions_stream():
from lerobot.rl.actor import transitions_stream
@@ -167,7 +167,7 @@ def test_transitions_stream():
assert streamed_data[2].data == b"transition_data_3"
@require_package("grpcio", "grpc")
@require_package("grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_interactions_stream():
from lerobot.rl.actor import interactions_stream
+3 -3
@@ -88,7 +88,7 @@ def cfg():
return cfg
@require_package("grpcio", "grpc")
@require_package("grpc")
@pytest.mark.timeout(10) # force cross-platform watchdog
def test_end_to_end_transitions_flow(cfg):
from lerobot.rl.actor import (
@@ -150,7 +150,7 @@ def test_end_to_end_transitions_flow(cfg):
assert_transitions_equal(transition, input_transitions[i])
@require_package("grpcio", "grpc")
@require_package("grpc")
@pytest.mark.timeout(10)
def test_end_to_end_interactions_flow(cfg):
from lerobot.rl.actor import (
@@ -223,7 +223,7 @@ def test_end_to_end_interactions_flow(cfg):
assert received == expected
@require_package("grpcio", "grpc")
@require_package("grpc")
@pytest.mark.parametrize("data_size", ["small", "large"])
@pytest.mark.timeout(10)
def test_end_to_end_parameters_flow(cfg, data_size):
+8 -8
@@ -39,7 +39,7 @@ def learner_service_stub():
close_learner_service_stub(channel, server)
@require_package("grpcio", "grpc")
@require_package("grpc")
def create_learner_service_stub(
shutdown_event: Event,
parameters_queue: Queue,
@@ -75,7 +75,7 @@ def create_learner_service_stub(
return services_pb2_grpc.LearnerServiceStub(channel), channel, server
@require_package("grpcio", "grpc")
@require_package("grpc")
def close_learner_service_stub(channel, server):
channel.close()
server.stop(None)
@@ -91,7 +91,7 @@ def test_ready_method(learner_service_stub):
assert response == services_pb2.Empty()
@require_package("grpcio", "grpc")
@require_package("grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_send_interactions():
from lerobot.transport import services_pb2
@@ -135,7 +135,7 @@ def test_send_interactions():
assert interactions == [b"123", b"4", b"5", b"678"]
@require_package("grpcio", "grpc")
@require_package("grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_send_transitions():
from lerobot.transport import services_pb2
@@ -181,7 +181,7 @@ def test_send_transitions():
assert transitions == [b"transition_1transition_2transition_3", b"batch_1batch_2"]
@require_package("grpcio", "grpc")
@require_package("grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_send_transitions_empty_stream():
from lerobot.transport import services_pb2
@@ -209,7 +209,7 @@ def test_send_transitions_empty_stream():
assert transitions_queue.empty()
@require_package("grpcio", "grpc")
@require_package("grpc")
@pytest.mark.timeout(10) # force cross-platform watchdog
def test_stream_parameters():
import time
@@ -267,7 +267,7 @@ def test_stream_parameters():
assert time_diff == pytest.approx(seconds_between_pushes, abs=0.1)
@require_package("grpcio", "grpc")
@require_package("grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_stream_parameters_with_shutdown():
from lerobot.transport import services_pb2
@@ -319,7 +319,7 @@ def test_stream_parameters_with_shutdown():
assert received_params == [b"param_batch_1", b"stop"]
@require_package("grpcio", "grpc")
@require_package("grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_stream_parameters_waits_and_retries_on_empty_queue():
import threading
+31 -31
@@ -26,7 +26,7 @@ from lerobot.utils.transition import Transition
from tests.utils import require_cuda, require_package
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_bytes_buffer_size_empty_buffer():
from lerobot.transport.utils import bytes_buffer_size
@@ -37,7 +37,7 @@ def test_bytes_buffer_size_empty_buffer():
assert buffer.tell() == 0
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_bytes_buffer_size_small_buffer():
from lerobot.transport.utils import bytes_buffer_size
@@ -47,7 +47,7 @@ def test_bytes_buffer_size_small_buffer():
assert buffer.tell() == 0
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_bytes_buffer_size_large_buffer():
from lerobot.transport.utils import CHUNK_SIZE, bytes_buffer_size
@@ -58,7 +58,7 @@ def test_bytes_buffer_size_large_buffer():
assert buffer.tell() == 0
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_send_bytes_in_chunks_empty_data():
from lerobot.transport.utils import send_bytes_in_chunks, services_pb2
@@ -68,7 +68,7 @@ def test_send_bytes_in_chunks_empty_data():
assert len(chunks) == 0
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_single_chunk_small_data():
from lerobot.transport.utils import send_bytes_in_chunks, services_pb2
@@ -82,7 +82,7 @@ def test_single_chunk_small_data():
assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_END
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_not_silent_mode():
from lerobot.transport.utils import send_bytes_in_chunks, services_pb2
@@ -94,7 +94,7 @@ def test_not_silent_mode():
assert chunks[0].data == b"Some data"
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_send_bytes_in_chunks_large_data():
from lerobot.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2
@@ -111,7 +111,7 @@ def test_send_bytes_in_chunks_large_data():
assert chunks[2].transfer_state == services_pb2.TransferState.TRANSFER_END
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_send_bytes_in_chunks_large_data_with_exact_chunk_size():
from lerobot.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2
@@ -124,7 +124,7 @@ def test_send_bytes_in_chunks_large_data_with_exact_chunk_size():
assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_END
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_receive_bytes_in_chunks_empty_data():
from lerobot.transport.utils import receive_bytes_in_chunks
@@ -138,7 +138,7 @@ def test_receive_bytes_in_chunks_empty_data():
assert queue.empty()
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_receive_bytes_in_chunks_single_chunk():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -157,7 +157,7 @@ def test_receive_bytes_in_chunks_single_chunk():
assert queue.empty()
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_receive_bytes_in_chunks_single_not_end_chunk():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -175,7 +175,7 @@ def test_receive_bytes_in_chunks_single_not_end_chunk():
assert queue.empty()
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_receive_bytes_in_chunks_multiple_chunks():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -199,7 +199,7 @@ def test_receive_bytes_in_chunks_multiple_chunks():
assert queue.empty()
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_receive_bytes_in_chunks_multiple_messages():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -235,7 +235,7 @@ def test_receive_bytes_in_chunks_multiple_messages():
assert queue.empty()
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_receive_bytes_in_chunks_shutdown_during_receive():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -259,7 +259,7 @@ def test_receive_bytes_in_chunks_shutdown_during_receive():
assert queue.empty()
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_receive_bytes_in_chunks_only_begin_chunk():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -279,7 +279,7 @@ def test_receive_bytes_in_chunks_only_begin_chunk():
assert queue.empty()
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_receive_bytes_in_chunks_missing_begin():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -303,7 +303,7 @@ def test_receive_bytes_in_chunks_missing_begin():
# Tests for state_to_bytes and bytes_to_state_dict
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_state_to_bytes_empty_dict():
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
@@ -314,7 +314,7 @@ def test_state_to_bytes_empty_dict():
assert reconstructed == state_dict
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_bytes_to_state_dict_empty_data():
from lerobot.transport.utils import bytes_to_state_dict
@@ -323,7 +323,7 @@ def test_bytes_to_state_dict_empty_data():
bytes_to_state_dict(b"")
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_state_to_bytes_simple_dict():
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
@@ -347,7 +347,7 @@ def test_state_to_bytes_simple_dict():
assert torch.allclose(state_dict[key], reconstructed[key])
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_state_to_bytes_various_dtypes():
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
@@ -372,7 +372,7 @@ def test_state_to_bytes_various_dtypes():
assert torch.allclose(state_dict[key], reconstructed[key])
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_bytes_to_state_dict_invalid_data():
from lerobot.transport.utils import bytes_to_state_dict
@@ -382,7 +382,7 @@ def test_bytes_to_state_dict_invalid_data():
@require_cuda
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_state_to_bytes_various_dtypes_cuda():
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
@@ -407,7 +407,7 @@ def test_state_to_bytes_various_dtypes_cuda():
assert torch.allclose(state_dict[key], reconstructed[key])
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_python_object_to_bytes_none():
from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes
@@ -439,7 +439,7 @@ def test_python_object_to_bytes_none():
(1, 2, 3),
],
)
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_python_object_to_bytes_simple_types(obj):
from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes
@@ -450,7 +450,7 @@ def test_python_object_to_bytes_simple_types(obj):
assert type(reconstructed) is type(obj)
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_python_object_to_bytes_with_tensors():
from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes
@@ -475,7 +475,7 @@ def test_python_object_to_bytes_with_tensors():
assert torch.equal(obj["nested"]["tensor2"], reconstructed["nested"]["tensor2"])
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_transitions_to_bytes_empty_list():
from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes
@@ -487,7 +487,7 @@ def test_transitions_to_bytes_empty_list():
assert isinstance(reconstructed, list)
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_transitions_to_bytes_single_transition():
from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes
@@ -509,7 +509,7 @@ def test_transitions_to_bytes_single_transition():
assert_transitions_equal(transitions[0], reconstructed[0])
@require_package("grpcio", "grpc")
@require_package("grpc")
def assert_transitions_equal(t1: Transition, t2: Transition):
"""Helper to assert two transitions are equal."""
assert_observation_equal(t1["state"], t2["state"])
@@ -519,7 +519,7 @@ def assert_transitions_equal(t1: Transition, t2: Transition):
assert_observation_equal(t1["next_state"], t2["next_state"])
@require_package("grpcio", "grpc")
@require_package("grpc")
def assert_observation_equal(o1: dict, o2: dict):
"""Helper to assert two observations are equal."""
assert set(o1.keys()) == set(o2.keys())
@@ -527,7 +527,7 @@ def assert_observation_equal(o1: dict, o2: dict):
assert torch.allclose(o1[key], o2[key])
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_transitions_to_bytes_multiple_transitions():
from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes
@@ -551,7 +551,7 @@ def test_transitions_to_bytes_multiple_transitions():
assert_transitions_equal(original, reconstructed_item)
@require_package("grpcio", "grpc")
@require_package("grpc")
def test_receive_bytes_in_chunks_unknown_state():
from lerobot.transport.utils import receive_bytes_in_chunks
+2 -2
@@ -167,7 +167,7 @@ def require_package_arg(func):
return wrapper
def require_package(package_name, import_name=None):
def require_package(package_name):
"""
Decorator that skips the test if the specified package is not installed.
"""
@@ -175,7 +175,7 @@ def require_package(package_name, import_name=None):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
if not is_package_available(pkg_name=package_name, import_name=import_name):
if not is_package_available(package_name):
pytest.skip(f"{package_name} not installed")
return func(*args, **kwargs)
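The two-argument form `require_package("grpcio", "grpc")` exists because the pip distribution is named `grpcio` while the importable module is named `grpc`. The sketch below shows one way an availability check could honour both names. `package_importable` is a hypothetical helper built on the standard library and is not lerobot's `is_package_available`.

```python
import importlib.util


def package_importable(package_name, import_name=None):
    """Return True when the module backing a package can be located.

    package_name is the distribution name used in skip messages;
    import_name overrides it when the module is named differently.
    """
    module_name = import_name or package_name
    try:
        return importlib.util.find_spec(module_name) is not None
    except (ImportError, ValueError):
        # find_spec can raise for dotted names with a missing parent
        # or for modules without a usable spec.
        return False


# The distribution name alone may not be importable, the module name is.
stdlib_found = package_importable("json")
renamed_found = package_importable("some-distribution-name", import_name="json")
```

A check that looks up only the single name it is given will skip a test when called with `"grpcio"` even if gRPC is installed, which is one reason a call site might pass the import name `"grpc"` directly instead.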