Compare commits

..

4 Commits

Author SHA1 Message Date
Jade Choghari 18bba97cd6 change default to 256 latent video 2026-01-25 19:48:54 +01:00
Jade Choghari 9c14524470 add video backbone to pi05 2026-01-25 19:38:39 +01:00
Jade Choghari 5ab3dfd762 add videoprism example 2026-01-25 15:51:50 +01:00
Jade Choghari bbe9407ead add videoprism 2026-01-25 15:51:21 +01:00
181 changed files with 6130 additions and 11809 deletions
+3 -5
View File
@@ -101,11 +101,9 @@ jobs:
runs-on:
group: aws-general-8-plus
if: |
github.repository == 'huggingface/lerobot' && (
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) ||
github.event_name == 'push' ||
github.event_name == 'workflow_dispatch'
)
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) ||
github.event_name == 'push' ||
github.event_name == 'workflow_dispatch'
outputs:
image_tag: ${{ steps.set_tag.outputs.image_tag }}
env:
-1
View File
@@ -91,7 +91,6 @@ jobs:
name: Build and Push Docker
runs-on:
group: aws-general-8-plus
if: github.repository == 'huggingface/lerobot'
outputs:
image_tag: ${{ env.DOCKER_IMAGE_NAME }}
env:
+42 -42
View File
@@ -28,9 +28,9 @@ We don't expect the same optimal settings for a dataset of images from a simulat
For these reasons, we run this benchmark on four representative datasets:
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
- `lerobot/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
- `lerobot/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
- `lerobot/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
- `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
- `aliberts/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
- `aliberts/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
Note: The datasets used for this benchmark need to be image datasets, not video datasets.
@@ -179,7 +179,7 @@ python benchmark/video/run_video_benchmark.py \
--output-dir outputs/video_benchmark \
--repo-ids \
lerobot/pusht_image \
lerobot/aloha_mobile_shrimp_image \
aliberts/aloha_mobile_shrimp_image \
--vcodec libx264 libx265 \
--pix-fmt yuv444p yuv420p \
--g 2 20 None \
@@ -203,9 +203,9 @@ python benchmark/video/run_video_benchmark.py \
--output-dir outputs/video_benchmark \
--repo-ids \
lerobot/pusht_image \
lerobot/aloha_mobile_shrimp_image \
lerobot/paris_street \
lerobot/kitchen \
aliberts/aloha_mobile_shrimp_image \
aliberts/paris_street \
aliberts/kitchen \
--vcodec libx264 libx265 \
--pix-fmt yuv444p yuv420p \
--g 1 2 3 4 5 6 10 15 20 40 None \
@@ -221,9 +221,9 @@ python benchmark/video/run_video_benchmark.py \
--output-dir outputs/video_benchmark \
--repo-ids \
lerobot/pusht_image \
lerobot/aloha_mobile_shrimp_image \
lerobot/paris_street \
lerobot/kitchen \
aliberts/aloha_mobile_shrimp_image \
aliberts/paris_street \
aliberts/kitchen \
--vcodec libsvtav1 \
--pix-fmt yuv420p \
--g 1 2 3 4 5 6 10 15 20 40 None \
@@ -252,37 +252,37 @@ Since we're using av1 encoding, we're choosing the `pyav` decoder as `video_read
These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_frames` and `backend=pyav`
| video_images_size_ratio | vcodec | pix_fmt | | | |
| --------------------------------- | ---------- | ------- | --------- | --------- | --------- |
| | libx264 | | libx265 | | libsvtav1 |
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
| lerobot/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
| lerobot/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
| lerobot/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
| video_images_size_ratio | vcodec | pix_fmt | | | |
| ---------------------------------- | ---------- | ------- | --------- | --------- | --------- |
| | libx264 | | libx265 | | libsvtav1 |
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
| aliberts/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
| aliberts/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
| aliberts/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
| --------------------------------- | ------- | ------- | -------- | ------- | --------- |
| | libx264 | | libx265 | | libsvtav1 |
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
| lerobot/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
| lerobot/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
| lerobot/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
| ---------------------------------- | ------- | ------- | -------- | ------- | --------- |
| | libx264 | | libx265 | | libsvtav1 |
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
| aliberts/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
| aliberts/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
| aliberts/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
| | | vcodec | pix_fmt | | | |
| --------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
| | | libx264 | | libx265 | | libsvtav1 |
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
| lerobot/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
| lerobot/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
| lerobot/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
| | | vcodec | pix_fmt | | | |
| ---------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
| | | libx264 | | libx265 | | libsvtav1 |
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
| aliberts/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
| aliberts/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
| aliberts/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
+2 -12
View File
@@ -7,6 +7,8 @@
- sections:
- local: il_robots
title: Imitation Learning for Robots
- local: cameras
title: Cameras
- local: bring_your_own_policies
title: Bring Your Own Policies
- local: integrate_hardware
@@ -27,8 +29,6 @@
title: Porting Large Datasets
- local: using_dataset_tools
title: Using the Dataset Tools
- local: dataset_subtask
title: Using Subtasks in the Dataset
title: "Datasets"
- sections:
- local: act
@@ -99,19 +99,11 @@
title: Unitree G1
- local: earthrover_mini_plus
title: Earth Rover Mini
- local: omx
title: OMX
- local: openarm
title: OpenArm
title: "Robots"
- sections:
- local: phone_teleop
title: Phone
title: "Teleoperators"
- sections:
- local: cameras
title: Cameras
title: "Sensors"
- sections:
- local: torch_accelerators
title: PyTorch accelerators
@@ -121,8 +113,6 @@
title: Notebooks
- local: feetech
title: Updating Feetech Firmware
- local: damiao
title: Damiao Motors and CAN Bus
title: "Resources"
- sections:
- local: contributing
+81 -95
View File
@@ -1,22 +1,12 @@
# Cameras
LeRobot offers multiple options for video capture:
LeRobot offers multiple options for video capture, including phone cameras, built-in laptop cameras, external webcams, and Intel RealSense cameras. To efficiently record frames from most cameras, you can use either the `OpenCVCamera` or `RealSenseCamera` class. For additional compatibility details on the `OpenCVCamera` class, refer to the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).
| Class | Supported Cameras |
| ----------------- | ----------------------------------- |
| `OpenCVCamera` | Phone, built-in laptop, USB webcams |
| `ZMQCamera` | Network-connected cameras |
| `RealSenseCamera` | Intel RealSense (with depth) |
| `Reachy2Camera` | Reachy 2 robot cameras |
### Finding your camera
> [!TIP]
> For `OpenCVCamera` compatibility details, see the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).
To instantiate a camera, you need a camera identifier. This identifier might change if you reboot your computer or re-plug your camera, a behavior mostly dependant on your operating system.
### Find your camera
Every camera requires a unique identifier to be instantiated, allowing you to distinguish between multiple connected devices.
`OpenCVCamera` and `RealSenseCamera` support auto-discovery. Run the command below to list available devices and their identifiers. Note that these identifiers may change after rebooting your computer or re-plugging the camera, depending on your operating system.
To find the camera indices of the cameras plugged into your system, run the following script:
```bash
lerobot-find-cameras opencv # or realsense for Intel Realsense cameras
@@ -24,7 +14,7 @@ lerobot-find-cameras opencv # or realsense for Intel Realsense cameras
The output will look something like this if you have two cameras connected:
```bash
```
--- Detected Cameras ---
Camera #0:
Name: OpenCV Camera @ 0
@@ -43,37 +33,13 @@ Camera #0:
> [!WARNING]
> When using Intel RealSense cameras in `macOS`, you could get this [error](https://github.com/IntelRealSense/librealsense/issues/12307): `Error finding RealSense cameras: failed to set power state`, this can be solved by running the same command with `sudo` permissions. Note that using RealSense cameras in `macOS` is unstable.
`ZMQCamera` and `Reachy2Camera` do not support auto-discovery. They must be configured manually by providing their network address and port or robot SDK settings.
## Use Cameras
## Use cameras
Below are two examples, demonstrating how to work with the API.
### Frame access modes
All camera classes implement three access modes for capturing frames:
| Method | Behavior | Blocks? | Best For |
| ------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------- | ---------------------------------------- |
| `read()` | Waits for the camera hardware to return a frame. May block for a long time depending on the camera and SDK. | Yes | Simple scripts, sequential capture |
| `async_read(timeout_ms)` | Returns the latest unconsumed frame from background thread. Blocks only if buffer is empty, up to `timeout_ms`. Raises `TimeoutError` if no frame arrives. | With a timeout | Control loops synchronized to camera FPS |
| `read_latest(max_age_ms)` | Peeks at the most recent frame in buffer (may be stale). Raises `TimeoutError` if frame is older than `max_age_ms`. | No | UI visualization, logging, monitoring |
### Usage examples
The following examples show how to use the camera API to configure and capture frames from different camera types.
- **Blocking and non-blocking frame capture** using an OpenCV-based camera
- **Asynchronous frame capture** using an OpenCV-based camera
- **Color and depth capture** using an Intel RealSense camera
> [!WARNING]
> Failing to cleanly disconnect cameras can cause resource leaks. Use the context manager protocol to ensure automatic cleanup:
>
> ```python
> with OpenCVCamera(config) as camera:
> ...
> ```
>
> You can also call `connect()` and `disconnect()` manually, but always use a `finally` block for the latter.
<hfoptions id="shell_restart">
<hfoption id="Open CV Camera">
@@ -94,30 +60,16 @@ config = OpenCVCameraConfig(
)
# Instantiate and connect an `OpenCVCamera`, performing a warm-up read (default).
with OpenCVCamera(config) as camera:
# Read a frame synchronously — blocks until hardware delivers a new frame
frame = camera.read()
print(f"read() call returned frame with shape:", frame.shape)
# Read a frame asynchronously with a timeout — returns the latest unconsumed frame or waits up to timeout_ms for a new one
try:
for i in range(10):
frame = camera.async_read(timeout_ms=200)
print(f"async_read call returned frame {i} with shape:", frame.shape)
except TimeoutError as e:
print(f"No frame received within timeout: {e}")
# Instantly return a frame - returns the most recent frame captured by the camera
try:
initial_frame = camera.read_latest(max_age_ms=1000)
for i in range(10):
frame = camera.read_latest(max_age_ms=1000)
print(f"read_latest call returned frame {i} with shape:", frame.shape)
print(f"Was a new frame received by the camera? {not (initial_frame == frame).any()}")
except TimeoutError as e:
print(f"Frame too old: {e}")
camera = OpenCVCamera(config)
camera.connect()
# Read frames asynchronously in a loop via `async_read(timeout_ms)`
try:
for i in range(10):
frame = camera.async_read(timeout_ms=200)
print(f"Async frame {i} shape:", frame.shape)
finally:
camera.disconnect()
```
<!-- prettier-ignore-end -->
@@ -159,10 +111,10 @@ finally:
</hfoption>
</hfoptions>
## Use your phone's camera
## Use your phone
<hfoptions id="use phone">
<hfoption id="iPhone & macOS">
<hfoption id="Mac">
To use your iPhone as a camera on macOS, enable the Continuity Camera feature:
@@ -172,49 +124,83 @@ To use your iPhone as a camera on macOS, enable the Continuity Camera feature:
For more details, visit [Apple support](https://support.apple.com/en-gb/guide/mac-help/mchl77879b8a/mac).
Your iPhone should be detected automatically when running the camera setup script in the next section.
</hfoption>
<hfoption id="OBS virtual camera">
<hfoption id="Linux">
If you want to use your phone as a camera using OBS, follow these steps to set up a virtual camera.
If you want to use your phone as a camera on Linux, follow these steps to set up a virtual camera
1. _(Linux only) Install `v4l2loopback-dkms` and `v4l-utils`_. These packages create virtual camera devices and verify their settings. Install with:
1. _Install `v4l2loopback-dkms` and `v4l-utils`_. Those packages are required to create virtual camera devices (`v4l2loopback`) and verify their settings with the `v4l2-ctl` utility from `v4l-utils`. Install them using:
```bash
<!-- prettier-ignore-start -->
```python
sudo apt install v4l2loopback-dkms v4l-utils
```
<!-- prettier-ignore-end -->
2. _Install the [DroidCam app](https://droidcam.app) on your phone_. This app is available for both iOS and Android.
3. _Download and install [OBS Studio](https://obsproject.com)_.
4. _Download and install the [DroidCam OBS plugin](https://droidcam.app/obs)_.
5. _Start OBS Studio_.
2. _Install [DroidCam](https://droidcam.app) on your phone_. This app is available for both iOS and Android.
3. _Install [OBS Studio](https://obsproject.com)_. This software will help you manage the camera feed. Install it using [Flatpak](https://flatpak.org):
6. _Add your phone as a source_. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480` to avoid the watermarks.
7. _Adjust resolution settings_. In OBS Studio, go to `File > Settings > Video` or `OBS > Preferences... > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it.
<!-- prettier-ignore-start -->
```python
flatpak install flathub com.obsproject.Studio
```
<!-- prettier-ignore-end -->
4. _Install the DroidCam OBS plugin_. This plugin integrates DroidCam with OBS Studio. Install it with:
<!-- prettier-ignore-start -->
```python
flatpak install flathub com.obsproject.Studio.Plugin.DroidCam
```
<!-- prettier-ignore-end -->
5. _Start OBS Studio_. Launch with:
<!-- prettier-ignore-start -->
```python
flatpak run com.obsproject.Studio
```
<!-- prettier-ignore-end -->
6. _Add your phone as a source_. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480`.
7. _Adjust resolution settings_. In OBS Studio, go to `File > Settings > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it in.
8. _Start virtual camera_. In OBS Studio, follow the instructions [here](https://obsproject.com/kb/virtual-camera-guide).
9. _Verify the virtual camera setup and resolution_.
- **Linux**: Use `v4l2-ctl` to list devices and check resolution:
```bash
v4l2-ctl --list-devices # find VirtualCam and note its /dev/videoX path
v4l2-ctl -d /dev/videoX --get-fmt-video # replace with your VirtualCam path
```
You should see `VirtualCam` listed and resolution `640x480`.
- **macOS**: Open Photo Booth or FaceTime and select "OBS Virtual Camera" as the input.
- **Windows**: The native Camera app doesn't support virtual cameras. Use a video conferencing app (Zoom, Teams) or run `lerobot-find-cameras opencv` directly to verify.
9. _Verify the virtual camera setup_. Use `v4l2-ctl` to list the devices:
<details>
<summary><strong>Troubleshooting</strong></summary>
<!-- prettier-ignore-start -->
```python
v4l2-ctl --list-devices
```
<!-- prettier-ignore-end -->
> The virtual camera resolution is incorrect.
You should see an entry like:
Delete the virtual camera source and recreate it. The resolution cannot be changed after creation.
```
VirtualCam (platform:v4l2loopback-000):
/dev/video1
```
> Error reading frame in background thread for OpenCVCamera(X): OpenCVCamera(X) frame width=640 or height=480 do not match configured width=1920 or height=1080.
10. _Check the camera resolution_. Use `v4l2-ctl` to ensure that the virtual camera output resolution is `640x480`. Change `/dev/video1` to the port of your virtual camera from the output of `v4l2-ctl --list-devices`.
This error is caused by OBS Virtual Camera advertising a `1920x1080` resolution despite rescaling. The only fix for now is to comment out the width and height check in `_postprocess_image()`.
<!-- prettier-ignore-start -->
```python
v4l2-ctl -d /dev/video1 --get-fmt-video
```
<!-- prettier-ignore-end -->
</details>
You should see an entry like:
```
>>> Format Video Capture:
>>> Width/Height : 640/480
>>> Pixel Format : 'YUYV' (YUYV 4:2:2)
```
Troubleshooting: If the resolution is not correct you will have to delete the Virtual Camera port and try again as it cannot be changed.
If everything is set up correctly, you can proceed with the rest of the tutorial.
</hfoption>
</hfoptions>
If everything is set up correctly, your phone will appear as a standard OpenCV camera and can be used with `OpenCVCamera`.
-165
View File
@@ -1,165 +0,0 @@
# Damiao Motors and CAN Bus
This guide covers setup and usage of Damiao motors with LeRobot via CAN bus communication.
Currently, only Linux is supported, as the OpenArms CAN adapter only has drivers for Linux.
## Linux CAN Setup
Before using Damiao motors, you need to set up the CAN interface on your Linux system.
### Install CAN Utilities
```bash
sudo apt-get install can-utils
```
### Configure CAN Interface (Manual)
For standard CAN FD (recommended for OpenArms):
```bash
sudo ip link set can0 down
sudo ip link set can0 type can bitrate 1000000 dbitrate 5000000 fd on
sudo ip link set can0 up
```
For standard CAN (without FD):
```bash
sudo ip link set can0 down
sudo ip link set can0 type can bitrate 1000000
sudo ip link set can0 up
```
### Configure CAN Interface (Using LeRobot)
LeRobot provides a utility script to setup and test CAN interfaces:
```bash
# Setup multiple interfaces (e.g., OpenArms Followers with 2 CAN buses)
lerobot-setup-can --mode=setup --interfaces=can0,can1
```
## Debugging CAN Communication
Use the built-in debug tools to test motor communication:
```bash
# Test motors on all interfaces
lerobot-setup-can --mode=test --interfaces=can0,can1
# Run speed/latency test
lerobot-setup-can --mode=speed --interfaces=can0
```
The test mode will scan for motors (IDs 0x01-0x08) and report which ones respond. Example output:
```
can0: UP (CAN FD)
Motor 0x01 (joint_1): ✓ FOUND
→ Response 0x11 [FD]: 00112233...
Motor 0x02 (joint_2): ✓ FOUND
Motor 0x03 (joint_3): ✗ No response
...
Summary: 2/8 motors found
```
## Usage
### Basic Setup
```python
from lerobot.motors import Motor
from lerobot.motors.damiao import DamiaoMotorsBus
# Define your motors with send/receive CAN IDs
motors = {
"joint_1": Motor(id=0x01, motor_type_str="dm8009", recv_id=0x11),
"joint_2": Motor(id=0x02, motor_type_str="dm4340", recv_id=0x12),
"joint_3": Motor(id=0x03, motor_type_str="dm4310", recv_id=0x13),
}
# Create the bus
bus = DamiaoMotorsBus(
port="can0", # Linux socketcan interface
motors=motors,
)
# Connect
bus.connect()
```
### Reading Motor States
```python
# Read single motor position (degrees)
position = bus.read("Present_Position", "joint_1")
# Read from multiple motors
positions = bus.sync_read("Present_Position") # All motors
positions = bus.sync_read("Present_Position", ["joint_1", "joint_2"])
# Read all states at once (position, velocity, torque)
states = bus.sync_read_all_states()
# Returns: {'joint_1': {'position': 45.2, 'velocity': 1.3, 'torque': 0.5}, ...}
```
### Writing Motor Commands
```python
# Enable torque
bus.enable_torque()
# Set goal position (degrees)
bus.write("Goal_Position", "joint_1", 45.0)
# Set positions for multiple motors
bus.sync_write("Goal_Position", {
"joint_1": 45.0,
"joint_2": -30.0,
"joint_3": 90.0,
})
# Disable torque
bus.disable_torque()
```
## Configuration Options
| Parameter | Default | Description |
| -------------- | --------- | ----------------------------------------------------------- |
| `port` | - | CAN interface (`can0`) or serial port (`/dev/cu.usbmodem*`) |
| `use_can_fd` | `True` | Enable CAN FD for higher data rates |
| `bitrate` | `1000000` | Nominal bitrate (1 Mbps) |
| `data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) |
## Motor Configuration
Each motor requires:
- `id`: CAN ID for sending commands
- `motor_type`: One of the supported motor types (e.g., `"dm8009"`, `"dm4340"`)
- `recv_id`: CAN ID for receiving responses
OpenArms default IDs follow the pattern: send ID `0x0N`, receive ID `0x1N` where N is the joint number.
## Troubleshooting
### No Response from Motors
1. **Check power**
2. **Verify CAN wiring**: Check CAN-H, CAN-L, and GND connections
3. **Check motor IDs**: Use Damiao Debugging Tools to verify/configure IDs
4. **Test CAN interface**: Run `candump can0` to see if messages are being received
5. **Run diagnostics**: `lerobot-setup-can --mode=test --interfaces=can0`
### Motor Timeout Parameter
If motors were configured with timeout=0, they won't respond to commands. Use Damiao Debugging Tools to set a non-zero timeout value.
### Verify CAN FD Status
```bash
ip -d link show can0 | grep fd
```
-278
View File
@@ -1,278 +0,0 @@
# Using Subtasks in LeRobot Datasets
Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for:
- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time
- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models)
- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps
LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks.
## What are Subtasks?
While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps:
1. "Approach the apple"
2. "Grasp the apple"
3. "Lift the apple"
4. "Move to basket"
5. "Release the apple"
Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages.
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/subtask-asset.png"
alt="An overview of subtask annotation showing how frames are labeled with intermediate subtask stages"
width="80%"
/>
<p>
<em>Figure: Overview of subtask annotation.</em>
</p>
**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022.
## Dataset Structure
Subtask information is stored in the dataset metadata:
```
my-dataset/
├── data/
│ └── ...
├── meta/
│ ├── info.json
│ ├── stats.json
│ ├── tasks.parquet
│ ├── subtasks.parquet # Subtask index → subtask string mapping
│ └── episodes/
│ └── ...
└── videos/
└── ...
```
### Subtasks Parquet File
The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions:
| subtask_index | subtask (index column) |
| ------------- | ---------------------- |
| 0 | "Approach the apple" |
| 1 | "Grasp the apple" |
| 2 | "Lift the apple" |
| ... | ... |
### Frame-Level Annotations
Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file:
```python
# Example frame data in the parquet file
{
"index": 42,
"timestamp": 1.4,
"episode_index": 0,
"task_index": 0,
"subtask_index": 2, # References "Lift the apple"
"observation.state": [...],
"action": [...],
}
```
## Annotating Datasets with Subtasks
We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks:
**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)**
After completing your annotation:
1. Click "Push to Hub" to upload your annotated dataset
2. You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate)
## Loading Datasets with Subtasks
When you load a dataset with subtask annotations, the subtask information is automatically available:
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Load a dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
# Access a sample
sample = dataset[100]
# The sample includes both task and subtask information
print(sample["task"]) # "Collect the fruit"
print(sample["subtask"]) # "Grasp the apple"
print(sample["task_index"]) # tensor(0)
print(sample["subtask_index"]) # tensor(2)
```
### Checking for Subtask Support
You can check if a dataset has subtask annotations:
```python
# Check if subtasks are available
has_subtasks = (
"subtask_index" in dataset.features
and dataset.meta.subtasks is not None
)
if has_subtasks:
print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks")
print("Subtasks:", list(dataset.meta.subtasks.index))
```
## Using Subtasks for Training
### With the Tokenizer Processor
The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models:
```python
from lerobot.processor.tokenizer_processor import TokenizerProcessor
from lerobot.processor.pipeline import ProcessorPipeline
# Create a tokenizer processor
tokenizer_processor = TokenizerProcessor(
tokenizer_name_or_path="google/paligemma-3b-pt-224",
padding="max_length",
max_length=64,
)
# The processor will automatically tokenize subtasks if present in the batch
# and add them to the observation under:
# - "observation.subtask.tokens"
# - "observation.subtask.attention_mask"
```
When subtasks are available in the batch, the tokenizer processor adds:
- `observation.subtask.tokens`: Tokenized subtask text
- `observation.subtask.attention_mask`: Attention mask for the subtask tokens
### DataLoader with Subtasks
```python
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=16,
shuffle=True,
)
for batch in dataloader:
# Access subtask information in the batch
subtasks = batch["subtask"] # List of subtask strings
subtask_indices = batch["subtask_index"] # Tensor of subtask indices
# Use for training hierarchical policies or reward models
print(f"Batch subtasks: {set(subtasks)}")
```
## Example Datasets with Subtask Annotations
Try loading a dataset with subtask annotations:
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Example dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
# Explore the subtasks
print("Available subtasks:")
for subtask_name in dataset.meta.subtasks.index:
print(f" - {subtask_name}")
# Get subtask distribution
subtask_counts = {}
for i in range(len(dataset)):
sample = dataset[i]
subtask = sample["subtask"]
subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1
print("\nSubtask distribution:")
for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]):
print(f" {subtask}: {count} frames")
```
## Use Cases
### 1. Hierarchical Policy Training
Train policies that predict both actions and current subtask:
```python
class HierarchicalPolicy(nn.Module):
def __init__(self, num_subtasks):
super().__init__()
self.action_head = nn.Linear(hidden_dim, action_dim)
self.subtask_head = nn.Linear(hidden_dim, num_subtasks)
def forward(self, observations):
features = self.encoder(observations)
actions = self.action_head(features)
subtask_logits = self.subtask_head(features)
return actions, subtask_logits
```
### 2. Stage-Aware Reward Modeling (SARM)
Build reward models that understand task progression:
```python
# SARM predicts:
# - Stage: Which subtask is being executed (discrete)
# - Progress: How far along the subtask (continuous 0-1)
class SARMRewardModel(nn.Module):
def forward(self, observations):
features = self.encoder(observations)
stage_logits = self.stage_classifier(features)
progress = self.progress_regressor(features)
return stage_logits, progress
```
### 3. Progress Visualization
Monitor robot execution by tracking subtask progression:
```python
def visualize_execution(model, observations):
for t, obs in enumerate(observations):
action, subtask_logits = model(obs)
predicted_subtask = subtask_names[subtask_logits.argmax()]
print(f"t={t}: Executing '{predicted_subtask}'")
```
## API Reference
### LeRobotDataset Properties
| Property | Type | Description |
| --------------------------- | ---------------------- | ------------------------------------------ |
| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices |
| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present |
### Sample Keys
When subtasks are available, each sample includes:
| Key | Type | Description |
| --------------- | -------------- | ------------------------------------ |
| `subtask_index` | `torch.Tensor` | Integer index of the current subtask |
| `subtask` | `str` | Natural language subtask description |
## Related Resources
- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation
- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool
- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation
+1 -7
View File
@@ -1,11 +1,5 @@
# EarthRover Mini Plus
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Earth_Rover_Mini_5_240c9adc-4f9e-44b7-982f-5d1dc24af1d8.png.webp"
alt="EarthRover Mini Plus"
width="70%"
/>
The EarthRover Mini Plus is a fully open source mobile robot that connects through the cloud using the Frodobots SDK. This lets you control the robot and record datasets for training AI models.
## What You Need
@@ -185,7 +179,7 @@ echo $HF_USER
Use the standard recording command:
```bash
lerobot-record \
python src/lerobot/scripts/lerobot_record.py \
--robot.type=earthrover_mini_plus \
--teleop.type=keyboard_rover \
--dataset.repo_id=your_username/dataset_name \
+5 -5
View File
@@ -224,7 +224,7 @@ lerobot-record \
--teleop.port=/dev/tty.usbmodem1201 \
--teleop.id=right \
--teleop.side=right \
--dataset.repo_id=<USER>/hand_record_test_with_video_data \
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
--dataset.single_task="Hand recording test with video data" \
--dataset.num_episodes=1 \
--dataset.episode_time_s=5 \
@@ -241,7 +241,7 @@ lerobot-replay \
--robot.port=/dev/tty.usbmodem58760432281 \
--robot.id=right \
--robot.side=right \
--dataset.repo_id=<USER>/hand_record_test_with_camera \
--dataset.repo_id=nepyope/hand_record_test_with_camera \
--dataset.episode=0
```
@@ -249,13 +249,13 @@ lerobot-replay \
```bash
lerobot-train \
--dataset.repo_id=<USER>/hand_record_test_with_video_data \
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
--policy.type=act \
--output_dir=outputs/train/hopejr_hand \
--job_name=hopejr \
--policy.device=mps \
--wandb.enable=true \
--policy.repo_id=<USER>/hand_test_policy
--policy.repo_id=nepyope/hand_test_policy
```
### Evaluate
@@ -270,7 +270,7 @@ lerobot-record \
--robot.side=right \
--robot.cameras='{"main": {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}}' \
--display_data=false \
--dataset.repo_id=<USER>/eval_hopejr \
--dataset.repo_id=nepyope/eval_hopejr \
--dataset.single_task="Evaluate hopejr hand policy" \
--dataset.num_episodes=10 \
--policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model
+3 -5
View File
@@ -1,15 +1,13 @@
# Installation
This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.10 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-).
## Step 1: Install [`miniforge`](https://conda-forge.org/download/)
## Install [`miniforge`](https://conda-forge.org/download/)
```bash
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
bash Miniforge3-$(uname)-$(uname -m).sh
```
## Step 2: Environment Setup
## Environment Setup
Create a virtual environment with Python 3.10, using conda:
@@ -40,7 +38,7 @@ conda install ffmpeg -c conda-forge
>
> - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
## Step 3: Install LeRobot 🤗
## Install LeRobot 🤗
### From Source
-6
View File
@@ -1,11 +1,5 @@
# LeKiwi
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/1740517739083.jpeg"
alt="LeKiwi"
width="70%"
/>
In the steps below, we explain how to assemble the LeKiwi mobile robot.
## Source the parts
-1
View File
@@ -42,7 +42,6 @@ lerobot-eval \
```
- `--env.task` picks the suite (`libero_object`, `libero_spatial`, etc.).
- `--env.task_ids` picks task ids to run (`[0]`, `[1,2,3]`, etc.). Omit this flag (or set it to `null`) to run all tasks in the suite.
- `--eval.batch_size` controls how many environments run in parallel.
- `--eval.n_episodes` sets how many episodes to run in total.
-197
View File
@@ -1,197 +0,0 @@
## Order and Assemble the parts
First, assemble the OMX hardware following the official assembly guide.
OMX Assembly Guide: https://ai.robotis.com/omx/assembly_guide_omx.html
OMX robots are shipped preconfigured from the factory. Motor IDs, communication parameters, and joint offsets are already set, so no additional motor setup or calibration is required before using LeRobot.
## Install LeRobot 🤗
To install LeRobot, follow our [Installation Guide](./installation)
In addition to these instructions, you need to install the Dynamixel SDK:
```bash
pip install -e ".[dynamixel]"
```
## Connect the robot
To find the port for each bus servo adapter, run this script:
```bash
lerobot-find-port
```
This command runs and when prompted, disconnect the USB cable from either the leader or follower arm and press Enter. The output will show 'The port of this MotorsBus is [port]'. This identifies the port for the disconnected arm. Repeat for the other arm to identify both ports.
<hfoptions id="find_port">
<hfoption id="Mac">
Example output on macOS:
```
Finding all available ports for the MotorBus.
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
Remove the USB cable from your MotorsBus and press Enter when done.
[...Disconnect corresponding leader or follower arm and press Enter...]
The port of this MotorsBus is /dev/tty.usbmodem575E0032081
Reconnect the USB cable.
```
Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your leader or follower arm.
</hfoption>
<hfoption id="Linux">
On Linux, we strongly recommend using udev rules to assign persistent and human-readable device names to the OMX leader and follower arms. This avoids issues where device names such as ttyACM0 and ttyACM1 change when the robot is unplugged, replugged, or when the system is rebooted.
#### 1. Find your device serial numbers
You should have obtained the port numbers like ../../ttyACM? for the leader and follower using `lerobot-find-port`. You can match those results with the serial numbers using the `ls -l /dev/serial/by-id/` command.
To create udev rules, you need the unique serial number for each OMX device. The easiest way is to list devices under:
```bash
ls -l /dev/serial/by-id/
```
You will see output similar to:
```bash
usb-ROBOTIS_OpenRB-150_228BDD7B503059384C2E3120FF0A2B19-if00 -> ../../ttyACM0
usb-ROBOTIS_OpenRB-150_67E1ED68503059384C2E3120FF092234-if00 -> ../../ttyACM1
```
In each line, the serial number is the long string after `usb-ROBOTIS_OpenRB-150_` and before `-if00`.
Follower serial: `228BDD7B503059384C2E3120FF0A2B19`
Leader serial: `67E1ED68503059384C2E3120FF092234`
#### 2. Create the udev rule
Create a new udev rule file:
```bash
sudo nano /etc/udev/rules.d/99-omx.rules
```
Paste the following lines, replacing the serial numbers with the values you found above:
```bash
SUBSYSTEM=="tty", ATTRS{idVendor}=="0403", ATTRS{serial}=="228BDD7B503059384C2E3120FF0A2B19", SYMLINK+="omx_follower"
SUBSYSTEM=="tty", ATTRS{idVendor}=="0403", ATTRS{serial}=="67E1ED68503059384C2E3120FF092234", SYMLINK+="omx_leader"
```
Save the file and reload udev rules:
```bash
sudo udevadm control --reload-rules
sudo udevadm trigger
```
Now unplug and replug both devices once.
#### 3. Verify the symlinks
Check that the persistent device names exist:
```bash
ls -l /dev/omx_follower /dev/omx_leader
```
You should see them pointing to ttyACM\* devices:
```bash
/dev/omx_follower -> ttyACM*
/dev/omx_leader -> ttyACM*
```
These names remain stable across reboots and reconnections.
</hfoption>
</hfoptions>
## Teleoperate
After identifying the correct ports, you can directly teleoperate the follower arm using the leader arm.
<hfoptions id="teleoperate">
<hfoption id="Mac">
### Teleoperate without camera
```bash
lerobot-teleoperate \
--robot.type=omx_follower \
--robot.port=<your_follower_port> \
--robot.id=omx_follower_arm \
--teleop.type=omx_leader \
--teleop.port=<your_leader_port> \
--teleop.id=omx_leader_arm
```
During teleoperation, motions of the leader arm are mirrored in real time by the follower arm. OMX is already preconfigured, teleoperation can begin immediately without any calibration steps.
### Teleoperate with camera
You can also enable camera input during teleoperation by providing a camera configuration for the follower arm.
```bash
lerobot-teleoperate \
--robot.type=omx_follower \
--robot.port=<your_follower_port> \
--robot.id=omx_follower_arm \
--robot.cameras="{front: {type: opencv, index_or_path: '/dev/video0', width: 640, height: 480, fps: 30}}" \
--teleop.type=omx_leader \
--teleop.port=<your_leader_port> \
--teleop.id=omx_leader_arm \
--display_data=true
```
When the camera is enabled, the camera stream is displayed in real time and synchronized with the robot state. This setup is useful for visual monitoring and can be reused later for demonstration recording and imitation learning.
</hfoption>
<hfoption id="Linux">
### Teleoperate without camera
```bash
lerobot-teleoperate \
--robot.type=omx_follower \
--robot.port=/dev/omx_follower \
--robot.id=omx_follower_arm \
--teleop.type=omx_leader \
--teleop.port=/dev/omx_leader \
--teleop.id=omx_leader_arm
```
During teleoperation, motions of the leader arm are mirrored in real time by the follower arm. OMX is already preconfigured, teleoperation can begin immediately without any calibration steps.
### Teleoperate with camera
You can also enable camera input during teleoperation by providing a camera configuration for the follower arm.
```bash
lerobot-teleoperate \
--robot.type=omx_follower \
--robot.port=/dev/omx_follower \
--robot.id=omx_follower_arm \
--robot.cameras="{front: {type: opencv, index_or_path: '/dev/video0', width: 640, height: 480, fps: 30}}" \
--teleop.type=omx_leader \
--teleop.port=/dev/omx_leader \
--teleop.id=omx_leader_arm \
--display_data=true
```
When the camera is enabled, the camera stream is displayed in real time and synchronized with the robot state. This setup is useful for visual monitoring and can be reused later for demonstration recording and imitation learning.
</hfoption>
</hfoptions>
Congrats 🎉, your robot is all set to learn a task on its own.
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/robotis).
-276
View File
@@ -1,276 +0,0 @@
# OpenArm
[OpenArm](https://openarm.dev) is an open-source 7DOF humanoid arm designed for physical AI research and deployment.
To get your OpenArm, assembled or DIY, and join the global community, browse verified and certified manufacturers worldwide at [openarm.dev](https://openarm.dev).
## What's Unique?
- **Human-Scale Design**: OpenArm is designed with human-like proportions, scaled for a person around 160-165cm tall. This provides an optimal balance between practical reach and manageable inertia for safe, responsive operation.
- **Safety-First Architecture**: Built with QDD backdrivable motors and high compliance, OpenArm prioritizes safe human-robot interaction while maintaining practical payload capabilities (6.0kg peak / 4.1kg nominal) for real-world tasks.
- **Built for Durability**: Critical structural components use aluminum and stainless steel construction, ensuring robust performance for repetitive data collection and continuous research use.
- **Fully Accessible & Buildable**: Every component, from CNC parts and 3D-printed casings to electrical wiring is designed to be purchasable and buildable by individual researchers and labs, with complete fabrication data provided.
- **Practical & Affordable**: At $6,500 USD for a complete bimanual system, OpenArm delivers research-grade capabilities at a fraction of traditional humanoid robot costs.
## Platform Requirements
<Tip warning={true}>
**Linux Only**: OpenArm currently only works on Linux. The CAN bus USB adapter
does not have macOS drivers and has not been tested on Windows.
</Tip>
## Safety Guide
Before operating OpenArm, please read the [official safety guide](https://docs.openarm.dev/getting-started/safety-guide). Key points:
- **Secure installation**: Fasten the arm to a flat, stable surface with screws or clamps
- **Safe distance**: Keep body parts and objects outside the range of motion during operation
- **Protective equipment**: Always wear safety goggles; use additional PPE as needed
- **Payload limits**: Do not exceed specified payload limits (6.0kg peak / 4.1kg nominal per arm)
- **Emergency stop**: Know the location and operation of the emergency stop device
- **Regular inspection**: Check for loose screws, damaged mechanical limits, unusual noises, and wiring damage
## Hardware Setup
Follow the official [OpenArm hardware documentation](https://docs.openarm.dev) for:
- Bill of materials and sourcing
- 3D printing instructions
- Mechanical assembly
- Electrical wiring
The hardware repositories are available at [github.com/enactic/openarm](https://github.com/enactic/openarm).
## CAN Bus Setup
OpenArm uses CAN bus communication with Damiao motors. Once you have the CAN bus USB adapter plugged into your Linux PC, follow the [Damiao Motors and CAN Bus guide](./damiao) to configure the interface.
Quick setup:
```bash
# Setup CAN interfaces
lerobot-setup-can --mode=setup --interfaces=can0,can1
# Test motor communication
lerobot-setup-can --mode=test --interfaces=can0,can1
```
## Install LeRobot 🤗
Follow our [Installation Guide](./installation), then install the Damiao motor support:
```bash
pip install -e ".[damiao]"
```
## Usage
### Follower Arm (Robot)
<hfoptions id="follower">
<hfoption id="Command">
```bash
lerobot-calibrate \
--robot.type=openarm_follower \
--robot.port=can0 \
--robot.side=right \
--robot.id=my_openarm_follower
```
</hfoption>
<hfoption id="API example">
```python
from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig
config = OpenArmFollowerConfig(
port="can0",
side="right", # or "left" for left arm
id="my_openarm_follower",
)
follower = OpenArmFollower(config)
follower.connect()
# Read current state
obs = follower.get_observation()
print(obs)
# Send action (position in degrees)
action = {
"joint_1.pos": 0.0,
"joint_2.pos": 0.0,
"joint_3.pos": 0.0,
"joint_4.pos": 45.0,
"joint_5.pos": 0.0,
"joint_6.pos": 0.0,
"joint_7.pos": 0.0,
"gripper.pos": 0.0,
}
follower.send_action(action)
follower.disconnect()
```
</hfoption>
</hfoptions>
### Leader Arm (Teleoperator)
The leader arm is used for teleoperation - manually moving it to control the follower arm.
<hfoptions id="leader">
<hfoption id="Command">
```bash
lerobot-calibrate \
--teleop.type=openarm_leader \
--teleop.port=can1 \
--teleop.id=my_openarm_leader
```
</hfoption>
<hfoption id="API example">
```python
from lerobot.teleoperators.openarm_leader import OpenArmLeader, OpenArmLeaderConfig
config = OpenArmLeaderConfig(
port="can1",
id="my_openarm_leader",
manual_control=True, # Disable torque for manual movement
)
leader = OpenArmLeader(config)
leader.connect()
# Read current position (as action to send to follower)
action = leader.get_action()
print(action)
leader.disconnect()
```
</hfoption>
</hfoptions>
### Teleoperation
To teleoperate OpenArm with leader-follower control:
```bash
lerobot-teleoperate \
--robot.type=openarm_follower \
--robot.port=can0 \
--robot.side=right \
--robot.id=my_follower \
--teleop.type=openarm_leader \
--teleop.port=can1 \
--teleop.id=my_leader
```
### Bimanual Teleoperation
To teleoperate a bimanual OpenArm setup with two leader and two follower arms:
```bash
lerobot-teleoperate \
--robot.type=bi_openarm_follower \
--robot.left_arm_config.port=can0 \
--robot.left_arm_config.side=left \
--robot.right_arm_config.port=can1 \
--robot.right_arm_config.side=right \
--robot.id=my_bimanual_follower \
--teleop.type=bi_openarm_leader \
--teleop.left_arm_config.port=can2 \
--teleop.right_arm_config.port=can3 \
--teleop.id=my_bimanual_leader
```
### Recording Data
To record a dataset during teleoperation:
```bash
lerobot-record \
--robot.type=openarm_follower \
--robot.port=can0 \
--robot.side=right \
--robot.id=my_follower \
--teleop.type=openarm_leader \
--teleop.port=can1 \
--teleop.id=my_leader \
--repo-id=my_hf_username/my_openarm_dataset \
--fps=30 \
--num-episodes=10
```
## Configuration Options
### Follower Configuration
| Parameter | Default | Description |
| --------------------- | --------- | ---------------------------------------------------------- |
| `port` | - | CAN interface (e.g., `can0`) |
| `side` | `None` | Arm side: `"left"`, `"right"`, or `None` for custom limits |
| `use_can_fd` | `True` | Enable CAN FD for higher data rates |
| `can_bitrate` | `1000000` | Nominal bitrate (1 Mbps) |
| `can_data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) |
| `max_relative_target` | `None` | Safety limit for relative target positions |
| `position_kp` | Per-joint | Position control proportional gains |
| `position_kd` | Per-joint | Position control derivative gains |
### Leader Configuration
| Parameter | Default | Description |
| ------------------ | --------- | ----------------------------------- |
| `port` | - | CAN interface (e.g., `can1`) |
| `manual_control` | `True` | Disable torque for manual movement |
| `use_can_fd` | `True` | Enable CAN FD for higher data rates |
| `can_bitrate` | `1000000` | Nominal bitrate (1 Mbps) |
| `can_data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) |
## Motor Configuration
OpenArm uses Damiao motors with the following default configuration:
| Joint | Motor Type | Send ID | Recv ID |
| --------------------------- | ---------- | ------- | ------- |
| joint_1 (Shoulder pan) | DM8009 | 0x01 | 0x11 |
| joint_2 (Shoulder lift) | DM8009 | 0x02 | 0x12 |
| joint_3 (Shoulder rotation) | DM4340 | 0x03 | 0x13 |
| joint_4 (Elbow flex) | DM4340 | 0x04 | 0x14 |
| joint_5 (Wrist roll) | DM4310 | 0x05 | 0x15 |
| joint_6 (Wrist pitch) | DM4310 | 0x06 | 0x16 |
| joint_7 (Wrist rotation) | DM4310 | 0x07 | 0x17 |
| gripper | DM4310 | 0x08 | 0x18 |
## Troubleshooting
### No Response from Motors
1. Check power supply connections
2. Verify CAN wiring (CAN-H, CAN-L, GND)
3. Run diagnostics: `lerobot-setup-can --mode=test --interfaces=can0`
4. See the [Damiao troubleshooting guide](./damiao#troubleshooting) for more details
### CAN Interface Not Found
Ensure the CAN interface is configured:
```bash
ip link show can0
```
## Resources
- [OpenArm Website](https://openarm.dev)
- [OpenArm Documentation](https://docs.openarm.dev)
- [OpenArm GitHub](https://github.com/enactic/openarm)
- [Safety Guide](https://docs.openarm.dev/getting-started/safety-guide)
- [Damiao Motors and CAN Bus](./damiao)
+1 -1
View File
@@ -60,7 +60,7 @@ policy.type=pi0
For training π₀, you can use the standard LeRobot training script with the appropriate configuration:
```bash
lerobot-train \
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your_dataset \
--policy.type=pi0 \
--output_dir=./outputs/pi0_training \
+1 -1
View File
@@ -56,7 +56,7 @@ policy.type=pi05
Here's a complete training command for finetuning the base π₀.₅ model on your own dataset:
```bash
lerobot-train \
python src/lerobot/scripts/lerobot_train.py\
--dataset.repo_id=your_dataset \
--policy.type=pi05 \
--output_dir=./outputs/pi05_training \
+4 -4
View File
@@ -269,7 +269,7 @@ This generates visualizations showing video frames with subtask boundaries overl
Train with **no annotations** - uses linear progress from 0 to 1:
```bash
lerobot-train \
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/your-dataset \
--policy.type=sarm \
--policy.annotation_mode=single_stage \
@@ -288,7 +288,7 @@ lerobot-train \
Train with **dense annotations only** (sparse auto-generated):
```bash
lerobot-train \
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/your-dataset \
--policy.type=sarm \
--policy.annotation_mode=dense_only \
@@ -307,7 +307,7 @@ lerobot-train \
Train with **both sparse and dense annotations**:
```bash
lerobot-train \
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/your-dataset \
--policy.type=sarm \
--policy.annotation_mode=dual \
@@ -468,7 +468,7 @@ This script:
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
```bash
lerobot-train \
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/your-dataset \
--policy.type=pi0 \
--use_rabc=true \
-13
View File
@@ -1,18 +1,5 @@
# SO-101
<div style="display: flex; align-items: center; gap: 10px;">
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/SO101_Follower.webp"
alt="SO-101"
width="60%"
/>
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/SO101_Leader.webp"
alt="SO-101"
width="60%"
/>
</div>
In the steps below, we explain how to assemble our flagship robot, the SO-101.
## Source the parts
+1 -99
View File
@@ -188,105 +188,7 @@ Press `Ctrl+C` to stop the policy.
## Running in Simulation Mode (MuJoCo)
You can test policies before deploying on the physical robot using MuJoCo simulation. Set `is_simulation=True` in config or pass `--robot.is_simulation=true` via CLI.
### Calibrate Exoskeleton Teleoperator
```bash
lerobot-calibrate \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo
```
### Teleoperate in Simulation
```bash
lerobot-teleoperate \
--robot.type=unitree_g1 \
--robot.is_simulation=true \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo \
--fps=100
```
### Record Dataset in Simulation
```bash
lerobot-record \
--robot.type=unitree_g1 \
--robot.is_simulation=true \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo \
--dataset.repo_id=your-username/dataset-name \
--dataset.single_task="Test" \
--dataset.num_episodes=2 \
--dataset.episode_time_s=5 \
--dataset.reset_time_s=5 \
--dataset.push_to_hub=true
```
Example simulation dataset: [nepyope/teleop_test_sim](https://huggingface.co/datasets/nepyope/teleop_test_sim)
---
## Running on Real Robot
Once the robot server is running on the G1 (see Part 3), you can teleoperate and record on the real robot.
### Start the Camera Server
On the robot, start the ZMQ image server:
```bash
python src/lerobot/cameras/zmq/image_server.py
```
Keep this running in a separate terminal for camera streaming during recording.
### Teleoperate Real Robot
```bash
lerobot-teleoperate \
--robot.type=unitree_g1 \
--robot.is_simulation=false \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo \
--fps=100
```
### Record Dataset on Real Robot
```bash
lerobot-record \
--robot.type=unitree_g1 \
--robot.is_simulation=false \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo \
--dataset.repo_id=your-username/dataset-name \
--dataset.single_task="Test" \
--dataset.num_episodes=2 \
--dataset.episode_time_s=5 \
--dataset.reset_time_s=5 \
--dataset.push_to_hub=true
```
**Note**: Update `server_address` to match your robot's camera server IP.
Example real robot dataset: [nepyope/teleop_test_real](https://huggingface.co/datasets/nepyope/teleop_test_real)
---
You can now test policies before unleashing them on the physical robot using MuJoCo. To do so simply set `is_simulation=True` in config.
## Additional Resources
-25
View File
@@ -12,7 +12,6 @@ LeRobot provides several utilities for manipulating datasets:
4. **Add Features** - Add new features to a dataset
5. **Remove Features** - Remove features from a dataset
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage
7. **Show the Info of Datasets** - Show the summary of datasets information such as number of episode etc.
The core implementation is in `lerobot.datasets.dataset_tools`.
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
@@ -157,30 +156,6 @@ lerobot-edit-dataset \
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved.
### Show the information of datasets
Show the information of datasets such as number of episode, number of frame, File size and so on.
No change will be made to the dataset
```bash
# Show dataset information without feature details
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type info \
# Show dataset information with feature details
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type info \
--operation.show_features true
```
**Parameters:**
- `parameters`: The flag to control show or no show dataset information with feature details.(default=false)
### Push to Hub
Add the `--push_to_hub true` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
+1 -1
View File
@@ -45,7 +45,7 @@ policy.type=wall_x
For training WallX, you can use the standard LeRobot training script with the appropriate configuration:
```bash
lerobot-train \
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your_dataset \
--policy.type=wall_x \
--output_dir=./outputs/wallx_training \
+1 -1
View File
@@ -154,7 +154,7 @@ lerobot-train \
```bash
lerobot-train \
--dataset.repo_id=<USER>/bimanual-so100-handover-cube \
--dataset.repo_id=pepijn223/bimanual-so100-handover-cube \
--output_dir=./outputs/xvla_bimanual \
--job_name=xvla_so101_training \
--policy.path="lerobot/xvla-base" \
+16 -17
View File
@@ -22,7 +22,7 @@ lerobot-replay \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.id=black \
--dataset.repo_id=<USER>/record-test \
--dataset.repo_id=aliberts/record-test \
--dataset.episode=2
```
"""
@@ -81,25 +81,24 @@ def replay(cfg: ReplayConfig):
actions = dataset.hf_dataset.select_columns(ACTION)
robot.connect()
try:
log_say("Replaying episode", cfg.play_sounds, blocking=True)
for idx in range(dataset.num_frames):
start_episode_t = time.perf_counter()
log_say("Replaying episode", cfg.play_sounds, blocking=True)
for idx in range(dataset.num_frames):
start_episode_t = time.perf_counter()
action_array = actions[idx][ACTION]
action = {}
for i, name in enumerate(dataset.features[ACTION]["names"]):
key = f"{name.removeprefix('main_')}.pos"
action[key] = action_array[i].item()
action_array = actions[idx][ACTION]
action = {}
for i, name in enumerate(dataset.features[ACTION]["names"]):
key = f"{name.removeprefix('main_')}.pos"
action[key] = action_array[i].item()
action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90)
action["elbow_flex.pos"] -= 90
robot.send_action(action)
action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90)
action["elbow_flex.pos"] -= 90
robot.send_action(action)
dt_s = time.perf_counter() - start_episode_t
precise_sleep(max(1 / dataset.fps - dt_s, 0.0))
finally:
robot.disconnect()
dt_s = time.perf_counter() - start_episode_t
precise_sleep(max(1 / dataset.fps - dt_s, 0.0))
robot.disconnect()
if __name__ == "__main__":
+43 -45
View File
@@ -78,24 +78,40 @@ def main():
listener, events = init_keyboard_listener()
init_rerun(session_name="lekiwi_evaluate")
try:
if not robot.is_connected:
raise ValueError("Robot is not connected!")
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting evaluate loop...")
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
print("Starting evaluate loop...")
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
# Main record loop
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
@@ -104,42 +120,24 @@ def main():
robot_observation_processor=robot_observation_processor,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Save episode
dataset.save_episode()
recorded_episodes += 1
# Save episode
dataset.save_episode()
recorded_episodes += 1
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
finally:
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":
+44 -45
View File
@@ -74,23 +74,40 @@ def main():
listener, events = init_keyboard_listener()
init_rerun(session_name="lekiwi_record")
try:
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
raise ValueError("Robot or teleop is not connected!")
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
raise ValueError("Robot or teleop is not connected!")
print("Starting record loop...")
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {recorded_episodes}")
print("Starting record loop...")
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {recorded_episodes}")
# Main record loop
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
dataset=dataset,
teleop=[leader_arm, keyboard],
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
dataset=dataset,
teleop=[leader_arm, keyboard],
control_time_s=EPISODE_TIME_SEC,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
@@ -98,44 +115,26 @@ def main():
robot_observation_processor=robot_observation_processor,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop=[leader_arm, keyboard],
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Save episode
dataset.save_episode()
recorded_episodes += 1
# Save episode
dataset.save_episode()
recorded_episodes += 1
finally:
# Clean up
log_say("Stop recording")
robot.disconnect()
leader_arm.disconnect()
keyboard.disconnect()
listener.stop()
# Clean up
log_say("Stop recording")
robot.disconnect()
leader_arm.disconnect()
keyboard.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":
+15 -17
View File
@@ -42,27 +42,25 @@ def main():
# Connect to the robot
robot.connect()
try:
if not robot.is_connected:
raise ValueError("Robot is not connected!")
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()
# Get recorded action from dataset
action = {
name: float(actions[idx][ACTION][i])
for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Get recorded action from dataset
action = {
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Send action to robot
_ = robot.send_action(action)
# Send action to robot
_ = robot.send_action(action)
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
finally:
robot.disconnect()
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
robot.disconnect()
if __name__ == "__main__":
+41 -44
View File
@@ -142,24 +142,38 @@ def main():
listener, events = init_keyboard_listener()
init_rerun(session_name="phone_so100_evaluate")
try:
if not robot.is_connected:
raise ValueError("Robot is not connected!")
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting evaluate loop...")
episode_idx = 0
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
print("Starting evaluate loop...")
episode_idx = 0
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
@@ -168,41 +182,24 @@ def main():
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Save episode
dataset.save_episode()
episode_idx += 1
# Save episode
dataset.save_episode()
episode_idx += 1
finally:
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":
+41 -44
View File
@@ -149,23 +149,38 @@ def main():
listener, events = init_keyboard_listener()
init_rerun(session_name="phone_so100_record")
try:
if not robot.is_connected or not phone.is_connected:
raise ValueError("Robot or teleop is not connected!")
if not robot.is_connected or not phone.is_connected:
raise ValueError("Robot or teleop is not connected!")
print("Starting record loop. Move your phone to teleoperate the robot...")
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
print("Starting record loop. Move your phone to teleoperate the robot...")
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop=phone,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop=phone,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose_processor,
@@ -173,43 +188,25 @@ def main():
robot_observation_processor=robot_joints_to_ee_pose,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop=phone,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
)
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Save episode
dataset.save_episode()
episode_idx += 1
# Save episode
dataset.save_episode()
episode_idx += 1
finally:
# Clean up
log_say("Stop recording")
robot.disconnect()
phone.disconnect()
listener.stop()
# Clean up
log_say("Stop recording")
robot.disconnect()
phone.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":
+20 -22
View File
@@ -73,34 +73,32 @@ def main():
# Connect to the robot
robot.connect()
try:
if not robot.is_connected:
raise ValueError("Robot is not connected!")
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()
# Get recorded action from dataset
ee_action = {
name: float(actions[idx][ACTION][i])
for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Get recorded action from dataset
ee_action = {
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Get robot observation
robot_obs = robot.get_observation()
# Get robot observation
robot_obs = robot.get_observation()
# Dataset EE -> robot joints
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
# Dataset EE -> robot joints
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
# Send action to robot
_ = robot.send_action(joint_action)
# Send action to robot
_ = robot.send_action(joint_action)
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
finally:
# Clean up
robot.disconnect()
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
# Clean up
robot.disconnect()
if __name__ == "__main__":
+10 -10
View File
@@ -27,8 +27,8 @@ measuring consistency and ground truth alignment.
Usage:
# Basic usage with smolvla policy
uv run python examples/rtc/eval_dataset.py \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--dataset.repo_id=<USER>/check_rtc \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=mps \
--rtc.max_guidance_weight=10.0 \
@@ -58,16 +58,16 @@ Usage:
--device=cuda
uv run python examples/rtc/eval_dataset.py \
--policy.path=<USER>/reuben_pi0 \
--dataset.repo_id=<USER>/so101_cube_in_cup \
--policy.path=lipsop/reuben_pi0 \
--dataset.repo_id=ReubenLim/so101_cube_in_cup \
--rtc.execution_horizon=8 \
--device=cuda
# With torch.compile for faster inference (PyTorch 2.0+)
# Note: CUDA graphs disabled by default due to in-place ops in denoising loop
uv run python examples/rtc/eval_dataset.py \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--dataset.repo_id=<USER>/check_rtc \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=mps \
--use_torch_compile=true \
@@ -75,8 +75,8 @@ Usage:
# With torch.compile on CUDA (CUDA graphs disabled by default)
uv run python examples/rtc/eval_dataset.py \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--dataset.repo_id=<USER>/check_rtc \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=cuda \
--use_torch_compile=true \
@@ -84,8 +84,8 @@ Usage:
# Enable CUDA graphs (advanced - may cause tensor aliasing errors)
uv run python examples/rtc/eval_dataset.py \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--dataset.repo_id=<USER>/check_rtc \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--use_torch_compile=true \
--torch_compile_backend=inductor \
--torch_compile_mode=max-autotune \
+3 -3
View File
@@ -28,7 +28,7 @@ For simulation environments, see eval_with_simulation.py
Usage:
# Run RTC with Real robot with RTC
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
@@ -41,7 +41,7 @@ Usage:
# Run RTC with Real robot without RTC
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--policy.device=mps \
--rtc.enabled=false \
--robot.type=so100_follower \
@@ -53,7 +53,7 @@ Usage:
# Run RTC with Real robot with pi0.5 policy
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=<USER>/pi05_check_rtc \
--policy.path=helper2424/pi05_check_rtc \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
+41 -44
View File
@@ -142,24 +142,38 @@ def main():
listener, events = init_keyboard_listener()
init_rerun(session_name="so100_so100_evaluate")
try:
if not robot.is_connected:
raise ValueError("Robot is not connected!")
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting evaluate loop...")
episode_idx = 0
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
print("Starting evaluate loop...")
episode_idx = 0
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
@@ -168,41 +182,24 @@ def main():
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Save episode
dataset.save_episode()
episode_idx += 1
# Save episode
dataset.save_episode()
episode_idx += 1
finally:
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":
+41 -45
View File
@@ -146,23 +146,38 @@ def main():
listener, events = init_keyboard_listener()
init_rerun(session_name="recording_phone")
try:
if not leader.is_connected or not follower.is_connected:
raise ValueError("Robot or teleop is not connected!")
if not leader.is_connected or not follower.is_connected:
raise ValueError("Robot or teleop is not connected!")
print("Starting record loop...")
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
print("Starting record loop...")
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop
# Main record loop
record_loop(
robot=follower,
events=events,
fps=FPS,
teleop=leader,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=leader_joints_to_ee,
robot_action_processor=ee_to_follower_joints,
robot_observation_processor=follower_joints_to_ee,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
record_loop(
robot=follower,
events=events,
fps=FPS,
teleop=leader,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=leader_joints_to_ee,
@@ -170,44 +185,25 @@ def main():
robot_observation_processor=follower_joints_to_ee,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=follower,
events=events,
fps=FPS,
teleop=leader,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=leader_joints_to_ee,
robot_action_processor=ee_to_follower_joints,
robot_observation_processor=follower_joints_to_ee,
)
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Save episode
dataset.save_episode()
episode_idx += 1
# Save episode
dataset.save_episode()
episode_idx += 1
# Clean up
log_say("Stop recording")
leader.disconnect()
follower.disconnect()
listener.stop()
finally:
# Clean up
log_say("Stop recording")
leader.disconnect()
follower.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":
+19 -22
View File
@@ -74,35 +74,32 @@ def main():
# Connect to the robot
robot.connect()
try:
if not robot.is_connected:
raise ValueError("Robot is not connected!")
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()
# Get recorded action from dataset
ee_action = {
name: float(actions[idx][ACTION][i])
for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Get recorded action from dataset
ee_action = {
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Get robot observation
robot_obs = robot.get_observation()
# Get robot observation
robot_obs = robot.get_observation()
# Dataset EE -> robot joints
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
# Dataset EE -> robot joints
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
# Send action to robot
_ = robot.send_action(joint_action)
# Send action to robot
_ = robot.send_action(joint_action)
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
finally:
# Clean up
robot.disconnect()
# Clean up
robot.disconnect()
if __name__ == "__main__":
+14 -17
View File
@@ -4,6 +4,7 @@ from pathlib import Path
from queue import Empty, Full
import torch
import torch.optim as optim
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
@@ -11,7 +12,6 @@ from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
from lerobot.rl.buffer import ReplayBuffer
from lerobot.rl.gym_manipulator import make_robot_env
from lerobot.robots.so_follower import SO100FollowerConfig
@@ -40,9 +40,8 @@ def run_learner(
policy_learner.train()
policy_learner.to(device)
algo_config = SACAlgorithmConfig.from_policy_config(policy_learner.config)
algorithm = SACAlgorithm(policy=policy_learner, config=algo_config)
algorithm.make_optimizers()
# Create Adam optimizer from scratch - simple and clean
optimizer = optim.Adam(policy_learner.parameters(), lr=lr)
print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")
@@ -84,26 +83,24 @@ def run_learner(
else:
batch[key] = online_batch[key]
def batch_iter(b=batch):
while True:
yield b
loss, _ = policy_learner.forward(batch)
stats = algorithm.update(batch_iter())
optimizer.zero_grad()
loss.backward()
optimizer.step()
training_step += 1
if training_step % LOG_EVERY == 0:
log_dict = stats.to_log_dict()
print(
f"[LEARNER] Training step {training_step}, "
f"critic_loss: {log_dict.get('critic', 'N/A'):.4f}, "
f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, "
f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
)
# Send updated parameters to actor every 10 training steps
if training_step % SEND_EVERY == 0:
try:
weights = algorithm.get_weights()
parameters_queue.put_nowait(weights)
state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()}
parameters_queue.put_nowait(state_dict)
print("[LEARNER] Sent updated parameters to actor")
except Full:
# Missing write due to queue not being consumed (should happen rarely)
@@ -147,15 +144,15 @@ def run_actor(
while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
try:
new_weights = parameters_queue.get_nowait()
policy_actor.load_state_dict(new_weights)
new_params = parameters_queue.get_nowait()
policy_actor.load_state_dict(new_params)
print("[ACTOR] Updated policy parameters from learner")
except Empty: # No new updated parameters available from learner, waiting
pass
# Get action from policy (returns full action: continuous + discrete)
# Get action from policy
policy_obs = make_policy_obs(obs, device=device)
action_tensor = policy_actor.select_action(policy_obs)
action_tensor = policy_actor.select_action(policy_obs) # predicts a single action
action = action_tensor.squeeze(0).cpu().numpy()
# Step environment
+7 -15
View File
@@ -76,9 +76,9 @@ dependencies = [
"pyserial>=3.5,<4.0",
"wandb>=0.24.0,<0.25.0",
"torch>=2.2.1,<2.11.0", # TODO: Bump dependency
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency
"torchvision>=0.21.0,<0.26.0", # TODO: Bump dependency
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
"torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency
"draccus==0.10.0", # TODO: Remove ==
"gymnasium>=1.1.1,<2.0.0",
@@ -102,20 +102,14 @@ grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
# Motors
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
damiao = ["python-can>=4.2.0,<5.0.0"]
# Robots
openarms = ["lerobot[damiao]"]
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
unitree_g1 = [
"pyzmq>=26.2.1,<28.0.0",
"onnxruntime>=1.16.0,<2.0.0",
"pin>=3.0.0,<4.0.0",
"meshcat>=0.3.0,<0.4.0",
"matplotlib>=3.9.0,<4.0.0",
"casadi>=3.6.0,<4.0.0",
"onnxruntime>=1.16.0,<2.0.0"
]
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
kinematics = ["lerobot[placo-dep]"]
@@ -209,7 +203,6 @@ lerobot-info="lerobot.scripts.lerobot_info:main"
lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
# ---------------- Tool Configurations ----------------
[tool.setuptools.packages.find]
@@ -285,7 +278,6 @@ default.extend-ignore-identifiers-re = [
"thw",
"inpt",
"ROBOTIS",
"OT_VALUE"
]
# TODO: Uncomment when ready to use
@@ -360,9 +352,9 @@ ignore_errors = false
module = "lerobot.cameras.*"
ignore_errors = false
[[tool.mypy.overrides]]
module = "lerobot.motors.*"
ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.motors.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.robots.*"
+1 -1
View File
@@ -13,5 +13,5 @@
# limitations under the License.
from .camera import Camera
from .configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation
from .configs import CameraConfig, ColorMode, Cv2Rotation
from .utils import make_cameras_from_configs
+18 -82
View File
@@ -15,12 +15,11 @@
# limitations under the License.
import abc
import warnings
from typing import Any
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
from .configs import CameraConfig
from .configs import CameraConfig, ColorMode
class Camera(abc.ABC):
@@ -31,12 +30,20 @@ class Camera(abc.ABC):
Manages basic camera properties (FPS, resolution) and core operations:
- Connection/disconnection
- Frame capture (sync/async/latest)
- Frame capture (sync/async)
Attributes:
fps (int | None): Configured frames per second
width (int | None): Frame width in pixels
height (int | None): Frame height in pixels
Example:
class MyCamera(Camera):
def __init__(self, config): ...
@property
def is_connected(self) -> bool: ...
def connect(self, warmup=True): ...
# Plus other required methods
"""
def __init__(self, config: CameraConfig):
@@ -49,32 +56,6 @@ class Camera(abc.ABC):
self.width: int | None = config.width
self.height: int | None = config.height
def __enter__(self):
"""
Context manager entry.
Automatically connects to the camera.
"""
self.connect()
return self
def __exit__(self, exc_type, exc_value, traceback) -> None:
"""
Context manager exit.
Automatically disconnects, ensuring resources are released even on error.
"""
self.disconnect()
def __del__(self) -> None:
"""
Destructor safety net.
Attempts to disconnect if the object is garbage collected without cleanup.
"""
try:
if self.is_connected:
self.disconnect()
except Exception: # nosec B110
pass
@property
@abc.abstractmethod
def is_connected(self) -> bool:
@@ -108,10 +89,12 @@ class Camera(abc.ABC):
pass
@abc.abstractmethod
def read(self) -> NDArray[Any]:
"""Capture and return a single frame from the camera synchronously.
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""Capture and return a single frame from the camera.
This is a blocking call that will wait for the hardware and its SDK.
Args:
color_mode: Desired color mode for the output frame. If None,
uses the camera's default color mode.
Returns:
np.ndarray: Captured frame as a numpy array.
@@ -120,64 +103,17 @@ class Camera(abc.ABC):
@abc.abstractmethod
def async_read(self, timeout_ms: float = ...) -> NDArray[Any]:
"""Return the most recent new frame.
This method retrieves the latest frame captured by the background thread.
If a new frame is already available in the buffer (captured since the last call),
it returns it immediately.
It blocks up to `timeout_ms` only if the buffer is empty or if the latest frame
was already consumed by a previous `async_read` call.
Essentially, this method return the latest unconsumed frame, waiting if necessary
for a new one to arrive within the specified timeout.
Usage:
- Ideal for control loops where you want to ensure every processed frame
is fresh, effectively synchronizing your loop to the camera's FPS.
- Causes of a timeout usually include: very low camera FPS, heavy processing load,
or if the camera is disconnected.
"""Asynchronously capture and return a single frame from the camera.
Args:
timeout_ms: Maximum time to wait for a new frame in milliseconds.
Defaults to 200ms (0.2s).
timeout_ms: Maximum time to wait for a frame in milliseconds.
Defaults to implementation-specific timeout.
Returns:
np.ndarray: Captured frame as a numpy array.
Raises:
TimeoutError: If no new frame arrives within `timeout_ms`.
"""
pass
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
memory buffer. The frame may be stale,
meaning it could have been captured a while ago (hanging camera scenario e.g.).
Usage:
Ideal for scenarios requiring zero latency or decoupled frequencies & when
we want a guaranteed frame, such as UI visualization, logging, or
non-critical monitoring.
Returns:
NDArray[Any]: The frame image (numpy array).
Raises:
TimeoutError: If the latest frame is older than `max_age_ms`.
NotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
warnings.warn(
f"{self.__class__.__name__}.read_latest() is not implemented. "
"Please override read_latest(); it will be required in future releases.",
FutureWarning,
stacklevel=2,
)
return self.async_read()
@abc.abstractmethod
def disconnect(self) -> None:
"""Disconnect from the camera and release resources."""
-23
View File
@@ -25,10 +25,6 @@ class ColorMode(str, Enum):
RGB = "rgb"
BGR = "bgr"
@classmethod
def _missing_(cls, value: object) -> None:
raise ValueError(f"`color_mode` is expected to be in {list(cls)}, but {value} is provided.")
class Cv2Rotation(int, Enum):
NO_ROTATION = 0
@@ -36,25 +32,6 @@ class Cv2Rotation(int, Enum):
ROTATE_180 = 180
ROTATE_270 = -90
@classmethod
def _missing_(cls, value: object) -> None:
raise ValueError(f"`rotation` is expected to be in {list(cls)}, but {value} is provided.")
# Subset from https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html
class Cv2Backends(int, Enum):
ANY = 0
V4L2 = 200
DSHOW = 700
PVAPI = 800
ANDROID = 1000
AVFOUNDATION = 1200
MSMF = 1400
@classmethod
def _missing_(cls, value: object) -> None:
raise ValueError(f"`backend` is expected to be in {list(cls)}, but {value} is provided.")
@dataclass(kw_only=True)
class CameraConfig(draccus.ChoiceRegistry, abc.ABC): # type: ignore # TODO: add type stubs for draccus
+65 -116
View File
@@ -32,11 +32,10 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
import cv2 # type: ignore # TODO: add type stubs for OpenCV
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..camera import Camera
from ..utils import get_cv2_rotation
from ..utils import get_cv2_backend, get_cv2_rotation
from .configuration_opencv import ColorMode, OpenCVCameraConfig
# NOTE(Steven): The maximum opencv device index depends on your operating system. For instance,
@@ -71,24 +70,34 @@ class OpenCVCamera(Camera):
Example:
```python
from lerobot.cameras.opencv import OpenCVCamera
from lerobot.cameras.configuration_opencv import OpenCVCameraConfig
from lerobot.cameras.configuration_opencv import OpenCVCameraConfig, ColorMode, Cv2Rotation
# Basic usage with camera index 0
config = OpenCVCameraConfig(index_or_path=0)
camera = OpenCVCamera(config)
camera.connect()
# Read 1 frame synchronously (blocking)
# Read 1 frame synchronously
color_image = camera.read()
print(color_image.shape)
# Read 1 frame asynchronously (waits for new frame with a timeout)
# Read 1 frame asynchronously
async_image = camera.async_read()
# Get the latest frame immediately (no wait, returns timestamp)
latest_image, timestamp = camera.read_latest()
# When done, properly disconnect the camera using
camera.disconnect()
# Example with custom settings
custom_config = OpenCVCameraConfig(
index_or_path='/dev/video0', # Or use an index
fps=30,
width=1280,
height=720,
color_mode=ColorMode.RGB,
rotation=Cv2Rotation.ROTATE_90
)
custom_camera = OpenCVCamera(custom_config)
# ... connect, read, disconnect ...
```
"""
@@ -114,11 +123,10 @@ class OpenCVCamera(Camera):
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: NDArray[Any] | None = None
self.latest_timestamp: float | None = None
self.new_frame_event: Event = Event()
self.rotation: int | None = get_cv2_rotation(config.rotation)
self.backend: int = config.backend
self.backend: int = get_cv2_backend()
if self.height and self.width:
self.capture_width, self.capture_height = self.width, self.height
@@ -133,23 +141,20 @@ class OpenCVCamera(Camera):
"""Checks if the camera is currently connected and opened."""
return isinstance(self.videocapture, cv2.VideoCapture) and self.videocapture.isOpened()
@check_if_already_connected
def connect(self, warmup: bool = True) -> None:
"""
Connects to the OpenCV camera specified in the configuration.
Initializes the OpenCV VideoCapture object, sets desired camera properties
(FPS, width, height), starts the background reading thread and performs initial checks.
Args:
warmup (bool): If True, waits at connect() time until at least one valid frame
has been captured by the background thread. Defaults to True.
(FPS, width, height), and performs initial checks.
Raises:
DeviceAlreadyConnectedError: If the camera is already connected.
ConnectionError: If the specified camera index/path is not found or fails to open.
RuntimeError: If the camera opens but fails to apply requested settings.
ConnectionError: If the specified camera index/path is not found or the camera is found but fails to open.
RuntimeError: If the camera opens but fails to apply requested FPS/resolution settings.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
# Use 1 thread for OpenCV operations to avoid potential conflicts or
# blocking in multi-threaded applications, especially during data collection.
@@ -165,20 +170,15 @@ class OpenCVCamera(Camera):
)
self._configure_capture_settings()
self._start_read_thread()
if warmup and self.warmup_s > 0:
if warmup:
start_time = time.time()
while time.time() - start_time < self.warmup_s:
self.async_read(timeout_ms=self.warmup_s * 1000)
self.read()
time.sleep(0.1)
with self.frame_lock:
if self.latest_frame is None:
raise ConnectionError(f"{self} failed to capture frames during warmup.")
logger.info(f"{self} connected.")
@check_if_not_connected
def _configure_capture_settings(self) -> None:
"""
Applies the specified FOURCC, FPS, width, and height settings to the connected camera.
@@ -196,8 +196,11 @@ class OpenCVCamera(Camera):
Raises:
RuntimeError: If the camera fails to set any of the specified properties
to the requested value.
DeviceNotConnectedError: If the camera is not connected.
DeviceNotConnectedError: If the camera is not connected when attempting
to configure settings.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.")
# Set FOURCC first (if specified) as it can affect available FPS/resolution options
if self.config.fourcc is not None:
@@ -336,18 +339,6 @@ class OpenCVCamera(Camera):
return found_cameras_info
def _read_from_hardware(self) -> NDArray[Any]:
if self.videocapture is None:
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
ret, frame = self.videocapture.read()
if not ret:
raise RuntimeError(f"{self} read failed (status={ret}).")
return frame
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the camera.
@@ -355,6 +346,11 @@ class OpenCVCamera(Camera):
This is a blocking call. It waits for the next available frame from the
camera hardware via OpenCV.
Args:
color_mode (Optional[ColorMode]): If specified, overrides the default
color mode (`self.color_mode`) for this read operation (e.g.,
request RGB even if default is BGR).
Returns:
np.ndarray: The captured frame as a NumPy array in the format
(height, width, channels), using the specified or default
@@ -366,31 +362,34 @@ class OpenCVCamera(Camera):
received frame dimensions don't match expectations before rotation.
ValueError: If an invalid `color_mode` is requested.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
start_time = time.perf_counter()
if color_mode is not None:
logger.warning(
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
)
if self.videocapture is None:
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
ret, frame = self.videocapture.read()
self.new_frame_event.clear()
frame = self.async_read(timeout_ms=10000)
if not ret or frame is None:
raise RuntimeError(f"{self} read failed (status={ret}).")
processed_frame = self._postprocess_image(frame, color_mode)
read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
return frame
return processed_frame
def _postprocess_image(self, image: NDArray[Any]) -> NDArray[Any]:
def _postprocess_image(self, image: NDArray[Any], color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Applies color conversion, dimension validation, and rotation to a raw frame.
Args:
image (np.ndarray): The raw image frame (expected BGR format from OpenCV).
color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None,
uses the instance's default `self.color_mode`.
Returns:
np.ndarray: The processed image frame.
@@ -400,10 +399,11 @@ class OpenCVCamera(Camera):
RuntimeError: If the raw frame dimensions do not match the configured
`width` and `height`.
"""
requested_color_mode = self.color_mode if color_mode is None else color_mode
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
if requested_color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"Invalid color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
f"Invalid color mode '{requested_color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
)
h, w, c = image.shape
@@ -417,7 +417,7 @@ class OpenCVCamera(Camera):
raise RuntimeError(f"{self} frame channels={c} do not match expected 3 channels (RGB/BGR).")
processed_image = image
if self.color_mode == ColorMode.RGB:
if requested_color_mode == ColorMode.RGB:
processed_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]:
@@ -431,7 +431,7 @@ class OpenCVCamera(Camera):
On each iteration:
1. Reads a color frame
2. Stores result in latest_frame and updates timestamp (thread-safe)
2. Stores result in latest_frame (thread-safe)
3. Sets new_frame_event to notify listeners
Stops on DeviceNotConnectedError, logs other errors and continues.
@@ -439,37 +439,30 @@ class OpenCVCamera(Camera):
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
failure_count = 0
while not self.stop_event.is_set():
try:
raw_frame = self._read_from_hardware()
processed_frame = self._postprocess_image(raw_frame)
capture_time = time.perf_counter()
color_image = self.read()
with self.frame_lock:
self.latest_frame = processed_frame
self.latest_timestamp = capture_time
self.latest_frame = color_image
self.new_frame_event.set()
failure_count = 0
except DeviceNotConnectedError:
break
except Exception as e:
if failure_count <= 10:
failure_count += 1
logger.warning(f"Error reading frame in background thread for {self}: {e}")
else:
raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e
logger.warning(f"Error reading frame in background thread for {self}: {e}")
def _start_read_thread(self) -> None:
"""Starts or restarts the background read thread if it's not running."""
self._stop_read_thread()
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=0.1)
if self.stop_event is not None:
self.stop_event.set()
self.stop_event = Event()
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
self.thread.daemon = True
self.thread.start()
time.sleep(0.1)
def _stop_read_thread(self) -> None:
"""Signals the background read thread to stop and waits for it to join."""
@@ -482,12 +475,6 @@ class OpenCVCamera(Camera):
self.thread = None
self.stop_event = None
with self.frame_lock:
self.latest_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame asynchronously.
@@ -495,7 +482,6 @@ class OpenCVCamera(Camera):
This method retrieves the most recent frame captured by the background
read thread. It does not block waiting for the camera hardware directly,
but may wait up to timeout_ms for the background thread to provide a frame.
It is “best effort” under high FPS.
Args:
timeout_ms (float): Maximum time in milliseconds to wait for a frame
@@ -510,14 +496,17 @@ class OpenCVCamera(Camera):
TimeoutError: If no frame becomes available within the specified timeout.
RuntimeError: If an unexpected error occurs.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
self._start_read_thread()
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
thread_alive = self.thread is not None and self.thread.is_alive()
raise TimeoutError(
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
f"Read thread alive: {self.thread.is_alive()}."
f"Read thread alive: {thread_alive}."
)
with self.frame_lock:
@@ -529,41 +518,6 @@ class OpenCVCamera(Camera):
return frame
@check_if_not_connected
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
memory buffer. The frame may be stale,
meaning it could have been captured a while ago (hanging camera scenario e.g.).
Returns:
NDArray[Any]: The frame image (numpy array).
Raises:
TimeoutError: If the latest frame is older than `max_age_ms`.
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
with self.frame_lock:
frame = self.latest_frame
timestamp = self.latest_timestamp
if frame is None or timestamp is None:
raise RuntimeError(f"{self} has not captured any frames yet.")
age_ms = (time.perf_counter() - timestamp) * 1e3
if age_ms > max_age_ms:
raise TimeoutError(
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
)
return frame
def disconnect(self) -> None:
"""
Disconnects from the camera and cleans up resources.
@@ -584,9 +538,4 @@ class OpenCVCamera(Camera):
self.videocapture.release()
self.videocapture = None
with self.frame_lock:
self.latest_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
logger.info(f"{self} disconnected.")
@@ -15,9 +15,9 @@
from dataclasses import dataclass
from pathlib import Path
from ..configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation
from ..configs import CameraConfig, ColorMode, Cv2Rotation
__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation", "Cv2Backends"]
__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation"]
@CameraConfig.register_subclass("opencv")
@@ -50,7 +50,6 @@ class OpenCVCameraConfig(CameraConfig):
rotation: Image rotation setting (0°, 90°, 180°, or 270°). Defaults to no rotation.
warmup_s: Time reading frames before returning from connect (in seconds)
fourcc: FOURCC code for video format (e.g., "MJPG", "YUYV", "I420"). Defaults to None (auto-detect).
backend: OpenCV backend identifier (https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html). Defaults to ANY.
Note:
- Only 3-channel color output (RGB/BGR) is currently supported.
@@ -63,12 +62,22 @@ class OpenCVCameraConfig(CameraConfig):
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
warmup_s: int = 1
fourcc: str | None = None
backend: Cv2Backends = Cv2Backends.ANY
def __post_init__(self) -> None:
self.color_mode = ColorMode(self.color_mode)
self.rotation = Cv2Rotation(self.rotation)
self.backend = Cv2Backends(self.backend)
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
)
if self.rotation not in (
Cv2Rotation.NO_ROTATION,
Cv2Rotation.ROTATE_90,
Cv2Rotation.ROTATE_180,
Cv2Rotation.ROTATE_270,
):
raise ValueError(
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
)
if self.fourcc is not None and (not isinstance(self.fourcc, str) or len(self.fourcc) != 4):
raise ValueError(
@@ -74,4 +74,7 @@ class Reachy2CameraConfig(CameraConfig):
f"`image_type` is expected to be 'left' or 'right' for teleop camera, and 'rgb' or 'depth' for depth camera, but {self.image_type} is provided."
)
self.color_mode = ColorMode(self.color_mode)
if self.color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)
@@ -32,7 +32,6 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"
import cv2 # type: ignore # TODO: add type stubs for OpenCV
import numpy as np # type: ignore # TODO: add type stubs for numpy
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.import_utils import _reachy2_sdk_available
if TYPE_CHECKING or _reachy2_sdk_available:
@@ -81,8 +80,6 @@ class Reachy2Camera(Camera):
self.config = config
self.color_mode = config.color_mode
self.latest_frame: NDArray[Any] | None = None
self.latest_timestamp: float | None = None
self.cam_manager: CameraManager | None = None
@@ -124,12 +121,16 @@ class Reachy2Camera(Camera):
"""
raise NotImplementedError("Camera detection is not implemented for Reachy2 cameras.")
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the camera.
This method retrieves the most recent frame available in Reachy 2's low-level software.
This is a blocking call.
Args:
color_mode (Optional[ColorMode]): If specified, overrides the default
color mode (`self.color_mode`) for this read operation (e.g.,
request RGB even if default is BGR).
Returns:
np.ndarray: The captured frame as a NumPy array in the format
@@ -138,13 +139,11 @@ class Reachy2Camera(Camera):
"""
start_time = time.perf_counter()
if self.cam_manager is None:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if color_mode is not None:
logger.warning(
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
)
if self.cam_manager is None:
raise DeviceNotConnectedError(f"{self} is not connected.")
frame: NDArray[Any] = np.empty((0, 0, 3), dtype=np.uint8)
@@ -166,27 +165,25 @@ class Reachy2Camera(Camera):
raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")
if frame is None:
raise RuntimeError(f"Internal error: No frame available for {self}.")
return np.empty((0, 0, 3), dtype=np.uint8)
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"Invalid color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
)
if self.color_mode == ColorMode.RGB:
if self.config.color_mode == "rgb":
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
self.latest_frame = frame
self.latest_timestamp = time.perf_counter()
read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
return frame
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Same as read()
Reads the latest available frame.
This method retrieves the most recent frame available in Reachy 2's low-level software.
Args:
timeout_ms (float): Maximum time in milliseconds to wait for a frame
to become available. Defaults to 200ms (0.2 seconds).
Returns:
np.ndarray: The latest captured frame as a NumPy array in the format
@@ -197,40 +194,16 @@ class Reachy2Camera(Camera):
TimeoutError: If no frame becomes available within the specified timeout.
RuntimeError: If an unexpected error occurs.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
return self.read()
frame = self.read()
@check_if_not_connected
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
if frame is None:
raise RuntimeError(f"Internal error: No frame available for {self}.")
This method is non-blocking and returns whatever is currently in the
memory buffer. The frame may be stale,
meaning it could have been captured a while ago (hanging camera scenario e.g.).
return frame
Returns:
tuple[NDArray, float]:
- The frame image (numpy array).
- The timestamp (time.perf_counter) when this frame was captured.
Raises:
TimeoutError: If the latest frame is older than `max_age_ms`.
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if self.latest_frame is None or self.latest_timestamp is None:
raise RuntimeError(f"{self} has not captured any frames yet.")
age_ms = (time.perf_counter() - self.latest_timestamp) * 1e3
if age_ms > max_age_ms:
raise TimeoutError(
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
)
return self.latest_frame
@check_if_not_connected
def disconnect(self) -> None:
"""
Stops the background read thread (if running).
@@ -238,6 +211,8 @@ class Reachy2Camera(Camera):
Raises:
DeviceNotConnectedError: If the camera is already disconnected.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} not connected.")
if self.cam_manager is not None:
self.cam_manager.disconnect()
+73 -144
View File
@@ -30,8 +30,7 @@ try:
except Exception as e:
logging.info(f"Could not import realsense: {e}")
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..camera import Camera
from ..configs import ColorMode
@@ -73,14 +72,15 @@ class RealSenseCamera(Camera):
camera = RealSenseCamera(config)
camera.connect()
# Read 1 frame synchronously (blocking)
# Read 1 frame synchronously
color_image = camera.read()
print(color_image.shape)
# Read 1 frame asynchronously (waits for new frame with a timeout)
# Read 1 frame asynchronously
async_image = camera.async_read()
# Get the latest frame immediately (no wait, returns timestamp)
latest_image, timestamp = camera.read_latest()
# When done, properly disconnect the camera using
camera.disconnect()
# Example with depth capture and custom settings
custom_config = RealSenseCameraConfig(
@@ -133,9 +133,7 @@ class RealSenseCamera(Camera):
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_color_frame: NDArray[Any] | None = None
self.latest_depth_frame: NDArray[Any] | None = None
self.latest_timestamp: float | None = None
self.latest_frame: NDArray[Any] | None = None
self.new_frame_event: Event = Event()
self.rotation: int | None = get_cv2_rotation(config.rotation)
@@ -153,7 +151,6 @@ class RealSenseCamera(Camera):
"""Checks if the camera pipeline is started and streams are active."""
return self.rs_pipeline is not None and self.rs_profile is not None
@check_if_already_connected
def connect(self, warmup: bool = True) -> None:
"""
Connects to the RealSense camera specified in the configuration.
@@ -161,16 +158,14 @@ class RealSenseCamera(Camera):
Initializes the RealSense pipeline, configures the required streams (color
and optionally depth), starts the pipeline, and validates the actual stream settings.
Args:
warmup (bool): If True, waits at connect() time until at least one valid frame
has been captured by the background thread. Defaults to True.
Raises:
DeviceAlreadyConnectedError: If the camera is already connected.
ValueError: If the configuration is invalid (e.g., missing serial/name, name not unique).
ConnectionError: If the camera is found but fails to start the pipeline or no RealSense devices are detected at all.
RuntimeError: If the pipeline starts but fails to apply requested settings.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
self.rs_pipeline = rs.pipeline()
rs_config = rs.config()
@@ -186,18 +181,15 @@ class RealSenseCamera(Camera):
) from e
self._configure_capture_settings()
self._start_read_thread()
# NOTE(Steven/Caroline): Enforcing at least one second of warmup as RS cameras need a bit of time before the first read. If we don't wait, the first read from the warmup will raise.
self.warmup_s = max(self.warmup_s, 1)
start_time = time.time()
while time.time() - start_time < self.warmup_s:
self.async_read(timeout_ms=self.warmup_s * 1000)
time.sleep(0.1)
with self.frame_lock:
if self.latest_color_frame is None or self.use_depth and self.latest_depth_frame is None:
raise ConnectionError(f"{self} failed to capture frames during warmup.")
if warmup:
time.sleep(
1
) # NOTE(Steven): RS cameras need a bit of time to warm up before the first read. If we don't wait, the first read from the warmup will raise.
start_time = time.time()
while time.time() - start_time < self.warmup_s:
self.read()
time.sleep(0.1)
logger.info(f"{self} connected.")
@@ -290,7 +282,6 @@ class RealSenseCamera(Camera):
if self.use_depth:
rs_config.enable_stream(rs.stream.depth)
@check_if_not_connected
def _configure_capture_settings(self) -> None:
"""Sets fps, width, and height from device stream if not already configured.
@@ -300,6 +291,8 @@ class RealSenseCamera(Camera):
Raises:
DeviceNotConnectedError: If device is not connected.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Cannot validate settings for {self} as it is not connected.")
if self.rs_profile is None:
raise RuntimeError(f"{self}: rs_profile must be initialized before use.")
@@ -319,7 +312,6 @@ class RealSenseCamera(Camera):
self.width, self.height = actual_width, actual_height
self.capture_width, self.capture_height = actual_width, actual_height
@check_if_not_connected
def read_depth(self, timeout_ms: int = 200) -> NDArray[Any]:
"""
Reads a single frame (depth) synchronously from the camera.
@@ -327,6 +319,9 @@ class RealSenseCamera(Camera):
This is a blocking call. It waits for a coherent set of frames (depth)
from the camera hardware via the RealSense pipeline.
Args:
timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms.
Returns:
np.ndarray: The depth map as a NumPy array (height, width)
of type `np.uint16` (raw depth values in millimeters) and rotation.
@@ -335,50 +330,44 @@ class RealSenseCamera(Camera):
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If reading frames from the pipeline fails or frames are invalid.
"""
if timeout_ms:
logger.warning(
f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions."
)
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if not self.use_depth:
raise RuntimeError(
f"Failed to capture depth frame '.read_depth()'. Depth stream is not enabled for {self}."
)
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
start_time = time.perf_counter()
self.new_frame_event.clear()
_ = self.async_read(timeout_ms=10000)
with self.frame_lock:
depth_map = self.latest_depth_frame
if depth_map is None:
raise RuntimeError("No depth frame available. Ensure camera is streaming.")
return depth_map
def _read_from_hardware(self):
if self.rs_pipeline is None:
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=10000)
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
if not ret or frame is None:
raise RuntimeError(f"{self} read failed (status={ret}).")
raise RuntimeError(f"{self} read_depth failed (status={ret}).")
return frame
depth_frame = frame.get_depth_frame()
depth_map = np.asanyarray(depth_frame.get_data())
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 0) -> NDArray[Any]:
depth_map_processed = self._postprocess_image(depth_map, depth_frame=True)
read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
return depth_map_processed
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> NDArray[Any]:
"""
Reads a single frame (color) synchronously from the camera.
This is a blocking call. It waits for a coherent set of frames (color)
from the camera hardware via the RealSense pipeline.
Args:
timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms.
Returns:
np.ndarray: The captured color frame as a NumPy array
(height, width, channels), processed according to `color_mode` and rotation.
@@ -389,36 +378,39 @@ class RealSenseCamera(Camera):
ValueError: If an invalid `color_mode` is requested.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
start_time = time.perf_counter()
if color_mode is not None:
logger.warning(
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
)
if self.rs_pipeline is None:
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
if timeout_ms:
logger.warning(
f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions."
)
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
if not ret or frame is None:
raise RuntimeError(f"{self} read failed (status={ret}).")
self.new_frame_event.clear()
color_frame = frame.get_color_frame()
color_image_raw = np.asanyarray(color_frame.get_data())
frame = self.async_read(timeout_ms=10000)
color_image_processed = self._postprocess_image(color_image_raw, color_mode)
read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
return frame
return color_image_processed
def _postprocess_image(self, image: NDArray[Any], depth_frame: bool = False) -> NDArray[Any]:
def _postprocess_image(
self, image: NDArray[Any], color_mode: ColorMode | None = None, depth_frame: bool = False
) -> NDArray[Any]:
"""
Applies color conversion, dimension validation, and rotation to a raw color frame.
Args:
image (np.ndarray): The raw image frame (expected RGB format from RealSense).
color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None,
uses the instance's default `self.color_mode`.
Returns:
np.ndarray: The processed image frame according to `self.color_mode` and `self.rotation`.
@@ -429,9 +421,9 @@ class RealSenseCamera(Camera):
`width` and `height`.
"""
if self.color_mode and self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
if color_mode and color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"Invalid requested color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
f"Invalid requested color mode '{color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
)
if depth_frame:
@@ -462,7 +454,7 @@ class RealSenseCamera(Camera):
On each iteration:
1. Reads a color frame with 500ms timeout
2. Stores result in latest_frame and updates timestamp (thread-safe)
2. Stores result in latest_frame (thread-safe)
3. Sets new_frame_event to notify listeners
Stops on DeviceNotConnectedError, logs other errors and continues.
@@ -470,41 +462,25 @@ class RealSenseCamera(Camera):
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
failure_count = 0
while not self.stop_event.is_set():
try:
frame = self._read_from_hardware()
color_frame_raw = frame.get_color_frame()
color_frame = np.asanyarray(color_frame_raw.get_data())
processed_color_frame = self._postprocess_image(color_frame)
if self.use_depth:
depth_frame_raw = frame.get_depth_frame()
depth_frame = np.asanyarray(depth_frame_raw.get_data())
processed_depth_frame = self._postprocess_image(depth_frame, depth_frame=True)
capture_time = time.perf_counter()
color_image = self.read(timeout_ms=500)
with self.frame_lock:
self.latest_color_frame = processed_color_frame
if self.use_depth:
self.latest_depth_frame = processed_depth_frame
self.latest_timestamp = capture_time
self.latest_frame = color_image
self.new_frame_event.set()
failure_count = 0
except DeviceNotConnectedError:
break
except Exception as e:
if failure_count <= 10:
failure_count += 1
logger.warning(f"Error reading frame in background thread for {self}: {e}")
else:
raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e
logger.warning(f"Error reading frame in background thread for {self}: {e}")
def _start_read_thread(self) -> None:
"""Starts or restarts the background read thread if it's not running."""
self._stop_read_thread()
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=0.1)
if self.stop_event is not None:
self.stop_event.set()
self.stop_event = Event()
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
@@ -522,14 +498,7 @@ class RealSenseCamera(Camera):
self.thread = None
self.stop_event = None
with self.frame_lock:
self.latest_color_frame = None
self.latest_depth_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
# NOTE(Steven): Missing implementation for depth for now
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame data (color) asynchronously.
@@ -537,7 +506,6 @@ class RealSenseCamera(Camera):
This method retrieves the most recent color frame captured by the background
read thread. It does not block waiting for the camera hardware directly,
but may wait up to timeout_ms for the background thread to provide a frame.
It is “best effort” under high FPS.
Args:
timeout_ms (float): Maximum time in milliseconds to wait for a frame
@@ -552,18 +520,21 @@ class RealSenseCamera(Camera):
TimeoutError: If no frame data becomes available within the specified timeout.
RuntimeError: If the background thread died unexpectedly or another error occurs.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
self._start_read_thread()
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
thread_alive = self.thread is not None and self.thread.is_alive()
raise TimeoutError(
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
f"Read thread alive: {self.thread.is_alive()}."
f"Read thread alive: {thread_alive}."
)
with self.frame_lock:
frame = self.latest_color_frame
frame = self.latest_frame
self.new_frame_event.clear()
if frame is None:
@@ -571,42 +542,6 @@ class RealSenseCamera(Camera):
return frame
# NOTE(Steven): Missing implementation for depth for now
@check_if_not_connected
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
"""Return the most recent (color) frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
memory buffer. The frame may be stale,
meaning it could have been captured a while ago (hanging camera scenario e.g.).
Returns:
NDArray[Any]: The frame image (numpy array).
Raises:
TimeoutError: If the latest frame is older than `max_age_ms`.
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
with self.frame_lock:
frame = self.latest_color_frame
timestamp = self.latest_timestamp
if frame is None or timestamp is None:
raise RuntimeError(f"{self} has not captured any frames yet.")
age_ms = (time.perf_counter() - timestamp) * 1e3
if age_ms > max_age_ms:
raise TimeoutError(
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
)
return frame
def disconnect(self) -> None:
"""
Disconnects from the camera, stops the pipeline, and cleans up resources.
@@ -630,10 +565,4 @@ class RealSenseCamera(Camera):
self.rs_pipeline = None
self.rs_profile = None
with self.frame_lock:
self.latest_color_frame = None
self.latest_depth_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
logger.info(f"{self} disconnected.")
@@ -60,8 +60,20 @@ class RealSenseCameraConfig(CameraConfig):
warmup_s: int = 1
def __post_init__(self) -> None:
self.color_mode = ColorMode(self.color_mode)
self.rotation = Cv2Rotation(self.rotation)
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
)
if self.rotation not in (
Cv2Rotation.NO_ROTATION,
Cv2Rotation.ROTATE_90,
Cv2Rotation.ROTATE_180,
Cv2Rotation.ROTATE_270,
):
raise ValueError(
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
)
values = (self.fps, self.width, self.height)
if any(v is not None for v in values) and any(v is None for v in values):
+12
View File
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
from typing import cast
from lerobot.utils.import_utils import make_device_from_device_class
@@ -67,3 +68,14 @@ def get_cv2_rotation(rotation: Cv2Rotation) -> int | None:
return int(cv2.ROTATE_90_COUNTERCLOCKWISE)
else:
return None
def get_cv2_backend() -> int:
import cv2
if platform.system() == "Windows":
return int(cv2.CAP_MSMF) # Use MSMF for Windows instead of AVFOUNDATION
# elif platform.system() == "Darwin": # macOS
# return cv2.CAP_AVFOUNDATION
else: # Linux and others
return int(cv2.CAP_ANY)
+35 -185
View File
@@ -34,8 +34,7 @@ import cv2
import numpy as np
from numpy.typing import NDArray
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..camera import Camera
from ..configs import ColorMode
@@ -46,12 +45,6 @@ logger = logging.getLogger(__name__)
class ZMQCamera(Camera):
"""
Manages camera interactions via ZeroMQ for receiving frames from a remote server.
This class connects to a ZMQ Publisher, subscribes to frame topics, and decodes
incoming JSON messages containing Base64 encoded images. It supports both
synchronous and asynchronous frame reading patterns.
Example usage:
```python
from lerobot.cameras.zmq import ZMQCamera, ZMQCameraConfig
@@ -59,16 +52,7 @@ class ZMQCamera(Camera):
config = ZMQCameraConfig(server_address="192.168.123.164", port=5555, camera_name="head_camera")
camera = ZMQCamera(config)
camera.connect()
# Read 1 frame synchronously (blocking)
color_image = camera.read()
# Read 1 frame asynchronously (waits for new frame with a timeout)
async_image = camera.async_read()
# Get the latest frame immediately (no wait, returns timestamp)
latest_image, timestamp = camera.read_latest()
frame = camera.read()
camera.disconnect()
```
"""
@@ -84,17 +68,14 @@ class ZMQCamera(Camera):
self.color_mode = config.color_mode
self.timeout_ms = config.timeout_ms
# ZMQ Context and Socket
self.context: zmq.Context | None = None
self.socket: zmq.Socket | None = None
self._connected = False
# Threading resources
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: NDArray[Any] | None = None
self.latest_timestamp: float | None = None
self.new_frame_event: Event = Event()
def __str__(self) -> str:
@@ -102,17 +83,12 @@ class ZMQCamera(Camera):
@property
def is_connected(self) -> bool:
"""Checks if the ZMQ socket is initialized and connected."""
return self._connected and self.context is not None and self.socket is not None
@check_if_already_connected
def connect(self, warmup: bool = True) -> None:
"""Connect to ZMQ camera server.
Args:
warmup (bool): If True, waits for the camera to provide at least one
valid frame before returning. Defaults to True.
"""
"""Connect to ZMQ camera server."""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
logger.info(f"Connecting to {self}...")
@@ -127,28 +103,17 @@ class ZMQCamera(Camera):
self.socket.connect(f"tcp://{self.server_address}:{self.port}")
self._connected = True
# Auto-detect resolution if not provided
# Auto-detect resolution
if self.width is None or self.height is None:
# Read directly from hardware because the thread isn't running yet
temp_frame = self._read_from_hardware()
h, w = temp_frame.shape[:2]
h, w = self.read().shape[:2]
self.height = h
self.width = w
logger.info(f"{self} resolution detected: {w}x{h}")
logger.info(f"{self} resolution: {w}x{h}")
self._start_read_thread()
logger.info(f"{self} connected.")
if warmup:
# Ensure we have captured at least one frame via the thread
start_time = time.time()
while time.time() - start_time < (self.config.warmup_s): # Wait a bit more than timeout
self.async_read(timeout_ms=self.config.warmup_s * 1000)
time.sleep(0.1)
with self.frame_lock:
if self.latest_frame is None:
raise ConnectionError(f"{self} failed to capture frames during warmup.")
time.sleep(0.1)
except Exception as e:
self._cleanup()
@@ -166,14 +131,15 @@ class ZMQCamera(Camera):
@staticmethod
def find_cameras() -> list[dict[str, Any]]:
"""
Detection not implemented for ZMQ cameras. These cameras require manual configuration (server address/port).
"""
raise NotImplementedError("Camera detection is not implemented for ZMQ cameras.")
"""ZMQ cameras require manual configuration (server address/port)."""
return []
def _read_from_hardware(self) -> NDArray[Any]:
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame directly from the ZMQ socket.
Read a single frame from the ZMQ camera.
Returns:
np.ndarray: Decoded frame (height, width, 3)
"""
if not self.is_connected or self.socket is None:
raise DeviceNotConnectedError(f"{self} is not connected.")
@@ -181,7 +147,6 @@ class ZMQCamera(Camera):
try:
message = self.socket.recv_string()
except Exception as e:
# Check for ZMQ timeout (EAGAIN/Again) without requiring global zmq import
if type(e).__name__ == "Again":
raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e
raise
@@ -211,114 +176,42 @@ class ZMQCamera(Camera):
return frame
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the camera.
This is a blocking call. It waits for the next available frame from the
camera background thread.
Returns:
np.ndarray: Decoded frame (height, width, 3)
"""
start_time = time.perf_counter()
if color_mode is not None:
logger.warning(
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
)
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
self.new_frame_event.clear()
frame = self.async_read(timeout_ms=10000)
read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
return frame
def _read_loop(self) -> None:
"""
Internal loop run by the background thread for asynchronous reading.
"""
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized.")
failure_count = 0
while not self.stop_event.is_set():
while self.stop_event and not self.stop_event.is_set():
try:
frame = self._read_from_hardware()
capture_time = time.perf_counter()
frame = self.read()
with self.frame_lock:
self.latest_frame = frame
self.latest_timestamp = capture_time
self.new_frame_event.set()
failure_count = 0
except DeviceNotConnectedError:
break
except (TimeoutError, Exception) as e:
if failure_count <= 10:
failure_count += 1
logger.warning(f"Read error: {e}")
else:
raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e
except TimeoutError:
pass
except Exception as e:
logger.warning(f"Read error: {e}")
def _start_read_thread(self) -> None:
if self.stop_event is not None:
self.stop_event.set()
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=2.0)
with self.frame_lock:
self.latest_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
if self.thread and self.thread.is_alive():
return
self.stop_event = Event()
self.thread = Thread(target=self._read_loop, daemon=True, name=f"{self}_read_loop")
self.thread = Thread(target=self._read_loop, daemon=True)
self.thread.start()
time.sleep(0.1)
def _stop_read_thread(self) -> None:
if self.stop_event is not None:
if self.stop_event:
self.stop_event.set()
if self.thread is not None and self.thread.is_alive():
if self.thread and self.thread.is_alive():
self.thread.join(timeout=2.0)
self.thread = None
self.stop_event = None
with self.frame_lock:
self.latest_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
def async_read(self, timeout_ms: float = 10000) -> NDArray[Any]:
"""Read latest frame asynchronously (non-blocking)."""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame asynchronously.
Args:
timeout_ms (float): Maximum time in milliseconds to wait for a frame
to become available. Defaults to 200ms.
Returns:
np.ndarray: The latest captured frame.
Raises:
DeviceNotConnectedError: If the camera is not connected.
TimeoutError: If no frame data becomes available within the specified timeout.
RuntimeError: If the background thread is not running.
"""
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
if not self.thread or not self.thread.is_alive():
self._start_read_thread()
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
raise TimeoutError(f"{self} async_read timeout after {timeout_ms}ms")
@@ -332,54 +225,11 @@ class ZMQCamera(Camera):
return frame
@check_if_not_connected
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
memory buffer. The frame may be stale,
meaning it could have been captured a while ago (hanging camera scenario e.g.).
Returns:
NDArray[Any]: The frame image (numpy array).
Raises:
TimeoutError: If the latest frame is older than `max_age_ms`.
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
with self.frame_lock:
frame = self.latest_frame
timestamp = self.latest_timestamp
if frame is None or timestamp is None:
raise RuntimeError(f"{self} has not captured any frames yet.")
age_ms = (time.perf_counter() - timestamp) * 1e3
if age_ms > max_age_ms:
raise TimeoutError(
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
)
return frame
def disconnect(self) -> None:
"""Disconnect from ZMQ camera."""
if not self.is_connected and self.thread is None:
if not self.is_connected and not self.thread:
raise DeviceNotConnectedError(f"{self} not connected.")
if self.thread is not None:
self._stop_read_thread()
self._stop_read_thread()
self._cleanup()
with self.frame_lock:
self.latest_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
logger.info(f"{self} disconnected.")
+4 -2
View File
@@ -29,10 +29,12 @@ class ZMQCameraConfig(CameraConfig):
camera_name: str = "zmq_camera"
color_mode: ColorMode = ColorMode.RGB
timeout_ms: int = 5000
warmup_s: int = 1
def __post_init__(self) -> None:
self.color_mode = ColorMode(self.color_mode)
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
)
if self.timeout_ms <= 0:
raise ValueError(f"`timeout_ms` must be positive, but {self.timeout_ms} is provided.")
+16 -6
View File
@@ -45,12 +45,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
input_shapes: A dictionary defining the shapes of the input data for the policy.
output_shapes: A dictionary defining the shapes of the output data for the policy.
input_normalization_modes: A dictionary with key representing the modality and the value specifies the
normalization mode to apply.
output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to
the original scale.
"""
n_obs_steps: int = 1
@@ -105,6 +105,16 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
def observation_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
raise NotImplementedError
@property
def image_observation_delta_indices(self) -> list | None: # type: ignore[type-arg]
"""Return indices for delta image observations only.
Unlike observation_delta_indices which applies to ALL observations,
this only applies to image observations (keys starting with observation.images).
Default returns None. Override in subclass to enable.
"""
return None
@property
@abc.abstractmethod
def action_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
-12
View File
@@ -211,15 +211,3 @@ class TrainRLServerPipelineConfig(TrainPipelineConfig):
# NOTE: In RL, we don't need an offline dataset
# TODO: Make `TrainPipelineConfig.dataset` optional
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
# Algorithm name registered in RLAlgorithmConfig registry
algorithm: str = "sac"
# Data mixer strategy name. Currently supports "online_offline"
mixer: str = "online_offline"
# Fraction sampled from online replay when using OnlineOfflineMixer
online_ratio: float = 0.5
# RL trainer iterator
async_prefetch: bool = True
queue_size: int = 2
+18 -82
View File
@@ -116,9 +116,6 @@ def update_meta_data(
Adjusts all indices and timestamps to account for previously aggregated
data and videos in the destination dataset.
For data file indices, uses the 'src_to_dst' mapping from aggregate_data()
to correctly map source file indices to their destination locations.
Args:
df: DataFrame containing the metadata to be updated.
dst_meta: Destination dataset metadata.
@@ -132,50 +129,8 @@ def update_meta_data(
df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"]
df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"]
# Update data file indices using source-to-destination mapping
# This is critical for handling datasets that are already results of a merge
data_src_to_dst = data_idx.get("src_to_dst", {})
if data_src_to_dst:
# Store original indices for lookup
df["_orig_data_chunk"] = df["data/chunk_index"].copy()
df["_orig_data_file"] = df["data/file_index"].copy()
# Vectorized mapping from (src_chunk, src_file) to (dst_chunk, dst_file)
# This is much faster than per-row iteration for large metadata tables
mapping_index = pd.MultiIndex.from_tuples(
list(data_src_to_dst.keys()),
names=["chunk_index", "file_index"],
)
mapping_values = list(data_src_to_dst.values())
mapping_df = pd.DataFrame(
mapping_values,
index=mapping_index,
columns=["dst_chunk", "dst_file"],
)
# Construct a MultiIndex for each row based on original data indices
row_index = pd.MultiIndex.from_arrays(
[df["_orig_data_chunk"], df["_orig_data_file"]],
names=["chunk_index", "file_index"],
)
# Align mapping to rows; missing keys fall back to the default destination
reindexed = mapping_df.reindex(row_index)
reindexed[["dst_chunk", "dst_file"]] = reindexed[["dst_chunk", "dst_file"]].fillna(
{"dst_chunk": data_idx["chunk"], "dst_file": data_idx["file"]}
)
# Assign mapped destination indices back to the DataFrame
df["data/chunk_index"] = reindexed["dst_chunk"].to_numpy()
df["data/file_index"] = reindexed["dst_file"].to_numpy()
# Clean up temporary columns
df = df.drop(columns=["_orig_data_chunk", "_orig_data_file"])
else:
# Fallback to simple offset (backward compatibility for single-file sources)
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
for key, video_idx in videos_idx.items():
# Store original video file indices before updating
orig_chunk_col = f"videos/{key}/chunk_index"
@@ -191,7 +146,8 @@ def update_meta_data(
if src_to_dst:
# Map each episode to its correct destination file and apply offset
for idx in df.index:
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
# Convert to Python int to avoid numpy type mismatch in dict lookup
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
# Get destination chunk/file for this source file
dst_chunk, dst_file = src_to_dst.get(src_key, (video_idx["chunk"], video_idx["file"]))
@@ -207,7 +163,8 @@ def update_meta_data(
df[orig_chunk_col] = video_idx["chunk"]
df[orig_file_col] = video_idx["file"]
for idx in df.index:
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
# Convert to Python int to avoid numpy type mismatch in dict lookup
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
offset = src_to_offset.get(src_key, 0)
df.at[idx, f"videos/{key}/from_timestamp"] += offset
df.at[idx, f"videos/{key}/to_timestamp"] += offset
@@ -305,10 +262,6 @@ def aggregate_datasets(
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
# Clear the src_to_dst mapping after processing each source dataset
# to avoid interference between different source datasets
data_idx.pop("src_to_dst", None)
dst_meta.info["total_episodes"] += src_meta.total_episodes
dst_meta.info["total_frames"] += src_meta.total_frames
@@ -359,6 +312,10 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
dst_file_durations = video_idx["dst_file_durations"]
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
# Convert to Python int to ensure consistent dict keys
src_chunk_idx = int(src_chunk_idx)
src_file_idx = int(src_file_idx)
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
video_key=key,
chunk_index=src_chunk_idx,
@@ -431,16 +388,10 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
Reads source data files, updates indices to match the aggregated dataset,
and writes them to the destination with proper file rotation.
Tracks a `src_to_dst` mapping from source (chunk, file) to destination (chunk, file)
which is critical for correctly updating episode metadata when source datasets
have multiple data files (e.g., from a previous merge operation).
Args:
src_meta: Source dataset metadata.
dst_meta: Destination dataset metadata.
data_idx: Dictionary tracking data chunk and file indices.
data_files_size_in_mb: Maximum size for data files in MB.
chunk_size: Maximum number of files per chunk.
Returns:
dict: Updated data_idx with current chunk and file indices.
@@ -458,10 +409,6 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
# retrieve features schema for proper image typing in parquet
hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None
# Track source to destination file mapping for metadata update
# This is critical for handling datasets that are already results of a merge
src_to_dst: dict[tuple[int, int], tuple[int, int]] = {}
for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
chunk_index=src_chunk_idx, file_index=src_file_idx
@@ -474,9 +421,7 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
df = pd.read_parquet(src_path)
df = update_data_df(df, src_meta, dst_meta)
# Write data and get the actual destination file it was written to
# This avoids duplicating the rotation logic here
data_idx, (dst_chunk, dst_file) = append_or_create_parquet_file(
data_idx = append_or_create_parquet_file(
df,
src_path,
data_idx,
@@ -488,12 +433,6 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
hf_features=hf_features,
)
# Record the mapping from source to actual destination
src_to_dst[(src_chunk_idx, src_file_idx)] = (dst_chunk, dst_file)
# Add the mapping to data_idx for use in metadata update
data_idx["src_to_dst"] = src_to_dst
return data_idx
@@ -534,7 +473,7 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
videos_idx,
)
meta_idx, _ = append_or_create_parquet_file(
meta_idx = append_or_create_parquet_file(
df,
src_path,
meta_idx,
@@ -562,7 +501,7 @@ def append_or_create_parquet_file(
contains_images: bool = False,
aggr_root: Path = None,
hf_features: datasets.Features | None = None,
) -> tuple[dict[str, int], tuple[int, int]]:
):
"""Appends data to an existing parquet file or creates a new one based on size constraints.
Manages file rotation when size limits are exceeded to prevent individual files
@@ -580,11 +519,9 @@ def append_or_create_parquet_file(
hf_features: Optional HuggingFace Features schema for proper image typing.
Returns:
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
and (dst_chunk, dst_file) is the actual destination file the data was written to.
dict: Updated index dictionary with current chunk and file indices.
"""
dst_chunk, dst_file = idx["chunk"], idx["file"]
dst_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
if not dst_path.exists():
dst_path.parent.mkdir(parents=True, exist_ok=True)
@@ -592,15 +529,14 @@ def append_or_create_parquet_file(
to_parquet_with_hf_images(df, dst_path, features=hf_features)
else:
df.to_parquet(dst_path)
return idx, (dst_chunk, dst_file)
return idx
src_size = get_parquet_file_size_in_mb(src_path)
dst_size = get_parquet_file_size_in_mb(dst_path)
if dst_size + src_size >= max_mb:
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
dst_chunk, dst_file = idx["chunk"], idx["file"]
new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
new_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
new_path.parent.mkdir(parents=True, exist_ok=True)
final_df = df
target_path = new_path
@@ -619,7 +555,7 @@ def append_or_create_parquet_file(
else:
final_df.to_parquet(target_path)
return idx, (dst_chunk, dst_file)
return idx
def finalize_aggregation(aggr_meta, all_metadata):
-126
View File
@@ -1396,132 +1396,6 @@ BYTES_PER_KIB = 1024
BYTES_PER_MIB = BYTES_PER_KIB * BYTES_PER_KIB
def modify_tasks(
dataset: LeRobotDataset,
new_task: str | None = None,
episode_tasks: dict[int, str] | None = None,
) -> LeRobotDataset:
"""Modify tasks in a LeRobotDataset.
This function allows you to either:
1. Set a single task for the entire dataset (using `new_task`)
2. Set specific tasks for specific episodes (using `episode_tasks`)
You can combine both: `new_task` sets the default, and `episode_tasks` overrides
specific episodes.
The dataset is modified in-place, updating only the task-related files:
- meta/tasks.parquet
- data/**/*.parquet (task_index column)
- meta/episodes/**/*.parquet (tasks column)
- meta/info.json (total_tasks)
Args:
dataset: The source LeRobotDataset to modify.
new_task: A single task string to apply to all episodes. If None and episode_tasks
is also None, raises an error.
episode_tasks: Optional dict mapping episode indices to their task strings.
Overrides `new_task` for specific episodes.
Examples:
Set a single task for all episodes:
dataset = modify_tasks(dataset, new_task="Pick up the cube")
Set different tasks for specific episodes:
dataset = modify_tasks(
dataset,
episode_tasks={0: "Task A", 1: "Task B", 2: "Task A"}
)
Set a default task with overrides:
dataset = modify_tasks(
dataset,
new_task="Default task",
episode_tasks={5: "Special task for episode 5"}
)
"""
if new_task is None and episode_tasks is None:
raise ValueError("Must specify at least one of new_task or episode_tasks")
if episode_tasks is not None:
valid_indices = set(range(dataset.meta.total_episodes))
invalid = set(episode_tasks.keys()) - valid_indices
if invalid:
raise ValueError(f"Invalid episode indices: {invalid}")
# Ensure episodes metadata is loaded
if dataset.meta.episodes is None:
dataset.meta.episodes = load_episodes(dataset.root)
# Build the mapping from episode index to task string
episode_to_task: dict[int, str] = {}
for ep_idx in range(dataset.meta.total_episodes):
if episode_tasks and ep_idx in episode_tasks:
episode_to_task[ep_idx] = episode_tasks[ep_idx]
elif new_task is not None:
episode_to_task[ep_idx] = new_task
else:
# Keep original task if not overridden and no default provided
original_tasks = dataset.meta.episodes[ep_idx]["tasks"]
if not original_tasks:
raise ValueError(f"Episode {ep_idx} has no tasks and no default task was provided")
episode_to_task[ep_idx] = original_tasks[0]
# Collect all unique tasks and create new task mapping
unique_tasks = sorted(set(episode_to_task.values()))
new_task_df = pd.DataFrame({"task_index": list(range(len(unique_tasks)))}, index=unique_tasks)
task_to_index = {task: idx for idx, task in enumerate(unique_tasks)}
logging.info(f"Modifying tasks in {dataset.repo_id}")
logging.info(f"New tasks: {unique_tasks}")
root = dataset.root
# Update data files - modify task_index column
logging.info("Updating data files...")
data_dir = root / DATA_DIR
for parquet_path in tqdm(sorted(data_dir.rglob("*.parquet")), desc="Updating data"):
df = pd.read_parquet(parquet_path)
# Build a mapping from episode_index to new task_index for rows in this file
episode_indices_in_file = df["episode_index"].unique()
ep_to_new_task_idx = {
ep_idx: task_to_index[episode_to_task[ep_idx]] for ep_idx in episode_indices_in_file
}
# Update task_index column
df["task_index"] = df["episode_index"].map(ep_to_new_task_idx)
df.to_parquet(parquet_path, index=False)
# Update episodes metadata - modify tasks column
logging.info("Updating episodes metadata...")
episodes_dir = root / "meta" / "episodes"
for parquet_path in tqdm(sorted(episodes_dir.rglob("*.parquet")), desc="Updating episodes"):
df = pd.read_parquet(parquet_path)
# Update tasks column
df["tasks"] = df["episode_index"].apply(lambda ep_idx: [episode_to_task[ep_idx]])
df.to_parquet(parquet_path, index=False)
# Write new tasks.parquet
write_tasks(new_task_df, root)
# Update info.json
dataset.meta.info["total_tasks"] = len(unique_tasks)
write_info(dataset.meta.info, root)
# Reload metadata to reflect changes
dataset.meta.tasks = new_task_df
dataset.meta.episodes = load_episodes(root)
logging.info(f"Tasks: {unique_tasks}")
return dataset
def convert_image_to_video_dataset(
dataset: LeRobotDataset,
output_dir: Path,
+7 -2
View File
@@ -27,7 +27,7 @@ from lerobot.datasets.lerobot_dataset import (
)
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.transforms import ImageTransforms
from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_PREFIX, REWARD
IMAGENET_STATS = {
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
@@ -59,7 +59,12 @@ def resolve_delta_timestamps(
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
if key == ACTION and cfg.action_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
# Check for image-specific delta indices first (e.g., for video encoding)
if key.startswith(OBS_IMAGES) and cfg.image_observation_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.image_observation_delta_indices]
# Fall back to generic observation delta indices for all observations
elif key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
if len(delta_timestamps) == 0:
+1 -10
View File
@@ -57,7 +57,6 @@ from lerobot.datasets.utils import (
load_info,
load_nested_dataset,
load_stats,
load_subtasks,
load_tasks,
update_chunk_file_indices,
validate_episode_buffer,
@@ -163,7 +162,6 @@ class LeRobotDatasetMetadata:
self.info = load_info(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
self.tasks = load_tasks(self.root)
self.subtasks = load_subtasks(self.root)
self.episodes = load_episodes(self.root)
self.stats = load_stats(self.root)
@@ -520,7 +518,6 @@ class LeRobotDatasetMetadata:
_validate_feature_names(features)
obj.tasks = None
obj.subtasks = None
obj.episodes = None
obj.stats = None
obj.info = create_empty_dataset_info(
@@ -656,7 +653,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
will be stored under root/repo_id.
root (Path | None, optional): Local directory to use for downloading/writing files. You can also
set the HF_LEROBOT_HOME environment variable to point to a different location. Defaults to
set the LEROBOT_HOME environment variable to point to a different location. Defaults to
'~/.cache/huggingface/lerobot'.
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
their episode_index in this list. Defaults to None.
@@ -1078,12 +1075,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Add task as a string
task_idx = item["task_index"].item()
item["task"] = self.meta.tasks.iloc[task_idx].name
# add subtask information if available
if "subtask_index" in self.features and self.meta.subtasks is not None:
subtask_idx = item["subtask_index"].item()
item["subtask"] = self.meta.subtasks.iloc[subtask_idx].name
return item
def __repr__(self):
+9 -10
View File
@@ -216,17 +216,16 @@ class ImageTransformsConfig:
def make_transform_from_config(cfg: ImageTransformConfig):
if cfg.type == "SharpnessJitter":
if cfg.type == "Identity":
return v2.Identity(**cfg.kwargs)
elif cfg.type == "ColorJitter":
return v2.ColorJitter(**cfg.kwargs)
elif cfg.type == "SharpnessJitter":
return SharpnessJitter(**cfg.kwargs)
transform_cls = getattr(v2, cfg.type, None)
if isinstance(transform_cls, type) and issubclass(transform_cls, Transform):
return transform_cls(**cfg.kwargs)
raise ValueError(
f"Transform '{cfg.type}' is not valid. It must be a class in "
f"torchvision.transforms.v2 or 'SharpnessJitter'."
)
elif cfg.type == "RandomAffine":
return v2.RandomAffine(**cfg.kwargs)
else:
raise ValueError(f"Transform '{cfg.type}' is not valid.")
class ImageTransforms(Transform):
+13 -12
View File
@@ -60,7 +60,6 @@ VIDEO_DIR = "videos"
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
@@ -122,9 +121,19 @@ def load_nested_dataset(
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
with SuppressProgressBars():
# We use .from_parquet() memory-mapped loading for efficiency
filters = pa_ds.field("episode_index").isin(episodes) if episodes is not None else None
return Dataset.from_parquet([str(path) for path in paths], filters=filters, features=features)
# When no filtering needed, Dataset uses memory-mapped loading for efficiency
# PyArrow loads the entire dataset into memory
if episodes is None:
return Dataset.from_parquet([str(path) for path in paths], features=features)
arrow_dataset = pa_ds.dataset(paths, format="parquet")
filter_expr = pa_ds.field("episode_index").isin(episodes)
table = arrow_dataset.to_table(filter=filter_expr)
if features is not None:
table = table.cast(features.arrow_schema)
return Dataset(table)
def get_parquet_num_frames(parquet_path: str | Path) -> int:
@@ -344,14 +353,6 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
return tasks
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
"""Load subtasks from subtasks.parquet if it exists."""
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
if subtasks_path.exists():
return pd.read_parquet(subtasks_path)
return None
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
This function writes episode-level metadata to a single parquet file.
@@ -529,7 +529,7 @@ if __name__ == "__main__":
type=str,
required=True,
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
"(e.g. `lerobot/pusht`, `<USER>/aloha_sim_insertion_human`).",
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
)
parser.add_argument(
"--branch",
+4 -6
View File
@@ -205,7 +205,6 @@ class ObservationConfig:
add_joint_velocity_to_observation: bool = False
add_current_to_observation: bool = False
add_ee_pose_to_observation: bool = False
display_cameras: bool = False
@@ -261,7 +260,6 @@ class HILSerlRobotEnvConfig(EnvConfig):
@dataclass
class LiberoEnv(EnvConfig):
task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
task_ids: list[int] | None = None
fps: int = 30
episode_length: int | None = None
obs_type: str = "pixels_agent_pos"
@@ -340,10 +338,10 @@ class LiberoEnv(EnvConfig):
@property
def gym_kwargs(self) -> dict:
kwargs: dict[str, Any] = {"obs_type": self.obs_type, "render_mode": self.render_mode}
if self.task_ids is not None:
kwargs["task_ids"] = self.task_ids
return kwargs
return {
"obs_type": self.obs_type,
"render_mode": self.render_mode,
}
@EnvConfig.register_subclass("metaworld")
+2 -7
View File
@@ -112,7 +112,6 @@ class LiberoEnv(gym.Env):
visualization_height: int = 480,
init_states: bool = True,
episode_index: int = 0,
n_envs: int = 1,
camera_name_mapping: dict[str, str] | None = None,
num_steps_wait: int = 10,
control_mode: str = "relative",
@@ -146,9 +145,7 @@ class LiberoEnv(gym.Env):
self.episode_length = episode_length
# Load once and keep
self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
self._reset_stride = n_envs # when performing a reset, append `_reset_stride` to `init_state_id`.
self.init_state_id = self.episode_index # tie each sub-env to a fixed init state
self._init_state_id = self.episode_index # tie each sub-env to a fixed init state
self._env = self._make_envs_task(task_suite, self.task_id)
default_steps = 500
@@ -298,8 +295,7 @@ class LiberoEnv(gym.Env):
self._env.seed(seed)
raw_obs = self._env.reset()
if self.init_states and self._init_states is not None:
raw_obs = self._env.set_init_state(self._init_states[self.init_state_id % len(self._init_states)])
self.init_state_id += self._reset_stride # Change init_state_id when reset
raw_obs = self._env.set_init_state(self._init_states[self._init_state_id])
# After reset, objects may be unstable (slightly floating, intersecting, etc.).
# Step the simulator with a no-op action for a few frames so everything settles.
@@ -377,7 +373,6 @@ def _make_env_fns(
init_states=init_states,
episode_length=episode_length,
episode_index=episode_index,
n_envs=n_envs,
control_mode=control_mode,
**local_kwargs,
)
+1 -5
View File
@@ -14,8 +14,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .motors_bus import (
Motor,
MotorCalibration,
MotorNormMode,
)
from .motors_bus import Motor, MotorCalibration, MotorNormMode, MotorsBus
+5 -7
View File
@@ -18,7 +18,7 @@ from dataclasses import dataclass
os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1"
from .motors_bus import MotorCalibration, MotorsBus
from lerobot.motors import MotorCalibration, MotorsBus
BAR_LEN, BAR_THICKNESS = 450, 8
HANDLE_R = 10
@@ -221,7 +221,7 @@ class RangeFinderGUI:
self.bus = bus
self.groups = groups if groups is not None else {"all": list(bus.motors)}
self.group_names = list(self.groups)
self.group_names = list(groups)
self.current_group = self.group_names[0]
if not bus.is_connected:
@@ -230,20 +230,18 @@ class RangeFinderGUI:
self.calibration = bus.read_calibration()
self.res_table = bus.model_resolution_table
self.present_cache = {
m: bus.read("Present_Position", m, normalize=False)
for motors in self.groups.values()
for m in motors
m: bus.read("Present_Position", m, normalize=False) for motors in groups.values() for m in motors
}
pygame.init()
self.font = pygame.font.Font(None, FONT_SIZE)
label_pad = max(self.font.size(m)[0] for ms in self.groups.values() for m in ms)
label_pad = max(self.font.size(m)[0] for ms in groups.values() for m in ms)
self.label_pad = label_pad
width = 40 + label_pad + BAR_LEN + 6 + BTN_W + 10 + SAVE_W + 10
self.controls_bottom = 10 + SAVE_H
self.base_y = self.controls_bottom + TOP_GAP
height = self.base_y + PADDING_Y * len(self.groups[self.current_group]) + 40
height = self.base_y + PADDING_Y * len(groups[self.current_group]) + 40
self.screen = pygame.display.set_mode((width, height))
pygame.display.set_caption("Motors range finder")
-18
View File
@@ -1,18 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .damiao import DamiaoMotorsBus
from .tables import *
-859
View File
@@ -1,859 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Portions of this file are derived from DM_Control_Python by cmjang.
# Licensed under the MIT License; see `LICENSE` for the full text:
# https://github.com/cmjang/DM_Control_Python
import logging
import time
from contextlib import contextmanager
from copy import deepcopy
from functools import cached_property
from typing import TYPE_CHECKING, Any, TypedDict
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.import_utils import _can_available
if TYPE_CHECKING or _can_available:
import can
else:
class can: # noqa: N801
Message = object
interface = None
import numpy as np
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import enter_pressed, move_cursor_up
from ..motors_bus import Motor, MotorCalibration, MotorsBusBase, NameOrID, Value
from .tables import (
AVAILABLE_BAUDRATES,
CAN_CMD_DISABLE,
CAN_CMD_ENABLE,
CAN_CMD_REFRESH,
CAN_CMD_SET_ZERO,
CAN_PARAM_ID,
DEFAULT_BAUDRATE,
DEFAULT_TIMEOUT_MS,
MIT_KD_RANGE,
MIT_KP_RANGE,
MOTOR_LIMIT_PARAMS,
MotorType,
)
logger = logging.getLogger(__name__)
LONG_TIMEOUT_SEC = 0.1
MEDIUM_TIMEOUT_SEC = 0.01
SHORT_TIMEOUT_SEC = 0.001
PRECISE_TIMEOUT_SEC = 0.0001
class MotorState(TypedDict):
position: float
velocity: float
torque: float
temp_mos: float
temp_rotor: float
class DamiaoMotorsBus(MotorsBusBase):
"""
The Damiao implementation for a MotorsBus using CAN bus communication.
This class uses python-can for CAN bus communication with Damiao motors.
For more info, see:
- python-can documentation: https://python-can.readthedocs.io/en/stable/
- Seedstudio documentation: https://wiki.seeedstudio.com/damiao_series/
- DM_Control_Python repo: https://github.com/cmjang/DM_Control_Python
"""
# CAN-specific settings
available_baudrates = deepcopy(AVAILABLE_BAUDRATES)
default_baudrate = DEFAULT_BAUDRATE
default_timeout = DEFAULT_TIMEOUT_MS
def __init__(
self,
port: str,
motors: dict[str, Motor],
calibration: dict[str, MotorCalibration] | None = None,
can_interface: str = "auto",
use_can_fd: bool = True,
bitrate: int = 1000000,
data_bitrate: int | None = 5000000,
):
"""
Initialize the Damiao motors bus.
Args:
port: CAN interface name (e.g., "can0" for Linux, "/dev/cu.usbmodem*" for macOS)
motors: Dictionary mapping motor names to Motor objects
calibration: Optional calibration data
can_interface: CAN interface type - "auto" (default), "socketcan" (Linux), or "slcan" (macOS/serial)
use_can_fd: Whether to use CAN FD mode (default: True for OpenArms)
bitrate: Nominal bitrate in bps (default: 1000000 = 1 Mbps)
data_bitrate: Data bitrate for CAN FD in bps (default: 5000000 = 5 Mbps), ignored if use_can_fd is False
"""
super().__init__(port, motors, calibration)
self.port = port
self.can_interface = can_interface
self.use_can_fd = use_can_fd
self.bitrate = bitrate
self.data_bitrate = data_bitrate
self.canbus: can.interface.Bus | None = None
self._is_connected = False
# Map motor names to CAN IDs
self._motor_can_ids: dict[str, int] = {}
self._recv_id_to_motor: dict[int, str] = {}
self._motor_types: dict[str, MotorType] = {}
for name, motor in self.motors.items():
if motor.motor_type_str is None:
raise ValueError(f"Motor '{name}' is missing required 'motor_type'")
self._motor_types[name] = getattr(MotorType, motor.motor_type_str.upper().replace("-", "_"))
# Map recv_id to motor name for filtering responses
if motor.recv_id is not None:
self._recv_id_to_motor[motor.recv_id] = name
# State cache for handling packet drops safely
self._last_known_states: dict[str, MotorState] = {
name: {
"position": 0.0,
"velocity": 0.0,
"torque": 0.0,
"temp_mos": 0.0,
"temp_rotor": 0.0,
}
for name in self.motors
}
# Dynamic gains storage
# Defaults: Kp=10.0 (Stiffness), Kd=0.5 (Damping)
self._gains: dict[str, dict[str, float]] = {name: {"kp": 10.0, "kd": 0.5} for name in self.motors}
@property
def is_connected(self) -> bool:
"""Check if the CAN bus is connected."""
return self._is_connected and self.canbus is not None
@check_if_already_connected
def connect(self, handshake: bool = True) -> None:
"""
Open the CAN bus and initialize communication.
Args:
handshake: If True, ping all motors to verify they're present
"""
try:
# Auto-detect interface type based on port name
if self.can_interface == "auto":
if self.port.startswith("/dev/"):
self.can_interface = "slcan"
logger.info(f"Auto-detected slcan interface for port {self.port}")
else:
self.can_interface = "socketcan"
logger.info(f"Auto-detected socketcan interface for port {self.port}")
# Connect to CAN bus
kwargs = {
"channel": self.port,
"bitrate": self.bitrate,
"interface": self.can_interface,
}
if self.can_interface == "socketcan" and self.use_can_fd and self.data_bitrate is not None:
kwargs.update({"data_bitrate": self.data_bitrate, "fd": True})
logger.info(
f"Connected to {self.port} with CAN FD (bitrate={self.bitrate}, data_bitrate={self.data_bitrate})"
)
else:
logger.info(f"Connected to {self.port} with {self.can_interface} (bitrate={self.bitrate})")
self.canbus = can.interface.Bus(**kwargs)
self._is_connected = True
if handshake:
self._handshake()
logger.debug(f"{self.__class__.__name__} connected via {self.can_interface}.")
except Exception as e:
self._is_connected = False
raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e
def _handshake(self) -> None:
"""
Verify all motors are present and populate initial state cache.
Raises ConnectionError if any motor fails to respond.
"""
logger.info("Starting handshake with motors...")
# Drain any pending messages
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
while self.canbus.recv(timeout=0.01):
pass
missing_motors = []
for motor_name in self.motors:
motor_id = self._get_motor_id(motor_name)
recv_id = self._get_motor_recv_id(motor_name)
# Send enable command
data = [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, CAN_CMD_ENABLE]
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
self.canbus.send(msg)
# Wait for response with longer timeout
response = None
start_time = time.time()
while time.time() - start_time < 0.1:
response = self.canbus.recv(timeout=0.1)
if response and response.arbitration_id == recv_id:
break
response = None
if response is None:
missing_motors.append(motor_name)
else:
self._process_response(motor_name, msg)
time.sleep(MEDIUM_TIMEOUT_SEC)
if missing_motors:
raise ConnectionError(
f"Handshake failed. The following motors did not respond: {missing_motors}. "
"Check power (24V) and CAN wiring."
)
logger.info("Handshake successful. All motors ready.")
@check_if_not_connected
def disconnect(self, disable_torque: bool = True) -> None:
"""
Close the CAN bus connection.
Args:
disable_torque: If True, disable torque on all motors before disconnecting
"""
if disable_torque:
try:
self.disable_torque()
except Exception as e:
logger.warning(f"Failed to disable torque during disconnect: {e}")
if self.canbus:
self.canbus.shutdown()
self.canbus = None
self._is_connected = False
logger.debug(f"{self.__class__.__name__} disconnected.")
def configure_motors(self) -> None:
"""Configure all motors with default settings."""
# Damiao motors don't require much configuration in MIT mode
# Just ensure they're enabled
for motor in self.motors:
self._send_simple_command(motor, CAN_CMD_ENABLE)
time.sleep(MEDIUM_TIMEOUT_SEC)
def _send_simple_command(self, motor: NameOrID, command_byte: int) -> None:
"""Helper to send simple 8-byte commands (Enable, Disable, Zero)."""
motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor)
recv_id = self._get_motor_recv_id(motor)
data = [0xFF] * 7 + [command_byte]
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
self.canbus.send(msg)
if msg := self._recv_motor_response(expected_recv_id=recv_id):
self._process_response(motor_name, msg)
else:
logger.debug(f"No response from {motor_name} after command 0x{command_byte:02X}")
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Enable torque on selected motors."""
target_motors = self._get_motors_list(motors)
for motor in target_motors:
for _ in range(num_retry + 1):
try:
self._send_simple_command(motor, CAN_CMD_ENABLE)
break
except Exception as e:
if _ == num_retry:
raise e
time.sleep(MEDIUM_TIMEOUT_SEC)
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Disable torque on selected motors."""
target_motors = self._get_motors_list(motors)
for motor in target_motors:
for _ in range(num_retry + 1):
try:
self._send_simple_command(motor, CAN_CMD_DISABLE)
break
except Exception as e:
if _ == num_retry:
raise e
time.sleep(MEDIUM_TIMEOUT_SEC)
@contextmanager
def torque_disabled(self, motors: str | list[str] | None = None):
"""
Context manager that guarantees torque is re-enabled.
This helper is useful to temporarily disable torque when configuring motors.
"""
self.disable_torque(motors)
try:
yield
finally:
self.enable_torque(motors)
def set_zero_position(self, motors: str | list[str] | None = None) -> None:
"""Set current position as zero for selected motors."""
target_motors = self._get_motors_list(motors)
for motor in target_motors:
self._send_simple_command(motor, CAN_CMD_SET_ZERO)
time.sleep(MEDIUM_TIMEOUT_SEC)
def _refresh_motor(self, motor: NameOrID) -> can.Message | None:
"""Refresh motor status and return the response."""
motor_id = self._get_motor_id(motor)
recv_id = self._get_motor_recv_id(motor)
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd)
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
self.canbus.send(msg)
return self._recv_motor_response(expected_recv_id=recv_id)
def _recv_motor_response(
self, expected_recv_id: int | None = None, timeout: float = 0.001
) -> can.Message | None:
"""
Receive a response from a motor.
Args:
expected_recv_id: If provided, only return messages from this CAN ID
timeout: Timeout in seconds (default: 1ms for high-speed operation)
Returns:
CAN message if received, None otherwise
"""
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
try:
start_time = time.time()
messages_seen = []
while time.time() - start_time < timeout:
msg = self.canbus.recv(timeout=PRECISE_TIMEOUT_SEC)
if msg:
messages_seen.append(f"0x{msg.arbitration_id:02X}")
if expected_recv_id is None or msg.arbitration_id == expected_recv_id:
return msg
logger.debug(
f"Ignoring message from 0x{msg.arbitration_id:02X}, expected 0x{expected_recv_id:02X}"
)
if logger.isEnabledFor(logging.DEBUG):
if messages_seen:
logger.debug(
f"Received {len(messages_seen)} msgs from {set(messages_seen)}, expected 0x{expected_recv_id:02X}"
)
else:
logger.debug(f"No CAN messages received (expected 0x{expected_recv_id:02X})")
except Exception as e:
logger.debug(f"Failed to receive CAN message: {e}")
return None
def _recv_all_responses(
self, expected_recv_ids: list[int], timeout: float = 0.002
) -> dict[int, can.Message]:
"""
Efficiently receive responses from multiple motors at once.
Uses the OpenArms pattern: collect all available messages within timeout.
Args:
expected_recv_ids: List of CAN IDs we expect responses from
timeout: Total timeout in seconds (default: 2ms)
Returns:
Dictionary mapping recv_id to CAN message
"""
responses: dict[int, can.Message] = {}
expected_set = set(expected_recv_ids)
start_time = time.time()
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
try:
while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout:
# 100us poll timeout
msg = self.canbus.recv(timeout=PRECISE_TIMEOUT_SEC)
if msg and msg.arbitration_id in expected_set:
responses[msg.arbitration_id] = msg
if len(responses) == len(expected_recv_ids):
break
except Exception as e:
logger.debug(f"Error receiving responses: {e}")
return responses
def _encode_mit_packet(
self,
motor_type: MotorType,
kp: float,
kd: float,
position_degrees: float,
velocity_deg_per_sec: float,
torque: float,
) -> list[int]:
"""Helper to encode control parameters into 8 bytes for MIT mode."""
# Convert degrees to radians
position_rad = np.radians(position_degrees)
velocity_rad_per_sec = np.radians(velocity_deg_per_sec)
# Get motor limits
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
# Encode parameters
kp_uint = self._float_to_uint(kp, *MIT_KP_RANGE, 12)
kd_uint = self._float_to_uint(kd, *MIT_KD_RANGE, 12)
q_uint = self._float_to_uint(position_rad, -pmax, pmax, 16)
dq_uint = self._float_to_uint(velocity_rad_per_sec, -vmax, vmax, 12)
tau_uint = self._float_to_uint(torque, -tmax, tmax, 12)
# Pack data
data = [0] * 8
data[0] = (q_uint >> 8) & 0xFF
data[1] = q_uint & 0xFF
data[2] = dq_uint >> 4
data[3] = ((dq_uint & 0xF) << 4) | ((kp_uint >> 8) & 0xF)
data[4] = kp_uint & 0xFF
data[5] = kd_uint >> 4
data[6] = ((kd_uint & 0xF) << 4) | ((tau_uint >> 8) & 0xF)
data[7] = tau_uint & 0xFF
return data
def _mit_control(
self,
motor: NameOrID,
kp: float,
kd: float,
position_degrees: float,
velocity_deg_per_sec: float,
torque: float,
) -> None:
"""Send MIT control command to a motor."""
motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor)
motor_type = self._motor_types[motor_name]
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
self.canbus.send(msg)
recv_id = self._get_motor_recv_id(motor)
if msg := self._recv_motor_response(expected_recv_id=recv_id):
self._process_response(motor_name, msg)
else:
logger.debug(f"No response from {motor_name} after MIT control command")
def _mit_control_batch(
self,
commands: dict[NameOrID, tuple[float, float, float, float, float]],
) -> None:
"""
Send MIT control commands to multiple motors in batch.
Sends all commands first, then collects responses.
Args:
commands: Dict mapping motor name/ID to (kp, kd, position_deg, velocity_deg/s, torque)
Example: {'joint_1': (10.0, 0.5, 45.0, 0.0, 0.0), ...}
"""
if not commands:
return
recv_id_to_motor: dict[int, str] = {}
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
# Step 1: Send all MIT control commands
for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items():
motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor)
motor_type = self._motor_types[motor_name]
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
self.canbus.send(msg)
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
# Step 2: Collect responses and update state cache
responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=SHORT_TIMEOUT_SEC)
for recv_id, motor_name in recv_id_to_motor.items():
if msg := responses.get(recv_id):
self._process_response(motor_name, msg)
def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
"""Convert float to unsigned integer for CAN transmission."""
x = max(x_min, min(x_max, x)) # Clamp to range
span = x_max - x_min
data_norm = (x - x_min) / span
return int(data_norm * ((1 << bits) - 1))
def _uint_to_float(self, x: int, x_min: float, x_max: float, bits: int) -> float:
"""Convert unsigned integer from CAN to float."""
span = x_max - x_min
data_norm = float(x) / ((1 << bits) - 1)
return data_norm * span + x_min
def _decode_motor_state(
self, data: bytearray | bytes, motor_type: MotorType
) -> tuple[float, float, float, int, int]:
"""
Decode motor state from CAN data.
Returns: (position_deg, velocity_deg_s, torque, temp_mos, temp_rotor)
"""
if len(data) < 8:
raise ValueError("Invalid motor state data")
# Extract encoded values
q_uint = (data[1] << 8) | data[2]
dq_uint = (data[3] << 4) | (data[4] >> 4)
tau_uint = ((data[4] & 0x0F) << 8) | data[5]
t_mos = data[6]
t_rotor = data[7]
# Get motor limits
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
# Decode to physical values
position_rad = self._uint_to_float(q_uint, -pmax, pmax, 16)
velocity_rad_per_sec = self._uint_to_float(dq_uint, -vmax, vmax, 12)
torque = self._uint_to_float(tau_uint, -tmax, tmax, 12)
return np.degrees(position_rad), np.degrees(velocity_rad_per_sec), torque, t_mos, t_rotor
def _process_response(self, motor: str, msg: can.Message) -> None:
"""Decode a message and update the motor state cache."""
try:
motor_type = self._motor_types[motor]
pos, vel, torque, t_mos, t_rotor = self._decode_motor_state(msg.data, motor_type)
self._last_known_states[motor] = {
"position": pos,
"velocity": vel,
"torque": torque,
"temp_mos": float(t_mos),
"temp_rotor": float(t_rotor),
}
except Exception as e:
logger.warning(f"Failed to decode response from {motor}: {e}")
@check_if_not_connected
def read(self, data_name: str, motor: str) -> Value:
"""Read a value from a single motor. Positions are always in degrees."""
# Refresh motor to get latest state
msg = self._refresh_motor(motor)
if msg is None:
motor_id = self._get_motor_id(motor)
recv_id = self._get_motor_recv_id(motor)
raise ConnectionError(
f"No response from motor '{motor}' (send ID: 0x{motor_id:02X}, recv ID: 0x{recv_id:02X}). "
f"Check that: 1) Motor is powered (24V), 2) CAN wiring is correct, "
f"3) Motor IDs are configured correctly using Damiao Debugging Tools"
)
self._process_response(motor, msg)
return self._get_cached_value(motor, data_name)
def _get_cached_value(self, motor: str, data_name: str) -> Value:
"""Retrieve a specific value from the cache."""
state = self._last_known_states[motor]
mapping: dict[str, Any] = {
"Present_Position": state["position"],
"Present_Velocity": state["velocity"],
"Present_Torque": state["torque"],
"Temperature_MOS": state["temp_mos"],
"Temperature_Rotor": state["temp_rotor"],
}
if data_name not in mapping:
raise ValueError(f"Unknown data_name: {data_name}")
return mapping[data_name]
@check_if_not_connected
def write(
self,
data_name: str,
motor: str,
value: Value,
) -> None:
"""
Write a value to a single motor. Positions are always in degrees.
Can write 'Goal_Position', 'Kp', or 'Kd'.
"""
if data_name in ("Kp", "Kd"):
self._gains[motor][data_name.lower()] = float(value)
elif data_name == "Goal_Position":
kp = self._gains[motor]["kp"]
kd = self._gains[motor]["kd"]
self._mit_control(motor, kp, kd, float(value), 0.0, 0.0)
else:
raise ValueError(f"Writing {data_name} not supported in MIT mode")
def sync_read(
self,
data_name: str,
motors: str | list[str] | None = None,
) -> dict[str, Value]:
"""
Read the same value from multiple motors simultaneously.
"""
target_motors = self._get_motors_list(motors)
self._batch_refresh(target_motors)
result = {}
for motor in target_motors:
result[motor] = self._get_cached_value(motor, data_name)
return result
def sync_read_all_states(
self,
motors: str | list[str] | None = None,
*,
num_retry: int = 0,
) -> dict[str, MotorState]:
"""
Read ALL motor states (position, velocity, torque) from multiple motors in ONE refresh cycle.
Returns:
Dictionary mapping motor names to state dicts with keys: 'position', 'velocity', 'torque'
Example: {'joint_1': {'position': 45.2, 'velocity': 1.3, 'torque': 0.5}, ...}
"""
target_motors = self._get_motors_list(motors)
self._batch_refresh(target_motors)
result = {}
for motor in target_motors:
result[motor] = self._last_known_states[motor].copy()
return result
def _batch_refresh(self, motors: list[str]) -> None:
"""Internal helper to refresh a list of motors and update cache."""
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
# Send refresh commands
for motor in motors:
motor_id = self._get_motor_id(motor)
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
msg = can.Message(
arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd
)
self.canbus.send(msg)
# Collect responses
expected_recv_ids = [self._get_motor_recv_id(m) for m in motors]
responses = self._recv_all_responses(expected_recv_ids, timeout=MEDIUM_TIMEOUT_SEC)
# Update cache
for motor in motors:
recv_id = self._get_motor_recv_id(motor)
msg = responses.get(recv_id)
if msg:
self._process_response(motor, msg)
else:
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
@check_if_not_connected
def sync_write(self, data_name: str, values: dict[str, Value]) -> None:
"""
Write values to multiple motors simultaneously. Positions are always in degrees.
"""
if data_name in ("Kp", "Kd"):
key = data_name.lower()
for motor, val in values.items():
self._gains[motor][key] = float(val)
elif data_name == "Goal_Position":
# Step 1: Send all MIT control commands
recv_id_to_motor: dict[int, str] = {}
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
for motor, value_degrees in values.items():
motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor)
motor_type = self._motor_types[motor_name]
kp = self._gains[motor]["kp"]
kd = self._gains[motor]["kd"]
data = self._encode_mit_packet(motor_type, kp, kd, float(value_degrees), 0.0, 0.0)
msg = can.Message(
arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd
)
self.canbus.send(msg)
precise_sleep(PRECISE_TIMEOUT_SEC)
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
# Step 2: Collect responses and update state cache
responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=MEDIUM_TIMEOUT_SEC)
for recv_id, motor_name in recv_id_to_motor.items():
if msg := responses.get(recv_id):
self._process_response(motor_name, msg)
else:
# Fall back to individual writes
for motor, value in values.items():
self.write(data_name, motor, value)
def read_calibration(self) -> dict[str, MotorCalibration]:
"""Read calibration data from motors."""
# Damiao motors don't store calibration internally
# Return existing calibration or empty dict
return self.calibration if self.calibration else {}
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
"""Write calibration data to motors."""
# Damiao motors don't store calibration internally
# Just cache it in memory
if cache:
self.calibration = calibration_dict
def record_ranges_of_motion(
self,
motors: str | list[str] | None = None,
display_values: bool = True,
) -> tuple[dict[str, Value], dict[str, Value]]:
"""
Interactively record the min/max values of each motor in degrees.
Move the joints by hand (with torque disabled) while the method streams live positions.
Press Enter to finish.
"""
target_motors = self._get_motors_list(motors)
self.disable_torque(target_motors)
time.sleep(LONG_TIMEOUT_SEC)
start_positions = self.sync_read("Present_Position", target_motors)
mins = start_positions.copy()
maxes = start_positions.copy()
print("\nMove joints through their full range of motion. Press ENTER when done.")
user_pressed_enter = False
while not user_pressed_enter:
positions = self.sync_read("Present_Position", target_motors)
for motor in target_motors:
if motor in positions:
mins[motor] = min(positions[motor], mins.get(motor, positions[motor]))
maxes[motor] = max(positions[motor], maxes.get(motor, positions[motor]))
if display_values:
print("\n" + "=" * 50)
print(f"{'MOTOR':<20} | {'MIN (deg)':>12} | {'POS (deg)':>12} | {'MAX (deg)':>12}")
print("-" * 50)
for motor in target_motors:
if motor in positions:
print(
f"{motor:<20} | {mins[motor]:>12.1f} | {positions[motor]:>12.1f} | {maxes[motor]:>12.1f}"
)
if enter_pressed():
user_pressed_enter = True
if display_values and not user_pressed_enter:
move_cursor_up(len(target_motors) + 4)
time.sleep(LONG_TIMEOUT_SEC)
self.enable_torque(target_motors)
for motor in target_motors:
if (motor in mins) and (motor in maxes) and (int(abs(maxes[motor] - mins[motor])) < 5):
raise ValueError(f"Motor {motor} has insufficient range of motion (< 5 degrees)")
return mins, maxes
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
"""Convert motor specification to list of motor names."""
if motors is None:
return list(self.motors.keys())
elif isinstance(motors, str):
return [motors]
elif isinstance(motors, list):
return motors
else:
raise TypeError(f"Invalid motors type: {type(motors)}")
def _get_motor_id(self, motor: NameOrID) -> int:
"""Get CAN ID for a motor."""
if isinstance(motor, str):
if motor in self.motors:
return self.motors[motor].id
else:
raise ValueError(f"Unknown motor: {motor}")
else:
return motor
def _get_motor_name(self, motor: NameOrID) -> str:
"""Get motor name from name or ID."""
if isinstance(motor, str):
return motor
else:
for name, m in self.motors.items():
if m.id == motor:
return name
raise ValueError(f"Unknown motor ID: {motor}")
def _get_motor_recv_id(self, motor: NameOrID) -> int:
"""Get motor recv_id from name or ID."""
motor_name = self._get_motor_name(motor)
motor_obj = self.motors.get(motor_name)
if motor_obj and motor_obj.recv_id is not None:
return motor_obj.recv_id
else:
raise ValueError(f"Motor {motor_obj} doesn't have a valid recv_id (None).")
@cached_property
def is_calibrated(self) -> bool:
"""Check if motors are calibrated."""
return bool(self.calibration)
-209
View File
@@ -1,209 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration tables for Damiao motors."""
from enum import IntEnum
# Motor type definitions
class MotorType(IntEnum):
DM3507 = 0
DM4310 = 1
DM4310_48V = 2
DM4340 = 3
DM4340_48V = 4
DM6006 = 5
DM8006 = 6
DM8009 = 7
DM10010L = 8
DM10010 = 9
DMH3510 = 10
DMH6215 = 11
DMG6220 = 12
# Control modes
class ControlMode(IntEnum):
MIT = 1
POS_VEL = 2
VEL = 3
TORQUE_POS = 4
# Motor variable IDs (RID)
class MotorVariable(IntEnum):
UV_VALUE = 0
KT_VALUE = 1
OT_VALUE = 2
OC_VALUE = 3
ACC = 4
DEC = 5
MAX_SPD = 6
MST_ID = 7
ESC_ID = 8
TIMEOUT = 9
CTRL_MODE = 10
DAMP = 11
INERTIA = 12
HW_VER = 13
SW_VER = 14
SN = 15
NPP = 16
RS = 17
LS = 18
FLUX = 19
GR = 20
PMAX = 21
VMAX = 22
TMAX = 23
I_BW = 24
KP_ASR = 25
KI_ASR = 26
KP_APR = 27
KI_APR = 28
OV_VALUE = 29
GREF = 30
DETA = 31
V_BW = 32
IQ_C1 = 33
VL_C1 = 34
CAN_BR = 35
SUB_VER = 36
U_OFF = 50
V_OFF = 51
K1 = 52
K2 = 53
M_OFF = 54
DIR = 55
P_M = 80
XOUT = 81
# Motor limit parameters [PMAX, VMAX, TMAX]
# PMAX: Maximum position (rad)
# VMAX: Maximum velocity (rad/s)
# TMAX: Maximum torque (N·m)
MOTOR_LIMIT_PARAMS = {
MotorType.DM3507: (12.5, 30, 10),
MotorType.DM4310: (12.5, 30, 10),
MotorType.DM4310_48V: (12.5, 50, 10),
MotorType.DM4340: (12.5, 8, 28),
MotorType.DM4340_48V: (12.5, 10, 28),
MotorType.DM6006: (12.5, 45, 20),
MotorType.DM8006: (12.5, 45, 40),
MotorType.DM8009: (12.5, 45, 54),
MotorType.DM10010L: (12.5, 25, 200),
MotorType.DM10010: (12.5, 20, 200),
MotorType.DMH3510: (12.5, 280, 1),
MotorType.DMH6215: (12.5, 45, 10),
MotorType.DMG6220: (12.5, 45, 10),
}
# Motor model names
MODEL_NAMES = {
MotorType.DM3507: "dm3507",
MotorType.DM4310: "dm4310",
MotorType.DM4310_48V: "dm4310_48v",
MotorType.DM4340: "dm4340",
MotorType.DM4340_48V: "dm4340_48v",
MotorType.DM6006: "dm6006",
MotorType.DM8006: "dm8006",
MotorType.DM8009: "dm8009",
MotorType.DM10010L: "dm10010l",
MotorType.DM10010: "dm10010",
MotorType.DMH3510: "dmh3510",
MotorType.DMH6215: "dmh6215",
MotorType.DMG6220: "dmg6220",
}
# Motor resolution table (encoder counts per revolution)
MODEL_RESOLUTION = {
"dm3507": 65536,
"dm4310": 65536,
"dm4310_48v": 65536,
"dm4340": 65536,
"dm4340_48v": 65536,
"dm6006": 65536,
"dm8006": 65536,
"dm8009": 65536,
"dm10010l": 65536,
"dm10010": 65536,
"dmh3510": 65536,
"dmh6215": 65536,
"dmg6220": 65536,
}
# CAN baudrates supported by Damiao motors
AVAILABLE_BAUDRATES = [
125000, # 0: 125 kbps
200000, # 1: 200 kbps
250000, # 2: 250 kbps
500000, # 3: 500 kbps
1000000, # 4: 1 mbps (default for OpenArms)
2000000, # 5: 2 mbps
2500000, # 6: 2.5 mbps
3200000, # 7: 3.2 mbps
4000000, # 8: 4 mbps
5000000, # 9: 5 mbps
]
DEFAULT_BAUDRATE = 1000000 # 1 Mbps is standard for OpenArms
# Default timeout in milliseconds
DEFAULT_TIMEOUT_MS = 1000
# OpenArms specific configurations
# Based on: https://docs.openarm.dev/software/setup/configure-test
# OpenArms has 7 DOF per arm (14 total for dual arm)
OPENARMS_ARM_MOTOR_IDS = {
"joint_1": {"send": 0x01, "recv": 0x11}, # J1 - Shoulder pan
"joint_2": {"send": 0x02, "recv": 0x12}, # J2 - Shoulder lift
"joint_3": {"send": 0x03, "recv": 0x13}, # J3 - Elbow flex
"joint_4": {"send": 0x04, "recv": 0x14}, # J4 - Wrist flex
"joint_5": {"send": 0x05, "recv": 0x15}, # J5 - Wrist roll
"joint_6": {"send": 0x06, "recv": 0x16}, # J6 - Wrist pitch
"joint_7": {"send": 0x07, "recv": 0x17}, # J7 - Wrist rotation
}
OPENARMS_GRIPPER_MOTOR_IDS = {
"gripper": {"send": 0x08, "recv": 0x18}, # J8 - Gripper
}
# Default motor types for OpenArms
OPENARMS_DEFAULT_MOTOR_TYPES = {
"joint_1": MotorType.DM8009, # Shoulder pan - high torque
"joint_2": MotorType.DM8009, # Shoulder lift - high torque
"joint_3": MotorType.DM4340, # Shoulder rotation
"joint_4": MotorType.DM4340, # Elbow flex
"joint_5": MotorType.DM4310, # Wrist roll
"joint_6": MotorType.DM4310, # Wrist pitch
"joint_7": MotorType.DM4310, # Wrist rotation
"gripper": MotorType.DM4310, # Gripper
}
# MIT control parameter ranges
MIT_KP_RANGE = (0.0, 500.0)
MIT_KD_RANGE = (0.0, 5.0)
# CAN frame command IDs
CAN_CMD_ENABLE = 0xFC
CAN_CMD_DISABLE = 0xFD
CAN_CMD_SET_ZERO = 0xFE
CAN_CMD_REFRESH = 0xCC
CAN_CMD_QUERY_PARAM = 0x33
CAN_CMD_WRITE_PARAM = 0x55
CAN_CMD_SAVE_PARAM = 0xAA
# CAN ID for parameter operations
CAN_PARAM_ID = 0x7FF
+14 -13
View File
@@ -22,8 +22,9 @@ import logging
from copy import deepcopy
from enum import Enum
from ..encoding_utils import decode_twos_complement, encode_twos_complement
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
from lerobot.motors.encoding_utils import decode_twos_complement, encode_twos_complement
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
from .tables import (
AVAILABLE_BAUDRATES,
MODEL_BAUDRATE_TABLE,
@@ -99,7 +100,7 @@ def _split_into_byte_chunks(value: int, length: int) -> list[int]:
return data
class DynamixelMotorsBus(SerialMotorsBus):
class DynamixelMotorsBus(MotorsBus):
"""
The Dynamixel implementation for a MotorsBus. It relies on the python dynamixel sdk to communicate with
the motors. For more info, see the Dynamixel SDK Documentation:
@@ -181,10 +182,10 @@ class DynamixelMotorsBus(SerialMotorsBus):
for motor, m in self.motors.items():
calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=int(drive_modes[motor]),
homing_offset=int(offsets[motor]),
range_min=int(mins[motor]),
range_max=int(maxes[motor]),
drive_mode=drive_modes[motor],
homing_offset=offsets[motor],
range_min=mins[motor],
range_max=maxes[motor],
)
return calibration
@@ -198,15 +199,15 @@ class DynamixelMotorsBus(SerialMotorsBus):
if cache:
self.calibration = calibration_dict
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None:
def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None:
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry)
self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry)
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
@@ -235,7 +236,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
On Dynamixel Motors:
Present_Position = Actual_Position + Homing_Offset
"""
half_turn_homings: dict[NameOrID, Value] = {}
half_turn_homings = {}
for motor, pos in positions.items():
model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1
@@ -258,6 +259,6 @@ class DynamixelMotorsBus(SerialMotorsBus):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
return None
return
return {id_: data[0] for id_, data in data_list.items()}
+16 -15
View File
@@ -17,8 +17,9 @@ from copy import deepcopy
from enum import Enum
from pprint import pformat
from ..encoding_utils import decode_sign_magnitude, encode_sign_magnitude
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
from lerobot.motors.encoding_utils import decode_sign_magnitude, encode_sign_magnitude
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
from .tables import (
FIRMWARE_MAJOR_VERSION,
FIRMWARE_MINOR_VERSION,
@@ -95,7 +96,7 @@ def patch_setPacketTimeout(self, packet_length): # noqa: N802
self.packet_timeout = (self.tx_time_per_byte * packet_length) + (self.tx_time_per_byte * 3.0) + 50
class FeetechMotorsBus(SerialMotorsBus):
class FeetechMotorsBus(MotorsBus):
"""
The FeetechMotorsBus class allows to efficiently read and write to the attached motors. It relies on the
python feetech sdk to communicate with the motors, which is itself based on the dynamixel sdk.
@@ -126,7 +127,7 @@ class FeetechMotorsBus(SerialMotorsBus):
self.port_handler = scs.PortHandler(self.port)
# HACK: monkeypatch
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( # type: ignore[method-assign]
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__(
self.port_handler, scs.PortHandler
)
self.packet_handler = scs.PacketHandler(protocol_version)
@@ -262,9 +263,9 @@ class FeetechMotorsBus(SerialMotorsBus):
calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=0,
homing_offset=int(offsets[motor]),
range_min=int(mins[motor]),
range_max=int(maxes[motor]),
homing_offset=offsets[motor],
range_min=mins[motor],
range_max=maxes[motor],
)
return calibration
@@ -284,7 +285,7 @@ class FeetechMotorsBus(SerialMotorsBus):
On Feetech Motors:
Present_Position = Actual_Position - Homing_Offset
"""
half_turn_homings: dict[NameOrID, Value] = {}
half_turn_homings = {}
for motor, pos in positions.items():
model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1
@@ -292,18 +293,18 @@ class FeetechMotorsBus(SerialMotorsBus):
return half_turn_homings
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
self.write("Lock", motor, 0, num_retry=num_retry)
def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None:
def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None:
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry)
self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry)
addr, length = get_address(self.model_ctrl_table, model, "Lock")
self._write(addr, length, motor, 0, num_retry=num_retry)
self._write(addr, length, motor_id, 0, num_retry=num_retry)
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
self.write("Lock", motor, 1, num_retry=num_retry)
@@ -334,7 +335,7 @@ class FeetechMotorsBus(SerialMotorsBus):
def _broadcast_ping(self) -> tuple[dict[int, int], int]:
import scservo_sdk as scs
data_list: dict[int, int] = {}
data_list = {}
status_length = 6
@@ -414,7 +415,7 @@ class FeetechMotorsBus(SerialMotorsBus):
if not self._is_comm_success(comm):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
return None
return
ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)}
if ids_errors:
+99 -183
View File
@@ -19,11 +19,8 @@
# TODO(aliberts): Add block noqa when feature below is available
# https://github.com/astral-sh/ruff/issues/3711
from __future__ import annotations
import abc
import logging
from collections.abc import Sequence
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
@@ -44,81 +41,6 @@ Value: TypeAlias = int | float
logger = logging.getLogger(__name__)
class MotorsBusBase(abc.ABC):
"""
Base class for all motor bus implementations.
This is a minimal interface that all motor buses must implement, regardless of their
communication protocol (serial, CAN, etc.).
"""
def __init__(
self,
port: str,
motors: dict[str, Motor],
calibration: dict[str, MotorCalibration] | None = None,
):
self.port = port
self.motors = motors
self.calibration = calibration if calibration else {}
@abc.abstractmethod
def connect(self, handshake: bool = True) -> None:
"""Establish connection to the motors."""
pass
@abc.abstractmethod
def disconnect(self, disable_torque: bool = True) -> None:
"""Disconnect from the motors."""
pass
@property
@abc.abstractmethod
def is_connected(self) -> bool:
"""Check if connected to the motors."""
pass
@abc.abstractmethod
def read(self, data_name: str, motor: str) -> Value:
"""Read a value from a single motor."""
pass
@abc.abstractmethod
def write(self, data_name: str, motor: str, value: Value) -> None:
"""Write a value to a single motor."""
pass
@abc.abstractmethod
def sync_read(self, data_name: str, motors: str | list[str] | None = None) -> dict[str, Value]:
"""Read a value from multiple motors."""
pass
@abc.abstractmethod
def sync_write(self, data_name: str, values: dict[str, Value]) -> None:
"""Write values to multiple motors."""
pass
@abc.abstractmethod
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Enable torque on selected motors."""
pass
@abc.abstractmethod
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Disable torque on selected motors."""
pass
@abc.abstractmethod
def read_calibration(self) -> dict[str, MotorCalibration]:
"""Read calibration parameters from the motors."""
pass
@abc.abstractmethod
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
"""Write calibration parameters to the motors."""
pass
def get_ctrl_table(model_ctrl_table: dict[str, dict], model: str) -> dict[str, tuple[int, int]]:
ctrl_table = model_ctrl_table.get(model)
if ctrl_table is None:
@@ -175,21 +97,18 @@ class Motor:
id: int
model: str
norm_mode: MotorNormMode
motor_type_str: str | None = None
recv_id: int | None = None
class PortHandler(Protocol):
is_open: bool
baudrate: int
packet_start_time: float
packet_timeout: float
tx_time_per_byte: float
is_using: bool
port_name: str
ser: serial.Serial
def __init__(self, port_name: str) -> None: ...
def __init__(self, port_name):
self.is_open: bool
self.baudrate: int
self.packet_start_time: float
self.packet_timeout: float
self.tx_time_per_byte: float
self.is_using: bool
self.port_name: str
self.ser: serial.Serial
def openPort(self): ...
def closePort(self): ...
@@ -242,22 +161,19 @@ class PacketHandler(Protocol):
def regWriteTxRx(self, port, id, address, length, data): ...
def syncReadTx(self, port, start_address, data_length, param, param_length): ...
def syncWriteTxOnly(self, port, start_address, data_length, param, param_length): ...
def broadcastPing(self, port): ...
class GroupSyncRead(Protocol):
port: str
ph: PortHandler
start_address: int
data_length: int
last_result: bool
is_param_changed: bool
param: list
data_dict: dict
def __init__(self, port, ph, start_address, data_length):
self.port: str
self.ph: PortHandler
self.start_address: int
self.data_length: int
self.last_result: bool
self.is_param_changed: bool
self.param: list
self.data_dict: dict
def __init__(
self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int
) -> None: ...
def makeParam(self): ...
def addParam(self, id): ...
def removeParam(self, id): ...
@@ -270,17 +186,15 @@ class GroupSyncRead(Protocol):
class GroupSyncWrite(Protocol):
port: str
ph: PortHandler
start_address: int
data_length: int
is_param_changed: bool
param: list
data_dict: dict
def __init__(self, port, ph, start_address, data_length):
self.port: str
self.ph: PortHandler
self.start_address: int
self.data_length: int
self.is_param_changed: bool
self.param: list
self.data_dict: dict
def __init__(
self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int
) -> None: ...
def makeParam(self): ...
def addParam(self, id, data): ...
def removeParam(self, id): ...
@@ -289,15 +203,15 @@ class GroupSyncWrite(Protocol):
def txPacket(self): ...
class SerialMotorsBus(MotorsBusBase):
class MotorsBus(abc.ABC):
"""
A SerialMotorsBus allows to efficiently read and write to motors connected via serial communication.
A MotorsBus allows to efficiently read and write to the attached motors.
It represents several motors daisy-chained together and connected through a serial port.
There are currently two implementations of this class:
There are currently two implementations of this abstract class:
- DynamixelMotorsBus
- FeetechMotorsBus
This class is specifically for serial-based motor protocols (Dynamixel, Feetech, etc.).
Note: This class may evolve in the future should we add support for other types of bus.
A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
To find the port, you can run our utility script:
@@ -346,7 +260,9 @@ class SerialMotorsBus(MotorsBusBase):
motors: dict[str, Motor],
calibration: dict[str, MotorCalibration] | None = None,
):
super().__init__(port, motors, calibration)
self.port = port
self.motors = motors
self.calibration = calibration if calibration else {}
self.port_handler: PortHandler
self.packet_handler: PacketHandler
@@ -407,7 +323,7 @@ class SerialMotorsBus(MotorsBusBase):
else:
raise TypeError(f"'{motor}' should be int, str.")
def _get_motor_model(self, motor: NameOrID) -> str:
def _get_motor_model(self, motor: NameOrID) -> int:
if isinstance(motor, str):
return self.motors[motor].model
elif isinstance(motor, int):
@@ -415,19 +331,17 @@ class SerialMotorsBus(MotorsBusBase):
else:
raise TypeError(f"'{motor}' should be int, str.")
def _get_motors_list(self, motors: NameOrID | Sequence[NameOrID] | None) -> list[str]:
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
if motors is None:
return list(self.motors)
elif isinstance(motors, str):
return [motors]
elif isinstance(motors, int):
return [self._id_to_name(motors)]
elif isinstance(motors, Sequence):
return [m if isinstance(m, str) else self._id_to_name(m) for m in motors]
elif isinstance(motors, list):
return motors.copy()
else:
raise TypeError(motors)
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> dict[int, Value]:
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]:
if isinstance(values, (int | float)):
return dict.fromkeys(self.ids, values)
elif isinstance(values, dict):
@@ -618,7 +532,7 @@ class SerialMotorsBus(MotorsBusBase):
self.set_baudrate(self.default_baudrate)
@abc.abstractmethod
def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
def _find_single_motor(self, motor: str, initial_baudrate: int | None) -> tuple[int, int]:
pass
@abc.abstractmethod
@@ -631,13 +545,13 @@ class SerialMotorsBus(MotorsBusBase):
pass
@abc.abstractmethod
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
"""Disable torque on selected motors.
Disabling Torque allows to write to the motors' permanent memory area (EPROM/EEPROM).
Args:
motors ( str | list[str] | None, optional): Target motors. Accepts a motor name, an ID, a
motors (int | str | list[str] | None, optional): Target motors. Accepts a motor name, an ID, a
list of names or `None` to affect every registered motor. Defaults to `None`.
num_retry (int, optional): Number of additional retry attempts on communication failure.
Defaults to 0.
@@ -649,19 +563,18 @@ class SerialMotorsBus(MotorsBusBase):
pass
@abc.abstractmethod
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Enable torque on selected motors.
Args:
motors (int | str | list[str] | None, optional): Same semantics as :pymeth:`disable_torque`.
Defaults to `None`.
motor (int): Same semantics as :pymeth:`disable_torque`. Defaults to `None`.
num_retry (int, optional): Number of additional retry attempts on communication failure.
Defaults to 0.
"""
pass
@contextmanager
def torque_disabled(self, motors: str | list[str] | None = None):
def torque_disabled(self, motors: int | str | list[str] | None = None):
"""Context-manager that guarantees torque is re-enabled.
This helper is useful to temporarily disable torque when configuring motors.
@@ -738,19 +651,24 @@ class SerialMotorsBus(MotorsBusBase):
"""
pass
def reset_calibration(self, motors: NameOrID | Sequence[NameOrID] | None = None) -> None:
def reset_calibration(self, motors: NameOrID | list[NameOrID] | None = None) -> None:
"""Restore factory calibration for the selected motors.
Homing offset is set to ``0`` and min/max position limits are set to the full usable range.
The in-memory :pyattr:`calibration` is cleared.
Args:
motors (NameOrID | Sequence[NameOrID] | None, optional): Selection of motors. `None` (default)
motors (NameOrID | list[NameOrID] | None, optional): Selection of motors. `None` (default)
resets every motor.
"""
motor_names = self._get_motors_list(motors)
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
for motor in motor_names:
for motor in motors:
model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1
self.write("Homing_Offset", motor, 0, normalize=False)
@@ -759,9 +677,7 @@ class SerialMotorsBus(MotorsBusBase):
self.calibration = {}
def set_half_turn_homings(
self, motors: NameOrID | Sequence[NameOrID] | None = None
) -> dict[NameOrID, Value]:
def set_half_turn_homings(self, motors: NameOrID | list[NameOrID] | None = None) -> dict[NameOrID, Value]:
"""Centre each motor range around its current position.
The function computes and writes a homing offset such that the present position becomes exactly one
@@ -771,12 +687,17 @@ class SerialMotorsBus(MotorsBusBase):
motors (NameOrID | list[NameOrID] | None, optional): Motors to adjust. Defaults to all motors (`None`).
Returns:
dict[str, Value]: Mapping *motor name written homing offset*.
dict[NameOrID, Value]: Mapping *motor written homing offset*.
"""
motor_names = self._get_motors_list(motors)
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
self.reset_calibration(motor_names)
actual_positions = self.sync_read("Present_Position", motor_names, normalize=False)
self.reset_calibration(motors)
actual_positions = self.sync_read("Present_Position", motors, normalize=False)
homing_offsets = self._get_half_turn_homings(actual_positions)
for motor, offset in homing_offsets.items():
self.write("Homing_Offset", motor, offset)
@@ -788,8 +709,8 @@ class SerialMotorsBus(MotorsBusBase):
pass
def record_ranges_of_motion(
self, motors: NameOrID | Sequence[NameOrID] | None = None, display_values: bool = True
) -> tuple[dict[str, Value], dict[str, Value]]:
self, motors: NameOrID | list[NameOrID] | None = None, display_values: bool = True
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
"""Interactively record the min/max encoder values of each motor.
Move the joints by hand (with torque disabled) while the method streams live positions. Press
@@ -801,25 +722,30 @@ class SerialMotorsBus(MotorsBusBase):
display_values (bool, optional): When `True` (default) a live table is printed to the console.
Returns:
tuple[dict[str, Value], dict[str, Value]]: Two dictionaries *mins* and *maxes* with the
tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: Two dictionaries *mins* and *maxes* with the
extreme values observed for each motor.
"""
motor_names = self._get_motors_list(motors)
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
start_positions = self.sync_read("Present_Position", motor_names, normalize=False)
start_positions = self.sync_read("Present_Position", motors, normalize=False)
mins = start_positions.copy()
maxes = start_positions.copy()
user_pressed_enter = False
while not user_pressed_enter:
positions = self.sync_read("Present_Position", motor_names, normalize=False)
positions = self.sync_read("Present_Position", motors, normalize=False)
mins = {motor: min(positions[motor], min_) for motor, min_ in mins.items()}
maxes = {motor: max(positions[motor], max_) for motor, max_ in maxes.items()}
if display_values:
print("\n-------------------------------------------")
print(f"{'NAME':<15} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}")
for motor in motor_names:
for motor in motors:
print(f"{motor:<15} | {mins[motor]:>6} | {positions[motor]:>6} | {maxes[motor]:>6}")
if enter_pressed():
@@ -827,9 +753,9 @@ class SerialMotorsBus(MotorsBusBase):
if display_values and not user_pressed_enter:
# Move cursor up to overwrite the previous output
move_cursor_up(len(motor_names) + 3)
move_cursor_up(len(motors) + 3)
same_min_max = [motor for motor in motor_names if mins[motor] == maxes[motor]]
same_min_max = [motor for motor in motors if mins[motor] == maxes[motor]]
if same_min_max:
raise ValueError(f"Some motors have the same min and max values:\n{pformat(same_min_max)}")
@@ -952,12 +878,12 @@ class SerialMotorsBus(MotorsBusBase):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
else:
return None
return
if self._is_error(error):
if raise_on_error:
raise RuntimeError(self.packet_handler.getRxPacketError(error))
else:
return None
return
return model_number
@@ -1004,13 +930,12 @@ class SerialMotorsBus(MotorsBusBase):
err_msg = f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries."
value, _, _ = self._read(addr, length, id_, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
decoded = self._decode_sign(data_name, {id_: value})
id_value = self._decode_sign(data_name, {id_: value})
if normalize and data_name in self.normalized_data:
normalized = self._normalize(decoded)
return normalized[id_]
id_value = self._normalize(id_value)
return decoded[id_]
return id_value[id_]
def _read(
self,
@@ -1021,7 +946,7 @@ class SerialMotorsBus(MotorsBusBase):
num_retry: int = 0,
raise_on_error: bool = True,
err_msg: str = "",
) -> tuple[int, int, int]:
) -> tuple[int, int]:
if length == 1:
read_fn = self.packet_handler.read1ByteTxRx
elif length == 2:
@@ -1071,14 +996,13 @@ class SerialMotorsBus(MotorsBusBase):
model = self.motors[motor].model
addr, length = get_address(self.model_ctrl_table, model, data_name)
int_value = int(value)
if normalize and data_name in self.normalized_data:
int_value = self._unnormalize({id_: value})[id_]
value = self._unnormalize({id_: value})[id_]
int_value = self._encode_sign(data_name, {id_: int_value})[id_]
value = self._encode_sign(data_name, {id_: value})[id_]
err_msg = f"Failed to write '{data_name}' on {id_=} with '{int_value}' after {num_retry + 1} tries."
self._write(addr, length, id_, int_value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
err_msg = f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
self._write(addr, length, id_, value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
def _write(
self,
@@ -1112,7 +1036,7 @@ class SerialMotorsBus(MotorsBusBase):
def sync_read(
self,
data_name: str,
motors: NameOrID | Sequence[NameOrID] | None = None,
motors: str | list[str] | None = None,
*,
normalize: bool = True,
num_retry: int = 0,
@@ -1121,7 +1045,7 @@ class SerialMotorsBus(MotorsBusBase):
Args:
data_name (str): Register name.
motors (NameOrID | Sequence[NameOrID] | None, optional): Motors to query. `None` (default) reads every motor.
motors (str | list[str] | None, optional): Motors to query. `None` (default) reads every motor.
normalize (bool, optional): Normalisation flag. Defaults to `True`.
num_retry (int, optional): Retry attempts. Defaults to `0`.
@@ -1142,17 +1066,16 @@ class SerialMotorsBus(MotorsBusBase):
addr, length = get_address(self.model_ctrl_table, model, data_name)
err_msg = f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries."
raw_ids_values, _ = self._sync_read(
ids_values, _ = self._sync_read(
addr, length, ids, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
)
decoded = self._decode_sign(data_name, raw_ids_values)
ids_values = self._decode_sign(data_name, ids_values)
if normalize and data_name in self.normalized_data:
normalized = self._normalize(decoded)
return {self._id_to_name(id_): value for id_, value in normalized.items()}
ids_values = self._normalize(ids_values)
return {self._id_to_name(id_): value for id_, value in decoded.items()}
return {self._id_to_name(id_): value for id_, value in ids_values.items()}
def _sync_read(
self,
@@ -1224,24 +1147,21 @@ class SerialMotorsBus(MotorsBusBase):
num_retry (int, optional): Retry attempts. Defaults to `0`.
"""
raw_ids_values = self._get_ids_values_dict(values)
models = [self._id_to_model(id_) for id_ in raw_ids_values]
ids_values = self._get_ids_values_dict(values)
models = [self._id_to_model(id_) for id_ in ids_values]
if self._has_different_ctrl_tables:
assert_same_address(self.model_ctrl_table, models, data_name)
model = next(iter(models))
addr, length = get_address(self.model_ctrl_table, model, data_name)
int_ids_values = {id_: int(val) for id_, val in raw_ids_values.items()}
if normalize and data_name in self.normalized_data:
int_ids_values = self._unnormalize(raw_ids_values)
ids_values = self._unnormalize(ids_values)
int_ids_values = self._encode_sign(data_name, int_ids_values)
ids_values = self._encode_sign(data_name, ids_values)
err_msg = f"Failed to sync write '{data_name}' with ids_values={int_ids_values} after {num_retry + 1} tries."
self._sync_write(
addr, length, int_ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
)
err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries."
self._sync_write(addr, length, ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
def _sync_write(
self,
@@ -1274,7 +1194,3 @@ class SerialMotorsBus(MotorsBusBase):
for id_, value in ids_values.items():
data = self._serialize_data(value, length)
self.sync_writer.addParam(id_, data)
# Backward compatibility alias
MotorsBus: TypeAlias = SerialMotorsBus
+16 -7
View File
@@ -28,7 +28,7 @@ class ACTConfig(PreTrainedConfig):
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_features` and `output_features`.
Those are: `input_shapes` and 'output_shapes`.
Notes on the inputs and outputs:
- Either:
@@ -48,12 +48,21 @@ class ACTConfig(PreTrainedConfig):
This should be no greater than the chunk size. For example, if the chunk size size 100, you may
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
environment, and throws the other 50 out.
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
`None` means no pretrained weights.
@@ -30,7 +30,7 @@ class DiffusionConfig(PreTrainedConfig):
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_features` and `output_features`.
Those are: `input_shapes` and `output_shapes`.
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
@@ -48,12 +48,21 @@ class DiffusionConfig(PreTrainedConfig):
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
See `DiffusionPolicy.select_action` for more details.
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done.
@@ -64,7 +73,7 @@ class DiffusionConfig(PreTrainedConfig):
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
use_separate_rgb_encoder_per_camera: Whether to use a separate RGB encoder for each camera view.
use_separate_rgb_encoders_per_camera: Whether to use a separate RGB encoder for each camera view.
down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
You may provide a variable number of dimensions, therefore also controlling the degree of
downsampling.
+17 -2
View File
@@ -35,6 +35,7 @@ from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.videovla.configuration_pi05 import PI05VideoConfig
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.sarm.configuration_sarm import SARMConfig
@@ -67,7 +68,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
"vqbet", "pi0", "pi05", "pi05_video", "sac", "reward_classifier", "smolvla", "wall_x".
Returns:
The policy class corresponding to the given name.
@@ -103,6 +104,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.pi05.modeling_pi05 import PI05Policy
return PI05Policy
elif name == "pi05_video":
from lerobot.policies.videovla.modeling_pi05 import PI05VideoPolicy
return PI05VideoPolicy
elif name == "sac":
from lerobot.policies.sac.modeling_sac import SACPolicy
@@ -147,7 +152,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
"diffusion", "act", "vqbet", "pi0", "pi05", "pi05_video", "sac", "smolvla",
"reward_classifier", "wall_x".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
@@ -169,6 +174,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return PI0Config(**kwargs)
elif policy_type == "pi05":
return PI05Config(**kwargs)
elif policy_type == "pi05_video":
return PI05VideoConfig(**kwargs)
elif policy_type == "sac":
return SACConfig(**kwargs)
elif policy_type == "smolvla":
@@ -333,6 +340,14 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, PI05VideoConfig):
from lerobot.policies.videovla.processor_pi05 import make_pi05_video_pre_post_processors
processors = make_pi05_video_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, SACConfig):
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
+8 -8
View File
@@ -460,8 +460,8 @@ class PaliGemmaWithExpertModel(
inputs_embeds=inputs_embeds[1],
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
use_cache=False,
past_key_values=None, #jadechoghari
adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
)
suffix_output = suffix_output.last_hidden_state
@@ -575,13 +575,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
try:
from transformers.models.siglip import check
# try:
# from transformers.models.siglip import check
if not check.check_whether_transformers_replace_is_installed_correctly():
raise ValueError(msg)
except ImportError:
raise ValueError(msg) from None
# if not check.check_whether_transformers_replace_is_installed_correctly():
# raise ValueError(msg)
# except ImportError:
# raise ValueError(msg) from None
def gradient_checkpointing_enable(self):
"""Enable gradient checkpointing for memory optimization."""
-18
View File
@@ -1,18 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.policies.rlt.configuration_rlt import RLTConfig
from lerobot.policies.rlt.modeling_rlt import RLTPolicy
__all__ = ["RLTConfig", "RLTPolicy"]
@@ -1,156 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RLT (RL Token) policy configuration.
Reference: "RL Token: Bootstrapping Online RL with Vision-Language-Action Models"
(Xu et al., Physical Intelligence, 2026)
"""
from __future__ import annotations
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.policies.sac.configuration_sac import ActorLearnerConfig, ConcurrencyConfig
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
@dataclass
class RLTokenConfig:
"""Configuration for the RL-token encoder/decoder transformer."""
input_dim: int = 2048
rl_token_dim: int = 2048
num_encoder_layers: int = 2
num_decoder_layers: int = 2
num_heads: int = 8
ff_dim: int = 2048
dropout: float = 0.0
@dataclass
class RLTActorConfig:
"""Configuration for the lightweight RL actor MLP."""
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
std: float = 0.1
@dataclass
class RLTCriticConfig:
"""Configuration for the RLT critic MLP."""
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
@PreTrainedConfig.register_subclass("rlt")
@dataclass
class RLTConfig(PreTrainedConfig):
"""Configuration for the RLT (RL Token) policy.
RLT adds an RL-token encoder/decoder to a frozen VLA backbone, then trains
a lightweight actor-critic head using the RL token as state representation.
The frozen VLA also provides reference action chunks that the actor refines.
"""
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
)
dataset_stats: dict[str, dict[str, list[float]]] | None = field(
default_factory=lambda: {
OBS_IMAGE: {
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
},
OBS_STATE: {"min": [0.0], "max": [1.0]},
ACTION: {"min": [0.0], "max": [1.0]},
}
)
# ── Device ──
device: str = "cuda"
storage_device: str = "cpu"
# ── VLA backbone ──
vla_checkpoint: str | None = None
# ── RL-token ──
rl_token: RLTokenConfig = field(default_factory=RLTokenConfig)
# ── Actor / Critic heads ──
actor: RLTActorConfig = field(default_factory=RLTActorConfig)
critic: RLTCriticConfig = field(default_factory=RLTCriticConfig)
# ── Action chunks ──
chunk_size: int = 10
vla_chunk_size: int = 50
# ── Training parameters ──
online_steps: int = 50000
offline_steps: int = 5000
online_buffer_capacity: int = 100000
offline_buffer_capacity: int = 100000
online_step_before_learning: int = 500
warmup_steps: int = 500
async_prefetch: bool = False
# ── Algorithm hyperparameters ──
utd_ratio: int = 5
policy_update_freq: int = 2
discount: float = 0.99
critic_lr: float = 3e-4
actor_lr: float = 3e-4
rl_token_lr: float = 1e-4
tau: float = 0.005
clip_grad_norm: float = 10.0
num_critics: int = 2
bc_reg_coeff: float = 0.1
ref_dropout: float = 0.5
chunk_stride: int = 2
vla_finetune_weight: float = 0.0
# ── Distributed ──
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
def __post_init__(self):
super().__post_init__()
def get_optimizer_preset(self):
return None
def get_scheduler_preset(self):
return None
def validate_features(self) -> None:
if ACTION not in self.output_features:
raise ValueError("You must provide 'action' in the output features")
@property
def observation_delta_indices(self) -> list | None:
return None
@property
def action_delta_indices(self) -> list | None:
return None
@property
def reward_delta_indices(self) -> None:
return None
-318
View File
@@ -1,318 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RLT (RL Token) policy networks.
Reference: "RL Token: Bootstrapping Online RL with Vision-Language-Action Models"
(Xu et al., Physical Intelligence, 2026)
Architecture:
- RLTokenEncoder: compresses VLA token embeddings into a single compact RL token
- RLTokenDecoder: reconstructs VLA embeddings from the RL token (Stage 1 training only)
- RLTActor: refines VLA reference action chunks conditioned on (z_rl, proprioception, ref_action)
- RLTCritic: Q(x, action_chunk) where x = (z_rl, proprioception)
- RLTPolicy: bundles RL-token modules + actor into a PreTrainedPolicy for inference
"""
from __future__ import annotations
import math
import torch
import torch.nn as nn
from torch import Tensor
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rlt.configuration_rlt import RLTConfig
# ── Building blocks ──────────────────────────────────────────────────
class MLP(nn.Module):
"""Simple feedforward network with ReLU activations."""
def __init__(self, input_dim: int, hidden_dims: list[int], output_dim: int):
super().__init__()
layers: list[nn.Module] = []
prev = input_dim
for h in hidden_dims:
layers.append(nn.Linear(prev, h))
layers.append(nn.ReLU())
prev = h
layers.append(nn.Linear(prev, output_dim))
self.net = nn.Sequential(*layers)
def forward(self, x: Tensor) -> Tensor:
return self.net(x)
# ── RL Token Encoder ─────────────────────────────────────────────────
class RLTokenEncoder(nn.Module):
"""Compress VLA token embeddings into a single RL token via a small transformer.
Appends a learnable ``e_rl`` embedding to the VLA token sequence, processes
through transformer encoder layers, and returns the output at the ``e_rl``
position as the RL token ``z_rl``.
Paper Eq. 1: z_rl = g_phi([z_{1:M}, e_rl])_{M+1}
"""
def __init__(
self,
input_dim: int,
rl_token_dim: int,
num_layers: int,
num_heads: int,
ff_dim: int,
dropout: float = 0.0,
):
super().__init__()
self.rl_token_dim = rl_token_dim
self.e_rl = nn.Parameter(torch.randn(1, 1, input_dim) * 0.02)
if input_dim != rl_token_dim:
self.input_proj = nn.Linear(input_dim, rl_token_dim)
else:
self.input_proj = nn.Identity()
encoder_layer = nn.TransformerEncoderLayer(
d_model=rl_token_dim,
nhead=num_heads,
dim_feedforward=ff_dim,
dropout=dropout,
batch_first=True,
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
def forward(self, z_vla: Tensor) -> Tensor:
"""
Args:
z_vla: VLA token embeddings, shape ``(B, M, D)``.
Returns:
RL token ``z_rl``, shape ``(B, rl_token_dim)``.
"""
batch_size = z_vla.shape[0]
e_rl = self.e_rl.expand(batch_size, -1, -1)
seq = torch.cat([z_vla, e_rl], dim=1) # (B, M+1, D)
seq = self.input_proj(seq)
out = self.transformer(seq)
z_rl = out[:, -1, :] # output at e_rl position
return z_rl
# ── RL Token Decoder ─────────────────────────────────────────────────
class RLTokenDecoder(nn.Module):
"""Autoregressively reconstruct VLA embeddings from z_rl.
Used only during Stage 1 (offline RL-token training).
Paper Eq. 2: L_ro = E[sum_i || h(d([z_rl, z_bar_{1:i-1}]))_i - z_bar_i ||^2]
"""
def __init__(
self,
rl_token_dim: int,
output_dim: int,
num_layers: int,
num_heads: int,
ff_dim: int,
dropout: float = 0.0,
):
super().__init__()
self.output_dim = output_dim
if rl_token_dim != output_dim:
self.rl_proj = nn.Linear(rl_token_dim, output_dim)
else:
self.rl_proj = nn.Identity()
decoder_layer = nn.TransformerDecoderLayer(
d_model=output_dim,
nhead=num_heads,
dim_feedforward=ff_dim,
dropout=dropout,
batch_first=True,
)
self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
self.output_head = nn.Linear(output_dim, output_dim)
def forward(self, z_rl: Tensor, z_vla_stopped: Tensor) -> Tensor:
"""
Args:
z_rl: RL token, shape ``(B, D_rl)``.
z_vla_stopped: Stop-gradient VLA embeddings, shape ``(B, M, D)``.
Returns:
Reconstructed embeddings, shape ``(B, M, D)``.
"""
seq_len = z_vla_stopped.shape[1]
z_rl_proj = self.rl_proj(z_rl).unsqueeze(1)
target = torch.cat([z_rl_proj, z_vla_stopped[:, :-1, :]], dim=1)
causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=z_rl.device)
decoded = self.transformer(
tgt=target,
memory=z_rl_proj,
tgt_mask=causal_mask,
)
return self.output_head(decoded) # (B, M, D)
# ── Actor ────────────────────────────────────────────────────────────
class RLTActor(nn.Module):
"""Lightweight actor that refines VLA reference action chunks.
Paper Eq. 4: pi_theta(a_{1:C} | x, a_tilde_{1:C}) = N(mu_theta(x, a_tilde), sigma^2 I)
The actor is conditioned on both the RL state and the VLA's proposed action
chunk, acting as a "VLA-guided action editor".
"""
def __init__(self, state_dim: int, action_chunk_dim: int, hidden_dims: list[int], std: float = 0.1):
super().__init__()
input_dim = state_dim + action_chunk_dim
self.net = MLP(input_dim, hidden_dims, action_chunk_dim)
self.log_std = math.log(std)
def forward(self, state: Tensor, ref_action_chunk: Tensor) -> Tensor:
"""Return the mean action chunk.
Args:
state: RL state ``x = (z_rl, proprioception)``, shape ``(B, state_dim)``.
ref_action_chunk: Flattened VLA reference chunk, shape ``(B, C*d)``.
Returns:
Refined action chunk (mean), shape ``(B, C*d)``.
"""
x = torch.cat([state, ref_action_chunk], dim=-1)
return self.net(x)
def sample(self, state: Tensor, ref_action_chunk: Tensor) -> tuple[Tensor, Tensor]:
"""Sample an action and return (action, log_prob)."""
mean = self.forward(state, ref_action_chunk)
std = math.exp(self.log_std)
noise = torch.randn_like(mean) * std
action = mean + noise
log_prob = -0.5 * (noise / std).pow(2).sum(dim=-1) - mean.shape[-1] * math.log(
std * math.sqrt(2 * math.pi)
)
return action, log_prob
# ── Policy (inference bundle) ────────────────────────────────────────
class RLTPolicy(PreTrainedPolicy):
"""RLT policy — bundles the RL-token encoder and actor for inference.
The frozen VLA backbone is **not** part of this module; it is loaded
separately and its embeddings / reference actions are passed in via the
observation dict (populated by the actor process or a preprocessor).
During training, the :class:`RLTAlgorithm` holds the critic, target networks,
and optimizers. This class only contains what is needed for ``select_action``.
"""
name = "rlt"
config_class = RLTConfig
def __init__(self, config: RLTConfig, dataset_stats=None):
super().__init__(config, dataset_stats)
action_dim = config.output_features["action"].shape[0]
action_chunk_dim = config.chunk_size * action_dim
prop_feature = config.input_features.get("observation.state", None)
proprioception_dim = prop_feature.shape[0] if prop_feature is not None else 0
state_dim = config.rl_token.rl_token_dim + proprioception_dim
# RL-token encoder (frozen after Stage 1)
self.rl_token_encoder = RLTokenEncoder(
input_dim=config.rl_token.input_dim,
rl_token_dim=config.rl_token.rl_token_dim,
num_layers=config.rl_token.num_encoder_layers,
num_heads=config.rl_token.num_heads,
ff_dim=config.rl_token.ff_dim,
dropout=config.rl_token.dropout,
)
# RL-token decoder (used only during Stage 1 training)
self.rl_token_decoder = RLTokenDecoder(
rl_token_dim=config.rl_token.rl_token_dim,
output_dim=config.rl_token.input_dim,
num_layers=config.rl_token.num_decoder_layers,
num_heads=config.rl_token.num_heads,
ff_dim=config.rl_token.ff_dim,
dropout=config.rl_token.dropout,
)
# Actor MLP
self.actor = RLTActor(
state_dim=state_dim,
action_chunk_dim=action_chunk_dim,
hidden_dims=config.actor.hidden_dims,
std=config.actor.std,
)
self._action_dim = action_dim
self._action_chunk_dim = action_chunk_dim
self._state_dim = state_dim
self._proprioception_dim = proprioception_dim
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a refined action chunk given an observation.
Expects the observation dict to contain:
- ``"observation.vla_embeddings"``: VLA internal token embeddings ``(M, D)``
- ``"observation.reference_action"``: VLA reference chunk ``(C*d,)``
- ``"observation.state"`` (optional): proprioceptive state ``(P,)``
Returns:
Action chunk tensor of shape ``(C*d,)``.
"""
self.eval()
vla_emb = batch["observation.vla_embeddings"]
if vla_emb.dim() == 2:
vla_emb = vla_emb.unsqueeze(0)
z_rl = self.rl_token_encoder(vla_emb) # (1, D_rl)
parts = [z_rl]
if "observation.state" in batch and self._proprioception_dim > 0:
prop = batch["observation.state"]
if prop.dim() == 1:
prop = prop.unsqueeze(0)
parts.append(prop)
state = torch.cat(parts, dim=-1)
ref = batch["observation.reference_action"]
if ref.dim() == 1:
ref = ref.unsqueeze(0)
action = self.actor(state, ref)
return action.squeeze(0)
def reset(self):
pass
+372 -23
View File
@@ -15,11 +15,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Callable
from dataclasses import asdict
from typing import Literal
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
@@ -47,13 +52,20 @@ class SACPolicy(
# Determine action dimension and initialize all components
continuous_action_dim = config.output_features[ACTION].shape[0]
self.encoder = SACObservationEncoder(config)
self._init_encoders()
self._init_critics(continuous_action_dim)
self._init_actor(continuous_action_dim)
self._init_discrete_critic()
self._init_temperature()
def get_optim_params(self) -> dict:
optim_params = {
"actor": [self.actor.parameters()],
"actor": [
p
for n, p in self.actor.named_parameters()
if not n.startswith("encoder") or not self.shared_encoder
],
"critic": self.critic_ensemble.parameters(),
"temperature": self.log_alpha,
}
if self.config.num_discrete_actions is not None:
optim_params["discrete_critic"] = self.discrete_critic.parameters()
@@ -71,9 +83,10 @@ class SACPolicy(
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select action for inference/evaluation"""
observations_features = None
if self.encoder.has_images:
observations_features = self.encoder.get_cached_image_features(batch)
if self.shared_encoder and self.actor.encoder.has_images:
observations_features = self.actor.encoder.get_cached_image_features(batch)
actions, _, _ = self.actor(batch, observations_features)
@@ -84,35 +97,371 @@ class SACPolicy(
return actions
def critic_forward(
self,
observations: dict[str, Tensor],
actions: Tensor,
use_target: bool = False,
observation_features: Tensor | None = None,
) -> Tensor:
"""Forward pass through a critic network ensemble
Args:
observations: Dictionary of observations
actions: Action tensor
use_target: If True, use target critics, otherwise use ensemble critics
Returns:
Tensor of Q-values from all critics
"""
critics = self.critic_target if use_target else self.critic_ensemble
q_values = critics(observations, actions, observation_features)
return q_values
def discrete_critic_forward(
self, observations, use_target=False, observation_features=None
) -> torch.Tensor:
"""Forward pass through a discrete critic network
Args:
observations: Dictionary of observations
use_target: If True, use target critics, otherwise use ensemble critics
observation_features: Optional pre-computed observation features to avoid recomputing encoder output
Returns:
Tensor of Q-values from the discrete critic network
"""
discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic
q_values = discrete_critic(observations, observation_features)
return q_values
def forward(
self,
batch: dict[str, Tensor | dict[str, Tensor]],
model: Literal["actor", "critic", "temperature", "discrete_critic"] = "critic",
) -> dict[str, Tensor]:
"""Actor forward pass."""
observations = batch.get("state", batch)
observation_features = batch.get("observation_feature") if isinstance(batch, dict) else None
actions, log_probs, means = self.actor(observations, observation_features)
return {"action": actions, "log_prob": log_probs, "action_mean": means}
"""Compute the loss for the given model
def _init_actor(self, continuous_action_dim: int) -> None:
self.actor = Policy(
encoder=self.encoder,
network=MLP(input_dim=self.encoder.output_dim, **asdict(self.config.actor_network_kwargs)),
action_dim=continuous_action_dim,
encoder_is_shared=False,
**asdict(self.config.policy_kwargs),
Args:
batch: Dictionary containing:
- action: Action tensor
- reward: Reward tensor
- state: Observations tensor dict
- next_state: Next observations tensor dict
- done: Done mask tensor
- observation_feature: Optional pre-computed observation features
- next_observation_feature: Optional pre-computed next observation features
model: Which model to compute the loss for ("actor", "critic", "discrete_critic", or "temperature")
Returns:
The computed loss tensor
"""
# Extract common components from batch
actions: Tensor = batch[ACTION]
observations: dict[str, Tensor] = batch["state"]
observation_features: Tensor = batch.get("observation_feature")
if model == "critic":
# Extract critic-specific components
rewards: Tensor = batch["reward"]
next_observations: dict[str, Tensor] = batch["next_state"]
done: Tensor = batch["done"]
next_observation_features: Tensor = batch.get("next_observation_feature")
loss_critic = self.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
return {"loss_critic": loss_critic}
if model == "discrete_critic" and self.config.num_discrete_actions is not None:
# Extract critic-specific components
rewards: Tensor = batch["reward"]
next_observations: dict[str, Tensor] = batch["next_state"]
done: Tensor = batch["done"]
next_observation_features: Tensor = batch.get("next_observation_feature")
complementary_info = batch.get("complementary_info")
loss_discrete_critic = self.compute_loss_discrete_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
complementary_info=complementary_info,
)
return {"loss_discrete_critic": loss_discrete_critic}
if model == "actor":
return {
"loss_actor": self.compute_loss_actor(
observations=observations,
observation_features=observation_features,
)
}
if model == "temperature":
return {
"loss_temperature": self.compute_loss_temperature(
observations=observations,
observation_features=observation_features,
)
}
raise ValueError(f"Unknown model type: {model}")
def update_target_networks(self):
"""Update target networks with exponential moving average"""
for target_param, param in zip(
self.critic_target.parameters(),
self.critic_ensemble.parameters(),
strict=True,
):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
if self.config.num_discrete_actions is not None:
for target_param, param in zip(
self.discrete_critic_target.parameters(),
self.discrete_critic.parameters(),
strict=True,
):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
def update_temperature(self):
self.temperature = self.log_alpha.exp().item()
def compute_loss_critic(
self,
observations,
actions,
rewards,
next_observations,
done,
observation_features: Tensor | None = None,
next_observation_features: Tensor | None = None,
) -> Tensor:
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
# 2- compute q targets
q_targets = self.critic_forward(
observations=next_observations,
actions=next_action_preds,
use_target=True,
observation_features=next_observation_features,
)
# subsample critics to prevent overfitting if use high UTD (update to date)
# TODO: Get indices before forward pass to avoid unnecessary computation
if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics)
indices = indices[: self.config.num_subsample_critics]
q_targets = q_targets[indices]
# critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation
if self.config.use_backup_entropy:
min_q = min_q - (self.temperature * next_log_probs)
td_target = rewards + (1 - done) * self.config.discount * min_q
# 3- compute predicted qs
if self.config.num_discrete_actions is not None:
# NOTE: We only want to keep the continuous action part
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
q_preds = self.critic_forward(
observations=observations,
actions=actions,
use_target=False,
observation_features=observation_features,
)
def _init_discrete_critic(self) -> None:
if self.config.num_discrete_actions is None:
self.discrete_critic = None
return
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
critics_loss = (
F.mse_loss(
input=q_preds,
target=td_target_duplicate,
reduction="none",
).mean(dim=1)
).sum()
return critics_loss
def compute_loss_discrete_critic(
self,
observations,
actions,
rewards,
next_observations,
done,
observation_features=None,
next_observation_features=None,
complementary_info=None,
):
# NOTE: We only want to keep the discrete action part
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
actions_discrete = torch.round(actions_discrete)
actions_discrete = actions_discrete.long()
discrete_penalties: Tensor | None = None
if complementary_info is not None:
discrete_penalties: Tensor | None = complementary_info.get("discrete_penalty")
with torch.no_grad():
# For DQN, select actions using online network, evaluate with target network
next_discrete_qs = self.discrete_critic_forward(
next_observations, use_target=False, observation_features=next_observation_features
)
best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
# Get target Q-values from target network
target_next_discrete_qs = self.discrete_critic_forward(
observations=next_observations,
use_target=True,
observation_features=next_observation_features,
)
# Use gather to select Q-values for best actions
target_next_discrete_q = torch.gather(
target_next_discrete_qs, dim=1, index=best_next_discrete_action
).squeeze(-1)
# Compute target Q-value with Bellman equation
rewards_discrete = rewards
if discrete_penalties is not None:
rewards_discrete = rewards + discrete_penalties
target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q
# Get predicted Q-values for current observations
predicted_discrete_qs = self.discrete_critic_forward(
observations=observations, use_target=False, observation_features=observation_features
)
# Use gather to select Q-values for taken actions
predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1)
# Compute MSE loss between predicted and target Q-values
discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q)
return discrete_critic_loss
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
"""Compute the temperature loss"""
# calculate temperature loss
with torch.no_grad():
_, log_probs, _ = self.actor(observations, observation_features)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
return temperature_loss
def compute_loss_actor(
self,
observations,
observation_features: Tensor | None = None,
) -> Tensor:
actions_pi, log_probs, _ = self.actor(observations, observation_features)
q_preds = self.critic_forward(
observations=observations,
actions=actions_pi,
use_target=False,
observation_features=observation_features,
)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
return actor_loss
def _init_encoders(self):
"""Initialize shared or separate encoders for actor and critic."""
self.shared_encoder = self.config.shared_encoder
self.encoder_critic = SACObservationEncoder(self.config)
self.encoder_actor = (
self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config)
)
def _init_critics(self, continuous_action_dim):
"""Build critic ensemble, targets, and optional discrete critic."""
heads = [
CriticHead(
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
**asdict(self.config.critic_network_kwargs),
)
for _ in range(self.config.num_critics)
]
self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads)
target_heads = [
CriticHead(
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
**asdict(self.config.critic_network_kwargs),
)
for _ in range(self.config.num_critics)
]
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
if self.config.use_torch_compile:
self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = torch.compile(self.critic_target)
if self.config.num_discrete_actions is not None:
self._init_discrete_critics()
def _init_discrete_critics(self):
"""Build discrete discrete critic ensemble and target networks."""
self.discrete_critic = DiscreteCritic(
encoder=self.encoder,
input_dim=self.encoder.output_dim,
encoder=self.encoder_critic,
input_dim=self.encoder_critic.output_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
self.discrete_critic_target = DiscreteCritic(
encoder=self.encoder_critic,
input_dim=self.encoder_critic.output_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
# TODO: (maractingi, azouitine) Compile the discrete critic
self.discrete_critic_target.load_state_dict(self.discrete_critic.state_dict())
def _init_actor(self, continuous_action_dim):
"""Initialize policy actor network and default target entropy."""
# NOTE: The actor select only the continuous action part
self.actor = Policy(
encoder=self.encoder_actor,
network=MLP(input_dim=self.encoder_actor.output_dim, **asdict(self.config.actor_network_kwargs)),
action_dim=continuous_action_dim,
encoder_is_shared=self.shared_encoder,
**asdict(self.config.policy_kwargs),
)
self.target_entropy = self.config.target_entropy
if self.target_entropy is None:
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
self.target_entropy = -np.prod(dim) / 2
def _init_temperature(self):
"""Set up temperature parameter and initial log_alpha."""
temp_init = self.config.temperature_init
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
self.temperature = self.log_alpha.exp().item()
class SACObservationEncoder(nn.Module):
@@ -27,18 +27,18 @@ Usage:
# Full RA-BC computation with visualizations
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path <USER>/sarm_single_uni4
--reward-model-path pepijn223/sarm_single_uni4
# Faster computation with stride (compute every 5 frames, interpolate the rest)
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path <USER>/sarm_single_uni4 \\
--reward-model-path pepijn223/sarm_single_uni4 \\
--stride 5
# Visualize predictions only (no RA-BC computation)
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path <USER>/sarm_single_uni4 \\
--reward-model-path pepijn223/sarm_single_uni4 \\
--visualize-only \\
--num-visualizations 5
@@ -714,12 +714,12 @@ Examples:
# Full RA-BC computation with visualizations
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path <USER>/sarm_single_uni4
--reward-model-path pepijn223/sarm_single_uni4
# Visualize predictions only (no RA-BC computation)
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path <USER>/sarm_single_uni4 \\
--reward-model-path pepijn223/sarm_single_uni4 \\
--visualize-only \\
--num-visualizations 10
""",
@@ -30,7 +30,7 @@ Example of finetuning the smolvla pretrained model (`smolvla_base`):
```bash
lerobot-train \
--policy.path=lerobot/smolvla_base \
--dataset.repo_id=<USER>/svla_so100_task1_v3 \
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
--batch_size=64 \
--steps=200000
```
@@ -40,7 +40,7 @@ and an action expert.
```bash
lerobot-train \
--policy.type=smolvla \
--dataset.repo_id=<USER>/svla_so100_task1_v3 \
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
--batch_size=64 \
--steps=200000
```
@@ -378,16 +378,16 @@ class SmolVLAPolicy(PreTrainedPolicy):
actions_is_pad = batch.get("actions_id_pad")
loss_dict = {}
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
loss_dict["losses_after_forward"] = losses.clone().mean().item()
loss_dict["losses_after_forward"] = losses.clone()
if actions_is_pad is not None:
in_episode_bound = ~actions_is_pad
losses = losses * in_episode_bound.unsqueeze(-1)
loss_dict["losses_after_in_ep_bound"] = losses.clone().mean().item()
loss_dict["losses_after_in_ep_bound"] = losses.clone()
# Remove padding
losses = losses[:, :, : self.config.max_action_dim]
loss_dict["losses_after_rm_padding"] = losses.clone().mean().item()
loss_dict["losses_after_rm_padding"] = losses.clone()
if reduction == "none":
# Return per-sample losses (B,) by averaging over time and action dims
@@ -30,7 +30,7 @@ class TDMPCConfig(PreTrainedConfig):
camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_features`, `output_features`, and perhaps `max_random_shift_ratio`.
Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`.
Args:
n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
@@ -40,12 +40,24 @@ class TDMPCConfig(PreTrainedConfig):
is an alternative to using action repeats. If this is set to more than 1, then we require
`n_action_repeats == 1`, `use_mpc == True` and `n_action_steps <= horizon`. Note that this
approach of using multiple steps from the plan is not in the original implementation.
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range. Note that here this defaults to None meaning inputs are not normalized. This is to
match the original implementation.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets. NOTE: Clipping
to [-1, +1] is used during MPPI/CEM. Therefore, it is recommended that you stick with "min_max"
normalization mode here.
image_encoder_hidden_dim: Number of channels for the convolutional layers used for image encoding.
state_encoder_hidden_dim: Hidden dimension for MLP used for state vector encoding.
latent_dim: Observation's latent embedding dimension.
+49
View File
@@ -0,0 +1,49 @@
# π₀.₅ (pi05)
This repository contains the Hugging Face port of **π₀.₅**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by the Physical Intelligence.
It is designed as a **Vision-Language-Action model with open-world generalization**.
---
## Model Overview
| Feature | π₀ | π₀.₅ |
| -------------------- | ------------------------------------------------------ | ----------------------------------------- |
| Time Conditioning | Concatenates time with actions via `action_time_mlp_*` | Uses `time_mlp_*` for AdaRMS conditioning |
| AdaRMS | Not used | Used in action expert |
| Tokenizer Length | 48 tokens | 200 tokens |
| Discrete State Input | False (Uses `state_proj` layer) | True |
| Parameter Count | Higher (includes state embedding) | Lower (no state embedding) |
---
## Citation
If you use this work, please cite both **OpenPI** and the π₀.₅ paper:
```bibtex
@misc{openpi2024,
author = {Physical Intelligence Lab},
title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
year = {2024},
publisher = {GitHub},
howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
license = {Apache-2.0}
}
@misc{intelligence2025pi05visionlanguageactionmodelopenworld,
title = {π₀.₅: a Vision-Language-Action Model with Open-World Generalization},
author = {Physical Intelligence and Kevin Black and Noah Brown and James Darpinian and Karan Dhabalia and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Manuel Y. Galliker and Dibya Ghosh and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Devin LeBlanc and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Allen Z. Ren and Lucy Xiaoyang Shi and Laura Smith and Jost Tobias Springenberg and Kyle Stachowicz and James Tanner and Quan Vuong and Homer Walke and Anna Walling and Haohuan Wang and Lili Yu and Ury Zhilinsky},
year = {2025},
eprint = {2504.16054},
archivePrefix= {arXiv},
primaryClass = {cs.LG},
url = {https://arxiv.org/abs/2504.16054},
}
```
---
## License
This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
+31
View File
@@ -0,0 +1,31 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lazy imports to avoid conflicts with lerobot.policies.pi05.PI05Config
# when only importing subpackages like videoprism
def __getattr__(name):
if name == "PI05VideoConfig":
from .configuration_pi05 import PI05VideoConfig
return PI05VideoConfig
elif name == "PI05VideoPolicy":
from .modeling_pi05 import PI05VideoPolicy
return PI05VideoPolicy
elif name == "make_pi05_video_pre_post_processors":
from .processor_pi05 import make_pi05_video_pre_post_processors
return make_pi05_video_pre_post_processors
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
__all__ = ["PI05VideoConfig", "PI05VideoPolicy", "make_pi05_video_pre_post_processors"]
@@ -0,0 +1,212 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
DEFAULT_IMAGE_SIZE = 224
@PreTrainedConfig.register_subclass("pi05_video")
@dataclass
class PI05VideoConfig(PreTrainedConfig):
paligemma_variant: str = "gemma_2b"
action_expert_variant: str = "gemma_300m"
dtype: str = "float32" # Options: "bfloat16", "float32"
n_obs_steps: int = 1
chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
n_action_steps: int = 50 # Number of action steps to execute
# Video encoder settings (VideoPrism)
use_video_encoder: bool = False # Enable video encoding with VideoPrism
video_num_frames: int = 16 # Number of frames for video encoding (VideoPrism default is 16)
videoprism_model_name: str = "MHRDYN7/videoprism-base-f16r288" # VideoPrism model to use
videoprism_image_size: int = 288 # VideoPrism expects 288x288 images
freeze_video_encoder: bool = True # Whether to freeze the video encoder weights
video_padding_mode: str = "repeat" # How to pad frames at episode start: "repeat" or "zero"
# Which camera to use for video encoding (None = first camera, or specify key like "observation.images.top")
video_encoder_camera_key: str | None = None
# Perceiver Resampler settings to reduce video tokens (4096 -> video_num_latents)
video_num_latents: int = 256 # Number of latent tokens for video resampler
video_resampler_num_heads: int = 8 # Number of attention heads in resampler
# Shorter state and action vectors will be padded to these dimensions
max_state_dim: int = 32
max_action_dim: int = 32
# Flow matching parameters: see openpi `PI0Pytorch`
num_inference_steps: int = 10
time_sampling_beta_alpha: float = 1.5
time_sampling_beta_beta: float = 1.0
time_sampling_scale: float = 0.999
time_sampling_offset: float = 0.001
min_period: float = 4e-3
max_period: float = 4.0
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
image_resolution: tuple[int, int] = (
DEFAULT_IMAGE_SIZE,
DEFAULT_IMAGE_SIZE,
) # see openpi `preprocessing_pytorch.py`
# Add empty images. Used to add empty cameras when no image features are present.
empty_cameras: int = 0
tokenizer_max_length: int = 200 # see openpi `__post_init__`
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for state
"ACTION": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for action
}
)
# Training settings
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
compile_model: bool = False # Whether to use torch.compile for model optimization
compile_mode: str = "max-autotune" # Torch compile mode
device: str | None = None # Device to use for the model (None = auto-detect)
# Finetuning settings
freeze_vision_encoder: bool = False # Freeze only the vision encoder
train_expert_only: bool = False # Freeze entire VLM, train only action expert and projections
# Optimizer settings: see openpi `AdamW`
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 0.01
optimizer_grad_clip_norm: float = 1.0
# Scheduler settings: see openpi `CosineDecaySchedule`
# Note: These will auto-scale if --steps < scheduler_decay_steps
# For example, --steps=3000 will scale warmup to 100 and decay to 3000
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
tokenizer_max_length: int = 200 # see openpi `__post_init__`
def __post_init__(self):
super().__post_init__()
# Validate configuration
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
)
if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}")
if self.action_expert_variant not in ["gemma_300m", "gemma_2b"]:
raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}")
if self.dtype not in ["bfloat16", "float32"]:
raise ValueError(f"Invalid dtype: {self.dtype}")
# Validate video encoder settings
if self.use_video_encoder:
if self.video_num_frames < 1:
raise ValueError(f"video_num_frames must be >= 1, got {self.video_num_frames}")
if self.videoprism_image_size < 1:
raise ValueError(f"videoprism_image_size must be >= 1, got {self.videoprism_image_size}")
if self.video_padding_mode not in ["repeat", "zero"]:
raise ValueError(
f"video_padding_mode must be 'repeat' or 'zero', got {self.video_padding_mode}"
)
def validate_features(self) -> None:
"""Validate and set up input/output features."""
for i in range(self.empty_cameras):
key = OBS_IMAGES + f".empty_camera_{i}"
empty_camera = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, *self.image_resolution), # Use configured image resolution
)
self.input_features[key] = empty_camera
if OBS_STATE not in self.input_features:
state_feature = PolicyFeature(
type=FeatureType.STATE,
shape=(self.max_state_dim,), # Padded to max_state_dim
)
self.input_features[OBS_STATE] = state_feature
if ACTION not in self.output_features:
action_feature = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.max_action_dim,), # Padded to max_action_dim
)
self.output_features[ACTION] = action_feature
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> list[int] | None:
"""Return indices for delta observations.
For PI05, we don't use generic observation_delta_indices because it would
apply to both images AND state. Instead, we use image_observation_delta_indices
which only applies to image observations.
"""
return None
@property
def image_observation_delta_indices(self) -> list[int] | None:
"""Return indices for delta image observations only.
When video encoding is enabled, returns indices for the past frames
needed by VideoPrism (e.g., -15, -14, ..., -1, 0 for 16 frames).
This only applies to image observations, not state.
"""
if self.use_video_encoder:
# Return indices for past frames: [-15, -14, ..., -1, 0] for 16 frames
return list(range(-(self.video_num_frames - 1), 1))
return None
@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,171 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.videovla.configuration_pi05 import PI05VideoConfig
from lerobot.policies.pi05.modeling_pi05 import pad_vector
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.utils.constants import (
OBS_STATE,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step")
@dataclass
class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
"""
Processor step to prepare the state and tokenize the language input.
"""
max_state_dim: int = 32
task_key: str = "task"
def __call__(self, transition: EnvTransition) -> EnvTransition:
transition = transition.copy()
state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE)
if state is None:
raise ValueError("State is required for PI05")
tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key)
if tasks is None:
raise ValueError("No task found in complementary data")
# TODO: check if this necessary
state = deepcopy(state)
# Prepare state (pad to max_state_dim)
state = pad_vector(state, self.max_state_dim)
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
state_np = state.cpu().numpy()
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
full_prompts = []
for i, task in enumerate(tasks):
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
state_str = " ".join(map(str, discretized_states[i]))
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
full_prompts.append(full_prompt)
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
# Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!)
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
return transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
This step does not alter the feature definitions.
"""
return features
def make_pi05_video_pre_post_processors(
config: PI05VideoConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Constructs pre-processor and post-processor pipelines for the PI05Video policy.
The pre-processing pipeline prepares input data for the model by:
1. Renaming features to match pretrained configurations.
2. Normalizing input and output features based on dataset statistics.
3. Adding a batch dimension.
4. Appending a newline character to the task description for tokenizer compatibility.
5. Tokenizing the text prompt using the PaliGemma tokenizer.
6. Moving all data to the specified device.
The post-processing pipeline handles the model's output by:
1. Moving data to the CPU.
2. Unnormalizing the output features to their original scale.
Args:
config: The configuration object for the PI0 policy.
dataset_stats: A dictionary of statistics for normalization.
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
# Add remaining processors
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
TokenizerProcessorStep(
tokenizer_name="google/paligemma-3b-pt-224",
max_length=config.tokenizer_max_length,
padding_side="right",
padding="max_length",
),
DeviceProcessorStep(device=config.device),
]
output_steps: list[ProcessorStep] = [
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
@@ -0,0 +1,214 @@
#!/usr/bin/env python
"""
Test script for PI05 with video encoder (VideoPrism).
This script creates a dummy example to test the model with video encoding enabled.
"""
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.videovla.configuration_pi05 import PI05VideoConfig
from lerobot.policies.videovla.modeling_pi05 import PI05VideoPolicy
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
def create_dummy_batch(
batch_size: int = 2,
num_frames: int = 16,
image_size: int = 224,
num_cameras: int = 2,
state_dim: int = 14,
action_dim: int = 14,
chunk_size: int = 50,
seq_len: int = 10,
device: str = "cuda",
) -> dict[str, torch.Tensor]:
"""Create a dummy batch for testing."""
batch = {}
# Create image observations with temporal dimension [B, T, C, H, W]
for i in range(num_cameras):
key = f"{OBS_IMAGES}.camera_{i}"
# Images in [0, 1] range
batch[key] = torch.rand(batch_size, num_frames, 3, image_size, image_size, device=device)
# Create state observation [B, state_dim]
batch[OBS_STATE] = torch.rand(batch_size, state_dim, device=device)
# Create language tokens and attention mask [B, seq_len]
batch["observation.language.tokens"] = torch.randint(0, 1000, (batch_size, seq_len), device=device)
batch["observation.language.attention_mask"] = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
# Create action targets [B, chunk_size, action_dim]
batch[ACTION] = torch.rand(batch_size, chunk_size, action_dim, device=device)
return batch
def test_video_encoder():
"""Test the PI05 model with video encoding enabled."""
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Configuration
batch_size = 2
num_frames = 16
image_size = 224
num_cameras = 2
state_dim = 14
action_dim = 14
chunk_size = 50
# Create config with video encoder enabled
print("Creating PI05VideoConfig with video encoder...")
config = PI05VideoConfig(
use_video_encoder=True,
video_num_frames=num_frames,
videoprism_model_name="MHRDYN7/videoprism-base-f16r288",
videoprism_image_size=288,
freeze_video_encoder=True,
video_padding_mode="repeat",
video_encoder_camera_key=f"{OBS_IMAGES}.camera_0", # Use first camera for video
chunk_size=chunk_size,
max_action_dim=32,
max_state_dim=32,
dtype="float32", # Use float32 for testing
device=device,
)
# Set up input/output features
for i in range(num_cameras):
key = f"{OBS_IMAGES}.camera_{i}"
config.input_features[key] = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, image_size, image_size),
)
config.input_features[OBS_STATE] = PolicyFeature(
type=FeatureType.STATE,
shape=(state_dim,),
)
config.output_features[ACTION] = PolicyFeature(
type=FeatureType.ACTION,
shape=(action_dim,),
)
print(f"use_video_encoder: {config.use_video_encoder}")
print(f"video_num_frames: {config.video_num_frames}")
print(f"video_padding_mode: {config.video_padding_mode}")
print(f"video_encoder_camera_key: {config.video_encoder_camera_key}")
print(f"image_observation_delta_indices: {config.image_observation_delta_indices}")
# Create model
model = PI05VideoPolicy(config)
model.to(device)
# Create dummy batch
batch = create_dummy_batch(
batch_size=batch_size,
num_frames=num_frames,
image_size=image_size,
num_cameras=num_cameras,
state_dim=state_dim,
action_dim=action_dim,
chunk_size=chunk_size,
device=device,
)
print(f"Batch keys: {list(batch.keys())}" )
for key, value in batch.items():
print(f"{key}: {value.shape}")
# Test forward pass
model.train()
try:
loss, loss_dict = model.forward(batch)
print(f"Forward pass successful!")
print(f"Loss: {loss.item():.4f}")
print(f"Loss dict: {loss_dict}")
except Exception as e:
print(f"Forward pass failed: {e}")
raise
# Test inference
model.eval()
with torch.no_grad():
try:
actions = model.predict_action_chunk(batch)
print(f"Test pass, inference pass!")
print(f"Predicted actions shape: {actions.shape}")
except Exception as e:
print(f"Inference failed: {e}")
raise
print("All tests passed!")
def test_frame_padding():
"""Test frame padding at episode start."""
device = "cuda" if torch.cuda.is_available() else "cpu"
# Create config
config = PI05VideoConfig(
use_video_encoder=True,
video_num_frames=16,
videoprism_model_name="MHRDYN7/videoprism-base-f16r288",
freeze_video_encoder=True,
video_padding_mode="repeat",
chunk_size=50,
dtype="float32",
device=device,
)
# Set up minimal features
config.input_features[f"{OBS_IMAGES}.camera_0"] = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 224, 224),
)
config.output_features[ACTION] = PolicyFeature(
type=FeatureType.ACTION,
shape=(14,),
)
# Create model
model = PI05VideoPolicy(config)
model.to(device)
# Test with fewer frames than expected (simulating episode start)
batch = {
f"{OBS_IMAGES}.camera_0": torch.rand(2, 5, 3, 224, 224, device=device),
"observation.language.tokens": torch.randint(0, 1000, (2, 10), device=device),
"observation.language.attention_mask": torch.ones(2, 10, dtype=torch.bool, device=device),
ACTION: torch.rand(2, 50, 14, device=device),
}
video_frames = model._preprocess_video(batch)
if video_frames is not None:
print(f"Input frames: 5")
print(f"Output video_frames shape: {video_frames.shape}")
print(f"Expected: [2, 16, 3, 224, 224]")
assert video_frames.shape == (2, 16, 3, 224, 224), f"Unexpected shape: {video_frames.shape}"
print("Frame padding test PASSED!")
else:
print("video_frames is None (unexpected)")
# Test with single frame
batch[f"{OBS_IMAGES}.camera_0"] = torch.rand(2, 3, 224, 224, device=device) # [B, C, H, W]
video_frames = model._preprocess_video(batch)
if video_frames is not None:
print(f"Input: single frame [B, C, H, W]")
print(f"Output video_frames shape: {video_frames.shape}")
print(f"Expected: [2, 16, 3, 224, 224]")
assert video_frames.shape == (2, 16, 3, 224, 224), f"Unexpected shape: {video_frames.shape}"
print("Single frame expansion test PASSED!")
else:
print("video_frames is None (unexpected)")
print("All tests passed!")
if __name__ == "__main__":
# Run tests
test_frame_padding()
test_video_encoder()
@@ -0,0 +1,37 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_videoprism import VideoPrismConfig, VideoPrismTextConfig, VideoPrismVisionConfig
from .modeling_videoprism import (
VideoPrismClipModel,
VideoPrismForVideoClassification,
VideoPrismPreTrainedModel,
VideoPrismTextModel,
VideoPrismVideoModel,
VideoPrismVisionModel,
)
from .video_processing_videoprism import VideoPrismVideoProcessor
__all__ = [
"VideoPrismConfig",
"VideoPrismTextConfig",
"VideoPrismVisionConfig",
"VideoPrismClipModel",
"VideoPrismForVideoClassification",
"VideoPrismPreTrainedModel",
"VideoPrismTextModel",
"VideoPrismVideoModel",
"VideoPrismVisionModel",
"VideoPrismVideoProcessor",
]
@@ -0,0 +1,269 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/videoprism/modular_videoprism.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_videoprism.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from transformers import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class VideoPrismVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`VideoPrismVisionModel`]. It is used to instantiate a
VideoPrism vision encoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the VideoPrism
[google/videoprism](https://huggingface.co/google/videoprism) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
image_size (`int`, *optional*, defaults to 288):
The size of the input image.
num_frames (`int`, *optional*, defaults to 16):
The number of frames in the input video.
tubelet_size (`List[int]`, *optional*, defaults to `[1, 18, 18]`):
The size of the tubelet patch.
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_spatial_layers (`int`, *optional*, defaults to 12):
Number of spatial transformer blocks.
num_temporal_layers (`int`, *optional*, defaults to 4):
Number of temporal transformer blocks.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_python"`):
The non-linear activation function (function or string).
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
qkv_bias (`bool`, *optional*, defaults to `True`):
Whether to add a bias to the qkv projections in attention layers.
attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
Softcapping constant for attention logits.
num_auxiliary_layers (`int`, *optional*, defaults to 2):
Number of auxiliary layers. This is used in the VideoPrismVideoModel that is a part of VideoPrismClipModel.
apply_l2_norm (`bool`, *optional*, defaults to `True`):
Whether to apply L2 normalization to the output. This is used in the VideoPrismVideoModel that is a part of VideoPrismClipModel.
Example:
```python
>>> from transformers import VideoPrismVisionConfig, VideoPrismVisionModel
>>> # Initializing a VideoPrismVisionConfig with default values
>>> configuration = VideoPrismVisionConfig()
>>> # Initializing a VideoPrismVisionModel with the configuration
>>> model = VideoPrismVisionModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "videoprism_vision_model"
base_config_key = "vision_config"
def __init__(
self,
image_size=288,
num_frames=16,
tubelet_size=[1, 18, 18],
num_channels=3,
hidden_size=768,
num_spatial_layers=12,
num_temporal_layers=4,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu_python",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
initializer_range=0.02,
layer_norm_eps=1e-06,
qkv_bias=True,
attn_logit_softcapping=50.0,
num_auxiliary_layers=2,
apply_l2_norm=True,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.image_size = image_size
self.num_frames = num_frames
self.tubelet_size = tubelet_size
self.num_channels = num_channels
self.qkv_bias = qkv_bias
self.num_spatial_layers = num_spatial_layers
self.num_temporal_layers = num_temporal_layers
self.attn_logit_softcapping = attn_logit_softcapping
self.num_auxiliary_layers = num_auxiliary_layers
self.apply_l2_norm = apply_l2_norm
class VideoPrismTextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`VideoPrismTextModel`]. It is used to instantiate a
VideoPrism text encoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the VideoPrism
[google/videoprism](https://huggingface.co/google/videoprism) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
intermediate_size (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
num_text_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the text Transformer encoder.
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the text model. Defines the number of different tokens that can be represented by the
`input_ids` passed when calling [`VideoPrismTextModel`].
apply_l2_norm (`bool`, *optional*, defaults to `True`):
Whether to apply L2 normalization to the output text embeddings.
hidden_act (`str` or `function`, *optional*, defaults to `"relu"`):
The non-linear activation function (function or string) in the encoder and pooler.
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
qkv_bias (`bool`, *optional*, defaults to `True`):
Whether to add a bias to the query, key, and value projections in the attention layers.
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
Softcapping constant for attention logits.
Example:
```python
>>> from transformers import VideoPrismTextConfig, VideoPrismTextModel
>>> # Initializing a VideoPrismTextConfig with default values
>>> configuration = VideoPrismTextConfig()
>>> # Initializing a VideoPrismTextModel (with random weights) from the configuration
>>> model = VideoPrismTextModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "videoprism_text_model"
base_config_key = "text_config"
def __init__(
self,
hidden_size=768,
intermediate_size=3072,
num_attention_heads=12,
num_text_layers=12,
vocab_size=32000,
apply_l2_norm=True,
hidden_act="relu",
attention_probs_dropout_prob=0.0,
qkv_bias=True,
hidden_dropout_prob=0.0,
layer_norm_eps=1e-06,
initializer_range=0.02,
attn_logit_softcapping=50.0,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.num_text_layers = num_text_layers
self.vocab_size = vocab_size
self.apply_l2_norm = apply_l2_norm
self.hidden_act = hidden_act
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.qkv_bias = qkv_bias
self.hidden_dropout_prob = hidden_dropout_prob
self.layer_norm_eps = layer_norm_eps
self.initializer_range = initializer_range
self.attn_logit_softcapping = attn_logit_softcapping
class VideoPrismConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`VideoPrismModel`]. It is used to instantiate a
VideoPrism model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the VideoPrism
[google/videoprism](https://huggingface.co/google/videoprism) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
text_config (`VideoPrismTextConfig`, *optional*):
Configuration for the text model.
vision_config (`VideoPrismVisionConfig`, *optional*):
Configuration for the vision model.
kwargs (*optional*):
Dictionary of keyword arguments.
Example:
```python
>>> from transformers import VideoPrismConfig, VideoPrismModel
>>> # Initializing a VideoPrismConfig with default values
>>> configuration = VideoPrismConfig()
>>> # Initializing a VideoPrismClipModel with the configuration
>>> model = VideoPrismClipModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "videoprism"
sub_configs = {"text_config": VideoPrismTextConfig, "vision_config": VideoPrismVisionConfig}
def __init__(self, text_config=None, vision_config=None, **kwargs):
if text_config is None:
text_config = VideoPrismTextConfig()
logger.info("`text_config` is `None`. Initializing the `VideoPrismTextConfig` with default values.")
elif isinstance(text_config, dict):
text_config = VideoPrismTextConfig(**text_config)
if vision_config is None:
vision_config = VideoPrismVisionConfig()
logger.info("`vision_config` is `None`. initializing the `VideoPrismVisionConfig` with default values.")
elif isinstance(vision_config, dict):
vision_config = VideoPrismVisionConfig(**vision_config)
self.text_config = text_config
self.vision_config = vision_config
super().__init__(**kwargs)
__all__ = ["VideoPrismVisionConfig", "VideoPrismTextConfig", "VideoPrismConfig"]
@@ -0,0 +1,245 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from collections import defaultdict
from contextlib import contextmanager
import torch
# Record all the torch primitives in advance, so that we can use them without them being modified when we patch torch
# in context managers
TORCH_INIT_FUNCTIONS = {
"uniform_": torch.nn.init.uniform_,
"normal_": torch.nn.init.normal_,
"constant_": torch.nn.init.constant_,
"ones_": torch.nn.init.ones_,
"zeros_": torch.nn.init.zeros_,
"eye_": torch.nn.init.eye_,
"dirac_": torch.nn.init.dirac_,
"xavier_uniform_": torch.nn.init.xavier_uniform_,
"xavier_normal_": torch.nn.init.xavier_normal_,
"kaiming_uniform_": torch.nn.init.kaiming_uniform_,
"kaiming_normal_": torch.nn.init.kaiming_normal_,
"trunc_normal_": torch.nn.init.trunc_normal_,
"orthogonal_": torch.nn.init.orthogonal_,
"sparse_": torch.nn.init.sparse_,
}
def uniform_(
tensor: torch.Tensor, a: float = 0.0, b: float = 1.0, generator: torch.Generator | None = None
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["uniform_"](tensor, a=a, b=b, generator=generator)
return tensor
def normal_(
tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, generator: torch.Generator | None = None
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["normal_"](tensor, mean=mean, std=std, generator=generator)
return tensor
def constant_(tensor: torch.Tensor, val: float) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["constant_"](tensor, val=val)
return tensor
def ones_(tensor: torch.Tensor) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["ones_"](tensor)
return tensor
def zeros_(tensor: torch.Tensor) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["zeros_"](tensor)
return tensor
def eye_(tensor: torch.Tensor) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["eye_"](tensor)
return tensor
def dirac_(tensor: torch.Tensor, groups: int = 1) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["dirac_"](tensor, groups=groups)
return tensor
def xavier_uniform_(tensor: torch.Tensor, gain: float = 1.0, generator: torch.Generator | None = None) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["xavier_uniform_"](tensor, gain=gain, generator=generator)
return tensor
def xavier_normal_(tensor: torch.Tensor, gain: float = 1.0, generator: torch.Generator | None = None) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["xavier_normal_"](tensor, gain=gain, generator=generator)
return tensor
def kaiming_uniform_(
tensor: torch.Tensor,
a: float = 0,
mode: str = "fan_in",
nonlinearity: str = "leaky_relu",
generator: torch.Generator | None = None,
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["kaiming_uniform_"](
tensor, a=a, mode=mode, nonlinearity=nonlinearity, generator=generator
)
return tensor
def kaiming_normal_(
tensor: torch.Tensor,
a: float = 0,
mode: str = "fan_in",
nonlinearity: str = "leaky_relu",
generator: torch.Generator | None = None,
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["kaiming_normal_"](
tensor, a=a, mode=mode, nonlinearity=nonlinearity, generator=generator
)
return tensor
def trunc_normal_(
tensor: torch.Tensor,
mean: float = 0.0,
std: float = 1.0,
a: float = -2.0,
b: float = 2.0,
generator: torch.Generator | None = None,
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["trunc_normal_"](tensor, mean=mean, std=std, a=a, b=b, generator=generator)
return tensor
def orthogonal_(
tensor: torch.Tensor,
gain: float = 1,
generator: torch.Generator | None = None,
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["orthogonal_"](tensor, gain=gain, generator=generator)
return tensor
def sparse_(
tensor: torch.Tensor, sparsity: float, std: float = 0.01, generator: torch.Generator | None = None
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["sparse_"](tensor, sparsity=sparsity, std=std, generator=generator)
return tensor
def copy_(tensor: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
with torch.no_grad():
return tensor.copy_(other)
return tensor
# Here, we need to check several modules imported, and hot patch all of them, as sometimes torch does
# something like `from torch.nn.init import xavier_uniform_` in their internals (e.g in torch.nn.modules.activations,
# where MultiHeadAttention lives), so the function name is binded at import time and just doing
# `setattr(torch.nn.init, name, globals()[name])` is thus not enough
# The following list should be enough for all torch versions we work with
TORCH_MODULES_TO_PATCH = (
"torch.nn.init",
"torch.nn.modules.activation",
"torch.nn.modules.transformer",
"torch.nn.modules.linear",
"torch.nn.modules.loss",
"torch.nn.modules.batchnorm",
"torch.nn.modules.conv",
"torch.nn.modules.normalization",
"torch.nn.modules.rnn",
"torch.nn.modules.sparse",
)
@contextmanager
def guard_torch_init_functions():
"""
Guard the `torch.nn.init` primitive functions to behave exactly like the functions in this file, i.e. be
protected against the `_is_hf_initialized` flag to avoid re-init if the param was already loaded.
Usually, all models are using the init from `transformers` which are already guarded, but just to make extra sure
and for remote code, we also use this context manager.
"""
originals = defaultdict(dict)
try:
# Replace all torch funcs by the ones in this file
for module_name in TORCH_MODULES_TO_PATCH:
if module_name in sys.modules:
module = sys.modules[module_name]
for func_name in TORCH_INIT_FUNCTIONS.keys():
if hasattr(module, func_name):
originals[module][func_name] = getattr(module, func_name)
setattr(module, func_name, globals()[func_name])
yield
finally:
# Set back the original functions on all modules
for module, functions in originals.items():
for func_name, func in functions.items():
setattr(module, func_name, func)
@contextmanager
def no_init_weights():
"""
Disable weight initialization both at the torch-level, and at the transformers-level (`init_weights`).
This is used to speed-up initializing an empty model with deepspeed, as we do not initialize the model on meta device
with deepspeed, but we still don't need to run expensive weight initializations as we are loading params afterwards.
"""
from .modeling_utils import PreTrainedModel
def empty_func(*args, **kwargs):
pass
originals = defaultdict(dict)
try:
# Replace all torch funcs by empty ones
for module_name in TORCH_MODULES_TO_PATCH:
if module_name in sys.modules:
module = sys.modules[module_name]
for func_name in TORCH_INIT_FUNCTIONS.keys():
if hasattr(module, func_name):
originals[module][func_name] = getattr(module, func_name)
setattr(module, func_name, empty_func)
# Also patch our own `init_weights`
original_init_weights = PreTrainedModel.init_weights
PreTrainedModel.init_weights = empty_func
yield
finally:
# Set back the original torch functions on all modules
for module, functions in originals.items():
for func_name, func in functions.items():
setattr(module, func_name, func)
# Set back `init_weights`
PreTrainedModel.init_weights = original_init_weights
@@ -0,0 +1,994 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/videoprism/modular_videoprism.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_videoprism.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import math
from collections.abc import Callable
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import _calculate_fan_in_and_fan_out
from . import initialization as init
from transformers.activations import ACT2FN
from transformers.masking_utils import create_causal_mask
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutput, ImageClassifierOutput
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.file_utils import ModelOutput
from .configuration_videoprism import VideoPrismConfig, VideoPrismTextConfig, VideoPrismVisionConfig
def torch_int(x):
"""
Casts an input to a torch int64 tensor if we are in a tracing context, otherwise to a Python int.
"""
if not torch.is_available():
return int(x)
return x.to(torch.int64) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)
@dataclass
class BaseModelOutputWithSpatialAndTemporalStates(ModelOutput):
"""
Base class for model outputs that include spatial and temporal states.
Args:
last_hidden_state (Optional[torch.FloatTensor]):
The last hidden state of the model, typically of shape
(batch_size, num_patches * num_frames, hidden_size).
temporal_hidden_state (Optional[torch.FloatTensor]):
The last hidden_state of the temporal encoder, typically of shape
(batch_size * num_patches, num_frames, hidden_size).
spatial_hidden_state (Optional[torch.FloatTensor]):
The last hidden_state of the spatial encoder, typically of shape
(batch_size * num_frames, num_patches, hidden_size).
"""
last_hidden_state: torch.FloatTensor | None = None
temporal_hidden_state: torch.FloatTensor | None = None
spatial_hidden_state: torch.FloatTensor | None = None
@dataclass
class VideoPrismClipOutput(ModelOutput):
"""
Base class for VideoPrismClip model outputs.
"""
logits_per_video: torch.FloatTensor | None = None
logits_per_text: torch.FloatTensor | None = None
video_embeds: torch.FloatTensor | None = None
text_embeds: torch.FloatTensor | None = None
@dataclass
class VideoPrismVideoOutput(ModelOutput):
"""
Base class for VideoPrismVideo model outputs.
"""
video_last_hidden_state: torch.FloatTensor | None = None
auxiliary_output: torch.FloatTensor | None = None
attention_pooling_output: torch.FloatTensor | None = None
class VideoPrismTubeletEmbeddings(nn.Module):
"""
Construct VideoPrism Tubelet embeddings.
This module turns a batch of videos of shape (batch_size, num_frames, num_channels, height, width) into a tensor of
shape (batch_size, seq_len, hidden_size) to be consumed by a Transformer encoder.
The seq_len (the number of patches) equals (number of frames // tubelet_size[0]) * (height // tubelet_size[1]) *
(width // tubelet_size[2]).
"""
def __init__(self, config: VideoPrismVisionConfig):
super().__init__()
self.config = config
self.num_frames = config.num_frames
self.image_size = (
config.image_size
if isinstance(self.config.image_size, tuple)
else (self.config.image_size, self.config.image_size)
)
self.patch_size = config.tubelet_size
self.embed_dim = config.hidden_size
self.projection = nn.Conv3d(
config.num_channels, config.hidden_size, kernel_size=config.tubelet_size, stride=config.tubelet_size
)
self.pos_emb_shape = [self.image_size[0] // self.patch_size[1], self.image_size[1] // self.patch_size[2]]
self.num_patches = self.pos_emb_shape[0] * self.pos_emb_shape[1]
def forward(self, pixel_values_videos: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
batch_size, num_frames, num_channels, height, width = pixel_values_videos.shape
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
raise ValueError(
f"Image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]}). Set interpolate_pos_encoding=True to automatically resize the model position embeddings."
)
# permute to (batch_size, num_channels, num_frames, height, width)
pixel_values_videos = pixel_values_videos.permute(0, 2, 1, 3, 4)
hidden_states = self.projection(pixel_values_videos)
# flatten the spatial part and permute to (B, T, num_patches, dim)
hidden_states = hidden_states.flatten(3).permute(0, 2, 3, 1)
# combine batch and time dimension
batch_size, num_frames, num_patches, hidden_size = hidden_states.shape
hidden_states = hidden_states.reshape(batch_size * num_frames, num_patches, hidden_size)
return hidden_states
class VideoPrismSpatialEmbeddings(nn.Module):
"""
VideoPrism Spatial Embeddings.
Creates embeddings from a video using VideoPrismSpatialTubeletEmbeddings and adds positional embeddings.
"""
def __init__(self, config: VideoPrismVisionConfig):
super().__init__()
self.config = config
self.patch_embeddings = VideoPrismTubeletEmbeddings(config)
self.position_embeddings = nn.Parameter(torch.zeros(1, self.patch_embeddings.num_patches, config.hidden_size))
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.patch_size = config.tubelet_size[1:]
self.tubelet_size = config.tubelet_size
# Adapted from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
images. This method is also adapted to support torch.jit tracing.
Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
"""
num_patches = embeddings.shape[1]
num_positions = self.position_embeddings.shape[1]
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embeddings
dim = embeddings.shape[-1]
num_row_patches = height // self.patch_size[0]
num_col_patches = width // self.patch_size[1]
sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = self.position_embeddings.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
size=(num_row_patches, num_col_patches),
mode="bilinear",
antialias=True,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return patch_pos_embed
def forward(
self, pixel_values_videos: torch.Tensor, interpolate_pos_encoding: bool | None = False
) -> torch.Tensor:
b, t, c, h, w = pixel_values_videos.shape
assert h == w, "Input image height and width must be the same"
embeddings = self.patch_embeddings(pixel_values_videos, interpolate_pos_encoding)
# add positional encoding to each token
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, h, w)
else:
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings
class VideoPrismTemporalEmbeddings(nn.Module):
"""
VideoPrism Temporal Embeddings.
Receives embeddings from spatial encoder, reshapes the hidden state to
(batch_size * num_patches, num_frames, hidden_size) and adds positional embeddings.
"""
def __init__(self, config: VideoPrismVisionConfig):
super().__init__()
self.config = config
self.position_embeddings = nn.Parameter(torch.zeros(1, self.config.num_frames, config.hidden_size))
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# Adapted from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
images. This method is also adapted to support torch.jit tracing.
Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
"""
target_emb_length = embeddings.shape[1]
source_emb_length = self.position_embeddings.shape[1]
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and target_emb_length == source_emb_length:
return self.position_embeddings
source_emb = self.position_embeddings
dim = embeddings.shape[-1]
source_emb = source_emb.unsqueeze(1)
source_emb = nn.functional.interpolate(
source_emb,
size=(target_emb_length, dim),
mode="bilinear",
antialias=True,
)
return source_emb.squeeze(1)
def forward(
self,
pixel_values_videos: torch.Tensor,
input_shape: torch.Size,
interpolate_pos_encoding: bool | None = False,
) -> torch.Tensor:
if input_shape is not None:
b, t, c, h, w = input_shape
_, features, dim = pixel_values_videos.shape
hidden_states = pixel_values_videos.view(b, t, features, dim)
hidden_states = hidden_states.permute(0, 2, 1, 3)
embeddings = hidden_states.reshape(b * features, t, dim)
# add positional encoding to each token
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings)
else:
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor | None,
scaling: float,
dropout: float = 0.0,
softcap: float | None = None,
**kwargs,
):
# Take the dot product between "query" and "key" to get the raw attention scores.
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
if softcap is not None:
attn_weights = attn_weights / softcap
attn_weights = torch.tanh(attn_weights)
attn_weights = attn_weights * softcap
if attention_mask is not None:
attn_weights = attn_weights + attention_mask.expand(*attn_weights.shape)
# Normalize the attention scores to probabilities.
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class VideoPrismSelfAttention(nn.Module):
def __init__(self, config: VideoPrismVisionConfig | VideoPrismTextConfig):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
f"heads {config.num_attention_heads}."
)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.dropout_prob = config.attention_probs_dropout_prob
self.scale = self.attention_head_size**-0.5
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
batch_size = hidden_states.shape[0]
new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
query = self.query(hidden_states).view(*new_shape).transpose(1, 2)
key = self.key(hidden_states).view(*new_shape).transpose(1, 2)
value = self.value(hidden_states).view(*new_shape).transpose(1, 2)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
context_layer, attention_probs = attention_interface(
self,
query,
key,
value,
attention_mask,
scaling=self.scale,
dropout=0.0 if not self.training else self.dropout_prob,
softcap=self.config.attn_logit_softcapping,
**kwargs,
)
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.reshape(new_context_layer_shape)
return (context_layer, attention_probs)
class VideoPrismSelfOutput(nn.Module):
"""
The residual connection is defined in VideoPrismLayer instead of here (as is the case with other models), due to the
layernorm applied before each block.
"""
def __init__(self, config: VideoPrismConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class VideoPrismAttention(nn.Module):
def __init__(self, config: VideoPrismConfig):
super().__init__()
self.attention = VideoPrismSelfAttention(config)
self.output = VideoPrismSelfOutput(config)
def forward(
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, **kwargs
) -> torch.Tensor:
self_attn_output, _ = self.attention(hidden_states, attention_mask, **kwargs)
output = self.output(self_attn_output, hidden_states)
return output
class VideoPrismLayerNorm(nn.LayerNorm):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return F.layer_norm(hidden_states, self.normalized_shape, self.weight + 1, self.bias, self.eps)
class VideoPrismIntermediate(nn.Module):
def __init__(self, config: VideoPrismConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class VideoPrismOutput(nn.Module):
def __init__(self, config: VideoPrismConfig):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = hidden_states + input_tensor
return hidden_states
class VideoPrismLayer(GradientCheckpointingLayer):
"""This corresponds to the EncoderBlock class in the scenic/videoprism implementation."""
def __init__(self, config: VideoPrismVisionConfig | VideoPrismTextConfig):
super().__init__()
self.config = config
self.attention = VideoPrismAttention(config)
self.intermediate = VideoPrismIntermediate(config)
self.output = VideoPrismOutput(config)
self.layernorm_before = VideoPrismLayerNorm(self.config.hidden_size, eps=self.config.layer_norm_eps)
self.layernorm_after = VideoPrismLayerNorm(self.config.hidden_size, eps=self.config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
hidden_states_norm = self.layernorm_before(hidden_states)
attention_output = self.attention(hidden_states_norm, attention_mask, **kwargs)
# first residual connection
hidden_states = attention_output + hidden_states
# in VideoPrism, layernorm is also applied after self-attention
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
# second residual connection is done here
layer_output = self.output(layer_output, hidden_states)
return layer_output
class VideoPrismSpatialEncoder(nn.Module):
def __init__(self, config: VideoPrismVisionConfig):
super().__init__()
self.config = config
self.layer = nn.ModuleList([VideoPrismLayer(config) for _ in range(config.num_spatial_layers)])
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> BaseModelOutput:
for i, layer_module in enumerate(self.layer):
hidden_states = layer_module(hidden_states)
return BaseModelOutput(last_hidden_state=hidden_states)
class VideoPrismTemporalEncoder(nn.Module):
def __init__(self, config: VideoPrismVisionConfig):
super().__init__()
self.config = config
self.layer = nn.ModuleList([VideoPrismLayer(config) for _ in range(config.num_temporal_layers)])
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> BaseModelOutput:
for i, layer_module in enumerate(self.layer):
hidden_states = layer_module(hidden_states)
return BaseModelOutput(last_hidden_state=hidden_states)
class VideoPrismAuxiliaryEncoder(nn.Module):
def __init__(self, config: VideoPrismVisionConfig):
super().__init__()
self.config = config
self.layer = nn.ModuleList([VideoPrismLayer(self.config) for _ in range(config.num_auxiliary_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
**kwargs,
) -> BaseModelOutput:
for i, layer_module in enumerate(self.layer):
hidden_states = layer_module(hidden_states, attention_mask, **kwargs)
return BaseModelOutput(last_hidden_state=hidden_states)
class VideoPrismTextEncoder(nn.Module):
def __init__(self, config: VideoPrismTextConfig):
super().__init__()
self.config = config
self.layer = nn.ModuleList([VideoPrismLayer(config) for _ in range(config.num_text_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
**kwargs,
) -> BaseModelOutput:
for i, layer_module in enumerate(self.layer):
hidden_states = layer_module(hidden_states, attention_mask, **kwargs)
return BaseModelOutput(last_hidden_state=hidden_states)
def variance_scaling_(tensor, mode="fan_in", distribution="normal"):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
if mode == "fan_in":
denom = fan_in
elif mode == "fan_out":
denom = fan_out
elif mode == "fan_avg":
denom = (fan_in + fan_out) / 2
variance = 1.0 / denom
if distribution == "truncated_normal":
init.trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
elif distribution == "normal":
init.normal_(tensor, std=math.sqrt(variance))
elif distribution == "uniform":
bound = math.sqrt(3 * variance)
init.uniform_(tensor, -bound, bound)
else:
raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
class VideoPrismPreTrainedModel(PreTrainedModel):
config_class = VideoPrismConfig
config: VideoPrismConfig
base_model_prefix = "videoprism"
main_input_name = "pixel_values_videos"
input_modalities = ("video", "text")
supports_gradient_checkpointing = True
_no_split_modules = [
"VideoPrismSpatialEmbeddings",
"VideoPrismTemporalEmbeddings",
"VideoPrismSpatialEncoder",
"VideoPrismTemporalEncoder",
"VideoPrismAuxiliaryEncoder",
"VideoPrismTextEncoder",
"VideoPrismMultiheadAttentionPoolingHead",
]
_supports_sdpa = True
_supports_flash_attn = True
_supports_attention_backend = True
_supports_flex_attention = True
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Conv3d)):
lecun_normal_(module.weight)
init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
init.zeros_(module.bias)
init.ones_(module.weight)
class VideoPrismVisionModel(VideoPrismPreTrainedModel):
config_class = VideoPrismVisionConfig
config: VideoPrismVisionConfig
def __init__(self, config: VideoPrismVisionConfig):
super().__init__(config)
self.config = config
self.layernorm1 = VideoPrismLayerNorm(self.config.hidden_size, eps=self.config.layer_norm_eps)
self.layernorm2 = VideoPrismLayerNorm(self.config.hidden_size, eps=self.config.layer_norm_eps)
self.spatial_embeddings = VideoPrismSpatialEmbeddings(self.config)
self.temporal_embeddings = VideoPrismTemporalEmbeddings(self.config)
self.spatial_encoder = VideoPrismSpatialEncoder(self.config)
self.temporal_encoder = VideoPrismTemporalEncoder(self.config)
self.post_init()
def get_input_embeddings(self):
return self.spatial_embeddings.patch_embeddings
def forward(
self,
pixel_values_videos: torch.FloatTensor | None = None,
interpolate_pos_encoding: bool | None = False,
**kwargs,
) -> BaseModelOutputWithSpatialAndTemporalStates:
r"""
Args:
pixel_values_videos (`torch.FloatTensor`):
Pixel values of the video frames of shape (batch_size, num_frames, num_channels, height, width).
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate positional encodings to match input size.
Example:
```python
>>> from transformers import VideoPrismVideoProcessor, VideoPrismVisionModel
>>> import torch
>>> processor = VideoPrismVideoProcessor.from_pretrained("google/videoprism")
>>> model = VideoPrismVisionModel.from_pretrained("google/videoprism")
>>> video = "sample_video.mp4"
>>> inputs = processor(videos=video)
>>> with torch.no_grad():
... outputs = model(**inputs)
... features = outputs.last_hidden_state
```
"""
if pixel_values_videos is None:
raise ValueError("You have to specify pixel_values_videos")
input_shape = pixel_values_videos.shape
spatial_embeds = self.spatial_embeddings(pixel_values_videos, interpolate_pos_encoding)
spatial_encoder_outputs: BaseModelOutput = self.spatial_encoder(hidden_states=spatial_embeds, **kwargs)
# shape of spatial_sequence_output is (B * num_frames, num_patches, dim)
spatial_sequence_output = spatial_encoder_outputs.last_hidden_state
features = self.layernorm1(spatial_sequence_output)
temporal_embeds = self.temporal_embeddings(features, input_shape, interpolate_pos_encoding)
temporal_encoder_outputs: BaseModelOutput = self.temporal_encoder(hidden_states=temporal_embeds, **kwargs)
# shape of temporal_sequence_output is (B * num_patches, num_frames, dim)
temporal_sequence_output = temporal_encoder_outputs.last_hidden_state
features = self.layernorm2(temporal_sequence_output)
_, num_frames, dim = features.shape
features = features.view(input_shape[0], -1, num_frames, dim).permute(0, 2, 1, 3).contiguous()
_, num_frames, num_patches, dim = features.shape
features = features.view(input_shape[0], num_frames * num_patches, -1)
return BaseModelOutputWithSpatialAndTemporalStates(
last_hidden_state=features,
temporal_hidden_state=temporal_sequence_output,
spatial_hidden_state=spatial_sequence_output,
)
class VideoPrismMultiheadAttentionPoolingHead(nn.Module):
def __init__(self, config: VideoPrismVisionConfig):
super().__init__()
self.config = config
self.num_attention_heads = self.config.num_attention_heads
self.attention_head_size = int(self.config.intermediate_size / self.config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.dropout_prob = self.config.attention_probs_dropout_prob
# PerDimScale
self.dim = int(self.config.intermediate_size / self.config.num_attention_heads)
self.per_dim_scale = nn.Parameter(torch.zeros(self.dim))
r_softplus_0 = 1.442695041
scale = torch.tensor(r_softplus_0 / (self.dim**0.5))
softplus = nn.functional.softplus(self.per_dim_scale)
scale = scale * softplus
self.register_buffer("scale", scale)
self.pooling_attention_query = nn.Parameter(torch.zeros(1, 1, self.config.hidden_size))
self.query = nn.Linear(self.config.hidden_size, self.config.intermediate_size, bias=self.config.qkv_bias)
self.key = nn.Linear(self.config.hidden_size, self.config.intermediate_size, bias=self.config.qkv_bias)
self.value = nn.Linear(self.config.hidden_size, self.config.intermediate_size, bias=self.config.qkv_bias)
self.projection = nn.Linear(self.config.intermediate_size, self.config.hidden_size, bias=self.config.qkv_bias)
self.layernorm = VideoPrismLayerNorm(self.config.hidden_size, eps=self.config.layer_norm_eps)
self.dim = int(self.config.intermediate_size / self.config.num_attention_heads)
def forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: torch.LongTensor | None = None,
**kwargs,
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
batch_size, seq_length, hidden_size = hidden_states.shape
query = self.pooling_attention_query.expand(batch_size, -1, -1)
query_layer = (
self.query(query).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
)
query_layer = query_layer * self.scale.expand(*query_layer.shape)
key_layer = (
self.key(hidden_states)
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
)
value_layer = (
self.value(hidden_states)
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
context_layer, attention_probs = attention_interface(
self,
query_layer,
key_layer,
value_layer,
attention_mask,
scaling=1.0,
dropout=0.0 if not self.training else self.dropout_prob,
softcap=None,
**kwargs,
)
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.reshape(new_context_layer_shape)
outputs = self.projection(context_layer)
outputs = self.layernorm(outputs)
return (outputs, attention_probs)
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
"""This function is intended to align with the l2norm implementation in the FLA library."""
inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
return x * inv_norm
class VideoPrismTextModel(VideoPrismPreTrainedModel):
config_class = VideoPrismTextConfig
config: VideoPrismTextConfig
def __init__(self, config: VideoPrismTextConfig):
super().__init__(config)
self.config = config
self.text_encoder = VideoPrismTextEncoder(self.config)
self.token_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.cls_emb = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
self.layernorm = VideoPrismLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.normalize = config.apply_l2_norm
self.post_init()
def create_sinusoidal_positions(self, num_pos: int, dim: int) -> torch.Tensor:
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / (dim - 2)))
sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
**kwargs,
) -> BaseModelOutput:
r"""
Args:
input_ids (`torch.Tensor`):
Input token IDs.
attention_mask (`torch.Tensor`, *optional*):
Attention mask to avoid performing attention on padding token indices.
"""
batch_size, seq_length = input_ids.shape
hidden_states = self.token_embeddings(input_ids)
hidden_states = hidden_states * (self.config.hidden_size**0.5)
cls_padding = torch.ones(batch_size, 1)
input_ids = torch.cat((input_ids, cls_padding), dim=1)
attention_mask = torch.cat((attention_mask, cls_padding), dim=1) if attention_mask is not None else None
if attention_mask is not None:
attention_mask = create_causal_mask(
config=self.config,
input_embeds=hidden_states,
attention_mask=attention_mask,
cache_position=torch.arange(hidden_states.shape[1] + 1, device=hidden_states.device),
past_key_values=None,
)
features = hidden_states + self.create_sinusoidal_positions(seq_length, self.config.hidden_size)
cls_emb = self.cls_emb * (self.config.hidden_size**0.5)
cls_emb = cls_emb.expand(features.shape[0], -1, -1)
features = torch.cat((features, cls_emb), dim=1)
text_encoder_output = self.text_encoder(features, attention_mask)
features = text_encoder_output.last_hidden_state
features = self.layernorm(features)
text_embeddings = features[:, -1]
if self.normalize:
text_embeddings = l2norm(text_embeddings, dim=-1)
return BaseModelOutput(
last_hidden_state=text_embeddings,
)
class VideoPrismVideoModel(VideoPrismPreTrainedModel):
config_class = VideoPrismVisionConfig
config: VideoPrismVisionConfig
def __init__(self, config: VideoPrismVisionConfig):
super().__init__(config)
self.config = config
self.backbone = VideoPrismVisionModel(self.config)
self.auxiliary_encoder = VideoPrismAuxiliaryEncoder(self.config)
self.contrastive_vision_pooler = VideoPrismMultiheadAttentionPoolingHead(self.config)
self.normalize = self.config.apply_l2_norm
self.post_init()
def get_input_embeddings(self):
return self.backbone.spatial_embeddings.patch_embeddings
def forward(
self,
pixel_values_videos: torch.FloatTensor,
interpolate_pos_encoding: bool | None = False,
**kwargs,
) -> VideoPrismVideoOutput:
r"""
Args:
pixel_values_videos (`torch.FloatTensor`):
Pixel values of the video frames.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate positional encodings to match input size.
"""
backbone_outputs = self.backbone(
pixel_values_videos=pixel_values_videos, interpolate_pos_encoding=interpolate_pos_encoding, **kwargs
)
video_features = backbone_outputs.last_hidden_state
auxiliary_output = self.auxiliary_encoder(video_features)
auxiliary_output_features = auxiliary_output.last_hidden_state
contrastive_vision_pooler_output = self.contrastive_vision_pooler(auxiliary_output_features, **kwargs)
video_embeddings = contrastive_vision_pooler_output[0]
if self.normalize:
video_embeddings = l2norm(video_embeddings, dim=-1)
return VideoPrismVideoOutput(
video_last_hidden_state=video_embeddings,
auxiliary_output=auxiliary_output,
attention_pooling_output=contrastive_vision_pooler_output,
)
class VideoPrismClipModel(VideoPrismPreTrainedModel):
config_class = VideoPrismConfig
def __init__(self, config: VideoPrismConfig):
super().__init__(config)
self.config = config
self.vision_config = config.vision_config
self.text_config = config.text_config
self.video_model = VideoPrismVideoModel(self.vision_config)
self.text_model = VideoPrismTextModel(self.text_config)
self.post_init()
def forward(
self,
pixel_values_videos: torch.FloatTensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
interpolate_pos_encoding: bool | None = False,
temperature: float | None = None,
**kwargs,
) -> VideoPrismClipOutput:
r"""
Args:
pixel_values_videos (`torch.FloatTensor`):
Pixel values of the video frames.
input_ids (`torch.Tensor`):
Input token IDs for text.
attention_mask (`torch.Tensor`, *optional*):
Attention mask for text inputs.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate positional encodings.
temperature (`float`, *optional*):
Temperature parameter for scaling similarity scores.
Example:
```python
>>> from transformers import VideoPrismProcessor, VideoPrismClipModel
>>> import torch
>>> processor = VideoPrismProcessor.from_pretrained("google/videoprism")
>>> model = VideoPrismClipModel.from_pretrained("google/videoprism")
>>> video = "sample_video.mp4"
>>> texts = ["a dog", "a cat"]
>>> inputs = processor(videos=video, texts=texts, return_tensors="pt", padding=True)
>>> with torch.no_grad():
... outputs = model(**inputs)
... logits_per_video = outputs.logits_per_video
```
"""
video_model_outputs = self.video_model(
pixel_values_videos=pixel_values_videos, interpolate_pos_encoding=interpolate_pos_encoding, **kwargs
)
text_model_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
video_embeddings = video_model_outputs.video_last_hidden_state
text_embeddings = text_model_outputs.last_hidden_state
emb_dim = video_embeddings[0].shape[-1]
assert emb_dim == text_embeddings[0].shape[-1]
video_embeds = video_embeddings.reshape(-1, emb_dim)
text_embeds = text_embeddings.reshape(-1, emb_dim)
similarity_matrix = torch.matmul(video_embeds, text_embeds.T)
if temperature is not None:
similarity_matrix /= temperature
logits_per_video = torch.exp(similarity_matrix)
logits_per_text = logits_per_video.T
logits_per_video = logits_per_video / torch.sum(logits_per_video, dim=0, keepdims=True)
logits_per_text = logits_per_text / torch.sum(logits_per_text, dim=0, keepdims=True)
return VideoPrismClipOutput(
logits_per_video=logits_per_video,
logits_per_text=logits_per_text,
video_embeds=video_embeds,
text_embeds=text_embeds,
)
class VideoPrismForVideoClassification(VideoPrismPreTrainedModel):
config_class = VideoPrismVisionConfig
config: VideoPrismVisionConfig
def __init__(self, config: VideoPrismVisionConfig):
super().__init__(config)
self.config = config
self.encoder = VideoPrismVisionModel(self.config)
self.contrastive_vision_pooler = VideoPrismMultiheadAttentionPoolingHead(self.config)
self.classifier = nn.Linear(self.config.hidden_size, self.config.num_labels)
self.post_init()
def get_input_embeddings(self):
return self.encoder.spatial_embeddings.patch_embeddings
def forward(
self,
pixel_values_videos: torch.FloatTensor,
labels: torch.LongTensor | None = None,
interpolate_pos_encoding: bool | None = False,
**kwargs,
) -> ImageClassifierOutput:
r"""
Args:
pixel_values_videos (`torch.FloatTensor`):
Pixel values of the video frames.
labels (`torch.LongTensor`, *optional*):
Video classification labels.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate positional encodings.
Example:
```python
>>> from transformers import VideoPrismVideoProcessor, VideoPrismForVideoClassification
>>> import torch
>>> processor = VideoPrismVideoProcessor("google/videoprism")
>>> model = VideoPrismForVideoClassification.from_pretrained("google/videoprism", num_labels=1000)
>>> video = "sample_video.mp4"
>>> inputs = processor(videos=video, return_tensors="pt")
>>> with torch.no_grad():
... outputs = model(**inputs)
... logits = outputs.logits
```
"""
encoder_outputs = self.encoder(
pixel_values_videos=pixel_values_videos, interpolate_pos_encoding=interpolate_pos_encoding, **kwargs
)
sequence_output = encoder_outputs.last_hidden_state
pooled_output = self.contrastive_vision_pooler(sequence_output, **kwargs).pooled_output
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
loss = self.loss_function(labels, logits, self.config, **kwargs)
return ImageClassifierOutput(
loss=loss,
logits=logits,
hidden_states=encoder_outputs.last_hidden_state,
)
__all__ = [
"VideoPrismVisionModel",
"VideoPrismPreTrainedModel",
"VideoPrismVideoModel",
"VideoPrismTextModel",
"VideoPrismClipModel",
"VideoPrismForVideoClassification",
]
@@ -0,0 +1,50 @@
import torch
import numpy as np
from torchcodec.decoders import VideoDecoder
from lerobot.policies.videovla.videoprism import VideoPrismVideoProcessor
from lerobot.policies.videovla.videoprism import VideoPrismVisionModel
processor = VideoPrismVideoProcessor.from_pretrained(
"MHRDYN7/videoprism-base-f16r288"
)
model = VideoPrismVisionModel.from_pretrained(
"MHRDYN7/videoprism-base-f16r288",
torch_dtype=torch.float16,
device_map="auto",
attn_implementation="sdpa",
)
video_url = "https://huggingface.co/datasets/nateraw/kinetics-mini/resolve/main/val/archery/-Qz25rXdMjE_000014_000024.mp4"
vr = VideoDecoder(video_url)
frame_idx = np.arange(0, 64)
video = vr.get_frames_at(indices=frame_idx).data # T x C x H x W
video = processor(video, return_tensors="pt")
video = {k: v.to(model.device, model.dtype) for k, v in video.items()}
outputs = model(**video)
encoder_outputs = outputs.last_hidden_state
print(encoder_outputs.shape) #
import time
import torch
# warmup
for _ in range(10):
_ = model(**video)
times = []
for _ in range(50):
torch.cuda.synchronize()
t0 = time.perf_counter()
_ = model(**video)
torch.cuda.synchronize()
t1 = time.perf_counter()
times.append(t1 - t0)
print(f"Mean: {1000*sum(times)/len(times):.2f} ms")
print(f"Min : {1000*min(times):.2f} ms")
print(f"Max : {1000*max(times):.2f} ms")
@@ -0,0 +1,44 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/videoprism/modular_videoprism.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_videoprism.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling
from transformers.video_processing_utils import BaseVideoProcessor
class VideoPrismVideoProcessor(BaseVideoProcessor):
r"""
Constructs a VideoPrism video processor.
This processor inherits from [`LlavaOnevisionVideoProcessor`] and sets default parameters for VideoPrism models.
Video frames are resized to 288x288 using bicubic resampling without normalization.
Args:
size (`Dict[str, int]`, *optional*, defaults to `{"height": 288, "width": 288}`):
The size to resize the video frames to.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
The resampling filter to use when resizing images.
do_normalize (`bool`, *optional*, defaults to `False`):
Whether to normalize the video frames.
"""
resample = PILImageResampling.BICUBIC
image_mean = OPENAI_CLIP_MEAN
image_std = OPENAI_CLIP_STD
size = {"height": 288, "width": 288}
rescale_factor = 1 / 255
default_to_square = False
crop_size = None
do_resize = True
do_center_crop = None
do_rescale = True
do_normalize = False
do_convert_rgb = True
do_sample_frames = False # Set to False for BC, recommended to set `True` in new models
__all__ = ["VideoPrismVideoProcessor"]
@@ -32,7 +32,7 @@ class VQBeTConfig(PreTrainedConfig):
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_features` and `output_features`.
Those are: `input_shapes` and `output_shapes`.
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
@@ -46,12 +46,21 @@ class VQBeTConfig(PreTrainedConfig):
current step and additional steps going back).
n_action_pred_token: Total number of current token and future tokens that VQ-BeT predicts.
action_chunk_size: Action chunk size of each action prediction token.
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
input_shapes: A dictionary defining the shapes of the input data for the policy.
The key represents the input data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "observation.image" refers to an input from
a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
Importantly, shapes doesnt include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy.
The key represents the output data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done.
-2
View File
@@ -44,7 +44,6 @@ from .hil_processor import (
AddTeleopActionAsComplimentaryDataStep,
AddTeleopEventsAsInfoStep,
GripperPenaltyProcessorStep,
GymHILAdapterProcessorStep,
ImageCropResizeProcessorStep,
InterventionActionProcessorStep,
RewardClassifierProcessorStep,
@@ -88,7 +87,6 @@ __all__ = [
"DoneProcessorStep",
"EnvAction",
"EnvTransition",
"GymHILAdapterProcessorStep",
"GripperPenaltyProcessorStep",
"hotswap_stats",
"IdentityProcessorStep",
+1 -2
View File
@@ -168,12 +168,11 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
"""
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
task_key = {"task": batch["task"]} if "task" in batch else {}
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key}
return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key}
def create_transition(
+6 -4
View File
@@ -17,7 +17,7 @@ from dataclasses import dataclass
import torch
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
@@ -92,7 +92,7 @@ class LiberoProcessorStep(ObservationProcessorStep):
# copy over non-STATE features
for ft, feats in features.items():
if ft != FeatureType.STATE:
if ft != PipelineFeatureType.STATE:
new_features[ft] = feats.copy()
# rebuild STATE features
@@ -100,11 +100,13 @@ class LiberoProcessorStep(ObservationProcessorStep):
# add our new flattened state
state_feats[OBS_STATE] = PolicyFeature(
type=FeatureType.STATE,
key=OBS_STATE,
shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)]
dtype="float32",
description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."),
)
new_features[FeatureType.STATE] = state_feats
new_features[PipelineFeatureType.STATE] = state_feats
return new_features
@@ -20,7 +20,6 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from .converters import to_tensor
from .core import EnvAction, EnvTransition, PolicyAction
from .hil_processor import TELEOP_ACTION_KEY
from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry
@@ -90,13 +89,6 @@ class Numpy2TorchActionProcessorStep(ProcessorStep):
torch_action = to_tensor(action, dtype=None) # Preserve original dtype
new_transition[TransitionKey.ACTION] = torch_action
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if TELEOP_ACTION_KEY in complementary_data:
teleop_action = complementary_data[TELEOP_ACTION_KEY]
if isinstance(teleop_action, EnvAction):
complementary_data[TELEOP_ACTION_KEY] = to_tensor(teleop_action)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
return new_transition
def transform_features(
+15 -50
View File
@@ -18,18 +18,16 @@
import math
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, runtime_checkable
from typing import Any, Protocol, TypeVar, runtime_checkable
import numpy as np
import torch
import torchvision.transforms.functional as F # noqa: N812
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.teleoperators.utils import TeleopEvents
if TYPE_CHECKING:
from lerobot.teleoperators.teleoperator import Teleoperator
from .core import EnvTransition, PolicyAction, TransitionKey
from .pipeline import (
ComplementaryDataProcessorStep,
@@ -71,10 +69,10 @@ class HasTeleopEvents(Protocol):
# Type variable constrained to Teleoperator subclasses that also implement events
TeleopWithEvents = TypeVar("TeleopWithEvents", bound="Teleoperator")
TeleopWithEvents = TypeVar("TeleopWithEvents", bound=Teleoperator)
def _check_teleop_with_events(teleop: "Teleoperator") -> None:
def _check_teleop_with_events(teleop: Teleoperator) -> None:
"""
Runtime check that a teleoperator implements the `HasTeleopEvents` protocol.
@@ -105,7 +103,7 @@ class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
teleop_device: The teleoperator instance to get the action from.
"""
teleop_device: "Teleoperator"
teleop_device: Teleoperator
def complementary_data(self, complementary_data: dict) -> dict:
"""
@@ -312,40 +310,9 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
return features
@ProcessorStepRegistry.register("gym_hil_adapter_processor")
class GymHILAdapterProcessorStep(ProcessorStep):
"""
Adapts the output of the `gym-hil` environment to the format expected by `lerobot` processors.
This step normalizes the `transition` object by:
1. Copying `teleop_action` from `info` to `complementary_data`.
2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key).
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
info = transition.get(TransitionKey.INFO, {})
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if TELEOP_ACTION_KEY in info:
complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY]
if "is_intervention" in info:
info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"]
transition[TransitionKey.INFO] = info
transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
return transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
@dataclass
@ProcessorStepRegistry.register("gripper_penalty_processor")
class GripperPenaltyProcessorStep(ProcessorStep):
class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
"""
Applies a penalty for inefficient gripper usage.
@@ -360,27 +327,26 @@ class GripperPenaltyProcessorStep(ProcessorStep):
penalty: float = -0.01
max_gripper_pos: float = 30.0
def __call__(self, transition: EnvTransition) -> EnvTransition:
def complementary_data(self, complementary_data: dict) -> dict:
"""
Calculates the gripper penalty and adds it to the complementary data.
Args:
transition: The incoming environment transition.
complementary_data: The incoming complementary data, which should contain
raw joint positions.
Returns:
The modified transition with the penalty added to complementary data.
A new complementary data dictionary with the `discrete_penalty` key added.
"""
new_transition = transition.copy()
action = new_transition.get(TransitionKey.ACTION)
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
action = self.transition.get(TransitionKey.ACTION)
raw_joint_positions = complementary_data.get("raw_joint_positions")
if raw_joint_positions is None:
return new_transition
return complementary_data
current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY, None)
if current_gripper_pos is None:
return new_transition
return complementary_data
# Gripper action is a PolicyAction at this stage
gripper_action = action[-1].item()
@@ -396,12 +362,11 @@ class GripperPenaltyProcessorStep(ProcessorStep):
gripper_penalty = self.penalty * int(gripper_penalty_bool)
# Update complementary data with penalty info
# Create new complementary data with penalty info
new_complementary_data = dict(complementary_data)
new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_transition
return new_complementary_data
def get_config(self) -> dict[str, Any]:
"""
@@ -131,15 +131,6 @@ class _NormalizationMixin:
if self.dtype is None:
self.dtype = torch.float32
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
self._reshape_visual_stats()
def _reshape_visual_stats(self) -> None:
"""Reshape visual stats from ``[C]`` to ``[C, 1, 1]`` for image broadcasting."""
for key, feature in self.features.items():
if feature.type == FeatureType.VISUAL and key in self._tensor_stats:
for stat_name, stat_tensor in self._tensor_stats[key].items():
if isinstance(stat_tensor, Tensor) and stat_tensor.ndim == 1:
self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
def to(
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
@@ -158,7 +149,6 @@ class _NormalizationMixin:
if dtype is not None:
self.dtype = dtype
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
self._reshape_visual_stats()
return self
def state_dict(self) -> dict[str, Tensor]:
@@ -208,7 +198,6 @@ class _NormalizationMixin:
# Don't load from state_dict, keep the explicitly provided stats
# But ensure _tensor_stats is properly initialized
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment]
self._reshape_visual_stats()
return
# Normal behavior: load stats from state_dict
@@ -219,7 +208,6 @@ class _NormalizationMixin:
self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
dtype=torch.float32, device=self.device
)
self._reshape_visual_stats()
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
# and other functions that rely on self.stats
+1 -1
View File
@@ -413,7 +413,7 @@ class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
Args:
save_directory: The directory where the pipeline will be saved. If None, saves to
HF_LEROBOT_HOME/processors/{sanitized_pipeline_name}.
repo_id: ID of your repository on the Hub. Used only if `push_to_hub=true`.
repo_id: ID of your repository on the Hub. Used only if `push_to_hub=True`.
push_to_hub: Whether or not to push your object to the Hugging Face Hub after saving it.
card_kwargs: Additional arguments passed to the card template to customize the card.
config_filename: The name of the JSON configuration file. If None, a name is
@@ -34,8 +34,6 @@ from lerobot.utils.constants import (
ACTION_TOKEN_MASK,
ACTION_TOKENS,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK,
OBS_LANGUAGE_SUBTASK_TOKENS,
OBS_LANGUAGE_TOKENS,
)
from lerobot.utils.import_utils import _transformers_available
@@ -141,32 +139,6 @@ class TokenizerProcessorStep(ObservationProcessorStep):
return None
def get_subtask(self, transition: EnvTransition) -> list[str] | None:
"""
Extracts the subtask from the transition's complementary data.
Args:
transition: The environment transition.
Returns:
A list of subtask strings, or None if the subtask key is not found or the value is None.
"""
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None:
return None
subtask = complementary_data.get("subtask")
if subtask is None:
return None
# Standardize to a list of strings for the tokenizer
if isinstance(subtask, str):
return [subtask]
elif isinstance(subtask, list) and all(isinstance(t, str) for t in subtask):
return subtask
return None
def observation(self, observation: RobotObservation) -> RobotObservation:
"""
Tokenizes the task description and adds it to the observation dictionary.
@@ -204,24 +176,6 @@ class TokenizerProcessorStep(ObservationProcessorStep):
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
# Tokenize subtask if available
subtask = self.get_subtask(self.transition)
if subtask is not None:
tokenized_subtask = self._tokenize_text(subtask)
# Move new tokenized tensors to the detected device
if target_device is not None:
tokenized_subtask = {
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
for k, v in tokenized_subtask.items()
}
# Add tokenized subtask to the observation
new_observation[OBS_LANGUAGE_SUBTASK_TOKENS] = tokenized_subtask["input_ids"]
new_observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] = tokenized_subtask["attention_mask"].to(
dtype=torch.bool
)
return new_observation
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
-13
View File
@@ -1,13 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Some files were not shown because too many files have changed in this diff Show More