Compare commits


10 Commits

Author SHA1 Message Date
Pepijn f147a4cd48 Add inference for training time rtc 2026-01-29 11:05:42 +01:00
Pepijn c3fa269b21 Merge branch 'main' into feat/training_time_rtc 2026-01-27 17:34:56 +01:00
Reece O'Mahoney f6b1c39b78 docs: update libero (#2857)
* update libero docs

* Update docs/source/libero.mdx

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Jade Choghari <chogharijade@gmail.com>

---------

Signed-off-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-01-27 15:31:53 +01:00
Pepijn 0c0c171d35 Add robot images to docs (#2862)
* Add robot images to docs

* increase img size

* remove img so100
2026-01-27 13:33:45 +01:00
Steven Palma 9cfb5ce546 feat(motors): add damiao motors & can bus (#2788)
* fix(motors): cleanup imports + fix signatures

* feat(motors): add damiao canbus + multiple fixes

* fix(motors): address comments -> last_state + different gains + sleep

* refactor(motors): reduce duplicated code + adressed some comments in the PR

* chore(motors): better timeouts

* tests(motors): damiao test and imports

* chore(deps): fix space

* Apply suggestions from code review

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* chore(motors): remove normalization tables damiao

* fix(motors): imports and signatures

* feat(motors): add motor_type_str + recv_id to motor class and _get_motor_recv_id raises if no motor_obj.recv_id

* chore(motors): remove normalize from base motor class and damaio

* tests(motors): remove bad tests (to be replaced)

* chore(motors): updated import check

* use constant for kp and kd range and check responses in mit_control_batch()

* Add docs on setting up canbus and use damiao otor bus, also add lerobot_setup_can.py and log if there is not response from a write command

* precommit format

* supress bandit as these are intentional cli commands

* fix setup-can

* add test

* skip test in ci

* nit precommit

* update doc example

* dont import can for tests

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Pepijn <pepijn@huggingface.co>
2026-01-26 17:53:25 +01:00
Reece O'Mahoney 366bef915c add task ids to libero env cfg (#2842) 2026-01-26 17:26:49 +01:00
Woojin Wie 9e10eb4a77 fix(robots): update gripper configuration and calibration settings for OMX (#2815) 2026-01-25 22:29:37 +01:00
Pepijn 385ba8d1b7 remove wall-oss from doc links 2026-01-20 20:11:56 +01:00
Pepijn f4ccf911fa format 2026-01-20 20:08:28 +01:00
Pepijn 0cb8c92fe4 Implement training time rtc for pi0, pi0.5 and smolvla 2026-01-20 20:02:10 +01:00
48 changed files with 2482 additions and 4086 deletions
+6
@@ -57,6 +57,8 @@
title: Use Async Inference
- local: rtc
title: Real-Time Chunking (RTC)
- local: training_time_rtc
title: Training-Time RTC
title: "Inference"
- sections:
- local: envhub
@@ -99,6 +101,8 @@
title: Unitree G1
- local: earthrover_mini_plus
title: Earth Rover Mini
- local: omx
title: OMX
title: "Robots"
- sections:
- local: phone_teleop
@@ -113,6 +117,8 @@
title: Notebooks
- local: feetech
title: Updating Feetech Firmware
- local: damiao
title: Damiao Motors and CAN Bus
title: "Resources"
- sections:
- local: contributing
+165
@@ -0,0 +1,165 @@
# Damiao Motors and CAN Bus
This guide covers setup and usage of Damiao motors with LeRobot via CAN bus communication.
Currently, only Linux is supported, as the OpenArms CAN adapter only has drivers for Linux.
## Linux CAN Setup
Before using Damiao motors, you need to set up the CAN interface on your Linux system.
### Install CAN Utilities
```bash
sudo apt-get install can-utils
```
### Configure CAN Interface (Manual)
For CAN FD (recommended for OpenArms):
```bash
sudo ip link set can0 down
sudo ip link set can0 type can bitrate 1000000 dbitrate 5000000 fd on
sudo ip link set can0 up
```
For classic CAN (without FD):
```bash
sudo ip link set can0 down
sudo ip link set can0 type can bitrate 1000000
sudo ip link set can0 up
```
### Configure CAN Interface (Using LeRobot)
LeRobot provides a utility script to set up and test CAN interfaces:
```bash
# Setup multiple interfaces (e.g., OpenArms Followers with 2 CAN buses)
lerobot-setup-can --mode=setup --interfaces=can0,can1
```
## Debugging CAN Communication
Use the built-in debug tools to test motor communication:
```bash
# Test motors on all interfaces
lerobot-setup-can --mode=test --interfaces=can0,can1
# Run speed/latency test
lerobot-setup-can --mode=speed --interfaces=can0
```
The test mode will scan for motors (IDs 0x01-0x08) and report which ones respond. Example output:
```
can0: UP (CAN FD)
Motor 0x01 (joint_1): ✓ FOUND
→ Response 0x11 [FD]: 00112233...
Motor 0x02 (joint_2): ✓ FOUND
Motor 0x03 (joint_3): ✗ No response
...
Summary: 2/8 motors found
```
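To probe a motor manually from Python, you can mirror what the test mode does: send the refresh command (`0xCC`) to the parameter CAN ID (`0x7FF`) and wait for a reply on the motor's receive ID. A minimal python-can sketch (IDs follow the OpenArms defaults; adjust for your setup):
```python
import can

# Assumes can0 is already configured (see the setup section above).
with can.interface.Bus(channel="can0", interface="socketcan", fd=True) as bus:
    send_id, recv_id = 0x01, 0x11  # joint_1 defaults
    # Refresh frame: [id_low, id_high, 0xCC, 0, 0, 0, 0, 0] sent to the param ID 0x7FF
    data = [send_id & 0xFF, (send_id >> 8) & 0xFF, 0xCC, 0, 0, 0, 0, 0]
    bus.send(can.Message(arbitration_id=0x7FF, data=data, is_extended_id=False))
    reply = bus.recv(timeout=0.1)
    if reply is not None and reply.arbitration_id == recv_id:
        print(f"joint_1 responded: {bytes(reply.data).hex()}")
    else:
        print("no response from joint_1")
```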
## Usage
### Basic Setup
```python
from lerobot.motors import Motor
from lerobot.motors.damiao import DamiaoMotorsBus
# Define your motors with send/receive CAN IDs
motors = {
"joint_1": Motor(id=0x01, motor_type_str="dm8009", recv_id=0x11),
"joint_2": Motor(id=0x02, motor_type_str="dm4340", recv_id=0x12),
"joint_3": Motor(id=0x03, motor_type_str="dm4310", recv_id=0x13),
}
# Create the bus
bus = DamiaoMotorsBus(
port="can0", # Linux socketcan interface
motors=motors,
)
# Connect
bus.connect()
```
### Reading Motor States
```python
# Read single motor position (degrees)
position = bus.read("Present_Position", "joint_1")
# Read from multiple motors
positions = bus.sync_read("Present_Position") # All motors
positions = bus.sync_read("Present_Position", ["joint_1", "joint_2"])
# Read all states at once (position, velocity, torque)
states = bus.sync_read_all_states()
# Returns: {'joint_1': {'position': 45.2, 'velocity': 1.3, 'torque': 0.5}, ...}
```
### Writing Motor Commands
```python
# Enable torque
bus.enable_torque()
# Set goal position (degrees)
bus.write("Goal_Position", "joint_1", 45.0)
# Set positions for multiple motors
bus.sync_write("Goal_Position", {
"joint_1": 45.0,
"joint_2": -30.0,
"joint_3": 90.0,
})
# Disable torque
bus.disable_torque()
```
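Putting reads and writes together, a simple control loop could look like this (a sketch only; gains, targets, and loop rate are illustrative):
```python
import time

bus.enable_torque()
try:
    targets = {"joint_1": 0.0, "joint_2": 0.0, "joint_3": 0.0}
    for _ in range(300):  # ~10 s at 30 Hz
        states = bus.sync_read_all_states()
        # Hold each joint at its target while monitoring the measured positions
        bus.sync_write("Goal_Position", targets)
        print({name: round(s["position"], 1) for name, s in states.items()})
        time.sleep(1 / 30)
finally:
    bus.disable_torque()
```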
## Configuration Options
| Parameter | Default | Description |
| -------------- | --------- | ----------------------------------------------------------- |
| `port` | - | CAN interface (`can0`) or serial port (`/dev/cu.usbmodem*`) |
| `use_can_fd` | `True` | Enable CAN FD for higher data rates |
| `bitrate` | `1000000` | Nominal bitrate (1 Mbps) |
| `data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) |
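These options map directly to the `DamiaoMotorsBus` constructor. For example, to force classic CAN at 1 Mbps instead of CAN FD (a sketch, reusing the `motors` dict from the setup above):
```python
bus = DamiaoMotorsBus(
    port="can0",
    motors=motors,
    use_can_fd=False,   # classic CAN instead of CAN FD
    bitrate=1_000_000,  # 1 Mbps nominal bitrate
)
```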
## Motor Configuration
Each motor requires:
- `id`: CAN ID for sending commands
- `motor_type_str`: One of the supported motor types (e.g., `"dm8009"`, `"dm4340"`)
- `recv_id`: CAN ID for receiving responses
OpenArms default IDs follow the pattern: send ID `0x0N`, receive ID `0x1N` where N is the joint number.
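For illustration, that pattern can be generated programmatically (the motor type here is a placeholder; use the types matching your joints):
```python
motors = {
    f"joint_{n}": Motor(id=n, motor_type_str="dm4310", recv_id=0x10 + n)
    for n in range(1, 8)  # send IDs 0x01-0x07, recv IDs 0x11-0x17
}
```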
## Troubleshooting
### No Response from Motors
1. **Check power**
2. **Verify CAN wiring**: Check CAN-H, CAN-L, and GND connections
3. **Check motor IDs**: Use Damiao Debugging Tools to verify/configure IDs
4. **Test CAN interface**: Run `candump can0` to see if messages are being received
5. **Run diagnostics**: `lerobot-setup-can --mode=test --interfaces=can0`
### Motor Timeout Parameter
If motors were configured with timeout=0, they won't respond to commands. Use Damiao Debugging Tools to set a non-zero timeout value.
### Verify CAN FD Status
```bash
ip -d link show can0 | grep fd
```
+6
@@ -1,5 +1,11 @@
# EarthRover Mini Plus
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Earth_Rover_Mini_5_240c9adc-4f9e-44b7-982f-5d1dc24af1d8.png.webp"
alt="EarthRover Mini Plus"
width="70%"
/>
The EarthRover Mini Plus is a fully open source mobile robot that connects through the cloud using the Frodobots SDK. This lets you control the robot and record datasets for training AI models.
## What You Need
+6
@@ -1,5 +1,11 @@
# LeKiwi
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/1740517739083.jpeg"
alt="LeKiwi"
width="70%"
/>
In the steps below, we explain how to assemble the LeKiwi mobile robot.
## Source the parts
+1
@@ -42,6 +42,7 @@ lerobot-eval \
```
- `--env.task` picks the suite (`libero_object`, `libero_spatial`, etc.).
- `--env.task_ids` picks task ids to run (`[0]`, `[1,2,3]`, etc.). Omit this flag (or set it to `null`) to run all tasks in the suite.
- `--eval.batch_size` controls how many environments run in parallel.
- `--eval.n_episodes` sets how many episodes to run in total.
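For example, to evaluate only tasks 1-3 of the spatial suite with two parallel environments (other required flags, such as the policy to evaluate, are elided here):
```bash
lerobot-eval \
  ... \
  --env.task=libero_spatial \
  --env.task_ids="[1,2,3]" \
  --eval.batch_size=2 \
  --eval.n_episodes=10
```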
+197
@@ -0,0 +1,197 @@
## Order and Assemble the parts
First, assemble the OMX hardware following the official assembly guide.
OMX Assembly Guide: https://ai.robotis.com/omx/assembly_guide_omx.html
OMX robots are shipped preconfigured from the factory. Motor IDs, communication parameters, and joint offsets are already set, so no additional motor setup or calibration is required before using LeRobot.
## Install LeRobot 🤗
To install LeRobot, follow our [Installation Guide](./installation)
In addition to these instructions, you need to install the Dynamixel SDK:
```bash
pip install -e ".[dynamixel]"
```
## Connect the robot
To find the port for each bus servo adapter, run this script:
```bash
lerobot-find-port
```
Run this command and, when prompted, disconnect the USB cable from either the leader or follower arm, then press Enter. The output will show `The port of this MotorsBus is [port]`, which identifies the port of the disconnected arm. Repeat for the other arm to identify both ports.
<hfoptions id="find_port">
<hfoption id="Mac">
Example output on macOS:
```
Finding all available ports for the MotorBus.
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
Remove the USB cable from your MotorsBus and press Enter when done.
[...Disconnect corresponding leader or follower arm and press Enter...]
The port of this MotorsBus is /dev/tty.usbmodem575E0032081
Reconnect the USB cable.
```
Here the detected port is `/dev/tty.usbmodem575E0032081`, corresponding to your leader or follower arm.
</hfoption>
<hfoption id="Linux">
On Linux, we strongly recommend using udev rules to assign persistent, human-readable device names to the OMX leader and follower arms. This avoids issues where device names such as `ttyACM0` and `ttyACM1` change when the robot is unplugged, replugged, or the system is rebooted.
#### 1. Find your device serial numbers
You should already have the port names (e.g. `../../ttyACM0`) for the leader and follower from `lerobot-find-port`. To create udev rules, you also need the unique serial number of each OMX device; the easiest way to match ports to serials is to list the devices under:
```bash
ls -l /dev/serial/by-id/
```
You will see output similar to:
```bash
usb-ROBOTIS_OpenRB-150_228BDD7B503059384C2E3120FF0A2B19-if00 -> ../../ttyACM0
usb-ROBOTIS_OpenRB-150_67E1ED68503059384C2E3120FF092234-if00 -> ../../ttyACM1
```
In each line, the serial number is the long string after `usb-ROBOTIS_OpenRB-150_` and before `-if00`.
Follower serial: `228BDD7B503059384C2E3120FF0A2B19`
Leader serial: `67E1ED68503059384C2E3120FF092234`
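If you prefer to extract the serial strings automatically, a small shell sketch (assuming the `usb-ROBOTIS_OpenRB-150_<serial>-if00` naming shown above):
```bash
ls /dev/serial/by-id/ | sed -E 's/^usb-ROBOTIS_OpenRB-150_(.+)-if00$/\1/'
```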
#### 2. Create the udev rule
Create a new udev rule file:
```bash
sudo nano /etc/udev/rules.d/99-omx.rules
```
Paste the following lines, replacing the serial numbers with the values you found above:
```bash
SUBSYSTEM=="tty", ATTRS{idVendor}=="0403", ATTRS{serial}=="228BDD7B503059384C2E3120FF0A2B19", SYMLINK+="omx_follower"
SUBSYSTEM=="tty", ATTRS{idVendor}=="0403", ATTRS{serial}=="67E1ED68503059384C2E3120FF092234", SYMLINK+="omx_leader"
```
Save the file and reload udev rules:
```bash
sudo udevadm control --reload-rules
sudo udevadm trigger
```
Now unplug and replug both devices once.
#### 3. Verify the symlinks
Check that the persistent device names exist:
```bash
ls -l /dev/omx_follower /dev/omx_leader
```
You should see them pointing to ttyACM\* devices:
```bash
/dev/omx_follower -> ttyACM*
/dev/omx_leader -> ttyACM*
```
These names remain stable across reboots and reconnections.
</hfoption>
</hfoptions>
## Teleoperate
After identifying the correct ports, you can directly teleoperate the follower arm using the leader arm.
<hfoptions id="teleoperate">
<hfoption id="Mac">
### Teleoperate without camera
```bash
lerobot-teleoperate \
--robot.type=omx_follower \
--robot.port=<your_follower_port> \
--robot.id=omx_follower_arm \
--teleop.type=omx_leader \
--teleop.port=<your_leader_port> \
--teleop.id=omx_leader_arm
```
During teleoperation, motions of the leader arm are mirrored in real time by the follower arm. Because OMX comes preconfigured, teleoperation can begin immediately, without any calibration steps.
### Teleoperate with camera
You can also enable camera input during teleoperation by providing a camera configuration for the follower arm.
```bash
lerobot-teleoperate \
--robot.type=omx_follower \
--robot.port=<your_follower_port> \
--robot.id=omx_follower_arm \
--robot.cameras="{front: {type: opencv, index_or_path: '/dev/video0', width: 640, height: 480, fps: 30}}" \
--teleop.type=omx_leader \
--teleop.port=<your_leader_port> \
--teleop.id=omx_leader_arm \
--display_data=true
```
When the camera is enabled, the camera stream is displayed in real time and synchronized with the robot state. This setup is useful for visual monitoring and can be reused later for demonstration recording and imitation learning.
</hfoption>
<hfoption id="Linux">
### Teleoperate without camera
```bash
lerobot-teleoperate \
--robot.type=omx_follower \
--robot.port=/dev/omx_follower \
--robot.id=omx_follower_arm \
--teleop.type=omx_leader \
--teleop.port=/dev/omx_leader \
--teleop.id=omx_leader_arm
```
During teleoperation, motions of the leader arm are mirrored in real time by the follower arm. Because OMX comes preconfigured, teleoperation can begin immediately, without any calibration steps.
### Teleoperate with camera
You can also enable camera input during teleoperation by providing a camera configuration for the follower arm.
```bash
lerobot-teleoperate \
--robot.type=omx_follower \
--robot.port=/dev/omx_follower \
--robot.id=omx_follower_arm \
--robot.cameras="{front: {type: opencv, index_or_path: '/dev/video0', width: 640, height: 480, fps: 30}}" \
--teleop.type=omx_leader \
--teleop.port=/dev/omx_leader \
--teleop.id=omx_leader_arm \
--display_data=true
```
When the camera is enabled, the camera stream is displayed in real time and synchronized with the robot state. This setup is useful for visual monitoring and can be reused later for demonstration recording and imitation learning.
</hfoption>
</hfoptions>
Congrats 🎉, your robot is all set to learn a task on its own.
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/robotis).
+13
@@ -1,5 +1,18 @@
# SO-101
<div style="display: flex; align-items: center; gap: 10px;">
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/SO101_Follower.webp"
alt="SO-101"
width="60%"
/>
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/SO101_Leader.webp"
alt="SO-101"
width="60%"
/>
</div>
In the steps below, we explain how to assemble our flagship robot, the SO-101.
## Source the parts
+86
@@ -0,0 +1,86 @@
# Training-Time RTC
Training-Time RTC teaches the model to handle inference delay during training.
It feeds the **ground-truth action prefix** to the model and trains only on the remaining postfix actions.
This keeps chunk transitions smooth without doing any inference-time inpainting.
Based on: [Training-Time Action Conditioning for Efficient Real-Time Chunking](https://arxiv.org/abs/2512.05964).
LeRobot supports this for `pi0`, `pi05` and `smolvla` without changing model parameters.
---
## How It Works
### At Training Time
- Sample a delay `d` per batch element.
- Keep the first `d` action steps as **ground truth** (no noise).
- Add noise only to the postfix actions.
- Set the flow-matching timestep to **1.0** for prefix tokens and normal timesteps for postfix tokens.
- Mask the loss to train only on the postfix (see the sketch below).
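To make the recipe concrete, here is a minimal, framework-agnostic sketch of the prefix/postfix construction (assuming the flow-matching convention where timestep 1.0 corresponds to clean actions; this is illustrative, not LeRobot's actual implementation):
```python
import torch

def rtc_training_inputs(actions: torch.Tensor, max_delay: int):
    """actions: (B, H, D) ground-truth chunk. Returns (x_t, timesteps, loss_mask)."""
    B, H, _ = actions.shape
    d = torch.randint(0, max_delay + 1, (B,))       # per-sample delay (UNIFORM)
    t = torch.rand(B)                               # flow-matching timestep per sample
    noise = torch.randn_like(actions)
    prefix = torch.arange(H)[None, :] < d[:, None]  # (B, H): True on the first d steps
    t_ = t[:, None, None]
    # Postfix interpolates noise and data; the prefix stays ground truth (no noise).
    x_t = torch.where(prefix[..., None], actions, t_ * actions + (1.0 - t_) * noise)
    # Timestep is 1.0 (clean) on the prefix, the sampled t elsewhere.
    timesteps = torch.where(prefix, torch.ones(B, H), t[:, None].expand(B, H))
    loss_mask = ~prefix                             # train only on the postfix
    return x_t, timesteps, loss_mask
```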
### At Inference Time
When `rtc_training_config.enabled=true`, the model uses training-time RTC inference:
- Replace prefix positions in `x_t` with previous chunk's leftover actions.
- Set timestep to **1.0** for prefix positions (sketched below).
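In pseudocode, the prefix replacement during denoising amounts to (illustrative only):
```python
d = inference_delay
x_t[:, :d] = prev_chunk_left_over[:, :d]  # pin prefix to already-committed actions
timesteps[:, :d] = 1.0                    # mark the prefix as clean
```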
---
## Quick Start (CLI)
```bash
lerobot-train \
--policy.type=pi0 \
--dataset.repo_id=your/dataset \
--policy.rtc_training_config.enabled=true \
--policy.rtc_training_config.min_delay=0 \
--policy.rtc_training_config.max_delay=6 \
--policy.rtc_training_config.delay_distribution=UNIFORM
```
---
## Inference with Training-Time RTC
After training with `rtc_training_config`, use the same config at inference. The model will automatically use training-time RTC inference:
```python
policy = PI0Policy.from_pretrained("path/to/trained/model")
# rtc_training_config is loaded from the saved config
actions = policy.predict_action_chunk(
batch,
inference_delay=5, # estimated delay in timesteps
prev_chunk_left_over=previous_actions, # from previous chunk
)
```
---
## Key Parameters
`RTCTrainingConfig` is available on the policy config (`pi0`, `pi05`, `smolvla`, `xvla`):
- **`enabled`**: Toggle training-time RTC (both training and inference).
- **`min_delay` / `max_delay`**: Delay range (inclusive).
- **`delay_distribution`**:
- `UNIFORM`: uniform in `[min_delay, max_delay]`
- `EXP`: exponentially decayed distribution over delays
- **`exp_decay`**: Exponential decay factor for `EXP` sampling (see the sketch below).
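One plausible reading of the `EXP` option (the exact form is an assumption; `exp_decay` is treated here as a per-step geometric decay factor):
```python
import numpy as np

def sample_delay_exp(min_delay: int, max_delay: int, exp_decay: float) -> int:
    delays = np.arange(min_delay, max_delay + 1)
    # Smaller delays get exponentially more weight when exp_decay < 1
    weights = exp_decay ** (delays - min_delay)
    return int(np.random.choice(delays, p=weights / weights.sum()))
```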
---
## Notes and Recommendations
- Start with `min_delay=0` and `max_delay` around your expected worst-case inference delay.
- Use `EXP` if you want more supervision on smaller delays.
---
## Related Docs
- [Real-Time Chunking (Inference-Time RTC)](./rtc)
- [Pi0](./pi0), [Pi0.5](./pi05), [SmolVLA](./smolvla)
+3
@@ -102,6 +102,7 @@ grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
# Motors
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
damiao = ["python-can>=4.2.0,<5.0.0"]
# Robots
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
@@ -203,6 +204,7 @@ lerobot-info="lerobot.scripts.lerobot_info:main"
lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
# ---------------- Tool Configurations ----------------
[tool.setuptools.packages.find]
@@ -278,6 +280,7 @@ default.extend-ignore-identifiers-re = [
"thw",
"inpt",
"ROBOTIS",
"OT_VALUE"
]
# TODO: Uncomment when ready to use
-10
@@ -105,16 +105,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
def observation_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
raise NotImplementedError
@property
def image_observation_delta_indices(self) -> list | None: # type: ignore[type-arg]
"""Return indices for delta image observations only.
Unlike observation_delta_indices which applies to ALL observations,
this only applies to image observations (keys starting with observation.images).
Default returns None. Override in subclass to enable.
"""
return None
@property
@abc.abstractmethod
def action_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
+5
@@ -50,3 +50,8 @@ class RTCAttentionSchedule(str, Enum):
ONES = "ONES"
LINEAR = "LINEAR"
EXP = "EXP"
class RTCTrainingDelayDistribution(str, Enum):
UNIFORM = "UNIFORM"
EXP = "EXP"
+2 -7
@@ -27,7 +27,7 @@ from lerobot.datasets.lerobot_dataset import (
)
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.transforms import ImageTransforms
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_PREFIX, REWARD
from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
IMAGENET_STATS = {
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
@@ -59,12 +59,7 @@ def resolve_delta_timestamps(
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
if key == ACTION and cfg.action_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
# Check for image-specific delta indices first (e.g., for video encoding)
if key.startswith(OBS_IMAGES) and cfg.image_observation_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.image_observation_delta_indices]
# Fall back to generic observation delta indices for all observations
elif key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
if len(delta_timestamps) == 0:
+5 -4
@@ -260,6 +260,7 @@ class HILSerlRobotEnvConfig(EnvConfig):
@dataclass
class LiberoEnv(EnvConfig):
task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
task_ids: list[int] | None = None
fps: int = 30
episode_length: int | None = None
obs_type: str = "pixels_agent_pos"
@@ -338,10 +339,10 @@ class LiberoEnv(EnvConfig):
@property
def gym_kwargs(self) -> dict:
return {
"obs_type": self.obs_type,
"render_mode": self.render_mode,
}
kwargs: dict[str, Any] = {"obs_type": self.obs_type, "render_mode": self.render_mode}
if self.task_ids is not None:
kwargs["task_ids"] = self.task_ids
return kwargs
@EnvConfig.register_subclass("metaworld")
+5 -1
View File
@@ -14,4 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .motors_bus import Motor, MotorCalibration, MotorNormMode, MotorsBus
from .motors_bus import (
Motor,
MotorCalibration,
MotorNormMode,
)
+1 -1
@@ -18,7 +18,7 @@ from dataclasses import dataclass
os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1"
from lerobot.motors import MotorCalibration, MotorsBus
from .motors_bus import MotorCalibration, MotorsBus
BAR_LEN, BAR_THICKNESS = 450, 8
HANDLE_R = 10
+18
@@ -0,0 +1,18 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .damiao import DamiaoMotorsBus
from .tables import *
+808
@@ -0,0 +1,808 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Portions of this file are derived from DM_Control_Python by cmjang.
# Licensed under the MIT License; see `LICENSE` for the full text:
# https://github.com/cmjang/DM_Control_Python
import logging
import time
from contextlib import contextmanager
from copy import deepcopy
from functools import cached_property
from typing import TYPE_CHECKING, Any, TypedDict
from lerobot.utils.import_utils import _can_available
if TYPE_CHECKING or _can_available:
import can
else:
    # Minimal stub so module-level annotations still resolve when python-can is absent.
    from types import SimpleNamespace

    can = SimpleNamespace(Message=object, interface=None)
import numpy as np
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import enter_pressed, move_cursor_up
from ..motors_bus import Motor, MotorCalibration, MotorsBusBase, NameOrID, Value
from .tables import (
AVAILABLE_BAUDRATES,
CAN_CMD_DISABLE,
CAN_CMD_ENABLE,
CAN_CMD_REFRESH,
CAN_CMD_SET_ZERO,
CAN_PARAM_ID,
DEFAULT_BAUDRATE,
DEFAULT_TIMEOUT_MS,
MIT_KD_RANGE,
MIT_KP_RANGE,
MOTOR_LIMIT_PARAMS,
MotorType,
)
logger = logging.getLogger(__name__)
LONG_TIMEOUT_SEC = 0.1
MEDIUM_TIMEOUT_SEC = 0.01
SHORT_TIMEOUT_SEC = 0.001
PRECISE_TIMEOUT_SEC = 0.0001
class MotorState(TypedDict):
position: float
velocity: float
torque: float
temp_mos: float
temp_rotor: float
class DamiaoMotorsBus(MotorsBusBase):
"""
The Damiao implementation for a MotorsBus using CAN bus communication.
This class uses python-can for CAN bus communication with Damiao motors.
For more info, see:
- python-can documentation: https://python-can.readthedocs.io/en/stable/
- Seedstudio documentation: https://wiki.seeedstudio.com/damiao_series/
- DM_Control_Python repo: https://github.com/cmjang/DM_Control_Python
"""
# CAN-specific settings
available_baudrates = deepcopy(AVAILABLE_BAUDRATES)
default_baudrate = DEFAULT_BAUDRATE
default_timeout = DEFAULT_TIMEOUT_MS
def __init__(
self,
port: str,
motors: dict[str, Motor],
calibration: dict[str, MotorCalibration] | None = None,
can_interface: str = "auto",
use_can_fd: bool = True,
bitrate: int = 1000000,
data_bitrate: int | None = 5000000,
):
"""
Initialize the Damiao motors bus.
Args:
port: CAN interface name (e.g., "can0" for Linux, "/dev/cu.usbmodem*" for macOS)
motors: Dictionary mapping motor names to Motor objects
calibration: Optional calibration data
can_interface: CAN interface type - "auto" (default), "socketcan" (Linux), or "slcan" (macOS/serial)
use_can_fd: Whether to use CAN FD mode (default: True for OpenArms)
bitrate: Nominal bitrate in bps (default: 1000000 = 1 Mbps)
data_bitrate: Data bitrate for CAN FD in bps (default: 5000000 = 5 Mbps), ignored if use_can_fd is False
"""
super().__init__(port, motors, calibration)
self.port = port
self.can_interface = can_interface
self.use_can_fd = use_can_fd
self.bitrate = bitrate
self.data_bitrate = data_bitrate
self.canbus: can.interface.Bus | None = None
self._is_connected = False
# Map motor names to CAN IDs
self._motor_can_ids: dict[str, int] = {}
self._recv_id_to_motor: dict[int, str] = {}
self._motor_types: dict[str, MotorType] = {}
for name, motor in self.motors.items():
if motor.motor_type_str is None:
raise ValueError(f"Motor '{name}' is missing required 'motor_type'")
self._motor_types[name] = getattr(MotorType, motor.motor_type_str.upper().replace("-", "_"))
# Map recv_id to motor name for filtering responses
if motor.recv_id is not None:
self._recv_id_to_motor[motor.recv_id] = name
# State cache for handling packet drops safely
self._last_known_states: dict[str, MotorState] = {
name: {
"position": 0.0,
"velocity": 0.0,
"torque": 0.0,
"temp_mos": 0.0,
"temp_rotor": 0.0,
}
for name in self.motors
}
# Dynamic gains storage
# Defaults: Kp=10.0 (Stiffness), Kd=0.5 (Damping)
self._gains: dict[str, dict[str, float]] = {name: {"kp": 10.0, "kd": 0.5} for name in self.motors}
@property
def is_connected(self) -> bool:
"""Check if the CAN bus is connected."""
return self._is_connected and self.canbus is not None
def connect(self, handshake: bool = True) -> None:
"""
Open the CAN bus and initialize communication.
Args:
handshake: If True, ping all motors to verify they're present
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(
f"{self.__class__.__name__}('{self.port}') is already connected."
)
try:
# Auto-detect interface type based on port name
if self.can_interface == "auto":
if self.port.startswith("/dev/"):
self.can_interface = "slcan"
logger.info(f"Auto-detected slcan interface for port {self.port}")
else:
self.can_interface = "socketcan"
logger.info(f"Auto-detected socketcan interface for port {self.port}")
# Connect to CAN bus
kwargs = {
"channel": self.port,
"bitrate": self.bitrate,
"interface": self.can_interface,
}
if self.can_interface == "socketcan" and self.use_can_fd and self.data_bitrate is not None:
kwargs.update({"data_bitrate": self.data_bitrate, "fd": True})
logger.info(
f"Connected to {self.port} with CAN FD (bitrate={self.bitrate}, data_bitrate={self.data_bitrate})"
)
else:
logger.info(f"Connected to {self.port} with {self.can_interface} (bitrate={self.bitrate})")
self.canbus = can.interface.Bus(**kwargs)
self._is_connected = True
if handshake:
self._handshake()
logger.debug(f"{self.__class__.__name__} connected via {self.can_interface}.")
except Exception as e:
self._is_connected = False
raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e
def _handshake(self) -> None:
"""
Verify all motors are present and populate initial state cache.
Raises ConnectionError if any motor fails to respond.
"""
logger.info("Starting handshake with motors...")
missing_motors = []
for motor_name in self.motors:
msg = self._refresh_motor(motor_name)
if msg is None:
missing_motors.append(motor_name)
else:
self._process_response(motor_name, msg)
time.sleep(MEDIUM_TIMEOUT_SEC)
if missing_motors:
raise ConnectionError(
f"Handshake failed. The following motors did not respond: {missing_motors}. "
"Check power (24V) and CAN wiring."
)
logger.info("Handshake successful. All motors ready.")
def disconnect(self, disable_torque: bool = True) -> None:
"""
Close the CAN bus connection.
Args:
disable_torque: If True, disable torque on all motors before disconnecting
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self.__class__.__name__}('{self.port}') is not connected.")
if disable_torque:
try:
self.disable_torque()
except Exception as e:
logger.warning(f"Failed to disable torque during disconnect: {e}")
if self.canbus:
self.canbus.shutdown()
self.canbus = None
self._is_connected = False
logger.debug(f"{self.__class__.__name__} disconnected.")
def configure_motors(self) -> None:
"""Configure all motors with default settings."""
# Damiao motors don't require much configuration in MIT mode
# Just ensure they're enabled
for motor in self.motors:
self._send_simple_command(motor, CAN_CMD_ENABLE)
time.sleep(MEDIUM_TIMEOUT_SEC)
def _send_simple_command(self, motor: NameOrID, command_byte: int) -> None:
"""Helper to send simple 8-byte commands (Enable, Disable, Zero)."""
motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor)
recv_id = self._get_motor_recv_id(motor)
data = [0xFF] * 7 + [command_byte]
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self.canbus.send(msg)
if msg := self._recv_motor_response(expected_recv_id=recv_id):
self._process_response(motor_name, msg)
else:
logger.debug(f"No response from {motor_name} after command 0x{command_byte:02X}")
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Enable torque on selected motors."""
target_motors = self._get_motors_list(motors)
for motor in target_motors:
for _ in range(num_retry + 1):
try:
self._send_simple_command(motor, CAN_CMD_ENABLE)
break
except Exception as e:
if _ == num_retry:
raise e
time.sleep(MEDIUM_TIMEOUT_SEC)
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Disable torque on selected motors."""
target_motors = self._get_motors_list(motors)
for motor in target_motors:
for _ in range(num_retry + 1):
try:
self._send_simple_command(motor, CAN_CMD_DISABLE)
break
except Exception as e:
if _ == num_retry:
raise e
time.sleep(MEDIUM_TIMEOUT_SEC)
@contextmanager
def torque_disabled(self, motors: str | list[str] | None = None):
"""
Context manager that guarantees torque is re-enabled.
This helper is useful to temporarily disable torque when configuring motors.
"""
self.disable_torque(motors)
try:
yield
finally:
self.enable_torque(motors)
def set_zero_position(self, motors: str | list[str] | None = None) -> None:
"""Set current position as zero for selected motors."""
target_motors = self._get_motors_list(motors)
for motor in target_motors:
self._send_simple_command(motor, CAN_CMD_SET_ZERO)
time.sleep(MEDIUM_TIMEOUT_SEC)
def _refresh_motor(self, motor: NameOrID) -> can.Message | None:
"""Refresh motor status and return the response."""
motor_id = self._get_motor_id(motor)
recv_id = self._get_motor_recv_id(motor)
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
self.canbus.send(msg)
return self._recv_motor_response(expected_recv_id=recv_id)
def _recv_motor_response(
self, expected_recv_id: int | None = None, timeout: float = 0.001
) -> can.Message | None:
"""
Receive a response from a motor.
Args:
expected_recv_id: If provided, only return messages from this CAN ID
timeout: Timeout in seconds (default: 1ms for high-speed operation)
Returns:
CAN message if received, None otherwise
"""
try:
start_time = time.time()
messages_seen = []
while time.time() - start_time < timeout:
msg = self.canbus.recv(timeout=PRECISE_TIMEOUT_SEC)
if msg:
messages_seen.append(f"0x{msg.arbitration_id:02X}")
if expected_recv_id is None or msg.arbitration_id == expected_recv_id:
return msg
logger.debug(
f"Ignoring message from 0x{msg.arbitration_id:02X}, expected 0x{expected_recv_id:02X}"
)
if logger.isEnabledFor(logging.DEBUG):
if messages_seen:
logger.debug(
f"Received {len(messages_seen)} msgs from {set(messages_seen)}, expected 0x{expected_recv_id:02X}"
)
else:
logger.debug(f"No CAN messages received (expected 0x{expected_recv_id:02X})")
except Exception as e:
logger.debug(f"Failed to receive CAN message: {e}")
return None
def _recv_all_responses(
self, expected_recv_ids: list[int], timeout: float = 0.002
) -> dict[int, can.Message]:
"""
Efficiently receive responses from multiple motors at once.
Uses the OpenArms pattern: collect all available messages within timeout.
Args:
expected_recv_ids: List of CAN IDs we expect responses from
timeout: Total timeout in seconds (default: 2ms)
Returns:
Dictionary mapping recv_id to CAN message
"""
responses = {}
expected_set = set(expected_recv_ids)
start_time = time.time()
try:
while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout:
# 100us poll timeout
msg = self.canbus.recv(timeout=PRECISE_TIMEOUT_SEC)
if msg and msg.arbitration_id in expected_set:
responses[msg.arbitration_id] = msg
if len(responses) == len(expected_recv_ids):
break
except Exception as e:
logger.debug(f"Error receiving responses: {e}")
return responses
def _encode_mit_packet(
self,
motor_type: MotorType,
kp: float,
kd: float,
position_degrees: float,
velocity_deg_per_sec: float,
torque: float,
) -> list[int]:
"""Helper to encode control parameters into 8 bytes for MIT mode."""
# Convert degrees to radians
position_rad = np.radians(position_degrees)
velocity_rad_per_sec = np.radians(velocity_deg_per_sec)
# Get motor limits
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
# Encode parameters
kp_uint = self._float_to_uint(kp, *MIT_KP_RANGE, 12)
kd_uint = self._float_to_uint(kd, *MIT_KD_RANGE, 12)
q_uint = self._float_to_uint(position_rad, -pmax, pmax, 16)
dq_uint = self._float_to_uint(velocity_rad_per_sec, -vmax, vmax, 12)
tau_uint = self._float_to_uint(torque, -tmax, tmax, 12)
# Pack data
data = [0] * 8
data[0] = (q_uint >> 8) & 0xFF
data[1] = q_uint & 0xFF
data[2] = dq_uint >> 4
data[3] = ((dq_uint & 0xF) << 4) | ((kp_uint >> 8) & 0xF)
data[4] = kp_uint & 0xFF
data[5] = kd_uint >> 4
data[6] = ((kd_uint & 0xF) << 4) | ((tau_uint >> 8) & 0xF)
data[7] = tau_uint & 0xFF
return data
def _mit_control(
self,
motor: NameOrID,
kp: float,
kd: float,
position_degrees: float,
velocity_deg_per_sec: float,
torque: float,
) -> None:
"""Send MIT control command to a motor."""
motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor)
motor_type = self._motor_types[motor_name]
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self.canbus.send(msg)
recv_id = self._get_motor_recv_id(motor)
if msg := self._recv_motor_response(expected_recv_id=recv_id):
self._process_response(motor_name, msg)
else:
logger.debug(f"No response from {motor_name} after MIT control command")
def _mit_control_batch(
self,
commands: dict[NameOrID, tuple[float, float, float, float, float]],
) -> None:
"""
Send MIT control commands to multiple motors in batch.
Sends all commands first, then collects responses.
Args:
commands: Dict mapping motor name/ID to (kp, kd, position_deg, velocity_deg/s, torque)
Example: {'joint_1': (10.0, 0.5, 45.0, 0.0, 0.0), ...}
"""
if not commands:
return
recv_id_to_motor: dict[int, str] = {}
# Step 1: Send all MIT control commands
for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items():
motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor)
motor_type = self._motor_types[motor_name]
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self.canbus.send(msg)
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
# Step 2: Collect responses and update state cache
responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=SHORT_TIMEOUT_SEC)
for recv_id, motor_name in recv_id_to_motor.items():
if msg := responses.get(recv_id):
self._process_response(motor_name, msg)
def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
"""Convert float to unsigned integer for CAN transmission."""
x = max(x_min, min(x_max, x)) # Clamp to range
span = x_max - x_min
data_norm = (x - x_min) / span
return int(data_norm * ((1 << bits) - 1))
def _uint_to_float(self, x: int, x_min: float, x_max: float, bits: int) -> float:
"""Convert unsigned integer from CAN to float."""
span = x_max - x_min
data_norm = float(x) / ((1 << bits) - 1)
return data_norm * span + x_min
def _decode_motor_state(
self, data: bytearray | bytes, motor_type: MotorType
) -> tuple[float, float, float, int, int]:
"""
Decode motor state from CAN data.
Returns: (position_deg, velocity_deg_s, torque, temp_mos, temp_rotor)
"""
if len(data) < 8:
raise ValueError("Invalid motor state data")
# Extract encoded values
q_uint = (data[1] << 8) | data[2]
dq_uint = (data[3] << 4) | (data[4] >> 4)
tau_uint = ((data[4] & 0x0F) << 8) | data[5]
t_mos = data[6]
t_rotor = data[7]
# Get motor limits
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
# Decode to physical values
position_rad = self._uint_to_float(q_uint, -pmax, pmax, 16)
velocity_rad_per_sec = self._uint_to_float(dq_uint, -vmax, vmax, 12)
torque = self._uint_to_float(tau_uint, -tmax, tmax, 12)
return np.degrees(position_rad), np.degrees(velocity_rad_per_sec), torque, t_mos, t_rotor
def _process_response(self, motor: str, msg: can.Message) -> None:
"""Decode a message and update the motor state cache."""
try:
motor_type = self._motor_types[motor]
pos, vel, torque, t_mos, t_rotor = self._decode_motor_state(msg.data, motor_type)
self._last_known_states[motor] = {
"position": pos,
"velocity": vel,
"torque": torque,
"temp_mos": float(t_mos),
"temp_rotor": float(t_rotor),
}
except Exception as e:
logger.warning(f"Failed to decode response from {motor}: {e}")
def read(self, data_name: str, motor: str) -> Value:
"""Read a value from a single motor. Positions are always in degrees."""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Refresh motor to get latest state
msg = self._refresh_motor(motor)
if msg is None:
motor_id = self._get_motor_id(motor)
recv_id = self._get_motor_recv_id(motor)
raise ConnectionError(
f"No response from motor '{motor}' (send ID: 0x{motor_id:02X}, recv ID: 0x{recv_id:02X}). "
f"Check that: 1) Motor is powered (24V), 2) CAN wiring is correct, "
f"3) Motor IDs are configured correctly using Damiao Debugging Tools"
)
self._process_response(motor, msg)
return self._get_cached_value(motor, data_name)
def _get_cached_value(self, motor: str, data_name: str) -> Value:
"""Retrieve a specific value from the cache."""
state = self._last_known_states[motor]
mapping: dict[str, Any] = {
"Present_Position": state["position"],
"Present_Velocity": state["velocity"],
"Present_Torque": state["torque"],
"Temperature_MOS": state["temp_mos"],
"Temperature_Rotor": state["temp_rotor"],
}
if data_name not in mapping:
raise ValueError(f"Unknown data_name: {data_name}")
return mapping[data_name]
def write(
self,
data_name: str,
motor: str,
value: Value,
) -> None:
"""
Write a value to a single motor. Positions are always in degrees.
Can write 'Goal_Position', 'Kp', or 'Kd'.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if data_name in ("Kp", "Kd"):
self._gains[motor][data_name.lower()] = float(value)
elif data_name == "Goal_Position":
kp = self._gains[motor]["kp"]
kd = self._gains[motor]["kd"]
self._mit_control(motor, kp, kd, float(value), 0.0, 0.0)
else:
raise ValueError(f"Writing {data_name} not supported in MIT mode")
def sync_read(
self,
data_name: str,
motors: str | list[str] | None = None,
) -> dict[str, Value]:
"""
Read the same value from multiple motors simultaneously.
"""
target_motors = self._get_motors_list(motors)
self._batch_refresh(target_motors)
result = {}
for motor in target_motors:
result[motor] = self._get_cached_value(motor, data_name)
return result
def sync_read_all_states(
self,
motors: str | list[str] | None = None,
*,
num_retry: int = 0,
) -> dict[str, MotorState]:
"""
Read ALL motor states (position, velocity, torque) from multiple motors in ONE refresh cycle.
Returns:
Dictionary mapping motor names to state dicts with keys: 'position', 'velocity', 'torque'
Example: {'joint_1': {'position': 45.2, 'velocity': 1.3, 'torque': 0.5}, ...}
"""
target_motors = self._get_motors_list(motors)
self._batch_refresh(target_motors)
result = {}
for motor in target_motors:
result[motor] = self._last_known_states[motor].copy()
return result
def _batch_refresh(self, motors: list[str]) -> None:
"""Internal helper to refresh a list of motors and update cache."""
# Send refresh commands
for motor in motors:
motor_id = self._get_motor_id(motor)
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
self.canbus.send(msg)
            # Optionally add a small delay between sends to reduce bus congestion, e.g.:
            # precise_sleep(PRECISE_TIMEOUT_SEC)
# Collect responses
expected_recv_ids = [self._get_motor_recv_id(m) for m in motors]
responses = self._recv_all_responses(expected_recv_ids, timeout=MEDIUM_TIMEOUT_SEC)
# Update cache
for motor in motors:
recv_id = self._get_motor_recv_id(motor)
msg = responses.get(recv_id)
if msg:
self._process_response(motor, msg)
else:
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
"""
Write values to multiple motors simultaneously. Positions are always in degrees.
"""
if data_name in ("Kp", "Kd"):
key = data_name.lower()
for motor, val in values.items():
self._gains[motor][key] = float(val)
elif data_name == "Goal_Position":
# Step 1: Send all MIT control commands
recv_id_to_motor: dict[int, str] = {}
for motor, value_degrees in values.items():
motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor)
motor_type = self._motor_types[motor_name]
kp = self._gains[motor]["kp"]
kd = self._gains[motor]["kd"]
data = self._encode_mit_packet(motor_type, kp, kd, float(value_degrees), 0.0, 0.0)
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self.canbus.send(msg)
precise_sleep(PRECISE_TIMEOUT_SEC)
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
# Step 2: Collect responses and update state cache
responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=MEDIUM_TIMEOUT_SEC)
for recv_id, motor_name in recv_id_to_motor.items():
if msg := responses.get(recv_id):
self._process_response(motor_name, msg)
else:
# Fall back to individual writes
for motor, value in values.items():
self.write(data_name, motor, value)
def read_calibration(self) -> dict[str, MotorCalibration]:
"""Read calibration data from motors."""
# Damiao motors don't store calibration internally
# Return existing calibration or empty dict
return self.calibration if self.calibration else {}
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
"""Write calibration data to motors."""
# Damiao motors don't store calibration internally
# Just cache it in memory
if cache:
self.calibration = calibration_dict
def record_ranges_of_motion(
self,
motors: NameOrID | list[NameOrID] | None = None,
display_values: bool = True,
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
"""
Interactively record the min/max values of each motor in degrees.
Move the joints by hand (with torque disabled) while the method streams live positions.
Press Enter to finish.
"""
target_motors = self._get_motors_list(motors)
self.disable_torque(target_motors)
time.sleep(LONG_TIMEOUT_SEC)
start_positions = self.sync_read("Present_Position", target_motors)
mins = start_positions.copy()
maxes = start_positions.copy()
print("\nMove joints through their full range of motion. Press ENTER when done.")
user_pressed_enter = False
while not user_pressed_enter:
positions = self.sync_read("Present_Position", target_motors)
for motor in target_motors:
if motor in positions:
mins[motor] = min(positions[motor], mins.get(motor, positions[motor]))
maxes[motor] = max(positions[motor], maxes.get(motor, positions[motor]))
if display_values:
print("\n" + "=" * 50)
print(f"{'MOTOR':<20} | {'MIN (deg)':>12} | {'POS (deg)':>12} | {'MAX (deg)':>12}")
print("-" * 50)
for motor in target_motors:
if motor in positions:
print(
f"{motor:<20} | {mins[motor]:>12.1f} | {positions[motor]:>12.1f} | {maxes[motor]:>12.1f}"
)
if enter_pressed():
user_pressed_enter = True
if display_values and not user_pressed_enter:
move_cursor_up(len(target_motors) + 4)
time.sleep(LONG_TIMEOUT_SEC)
self.enable_torque(target_motors)
for motor in target_motors:
if (motor in mins) and (motor in maxes) and (int(abs(maxes[motor] - mins[motor])) < 5):
raise ValueError(f"Motor {motor} has insufficient range of motion (< 5 degrees)")
return mins, maxes
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
"""Convert motor specification to list of motor names."""
if motors is None:
return list(self.motors.keys())
elif isinstance(motors, str):
return [motors]
elif isinstance(motors, list):
return motors
else:
raise TypeError(f"Invalid motors type: {type(motors)}")
def _get_motor_id(self, motor: NameOrID) -> int:
"""Get CAN ID for a motor."""
if isinstance(motor, str):
if motor in self.motors:
return self.motors[motor].id
else:
raise ValueError(f"Unknown motor: {motor}")
else:
return motor
def _get_motor_name(self, motor: NameOrID) -> str:
"""Get motor name from name or ID."""
if isinstance(motor, str):
return motor
else:
for name, m in self.motors.items():
if m.id == motor:
return name
raise ValueError(f"Unknown motor ID: {motor}")
def _get_motor_recv_id(self, motor: NameOrID) -> int:
"""Get motor recv_id from name or ID."""
motor_name = self._get_motor_name(motor)
motor_obj = self.motors.get(motor_name)
if motor_obj and motor_obj.recv_id is not None:
return motor_obj.recv_id
else:
raise ValueError(f"Motor {motor_obj} doesn't have a valid recv_id (None).")
@cached_property
def is_calibrated(self) -> bool:
"""Check if motors are calibrated."""
return bool(self.calibration)
+209
@@ -0,0 +1,209 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration tables for Damiao motors."""
from enum import IntEnum
# Motor type definitions
class MotorType(IntEnum):
DM3507 = 0
DM4310 = 1
DM4310_48V = 2
DM4340 = 3
DM4340_48V = 4
DM6006 = 5
DM8006 = 6
DM8009 = 7
DM10010L = 8
DM10010 = 9
DMH3510 = 10
DMH6215 = 11
DMG6220 = 12
# Control modes
class ControlMode(IntEnum):
MIT = 1
POS_VEL = 2
VEL = 3
TORQUE_POS = 4
# Motor variable IDs (RID)
class MotorVariable(IntEnum):
UV_VALUE = 0
KT_VALUE = 1
OT_VALUE = 2
OC_VALUE = 3
ACC = 4
DEC = 5
MAX_SPD = 6
MST_ID = 7
ESC_ID = 8
TIMEOUT = 9
CTRL_MODE = 10
DAMP = 11
INERTIA = 12
HW_VER = 13
SW_VER = 14
SN = 15
NPP = 16
RS = 17
LS = 18
FLUX = 19
GR = 20
PMAX = 21
VMAX = 22
TMAX = 23
I_BW = 24
KP_ASR = 25
KI_ASR = 26
KP_APR = 27
KI_APR = 28
OV_VALUE = 29
GREF = 30
DETA = 31
V_BW = 32
IQ_C1 = 33
VL_C1 = 34
CAN_BR = 35
SUB_VER = 36
U_OFF = 50
V_OFF = 51
K1 = 52
K2 = 53
M_OFF = 54
DIR = 55
P_M = 80
XOUT = 81
# Motor limit parameters [PMAX, VMAX, TMAX]
# PMAX: Maximum position (rad)
# VMAX: Maximum velocity (rad/s)
# TMAX: Maximum torque (N·m)
MOTOR_LIMIT_PARAMS = {
MotorType.DM3507: (12.5, 30, 10),
MotorType.DM4310: (12.5, 30, 10),
MotorType.DM4310_48V: (12.5, 50, 10),
MotorType.DM4340: (12.5, 8, 28),
MotorType.DM4340_48V: (12.5, 10, 28),
MotorType.DM6006: (12.5, 45, 20),
MotorType.DM8006: (12.5, 45, 40),
MotorType.DM8009: (12.5, 45, 54),
MotorType.DM10010L: (12.5, 25, 200),
MotorType.DM10010: (12.5, 20, 200),
MotorType.DMH3510: (12.5, 280, 1),
MotorType.DMH6215: (12.5, 45, 10),
MotorType.DMG6220: (12.5, 45, 10),
}
# Motor model names
MODEL_NAMES = {
MotorType.DM3507: "dm3507",
MotorType.DM4310: "dm4310",
MotorType.DM4310_48V: "dm4310_48v",
MotorType.DM4340: "dm4340",
MotorType.DM4340_48V: "dm4340_48v",
MotorType.DM6006: "dm6006",
MotorType.DM8006: "dm8006",
MotorType.DM8009: "dm8009",
MotorType.DM10010L: "dm10010l",
MotorType.DM10010: "dm10010",
MotorType.DMH3510: "dmh3510",
MotorType.DMH6215: "dmh6215",
MotorType.DMG6220: "dmg6220",
}
# Motor resolution table (encoder counts per revolution)
MODEL_RESOLUTION = {
"dm3507": 65536,
"dm4310": 65536,
"dm4310_48v": 65536,
"dm4340": 65536,
"dm4340_48v": 65536,
"dm6006": 65536,
"dm8006": 65536,
"dm8009": 65536,
"dm10010l": 65536,
"dm10010": 65536,
"dmh3510": 65536,
"dmh6215": 65536,
"dmg6220": 65536,
}
# CAN baudrates supported by Damiao motors
AVAILABLE_BAUDRATES = [
125000, # 0: 125 kbps
200000, # 1: 200 kbps
250000, # 2: 250 kbps
500000, # 3: 500 kbps
1000000, # 4: 1 mbps (default for OpenArms)
2000000, # 5: 2 mbps
2500000, # 6: 2.5 mbps
3200000, # 7: 3.2 mbps
4000000, # 8: 4 mbps
5000000, # 9: 5 mbps
]
DEFAULT_BAUDRATE = 1000000 # 1 Mbps is standard for OpenArms
# Default timeout in milliseconds
DEFAULT_TIMEOUT_MS = 1000
# OpenArms specific configurations
# Based on: https://docs.openarm.dev/software/setup/configure-test
# OpenArms has 7 DOF per arm (14 total for dual arm)
OPENARMS_ARM_MOTOR_IDS = {
"joint_1": {"send": 0x01, "recv": 0x11}, # J1 - Shoulder pan
"joint_2": {"send": 0x02, "recv": 0x12}, # J2 - Shoulder lift
"joint_3": {"send": 0x03, "recv": 0x13}, # J3 - Elbow flex
"joint_4": {"send": 0x04, "recv": 0x14}, # J4 - Wrist flex
"joint_5": {"send": 0x05, "recv": 0x15}, # J5 - Wrist roll
"joint_6": {"send": 0x06, "recv": 0x16}, # J6 - Wrist pitch
"joint_7": {"send": 0x07, "recv": 0x17}, # J7 - Wrist rotation
}
OPENARMS_GRIPPER_MOTOR_IDS = {
"gripper": {"send": 0x08, "recv": 0x18}, # J8 - Gripper
}
# Default motor types for OpenArms
OPENARMS_DEFAULT_MOTOR_TYPES = {
"joint_1": MotorType.DM8009, # Shoulder pan - high torque
"joint_2": MotorType.DM8009, # Shoulder lift - high torque
"joint_3": MotorType.DM4340, # Shoulder rotation
"joint_4": MotorType.DM4340, # Elbow flex
"joint_5": MotorType.DM4310, # Wrist roll
"joint_6": MotorType.DM4310, # Wrist pitch
"joint_7": MotorType.DM4310, # Wrist rotation
"gripper": MotorType.DM4310, # Gripper
}
# MIT control parameter ranges
MIT_KP_RANGE = (0.0, 500.0)
MIT_KD_RANGE = (0.0, 5.0)
# CAN frame command IDs
CAN_CMD_ENABLE = 0xFC
CAN_CMD_DISABLE = 0xFD
CAN_CMD_SET_ZERO = 0xFE
CAN_CMD_REFRESH = 0xCC
CAN_CMD_QUERY_PARAM = 0x33
CAN_CMD_WRITE_PARAM = 0x55
CAN_CMD_SAVE_PARAM = 0xAA
# CAN ID for parameter operations
CAN_PARAM_ID = 0x7FF
+5 -6
@@ -22,9 +22,8 @@ import logging
from copy import deepcopy
from enum import Enum
from lerobot.motors.encoding_utils import decode_twos_complement, encode_twos_complement
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
from ..encoding_utils import decode_twos_complement, encode_twos_complement
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
from .tables import (
AVAILABLE_BAUDRATES,
MODEL_BAUDRATE_TABLE,
@@ -100,7 +99,7 @@ def _split_into_byte_chunks(value: int, length: int) -> list[int]:
return data
class DynamixelMotorsBus(MotorsBus):
class DynamixelMotorsBus(SerialMotorsBus):
"""
The Dynamixel implementation for a MotorsBus. It relies on the python dynamixel sdk to communicate with
the motors. For more info, see the Dynamixel SDK Documentation:
@@ -203,9 +202,9 @@ class DynamixelMotorsBus(MotorsBus):
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None:
def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None:
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry)
self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry)
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
+6 -7
@@ -17,9 +17,8 @@ from copy import deepcopy
from enum import Enum
from pprint import pformat
from lerobot.motors.encoding_utils import decode_sign_magnitude, encode_sign_magnitude
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
from ..encoding_utils import decode_sign_magnitude, encode_sign_magnitude
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
from .tables import (
FIRMWARE_MAJOR_VERSION,
FIRMWARE_MINOR_VERSION,
@@ -96,7 +95,7 @@ def patch_setPacketTimeout(self, packet_length): # noqa: N802
self.packet_timeout = (self.tx_time_per_byte * packet_length) + (self.tx_time_per_byte * 3.0) + 50
class FeetechMotorsBus(MotorsBus):
class FeetechMotorsBus(SerialMotorsBus):
"""
The FeetechMotorsBus class allows to efficiently read and write to the attached motors. It relies on the
python feetech sdk to communicate with the motors, which is itself based on the dynamixel sdk.
@@ -298,11 +297,11 @@ class FeetechMotorsBus(MotorsBus):
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
self.write("Lock", motor, 0, num_retry=num_retry)
def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None:
def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None:
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry)
self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry)
addr, length = get_address(self.model_ctrl_table, model, "Lock")
self._write(addr, length, motor_id, 0, num_retry=num_retry)
self._write(addr, length, motor, 0, num_retry=num_retry)
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
+91 -10
@@ -19,6 +19,8 @@
# TODO(aliberts): Add block noqa when feature below is available
# https://github.com/astral-sh/ruff/issues/3711
from __future__ import annotations
import abc
import logging
from contextlib import contextmanager
@@ -41,6 +43,81 @@ Value: TypeAlias = int | float
logger = logging.getLogger(__name__)
class MotorsBusBase(abc.ABC):
"""
Base class for all motor bus implementations.
This is a minimal interface that all motor buses must implement, regardless of their
communication protocol (serial, CAN, etc.).
"""
def __init__(
self,
port: str,
motors: dict[str, Motor],
calibration: dict[str, MotorCalibration] | None = None,
):
self.port = port
self.motors = motors
self.calibration = calibration if calibration else {}
@abc.abstractmethod
def connect(self, handshake: bool = True) -> None:
"""Establish connection to the motors."""
pass
@abc.abstractmethod
def disconnect(self, disable_torque: bool = True) -> None:
"""Disconnect from the motors."""
pass
@property
@abc.abstractmethod
def is_connected(self) -> bool:
"""Check if connected to the motors."""
pass
@abc.abstractmethod
def read(self, data_name: str, motor: str) -> Value:
"""Read a value from a single motor."""
pass
@abc.abstractmethod
def write(self, data_name: str, motor: str, value: Value) -> None:
"""Write a value to a single motor."""
pass
@abc.abstractmethod
def sync_read(self, data_name: str, motors: str | list[str] | None = None) -> dict[str, Value]:
"""Read a value from multiple motors."""
pass
@abc.abstractmethod
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
"""Write values to multiple motors."""
pass
@abc.abstractmethod
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Enable torque on selected motors."""
pass
@abc.abstractmethod
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Disable torque on selected motors."""
pass
@abc.abstractmethod
def read_calibration(self) -> dict[str, MotorCalibration]:
"""Read calibration parameters from the motors."""
pass
@abc.abstractmethod
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
"""Write calibration parameters to the motors."""
pass
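The interface above is deliberately protocol-agnostic. As a rough illustration of the contract (not the shipped Damiao/CAN implementation), a minimal in-memory subclass could satisfy it as follows; everything except the imports is hypothetical:

```python
from lerobot.motors.motors_bus import MotorCalibration, MotorsBusBase, Value


class InMemoryMotorsBus(MotorsBusBase):
    """Hypothetical stand-in bus, used only to illustrate the base interface."""

    def __init__(self, port, motors, calibration=None):
        super().__init__(port, motors, calibration)
        self._connected = False
        self._registers: dict[tuple[str, str], Value] = {}

    def connect(self, handshake: bool = True) -> None:
        # A real implementation would open the serial port or CAN socket here.
        self._connected = True

    def disconnect(self, disable_torque: bool = True) -> None:
        if disable_torque:
            self.disable_torque()
        self._connected = False

    @property
    def is_connected(self) -> bool:
        return self._connected

    def read(self, data_name: str, motor: str) -> Value:
        return self._registers.get((motor, data_name), 0)

    def write(self, data_name: str, motor: str, value: Value) -> None:
        self._registers[(motor, data_name)] = value

    def sync_read(self, data_name, motors=None):
        names = [motors] if isinstance(motors, str) else (motors or list(self.motors))
        return {name: self.read(data_name, name) for name in names}

    def sync_write(self, data_name, values):
        values = values if isinstance(values, dict) else {m: values for m in self.motors}
        for name, value in values.items():
            self.write(data_name, name, value)

    def enable_torque(self, motors=None, num_retry: int = 0) -> None:
        self.sync_write("Torque_Enable", 1)

    def disable_torque(self, motors=None, num_retry: int = 0) -> None:
        self.sync_write("Torque_Enable", 0)

    def read_calibration(self) -> dict[str, MotorCalibration]:
        return self.calibration

    def write_calibration(self, calibration_dict, cache: bool = True) -> None:
        if cache:
            self.calibration = calibration_dict
```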
def get_ctrl_table(model_ctrl_table: dict[str, dict], model: str) -> dict[str, tuple[int, int]]:
ctrl_table = model_ctrl_table.get(model)
if ctrl_table is None:
@@ -97,6 +174,8 @@ class Motor:
id: int
model: str
norm_mode: MotorNormMode
motor_type_str: str | None = None
recv_id: int | None = None
class PortHandler(Protocol):
@@ -203,15 +282,15 @@ class GroupSyncWrite(Protocol):
def txPacket(self): ...
class MotorsBus(abc.ABC):
class SerialMotorsBus(MotorsBusBase):
"""
A MotorsBus allows to efficiently read and write to the attached motors.
A SerialMotorsBus allows to efficiently read and write to motors connected via serial communication.
It represents several motors daisy-chained together and connected through a serial port.
There are currently two implementations of this abstract class:
There are currently two implementations of this class:
- DynamixelMotorsBus
- FeetechMotorsBus
Note: This class may evolve in the future should we add support for other types of bus.
This class is specifically for serial-based motor protocols (Dynamixel, Feetech, etc.).
A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
To find the port, you can run our utility script:
@@ -260,9 +339,7 @@ class MotorsBus(abc.ABC):
motors: dict[str, Motor],
calibration: dict[str, MotorCalibration] | None = None,
):
self.port = port
self.motors = motors
self.calibration = calibration if calibration else {}
super().__init__(port, motors, calibration)
self.port_handler: PortHandler
self.packet_handler: PacketHandler
@@ -532,7 +609,7 @@ class MotorsBus(abc.ABC):
self.set_baudrate(self.default_baudrate)
@abc.abstractmethod
def _find_single_motor(self, motor: str, initial_baudrate: int | None) -> tuple[int, int]:
def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
pass
@abc.abstractmethod
@@ -545,13 +622,13 @@ class MotorsBus(abc.ABC):
pass
@abc.abstractmethod
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Disable torque on selected motors.
Disabling Torque allows to write to the motors' permanent memory area (EPROM/EEPROM).
Args:
motors (int | str | list[str] | None, optional): Target motors. Accepts a motor name, an ID, a
motors (str | list[str] | None, optional): Target motors. Accepts a motor name, a
list of names or `None` to affect every registered motor. Defaults to `None`.
num_retry (int, optional): Number of additional retry attempts on communication failure.
Defaults to 0.
@@ -1194,3 +1271,7 @@ class MotorsBus(abc.ABC):
for id_, value in ids_values.items():
data = self._serialize_data(value, length)
self.sync_writer.addParam(id_, data)
# Backward compatibility alias
MotorsBus: TypeAlias = SerialMotorsBus
+2 -17
@@ -35,7 +35,6 @@ from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.videovla.configuration_pi05 import PI05VideoConfig
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.sarm.configuration_sarm import SARMConfig
@@ -68,7 +67,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"vqbet", "pi0", "pi05", "pi05_video", "sac", "reward_classifier", "smolvla", "wall_x".
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
Returns:
The policy class corresponding to the given name.
@@ -104,10 +103,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.pi05.modeling_pi05 import PI05Policy
return PI05Policy
elif name == "pi05_video":
from lerobot.policies.videovla.modeling_pi05 import PI05VideoPolicy
return PI05VideoPolicy
elif name == "sac":
from lerobot.policies.sac.modeling_sac import SACPolicy
@@ -152,7 +147,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"diffusion", "act", "vqbet", "pi0", "pi05", "pi05_video", "sac", "smolvla",
"diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
"reward_classifier", "wall_x".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
@@ -174,8 +169,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return PI0Config(**kwargs)
elif policy_type == "pi05":
return PI05Config(**kwargs)
elif policy_type == "pi05_video":
return PI05VideoConfig(**kwargs)
elif policy_type == "sac":
return SACConfig(**kwargs)
elif policy_type == "smolvla":
@@ -340,14 +333,6 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, PI05VideoConfig):
from lerobot.policies.videovla.processor_pi05 import make_pi05_video_pre_post_processors
processors = make_pi05_video_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, SACConfig):
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
@@ -20,7 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig, RTCTrainingConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
DEFAULT_IMAGE_SIZE = 224
@@ -50,8 +50,9 @@ class PI0Config(PreTrainedConfig):
min_period: float = 4e-3
max_period: float = 4.0
# Real-Time Chunking (RTC) configuration
# Real-Time Chunking (RTC) configurations
rtc_config: RTCConfig | None = None
rtc_training_config: RTCTrainingConfig | None = None
image_resolution: tuple[int, int] = (
DEFAULT_IMAGE_SIZE,
+74 -19
@@ -44,6 +44,12 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.policies.rtc.training_time import (
apply_rtc_training_time,
apply_training_time_rtc_inference,
masked_mean,
sample_rtc_delay,
)
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -79,8 +85,8 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim != 1:
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
if time.ndim not in (1, 2):
raise ValueError("The time tensor is expected to be of shape `(batch_size,)` or `(batch_size, T)`.")
dtype = get_safe_dtype(torch.float64, device.type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
@@ -88,8 +94,14 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
# Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi
sin_input = scaling_factor[None, :] * time[:, None]
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
if time.ndim == 1:
sin_input = scaling_factor[None, :] * time[:, None]
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
time_flat = time.reshape(-1)
sin_input = scaling_factor[None, :] * time_flat[:, None]
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
return pos_emb.reshape(*time.shape, dimension)
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
@@ -605,6 +617,9 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _training_time_rtc_inference_enabled(self):
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training:
@@ -714,7 +729,10 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
time_emb = time_emb[:, None, :].expand_as(action_emb)
if time_emb.dim() == 2:
time_emb = time_emb[:, None, :].expand_as(action_emb)
elif time_emb.shape[:2] != action_emb.shape[:2]:
raise ValueError(f"Expected time_emb shape {action_emb.shape[:2]}, got {time_emb.shape[:2]}")
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
def mlp_func(action_time_emb):
@@ -750,7 +768,12 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
if time is None:
time = self.sample_time(actions.shape[0], actions.device)
time_expanded = time[:, None, None]
if time.ndim == 1:
time_expanded = time[:, None, None]
elif time.ndim == 2:
time_expanded = time[:, :, None]
else:
raise ValueError(f"Expected time shape (B,) or (B, T), got {time.shape}")
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
@@ -846,24 +869,37 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
dt = -1.0 / num_steps
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
use_training_time_rtc = self._training_time_rtc_inference_enabled()
x_t = noise
for step in range(num_steps):
time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
if use_training_time_rtc:
x_t_cond, time_tensor = apply_training_time_rtc_inference(
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
)
v_t = self.denoise_step(
state=state,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=input_x_t,
timestep=current_timestep,
x_t=x_t_cond,
timestep=time_tensor,
)
elif self._rtc_enabled():
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
state=state,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=input_x_t,
timestep=current_timestep,
)
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
@@ -874,7 +910,14 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
execution_horizon=execution_horizon,
)
else:
v_t = denoise_step_partial_call(x_t)
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
v_t = self.denoise_step(
state=state,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=x_t,
timestep=time_tensor,
)
x_t = x_t + dt * v_t
@@ -1277,7 +1320,19 @@ class PI0Policy(PreTrainedPolicy):
actions = self.prepare_action(batch)
# Compute loss
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
postfix_mask = None
rtc_cfg = self.config.rtc_training_config
if rtc_cfg is not None and rtc_cfg.enabled and self.training:
batch_size = actions.shape[0]
time = self.model.sample_time(batch_size, actions.device)
noise = self.model.sample_noise(actions.shape, actions.device)
delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device)
time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1])
losses = self.model.forward(
images, img_masks, lang_tokens, lang_masks, state, actions, noise=noise, time=time
)
else:
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
# Truncate losses to actual action dimensions
original_action_dim = self.config.output_features[ACTION].shape[0]
@@ -1289,12 +1344,12 @@ class PI0Policy(PreTrainedPolicy):
if reduction == "none":
# Return per-sample losses (B,) by averaging over time and action dims
per_sample_loss = losses.mean(dim=(1, 2))
per_sample_loss = masked_mean(losses, postfix_mask, reduce_dims=(1, 2))
loss_dict["loss"] = per_sample_loss.mean().item()
return per_sample_loss, loss_dict
else:
# Default: return scalar mean loss
loss = losses.mean()
loss = masked_mean(losses, postfix_mask, reduce_dims=(0, 1, 2))
loss_dict["loss"] = loss.item()
return loss, loss_dict
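Taken together, the PI0 changes gate the new loss path on `rtc_training_config` and `self.training`. A hedged sketch of switching it on when building a config (field values are arbitrary; dataset and feature setup are elided):

```python
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.rtc.configuration_rtc import RTCTrainingConfig

config = PI0Config(
    rtc_training_config=RTCTrainingConfig(enabled=True, min_delay=0, max_delay=8),
)
# In train mode, policy.forward(...) then samples a delay per batch element,
# pins the prefix timesteps to 0, and averages the loss with masked_mean over
# the postfix positions only.
```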
@@ -20,7 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig, RTCTrainingConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
DEFAULT_IMAGE_SIZE = 224
@@ -52,6 +52,7 @@ class PI05Config(PreTrainedConfig):
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
rtc_training_config: RTCTrainingConfig | None = None
image_resolution: tuple[int, int] = (
DEFAULT_IMAGE_SIZE,
+74 -26
@@ -44,6 +44,12 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.policies.rtc.training_time import (
apply_rtc_training_time,
apply_training_time_rtc_inference,
masked_mean,
sample_rtc_delay,
)
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -78,8 +84,8 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim != 1:
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
if time.ndim not in (1, 2):
raise ValueError("The time tensor is expected to be of shape `(batch_size,)` or `(batch_size, T)`.")
dtype = get_safe_dtype(torch.float64, device.type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
@@ -87,8 +93,14 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
# Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi
sin_input = scaling_factor[None, :] * time[:, None]
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
if time.ndim == 1:
sin_input = scaling_factor[None, :] * time[:, None]
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
time_flat = time.reshape(-1)
sin_input = scaling_factor[None, :] * time_flat[:, None]
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
return pos_emb.reshape(*time.shape, dimension)
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
@@ -460,8 +472,8 @@ class PaliGemmaWithExpertModel(
inputs_embeds=inputs_embeds[1],
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=False,
past_key_values=None, #jadechoghari
past_key_values=past_key_values,
use_cache=use_cache,
adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
)
suffix_output = suffix_output.last_hidden_state
@@ -575,13 +587,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
# try:
# from transformers.models.siglip import check
try:
from transformers.models.siglip import check
# if not check.check_whether_transformers_replace_is_installed_correctly():
# raise ValueError(msg)
# except ImportError:
# raise ValueError(msg) from None
if not check.check_whether_transformers_replace_is_installed_correctly():
raise ValueError(msg)
except ImportError:
raise ValueError(msg) from None
def gradient_checkpointing_enable(self):
"""Enable gradient checkpointing for memory optimization."""
@@ -602,6 +614,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _training_time_rtc_inference_enabled(self):
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training:
@@ -729,7 +744,12 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
if time is None:
time = self.sample_time(actions.shape[0], actions.device)
time_expanded = time[:, None, None]
if time.ndim == 1:
time_expanded = time[:, None, None]
elif time.ndim == 2:
time_expanded = time[:, :, None]
else:
raise ValueError(f"Expected time shape (B,) or (B, T), got {time.shape}")
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
@@ -820,23 +840,35 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
dt = -1.0 / num_steps
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
use_training_time_rtc = self._training_time_rtc_inference_enabled()
x_t = noise
for step in range(num_steps):
time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
if use_training_time_rtc:
x_t_cond, time_tensor = apply_training_time_rtc_inference(
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
)
v_t = self.denoise_step(
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=input_x_t,
timestep=current_timestep,
x_t=x_t_cond,
timestep=time_tensor,
)
elif self._rtc_enabled():
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=input_x_t,
timestep=current_timestep,
)
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
@@ -847,7 +879,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
execution_horizon=execution_horizon,
)
else:
v_t = denoise_step_partial_call(x_t)
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
v_t = self.denoise_step(
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=x_t,
timestep=time_tensor,
)
x_t = x_t + dt * v_t
@@ -1250,7 +1288,17 @@ class PI05Policy(PreTrainedPolicy):
actions = self.prepare_action(batch)
# Compute loss (no separate state needed for PI05)
losses = self.model.forward(images, img_masks, tokens, masks, actions)
postfix_mask = None
rtc_cfg = self.config.rtc_training_config
if rtc_cfg is not None and rtc_cfg.enabled and self.training:
batch_size = actions.shape[0]
time = self.model.sample_time(batch_size, actions.device)
noise = self.model.sample_noise(actions.shape, actions.device)
delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device)
time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1])
losses = self.model.forward(images, img_masks, tokens, masks, actions, noise=noise, time=time)
else:
losses = self.model.forward(images, img_masks, tokens, masks, actions)
# Truncate losses to actual action dimensions
original_action_dim = self.config.output_features[ACTION].shape[0]
@@ -1262,12 +1310,12 @@ class PI05Policy(PreTrainedPolicy):
if reduction == "none":
# Return per-sample losses (B,) by averaging over time and action dims
per_sample_loss = losses.mean(dim=(1, 2))
per_sample_loss = masked_mean(losses, postfix_mask, reduce_dims=(1, 2))
loss_dict["loss"] = per_sample_loss.mean().item()
return per_sample_loss, loss_dict
else:
# Default: return scalar mean loss
loss = losses.mean()
loss = masked_mean(losses, postfix_mask, reduce_dims=(0, 1, 2))
loss_dict["loss"] = loss.item()
return loss, loss_dict
+20 -1
@@ -23,7 +23,7 @@ Based on:
from dataclasses import dataclass
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.configs.types import RTCAttentionSchedule, RTCTrainingDelayDistribution
@dataclass
@@ -53,3 +53,22 @@ class RTCConfig:
raise ValueError(f"max_guidance_weight must be positive, got {self.max_guidance_weight}")
if self.debug_maxlen <= 0:
raise ValueError(f"debug_maxlen must be positive, got {self.debug_maxlen}")
@dataclass
class RTCTrainingConfig:
"""Configuration for training-time RTC action prefix conditioning."""
enabled: bool = False
min_delay: int = 0
max_delay: int = 0
delay_distribution: RTCTrainingDelayDistribution = RTCTrainingDelayDistribution.UNIFORM
exp_decay: float = 1.0
def __post_init__(self):
if self.min_delay < 0:
raise ValueError(f"min_delay must be >= 0, got {self.min_delay}")
if self.max_delay < self.min_delay:
raise ValueError(f"max_delay ({self.max_delay}) must be >= min_delay ({self.min_delay})")
if self.exp_decay <= 0:
raise ValueError(f"exp_decay must be positive, got {self.exp_decay}")
+110
@@ -0,0 +1,110 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
from lerobot.configs.types import RTCTrainingDelayDistribution
from lerobot.policies.rtc.configuration_rtc import RTCTrainingConfig
def sample_rtc_delay(cfg: RTCTrainingConfig, batch_size: int, device: torch.device) -> torch.Tensor:
if cfg.max_delay == cfg.min_delay:
return torch.full((batch_size,), cfg.min_delay, device=device, dtype=torch.long)
if cfg.delay_distribution == RTCTrainingDelayDistribution.UNIFORM:
return torch.randint(cfg.min_delay, cfg.max_delay + 1, (batch_size,), device=device, dtype=torch.long)
delay_values = torch.arange(cfg.min_delay, cfg.max_delay + 1, device=device, dtype=torch.long)
weights = torch.exp(-cfg.exp_decay * delay_values.to(dtype=torch.float32))
probs = weights / weights.sum()
samples = torch.multinomial(probs, batch_size, replacement=True)
return delay_values[samples]
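A quick, illustrative call of the sampler above; with the uniform distribution every delay in `[min_delay, max_delay]` is equally likely, while the second branch weights delay `d` proportionally to `exp(-exp_decay * d)`:

```python
import torch

from lerobot.policies.rtc.configuration_rtc import RTCTrainingConfig
from lerobot.policies.rtc.training_time import sample_rtc_delay

cfg = RTCTrainingConfig(enabled=True, min_delay=1, max_delay=3)
delays = sample_rtc_delay(cfg, batch_size=4, device=torch.device("cpu"))
print(delays)  # e.g. tensor([2, 1, 3, 1]) -- one integer delay per sample
```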
def apply_rtc_training_time(
time: torch.Tensor, delay: torch.Tensor, seq_len: int
) -> tuple[torch.Tensor, torch.Tensor]:
device = time.device
delay = torch.clamp(delay, max=seq_len)
prefix_mask = torch.arange(seq_len, device=device)[None, :] < delay[:, None]
time_tokens = time[:, None].expand(-1, seq_len)
time_tokens = time_tokens.masked_fill(prefix_mask, 0.0)
postfix_mask = ~prefix_mask
return time_tokens, postfix_mask
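A worked example of the masking above (batch of 2, chunk length 5, delays 2 and 0): prefix positions get timestep 0.0, i.e. they are treated as already-denoised actions, and the returned `postfix_mask` marks the positions that should contribute to the loss:

```python
import torch

from lerobot.policies.rtc.training_time import apply_rtc_training_time

time = torch.tensor([0.7, 0.3])
delay = torch.tensor([2, 0])
time_tokens, postfix_mask = apply_rtc_training_time(time, delay, seq_len=5)
print(time_tokens[0])   # tensor([0.0000, 0.0000, 0.7000, 0.7000, 0.7000])
print(postfix_mask[0])  # tensor([False, False,  True,  True,  True])
```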
def masked_mean(
losses: torch.Tensor, mask: torch.Tensor | None, reduce_dims: tuple[int, ...], eps: float = 1e-8
) -> torch.Tensor:
if mask is None:
return losses.mean(dim=reduce_dims)
mask = mask.to(dtype=losses.dtype)
while mask.dim() < losses.dim():
mask = mask.unsqueeze(-1)
masked = losses * mask
denom = mask.sum(dim=reduce_dims).clamp_min(eps)
return masked.sum(dim=reduce_dims) / denom
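And a small numeric check of `masked_mean`: masked-out prefix positions contribute neither to the numerator nor the denominator:

```python
import torch

from lerobot.policies.rtc.training_time import masked_mean

losses = torch.tensor([[[2.0], [4.0]], [[6.0], [8.0]]])  # (B=2, T=2, D=1)
mask = torch.tensor([[False, True], [True, True]])       # first prefix position masked
print(masked_mean(losses, mask, reduce_dims=(1, 2)))     # tensor([4., 7.])
```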
def apply_training_time_rtc_inference(
x_t: torch.Tensor,
time: float,
inference_delay: int | None,
prev_chunk_left_over: torch.Tensor | None,
chunk_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Apply training-time RTC conditioning during inference.
Based on Algorithm 1 from "Training-Time Action Conditioning for Efficient Real-Time Chunking".
At each denoising step:
1. Replace prefix positions in x_t with ground truth from previous chunk
2. Create per-token timesteps with 1.0 for prefix positions
Args:
x_t: Current noisy actions (B, T, D)
time: Current flow matching timestep (scalar)
inference_delay: Number of prefix actions to condition on
prev_chunk_left_over: Previous chunk's leftover actions (B, T, D)
chunk_size: Total chunk size T
Returns:
x_t_conditioned: x_t with prefix replaced by previous actions
time_per_token: Per-token timesteps (B, T) with 1.0 for prefix
"""
batch_size = x_t.shape[0]
device = x_t.device
if inference_delay is None or inference_delay <= 0 or prev_chunk_left_over is None:
time_scalar = torch.full((batch_size,), time, device=device, dtype=torch.float32)
return x_t, time_scalar
delay = min(inference_delay, chunk_size)
prefix_mask = torch.arange(chunk_size, device=device)[None, :] < delay
x_t_conditioned = torch.where(
prefix_mask[:, :, None].expand_as(x_t),
prev_chunk_left_over[:, :chunk_size, :],
x_t,
)
time_per_token = torch.full((batch_size, chunk_size), time, device=device, dtype=torch.float32)
time_per_token = time_per_token.masked_fill(prefix_mask, 1.0)
return x_t_conditioned, time_per_token
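A shape-level sketch of the inference helper (tensors are random placeholders; the real call happens inside the `sample_actions` loops shown in the policy diffs): the first `inference_delay` positions of `x_t` are overwritten with the previous chunk's actions and their per-token timestep is pinned to 1.0:

```python
import torch

from lerobot.policies.rtc.training_time import apply_training_time_rtc_inference

x_t = torch.randn(1, 50, 32)
prev = torch.randn(1, 50, 32)
x_t_cond, t_tok = apply_training_time_rtc_inference(
    x_t, time=0.9, inference_delay=4, prev_chunk_left_over=prev, chunk_size=50
)
assert torch.equal(x_t_cond[:, :4], prev[:, :4])  # prefix taken from previous chunk
assert bool((t_tok[0, :4] == 1.0).all()) and bool((t_tok[0, 4:] == 0.9).all())
```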
@@ -20,7 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig, RTCTrainingConfig
from lerobot.utils.constants import OBS_IMAGES
@@ -103,8 +103,9 @@ class SmolVLAConfig(PreTrainedConfig):
min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding
max_period: float = 4.0
# Real-Time Chunking (RTC) configuration
# Real-Time Chunking (RTC) configurations
rtc_config: RTCConfig | None = None
rtc_training_config: RTCTrainingConfig | None = None
def __post_init__(self):
super().__post_init__()
@@ -63,6 +63,12 @@ from typing_extensions import Unpack
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.policies.rtc.training_time import (
apply_rtc_training_time,
apply_training_time_rtc_inference,
masked_mean,
sample_rtc_delay,
)
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
from lerobot.policies.utils import (
@@ -85,8 +91,8 @@ def create_sinusoidal_pos_embedding(
if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim != 1:
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
if time.ndim not in (1, 2):
raise ValueError("The time tensor is expected to be of shape `(batch_size,)` or `(batch_size, T)`.")
dtype = get_safe_dtype(torch.float64, device.type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
@@ -94,9 +100,14 @@ def create_sinusoidal_pos_embedding(
# Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi
sin_input = scaling_factor[None, :] * time[:, None]
if time.ndim == 1:
sin_input = scaling_factor[None, :] * time[:, None]
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
time_flat = time.reshape(-1)
sin_input = scaling_factor[None, :] * time_flat[:, None]
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
return pos_emb
return pos_emb.reshape(*time.shape, dimension)
def make_att_2d_masks(pad_masks, att_masks):
@@ -375,6 +386,16 @@ class SmolVLAPolicy(PreTrainedPolicy):
lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
actions = self.prepare_action(batch)
postfix_mask = None
rtc_cfg = self.config.rtc_training_config
if rtc_cfg is not None and rtc_cfg.enabled and self.training:
batch_size = actions.shape[0]
if time is None:
time = self.model.sample_time(batch_size, actions.device)
if noise is None:
noise = self.model.sample_noise(actions.shape, actions.device)
delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device)
time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1])
actions_is_pad = batch.get("actions_id_pad")
loss_dict = {}
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
@@ -384,6 +405,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
in_episode_bound = ~actions_is_pad
losses = losses * in_episode_bound.unsqueeze(-1)
loss_dict["losses_after_in_ep_bound"] = losses.clone()
postfix_mask = in_episode_bound if postfix_mask is None else (postfix_mask & in_episode_bound)
# Remove padding
losses = losses[:, :, : self.config.max_action_dim]
@@ -391,12 +413,12 @@ class SmolVLAPolicy(PreTrainedPolicy):
if reduction == "none":
# Return per-sample losses (B,) by averaging over time and action dims
per_sample_loss = losses.mean(dim=(1, 2))
per_sample_loss = masked_mean(losses, postfix_mask, reduce_dims=(1, 2))
loss_dict["loss"] = per_sample_loss.mean().item()
return per_sample_loss, loss_dict
else:
# Default: return scalar mean loss
loss = losses.mean()
loss = masked_mean(losses, postfix_mask, reduce_dims=(0, 1, 2))
loss_dict["loss"] = loss.item()
return loss, loss_dict
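In SmolVLA the postfix mask is additionally intersected with the episode-bound mask, so padded action steps and RTC prefix steps are both excluded from the loss. A tiny illustration of that combination (hypothetical masks):

```python
import torch

postfix_mask = torch.tensor([[False, True, True]])      # RTC prefix excluded
in_episode_bound = torch.tensor([[True, True, False]])  # padded step excluded
print(postfix_mask & in_episode_bound)                  # tensor([[False,  True, False]])
```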
@@ -596,6 +618,9 @@ class VLAFlowMatching(nn.Module):
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _training_time_rtc_inference_enabled(self):
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
def set_requires_grad(self):
for params in self.state_proj.parameters():
params.requires_grad = self.config.train_state_proj
@@ -731,7 +756,10 @@ class VLAFlowMatching(nn.Module):
)
time_emb = time_emb.type(dtype=dtype)
time_emb = time_emb[:, None, :].expand_as(action_emb)
if time_emb.dim() == 2:
time_emb = time_emb[:, None, :].expand_as(action_emb)
elif time_emb.shape[:2] != action_emb.shape[:2]:
raise ValueError(f"Expected time_emb shape {action_emb.shape[:2]}, got {time_emb.shape[:2]}")
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
action_time_emb = self.action_time_mlp_in(action_time_emb)
@@ -763,7 +791,12 @@ class VLAFlowMatching(nn.Module):
if time is None:
time = self.sample_time(actions.shape[0], actions.device)
time_expanded = time[:, None, None]
if time.ndim == 1:
time_expanded = time[:, None, None]
elif time.ndim == 2:
time_expanded = time[:, :, None]
else:
raise ValueError(f"Expected time shape (B,) or (B, T), got {time.shape}")
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
@@ -826,23 +859,35 @@ class VLAFlowMatching(nn.Module):
num_steps = self.config.num_steps
dt = -1.0 / num_steps
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
use_training_time_rtc = self._training_time_rtc_inference_enabled()
x_t = noise
for step in range(num_steps):
time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
x_t=input_x_t,
if use_training_time_rtc:
x_t_cond, time_tensor = apply_training_time_rtc_inference(
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
)
v_t = self.denoise_step(
x_t=x_t_cond,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
timestep=current_timestep,
timestep=time_tensor,
)
elif self._rtc_enabled():
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
x_t=input_x_t,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
timestep=current_timestep,
)
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
@@ -853,7 +898,13 @@ class VLAFlowMatching(nn.Module):
execution_horizon=execution_horizon,
)
else:
v_t = denoise_step_partial_call(x_t)
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
v_t = self.denoise_step(
x_t=x_t,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
timestep=time_tensor,
)
x_t = x_t + dt * v_t
-49
@@ -1,49 +0,0 @@
# π₀.₅ (pi05)
This repository contains the Hugging Face port of **π₀.₅**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by Physical Intelligence.
It is designed as a **Vision-Language-Action model with open-world generalization**.
---
## Model Overview
| Feature | π₀ | π₀.₅ |
| -------------------- | ------------------------------------------------------ | ----------------------------------------- |
| Time Conditioning | Concatenates time with actions via `action_time_mlp_*` | Uses `time_mlp_*` for AdaRMS conditioning |
| AdaRMS | Not used | Used in action expert |
| Tokenizer Length | 48 tokens | 200 tokens |
| Discrete State Input | False (Uses `state_proj` layer) | True |
| Parameter Count | Higher (includes state embedding) | Lower (no state embedding) |
---
## Citation
If you use this work, please cite both **OpenPI** and the π₀.₅ paper:
```bibtex
@misc{openpi2024,
author = {Physical Intelligence Lab},
title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
year = {2024},
publisher = {GitHub},
howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
license = {Apache-2.0}
}
@misc{intelligence2025pi05visionlanguageactionmodelopenworld,
title = {π₀.₅: a Vision-Language-Action Model with Open-World Generalization},
author = {Physical Intelligence and Kevin Black and Noah Brown and James Darpinian and Karan Dhabalia and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Manuel Y. Galliker and Dibya Ghosh and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Devin LeBlanc and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Allen Z. Ren and Lucy Xiaoyang Shi and Laura Smith and Jost Tobias Springenberg and Kyle Stachowicz and James Tanner and Quan Vuong and Homer Walke and Anna Walling and Haohuan Wang and Lili Yu and Ury Zhilinsky},
year = {2025},
eprint = {2504.16054},
archivePrefix= {arXiv},
primaryClass = {cs.LG},
url = {https://arxiv.org/abs/2504.16054},
}
```
---
## License
This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
-31
@@ -1,31 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lazy imports to avoid conflicts with lerobot.policies.pi05.PI05Config
# when only importing subpackages like videoprism
def __getattr__(name):
if name == "PI05VideoConfig":
from .configuration_pi05 import PI05VideoConfig
return PI05VideoConfig
elif name == "PI05VideoPolicy":
from .modeling_pi05 import PI05VideoPolicy
return PI05VideoPolicy
elif name == "make_pi05_video_pre_post_processors":
from .processor_pi05 import make_pi05_video_pre_post_processors
return make_pi05_video_pre_post_processors
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
__all__ = ["PI05VideoConfig", "PI05VideoPolicy", "make_pi05_video_pre_post_processors"]
@@ -1,212 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
DEFAULT_IMAGE_SIZE = 224
@PreTrainedConfig.register_subclass("pi05_video")
@dataclass
class PI05VideoConfig(PreTrainedConfig):
paligemma_variant: str = "gemma_2b"
action_expert_variant: str = "gemma_300m"
dtype: str = "float32" # Options: "bfloat16", "float32"
n_obs_steps: int = 1
chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
n_action_steps: int = 50 # Number of action steps to execute
# Video encoder settings (VideoPrism)
use_video_encoder: bool = False # Enable video encoding with VideoPrism
video_num_frames: int = 16 # Number of frames for video encoding (VideoPrism default is 16)
videoprism_model_name: str = "MHRDYN7/videoprism-base-f16r288" # VideoPrism model to use
videoprism_image_size: int = 288 # VideoPrism expects 288x288 images
freeze_video_encoder: bool = True # Whether to freeze the video encoder weights
video_padding_mode: str = "repeat" # How to pad frames at episode start: "repeat" or "zero"
# Which camera to use for video encoding (None = first camera, or specify key like "observation.images.top")
video_encoder_camera_key: str | None = None
# Perceiver Resampler settings to reduce video tokens (4096 -> video_num_latents)
video_num_latents: int = 256 # Number of latent tokens for video resampler
video_resampler_num_heads: int = 8 # Number of attention heads in resampler
# Shorter state and action vectors will be padded to these dimensions
max_state_dim: int = 32
max_action_dim: int = 32
# Flow matching parameters: see openpi `PI0Pytorch`
num_inference_steps: int = 10
time_sampling_beta_alpha: float = 1.5
time_sampling_beta_beta: float = 1.0
time_sampling_scale: float = 0.999
time_sampling_offset: float = 0.001
min_period: float = 4e-3
max_period: float = 4.0
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
image_resolution: tuple[int, int] = (
DEFAULT_IMAGE_SIZE,
DEFAULT_IMAGE_SIZE,
) # see openpi `preprocessing_pytorch.py`
# Add empty images. Used to add empty cameras when no image features are present.
empty_cameras: int = 0
tokenizer_max_length: int = 200 # see openpi `__post_init__`
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for state
"ACTION": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for action
}
)
# Training settings
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
compile_model: bool = False # Whether to use torch.compile for model optimization
compile_mode: str = "max-autotune" # Torch compile mode
device: str | None = None # Device to use for the model (None = auto-detect)
# Finetuning settings
freeze_vision_encoder: bool = False # Freeze only the vision encoder
train_expert_only: bool = False # Freeze entire VLM, train only action expert and projections
# Optimizer settings: see openpi `AdamW`
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 0.01
optimizer_grad_clip_norm: float = 1.0
# Scheduler settings: see openpi `CosineDecaySchedule`
# Note: These will auto-scale if --steps < scheduler_decay_steps
# For example, --steps=3000 will scale warmup to 100 and decay to 3000
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
tokenizer_max_length: int = 200 # see openpi `__post_init__`
def __post_init__(self):
super().__post_init__()
# Validate configuration
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
)
if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}")
if self.action_expert_variant not in ["gemma_300m", "gemma_2b"]:
raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}")
if self.dtype not in ["bfloat16", "float32"]:
raise ValueError(f"Invalid dtype: {self.dtype}")
# Validate video encoder settings
if self.use_video_encoder:
if self.video_num_frames < 1:
raise ValueError(f"video_num_frames must be >= 1, got {self.video_num_frames}")
if self.videoprism_image_size < 1:
raise ValueError(f"videoprism_image_size must be >= 1, got {self.videoprism_image_size}")
if self.video_padding_mode not in ["repeat", "zero"]:
raise ValueError(
f"video_padding_mode must be 'repeat' or 'zero', got {self.video_padding_mode}"
)
def validate_features(self) -> None:
"""Validate and set up input/output features."""
for i in range(self.empty_cameras):
key = OBS_IMAGES + f".empty_camera_{i}"
empty_camera = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, *self.image_resolution), # Use configured image resolution
)
self.input_features[key] = empty_camera
if OBS_STATE not in self.input_features:
state_feature = PolicyFeature(
type=FeatureType.STATE,
shape=(self.max_state_dim,), # Padded to max_state_dim
)
self.input_features[OBS_STATE] = state_feature
if ACTION not in self.output_features:
action_feature = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.max_action_dim,), # Padded to max_action_dim
)
self.output_features[ACTION] = action_feature
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> list[int] | None:
"""Return indices for delta observations.
For PI05, we don't use generic observation_delta_indices because it would
apply to both images AND state. Instead, we use image_observation_delta_indices
which only applies to image observations.
"""
return None
@property
def image_observation_delta_indices(self) -> list[int] | None:
"""Return indices for delta image observations only.
When video encoding is enabled, returns indices for the past frames
needed by VideoPrism (e.g., -15, -14, ..., -1, 0 for 16 frames).
This only applies to image observations, not state.
"""
if self.use_video_encoder:
# Return indices for past frames: [-15, -14, ..., -1, 0] for 16 frames
return list(range(-(self.video_num_frames - 1), 1))
return None
@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None
File diff suppressed because it is too large.
@@ -1,171 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.videovla.configuration_pi05 import PI05VideoConfig
from lerobot.policies.pi05.modeling_pi05 import pad_vector
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.utils.constants import (
OBS_STATE,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step")
@dataclass
class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
"""
Processor step to prepare the state and tokenize the language input.
"""
max_state_dim: int = 32
task_key: str = "task"
def __call__(self, transition: EnvTransition) -> EnvTransition:
transition = transition.copy()
state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE)
if state is None:
raise ValueError("State is required for PI05")
tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key)
if tasks is None:
raise ValueError("No task found in complementary data")
# TODO: check if this is necessary
state = deepcopy(state)
# Prepare state (pad to max_state_dim)
state = pad_vector(state, self.max_state_dim)
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
state_np = state.cpu().numpy()
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
full_prompts = []
for i, task in enumerate(tasks):
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
state_str = " ".join(map(str, discretized_states[i]))
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
full_prompts.append(full_prompt)
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
# Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!)
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
return transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
This step does not alter the feature definitions.
"""
return features
def make_pi05_video_pre_post_processors(
config: PI05VideoConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Constructs pre-processor and post-processor pipelines for the PI05Video policy.
The pre-processing pipeline prepares input data for the model by:
1. Renaming features to match pretrained configurations.
2. Normalizing input and output features based on dataset statistics.
3. Adding a batch dimension.
4. Appending a newline character to the task description for tokenizer compatibility.
5. Tokenizing the text prompt using the PaliGemma tokenizer.
6. Moving all data to the specified device.
The post-processing pipeline handles the model's output by:
1. Moving data to the CPU.
2. Unnormalizing the output features to their original scale.
Args:
config: The configuration object for the PI0 policy.
dataset_stats: A dictionary of statistics for normalization.
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
# Add remaining processors
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
TokenizerProcessorStep(
tokenizer_name="google/paligemma-3b-pt-224",
max_length=config.tokenizer_max_length,
padding_side="right",
padding="max_length",
),
DeviceProcessorStep(device=config.device),
]
output_steps: list[ProcessorStep] = [
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
@@ -1,214 +0,0 @@
#!/usr/bin/env python
"""
Test script for PI05 with video encoder (VideoPrism).
This script creates a dummy example to test the model with video encoding enabled.
"""
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.videovla.configuration_pi05 import PI05VideoConfig
from lerobot.policies.videovla.modeling_pi05 import PI05VideoPolicy
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
def create_dummy_batch(
batch_size: int = 2,
num_frames: int = 16,
image_size: int = 224,
num_cameras: int = 2,
state_dim: int = 14,
action_dim: int = 14,
chunk_size: int = 50,
seq_len: int = 10,
device: str = "cuda",
) -> dict[str, torch.Tensor]:
"""Create a dummy batch for testing."""
batch = {}
# Create image observations with temporal dimension [B, T, C, H, W]
for i in range(num_cameras):
key = f"{OBS_IMAGES}.camera_{i}"
# Images in [0, 1] range
batch[key] = torch.rand(batch_size, num_frames, 3, image_size, image_size, device=device)
# Create state observation [B, state_dim]
batch[OBS_STATE] = torch.rand(batch_size, state_dim, device=device)
# Create language tokens and attention mask [B, seq_len]
batch["observation.language.tokens"] = torch.randint(0, 1000, (batch_size, seq_len), device=device)
batch["observation.language.attention_mask"] = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
# Create action targets [B, chunk_size, action_dim]
batch[ACTION] = torch.rand(batch_size, chunk_size, action_dim, device=device)
return batch
def test_video_encoder():
"""Test the PI05 model with video encoding enabled."""
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Configuration
batch_size = 2
num_frames = 16
image_size = 224
num_cameras = 2
state_dim = 14
action_dim = 14
chunk_size = 50
# Create config with video encoder enabled
print("Creating PI05VideoConfig with video encoder...")
config = PI05VideoConfig(
use_video_encoder=True,
video_num_frames=num_frames,
videoprism_model_name="MHRDYN7/videoprism-base-f16r288",
videoprism_image_size=288,
freeze_video_encoder=True,
video_padding_mode="repeat",
video_encoder_camera_key=f"{OBS_IMAGES}.camera_0", # Use first camera for video
chunk_size=chunk_size,
max_action_dim=32,
max_state_dim=32,
dtype="float32", # Use float32 for testing
device=device,
)
# Set up input/output features
for i in range(num_cameras):
key = f"{OBS_IMAGES}.camera_{i}"
config.input_features[key] = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, image_size, image_size),
)
config.input_features[OBS_STATE] = PolicyFeature(
type=FeatureType.STATE,
shape=(state_dim,),
)
config.output_features[ACTION] = PolicyFeature(
type=FeatureType.ACTION,
shape=(action_dim,),
)
print(f"use_video_encoder: {config.use_video_encoder}")
print(f"video_num_frames: {config.video_num_frames}")
print(f"video_padding_mode: {config.video_padding_mode}")
print(f"video_encoder_camera_key: {config.video_encoder_camera_key}")
print(f"image_observation_delta_indices: {config.image_observation_delta_indices}")
# Create model
model = PI05VideoPolicy(config)
model.to(device)
# Create dummy batch
batch = create_dummy_batch(
batch_size=batch_size,
num_frames=num_frames,
image_size=image_size,
num_cameras=num_cameras,
state_dim=state_dim,
action_dim=action_dim,
chunk_size=chunk_size,
device=device,
)
print(f"Batch keys: {list(batch.keys())}" )
for key, value in batch.items():
print(f"{key}: {value.shape}")
# Test forward pass
model.train()
try:
loss, loss_dict = model.forward(batch)
print(f"Forward pass successful!")
print(f"Loss: {loss.item():.4f}")
print(f"Loss dict: {loss_dict}")
except Exception as e:
print(f"Forward pass failed: {e}")
raise
# Test inference
model.eval()
with torch.no_grad():
try:
actions = model.predict_action_chunk(batch)
print(f"Test pass, inference pass!")
print(f"Predicted actions shape: {actions.shape}")
except Exception as e:
print(f"Inference failed: {e}")
raise
print("All tests passed!")
def test_frame_padding():
"""Test frame padding at episode start."""
device = "cuda" if torch.cuda.is_available() else "cpu"
# Create config
config = PI05VideoConfig(
use_video_encoder=True,
video_num_frames=16,
videoprism_model_name="MHRDYN7/videoprism-base-f16r288",
freeze_video_encoder=True,
video_padding_mode="repeat",
chunk_size=50,
dtype="float32",
device=device,
)
# Set up minimal features
config.input_features[f"{OBS_IMAGES}.camera_0"] = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 224, 224),
)
config.output_features[ACTION] = PolicyFeature(
type=FeatureType.ACTION,
shape=(14,),
)
# Create model
model = PI05VideoPolicy(config)
model.to(device)
# Test with fewer frames than expected (simulating episode start)
batch = {
f"{OBS_IMAGES}.camera_0": torch.rand(2, 5, 3, 224, 224, device=device),
"observation.language.tokens": torch.randint(0, 1000, (2, 10), device=device),
"observation.language.attention_mask": torch.ones(2, 10, dtype=torch.bool, device=device),
ACTION: torch.rand(2, 50, 14, device=device),
}
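# Only 5 frames are provided while video_num_frames=16, so _preprocess_video is expected to pad by
# repeating frames (video_padding_mode="repeat") up to the full 16-frame window.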
video_frames = model._preprocess_video(batch)
if video_frames is not None:
print(f"Input frames: 5")
print(f"Output video_frames shape: {video_frames.shape}")
print(f"Expected: [2, 16, 3, 224, 224]")
assert video_frames.shape == (2, 16, 3, 224, 224), f"Unexpected shape: {video_frames.shape}"
print("Frame padding test PASSED!")
else:
print("video_frames is None (unexpected)")
# Test with single frame
batch[f"{OBS_IMAGES}.camera_0"] = torch.rand(2, 3, 224, 224, device=device) # [B, C, H, W]
video_frames = model._preprocess_video(batch)
if video_frames is not None:
print(f"Input: single frame [B, C, H, W]")
print(f"Output video_frames shape: {video_frames.shape}")
print(f"Expected: [2, 16, 3, 224, 224]")
assert video_frames.shape == (2, 16, 3, 224, 224), f"Unexpected shape: {video_frames.shape}"
print("Single frame expansion test PASSED!")
else:
print("video_frames is None (unexpected)")
print("All tests passed!")
if __name__ == "__main__":
# Run tests
test_frame_padding()
test_video_encoder()
@@ -1,37 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_videoprism import VideoPrismConfig, VideoPrismTextConfig, VideoPrismVisionConfig
from .modeling_videoprism import (
VideoPrismClipModel,
VideoPrismForVideoClassification,
VideoPrismPreTrainedModel,
VideoPrismTextModel,
VideoPrismVideoModel,
VideoPrismVisionModel,
)
from .video_processing_videoprism import VideoPrismVideoProcessor
__all__ = [
"VideoPrismConfig",
"VideoPrismTextConfig",
"VideoPrismVisionConfig",
"VideoPrismClipModel",
"VideoPrismForVideoClassification",
"VideoPrismPreTrainedModel",
"VideoPrismTextModel",
"VideoPrismVideoModel",
"VideoPrismVisionModel",
"VideoPrismVideoProcessor",
]
@@ -1,269 +0,0 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/videoprism/modular_videoprism.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_videoprism.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from transformers import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class VideoPrismVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`VideoPrismVisionModel`]. It is used to instantiate a
VideoPrism vision encoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the VideoPrism
[google/videoprism](https://huggingface.co/google/videoprism) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
image_size (`int`, *optional*, defaults to 288):
The size of the input image.
num_frames (`int`, *optional*, defaults to 16):
The number of frames in the input video.
tubelet_size (`List[int]`, *optional*, defaults to `[1, 18, 18]`):
The size of the tubelet patch.
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_spatial_layers (`int`, *optional*, defaults to 12):
Number of spatial transformer blocks.
num_temporal_layers (`int`, *optional*, defaults to 4):
Number of temporal transformer blocks.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_python"`):
The non-linear activation function (function or string).
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
qkv_bias (`bool`, *optional*, defaults to `True`):
Whether to add a bias to the qkv projections in attention layers.
attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
Softcapping constant for attention logits.
num_auxiliary_layers (`int`, *optional*, defaults to 2):
Number of auxiliary layers. This is used in the VideoPrismVideoModel that is a part of VideoPrismClipModel.
apply_l2_norm (`bool`, *optional*, defaults to `True`):
Whether to apply L2 normalization to the output. This is used in the VideoPrismVideoModel that is a part of VideoPrismClipModel.
Example:
```python
>>> from transformers import VideoPrismVisionConfig, VideoPrismVisionModel
>>> # Initializing a VideoPrismVisionConfig with default values
>>> configuration = VideoPrismVisionConfig()
>>> # Initializing a VideoPrismVisionModel with the configuration
>>> model = VideoPrismVisionModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "videoprism_vision_model"
base_config_key = "vision_config"
def __init__(
self,
image_size=288,
num_frames=16,
tubelet_size=[1, 18, 18],
num_channels=3,
hidden_size=768,
num_spatial_layers=12,
num_temporal_layers=4,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu_python",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
initializer_range=0.02,
layer_norm_eps=1e-06,
qkv_bias=True,
attn_logit_softcapping=50.0,
num_auxiliary_layers=2,
apply_l2_norm=True,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.image_size = image_size
self.num_frames = num_frames
self.tubelet_size = tubelet_size
self.num_channels = num_channels
self.qkv_bias = qkv_bias
self.num_spatial_layers = num_spatial_layers
self.num_temporal_layers = num_temporal_layers
self.attn_logit_softcapping = attn_logit_softcapping
self.num_auxiliary_layers = num_auxiliary_layers
self.apply_l2_norm = apply_l2_norm
class VideoPrismTextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`VideoPrismTextModel`]. It is used to instantiate a
VideoPrism text encoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the VideoPrism
[google/videoprism](https://huggingface.co/google/videoprism) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
intermediate_size (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
num_text_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the text Transformer encoder.
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the text model. Defines the number of different tokens that can be represented by the
`input_ids` passed when calling [`VideoPrismTextModel`].
apply_l2_norm (`bool`, *optional*, defaults to `True`):
Whether to apply L2 normalization to the output text embeddings.
hidden_act (`str` or `function`, *optional*, defaults to `"relu"`):
The non-linear activation function (function or string) in the encoder and pooler.
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
qkv_bias (`bool`, *optional*, defaults to `True`):
Whether to add a bias to the query, key, and value projections in the attention layers.
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
Softcapping constant for attention logits.
Example:
```python
>>> from transformers import VideoPrismTextConfig, VideoPrismTextModel
>>> # Initializing a VideoPrismTextConfig with default values
>>> configuration = VideoPrismTextConfig()
>>> # Initializing a VideoPrismTextModel (with random weights) from the configuration
>>> model = VideoPrismTextModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "videoprism_text_model"
base_config_key = "text_config"
def __init__(
self,
hidden_size=768,
intermediate_size=3072,
num_attention_heads=12,
num_text_layers=12,
vocab_size=32000,
apply_l2_norm=True,
hidden_act="relu",
attention_probs_dropout_prob=0.0,
qkv_bias=True,
hidden_dropout_prob=0.0,
layer_norm_eps=1e-06,
initializer_range=0.02,
attn_logit_softcapping=50.0,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.num_text_layers = num_text_layers
self.vocab_size = vocab_size
self.apply_l2_norm = apply_l2_norm
self.hidden_act = hidden_act
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.qkv_bias = qkv_bias
self.hidden_dropout_prob = hidden_dropout_prob
self.layer_norm_eps = layer_norm_eps
self.initializer_range = initializer_range
self.attn_logit_softcapping = attn_logit_softcapping
class VideoPrismConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`VideoPrismModel`]. It is used to instantiate a
VideoPrism model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the VideoPrism
[google/videoprism](https://huggingface.co/google/videoprism) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
text_config (`VideoPrismTextConfig`, *optional*):
Configuration for the text model.
vision_config (`VideoPrismVisionConfig`, *optional*):
Configuration for the vision model.
kwargs (*optional*):
Dictionary of keyword arguments.
Example:
```python
>>> from transformers import VideoPrismConfig, VideoPrismModel
>>> # Initializing a VideoPrismConfig with default values
>>> configuration = VideoPrismConfig()
>>> # Initializing a VideoPrismClipModel with the configuration
>>> model = VideoPrismClipModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "videoprism"
sub_configs = {"text_config": VideoPrismTextConfig, "vision_config": VideoPrismVisionConfig}
def __init__(self, text_config=None, vision_config=None, **kwargs):
if text_config is None:
text_config = VideoPrismTextConfig()
logger.info("`text_config` is `None`. Initializing the `VideoPrismTextConfig` with default values.")
elif isinstance(text_config, dict):
text_config = VideoPrismTextConfig(**text_config)
if vision_config is None:
vision_config = VideoPrismVisionConfig()
logger.info("`vision_config` is `None`. initializing the `VideoPrismVisionConfig` with default values.")
elif isinstance(vision_config, dict):
vision_config = VideoPrismVisionConfig(**vision_config)
self.text_config = text_config
self.vision_config = vision_config
super().__init__(**kwargs)
__all__ = ["VideoPrismVisionConfig", "VideoPrismTextConfig", "VideoPrismConfig"]
@@ -1,245 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from collections import defaultdict
from contextlib import contextmanager
import torch
# Record all the torch primitives in advance, so that we can use them without them being modified when we patch torch
# in context managers
TORCH_INIT_FUNCTIONS = {
"uniform_": torch.nn.init.uniform_,
"normal_": torch.nn.init.normal_,
"constant_": torch.nn.init.constant_,
"ones_": torch.nn.init.ones_,
"zeros_": torch.nn.init.zeros_,
"eye_": torch.nn.init.eye_,
"dirac_": torch.nn.init.dirac_,
"xavier_uniform_": torch.nn.init.xavier_uniform_,
"xavier_normal_": torch.nn.init.xavier_normal_,
"kaiming_uniform_": torch.nn.init.kaiming_uniform_,
"kaiming_normal_": torch.nn.init.kaiming_normal_,
"trunc_normal_": torch.nn.init.trunc_normal_,
"orthogonal_": torch.nn.init.orthogonal_,
"sparse_": torch.nn.init.sparse_,
}
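# The wrapper functions below only call the recorded initializers when the tensor has not already been
# flagged as initialized (the `_is_hf_initialized` attribute set on parameters loaded from a checkpoint),
# so loaded weights are never re-initialized.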
def uniform_(
tensor: torch.Tensor, a: float = 0.0, b: float = 1.0, generator: torch.Generator | None = None
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["uniform_"](tensor, a=a, b=b, generator=generator)
return tensor
def normal_(
tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, generator: torch.Generator | None = None
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["normal_"](tensor, mean=mean, std=std, generator=generator)
return tensor
def constant_(tensor: torch.Tensor, val: float) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["constant_"](tensor, val=val)
return tensor
def ones_(tensor: torch.Tensor) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["ones_"](tensor)
return tensor
def zeros_(tensor: torch.Tensor) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["zeros_"](tensor)
return tensor
def eye_(tensor: torch.Tensor) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["eye_"](tensor)
return tensor
def dirac_(tensor: torch.Tensor, groups: int = 1) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["dirac_"](tensor, groups=groups)
return tensor
def xavier_uniform_(tensor: torch.Tensor, gain: float = 1.0, generator: torch.Generator | None = None) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["xavier_uniform_"](tensor, gain=gain, generator=generator)
return tensor
def xavier_normal_(tensor: torch.Tensor, gain: float = 1.0, generator: torch.Generator | None = None) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["xavier_normal_"](tensor, gain=gain, generator=generator)
return tensor
def kaiming_uniform_(
tensor: torch.Tensor,
a: float = 0,
mode: str = "fan_in",
nonlinearity: str = "leaky_relu",
generator: torch.Generator | None = None,
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["kaiming_uniform_"](
tensor, a=a, mode=mode, nonlinearity=nonlinearity, generator=generator
)
return tensor
def kaiming_normal_(
tensor: torch.Tensor,
a: float = 0,
mode: str = "fan_in",
nonlinearity: str = "leaky_relu",
generator: torch.Generator | None = None,
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["kaiming_normal_"](
tensor, a=a, mode=mode, nonlinearity=nonlinearity, generator=generator
)
return tensor
def trunc_normal_(
tensor: torch.Tensor,
mean: float = 0.0,
std: float = 1.0,
a: float = -2.0,
b: float = 2.0,
generator: torch.Generator | None = None,
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["trunc_normal_"](tensor, mean=mean, std=std, a=a, b=b, generator=generator)
return tensor
def orthogonal_(
tensor: torch.Tensor,
gain: float = 1,
generator: torch.Generator | None = None,
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["orthogonal_"](tensor, gain=gain, generator=generator)
return tensor
def sparse_(
tensor: torch.Tensor, sparsity: float, std: float = 0.01, generator: torch.Generator | None = None
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["sparse_"](tensor, sparsity=sparsity, std=std, generator=generator)
return tensor
def copy_(tensor: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
with torch.no_grad():
return tensor.copy_(other)
return tensor
# Here, we need to check several modules imported, and hot patch all of them, as sometimes torch does
something like `from torch.nn.init import xavier_uniform_` in their internals (e.g. in torch.nn.modules.activation,
where MultiHeadAttention lives), so the function name is bound at import time and just doing
# `setattr(torch.nn.init, name, globals()[name])` is thus not enough
# The following list should be enough for all torch versions we work with
TORCH_MODULES_TO_PATCH = (
"torch.nn.init",
"torch.nn.modules.activation",
"torch.nn.modules.transformer",
"torch.nn.modules.linear",
"torch.nn.modules.loss",
"torch.nn.modules.batchnorm",
"torch.nn.modules.conv",
"torch.nn.modules.normalization",
"torch.nn.modules.rnn",
"torch.nn.modules.sparse",
)
@contextmanager
def guard_torch_init_functions():
"""
Guard the `torch.nn.init` primitive functions to behave exactly like the functions in this file, i.e. be
protected against the `_is_hf_initialized` flag to avoid re-init if the param was already loaded.
Usually, all models use the already-guarded init functions from `transformers`, but to make extra sure,
and to cover remote code, we also use this context manager.
"""
originals = defaultdict(dict)
try:
# Replace all torch funcs by the ones in this file
for module_name in TORCH_MODULES_TO_PATCH:
if module_name in sys.modules:
module = sys.modules[module_name]
for func_name in TORCH_INIT_FUNCTIONS.keys():
if hasattr(module, func_name):
originals[module][func_name] = getattr(module, func_name)
setattr(module, func_name, globals()[func_name])
yield
finally:
# Set back the original functions on all modules
for module, functions in originals.items():
for func_name, func in functions.items():
setattr(module, func_name, func)
@contextmanager
def no_init_weights():
"""
Disable weight initialization both at the torch level and at the transformers level (`init_weights`).
This is used to speed up initializing an empty model with deepspeed: in that case we do not initialize the model on
the meta device, but we still don't need to run expensive weight initializations, since the parameters are loaded afterwards.
"""
from .modeling_utils import PreTrainedModel
def empty_func(*args, **kwargs):
pass
originals = defaultdict(dict)
try:
# Replace all torch funcs by empty ones
for module_name in TORCH_MODULES_TO_PATCH:
if module_name in sys.modules:
module = sys.modules[module_name]
for func_name in TORCH_INIT_FUNCTIONS.keys():
if hasattr(module, func_name):
originals[module][func_name] = getattr(module, func_name)
setattr(module, func_name, empty_func)
# Also patch our own `init_weights`
original_init_weights = PreTrainedModel.init_weights
PreTrainedModel.init_weights = empty_func
yield
finally:
# Set back the original torch functions on all modules
for module, functions in originals.items():
for func_name, func in functions.items():
setattr(module, func_name, func)
# Set back `init_weights`
PreTrainedModel.init_weights = original_init_weights
@@ -1,994 +0,0 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/videoprism/modular_videoprism.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_videoprism.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import math
from collections.abc import Callable
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import _calculate_fan_in_and_fan_out
from . import initialization as init
from transformers.activations import ACT2FN
from transformers.masking_utils import create_causal_mask
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutput, ImageClassifierOutput
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.file_utils import ModelOutput
from .configuration_videoprism import VideoPrismConfig, VideoPrismTextConfig, VideoPrismVisionConfig
def torch_int(x):
"""
Casts an input to a torch int64 tensor if we are in a tracing context, otherwise to a Python int.
"""
# torch is imported unconditionally at the top of this module, so no availability guard is needed here.
return x.to(torch.int64) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)
@dataclass
class BaseModelOutputWithSpatialAndTemporalStates(ModelOutput):
"""
Base class for model outputs that include spatial and temporal states.
Args:
last_hidden_state (Optional[torch.FloatTensor]):
The last hidden state of the model, typically of shape
(batch_size, num_patches * num_frames, hidden_size).
temporal_hidden_state (Optional[torch.FloatTensor]):
The last hidden_state of the temporal encoder, typically of shape
(batch_size * num_patches, num_frames, hidden_size).
spatial_hidden_state (Optional[torch.FloatTensor]):
The last hidden_state of the spatial encoder, typically of shape
(batch_size * num_frames, num_patches, hidden_size).
"""
last_hidden_state: torch.FloatTensor | None = None
temporal_hidden_state: torch.FloatTensor | None = None
spatial_hidden_state: torch.FloatTensor | None = None
@dataclass
class VideoPrismClipOutput(ModelOutput):
"""
Base class for VideoPrismClip model outputs.
"""
logits_per_video: torch.FloatTensor | None = None
logits_per_text: torch.FloatTensor | None = None
video_embeds: torch.FloatTensor | None = None
text_embeds: torch.FloatTensor | None = None
@dataclass
class VideoPrismVideoOutput(ModelOutput):
"""
Base class for VideoPrismVideo model outputs.
"""
video_last_hidden_state: torch.FloatTensor | None = None
auxiliary_output: torch.FloatTensor | None = None
attention_pooling_output: torch.FloatTensor | None = None
class VideoPrismTubeletEmbeddings(nn.Module):
"""
Construct VideoPrism Tubelet embeddings.
This module turns a batch of videos of shape (batch_size, num_frames, num_channels, height, width) into a tensor of
shape (batch_size * num_frames, num_patches, hidden_size) to be consumed by the per-frame spatial Transformer encoder.
The number of patches per frame equals (height // tubelet_size[1]) * (width // tubelet_size[2]).
"""
def __init__(self, config: VideoPrismVisionConfig):
super().__init__()
self.config = config
self.num_frames = config.num_frames
self.image_size = (
config.image_size
if isinstance(self.config.image_size, tuple)
else (self.config.image_size, self.config.image_size)
)
self.patch_size = config.tubelet_size
self.embed_dim = config.hidden_size
self.projection = nn.Conv3d(
config.num_channels, config.hidden_size, kernel_size=config.tubelet_size, stride=config.tubelet_size
)
self.pos_emb_shape = [self.image_size[0] // self.patch_size[1], self.image_size[1] // self.patch_size[2]]
self.num_patches = self.pos_emb_shape[0] * self.pos_emb_shape[1]
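# With the defaults in this file (image_size=288, tubelet_size=[1, 18, 18]), this gives a 16 x 16 grid of
# spatial patches per frame, i.e. num_patches = 256.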
def forward(self, pixel_values_videos: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
batch_size, num_frames, num_channels, height, width = pixel_values_videos.shape
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
raise ValueError(
f"Image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]}). Set interpolate_pos_encoding=True to automatically resize the model position embeddings."
)
# permute to (batch_size, num_channels, num_frames, height, width)
pixel_values_videos = pixel_values_videos.permute(0, 2, 1, 3, 4)
hidden_states = self.projection(pixel_values_videos)
# flatten the spatial part and permute to (B, T, num_patches, dim)
hidden_states = hidden_states.flatten(3).permute(0, 2, 3, 1)
# combine batch and time dimension
batch_size, num_frames, num_patches, hidden_size = hidden_states.shape
hidden_states = hidden_states.reshape(batch_size * num_frames, num_patches, hidden_size)
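# Frames are folded into the batch dimension, so the spatial encoder attends within each frame independently.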
return hidden_states
class VideoPrismSpatialEmbeddings(nn.Module):
"""
VideoPrism Spatial Embeddings.
Creates embeddings from a video using VideoPrismSpatialTubeletEmbeddings and adds positional embeddings.
"""
def __init__(self, config: VideoPrismVisionConfig):
super().__init__()
self.config = config
self.patch_embeddings = VideoPrismTubeletEmbeddings(config)
self.position_embeddings = nn.Parameter(torch.zeros(1, self.patch_embeddings.num_patches, config.hidden_size))
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.patch_size = config.tubelet_size[1:]
self.tubelet_size = config.tubelet_size
# Adapted from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method interpolates the pre-trained position encodings so that the model can be used on higher-resolution
images. It is also adapted to support torch.jit tracing.
Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
"""
num_patches = embeddings.shape[1]
num_positions = self.position_embeddings.shape[1]
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embeddings
dim = embeddings.shape[-1]
num_row_patches = height // self.patch_size[0]
num_col_patches = width // self.patch_size[1]
sqrt_num_positions = torch_int(num_positions**0.5)
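# The stored position embeddings form a square sqrt(num_positions) x sqrt(num_positions) grid; reshape them
# back into that grid and bilinearly resize to the new patch grid.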
patch_pos_embed = self.position_embeddings.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
size=(num_row_patches, num_col_patches),
mode="bilinear",
antialias=True,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return patch_pos_embed
def forward(
self, pixel_values_videos: torch.Tensor, interpolate_pos_encoding: bool | None = False
) -> torch.Tensor:
b, t, c, h, w = pixel_values_videos.shape
assert h == w, "Input image height and width must be the same"
embeddings = self.patch_embeddings(pixel_values_videos, interpolate_pos_encoding)
# add positional encoding to each token
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, h, w)
else:
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings
class VideoPrismTemporalEmbeddings(nn.Module):
"""
VideoPrism Temporal Embeddings.
Receives embeddings from spatial encoder, reshapes the hidden state to
(batch_size * num_patches, num_frames, hidden_size) and adds positional embeddings.
"""
def __init__(self, config: VideoPrismVisionConfig):
super().__init__()
self.config = config
self.position_embeddings = nn.Parameter(torch.zeros(1, self.config.num_frames, config.hidden_size))
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# Adapted from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
"""
This method interpolates the pre-trained position encodings so that the model can be used on higher-resolution
images. It is also adapted to support torch.jit tracing.
Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
"""
target_emb_length = embeddings.shape[1]
source_emb_length = self.position_embeddings.shape[1]
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and target_emb_length == source_emb_length:
return self.position_embeddings
source_emb = self.position_embeddings
dim = embeddings.shape[-1]
source_emb = source_emb.unsqueeze(1)
source_emb = nn.functional.interpolate(
source_emb,
size=(target_emb_length, dim),
mode="bilinear",
antialias=True,
)
return source_emb.squeeze(1)
def forward(
self,
pixel_values_videos: torch.Tensor,
input_shape: torch.Size,
interpolate_pos_encoding: bool | None = False,
) -> torch.Tensor:
if input_shape is not None:
b, t, c, h, w = input_shape
_, features, dim = pixel_values_videos.shape
hidden_states = pixel_values_videos.view(b, t, features, dim)
hidden_states = hidden_states.permute(0, 2, 1, 3)
embeddings = hidden_states.reshape(b * features, t, dim)
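# Each spatial patch location becomes its own temporal sequence: (B, T, num_patches, D) -> (B * num_patches, T, D).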
# add positional encoding to each token
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings)
else:
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor | None,
scaling: float,
dropout: float = 0.0,
softcap: float | None = None,
**kwargs,
):
# Take the dot product between "query" and "key" to get the raw attention scores.
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
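# Optional logit soft-capping below squashes the raw scores into (-softcap, softcap) via softcap * tanh(logits / softcap).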
if softcap is not None:
attn_weights = attn_weights / softcap
attn_weights = torch.tanh(attn_weights)
attn_weights = attn_weights * softcap
if attention_mask is not None:
attn_weights = attn_weights + attention_mask.expand(*attn_weights.shape)
# Normalize the attention scores to probabilities.
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class VideoPrismSelfAttention(nn.Module):
def __init__(self, config: VideoPrismVisionConfig | VideoPrismTextConfig):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
f"heads {config.num_attention_heads}."
)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.dropout_prob = config.attention_probs_dropout_prob
self.scale = self.attention_head_size**-0.5
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
batch_size = hidden_states.shape[0]
new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
query = self.query(hidden_states).view(*new_shape).transpose(1, 2)
key = self.key(hidden_states).view(*new_shape).transpose(1, 2)
value = self.value(hidden_states).view(*new_shape).transpose(1, 2)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
context_layer, attention_probs = attention_interface(
self,
query,
key,
value,
attention_mask,
scaling=self.scale,
dropout=0.0 if not self.training else self.dropout_prob,
softcap=self.config.attn_logit_softcapping,
**kwargs,
)
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.reshape(new_context_layer_shape)
return (context_layer, attention_probs)
class VideoPrismSelfOutput(nn.Module):
"""
The residual connection is defined in VideoPrismLayer instead of here (as is the case with other models), due to the
layernorm applied before each block.
"""
def __init__(self, config: VideoPrismConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class VideoPrismAttention(nn.Module):
def __init__(self, config: VideoPrismConfig):
super().__init__()
self.attention = VideoPrismSelfAttention(config)
self.output = VideoPrismSelfOutput(config)
def forward(
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, **kwargs
) -> torch.Tensor:
self_attn_output, _ = self.attention(hidden_states, attention_mask, **kwargs)
output = self.output(self_attn_output, hidden_states)
return output
class VideoPrismLayerNorm(nn.LayerNorm):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
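# The learnable scale is stored zero-centered, so the effective LayerNorm gain is self.weight + 1.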
return F.layer_norm(hidden_states, self.normalized_shape, self.weight + 1, self.bias, self.eps)
class VideoPrismIntermediate(nn.Module):
def __init__(self, config: VideoPrismConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class VideoPrismOutput(nn.Module):
def __init__(self, config: VideoPrismConfig):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = hidden_states + input_tensor
return hidden_states
class VideoPrismLayer(GradientCheckpointingLayer):
"""This corresponds to the EncoderBlock class in the scenic/videoprism implementation."""
def __init__(self, config: VideoPrismVisionConfig | VideoPrismTextConfig):
super().__init__()
self.config = config
self.attention = VideoPrismAttention(config)
self.intermediate = VideoPrismIntermediate(config)
self.output = VideoPrismOutput(config)
self.layernorm_before = VideoPrismLayerNorm(self.config.hidden_size, eps=self.config.layer_norm_eps)
self.layernorm_after = VideoPrismLayerNorm(self.config.hidden_size, eps=self.config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
hidden_states_norm = self.layernorm_before(hidden_states)
attention_output = self.attention(hidden_states_norm, attention_mask, **kwargs)
# first residual connection
hidden_states = attention_output + hidden_states
# in VideoPrism, layernorm is also applied after self-attention
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
# second residual connection is done here
layer_output = self.output(layer_output, hidden_states)
return layer_output
class VideoPrismSpatialEncoder(nn.Module):
def __init__(self, config: VideoPrismVisionConfig):
super().__init__()
self.config = config
self.layer = nn.ModuleList([VideoPrismLayer(config) for _ in range(config.num_spatial_layers)])
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> BaseModelOutput:
for i, layer_module in enumerate(self.layer):
hidden_states = layer_module(hidden_states)
return BaseModelOutput(last_hidden_state=hidden_states)
class VideoPrismTemporalEncoder(nn.Module):
def __init__(self, config: VideoPrismVisionConfig):
super().__init__()
self.config = config
self.layer = nn.ModuleList([VideoPrismLayer(config) for _ in range(config.num_temporal_layers)])
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> BaseModelOutput:
for i, layer_module in enumerate(self.layer):
hidden_states = layer_module(hidden_states)
return BaseModelOutput(last_hidden_state=hidden_states)
class VideoPrismAuxiliaryEncoder(nn.Module):
def __init__(self, config: VideoPrismVisionConfig):
super().__init__()
self.config = config
self.layer = nn.ModuleList([VideoPrismLayer(self.config) for _ in range(config.num_auxiliary_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
**kwargs,
) -> BaseModelOutput:
for i, layer_module in enumerate(self.layer):
hidden_states = layer_module(hidden_states, attention_mask, **kwargs)
return BaseModelOutput(last_hidden_state=hidden_states)
class VideoPrismTextEncoder(nn.Module):
def __init__(self, config: VideoPrismTextConfig):
super().__init__()
self.config = config
self.layer = nn.ModuleList([VideoPrismLayer(config) for _ in range(config.num_text_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
**kwargs,
) -> BaseModelOutput:
for i, layer_module in enumerate(self.layer):
hidden_states = layer_module(hidden_states, attention_mask, **kwargs)
return BaseModelOutput(last_hidden_state=hidden_states)
def variance_scaling_(tensor, mode="fan_in", distribution="normal"):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
if mode == "fan_in":
denom = fan_in
elif mode == "fan_out":
denom = fan_out
elif mode == "fan_avg":
denom = (fan_in + fan_out) / 2
else:
raise ValueError(f"invalid mode {mode}")
variance = 1.0 / denom
if distribution == "truncated_normal":
init.trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
elif distribution == "normal":
init.normal_(tensor, std=math.sqrt(variance))
elif distribution == "uniform":
bound = math.sqrt(3 * variance)
init.uniform_(tensor, -bound, bound)
else:
raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
class VideoPrismPreTrainedModel(PreTrainedModel):
config_class = VideoPrismConfig
config: VideoPrismConfig
base_model_prefix = "videoprism"
main_input_name = "pixel_values_videos"
input_modalities = ("video", "text")
supports_gradient_checkpointing = True
_no_split_modules = [
"VideoPrismSpatialEmbeddings",
"VideoPrismTemporalEmbeddings",
"VideoPrismSpatialEncoder",
"VideoPrismTemporalEncoder",
"VideoPrismAuxiliaryEncoder",
"VideoPrismTextEncoder",
"VideoPrismMultiheadAttentionPoolingHead",
]
_supports_sdpa = True
_supports_flash_attn = True
_supports_attention_backend = True
_supports_flex_attention = True
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Conv3d)):
lecun_normal_(module.weight)
init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
init.zeros_(module.bias)
init.ones_(module.weight)
class VideoPrismVisionModel(VideoPrismPreTrainedModel):
config_class = VideoPrismVisionConfig
config: VideoPrismVisionConfig
def __init__(self, config: VideoPrismVisionConfig):
super().__init__(config)
self.config = config
self.layernorm1 = VideoPrismLayerNorm(self.config.hidden_size, eps=self.config.layer_norm_eps)
self.layernorm2 = VideoPrismLayerNorm(self.config.hidden_size, eps=self.config.layer_norm_eps)
self.spatial_embeddings = VideoPrismSpatialEmbeddings(self.config)
self.temporal_embeddings = VideoPrismTemporalEmbeddings(self.config)
self.spatial_encoder = VideoPrismSpatialEncoder(self.config)
self.temporal_encoder = VideoPrismTemporalEncoder(self.config)
self.post_init()
def get_input_embeddings(self):
return self.spatial_embeddings.patch_embeddings
def forward(
self,
pixel_values_videos: torch.FloatTensor | None = None,
interpolate_pos_encoding: bool | None = False,
**kwargs,
) -> BaseModelOutputWithSpatialAndTemporalStates:
r"""
Args:
pixel_values_videos (`torch.FloatTensor`):
Pixel values of the video frames of shape (batch_size, num_frames, num_channels, height, width).
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate positional encodings to match input size.
Example:
```python
>>> from transformers import VideoPrismVideoProcessor, VideoPrismVisionModel
>>> import torch
>>> processor = VideoPrismVideoProcessor.from_pretrained("google/videoprism")
>>> model = VideoPrismVisionModel.from_pretrained("google/videoprism")
>>> video = "sample_video.mp4"
>>> inputs = processor(videos=video)
>>> with torch.no_grad():
... outputs = model(**inputs)
... features = outputs.last_hidden_state
```
"""
if pixel_values_videos is None:
raise ValueError("You have to specify pixel_values_videos")
input_shape = pixel_values_videos.shape
spatial_embeds = self.spatial_embeddings(pixel_values_videos, interpolate_pos_encoding)
spatial_encoder_outputs: BaseModelOutput = self.spatial_encoder(hidden_states=spatial_embeds, **kwargs)
# shape of spatial_sequence_output is (B * num_frames, num_patches, dim)
spatial_sequence_output = spatial_encoder_outputs.last_hidden_state
features = self.layernorm1(spatial_sequence_output)
temporal_embeds = self.temporal_embeddings(features, input_shape, interpolate_pos_encoding)
temporal_encoder_outputs: BaseModelOutput = self.temporal_encoder(hidden_states=temporal_embeds, **kwargs)
# shape of temporal_sequence_output is (B * num_patches, num_frames, dim)
temporal_sequence_output = temporal_encoder_outputs.last_hidden_state
features = self.layernorm2(temporal_sequence_output)
_, num_frames, dim = features.shape
features = features.view(input_shape[0], -1, num_frames, dim).permute(0, 2, 1, 3).contiguous()
_, num_frames, num_patches, dim = features.shape
features = features.view(input_shape[0], num_frames * num_patches, -1)
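# The final sequence interleaves frames and patches: (B, num_frames * num_patches, hidden_size);
# with the defaults (16 frames, 16 x 16 patches) that is 4096 tokens per video.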
return BaseModelOutputWithSpatialAndTemporalStates(
last_hidden_state=features,
temporal_hidden_state=temporal_sequence_output,
spatial_hidden_state=spatial_sequence_output,
)
class VideoPrismMultiheadAttentionPoolingHead(nn.Module):
def __init__(self, config: VideoPrismVisionConfig):
super().__init__()
self.config = config
self.num_attention_heads = self.config.num_attention_heads
self.attention_head_size = int(self.config.intermediate_size / self.config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.dropout_prob = self.config.attention_probs_dropout_prob
# PerDimScale
self.dim = int(self.config.intermediate_size / self.config.num_attention_heads)
self.per_dim_scale = nn.Parameter(torch.zeros(self.dim))
r_softplus_0 = 1.442695041
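# 1.442695041 = 1 / softplus(0), so with per_dim_scale initialized at zero the query scale starts at 1 / sqrt(dim).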
scale = torch.tensor(r_softplus_0 / (self.dim**0.5))
softplus = nn.functional.softplus(self.per_dim_scale)
scale = scale * softplus
self.register_buffer("scale", scale)
self.pooling_attention_query = nn.Parameter(torch.zeros(1, 1, self.config.hidden_size))
self.query = nn.Linear(self.config.hidden_size, self.config.intermediate_size, bias=self.config.qkv_bias)
self.key = nn.Linear(self.config.hidden_size, self.config.intermediate_size, bias=self.config.qkv_bias)
self.value = nn.Linear(self.config.hidden_size, self.config.intermediate_size, bias=self.config.qkv_bias)
self.projection = nn.Linear(self.config.intermediate_size, self.config.hidden_size, bias=self.config.qkv_bias)
self.layernorm = VideoPrismLayerNorm(self.config.hidden_size, eps=self.config.layer_norm_eps)
self.dim = int(self.config.intermediate_size / self.config.num_attention_heads)
def forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: torch.LongTensor | None = None,
**kwargs,
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
batch_size, seq_length, hidden_size = hidden_states.shape
query = self.pooling_attention_query.expand(batch_size, -1, -1)
query_layer = (
self.query(query).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
)
query_layer = query_layer * self.scale.expand(*query_layer.shape)
key_layer = (
self.key(hidden_states)
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
)
value_layer = (
self.value(hidden_states)
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
context_layer, attention_probs = attention_interface(
self,
query_layer,
key_layer,
value_layer,
attention_mask,
scaling=1.0,
dropout=0.0 if not self.training else self.dropout_prob,
softcap=None,
**kwargs,
)
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.reshape(new_context_layer_shape)
outputs = self.projection(context_layer)
outputs = self.layernorm(outputs)
return (outputs, attention_probs)
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
"""This function is intended to align with the l2norm implementation in the FLA library."""
inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
return x * inv_norm
class VideoPrismTextModel(VideoPrismPreTrainedModel):
config_class = VideoPrismTextConfig
config: VideoPrismTextConfig
def __init__(self, config: VideoPrismTextConfig):
super().__init__(config)
self.config = config
self.text_encoder = VideoPrismTextEncoder(self.config)
self.token_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.cls_emb = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
self.layernorm = VideoPrismLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.normalize = config.apply_l2_norm
self.post_init()
def create_sinusoidal_positions(self, num_pos: int, dim: int) -> torch.Tensor:
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / (dim - 2)))
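# Standard sinusoidal table: frequencies follow 1 / 10000^(2j / (dim - 2)), concatenated as [sin | cos] over positions.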
sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
**kwargs,
) -> BaseModelOutput:
r"""
Args:
input_ids (`torch.Tensor`):
Input token IDs.
attention_mask (`torch.Tensor`, *optional*):
Attention mask to avoid performing attention on padding token indices.
"""
batch_size, seq_length = input_ids.shape
hidden_states = self.token_embeddings(input_ids)
hidden_states = hidden_states * (self.config.hidden_size**0.5)
cls_padding = torch.ones(batch_size, 1, device=input_ids.device)
input_ids = torch.cat((input_ids, cls_padding), dim=1)
attention_mask = torch.cat((attention_mask, cls_padding), dim=1) if attention_mask is not None else None
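# A learned CLS embedding is appended after the positional encoding below, so the mask and cache positions
# cover seq_length + 1 tokens.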
if attention_mask is not None:
attention_mask = create_causal_mask(
config=self.config,
input_embeds=hidden_states,
attention_mask=attention_mask,
cache_position=torch.arange(hidden_states.shape[1] + 1, device=hidden_states.device),
past_key_values=None,
)
features = hidden_states + self.create_sinusoidal_positions(seq_length, self.config.hidden_size).to(hidden_states.device)
cls_emb = self.cls_emb * (self.config.hidden_size**0.5)
cls_emb = cls_emb.expand(features.shape[0], -1, -1)
features = torch.cat((features, cls_emb), dim=1)
text_encoder_output = self.text_encoder(features, attention_mask)
features = text_encoder_output.last_hidden_state
features = self.layernorm(features)
text_embeddings = features[:, -1]
if self.normalize:
text_embeddings = l2norm(text_embeddings, dim=-1)
return BaseModelOutput(
last_hidden_state=text_embeddings,
)
class VideoPrismVideoModel(VideoPrismPreTrainedModel):
config_class = VideoPrismVisionConfig
config: VideoPrismVisionConfig
def __init__(self, config: VideoPrismVisionConfig):
super().__init__(config)
self.config = config
self.backbone = VideoPrismVisionModel(self.config)
self.auxiliary_encoder = VideoPrismAuxiliaryEncoder(self.config)
self.contrastive_vision_pooler = VideoPrismMultiheadAttentionPoolingHead(self.config)
self.normalize = self.config.apply_l2_norm
self.post_init()
def get_input_embeddings(self):
return self.backbone.spatial_embeddings.patch_embeddings
def forward(
self,
pixel_values_videos: torch.FloatTensor,
interpolate_pos_encoding: bool | None = False,
**kwargs,
) -> VideoPrismVideoOutput:
r"""
Args:
pixel_values_videos (`torch.FloatTensor`):
Pixel values of the video frames.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate positional encodings to match input size.
"""
backbone_outputs = self.backbone(
pixel_values_videos=pixel_values_videos, interpolate_pos_encoding=interpolate_pos_encoding, **kwargs
)
video_features = backbone_outputs.last_hidden_state
auxiliary_output = self.auxiliary_encoder(video_features)
auxiliary_output_features = auxiliary_output.last_hidden_state
contrastive_vision_pooler_output = self.contrastive_vision_pooler(auxiliary_output_features, **kwargs)
video_embeddings = contrastive_vision_pooler_output[0]
if self.normalize:
video_embeddings = l2norm(video_embeddings, dim=-1)
return VideoPrismVideoOutput(
video_last_hidden_state=video_embeddings,
auxiliary_output=auxiliary_output,
attention_pooling_output=contrastive_vision_pooler_output,
)
class VideoPrismClipModel(VideoPrismPreTrainedModel):
config_class = VideoPrismConfig
def __init__(self, config: VideoPrismConfig):
super().__init__(config)
self.config = config
self.vision_config = config.vision_config
self.text_config = config.text_config
self.video_model = VideoPrismVideoModel(self.vision_config)
self.text_model = VideoPrismTextModel(self.text_config)
self.post_init()
def forward(
self,
pixel_values_videos: torch.FloatTensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
interpolate_pos_encoding: bool | None = False,
temperature: float | None = None,
**kwargs,
) -> VideoPrismClipOutput:
r"""
Args:
pixel_values_videos (`torch.FloatTensor`):
Pixel values of the video frames.
input_ids (`torch.Tensor`):
Input token IDs for text.
attention_mask (`torch.Tensor`, *optional*):
Attention mask for text inputs.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate positional encodings.
temperature (`float`, *optional*):
Temperature parameter for scaling similarity scores.
Example:
```python
>>> from transformers import VideoPrismProcessor, VideoPrismClipModel
>>> import torch
>>> processor = VideoPrismProcessor.from_pretrained("google/videoprism")
>>> model = VideoPrismClipModel.from_pretrained("google/videoprism")
>>> video = "sample_video.mp4"
>>> texts = ["a dog", "a cat"]
>>> inputs = processor(videos=video, texts=texts, return_tensors="pt", padding=True)
>>> with torch.no_grad():
... outputs = model(**inputs)
... logits_per_video = outputs.logits_per_video
```
"""
video_model_outputs = self.video_model(
pixel_values_videos=pixel_values_videos, interpolate_pos_encoding=interpolate_pos_encoding, **kwargs
)
text_model_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
video_embeddings = video_model_outputs.video_last_hidden_state
text_embeddings = text_model_outputs.last_hidden_state
emb_dim = video_embeddings[0].shape[-1]
assert emb_dim == text_embeddings[0].shape[-1]
video_embeds = video_embeddings.reshape(-1, emb_dim)
text_embeds = text_embeddings.reshape(-1, emb_dim)
similarity_matrix = torch.matmul(video_embeds, text_embeds.T)
if temperature is not None:
similarity_matrix /= temperature
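# exp() plus the dim=0 normalization below amounts to a softmax over the batch: each column of
# logits_per_video is a distribution over videos for one text, and vice versa for logits_per_text.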
logits_per_video = torch.exp(similarity_matrix)
logits_per_text = logits_per_video.T
logits_per_video = logits_per_video / torch.sum(logits_per_video, dim=0, keepdims=True)
logits_per_text = logits_per_text / torch.sum(logits_per_text, dim=0, keepdims=True)
return VideoPrismClipOutput(
logits_per_video=logits_per_video,
logits_per_text=logits_per_text,
video_embeds=video_embeds,
text_embeds=text_embeds,
)
class VideoPrismForVideoClassification(VideoPrismPreTrainedModel):
config_class = VideoPrismVisionConfig
config: VideoPrismVisionConfig
def __init__(self, config: VideoPrismVisionConfig):
super().__init__(config)
self.config = config
self.encoder = VideoPrismVisionModel(self.config)
self.contrastive_vision_pooler = VideoPrismMultiheadAttentionPoolingHead(self.config)
self.classifier = nn.Linear(self.config.hidden_size, self.config.num_labels)
self.post_init()
def get_input_embeddings(self):
return self.encoder.spatial_embeddings.patch_embeddings
def forward(
self,
pixel_values_videos: torch.FloatTensor,
labels: torch.LongTensor | None = None,
interpolate_pos_encoding: bool | None = False,
**kwargs,
) -> ImageClassifierOutput:
r"""
Args:
pixel_values_videos (`torch.FloatTensor`):
Pixel values of the video frames.
labels (`torch.LongTensor`, *optional*):
Video classification labels.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate positional encodings.
Example:
```python
>>> from transformers import VideoPrismVideoProcessor, VideoPrismForVideoClassification
>>> import torch
>>> processor = VideoPrismVideoProcessor("google/videoprism")
>>> model = VideoPrismForVideoClassification.from_pretrained("google/videoprism", num_labels=1000)
>>> video = "sample_video.mp4"
>>> inputs = processor(videos=video, return_tensors="pt")
>>> with torch.no_grad():
... outputs = model(**inputs)
... logits = outputs.logits
```
"""
encoder_outputs = self.encoder(
pixel_values_videos=pixel_values_videos, interpolate_pos_encoding=interpolate_pos_encoding, **kwargs
)
sequence_output = encoder_outputs.last_hidden_state
pooled_output = self.contrastive_vision_pooler(sequence_output, **kwargs)[0]
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
loss = self.loss_function(labels, logits, self.config, **kwargs)
return ImageClassifierOutput(
loss=loss,
logits=logits,
hidden_states=encoder_outputs.last_hidden_state,
)
__all__ = [
"VideoPrismVisionModel",
"VideoPrismPreTrainedModel",
"VideoPrismVideoModel",
"VideoPrismTextModel",
"VideoPrismClipModel",
"VideoPrismForVideoClassification",
]
@@ -1,50 +0,0 @@
import torch
import numpy as np
from torchcodec.decoders import VideoDecoder
from lerobot.policies.videovla.videoprism import VideoPrismVideoProcessor
from lerobot.policies.videovla.videoprism import VideoPrismVisionModel
processor = VideoPrismVideoProcessor.from_pretrained(
"MHRDYN7/videoprism-base-f16r288"
)
model = VideoPrismVisionModel.from_pretrained(
"MHRDYN7/videoprism-base-f16r288",
torch_dtype=torch.float16,
device_map="auto",
attn_implementation="sdpa",
)
video_url = "https://huggingface.co/datasets/nateraw/kinetics-mini/resolve/main/val/archery/-Qz25rXdMjE_000014_000024.mp4"
vr = VideoDecoder(video_url)
frame_idx = np.arange(0, 64)
video = vr.get_frames_at(indices=frame_idx).data # T x C x H x W
video = processor(video, return_tensors="pt")
video = {k: v.to(model.device, model.dtype) for k, v in video.items()}
outputs = model(**video)
encoder_outputs = outputs.last_hidden_state
print(encoder_outputs.shape)
import time
import torch
# warmup
for _ in range(10):
_ = model(**video)
times = []
for _ in range(50):
torch.cuda.synchronize()
t0 = time.perf_counter()
_ = model(**video)
torch.cuda.synchronize()
t1 = time.perf_counter()
times.append(t1 - t0)
print(f"Mean: {1000*sum(times)/len(times):.2f} ms")
print(f"Min : {1000*min(times):.2f} ms")
print(f"Max : {1000*max(times):.2f} ms")
@@ -1,44 +0,0 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/videoprism/modular_videoprism.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_videoprism.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling
from transformers.video_processing_utils import BaseVideoProcessor
class VideoPrismVideoProcessor(BaseVideoProcessor):
r"""
Constructs a VideoPrism video processor.
This processor inherits from [`BaseVideoProcessor`] and sets default parameters for VideoPrism models.
Video frames are resized to 288x288 using bicubic resampling without normalization.
Args:
size (`Dict[str, int]`, *optional*, defaults to `{"height": 288, "width": 288}`):
The size to resize the video frames to.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
The resampling filter to use when resizing images.
do_normalize (`bool`, *optional*, defaults to `False`):
Whether to normalize the video frames.
"""
resample = PILImageResampling.BICUBIC
image_mean = OPENAI_CLIP_MEAN
image_std = OPENAI_CLIP_STD
size = {"height": 288, "width": 288}
rescale_factor = 1 / 255
default_to_square = False
crop_size = None
do_resize = True
do_center_crop = None
do_rescale = True
do_normalize = False
do_convert_rgb = True
do_sample_frames = False # Set to False for BC, recommended to set `True` in new models
__all__ = ["VideoPrismVideoProcessor"]
@@ -0,0 +1,360 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Set up and debug CAN interfaces for Damiao motors (e.g., OpenArms).
Examples:
Setup CAN interfaces with CAN FD:
```shell
lerobot-setup-can --mode=setup --interfaces=can0,can1,can2,can3
```
Test motors on a single interface:
```shell
lerobot-setup-can --mode=test --interfaces=can0
```
Test motors on all interfaces:
```shell
lerobot-setup-can --mode=test --interfaces=can0,can1,can2,can3
```
Speed test:
```shell
lerobot-setup-can --mode=speed --interfaces=can0
```
"""
import subprocess
import sys
import time
from dataclasses import dataclass, field
import draccus
from lerobot.utils.import_utils import is_package_available
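# Default motor ID → joint-name mapping for an OpenArms-style arm (7 joints + gripper).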
MOTOR_NAMES = {
0x01: "joint_1",
0x02: "joint_2",
0x03: "joint_3",
0x04: "joint_4",
0x05: "joint_5",
0x06: "joint_6",
0x07: "joint_7",
0x08: "gripper",
}
@dataclass
class CANSetupConfig:
mode: str = "test"
interfaces: str = "can0" # Comma-separated, e.g. "can0,can1,can2,can3"
bitrate: int = 1000000
data_bitrate: int = 5000000
use_fd: bool = True
motor_ids: list[int] = field(default_factory=lambda: list(range(0x01, 0x09)))
timeout: float = 1.0
speed_iterations: int = 100
def get_interfaces(self) -> list[str]:
return [i.strip() for i in self.interfaces.split(",") if i.strip()]
def check_interface_status(interface: str) -> tuple[bool, str, bool]:
"""Check if CAN interface is UP and configured."""
try:
result = subprocess.run(["ip", "link", "show", interface], capture_output=True, text=True) # nosec B607
if result.returncode != 0:
return False, "Interface not found", False
output = result.stdout
is_up = "UP" in output
is_fd = "fd on" in output.lower() or "canfd" in output.lower()
status = "UP" if is_up else "DOWN"
if is_fd:
status += " (CAN FD)"
return is_up, status, is_fd
except FileNotFoundError:
return False, "ip command not found", False
def setup_interface(interface: str, bitrate: int, data_bitrate: int, use_fd: bool) -> bool:
"""Configure a CAN interface."""
try:
subprocess.run(["sudo", "ip", "link", "set", interface, "down"], check=False, capture_output=True) # nosec B607
cmd = ["sudo", "ip", "link", "set", interface, "type", "can", "bitrate", str(bitrate)]
if use_fd:
cmd.extend(["dbitrate", str(data_bitrate), "fd", "on"])
result = subprocess.run(cmd, capture_output=True, text=True) # nosec B607
if result.returncode != 0:
print(f" ✗ Failed to configure: {result.stderr}")
return False
result = subprocess.run( # nosec B607
["sudo", "ip", "link", "set", interface, "up"], capture_output=True, text=True
)
if result.returncode != 0:
print(f" ✗ Failed to bring up: {result.stderr}")
return False
return True
except Exception as e:
print(f" ✗ Error: {e}")
return False
def test_motor(bus, motor_id: int, timeout: float, use_fd: bool):
"""Test a single motor and return responses."""
import can
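# Damiao command frames are eight bytes: seven 0xFF bytes followed by 0xFC
# (enable torque) or 0xFD (disable torque), sent to the motor's arbitration ID.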
enable_msg = can.Message(
arbitration_id=motor_id,
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC],
is_extended_id=False,
is_fd=use_fd,
)
try:
bus.send(enable_msg)
except Exception as e:
return None, f"Send error: {e}"
responses = []
start_time = time.time()
while time.time() - start_time < timeout:
msg = bus.recv(timeout=0.1)
if msg:
responses.append((msg.arbitration_id, msg.data.hex(), getattr(msg, "is_fd", False)))
disable_msg = can.Message(
arbitration_id=motor_id,
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFD],
is_extended_id=False,
is_fd=use_fd,
)
try:
bus.send(disable_msg)
except Exception:
print(f"Error sending message to motor 0x{motor_id:02X}")
return responses, None
def test_interface(cfg: CANSetupConfig, interface: str):
"""Test all motors on a CAN interface."""
import can
is_up, status, _ = check_interface_status(interface)
print(f"\n{interface}: {status}")
if not is_up:
print(f" ⚠ Interface is not UP. Run: lerobot-setup-can --mode=setup --interfaces {interface}")
return {}
try:
kwargs = {"channel": interface, "interface": "socketcan", "bitrate": cfg.bitrate}
if cfg.use_fd:
kwargs.update({"data_bitrate": cfg.data_bitrate, "fd": True})
bus = can.interface.Bus(**kwargs)
except Exception as e:
print(f" ✗ Connection failed: {e}")
return {}
results = {}
try:
while bus.recv(timeout=0.01):
pass
for motor_id in cfg.motor_ids:
motor_name = MOTOR_NAMES.get(motor_id, f"motor_0x{motor_id:02X}")
responses, error = test_motor(bus, motor_id, cfg.timeout, cfg.use_fd)
if error:
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ {error}")
results[motor_id] = {"found": False, "error": error}
elif responses:
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✓ FOUND")
for resp_id, data, is_fd in responses:
fd_flag = " [FD]" if is_fd else ""
print(f" → Response 0x{resp_id:02X}{fd_flag}: {data}")
results[motor_id] = {"found": True, "responses": responses}
else:
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ No response")
results[motor_id] = {"found": False}
time.sleep(0.05)
finally:
bus.shutdown()
found = sum(1 for r in results.values() if r.get("found"))
print(f"\n Summary: {found}/{len(cfg.motor_ids)} motors found")
return results
def speed_test(cfg: CANSetupConfig, interface: str):
"""Test communication speed with motors."""
import can
is_up, status, _ = check_interface_status(interface)
if not is_up:
print(f"{interface}: {status} - skipping")
return
print(f"\n{interface}: Running speed test ({cfg.speed_iterations} iterations)...")
try:
kwargs = {"channel": interface, "interface": "socketcan", "bitrate": cfg.bitrate}
if cfg.use_fd:
kwargs.update({"data_bitrate": cfg.data_bitrate, "fd": True})
bus = can.interface.Bus(**kwargs)
except Exception as e:
print(f" ✗ Connection failed: {e}")
return
responding_motor = None
for motor_id in cfg.motor_ids:
responses, _ = test_motor(bus, motor_id, 0.5, cfg.use_fd)
if responses:
responding_motor = motor_id
break
if responding_motor is None:
print(" ✗ No responding motors found")
bus.shutdown()
return
print(f" Testing with motor 0x{responding_motor:02X}...")
latencies = []
for _ in range(cfg.speed_iterations):
start = time.perf_counter()
msg = can.Message(
arbitration_id=responding_motor,
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC],
is_extended_id=False,
is_fd=cfg.use_fd,
)
bus.send(msg)
resp = bus.recv(timeout=0.1)
if resp:
latencies.append((time.perf_counter() - start) * 1000)
bus.shutdown()
if latencies:
avg_latency = sum(latencies) / len(latencies)
hz = 1000.0 / avg_latency if avg_latency > 0 else 0
print(f" ✓ Success rate: {len(latencies)}/{cfg.speed_iterations}")
print(f" ✓ Avg latency: {avg_latency:.2f} ms")
print(f" ✓ Max frequency: {hz:.1f} Hz")
else:
print(" ✗ No successful responses")
def run_setup(cfg: CANSetupConfig):
"""Setup CAN interfaces."""
print("=" * 50)
print("CAN Interface Setup")
print("=" * 50)
print(f"Mode: {'CAN FD' if cfg.use_fd else 'CAN 2.0'}")
print(f"Bitrate: {cfg.bitrate / 1_000_000:.1f} Mbps")
if cfg.use_fd:
print(f"Data bitrate: {cfg.data_bitrate / 1_000_000:.1f} Mbps")
print()
interfaces = cfg.get_interfaces()
for interface in interfaces:
print(f"Configuring {interface}...")
if setup_interface(interface, cfg.bitrate, cfg.data_bitrate, cfg.use_fd):
is_up, status, _ = check_interface_status(interface)
print(f"{interface}: {status}")
else:
print(f"{interface}: Failed")
print("\nSetup complete!")
print("\nNext: Test motors with:")
print(f" lerobot-setup-can --mode=test --interfaces {','.join(interfaces)}")
def run_test(cfg: CANSetupConfig):
"""Test motors on CAN interfaces."""
print("=" * 50)
print("CAN Motor Test")
print("=" * 50)
print(f"Testing motors 0x{min(cfg.motor_ids):02X}-0x{max(cfg.motor_ids):02X}")
print(f"Mode: {'CAN FD' if cfg.use_fd else 'CAN 2.0'}")
print()
interfaces = cfg.get_interfaces()
all_results = {}
for interface in interfaces:
all_results[interface] = test_interface(cfg, interface)
total_found = sum(sum(1 for r in res.values() if r.get("found")) for res in all_results.values())
print("\n" + "=" * 50)
print("Summary")
print("=" * 50)
print(f"Total motors found: {total_found}")
if total_found == 0:
print("\n⚠ No motors found! Check:")
print(" 1. Motors are powered (24V)")
print(" 2. CAN wiring (CANH, CANL, GND)")
print(" 3. Motor timeout parameter > 0 (use Damiao tools)")
print(" 4. 120Ω termination at both cable ends")
print(f" 5. Interface configured: lerobot-setup-can --mode=setup --interfaces {interfaces[0]}")
def run_speed(cfg: CANSetupConfig):
"""Run speed tests on CAN interfaces."""
print("=" * 50)
print("CAN Speed Test")
print("=" * 50)
for interface in cfg.get_interfaces():
speed_test(cfg, interface)
@draccus.wrap()
def setup_can(cfg: CANSetupConfig):
if not is_package_available("can"):
print("Error: python-can not installed. Install with: pip install python-can")
sys.exit(1)
if cfg.mode == "setup":
run_setup(cfg)
elif cfg.mode == "test":
run_test(cfg)
elif cfg.mode == "speed":
run_speed(cfg)
else:
print(f"Unknown mode: {cfg.mode}")
print("Available modes: setup, test, speed")
sys.exit(1)
def main():
setup_can()
if __name__ == "__main__":
main()
@@ -27,4 +27,4 @@ class OmxLeaderConfig(TeleoperatorConfig):
# Sets the arm in torque mode with the gripper motor set to this value. This makes it possible to squeeze
# the gripper and have it spring back to an open position on its own.
-gripper_open_pos: float = 37.0
+gripper_open_pos: float = 60.0
@@ -103,7 +103,7 @@ class OmxLeader(Teleoperator):
self.calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=drive_modes[motor],
-homing_offset=0,
+homing_offset=0 if motor != "gripper" else 100,
range_min=0,
range_max=4095,
)
@@ -123,12 +123,20 @@ class OmxLeader(Teleoperator):
# point
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
if motor == "gripper":
self.bus.write("Drive_Mode", motor, DriveMode.INVERTED.value)
else:
self.bus.write("Drive_Mode", motor, DriveMode.NON_INVERTED.value)
# Use 'position control current based' for the gripper so its motion is limited by the current limit.
# For the follower gripper, this means it can grasp an object without applying excessive force, even though
# its goal position is a full grasp (both fingers are commanded to close until they touch).
# For the leader gripper, this means it can act as a physical trigger: we can push it with a finger
# to make it move, and it springs back to its target position once released.
self.bus.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value)
self.bus.write("Current_Limit", "gripper", 100)
self.bus.write("Goal_Current", "gripper", 100)
self.bus.write("Homing_Offset", "gripper", 100)
# Set gripper's goal pos in current position mode so that we can use it as a trigger.
self.bus.enable_torque("gripper")
if self.is_calibrated:
@@ -73,6 +73,7 @@ _transformers_available = is_package_available("transformers")
_peft_available = is_package_available("peft")
_scipy_available = is_package_available("scipy")
_reachy2_sdk_available = is_package_available("reachy2_sdk")
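# The distribution is published on PyPI as "python-can" but imports as "can".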
_can_available = is_package_available("python-can", "can")
def make_device_from_device_class(config: ChoiceRegistry) -> Any:
@@ -0,0 +1,66 @@
"""Minimal test script for Damiao motor with ID 3."""
import pytest
from lerobot.utils.import_utils import _can_available
if not _can_available:
pytest.skip("python-can not available", allow_module_level=True)
from lerobot.motors import Motor
from lerobot.motors.damiao import DamiaoMotorsBus
@pytest.mark.skip(reason="Requires physical Damiao motor and CAN interface")
def test_damiao_motor():
motors = {
"joint_3": Motor(
id=0x03,
model="damiao",
norm_mode="degrees",
motor_type_str="dm4310",
recv_id=0x13,
),
}
bus = DamiaoMotorsBus(port="can0", motors=motors)
try:
print("Connecting...")
bus.connect()
print("✓ Connected")
print("Enabling torque...")
bus.enable_torque()
print("✓ Torque enabled")
print("Reading all states...")
states = bus.sync_read_all_states()
print(f"✓ States: {states}")
print("Reading position...")
positions = bus.sync_read("Present_Position")
print(f"✓ Position: {positions}")
print("Testing MIT control batch...")
current_pos = states["joint_3"]["position"]
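# The MIT-control command tuple appears to be (kp, kd, position, velocity, torque):
# modest gains holding the present position with zero velocity and feed-forward torque.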
commands = {"joint_3": (10.0, 0.5, current_pos, 0.0, 0.0)}
bus._mit_control_batch(commands)
print("✓ MIT control batch sent")
print("Disabling torque...")
bus.disable_torque()
print("✓ Torque disabled")
print("Setting zero position...")
bus.set_zero_position()
print("✓ Zero position set")
finally:
print("Disconnecting...")
bus.disconnect(disable_torque=True)
print("✓ Disconnected")
if __name__ == "__main__":
test_damiao_motor()
@@ -0,0 +1,50 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for training-time RTC helpers."""
import torch
from lerobot.configs.types import RTCTrainingDelayDistribution
from lerobot.policies.rtc.configuration_rtc import RTCTrainingConfig
from lerobot.policies.rtc.training_time import apply_rtc_training_time, sample_rtc_delay
def test_rtc_training_config_defaults():
config = RTCTrainingConfig()
assert config.enabled is False
assert config.min_delay == 0
assert config.max_delay == 0
assert config.delay_distribution == RTCTrainingDelayDistribution.UNIFORM
assert config.exp_decay == 1.0
def test_sample_rtc_delay_uniform_range():
cfg = RTCTrainingConfig(enabled=True, min_delay=1, max_delay=4)
delays = sample_rtc_delay(cfg, batch_size=100, device=torch.device("cpu"))
assert delays.min().item() >= 1
assert delays.max().item() <= 4
def test_apply_rtc_training_time_prefix_mask():
time = torch.tensor([0.5])
delays = torch.tensor([2])
time_tokens, postfix_mask = apply_rtc_training_time(time, delays, seq_len=4)
assert time_tokens.shape == (1, 4)
assert postfix_mask.shape == (1, 4)
# Delay=2 means the first two steps are prefix (time forced to 0.0) and only the last two are postfix.
assert torch.allclose(time_tokens[0], torch.tensor([0.0, 0.0, 0.5, 0.5]))
assert torch.equal(postfix_mask[0], torch.tensor([False, False, True, True]))
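To see how these helpers could slot into a flow-matching training step, here is a hedged sketch: the helper calls mirror the tests above, while the loss masking at the end is an assumption about how a policy would typically consume the postfix mask (the actual pi0/pi0.5/SmolVLA integration is not shown here):
```python
import torch

from lerobot.policies.rtc.configuration_rtc import RTCTrainingConfig
from lerobot.policies.rtc.training_time import apply_rtc_training_time, sample_rtc_delay

cfg = RTCTrainingConfig(enabled=True, min_delay=1, max_delay=4)
batch_size, seq_len = 8, 16

time = torch.rand(batch_size)  # per-sample flow-matching time
delays = sample_rtc_delay(cfg, batch_size=batch_size, device=time.device)
time_tokens, postfix_mask = apply_rtc_training_time(time, delays, seq_len=seq_len)

# time_tokens: (B, seq_len); the first `delay` steps of each sample are forced to 0.0
# (the already-executed prefix), while the remaining steps keep the sampled time.
# postfix_mask: (B, seq_len) bool; True where the policy is still free to re-plan.
per_step_loss = torch.rand(batch_size, seq_len)  # placeholder for the real denoising loss
loss = (per_step_loss * postfix_mask).sum() / postfix_mask.sum().clamp(min=1)
```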