mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-12 07:09:43 +00:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f147a4cd48 | |||
| c3fa269b21 | |||
| f6b1c39b78 | |||
| 0c0c171d35 | |||
| 9cfb5ce546 | |||
| 366bef915c | |||
| 9e10eb4a77 | |||
| 385ba8d1b7 | |||
| f4ccf911fa | |||
| 0cb8c92fe4 |
@@ -57,6 +57,8 @@
|
||||
title: Use Async Inference
|
||||
- local: rtc
|
||||
title: Real-Time Chunking (RTC)
|
||||
- local: training_time_rtc
|
||||
title: Training-Time RTC
|
||||
title: "Inference"
|
||||
- sections:
|
||||
- local: envhub
|
||||
@@ -99,6 +101,8 @@
|
||||
title: Unitree G1
|
||||
- local: earthrover_mini_plus
|
||||
title: Earth Rover Mini
|
||||
- local: omx
|
||||
title: OMX
|
||||
title: "Robots"
|
||||
- sections:
|
||||
- local: phone_teleop
|
||||
@@ -113,6 +117,8 @@
|
||||
title: Notebooks
|
||||
- local: feetech
|
||||
title: Updating Feetech Firmware
|
||||
- local: damiao
|
||||
title: Damiao Motors and CAN Bus
|
||||
title: "Resources"
|
||||
- sections:
|
||||
- local: contributing
|
||||
|
||||
@@ -0,0 +1,165 @@
|
||||
# Damiao Motors and CAN Bus
|
||||
|
||||
This guide covers setup and usage of Damiao motors with LeRobot via CAN bus communication.
|
||||
|
||||
Currently, only Linux is supported, as the OpenArms CAN adapter only has drivers for Linux.
|
||||
|
||||
## Linux CAN Setup
|
||||
|
||||
Before using Damiao motors, you need to set up the CAN interface on your Linux system.
|
||||
|
||||
### Install CAN Utilities
|
||||
|
||||
```bash
|
||||
sudo apt-get install can-utils
|
||||
```
|
||||
|
||||
### Configure CAN Interface (Manual)
|
||||
|
||||
For standard CAN FD (recommended for OpenArms):
|
||||
|
||||
```bash
|
||||
sudo ip link set can0 down
|
||||
sudo ip link set can0 type can bitrate 1000000 dbitrate 5000000 fd on
|
||||
sudo ip link set can0 up
|
||||
```
|
||||
|
||||
For standard CAN (without FD):
|
||||
|
||||
```bash
|
||||
sudo ip link set can0 down
|
||||
sudo ip link set can0 type can bitrate 1000000
|
||||
sudo ip link set can0 up
|
||||
```
|
||||
|
||||
### Configure CAN Interface (Using LeRobot)
|
||||
|
||||
LeRobot provides a utility script to setup and test CAN interfaces:
|
||||
|
||||
```bash
|
||||
# Setup multiple interfaces (e.g., OpenArms Followers with 2 CAN buses)
|
||||
lerobot-setup-can --mode=setup --interfaces=can0,can1
|
||||
```
|
||||
|
||||
## Debugging CAN Communication
|
||||
|
||||
Use the built-in debug tools to test motor communication:
|
||||
|
||||
```bash
|
||||
# Test motors on all interfaces
|
||||
lerobot-setup-can --mode=test --interfaces=can0,can1
|
||||
|
||||
# Run speed/latency test
|
||||
lerobot-setup-can --mode=speed --interfaces=can0
|
||||
```
|
||||
|
||||
The test mode will scan for motors (IDs 0x01-0x08) and report which ones respond. Example output:
|
||||
|
||||
```
|
||||
can0: UP (CAN FD)
|
||||
Motor 0x01 (joint_1): ✓ FOUND
|
||||
→ Response 0x11 [FD]: 00112233...
|
||||
Motor 0x02 (joint_2): ✓ FOUND
|
||||
Motor 0x03 (joint_3): ✗ No response
|
||||
...
|
||||
Summary: 2/8 motors found
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Setup
|
||||
|
||||
```python
|
||||
from lerobot.motors import Motor
|
||||
from lerobot.motors.damiao import DamiaoMotorsBus
|
||||
|
||||
# Define your motors with send/receive CAN IDs
|
||||
motors = {
|
||||
"joint_1": Motor(id=0x01, motor_type_str="dm8009", recv_id=0x11),
|
||||
"joint_2": Motor(id=0x02, motor_type_str="dm4340", recv_id=0x12),
|
||||
"joint_3": Motor(id=0x03, motor_type_str="dm4310", recv_id=0x13),
|
||||
}
|
||||
|
||||
# Create the bus
|
||||
bus = DamiaoMotorsBus(
|
||||
port="can0", # Linux socketcan interface
|
||||
motors=motors,
|
||||
)
|
||||
|
||||
# Connect
|
||||
bus.connect()
|
||||
```
|
||||
|
||||
### Reading Motor States
|
||||
|
||||
```python
|
||||
# Read single motor position (degrees)
|
||||
position = bus.read("Present_Position", "joint_1")
|
||||
|
||||
# Read from multiple motors
|
||||
positions = bus.sync_read("Present_Position") # All motors
|
||||
positions = bus.sync_read("Present_Position", ["joint_1", "joint_2"])
|
||||
|
||||
# Read all states at once (position, velocity, torque)
|
||||
states = bus.sync_read_all_states()
|
||||
# Returns: {'joint_1': {'position': 45.2, 'velocity': 1.3, 'torque': 0.5}, ...}
|
||||
```
|
||||
|
||||
### Writing Motor Commands
|
||||
|
||||
```python
|
||||
# Enable torque
|
||||
bus.enable_torque()
|
||||
|
||||
# Set goal position (degrees)
|
||||
bus.write("Goal_Position", "joint_1", 45.0)
|
||||
|
||||
# Set positions for multiple motors
|
||||
bus.sync_write("Goal_Position", {
|
||||
"joint_1": 45.0,
|
||||
"joint_2": -30.0,
|
||||
"joint_3": 90.0,
|
||||
})
|
||||
|
||||
# Disable torque
|
||||
bus.disable_torque()
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| -------------- | --------- | ----------------------------------------------------------- |
|
||||
| `port` | - | CAN interface (`can0`) or serial port (`/dev/cu.usbmodem*`) |
|
||||
| `use_can_fd` | `True` | Enable CAN FD for higher data rates |
|
||||
| `bitrate` | `1000000` | Nominal bitrate (1 Mbps) |
|
||||
| `data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) |
|
||||
|
||||
## Motor Configuration
|
||||
|
||||
Each motor requires:
|
||||
|
||||
- `id`: CAN ID for sending commands
|
||||
- `motor_type`: One of the supported motor types (e.g., `"dm8009"`, `"dm4340"`)
|
||||
- `recv_id`: CAN ID for receiving responses
|
||||
|
||||
OpenArms default IDs follow the pattern: send ID `0x0N`, receive ID `0x1N` where N is the joint number.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### No Response from Motors
|
||||
|
||||
1. **Check power**
|
||||
2. **Verify CAN wiring**: Check CAN-H, CAN-L, and GND connections
|
||||
3. **Check motor IDs**: Use Damiao Debugging Tools to verify/configure IDs
|
||||
4. **Test CAN interface**: Run `candump can0` to see if messages are being received
|
||||
5. **Run diagnostics**: `lerobot-setup-can --mode=test --interfaces=can0`
|
||||
|
||||
### Motor Timeout Parameter
|
||||
|
||||
If motors were configured with timeout=0, they won't respond to commands. Use Damiao Debugging Tools to set a non-zero timeout value.
|
||||
|
||||
### Verify CAN FD Status
|
||||
|
||||
```bash
|
||||
ip -d link show can0 | grep fd
|
||||
```
|
||||
@@ -1,5 +1,11 @@
|
||||
# EarthRover Mini Plus
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Earth_Rover_Mini_5_240c9adc-4f9e-44b7-982f-5d1dc24af1d8.png.webp"
|
||||
alt="EarthRover Mini Plus"
|
||||
width="70%"
|
||||
/>
|
||||
|
||||
The EarthRover Mini Plus is a fully open source mobile robot that connects through the cloud using the Frodobots SDK. This lets you control the robot and record datasets for training AI models.
|
||||
|
||||
## What You Need
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
# LeKiwi
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/1740517739083.jpeg"
|
||||
alt="LeKiwi"
|
||||
width="70%"
|
||||
/>
|
||||
|
||||
In the steps below, we explain how to assemble the LeKiwi mobile robot.
|
||||
|
||||
## Source the parts
|
||||
|
||||
@@ -42,6 +42,7 @@ lerobot-eval \
|
||||
```
|
||||
|
||||
- `--env.task` picks the suite (`libero_object`, `libero_spatial`, etc.).
|
||||
- `--env.task_ids` picks task ids to run (`[0]`, `[1,2,3]`, etc.). Omit this flag (or set it to `null`) to run all tasks in the suite.
|
||||
- `--eval.batch_size` controls how many environments run in parallel.
|
||||
- `--eval.n_episodes` sets how many episodes to run in total.
|
||||
|
||||
|
||||
@@ -0,0 +1,197 @@
|
||||
## Order and Assemble the parts
|
||||
|
||||
First, assemble the OMX hardware following the official assembly guide.
|
||||
|
||||
OMX Assembly Guide: https://ai.robotis.com/omx/assembly_guide_omx.html
|
||||
|
||||
OMX robots are shipped preconfigured from the factory. Motor IDs, communication parameters, and joint offsets are already set, so no additional motor setup or calibration is required before using LeRobot.
|
||||
|
||||
## Install LeRobot 🤗
|
||||
|
||||
To install LeRobot, follow our [Installation Guide](./installation)
|
||||
|
||||
In addition to these instructions, you need to install the Dynamixel SDK:
|
||||
|
||||
```bash
|
||||
pip install -e ".[dynamixel]"
|
||||
```
|
||||
|
||||
## Connect the robot
|
||||
|
||||
To find the port for each bus servo adapter, run this script:
|
||||
|
||||
```bash
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
This command runs and when prompted, disconnect the USB cable from either the leader or follower arm and press Enter. The output will show 'The port of this MotorsBus is [port]'. This identifies the port for the disconnected arm. Repeat for the other arm to identify both ports.
|
||||
|
||||
<hfoptions id="find_port">
|
||||
<hfoption id="Mac">
|
||||
|
||||
Example output on macOS:
|
||||
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the USB cable from your MotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect corresponding leader or follower arm and press Enter...]
|
||||
|
||||
The port of this MotorsBus is /dev/tty.usbmodem575E0032081
|
||||
Reconnect the USB cable.
|
||||
```
|
||||
|
||||
Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your leader or follower arm.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Linux">
|
||||
|
||||
On Linux, we strongly recommend using udev rules to assign persistent and human-readable device names to the OMX leader and follower arms. This avoids issues where device names such as ttyACM0 and ttyACM1 change when the robot is unplugged, replugged, or when the system is rebooted.
|
||||
|
||||
#### 1. Find your device serial numbers
|
||||
|
||||
You should have obtained the port numbers like ../../ttyACM? for the leader and follower using `lerobot-find-port`. You can match those results with the serial numbers using the `ls -l /dev/serial/by-id/` command.
|
||||
To create udev rules, you need the unique serial number for each OMX device. The easiest way is to list devices under:
|
||||
|
||||
```bash
|
||||
ls -l /dev/serial/by-id/
|
||||
```
|
||||
|
||||
You will see output similar to:
|
||||
|
||||
```bash
|
||||
usb-ROBOTIS_OpenRB-150_228BDD7B503059384C2E3120FF0A2B19-if00 -> ../../ttyACM0
|
||||
usb-ROBOTIS_OpenRB-150_67E1ED68503059384C2E3120FF092234-if00 -> ../../ttyACM1
|
||||
```
|
||||
|
||||
In each line, the serial number is the long string after `usb-ROBOTIS_OpenRB-150_` and before `-if00`.
|
||||
|
||||
Follower serial: `228BDD7B503059384C2E3120FF0A2B19`
|
||||
|
||||
Leader serial: `67E1ED68503059384C2E3120FF092234`
|
||||
|
||||
#### 2. Create the udev rule
|
||||
|
||||
Create a new udev rule file:
|
||||
|
||||
```bash
|
||||
sudo nano /etc/udev/rules.d/99-omx.rules
|
||||
```
|
||||
|
||||
Paste the following lines, replacing the serial numbers with the values you found above:
|
||||
|
||||
```bash
|
||||
SUBSYSTEM=="tty", ATTRS{idVendor}=="0403", ATTRS{serial}=="228BDD7B503059384C2E3120FF0A2B19", SYMLINK+="omx_follower"
|
||||
SUBSYSTEM=="tty", ATTRS{idVendor}=="0403", ATTRS{serial}=="67E1ED68503059384C2E3120FF092234", SYMLINK+="omx_leader"
|
||||
```
|
||||
|
||||
Save the file and reload udev rules:
|
||||
|
||||
```bash
|
||||
sudo udevadm control --reload-rules
|
||||
sudo udevadm trigger
|
||||
```
|
||||
|
||||
Now unplug and replug both devices once.
|
||||
|
||||
#### 3. Verify the symlinks
|
||||
|
||||
Check that the persistent device names exist:
|
||||
|
||||
```bash
|
||||
ls -l /dev/omx_follower /dev/omx_leader
|
||||
```
|
||||
|
||||
You should see them pointing to ttyACM\* devices:
|
||||
|
||||
```bash
|
||||
/dev/omx_follower -> ttyACM*
|
||||
/dev/omx_leader -> ttyACM*
|
||||
```
|
||||
|
||||
These names remain stable across reboots and reconnections.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Teleoperate
|
||||
|
||||
After identifying the correct ports, you can directly teleoperate the follower arm using the leader arm.
|
||||
|
||||
<hfoptions id="teleoperate">
|
||||
<hfoption id="Mac">
|
||||
|
||||
### Teleoperate without camera
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=<your_follower_port> \
|
||||
--robot.id=omx_follower_arm \
|
||||
--teleop.type=omx_leader \
|
||||
--teleop.port=<your_leader_port> \
|
||||
--teleop.id=omx_leader_arm
|
||||
```
|
||||
|
||||
During teleoperation, motions of the leader arm are mirrored in real time by the follower arm. OMX is already preconfigured, teleoperation can begin immediately without any calibration steps.
|
||||
|
||||
### Teleoperate with camera
|
||||
|
||||
You can also enable camera input during teleoperation by providing a camera configuration for the follower arm.
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=<your_follower_port> \
|
||||
--robot.id=omx_follower_arm \
|
||||
--robot.cameras="{front: {type: opencv, index_or_path: '/dev/video0', width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=omx_leader \
|
||||
--teleop.port=<your_leader_port> \
|
||||
--teleop.id=omx_leader_arm \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
When the camera is enabled, the camera stream is displayed in real time and synchronized with the robot state. This setup is useful for visual monitoring and can be reused later for demonstration recording and imitation learning.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Linux">
|
||||
|
||||
### Teleoperate without camera
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=/dev/omx_follower \
|
||||
--robot.id=omx_follower_arm \
|
||||
--teleop.type=omx_leader \
|
||||
--teleop.port=/dev/omx_leader \
|
||||
--teleop.id=omx_leader_arm
|
||||
```
|
||||
|
||||
During teleoperation, motions of the leader arm are mirrored in real time by the follower arm. OMX is already preconfigured, teleoperation can begin immediately without any calibration steps.
|
||||
|
||||
### Teleoperate with camera
|
||||
|
||||
You can also enable camera input during teleoperation by providing a camera configuration for the follower arm.
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=/dev/omx_follower \
|
||||
--robot.id=omx_follower_arm \
|
||||
--robot.cameras="{front: {type: opencv, index_or_path: '/dev/video0', width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=omx_leader \
|
||||
--teleop.port=/dev/omx_leader \
|
||||
--teleop.id=omx_leader_arm \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
When the camera is enabled, the camera stream is displayed in real time and synchronized with the robot state. This setup is useful for visual monitoring and can be reused later for demonstration recording and imitation learning.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Congrats 🎉, your robot is all set to learn a task on its own.
|
||||
|
||||
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/robotis).
|
||||
@@ -1,5 +1,18 @@
|
||||
# SO-101
|
||||
|
||||
<div style="display: flex; align-items: center; gap: 10px;">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/SO101_Follower.webp"
|
||||
alt="SO-101"
|
||||
width="60%"
|
||||
/>
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/SO101_Leader.webp"
|
||||
alt="SO-101"
|
||||
width="60%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
In the steps below, we explain how to assemble our flagship robot, the SO-101.
|
||||
|
||||
## Source the parts
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
# Training-Time RTC
|
||||
|
||||
Training-Time RTC teaches the model to handle inference delay during training.
|
||||
It feeds the **ground-truth action prefix** to the model and trains only on the remaining postfix actions.
|
||||
This keeps chunk transitions smooth without doing any inference-time inpainting.
|
||||
|
||||
Based on: [Training-Time Action Conditioning for Efficient Real-Time Chunking](https://arxiv.org/abs/2512.05964).
|
||||
|
||||
LeRobot supports this for `pi0`, `pi05` and `smolvla` without changing model parameters.
|
||||
|
||||
---
|
||||
|
||||
## How It Works
|
||||
|
||||
### At Training Time
|
||||
|
||||
- Sample a delay `d` per batch element.
|
||||
- Keep the first `d` action steps as **ground truth** (no noise).
|
||||
- Add noise only to the postfix actions.
|
||||
- Set the flow-matching timestep to **1.0** for prefix tokens and normal timesteps for postfix tokens.
|
||||
- Mask the loss to only train on the postfix.
|
||||
|
||||
### At Inference Time
|
||||
|
||||
When `rtc_training_config.enabled=true`, the model uses training-time RTC inference:
|
||||
|
||||
- Replace prefix positions in `x_t` with previous chunk's leftover actions.
|
||||
- Set timestep to **1.0** for prefix positions.
|
||||
|
||||
---
|
||||
|
||||
## Quick Start (CLI)
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.type=pi0 \
|
||||
--dataset.repo_id=your/dataset \
|
||||
--policy.rtc_training_config.enabled=true \
|
||||
--policy.rtc_training_config.min_delay=0 \
|
||||
--policy.rtc_training_config.max_delay=6 \
|
||||
--policy.rtc_training_config.delay_distribution=UNIFORM
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Inference with Training-Time RTC
|
||||
|
||||
After training with `rtc_training_config`, use the same config at inference. The model will automatically use training-time RTC inference:
|
||||
|
||||
```python
|
||||
policy = PI0Policy.from_pretrained("path/to/trained/model")
|
||||
# rtc_training_config is loaded from the saved config
|
||||
|
||||
actions = policy.predict_action_chunk(
|
||||
batch,
|
||||
inference_delay=5, # estimated delay in timesteps
|
||||
prev_chunk_left_over=previous_actions, # from previous chunk
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Key Parameters
|
||||
|
||||
`RTCTrainingConfig` is available on the policy config (`pi0`, `pi05`, `smolvla`, `xvla`):
|
||||
|
||||
- **`enabled`**: Toggle training-time RTC (both training and inference).
|
||||
- **`min_delay` / `max_delay`**: Delay range (inclusive).
|
||||
- **`delay_distribution`**:
|
||||
- `UNIFORM`: uniform in `[min_delay, max_delay]`
|
||||
- `EXP`: exponentially decayed distribution over delays
|
||||
- **`exp_decay`**: Exponential decay factor for `EXP` sampling.
|
||||
|
||||
---
|
||||
|
||||
## Notes and Recommendations
|
||||
|
||||
- Start with `min_delay=0` and `max_delay` around your expected worst-case inference delay.
|
||||
- Use `EXP` if you want more supervision on smaller delays.
|
||||
|
||||
---
|
||||
|
||||
## Related Docs
|
||||
|
||||
- [Real-Time Chunking (Inference-Time RTC)](./rtc)
|
||||
- [Pi0](./pi0), [Pi0.5](./pi05), [SmolVLA](./smolvla)
|
||||
@@ -102,6 +102,7 @@ grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
# Motors
|
||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
|
||||
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
|
||||
damiao = ["python-can>=4.2.0,<5.0.0"]
|
||||
|
||||
# Robots
|
||||
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
|
||||
@@ -203,6 +204,7 @@ lerobot-info="lerobot.scripts.lerobot_info:main"
|
||||
lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
|
||||
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
[tool.setuptools.packages.find]
|
||||
@@ -278,6 +280,7 @@ default.extend-ignore-identifiers-re = [
|
||||
"thw",
|
||||
"inpt",
|
||||
"ROBOTIS",
|
||||
"OT_VALUE"
|
||||
]
|
||||
|
||||
# TODO: Uncomment when ready to use
|
||||
|
||||
@@ -50,3 +50,8 @@ class RTCAttentionSchedule(str, Enum):
|
||||
ONES = "ONES"
|
||||
LINEAR = "LINEAR"
|
||||
EXP = "EXP"
|
||||
|
||||
|
||||
class RTCTrainingDelayDistribution(str, Enum):
|
||||
UNIFORM = "UNIFORM"
|
||||
EXP = "EXP"
|
||||
|
||||
@@ -260,6 +260,7 @@ class HILSerlRobotEnvConfig(EnvConfig):
|
||||
@dataclass
|
||||
class LiberoEnv(EnvConfig):
|
||||
task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
|
||||
task_ids: list[int] | None = None
|
||||
fps: int = 30
|
||||
episode_length: int | None = None
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
@@ -338,10 +339,10 @@ class LiberoEnv(EnvConfig):
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
}
|
||||
kwargs: dict[str, Any] = {"obs_type": self.obs_type, "render_mode": self.render_mode}
|
||||
if self.task_ids is not None:
|
||||
kwargs["task_ids"] = self.task_ids
|
||||
return kwargs
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("metaworld")
|
||||
|
||||
@@ -14,4 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .motors_bus import Motor, MotorCalibration, MotorNormMode, MotorsBus
|
||||
from .motors_bus import (
|
||||
Motor,
|
||||
MotorCalibration,
|
||||
MotorNormMode,
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ from dataclasses import dataclass
|
||||
|
||||
os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1"
|
||||
|
||||
from lerobot.motors import MotorCalibration, MotorsBus
|
||||
from .motors_bus import MotorCalibration, MotorsBus
|
||||
|
||||
BAR_LEN, BAR_THICKNESS = 450, 8
|
||||
HANDLE_R = 10
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .damiao import DamiaoMotorsBus
|
||||
from .tables import *
|
||||
@@ -0,0 +1,808 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Portions of this file are derived from DM_Control_Python by cmjang.
|
||||
# Licensed under the MIT License; see `LICENSE` for the full text:
|
||||
# https://github.com/cmjang/DM_Control_Python
|
||||
|
||||
import logging
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
|
||||
from lerobot.utils.import_utils import _can_available
|
||||
|
||||
if TYPE_CHECKING or _can_available:
|
||||
import can
|
||||
else:
|
||||
can.Message = object
|
||||
can.interface = None
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
||||
|
||||
from ..motors_bus import Motor, MotorCalibration, MotorsBusBase, NameOrID, Value
|
||||
from .tables import (
|
||||
AVAILABLE_BAUDRATES,
|
||||
CAN_CMD_DISABLE,
|
||||
CAN_CMD_ENABLE,
|
||||
CAN_CMD_REFRESH,
|
||||
CAN_CMD_SET_ZERO,
|
||||
CAN_PARAM_ID,
|
||||
DEFAULT_BAUDRATE,
|
||||
DEFAULT_TIMEOUT_MS,
|
||||
MIT_KD_RANGE,
|
||||
MIT_KP_RANGE,
|
||||
MOTOR_LIMIT_PARAMS,
|
||||
MotorType,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
LONG_TIMEOUT_SEC = 0.1
|
||||
MEDIUM_TIMEOUT_SEC = 0.01
|
||||
SHORT_TIMEOUT_SEC = 0.001
|
||||
PRECISE_TIMEOUT_SEC = 0.0001
|
||||
|
||||
|
||||
class MotorState(TypedDict):
|
||||
position: float
|
||||
velocity: float
|
||||
torque: float
|
||||
temp_mos: float
|
||||
temp_rotor: float
|
||||
|
||||
|
||||
class DamiaoMotorsBus(MotorsBusBase):
|
||||
"""
|
||||
The Damiao implementation for a MotorsBus using CAN bus communication.
|
||||
|
||||
This class uses python-can for CAN bus communication with Damiao motors.
|
||||
For more info, see:
|
||||
- python-can documentation: https://python-can.readthedocs.io/en/stable/
|
||||
- Seedstudio documentation: https://wiki.seeedstudio.com/damiao_series/
|
||||
- DM_Control_Python repo: https://github.com/cmjang/DM_Control_Python
|
||||
"""
|
||||
|
||||
# CAN-specific settings
|
||||
available_baudrates = deepcopy(AVAILABLE_BAUDRATES)
|
||||
default_baudrate = DEFAULT_BAUDRATE
|
||||
default_timeout = DEFAULT_TIMEOUT_MS
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
port: str,
|
||||
motors: dict[str, Motor],
|
||||
calibration: dict[str, MotorCalibration] | None = None,
|
||||
can_interface: str = "auto",
|
||||
use_can_fd: bool = True,
|
||||
bitrate: int = 1000000,
|
||||
data_bitrate: int | None = 5000000,
|
||||
):
|
||||
"""
|
||||
Initialize the Damiao motors bus.
|
||||
|
||||
Args:
|
||||
port: CAN interface name (e.g., "can0" for Linux, "/dev/cu.usbmodem*" for macOS)
|
||||
motors: Dictionary mapping motor names to Motor objects
|
||||
calibration: Optional calibration data
|
||||
can_interface: CAN interface type - "auto" (default), "socketcan" (Linux), or "slcan" (macOS/serial)
|
||||
use_can_fd: Whether to use CAN FD mode (default: True for OpenArms)
|
||||
bitrate: Nominal bitrate in bps (default: 1000000 = 1 Mbps)
|
||||
data_bitrate: Data bitrate for CAN FD in bps (default: 5000000 = 5 Mbps), ignored if use_can_fd is False
|
||||
"""
|
||||
super().__init__(port, motors, calibration)
|
||||
self.port = port
|
||||
self.can_interface = can_interface
|
||||
self.use_can_fd = use_can_fd
|
||||
self.bitrate = bitrate
|
||||
self.data_bitrate = data_bitrate
|
||||
self.canbus: can.interface.Bus | None = None
|
||||
self._is_connected = False
|
||||
|
||||
# Map motor names to CAN IDs
|
||||
self._motor_can_ids: dict[str, int] = {}
|
||||
self._recv_id_to_motor: dict[int, str] = {}
|
||||
self._motor_types: dict[str, MotorType] = {}
|
||||
|
||||
for name, motor in self.motors.items():
|
||||
if motor.motor_type_str is None:
|
||||
raise ValueError(f"Motor '{name}' is missing required 'motor_type'")
|
||||
self._motor_types[name] = getattr(MotorType, motor.motor_type_str.upper().replace("-", "_"))
|
||||
|
||||
# Map recv_id to motor name for filtering responses
|
||||
if motor.recv_id is not None:
|
||||
self._recv_id_to_motor[motor.recv_id] = name
|
||||
|
||||
# State cache for handling packet drops safely
|
||||
self._last_known_states: dict[str, MotorState] = {
|
||||
name: {
|
||||
"position": 0.0,
|
||||
"velocity": 0.0,
|
||||
"torque": 0.0,
|
||||
"temp_mos": 0.0,
|
||||
"temp_rotor": 0.0,
|
||||
}
|
||||
for name in self.motors
|
||||
}
|
||||
|
||||
# Dynamic gains storage
|
||||
# Defaults: Kp=10.0 (Stiffness), Kd=0.5 (Damping)
|
||||
self._gains: dict[str, dict[str, float]] = {name: {"kp": 10.0, "kd": 0.5} for name in self.motors}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if the CAN bus is connected."""
|
||||
return self._is_connected and self.canbus is not None
|
||||
|
||||
def connect(self, handshake: bool = True) -> None:
|
||||
"""
|
||||
Open the CAN bus and initialize communication.
|
||||
|
||||
Args:
|
||||
handshake: If True, ping all motors to verify they're present
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is already connected."
|
||||
)
|
||||
|
||||
try:
|
||||
# Auto-detect interface type based on port name
|
||||
if self.can_interface == "auto":
|
||||
if self.port.startswith("/dev/"):
|
||||
self.can_interface = "slcan"
|
||||
logger.info(f"Auto-detected slcan interface for port {self.port}")
|
||||
else:
|
||||
self.can_interface = "socketcan"
|
||||
logger.info(f"Auto-detected socketcan interface for port {self.port}")
|
||||
|
||||
# Connect to CAN bus
|
||||
kwargs = {
|
||||
"channel": self.port,
|
||||
"bitrate": self.bitrate,
|
||||
"interface": self.can_interface,
|
||||
}
|
||||
|
||||
if self.can_interface == "socketcan" and self.use_can_fd and self.data_bitrate is not None:
|
||||
kwargs.update({"data_bitrate": self.data_bitrate, "fd": True})
|
||||
logger.info(
|
||||
f"Connected to {self.port} with CAN FD (bitrate={self.bitrate}, data_bitrate={self.data_bitrate})"
|
||||
)
|
||||
else:
|
||||
logger.info(f"Connected to {self.port} with {self.can_interface} (bitrate={self.bitrate})")
|
||||
|
||||
self.canbus = can.interface.Bus(**kwargs)
|
||||
self._is_connected = True
|
||||
|
||||
if handshake:
|
||||
self._handshake()
|
||||
|
||||
logger.debug(f"{self.__class__.__name__} connected via {self.can_interface}.")
|
||||
except Exception as e:
|
||||
self._is_connected = False
|
||||
raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e
|
||||
|
||||
def _handshake(self) -> None:
|
||||
"""
|
||||
Verify all motors are present and populate initial state cache.
|
||||
Raises ConnectionError if any motor fails to respond.
|
||||
"""
|
||||
logger.info("Starting handshake with motors...")
|
||||
missing_motors = []
|
||||
|
||||
for motor_name in self.motors:
|
||||
msg = self._refresh_motor(motor_name)
|
||||
if msg is None:
|
||||
missing_motors.append(motor_name)
|
||||
else:
|
||||
self._process_response(motor_name, msg)
|
||||
time.sleep(MEDIUM_TIMEOUT_SEC)
|
||||
|
||||
if missing_motors:
|
||||
raise ConnectionError(
|
||||
f"Handshake failed. The following motors did not respond: {missing_motors}. "
|
||||
"Check power (24V) and CAN wiring."
|
||||
)
|
||||
logger.info("Handshake successful. All motors ready.")
|
||||
|
||||
def disconnect(self, disable_torque: bool = True) -> None:
|
||||
"""
|
||||
Close the CAN bus connection.
|
||||
|
||||
Args:
|
||||
disable_torque: If True, disable torque on all motors before disconnecting
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self.__class__.__name__}('{self.port}') is not connected.")
|
||||
|
||||
if disable_torque:
|
||||
try:
|
||||
self.disable_torque()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to disable torque during disconnect: {e}")
|
||||
|
||||
if self.canbus:
|
||||
self.canbus.shutdown()
|
||||
self.canbus = None
|
||||
self._is_connected = False
|
||||
logger.debug(f"{self.__class__.__name__} disconnected.")
|
||||
|
||||
def configure_motors(self) -> None:
|
||||
"""Configure all motors with default settings."""
|
||||
# Damiao motors don't require much configuration in MIT mode
|
||||
# Just ensure they're enabled
|
||||
for motor in self.motors:
|
||||
self._send_simple_command(motor, CAN_CMD_ENABLE)
|
||||
time.sleep(MEDIUM_TIMEOUT_SEC)
|
||||
|
||||
def _send_simple_command(self, motor: NameOrID, command_byte: int) -> None:
|
||||
"""Helper to send simple 8-byte commands (Enable, Disable, Zero)."""
|
||||
motor_id = self._get_motor_id(motor)
|
||||
motor_name = self._get_motor_name(motor)
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
data = [0xFF] * 7 + [command_byte]
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
if msg := self._recv_motor_response(expected_recv_id=recv_id):
|
||||
self._process_response(motor_name, msg)
|
||||
else:
|
||||
logger.debug(f"No response from {motor_name} after command 0x{command_byte:02X}")
|
||||
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
"""Enable torque on selected motors."""
|
||||
target_motors = self._get_motors_list(motors)
|
||||
for motor in target_motors:
|
||||
for _ in range(num_retry + 1):
|
||||
try:
|
||||
self._send_simple_command(motor, CAN_CMD_ENABLE)
|
||||
break
|
||||
except Exception as e:
|
||||
if _ == num_retry:
|
||||
raise e
|
||||
time.sleep(MEDIUM_TIMEOUT_SEC)
|
||||
|
||||
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
"""Disable torque on selected motors."""
|
||||
target_motors = self._get_motors_list(motors)
|
||||
for motor in target_motors:
|
||||
for _ in range(num_retry + 1):
|
||||
try:
|
||||
self._send_simple_command(motor, CAN_CMD_DISABLE)
|
||||
break
|
||||
except Exception as e:
|
||||
if _ == num_retry:
|
||||
raise e
|
||||
time.sleep(MEDIUM_TIMEOUT_SEC)
|
||||
|
||||
@contextmanager
|
||||
def torque_disabled(self, motors: str | list[str] | None = None):
|
||||
"""
|
||||
Context manager that guarantees torque is re-enabled.
|
||||
|
||||
This helper is useful to temporarily disable torque when configuring motors.
|
||||
"""
|
||||
self.disable_torque(motors)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.enable_torque(motors)
|
||||
|
||||
def set_zero_position(self, motors: str | list[str] | None = None) -> None:
|
||||
"""Set current position as zero for selected motors."""
|
||||
target_motors = self._get_motors_list(motors)
|
||||
for motor in target_motors:
|
||||
self._send_simple_command(motor, CAN_CMD_SET_ZERO)
|
||||
time.sleep(MEDIUM_TIMEOUT_SEC)
|
||||
|
||||
def _refresh_motor(self, motor: NameOrID) -> can.Message | None:
|
||||
"""Refresh motor status and return the response."""
|
||||
motor_id = self._get_motor_id(motor)
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
|
||||
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
return self._recv_motor_response(expected_recv_id=recv_id)
|
||||
|
||||
def _recv_motor_response(
|
||||
self, expected_recv_id: int | None = None, timeout: float = 0.001
|
||||
) -> can.Message | None:
|
||||
"""
|
||||
Receive a response from a motor.
|
||||
|
||||
Args:
|
||||
expected_recv_id: If provided, only return messages from this CAN ID
|
||||
timeout: Timeout in seconds (default: 1ms for high-speed operation)
|
||||
Returns:
|
||||
CAN message if received, None otherwise
|
||||
"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
messages_seen = []
|
||||
while time.time() - start_time < timeout:
|
||||
msg = self.canbus.recv(timeout=PRECISE_TIMEOUT_SEC)
|
||||
if msg:
|
||||
messages_seen.append(f"0x{msg.arbitration_id:02X}")
|
||||
if expected_recv_id is None or msg.arbitration_id == expected_recv_id:
|
||||
return msg
|
||||
logger.debug(
|
||||
f"Ignoring message from 0x{msg.arbitration_id:02X}, expected 0x{expected_recv_id:02X}"
|
||||
)
|
||||
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
if messages_seen:
|
||||
logger.debug(
|
||||
f"Received {len(messages_seen)} msgs from {set(messages_seen)}, expected 0x{expected_recv_id:02X}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"No CAN messages received (expected 0x{expected_recv_id:02X})")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to receive CAN message: {e}")
|
||||
return None
|
||||
|
||||
def _recv_all_responses(
|
||||
self, expected_recv_ids: list[int], timeout: float = 0.002
|
||||
) -> dict[int, can.Message]:
|
||||
"""
|
||||
Efficiently receive responses from multiple motors at once.
|
||||
Uses the OpenArms pattern: collect all available messages within timeout.
|
||||
|
||||
Args:
|
||||
expected_recv_ids: List of CAN IDs we expect responses from
|
||||
timeout: Total timeout in seconds (default: 2ms)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping recv_id to CAN message
|
||||
"""
|
||||
responses = {}
|
||||
expected_set = set(expected_recv_ids)
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout:
|
||||
# 100us poll timeout
|
||||
msg = self.canbus.recv(timeout=PRECISE_TIMEOUT_SEC)
|
||||
if msg and msg.arbitration_id in expected_set:
|
||||
responses[msg.arbitration_id] = msg
|
||||
if len(responses) == len(expected_recv_ids):
|
||||
break
|
||||
except Exception as e:
|
||||
logger.debug(f"Error receiving responses: {e}")
|
||||
|
||||
return responses
|
||||
|
||||
def _encode_mit_packet(
|
||||
self,
|
||||
motor_type: MotorType,
|
||||
kp: float,
|
||||
kd: float,
|
||||
position_degrees: float,
|
||||
velocity_deg_per_sec: float,
|
||||
torque: float,
|
||||
) -> list[int]:
|
||||
"""Helper to encode control parameters into 8 bytes for MIT mode."""
|
||||
# Convert degrees to radians
|
||||
position_rad = np.radians(position_degrees)
|
||||
velocity_rad_per_sec = np.radians(velocity_deg_per_sec)
|
||||
|
||||
# Get motor limits
|
||||
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
|
||||
|
||||
# Encode parameters
|
||||
kp_uint = self._float_to_uint(kp, *MIT_KP_RANGE, 12)
|
||||
kd_uint = self._float_to_uint(kd, *MIT_KD_RANGE, 12)
|
||||
q_uint = self._float_to_uint(position_rad, -pmax, pmax, 16)
|
||||
dq_uint = self._float_to_uint(velocity_rad_per_sec, -vmax, vmax, 12)
|
||||
tau_uint = self._float_to_uint(torque, -tmax, tmax, 12)
|
||||
|
||||
# Pack data
|
||||
data = [0] * 8
|
||||
data[0] = (q_uint >> 8) & 0xFF
|
||||
data[1] = q_uint & 0xFF
|
||||
data[2] = dq_uint >> 4
|
||||
data[3] = ((dq_uint & 0xF) << 4) | ((kp_uint >> 8) & 0xF)
|
||||
data[4] = kp_uint & 0xFF
|
||||
data[5] = kd_uint >> 4
|
||||
data[6] = ((kd_uint & 0xF) << 4) | ((tau_uint >> 8) & 0xF)
|
||||
data[7] = tau_uint & 0xFF
|
||||
return data
|
||||
|
||||
def _mit_control(
|
||||
self,
|
||||
motor: NameOrID,
|
||||
kp: float,
|
||||
kd: float,
|
||||
position_degrees: float,
|
||||
velocity_deg_per_sec: float,
|
||||
torque: float,
|
||||
) -> None:
|
||||
"""Send MIT control command to a motor."""
|
||||
motor_id = self._get_motor_id(motor)
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_type = self._motor_types[motor_name]
|
||||
|
||||
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
if msg := self._recv_motor_response(expected_recv_id=recv_id):
|
||||
self._process_response(motor_name, msg)
|
||||
else:
|
||||
logger.debug(f"No response from {motor_name} after MIT control command")
|
||||
|
||||
def _mit_control_batch(
|
||||
self,
|
||||
commands: dict[NameOrID, tuple[float, float, float, float, float]],
|
||||
) -> None:
|
||||
"""
|
||||
Send MIT control commands to multiple motors in batch.
|
||||
Sends all commands first, then collects responses.
|
||||
|
||||
Args:
|
||||
commands: Dict mapping motor name/ID to (kp, kd, position_deg, velocity_deg/s, torque)
|
||||
Example: {'joint_1': (10.0, 0.5, 45.0, 0.0, 0.0), ...}
|
||||
"""
|
||||
if not commands:
|
||||
return
|
||||
|
||||
recv_id_to_motor: dict[int, str] = {}
|
||||
|
||||
# Step 1: Send all MIT control commands
|
||||
for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items():
|
||||
motor_id = self._get_motor_id(motor)
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_type = self._motor_types[motor_name]
|
||||
|
||||
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
|
||||
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
|
||||
|
||||
# Step 2: Collect responses and update state cache
|
||||
responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=SHORT_TIMEOUT_SEC)
|
||||
for recv_id, motor_name in recv_id_to_motor.items():
|
||||
if msg := responses.get(recv_id):
|
||||
self._process_response(motor_name, msg)
|
||||
|
||||
def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
|
||||
"""Convert float to unsigned integer for CAN transmission."""
|
||||
x = max(x_min, min(x_max, x)) # Clamp to range
|
||||
span = x_max - x_min
|
||||
data_norm = (x - x_min) / span
|
||||
return int(data_norm * ((1 << bits) - 1))
|
||||
|
||||
def _uint_to_float(self, x: int, x_min: float, x_max: float, bits: int) -> float:
|
||||
"""Convert unsigned integer from CAN to float."""
|
||||
span = x_max - x_min
|
||||
data_norm = float(x) / ((1 << bits) - 1)
|
||||
return data_norm * span + x_min
|
||||
|
||||
def _decode_motor_state(
|
||||
self, data: bytearray | bytes, motor_type: MotorType
|
||||
) -> tuple[float, float, float, int, int]:
|
||||
"""
|
||||
Decode motor state from CAN data.
|
||||
Returns: (position_deg, velocity_deg_s, torque, temp_mos, temp_rotor)
|
||||
"""
|
||||
if len(data) < 8:
|
||||
raise ValueError("Invalid motor state data")
|
||||
|
||||
# Extract encoded values
|
||||
q_uint = (data[1] << 8) | data[2]
|
||||
dq_uint = (data[3] << 4) | (data[4] >> 4)
|
||||
tau_uint = ((data[4] & 0x0F) << 8) | data[5]
|
||||
t_mos = data[6]
|
||||
t_rotor = data[7]
|
||||
|
||||
# Get motor limits
|
||||
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
|
||||
|
||||
# Decode to physical values
|
||||
position_rad = self._uint_to_float(q_uint, -pmax, pmax, 16)
|
||||
velocity_rad_per_sec = self._uint_to_float(dq_uint, -vmax, vmax, 12)
|
||||
torque = self._uint_to_float(tau_uint, -tmax, tmax, 12)
|
||||
|
||||
return np.degrees(position_rad), np.degrees(velocity_rad_per_sec), torque, t_mos, t_rotor
|
||||
|
||||
def _process_response(self, motor: str, msg: can.Message) -> None:
|
||||
"""Decode a message and update the motor state cache."""
|
||||
try:
|
||||
motor_type = self._motor_types[motor]
|
||||
pos, vel, torque, t_mos, t_rotor = self._decode_motor_state(msg.data, motor_type)
|
||||
|
||||
self._last_known_states[motor] = {
|
||||
"position": pos,
|
||||
"velocity": vel,
|
||||
"torque": torque,
|
||||
"temp_mos": float(t_mos),
|
||||
"temp_rotor": float(t_rotor),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to decode response from {motor}: {e}")
|
||||
|
||||
def read(self, data_name: str, motor: str) -> Value:
|
||||
"""Read a value from a single motor. Positions are always in degrees."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Refresh motor to get latest state
|
||||
msg = self._refresh_motor(motor)
|
||||
if msg is None:
|
||||
motor_id = self._get_motor_id(motor)
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
raise ConnectionError(
|
||||
f"No response from motor '{motor}' (send ID: 0x{motor_id:02X}, recv ID: 0x{recv_id:02X}). "
|
||||
f"Check that: 1) Motor is powered (24V), 2) CAN wiring is correct, "
|
||||
f"3) Motor IDs are configured correctly using Damiao Debugging Tools"
|
||||
)
|
||||
|
||||
self._process_response(motor, msg)
|
||||
return self._get_cached_value(motor, data_name)
|
||||
|
||||
def _get_cached_value(self, motor: str, data_name: str) -> Value:
|
||||
"""Retrieve a specific value from the cache."""
|
||||
state = self._last_known_states[motor]
|
||||
mapping: dict[str, Any] = {
|
||||
"Present_Position": state["position"],
|
||||
"Present_Velocity": state["velocity"],
|
||||
"Present_Torque": state["torque"],
|
||||
"Temperature_MOS": state["temp_mos"],
|
||||
"Temperature_Rotor": state["temp_rotor"],
|
||||
}
|
||||
if data_name not in mapping:
|
||||
raise ValueError(f"Unknown data_name: {data_name}")
|
||||
return mapping[data_name]
|
||||
|
||||
def write(
|
||||
self,
|
||||
data_name: str,
|
||||
motor: str,
|
||||
value: Value,
|
||||
) -> None:
|
||||
"""
|
||||
Write a value to a single motor. Positions are always in degrees.
|
||||
Can write 'Goal_Position', 'Kp', or 'Kd'.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if data_name in ("Kp", "Kd"):
|
||||
self._gains[motor][data_name.lower()] = float(value)
|
||||
elif data_name == "Goal_Position":
|
||||
kp = self._gains[motor]["kp"]
|
||||
kd = self._gains[motor]["kd"]
|
||||
self._mit_control(motor, kp, kd, float(value), 0.0, 0.0)
|
||||
else:
|
||||
raise ValueError(f"Writing {data_name} not supported in MIT mode")
|
||||
|
||||
def sync_read(
|
||||
self,
|
||||
data_name: str,
|
||||
motors: str | list[str] | None = None,
|
||||
) -> dict[str, Value]:
|
||||
"""
|
||||
Read the same value from multiple motors simultaneously.
|
||||
"""
|
||||
target_motors = self._get_motors_list(motors)
|
||||
self._batch_refresh(target_motors)
|
||||
|
||||
result = {}
|
||||
for motor in target_motors:
|
||||
result[motor] = self._get_cached_value(motor, data_name)
|
||||
return result
|
||||
|
||||
def sync_read_all_states(
|
||||
self,
|
||||
motors: str | list[str] | None = None,
|
||||
*,
|
||||
num_retry: int = 0,
|
||||
) -> dict[str, MotorState]:
|
||||
"""
|
||||
Read ALL motor states (position, velocity, torque) from multiple motors in ONE refresh cycle.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping motor names to state dicts with keys: 'position', 'velocity', 'torque'
|
||||
Example: {'joint_1': {'position': 45.2, 'velocity': 1.3, 'torque': 0.5}, ...}
|
||||
"""
|
||||
target_motors = self._get_motors_list(motors)
|
||||
self._batch_refresh(target_motors)
|
||||
|
||||
result = {}
|
||||
for motor in target_motors:
|
||||
result[motor] = self._last_known_states[motor].copy()
|
||||
return result
|
||||
|
||||
def _batch_refresh(self, motors: list[str]) -> None:
|
||||
"""Internal helper to refresh a list of motors and update cache."""
|
||||
# Send refresh commands
|
||||
for motor in motors:
|
||||
motor_id = self._get_motor_id(motor)
|
||||
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
|
||||
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
# Small delay to reduce bus congestion if necessary, though removed in sync_read previously
|
||||
# precise_sleep(PRECISE_SLEEP_SEC)
|
||||
|
||||
# Collect responses
|
||||
expected_recv_ids = [self._get_motor_recv_id(m) for m in motors]
|
||||
responses = self._recv_all_responses(expected_recv_ids, timeout=MEDIUM_TIMEOUT_SEC)
|
||||
|
||||
# Update cache
|
||||
for motor in motors:
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
msg = responses.get(recv_id)
|
||||
if msg:
|
||||
self._process_response(motor, msg)
|
||||
else:
|
||||
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
|
||||
|
||||
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
|
||||
"""
|
||||
Write values to multiple motors simultaneously. Positions are always in degrees.
|
||||
"""
|
||||
if data_name in ("Kp", "Kd"):
|
||||
key = data_name.lower()
|
||||
for motor, val in values.items():
|
||||
self._gains[motor][key] = float(val)
|
||||
|
||||
elif data_name == "Goal_Position":
|
||||
# Step 1: Send all MIT control commands
|
||||
recv_id_to_motor: dict[int, str] = {}
|
||||
for motor, value_degrees in values.items():
|
||||
motor_id = self._get_motor_id(motor)
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_type = self._motor_types[motor_name]
|
||||
|
||||
kp = self._gains[motor]["kp"]
|
||||
kd = self._gains[motor]["kd"]
|
||||
|
||||
data = self._encode_mit_packet(motor_type, kp, kd, float(value_degrees), 0.0, 0.0)
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
precise_sleep(PRECISE_TIMEOUT_SEC)
|
||||
|
||||
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
|
||||
|
||||
# Step 2: Collect responses and update state cache
|
||||
responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=MEDIUM_TIMEOUT_SEC)
|
||||
for recv_id, motor_name in recv_id_to_motor.items():
|
||||
if msg := responses.get(recv_id):
|
||||
self._process_response(motor_name, msg)
|
||||
else:
|
||||
# Fall back to individual writes
|
||||
for motor, value in values.items():
|
||||
self.write(data_name, motor, value)
|
||||
|
||||
def read_calibration(self) -> dict[str, MotorCalibration]:
|
||||
"""Read calibration data from motors."""
|
||||
# Damiao motors don't store calibration internally
|
||||
# Return existing calibration or empty dict
|
||||
return self.calibration if self.calibration else {}
|
||||
|
||||
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
|
||||
"""Write calibration data to motors."""
|
||||
# Damiao motors don't store calibration internally
|
||||
# Just cache it in memory
|
||||
if cache:
|
||||
self.calibration = calibration_dict
|
||||
|
||||
def record_ranges_of_motion(
|
||||
self,
|
||||
motors: NameOrID | list[NameOrID] | None = None,
|
||||
display_values: bool = True,
|
||||
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
|
||||
"""
|
||||
Interactively record the min/max values of each motor in degrees.
|
||||
|
||||
Move the joints by hand (with torque disabled) while the method streams live positions.
|
||||
Press Enter to finish.
|
||||
"""
|
||||
target_motors = self._get_motors_list(motors)
|
||||
|
||||
self.disable_torque(target_motors)
|
||||
time.sleep(LONG_TIMEOUT_SEC)
|
||||
|
||||
start_positions = self.sync_read("Present_Position", target_motors)
|
||||
mins = start_positions.copy()
|
||||
maxes = start_positions.copy()
|
||||
|
||||
print("\nMove joints through their full range of motion. Press ENTER when done.")
|
||||
user_pressed_enter = False
|
||||
|
||||
while not user_pressed_enter:
|
||||
positions = self.sync_read("Present_Position", target_motors)
|
||||
|
||||
for motor in target_motors:
|
||||
if motor in positions:
|
||||
mins[motor] = min(positions[motor], mins.get(motor, positions[motor]))
|
||||
maxes[motor] = max(positions[motor], maxes.get(motor, positions[motor]))
|
||||
|
||||
if display_values:
|
||||
print("\n" + "=" * 50)
|
||||
print(f"{'MOTOR':<20} | {'MIN (deg)':>12} | {'POS (deg)':>12} | {'MAX (deg)':>12}")
|
||||
print("-" * 50)
|
||||
for motor in target_motors:
|
||||
if motor in positions:
|
||||
print(
|
||||
f"{motor:<20} | {mins[motor]:>12.1f} | {positions[motor]:>12.1f} | {maxes[motor]:>12.1f}"
|
||||
)
|
||||
|
||||
if enter_pressed():
|
||||
user_pressed_enter = True
|
||||
|
||||
if display_values and not user_pressed_enter:
|
||||
move_cursor_up(len(target_motors) + 4)
|
||||
|
||||
time.sleep(LONG_TIMEOUT_SEC)
|
||||
|
||||
self.enable_torque(target_motors)
|
||||
|
||||
for motor in target_motors:
|
||||
if (motor in mins) and (motor in maxes) and (int(abs(maxes[motor] - mins[motor])) < 5):
|
||||
raise ValueError(f"Motor {motor} has insufficient range of motion (< 5 degrees)")
|
||||
|
||||
return mins, maxes
|
||||
|
||||
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
|
||||
"""Convert motor specification to list of motor names."""
|
||||
if motors is None:
|
||||
return list(self.motors.keys())
|
||||
elif isinstance(motors, str):
|
||||
return [motors]
|
||||
elif isinstance(motors, list):
|
||||
return motors
|
||||
else:
|
||||
raise TypeError(f"Invalid motors type: {type(motors)}")
|
||||
|
||||
def _get_motor_id(self, motor: NameOrID) -> int:
|
||||
"""Get CAN ID for a motor."""
|
||||
if isinstance(motor, str):
|
||||
if motor in self.motors:
|
||||
return self.motors[motor].id
|
||||
else:
|
||||
raise ValueError(f"Unknown motor: {motor}")
|
||||
else:
|
||||
return motor
|
||||
|
||||
def _get_motor_name(self, motor: NameOrID) -> str:
|
||||
"""Get motor name from name or ID."""
|
||||
if isinstance(motor, str):
|
||||
return motor
|
||||
else:
|
||||
for name, m in self.motors.items():
|
||||
if m.id == motor:
|
||||
return name
|
||||
raise ValueError(f"Unknown motor ID: {motor}")
|
||||
|
||||
def _get_motor_recv_id(self, motor: NameOrID) -> int:
|
||||
"""Get motor recv_id from name or ID."""
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_obj = self.motors.get(motor_name)
|
||||
if motor_obj and motor_obj.recv_id is not None:
|
||||
return motor_obj.recv_id
|
||||
else:
|
||||
raise ValueError(f"Motor {motor_obj} doesn't have a valid recv_id (None).")
|
||||
|
||||
@cached_property
|
||||
def is_calibrated(self) -> bool:
|
||||
"""Check if motors are calibrated."""
|
||||
return bool(self.calibration)
|
||||
@@ -0,0 +1,209 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Configuration tables for Damiao motors."""
|
||||
|
||||
from enum import IntEnum
|
||||
|
||||
|
||||
# Motor type definitions
|
||||
class MotorType(IntEnum):
|
||||
DM3507 = 0
|
||||
DM4310 = 1
|
||||
DM4310_48V = 2
|
||||
DM4340 = 3
|
||||
DM4340_48V = 4
|
||||
DM6006 = 5
|
||||
DM8006 = 6
|
||||
DM8009 = 7
|
||||
DM10010L = 8
|
||||
DM10010 = 9
|
||||
DMH3510 = 10
|
||||
DMH6215 = 11
|
||||
DMG6220 = 12
|
||||
|
||||
|
||||
# Control modes
|
||||
class ControlMode(IntEnum):
|
||||
MIT = 1
|
||||
POS_VEL = 2
|
||||
VEL = 3
|
||||
TORQUE_POS = 4
|
||||
|
||||
|
||||
# Motor variable IDs (RID)
|
||||
class MotorVariable(IntEnum):
|
||||
UV_VALUE = 0
|
||||
KT_VALUE = 1
|
||||
OT_VALUE = 2
|
||||
OC_VALUE = 3
|
||||
ACC = 4
|
||||
DEC = 5
|
||||
MAX_SPD = 6
|
||||
MST_ID = 7
|
||||
ESC_ID = 8
|
||||
TIMEOUT = 9
|
||||
CTRL_MODE = 10
|
||||
DAMP = 11
|
||||
INERTIA = 12
|
||||
HW_VER = 13
|
||||
SW_VER = 14
|
||||
SN = 15
|
||||
NPP = 16
|
||||
RS = 17
|
||||
LS = 18
|
||||
FLUX = 19
|
||||
GR = 20
|
||||
PMAX = 21
|
||||
VMAX = 22
|
||||
TMAX = 23
|
||||
I_BW = 24
|
||||
KP_ASR = 25
|
||||
KI_ASR = 26
|
||||
KP_APR = 27
|
||||
KI_APR = 28
|
||||
OV_VALUE = 29
|
||||
GREF = 30
|
||||
DETA = 31
|
||||
V_BW = 32
|
||||
IQ_C1 = 33
|
||||
VL_C1 = 34
|
||||
CAN_BR = 35
|
||||
SUB_VER = 36
|
||||
U_OFF = 50
|
||||
V_OFF = 51
|
||||
K1 = 52
|
||||
K2 = 53
|
||||
M_OFF = 54
|
||||
DIR = 55
|
||||
P_M = 80
|
||||
XOUT = 81
|
||||
|
||||
|
||||
# Motor limit parameters [PMAX, VMAX, TMAX]
|
||||
# PMAX: Maximum position (rad)
|
||||
# VMAX: Maximum velocity (rad/s)
|
||||
# TMAX: Maximum torque (N·m)
|
||||
MOTOR_LIMIT_PARAMS = {
|
||||
MotorType.DM3507: (12.5, 30, 10),
|
||||
MotorType.DM4310: (12.5, 30, 10),
|
||||
MotorType.DM4310_48V: (12.5, 50, 10),
|
||||
MotorType.DM4340: (12.5, 8, 28),
|
||||
MotorType.DM4340_48V: (12.5, 10, 28),
|
||||
MotorType.DM6006: (12.5, 45, 20),
|
||||
MotorType.DM8006: (12.5, 45, 40),
|
||||
MotorType.DM8009: (12.5, 45, 54),
|
||||
MotorType.DM10010L: (12.5, 25, 200),
|
||||
MotorType.DM10010: (12.5, 20, 200),
|
||||
MotorType.DMH3510: (12.5, 280, 1),
|
||||
MotorType.DMH6215: (12.5, 45, 10),
|
||||
MotorType.DMG6220: (12.5, 45, 10),
|
||||
}
|
||||
|
||||
# Motor model names
|
||||
MODEL_NAMES = {
|
||||
MotorType.DM3507: "dm3507",
|
||||
MotorType.DM4310: "dm4310",
|
||||
MotorType.DM4310_48V: "dm4310_48v",
|
||||
MotorType.DM4340: "dm4340",
|
||||
MotorType.DM4340_48V: "dm4340_48v",
|
||||
MotorType.DM6006: "dm6006",
|
||||
MotorType.DM8006: "dm8006",
|
||||
MotorType.DM8009: "dm8009",
|
||||
MotorType.DM10010L: "dm10010l",
|
||||
MotorType.DM10010: "dm10010",
|
||||
MotorType.DMH3510: "dmh3510",
|
||||
MotorType.DMH6215: "dmh6215",
|
||||
MotorType.DMG6220: "dmg6220",
|
||||
}
|
||||
|
||||
# Motor resolution table (encoder counts per revolution)
|
||||
MODEL_RESOLUTION = {
|
||||
"dm3507": 65536,
|
||||
"dm4310": 65536,
|
||||
"dm4310_48v": 65536,
|
||||
"dm4340": 65536,
|
||||
"dm4340_48v": 65536,
|
||||
"dm6006": 65536,
|
||||
"dm8006": 65536,
|
||||
"dm8009": 65536,
|
||||
"dm10010l": 65536,
|
||||
"dm10010": 65536,
|
||||
"dmh3510": 65536,
|
||||
"dmh6215": 65536,
|
||||
"dmg6220": 65536,
|
||||
}
|
||||
|
||||
# CAN baudrates supported by Damiao motors
|
||||
AVAILABLE_BAUDRATES = [
|
||||
125000, # 0: 125 kbps
|
||||
200000, # 1: 200 kbps
|
||||
250000, # 2: 250 kbps
|
||||
500000, # 3: 500 kbps
|
||||
1000000, # 4: 1 mbps (default for OpenArms)
|
||||
2000000, # 5: 2 mbps
|
||||
2500000, # 6: 2.5 mbps
|
||||
3200000, # 7: 3.2 mbps
|
||||
4000000, # 8: 4 mbps
|
||||
5000000, # 9: 5 mbps
|
||||
]
|
||||
DEFAULT_BAUDRATE = 1000000 # 1 Mbps is standard for OpenArms
|
||||
|
||||
# Default timeout in milliseconds
|
||||
DEFAULT_TIMEOUT_MS = 1000
|
||||
|
||||
# OpenArms specific configurations
|
||||
# Based on: https://docs.openarm.dev/software/setup/configure-test
|
||||
# OpenArms has 7 DOF per arm (14 total for dual arm)
|
||||
OPENARMS_ARM_MOTOR_IDS = {
|
||||
"joint_1": {"send": 0x01, "recv": 0x11}, # J1 - Shoulder pan
|
||||
"joint_2": {"send": 0x02, "recv": 0x12}, # J2 - Shoulder lift
|
||||
"joint_3": {"send": 0x03, "recv": 0x13}, # J3 - Elbow flex
|
||||
"joint_4": {"send": 0x04, "recv": 0x14}, # J4 - Wrist flex
|
||||
"joint_5": {"send": 0x05, "recv": 0x15}, # J5 - Wrist roll
|
||||
"joint_6": {"send": 0x06, "recv": 0x16}, # J6 - Wrist pitch
|
||||
"joint_7": {"send": 0x07, "recv": 0x17}, # J7 - Wrist rotation
|
||||
}
|
||||
|
||||
OPENARMS_GRIPPER_MOTOR_IDS = {
|
||||
"gripper": {"send": 0x08, "recv": 0x18}, # J8 - Gripper
|
||||
}
|
||||
|
||||
# Default motor types for OpenArms
|
||||
OPENARMS_DEFAULT_MOTOR_TYPES = {
|
||||
"joint_1": MotorType.DM8009, # Shoulder pan - high torque
|
||||
"joint_2": MotorType.DM8009, # Shoulder lift - high torque
|
||||
"joint_3": MotorType.DM4340, # Shoulder rotation
|
||||
"joint_4": MotorType.DM4340, # Elbow flex
|
||||
"joint_5": MotorType.DM4310, # Wrist roll
|
||||
"joint_6": MotorType.DM4310, # Wrist pitch
|
||||
"joint_7": MotorType.DM4310, # Wrist rotation
|
||||
"gripper": MotorType.DM4310, # Gripper
|
||||
}
|
||||
|
||||
# MIT control parameter ranges
|
||||
MIT_KP_RANGE = (0.0, 500.0)
|
||||
MIT_KD_RANGE = (0.0, 5.0)
|
||||
|
||||
# CAN frame command IDs
|
||||
CAN_CMD_ENABLE = 0xFC
|
||||
CAN_CMD_DISABLE = 0xFD
|
||||
CAN_CMD_SET_ZERO = 0xFE
|
||||
CAN_CMD_REFRESH = 0xCC
|
||||
CAN_CMD_QUERY_PARAM = 0x33
|
||||
CAN_CMD_WRITE_PARAM = 0x55
|
||||
CAN_CMD_SAVE_PARAM = 0xAA
|
||||
|
||||
# CAN ID for parameter operations
|
||||
CAN_PARAM_ID = 0x7FF
|
||||
@@ -22,9 +22,8 @@ import logging
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
|
||||
from lerobot.motors.encoding_utils import decode_twos_complement, encode_twos_complement
|
||||
|
||||
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
|
||||
from ..encoding_utils import decode_twos_complement, encode_twos_complement
|
||||
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
|
||||
from .tables import (
|
||||
AVAILABLE_BAUDRATES,
|
||||
MODEL_BAUDRATE_TABLE,
|
||||
@@ -100,7 +99,7 @@ def _split_into_byte_chunks(value: int, length: int) -> list[int]:
|
||||
return data
|
||||
|
||||
|
||||
class DynamixelMotorsBus(MotorsBus):
|
||||
class DynamixelMotorsBus(SerialMotorsBus):
|
||||
"""
|
||||
The Dynamixel implementation for a MotorsBus. It relies on the python dynamixel sdk to communicate with
|
||||
the motors. For more info, see the Dynamixel SDK Documentation:
|
||||
@@ -203,9 +202,9 @@ class DynamixelMotorsBus(MotorsBus):
|
||||
for motor in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
|
||||
def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None:
|
||||
def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None:
|
||||
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
|
||||
self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for motor in self._get_motors_list(motors):
|
||||
|
||||
@@ -17,9 +17,8 @@ from copy import deepcopy
|
||||
from enum import Enum
|
||||
from pprint import pformat
|
||||
|
||||
from lerobot.motors.encoding_utils import decode_sign_magnitude, encode_sign_magnitude
|
||||
|
||||
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
|
||||
from ..encoding_utils import decode_sign_magnitude, encode_sign_magnitude
|
||||
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
|
||||
from .tables import (
|
||||
FIRMWARE_MAJOR_VERSION,
|
||||
FIRMWARE_MINOR_VERSION,
|
||||
@@ -96,7 +95,7 @@ def patch_setPacketTimeout(self, packet_length): # noqa: N802
|
||||
self.packet_timeout = (self.tx_time_per_byte * packet_length) + (self.tx_time_per_byte * 3.0) + 50
|
||||
|
||||
|
||||
class FeetechMotorsBus(MotorsBus):
|
||||
class FeetechMotorsBus(SerialMotorsBus):
|
||||
"""
|
||||
The FeetechMotorsBus class allows to efficiently read and write to the attached motors. It relies on the
|
||||
python feetech sdk to communicate with the motors, which is itself based on the dynamixel sdk.
|
||||
@@ -298,11 +297,11 @@ class FeetechMotorsBus(MotorsBus):
|
||||
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
self.write("Lock", motor, 0, num_retry=num_retry)
|
||||
|
||||
def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None:
|
||||
def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None:
|
||||
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
|
||||
self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
addr, length = get_address(self.model_ctrl_table, model, "Lock")
|
||||
self._write(addr, length, motor_id, 0, num_retry=num_retry)
|
||||
self._write(addr, length, motor, 0, num_retry=num_retry)
|
||||
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for motor in self._get_motors_list(motors):
|
||||
|
||||
@@ -19,6 +19,8 @@
|
||||
# TODO(aliberts): Add block noqa when feature below is available
|
||||
# https://github.com/astral-sh/ruff/issues/3711
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
@@ -41,6 +43,81 @@ Value: TypeAlias = int | float
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MotorsBusBase(abc.ABC):
|
||||
"""
|
||||
Base class for all motor bus implementations.
|
||||
|
||||
This is a minimal interface that all motor buses must implement, regardless of their
|
||||
communication protocol (serial, CAN, etc.).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
port: str,
|
||||
motors: dict[str, Motor],
|
||||
calibration: dict[str, MotorCalibration] | None = None,
|
||||
):
|
||||
self.port = port
|
||||
self.motors = motors
|
||||
self.calibration = calibration if calibration else {}
|
||||
|
||||
@abc.abstractmethod
|
||||
def connect(self, handshake: bool = True) -> None:
|
||||
"""Establish connection to the motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def disconnect(self, disable_torque: bool = True) -> None:
|
||||
"""Disconnect from the motors."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if connected to the motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def read(self, data_name: str, motor: str) -> Value:
|
||||
"""Read a value from a single motor."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def write(self, data_name: str, motor: str, value: Value) -> None:
|
||||
"""Write a value to a single motor."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def sync_read(self, data_name: str, motors: str | list[str] | None = None) -> dict[str, Value]:
|
||||
"""Read a value from multiple motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
|
||||
"""Write values to multiple motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
"""Enable torque on selected motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
"""Disable torque on selected motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def read_calibration(self) -> dict[str, MotorCalibration]:
|
||||
"""Read calibration parameters from the motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
|
||||
"""Write calibration parameters to the motors."""
|
||||
pass
|
||||
|
||||
|
||||
def get_ctrl_table(model_ctrl_table: dict[str, dict], model: str) -> dict[str, tuple[int, int]]:
|
||||
ctrl_table = model_ctrl_table.get(model)
|
||||
if ctrl_table is None:
|
||||
@@ -97,6 +174,8 @@ class Motor:
|
||||
id: int
|
||||
model: str
|
||||
norm_mode: MotorNormMode
|
||||
motor_type_str: str | None = None
|
||||
recv_id: int | None = None
|
||||
|
||||
|
||||
class PortHandler(Protocol):
|
||||
@@ -203,15 +282,15 @@ class GroupSyncWrite(Protocol):
|
||||
def txPacket(self): ...
|
||||
|
||||
|
||||
class MotorsBus(abc.ABC):
|
||||
class SerialMotorsBus(MotorsBusBase):
|
||||
"""
|
||||
A MotorsBus allows to efficiently read and write to the attached motors.
|
||||
A SerialMotorsBus allows to efficiently read and write to motors connected via serial communication.
|
||||
It represents several motors daisy-chained together and connected through a serial port.
|
||||
There are currently two implementations of this abstract class:
|
||||
There are currently two implementations of this class:
|
||||
- DynamixelMotorsBus
|
||||
- FeetechMotorsBus
|
||||
|
||||
Note: This class may evolve in the future should we add support for other types of bus.
|
||||
This class is specifically for serial-based motor protocols (Dynamixel, Feetech, etc.).
|
||||
|
||||
A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
|
||||
To find the port, you can run our utility script:
|
||||
@@ -260,9 +339,7 @@ class MotorsBus(abc.ABC):
|
||||
motors: dict[str, Motor],
|
||||
calibration: dict[str, MotorCalibration] | None = None,
|
||||
):
|
||||
self.port = port
|
||||
self.motors = motors
|
||||
self.calibration = calibration if calibration else {}
|
||||
super().__init__(port, motors, calibration)
|
||||
|
||||
self.port_handler: PortHandler
|
||||
self.packet_handler: PacketHandler
|
||||
@@ -532,7 +609,7 @@ class MotorsBus(abc.ABC):
|
||||
self.set_baudrate(self.default_baudrate)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _find_single_motor(self, motor: str, initial_baudrate: int | None) -> tuple[int, int]:
|
||||
def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -545,13 +622,13 @@ class MotorsBus(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
"""Disable torque on selected motors.
|
||||
|
||||
Disabling Torque allows to write to the motors' permanent memory area (EPROM/EEPROM).
|
||||
|
||||
Args:
|
||||
motors (int | str | list[str] | None, optional): Target motors. Accepts a motor name, an ID, a
|
||||
motors ( str | list[str] | None, optional): Target motors. Accepts a motor name, an ID, a
|
||||
list of names or `None` to affect every registered motor. Defaults to `None`.
|
||||
num_retry (int, optional): Number of additional retry attempts on communication failure.
|
||||
Defaults to 0.
|
||||
@@ -1194,3 +1271,7 @@ class MotorsBus(abc.ABC):
|
||||
for id_, value in ids_values.items():
|
||||
data = self._serialize_data(value, length)
|
||||
self.sync_writer.addParam(id_, data)
|
||||
|
||||
|
||||
# Backward compatibility alias
|
||||
MotorsBus: TypeAlias = SerialMotorsBus
|
||||
|
||||
@@ -20,7 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig, RTCTrainingConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
DEFAULT_IMAGE_SIZE = 224
|
||||
@@ -50,8 +50,9 @@ class PI0Config(PreTrainedConfig):
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
# Real-Time Chunking (RTC) configurations
|
||||
rtc_config: RTCConfig | None = None
|
||||
rtc_training_config: RTCTrainingConfig | None = None
|
||||
|
||||
image_resolution: tuple[int, int] = (
|
||||
DEFAULT_IMAGE_SIZE,
|
||||
|
||||
@@ -44,6 +44,12 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.policies.rtc.training_time import (
|
||||
apply_rtc_training_time,
|
||||
apply_training_time_rtc_inference,
|
||||
masked_mean,
|
||||
sample_rtc_delay,
|
||||
)
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
@@ -79,8 +85,8 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
if time.ndim not in (1, 2):
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size,)` or `(batch_size, T)`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
@@ -88,8 +94,14 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
if time.ndim == 1:
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
|
||||
time_flat = time.reshape(-1)
|
||||
sin_input = scaling_factor[None, :] * time_flat[:, None]
|
||||
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
return pos_emb.reshape(*time.shape, dimension)
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
|
||||
@@ -605,6 +617,9 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _training_time_rtc_inference_enabled(self):
|
||||
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Helper method to apply gradient checkpointing if enabled."""
|
||||
if self.gradient_checkpointing_enabled and self.training:
|
||||
@@ -714,7 +729,10 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
|
||||
|
||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||
if time_emb.dim() == 2:
|
||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||
elif time_emb.shape[:2] != action_emb.shape[:2]:
|
||||
raise ValueError(f"Expected time_emb shape {action_emb.shape[:2]}, got {time_emb.shape[:2]}")
|
||||
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
||||
|
||||
def mlp_func(action_time_emb):
|
||||
@@ -750,7 +768,12 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
if time.ndim == 1:
|
||||
time_expanded = time[:, None, None]
|
||||
elif time.ndim == 2:
|
||||
time_expanded = time[:, :, None]
|
||||
else:
|
||||
raise ValueError(f"Expected time shape (B,) or (B, T), got {time.shape}")
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
@@ -846,24 +869,37 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
dt = -1.0 / num_steps
|
||||
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
use_training_time_rtc = self._training_time_rtc_inference_enabled()
|
||||
|
||||
x_t = noise
|
||||
for step in range(num_steps):
|
||||
time = 1.0 + step * dt
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
return self.denoise_step(
|
||||
if use_training_time_rtc:
|
||||
x_t_cond, time_tensor = apply_training_time_rtc_inference(
|
||||
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
|
||||
)
|
||||
v_t = self.denoise_step(
|
||||
state=state,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=input_x_t,
|
||||
timestep=current_timestep,
|
||||
x_t=x_t_cond,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
elif self._rtc_enabled():
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
if self._rtc_enabled():
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
return self.denoise_step(
|
||||
state=state,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=input_x_t,
|
||||
timestep=current_timestep,
|
||||
)
|
||||
|
||||
v_t = self.rtc_processor.denoise_step(
|
||||
x_t=x_t,
|
||||
@@ -874,7 +910,14 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
state=state,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=x_t,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
|
||||
x_t = x_t + dt * v_t
|
||||
|
||||
@@ -1277,7 +1320,19 @@ class PI0Policy(PreTrainedPolicy):
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
# Compute loss
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
|
||||
postfix_mask = None
|
||||
rtc_cfg = self.config.rtc_training_config
|
||||
if rtc_cfg is not None and rtc_cfg.enabled and self.training:
|
||||
batch_size = actions.shape[0]
|
||||
time = self.model.sample_time(batch_size, actions.device)
|
||||
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||
delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device)
|
||||
time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1])
|
||||
losses = self.model.forward(
|
||||
images, img_masks, lang_tokens, lang_masks, state, actions, noise=noise, time=time
|
||||
)
|
||||
else:
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
|
||||
|
||||
# Truncate losses to actual action dimensions
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
@@ -1289,12 +1344,12 @@ class PI0Policy(PreTrainedPolicy):
|
||||
|
||||
if reduction == "none":
|
||||
# Return per-sample losses (B,) by averaging over time and action dims
|
||||
per_sample_loss = losses.mean(dim=(1, 2))
|
||||
per_sample_loss = masked_mean(losses, postfix_mask, reduce_dims=(1, 2))
|
||||
loss_dict["loss"] = per_sample_loss.mean().item()
|
||||
return per_sample_loss, loss_dict
|
||||
else:
|
||||
# Default: return scalar mean loss
|
||||
loss = losses.mean()
|
||||
loss = masked_mean(losses, postfix_mask, reduce_dims=(0, 1, 2))
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig, RTCTrainingConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
DEFAULT_IMAGE_SIZE = 224
|
||||
@@ -52,6 +52,7 @@ class PI05Config(PreTrainedConfig):
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
rtc_training_config: RTCTrainingConfig | None = None
|
||||
|
||||
image_resolution: tuple[int, int] = (
|
||||
DEFAULT_IMAGE_SIZE,
|
||||
|
||||
@@ -44,6 +44,12 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.policies.rtc.training_time import (
|
||||
apply_rtc_training_time,
|
||||
apply_training_time_rtc_inference,
|
||||
masked_mean,
|
||||
sample_rtc_delay,
|
||||
)
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
@@ -78,8 +84,8 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
if time.ndim not in (1, 2):
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size,)` or `(batch_size, T)`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
@@ -87,8 +93,14 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
if time.ndim == 1:
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
|
||||
time_flat = time.reshape(-1)
|
||||
sin_input = scaling_factor[None, :] * time_flat[:, None]
|
||||
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
return pos_emb.reshape(*time.shape, dimension)
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
|
||||
@@ -602,6 +614,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _training_time_rtc_inference_enabled(self):
|
||||
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Helper method to apply gradient checkpointing if enabled."""
|
||||
if self.gradient_checkpointing_enabled and self.training:
|
||||
@@ -729,7 +744,12 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
if time.ndim == 1:
|
||||
time_expanded = time[:, None, None]
|
||||
elif time.ndim == 2:
|
||||
time_expanded = time[:, :, None]
|
||||
else:
|
||||
raise ValueError(f"Expected time shape (B,) or (B, T), got {time.shape}")
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
@@ -820,23 +840,35 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
dt = -1.0 / num_steps
|
||||
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
use_training_time_rtc = self._training_time_rtc_inference_enabled()
|
||||
|
||||
x_t = noise
|
||||
for step in range(num_steps):
|
||||
time = 1.0 + step * dt
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
return self.denoise_step(
|
||||
if use_training_time_rtc:
|
||||
x_t_cond, time_tensor = apply_training_time_rtc_inference(
|
||||
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
|
||||
)
|
||||
v_t = self.denoise_step(
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=input_x_t,
|
||||
timestep=current_timestep,
|
||||
x_t=x_t_cond,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
elif self._rtc_enabled():
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
if self._rtc_enabled():
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
return self.denoise_step(
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=input_x_t,
|
||||
timestep=current_timestep,
|
||||
)
|
||||
|
||||
v_t = self.rtc_processor.denoise_step(
|
||||
x_t=x_t,
|
||||
@@ -847,7 +879,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=x_t,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
|
||||
x_t = x_t + dt * v_t
|
||||
|
||||
@@ -1250,7 +1288,17 @@ class PI05Policy(PreTrainedPolicy):
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
# Compute loss (no separate state needed for PI05)
|
||||
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
||||
postfix_mask = None
|
||||
rtc_cfg = self.config.rtc_training_config
|
||||
if rtc_cfg is not None and rtc_cfg.enabled and self.training:
|
||||
batch_size = actions.shape[0]
|
||||
time = self.model.sample_time(batch_size, actions.device)
|
||||
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||
delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device)
|
||||
time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1])
|
||||
losses = self.model.forward(images, img_masks, tokens, masks, actions, noise=noise, time=time)
|
||||
else:
|
||||
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
||||
|
||||
# Truncate losses to actual action dimensions
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
@@ -1262,12 +1310,12 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
if reduction == "none":
|
||||
# Return per-sample losses (B,) by averaging over time and action dims
|
||||
per_sample_loss = losses.mean(dim=(1, 2))
|
||||
per_sample_loss = masked_mean(losses, postfix_mask, reduce_dims=(1, 2))
|
||||
loss_dict["loss"] = per_sample_loss.mean().item()
|
||||
return per_sample_loss, loss_dict
|
||||
else:
|
||||
# Default: return scalar mean loss
|
||||
loss = losses.mean()
|
||||
loss = masked_mean(losses, postfix_mask, reduce_dims=(0, 1, 2))
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ Based on:
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.configs.types import RTCAttentionSchedule, RTCTrainingDelayDistribution
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -53,3 +53,22 @@ class RTCConfig:
|
||||
raise ValueError(f"max_guidance_weight must be positive, got {self.max_guidance_weight}")
|
||||
if self.debug_maxlen <= 0:
|
||||
raise ValueError(f"debug_maxlen must be positive, got {self.debug_maxlen}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RTCTrainingConfig:
|
||||
"""Configuration for training-time RTC action prefix conditioning."""
|
||||
|
||||
enabled: bool = False
|
||||
min_delay: int = 0
|
||||
max_delay: int = 0
|
||||
delay_distribution: RTCTrainingDelayDistribution = RTCTrainingDelayDistribution.UNIFORM
|
||||
exp_decay: float = 1.0
|
||||
|
||||
def __post_init__(self):
|
||||
if self.min_delay < 0:
|
||||
raise ValueError(f"min_delay must be >= 0, got {self.min_delay}")
|
||||
if self.max_delay < self.min_delay:
|
||||
raise ValueError(f"max_delay ({self.max_delay}) must be >= min_delay ({self.min_delay})")
|
||||
if self.exp_decay <= 0:
|
||||
raise ValueError(f"exp_decay must be positive, got {self.exp_decay}")
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import RTCTrainingDelayDistribution
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCTrainingConfig
|
||||
|
||||
|
||||
def sample_rtc_delay(cfg: RTCTrainingConfig, batch_size: int, device: torch.device) -> torch.Tensor:
|
||||
if cfg.max_delay == cfg.min_delay:
|
||||
return torch.full((batch_size,), cfg.min_delay, device=device, dtype=torch.long)
|
||||
|
||||
if cfg.delay_distribution == RTCTrainingDelayDistribution.UNIFORM:
|
||||
return torch.randint(cfg.min_delay, cfg.max_delay + 1, (batch_size,), device=device, dtype=torch.long)
|
||||
|
||||
delay_values = torch.arange(cfg.min_delay, cfg.max_delay + 1, device=device, dtype=torch.long)
|
||||
weights = torch.exp(-cfg.exp_decay * delay_values.to(dtype=torch.float32))
|
||||
probs = weights / weights.sum()
|
||||
samples = torch.multinomial(probs, batch_size, replacement=True)
|
||||
return delay_values[samples]
|
||||
|
||||
|
||||
def apply_rtc_training_time(
|
||||
time: torch.Tensor, delay: torch.Tensor, seq_len: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
device = time.device
|
||||
delay = torch.clamp(delay, max=seq_len)
|
||||
prefix_mask = torch.arange(seq_len, device=device)[None, :] < delay[:, None]
|
||||
time_tokens = time[:, None].expand(-1, seq_len)
|
||||
time_tokens = time_tokens.masked_fill(prefix_mask, 0.0)
|
||||
postfix_mask = ~prefix_mask
|
||||
return time_tokens, postfix_mask
|
||||
|
||||
|
||||
def masked_mean(
|
||||
losses: torch.Tensor, mask: torch.Tensor | None, reduce_dims: tuple[int, ...], eps: float = 1e-8
|
||||
) -> torch.Tensor:
|
||||
if mask is None:
|
||||
return losses.mean(dim=reduce_dims)
|
||||
|
||||
mask = mask.to(dtype=losses.dtype)
|
||||
while mask.dim() < losses.dim():
|
||||
mask = mask.unsqueeze(-1)
|
||||
masked = losses * mask
|
||||
denom = mask.sum(dim=reduce_dims).clamp_min(eps)
|
||||
return masked.sum(dim=reduce_dims) / denom
|
||||
|
||||
|
||||
def apply_training_time_rtc_inference(
|
||||
x_t: torch.Tensor,
|
||||
time: float,
|
||||
inference_delay: int | None,
|
||||
prev_chunk_left_over: torch.Tensor | None,
|
||||
chunk_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Apply training-time RTC conditioning during inference.
|
||||
|
||||
Based on Algorithm 1 from "Training-Time Action Conditioning for Efficient Real-Time Chunking".
|
||||
|
||||
At each denoising step:
|
||||
1. Replace prefix positions in x_t with ground truth from previous chunk
|
||||
2. Create per-token timesteps with 1.0 for prefix positions
|
||||
|
||||
Args:
|
||||
x_t: Current noisy actions (B, T, D)
|
||||
time: Current flow matching timestep (scalar)
|
||||
inference_delay: Number of prefix actions to condition on
|
||||
prev_chunk_left_over: Previous chunk's leftover actions (B, T, D)
|
||||
chunk_size: Total chunk size T
|
||||
|
||||
Returns:
|
||||
x_t_conditioned: x_t with prefix replaced by previous actions
|
||||
time_per_token: Per-token timesteps (B, T) with 1.0 for prefix
|
||||
"""
|
||||
batch_size = x_t.shape[0]
|
||||
device = x_t.device
|
||||
|
||||
if inference_delay is None or inference_delay <= 0 or prev_chunk_left_over is None:
|
||||
time_scalar = torch.full((batch_size,), time, device=device, dtype=torch.float32)
|
||||
return x_t, time_scalar
|
||||
|
||||
delay = min(inference_delay, chunk_size)
|
||||
prefix_mask = torch.arange(chunk_size, device=device)[None, :] < delay
|
||||
|
||||
x_t_conditioned = torch.where(
|
||||
prefix_mask[:, :, None].expand_as(x_t),
|
||||
prev_chunk_left_over[:, :chunk_size, :],
|
||||
x_t,
|
||||
)
|
||||
|
||||
time_per_token = torch.full((batch_size, chunk_size), time, device=device, dtype=torch.float32)
|
||||
time_per_token = time_per_token.masked_fill(prefix_mask, 1.0)
|
||||
|
||||
return x_t_conditioned, time_per_token
|
||||
@@ -20,7 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig, RTCTrainingConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
|
||||
@@ -103,8 +103,9 @@ class SmolVLAConfig(PreTrainedConfig):
|
||||
min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding
|
||||
max_period: float = 4.0
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
# Real-Time Chunking (RTC) configurations
|
||||
rtc_config: RTCConfig | None = None
|
||||
rtc_training_config: RTCTrainingConfig | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
@@ -63,6 +63,12 @@ from typing_extensions import Unpack
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.policies.rtc.training_time import (
|
||||
apply_rtc_training_time,
|
||||
apply_training_time_rtc_inference,
|
||||
masked_mean,
|
||||
sample_rtc_delay,
|
||||
)
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
|
||||
from lerobot.policies.utils import (
|
||||
@@ -85,8 +91,8 @@ def create_sinusoidal_pos_embedding(
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
if time.ndim not in (1, 2):
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size,)` or `(batch_size, T)`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
@@ -94,9 +100,14 @@ def create_sinusoidal_pos_embedding(
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
if time.ndim == 1:
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
|
||||
time_flat = time.reshape(-1)
|
||||
sin_input = scaling_factor[None, :] * time_flat[:, None]
|
||||
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
return pos_emb
|
||||
return pos_emb.reshape(*time.shape, dimension)
|
||||
|
||||
|
||||
def make_att_2d_masks(pad_masks, att_masks):
|
||||
@@ -375,6 +386,16 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
|
||||
lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
actions = self.prepare_action(batch)
|
||||
postfix_mask = None
|
||||
rtc_cfg = self.config.rtc_training_config
|
||||
if rtc_cfg is not None and rtc_cfg.enabled and self.training:
|
||||
batch_size = actions.shape[0]
|
||||
if time is None:
|
||||
time = self.model.sample_time(batch_size, actions.device)
|
||||
if noise is None:
|
||||
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||
delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device)
|
||||
time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1])
|
||||
actions_is_pad = batch.get("actions_id_pad")
|
||||
loss_dict = {}
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
||||
@@ -384,6 +405,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
in_episode_bound = ~actions_is_pad
|
||||
losses = losses * in_episode_bound.unsqueeze(-1)
|
||||
loss_dict["losses_after_in_ep_bound"] = losses.clone()
|
||||
postfix_mask = in_episode_bound if postfix_mask is None else (postfix_mask & in_episode_bound)
|
||||
|
||||
# Remove padding
|
||||
losses = losses[:, :, : self.config.max_action_dim]
|
||||
@@ -391,12 +413,12 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
|
||||
if reduction == "none":
|
||||
# Return per-sample losses (B,) by averaging over time and action dims
|
||||
per_sample_loss = losses.mean(dim=(1, 2))
|
||||
per_sample_loss = masked_mean(losses, postfix_mask, reduce_dims=(1, 2))
|
||||
loss_dict["loss"] = per_sample_loss.mean().item()
|
||||
return per_sample_loss, loss_dict
|
||||
else:
|
||||
# Default: return scalar mean loss
|
||||
loss = losses.mean()
|
||||
loss = masked_mean(losses, postfix_mask, reduce_dims=(0, 1, 2))
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
|
||||
@@ -596,6 +618,9 @@ class VLAFlowMatching(nn.Module):
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _training_time_rtc_inference_enabled(self):
|
||||
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
|
||||
|
||||
def set_requires_grad(self):
|
||||
for params in self.state_proj.parameters():
|
||||
params.requires_grad = self.config.train_state_proj
|
||||
@@ -731,7 +756,10 @@ class VLAFlowMatching(nn.Module):
|
||||
)
|
||||
time_emb = time_emb.type(dtype=dtype)
|
||||
|
||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||
if time_emb.dim() == 2:
|
||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||
elif time_emb.shape[:2] != action_emb.shape[:2]:
|
||||
raise ValueError(f"Expected time_emb shape {action_emb.shape[:2]}, got {time_emb.shape[:2]}")
|
||||
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
||||
|
||||
action_time_emb = self.action_time_mlp_in(action_time_emb)
|
||||
@@ -763,7 +791,12 @@ class VLAFlowMatching(nn.Module):
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
if time.ndim == 1:
|
||||
time_expanded = time[:, None, None]
|
||||
elif time.ndim == 2:
|
||||
time_expanded = time[:, :, None]
|
||||
else:
|
||||
raise ValueError(f"Expected time shape (B,) or (B, T), got {time.shape}")
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
@@ -826,23 +859,35 @@ class VLAFlowMatching(nn.Module):
|
||||
num_steps = self.config.num_steps
|
||||
dt = -1.0 / num_steps
|
||||
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
use_training_time_rtc = self._training_time_rtc_inference_enabled()
|
||||
|
||||
x_t = noise
|
||||
for step in range(num_steps):
|
||||
time = 1.0 + step * dt
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
return self.denoise_step(
|
||||
x_t=input_x_t,
|
||||
if use_training_time_rtc:
|
||||
x_t_cond, time_tensor = apply_training_time_rtc_inference(
|
||||
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
|
||||
)
|
||||
v_t = self.denoise_step(
|
||||
x_t=x_t_cond,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
timestep=current_timestep,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
elif self._rtc_enabled():
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
if self._rtc_enabled():
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
return self.denoise_step(
|
||||
x_t=input_x_t,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
timestep=current_timestep,
|
||||
)
|
||||
|
||||
v_t = self.rtc_processor.denoise_step(
|
||||
x_t=x_t,
|
||||
@@ -853,7 +898,13 @@ class VLAFlowMatching(nn.Module):
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
x_t=x_t,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
|
||||
x_t = x_t + dt * v_t
|
||||
|
||||
|
||||
@@ -0,0 +1,360 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Setup and debug CAN interfaces for Damiao motors (e.g., OpenArms).
|
||||
|
||||
Examples:
|
||||
|
||||
Setup CAN interfaces with CAN FD:
|
||||
```shell
|
||||
lerobot-setup-can --mode=setup --interfaces=can0,can1,can2,can3
|
||||
```
|
||||
|
||||
Test motors on a single interface:
|
||||
```shell
|
||||
lerobot-setup-can --mode=test --interfaces=can0
|
||||
```
|
||||
|
||||
Test motors on all interfaces:
|
||||
```shell
|
||||
lerobot-setup-can --mode=test --interfaces=can0,can1,can2,can3
|
||||
```
|
||||
|
||||
Speed test:
|
||||
```shell
|
||||
lerobot-setup-can --mode=speed --interfaces=can0
|
||||
```
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.utils.import_utils import is_package_available
|
||||
|
||||
MOTOR_NAMES = {
|
||||
0x01: "joint_1",
|
||||
0x02: "joint_2",
|
||||
0x03: "joint_3",
|
||||
0x04: "joint_4",
|
||||
0x05: "joint_5",
|
||||
0x06: "joint_6",
|
||||
0x07: "joint_7",
|
||||
0x08: "gripper",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class CANSetupConfig:
|
||||
mode: str = "test"
|
||||
interfaces: str = "can0" # Comma-separated, e.g. "can0,can1,can2,can3"
|
||||
bitrate: int = 1000000
|
||||
data_bitrate: int = 5000000
|
||||
use_fd: bool = True
|
||||
motor_ids: list[int] = field(default_factory=lambda: list(range(0x01, 0x09)))
|
||||
timeout: float = 1.0
|
||||
speed_iterations: int = 100
|
||||
|
||||
def get_interfaces(self) -> list[str]:
|
||||
return [i.strip() for i in self.interfaces.split(",") if i.strip()]
|
||||
|
||||
|
||||
def check_interface_status(interface: str) -> tuple[bool, str, bool]:
|
||||
"""Check if CAN interface is UP and configured."""
|
||||
try:
|
||||
result = subprocess.run(["ip", "link", "show", interface], capture_output=True, text=True) # nosec B607
|
||||
if result.returncode != 0:
|
||||
return False, "Interface not found", False
|
||||
|
||||
output = result.stdout
|
||||
is_up = "UP" in output
|
||||
is_fd = "fd on" in output.lower() or "canfd" in output.lower()
|
||||
status = "UP" if is_up else "DOWN"
|
||||
if is_fd:
|
||||
status += " (CAN FD)"
|
||||
|
||||
return is_up, status, is_fd
|
||||
except FileNotFoundError:
|
||||
return False, "ip command not found", False
|
||||
|
||||
|
||||
def setup_interface(interface: str, bitrate: int, data_bitrate: int, use_fd: bool) -> bool:
|
||||
"""Configure a CAN interface."""
|
||||
try:
|
||||
subprocess.run(["sudo", "ip", "link", "set", interface, "down"], check=False, capture_output=True) # nosec B607
|
||||
|
||||
cmd = ["sudo", "ip", "link", "set", interface, "type", "can", "bitrate", str(bitrate)]
|
||||
if use_fd:
|
||||
cmd.extend(["dbitrate", str(data_bitrate), "fd", "on"])
|
||||
|
||||
result = subprocess.run(cmd, capture_output=True, text=True) # nosec B607
|
||||
if result.returncode != 0:
|
||||
print(f" ✗ Failed to configure: {result.stderr}")
|
||||
return False
|
||||
|
||||
result = subprocess.run( # nosec B607
|
||||
["sudo", "ip", "link", "set", interface, "up"], capture_output=True, text=True
|
||||
)
|
||||
if result.returncode != 0:
|
||||
print(f" ✗ Failed to bring up: {result.stderr}")
|
||||
return False
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f" ✗ Error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_motor(bus, motor_id: int, timeout: float, use_fd: bool):
|
||||
"""Test a single motor and return responses."""
|
||||
import can
|
||||
|
||||
enable_msg = can.Message(
|
||||
arbitration_id=motor_id,
|
||||
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC],
|
||||
is_extended_id=False,
|
||||
is_fd=use_fd,
|
||||
)
|
||||
|
||||
try:
|
||||
bus.send(enable_msg)
|
||||
except Exception as e:
|
||||
return None, f"Send error: {e}"
|
||||
|
||||
responses = []
|
||||
start_time = time.time()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
msg = bus.recv(timeout=0.1)
|
||||
if msg:
|
||||
responses.append((msg.arbitration_id, msg.data.hex(), getattr(msg, "is_fd", False)))
|
||||
|
||||
disable_msg = can.Message(
|
||||
arbitration_id=motor_id,
|
||||
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFD],
|
||||
is_extended_id=False,
|
||||
is_fd=use_fd,
|
||||
)
|
||||
try:
|
||||
bus.send(disable_msg)
|
||||
except Exception:
|
||||
print(f"Error sending message to motor 0x{motor_id:02X}")
|
||||
|
||||
return responses, None
|
||||
|
||||
|
||||
def test_interface(cfg: CANSetupConfig, interface: str):
|
||||
"""Test all motors on a CAN interface."""
|
||||
import can
|
||||
|
||||
is_up, status, _ = check_interface_status(interface)
|
||||
print(f"\n{interface}: {status}")
|
||||
|
||||
if not is_up:
|
||||
print(f" ⚠ Interface is not UP. Run: lerobot-setup-can --mode=setup --interfaces {interface}")
|
||||
return {}
|
||||
|
||||
try:
|
||||
kwargs = {"channel": interface, "interface": "socketcan", "bitrate": cfg.bitrate}
|
||||
if cfg.use_fd:
|
||||
kwargs.update({"data_bitrate": cfg.data_bitrate, "fd": True})
|
||||
bus = can.interface.Bus(**kwargs)
|
||||
except Exception as e:
|
||||
print(f" ✗ Connection failed: {e}")
|
||||
return {}
|
||||
|
||||
results = {}
|
||||
try:
|
||||
while bus.recv(timeout=0.01):
|
||||
pass
|
||||
|
||||
for motor_id in cfg.motor_ids:
|
||||
motor_name = MOTOR_NAMES.get(motor_id, f"motor_0x{motor_id:02X}")
|
||||
responses, error = test_motor(bus, motor_id, cfg.timeout, cfg.use_fd)
|
||||
|
||||
if error:
|
||||
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ {error}")
|
||||
results[motor_id] = {"found": False, "error": error}
|
||||
elif responses:
|
||||
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✓ FOUND")
|
||||
for resp_id, data, is_fd in responses:
|
||||
fd_flag = " [FD]" if is_fd else ""
|
||||
print(f" → Response 0x{resp_id:02X}{fd_flag}: {data}")
|
||||
results[motor_id] = {"found": True, "responses": responses}
|
||||
else:
|
||||
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ No response")
|
||||
results[motor_id] = {"found": False}
|
||||
|
||||
time.sleep(0.05)
|
||||
finally:
|
||||
bus.shutdown()
|
||||
|
||||
found = sum(1 for r in results.values() if r.get("found"))
|
||||
print(f"\n Summary: {found}/{len(cfg.motor_ids)} motors found")
|
||||
return results
|
||||
|
||||
|
||||
def speed_test(cfg: CANSetupConfig, interface: str):
|
||||
"""Test communication speed with motors."""
|
||||
import can
|
||||
|
||||
is_up, status, _ = check_interface_status(interface)
|
||||
if not is_up:
|
||||
print(f"{interface}: {status} - skipping")
|
||||
return
|
||||
|
||||
print(f"\n{interface}: Running speed test ({cfg.speed_iterations} iterations)...")
|
||||
|
||||
try:
|
||||
kwargs = {"channel": interface, "interface": "socketcan", "bitrate": cfg.bitrate}
|
||||
if cfg.use_fd:
|
||||
kwargs.update({"data_bitrate": cfg.data_bitrate, "fd": True})
|
||||
bus = can.interface.Bus(**kwargs)
|
||||
except Exception as e:
|
||||
print(f" ✗ Connection failed: {e}")
|
||||
return
|
||||
|
||||
responding_motor = None
|
||||
for motor_id in cfg.motor_ids:
|
||||
responses, _ = test_motor(bus, motor_id, 0.5, cfg.use_fd)
|
||||
if responses:
|
||||
responding_motor = motor_id
|
||||
break
|
||||
|
||||
if not responding_motor:
|
||||
print(" ✗ No responding motors found")
|
||||
bus.shutdown()
|
||||
return
|
||||
|
||||
print(f" Testing with motor 0x{responding_motor:02X}...")
|
||||
latencies = []
|
||||
|
||||
for _ in range(cfg.speed_iterations):
|
||||
start = time.perf_counter()
|
||||
msg = can.Message(
|
||||
arbitration_id=responding_motor,
|
||||
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC],
|
||||
is_extended_id=False,
|
||||
is_fd=cfg.use_fd,
|
||||
)
|
||||
bus.send(msg)
|
||||
resp = bus.recv(timeout=0.1)
|
||||
if resp:
|
||||
latencies.append((time.perf_counter() - start) * 1000)
|
||||
|
||||
bus.shutdown()
|
||||
|
||||
if latencies:
|
||||
avg_latency = sum(latencies) / len(latencies)
|
||||
hz = 1000.0 / avg_latency if avg_latency > 0 else 0
|
||||
print(f" ✓ Success rate: {len(latencies)}/{cfg.speed_iterations}")
|
||||
print(f" ✓ Avg latency: {avg_latency:.2f} ms")
|
||||
print(f" ✓ Max frequency: {hz:.1f} Hz")
|
||||
else:
|
||||
print(" ✗ No successful responses")
|
||||
|
||||
|
||||
def run_setup(cfg: CANSetupConfig):
|
||||
"""Setup CAN interfaces."""
|
||||
print("=" * 50)
|
||||
print("CAN Interface Setup")
|
||||
print("=" * 50)
|
||||
print(f"Mode: {'CAN FD' if cfg.use_fd else 'CAN 2.0'}")
|
||||
print(f"Bitrate: {cfg.bitrate / 1_000_000:.1f} Mbps")
|
||||
if cfg.use_fd:
|
||||
print(f"Data bitrate: {cfg.data_bitrate / 1_000_000:.1f} Mbps")
|
||||
print()
|
||||
|
||||
interfaces = cfg.get_interfaces()
|
||||
for interface in interfaces:
|
||||
print(f"Configuring {interface}...")
|
||||
if setup_interface(interface, cfg.bitrate, cfg.data_bitrate, cfg.use_fd):
|
||||
is_up, status, _ = check_interface_status(interface)
|
||||
print(f" ✓ {interface}: {status}")
|
||||
else:
|
||||
print(f" ✗ {interface}: Failed")
|
||||
|
||||
print("\nSetup complete!")
|
||||
print("\nNext: Test motors with:")
|
||||
print(f" lerobot-setup-can --mode=test --interfaces {','.join(interfaces)}")
|
||||
|
||||
|
||||
def run_test(cfg: CANSetupConfig):
|
||||
"""Test motors on CAN interfaces."""
|
||||
print("=" * 50)
|
||||
print("CAN Motor Test")
|
||||
print("=" * 50)
|
||||
print(f"Testing motors 0x{min(cfg.motor_ids):02X}-0x{max(cfg.motor_ids):02X}")
|
||||
print(f"Mode: {'CAN FD' if cfg.use_fd else 'CAN 2.0'}")
|
||||
print()
|
||||
|
||||
interfaces = cfg.get_interfaces()
|
||||
all_results = {}
|
||||
for interface in interfaces:
|
||||
all_results[interface] = test_interface(cfg, interface)
|
||||
|
||||
total_found = sum(sum(1 for r in res.values() if r.get("found")) for res in all_results.values())
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("Summary")
|
||||
print("=" * 50)
|
||||
print(f"Total motors found: {total_found}")
|
||||
|
||||
if total_found == 0:
|
||||
print("\n⚠ No motors found! Check:")
|
||||
print(" 1. Motors are powered (24V)")
|
||||
print(" 2. CAN wiring (CANH, CANL, GND)")
|
||||
print(" 3. Motor timeout parameter > 0 (use Damiao tools)")
|
||||
print(" 4. 120Ω termination at both cable ends")
|
||||
print(f" 5. Interface configured: lerobot-setup-can --mode=setup --interfaces {interfaces[0]}")
|
||||
|
||||
|
||||
def run_speed(cfg: CANSetupConfig):
|
||||
"""Run speed tests on CAN interfaces."""
|
||||
print("=" * 50)
|
||||
print("CAN Speed Test")
|
||||
print("=" * 50)
|
||||
|
||||
for interface in cfg.get_interfaces():
|
||||
speed_test(cfg, interface)
|
||||
|
||||
|
||||
@draccus.wrap()
|
||||
def setup_can(cfg: CANSetupConfig):
|
||||
if not is_package_available("can"):
|
||||
print("Error: python-can not installed. Install with: pip install python-can")
|
||||
sys.exit(1)
|
||||
|
||||
if cfg.mode == "setup":
|
||||
run_setup(cfg)
|
||||
elif cfg.mode == "test":
|
||||
run_test(cfg)
|
||||
elif cfg.mode == "speed":
|
||||
run_speed(cfg)
|
||||
else:
|
||||
print(f"Unknown mode: {cfg.mode}")
|
||||
print("Available modes: setup, test, speed")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main():
|
||||
setup_can()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -27,4 +27,4 @@ class OmxLeaderConfig(TeleoperatorConfig):
|
||||
|
||||
# Sets the arm in torque mode with the gripper motor set to this value. This makes it possible to squeeze
|
||||
# the gripper and have it spring back to an open position on its own.
|
||||
gripper_open_pos: float = 37.0
|
||||
gripper_open_pos: float = 60.0
|
||||
|
||||
@@ -103,7 +103,7 @@ class OmxLeader(Teleoperator):
|
||||
self.calibration[motor] = MotorCalibration(
|
||||
id=m.id,
|
||||
drive_mode=drive_modes[motor],
|
||||
homing_offset=0,
|
||||
homing_offset=0 if motor != "gripper" else 100,
|
||||
range_min=0,
|
||||
range_max=4095,
|
||||
)
|
||||
@@ -123,12 +123,20 @@ class OmxLeader(Teleoperator):
|
||||
# point
|
||||
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
|
||||
|
||||
if motor == "gripper":
|
||||
self.bus.write("Drive_Mode", motor, DriveMode.INVERTED.value)
|
||||
else:
|
||||
self.bus.write("Drive_Mode", motor, DriveMode.NON_INVERTED.value)
|
||||
|
||||
# Use 'position control current based' for gripper to be limited by the limit of the current.
|
||||
# For the follower gripper, it means it can grasp an object without forcing too much even tho,
|
||||
# its goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch).
|
||||
# For the leader gripper, it means we can use it as a physical trigger, since we can force with our finger
|
||||
# to make it move, and it will move back to its original target position when we release the force.
|
||||
self.bus.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value)
|
||||
self.bus.write("Current_Limit", "gripper", 100)
|
||||
self.bus.write("Goal_Current", "gripper", 100)
|
||||
self.bus.write("Homing_Offset", "gripper", 100)
|
||||
# Set gripper's goal pos in current position mode so that we can use it as a trigger.
|
||||
self.bus.enable_torque("gripper")
|
||||
if self.is_calibrated:
|
||||
|
||||
@@ -73,6 +73,7 @@ _transformers_available = is_package_available("transformers")
|
||||
_peft_available = is_package_available("peft")
|
||||
_scipy_available = is_package_available("scipy")
|
||||
_reachy2_sdk_available = is_package_available("reachy2_sdk")
|
||||
_can_available = is_package_available("python-can", "can")
|
||||
|
||||
|
||||
def make_device_from_device_class(config: ChoiceRegistry) -> Any:
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
"""Minimal test script for Damiao motor with ID 3."""
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.utils.import_utils import _can_available
|
||||
|
||||
if not _can_available:
|
||||
pytest.skip("python-can not available", allow_module_level=True)
|
||||
|
||||
from lerobot.motors import Motor
|
||||
from lerobot.motors.damiao import DamiaoMotorsBus
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Requires physical Damiao motor and CAN interface")
|
||||
def test_damiao_motor():
|
||||
motors = {
|
||||
"joint_3": Motor(
|
||||
id=0x03,
|
||||
model="damiao",
|
||||
norm_mode="degrees",
|
||||
motor_type_str="dm4310",
|
||||
recv_id=0x13,
|
||||
),
|
||||
}
|
||||
|
||||
bus = DamiaoMotorsBus(port="can0", motors=motors)
|
||||
|
||||
try:
|
||||
print("Connecting...")
|
||||
bus.connect()
|
||||
print("✓ Connected")
|
||||
|
||||
print("Enabling torque...")
|
||||
bus.enable_torque()
|
||||
print("✓ Torque enabled")
|
||||
|
||||
print("Reading all states...")
|
||||
states = bus.sync_read_all_states()
|
||||
print(f"✓ States: {states}")
|
||||
|
||||
print("Reading position...")
|
||||
positions = bus.sync_read("Present_Position")
|
||||
print(f"✓ Position: {positions}")
|
||||
|
||||
print("Testing MIT control batch...")
|
||||
current_pos = states["joint_3"]["position"]
|
||||
commands = {"joint_3": (10.0, 0.5, current_pos, 0.0, 0.0)}
|
||||
bus._mit_control_batch(commands)
|
||||
print("✓ MIT control batch sent")
|
||||
|
||||
print("Disabling torque...")
|
||||
bus.disable_torque()
|
||||
print("✓ Torque disabled")
|
||||
|
||||
print("Setting zero position...")
|
||||
bus.set_zero_position()
|
||||
print("✓ Zero position set")
|
||||
|
||||
finally:
|
||||
print("Disconnecting...")
|
||||
bus.disconnect(disable_torque=True)
|
||||
print("✓ Disconnected")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_damiao_motor()
|
||||
@@ -0,0 +1,50 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for training-time RTC helpers."""
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import RTCTrainingDelayDistribution
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCTrainingConfig
|
||||
from lerobot.policies.rtc.training_time import apply_rtc_training_time, sample_rtc_delay
|
||||
|
||||
|
||||
def test_rtc_training_config_defaults():
|
||||
config = RTCTrainingConfig()
|
||||
assert config.enabled is False
|
||||
assert config.min_delay == 0
|
||||
assert config.max_delay == 0
|
||||
assert config.delay_distribution == RTCTrainingDelayDistribution.UNIFORM
|
||||
assert config.exp_decay == 1.0
|
||||
|
||||
|
||||
def test_sample_rtc_delay_uniform_range():
|
||||
cfg = RTCTrainingConfig(enabled=True, min_delay=1, max_delay=4)
|
||||
delays = sample_rtc_delay(cfg, batch_size=100, device=torch.device("cpu"))
|
||||
assert delays.min().item() >= 1
|
||||
assert delays.max().item() <= 4
|
||||
|
||||
|
||||
def test_apply_rtc_training_time_prefix_mask():
|
||||
time = torch.tensor([0.5])
|
||||
delays = torch.tensor([2])
|
||||
time_tokens, postfix_mask = apply_rtc_training_time(time, delays, seq_len=4)
|
||||
assert time_tokens.shape == (1, 4)
|
||||
assert postfix_mask.shape == (1, 4)
|
||||
# Delay=2 means the first two steps are prefix (time forced to 0.0) and only the last two are postfix.
|
||||
assert torch.allclose(time_tokens[0], torch.tensor([0.0, 0.0, 0.5, 0.5]))
|
||||
assert torch.equal(postfix_mask[0], torch.tensor([False, False, True, True]))
|
||||
Reference in New Issue
Block a user