mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 22:59:50 +00:00
Compare commits
43 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a07ff640bb | |||
| 0753415244 | |||
| a054663e38 | |||
| aceb651e40 | |||
| 76a4529d29 | |||
| 39e14c086c | |||
| 0af2029328 | |||
| c027b2971c | |||
| b18cef2e26 | |||
| 5c6182176f | |||
| 55c0471db9 | |||
| ec04b7ce3a | |||
| 04cbf669cf | |||
| 3409ef0dc2 | |||
| 4483184875 | |||
| 149628dfd5 | |||
| bf337e716d | |||
| 736b43f3cf | |||
| f6b1c39b78 | |||
| 0c0c171d35 | |||
| 9cfb5ce546 | |||
| 366bef915c | |||
| 9cc203034e | |||
| 9e10eb4a77 | |||
| 6d34a986de | |||
| 961277d86e | |||
| 0b067df57d | |||
| 9ca680dce2 | |||
| 9919b16b36 | |||
| d36dfcdf71 | |||
| 13bfee1aa4 | |||
| 79688a09f2 | |||
| b2ff219624 | |||
| 66929c5935 | |||
| 5286ef8439 | |||
| fe068df711 | |||
| da41646073 | |||
| 46e19ae579 | |||
| 77dc49b3a3 | |||
| 33910673ec | |||
| 19dce78457 | |||
| 112b2d173a | |||
| b825880c40 |
@@ -18,6 +18,11 @@ name: Documentation
|
||||
on:
|
||||
# Allows running this workflow manually from the Actions tab
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: 'Version tag (e.g. v0.1.2) - Leave empty for standard main build'
|
||||
required: false
|
||||
type: string
|
||||
|
||||
# Triggers the workflow on push events to main for the docs folder
|
||||
push:
|
||||
@@ -54,7 +59,13 @@ jobs:
|
||||
with:
|
||||
commit_sha: ${{ github.sha }}
|
||||
package: lerobot
|
||||
additional_args: --not_python_module ${{ github.event_name == 'release' && format('--version {0}', github.event.release.tag_name) || '' }}
|
||||
additional_args: >-
|
||||
--not_python_module
|
||||
${{
|
||||
(github.event_name == 'release' && format('--version {0}', github.event.release.tag_name)) ||
|
||||
(inputs.version != '' && format('--version {0}', inputs.version)) ||
|
||||
''
|
||||
}}
|
||||
secrets:
|
||||
token: ${{ secrets.HUGGINGFACE_PUSH }}
|
||||
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
|
||||
|
||||
@@ -20,8 +20,8 @@ on:
|
||||
workflow_dispatch:
|
||||
|
||||
# Run on the 1st and 15th of every month at 09:00 UTC
|
||||
schedule:
|
||||
- cron: '0 2 1,15 * *'
|
||||
# schedule:
|
||||
# - cron: '0 2 1,15 * *'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
+1
-1
@@ -14,7 +14,7 @@ You can contribute in many ways:
|
||||
- **Documentation:** Improve examples, guides, and docstrings.
|
||||
- **Feedback:** Submit tickets related to bugs or desired new features.
|
||||
|
||||
If you are unsure where to start, join our [Discord Channel](https://discord.gg/JkrYNdmw).
|
||||
If you are unsure where to start, join our [Discord Channel](https://discord.gg/q8Dzzpym3f).
|
||||
|
||||
## Development Setup
|
||||
|
||||
|
||||
@@ -128,6 +128,7 @@ Learn how to implement your own simulation environment or benchmark and distribu
|
||||
## Resources
|
||||
|
||||
- **[Documentation](https://huggingface.co/docs/lerobot/index):** The complete guide to tutorials & API.
|
||||
- **[Chinese Tutorials: LeRobot+SO-ARM101中文教程-同济子豪兄](https://zihao-ai.feishu.cn/wiki/space/7589642043471924447)** Detailed doc for assembling, teleoperate, dataset, train, deploy. Verified by Seed Studio and 5 global hackathon players.
|
||||
- **[Discord](https://discord.gg/q8Dzzpym3f):** Join the `LeRobot` server to discuss with the community.
|
||||
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
|
||||
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
|
||||
|
||||
+48
@@ -0,0 +1,48 @@
|
||||
# Security Policy
|
||||
|
||||
## Project Status & Philosophy
|
||||
|
||||
`lerobot` has so far been primarily a research and prototyping tool, which is why deployment security hasn’t been a strong focus until now. As `lerobot` continues to be adopted and deployed in production, we are paying much closer attention to these kinds of issues.
|
||||
|
||||
Fortunately, being an open-source project, the community can also help by reporting and fixing vulnerabilities. We appreciate your efforts to responsibly disclose your findings and will make every effort to acknowledge your contributions.
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
To report a security issue, please use the GitHub Security Advisory ["Report a Vulnerability"](https://github.com/huggingface/lerobot/security/advisories/new) tab.
|
||||
|
||||
The `lerobot` team will send a response indicating the next steps in handling your report. After the initial reply to your report, the security team will keep you informed of the progress towards a fix and full announcement, and may ask for additional information or guidance.
|
||||
|
||||
#### Hugging Face Security Team
|
||||
|
||||
Since this project is part of the Hugging Face ecosystem, feel free to submit vulnerability reports directly to: **[security@huggingface.co](mailto:security@huggingface.co)**. Someone from the HF security team will review the report and recommend next steps.
|
||||
|
||||
#### Open Source Disclosures
|
||||
|
||||
If reporting a vulnerability specific to the open-source codebase (and not the underlying Hub infrastructure), you may also use [Huntr](https://huntr.com), a vulnerability disclosure program for open source software.
|
||||
|
||||
## Supported Versions
|
||||
|
||||
Currently, we treat `lerobot` as a rolling release. We prioritize security updates for the latest available version (`main` branch).
|
||||
|
||||
| Version | Supported |
|
||||
| -------- | --------- |
|
||||
| Latest | ✅ |
|
||||
| < Latest | ❌ |
|
||||
|
||||
## Secure Usage Guidelines
|
||||
|
||||
`lerobot` is tightly coupled to the Hugging Face Hub for sharing data and pretrained policies. When downloading artifacts uploaded by others, you expose yourself to risks. Please read below for recommendations to keep your runtime and robot environment safe.
|
||||
|
||||
### Remote Artefacts (Weights & Policies)
|
||||
|
||||
Models and policies uploaded to the Hugging Face Hub come in different formats. We heavily recommend uploading and downloading models in the [`safetensors`](https://github.com/huggingface/safetensors) format.
|
||||
|
||||
`safetensors` was developed specifically to prevent arbitrary code execution on your system, which is critical when running software on physical hardware/robots.
|
||||
|
||||
To avoid loading models from unsafe formats (e.g., `pickle`), you should ensure you are prioritizing `safetensors` files.
|
||||
|
||||
### Remote Code
|
||||
|
||||
Some models or environments on the Hub may require `trust_remote_code=True` to run custom architecture code.
|
||||
|
||||
Please **always** verify the content of the modeling files when using this argument. We recommend setting a specific `revision` (commit hash) when loading remote code to ensure you protect yourself from unverified updates to the repository.
|
||||
@@ -7,8 +7,6 @@
|
||||
- sections:
|
||||
- local: il_robots
|
||||
title: Imitation Learning for Robots
|
||||
- local: cameras
|
||||
title: Cameras
|
||||
- local: bring_your_own_policies
|
||||
title: Bring Your Own Policies
|
||||
- local: integrate_hardware
|
||||
@@ -29,6 +27,8 @@
|
||||
title: Porting Large Datasets
|
||||
- local: using_dataset_tools
|
||||
title: Using the Dataset Tools
|
||||
- local: dataset_subtask
|
||||
title: Using Subtasks in the Dataset
|
||||
title: "Datasets"
|
||||
- sections:
|
||||
- local: act
|
||||
@@ -99,11 +99,19 @@
|
||||
title: Unitree G1
|
||||
- local: earthrover_mini_plus
|
||||
title: Earth Rover Mini
|
||||
- local: omx
|
||||
title: OMX
|
||||
- local: openarm
|
||||
title: OpenArm
|
||||
title: "Robots"
|
||||
- sections:
|
||||
- local: phone_teleop
|
||||
title: Phone
|
||||
title: "Teleoperators"
|
||||
- sections:
|
||||
- local: cameras
|
||||
title: Cameras
|
||||
title: "Sensors"
|
||||
- sections:
|
||||
- local: torch_accelerators
|
||||
title: PyTorch accelerators
|
||||
@@ -113,6 +121,8 @@
|
||||
title: Notebooks
|
||||
- local: feetech
|
||||
title: Updating Feetech Firmware
|
||||
- local: damiao
|
||||
title: Damiao Motors and CAN Bus
|
||||
title: "Resources"
|
||||
- sections:
|
||||
- local: contributing
|
||||
|
||||
@@ -195,6 +195,7 @@ client_cfg = RobotClientConfig(
|
||||
robot=robot_cfg,
|
||||
server_address="localhost:8080",
|
||||
policy_device="mps",
|
||||
client_device="cpu",
|
||||
policy_type="smolvla",
|
||||
pretrained_name_or_path="<user>/smolvla_async",
|
||||
chunk_size_threshold=0.5,
|
||||
|
||||
+95
-81
@@ -1,12 +1,22 @@
|
||||
# Cameras
|
||||
|
||||
LeRobot offers multiple options for video capture, including phone cameras, built-in laptop cameras, external webcams, and Intel RealSense cameras. To efficiently record frames from most cameras, you can use either the `OpenCVCamera` or `RealSenseCamera` class. For additional compatibility details on the `OpenCVCamera` class, refer to the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).
|
||||
LeRobot offers multiple options for video capture:
|
||||
|
||||
### Finding your camera
|
||||
| Class | Supported Cameras |
|
||||
| ----------------- | ----------------------------------- |
|
||||
| `OpenCVCamera` | Phone, built-in laptop, USB webcams |
|
||||
| `ZMQCamera` | Network-connected cameras |
|
||||
| `RealSenseCamera` | Intel RealSense (with depth) |
|
||||
| `Reachy2Camera` | Reachy 2 robot cameras |
|
||||
|
||||
To instantiate a camera, you need a camera identifier. This identifier might change if you reboot your computer or re-plug your camera, a behavior mostly dependant on your operating system.
|
||||
> [!TIP]
|
||||
> For `OpenCVCamera` compatibility details, see the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).
|
||||
|
||||
To find the camera indices of the cameras plugged into your system, run the following script:
|
||||
### Find your camera
|
||||
|
||||
Every camera requires a unique identifier to be instantiated, allowing you to distinguish between multiple connected devices.
|
||||
|
||||
`OpenCVCamera` and `RealSenseCamera` support auto-discovery. Run the command below to list available devices and their identifiers. Note that these identifiers may change after rebooting your computer or re-plugging the camera, depending on your operating system.
|
||||
|
||||
```bash
|
||||
lerobot-find-cameras opencv # or realsense for Intel Realsense cameras
|
||||
@@ -14,7 +24,7 @@ lerobot-find-cameras opencv # or realsense for Intel Realsense cameras
|
||||
|
||||
The output will look something like this if you have two cameras connected:
|
||||
|
||||
```
|
||||
```bash
|
||||
--- Detected Cameras ---
|
||||
Camera #0:
|
||||
Name: OpenCV Camera @ 0
|
||||
@@ -33,13 +43,37 @@ Camera #0:
|
||||
> [!WARNING]
|
||||
> When using Intel RealSense cameras in `macOS`, you could get this [error](https://github.com/IntelRealSense/librealsense/issues/12307): `Error finding RealSense cameras: failed to set power state`, this can be solved by running the same command with `sudo` permissions. Note that using RealSense cameras in `macOS` is unstable.
|
||||
|
||||
## Use Cameras
|
||||
`ZMQCamera` and `Reachy2Camera` do not support auto-discovery. They must be configured manually by providing their network address and port or robot SDK settings.
|
||||
|
||||
Below are two examples, demonstrating how to work with the API.
|
||||
## Use cameras
|
||||
|
||||
- **Asynchronous frame capture** using an OpenCV-based camera
|
||||
### Frame access modes
|
||||
|
||||
All camera classes implement three access modes for capturing frames:
|
||||
|
||||
| Method | Behavior | Blocks? | Best For |
|
||||
| ------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------- | ---------------------------------------- |
|
||||
| `read()` | Waits for the camera hardware to return a frame. May block for a long time depending on the camera and SDK. | Yes | Simple scripts, sequential capture |
|
||||
| `async_read(timeout_ms)` | Returns the latest unconsumed frame from background thread. Blocks only if buffer is empty, up to `timeout_ms`. Raises `TimeoutError` if no frame arrives. | With a timeout | Control loops synchronized to camera FPS |
|
||||
| `read_latest(max_age_ms)` | Peeks at the most recent frame in buffer (may be stale). Raises `TimeoutError` if frame is older than `max_age_ms`. | No | UI visualization, logging, monitoring |
|
||||
|
||||
### Usage examples
|
||||
|
||||
The following examples show how to use the camera API to configure and capture frames from different camera types.
|
||||
|
||||
- **Blocking and non-blocking frame capture** using an OpenCV-based camera
|
||||
- **Color and depth capture** using an Intel RealSense camera
|
||||
|
||||
> [!WARNING]
|
||||
> Failing to cleanly disconnect cameras can cause resource leaks. Use the context manager protocol to ensure automatic cleanup:
|
||||
>
|
||||
> ```python
|
||||
> with OpenCVCamera(config) as camera:
|
||||
> ...
|
||||
> ```
|
||||
>
|
||||
> You can also call `connect()` and `disconnect()` manually, but always use a `finally` block for the latter.
|
||||
|
||||
<hfoptions id="shell_restart">
|
||||
<hfoption id="Open CV Camera">
|
||||
|
||||
@@ -60,16 +94,30 @@ config = OpenCVCameraConfig(
|
||||
)
|
||||
|
||||
# Instantiate and connect an `OpenCVCamera`, performing a warm-up read (default).
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect()
|
||||
with OpenCVCamera(config) as camera:
|
||||
|
||||
# Read a frame synchronously — blocks until hardware delivers a new frame
|
||||
frame = camera.read()
|
||||
print(f"read() call returned frame with shape:", frame.shape)
|
||||
|
||||
# Read a frame asynchronously with a timeout — returns the latest unconsumed frame or waits up to timeout_ms for a new one
|
||||
try:
|
||||
for i in range(10):
|
||||
frame = camera.async_read(timeout_ms=200)
|
||||
print(f"async_read call returned frame {i} with shape:", frame.shape)
|
||||
except TimeoutError as e:
|
||||
print(f"No frame received within timeout: {e}")
|
||||
|
||||
# Instantly return a frame - returns the most recent frame captured by the camera
|
||||
try:
|
||||
initial_frame = camera.read_latest(max_age_ms=1000)
|
||||
for i in range(10):
|
||||
frame = camera.read_latest(max_age_ms=1000)
|
||||
print(f"read_latest call returned frame {i} with shape:", frame.shape)
|
||||
print(f"Was a new frame received by the camera? {not (initial_frame == frame).any()}")
|
||||
except TimeoutError as e:
|
||||
print(f"Frame too old: {e}")
|
||||
|
||||
# Read frames asynchronously in a loop via `async_read(timeout_ms)`
|
||||
try:
|
||||
for i in range(10):
|
||||
frame = camera.async_read(timeout_ms=200)
|
||||
print(f"Async frame {i} shape:", frame.shape)
|
||||
finally:
|
||||
camera.disconnect()
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
@@ -111,10 +159,10 @@ finally:
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Use your phone
|
||||
## Use your phone's camera
|
||||
|
||||
<hfoptions id="use phone">
|
||||
<hfoption id="Mac">
|
||||
<hfoption id="iPhone & macOS">
|
||||
|
||||
To use your iPhone as a camera on macOS, enable the Continuity Camera feature:
|
||||
|
||||
@@ -124,83 +172,49 @@ To use your iPhone as a camera on macOS, enable the Continuity Camera feature:
|
||||
|
||||
For more details, visit [Apple support](https://support.apple.com/en-gb/guide/mac-help/mchl77879b8a/mac).
|
||||
|
||||
Your iPhone should be detected automatically when running the camera setup script in the next section.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Linux">
|
||||
<hfoption id="OBS virtual camera">
|
||||
|
||||
If you want to use your phone as a camera on Linux, follow these steps to set up a virtual camera
|
||||
If you want to use your phone as a camera using OBS, follow these steps to set up a virtual camera.
|
||||
|
||||
1. _Install `v4l2loopback-dkms` and `v4l-utils`_. Those packages are required to create virtual camera devices (`v4l2loopback`) and verify their settings with the `v4l2-ctl` utility from `v4l-utils`. Install them using:
|
||||
1. _(Linux only) Install `v4l2loopback-dkms` and `v4l-utils`_. These packages create virtual camera devices and verify their settings. Install with:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
```bash
|
||||
sudo apt install v4l2loopback-dkms v4l-utils
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
2. _Install [DroidCam](https://droidcam.app) on your phone_. This app is available for both iOS and Android.
|
||||
3. _Install [OBS Studio](https://obsproject.com)_. This software will help you manage the camera feed. Install it using [Flatpak](https://flatpak.org):
|
||||
2. _Install the [DroidCam app](https://droidcam.app) on your phone_. This app is available for both iOS and Android.
|
||||
3. _Download and install [OBS Studio](https://obsproject.com)_.
|
||||
4. _Download and install the [DroidCam OBS plugin](https://droidcam.app/obs)_.
|
||||
5. _Start OBS Studio_.
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
flatpak install flathub com.obsproject.Studio
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
4. _Install the DroidCam OBS plugin_. This plugin integrates DroidCam with OBS Studio. Install it with:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
flatpak install flathub com.obsproject.Studio.Plugin.DroidCam
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
5. _Start OBS Studio_. Launch with:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
flatpak run com.obsproject.Studio
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
6. _Add your phone as a source_. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480`.
|
||||
7. _Adjust resolution settings_. In OBS Studio, go to `File > Settings > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it in.
|
||||
6. _Add your phone as a source_. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480` to avoid the watermarks.
|
||||
7. _Adjust resolution settings_. In OBS Studio, go to `File > Settings > Video` or `OBS > Preferences... > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it.
|
||||
8. _Start virtual camera_. In OBS Studio, follow the instructions [here](https://obsproject.com/kb/virtual-camera-guide).
|
||||
9. _Verify the virtual camera setup_. Use `v4l2-ctl` to list the devices:
|
||||
9. _Verify the virtual camera setup and resolution_.
|
||||
- **Linux**: Use `v4l2-ctl` to list devices and check resolution:
|
||||
```bash
|
||||
v4l2-ctl --list-devices # find VirtualCam and note its /dev/videoX path
|
||||
v4l2-ctl -d /dev/videoX --get-fmt-video # replace with your VirtualCam path
|
||||
```
|
||||
You should see `VirtualCam` listed and resolution `640x480`.
|
||||
- **macOS**: Open Photo Booth or FaceTime and select "OBS Virtual Camera" as the input.
|
||||
- **Windows**: The native Camera app doesn't support virtual cameras. Use a video conferencing app (Zoom, Teams) or run `lerobot-find-cameras opencv` directly to verify.
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
v4l2-ctl --list-devices
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
<details>
|
||||
<summary><strong>Troubleshooting</strong></summary>
|
||||
|
||||
You should see an entry like:
|
||||
> The virtual camera resolution is incorrect.
|
||||
|
||||
```
|
||||
VirtualCam (platform:v4l2loopback-000):
|
||||
/dev/video1
|
||||
```
|
||||
Delete the virtual camera source and recreate it. The resolution cannot be changed after creation.
|
||||
|
||||
10. _Check the camera resolution_. Use `v4l2-ctl` to ensure that the virtual camera output resolution is `640x480`. Change `/dev/video1` to the port of your virtual camera from the output of `v4l2-ctl --list-devices`.
|
||||
> Error reading frame in background thread for OpenCVCamera(X): OpenCVCamera(X) frame width=640 or height=480 do not match configured width=1920 or height=1080.
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
v4l2-ctl -d /dev/video1 --get-fmt-video
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
This error is caused by OBS Virtual Camera advertising a `1920x1080` resolution despite rescaling. The only fix for now is to comment out the width and height check in `_postprocess_image()`.
|
||||
|
||||
You should see an entry like:
|
||||
|
||||
```
|
||||
>>> Format Video Capture:
|
||||
>>> Width/Height : 640/480
|
||||
>>> Pixel Format : 'YUYV' (YUYV 4:2:2)
|
||||
```
|
||||
|
||||
Troubleshooting: If the resolution is not correct you will have to delete the Virtual Camera port and try again as it cannot be changed.
|
||||
|
||||
If everything is set up correctly, you can proceed with the rest of the tutorial.
|
||||
</details>
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
If everything is set up correctly, your phone will appear as a standard OpenCV camera and can be used with `OpenCVCamera`.
|
||||
|
||||
@@ -0,0 +1,165 @@
|
||||
# Damiao Motors and CAN Bus
|
||||
|
||||
This guide covers setup and usage of Damiao motors with LeRobot via CAN bus communication.
|
||||
|
||||
Currently, only Linux is supported, as the OpenArms CAN adapter only has drivers for Linux.
|
||||
|
||||
## Linux CAN Setup
|
||||
|
||||
Before using Damiao motors, you need to set up the CAN interface on your Linux system.
|
||||
|
||||
### Install CAN Utilities
|
||||
|
||||
```bash
|
||||
sudo apt-get install can-utils
|
||||
```
|
||||
|
||||
### Configure CAN Interface (Manual)
|
||||
|
||||
For standard CAN FD (recommended for OpenArms):
|
||||
|
||||
```bash
|
||||
sudo ip link set can0 down
|
||||
sudo ip link set can0 type can bitrate 1000000 dbitrate 5000000 fd on
|
||||
sudo ip link set can0 up
|
||||
```
|
||||
|
||||
For standard CAN (without FD):
|
||||
|
||||
```bash
|
||||
sudo ip link set can0 down
|
||||
sudo ip link set can0 type can bitrate 1000000
|
||||
sudo ip link set can0 up
|
||||
```
|
||||
|
||||
### Configure CAN Interface (Using LeRobot)
|
||||
|
||||
LeRobot provides a utility script to setup and test CAN interfaces:
|
||||
|
||||
```bash
|
||||
# Setup multiple interfaces (e.g., OpenArms Followers with 2 CAN buses)
|
||||
lerobot-setup-can --mode=setup --interfaces=can0,can1
|
||||
```
|
||||
|
||||
## Debugging CAN Communication
|
||||
|
||||
Use the built-in debug tools to test motor communication:
|
||||
|
||||
```bash
|
||||
# Test motors on all interfaces
|
||||
lerobot-setup-can --mode=test --interfaces=can0,can1
|
||||
|
||||
# Run speed/latency test
|
||||
lerobot-setup-can --mode=speed --interfaces=can0
|
||||
```
|
||||
|
||||
The test mode will scan for motors (IDs 0x01-0x08) and report which ones respond. Example output:
|
||||
|
||||
```
|
||||
can0: UP (CAN FD)
|
||||
Motor 0x01 (joint_1): ✓ FOUND
|
||||
→ Response 0x11 [FD]: 00112233...
|
||||
Motor 0x02 (joint_2): ✓ FOUND
|
||||
Motor 0x03 (joint_3): ✗ No response
|
||||
...
|
||||
Summary: 2/8 motors found
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Setup
|
||||
|
||||
```python
|
||||
from lerobot.motors import Motor
|
||||
from lerobot.motors.damiao import DamiaoMotorsBus
|
||||
|
||||
# Define your motors with send/receive CAN IDs
|
||||
motors = {
|
||||
"joint_1": Motor(id=0x01, motor_type_str="dm8009", recv_id=0x11),
|
||||
"joint_2": Motor(id=0x02, motor_type_str="dm4340", recv_id=0x12),
|
||||
"joint_3": Motor(id=0x03, motor_type_str="dm4310", recv_id=0x13),
|
||||
}
|
||||
|
||||
# Create the bus
|
||||
bus = DamiaoMotorsBus(
|
||||
port="can0", # Linux socketcan interface
|
||||
motors=motors,
|
||||
)
|
||||
|
||||
# Connect
|
||||
bus.connect()
|
||||
```
|
||||
|
||||
### Reading Motor States
|
||||
|
||||
```python
|
||||
# Read single motor position (degrees)
|
||||
position = bus.read("Present_Position", "joint_1")
|
||||
|
||||
# Read from multiple motors
|
||||
positions = bus.sync_read("Present_Position") # All motors
|
||||
positions = bus.sync_read("Present_Position", ["joint_1", "joint_2"])
|
||||
|
||||
# Read all states at once (position, velocity, torque)
|
||||
states = bus.sync_read_all_states()
|
||||
# Returns: {'joint_1': {'position': 45.2, 'velocity': 1.3, 'torque': 0.5}, ...}
|
||||
```
|
||||
|
||||
### Writing Motor Commands
|
||||
|
||||
```python
|
||||
# Enable torque
|
||||
bus.enable_torque()
|
||||
|
||||
# Set goal position (degrees)
|
||||
bus.write("Goal_Position", "joint_1", 45.0)
|
||||
|
||||
# Set positions for multiple motors
|
||||
bus.sync_write("Goal_Position", {
|
||||
"joint_1": 45.0,
|
||||
"joint_2": -30.0,
|
||||
"joint_3": 90.0,
|
||||
})
|
||||
|
||||
# Disable torque
|
||||
bus.disable_torque()
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| -------------- | --------- | ----------------------------------------------------------- |
|
||||
| `port` | - | CAN interface (`can0`) or serial port (`/dev/cu.usbmodem*`) |
|
||||
| `use_can_fd` | `True` | Enable CAN FD for higher data rates |
|
||||
| `bitrate` | `1000000` | Nominal bitrate (1 Mbps) |
|
||||
| `data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) |
|
||||
|
||||
## Motor Configuration
|
||||
|
||||
Each motor requires:
|
||||
|
||||
- `id`: CAN ID for sending commands
|
||||
- `motor_type`: One of the supported motor types (e.g., `"dm8009"`, `"dm4340"`)
|
||||
- `recv_id`: CAN ID for receiving responses
|
||||
|
||||
OpenArms default IDs follow the pattern: send ID `0x0N`, receive ID `0x1N` where N is the joint number.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### No Response from Motors
|
||||
|
||||
1. **Check power**
|
||||
2. **Verify CAN wiring**: Check CAN-H, CAN-L, and GND connections
|
||||
3. **Check motor IDs**: Use Damiao Debugging Tools to verify/configure IDs
|
||||
4. **Test CAN interface**: Run `candump can0` to see if messages are being received
|
||||
5. **Run diagnostics**: `lerobot-setup-can --mode=test --interfaces=can0`
|
||||
|
||||
### Motor Timeout Parameter
|
||||
|
||||
If motors were configured with timeout=0, they won't respond to commands. Use Damiao Debugging Tools to set a non-zero timeout value.
|
||||
|
||||
### Verify CAN FD Status
|
||||
|
||||
```bash
|
||||
ip -d link show can0 | grep fd
|
||||
```
|
||||
@@ -0,0 +1,278 @@
|
||||
# Using Subtasks in LeRobot Datasets
|
||||
|
||||
Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for:
|
||||
|
||||
- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time
|
||||
- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models)
|
||||
- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps
|
||||
|
||||
LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks.
|
||||
|
||||
## What are Subtasks?
|
||||
|
||||
While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps:
|
||||
|
||||
1. "Approach the apple"
|
||||
2. "Grasp the apple"
|
||||
3. "Lift the apple"
|
||||
4. "Move to basket"
|
||||
5. "Release the apple"
|
||||
|
||||
Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/subtask-asset.png"
|
||||
alt="An overview of subtask annotation showing how frames are labeled with intermediate subtask stages"
|
||||
width="80%"
|
||||
/>
|
||||
|
||||
<p>
|
||||
<em>Figure: Overview of subtask annotation.</em>
|
||||
</p>
|
||||
|
||||
**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022.
|
||||
|
||||
## Dataset Structure
|
||||
|
||||
Subtask information is stored in the dataset metadata:
|
||||
|
||||
```
|
||||
my-dataset/
|
||||
├── data/
|
||||
│ └── ...
|
||||
├── meta/
|
||||
│ ├── info.json
|
||||
│ ├── stats.json
|
||||
│ ├── tasks.parquet
|
||||
│ ├── subtasks.parquet # Subtask index → subtask string mapping
|
||||
│ └── episodes/
|
||||
│ └── ...
|
||||
└── videos/
|
||||
└── ...
|
||||
```
|
||||
|
||||
### Subtasks Parquet File
|
||||
|
||||
The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions:
|
||||
|
||||
| subtask_index | subtask (index column) |
|
||||
| ------------- | ---------------------- |
|
||||
| 0 | "Approach the apple" |
|
||||
| 1 | "Grasp the apple" |
|
||||
| 2 | "Lift the apple" |
|
||||
| ... | ... |
|
||||
|
||||
### Frame-Level Annotations
|
||||
|
||||
Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file:
|
||||
|
||||
```python
|
||||
# Example frame data in the parquet file
|
||||
{
|
||||
"index": 42,
|
||||
"timestamp": 1.4,
|
||||
"episode_index": 0,
|
||||
"task_index": 0,
|
||||
"subtask_index": 2, # References "Lift the apple"
|
||||
"observation.state": [...],
|
||||
"action": [...],
|
||||
}
|
||||
```
|
||||
|
||||
## Annotating Datasets with Subtasks
|
||||
|
||||
We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks:
|
||||
|
||||
**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)**
|
||||
|
||||
After completing your annotation:
|
||||
|
||||
1. Click "Push to Hub" to upload your annotated dataset
|
||||
2. You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate)
|
||||
|
||||
## Loading Datasets with Subtasks
|
||||
|
||||
When you load a dataset with subtask annotations, the subtask information is automatically available:
|
||||
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Load a dataset with subtask annotations
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
# Access a sample
|
||||
sample = dataset[100]
|
||||
|
||||
# The sample includes both task and subtask information
|
||||
print(sample["task"]) # "Collect the fruit"
|
||||
print(sample["subtask"]) # "Grasp the apple"
|
||||
print(sample["task_index"]) # tensor(0)
|
||||
print(sample["subtask_index"]) # tensor(2)
|
||||
```
|
||||
|
||||
### Checking for Subtask Support
|
||||
|
||||
You can check if a dataset has subtask annotations:
|
||||
|
||||
```python
|
||||
# Check if subtasks are available
|
||||
has_subtasks = (
|
||||
"subtask_index" in dataset.features
|
||||
and dataset.meta.subtasks is not None
|
||||
)
|
||||
|
||||
if has_subtasks:
|
||||
print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks")
|
||||
print("Subtasks:", list(dataset.meta.subtasks.index))
|
||||
```
|
||||
|
||||
## Using Subtasks for Training
|
||||
|
||||
### With the Tokenizer Processor
|
||||
|
||||
The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models:
|
||||
|
||||
```python
|
||||
from lerobot.processor.tokenizer_processor import TokenizerProcessor
|
||||
from lerobot.processor.pipeline import ProcessorPipeline
|
||||
|
||||
# Create a tokenizer processor
|
||||
tokenizer_processor = TokenizerProcessor(
|
||||
tokenizer_name_or_path="google/paligemma-3b-pt-224",
|
||||
padding="max_length",
|
||||
max_length=64,
|
||||
)
|
||||
|
||||
# The processor will automatically tokenize subtasks if present in the batch
|
||||
# and add them to the observation under:
|
||||
# - "observation.subtask.tokens"
|
||||
# - "observation.subtask.attention_mask"
|
||||
```
|
||||
|
||||
When subtasks are available in the batch, the tokenizer processor adds:
|
||||
|
||||
- `observation.subtask.tokens`: Tokenized subtask text
|
||||
- `observation.subtask.attention_mask`: Attention mask for the subtask tokens
|
||||
|
||||
### DataLoader with Subtasks
|
||||
|
||||
```python
|
||||
import torch
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=16,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
for batch in dataloader:
|
||||
# Access subtask information in the batch
|
||||
subtasks = batch["subtask"] # List of subtask strings
|
||||
subtask_indices = batch["subtask_index"] # Tensor of subtask indices
|
||||
|
||||
# Use for training hierarchical policies or reward models
|
||||
print(f"Batch subtasks: {set(subtasks)}")
|
||||
```
|
||||
|
||||
## Example Datasets with Subtask Annotations
|
||||
|
||||
Try loading a dataset with subtask annotations:
|
||||
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Example dataset with subtask annotations
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
# Explore the subtasks
|
||||
print("Available subtasks:")
|
||||
for subtask_name in dataset.meta.subtasks.index:
|
||||
print(f" - {subtask_name}")
|
||||
|
||||
# Get subtask distribution
|
||||
subtask_counts = {}
|
||||
for i in range(len(dataset)):
|
||||
sample = dataset[i]
|
||||
subtask = sample["subtask"]
|
||||
subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1
|
||||
|
||||
print("\nSubtask distribution:")
|
||||
for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]):
|
||||
print(f" {subtask}: {count} frames")
|
||||
```
|
||||
|
||||
## Use Cases
|
||||
|
||||
### 1. Hierarchical Policy Training
|
||||
|
||||
Train policies that predict both actions and current subtask:
|
||||
|
||||
```python
|
||||
class HierarchicalPolicy(nn.Module):
|
||||
def __init__(self, num_subtasks):
|
||||
super().__init__()
|
||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
||||
self.subtask_head = nn.Linear(hidden_dim, num_subtasks)
|
||||
|
||||
def forward(self, observations):
|
||||
features = self.encoder(observations)
|
||||
actions = self.action_head(features)
|
||||
subtask_logits = self.subtask_head(features)
|
||||
return actions, subtask_logits
|
||||
```
|
||||
|
||||
### 2. Stage-Aware Reward Modeling (SARM)
|
||||
|
||||
Build reward models that understand task progression:
|
||||
|
||||
```python
|
||||
# SARM predicts:
|
||||
# - Stage: Which subtask is being executed (discrete)
|
||||
# - Progress: How far along the subtask (continuous 0-1)
|
||||
|
||||
class SARMRewardModel(nn.Module):
|
||||
def forward(self, observations):
|
||||
features = self.encoder(observations)
|
||||
stage_logits = self.stage_classifier(features)
|
||||
progress = self.progress_regressor(features)
|
||||
return stage_logits, progress
|
||||
```
|
||||
|
||||
### 3. Progress Visualization
|
||||
|
||||
Monitor robot execution by tracking subtask progression:
|
||||
|
||||
```python
|
||||
def visualize_execution(model, observations):
|
||||
for t, obs in enumerate(observations):
|
||||
action, subtask_logits = model(obs)
|
||||
predicted_subtask = subtask_names[subtask_logits.argmax()]
|
||||
print(f"t={t}: Executing '{predicted_subtask}'")
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### LeRobotDataset Properties
|
||||
|
||||
| Property | Type | Description |
|
||||
| --------------------------- | ---------------------- | ------------------------------------------ |
|
||||
| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices |
|
||||
| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present |
|
||||
|
||||
### Sample Keys
|
||||
|
||||
When subtasks are available, each sample includes:
|
||||
|
||||
| Key | Type | Description |
|
||||
| --------------- | -------------- | ------------------------------------ |
|
||||
| `subtask_index` | `torch.Tensor` | Integer index of the current subtask |
|
||||
| `subtask` | `str` | Natural language subtask description |
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation
|
||||
- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool
|
||||
- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation
|
||||
@@ -1,5 +1,11 @@
|
||||
# EarthRover Mini Plus
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Earth_Rover_Mini_5_240c9adc-4f9e-44b7-982f-5d1dc24af1d8.png.webp"
|
||||
alt="EarthRover Mini Plus"
|
||||
width="70%"
|
||||
/>
|
||||
|
||||
The EarthRover Mini Plus is a fully open source mobile robot that connects through the cloud using the Frodobots SDK. This lets you control the robot and record datasets for training AI models.
|
||||
|
||||
## What You Need
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
# LeKiwi
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/1740517739083.jpeg"
|
||||
alt="LeKiwi"
|
||||
width="70%"
|
||||
/>
|
||||
|
||||
In the steps below, we explain how to assemble the LeKiwi mobile robot.
|
||||
|
||||
## Source the parts
|
||||
|
||||
@@ -42,6 +42,7 @@ lerobot-eval \
|
||||
```
|
||||
|
||||
- `--env.task` picks the suite (`libero_object`, `libero_spatial`, etc.).
|
||||
- `--env.task_ids` picks task ids to run (`[0]`, `[1,2,3]`, etc.). Omit this flag (or set it to `null`) to run all tasks in the suite.
|
||||
- `--eval.batch_size` controls how many environments run in parallel.
|
||||
- `--eval.n_episodes` sets how many episodes to run in total.
|
||||
|
||||
|
||||
@@ -0,0 +1,197 @@
|
||||
## Order and Assemble the parts
|
||||
|
||||
First, assemble the OMX hardware following the official assembly guide.
|
||||
|
||||
OMX Assembly Guide: https://ai.robotis.com/omx/assembly_guide_omx.html
|
||||
|
||||
OMX robots are shipped preconfigured from the factory. Motor IDs, communication parameters, and joint offsets are already set, so no additional motor setup or calibration is required before using LeRobot.
|
||||
|
||||
## Install LeRobot 🤗
|
||||
|
||||
To install LeRobot, follow our [Installation Guide](./installation)
|
||||
|
||||
In addition to these instructions, you need to install the Dynamixel SDK:
|
||||
|
||||
```bash
|
||||
pip install -e ".[dynamixel]"
|
||||
```
|
||||
|
||||
## Connect the robot
|
||||
|
||||
To find the port for each bus servo adapter, run this script:
|
||||
|
||||
```bash
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
This command runs and when prompted, disconnect the USB cable from either the leader or follower arm and press Enter. The output will show 'The port of this MotorsBus is [port]'. This identifies the port for the disconnected arm. Repeat for the other arm to identify both ports.
|
||||
|
||||
<hfoptions id="find_port">
|
||||
<hfoption id="Mac">
|
||||
|
||||
Example output on macOS:
|
||||
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the USB cable from your MotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect corresponding leader or follower arm and press Enter...]
|
||||
|
||||
The port of this MotorsBus is /dev/tty.usbmodem575E0032081
|
||||
Reconnect the USB cable.
|
||||
```
|
||||
|
||||
Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your leader or follower arm.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Linux">
|
||||
|
||||
On Linux, we strongly recommend using udev rules to assign persistent and human-readable device names to the OMX leader and follower arms. This avoids issues where device names such as ttyACM0 and ttyACM1 change when the robot is unplugged, replugged, or when the system is rebooted.
|
||||
|
||||
#### 1. Find your device serial numbers
|
||||
|
||||
You should have obtained the port numbers like ../../ttyACM? for the leader and follower using `lerobot-find-port`. You can match those results with the serial numbers using the `ls -l /dev/serial/by-id/` command.
|
||||
To create udev rules, you need the unique serial number for each OMX device. The easiest way is to list devices under:
|
||||
|
||||
```bash
|
||||
ls -l /dev/serial/by-id/
|
||||
```
|
||||
|
||||
You will see output similar to:
|
||||
|
||||
```bash
|
||||
usb-ROBOTIS_OpenRB-150_228BDD7B503059384C2E3120FF0A2B19-if00 -> ../../ttyACM0
|
||||
usb-ROBOTIS_OpenRB-150_67E1ED68503059384C2E3120FF092234-if00 -> ../../ttyACM1
|
||||
```
|
||||
|
||||
In each line, the serial number is the long string after `usb-ROBOTIS_OpenRB-150_` and before `-if00`.
|
||||
|
||||
Follower serial: `228BDD7B503059384C2E3120FF0A2B19`
|
||||
|
||||
Leader serial: `67E1ED68503059384C2E3120FF092234`
|
||||
|
||||
#### 2. Create the udev rule
|
||||
|
||||
Create a new udev rule file:
|
||||
|
||||
```bash
|
||||
sudo nano /etc/udev/rules.d/99-omx.rules
|
||||
```
|
||||
|
||||
Paste the following lines, replacing the serial numbers with the values you found above:
|
||||
|
||||
```bash
|
||||
SUBSYSTEM=="tty", ATTRS{idVendor}=="0403", ATTRS{serial}=="228BDD7B503059384C2E3120FF0A2B19", SYMLINK+="omx_follower"
|
||||
SUBSYSTEM=="tty", ATTRS{idVendor}=="0403", ATTRS{serial}=="67E1ED68503059384C2E3120FF092234", SYMLINK+="omx_leader"
|
||||
```
|
||||
|
||||
Save the file and reload udev rules:
|
||||
|
||||
```bash
|
||||
sudo udevadm control --reload-rules
|
||||
sudo udevadm trigger
|
||||
```
|
||||
|
||||
Now unplug and replug both devices once.
|
||||
|
||||
#### 3. Verify the symlinks
|
||||
|
||||
Check that the persistent device names exist:
|
||||
|
||||
```bash
|
||||
ls -l /dev/omx_follower /dev/omx_leader
|
||||
```
|
||||
|
||||
You should see them pointing to ttyACM\* devices:
|
||||
|
||||
```bash
|
||||
/dev/omx_follower -> ttyACM*
|
||||
/dev/omx_leader -> ttyACM*
|
||||
```
|
||||
|
||||
These names remain stable across reboots and reconnections.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Teleoperate
|
||||
|
||||
After identifying the correct ports, you can directly teleoperate the follower arm using the leader arm.
|
||||
|
||||
<hfoptions id="teleoperate">
|
||||
<hfoption id="Mac">
|
||||
|
||||
### Teleoperate without camera
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=<your_follower_port> \
|
||||
--robot.id=omx_follower_arm \
|
||||
--teleop.type=omx_leader \
|
||||
--teleop.port=<your_leader_port> \
|
||||
--teleop.id=omx_leader_arm
|
||||
```
|
||||
|
||||
During teleoperation, motions of the leader arm are mirrored in real time by the follower arm. OMX is already preconfigured, teleoperation can begin immediately without any calibration steps.
|
||||
|
||||
### Teleoperate with camera
|
||||
|
||||
You can also enable camera input during teleoperation by providing a camera configuration for the follower arm.
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=<your_follower_port> \
|
||||
--robot.id=omx_follower_arm \
|
||||
--robot.cameras="{front: {type: opencv, index_or_path: '/dev/video0', width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=omx_leader \
|
||||
--teleop.port=<your_leader_port> \
|
||||
--teleop.id=omx_leader_arm \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
When the camera is enabled, the camera stream is displayed in real time and synchronized with the robot state. This setup is useful for visual monitoring and can be reused later for demonstration recording and imitation learning.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Linux">
|
||||
|
||||
### Teleoperate without camera
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=/dev/omx_follower \
|
||||
--robot.id=omx_follower_arm \
|
||||
--teleop.type=omx_leader \
|
||||
--teleop.port=/dev/omx_leader \
|
||||
--teleop.id=omx_leader_arm
|
||||
```
|
||||
|
||||
During teleoperation, motions of the leader arm are mirrored in real time by the follower arm. OMX is already preconfigured, teleoperation can begin immediately without any calibration steps.
|
||||
|
||||
### Teleoperate with camera
|
||||
|
||||
You can also enable camera input during teleoperation by providing a camera configuration for the follower arm.
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=/dev/omx_follower \
|
||||
--robot.id=omx_follower_arm \
|
||||
--robot.cameras="{front: {type: opencv, index_or_path: '/dev/video0', width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=omx_leader \
|
||||
--teleop.port=/dev/omx_leader \
|
||||
--teleop.id=omx_leader_arm \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
When the camera is enabled, the camera stream is displayed in real time and synchronized with the robot state. This setup is useful for visual monitoring and can be reused later for demonstration recording and imitation learning.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Congrats 🎉, your robot is all set to learn a task on its own.
|
||||
|
||||
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/robotis).
|
||||
@@ -0,0 +1,276 @@
|
||||
# OpenArm
|
||||
|
||||
[OpenArm](https://openarm.dev) is an open-source 7DOF humanoid arm designed for physical AI research and deployment.
|
||||
|
||||
To get your OpenArm, assembled or DIY, and join the global community, browse verified and certified manufacturers worldwide at [openarm.dev](https://openarm.dev).
|
||||
|
||||
## What's Unique?
|
||||
|
||||
- **Human-Scale Design**: OpenArm is designed with human-like proportions, scaled for a person around 160-165cm tall. This provides an optimal balance between practical reach and manageable inertia for safe, responsive operation.
|
||||
|
||||
- **Safety-First Architecture**: Built with QDD backdrivable motors and high compliance, OpenArm prioritizes safe human-robot interaction while maintaining practical payload capabilities (6.0kg peak / 4.1kg nominal) for real-world tasks.
|
||||
|
||||
- **Built for Durability**: Critical structural components use aluminum and stainless steel construction, ensuring robust performance for repetitive data collection and continuous research use.
|
||||
|
||||
- **Fully Accessible & Buildable**: Every component, from CNC parts and 3D-printed casings to electrical wiring is designed to be purchasable and buildable by individual researchers and labs, with complete fabrication data provided.
|
||||
|
||||
- **Practical & Affordable**: At $6,500 USD for a complete bimanual system, OpenArm delivers research-grade capabilities at a fraction of traditional humanoid robot costs.
|
||||
|
||||
## Platform Requirements
|
||||
|
||||
<Tip warning={true}>
|
||||
**Linux Only**: OpenArm currently only works on Linux. The CAN bus USB adapter
|
||||
does not have macOS drivers and has not been tested on Windows.
|
||||
</Tip>
|
||||
|
||||
## Safety Guide
|
||||
|
||||
Before operating OpenArm, please read the [official safety guide](https://docs.openarm.dev/getting-started/safety-guide). Key points:
|
||||
|
||||
- **Secure installation**: Fasten the arm to a flat, stable surface with screws or clamps
|
||||
- **Safe distance**: Keep body parts and objects outside the range of motion during operation
|
||||
- **Protective equipment**: Always wear safety goggles; use additional PPE as needed
|
||||
- **Payload limits**: Do not exceed specified payload limits (6.0kg peak / 4.1kg nominal per arm)
|
||||
- **Emergency stop**: Know the location and operation of the emergency stop device
|
||||
- **Regular inspection**: Check for loose screws, damaged mechanical limits, unusual noises, and wiring damage
|
||||
|
||||
## Hardware Setup
|
||||
|
||||
Follow the official [OpenArm hardware documentation](https://docs.openarm.dev) for:
|
||||
|
||||
- Bill of materials and sourcing
|
||||
- 3D printing instructions
|
||||
- Mechanical assembly
|
||||
- Electrical wiring
|
||||
|
||||
The hardware repositories are available at [github.com/enactic/openarm](https://github.com/enactic/openarm).
|
||||
|
||||
## CAN Bus Setup
|
||||
|
||||
OpenArm uses CAN bus communication with Damiao motors. Once you have the CAN bus USB adapter plugged into your Linux PC, follow the [Damiao Motors and CAN Bus guide](./damiao) to configure the interface.
|
||||
|
||||
Quick setup:
|
||||
|
||||
```bash
|
||||
# Setup CAN interfaces
|
||||
lerobot-setup-can --mode=setup --interfaces=can0,can1
|
||||
|
||||
# Test motor communication
|
||||
lerobot-setup-can --mode=test --interfaces=can0,can1
|
||||
```
|
||||
|
||||
## Install LeRobot 🤗
|
||||
|
||||
Follow our [Installation Guide](./installation), then install the Damiao motor support:
|
||||
|
||||
```bash
|
||||
pip install -e ".[damiao]"
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Follower Arm (Robot)
|
||||
|
||||
<hfoptions id="follower">
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--robot.type=openarm_follower \
|
||||
--robot.port=can0 \
|
||||
--robot.side=right \
|
||||
--robot.id=my_openarm_follower
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
```python
|
||||
from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig
|
||||
|
||||
config = OpenArmFollowerConfig(
|
||||
port="can0",
|
||||
side="right", # or "left" for left arm
|
||||
id="my_openarm_follower",
|
||||
)
|
||||
|
||||
follower = OpenArmFollower(config)
|
||||
follower.connect()
|
||||
|
||||
# Read current state
|
||||
obs = follower.get_observation()
|
||||
print(obs)
|
||||
|
||||
# Send action (position in degrees)
|
||||
action = {
|
||||
"joint_1.pos": 0.0,
|
||||
"joint_2.pos": 0.0,
|
||||
"joint_3.pos": 0.0,
|
||||
"joint_4.pos": 45.0,
|
||||
"joint_5.pos": 0.0,
|
||||
"joint_6.pos": 0.0,
|
||||
"joint_7.pos": 0.0,
|
||||
"gripper.pos": 0.0,
|
||||
}
|
||||
follower.send_action(action)
|
||||
|
||||
follower.disconnect()
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Leader Arm (Teleoperator)
|
||||
|
||||
The leader arm is used for teleoperation - manually moving it to control the follower arm.
|
||||
|
||||
<hfoptions id="leader">
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--teleop.type=openarm_leader \
|
||||
--teleop.port=can1 \
|
||||
--teleop.id=my_openarm_leader
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
```python
|
||||
from lerobot.teleoperators.openarm_leader import OpenArmLeader, OpenArmLeaderConfig
|
||||
|
||||
config = OpenArmLeaderConfig(
|
||||
port="can1",
|
||||
id="my_openarm_leader",
|
||||
manual_control=True, # Disable torque for manual movement
|
||||
)
|
||||
|
||||
leader = OpenArmLeader(config)
|
||||
leader.connect()
|
||||
|
||||
# Read current position (as action to send to follower)
|
||||
action = leader.get_action()
|
||||
print(action)
|
||||
|
||||
leader.disconnect()
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Teleoperation
|
||||
|
||||
To teleoperate OpenArm with leader-follower control:
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=openarm_follower \
|
||||
--robot.port=can0 \
|
||||
--robot.side=right \
|
||||
--robot.id=my_follower \
|
||||
--teleop.type=openarm_leader \
|
||||
--teleop.port=can1 \
|
||||
--teleop.id=my_leader
|
||||
```
|
||||
|
||||
### Bimanual Teleoperation
|
||||
|
||||
To teleoperate a bimanual OpenArm setup with two leader and two follower arms:
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=bi_openarm_follower \
|
||||
--robot.left_arm_config.port=can0 \
|
||||
--robot.left_arm_config.side=left \
|
||||
--robot.right_arm_config.port=can1 \
|
||||
--robot.right_arm_config.side=right \
|
||||
--robot.id=my_bimanual_follower \
|
||||
--teleop.type=bi_openarm_leader \
|
||||
--teleop.left_arm_config.port=can2 \
|
||||
--teleop.right_arm_config.port=can3 \
|
||||
--teleop.id=my_bimanual_leader
|
||||
```
|
||||
|
||||
### Recording Data
|
||||
|
||||
To record a dataset during teleoperation:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=openarm_follower \
|
||||
--robot.port=can0 \
|
||||
--robot.side=right \
|
||||
--robot.id=my_follower \
|
||||
--teleop.type=openarm_leader \
|
||||
--teleop.port=can1 \
|
||||
--teleop.id=my_leader \
|
||||
--repo-id=my_hf_username/my_openarm_dataset \
|
||||
--fps=30 \
|
||||
--num-episodes=10
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### Follower Configuration
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| --------------------- | --------- | ---------------------------------------------------------- |
|
||||
| `port` | - | CAN interface (e.g., `can0`) |
|
||||
| `side` | `None` | Arm side: `"left"`, `"right"`, or `None` for custom limits |
|
||||
| `use_can_fd` | `True` | Enable CAN FD for higher data rates |
|
||||
| `can_bitrate` | `1000000` | Nominal bitrate (1 Mbps) |
|
||||
| `can_data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) |
|
||||
| `max_relative_target` | `None` | Safety limit for relative target positions |
|
||||
| `position_kp` | Per-joint | Position control proportional gains |
|
||||
| `position_kd` | Per-joint | Position control derivative gains |
|
||||
|
||||
### Leader Configuration
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| ------------------ | --------- | ----------------------------------- |
|
||||
| `port` | - | CAN interface (e.g., `can1`) |
|
||||
| `manual_control` | `True` | Disable torque for manual movement |
|
||||
| `use_can_fd` | `True` | Enable CAN FD for higher data rates |
|
||||
| `can_bitrate` | `1000000` | Nominal bitrate (1 Mbps) |
|
||||
| `can_data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) |
|
||||
|
||||
## Motor Configuration
|
||||
|
||||
OpenArm uses Damiao motors with the following default configuration:
|
||||
|
||||
| Joint | Motor Type | Send ID | Recv ID |
|
||||
| --------------------------- | ---------- | ------- | ------- |
|
||||
| joint_1 (Shoulder pan) | DM8009 | 0x01 | 0x11 |
|
||||
| joint_2 (Shoulder lift) | DM8009 | 0x02 | 0x12 |
|
||||
| joint_3 (Shoulder rotation) | DM4340 | 0x03 | 0x13 |
|
||||
| joint_4 (Elbow flex) | DM4340 | 0x04 | 0x14 |
|
||||
| joint_5 (Wrist roll) | DM4310 | 0x05 | 0x15 |
|
||||
| joint_6 (Wrist pitch) | DM4310 | 0x06 | 0x16 |
|
||||
| joint_7 (Wrist rotation) | DM4310 | 0x07 | 0x17 |
|
||||
| gripper | DM4310 | 0x08 | 0x18 |
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### No Response from Motors
|
||||
|
||||
1. Check power supply connections
|
||||
2. Verify CAN wiring (CAN-H, CAN-L, GND)
|
||||
3. Run diagnostics: `lerobot-setup-can --mode=test --interfaces=can0`
|
||||
4. See the [Damiao troubleshooting guide](./damiao#troubleshooting) for more details
|
||||
|
||||
### CAN Interface Not Found
|
||||
|
||||
Ensure the CAN interface is configured:
|
||||
|
||||
```bash
|
||||
ip link show can0
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- [OpenArm Website](https://openarm.dev)
|
||||
- [OpenArm Documentation](https://docs.openarm.dev)
|
||||
- [OpenArm GitHub](https://github.com/enactic/openarm)
|
||||
- [Safety Guide](https://docs.openarm.dev/getting-started/safety-guide)
|
||||
- [Damiao Motors and CAN Bus](./damiao)
|
||||
@@ -1,5 +1,18 @@
|
||||
# SO-101
|
||||
|
||||
<div style="display: flex; align-items: center; gap: 10px;">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/SO101_Follower.webp"
|
||||
alt="SO-101"
|
||||
width="60%"
|
||||
/>
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/SO101_Leader.webp"
|
||||
alt="SO-101"
|
||||
width="60%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
In the steps below, we explain how to assemble our flagship robot, the SO-101.
|
||||
|
||||
## Source the parts
|
||||
|
||||
@@ -188,7 +188,105 @@ Press `Ctrl+C` to stop the policy.
|
||||
|
||||
## Running in Simulation Mode (MuJoCo)
|
||||
|
||||
You can now test policies before unleashing them on the physical robot using MuJoCo. To do so simply set `is_simulation=True` in config.
|
||||
You can test policies before deploying on the physical robot using MuJoCo simulation. Set `is_simulation=True` in config or pass `--robot.is_simulation=true` via CLI.
|
||||
|
||||
### Calibrate Exoskeleton Teleoperator
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo
|
||||
```
|
||||
|
||||
### Teleoperate in Simulation
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=true \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--fps=100
|
||||
```
|
||||
|
||||
### Record Dataset in Simulation
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.lerobot_record \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=true \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--dataset.repo_id=your-username/dataset-name \
|
||||
--dataset.single_task="Test" \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.reset_time_s=5 \
|
||||
--dataset.push_to_hub=true
|
||||
```
|
||||
|
||||
Example simulation dataset: [nepyope/teleop_test_sim](https://huggingface.co/datasets/nepyope/teleop_test_sim)
|
||||
|
||||
---
|
||||
|
||||
## Running on Real Robot
|
||||
|
||||
Once the robot server is running on the G1 (see Part 3), you can teleoperate and record on the real robot.
|
||||
|
||||
### Start the Camera Server
|
||||
|
||||
On the robot, start the ZMQ image server:
|
||||
|
||||
```bash
|
||||
python src/lerobot/cameras/zmq/image_server.py
|
||||
```
|
||||
|
||||
Keep this running in a separate terminal for camera streaming during recording.
|
||||
|
||||
### Teleoperate Real Robot
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=false \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--fps=100
|
||||
```
|
||||
|
||||
### Record Dataset on Real Robot
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.lerobot_record \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=false \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--dataset.repo_id=your-username/dataset-name \
|
||||
--dataset.single_task="Test" \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.reset_time_s=5 \
|
||||
--dataset.push_to_hub=true
|
||||
```
|
||||
|
||||
**Note**: Update `server_address` to match your robot's camera server IP.
|
||||
|
||||
Example real robot dataset: [nepyope/teleop_test_real](https://huggingface.co/datasets/nepyope/teleop_test_real)
|
||||
|
||||
---
|
||||
|
||||
## Additional Resources
|
||||
|
||||
|
||||
@@ -95,26 +95,26 @@ Convert an image-based dataset to video format, creating a new LeRobotDataset wh
|
||||
# Local-only: Save to a custom output directory (no hub push)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir /path/to/output/pusht_video
|
||||
|
||||
# Save with new repo_id (local storage)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_to_video
|
||||
--operation.type convert_image_to_video
|
||||
|
||||
# Convert and push to Hugging Face Hub
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_to_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--push_to_hub true
|
||||
|
||||
# Convert with custom video codec and quality settings
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.vcodec libsvtav1 \
|
||||
--operation.pix_fmt yuv420p \
|
||||
@@ -124,16 +124,23 @@ lerobot-edit-dataset \
|
||||
# Convert only specific episodes
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.episode_indices "[0, 1, 2, 5, 10]"
|
||||
|
||||
# Convert with multiple workers for parallel processing
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.num_workers 8
|
||||
|
||||
# For memory-constrained systems, users can now specify limits:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.max_episodes_per_batch 50 \
|
||||
--operation.max_frames_per_batch 10000
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
|
||||
@@ -81,24 +81,25 @@ def replay(cfg: ReplayConfig):
|
||||
actions = dataset.hf_dataset.select_columns(ACTION)
|
||||
robot.connect()
|
||||
|
||||
log_say("Replaying episode", cfg.play_sounds, blocking=True)
|
||||
for idx in range(dataset.num_frames):
|
||||
start_episode_t = time.perf_counter()
|
||||
try:
|
||||
log_say("Replaying episode", cfg.play_sounds, blocking=True)
|
||||
for idx in range(dataset.num_frames):
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action_array = actions[idx][ACTION]
|
||||
action = {}
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"]):
|
||||
key = f"{name.removeprefix('main_')}.pos"
|
||||
action[key] = action_array[i].item()
|
||||
action_array = actions[idx][ACTION]
|
||||
action = {}
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"]):
|
||||
key = f"{name.removeprefix('main_')}.pos"
|
||||
action[key] = action_array[i].item()
|
||||
|
||||
action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90)
|
||||
action["elbow_flex.pos"] -= 90
|
||||
robot.send_action(action)
|
||||
action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90)
|
||||
action["elbow_flex.pos"] -= 90
|
||||
robot.send_action(action)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
precise_sleep(max(1 / dataset.fps - dt_s, 0.0))
|
||||
|
||||
robot.disconnect()
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
precise_sleep(max(1 / dataset.fps - dt_s, 0.0))
|
||||
finally:
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
+45
-43
@@ -78,40 +78,24 @@ def main():
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="lekiwi_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
try:
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
|
||||
print("Starting evaluate loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
@@ -120,24 +104,42 @@ def main():
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
+45
-44
@@ -74,40 +74,23 @@ def main():
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="lekiwi_record")
|
||||
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
try:
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting record loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {recorded_episodes}")
|
||||
print("Starting record loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {recorded_episodes}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
dataset=dataset,
|
||||
teleop=[leader_arm, keyboard],
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
dataset=dataset,
|
||||
teleop=[leader_arm, keyboard],
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
@@ -115,26 +98,44 @@ def main():
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=[leader_arm, keyboard],
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
leader_arm.disconnect()
|
||||
keyboard.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
leader_arm.disconnect()
|
||||
keyboard.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
+17
-15
@@ -42,25 +42,27 @@ def main():
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
try:
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get recorded action from dataset
|
||||
action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
# Get recorded action from dataset
|
||||
action = {
|
||||
name: float(actions[idx][ACTION][i])
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Send action to robot
|
||||
_ = robot.send_action(action)
|
||||
# Send action to robot
|
||||
_ = robot.send_action(action)
|
||||
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
robot.disconnect()
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
finally:
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -142,38 +142,24 @@ def main():
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="phone_so100_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
try:
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
@@ -182,24 +168,41 @@ def main():
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -149,38 +149,23 @@ def main():
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="phone_so100_record")
|
||||
|
||||
if not robot.is_connected or not phone.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
try:
|
||||
if not robot.is_connected or not phone.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting record loop. Move your phone to teleoperate the robot...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
print("Starting record loop. Move your phone to teleoperate the robot...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
@@ -188,25 +173,43 @@ def main():
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
phone.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
phone.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -73,32 +73,34 @@ def main():
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
try:
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i])
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
finally:
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -0,0 +1,480 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Mirror a bimanual robot dataset using SLURM for distributed video processing.
|
||||
|
||||
This script creates a mirrored version of a dataset where:
|
||||
1. Left and right arm observations/actions are swapped
|
||||
2. Joint values are inverted according to a mirroring mask
|
||||
3. Video frames are horizontally flipped (parallelized via SLURM)
|
||||
|
||||
Example usage:
|
||||
```shell
|
||||
# SLURM execution
|
||||
python examples/port_datasets/slurm_mirror_dataset.py \
|
||||
--repo-id pepijn/openarm_bimanual \
|
||||
--output-repo-id pepijn/openarm_bimanual_mirrored \
|
||||
--logs-dir /fsx/user/logs \
|
||||
--partition hopper-cpu
|
||||
|
||||
# Local execution (for debugging)
|
||||
python examples/port_datasets/slurm_mirror_dataset.py \
|
||||
--repo-id pepijn/openarm_bimanual \
|
||||
--output-repo-id pepijn/openarm_bimanual_mirrored \
|
||||
--slurm 0 \
|
||||
--push-to-hub
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OPENARM_MIRRORING_MASK = {
|
||||
"joint_1": -1,
|
||||
"joint_2": -1,
|
||||
"joint_3": -1,
|
||||
"joint_4": 1,
|
||||
"joint_5": -1,
|
||||
"joint_6": -1,
|
||||
"joint_7": -1,
|
||||
"gripper": 1,
|
||||
}
|
||||
|
||||
|
||||
class MirrorVideos(PipelineStep):
|
||||
"""Pipeline step that mirrors video files for assigned episodes."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: str,
|
||||
output_repo_id: str,
|
||||
root: str | None = None,
|
||||
output_root: str | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self.output_repo_id = output_repo_id
|
||||
self.root = root
|
||||
self.output_root = output_root
|
||||
self.vcodec = vcodec
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
import logging
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
from datasets.utils.tqdm import disable_progress_bars
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
init_logging()
|
||||
disable_progress_bars()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def swap_left_right_name(name: str) -> str:
|
||||
result = name.replace("left_", "LEFT_PLACEHOLDER_")
|
||||
result = result.replace("right_", "left_")
|
||||
result = result.replace("LEFT_PLACEHOLDER_", "right_")
|
||||
return result
|
||||
|
||||
def flip_video_frames(input_path: Path, output_path: Path, fps: float, vcodec: str):
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
cmd = [
|
||||
"ffmpeg", "-y", "-i", str(input_path),
|
||||
"-vf", "hflip",
|
||||
"-c:v", vcodec,
|
||||
"-g", "2",
|
||||
"-crf", "30",
|
||||
"-r", str(int(fps)),
|
||||
"-pix_fmt", "yuv420p",
|
||||
"-loglevel", "error",
|
||||
]
|
||||
if vcodec == "libsvtav1":
|
||||
cmd.extend(["-preset", "12"])
|
||||
cmd.append(str(output_path))
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"FFmpeg failed: {result.stderr}")
|
||||
|
||||
def video_is_valid(path: Path) -> bool:
|
||||
if not path.exists():
|
||||
return False
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["ffprobe", "-v", "error", "-select_streams", "v:0",
|
||||
"-show_entries", "stream=nb_frames", "-of", "csv=p=0", str(path)],
|
||||
capture_output=True, text=True, timeout=30
|
||||
)
|
||||
return result.returncode == 0 and result.stdout.strip().isdigit()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
root = Path(self.root) if self.root else None
|
||||
output_root = Path(self.output_root) if self.output_root else None
|
||||
|
||||
dataset = LeRobotDataset(self.repo_id, root=root)
|
||||
output_root = output_root or (HF_LEROBOT_HOME / self.output_repo_id)
|
||||
|
||||
if not dataset.meta.video_keys:
|
||||
logger.info(f"Rank {rank}: No videos to process")
|
||||
return
|
||||
|
||||
video_tasks = []
|
||||
for old_video_key in dataset.meta.video_keys:
|
||||
new_video_key = swap_left_right_name(old_video_key)
|
||||
for ep_idx in range(dataset.meta.total_episodes):
|
||||
try:
|
||||
src_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, old_video_key)
|
||||
dst_relative = dataset.meta.get_video_file_path(ep_idx, old_video_key)
|
||||
dst_relative_str = str(dst_relative).replace(old_video_key, new_video_key)
|
||||
dst_path = output_root / dst_relative_str
|
||||
if src_path.exists():
|
||||
video_tasks.append((src_path, dst_path, ep_idx, old_video_key))
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
my_tasks = [t for i, t in enumerate(video_tasks) if i % world_size == rank]
|
||||
logger.info(f"Rank {rank}/{world_size}: Processing {len(my_tasks)}/{len(video_tasks)} videos")
|
||||
|
||||
for src_path, dst_path, ep_idx, video_key in my_tasks:
|
||||
if video_is_valid(dst_path):
|
||||
logger.info(f"Rank {rank}: Skipping {dst_path.name} (already done)")
|
||||
continue
|
||||
logger.info(f"Rank {rank}: Processing {src_path.name} -> {dst_path.name}")
|
||||
flip_video_frames(src_path, dst_path, dataset.meta.fps, self.vcodec)
|
||||
|
||||
|
||||
class MirrorDataAndMetadata(PipelineStep):
|
||||
"""Pipeline step that mirrors parquet data and metadata (runs once on rank 0)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: str,
|
||||
output_repo_id: str,
|
||||
root: str | None = None,
|
||||
output_root: str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self.output_repo_id = output_repo_id
|
||||
self.root = root
|
||||
self.output_root = output_root
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
if rank != 0:
|
||||
return
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from datasets.utils.tqdm import disable_progress_bars
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import DATA_DIR, DEFAULT_DATA_PATH, write_info, write_stats, write_tasks
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
init_logging()
|
||||
disable_progress_bars()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MIRRORING_MASK = {
|
||||
"joint_1": -1, "joint_2": -1, "joint_3": -1, "joint_4": 1,
|
||||
"joint_5": -1, "joint_6": -1, "joint_7": -1, "gripper": 1,
|
||||
}
|
||||
|
||||
def get_mirroring_mask(robot_type: str) -> dict[str, int]:
|
||||
if robot_type in ["bi_openarm_follower", "openarm_follower", "bi_openarms_follower", "openarms_follower"]:
|
||||
return MIRRORING_MASK
|
||||
raise ValueError(f"Unknown robot type: {robot_type}. Add a mirroring mask for this robot.")
|
||||
|
||||
def swap_left_right_name(name: str) -> str:
|
||||
result = name.replace("left_", "LEFT_PLACEHOLDER_")
|
||||
result = result.replace("right_", "left_")
|
||||
result = result.replace("LEFT_PLACEHOLDER_", "right_")
|
||||
return result
|
||||
|
||||
def mirror_feature_names(names: list[str]) -> tuple[list[str], dict[int, int]]:
|
||||
mirrored_names = [swap_left_right_name(n) for n in names]
|
||||
old_to_new_idx = {}
|
||||
for old_idx, old_name in enumerate(names):
|
||||
new_name = swap_left_right_name(old_name)
|
||||
new_idx = mirrored_names.index(new_name)
|
||||
old_to_new_idx[old_idx] = new_idx
|
||||
return mirrored_names, old_to_new_idx
|
||||
|
||||
def apply_mirroring_mask(value: float, feature_name: str, mirroring_mask: dict[str, int]) -> float:
|
||||
name_without_prefix = feature_name.split("_", 1)[1] if "_" in feature_name else feature_name
|
||||
joint_name = name_without_prefix.split(".")[0]
|
||||
if joint_name in mirroring_mask:
|
||||
return value * mirroring_mask[joint_name]
|
||||
return value
|
||||
|
||||
def mirror_array(array: np.ndarray, names: list[str], mirroring_mask: dict[str, int]) -> np.ndarray:
|
||||
mirrored_names, idx_mapping = mirror_feature_names(names)
|
||||
result = np.zeros_like(array)
|
||||
for old_idx, new_idx in idx_mapping.items():
|
||||
new_name = mirrored_names[new_idx]
|
||||
value = array[old_idx]
|
||||
mirrored_value = apply_mirroring_mask(value, new_name, mirroring_mask)
|
||||
result[new_idx] = mirrored_value
|
||||
return result
|
||||
|
||||
def mirror_stats(stats: dict) -> dict:
|
||||
mirrored = {}
|
||||
for key, value in stats.items():
|
||||
new_key = swap_left_right_name(key)
|
||||
if isinstance(value, dict):
|
||||
mirrored[new_key] = mirror_stats(value)
|
||||
else:
|
||||
mirrored[new_key] = value
|
||||
return mirrored
|
||||
|
||||
import shutil
|
||||
|
||||
root = Path(self.root) if self.root else None
|
||||
output_root = Path(self.output_root) if self.output_root else None
|
||||
|
||||
dataset = LeRobotDataset(self.repo_id, root=root)
|
||||
output_root = output_root or (HF_LEROBOT_HOME / self.output_repo_id)
|
||||
|
||||
done_marker = output_root / ".data_mirrored"
|
||||
if done_marker.exists():
|
||||
logger.info("Data and metadata already mirrored, skipping")
|
||||
return
|
||||
|
||||
# Clean up partial output from previous failed runs
|
||||
if output_root.exists():
|
||||
logger.info(f"Removing existing partial output: {output_root}")
|
||||
shutil.rmtree(output_root)
|
||||
|
||||
robot_type = dataset.meta.robot_type or "bi_openarms_follower"
|
||||
mirroring_mask = get_mirroring_mask(robot_type)
|
||||
|
||||
mirrored_features = {}
|
||||
for key, feat in dataset.meta.features.items():
|
||||
new_key = swap_left_right_name(key)
|
||||
new_feat = feat.copy()
|
||||
if "names" in new_feat and new_feat["names"]:
|
||||
new_feat["names"] = [swap_left_right_name(n) for n in new_feat["names"]]
|
||||
mirrored_features[new_key] = new_feat
|
||||
|
||||
new_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=self.output_repo_id,
|
||||
fps=dataset.meta.fps,
|
||||
features=mirrored_features,
|
||||
robot_type=dataset.meta.robot_type,
|
||||
root=output_root,
|
||||
use_videos=len(dataset.meta.video_keys) > 0,
|
||||
)
|
||||
|
||||
if dataset.meta.tasks is not None:
|
||||
write_tasks(dataset.meta.tasks, new_meta.root)
|
||||
|
||||
data_dir = dataset.root / DATA_DIR
|
||||
parquet_files = sorted(data_dir.glob("*/*.parquet"))
|
||||
action_names = dataset.meta.features.get("action", {}).get("names", [])
|
||||
state_names = dataset.meta.features.get("observation.state", {}).get("names", [])
|
||||
|
||||
for src_path in parquet_files:
|
||||
df = pd.read_parquet(src_path).reset_index(drop=True)
|
||||
relative_path = src_path.relative_to(dataset.root)
|
||||
chunk_dir = relative_path.parts[1]
|
||||
file_name = relative_path.parts[2]
|
||||
chunk_idx = int(chunk_dir.split("-")[1])
|
||||
file_idx = int(file_name.split("-")[1].split(".")[0])
|
||||
|
||||
if "action" in df.columns and action_names:
|
||||
actions = np.stack(df["action"].values)
|
||||
mirrored_actions = np.array([mirror_array(row, action_names, mirroring_mask) for row in actions])
|
||||
df["action"] = list(mirrored_actions)
|
||||
|
||||
if "observation.state" in df.columns and state_names:
|
||||
states = np.stack(df["observation.state"].values)
|
||||
mirrored_states = np.array([mirror_array(row, state_names, mirroring_mask) for row in states])
|
||||
df["observation.state"] = list(mirrored_states)
|
||||
|
||||
dst_path = new_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_parquet(dst_path, index=False)
|
||||
|
||||
episodes_dir = dataset.root / "meta/episodes"
|
||||
dst_episodes_dir = new_meta.root / "meta/episodes"
|
||||
if episodes_dir.exists():
|
||||
dst_episodes_dir.mkdir(parents=True, exist_ok=True)
|
||||
for src_parquet in episodes_dir.glob("*/*.parquet"):
|
||||
df = pd.read_parquet(src_parquet)
|
||||
columns_to_rename = {}
|
||||
for col in df.columns:
|
||||
if col.startswith("videos/"):
|
||||
parts = col.split("/")
|
||||
if len(parts) >= 2:
|
||||
video_key = parts[1]
|
||||
new_video_key = swap_left_right_name(video_key)
|
||||
new_col = col.replace(f"videos/{video_key}/", f"videos/{new_video_key}/")
|
||||
columns_to_rename[col] = new_col
|
||||
if columns_to_rename:
|
||||
df = df.rename(columns=columns_to_rename)
|
||||
dst_parquet = dst_episodes_dir / src_parquet.relative_to(episodes_dir)
|
||||
dst_parquet.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_parquet(dst_parquet, index=False)
|
||||
|
||||
new_meta.info.update({
|
||||
"total_episodes": dataset.meta.info["total_episodes"],
|
||||
"total_frames": dataset.meta.info["total_frames"],
|
||||
"total_tasks": dataset.meta.info["total_tasks"],
|
||||
"splits": dataset.meta.info.get("splits", {}),
|
||||
})
|
||||
write_info(new_meta.info, new_meta.root)
|
||||
|
||||
if dataset.meta.stats is not None:
|
||||
mirrored_stats = mirror_stats(dataset.meta.stats)
|
||||
write_stats(mirrored_stats, new_meta.root)
|
||||
|
||||
done_marker.touch()
|
||||
logger.info(f"Data and metadata mirrored to {output_root}")
|
||||
|
||||
|
||||
def swap_left_right_name(name: str) -> str:
|
||||
result = name.replace("left_", "LEFT_PLACEHOLDER_")
|
||||
result = result.replace("right_", "left_")
|
||||
result = result.replace("LEFT_PLACEHOLDER_", "right_")
|
||||
return result
|
||||
|
||||
|
||||
def get_num_video_tasks(repo_id: str, root: str | None = None) -> int:
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
root_path = Path(root) if root else None
|
||||
dataset = LeRobotDataset(repo_id, root=root_path)
|
||||
count = 0
|
||||
for video_key in dataset.meta.video_keys:
|
||||
for ep_idx in range(dataset.meta.total_episodes):
|
||||
try:
|
||||
src_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, video_key)
|
||||
if src_path.exists():
|
||||
count += 1
|
||||
except KeyError:
|
||||
continue
|
||||
return count
|
||||
|
||||
|
||||
def make_mirror_executor(
|
||||
repo_id: str,
|
||||
output_repo_id: str,
|
||||
root: str | None,
|
||||
output_root: str | None,
|
||||
vcodec: str,
|
||||
job_name: str,
|
||||
logs_dir: Path,
|
||||
workers: int,
|
||||
partition: str,
|
||||
cpus_per_task: int,
|
||||
mem_per_cpu: str,
|
||||
time_limit: str,
|
||||
slurm: bool = True,
|
||||
):
|
||||
num_tasks = get_num_video_tasks(repo_id, root) if slurm else 1
|
||||
num_tasks = max(1, num_tasks)
|
||||
|
||||
kwargs = {
|
||||
"pipeline": [
|
||||
MirrorDataAndMetadata(repo_id, output_repo_id, root, output_root),
|
||||
MirrorVideos(repo_id, output_repo_id, root, output_root, vcodec),
|
||||
],
|
||||
"logging_dir": str(logs_dir / job_name),
|
||||
}
|
||||
|
||||
if slurm:
|
||||
kwargs.update({
|
||||
"job_name": job_name,
|
||||
"tasks": num_tasks,
|
||||
"workers": min(workers, num_tasks),
|
||||
"time": time_limit,
|
||||
"partition": partition,
|
||||
"cpus_per_task": cpus_per_task,
|
||||
"sbatch_args": {
|
||||
"mem-per-cpu": mem_per_cpu,
|
||||
"requeue": True,
|
||||
"signal": "USR1@30",
|
||||
},
|
||||
})
|
||||
return SlurmPipelineExecutor(**kwargs)
|
||||
else:
|
||||
kwargs.update({"tasks": 1, "workers": 1})
|
||||
return LocalPipelineExecutor(**kwargs)
|
||||
|
||||
|
||||
def main():
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
parser = argparse.ArgumentParser(description="Mirror a bimanual robot dataset using SLURM")
|
||||
parser.add_argument("--repo-id", type=str, required=True, help="Source dataset repo_id")
|
||||
parser.add_argument("--output-repo-id", type=str, required=True, help="Output dataset repo_id")
|
||||
parser.add_argument("--root", type=str, default=None, help="Source dataset root directory")
|
||||
parser.add_argument("--output-root", type=str, default=None, help="Output dataset root directory")
|
||||
parser.add_argument("--vcodec", type=str, default="libsvtav1", help="Video codec")
|
||||
parser.add_argument("--logs-dir", type=Path, default=Path("logs"), help="Directory for datatrove logs")
|
||||
parser.add_argument("--job-name", type=str, default="mirror_dataset", help="SLURM job name")
|
||||
parser.add_argument("--slurm", type=int, default=1, help="Use SLURM (1) or local (0)")
|
||||
parser.add_argument("--workers", type=int, default=64, help="Number of SLURM workers")
|
||||
parser.add_argument("--partition", type=str, default="hopper-cpu", help="SLURM partition")
|
||||
parser.add_argument("--cpus-per-task", type=int, default=4, help="CPUs per task")
|
||||
parser.add_argument("--mem-per-cpu", type=str, default="2G", help="Memory per CPU")
|
||||
parser.add_argument("--time-limit", type=str, default="04:00:00", help="SLURM time limit")
|
||||
parser.add_argument("--push-to-hub", action="store_true", help="Push mirrored dataset to HuggingFace Hub")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
executor = make_mirror_executor(
|
||||
repo_id=args.repo_id,
|
||||
output_repo_id=args.output_repo_id,
|
||||
root=args.root,
|
||||
output_root=args.output_root,
|
||||
vcodec=args.vcodec,
|
||||
job_name=args.job_name,
|
||||
logs_dir=args.logs_dir,
|
||||
workers=args.workers,
|
||||
partition=args.partition,
|
||||
cpus_per_task=args.cpus_per_task,
|
||||
mem_per_cpu=args.mem_per_cpu,
|
||||
time_limit=args.time_limit,
|
||||
slurm=args.slurm == 1,
|
||||
)
|
||||
executor.run()
|
||||
|
||||
if args.push_to_hub:
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
output_root = Path(args.output_root) if args.output_root else HF_LEROBOT_HOME / args.output_repo_id
|
||||
logger.info(f"Pushing dataset to HuggingFace Hub: {args.output_repo_id}")
|
||||
dataset = LeRobotDataset(args.output_repo_id, root=output_root)
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -142,38 +142,24 @@ def main():
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="so100_so100_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
try:
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
@@ -182,24 +168,41 @@ def main():
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -146,38 +146,23 @@ def main():
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording_phone")
|
||||
|
||||
if not leader.is_connected or not follower.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
try:
|
||||
if not leader.is_connected or not follower.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting record loop...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
print("Starting record loop...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=leader,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=leader,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
@@ -185,25 +170,44 @@ def main():
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=leader,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
leader.disconnect()
|
||||
follower.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
leader.disconnect()
|
||||
follower.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -74,32 +74,35 @@ def main():
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
try:
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i])
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
finally:
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -30,6 +30,7 @@ def main():
|
||||
robot=robot_cfg,
|
||||
server_address=server_address,
|
||||
policy_device="mps",
|
||||
client_device="cpu",
|
||||
policy_type="act",
|
||||
pretrained_name_or_path="<user>/robot_learning_tutorial_act",
|
||||
chunk_size_threshold=0.5, # g
|
||||
|
||||
+10
-2
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
|
||||
|
||||
[project]
|
||||
name = "lerobot"
|
||||
version = "0.4.3"
|
||||
version = "0.4.4"
|
||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||
dynamic = ["readme"]
|
||||
license = { text = "Apache-2.0" }
|
||||
@@ -102,14 +102,20 @@ grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
# Motors
|
||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
|
||||
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
|
||||
damiao = ["python-can>=4.2.0,<5.0.0"]
|
||||
|
||||
# Robots
|
||||
openarms = ["lerobot[damiao]"]
|
||||
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
|
||||
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
|
||||
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
|
||||
unitree_g1 = [
|
||||
"pyzmq>=26.2.1,<28.0.0",
|
||||
"onnxruntime>=1.16.0,<2.0.0"
|
||||
"onnxruntime>=1.16.0,<2.0.0",
|
||||
"pin>=3.0.0,<4.0.0",
|
||||
"meshcat>=0.3.0,<0.4.0",
|
||||
"matplotlib>=3.9.0,<4.0.0",
|
||||
"casadi>=3.6.0,<4.0.0",
|
||||
]
|
||||
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
||||
kinematics = ["lerobot[placo-dep]"]
|
||||
@@ -203,6 +209,7 @@ lerobot-info="lerobot.scripts.lerobot_info:main"
|
||||
lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
|
||||
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
[tool.setuptools.packages.find]
|
||||
@@ -278,6 +285,7 @@ default.extend-ignore-identifiers-re = [
|
||||
"thw",
|
||||
"inpt",
|
||||
"ROBOTIS",
|
||||
"OT_VALUE"
|
||||
]
|
||||
|
||||
# TODO: Uncomment when ready to use
|
||||
|
||||
@@ -126,6 +126,12 @@ class RobotClientConfig:
|
||||
|
||||
# Device configuration
|
||||
policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"})
|
||||
client_device: str = field(
|
||||
default="cpu",
|
||||
metadata={
|
||||
"help": "Device to move actions to after receiving from server (e.g., for downstream planners)"
|
||||
},
|
||||
)
|
||||
|
||||
# Control behavior configuration
|
||||
chunk_size_threshold: float = field(default=0.5, metadata={"help": "Threshold for chunk size control"})
|
||||
@@ -161,6 +167,9 @@ class RobotClientConfig:
|
||||
if not self.policy_device:
|
||||
raise ValueError("policy_device cannot be empty")
|
||||
|
||||
if not self.client_device:
|
||||
raise ValueError("client_device cannot be empty")
|
||||
|
||||
if self.chunk_size_threshold < 0 or self.chunk_size_threshold > 1:
|
||||
raise ValueError(f"chunk_size_threshold must be between 0 and 1, got {self.chunk_size_threshold}")
|
||||
|
||||
@@ -184,6 +193,7 @@ class RobotClientConfig:
|
||||
"policy_type": self.policy_type,
|
||||
"pretrained_name_or_path": self.pretrained_name_or_path,
|
||||
"policy_device": self.policy_device,
|
||||
"client_device": self.client_device,
|
||||
"chunk_size_threshold": self.chunk_size_threshold,
|
||||
"fps": self.fps,
|
||||
"actions_per_chunk": self.actions_per_chunk,
|
||||
|
||||
@@ -23,7 +23,7 @@ DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS
|
||||
DEFAULT_OBS_QUEUE_TIMEOUT = 2
|
||||
|
||||
# All action chunking policies
|
||||
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"]
|
||||
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05", "groot"]
|
||||
|
||||
# TODO: Add all other robots
|
||||
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so_follower", "omx_follower"]
|
||||
|
||||
@@ -18,6 +18,7 @@ import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
@@ -39,8 +40,8 @@ from lerobot.utils.utils import init_logging
|
||||
|
||||
Action = torch.Tensor
|
||||
|
||||
# observation as received from the robot
|
||||
RawObservation = dict[str, torch.Tensor]
|
||||
# observation as received from the robot (can be numpy arrays, floats, etc.)
|
||||
RawObservation = dict[str, Any]
|
||||
|
||||
# observation as those recorded in LeRobot dataset (keys are different)
|
||||
LeRobotObservation = dict[str, torch.Tensor]
|
||||
|
||||
@@ -381,6 +381,8 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
action_tensor = torch.stack(processed_actions, dim=1).squeeze(0)
|
||||
self.logger.debug(f"Postprocessed action shape: {action_tensor.shape}")
|
||||
|
||||
action_tensor = action_tensor.detach().cpu()
|
||||
|
||||
"""5. Convert to TimedAction list"""
|
||||
action_chunk = self._time_action_chunk(
|
||||
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
|
||||
|
||||
@@ -25,6 +25,7 @@ python src/lerobot/async_inference/robot_client.py \
|
||||
--policy_type=act \
|
||||
--pretrained_name_or_path=user/model \
|
||||
--policy_device=mps \
|
||||
--client_device=cpu \
|
||||
--actions_per_chunk=50 \
|
||||
--chunk_size_threshold=0.5 \
|
||||
--aggregate_fn_name=weighted_average \
|
||||
@@ -40,6 +41,7 @@ from collections.abc import Callable
|
||||
from dataclasses import asdict
|
||||
from pprint import pformat
|
||||
from queue import Queue
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
import grpc
|
||||
@@ -47,7 +49,6 @@ import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
@@ -285,6 +286,21 @@ class RobotClient:
|
||||
timed_actions = pickle.loads(actions_chunk.data) # nosec
|
||||
deserialize_time = time.perf_counter() - deserialize_start
|
||||
|
||||
# Log device type of received actions
|
||||
if len(timed_actions) > 0:
|
||||
received_device = timed_actions[0].get_action().device.type
|
||||
self.logger.debug(f"Received actions on device: {received_device}")
|
||||
|
||||
# Move actions to client_device (e.g., for downstream planners that need GPU)
|
||||
client_device = self.config.client_device
|
||||
if client_device != "cpu":
|
||||
for timed_action in timed_actions:
|
||||
if timed_action.get_action().device.type != client_device:
|
||||
timed_action.action = timed_action.get_action().to(client_device)
|
||||
self.logger.debug(f"Converted actions to device: {client_device}")
|
||||
else:
|
||||
self.logger.debug(f"Actions kept on device: {client_device}")
|
||||
|
||||
self.action_chunk_size = max(self.action_chunk_size, len(timed_actions))
|
||||
|
||||
# Calculate network latency if we have matching observations
|
||||
@@ -351,7 +367,7 @@ class RobotClient:
|
||||
action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)}
|
||||
return action
|
||||
|
||||
def control_loop_action(self, verbose: bool = False) -> RobotAction:
|
||||
def control_loop_action(self, verbose: bool = False) -> dict[str, Any]:
|
||||
"""Reading and performing actions in local queue"""
|
||||
|
||||
# Lock only for queue operations
|
||||
|
||||
@@ -15,11 +15,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
|
||||
|
||||
from .configs import CameraConfig, ColorMode
|
||||
from .configs import CameraConfig
|
||||
|
||||
|
||||
class Camera(abc.ABC):
|
||||
@@ -30,20 +31,12 @@ class Camera(abc.ABC):
|
||||
|
||||
Manages basic camera properties (FPS, resolution) and core operations:
|
||||
- Connection/disconnection
|
||||
- Frame capture (sync/async)
|
||||
- Frame capture (sync/async/latest)
|
||||
|
||||
Attributes:
|
||||
fps (int | None): Configured frames per second
|
||||
width (int | None): Frame width in pixels
|
||||
height (int | None): Frame height in pixels
|
||||
|
||||
Example:
|
||||
class MyCamera(Camera):
|
||||
def __init__(self, config): ...
|
||||
@property
|
||||
def is_connected(self) -> bool: ...
|
||||
def connect(self, warmup=True): ...
|
||||
# Plus other required methods
|
||||
"""
|
||||
|
||||
def __init__(self, config: CameraConfig):
|
||||
@@ -56,6 +49,32 @@ class Camera(abc.ABC):
|
||||
self.width: int | None = config.width
|
||||
self.height: int | None = config.height
|
||||
|
||||
def __enter__(self):
|
||||
"""
|
||||
Context manager entry.
|
||||
Automatically connects to the camera.
|
||||
"""
|
||||
self.connect()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback) -> None:
|
||||
"""
|
||||
Context manager exit.
|
||||
Automatically disconnects, ensuring resources are released even on error.
|
||||
"""
|
||||
self.disconnect()
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""
|
||||
Destructor safety net.
|
||||
Attempts to disconnect if the object is garbage collected without cleanup.
|
||||
"""
|
||||
try:
|
||||
if self.is_connected:
|
||||
self.disconnect()
|
||||
except Exception: # nosec B110
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_connected(self) -> bool:
|
||||
@@ -89,12 +108,10 @@ class Camera(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""Capture and return a single frame from the camera.
|
||||
def read(self) -> NDArray[Any]:
|
||||
"""Capture and return a single frame from the camera synchronously.
|
||||
|
||||
Args:
|
||||
color_mode: Desired color mode for the output frame. If None,
|
||||
uses the camera's default color mode.
|
||||
This is a blocking call that will wait for the hardware and its SDK.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Captured frame as a numpy array.
|
||||
@@ -103,17 +120,64 @@ class Camera(abc.ABC):
|
||||
|
||||
@abc.abstractmethod
|
||||
def async_read(self, timeout_ms: float = ...) -> NDArray[Any]:
|
||||
"""Asynchronously capture and return a single frame from the camera.
|
||||
"""Return the most recent new frame.
|
||||
|
||||
This method retrieves the latest frame captured by the background thread.
|
||||
If a new frame is already available in the buffer (captured since the last call),
|
||||
it returns it immediately.
|
||||
|
||||
It blocks up to `timeout_ms` only if the buffer is empty or if the latest frame
|
||||
was already consumed by a previous `async_read` call.
|
||||
|
||||
Essentially, this method return the latest unconsumed frame, waiting if necessary
|
||||
for a new one to arrive within the specified timeout.
|
||||
|
||||
Usage:
|
||||
- Ideal for control loops where you want to ensure every processed frame
|
||||
is fresh, effectively synchronizing your loop to the camera's FPS.
|
||||
- Causes of a timeout usually include: very low camera FPS, heavy processing load,
|
||||
or if the camera is disconnected.
|
||||
|
||||
Args:
|
||||
timeout_ms: Maximum time to wait for a frame in milliseconds.
|
||||
Defaults to implementation-specific timeout.
|
||||
timeout_ms: Maximum time to wait for a new frame in milliseconds.
|
||||
Defaults to 200ms (0.2s).
|
||||
|
||||
Returns:
|
||||
np.ndarray: Captured frame as a numpy array.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If no new frame arrives within `timeout_ms`.
|
||||
"""
|
||||
pass
|
||||
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
memory buffer. The frame may be stale,
|
||||
meaning it could have been captured a while ago (hanging camera scenario e.g.).
|
||||
|
||||
Usage:
|
||||
Ideal for scenarios requiring zero latency or decoupled frequencies & when
|
||||
we want a guaranteed frame, such as UI visualization, logging, or
|
||||
non-critical monitoring.
|
||||
|
||||
Returns:
|
||||
NDArray[Any]: The frame image (numpy array).
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the latest frame is older than `max_age_ms`.
|
||||
NotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
warnings.warn(
|
||||
f"{self.__class__.__name__}.read_latest() is not implemented. "
|
||||
"Please override read_latest(); it will be required in future releases.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.async_read()
|
||||
|
||||
@abc.abstractmethod
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from the camera and release resources."""
|
||||
|
||||
@@ -70,34 +70,24 @@ class OpenCVCamera(Camera):
|
||||
Example:
|
||||
```python
|
||||
from lerobot.cameras.opencv import OpenCVCamera
|
||||
from lerobot.cameras.configuration_opencv import OpenCVCameraConfig, ColorMode, Cv2Rotation
|
||||
from lerobot.cameras.configuration_opencv import OpenCVCameraConfig
|
||||
|
||||
# Basic usage with camera index 0
|
||||
config = OpenCVCameraConfig(index_or_path=0)
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect()
|
||||
|
||||
# Read 1 frame synchronously
|
||||
# Read 1 frame synchronously (blocking)
|
||||
color_image = camera.read()
|
||||
print(color_image.shape)
|
||||
|
||||
# Read 1 frame asynchronously
|
||||
# Read 1 frame asynchronously (waits for new frame with a timeout)
|
||||
async_image = camera.async_read()
|
||||
|
||||
# Get the latest frame immediately (no wait, returns timestamp)
|
||||
latest_image, timestamp = camera.read_latest()
|
||||
|
||||
# When done, properly disconnect the camera using
|
||||
camera.disconnect()
|
||||
|
||||
# Example with custom settings
|
||||
custom_config = OpenCVCameraConfig(
|
||||
index_or_path='/dev/video0', # Or use an index
|
||||
fps=30,
|
||||
width=1280,
|
||||
height=720,
|
||||
color_mode=ColorMode.RGB,
|
||||
rotation=Cv2Rotation.ROTATE_90
|
||||
)
|
||||
custom_camera = OpenCVCamera(custom_config)
|
||||
# ... connect, read, disconnect ...
|
||||
```
|
||||
"""
|
||||
|
||||
@@ -123,6 +113,7 @@ class OpenCVCamera(Camera):
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.latest_timestamp: float | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
self.rotation: int | None = get_cv2_rotation(config.rotation)
|
||||
@@ -146,12 +137,16 @@ class OpenCVCamera(Camera):
|
||||
Connects to the OpenCV camera specified in the configuration.
|
||||
|
||||
Initializes the OpenCV VideoCapture object, sets desired camera properties
|
||||
(FPS, width, height), and performs initial checks.
|
||||
(FPS, width, height), starts the background reading thread and performs initial checks.
|
||||
|
||||
Args:
|
||||
warmup (bool): If True, waits at connect() time until at least one valid frame
|
||||
has been captured by the background thread. Defaults to True.
|
||||
|
||||
Raises:
|
||||
DeviceAlreadyConnectedError: If the camera is already connected.
|
||||
ConnectionError: If the specified camera index/path is not found or the camera is found but fails to open.
|
||||
RuntimeError: If the camera opens but fails to apply requested FPS/resolution settings.
|
||||
ConnectionError: If the specified camera index/path is not found or fails to open.
|
||||
RuntimeError: If the camera opens but fails to apply requested settings.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
|
||||
@@ -170,12 +165,16 @@ class OpenCVCamera(Camera):
|
||||
)
|
||||
|
||||
self._configure_capture_settings()
|
||||
self._start_read_thread()
|
||||
|
||||
if warmup:
|
||||
if warmup and self.warmup_s > 0:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < self.warmup_s:
|
||||
self.read()
|
||||
self.async_read(timeout_ms=self.warmup_s * 1000)
|
||||
time.sleep(0.1)
|
||||
with self.frame_lock:
|
||||
if self.latest_frame is None:
|
||||
raise ConnectionError(f"{self} failed to capture frames during warmup.")
|
||||
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@@ -196,8 +195,7 @@ class OpenCVCamera(Camera):
|
||||
Raises:
|
||||
RuntimeError: If the camera fails to set any of the specified properties
|
||||
to the requested value.
|
||||
DeviceNotConnectedError: If the camera is not connected when attempting
|
||||
to configure settings.
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.")
|
||||
@@ -339,6 +337,17 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return found_cameras_info
|
||||
|
||||
def _read_from_hardware(self) -> NDArray[Any]:
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
ret, frame = self.videocapture.read()
|
||||
|
||||
if not ret:
|
||||
raise RuntimeError(f"{self} read failed (status={ret}).")
|
||||
|
||||
return frame
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
@@ -346,11 +355,6 @@ class OpenCVCamera(Camera):
|
||||
This is a blocking call. It waits for the next available frame from the
|
||||
camera hardware via OpenCV.
|
||||
|
||||
Args:
|
||||
color_mode (Optional[ColorMode]): If specified, overrides the default
|
||||
color mode (`self.color_mode`) for this read operation (e.g.,
|
||||
request RGB even if default is BGR).
|
||||
|
||||
Returns:
|
||||
np.ndarray: The captured frame as a NumPy array in the format
|
||||
(height, width, channels), using the specified or default
|
||||
@@ -362,34 +366,34 @@ class OpenCVCamera(Camera):
|
||||
received frame dimensions don't match expectations before rotation.
|
||||
ValueError: If an invalid `color_mode` is requested.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
if color_mode is not None:
|
||||
logger.warning(
|
||||
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
ret, frame = self.videocapture.read()
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if not ret or frame is None:
|
||||
raise RuntimeError(f"{self} read failed (status={ret}).")
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
processed_frame = self._postprocess_image(frame, color_mode)
|
||||
self.new_frame_event.clear()
|
||||
frame = self.async_read(timeout_ms=10000)
|
||||
|
||||
read_duration_ms = (time.perf_counter() - start_time) * 1e3
|
||||
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
|
||||
|
||||
return processed_frame
|
||||
return frame
|
||||
|
||||
def _postprocess_image(self, image: NDArray[Any], color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
def _postprocess_image(self, image: NDArray[Any]) -> NDArray[Any]:
|
||||
"""
|
||||
Applies color conversion, dimension validation, and rotation to a raw frame.
|
||||
|
||||
Args:
|
||||
image (np.ndarray): The raw image frame (expected BGR format from OpenCV).
|
||||
color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None,
|
||||
uses the instance's default `self.color_mode`.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The processed image frame.
|
||||
@@ -399,11 +403,10 @@ class OpenCVCamera(Camera):
|
||||
RuntimeError: If the raw frame dimensions do not match the configured
|
||||
`width` and `height`.
|
||||
"""
|
||||
requested_color_mode = self.color_mode if color_mode is None else color_mode
|
||||
|
||||
if requested_color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"Invalid color mode '{requested_color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
|
||||
f"Invalid color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
|
||||
)
|
||||
|
||||
h, w, c = image.shape
|
||||
@@ -417,7 +420,7 @@ class OpenCVCamera(Camera):
|
||||
raise RuntimeError(f"{self} frame channels={c} do not match expected 3 channels (RGB/BGR).")
|
||||
|
||||
processed_image = image
|
||||
if requested_color_mode == ColorMode.RGB:
|
||||
if self.color_mode == ColorMode.RGB:
|
||||
processed_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]:
|
||||
@@ -431,7 +434,7 @@ class OpenCVCamera(Camera):
|
||||
|
||||
On each iteration:
|
||||
1. Reads a color frame
|
||||
2. Stores result in latest_frame (thread-safe)
|
||||
2. Stores result in latest_frame and updates timestamp (thread-safe)
|
||||
3. Sets new_frame_event to notify listeners
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
@@ -439,30 +442,37 @@ class OpenCVCamera(Camera):
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
failure_count = 0
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
color_image = self.read()
|
||||
raw_frame = self._read_from_hardware()
|
||||
processed_frame = self._postprocess_image(raw_frame)
|
||||
capture_time = time.perf_counter()
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = color_image
|
||||
self.latest_frame = processed_frame
|
||||
self.latest_timestamp = capture_time
|
||||
self.new_frame_event.set()
|
||||
failure_count = 0
|
||||
|
||||
except DeviceNotConnectedError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Error reading frame in background thread for {self}: {e}")
|
||||
if failure_count <= 10:
|
||||
failure_count += 1
|
||||
logger.warning(f"Error reading frame in background thread for {self}: {e}")
|
||||
else:
|
||||
raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e
|
||||
|
||||
def _start_read_thread(self) -> None:
|
||||
"""Starts or restarts the background read thread if it's not running."""
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=0.1)
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
self._stop_read_thread()
|
||||
|
||||
self.stop_event = Event()
|
||||
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
time.sleep(0.1)
|
||||
|
||||
def _stop_read_thread(self) -> None:
|
||||
"""Signals the background read thread to stop and waits for it to join."""
|
||||
@@ -475,6 +485,11 @@ class OpenCVCamera(Camera):
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame asynchronously.
|
||||
@@ -482,6 +497,7 @@ class OpenCVCamera(Camera):
|
||||
This method retrieves the most recent frame captured by the background
|
||||
read thread. It does not block waiting for the camera hardware directly,
|
||||
but may wait up to timeout_ms for the background thread to provide a frame.
|
||||
It is “best effort” under high FPS.
|
||||
|
||||
Args:
|
||||
timeout_ms (float): Maximum time in milliseconds to wait for a frame
|
||||
@@ -500,13 +516,12 @@ class OpenCVCamera(Camera):
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
self._start_read_thread()
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
|
||||
thread_alive = self.thread is not None and self.thread.is_alive()
|
||||
raise TimeoutError(
|
||||
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
|
||||
f"Read thread alive: {thread_alive}."
|
||||
f"Read thread alive: {self.thread.is_alive()}."
|
||||
)
|
||||
|
||||
with self.frame_lock:
|
||||
@@ -518,6 +533,42 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
memory buffer. The frame may be stale,
|
||||
meaning it could have been captured a while ago (hanging camera scenario e.g.).
|
||||
|
||||
Returns:
|
||||
NDArray[Any]: The frame image (numpy array).
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the latest frame is older than `max_age_ms`.
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
with self.frame_lock:
|
||||
frame = self.latest_frame
|
||||
timestamp = self.latest_timestamp
|
||||
|
||||
if frame is None or timestamp is None:
|
||||
raise RuntimeError(f"{self} has not captured any frames yet.")
|
||||
|
||||
age_ms = (time.perf_counter() - timestamp) * 1e3
|
||||
if age_ms > max_age_ms:
|
||||
raise TimeoutError(
|
||||
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
|
||||
)
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnects from the camera and cleans up resources.
|
||||
@@ -538,4 +589,9 @@ class OpenCVCamera(Camera):
|
||||
self.videocapture.release()
|
||||
self.videocapture = None
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
@@ -80,6 +80,8 @@ class Reachy2Camera(Camera):
|
||||
self.config = config
|
||||
|
||||
self.color_mode = config.color_mode
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.latest_timestamp: float | None = None
|
||||
|
||||
self.cam_manager: CameraManager | None = None
|
||||
|
||||
@@ -125,12 +127,7 @@ class Reachy2Camera(Camera):
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
|
||||
This is a blocking call.
|
||||
|
||||
Args:
|
||||
color_mode (Optional[ColorMode]): If specified, overrides the default
|
||||
color mode (`self.color_mode`) for this read operation (e.g.,
|
||||
request RGB even if default is BGR).
|
||||
This method retrieves the most recent frame available in Reachy 2's low-level software.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The captured frame as a NumPy array in the format
|
||||
@@ -145,6 +142,11 @@ class Reachy2Camera(Camera):
|
||||
if self.cam_manager is None:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if color_mode is not None:
|
||||
logger.warning(
|
||||
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
frame: NDArray[Any] = np.empty((0, 0, 3), dtype=np.uint8)
|
||||
|
||||
if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"):
|
||||
@@ -165,11 +167,18 @@ class Reachy2Camera(Camera):
|
||||
raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")
|
||||
|
||||
if frame is None:
|
||||
return np.empty((0, 0, 3), dtype=np.uint8)
|
||||
raise RuntimeError(f"Internal error: No frame available for {self}.")
|
||||
|
||||
if self.config.color_mode == "rgb":
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"Invalid color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
|
||||
)
|
||||
if self.color_mode == ColorMode.RGB:
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
self.latest_frame = frame
|
||||
self.latest_timestamp = time.perf_counter()
|
||||
|
||||
read_duration_ms = (time.perf_counter() - start_time) * 1e3
|
||||
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
|
||||
|
||||
@@ -177,13 +186,7 @@ class Reachy2Camera(Camera):
|
||||
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame.
|
||||
|
||||
This method retrieves the most recent frame available in Reachy 2's low-level software.
|
||||
|
||||
Args:
|
||||
timeout_ms (float): Maximum time in milliseconds to wait for a frame
|
||||
to become available. Defaults to 200ms (0.2 seconds).
|
||||
Same as read()
|
||||
|
||||
Returns:
|
||||
np.ndarray: The latest captured frame as a NumPy array in the format
|
||||
@@ -197,12 +200,38 @@ class Reachy2Camera(Camera):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
frame = self.read()
|
||||
return self.read()
|
||||
|
||||
if frame is None:
|
||||
raise RuntimeError(f"Internal error: No frame available for {self}.")
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
return frame
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
memory buffer. The frame may be stale,
|
||||
meaning it could have been captured a while ago (hanging camera scenario e.g.).
|
||||
|
||||
Returns:
|
||||
tuple[NDArray, float]:
|
||||
- The frame image (numpy array).
|
||||
- The timestamp (time.perf_counter) when this frame was captured.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the latest frame is older than `max_age_ms`.
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.latest_frame is None or self.latest_timestamp is None:
|
||||
raise RuntimeError(f"{self} has not captured any frames yet.")
|
||||
|
||||
age_ms = (time.perf_counter() - self.latest_timestamp) * 1e3
|
||||
if age_ms > max_age_ms:
|
||||
raise TimeoutError(
|
||||
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
|
||||
)
|
||||
|
||||
return self.latest_frame
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
|
||||
@@ -72,15 +72,14 @@ class RealSenseCamera(Camera):
|
||||
camera = RealSenseCamera(config)
|
||||
camera.connect()
|
||||
|
||||
# Read 1 frame synchronously
|
||||
# Read 1 frame synchronously (blocking)
|
||||
color_image = camera.read()
|
||||
print(color_image.shape)
|
||||
|
||||
# Read 1 frame asynchronously
|
||||
# Read 1 frame asynchronously (waits for new frame with a timeout)
|
||||
async_image = camera.async_read()
|
||||
|
||||
# When done, properly disconnect the camera using
|
||||
camera.disconnect()
|
||||
# Get the latest frame immediately (no wait, returns timestamp)
|
||||
latest_image, timestamp = camera.read_latest()
|
||||
|
||||
# Example with depth capture and custom settings
|
||||
custom_config = RealSenseCameraConfig(
|
||||
@@ -133,7 +132,9 @@ class RealSenseCamera(Camera):
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.latest_color_frame: NDArray[Any] | None = None
|
||||
self.latest_depth_frame: NDArray[Any] | None = None
|
||||
self.latest_timestamp: float | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
self.rotation: int | None = get_cv2_rotation(config.rotation)
|
||||
@@ -158,6 +159,10 @@ class RealSenseCamera(Camera):
|
||||
Initializes the RealSense pipeline, configures the required streams (color
|
||||
and optionally depth), starts the pipeline, and validates the actual stream settings.
|
||||
|
||||
Args:
|
||||
warmup (bool): If True, waits at connect() time until at least one valid frame
|
||||
has been captured by the background thread. Defaults to True.
|
||||
|
||||
Raises:
|
||||
DeviceAlreadyConnectedError: If the camera is already connected.
|
||||
ValueError: If the configuration is invalid (e.g., missing serial/name, name not unique).
|
||||
@@ -181,15 +186,18 @@ class RealSenseCamera(Camera):
|
||||
) from e
|
||||
|
||||
self._configure_capture_settings()
|
||||
self._start_read_thread()
|
||||
|
||||
if warmup:
|
||||
time.sleep(
|
||||
1
|
||||
) # NOTE(Steven): RS cameras need a bit of time to warm up before the first read. If we don't wait, the first read from the warmup will raise.
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < self.warmup_s:
|
||||
self.read()
|
||||
time.sleep(0.1)
|
||||
# NOTE(Steven/Caroline): Enforcing at least one second of warmup as RS cameras need a bit of time before the first read. If we don't wait, the first read from the warmup will raise.
|
||||
self.warmup_s = max(self.warmup_s, 1)
|
||||
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < self.warmup_s:
|
||||
self.async_read(timeout_ms=self.warmup_s * 1000)
|
||||
time.sleep(0.1)
|
||||
with self.frame_lock:
|
||||
if self.latest_color_frame is None or self.use_depth and self.latest_depth_frame is None:
|
||||
raise ConnectionError(f"{self} failed to capture frames during warmup.")
|
||||
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@@ -319,9 +327,6 @@ class RealSenseCamera(Camera):
|
||||
This is a blocking call. It waits for a coherent set of frames (depth)
|
||||
from the camera hardware via the RealSense pipeline.
|
||||
|
||||
Args:
|
||||
timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The depth map as a NumPy array (height, width)
|
||||
of type `np.uint16` (raw depth values in millimeters) and rotation.
|
||||
@@ -330,44 +335,52 @@ class RealSenseCamera(Camera):
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If reading frames from the pipeline fails or frames are invalid.
|
||||
"""
|
||||
if timeout_ms:
|
||||
logger.warning(
|
||||
f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
if not self.use_depth:
|
||||
raise RuntimeError(
|
||||
f"Failed to capture depth frame '.read_depth()'. Depth stream is not enabled for {self}."
|
||||
)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
self.new_frame_event.clear()
|
||||
|
||||
_ = self.async_read(timeout_ms=10000)
|
||||
|
||||
with self.frame_lock:
|
||||
depth_map = self.latest_depth_frame
|
||||
|
||||
if depth_map is None:
|
||||
raise RuntimeError("No depth frame available. Ensure camera is streaming.")
|
||||
|
||||
return depth_map
|
||||
|
||||
def _read_from_hardware(self):
|
||||
if self.rs_pipeline is None:
|
||||
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
|
||||
|
||||
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
|
||||
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=10000)
|
||||
|
||||
if not ret or frame is None:
|
||||
raise RuntimeError(f"{self} read_depth failed (status={ret}).")
|
||||
raise RuntimeError(f"{self} read failed (status={ret}).")
|
||||
|
||||
depth_frame = frame.get_depth_frame()
|
||||
depth_map = np.asanyarray(depth_frame.get_data())
|
||||
return frame
|
||||
|
||||
depth_map_processed = self._postprocess_image(depth_map, depth_frame=True)
|
||||
|
||||
read_duration_ms = (time.perf_counter() - start_time) * 1e3
|
||||
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
|
||||
|
||||
return depth_map_processed
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> NDArray[Any]:
|
||||
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 0) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame (color) synchronously from the camera.
|
||||
|
||||
This is a blocking call. It waits for a coherent set of frames (color)
|
||||
from the camera hardware via the RealSense pipeline.
|
||||
|
||||
Args:
|
||||
timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The captured color frame as a NumPy array
|
||||
(height, width, channels), processed according to `color_mode` and rotation.
|
||||
@@ -378,39 +391,39 @@ class RealSenseCamera(Camera):
|
||||
ValueError: If an invalid `color_mode` is requested.
|
||||
"""
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if color_mode is not None:
|
||||
logger.warning(
|
||||
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
if timeout_ms:
|
||||
logger.warning(
|
||||
f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
if self.rs_pipeline is None:
|
||||
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
|
||||
self.new_frame_event.clear()
|
||||
|
||||
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
|
||||
|
||||
if not ret or frame is None:
|
||||
raise RuntimeError(f"{self} read failed (status={ret}).")
|
||||
|
||||
color_frame = frame.get_color_frame()
|
||||
color_image_raw = np.asanyarray(color_frame.get_data())
|
||||
|
||||
color_image_processed = self._postprocess_image(color_image_raw, color_mode)
|
||||
frame = self.async_read(timeout_ms=10000)
|
||||
|
||||
read_duration_ms = (time.perf_counter() - start_time) * 1e3
|
||||
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
|
||||
|
||||
return color_image_processed
|
||||
return frame
|
||||
|
||||
def _postprocess_image(
|
||||
self, image: NDArray[Any], color_mode: ColorMode | None = None, depth_frame: bool = False
|
||||
) -> NDArray[Any]:
|
||||
def _postprocess_image(self, image: NDArray[Any], depth_frame: bool = False) -> NDArray[Any]:
|
||||
"""
|
||||
Applies color conversion, dimension validation, and rotation to a raw color frame.
|
||||
|
||||
Args:
|
||||
image (np.ndarray): The raw image frame (expected RGB format from RealSense).
|
||||
color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None,
|
||||
uses the instance's default `self.color_mode`.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The processed image frame according to `self.color_mode` and `self.rotation`.
|
||||
@@ -421,9 +434,9 @@ class RealSenseCamera(Camera):
|
||||
`width` and `height`.
|
||||
"""
|
||||
|
||||
if color_mode and color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
if self.color_mode and self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"Invalid requested color mode '{color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
|
||||
f"Invalid requested color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
|
||||
)
|
||||
|
||||
if depth_frame:
|
||||
@@ -454,7 +467,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
On each iteration:
|
||||
1. Reads a color frame with 500ms timeout
|
||||
2. Stores result in latest_frame (thread-safe)
|
||||
2. Stores result in latest_frame and updates timestamp (thread-safe)
|
||||
3. Sets new_frame_event to notify listeners
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
@@ -462,25 +475,41 @@ class RealSenseCamera(Camera):
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
failure_count = 0
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
color_image = self.read(timeout_ms=500)
|
||||
frame = self._read_from_hardware()
|
||||
color_frame_raw = frame.get_color_frame()
|
||||
color_frame = np.asanyarray(color_frame_raw.get_data())
|
||||
processed_color_frame = self._postprocess_image(color_frame)
|
||||
|
||||
if self.use_depth:
|
||||
depth_frame_raw = frame.get_depth_frame()
|
||||
depth_frame = np.asanyarray(depth_frame_raw.get_data())
|
||||
processed_depth_frame = self._postprocess_image(depth_frame, depth_frame=True)
|
||||
|
||||
capture_time = time.perf_counter()
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = color_image
|
||||
self.latest_color_frame = processed_color_frame
|
||||
if self.use_depth:
|
||||
self.latest_depth_frame = processed_depth_frame
|
||||
self.latest_timestamp = capture_time
|
||||
self.new_frame_event.set()
|
||||
failure_count = 0
|
||||
|
||||
except DeviceNotConnectedError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Error reading frame in background thread for {self}: {e}")
|
||||
if failure_count <= 10:
|
||||
failure_count += 1
|
||||
logger.warning(f"Error reading frame in background thread for {self}: {e}")
|
||||
else:
|
||||
raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e
|
||||
|
||||
def _start_read_thread(self) -> None:
|
||||
"""Starts or restarts the background read thread if it's not running."""
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=0.1)
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
self._stop_read_thread()
|
||||
|
||||
self.stop_event = Event()
|
||||
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
|
||||
@@ -498,6 +527,12 @@ class RealSenseCamera(Camera):
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_color_frame = None
|
||||
self.latest_depth_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
@@ -506,6 +541,7 @@ class RealSenseCamera(Camera):
|
||||
This method retrieves the most recent color frame captured by the background
|
||||
read thread. It does not block waiting for the camera hardware directly,
|
||||
but may wait up to timeout_ms for the background thread to provide a frame.
|
||||
It is “best effort” under high FPS.
|
||||
|
||||
Args:
|
||||
timeout_ms (float): Maximum time in milliseconds to wait for a frame
|
||||
@@ -524,17 +560,16 @@ class RealSenseCamera(Camera):
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
self._start_read_thread()
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
|
||||
thread_alive = self.thread is not None and self.thread.is_alive()
|
||||
raise TimeoutError(
|
||||
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
|
||||
f"Read thread alive: {thread_alive}."
|
||||
f"Read thread alive: {self.thread.is_alive()}."
|
||||
)
|
||||
|
||||
with self.frame_lock:
|
||||
frame = self.latest_frame
|
||||
frame = self.latest_color_frame
|
||||
self.new_frame_event.clear()
|
||||
|
||||
if frame is None:
|
||||
@@ -542,6 +577,43 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent (color) frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
memory buffer. The frame may be stale,
|
||||
meaning it could have been captured a while ago (hanging camera scenario e.g.).
|
||||
|
||||
Returns:
|
||||
NDArray[Any]: The frame image (numpy array).
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the latest frame is older than `max_age_ms`.
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
with self.frame_lock:
|
||||
frame = self.latest_color_frame
|
||||
timestamp = self.latest_timestamp
|
||||
|
||||
if frame is None or timestamp is None:
|
||||
raise RuntimeError(f"{self} has not captured any frames yet.")
|
||||
|
||||
age_ms = (time.perf_counter() - timestamp) * 1e3
|
||||
if age_ms > max_age_ms:
|
||||
raise TimeoutError(
|
||||
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
|
||||
)
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnects from the camera, stops the pipeline, and cleans up resources.
|
||||
@@ -565,4 +637,10 @@ class RealSenseCamera(Camera):
|
||||
self.rs_pipeline = None
|
||||
self.rs_profile = None
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_color_frame = None
|
||||
self.latest_depth_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
@@ -45,6 +45,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ZMQCamera(Camera):
|
||||
"""
|
||||
Manages camera interactions via ZeroMQ for receiving frames from a remote server.
|
||||
|
||||
This class connects to a ZMQ Publisher, subscribes to frame topics, and decodes
|
||||
incoming JSON messages containing Base64 encoded images. It supports both
|
||||
synchronous and asynchronous frame reading patterns.
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
from lerobot.cameras.zmq import ZMQCamera, ZMQCameraConfig
|
||||
@@ -52,7 +58,16 @@ class ZMQCamera(Camera):
|
||||
config = ZMQCameraConfig(server_address="192.168.123.164", port=5555, camera_name="head_camera")
|
||||
camera = ZMQCamera(config)
|
||||
camera.connect()
|
||||
frame = camera.read()
|
||||
|
||||
# Read 1 frame synchronously (blocking)
|
||||
color_image = camera.read()
|
||||
|
||||
# Read 1 frame asynchronously (waits for new frame with a timeout)
|
||||
async_image = camera.async_read()
|
||||
|
||||
# Get the latest frame immediately (no wait, returns timestamp)
|
||||
latest_image, timestamp = camera.read_latest()
|
||||
|
||||
camera.disconnect()
|
||||
```
|
||||
"""
|
||||
@@ -68,14 +83,17 @@ class ZMQCamera(Camera):
|
||||
self.color_mode = config.color_mode
|
||||
self.timeout_ms = config.timeout_ms
|
||||
|
||||
# ZMQ Context and Socket
|
||||
self.context: zmq.Context | None = None
|
||||
self.socket: zmq.Socket | None = None
|
||||
self._connected = False
|
||||
|
||||
# Threading resources
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.latest_timestamp: float | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
def __str__(self) -> str:
|
||||
@@ -83,10 +101,16 @@ class ZMQCamera(Camera):
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Checks if the ZMQ socket is initialized and connected."""
|
||||
return self._connected and self.context is not None and self.socket is not None
|
||||
|
||||
def connect(self, warmup: bool = True) -> None:
|
||||
"""Connect to ZMQ camera server."""
|
||||
"""Connect to ZMQ camera server.
|
||||
|
||||
Args:
|
||||
warmup (bool): If True, waits for the camera to provide at least one
|
||||
valid frame before returning. Defaults to True.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
|
||||
|
||||
@@ -103,17 +127,28 @@ class ZMQCamera(Camera):
|
||||
self.socket.connect(f"tcp://{self.server_address}:{self.port}")
|
||||
self._connected = True
|
||||
|
||||
# Auto-detect resolution
|
||||
# Auto-detect resolution if not provided
|
||||
if self.width is None or self.height is None:
|
||||
h, w = self.read().shape[:2]
|
||||
# Read directly from hardware because the thread isn't running yet
|
||||
temp_frame = self._read_from_hardware()
|
||||
h, w = temp_frame.shape[:2]
|
||||
self.height = h
|
||||
self.width = w
|
||||
logger.info(f"{self} resolution: {w}x{h}")
|
||||
logger.info(f"{self} resolution detected: {w}x{h}")
|
||||
|
||||
self._start_read_thread()
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
if warmup:
|
||||
time.sleep(0.1)
|
||||
# Ensure we have captured at least one frame via the thread
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < (self.config.warmup_s): # Wait a bit more than timeout
|
||||
self.async_read(timeout_ms=self.config.warmup_s * 1000)
|
||||
time.sleep(0.1)
|
||||
|
||||
with self.frame_lock:
|
||||
if self.latest_frame is None:
|
||||
raise ConnectionError(f"{self} failed to capture frames during warmup.")
|
||||
|
||||
except Exception as e:
|
||||
self._cleanup()
|
||||
@@ -131,15 +166,14 @@ class ZMQCamera(Camera):
|
||||
|
||||
@staticmethod
|
||||
def find_cameras() -> list[dict[str, Any]]:
|
||||
"""ZMQ cameras require manual configuration (server address/port)."""
|
||||
return []
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Read a single frame from the ZMQ camera.
|
||||
Detection not implemented for ZMQ cameras. These cameras require manual configuration (server address/port).
|
||||
"""
|
||||
raise NotImplementedError("Camera detection is not implemented for ZMQ cameras.")
|
||||
|
||||
Returns:
|
||||
np.ndarray: Decoded frame (height, width, 3)
|
||||
def _read_from_hardware(self) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame directly from the ZMQ socket.
|
||||
"""
|
||||
if not self.is_connected or self.socket is None:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
@@ -147,6 +181,7 @@ class ZMQCamera(Camera):
|
||||
try:
|
||||
message = self.socket.recv_string()
|
||||
except Exception as e:
|
||||
# Check for ZMQ timeout (EAGAIN/Again) without requiring global zmq import
|
||||
if type(e).__name__ == "Again":
|
||||
raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e
|
||||
raise
|
||||
@@ -176,42 +211,117 @@ class ZMQCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
def _read_loop(self) -> None:
|
||||
while self.stop_event and not self.stop_event.is_set():
|
||||
try:
|
||||
frame = self.read()
|
||||
with self.frame_lock:
|
||||
self.latest_frame = frame
|
||||
self.new_frame_event.set()
|
||||
except DeviceNotConnectedError:
|
||||
break
|
||||
except TimeoutError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"Read error: {e}")
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
|
||||
def _start_read_thread(self) -> None:
|
||||
if self.thread and self.thread.is_alive():
|
||||
return
|
||||
self.stop_event = Event()
|
||||
self.thread = Thread(target=self._read_loop, daemon=True)
|
||||
self.thread.start()
|
||||
This is a blocking call. It waits for the next available frame from the
|
||||
camera background thread.
|
||||
|
||||
def _stop_read_thread(self) -> None:
|
||||
if self.stop_event:
|
||||
self.stop_event.set()
|
||||
if self.thread and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
Returns:
|
||||
np.ndarray: Decoded frame (height, width, 3)
|
||||
"""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if color_mode is not None:
|
||||
logger.warning(
|
||||
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
def async_read(self, timeout_ms: float = 10000) -> NDArray[Any]:
|
||||
"""Read latest frame asynchronously (non-blocking)."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if not self.thread or not self.thread.is_alive():
|
||||
self._start_read_thread()
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
self.new_frame_event.clear()
|
||||
frame = self.async_read(timeout_ms=10000)
|
||||
|
||||
read_duration_ms = (time.perf_counter() - start_time) * 1e3
|
||||
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
|
||||
|
||||
return frame
|
||||
|
||||
def _read_loop(self) -> None:
|
||||
"""
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
"""
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized.")
|
||||
|
||||
failure_count = 0
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
frame = self._read_from_hardware()
|
||||
capture_time = time.perf_counter()
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = frame
|
||||
self.latest_timestamp = capture_time
|
||||
self.new_frame_event.set()
|
||||
failure_count = 0
|
||||
|
||||
except DeviceNotConnectedError:
|
||||
break
|
||||
except (TimeoutError, Exception) as e:
|
||||
if failure_count <= 10:
|
||||
failure_count += 1
|
||||
logger.warning(f"Read error: {e}")
|
||||
else:
|
||||
raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e
|
||||
|
||||
def _start_read_thread(self) -> None:
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
self.stop_event = Event()
|
||||
self.thread = Thread(target=self._read_loop, daemon=True, name=f"{self}_read_loop")
|
||||
self.thread.start()
|
||||
time.sleep(0.1)
|
||||
|
||||
def _stop_read_thread(self) -> None:
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame asynchronously.
|
||||
|
||||
Args:
|
||||
timeout_ms (float): Maximum time in milliseconds to wait for a frame
|
||||
to become available. Defaults to 200ms.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The latest captured frame.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
TimeoutError: If no frame data becomes available within the specified timeout.
|
||||
RuntimeError: If the background thread is not running.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
|
||||
raise TimeoutError(f"{self} async_read timeout after {timeout_ms}ms")
|
||||
@@ -225,11 +335,55 @@ class ZMQCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
memory buffer. The frame may be stale,
|
||||
meaning it could have been captured a while ago (hanging camera scenario e.g.).
|
||||
|
||||
Returns:
|
||||
NDArray[Any]: The frame image (numpy array).
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the latest frame is older than `max_age_ms`.
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
with self.frame_lock:
|
||||
frame = self.latest_frame
|
||||
timestamp = self.latest_timestamp
|
||||
|
||||
if frame is None or timestamp is None:
|
||||
raise RuntimeError(f"{self} has not captured any frames yet.")
|
||||
|
||||
age_ms = (time.perf_counter() - timestamp) * 1e3
|
||||
if age_ms > max_age_ms:
|
||||
raise TimeoutError(
|
||||
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
|
||||
)
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from ZMQ camera."""
|
||||
if not self.is_connected and not self.thread:
|
||||
if not self.is_connected and self.thread is None:
|
||||
raise DeviceNotConnectedError(f"{self} not connected.")
|
||||
|
||||
self._stop_read_thread()
|
||||
if self.thread is not None:
|
||||
self._stop_read_thread()
|
||||
|
||||
self._cleanup()
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
@@ -29,6 +29,7 @@ class ZMQCameraConfig(CameraConfig):
|
||||
camera_name: str = "zmq_camera"
|
||||
color_mode: ColorMode = ColorMode.RGB
|
||||
timeout_ms: int = 5000
|
||||
warmup_s: int = 1
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
|
||||
@@ -19,6 +19,7 @@ import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import pandas as pd
|
||||
import tqdm
|
||||
|
||||
@@ -32,6 +33,7 @@ from lerobot.datasets.utils import (
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
get_file_size_in_mb,
|
||||
get_hf_features_from_features,
|
||||
get_parquet_file_size_in_mb,
|
||||
to_parquet_with_hf_images,
|
||||
update_chunk_file_indices,
|
||||
@@ -70,10 +72,11 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
||||
raise ValueError(
|
||||
f"Same robot_type is expected, but got robot_type={meta.robot_type} instead of {robot_type}."
|
||||
)
|
||||
if features != meta.features:
|
||||
raise ValueError(
|
||||
f"Same features is expected, but got features={meta.features} instead of {features}."
|
||||
)
|
||||
# TODO: Temporarily disabled for merging datasets with different features (e.g. shirt_id)
|
||||
# if features != meta.features:
|
||||
# raise ValueError(
|
||||
# f"Same features is expected, but got features={meta.features} instead of {features}."
|
||||
# )
|
||||
|
||||
return fps, robot_type, features
|
||||
|
||||
@@ -114,6 +117,9 @@ def update_meta_data(
|
||||
Adjusts all indices and timestamps to account for previously aggregated
|
||||
data and videos in the destination dataset.
|
||||
|
||||
For data file indices, uses the 'src_to_dst' mapping from aggregate_data()
|
||||
to correctly map source file indices to their destination locations.
|
||||
|
||||
Args:
|
||||
df: DataFrame containing the metadata to be updated.
|
||||
dst_meta: Destination dataset metadata.
|
||||
@@ -127,8 +133,50 @@ def update_meta_data(
|
||||
|
||||
df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"]
|
||||
df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"]
|
||||
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
|
||||
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
|
||||
|
||||
# Update data file indices using source-to-destination mapping
|
||||
# This is critical for handling datasets that are already results of a merge
|
||||
data_src_to_dst = data_idx.get("src_to_dst", {})
|
||||
if data_src_to_dst:
|
||||
# Store original indices for lookup
|
||||
df["_orig_data_chunk"] = df["data/chunk_index"].copy()
|
||||
df["_orig_data_file"] = df["data/file_index"].copy()
|
||||
|
||||
# Vectorized mapping from (src_chunk, src_file) to (dst_chunk, dst_file)
|
||||
# This is much faster than per-row iteration for large metadata tables
|
||||
mapping_index = pd.MultiIndex.from_tuples(
|
||||
list(data_src_to_dst.keys()),
|
||||
names=["chunk_index", "file_index"],
|
||||
)
|
||||
mapping_values = list(data_src_to_dst.values())
|
||||
mapping_df = pd.DataFrame(
|
||||
mapping_values,
|
||||
index=mapping_index,
|
||||
columns=["dst_chunk", "dst_file"],
|
||||
)
|
||||
|
||||
# Construct a MultiIndex for each row based on original data indices
|
||||
row_index = pd.MultiIndex.from_arrays(
|
||||
[df["_orig_data_chunk"], df["_orig_data_file"]],
|
||||
names=["chunk_index", "file_index"],
|
||||
)
|
||||
|
||||
# Align mapping to rows; missing keys fall back to the default destination
|
||||
reindexed = mapping_df.reindex(row_index)
|
||||
reindexed[["dst_chunk", "dst_file"]] = reindexed[["dst_chunk", "dst_file"]].fillna(
|
||||
{"dst_chunk": data_idx["chunk"], "dst_file": data_idx["file"]}
|
||||
)
|
||||
|
||||
# Assign mapped destination indices back to the DataFrame
|
||||
df["data/chunk_index"] = reindexed["dst_chunk"].to_numpy()
|
||||
df["data/file_index"] = reindexed["dst_file"].to_numpy()
|
||||
|
||||
# Clean up temporary columns
|
||||
df = df.drop(columns=["_orig_data_chunk", "_orig_data_file"])
|
||||
else:
|
||||
# Fallback to simple offset (backward compatibility for single-file sources)
|
||||
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
|
||||
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
|
||||
for key, video_idx in videos_idx.items():
|
||||
# Store original video file indices before updating
|
||||
orig_chunk_col = f"videos/{key}/chunk_index"
|
||||
@@ -144,8 +192,7 @@ def update_meta_data(
|
||||
if src_to_dst:
|
||||
# Map each episode to its correct destination file and apply offset
|
||||
for idx in df.index:
|
||||
# Convert to Python int to avoid numpy type mismatch in dict lookup
|
||||
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
|
||||
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
|
||||
|
||||
# Get destination chunk/file for this source file
|
||||
dst_chunk, dst_file = src_to_dst.get(src_key, (video_idx["chunk"], video_idx["file"]))
|
||||
@@ -161,8 +208,7 @@ def update_meta_data(
|
||||
df[orig_chunk_col] = video_idx["chunk"]
|
||||
df[orig_file_col] = video_idx["file"]
|
||||
for idx in df.index:
|
||||
# Convert to Python int to avoid numpy type mismatch in dict lookup
|
||||
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
|
||||
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
|
||||
offset = src_to_offset.get(src_key, 0)
|
||||
df.at[idx, f"videos/{key}/from_timestamp"] += offset
|
||||
df.at[idx, f"videos/{key}/to_timestamp"] += offset
|
||||
@@ -260,6 +306,10 @@ def aggregate_datasets(
|
||||
|
||||
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
|
||||
|
||||
# Clear the src_to_dst mapping after processing each source dataset
|
||||
# to avoid interference between different source datasets
|
||||
data_idx.pop("src_to_dst", None)
|
||||
|
||||
dst_meta.info["total_episodes"] += src_meta.total_episodes
|
||||
dst_meta.info["total_frames"] += src_meta.total_frames
|
||||
|
||||
@@ -310,10 +360,6 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
dst_file_durations = video_idx["dst_file_durations"]
|
||||
|
||||
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
|
||||
# Convert to Python int to ensure consistent dict keys
|
||||
src_chunk_idx = int(src_chunk_idx)
|
||||
src_file_idx = int(src_file_idx)
|
||||
|
||||
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
|
||||
video_key=key,
|
||||
chunk_index=src_chunk_idx,
|
||||
@@ -386,10 +432,16 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
Reads source data files, updates indices to match the aggregated dataset,
|
||||
and writes them to the destination with proper file rotation.
|
||||
|
||||
Tracks a `src_to_dst` mapping from source (chunk, file) to destination (chunk, file)
|
||||
which is critical for correctly updating episode metadata when source datasets
|
||||
have multiple data files (e.g., from a previous merge operation).
|
||||
|
||||
Args:
|
||||
src_meta: Source dataset metadata.
|
||||
dst_meta: Destination dataset metadata.
|
||||
data_idx: Dictionary tracking data chunk and file indices.
|
||||
data_files_size_in_mb: Maximum size for data files in MB.
|
||||
chunk_size: Maximum number of files per chunk.
|
||||
|
||||
Returns:
|
||||
dict: Updated data_idx with current chunk and file indices.
|
||||
@@ -402,25 +454,47 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
}
|
||||
|
||||
unique_chunk_file_ids = sorted(unique_chunk_file_ids)
|
||||
contains_images = len(dst_meta.image_keys) > 0
|
||||
|
||||
# retrieve features schema for proper image typing in parquet
|
||||
hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None
|
||||
|
||||
# Track source to destination file mapping for metadata update
|
||||
# This is critical for handling datasets that are already results of a merge
|
||||
src_to_dst: dict[tuple[int, int], tuple[int, int]] = {}
|
||||
|
||||
for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
|
||||
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
|
||||
chunk_index=src_chunk_idx, file_index=src_file_idx
|
||||
)
|
||||
df = pd.read_parquet(src_path)
|
||||
if contains_images:
|
||||
# Use HuggingFace datasets to read source data to preserve image format
|
||||
src_ds = datasets.Dataset.from_parquet(str(src_path))
|
||||
df = src_ds.to_pandas()
|
||||
else:
|
||||
df = pd.read_parquet(src_path)
|
||||
df = update_data_df(df, src_meta, dst_meta)
|
||||
|
||||
data_idx = append_or_create_parquet_file(
|
||||
# Write data and get the actual destination file it was written to
|
||||
# This avoids duplicating the rotation logic here
|
||||
data_idx, (dst_chunk, dst_file) = append_or_create_parquet_file(
|
||||
df,
|
||||
src_path,
|
||||
data_idx,
|
||||
data_files_size_in_mb,
|
||||
chunk_size,
|
||||
DEFAULT_DATA_PATH,
|
||||
contains_images=len(dst_meta.image_keys) > 0,
|
||||
contains_images=contains_images,
|
||||
aggr_root=dst_meta.root,
|
||||
hf_features=hf_features,
|
||||
)
|
||||
|
||||
# Record the mapping from source to actual destination
|
||||
src_to_dst[(src_chunk_idx, src_file_idx)] = (dst_chunk, dst_file)
|
||||
|
||||
# Add the mapping to data_idx for use in metadata update
|
||||
data_idx["src_to_dst"] = src_to_dst
|
||||
|
||||
return data_idx
|
||||
|
||||
|
||||
@@ -461,7 +535,7 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
videos_idx,
|
||||
)
|
||||
|
||||
meta_idx = append_or_create_parquet_file(
|
||||
meta_idx, _ = append_or_create_parquet_file(
|
||||
df,
|
||||
src_path,
|
||||
meta_idx,
|
||||
@@ -488,7 +562,8 @@ def append_or_create_parquet_file(
|
||||
default_path: str,
|
||||
contains_images: bool = False,
|
||||
aggr_root: Path = None,
|
||||
):
|
||||
hf_features: datasets.Features | None = None,
|
||||
) -> tuple[dict[str, int], tuple[int, int]]:
|
||||
"""Appends data to an existing parquet file or creates a new one based on size constraints.
|
||||
|
||||
Manages file rotation when size limits are exceeded to prevent individual files
|
||||
@@ -503,40 +578,49 @@ def append_or_create_parquet_file(
|
||||
default_path: Format string for generating file paths.
|
||||
contains_images: Whether the data contains images requiring special handling.
|
||||
aggr_root: Root path for the aggregated dataset.
|
||||
hf_features: Optional HuggingFace Features schema for proper image typing.
|
||||
|
||||
Returns:
|
||||
dict: Updated index dictionary with current chunk and file indices.
|
||||
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
|
||||
and (dst_chunk, dst_file) is the actual destination file the data was written to.
|
||||
"""
|
||||
dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
|
||||
dst_chunk, dst_file = idx["chunk"], idx["file"]
|
||||
dst_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
|
||||
|
||||
if not dst_path.exists():
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if contains_images:
|
||||
to_parquet_with_hf_images(df, dst_path)
|
||||
to_parquet_with_hf_images(df, dst_path, features=hf_features)
|
||||
else:
|
||||
df.to_parquet(dst_path)
|
||||
return idx
|
||||
return idx, (dst_chunk, dst_file)
|
||||
|
||||
src_size = get_parquet_file_size_in_mb(src_path)
|
||||
dst_size = get_parquet_file_size_in_mb(dst_path)
|
||||
|
||||
if dst_size + src_size >= max_mb:
|
||||
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
|
||||
new_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
|
||||
dst_chunk, dst_file = idx["chunk"], idx["file"]
|
||||
new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
final_df = df
|
||||
target_path = new_path
|
||||
else:
|
||||
existing_df = pd.read_parquet(dst_path)
|
||||
if contains_images:
|
||||
# Use HuggingFace datasets to read existing data to preserve image format
|
||||
existing_ds = datasets.Dataset.from_parquet(str(dst_path))
|
||||
existing_df = existing_ds.to_pandas()
|
||||
else:
|
||||
existing_df = pd.read_parquet(dst_path)
|
||||
final_df = pd.concat([existing_df, df], ignore_index=True)
|
||||
target_path = dst_path
|
||||
|
||||
if contains_images:
|
||||
to_parquet_with_hf_images(final_df, target_path)
|
||||
to_parquet_with_hf_images(final_df, target_path, features=hf_features)
|
||||
else:
|
||||
final_df.to_parquet(target_path)
|
||||
|
||||
return idx
|
||||
return idx, (dst_chunk, dst_file)
|
||||
|
||||
|
||||
def finalize_aggregation(aggr_meta, all_metadata):
|
||||
|
||||
@@ -26,6 +26,7 @@ This module provides utilities for:
|
||||
import logging
|
||||
import shutil
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
@@ -51,7 +52,8 @@ from lerobot.datasets.utils import (
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
from lerobot.datasets.video_utils import encode_video_frames, get_video_info
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE
|
||||
|
||||
|
||||
def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict:
|
||||
@@ -1083,3 +1085,687 @@ def _copy_episodes_metadata_and_stats(
|
||||
else:
|
||||
if src_dataset.meta.stats:
|
||||
write_stats(src_dataset.meta.stats, dst_meta.root)
|
||||
|
||||
|
||||
def _save_episode_images_for_video(
|
||||
dataset: LeRobotDataset,
|
||||
imgs_dir: Path,
|
||||
img_key: str,
|
||||
episode_index: int,
|
||||
num_workers: int = 4,
|
||||
) -> None:
|
||||
"""Save images from a specific episode and camera to disk for video encoding.
|
||||
|
||||
Args:
|
||||
dataset: The LeRobot dataset to extract images from
|
||||
imgs_dir: Directory to save images to
|
||||
img_key: The image key (camera) to extract
|
||||
episode_index: Index of the episode to save
|
||||
num_workers: Number of threads for parallel image saving
|
||||
"""
|
||||
# Create directory
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get dataset without torch format for PIL image access
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
|
||||
# Select only this camera's images
|
||||
imgs_dataset = hf_dataset.select_columns(img_key)
|
||||
|
||||
# Get episode start and end indices
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
||||
|
||||
# Get all items for this episode
|
||||
episode_dataset = imgs_dataset.select(range(from_idx, to_idx))
|
||||
|
||||
# Define function to save a single image
|
||||
def save_single_image(i_item_tuple):
|
||||
i, item = i_item_tuple
|
||||
img = item[img_key]
|
||||
# Use frame-XXXXXX.png format to match encode_video_frames expectations
|
||||
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
|
||||
return i
|
||||
|
||||
# Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png)
|
||||
items = list(enumerate(episode_dataset))
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = [executor.submit(save_single_image, item) for item in items]
|
||||
for future in as_completed(futures):
|
||||
future.result() # This will raise any exceptions that occurred
|
||||
|
||||
|
||||
def _save_batch_episodes_images(
|
||||
dataset: LeRobotDataset,
|
||||
imgs_dir: Path,
|
||||
img_key: str,
|
||||
episode_indices: list[int],
|
||||
num_workers: int = 4,
|
||||
) -> list[float]:
|
||||
"""Save images from multiple episodes to disk for batch video encoding.
|
||||
|
||||
Args:
|
||||
dataset: The LeRobot dataset to extract images from
|
||||
imgs_dir: Directory to save images to
|
||||
img_key: The image key (camera) to extract
|
||||
episode_indices: List of episode indices to save
|
||||
num_workers: Number of threads for parallel image saving
|
||||
|
||||
Returns:
|
||||
List of episode durations in seconds
|
||||
"""
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
imgs_dataset = hf_dataset.select_columns(img_key)
|
||||
|
||||
# Define function to save a single image with global frame index
|
||||
# Defined once outside the loop to avoid repeated closure creation
|
||||
def save_single_image(i_item_tuple, base_frame_idx, img_key_param):
|
||||
i, item = i_item_tuple
|
||||
img = item[img_key_param]
|
||||
# Use global frame index for naming
|
||||
img.save(str(imgs_dir / f"frame-{base_frame_idx + i:06d}.png"), quality=100)
|
||||
return i
|
||||
|
||||
episode_durations = []
|
||||
frame_idx = 0
|
||||
|
||||
for ep_idx in episode_indices:
|
||||
# Get episode range
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][ep_idx]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][ep_idx]
|
||||
episode_length = to_idx - from_idx
|
||||
episode_durations.append(episode_length / dataset.fps)
|
||||
|
||||
# Get episode images
|
||||
episode_dataset = imgs_dataset.select(range(from_idx, to_idx))
|
||||
|
||||
# Save images
|
||||
items = list(enumerate(episode_dataset))
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = [executor.submit(save_single_image, item, frame_idx, img_key) for item in items]
|
||||
for future in as_completed(futures):
|
||||
future.result()
|
||||
|
||||
frame_idx += episode_length
|
||||
|
||||
return episode_durations
|
||||
|
||||
|
||||
def _iter_episode_batches(
|
||||
episode_indices: list[int],
|
||||
episode_lengths: dict[int, int],
|
||||
size_per_frame_mb: float,
|
||||
video_file_size_limit: float,
|
||||
max_episodes: int | None,
|
||||
max_frames: int | None,
|
||||
):
|
||||
"""Generator that yields batches of episode indices for video encoding.
|
||||
|
||||
Groups episodes into batches that respect size and memory constraints:
|
||||
- Stays under video file size limit
|
||||
- Respects maximum episodes per batch (if specified)
|
||||
- Respects maximum frames per batch (if specified)
|
||||
|
||||
Args:
|
||||
episode_indices: List of episode indices to batch
|
||||
episode_lengths: Dictionary mapping episode index to episode length
|
||||
size_per_frame_mb: Estimated size per frame in MB
|
||||
video_file_size_limit: Maximum video file size in MB
|
||||
max_episodes: Maximum number of episodes per batch (None = no limit)
|
||||
max_frames: Maximum number of frames per batch (None = no limit)
|
||||
|
||||
Yields:
|
||||
List of episode indices for each batch
|
||||
"""
|
||||
batch_episodes = []
|
||||
estimated_size = 0.0
|
||||
total_frames = 0
|
||||
|
||||
for ep_idx in episode_indices:
|
||||
ep_length = episode_lengths[ep_idx]
|
||||
ep_estimated_size = ep_length * size_per_frame_mb
|
||||
|
||||
# we check if adding this episode would exceed any constraint
|
||||
would_exceed_size = estimated_size > 0 and estimated_size + ep_estimated_size >= video_file_size_limit
|
||||
would_exceed_episodes = max_episodes is not None and len(batch_episodes) >= max_episodes
|
||||
would_exceed_frames = max_frames is not None and total_frames + ep_length > max_frames
|
||||
|
||||
if batch_episodes and (would_exceed_size or would_exceed_episodes or would_exceed_frames):
|
||||
# yield current batch before adding this episode
|
||||
yield batch_episodes
|
||||
# start a new batch with current episode
|
||||
batch_episodes = [ep_idx]
|
||||
estimated_size = ep_estimated_size
|
||||
total_frames = ep_length
|
||||
else:
|
||||
# add to current batch
|
||||
batch_episodes.append(ep_idx)
|
||||
estimated_size += ep_estimated_size
|
||||
total_frames += ep_length
|
||||
|
||||
# yield final batch if not empty
|
||||
if batch_episodes:
|
||||
yield batch_episodes
|
||||
|
||||
|
||||
def _estimate_frame_size_via_calibration(
|
||||
dataset: LeRobotDataset,
|
||||
img_key: str,
|
||||
episode_indices: list[int],
|
||||
temp_dir: Path,
|
||||
fps: int,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
g: int,
|
||||
crf: int,
|
||||
fast_decode: int,
|
||||
num_calibration_frames: int = 30,
|
||||
) -> float:
|
||||
"""Estimate MB per frame by encoding a small calibration sample.
|
||||
|
||||
Encodes a representative sample of frames using the exact codec parameters
|
||||
to measure actual compression ratio, which is more accurate than heuristics.
|
||||
|
||||
Args:
|
||||
dataset: Source dataset with images.
|
||||
img_key: Image key to calibrate (e.g., "observation.images.top").
|
||||
episode_indices: List of episode indices being processed.
|
||||
temp_dir: Temporary directory for calibration files.
|
||||
fps: Frames per second for video encoding.
|
||||
vcodec: Video codec (libsvtav1, h264, hevc).
|
||||
pix_fmt: Pixel format (yuv420p, etc.).
|
||||
g: GOP size (group of pictures).
|
||||
crf: Constant Rate Factor (quality).
|
||||
fast_decode: Fast decode tuning parameter.
|
||||
num_calibration_frames: Number of frames to use for calibration (default: 30).
|
||||
|
||||
Returns:
|
||||
Estimated size in MB per frame based on actual encoding.
|
||||
"""
|
||||
calibration_dir = temp_dir / "calibration" / img_key
|
||||
calibration_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
# Select a representative episode (prefer middle episode if available)
|
||||
calibration_ep_idx = episode_indices[len(episode_indices) // 2]
|
||||
|
||||
# Get episode range
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][calibration_ep_idx]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][calibration_ep_idx]
|
||||
episode_length = to_idx - from_idx
|
||||
|
||||
# Use up to num_calibration_frames from this episode
|
||||
num_frames = min(num_calibration_frames, episode_length)
|
||||
|
||||
# Get frames from dataset
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
sample_indices = range(from_idx, from_idx + num_frames)
|
||||
|
||||
# Save calibration frames
|
||||
for i, idx in enumerate(sample_indices):
|
||||
img = hf_dataset[idx][img_key]
|
||||
img.save(str(calibration_dir / f"frame-{i:06d}.png"), quality=100)
|
||||
|
||||
# Encode calibration video
|
||||
calibration_video_path = calibration_dir / "calibration.mp4"
|
||||
encode_video_frames(
|
||||
imgs_dir=calibration_dir,
|
||||
video_path=calibration_video_path,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
# Measure actual compressed size
|
||||
video_size_bytes = calibration_video_path.stat().st_size
|
||||
video_size_mb = video_size_bytes / BYTES_PER_MIB
|
||||
size_per_frame_mb = video_size_mb / num_frames
|
||||
|
||||
logging.info(
|
||||
f" Calibration: {num_frames} frames -> {video_size_mb:.2f} MB "
|
||||
f"= {size_per_frame_mb:.4f} MB/frame for {img_key}"
|
||||
)
|
||||
|
||||
return size_per_frame_mb
|
||||
|
||||
finally:
|
||||
# Clean up calibration files
|
||||
if calibration_dir.exists():
|
||||
shutil.rmtree(calibration_dir)
|
||||
|
||||
|
||||
def _copy_data_without_images(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
episode_indices: list[int],
|
||||
img_keys: list[str],
|
||||
) -> None:
|
||||
"""Copy data files without image columns.
|
||||
|
||||
Args:
|
||||
src_dataset: Source dataset
|
||||
dst_meta: Destination metadata
|
||||
episode_indices: Episodes to include
|
||||
img_keys: Image keys to remove
|
||||
"""
|
||||
from lerobot.datasets.utils import DATA_DIR
|
||||
|
||||
data_dir = src_dataset.root / DATA_DIR
|
||||
parquet_files = sorted(data_dir.glob("*/*.parquet"))
|
||||
|
||||
if not parquet_files:
|
||||
raise ValueError(f"No parquet files found in {data_dir}")
|
||||
|
||||
episode_set = set(episode_indices)
|
||||
|
||||
for src_path in tqdm(parquet_files, desc="Processing data files"):
|
||||
df = pd.read_parquet(src_path).reset_index(drop=True)
|
||||
|
||||
# Filter to only include selected episodes
|
||||
df = df[df["episode_index"].isin(episode_set)].copy()
|
||||
|
||||
if len(df) == 0:
|
||||
continue
|
||||
|
||||
# Remove image columns
|
||||
columns_to_drop = [col for col in img_keys if col in df.columns]
|
||||
if columns_to_drop:
|
||||
df = df.drop(columns=columns_to_drop)
|
||||
|
||||
# Get chunk and file indices from path
|
||||
relative_path = src_path.relative_to(src_dataset.root)
|
||||
chunk_dir = relative_path.parts[1]
|
||||
file_name = relative_path.parts[2]
|
||||
chunk_idx = int(chunk_dir.split("-")[1])
|
||||
file_idx = int(file_name.split("-")[1].split(".")[0])
|
||||
|
||||
# Write to destination without pandas index
|
||||
dst_path = dst_meta.root / f"data/chunk-{chunk_idx:03d}/file-{file_idx:03d}.parquet"
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_parquet(dst_path, index=False)
|
||||
|
||||
|
||||
# Video conversion constants
|
||||
BYTES_PER_KIB = 1024
|
||||
BYTES_PER_MIB = BYTES_PER_KIB * BYTES_PER_KIB
|
||||
|
||||
|
||||
def modify_tasks(
|
||||
dataset: LeRobotDataset,
|
||||
new_task: str | None = None,
|
||||
episode_tasks: dict[int, str] | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Modify tasks in a LeRobotDataset.
|
||||
|
||||
This function allows you to either:
|
||||
1. Set a single task for the entire dataset (using `new_task`)
|
||||
2. Set specific tasks for specific episodes (using `episode_tasks`)
|
||||
|
||||
You can combine both: `new_task` sets the default, and `episode_tasks` overrides
|
||||
specific episodes.
|
||||
|
||||
The dataset is modified in-place, updating only the task-related files:
|
||||
- meta/tasks.parquet
|
||||
- data/**/*.parquet (task_index column)
|
||||
- meta/episodes/**/*.parquet (tasks column)
|
||||
- meta/info.json (total_tasks)
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobotDataset to modify.
|
||||
new_task: A single task string to apply to all episodes. If None and episode_tasks
|
||||
is also None, raises an error.
|
||||
episode_tasks: Optional dict mapping episode indices to their task strings.
|
||||
Overrides `new_task` for specific episodes.
|
||||
|
||||
|
||||
Examples:
|
||||
Set a single task for all episodes:
|
||||
dataset = modify_tasks(dataset, new_task="Pick up the cube")
|
||||
|
||||
Set different tasks for specific episodes:
|
||||
dataset = modify_tasks(
|
||||
dataset,
|
||||
episode_tasks={0: "Task A", 1: "Task B", 2: "Task A"}
|
||||
)
|
||||
|
||||
Set a default task with overrides:
|
||||
dataset = modify_tasks(
|
||||
dataset,
|
||||
new_task="Default task",
|
||||
episode_tasks={5: "Special task for episode 5"}
|
||||
)
|
||||
"""
|
||||
if new_task is None and episode_tasks is None:
|
||||
raise ValueError("Must specify at least one of new_task or episode_tasks")
|
||||
|
||||
if episode_tasks is not None:
|
||||
valid_indices = set(range(dataset.meta.total_episodes))
|
||||
invalid = set(episode_tasks.keys()) - valid_indices
|
||||
if invalid:
|
||||
raise ValueError(f"Invalid episode indices: {invalid}")
|
||||
|
||||
# Ensure episodes metadata is loaded
|
||||
if dataset.meta.episodes is None:
|
||||
dataset.meta.episodes = load_episodes(dataset.root)
|
||||
|
||||
# Build the mapping from episode index to task string
|
||||
episode_to_task: dict[int, str] = {}
|
||||
for ep_idx in range(dataset.meta.total_episodes):
|
||||
if episode_tasks and ep_idx in episode_tasks:
|
||||
episode_to_task[ep_idx] = episode_tasks[ep_idx]
|
||||
elif new_task is not None:
|
||||
episode_to_task[ep_idx] = new_task
|
||||
else:
|
||||
# Keep original task if not overridden and no default provided
|
||||
original_tasks = dataset.meta.episodes[ep_idx]["tasks"]
|
||||
if not original_tasks:
|
||||
raise ValueError(f"Episode {ep_idx} has no tasks and no default task was provided")
|
||||
episode_to_task[ep_idx] = original_tasks[0]
|
||||
|
||||
# Collect all unique tasks and create new task mapping
|
||||
unique_tasks = sorted(set(episode_to_task.values()))
|
||||
new_task_df = pd.DataFrame({"task_index": list(range(len(unique_tasks)))}, index=unique_tasks)
|
||||
task_to_index = {task: idx for idx, task in enumerate(unique_tasks)}
|
||||
|
||||
logging.info(f"Modifying tasks in {dataset.repo_id}")
|
||||
logging.info(f"New tasks: {unique_tasks}")
|
||||
|
||||
root = dataset.root
|
||||
|
||||
# Update data files - modify task_index column
|
||||
logging.info("Updating data files...")
|
||||
data_dir = root / DATA_DIR
|
||||
|
||||
for parquet_path in tqdm(sorted(data_dir.rglob("*.parquet")), desc="Updating data"):
|
||||
df = pd.read_parquet(parquet_path)
|
||||
|
||||
# Build a mapping from episode_index to new task_index for rows in this file
|
||||
episode_indices_in_file = df["episode_index"].unique()
|
||||
ep_to_new_task_idx = {
|
||||
ep_idx: task_to_index[episode_to_task[ep_idx]] for ep_idx in episode_indices_in_file
|
||||
}
|
||||
|
||||
# Update task_index column
|
||||
df["task_index"] = df["episode_index"].map(ep_to_new_task_idx)
|
||||
df.to_parquet(parquet_path, index=False)
|
||||
|
||||
# Update episodes metadata - modify tasks column
|
||||
logging.info("Updating episodes metadata...")
|
||||
episodes_dir = root / "meta" / "episodes"
|
||||
|
||||
for parquet_path in tqdm(sorted(episodes_dir.rglob("*.parquet")), desc="Updating episodes"):
|
||||
df = pd.read_parquet(parquet_path)
|
||||
|
||||
# Update tasks column
|
||||
df["tasks"] = df["episode_index"].apply(lambda ep_idx: [episode_to_task[ep_idx]])
|
||||
df.to_parquet(parquet_path, index=False)
|
||||
|
||||
# Write new tasks.parquet
|
||||
write_tasks(new_task_df, root)
|
||||
|
||||
# Update info.json
|
||||
dataset.meta.info["total_tasks"] = len(unique_tasks)
|
||||
write_info(dataset.meta.info, root)
|
||||
|
||||
# Reload metadata to reflect changes
|
||||
dataset.meta.tasks = new_task_df
|
||||
dataset.meta.episodes = load_episodes(root)
|
||||
|
||||
logging.info(f"Tasks: {unique_tasks}")
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def convert_image_to_video_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
output_dir: Path,
|
||||
repo_id: str | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int = 2,
|
||||
crf: int = 30,
|
||||
fast_decode: int = 0,
|
||||
episode_indices: list[int] | None = None,
|
||||
num_workers: int = 4,
|
||||
max_episodes_per_batch: int | None = None,
|
||||
max_frames_per_batch: int | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Convert image-to-video dataset.
|
||||
|
||||
Creates a new LeRobotDataset with images encoded as videos, following the proper
|
||||
LeRobot dataset structure with videos stored in chunked MP4 files.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobot dataset with images
|
||||
output_dir: Directory to save the new video dataset
|
||||
repo_id: Repository ID for the new dataset (default: original_id + "_video")
|
||||
vcodec: Video codec (default: libsvtav1)
|
||||
pix_fmt: Pixel format (default: yuv420p)
|
||||
g: Group of pictures size (default: 2)
|
||||
crf: Constant rate factor (default: 30)
|
||||
fast_decode: Fast decode tuning (default: 0)
|
||||
episode_indices: List of episode indices to convert (None = all episodes)
|
||||
num_workers: Number of threads for parallel processing (default: 4)
|
||||
max_episodes_per_batch: Maximum episodes per video batch to avoid memory issues (None = no limit)
|
||||
max_frames_per_batch: Maximum frames per video batch to avoid memory issues (None = no limit)
|
||||
|
||||
Returns:
|
||||
New LeRobotDataset with images encoded as videos
|
||||
"""
|
||||
# Check that it's an image dataset
|
||||
if len(dataset.meta.video_keys) > 0:
|
||||
raise ValueError(
|
||||
f"This operation is for image datasets only. Video dataset provided: {dataset.repo_id}"
|
||||
)
|
||||
|
||||
# Get all image keys
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
|
||||
|
||||
if len(img_keys) == 0:
|
||||
raise ValueError(f"No image keys found in dataset {dataset.repo_id}")
|
||||
|
||||
# Determine which episodes to process
|
||||
if episode_indices is None:
|
||||
episode_indices = list(range(dataset.meta.total_episodes))
|
||||
|
||||
if repo_id is None:
|
||||
repo_id = f"{dataset.repo_id}_video"
|
||||
|
||||
logging.info(
|
||||
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
|
||||
)
|
||||
logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}")
|
||||
|
||||
# Create new features dict, converting image features to video features
|
||||
new_features = {}
|
||||
for key, value in dataset.meta.features.items():
|
||||
if key not in img_keys:
|
||||
new_features[key] = value
|
||||
else:
|
||||
# Convert image key to video format
|
||||
new_features[key] = value.copy()
|
||||
new_features[key]["dtype"] = "video" # Change dtype from "image" to "video"
|
||||
# Video info will be updated after episodes are encoded
|
||||
|
||||
# Create new metadata for video dataset
|
||||
new_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
fps=dataset.meta.fps,
|
||||
features=new_features,
|
||||
robot_type=dataset.meta.robot_type,
|
||||
root=output_dir,
|
||||
use_videos=True,
|
||||
chunks_size=dataset.meta.chunks_size,
|
||||
data_files_size_in_mb=dataset.meta.data_files_size_in_mb,
|
||||
video_files_size_in_mb=dataset.meta.video_files_size_in_mb,
|
||||
)
|
||||
|
||||
# Create temporary directory for image extraction
|
||||
temp_dir = output_dir / "temp_images"
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Process all episodes and batch encode videos
|
||||
# Use dictionary for O(1) episode metadata lookups instead of O(n) linear search
|
||||
all_episode_metadata = {}
|
||||
fps = int(dataset.fps)
|
||||
|
||||
try:
|
||||
# Build episode metadata entries first
|
||||
logging.info("Building episode metadata...")
|
||||
cumulative_frame_idx = 0
|
||||
for ep_idx in episode_indices:
|
||||
src_episode = dataset.meta.episodes[ep_idx]
|
||||
ep_length = src_episode["length"]
|
||||
ep_meta = {
|
||||
"episode_index": ep_idx,
|
||||
"length": ep_length,
|
||||
"dataset_from_index": cumulative_frame_idx,
|
||||
"dataset_to_index": cumulative_frame_idx + ep_length,
|
||||
}
|
||||
if "data/chunk_index" in src_episode:
|
||||
ep_meta["data/chunk_index"] = src_episode["data/chunk_index"]
|
||||
ep_meta["data/file_index"] = src_episode["data/file_index"]
|
||||
all_episode_metadata[ep_idx] = ep_meta
|
||||
cumulative_frame_idx += ep_length
|
||||
|
||||
# Process each camera and batch encode multiple episodes together
|
||||
video_file_size_limit = new_meta.video_files_size_in_mb
|
||||
|
||||
# Pre-compute episode lengths for batching
|
||||
episode_lengths = {ep_idx: dataset.meta.episodes["length"][ep_idx] for ep_idx in episode_indices}
|
||||
|
||||
for img_key in tqdm(img_keys, desc="Processing cameras"):
|
||||
# Estimate size per frame by encoding a small calibration sample
|
||||
# This provides accurate compression ratio for the specific codec parameters
|
||||
size_per_frame_mb = _estimate_frame_size_via_calibration(
|
||||
dataset=dataset,
|
||||
img_key=img_key,
|
||||
episode_indices=episode_indices,
|
||||
temp_dir=temp_dir,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
)
|
||||
|
||||
logging.info(f"Processing camera: {img_key}")
|
||||
chunk_idx, file_idx = 0, 0
|
||||
cumulative_timestamp = 0.0
|
||||
|
||||
# Process episodes in batches to stay under size limit
|
||||
for batch_episodes in _iter_episode_batches(
|
||||
episode_indices=episode_indices,
|
||||
episode_lengths=episode_lengths,
|
||||
size_per_frame_mb=size_per_frame_mb,
|
||||
video_file_size_limit=video_file_size_limit,
|
||||
max_episodes=max_episodes_per_batch,
|
||||
max_frames=max_frames_per_batch,
|
||||
):
|
||||
total_frames_in_batch = sum(episode_lengths[idx] for idx in batch_episodes)
|
||||
logging.info(
|
||||
f" Encoding batch of {len(batch_episodes)} episodes "
|
||||
f"({batch_episodes[0]}-{batch_episodes[-1]}) = {total_frames_in_batch} frames"
|
||||
)
|
||||
|
||||
# Save images for all episodes in this batch
|
||||
imgs_dir = temp_dir / f"batch_{chunk_idx}_{file_idx}" / img_key
|
||||
episode_durations = _save_batch_episodes_images(
|
||||
dataset=dataset,
|
||||
imgs_dir=imgs_dir,
|
||||
img_key=img_key,
|
||||
episode_indices=batch_episodes,
|
||||
num_workers=num_workers,
|
||||
)
|
||||
|
||||
# Encode all batched episodes into single video
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
encode_video_frames(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
# Clean up temporary images
|
||||
shutil.rmtree(imgs_dir)
|
||||
|
||||
# Update metadata for each episode in the batch
|
||||
for ep_idx, duration in zip(batch_episodes, episode_durations, strict=True):
|
||||
from_timestamp = cumulative_timestamp
|
||||
to_timestamp = cumulative_timestamp + duration
|
||||
cumulative_timestamp = to_timestamp
|
||||
|
||||
# Find episode metadata entry and add video metadata (O(1) dictionary lookup)
|
||||
ep_meta = all_episode_metadata[ep_idx]
|
||||
ep_meta[f"videos/{img_key}/chunk_index"] = chunk_idx
|
||||
ep_meta[f"videos/{img_key}/file_index"] = file_idx
|
||||
ep_meta[f"videos/{img_key}/from_timestamp"] = from_timestamp
|
||||
ep_meta[f"videos/{img_key}/to_timestamp"] = to_timestamp
|
||||
|
||||
# Move to next video file for next batch
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, new_meta.chunks_size)
|
||||
cumulative_timestamp = 0.0
|
||||
|
||||
# Copy and transform data files (removing image columns)
|
||||
_copy_data_without_images(dataset, new_meta, episode_indices, img_keys)
|
||||
|
||||
# Save episode metadata
|
||||
episodes_df = pd.DataFrame(list(all_episode_metadata.values()))
|
||||
episodes_path = new_meta.root / "meta" / "episodes" / "chunk-000" / "file-000.parquet"
|
||||
episodes_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
episodes_df.to_parquet(episodes_path, index=False)
|
||||
|
||||
# Update metadata info
|
||||
new_meta.info["total_episodes"] = len(episode_indices)
|
||||
new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata.values())
|
||||
new_meta.info["total_tasks"] = dataset.meta.total_tasks
|
||||
new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"}
|
||||
|
||||
# Update video info for all image keys (now videos)
|
||||
# We need to manually set video info since update_video_info() checks video_keys first
|
||||
for img_key in img_keys:
|
||||
if not new_meta.features[img_key].get("info", None):
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=0, file_index=0
|
||||
)
|
||||
new_meta.info["features"][img_key]["info"] = get_video_info(video_path)
|
||||
|
||||
write_info(new_meta.info, new_meta.root)
|
||||
|
||||
# Copy stats and tasks
|
||||
if dataset.meta.stats is not None:
|
||||
# Remove image stats
|
||||
new_stats = {k: v for k, v in dataset.meta.stats.items() if k not in img_keys}
|
||||
write_stats(new_stats, new_meta.root)
|
||||
|
||||
if dataset.meta.tasks is not None:
|
||||
write_tasks(dataset.meta.tasks, new_meta.root)
|
||||
|
||||
finally:
|
||||
# Clean up temporary directory
|
||||
if temp_dir.exists():
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
logging.info(f"Completed converting {dataset.repo_id} to video format")
|
||||
logging.info(f"New dataset saved to: {output_dir}")
|
||||
|
||||
# Return new dataset
|
||||
return LeRobotDataset(repo_id=repo_id, root=output_dir)
|
||||
|
||||
@@ -57,6 +57,7 @@ from lerobot.datasets.utils import (
|
||||
load_info,
|
||||
load_nested_dataset,
|
||||
load_stats,
|
||||
load_subtasks,
|
||||
load_tasks,
|
||||
update_chunk_file_indices,
|
||||
validate_episode_buffer,
|
||||
@@ -162,6 +163,7 @@ class LeRobotDatasetMetadata:
|
||||
self.info = load_info(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks = load_tasks(self.root)
|
||||
self.subtasks = load_subtasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
|
||||
@@ -518,6 +520,7 @@ class LeRobotDatasetMetadata:
|
||||
_validate_feature_names(features)
|
||||
|
||||
obj.tasks = None
|
||||
obj.subtasks = None
|
||||
obj.episodes = None
|
||||
obj.stats = None
|
||||
obj.info = create_empty_dataset_info(
|
||||
@@ -560,7 +563,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
episodes: list[int] | None = None,
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[str, list[float]] | None = None,
|
||||
tolerance_s: float = 1e-4,
|
||||
tolerance_s: float = 1e-2,
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
download_videos: bool = True,
|
||||
@@ -935,17 +938,30 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
else:
|
||||
return get_hf_features_from_features(self.features)
|
||||
|
||||
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
|
||||
def _get_query_indices(
|
||||
self, abs_idx: int, ep_idx: int
|
||||
) -> tuple[dict[str, list[int]], dict[str, torch.Tensor]]:
|
||||
"""Compute query indices for delta timestamps.
|
||||
|
||||
Args:
|
||||
abs_idx: The absolute index in the full dataset (not the relative index in filtered episodes).
|
||||
ep_idx: The episode index.
|
||||
|
||||
Returns:
|
||||
A tuple of (query_indices, padding) where:
|
||||
- query_indices: Dict mapping keys to lists of absolute indices to query
|
||||
- padding: Dict mapping "{key}_is_pad" to boolean tensors indicating padded positions
|
||||
"""
|
||||
ep = self.meta.episodes[ep_idx]
|
||||
ep_start = ep["dataset_from_index"]
|
||||
ep_end = ep["dataset_to_index"]
|
||||
query_indices = {
|
||||
key: [max(ep_start, min(ep_end - 1, idx + delta)) for delta in delta_idx]
|
||||
key: [max(ep_start, min(ep_end - 1, abs_idx + delta)) for delta in delta_idx]
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
padding = { # Pad values outside of current episode range
|
||||
f"{key}_is_pad": torch.BoolTensor(
|
||||
[(idx + delta < ep_start) | (idx + delta >= ep_end) for delta in delta_idx]
|
||||
[(abs_idx + delta < ep_start) | (abs_idx + delta >= ep_end) for delta in delta_idx]
|
||||
)
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
@@ -1037,10 +1053,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self._ensure_hf_dataset_loaded()
|
||||
item = self.hf_dataset[idx]
|
||||
ep_idx = item["episode_index"].item()
|
||||
# Use the absolute index from the dataset for delta timestamp calculations
|
||||
abs_idx = item["index"].item()
|
||||
|
||||
query_indices = None
|
||||
if self.delta_indices is not None:
|
||||
query_indices, padding = self._get_query_indices(idx, ep_idx)
|
||||
query_indices, padding = self._get_query_indices(abs_idx, ep_idx)
|
||||
query_result = self._query_hf_dataset(query_indices)
|
||||
item = {**item, **padding}
|
||||
for key, val in query_result.items():
|
||||
@@ -1060,6 +1078,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Add task as a string
|
||||
task_idx = item["task_index"].item()
|
||||
item["task"] = self.meta.tasks.iloc[task_idx].name
|
||||
|
||||
# add subtask information if available
|
||||
if "subtask_index" in self.features and self.meta.subtasks is not None:
|
||||
subtask_idx = item["subtask_index"].item()
|
||||
item["subtask"] = self.meta.subtasks.iloc[subtask_idx].name
|
||||
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
@@ -1498,7 +1522,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
episode_index = self.episode_buffer["episode_index"]
|
||||
if isinstance(episode_index, np.ndarray):
|
||||
episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0]
|
||||
for cam_key in self.meta.camera_keys:
|
||||
for cam_key in self.meta.image_keys:
|
||||
img_dir = self._get_image_file_dir(episode_index, cam_key)
|
||||
if img_dir.is_dir():
|
||||
shutil.rmtree(img_dir)
|
||||
@@ -1548,7 +1572,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
root: str | Path | None = None,
|
||||
robot_type: str | None = None,
|
||||
use_videos: bool = True,
|
||||
tolerance_s: float = 1e-4,
|
||||
tolerance_s: float = 1e-2,
|
||||
image_writer_processes: int = 0,
|
||||
image_writer_threads: int = 0,
|
||||
video_backend: str | None = None,
|
||||
|
||||
@@ -60,6 +60,7 @@ VIDEO_DIR = "videos"
|
||||
|
||||
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
|
||||
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
|
||||
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
|
||||
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
@@ -353,6 +354,14 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
|
||||
return tasks
|
||||
|
||||
|
||||
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
|
||||
"""Load subtasks from subtasks.parquet if it exists."""
|
||||
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
|
||||
if subtasks_path.exists():
|
||||
return pd.read_parquet(subtasks_path)
|
||||
return None
|
||||
|
||||
|
||||
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
|
||||
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
|
||||
This function writes episode-level metadata to a single parquet file.
|
||||
@@ -1172,12 +1181,21 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
|
||||
)
|
||||
|
||||
|
||||
def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path) -> None:
|
||||
def to_parquet_with_hf_images(
|
||||
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
|
||||
) -> None:
|
||||
"""This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
|
||||
This way, it can be loaded by HF dataset and correctly formatted images are returned.
|
||||
|
||||
Args:
|
||||
df: DataFrame to write to parquet.
|
||||
path: Path to write the parquet file.
|
||||
features: Optional HuggingFace Features schema. If provided, ensures image columns
|
||||
are properly typed as Image() in the parquet schema.
|
||||
"""
|
||||
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
|
||||
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
|
||||
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
|
||||
ds.to_parquet(path)
|
||||
|
||||
|
||||
def item_to_torch(item: dict) -> dict:
|
||||
|
||||
@@ -260,6 +260,7 @@ class HILSerlRobotEnvConfig(EnvConfig):
|
||||
@dataclass
|
||||
class LiberoEnv(EnvConfig):
|
||||
task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
|
||||
task_ids: list[int] | None = None
|
||||
fps: int = 30
|
||||
episode_length: int | None = None
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
@@ -338,10 +339,10 @@ class LiberoEnv(EnvConfig):
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
}
|
||||
kwargs: dict[str, Any] = {"obs_type": self.obs_type, "render_mode": self.render_mode}
|
||||
if self.task_ids is not None:
|
||||
kwargs["task_ids"] = self.task_ids
|
||||
return kwargs
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("metaworld")
|
||||
|
||||
@@ -293,9 +293,9 @@ class LiberoEnv(gym.Env):
|
||||
def reset(self, seed=None, **kwargs):
|
||||
super().reset(seed=seed)
|
||||
self._env.seed(seed)
|
||||
if self.init_states and self._init_states is not None:
|
||||
self._env.set_init_state(self._init_states[self._init_state_id])
|
||||
raw_obs = self._env.reset()
|
||||
if self.init_states and self._init_states is not None:
|
||||
raw_obs = self._env.set_init_state(self._init_states[self._init_state_id])
|
||||
|
||||
# After reset, objects may be unstable (slightly floating, intersecting, etc.).
|
||||
# Step the simulator with a no-op action for a few frames so everything settles.
|
||||
|
||||
@@ -14,4 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .motors_bus import Motor, MotorCalibration, MotorNormMode, MotorsBus
|
||||
from .motors_bus import (
|
||||
Motor,
|
||||
MotorCalibration,
|
||||
MotorNormMode,
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ from dataclasses import dataclass
|
||||
|
||||
os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1"
|
||||
|
||||
from lerobot.motors import MotorCalibration, MotorsBus
|
||||
from .motors_bus import MotorCalibration, MotorsBus
|
||||
|
||||
BAR_LEN, BAR_THICKNESS = 450, 8
|
||||
HANDLE_R = 10
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .damiao import DamiaoMotorsBus
|
||||
from .tables import *
|
||||
@@ -0,0 +1,833 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Portions of this file are derived from DM_Control_Python by cmjang.
|
||||
# Licensed under the MIT License; see `LICENSE` for the full text:
|
||||
# https://github.com/cmjang/DM_Control_Python
|
||||
|
||||
import logging
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
|
||||
from lerobot.utils.import_utils import _can_available
|
||||
|
||||
if TYPE_CHECKING or _can_available:
|
||||
import can
|
||||
else:
|
||||
|
||||
class can: # noqa: N801
|
||||
Message = object
|
||||
interface = None
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
||||
|
||||
from ..motors_bus import Motor, MotorCalibration, MotorsBusBase, NameOrID, Value
|
||||
from .tables import (
|
||||
AVAILABLE_BAUDRATES,
|
||||
CAN_CMD_DISABLE,
|
||||
CAN_CMD_ENABLE,
|
||||
CAN_CMD_REFRESH,
|
||||
CAN_CMD_SET_ZERO,
|
||||
CAN_PARAM_ID,
|
||||
DEFAULT_BAUDRATE,
|
||||
DEFAULT_TIMEOUT_MS,
|
||||
MIT_KD_RANGE,
|
||||
MIT_KP_RANGE,
|
||||
MOTOR_LIMIT_PARAMS,
|
||||
MotorType,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
LONG_TIMEOUT_SEC = 0.1
|
||||
MEDIUM_TIMEOUT_SEC = 0.01
|
||||
SHORT_TIMEOUT_SEC = 0.001
|
||||
PRECISE_TIMEOUT_SEC = 0.0001
|
||||
|
||||
|
||||
class MotorState(TypedDict):
|
||||
position: float
|
||||
velocity: float
|
||||
torque: float
|
||||
temp_mos: float
|
||||
temp_rotor: float
|
||||
|
||||
|
||||
class DamiaoMotorsBus(MotorsBusBase):
|
||||
"""
|
||||
The Damiao implementation for a MotorsBus using CAN bus communication.
|
||||
|
||||
This class uses python-can for CAN bus communication with Damiao motors.
|
||||
For more info, see:
|
||||
- python-can documentation: https://python-can.readthedocs.io/en/stable/
|
||||
- Seedstudio documentation: https://wiki.seeedstudio.com/damiao_series/
|
||||
- DM_Control_Python repo: https://github.com/cmjang/DM_Control_Python
|
||||
"""
|
||||
|
||||
# CAN-specific settings
|
||||
available_baudrates = deepcopy(AVAILABLE_BAUDRATES)
|
||||
default_baudrate = DEFAULT_BAUDRATE
|
||||
default_timeout = DEFAULT_TIMEOUT_MS
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
port: str,
|
||||
motors: dict[str, Motor],
|
||||
calibration: dict[str, MotorCalibration] | None = None,
|
||||
can_interface: str = "auto",
|
||||
use_can_fd: bool = True,
|
||||
bitrate: int = 1000000,
|
||||
data_bitrate: int | None = 5000000,
|
||||
):
|
||||
"""
|
||||
Initialize the Damiao motors bus.
|
||||
|
||||
Args:
|
||||
port: CAN interface name (e.g., "can0" for Linux, "/dev/cu.usbmodem*" for macOS)
|
||||
motors: Dictionary mapping motor names to Motor objects
|
||||
calibration: Optional calibration data
|
||||
can_interface: CAN interface type - "auto" (default), "socketcan" (Linux), or "slcan" (macOS/serial)
|
||||
use_can_fd: Whether to use CAN FD mode (default: True for OpenArms)
|
||||
bitrate: Nominal bitrate in bps (default: 1000000 = 1 Mbps)
|
||||
data_bitrate: Data bitrate for CAN FD in bps (default: 5000000 = 5 Mbps), ignored if use_can_fd is False
|
||||
"""
|
||||
super().__init__(port, motors, calibration)
|
||||
self.port = port
|
||||
self.can_interface = can_interface
|
||||
self.use_can_fd = use_can_fd
|
||||
self.bitrate = bitrate
|
||||
self.data_bitrate = data_bitrate
|
||||
self.canbus: can.interface.Bus | None = None
|
||||
self._is_connected = False
|
||||
|
||||
# Map motor names to CAN IDs
|
||||
self._motor_can_ids: dict[str, int] = {}
|
||||
self._recv_id_to_motor: dict[int, str] = {}
|
||||
self._motor_types: dict[str, MotorType] = {}
|
||||
|
||||
for name, motor in self.motors.items():
|
||||
if motor.motor_type_str is None:
|
||||
raise ValueError(f"Motor '{name}' is missing required 'motor_type'")
|
||||
self._motor_types[name] = getattr(MotorType, motor.motor_type_str.upper().replace("-", "_"))
|
||||
|
||||
# Map recv_id to motor name for filtering responses
|
||||
if motor.recv_id is not None:
|
||||
self._recv_id_to_motor[motor.recv_id] = name
|
||||
|
||||
# State cache for handling packet drops safely
|
||||
self._last_known_states: dict[str, MotorState] = {
|
||||
name: {
|
||||
"position": 0.0,
|
||||
"velocity": 0.0,
|
||||
"torque": 0.0,
|
||||
"temp_mos": 0.0,
|
||||
"temp_rotor": 0.0,
|
||||
}
|
||||
for name in self.motors
|
||||
}
|
||||
|
||||
# Dynamic gains storage
|
||||
# Defaults: Kp=10.0 (Stiffness), Kd=0.5 (Damping)
|
||||
self._gains: dict[str, dict[str, float]] = {name: {"kp": 10.0, "kd": 0.5} for name in self.motors}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if the CAN bus is connected."""
|
||||
return self._is_connected and self.canbus is not None
|
||||
|
||||
def connect(self, handshake: bool = True) -> None:
|
||||
"""
|
||||
Open the CAN bus and initialize communication.
|
||||
|
||||
Args:
|
||||
handshake: If True, ping all motors to verify they're present
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is already connected."
|
||||
)
|
||||
|
||||
try:
|
||||
# Auto-detect interface type based on port name
|
||||
if self.can_interface == "auto":
|
||||
if self.port.startswith("/dev/"):
|
||||
self.can_interface = "slcan"
|
||||
logger.info(f"Auto-detected slcan interface for port {self.port}")
|
||||
else:
|
||||
self.can_interface = "socketcan"
|
||||
logger.info(f"Auto-detected socketcan interface for port {self.port}")
|
||||
|
||||
# Connect to CAN bus
|
||||
kwargs = {
|
||||
"channel": self.port,
|
||||
"bitrate": self.bitrate,
|
||||
"interface": self.can_interface,
|
||||
}
|
||||
|
||||
if self.can_interface == "socketcan" and self.use_can_fd and self.data_bitrate is not None:
|
||||
kwargs.update({"data_bitrate": self.data_bitrate, "fd": True})
|
||||
logger.info(
|
||||
f"Connected to {self.port} with CAN FD (bitrate={self.bitrate}, data_bitrate={self.data_bitrate})"
|
||||
)
|
||||
else:
|
||||
logger.info(f"Connected to {self.port} with {self.can_interface} (bitrate={self.bitrate})")
|
||||
|
||||
self.canbus = can.interface.Bus(**kwargs)
|
||||
self._is_connected = True
|
||||
|
||||
if handshake:
|
||||
self._handshake()
|
||||
|
||||
logger.debug(f"{self.__class__.__name__} connected via {self.can_interface}.")
|
||||
except Exception as e:
|
||||
self._is_connected = False
|
||||
raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e
|
||||
|
||||
def _handshake(self) -> None:
|
||||
"""
|
||||
Verify all motors are present and populate initial state cache.
|
||||
Raises ConnectionError if any motor fails to respond.
|
||||
"""
|
||||
logger.info("Starting handshake with motors...")
|
||||
|
||||
# Drain any pending messages
|
||||
while self.canbus.recv(timeout=0.01):
|
||||
pass
|
||||
|
||||
missing_motors = []
|
||||
for motor_name in self.motors:
|
||||
motor_id = self._get_motor_id(motor_name)
|
||||
recv_id = self._get_motor_recv_id(motor_name)
|
||||
|
||||
# Send enable command
|
||||
data = [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, CAN_CMD_ENABLE]
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
self.canbus.send(msg)
|
||||
|
||||
# Wait for response with longer timeout
|
||||
response = None
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < 0.1:
|
||||
response = self.canbus.recv(timeout=0.1)
|
||||
if response and response.arbitration_id == recv_id:
|
||||
break
|
||||
response = None
|
||||
|
||||
if response is None:
|
||||
missing_motors.append(motor_name)
|
||||
else:
|
||||
self._process_response(motor_name, msg)
|
||||
time.sleep(MEDIUM_TIMEOUT_SEC)
|
||||
|
||||
if missing_motors:
|
||||
raise ConnectionError(
|
||||
f"Handshake failed. The following motors did not respond: {missing_motors}. "
|
||||
"Check power (24V) and CAN wiring."
|
||||
)
|
||||
logger.info("Handshake successful. All motors ready.")
|
||||
|
||||
def disconnect(self, disable_torque: bool = True) -> None:
|
||||
"""
|
||||
Close the CAN bus connection.
|
||||
|
||||
Args:
|
||||
disable_torque: If True, disable torque on all motors before disconnecting
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self.__class__.__name__}('{self.port}') is not connected.")
|
||||
|
||||
if disable_torque:
|
||||
try:
|
||||
self.disable_torque()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to disable torque during disconnect: {e}")
|
||||
|
||||
if self.canbus:
|
||||
self.canbus.shutdown()
|
||||
self.canbus = None
|
||||
self._is_connected = False
|
||||
logger.debug(f"{self.__class__.__name__} disconnected.")
|
||||
|
||||
def configure_motors(self) -> None:
|
||||
"""Configure all motors with default settings."""
|
||||
# Damiao motors don't require much configuration in MIT mode
|
||||
# Just ensure they're enabled
|
||||
for motor in self.motors:
|
||||
self._send_simple_command(motor, CAN_CMD_ENABLE)
|
||||
time.sleep(MEDIUM_TIMEOUT_SEC)
|
||||
|
||||
def _send_simple_command(self, motor: NameOrID, command_byte: int) -> None:
|
||||
"""Helper to send simple 8-byte commands (Enable, Disable, Zero)."""
|
||||
motor_id = self._get_motor_id(motor)
|
||||
motor_name = self._get_motor_name(motor)
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
data = [0xFF] * 7 + [command_byte]
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
self.canbus.send(msg)
|
||||
if msg := self._recv_motor_response(expected_recv_id=recv_id):
|
||||
self._process_response(motor_name, msg)
|
||||
else:
|
||||
logger.debug(f"No response from {motor_name} after command 0x{command_byte:02X}")
|
||||
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
"""Enable torque on selected motors."""
|
||||
target_motors = self._get_motors_list(motors)
|
||||
for motor in target_motors:
|
||||
for _ in range(num_retry + 1):
|
||||
try:
|
||||
self._send_simple_command(motor, CAN_CMD_ENABLE)
|
||||
break
|
||||
except Exception as e:
|
||||
if _ == num_retry:
|
||||
raise e
|
||||
time.sleep(MEDIUM_TIMEOUT_SEC)
|
||||
|
||||
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
"""Disable torque on selected motors."""
|
||||
target_motors = self._get_motors_list(motors)
|
||||
for motor in target_motors:
|
||||
for _ in range(num_retry + 1):
|
||||
try:
|
||||
self._send_simple_command(motor, CAN_CMD_DISABLE)
|
||||
break
|
||||
except Exception as e:
|
||||
if _ == num_retry:
|
||||
raise e
|
||||
time.sleep(MEDIUM_TIMEOUT_SEC)
|
||||
|
||||
@contextmanager
|
||||
def torque_disabled(self, motors: str | list[str] | None = None):
|
||||
"""
|
||||
Context manager that guarantees torque is re-enabled.
|
||||
|
||||
This helper is useful to temporarily disable torque when configuring motors.
|
||||
"""
|
||||
self.disable_torque(motors)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.enable_torque(motors)
|
||||
|
||||
def set_zero_position(self, motors: str | list[str] | None = None) -> None:
|
||||
"""Set current position as zero for selected motors."""
|
||||
target_motors = self._get_motors_list(motors)
|
||||
for motor in target_motors:
|
||||
self._send_simple_command(motor, CAN_CMD_SET_ZERO)
|
||||
time.sleep(MEDIUM_TIMEOUT_SEC)
|
||||
|
||||
def _refresh_motor(self, motor: NameOrID) -> can.Message | None:
|
||||
"""Refresh motor status and return the response."""
|
||||
motor_id = self._get_motor_id(motor)
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
|
||||
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
self.canbus.send(msg)
|
||||
return self._recv_motor_response(expected_recv_id=recv_id)
|
||||
|
||||
def _recv_motor_response(
|
||||
self, expected_recv_id: int | None = None, timeout: float = 0.001
|
||||
) -> can.Message | None:
|
||||
"""
|
||||
Receive a response from a motor.
|
||||
|
||||
Args:
|
||||
expected_recv_id: If provided, only return messages from this CAN ID
|
||||
timeout: Timeout in seconds (default: 1ms for high-speed operation)
|
||||
Returns:
|
||||
CAN message if received, None otherwise
|
||||
"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
messages_seen = []
|
||||
while time.time() - start_time < timeout:
|
||||
msg = self.canbus.recv(timeout=PRECISE_TIMEOUT_SEC)
|
||||
if msg:
|
||||
messages_seen.append(f"0x{msg.arbitration_id:02X}")
|
||||
if expected_recv_id is None or msg.arbitration_id == expected_recv_id:
|
||||
return msg
|
||||
logger.debug(
|
||||
f"Ignoring message from 0x{msg.arbitration_id:02X}, expected 0x{expected_recv_id:02X}"
|
||||
)
|
||||
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
if messages_seen:
|
||||
logger.debug(
|
||||
f"Received {len(messages_seen)} msgs from {set(messages_seen)}, expected 0x{expected_recv_id:02X}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"No CAN messages received (expected 0x{expected_recv_id:02X})")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to receive CAN message: {e}")
|
||||
return None
|
||||
|
||||
def _recv_all_responses(
|
||||
self, expected_recv_ids: list[int], timeout: float = 0.002
|
||||
) -> dict[int, can.Message]:
|
||||
"""
|
||||
Efficiently receive responses from multiple motors at once.
|
||||
Uses the OpenArms pattern: collect all available messages within timeout.
|
||||
|
||||
Args:
|
||||
expected_recv_ids: List of CAN IDs we expect responses from
|
||||
timeout: Total timeout in seconds (default: 2ms)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping recv_id to CAN message
|
||||
"""
|
||||
responses = {}
|
||||
expected_set = set(expected_recv_ids)
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout:
|
||||
# 100us poll timeout
|
||||
msg = self.canbus.recv(timeout=PRECISE_TIMEOUT_SEC)
|
||||
if msg and msg.arbitration_id in expected_set:
|
||||
responses[msg.arbitration_id] = msg
|
||||
if len(responses) == len(expected_recv_ids):
|
||||
break
|
||||
except Exception as e:
|
||||
logger.debug(f"Error receiving responses: {e}")
|
||||
|
||||
return responses
|
||||
|
||||
def _encode_mit_packet(
|
||||
self,
|
||||
motor_type: MotorType,
|
||||
kp: float,
|
||||
kd: float,
|
||||
position_degrees: float,
|
||||
velocity_deg_per_sec: float,
|
||||
torque: float,
|
||||
) -> list[int]:
|
||||
"""Helper to encode control parameters into 8 bytes for MIT mode."""
|
||||
# Convert degrees to radians
|
||||
position_rad = np.radians(position_degrees)
|
||||
velocity_rad_per_sec = np.radians(velocity_deg_per_sec)
|
||||
|
||||
# Get motor limits
|
||||
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
|
||||
|
||||
# Encode parameters
|
||||
kp_uint = self._float_to_uint(kp, *MIT_KP_RANGE, 12)
|
||||
kd_uint = self._float_to_uint(kd, *MIT_KD_RANGE, 12)
|
||||
q_uint = self._float_to_uint(position_rad, -pmax, pmax, 16)
|
||||
dq_uint = self._float_to_uint(velocity_rad_per_sec, -vmax, vmax, 12)
|
||||
tau_uint = self._float_to_uint(torque, -tmax, tmax, 12)
|
||||
|
||||
# Pack data
|
||||
data = [0] * 8
|
||||
data[0] = (q_uint >> 8) & 0xFF
|
||||
data[1] = q_uint & 0xFF
|
||||
data[2] = dq_uint >> 4
|
||||
data[3] = ((dq_uint & 0xF) << 4) | ((kp_uint >> 8) & 0xF)
|
||||
data[4] = kp_uint & 0xFF
|
||||
data[5] = kd_uint >> 4
|
||||
data[6] = ((kd_uint & 0xF) << 4) | ((tau_uint >> 8) & 0xF)
|
||||
data[7] = tau_uint & 0xFF
|
||||
return data
|
||||
|
||||
def _mit_control(
|
||||
self,
|
||||
motor: NameOrID,
|
||||
kp: float,
|
||||
kd: float,
|
||||
position_degrees: float,
|
||||
velocity_deg_per_sec: float,
|
||||
torque: float,
|
||||
) -> None:
|
||||
"""Send MIT control command to a motor."""
|
||||
motor_id = self._get_motor_id(motor)
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_type = self._motor_types[motor_name]
|
||||
|
||||
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
self.canbus.send(msg)
|
||||
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
if msg := self._recv_motor_response(expected_recv_id=recv_id):
|
||||
self._process_response(motor_name, msg)
|
||||
else:
|
||||
logger.debug(f"No response from {motor_name} after MIT control command")
|
||||
|
||||
def _mit_control_batch(
|
||||
self,
|
||||
commands: dict[NameOrID, tuple[float, float, float, float, float]],
|
||||
) -> None:
|
||||
"""
|
||||
Send MIT control commands to multiple motors in batch.
|
||||
Sends all commands first, then collects responses.
|
||||
|
||||
Args:
|
||||
commands: Dict mapping motor name/ID to (kp, kd, position_deg, velocity_deg/s, torque)
|
||||
Example: {'joint_1': (10.0, 0.5, 45.0, 0.0, 0.0), ...}
|
||||
"""
|
||||
if not commands:
|
||||
return
|
||||
|
||||
recv_id_to_motor: dict[int, str] = {}
|
||||
|
||||
# Step 1: Send all MIT control commands
|
||||
for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items():
|
||||
motor_id = self._get_motor_id(motor)
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_type = self._motor_types[motor_name]
|
||||
|
||||
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
self.canbus.send(msg)
|
||||
|
||||
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
|
||||
|
||||
# Step 2: Collect responses and update state cache
|
||||
responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=SHORT_TIMEOUT_SEC)
|
||||
for recv_id, motor_name in recv_id_to_motor.items():
|
||||
if msg := responses.get(recv_id):
|
||||
self._process_response(motor_name, msg)
|
||||
|
||||
def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
|
||||
"""Convert float to unsigned integer for CAN transmission."""
|
||||
x = max(x_min, min(x_max, x)) # Clamp to range
|
||||
span = x_max - x_min
|
||||
data_norm = (x - x_min) / span
|
||||
return int(data_norm * ((1 << bits) - 1))
|
||||
|
||||
def _uint_to_float(self, x: int, x_min: float, x_max: float, bits: int) -> float:
|
||||
"""Convert unsigned integer from CAN to float."""
|
||||
span = x_max - x_min
|
||||
data_norm = float(x) / ((1 << bits) - 1)
|
||||
return data_norm * span + x_min
|
||||
|
||||
def _decode_motor_state(
|
||||
self, data: bytearray | bytes, motor_type: MotorType
|
||||
) -> tuple[float, float, float, int, int]:
|
||||
"""
|
||||
Decode motor state from CAN data.
|
||||
Returns: (position_deg, velocity_deg_s, torque, temp_mos, temp_rotor)
|
||||
"""
|
||||
if len(data) < 8:
|
||||
raise ValueError("Invalid motor state data")
|
||||
|
||||
# Extract encoded values
|
||||
q_uint = (data[1] << 8) | data[2]
|
||||
dq_uint = (data[3] << 4) | (data[4] >> 4)
|
||||
tau_uint = ((data[4] & 0x0F) << 8) | data[5]
|
||||
t_mos = data[6]
|
||||
t_rotor = data[7]
|
||||
|
||||
# Get motor limits
|
||||
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
|
||||
|
||||
# Decode to physical values
|
||||
position_rad = self._uint_to_float(q_uint, -pmax, pmax, 16)
|
||||
velocity_rad_per_sec = self._uint_to_float(dq_uint, -vmax, vmax, 12)
|
||||
torque = self._uint_to_float(tau_uint, -tmax, tmax, 12)
|
||||
|
||||
return np.degrees(position_rad), np.degrees(velocity_rad_per_sec), torque, t_mos, t_rotor
|
||||
|
||||
def _process_response(self, motor: str, msg: can.Message) -> None:
|
||||
"""Decode a message and update the motor state cache."""
|
||||
try:
|
||||
motor_type = self._motor_types[motor]
|
||||
pos, vel, torque, t_mos, t_rotor = self._decode_motor_state(msg.data, motor_type)
|
||||
|
||||
self._last_known_states[motor] = {
|
||||
"position": pos,
|
||||
"velocity": vel,
|
||||
"torque": torque,
|
||||
"temp_mos": float(t_mos),
|
||||
"temp_rotor": float(t_rotor),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to decode response from {motor}: {e}")
|
||||
|
||||
def read(self, data_name: str, motor: str) -> Value:
|
||||
"""Read a value from a single motor. Positions are always in degrees."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Refresh motor to get latest state
|
||||
msg = self._refresh_motor(motor)
|
||||
if msg is None:
|
||||
motor_id = self._get_motor_id(motor)
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
raise ConnectionError(
|
||||
f"No response from motor '{motor}' (send ID: 0x{motor_id:02X}, recv ID: 0x{recv_id:02X}). "
|
||||
f"Check that: 1) Motor is powered (24V), 2) CAN wiring is correct, "
|
||||
f"3) Motor IDs are configured correctly using Damiao Debugging Tools"
|
||||
)
|
||||
|
||||
self._process_response(motor, msg)
|
||||
return self._get_cached_value(motor, data_name)
|
||||
|
||||
def _get_cached_value(self, motor: str, data_name: str) -> Value:
|
||||
"""Retrieve a specific value from the cache."""
|
||||
state = self._last_known_states[motor]
|
||||
mapping: dict[str, Any] = {
|
||||
"Present_Position": state["position"],
|
||||
"Present_Velocity": state["velocity"],
|
||||
"Present_Torque": state["torque"],
|
||||
"Temperature_MOS": state["temp_mos"],
|
||||
"Temperature_Rotor": state["temp_rotor"],
|
||||
}
|
||||
if data_name not in mapping:
|
||||
raise ValueError(f"Unknown data_name: {data_name}")
|
||||
return mapping[data_name]
|
||||
|
||||
def write(
|
||||
self,
|
||||
data_name: str,
|
||||
motor: str,
|
||||
value: Value,
|
||||
) -> None:
|
||||
"""
|
||||
Write a value to a single motor. Positions are always in degrees.
|
||||
Can write 'Goal_Position', 'Kp', or 'Kd'.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if data_name in ("Kp", "Kd"):
|
||||
self._gains[motor][data_name.lower()] = float(value)
|
||||
elif data_name == "Goal_Position":
|
||||
kp = self._gains[motor]["kp"]
|
||||
kd = self._gains[motor]["kd"]
|
||||
self._mit_control(motor, kp, kd, float(value), 0.0, 0.0)
|
||||
else:
|
||||
raise ValueError(f"Writing {data_name} not supported in MIT mode")
|
||||
|
||||
def sync_read(
|
||||
self,
|
||||
data_name: str,
|
||||
motors: str | list[str] | None = None,
|
||||
) -> dict[str, Value]:
|
||||
"""
|
||||
Read the same value from multiple motors simultaneously.
|
||||
"""
|
||||
target_motors = self._get_motors_list(motors)
|
||||
self._batch_refresh(target_motors)
|
||||
|
||||
result = {}
|
||||
for motor in target_motors:
|
||||
result[motor] = self._get_cached_value(motor, data_name)
|
||||
return result
|
||||
|
||||
def sync_read_all_states(
|
||||
self,
|
||||
motors: str | list[str] | None = None,
|
||||
*,
|
||||
num_retry: int = 0,
|
||||
) -> dict[str, MotorState]:
|
||||
"""
|
||||
Read ALL motor states (position, velocity, torque) from multiple motors in ONE refresh cycle.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping motor names to state dicts with keys: 'position', 'velocity', 'torque'
|
||||
Example: {'joint_1': {'position': 45.2, 'velocity': 1.3, 'torque': 0.5}, ...}
|
||||
"""
|
||||
target_motors = self._get_motors_list(motors)
|
||||
self._batch_refresh(target_motors)
|
||||
|
||||
result = {}
|
||||
for motor in target_motors:
|
||||
result[motor] = self._last_known_states[motor].copy()
|
||||
return result
|
||||
|
||||
def _batch_refresh(self, motors: list[str]) -> None:
|
||||
"""Internal helper to refresh a list of motors and update cache."""
|
||||
# Send refresh commands
|
||||
for motor in motors:
|
||||
motor_id = self._get_motor_id(motor)
|
||||
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
|
||||
msg = can.Message(
|
||||
arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd
|
||||
)
|
||||
self.canbus.send(msg)
|
||||
|
||||
# Collect responses
|
||||
expected_recv_ids = [self._get_motor_recv_id(m) for m in motors]
|
||||
responses = self._recv_all_responses(expected_recv_ids, timeout=MEDIUM_TIMEOUT_SEC)
|
||||
|
||||
# Update cache
|
||||
for motor in motors:
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
msg = responses.get(recv_id)
|
||||
if msg:
|
||||
self._process_response(motor, msg)
|
||||
else:
|
||||
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
|
||||
|
||||
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
|
||||
"""
|
||||
Write values to multiple motors simultaneously. Positions are always in degrees.
|
||||
"""
|
||||
if data_name in ("Kp", "Kd"):
|
||||
key = data_name.lower()
|
||||
for motor, val in values.items():
|
||||
self._gains[motor][key] = float(val)
|
||||
|
||||
elif data_name == "Goal_Position":
|
||||
# Step 1: Send all MIT control commands
|
||||
recv_id_to_motor: dict[int, str] = {}
|
||||
for motor, value_degrees in values.items():
|
||||
motor_id = self._get_motor_id(motor)
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_type = self._motor_types[motor_name]
|
||||
|
||||
kp = self._gains[motor]["kp"]
|
||||
kd = self._gains[motor]["kd"]
|
||||
|
||||
data = self._encode_mit_packet(motor_type, kp, kd, float(value_degrees), 0.0, 0.0)
|
||||
msg = can.Message(
|
||||
arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd
|
||||
)
|
||||
self.canbus.send(msg)
|
||||
precise_sleep(PRECISE_TIMEOUT_SEC)
|
||||
|
||||
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
|
||||
|
||||
# Step 2: Collect responses and update state cache
|
||||
responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=MEDIUM_TIMEOUT_SEC)
|
||||
for recv_id, motor_name in recv_id_to_motor.items():
|
||||
if msg := responses.get(recv_id):
|
||||
self._process_response(motor_name, msg)
|
||||
else:
|
||||
# Fall back to individual writes
|
||||
for motor, value in values.items():
|
||||
self.write(data_name, motor, value)
|
||||
|
||||
def read_calibration(self) -> dict[str, MotorCalibration]:
|
||||
"""Read calibration data from motors."""
|
||||
# Damiao motors don't store calibration internally
|
||||
# Return existing calibration or empty dict
|
||||
return self.calibration if self.calibration else {}
|
||||
|
||||
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
|
||||
"""Write calibration data to motors."""
|
||||
# Damiao motors don't store calibration internally
|
||||
# Just cache it in memory
|
||||
if cache:
|
||||
self.calibration = calibration_dict
|
||||
|
||||
def record_ranges_of_motion(
|
||||
self,
|
||||
motors: NameOrID | list[NameOrID] | None = None,
|
||||
display_values: bool = True,
|
||||
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
|
||||
"""
|
||||
Interactively record the min/max values of each motor in degrees.
|
||||
|
||||
Move the joints by hand (with torque disabled) while the method streams live positions.
|
||||
Press Enter to finish.
|
||||
"""
|
||||
target_motors = self._get_motors_list(motors)
|
||||
|
||||
self.disable_torque(target_motors)
|
||||
time.sleep(LONG_TIMEOUT_SEC)
|
||||
|
||||
start_positions = self.sync_read("Present_Position", target_motors)
|
||||
mins = start_positions.copy()
|
||||
maxes = start_positions.copy()
|
||||
|
||||
print("\nMove joints through their full range of motion. Press ENTER when done.")
|
||||
user_pressed_enter = False
|
||||
|
||||
while not user_pressed_enter:
|
||||
positions = self.sync_read("Present_Position", target_motors)
|
||||
|
||||
for motor in target_motors:
|
||||
if motor in positions:
|
||||
mins[motor] = min(positions[motor], mins.get(motor, positions[motor]))
|
||||
maxes[motor] = max(positions[motor], maxes.get(motor, positions[motor]))
|
||||
|
||||
if display_values:
|
||||
print("\n" + "=" * 50)
|
||||
print(f"{'MOTOR':<20} | {'MIN (deg)':>12} | {'POS (deg)':>12} | {'MAX (deg)':>12}")
|
||||
print("-" * 50)
|
||||
for motor in target_motors:
|
||||
if motor in positions:
|
||||
print(
|
||||
f"{motor:<20} | {mins[motor]:>12.1f} | {positions[motor]:>12.1f} | {maxes[motor]:>12.1f}"
|
||||
)
|
||||
|
||||
if enter_pressed():
|
||||
user_pressed_enter = True
|
||||
|
||||
if display_values and not user_pressed_enter:
|
||||
move_cursor_up(len(target_motors) + 4)
|
||||
|
||||
time.sleep(LONG_TIMEOUT_SEC)
|
||||
|
||||
self.enable_torque(target_motors)
|
||||
|
||||
for motor in target_motors:
|
||||
if (motor in mins) and (motor in maxes) and (int(abs(maxes[motor] - mins[motor])) < 5):
|
||||
raise ValueError(f"Motor {motor} has insufficient range of motion (< 5 degrees)")
|
||||
|
||||
return mins, maxes
|
||||
|
||||
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
|
||||
"""Convert motor specification to list of motor names."""
|
||||
if motors is None:
|
||||
return list(self.motors.keys())
|
||||
elif isinstance(motors, str):
|
||||
return [motors]
|
||||
elif isinstance(motors, list):
|
||||
return motors
|
||||
else:
|
||||
raise TypeError(f"Invalid motors type: {type(motors)}")
|
||||
|
||||
def _get_motor_id(self, motor: NameOrID) -> int:
|
||||
"""Get CAN ID for a motor."""
|
||||
if isinstance(motor, str):
|
||||
if motor in self.motors:
|
||||
return self.motors[motor].id
|
||||
else:
|
||||
raise ValueError(f"Unknown motor: {motor}")
|
||||
else:
|
||||
return motor
|
||||
|
||||
def _get_motor_name(self, motor: NameOrID) -> str:
|
||||
"""Get motor name from name or ID."""
|
||||
if isinstance(motor, str):
|
||||
return motor
|
||||
else:
|
||||
for name, m in self.motors.items():
|
||||
if m.id == motor:
|
||||
return name
|
||||
raise ValueError(f"Unknown motor ID: {motor}")
|
||||
|
||||
def _get_motor_recv_id(self, motor: NameOrID) -> int:
|
||||
"""Get motor recv_id from name or ID."""
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_obj = self.motors.get(motor_name)
|
||||
if motor_obj and motor_obj.recv_id is not None:
|
||||
return motor_obj.recv_id
|
||||
else:
|
||||
raise ValueError(f"Motor {motor_obj} doesn't have a valid recv_id (None).")
|
||||
|
||||
@cached_property
|
||||
def is_calibrated(self) -> bool:
|
||||
"""Check if motors are calibrated."""
|
||||
return bool(self.calibration)
|
||||
@@ -0,0 +1,209 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Configuration tables for Damiao motors."""
|
||||
|
||||
from enum import IntEnum
|
||||
|
||||
|
||||
# Motor type definitions
|
||||
class MotorType(IntEnum):
|
||||
DM3507 = 0
|
||||
DM4310 = 1
|
||||
DM4310_48V = 2
|
||||
DM4340 = 3
|
||||
DM4340_48V = 4
|
||||
DM6006 = 5
|
||||
DM8006 = 6
|
||||
DM8009 = 7
|
||||
DM10010L = 8
|
||||
DM10010 = 9
|
||||
DMH3510 = 10
|
||||
DMH6215 = 11
|
||||
DMG6220 = 12
|
||||
|
||||
|
||||
# Control modes
|
||||
class ControlMode(IntEnum):
|
||||
MIT = 1
|
||||
POS_VEL = 2
|
||||
VEL = 3
|
||||
TORQUE_POS = 4
|
||||
|
||||
|
||||
# Motor variable IDs (RID)
|
||||
class MotorVariable(IntEnum):
|
||||
UV_VALUE = 0
|
||||
KT_VALUE = 1
|
||||
OT_VALUE = 2
|
||||
OC_VALUE = 3
|
||||
ACC = 4
|
||||
DEC = 5
|
||||
MAX_SPD = 6
|
||||
MST_ID = 7
|
||||
ESC_ID = 8
|
||||
TIMEOUT = 9
|
||||
CTRL_MODE = 10
|
||||
DAMP = 11
|
||||
INERTIA = 12
|
||||
HW_VER = 13
|
||||
SW_VER = 14
|
||||
SN = 15
|
||||
NPP = 16
|
||||
RS = 17
|
||||
LS = 18
|
||||
FLUX = 19
|
||||
GR = 20
|
||||
PMAX = 21
|
||||
VMAX = 22
|
||||
TMAX = 23
|
||||
I_BW = 24
|
||||
KP_ASR = 25
|
||||
KI_ASR = 26
|
||||
KP_APR = 27
|
||||
KI_APR = 28
|
||||
OV_VALUE = 29
|
||||
GREF = 30
|
||||
DETA = 31
|
||||
V_BW = 32
|
||||
IQ_C1 = 33
|
||||
VL_C1 = 34
|
||||
CAN_BR = 35
|
||||
SUB_VER = 36
|
||||
U_OFF = 50
|
||||
V_OFF = 51
|
||||
K1 = 52
|
||||
K2 = 53
|
||||
M_OFF = 54
|
||||
DIR = 55
|
||||
P_M = 80
|
||||
XOUT = 81
|
||||
|
||||
|
||||
# Motor limit parameters [PMAX, VMAX, TMAX]
|
||||
# PMAX: Maximum position (rad)
|
||||
# VMAX: Maximum velocity (rad/s)
|
||||
# TMAX: Maximum torque (N·m)
|
||||
MOTOR_LIMIT_PARAMS = {
|
||||
MotorType.DM3507: (12.5, 30, 10),
|
||||
MotorType.DM4310: (12.5, 30, 10),
|
||||
MotorType.DM4310_48V: (12.5, 50, 10),
|
||||
MotorType.DM4340: (12.5, 8, 28),
|
||||
MotorType.DM4340_48V: (12.5, 10, 28),
|
||||
MotorType.DM6006: (12.5, 45, 20),
|
||||
MotorType.DM8006: (12.5, 45, 40),
|
||||
MotorType.DM8009: (12.5, 45, 54),
|
||||
MotorType.DM10010L: (12.5, 25, 200),
|
||||
MotorType.DM10010: (12.5, 20, 200),
|
||||
MotorType.DMH3510: (12.5, 280, 1),
|
||||
MotorType.DMH6215: (12.5, 45, 10),
|
||||
MotorType.DMG6220: (12.5, 45, 10),
|
||||
}
|
||||
|
||||
# Motor model names
|
||||
MODEL_NAMES = {
|
||||
MotorType.DM3507: "dm3507",
|
||||
MotorType.DM4310: "dm4310",
|
||||
MotorType.DM4310_48V: "dm4310_48v",
|
||||
MotorType.DM4340: "dm4340",
|
||||
MotorType.DM4340_48V: "dm4340_48v",
|
||||
MotorType.DM6006: "dm6006",
|
||||
MotorType.DM8006: "dm8006",
|
||||
MotorType.DM8009: "dm8009",
|
||||
MotorType.DM10010L: "dm10010l",
|
||||
MotorType.DM10010: "dm10010",
|
||||
MotorType.DMH3510: "dmh3510",
|
||||
MotorType.DMH6215: "dmh6215",
|
||||
MotorType.DMG6220: "dmg6220",
|
||||
}
|
||||
|
||||
# Motor resolution table (encoder counts per revolution)
|
||||
MODEL_RESOLUTION = {
|
||||
"dm3507": 65536,
|
||||
"dm4310": 65536,
|
||||
"dm4310_48v": 65536,
|
||||
"dm4340": 65536,
|
||||
"dm4340_48v": 65536,
|
||||
"dm6006": 65536,
|
||||
"dm8006": 65536,
|
||||
"dm8009": 65536,
|
||||
"dm10010l": 65536,
|
||||
"dm10010": 65536,
|
||||
"dmh3510": 65536,
|
||||
"dmh6215": 65536,
|
||||
"dmg6220": 65536,
|
||||
}
|
||||
|
||||
# CAN baudrates supported by Damiao motors
|
||||
AVAILABLE_BAUDRATES = [
|
||||
125000, # 0: 125 kbps
|
||||
200000, # 1: 200 kbps
|
||||
250000, # 2: 250 kbps
|
||||
500000, # 3: 500 kbps
|
||||
1000000, # 4: 1 mbps (default for OpenArms)
|
||||
2000000, # 5: 2 mbps
|
||||
2500000, # 6: 2.5 mbps
|
||||
3200000, # 7: 3.2 mbps
|
||||
4000000, # 8: 4 mbps
|
||||
5000000, # 9: 5 mbps
|
||||
]
|
||||
DEFAULT_BAUDRATE = 1000000 # 1 Mbps is standard for OpenArms
|
||||
|
||||
# Default timeout in milliseconds
|
||||
DEFAULT_TIMEOUT_MS = 1000
|
||||
|
||||
# OpenArms specific configurations
|
||||
# Based on: https://docs.openarm.dev/software/setup/configure-test
|
||||
# OpenArms has 7 DOF per arm (14 total for dual arm)
|
||||
OPENARMS_ARM_MOTOR_IDS = {
|
||||
"joint_1": {"send": 0x01, "recv": 0x11}, # J1 - Shoulder pan
|
||||
"joint_2": {"send": 0x02, "recv": 0x12}, # J2 - Shoulder lift
|
||||
"joint_3": {"send": 0x03, "recv": 0x13}, # J3 - Elbow flex
|
||||
"joint_4": {"send": 0x04, "recv": 0x14}, # J4 - Wrist flex
|
||||
"joint_5": {"send": 0x05, "recv": 0x15}, # J5 - Wrist roll
|
||||
"joint_6": {"send": 0x06, "recv": 0x16}, # J6 - Wrist pitch
|
||||
"joint_7": {"send": 0x07, "recv": 0x17}, # J7 - Wrist rotation
|
||||
}
|
||||
|
||||
OPENARMS_GRIPPER_MOTOR_IDS = {
|
||||
"gripper": {"send": 0x08, "recv": 0x18}, # J8 - Gripper
|
||||
}
|
||||
|
||||
# Default motor types for OpenArms
|
||||
OPENARMS_DEFAULT_MOTOR_TYPES = {
|
||||
"joint_1": MotorType.DM8009, # Shoulder pan - high torque
|
||||
"joint_2": MotorType.DM8009, # Shoulder lift - high torque
|
||||
"joint_3": MotorType.DM4340, # Shoulder rotation
|
||||
"joint_4": MotorType.DM4340, # Elbow flex
|
||||
"joint_5": MotorType.DM4310, # Wrist roll
|
||||
"joint_6": MotorType.DM4310, # Wrist pitch
|
||||
"joint_7": MotorType.DM4310, # Wrist rotation
|
||||
"gripper": MotorType.DM4310, # Gripper
|
||||
}
|
||||
|
||||
# MIT control parameter ranges
|
||||
MIT_KP_RANGE = (0.0, 500.0)
|
||||
MIT_KD_RANGE = (0.0, 5.0)
|
||||
|
||||
# CAN frame command IDs
|
||||
CAN_CMD_ENABLE = 0xFC
|
||||
CAN_CMD_DISABLE = 0xFD
|
||||
CAN_CMD_SET_ZERO = 0xFE
|
||||
CAN_CMD_REFRESH = 0xCC
|
||||
CAN_CMD_QUERY_PARAM = 0x33
|
||||
CAN_CMD_WRITE_PARAM = 0x55
|
||||
CAN_CMD_SAVE_PARAM = 0xAA
|
||||
|
||||
# CAN ID for parameter operations
|
||||
CAN_PARAM_ID = 0x7FF
|
||||
@@ -22,9 +22,8 @@ import logging
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
|
||||
from lerobot.motors.encoding_utils import decode_twos_complement, encode_twos_complement
|
||||
|
||||
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
|
||||
from ..encoding_utils import decode_twos_complement, encode_twos_complement
|
||||
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
|
||||
from .tables import (
|
||||
AVAILABLE_BAUDRATES,
|
||||
MODEL_BAUDRATE_TABLE,
|
||||
@@ -100,7 +99,7 @@ def _split_into_byte_chunks(value: int, length: int) -> list[int]:
|
||||
return data
|
||||
|
||||
|
||||
class DynamixelMotorsBus(MotorsBus):
|
||||
class DynamixelMotorsBus(SerialMotorsBus):
|
||||
"""
|
||||
The Dynamixel implementation for a MotorsBus. It relies on the python dynamixel sdk to communicate with
|
||||
the motors. For more info, see the Dynamixel SDK Documentation:
|
||||
@@ -203,9 +202,9 @@ class DynamixelMotorsBus(MotorsBus):
|
||||
for motor in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
|
||||
def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None:
|
||||
def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None:
|
||||
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
|
||||
self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for motor in self._get_motors_list(motors):
|
||||
|
||||
@@ -17,9 +17,8 @@ from copy import deepcopy
|
||||
from enum import Enum
|
||||
from pprint import pformat
|
||||
|
||||
from lerobot.motors.encoding_utils import decode_sign_magnitude, encode_sign_magnitude
|
||||
|
||||
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
|
||||
from ..encoding_utils import decode_sign_magnitude, encode_sign_magnitude
|
||||
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
|
||||
from .tables import (
|
||||
FIRMWARE_MAJOR_VERSION,
|
||||
FIRMWARE_MINOR_VERSION,
|
||||
@@ -96,7 +95,7 @@ def patch_setPacketTimeout(self, packet_length): # noqa: N802
|
||||
self.packet_timeout = (self.tx_time_per_byte * packet_length) + (self.tx_time_per_byte * 3.0) + 50
|
||||
|
||||
|
||||
class FeetechMotorsBus(MotorsBus):
|
||||
class FeetechMotorsBus(SerialMotorsBus):
|
||||
"""
|
||||
The FeetechMotorsBus class allows to efficiently read and write to the attached motors. It relies on the
|
||||
python feetech sdk to communicate with the motors, which is itself based on the dynamixel sdk.
|
||||
@@ -298,11 +297,11 @@ class FeetechMotorsBus(MotorsBus):
|
||||
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
self.write("Lock", motor, 0, num_retry=num_retry)
|
||||
|
||||
def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None:
|
||||
def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None:
|
||||
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
|
||||
self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
addr, length = get_address(self.model_ctrl_table, model, "Lock")
|
||||
self._write(addr, length, motor_id, 0, num_retry=num_retry)
|
||||
self._write(addr, length, motor, 0, num_retry=num_retry)
|
||||
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for motor in self._get_motors_list(motors):
|
||||
|
||||
@@ -205,6 +205,7 @@ MODEL_BAUDRATE_TABLE = {
|
||||
|
||||
# Sign-Magnitude encoding bits
|
||||
STS_SMS_SERIES_ENCODINGS_TABLE = {
|
||||
"Present_Load": 10,
|
||||
"Homing_Offset": 11,
|
||||
"Goal_Position": 15,
|
||||
"Goal_Velocity": 15,
|
||||
|
||||
@@ -19,6 +19,8 @@
|
||||
# TODO(aliberts): Add block noqa when feature below is available
|
||||
# https://github.com/astral-sh/ruff/issues/3711
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
@@ -32,7 +34,7 @@ import serial
|
||||
from deepdiff import DeepDiff
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
||||
|
||||
NameOrID: TypeAlias = str | int
|
||||
@@ -41,6 +43,81 @@ Value: TypeAlias = int | float
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MotorsBusBase(abc.ABC):
|
||||
"""
|
||||
Base class for all motor bus implementations.
|
||||
|
||||
This is a minimal interface that all motor buses must implement, regardless of their
|
||||
communication protocol (serial, CAN, etc.).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
port: str,
|
||||
motors: dict[str, Motor],
|
||||
calibration: dict[str, MotorCalibration] | None = None,
|
||||
):
|
||||
self.port = port
|
||||
self.motors = motors
|
||||
self.calibration = calibration if calibration else {}
|
||||
|
||||
@abc.abstractmethod
|
||||
def connect(self, handshake: bool = True) -> None:
|
||||
"""Establish connection to the motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def disconnect(self, disable_torque: bool = True) -> None:
|
||||
"""Disconnect from the motors."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if connected to the motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def read(self, data_name: str, motor: str) -> Value:
|
||||
"""Read a value from a single motor."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def write(self, data_name: str, motor: str, value: Value) -> None:
|
||||
"""Write a value to a single motor."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def sync_read(self, data_name: str, motors: str | list[str] | None = None) -> dict[str, Value]:
|
||||
"""Read a value from multiple motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
|
||||
"""Write values to multiple motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
"""Enable torque on selected motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
"""Disable torque on selected motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def read_calibration(self) -> dict[str, MotorCalibration]:
|
||||
"""Read calibration parameters from the motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
|
||||
"""Write calibration parameters to the motors."""
|
||||
pass
|
||||
|
||||
|
||||
def get_ctrl_table(model_ctrl_table: dict[str, dict], model: str) -> dict[str, tuple[int, int]]:
|
||||
ctrl_table = model_ctrl_table.get(model)
|
||||
if ctrl_table is None:
|
||||
@@ -97,6 +174,8 @@ class Motor:
|
||||
id: int
|
||||
model: str
|
||||
norm_mode: MotorNormMode
|
||||
motor_type_str: str | None = None
|
||||
recv_id: int | None = None
|
||||
|
||||
|
||||
class PortHandler(Protocol):
|
||||
@@ -203,15 +282,15 @@ class GroupSyncWrite(Protocol):
|
||||
def txPacket(self): ...
|
||||
|
||||
|
||||
class MotorsBus(abc.ABC):
|
||||
class SerialMotorsBus(MotorsBusBase):
|
||||
"""
|
||||
A MotorsBus allows to efficiently read and write to the attached motors.
|
||||
A SerialMotorsBus allows to efficiently read and write to motors connected via serial communication.
|
||||
It represents several motors daisy-chained together and connected through a serial port.
|
||||
There are currently two implementations of this abstract class:
|
||||
There are currently two implementations of this class:
|
||||
- DynamixelMotorsBus
|
||||
- FeetechMotorsBus
|
||||
|
||||
Note: This class may evolve in the future should we add support for other types of bus.
|
||||
This class is specifically for serial-based motor protocols (Dynamixel, Feetech, etc.).
|
||||
|
||||
A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
|
||||
To find the port, you can run our utility script:
|
||||
@@ -260,9 +339,7 @@ class MotorsBus(abc.ABC):
|
||||
motors: dict[str, Motor],
|
||||
calibration: dict[str, MotorCalibration] | None = None,
|
||||
):
|
||||
self.port = port
|
||||
self.motors = motors
|
||||
self.calibration = calibration if calibration else {}
|
||||
super().__init__(port, motors, calibration)
|
||||
|
||||
self.port_handler: PortHandler
|
||||
self.packet_handler: PacketHandler
|
||||
@@ -411,6 +488,7 @@ class MotorsBus(abc.ABC):
|
||||
"""bool: `True` if the underlying serial port is open."""
|
||||
return self.port_handler.is_open
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, handshake: bool = True) -> None:
|
||||
"""Open the serial port and initialise communication.
|
||||
|
||||
@@ -422,10 +500,6 @@ class MotorsBus(abc.ABC):
|
||||
DeviceAlreadyConnectedError: The port is already open.
|
||||
ConnectionError: The underlying SDK failed to open the port or the handshake did not succeed.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is already connected. Do not call `{self.__class__.__name__}.connect()` twice."
|
||||
)
|
||||
|
||||
self._connect(handshake)
|
||||
self.set_timeout()
|
||||
@@ -447,6 +521,7 @@ class MotorsBus(abc.ABC):
|
||||
def _handshake(self) -> None:
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self, disable_torque: bool = True) -> None:
|
||||
"""Close the serial port (optionally disabling torque first).
|
||||
|
||||
@@ -455,10 +530,6 @@ class MotorsBus(abc.ABC):
|
||||
closing the port. This can prevent damaging motors if they are left applying resisting torque
|
||||
after disconnect.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. Try running `{self.__class__.__name__}.connect()` first."
|
||||
)
|
||||
|
||||
if disable_torque:
|
||||
self.port_handler.clearPort()
|
||||
@@ -538,7 +609,7 @@ class MotorsBus(abc.ABC):
|
||||
self.set_baudrate(self.default_baudrate)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _find_single_motor(self, motor: str, initial_baudrate: int | None) -> tuple[int, int]:
|
||||
def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -551,13 +622,13 @@ class MotorsBus(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
"""Disable torque on selected motors.
|
||||
|
||||
Disabling Torque allows to write to the motors' permanent memory area (EPROM/EEPROM).
|
||||
|
||||
Args:
|
||||
motors (int | str | list[str] | None, optional): Target motors. Accepts a motor name, an ID, a
|
||||
motors ( str | list[str] | None, optional): Target motors. Accepts a motor name, an ID, a
|
||||
list of names or `None` to affect every registered motor. Defaults to `None`.
|
||||
num_retry (int, optional): Number of additional retry attempts on communication failure.
|
||||
Defaults to 0.
|
||||
@@ -907,6 +978,7 @@ class MotorsBus(abc.ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
def read(
|
||||
self,
|
||||
data_name: str,
|
||||
@@ -927,10 +999,6 @@ class MotorsBus(abc.ABC):
|
||||
Returns:
|
||||
Value: Raw or normalised value depending on *normalize*.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||
)
|
||||
|
||||
id_ = self.motors[motor].id
|
||||
model = self.motors[motor].model
|
||||
@@ -981,6 +1049,7 @@ class MotorsBus(abc.ABC):
|
||||
|
||||
return value, comm, error
|
||||
|
||||
@check_if_not_connected
|
||||
def write(
|
||||
self, data_name: str, motor: str, value: Value, *, normalize: bool = True, num_retry: int = 0
|
||||
) -> None:
|
||||
@@ -999,10 +1068,6 @@ class MotorsBus(abc.ABC):
|
||||
normalize (bool, optional): Enable or disable normalisation. Defaults to `True`.
|
||||
num_retry (int, optional): Retry attempts. Defaults to `0`.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||
)
|
||||
|
||||
id_ = self.motors[motor].id
|
||||
model = self.motors[motor].model
|
||||
@@ -1044,6 +1109,7 @@ class MotorsBus(abc.ABC):
|
||||
|
||||
return comm, error
|
||||
|
||||
@check_if_not_connected
|
||||
def sync_read(
|
||||
self,
|
||||
data_name: str,
|
||||
@@ -1063,10 +1129,6 @@ class MotorsBus(abc.ABC):
|
||||
Returns:
|
||||
dict[str, Value]: Mapping *motor name → value*.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||
)
|
||||
|
||||
self._assert_protocol_is_compatible("sync_read")
|
||||
|
||||
@@ -1139,6 +1201,7 @@ class MotorsBus(abc.ABC):
|
||||
# for id_ in motor_ids:
|
||||
# value = self.sync_reader.getData(id_, address, length)
|
||||
|
||||
@check_if_not_connected
|
||||
def sync_write(
|
||||
self,
|
||||
data_name: str,
|
||||
@@ -1160,10 +1223,6 @@ class MotorsBus(abc.ABC):
|
||||
normalize (bool, optional): If `True` (default) convert values from the user range to raw units.
|
||||
num_retry (int, optional): Retry attempts. Defaults to `0`.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||
)
|
||||
|
||||
ids_values = self._get_ids_values_dict(values)
|
||||
models = [self._id_to_model(id_) for id_ in ids_values]
|
||||
@@ -1212,3 +1271,7 @@ class MotorsBus(abc.ABC):
|
||||
for id_, value in ids_values.items():
|
||||
data = self._serialize_data(value, length)
|
||||
self.sync_writer.addParam(id_, data)
|
||||
|
||||
|
||||
# Backward compatibility alias
|
||||
MotorsBus: TypeAlias = SerialMotorsBus
|
||||
|
||||
@@ -32,16 +32,22 @@ Notes:
|
||||
from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import os
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TypeVar
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||
from lerobot.policies.groot.groot_n1 import GR00TN15
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||
|
||||
T = TypeVar("T", bound="GrootPolicy")
|
||||
|
||||
|
||||
class GrootPolicy(PreTrainedPolicy):
|
||||
@@ -90,6 +96,129 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
"""Reset policy state when environment resets."""
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: builtins.type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
config: GrootConfig | None = None,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
strict: bool = True,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
"""Load Groot policy from pretrained model.
|
||||
|
||||
Handles two cases:
|
||||
1. Base GR00T models (e.g., 'nvidia/GR00T-N1.5-3B') - loads the raw model
|
||||
2. Fine-tuned LeRobot checkpoints - loads config and weights from safetensors
|
||||
|
||||
Args:
|
||||
pretrained_name_or_path: Path to the GR00T model or fine-tuned checkpoint
|
||||
config: Optional GrootConfig. If None, loads from checkpoint or creates default
|
||||
force_download: Force download even if cached
|
||||
resume_download: Resume interrupted download
|
||||
proxies: Proxy settings
|
||||
token: HuggingFace authentication token
|
||||
cache_dir: Cache directory path
|
||||
local_files_only: Only use local files
|
||||
revision: Specific model revision
|
||||
strict: Strict state dict loading
|
||||
**kwargs: Additional arguments (passed to config)
|
||||
|
||||
Returns:
|
||||
Initialized GrootPolicy instance with loaded model
|
||||
"""
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
print(
|
||||
"The Groot policy is a wrapper around Nvidia's GR00T N1.5 model.\n"
|
||||
f"Loading pretrained model from: {pretrained_name_or_path}"
|
||||
)
|
||||
|
||||
model_id = str(pretrained_name_or_path)
|
||||
is_finetuned_checkpoint = False
|
||||
|
||||
# Check if this is a fine-tuned LeRobot checkpoint (has model.safetensors)
|
||||
try:
|
||||
if os.path.isdir(model_id):
|
||||
is_finetuned_checkpoint = os.path.exists(os.path.join(model_id, SAFETENSORS_SINGLE_FILE))
|
||||
else:
|
||||
# Try to download the safetensors file to check if it exists
|
||||
try:
|
||||
hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=SAFETENSORS_SINGLE_FILE,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=False, # Just check, don't force download
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
is_finetuned_checkpoint = True
|
||||
except HfHubHTTPError:
|
||||
is_finetuned_checkpoint = False
|
||||
except Exception:
|
||||
is_finetuned_checkpoint = False
|
||||
|
||||
if is_finetuned_checkpoint:
|
||||
# This is a fine-tuned LeRobot checkpoint - use parent class loading
|
||||
print("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
|
||||
return super().from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
config=config,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
strict=strict,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# This is a base GR00T model - load it fresh
|
||||
print("Detected base GR00T model, loading from HuggingFace...")
|
||||
|
||||
if config is None:
|
||||
# Create default config with the pretrained path
|
||||
config = GrootConfig(base_model_path=str(pretrained_name_or_path))
|
||||
|
||||
# Add minimal visual feature required for validation
|
||||
# validate_features() will automatically add state and action features
|
||||
# These are placeholders - actual robot features come from the preprocessor
|
||||
if not config.input_features:
|
||||
config.input_features = {
|
||||
f"{OBS_IMAGES}.camera": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 224, 224), # Default image size from config
|
||||
),
|
||||
}
|
||||
else:
|
||||
# Override the base_model_path with the provided path
|
||||
config.base_model_path = str(pretrained_name_or_path)
|
||||
|
||||
# Pass through any additional config overrides from kwargs
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(config, key):
|
||||
setattr(config, key, value)
|
||||
|
||||
# Create a fresh policy instance - this will automatically load the GR00T model
|
||||
# in __init__ via _create_groot_model()
|
||||
policy = cls(config)
|
||||
|
||||
policy.eval()
|
||||
return policy
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
|
||||
@@ -1297,3 +1297,14 @@ class PI0Policy(PreTrainedPolicy):
|
||||
loss = losses.mean()
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
|
||||
def _get_default_peft_targets(self) -> dict[str, any]:
|
||||
"""Return default PEFT target modules for PI0 fine-tuning."""
|
||||
common_projections = (
|
||||
"state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
|
||||
)
|
||||
target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))"
|
||||
return {
|
||||
"target_modules": target_modules,
|
||||
"modules_to_save": [],
|
||||
}
|
||||
|
||||
@@ -61,8 +61,6 @@ class PI05Config(PreTrainedConfig):
|
||||
# Add empty images. Used to add empty cameras when no image features are present.
|
||||
empty_cameras: int = 0
|
||||
|
||||
tokenizer_max_length: int = 200 # see openpi `__post_init__`
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
|
||||
@@ -1270,3 +1270,14 @@ class PI05Policy(PreTrainedPolicy):
|
||||
loss = losses.mean()
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
|
||||
def _get_default_peft_targets(self) -> dict[str, any]:
|
||||
"""Return default PEFT target modules for PI0.5 fine-tuning."""
|
||||
common_projections = (
|
||||
"state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
|
||||
)
|
||||
target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))"
|
||||
return {
|
||||
"target_modules": target_modules,
|
||||
"modules_to_save": [],
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
import abc
|
||||
import builtins
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
from importlib.resources import files
|
||||
@@ -265,3 +266,166 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
card = ModelCard.from_template(card_data, template_str=template_card)
|
||||
card.validate()
|
||||
return card
|
||||
|
||||
def wrap_with_peft(
|
||||
self,
|
||||
peft_config=None,
|
||||
peft_cli_overrides: dict | None = None,
|
||||
) -> "PreTrainedPolicy":
|
||||
"""
|
||||
Wrap this policy with PEFT adapters for parameter-efficient fine-tuning.
|
||||
|
||||
This method is the single entry point for PEFT integration. Subclasses should
|
||||
override `_get_default_peft_targets()` to provide default target modules, and
|
||||
`_validate_peft_config()` for policy-specific validation.
|
||||
|
||||
Args:
|
||||
peft_config: Optional PEFT adapter configuration (e.g., LoraConfig).
|
||||
If provided, used directly (with CLI overrides applied).
|
||||
peft_cli_overrides: Optional dict of CLI overrides (method_type, target_modules, r, etc.)
|
||||
These are merged with policy defaults to build the final config.
|
||||
"""
|
||||
from peft import get_peft_model
|
||||
|
||||
# If user provided a complete config, use it directly (with overrides)
|
||||
if peft_config is not None:
|
||||
final_config = peft_config
|
||||
if peft_cli_overrides:
|
||||
final_config = self._apply_peft_cli_overrides(final_config, peft_cli_overrides)
|
||||
else:
|
||||
# Build config from defaults + CLI overrides
|
||||
final_config = self._build_peft_config(peft_cli_overrides or {})
|
||||
|
||||
# Validate the configuration
|
||||
self._validate_peft_config(final_config)
|
||||
|
||||
# Freeze base parameters, only adapter params will be trained
|
||||
for p in self.parameters():
|
||||
p.requires_grad_(False)
|
||||
|
||||
# Store pretrained path for PEFT's base_model_name_or_path
|
||||
if self.config.pretrained_path:
|
||||
self.name_or_path = str(self.config.pretrained_path)
|
||||
|
||||
# Wrap with PEFT
|
||||
peft_model = get_peft_model(self, final_config)
|
||||
|
||||
# Mark config as using PEFT for proper loading later
|
||||
peft_model.config.use_peft = True
|
||||
|
||||
logging.info(f"Wrapped {self.name} with PEFT ({type(final_config).__name__})")
|
||||
return peft_model
|
||||
|
||||
def _get_default_peft_targets(self) -> dict[str, any] | None:
|
||||
"""
|
||||
Return default PEFT target modules for this policy.
|
||||
|
||||
Override this in subclasses to provide policy-specific defaults. These defaults
|
||||
are PEFT-method agnostic - they only specify which modules to target.
|
||||
|
||||
"""
|
||||
return None
|
||||
|
||||
def _validate_peft_config(self, peft_config) -> None:
|
||||
"""
|
||||
Validate the PEFT configuration for this policy.
|
||||
|
||||
Override this in subclasses to add policy-specific validation or warnings.
|
||||
The default implementation checks that a pretrained_path exists.
|
||||
|
||||
Args:
|
||||
peft_config: The PEFT configuration to validate.
|
||||
|
||||
Raises:
|
||||
ValueError: If the configuration is invalid.
|
||||
"""
|
||||
if not self.config.pretrained_path:
|
||||
raise ValueError(
|
||||
"Training from scratch using PEFT is unlikely to yield good results. "
|
||||
"Supply a `policy.pretrained_path` to fine-tune an existing model."
|
||||
)
|
||||
|
||||
def _preprocess_peft_cli_overrides(self, cli_overrides: dict, peft_method_type) -> dict:
|
||||
"""
|
||||
Preprocess CLI overrides: rename keys and handle method-specific init_type.
|
||||
|
||||
Args:
|
||||
cli_overrides: Dict of CLI options (will be copied, not mutated).
|
||||
peft_method_type: The PeftType enum value for the PEFT method.
|
||||
|
||||
Returns:
|
||||
Preprocessed dict with renamed keys and init_type mapped to method-specific key.
|
||||
"""
|
||||
from peft import PeftType
|
||||
|
||||
cli_overrides = cli_overrides.copy()
|
||||
|
||||
# Handle the full_training_modules -> modules_to_save rename
|
||||
if "full_training_modules" in cli_overrides:
|
||||
cli_overrides["modules_to_save"] = cli_overrides.pop("full_training_modules")
|
||||
|
||||
# Remove method_type as it's handled separately
|
||||
cli_overrides.pop("method_type", None)
|
||||
|
||||
# Handle init_type specially based on PEFT method
|
||||
init_type = cli_overrides.pop("init_type", None)
|
||||
if init_type is not None:
|
||||
if peft_method_type == PeftType.LORA:
|
||||
cli_overrides["init_lora_weights"] = init_type
|
||||
elif peft_method_type == PeftType.MISS:
|
||||
cli_overrides["init_weights"] = init_type
|
||||
else:
|
||||
raise ValueError(f"Init type '{init_type}' unknown for PEFT method {peft_method_type}.")
|
||||
|
||||
return cli_overrides
|
||||
|
||||
def _build_peft_config(self, cli_overrides: dict):
|
||||
"""Build a PEFT config from policy defaults and CLI overrides."""
|
||||
from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType
|
||||
|
||||
# Determine PEFT method type (default to LORA)
|
||||
method_type_str = cli_overrides.get("method_type") or "lora"
|
||||
peft_method_type = PeftType[method_type_str.upper()]
|
||||
peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type]
|
||||
|
||||
# Preprocess CLI overrides
|
||||
cli_overrides = self._preprocess_peft_cli_overrides(cli_overrides, peft_method_type)
|
||||
|
||||
# Start with policy defaults, apply CLI overrides
|
||||
config_dict = dict(self._get_default_peft_targets() or {})
|
||||
for key, value in cli_overrides.items():
|
||||
if value is not None:
|
||||
config_dict[key] = value
|
||||
|
||||
# Ensure we have target_modules
|
||||
if not config_dict.get("target_modules"):
|
||||
raise ValueError(
|
||||
f"Policy '{self.name}' does not define default target_modules. "
|
||||
"Please pass --peft.target_modules explicitly."
|
||||
)
|
||||
|
||||
return peft_config_cls(**config_dict)
|
||||
|
||||
def _apply_peft_cli_overrides(self, peft_config, cli_overrides: dict):
|
||||
"""Apply CLI overrides to an existing PEFT config."""
|
||||
from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType
|
||||
|
||||
# Get method type from existing config or CLI override
|
||||
method_type_str = cli_overrides.get("method_type")
|
||||
if method_type_str:
|
||||
peft_method_type = PeftType[method_type_str.upper()]
|
||||
peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type]
|
||||
else:
|
||||
peft_method_type = PeftType(peft_config.peft_type)
|
||||
peft_config_cls = type(peft_config)
|
||||
|
||||
# Preprocess CLI overrides
|
||||
cli_overrides = self._preprocess_peft_cli_overrides(cli_overrides, peft_method_type)
|
||||
|
||||
# Start with existing config, apply CLI overrides
|
||||
config_dict = {k: v for k, v in dataclasses.asdict(peft_config).items() if not k.startswith("_")}
|
||||
for key, value in cli_overrides.items():
|
||||
if value is not None:
|
||||
config_dict[key] = value
|
||||
|
||||
return peft_config_cls(**config_dict)
|
||||
|
||||
@@ -239,8 +239,10 @@ class SACPolicy(
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
def update_temperature(self):
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
@property
|
||||
def temperature(self) -> float:
|
||||
"""Return the current temperature value, always in sync with log_alpha."""
|
||||
return self.log_alpha.exp().item()
|
||||
|
||||
def compute_loss_critic(
|
||||
self,
|
||||
@@ -457,11 +459,10 @@ class SACPolicy(
|
||||
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
|
||||
self.target_entropy = -np.prod(dim) / 2
|
||||
|
||||
def _init_temperature(self):
|
||||
"""Set up temperature parameter and initial log_alpha."""
|
||||
def _init_temperature(self) -> None:
|
||||
"""Set up temperature parameter (log_alpha)."""
|
||||
temp_init = self.config.temperature_init
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
|
||||
|
||||
class SACObservationEncoder(nn.Module):
|
||||
|
||||
@@ -480,6 +480,28 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
|
||||
return actions
|
||||
|
||||
def _get_default_peft_targets(self) -> dict[str, any]:
|
||||
"""Return default PEFT target modules for SmolVLA fine-tuning."""
|
||||
common_projections = (
|
||||
"state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
|
||||
)
|
||||
target_modules = rf"(model\.vlm_with_expert\.lm_expert\..*\.(q|v)_proj|model\.({common_projections}))"
|
||||
return {
|
||||
"target_modules": target_modules,
|
||||
"modules_to_save": [],
|
||||
}
|
||||
|
||||
def _validate_peft_config(self, peft_config) -> None:
|
||||
"""Validate PEFT configuration for SmolVLA."""
|
||||
super()._validate_peft_config(peft_config)
|
||||
if not self.config.load_vlm_weights:
|
||||
import logging
|
||||
|
||||
logging.warning(
|
||||
"Training SmolVLA from scratch using PEFT. This is unlikely to yield good results. "
|
||||
"Set `load_vlm_weights=True` to fine-tune the existing policy."
|
||||
)
|
||||
|
||||
|
||||
def pad_tensor(tensor, max_len, pad_value=0):
|
||||
"""
|
||||
|
||||
@@ -168,11 +168,12 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
|
||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
||||
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
|
||||
|
||||
return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key}
|
||||
return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key}
|
||||
|
||||
|
||||
def create_transition(
|
||||
|
||||
@@ -18,16 +18,18 @@
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol, TypeVar, runtime_checkable
|
||||
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, runtime_checkable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as F # noqa: N812
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
|
||||
from .core import EnvTransition, PolicyAction, TransitionKey
|
||||
from .pipeline import (
|
||||
ComplementaryDataProcessorStep,
|
||||
@@ -69,10 +71,10 @@ class HasTeleopEvents(Protocol):
|
||||
|
||||
|
||||
# Type variable constrained to Teleoperator subclasses that also implement events
|
||||
TeleopWithEvents = TypeVar("TeleopWithEvents", bound=Teleoperator)
|
||||
TeleopWithEvents = TypeVar("TeleopWithEvents", bound="Teleoperator")
|
||||
|
||||
|
||||
def _check_teleop_with_events(teleop: Teleoperator) -> None:
|
||||
def _check_teleop_with_events(teleop: "Teleoperator") -> None:
|
||||
"""
|
||||
Runtime check that a teleoperator implements the `HasTeleopEvents` protocol.
|
||||
|
||||
@@ -103,7 +105,7 @@ class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
|
||||
teleop_device: The teleoperator instance to get the action from.
|
||||
"""
|
||||
|
||||
teleop_device: Teleoperator
|
||||
teleop_device: "Teleoperator"
|
||||
|
||||
def complementary_data(self, complementary_data: dict) -> dict:
|
||||
"""
|
||||
|
||||
@@ -34,6 +34,8 @@ from lerobot.utils.constants import (
|
||||
ACTION_TOKEN_MASK,
|
||||
ACTION_TOKENS,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_SUBTASK_TOKENS,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
)
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
@@ -139,6 +141,32 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
|
||||
return None
|
||||
|
||||
def get_subtask(self, transition: EnvTransition) -> list[str] | None:
|
||||
"""
|
||||
Extracts the subtask from the transition's complementary data.
|
||||
|
||||
Args:
|
||||
transition: The environment transition.
|
||||
|
||||
Returns:
|
||||
A list of subtask strings, or None if the subtask key is not found or the value is None.
|
||||
"""
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data is None:
|
||||
return None
|
||||
|
||||
subtask = complementary_data.get("subtask")
|
||||
if subtask is None:
|
||||
return None
|
||||
|
||||
# Standardize to a list of strings for the tokenizer
|
||||
if isinstance(subtask, str):
|
||||
return [subtask]
|
||||
elif isinstance(subtask, list) and all(isinstance(t, str) for t in subtask):
|
||||
return subtask
|
||||
|
||||
return None
|
||||
|
||||
def observation(self, observation: RobotObservation) -> RobotObservation:
|
||||
"""
|
||||
Tokenizes the task description and adds it to the observation dictionary.
|
||||
@@ -176,6 +204,24 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
# Tokenize subtask if available
|
||||
subtask = self.get_subtask(self.transition)
|
||||
if subtask is not None:
|
||||
tokenized_subtask = self._tokenize_text(subtask)
|
||||
|
||||
# Move new tokenized tensors to the detected device
|
||||
if target_device is not None:
|
||||
tokenized_subtask = {
|
||||
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in tokenized_subtask.items()
|
||||
}
|
||||
|
||||
# Add tokenized subtask to the observation
|
||||
new_observation[OBS_LANGUAGE_SUBTASK_TOKENS] = tokenized_subtask["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] = tokenized_subtask["attention_mask"].to(
|
||||
dtype=torch.bool
|
||||
)
|
||||
|
||||
return new_observation
|
||||
|
||||
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
|
||||
|
||||
@@ -545,9 +545,6 @@ def add_actor_information_and_train(
|
||||
training_infos["temperature_grad_norm"] = temp_grad_norm
|
||||
training_infos["temperature"] = policy.temperature
|
||||
|
||||
# Update temperature
|
||||
policy.update_temperature()
|
||||
|
||||
# Push policy to actors if needed
|
||||
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
|
||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .bi_openarm_follower import BiOpenArmFollower
|
||||
from .config_bi_openarm_follower import BiOpenArmFollowerConfig
|
||||
|
||||
__all__ = ["BiOpenArmFollower", "BiOpenArmFollowerConfig"]
|
||||
@@ -0,0 +1,175 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_bi_openarm_follower import BiOpenArmFollowerConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BiOpenArmFollower(Robot):
|
||||
"""
|
||||
Bimanual OpenArm Follower Arms
|
||||
"""
|
||||
|
||||
config_class = BiOpenArmFollowerConfig
|
||||
name = "bi_openarm_follower"
|
||||
|
||||
def __init__(self, config: BiOpenArmFollowerConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
left_arm_config = OpenArmFollowerConfig(
|
||||
id=f"{config.id}_left" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.left_arm_config.port,
|
||||
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
||||
max_relative_target=config.left_arm_config.max_relative_target,
|
||||
cameras=config.left_arm_config.cameras,
|
||||
side=config.left_arm_config.side,
|
||||
can_interface=config.left_arm_config.can_interface,
|
||||
use_can_fd=config.left_arm_config.use_can_fd,
|
||||
can_bitrate=config.left_arm_config.can_bitrate,
|
||||
can_data_bitrate=config.left_arm_config.can_data_bitrate,
|
||||
motor_config=config.left_arm_config.motor_config,
|
||||
position_kd=config.left_arm_config.position_kd,
|
||||
position_kp=config.left_arm_config.position_kp,
|
||||
joint_limits=config.left_arm_config.joint_limits,
|
||||
)
|
||||
|
||||
right_arm_config = OpenArmFollowerConfig(
|
||||
id=f"{config.id}_right" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.right_arm_config.port,
|
||||
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
|
||||
max_relative_target=config.right_arm_config.max_relative_target,
|
||||
cameras=config.right_arm_config.cameras,
|
||||
side=config.right_arm_config.side,
|
||||
can_interface=config.right_arm_config.can_interface,
|
||||
use_can_fd=config.right_arm_config.use_can_fd,
|
||||
can_bitrate=config.right_arm_config.can_bitrate,
|
||||
can_data_bitrate=config.right_arm_config.can_data_bitrate,
|
||||
motor_config=config.right_arm_config.motor_config,
|
||||
position_kd=config.right_arm_config.position_kd,
|
||||
position_kp=config.right_arm_config.position_kp,
|
||||
joint_limits=config.right_arm_config.joint_limits,
|
||||
)
|
||||
|
||||
self.left_arm = OpenArmFollower(left_arm_config)
|
||||
self.right_arm = OpenArmFollower(right_arm_config)
|
||||
|
||||
# Only for compatibility with other parts of the codebase that expect a `robot.cameras` attribute
|
||||
self.cameras = {**self.left_arm.cameras, **self.right_arm.cameras}
|
||||
|
||||
@property
|
||||
def _motors_ft(self) -> dict[str, type]:
|
||||
left_arm_motors_ft = self.left_arm._motors_ft
|
||||
right_arm_motors_ft = self.right_arm._motors_ft
|
||||
|
||||
return {
|
||||
**{f"left_{k}": v for k, v in left_arm_motors_ft.items()},
|
||||
**{f"right_{k}": v for k, v in right_arm_motors_ft.items()},
|
||||
}
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
left_arm_cameras_ft = self.left_arm._cameras_ft
|
||||
right_arm_cameras_ft = self.right_arm._cameras_ft
|
||||
|
||||
return {
|
||||
**{f"left_{k}": v for k, v in left_arm_cameras_ft.items()},
|
||||
**{f"right_{k}": v for k, v in right_arm_cameras_ft.items()},
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.left_arm.calibrate()
|
||||
self.right_arm.calibrate()
|
||||
|
||||
def configure(self) -> None:
|
||||
self.left_arm.configure()
|
||||
self.right_arm.configure()
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||
)
|
||||
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict = {}
|
||||
|
||||
# Add "left_" prefix
|
||||
left_obs = self.left_arm.get_observation()
|
||||
obs_dict.update({f"left_{key}": value for key, value in left_obs.items()})
|
||||
|
||||
# Add "right_" prefix
|
||||
right_obs = self.right_arm.get_observation()
|
||||
obs_dict.update({f"right_{key}": value for key, value in right_obs.items()})
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(
|
||||
self,
|
||||
action: RobotAction,
|
||||
custom_kp: dict[str, float] | None = None,
|
||||
custom_kd: dict[str, float] | None = None,
|
||||
) -> RobotAction:
|
||||
# Remove "left_" prefix
|
||||
left_action = {
|
||||
key.removeprefix("left_"): value for key, value in action.items() if key.startswith("left_")
|
||||
}
|
||||
# Remove "right_" prefix
|
||||
right_action = {
|
||||
key.removeprefix("right_"): value for key, value in action.items() if key.startswith("right_")
|
||||
}
|
||||
|
||||
sent_action_left = self.left_arm.send_action(left_action, custom_kp, custom_kd)
|
||||
sent_action_right = self.right_arm.send_action(right_action, custom_kp, custom_kd)
|
||||
|
||||
# Add prefixes back
|
||||
prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()}
|
||||
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
|
||||
|
||||
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
|
||||
|
||||
def disconnect(self):
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
@@ -0,0 +1,30 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.robots.openarm_follower import OpenArmFollowerConfigBase
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("bi_openarm_follower")
|
||||
@dataclass
|
||||
class BiOpenArmFollowerConfig(RobotConfig):
|
||||
"""Configuration class for Bi OpenArm Follower robots."""
|
||||
|
||||
left_arm_config: OpenArmFollowerConfigBase
|
||||
right_arm_config: OpenArmFollowerConfigBase
|
||||
@@ -24,7 +24,8 @@ import numpy as np
|
||||
import requests
|
||||
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_earthrover_mini_plus import EarthRoverMiniPlusConfig
|
||||
@@ -99,6 +100,7 @@ class EarthRoverMiniPlus(Robot):
|
||||
"""Check if robot is connected to SDK."""
|
||||
return self._is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""Connect to robot via Frodobots SDK.
|
||||
|
||||
@@ -109,8 +111,6 @@ class EarthRoverMiniPlus(Robot):
|
||||
DeviceAlreadyConnectedError: If robot is already connected
|
||||
DeviceNotConnectedError: If cannot connect to SDK server
|
||||
"""
|
||||
if self._is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self.name} is already connected")
|
||||
|
||||
# Verify SDK is running and accessible
|
||||
try:
|
||||
@@ -197,6 +197,7 @@ class EarthRoverMiniPlus(Robot):
|
||||
ACTION_ANGULAR_VEL: float,
|
||||
}
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
"""Get current robot observation from SDK.
|
||||
|
||||
@@ -223,8 +224,6 @@ class EarthRoverMiniPlus(Robot):
|
||||
Robot telemetry is retrieved from /data endpoint.
|
||||
All SDK values are normalized to appropriate ranges for dataset recording.
|
||||
"""
|
||||
if not self._is_connected:
|
||||
raise DeviceNotConnectedError(f"{self.name} is not connected")
|
||||
|
||||
observation = {}
|
||||
|
||||
@@ -255,6 +254,7 @@ class EarthRoverMiniPlus(Robot):
|
||||
|
||||
return observation
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Send action to robot via SDK.
|
||||
|
||||
@@ -272,8 +272,6 @@ class EarthRoverMiniPlus(Robot):
|
||||
Actions are sent to SDK via POST /control endpoint.
|
||||
SDK expects commands in range [-1, 1].
|
||||
"""
|
||||
if not self._is_connected:
|
||||
raise DeviceNotConnectedError(f"{self.name} is not connected")
|
||||
|
||||
# Extract action values and convert to float
|
||||
linear = float(action.get(ACTION_LINEAR_VEL, 0.0))
|
||||
@@ -291,6 +289,7 @@ class EarthRoverMiniPlus(Robot):
|
||||
ACTION_ANGULAR_VEL: angular,
|
||||
}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from robot.
|
||||
|
||||
@@ -299,8 +298,6 @@ class EarthRoverMiniPlus(Robot):
|
||||
Raises:
|
||||
DeviceNotConnectedError: If robot is not connected
|
||||
"""
|
||||
if not self._is_connected:
|
||||
raise DeviceNotConnectedError(f"{self.name} is not connected")
|
||||
|
||||
# Stop the robot before disconnecting
|
||||
try:
|
||||
|
||||
@@ -25,7 +25,7 @@ from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
@@ -82,13 +82,12 @@ class HopeJrArm(Robot):
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""
|
||||
We assume that at connection time, arm is in a rest position,
|
||||
and torque can be safely disabled to run calibration.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.bus.connect(handshake=False)
|
||||
if not self.is_calibrated and calibrate:
|
||||
@@ -128,10 +127,8 @@ class HopeJrArm(Robot):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict = self.bus.sync_read("Present_Position", self.other_motors)
|
||||
@@ -149,10 +146,8 @@ class HopeJrArm(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
|
||||
# Cap goal position when too far away from present position.
|
||||
@@ -165,10 +160,8 @@ class HopeJrArm(Robot):
|
||||
self.bus.sync_write("Goal_Position", goal_pos)
|
||||
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
@@ -25,7 +25,7 @@ from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_hope_jr import HopeJrHandConfig
|
||||
@@ -118,10 +118,8 @@ class HopeJrHand(Robot):
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.bus.connect()
|
||||
if not self.is_calibrated and calibrate:
|
||||
self.calibrate()
|
||||
@@ -159,10 +157,8 @@ class HopeJrHand(Robot):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
obs_dict = {}
|
||||
|
||||
# Read hand position
|
||||
@@ -181,18 +177,14 @@ class HopeJrHand(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
self.bus.sync_write("Goal_Position", goal_pos)
|
||||
return action
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
@@ -25,7 +25,7 @@ from lerobot.motors.dynamixel import (
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
@@ -84,13 +84,12 @@ class KochFollower(Robot):
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""
|
||||
We assume that at connection time, arm is in a rest position,
|
||||
and torque can be safely disabled to run calibration.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.bus.connect()
|
||||
if not self.is_calibrated and calibrate:
|
||||
@@ -182,10 +181,8 @@ class KochFollower(Robot):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict = self.bus.sync_read("Present_Position")
|
||||
@@ -202,6 +199,7 @@ class KochFollower(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command arm to move to a target joint configuration.
|
||||
|
||||
@@ -215,8 +213,6 @@ class KochFollower(Robot):
|
||||
Returns:
|
||||
RobotAction: The action sent to the motors, potentially clipped.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
|
||||
@@ -231,10 +227,8 @@ class KochFollower(Robot):
|
||||
self.bus.sync_write("Goal_Position", goal_pos)
|
||||
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
@@ -29,7 +29,7 @@ from lerobot.motors.feetech import (
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
@@ -109,10 +109,8 @@ class LeKiwi(Robot):
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.bus.connect()
|
||||
if not self.is_calibrated and calibrate:
|
||||
logger.info(
|
||||
@@ -339,10 +337,8 @@ class LeKiwi(Robot):
|
||||
"theta.vel": theta,
|
||||
} # m/s and deg/s
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Read actuators position for arm and vel for base
|
||||
start = time.perf_counter()
|
||||
arm_pos = self.bus.sync_read("Present_Position", self.arm_motors)
|
||||
@@ -370,6 +366,7 @@ class LeKiwi(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command lekiwi to move to a target joint configuration.
|
||||
|
||||
@@ -383,8 +380,6 @@ class LeKiwi(Robot):
|
||||
Returns:
|
||||
RobotAction: the action sent to the motors, potentially clipped.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
arm_goal_pos = {k: v for k, v in action.items() if k.endswith(".pos")}
|
||||
base_goal_vel = {k: v for k, v in action.items() if k.endswith(".vel")}
|
||||
@@ -412,10 +407,8 @@ class LeKiwi(Robot):
|
||||
self.bus.sync_write("Goal_Velocity", dict.fromkeys(self.base_motors, 0), num_retry=5)
|
||||
logger.info("Base motors stopped")
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.stop_base()
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
for cam in self.cameras.values():
|
||||
|
||||
@@ -24,7 +24,8 @@ import numpy as np
|
||||
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_lekiwi import LeKiwiClientConfig
|
||||
@@ -112,14 +113,10 @@ class LeKiwiClient(Robot):
|
||||
def is_calibrated(self) -> bool:
|
||||
pass
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self) -> None:
|
||||
"""Establishes ZMQ sockets with the remote mobile robot"""
|
||||
|
||||
if self._is_connected:
|
||||
raise DeviceAlreadyConnectedError(
|
||||
"LeKiwi Daemon is already connected. Do not run `robot.connect()` twice."
|
||||
)
|
||||
|
||||
zmq = self._zmq
|
||||
self.zmq_context = zmq.Context()
|
||||
self.zmq_cmd_socket = self.zmq_context.socket(zmq.PUSH)
|
||||
@@ -252,14 +249,13 @@ class LeKiwiClient(Robot):
|
||||
|
||||
return new_frames, new_state
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
"""
|
||||
Capture observations from the remote robot: current follower arm positions,
|
||||
present wheel speeds (converted to body-frame velocities: x, y, theta),
|
||||
and a camera frame. Receives over ZMQ, translate to body-frame vel
|
||||
"""
|
||||
if not self._is_connected:
|
||||
raise DeviceNotConnectedError("LeKiwiClient is not connected. You need to run `robot.connect()`.")
|
||||
|
||||
frames, obs_dict = self._get_data()
|
||||
|
||||
@@ -307,6 +303,7 @@ class LeKiwiClient(Robot):
|
||||
def configure(self):
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command lekiwi to move to a target joint configuration. Translates to motor space + sends over ZMQ
|
||||
|
||||
@@ -318,10 +315,6 @@ class LeKiwiClient(Robot):
|
||||
Returns:
|
||||
np.ndarray: the action sent to the motors, potentially clipped.
|
||||
"""
|
||||
if not self._is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
"ManipulatorRobot is not connected. You need to run `robot.connect()`."
|
||||
)
|
||||
|
||||
self.zmq_cmd_socket.send_string(json.dumps(action)) # action is in motor space
|
||||
|
||||
@@ -332,13 +325,10 @@ class LeKiwiClient(Robot):
|
||||
action_sent[ACTION] = actions
|
||||
return action_sent
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
"""Cleans ZMQ comms"""
|
||||
|
||||
if not self._is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
"LeKiwi is not connected. You need to run `robot.connect()` before disconnecting."
|
||||
)
|
||||
self.zmq_observation_socket.close()
|
||||
self.zmq_cmd_socket.close()
|
||||
self.zmq_context.term()
|
||||
|
||||
@@ -26,7 +26,7 @@ from lerobot.motors.dynamixel import (
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
@@ -84,6 +84,7 @@ class OmxFollower(Robot):
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""
|
||||
For OMX robots that come pre-calibrated:
|
||||
@@ -91,8 +92,6 @@ class OmxFollower(Robot):
|
||||
- This allows using pre-calibrated robots without manual calibration
|
||||
- If no calibration file exists, use factory default values (homing_offset=0, range_min=0, range_max=4095)
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.bus.connect()
|
||||
if not self.is_calibrated and calibrate:
|
||||
@@ -165,10 +164,8 @@ class OmxFollower(Robot):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict = self.bus.sync_read("Present_Position")
|
||||
@@ -185,6 +182,7 @@ class OmxFollower(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command arm to move to a target joint configuration.
|
||||
|
||||
@@ -198,8 +196,6 @@ class OmxFollower(Robot):
|
||||
Returns:
|
||||
RobotAction: The action sent to the motors, potentially clipped.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
|
||||
@@ -214,10 +210,8 @@ class OmxFollower(Robot):
|
||||
self.bus.sync_write("Goal_Position", goal_pos)
|
||||
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_openarm_follower import OpenArmFollowerConfig, OpenArmFollowerConfigBase
|
||||
from .openarm_follower import OpenArmFollower
|
||||
|
||||
__all__ = ["OpenArmFollower", "OpenArmFollowerConfig", "OpenArmFollowerConfigBase"]
|
||||
@@ -0,0 +1,122 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
LEFT_DEFAULT_JOINTS_LIMITS: dict[str, tuple[float, float]] = {
|
||||
"joint_1": (-75.0, 75.0),
|
||||
"joint_2": (-90.0, 9.0),
|
||||
"joint_3": (-85.0, 85.0),
|
||||
"joint_4": (0.0, 135.0),
|
||||
"joint_5": (-85.0, 85.0),
|
||||
"joint_6": (-40.0, 40.0),
|
||||
"joint_7": (-80.0, 80.0),
|
||||
"gripper": (-65.0, 0.0),
|
||||
}
|
||||
|
||||
RIGHT_DEFAULT_JOINTS_LIMITS: dict[str, tuple[float, float]] = {
|
||||
"joint_1": (-75.0, 75.0),
|
||||
"joint_2": (-9.0, 90.0),
|
||||
"joint_3": (-85.0, 85.0),
|
||||
"joint_4": (0.0, 135.0),
|
||||
"joint_5": (-85.0, 85.0),
|
||||
"joint_6": (-40.0, 40.0),
|
||||
"joint_7": (-80.0, 80.0),
|
||||
"gripper": (-65.0, 0.0),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenArmFollowerConfigBase:
|
||||
"""Base configuration for the OpenArms follower robot with Damiao motors."""
|
||||
|
||||
# CAN interfaces - one per arm
|
||||
# arm CAN interface (e.g., "can1")
|
||||
# Linux: "can0", "can1", etc.
|
||||
port: str
|
||||
|
||||
# side of the arm: "left" or "right". If "None" default values will be used
|
||||
side: str | None = None
|
||||
|
||||
# CAN interface type: "socketcan" (Linux), "slcan" (serial), or "auto" (auto-detect)
|
||||
can_interface: str = "socketcan"
|
||||
|
||||
# CAN FD settings (OpenArms uses CAN FD by default)
|
||||
use_can_fd: bool = True
|
||||
can_bitrate: int = 1000000 # Nominal bitrate (1 Mbps)
|
||||
can_data_bitrate: int = 5000000 # Data bitrate for CAN FD (5 Mbps)
|
||||
|
||||
# Whether to disable torque when disconnecting
|
||||
disable_torque_on_disconnect: bool = True
|
||||
|
||||
# Safety limit for relative target positions
|
||||
# Set to a positive scalar for all motors, or a dict mapping motor names to limits
|
||||
max_relative_target: float | dict[str, float] | None = None
|
||||
|
||||
# Camera configurations
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
# Motor configuration for OpenArms (7 DOF per arm)
|
||||
# Maps motor names to (send_can_id, recv_can_id, motor_type)
|
||||
# Based on: https://docs.openarm.dev/software/setup/configure-test
|
||||
# OpenArms uses 4 types of motors:
|
||||
# - DM8009 (DM-J8009P-2EC) for shoulders (high torque)
|
||||
# - DM4340P and DM4340 for shoulder rotation and elbow
|
||||
# - DM4310 (DM-J4310-2EC V1.1) for wrist and gripper
|
||||
motor_config: dict[str, tuple[int, int, str]] = field(
|
||||
default_factory=lambda: {
|
||||
"joint_1": (0x01, 0x11, "dm8009"), # J1 - Shoulder pan (DM8009)
|
||||
"joint_2": (0x02, 0x12, "dm8009"), # J2 - Shoulder lift (DM8009)
|
||||
"joint_3": (0x03, 0x13, "dm4340"), # J3 - Shoulder rotation (DM4340)
|
||||
"joint_4": (0x04, 0x14, "dm4340"), # J4 - Elbow flex (DM4340)
|
||||
"joint_5": (0x05, 0x15, "dm4310"), # J5 - Wrist roll (DM4310)
|
||||
"joint_6": (0x06, 0x16, "dm4310"), # J6 - Wrist pitch (DM4310)
|
||||
"joint_7": (0x07, 0x17, "dm4310"), # J7 - Wrist rotation (DM4310)
|
||||
"gripper": (0x08, 0x18, "dm4310"), # J8 - Gripper (DM4310)
|
||||
}
|
||||
)
|
||||
|
||||
# MIT control parameters for position control (used in send_action)
|
||||
# List of 8 values: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper]
|
||||
position_kp: list[float] = field(
|
||||
default_factory=lambda: [240.0, 240.0, 240.0, 240.0, 24.0, 31.0, 25.0, 25.0]
|
||||
)
|
||||
position_kd: list[float] = field(default_factory=lambda: [5.0, 5.0, 3.0, 5.0, 0.3, 0.3, 0.3, 0.3])
|
||||
|
||||
# Values for joint limits. Can be overridden via CLI (for custom values) or by setting config.side to either 'left' or 'right'.
|
||||
# If config.side is left set to None and no CLI values are passed, the default joint limit values are small for safety.
|
||||
joint_limits: dict[str, tuple[float, float]] = field(
|
||||
default_factory=lambda: {
|
||||
"joint_1": (-5.0, 5.0),
|
||||
"joint_2": (-5.0, 5.0),
|
||||
"joint_3": (-5.0, 5.0),
|
||||
"joint_4": (0.0, 5.0),
|
||||
"joint_5": (-5.0, 5.0),
|
||||
"joint_6": (-5.0, 5.0),
|
||||
"joint_7": (-5.0, 5.0),
|
||||
"gripper": (-5.0, 0.0),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("openarm_follower")
|
||||
@dataclass
|
||||
class OpenArmFollowerConfig(RobotConfig, OpenArmFollowerConfigBase):
|
||||
pass
|
||||
@@ -0,0 +1,348 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.damiao import DamiaoMotorsBus
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
from .config_openarm_follower import (
|
||||
LEFT_DEFAULT_JOINTS_LIMITS,
|
||||
RIGHT_DEFAULT_JOINTS_LIMITS,
|
||||
OpenArmFollowerConfig,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenArmFollower(Robot):
|
||||
"""
|
||||
OpenArms Follower Robot which uses CAN bus communication to control 7 DOF arm with a gripper.
|
||||
The arm uses Damiao motors in MIT control mode.
|
||||
"""
|
||||
|
||||
config_class = OpenArmFollowerConfig
|
||||
name = "openarm_follower"
|
||||
|
||||
def __init__(self, config: OpenArmFollowerConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Arm motors
|
||||
motors: dict[str, Motor] = {}
|
||||
for motor_name, (send_id, recv_id, motor_type_str) in config.motor_config.items():
|
||||
motor = Motor(
|
||||
send_id, motor_type_str, MotorNormMode.DEGREES
|
||||
) # Always use degrees for Damiao motors
|
||||
motor.recv_id = recv_id
|
||||
motor.motor_type_str = motor_type_str
|
||||
motors[motor_name] = motor
|
||||
|
||||
self.bus = DamiaoMotorsBus(
|
||||
port=self.config.port,
|
||||
motors=motors,
|
||||
calibration=self.calibration,
|
||||
can_interface=self.config.can_interface,
|
||||
use_can_fd=self.config.use_can_fd,
|
||||
bitrate=self.config.can_bitrate,
|
||||
data_bitrate=self.config.can_data_bitrate if self.config.use_can_fd else None,
|
||||
)
|
||||
|
||||
if config.side is not None:
|
||||
if config.side == "left":
|
||||
config.joint_limits = LEFT_DEFAULT_JOINTS_LIMITS
|
||||
elif config.side == "right":
|
||||
config.joint_limits = RIGHT_DEFAULT_JOINTS_LIMITS
|
||||
else:
|
||||
raise ValueError(
|
||||
"config.side must be either 'left', 'right' (for default values) or 'None' (for CLI values)"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Set config.side to either 'left' or 'right' to use pre-configured values for joint limits."
|
||||
)
|
||||
logger.info(f"Values used for joint limits: {config.joint_limits}.")
|
||||
|
||||
# Initialize cameras
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
|
||||
@property
|
||||
def _motors_ft(self) -> dict[str, type]:
|
||||
"""Motor features for observation and action spaces."""
|
||||
features: dict[str, type] = {}
|
||||
for motor in self.bus.motors:
|
||||
features[f"{motor}.pos"] = float
|
||||
features[f"{motor}.vel"] = float # Add this
|
||||
features[f"{motor}.torque"] = float # Add this
|
||||
return features
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
"""Camera features for observation space."""
|
||||
return {
|
||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
"""Combined observation features from motors and cameras."""
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
"""Action features."""
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if robot is connected."""
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""
|
||||
Connect to the robot and optionally calibrate.
|
||||
|
||||
We assume that at connection time, the arms are in a safe rest position,
|
||||
and torque can be safely disabled to run calibration if needed.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
# Connect to CAN bus
|
||||
logger.info(f"Connecting arm on {self.config.port}...")
|
||||
self.bus.connect()
|
||||
|
||||
# Run calibration if needed
|
||||
if not self.is_calibrated and calibrate:
|
||||
logger.info(
|
||||
"Mismatch between calibration values in the motor and the calibration file or no calibration file found"
|
||||
)
|
||||
self.calibrate()
|
||||
|
||||
for cam in self.cameras.values():
|
||||
cam.connect()
|
||||
|
||||
self.configure()
|
||||
|
||||
if self.is_calibrated:
|
||||
self.bus.set_zero_position()
|
||||
|
||||
self.bus.enable_torque()
|
||||
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
"""Check if robot is calibrated."""
|
||||
return self.bus.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
"""
|
||||
Run calibration procedure for OpenArms robot.
|
||||
|
||||
The calibration procedure:
|
||||
1. Disable torque
|
||||
2. Ask user to position arms in hanging position with grippers closed
|
||||
3. Set this as zero position
|
||||
4. Record range of motion for each joint
|
||||
5. Save calibration
|
||||
"""
|
||||
if self.calibration:
|
||||
# Calibration file exists, ask user whether to use it or run new calibration
|
||||
user_input = input(
|
||||
f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: "
|
||||
)
|
||||
if user_input.strip().lower() != "c":
|
||||
logger.info(f"Writing calibration file associated with the id {self.id} to the motors")
|
||||
self.bus.write_calibration(self.calibration)
|
||||
return
|
||||
|
||||
logger.info(f"\nRunning calibration for {self}")
|
||||
self.bus.disable_torque()
|
||||
|
||||
# Step 1: Set zero position
|
||||
input(
|
||||
"\nCalibration: Set Zero Position)\n"
|
||||
"Position the arm in the following configuration:\n"
|
||||
" - Arm hanging straight down\n"
|
||||
" - Gripper closed\n"
|
||||
"Press ENTER when ready..."
|
||||
)
|
||||
|
||||
# Set current position as zero for all motors
|
||||
self.bus.set_zero_position()
|
||||
logger.info("Arm zero position set.")
|
||||
|
||||
logger.info("Setting range: -90° to +90° for safety by default for all joints")
|
||||
for motor_name, motor in self.bus.motors.items():
|
||||
self.calibration[motor_name] = MotorCalibration(
|
||||
id=motor.id,
|
||||
drive_mode=0,
|
||||
homing_offset=0,
|
||||
range_min=-90,
|
||||
range_max=90,
|
||||
)
|
||||
|
||||
self.bus.write_calibration(self.calibration)
|
||||
self._save_calibration()
|
||||
print(f"Calibration saved to {self.calibration_fpath}")
|
||||
|
||||
def configure(self) -> None:
|
||||
"""Configure motors with appropriate settings."""
|
||||
# TODO(Steven, Pepijn): Slightly different from what it is happening in the leader
|
||||
with self.bus.torque_disabled():
|
||||
self.bus.configure_motors()
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||
)
|
||||
|
||||
def get_observation(self) -> RobotObservation:
|
||||
"""
|
||||
Get current observation from robot including position, velocity, and torque.
|
||||
|
||||
Reads all motor states (pos/vel/torque) in one CAN refresh cycle
|
||||
instead of 3 separate reads.
|
||||
"""
|
||||
start = time.perf_counter()
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
obs_dict: dict[str, Any] = {}
|
||||
|
||||
states = self.bus.sync_read_all_states()
|
||||
|
||||
for motor in self.bus.motors:
|
||||
state = states.get(motor, {})
|
||||
obs_dict[f"{motor}.pos"] = state.get("position", 0.0)
|
||||
obs_dict[f"{motor}.vel"] = state.get("velocity", 0.0)
|
||||
obs_dict[f"{motor}.torque"] = state.get("torque", 0.0)
|
||||
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} get_observation took: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(
|
||||
self,
|
||||
action: RobotAction,
|
||||
custom_kp: dict[str, float] | None = None,
|
||||
custom_kd: dict[str, float] | None = None,
|
||||
) -> RobotAction:
|
||||
"""
|
||||
Send action command to robot.
|
||||
|
||||
The action magnitude may be clipped based on safety limits.
|
||||
|
||||
Args:
|
||||
action: Dictionary with motor positions (e.g., "joint_1.pos", "joint_2.pos")
|
||||
custom_kp: Optional custom kp gains per motor (e.g., {"joint_1": 120.0, "joint_2": 150.0})
|
||||
custom_kd: Optional custom kd gains per motor (e.g., {"joint_1": 1.5, "joint_2": 2.0})
|
||||
|
||||
Returns:
|
||||
The action actually sent (potentially clipped)
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
|
||||
# Apply joint limit clipping to arm
|
||||
for motor_name, position in goal_pos.items():
|
||||
if motor_name in self.config.joint_limits:
|
||||
min_limit, max_limit = self.config.joint_limits[motor_name]
|
||||
clipped_position = max(min_limit, min(max_limit, position))
|
||||
if clipped_position != position:
|
||||
logger.debug(f"Clipped {motor_name} from {position:.2f}° to {clipped_position:.2f}°")
|
||||
goal_pos[motor_name] = clipped_position
|
||||
|
||||
# Cap goal position when too far away from present position.
|
||||
# /!\ Slower fps expected due to reading from the follower.
|
||||
if self.config.max_relative_target is not None:
|
||||
present_pos = self.bus.sync_read("Present_Position")
|
||||
goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()}
|
||||
goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
|
||||
|
||||
# TODO(Steven, Pepijn): Refactor writing
|
||||
# Motor name to index mapping for gains
|
||||
motor_index = {
|
||||
"joint_1": 0,
|
||||
"joint_2": 1,
|
||||
"joint_3": 2,
|
||||
"joint_4": 3,
|
||||
"joint_5": 4,
|
||||
"joint_6": 5,
|
||||
"joint_7": 6,
|
||||
"gripper": 7,
|
||||
}
|
||||
|
||||
# Use batch MIT control for arm (sends all commands, then collects responses)
|
||||
commands = {}
|
||||
for motor_name, position_degrees in goal_pos.items():
|
||||
idx = motor_index.get(motor_name, 0)
|
||||
# Use custom gains if provided, otherwise use config defaults
|
||||
if custom_kp is not None and motor_name in custom_kp:
|
||||
kp = custom_kp[motor_name]
|
||||
else:
|
||||
kp = (
|
||||
self.config.position_kp[idx]
|
||||
if isinstance(self.config.position_kp, list)
|
||||
else self.config.position_kp
|
||||
)
|
||||
if custom_kd is not None and motor_name in custom_kd:
|
||||
kd = custom_kd[motor_name]
|
||||
else:
|
||||
kd = (
|
||||
self.config.position_kd[idx]
|
||||
if isinstance(self.config.position_kd, list)
|
||||
else self.config.position_kd
|
||||
)
|
||||
commands[motor_name] = (kp, kd, position_degrees, 0.0, 0.0)
|
||||
|
||||
self.bus._mit_control_batch(commands)
|
||||
|
||||
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
|
||||
|
||||
def disconnect(self):
|
||||
"""Disconnect from robot."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Disconnect CAN bus
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
|
||||
# Disconnect cameras
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
@@ -58,6 +58,32 @@ class Robot(abc.ABC):
|
||||
def __str__(self) -> str:
|
||||
return f"{self.id} {self.__class__.__name__}"
|
||||
|
||||
def __enter__(self):
|
||||
"""
|
||||
Context manager entry.
|
||||
Automatically connects to the camera.
|
||||
"""
|
||||
self.connect()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback) -> None:
|
||||
"""
|
||||
Context manager exit.
|
||||
Automatically disconnects, ensuring resources are released even on error.
|
||||
"""
|
||||
self.disconnect()
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""
|
||||
Destructor safety net.
|
||||
Attempts to disconnect if the object is garbage collected without cleanup.
|
||||
"""
|
||||
try:
|
||||
if self.is_connected:
|
||||
self.disconnect()
|
||||
except Exception: # nosec B110
|
||||
pass
|
||||
|
||||
# TODO(aliberts): create a proper Feature class for this that links with datasets
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
|
||||
@@ -26,7 +26,7 @@ from lerobot.motors.feetech import (
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
@@ -85,13 +85,12 @@ class SOFollower(Robot):
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""
|
||||
We assume that at connection time, arm is in a rest position,
|
||||
and torque can be safely disabled to run calibration.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.bus.connect()
|
||||
if not self.is_calibrated and calibrate:
|
||||
@@ -176,10 +175,8 @@ class SOFollower(Robot):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict = self.bus.sync_read("Present_Position")
|
||||
@@ -196,6 +193,7 @@ class SOFollower(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command arm to move to a target joint configuration.
|
||||
|
||||
@@ -209,8 +207,6 @@ class SOFollower(Robot):
|
||||
Returns:
|
||||
RobotAction: the action sent to the motors, potentially clipped.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
|
||||
@@ -225,10 +221,8 @@ class SOFollower(Robot):
|
||||
self.bus.sync_write("Goal_Position", goal_pos)
|
||||
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
@@ -65,3 +65,6 @@ class UnitreeG1Config(RobotConfig):
|
||||
|
||||
# Cameras (ZMQ-based remote cameras)
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
# Compensates for gravity on the unitree's arms using the arm ik solver
|
||||
gravity_compensation: bool = False
|
||||
|
||||
@@ -18,7 +18,7 @@ from enum import IntEnum
|
||||
|
||||
# ruff: noqa: N801, N815
|
||||
|
||||
NUM_MOTORS = 35
|
||||
NUM_MOTORS = 29
|
||||
|
||||
|
||||
class G1_29_JointArmIndex(IntEnum):
|
||||
|
||||
@@ -0,0 +1,313 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
parent2_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(parent2_dir)
|
||||
|
||||
|
||||
class WeightedMovingFilter:
|
||||
def __init__(self, weights, data_size=14):
|
||||
self._window_size = len(weights)
|
||||
self._weights = np.array(weights)
|
||||
self._data_size = data_size
|
||||
self._filtered_data = np.zeros(self._data_size)
|
||||
self._data_queue = []
|
||||
|
||||
def _apply_filter(self):
|
||||
if len(self._data_queue) < self._window_size:
|
||||
return self._data_queue[-1]
|
||||
|
||||
data_array = np.array(self._data_queue)
|
||||
temp_filtered_data = np.zeros(self._data_size)
|
||||
for i in range(self._data_size):
|
||||
temp_filtered_data[i] = np.convolve(data_array[:, i], self._weights, mode="valid")[-1]
|
||||
|
||||
return temp_filtered_data
|
||||
|
||||
def add_data(self, new_data):
|
||||
assert len(new_data) == self._data_size
|
||||
|
||||
if len(self._data_queue) > 0 and np.array_equal(
|
||||
new_data, self._data_queue[-1]
|
||||
): # skip duplicate data
|
||||
return
|
||||
|
||||
if len(self._data_queue) >= self._window_size:
|
||||
self._data_queue.pop(0)
|
||||
|
||||
self._data_queue.append(new_data)
|
||||
self._filtered_data = self._apply_filter()
|
||||
|
||||
@property
|
||||
def filtered_data(self):
|
||||
return self._filtered_data
|
||||
|
||||
|
||||
class G1_29_ArmIK: # noqa: N801
|
||||
def __init__(self, unit_test=False):
|
||||
import casadi
|
||||
import pinocchio as pin
|
||||
from huggingface_hub import snapshot_download
|
||||
from pinocchio import casadi as cpin
|
||||
|
||||
self._pin = pin
|
||||
np.set_printoptions(precision=5, suppress=True, linewidth=200)
|
||||
|
||||
self.unit_test = unit_test
|
||||
|
||||
self.repo_path = snapshot_download("lerobot/unitree-g1-mujoco")
|
||||
urdf_path = os.path.join(self.repo_path, "assets", "g1_body29_hand14.urdf")
|
||||
mesh_dir = os.path.join(self.repo_path, "assets")
|
||||
|
||||
self.robot = self._pin.RobotWrapper.BuildFromURDF(urdf_path, mesh_dir)
|
||||
|
||||
self.mixed_jointsToLockIDs = [
|
||||
"left_hip_pitch_joint",
|
||||
"left_hip_roll_joint",
|
||||
"left_hip_yaw_joint",
|
||||
"left_knee_joint",
|
||||
"left_ankle_pitch_joint",
|
||||
"left_ankle_roll_joint",
|
||||
"right_hip_pitch_joint",
|
||||
"right_hip_roll_joint",
|
||||
"right_hip_yaw_joint",
|
||||
"right_knee_joint",
|
||||
"right_ankle_pitch_joint",
|
||||
"right_ankle_roll_joint",
|
||||
"waist_yaw_joint",
|
||||
"waist_roll_joint",
|
||||
"waist_pitch_joint",
|
||||
"left_hand_thumb_0_joint",
|
||||
"left_hand_thumb_1_joint",
|
||||
"left_hand_thumb_2_joint",
|
||||
"left_hand_middle_0_joint",
|
||||
"left_hand_middle_1_joint",
|
||||
"left_hand_index_0_joint",
|
||||
"left_hand_index_1_joint",
|
||||
"right_hand_thumb_0_joint",
|
||||
"right_hand_thumb_1_joint",
|
||||
"right_hand_thumb_2_joint",
|
||||
"right_hand_index_0_joint",
|
||||
"right_hand_index_1_joint",
|
||||
"right_hand_middle_0_joint",
|
||||
"right_hand_middle_1_joint",
|
||||
]
|
||||
|
||||
self.reduced_robot = self.robot.buildReducedRobot(
|
||||
list_of_joints_to_lock=self.mixed_jointsToLockIDs,
|
||||
reference_configuration=np.array([0.0] * self.robot.model.nq),
|
||||
)
|
||||
|
||||
# Arm joint names in G1 motor order (G1_29_JointArmIndex)
|
||||
self._arm_joint_names_g1 = [
|
||||
"left_shoulder_pitch_joint",
|
||||
"left_shoulder_roll_joint",
|
||||
"left_shoulder_yaw_joint",
|
||||
"left_elbow_joint",
|
||||
"left_wrist_roll_joint",
|
||||
"left_wrist_pitch_joint",
|
||||
"left_wrist_yaw_joint",
|
||||
"right_shoulder_pitch_joint",
|
||||
"right_shoulder_roll_joint",
|
||||
"right_shoulder_yaw_joint",
|
||||
"right_elbow_joint",
|
||||
"right_wrist_roll_joint",
|
||||
"right_wrist_pitch_joint",
|
||||
"right_wrist_yaw_joint",
|
||||
]
|
||||
# Pinocchio uses its own joint order in q; build index mapping.
|
||||
self._arm_joint_names_pin = sorted(
|
||||
self._arm_joint_names_g1,
|
||||
key=lambda name: self.reduced_robot.model.idx_qs[self.reduced_robot.model.getJointId(name)],
|
||||
)
|
||||
logger.info(f"Pinocchio arm joint order: {self._arm_joint_names_pin}")
|
||||
self._arm_reorder_g1_to_pin = [
|
||||
self._arm_joint_names_g1.index(name) for name in self._arm_joint_names_pin
|
||||
]
|
||||
# Inverse mapping to return tau in G1 motor order.
|
||||
self._arm_reorder_pin_to_g1 = np.argsort(self._arm_reorder_g1_to_pin)
|
||||
|
||||
self.reduced_robot.model.addFrame(
|
||||
self._pin.Frame(
|
||||
"L_ee",
|
||||
self.reduced_robot.model.getJointId("left_wrist_yaw_joint"),
|
||||
self._pin.SE3(np.eye(3), np.array([0.05, 0, 0]).T),
|
||||
self._pin.FrameType.OP_FRAME,
|
||||
)
|
||||
)
|
||||
|
||||
self.reduced_robot.model.addFrame(
|
||||
self._pin.Frame(
|
||||
"R_ee",
|
||||
self.reduced_robot.model.getJointId("right_wrist_yaw_joint"),
|
||||
self._pin.SE3(np.eye(3), np.array([0.05, 0, 0]).T),
|
||||
self._pin.FrameType.OP_FRAME,
|
||||
)
|
||||
)
|
||||
|
||||
# Creating Casadi models and data for symbolic computing
|
||||
self.cmodel = cpin.Model(self.reduced_robot.model)
|
||||
self.cdata = self.cmodel.createData()
|
||||
|
||||
# Creating symbolic variables
|
||||
self.cq = casadi.SX.sym("q", self.reduced_robot.model.nq, 1)
|
||||
self.cTf_l = casadi.SX.sym("tf_l", 4, 4)
|
||||
self.cTf_r = casadi.SX.sym("tf_r", 4, 4)
|
||||
cpin.framesForwardKinematics(self.cmodel, self.cdata, self.cq)
|
||||
|
||||
# Get the hand joint ID and define the error function
|
||||
self.L_hand_id = self.reduced_robot.model.getFrameId("L_ee")
|
||||
self.R_hand_id = self.reduced_robot.model.getFrameId("R_ee")
|
||||
|
||||
self.translational_error = casadi.Function(
|
||||
"translational_error",
|
||||
[self.cq, self.cTf_l, self.cTf_r],
|
||||
[
|
||||
casadi.vertcat(
|
||||
self.cdata.oMf[self.L_hand_id].translation - self.cTf_l[:3, 3],
|
||||
self.cdata.oMf[self.R_hand_id].translation - self.cTf_r[:3, 3],
|
||||
)
|
||||
],
|
||||
)
|
||||
self.rotational_error = casadi.Function(
|
||||
"rotational_error",
|
||||
[self.cq, self.cTf_l, self.cTf_r],
|
||||
[
|
||||
casadi.vertcat(
|
||||
cpin.log3(self.cdata.oMf[self.L_hand_id].rotation @ self.cTf_l[:3, :3].T),
|
||||
cpin.log3(self.cdata.oMf[self.R_hand_id].rotation @ self.cTf_r[:3, :3].T),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# Defining the optimization problem
|
||||
self.opti = casadi.Opti()
|
||||
self.var_q = self.opti.variable(self.reduced_robot.model.nq)
|
||||
self.var_q_last = self.opti.parameter(self.reduced_robot.model.nq) # for smooth
|
||||
self.param_tf_l = self.opti.parameter(4, 4)
|
||||
self.param_tf_r = self.opti.parameter(4, 4)
|
||||
self.translational_cost = casadi.sumsqr(
|
||||
self.translational_error(self.var_q, self.param_tf_l, self.param_tf_r)
|
||||
)
|
||||
self.rotation_cost = casadi.sumsqr(
|
||||
self.rotational_error(self.var_q, self.param_tf_l, self.param_tf_r)
|
||||
)
|
||||
self.regularization_cost = casadi.sumsqr(self.var_q)
|
||||
self.smooth_cost = casadi.sumsqr(self.var_q - self.var_q_last)
|
||||
|
||||
# Setting optimization constraints and goals
|
||||
self.opti.subject_to(
|
||||
self.opti.bounded(
|
||||
self.reduced_robot.model.lowerPositionLimit,
|
||||
self.var_q,
|
||||
self.reduced_robot.model.upperPositionLimit,
|
||||
)
|
||||
)
|
||||
self.opti.minimize(
|
||||
50 * self.translational_cost
|
||||
+ self.rotation_cost
|
||||
+ 0.02 * self.regularization_cost
|
||||
+ 0.1 * self.smooth_cost
|
||||
)
|
||||
|
||||
opts = {
|
||||
"ipopt": {"print_level": 0, "max_iter": 50, "tol": 1e-6},
|
||||
"print_time": False, # print or not
|
||||
"calc_lam_p": False, # https://github.com/casadi/casadi/wiki/FAQ:-Why-am-I-getting-%22NaN-detected%22in-my-optimization%3F
|
||||
}
|
||||
self.opti.solver("ipopt", opts)
|
||||
|
||||
self.init_data = np.zeros(self.reduced_robot.model.nq)
|
||||
self.smooth_filter = WeightedMovingFilter(np.array([0.4, 0.3, 0.2, 0.1]), 14)
|
||||
|
||||
def solve_ik(self, left_wrist, right_wrist, current_lr_arm_motor_q=None, current_lr_arm_motor_dq=None):
|
||||
if current_lr_arm_motor_q is not None:
|
||||
self.init_data = current_lr_arm_motor_q
|
||||
self.opti.set_initial(self.var_q, self.init_data)
|
||||
|
||||
self.opti.set_value(self.param_tf_l, left_wrist)
|
||||
self.opti.set_value(self.param_tf_r, right_wrist)
|
||||
self.opti.set_value(self.var_q_last, self.init_data) # for smooth
|
||||
|
||||
try:
|
||||
self.opti.solve()
|
||||
|
||||
sol_q = self.opti.value(self.var_q)
|
||||
self.smooth_filter.add_data(sol_q)
|
||||
sol_q = self.smooth_filter.filtered_data
|
||||
|
||||
if current_lr_arm_motor_dq is not None:
|
||||
v = current_lr_arm_motor_dq * 0.0
|
||||
else:
|
||||
v = (sol_q - self.init_data) * 0.0
|
||||
|
||||
self.init_data = sol_q
|
||||
|
||||
sol_tauff = self._pin.rnea(
|
||||
self.reduced_robot.model,
|
||||
self.reduced_robot.data,
|
||||
sol_q,
|
||||
v,
|
||||
np.zeros(self.reduced_robot.model.nv),
|
||||
)
|
||||
|
||||
return sol_q, sol_tauff
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ERROR in convergence, plotting debug info.{e}")
|
||||
|
||||
sol_q = self.opti.debug.value(self.var_q)
|
||||
self.smooth_filter.add_data(sol_q)
|
||||
sol_q = self.smooth_filter.filtered_data
|
||||
|
||||
if current_lr_arm_motor_dq is not None:
|
||||
v = current_lr_arm_motor_dq * 0.0
|
||||
else:
|
||||
v = (sol_q - self.init_data) * 0.0
|
||||
|
||||
self.init_data = sol_q
|
||||
|
||||
logger.error(
|
||||
f"sol_q:{sol_q} \nmotorstate: \n{current_lr_arm_motor_q} \nleft_pose: \n{left_wrist} \nright_pose: \n{right_wrist}"
|
||||
)
|
||||
|
||||
return current_lr_arm_motor_q, np.zeros(self.reduced_robot.model.nv)
|
||||
|
||||
def solve_tau(self, current_lr_arm_motor_q=None, current_lr_arm_motor_dq=None):
|
||||
try:
|
||||
q_g1 = np.array(current_lr_arm_motor_q, dtype=float)
|
||||
if q_g1.shape[0] != len(self._arm_joint_names_g1):
|
||||
raise ValueError(f"Expected {len(self._arm_joint_names_g1)} arm joints, got {q_g1.shape[0]}")
|
||||
q_pin = q_g1[self._arm_reorder_g1_to_pin]
|
||||
sol_tauff = self._pin.rnea(
|
||||
self.reduced_robot.model,
|
||||
self.reduced_robot.data,
|
||||
q_pin,
|
||||
np.zeros(self.reduced_robot.model.nv),
|
||||
np.zeros(self.reduced_robot.model.nv),
|
||||
)
|
||||
return sol_tauff[self._arm_reorder_pin_to_g1]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ERROR in convergence, plotting debug info.{e}")
|
||||
return np.zeros(self.reduced_robot.model.nv)
|
||||
@@ -27,7 +27,8 @@ import numpy as np
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
|
||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointArmIndex, G1_29_JointIndex
|
||||
from lerobot.robots.unitree_g1.robot_kinematic_processor import G1_29_ArmIK
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_unitree_g1 import UnitreeG1Config
|
||||
@@ -127,6 +128,8 @@ class UnitreeG1(Robot):
|
||||
self.subscribe_thread = None
|
||||
self.remote_controller = self.RemoteController()
|
||||
|
||||
self.arm_ik = G1_29_ArmIK()
|
||||
|
||||
def _subscribe_motor_state(self): # polls robot state @ 250Hz
|
||||
while not self._shutdown_event.is_set():
|
||||
start_time = time.time()
|
||||
@@ -361,6 +364,20 @@ class UnitreeG1(Robot):
|
||||
self.msg.motor_cmd[motor.value].kd = self.kd[motor.value]
|
||||
self.msg.motor_cmd[motor.value].tau = 0
|
||||
|
||||
if self.config.gravity_compensation:
|
||||
# Build action_np from motor commands (arm joints are indices 15-28, local indices 0-13)
|
||||
action_np = np.zeros(14)
|
||||
arm_start_idx = G1_29_JointArmIndex.kLeftShoulderPitch.value # 15
|
||||
for joint in G1_29_JointArmIndex:
|
||||
local_idx = joint.value - arm_start_idx
|
||||
action_np[local_idx] = self.msg.motor_cmd[joint.value].q
|
||||
tau = self.arm_ik.solve_tau(action_np)
|
||||
|
||||
# Apply tau back to motor commands
|
||||
for joint in G1_29_JointArmIndex:
|
||||
local_idx = joint.value - arm_start_idx
|
||||
self.msg.motor_cmd[joint.value].tau = tau[local_idx]
|
||||
|
||||
self.msg.crc = self.crc.Crc(self.msg)
|
||||
self.lowcmd_publisher.Write(self.msg)
|
||||
return action
|
||||
|
||||
@@ -60,6 +60,14 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
|
||||
from .reachy2 import Reachy2Robot
|
||||
|
||||
return Reachy2Robot(config)
|
||||
elif config.type == "openarm_follower":
|
||||
from .openarm_follower import OpenArmFollower
|
||||
|
||||
return OpenArmFollower(config)
|
||||
elif config.type == "bi_openarm_follower":
|
||||
from .bi_openarm_follower import BiOpenArmFollower
|
||||
|
||||
return BiOpenArmFollower(config)
|
||||
elif config.type == "mock_robot":
|
||||
from tests.mocks.mock_robot import MockRobot
|
||||
|
||||
|
||||
@@ -36,23 +36,28 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_openarm_follower,
|
||||
bi_so_follower,
|
||||
hope_jr,
|
||||
koch_follower,
|
||||
lekiwi,
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
openarm_follower,
|
||||
so_follower,
|
||||
)
|
||||
from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
TeleoperatorConfig,
|
||||
bi_openarm_leader,
|
||||
bi_so_leader,
|
||||
homunculus,
|
||||
koch_leader,
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
so_leader,
|
||||
unitree_g1,
|
||||
)
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.utils import init_logging
|
||||
@@ -81,8 +86,11 @@ def calibrate(cfg: CalibrateConfig):
|
||||
device = make_teleoperator_from_config(cfg.device)
|
||||
|
||||
device.connect(calibrate=False)
|
||||
device.calibrate()
|
||||
device.disconnect()
|
||||
|
||||
try:
|
||||
device.calibrate()
|
||||
finally:
|
||||
device.disconnect()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
Edit LeRobot datasets using various transformation tools.
|
||||
|
||||
This script allows you to delete episodes, split datasets, merge datasets,
|
||||
remove features, and convert image datasets to video format.
|
||||
remove features, modify tasks, and convert image datasets to video format.
|
||||
When new_repo_id is specified, creates a new dataset.
|
||||
|
||||
Usage Examples:
|
||||
@@ -66,23 +66,42 @@ Remove camera feature:
|
||||
--operation.type remove_feature \
|
||||
--operation.feature_names "['observation.images.top']"
|
||||
|
||||
Convert image dataset to video format (saves locally):
|
||||
Modify tasks - set a single task for all episodes (WARNING: modifies in-place):
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type modify_tasks \
|
||||
--operation.new_task "Pick up the cube and place it"
|
||||
|
||||
Modify tasks - set different tasks for specific episodes (WARNING: modifies in-place):
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type modify_tasks \
|
||||
--operation.episode_tasks '{"0": "Task A", "1": "Task B", "2": "Task A"}'
|
||||
|
||||
Modify tasks - set default task with overrides for specific episodes (WARNING: modifies in-place):
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type modify_tasks \
|
||||
--operation.new_task "Default task" \
|
||||
--operation.episode_tasks '{"5": "Special task for episode 5"}'
|
||||
|
||||
Convert image dataset to video format and save locally:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir /path/to/output/pusht_video
|
||||
|
||||
Convert image dataset and save with new repo_id:
|
||||
Convert image dataset to video format and save with new repo_id:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_to_video
|
||||
--operation.type convert_image_to_video
|
||||
|
||||
Convert and push to hub:
|
||||
Convert image dataset to video format and push to hub:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_to_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--push_to_hub true
|
||||
|
||||
Using JSON config file:
|
||||
@@ -92,24 +111,20 @@ Using JSON config file:
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.datasets.dataset_tools import (
|
||||
convert_image_to_video_dataset,
|
||||
delete_episodes,
|
||||
merge_datasets,
|
||||
modify_tasks,
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import write_stats, write_tasks
|
||||
from lerobot.datasets.video_utils import encode_video_frames, get_video_info
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
@@ -138,8 +153,15 @@ class RemoveFeatureConfig:
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConvertToVideoConfig:
|
||||
type: str = "convert_to_video"
|
||||
class ModifyTasksConfig:
|
||||
type: str = "modify_tasks"
|
||||
new_task: str | None = None
|
||||
episode_tasks: dict[str, str] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConvertImageToVideoConfig:
|
||||
type: str = "convert_image_to_video"
|
||||
output_dir: str | None = None
|
||||
vcodec: str = "libsvtav1"
|
||||
pix_fmt: str = "yuv420p"
|
||||
@@ -148,12 +170,21 @@ class ConvertToVideoConfig:
|
||||
fast_decode: int = 0
|
||||
episode_indices: list[int] | None = None
|
||||
num_workers: int = 4
|
||||
max_episodes_per_batch: int | None = None
|
||||
max_frames_per_batch: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EditDatasetConfig:
|
||||
repo_id: str
|
||||
operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertToVideoConfig
|
||||
operation: (
|
||||
DeleteEpisodesConfig
|
||||
| SplitConfig
|
||||
| MergeConfig
|
||||
| RemoveFeatureConfig
|
||||
| ModifyTasksConfig
|
||||
| ConvertImageToVideoConfig
|
||||
)
|
||||
root: str | None = None
|
||||
new_repo_id: str | None = None
|
||||
push_to_hub: bool = False
|
||||
@@ -297,362 +328,49 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
|
||||
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
|
||||
|
||||
|
||||
def save_episode_images_for_video(
|
||||
dataset: LeRobotDataset,
|
||||
imgs_dir: Path,
|
||||
img_key: str,
|
||||
episode_index: int,
|
||||
num_workers: int = 4,
|
||||
) -> None:
|
||||
"""Save images from a specific episode and camera to disk for video encoding.
|
||||
def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
|
||||
if not isinstance(cfg.operation, ModifyTasksConfig):
|
||||
raise ValueError("Operation config must be ModifyTasksConfig")
|
||||
|
||||
Args:
|
||||
dataset: The LeRobot dataset to extract images from
|
||||
imgs_dir: Directory to save images to
|
||||
img_key: The image key (camera) to extract
|
||||
episode_index: Index of the episode to save
|
||||
num_workers: Number of threads for parallel image saving
|
||||
"""
|
||||
# Create directory
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
new_task = cfg.operation.new_task
|
||||
episode_tasks_raw = cfg.operation.episode_tasks
|
||||
|
||||
# Get dataset without torch format for PIL image access
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
if new_task is None and episode_tasks_raw is None:
|
||||
raise ValueError("Must specify at least one of new_task or episode_tasks for modify_tasks operation")
|
||||
|
||||
# Select only this camera's images
|
||||
imgs_dataset = hf_dataset.select_columns(img_key)
|
||||
# Warn about in-place modification behavior
|
||||
if cfg.new_repo_id is not None:
|
||||
logging.warning("modify_tasks modifies datasets in-place. The --new_repo_id parameter is ignored.")
|
||||
|
||||
# Get episode start and end indices
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
logging.warning(f"Modifying dataset in-place at {dataset.root}. Original data will be overwritten.")
|
||||
|
||||
# Get all items for this episode
|
||||
episode_dataset = imgs_dataset.select(range(from_idx, to_idx))
|
||||
# Convert episode_tasks keys from string to int if needed (CLI passes strings)
|
||||
episode_tasks: dict[int, str] | None = None
|
||||
if episode_tasks_raw is not None:
|
||||
episode_tasks = {int(k): v for k, v in episode_tasks_raw.items()}
|
||||
|
||||
# Define function to save a single image
|
||||
def save_single_image(i_item_tuple):
|
||||
i, item = i_item_tuple
|
||||
img = item[img_key]
|
||||
# Use frame-XXXXXX.png format to match encode_video_frames expectations
|
||||
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
|
||||
return i
|
||||
logging.info(f"Modifying tasks in {cfg.repo_id}")
|
||||
if new_task:
|
||||
logging.info(f" Default task: '{new_task}'")
|
||||
if episode_tasks:
|
||||
logging.info(f" Episode-specific tasks: {episode_tasks}")
|
||||
|
||||
# Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png)
|
||||
items = list(enumerate(episode_dataset))
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = [executor.submit(save_single_image, item) for item in items]
|
||||
for future in as_completed(futures):
|
||||
future.result() # This will raise any exceptions that occurred
|
||||
|
||||
|
||||
def encode_episode_videos(
|
||||
dataset: LeRobotDataset,
|
||||
new_meta: LeRobotDatasetMetadata,
|
||||
episode_index: int,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
g: int,
|
||||
crf: int,
|
||||
fast_decode: int,
|
||||
temp_dir: Path,
|
||||
num_image_workers: int = 4,
|
||||
) -> dict[str, dict]:
|
||||
"""Encode videos for a single episode and return video metadata.
|
||||
|
||||
Args:
|
||||
dataset: Source dataset with images
|
||||
new_meta: Metadata object for the new video dataset
|
||||
episode_index: Episode index to process
|
||||
vcodec: Video codec
|
||||
pix_fmt: Pixel format
|
||||
g: Group of pictures size
|
||||
crf: Constant rate factor
|
||||
fast_decode: Fast decode tuning
|
||||
temp_dir: Temporary directory for images
|
||||
num_image_workers: Number of workers for saving images
|
||||
|
||||
Returns:
|
||||
Dictionary mapping video keys to their metadata (chunk_index, file_index, timestamps)
|
||||
"""
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
|
||||
|
||||
video_metadata = {}
|
||||
fps = int(dataset.fps) # Convert to int for PyAV compatibility
|
||||
episode_length = dataset.meta.episodes["length"][episode_index]
|
||||
episode_duration = episode_length / dataset.fps # Use original fps for duration calculation
|
||||
|
||||
for img_key in img_keys:
|
||||
# Save images temporarily
|
||||
imgs_dir = temp_dir / f"episode_{episode_index:06d}" / img_key
|
||||
save_episode_images_for_video(dataset, imgs_dir, img_key, episode_index, num_image_workers)
|
||||
|
||||
# Determine chunk and file indices
|
||||
# For simplicity, we'll put each episode in its own file
|
||||
chunk_idx = episode_index // new_meta.chunks_size
|
||||
file_idx = episode_index % new_meta.chunks_size
|
||||
|
||||
# Create video path in the new dataset structure
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Encode video
|
||||
encode_video_frames(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
# Clean up temporary images
|
||||
shutil.rmtree(imgs_dir)
|
||||
|
||||
# Store video metadata
|
||||
video_metadata[img_key] = {
|
||||
f"videos/{img_key}/chunk_index": chunk_idx,
|
||||
f"videos/{img_key}/file_index": file_idx,
|
||||
f"videos/{img_key}/from_timestamp": 0.0,
|
||||
f"videos/{img_key}/to_timestamp": episode_duration,
|
||||
}
|
||||
|
||||
return video_metadata
|
||||
|
||||
|
||||
def convert_dataset_to_videos(
|
||||
dataset: LeRobotDataset,
|
||||
output_dir: Path,
|
||||
repo_id: str | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int = 2,
|
||||
crf: int = 30,
|
||||
fast_decode: int = 0,
|
||||
episode_indices: list[int] | None = None,
|
||||
num_workers: int = 4,
|
||||
) -> LeRobotDataset:
|
||||
"""Convert image-based dataset to video-based dataset.
|
||||
|
||||
Creates a new LeRobotDataset with videos instead of images, following the proper
|
||||
LeRobot dataset structure with videos stored in chunked MP4 files.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobot dataset with images
|
||||
output_dir: Directory to save the new video dataset
|
||||
repo_id: Repository ID for the new dataset (default: original_id + "_video")
|
||||
vcodec: Video codec (default: libsvtav1)
|
||||
pix_fmt: Pixel format (default: yuv420p)
|
||||
g: Group of pictures size (default: 2)
|
||||
crf: Constant rate factor (default: 30)
|
||||
fast_decode: Fast decode tuning (default: 0)
|
||||
episode_indices: List of episode indices to convert (None = all episodes)
|
||||
num_workers: Number of threads for parallel processing (default: 4)
|
||||
|
||||
Returns:
|
||||
New LeRobotDataset with videos
|
||||
"""
|
||||
# Check that it's an image dataset
|
||||
if len(dataset.meta.video_keys) > 0:
|
||||
raise ValueError(
|
||||
f"This operation is for image datasets only. Video dataset provided: {dataset.repo_id}"
|
||||
)
|
||||
|
||||
# Get all image keys
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
|
||||
|
||||
if len(img_keys) == 0:
|
||||
raise ValueError(f"No image keys found in dataset {dataset.repo_id}")
|
||||
|
||||
# Determine which episodes to process
|
||||
if episode_indices is None:
|
||||
episode_indices = list(range(dataset.meta.total_episodes))
|
||||
|
||||
if repo_id is None:
|
||||
repo_id = f"{dataset.repo_id}_video"
|
||||
|
||||
logging.info(
|
||||
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
|
||||
)
|
||||
logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}")
|
||||
|
||||
# Create new features dict, converting image features to video features
|
||||
new_features = {}
|
||||
for key, value in dataset.meta.features.items():
|
||||
if key not in img_keys:
|
||||
new_features[key] = value
|
||||
else:
|
||||
# Convert image key to video format
|
||||
new_features[key] = value.copy()
|
||||
new_features[key]["dtype"] = "video" # Change dtype from "image" to "video"
|
||||
# Video info will be updated after episodes are encoded
|
||||
|
||||
# Create new metadata for video dataset
|
||||
new_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
fps=dataset.meta.fps,
|
||||
features=new_features,
|
||||
robot_type=dataset.meta.robot_type,
|
||||
root=output_dir,
|
||||
use_videos=True,
|
||||
chunks_size=dataset.meta.chunks_size,
|
||||
data_files_size_in_mb=dataset.meta.data_files_size_in_mb,
|
||||
video_files_size_in_mb=dataset.meta.video_files_size_in_mb,
|
||||
modified_dataset = modify_tasks(
|
||||
dataset,
|
||||
new_task=new_task,
|
||||
episode_tasks=episode_tasks,
|
||||
)
|
||||
|
||||
# Create temporary directory for image extraction
|
||||
temp_dir = output_dir / "temp_images"
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
logging.info(f"Dataset modified at {dataset.root}")
|
||||
logging.info(f"Tasks: {list(modified_dataset.meta.tasks.index)}")
|
||||
|
||||
# Process each episode
|
||||
all_episode_metadata = []
|
||||
|
||||
try:
|
||||
for ep_idx in tqdm(episode_indices, desc="Converting episodes to videos"):
|
||||
# Get episode metadata from source
|
||||
src_episode = dataset.meta.episodes[ep_idx]
|
||||
|
||||
# Encode videos for this episode
|
||||
video_metadata = encode_episode_videos(
|
||||
dataset=dataset,
|
||||
new_meta=new_meta,
|
||||
episode_index=ep_idx,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
temp_dir=temp_dir,
|
||||
num_image_workers=num_workers,
|
||||
)
|
||||
|
||||
# Build episode metadata
|
||||
episode_meta = {
|
||||
"episode_index": ep_idx,
|
||||
"length": src_episode["length"],
|
||||
"dataset_from_index": ep_idx * src_episode["length"],
|
||||
"dataset_to_index": (ep_idx + 1) * src_episode["length"],
|
||||
}
|
||||
|
||||
# Add video metadata
|
||||
for img_key in img_keys:
|
||||
episode_meta.update(video_metadata[img_key])
|
||||
|
||||
# Add data chunk/file info (using same structure as source)
|
||||
if "data/chunk_index" in src_episode:
|
||||
episode_meta["data/chunk_index"] = src_episode["data/chunk_index"]
|
||||
episode_meta["data/file_index"] = src_episode["data/file_index"]
|
||||
|
||||
all_episode_metadata.append(episode_meta)
|
||||
|
||||
# Copy and transform data files (removing image columns)
|
||||
_copy_data_without_images(dataset, new_meta, episode_indices, img_keys)
|
||||
|
||||
# Save episode metadata
|
||||
episodes_df = pd.DataFrame(all_episode_metadata)
|
||||
episodes_path = new_meta.root / "meta" / "episodes" / "chunk-000" / "file-000.parquet"
|
||||
episodes_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
episodes_df.to_parquet(episodes_path, index=False)
|
||||
|
||||
# Update metadata info
|
||||
new_meta.info["total_episodes"] = len(episode_indices)
|
||||
new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata)
|
||||
new_meta.info["total_tasks"] = dataset.meta.total_tasks
|
||||
new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"}
|
||||
|
||||
# Update video info for all image keys (now videos)
|
||||
# We need to manually set video info since update_video_info() checks video_keys first
|
||||
for img_key in img_keys:
|
||||
if not new_meta.features[img_key].get("info", None):
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=0, file_index=0
|
||||
)
|
||||
new_meta.info["features"][img_key]["info"] = get_video_info(video_path)
|
||||
|
||||
from lerobot.datasets.utils import write_info
|
||||
|
||||
write_info(new_meta.info, new_meta.root)
|
||||
|
||||
# Copy stats and tasks
|
||||
if dataset.meta.stats is not None:
|
||||
# Remove image stats
|
||||
new_stats = {k: v for k, v in dataset.meta.stats.items() if k not in img_keys}
|
||||
write_stats(new_stats, new_meta.root)
|
||||
|
||||
if dataset.meta.tasks is not None:
|
||||
write_tasks(dataset.meta.tasks, new_meta.root)
|
||||
|
||||
finally:
|
||||
# Clean up temporary directory
|
||||
if temp_dir.exists():
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
logging.info(f"✓ Completed converting {dataset.repo_id} to video format")
|
||||
logging.info(f"New dataset saved to: {output_dir}")
|
||||
|
||||
# Return new dataset
|
||||
return LeRobotDataset(repo_id=repo_id, root=output_dir)
|
||||
if cfg.push_to_hub:
|
||||
logging.info(f"Pushing to hub as {cfg.repo_id}")
|
||||
modified_dataset.push_to_hub()
|
||||
|
||||
|
||||
def _copy_data_without_images(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
episode_indices: list[int],
|
||||
img_keys: list[str],
|
||||
) -> None:
|
||||
"""Copy data files without image columns.
|
||||
|
||||
Args:
|
||||
src_dataset: Source dataset
|
||||
dst_meta: Destination metadata
|
||||
episode_indices: Episodes to include
|
||||
img_keys: Image keys to remove
|
||||
"""
|
||||
from lerobot.datasets.utils import DATA_DIR
|
||||
|
||||
data_dir = src_dataset.root / DATA_DIR
|
||||
parquet_files = sorted(data_dir.glob("*/*.parquet"))
|
||||
|
||||
if not parquet_files:
|
||||
raise ValueError(f"No parquet files found in {data_dir}")
|
||||
|
||||
episode_set = set(episode_indices)
|
||||
|
||||
for src_path in tqdm(parquet_files, desc="Processing data files"):
|
||||
df = pd.read_parquet(src_path).reset_index(drop=True)
|
||||
|
||||
# Filter to only include selected episodes
|
||||
df = df[df["episode_index"].isin(episode_set)].copy()
|
||||
|
||||
if len(df) == 0:
|
||||
continue
|
||||
|
||||
# Remove image columns
|
||||
columns_to_drop = [col for col in img_keys if col in df.columns]
|
||||
if columns_to_drop:
|
||||
df = df.drop(columns=columns_to_drop)
|
||||
|
||||
# Get chunk and file indices from path
|
||||
relative_path = src_path.relative_to(src_dataset.root)
|
||||
chunk_dir = relative_path.parts[1]
|
||||
file_name = relative_path.parts[2]
|
||||
chunk_idx = int(chunk_dir.split("-")[1])
|
||||
file_idx = int(file_name.split("-")[1].split(".")[0])
|
||||
|
||||
# Write to destination without pandas index
|
||||
dst_path = dst_meta.root / f"data/chunk-{chunk_idx:03d}/file-{file_idx:03d}.parquet"
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_parquet(dst_path, index=False)
|
||||
|
||||
|
||||
def handle_convert_to_video(cfg: EditDatasetConfig) -> None:
|
||||
def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
|
||||
# Note: Parser may create any config type with the right fields, so we access fields directly
|
||||
# instead of checking isinstance()
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
@@ -664,8 +382,12 @@ def handle_convert_to_video(cfg: EditDatasetConfig) -> None:
|
||||
if cfg.new_repo_id:
|
||||
# Use new_repo_id for both local storage and hub push
|
||||
output_repo_id = cfg.new_repo_id
|
||||
output_dir = Path(cfg.root) / cfg.new_repo_id if cfg.root else HF_LEROBOT_HOME / cfg.new_repo_id
|
||||
logging.info(f"Saving to new dataset: {cfg.new_repo_id}")
|
||||
# Place new dataset as a sibling to the original dataset
|
||||
# Get the parent of the actual dataset root (not cfg.root which might be the lerobot cache dir)
|
||||
# Extract just the dataset name (after last slash) for the local directory
|
||||
local_dir_name = cfg.new_repo_id.split("/")[-1]
|
||||
output_dir = dataset.root.parent / local_dir_name
|
||||
logging.info(f"Saving to new dataset: {cfg.new_repo_id} at {output_dir}")
|
||||
elif output_dir_config:
|
||||
# Use custom output directory for local-only storage
|
||||
output_dir = Path(output_dir_config)
|
||||
@@ -675,12 +397,15 @@ def handle_convert_to_video(cfg: EditDatasetConfig) -> None:
|
||||
else:
|
||||
# Auto-generate name: append "_video" to original repo_id
|
||||
output_repo_id = f"{cfg.repo_id}_video"
|
||||
output_dir = Path(cfg.root) / output_repo_id if cfg.root else HF_LEROBOT_HOME / output_repo_id
|
||||
# Place new dataset as a sibling to the original dataset
|
||||
# Extract just the dataset name (after last slash) for the local directory
|
||||
local_dir_name = output_repo_id.split("/")[-1]
|
||||
output_dir = dataset.root.parent / local_dir_name
|
||||
logging.info(f"Saving to auto-generated location: {output_dir}")
|
||||
|
||||
logging.info(f"Converting dataset {cfg.repo_id} to video format")
|
||||
|
||||
new_dataset = convert_dataset_to_videos(
|
||||
new_dataset = convert_image_to_video_dataset(
|
||||
dataset=dataset,
|
||||
output_dir=output_dir,
|
||||
repo_id=output_repo_id,
|
||||
@@ -691,6 +416,8 @@ def handle_convert_to_video(cfg: EditDatasetConfig) -> None:
|
||||
fast_decode=getattr(cfg.operation, "fast_decode", 0),
|
||||
episode_indices=getattr(cfg.operation, "episode_indices", None),
|
||||
num_workers=getattr(cfg.operation, "num_workers", 4),
|
||||
max_episodes_per_batch=getattr(cfg.operation, "max_episodes_per_batch", None),
|
||||
max_frames_per_batch=getattr(cfg.operation, "max_frames_per_batch", None),
|
||||
)
|
||||
|
||||
logging.info("Video dataset created successfully!")
|
||||
@@ -718,12 +445,14 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
handle_merge(cfg)
|
||||
elif operation_type == "remove_feature":
|
||||
handle_remove_feature(cfg)
|
||||
elif operation_type == "convert_to_video":
|
||||
handle_convert_to_video(cfg)
|
||||
elif operation_type == "modify_tasks":
|
||||
handle_modify_tasks(cfg)
|
||||
elif operation_type == "convert_image_to_video":
|
||||
handle_convert_image_to_video(cfg)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown operation type: {operation_type}\n"
|
||||
f"Available operations: delete_episodes, split, merge, remove_feature, convert_to_video"
|
||||
f"Available operations: delete_episodes, split, merge, remove_feature, modify_tasks, convert_image_to_video"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -44,19 +44,23 @@ import numpy as np
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
RobotConfig,
|
||||
bi_openarm_follower,
|
||||
bi_so_follower,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
openarm_follower,
|
||||
so_follower,
|
||||
)
|
||||
from lerobot.teleoperators import ( # noqa: F401
|
||||
TeleoperatorConfig,
|
||||
bi_openarm_leader,
|
||||
bi_so_leader,
|
||||
gamepad,
|
||||
koch_leader,
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
so_leader,
|
||||
)
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
@@ -98,26 +98,31 @@ from lerobot.processor.rename_processor import rename_stats
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_openarm_follower,
|
||||
bi_so_follower,
|
||||
earthrover_mini_plus,
|
||||
hope_jr,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
openarm_follower,
|
||||
reachy2,
|
||||
so_follower,
|
||||
unitree_g1,
|
||||
unitree_g1 as unitree_g1_robot,
|
||||
)
|
||||
from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
TeleoperatorConfig,
|
||||
bi_openarm_leader,
|
||||
bi_so_leader,
|
||||
homunculus,
|
||||
koch_leader,
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
reachy2_teleoperator,
|
||||
so_leader,
|
||||
unitree_g1,
|
||||
)
|
||||
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
|
||||
@@ -53,12 +53,14 @@ from lerobot.processor import (
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_openarm_follower,
|
||||
bi_so_follower,
|
||||
earthrover_mini_plus,
|
||||
hope_jr,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
openarm_follower,
|
||||
reachy2,
|
||||
so_follower,
|
||||
unitree_g1,
|
||||
@@ -108,25 +110,26 @@ def replay(cfg: ReplayConfig):
|
||||
|
||||
robot.connect()
|
||||
|
||||
log_say("Replaying episode", cfg.play_sounds, blocking=True)
|
||||
for idx in range(len(episode_frames)):
|
||||
start_episode_t = time.perf_counter()
|
||||
try:
|
||||
log_say("Replaying episode", cfg.play_sounds, blocking=True)
|
||||
for idx in range(len(episode_frames)):
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action_array = actions[idx][ACTION]
|
||||
action = {}
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"]):
|
||||
action[name] = action_array[i]
|
||||
action_array = actions[idx][ACTION]
|
||||
action = {}
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"]):
|
||||
action[name] = action_array[i]
|
||||
|
||||
robot_obs = robot.get_observation()
|
||||
robot_obs = robot.get_observation()
|
||||
|
||||
processed_action = robot_action_processor((action, robot_obs))
|
||||
processed_action = robot_action_processor((action, robot_obs))
|
||||
|
||||
_ = robot.send_action(processed_action)
|
||||
_ = robot.send_action(processed_action)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
precise_sleep(max(1 / dataset.fps - dt_s, 0.0))
|
||||
|
||||
robot.disconnect()
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
precise_sleep(max(1 / dataset.fps - dt_s, 0.0))
|
||||
finally:
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -0,0 +1,360 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Setup and debug CAN interfaces for Damiao motors (e.g., OpenArms).
|
||||
|
||||
Examples:
|
||||
|
||||
Setup CAN interfaces with CAN FD:
|
||||
```shell
|
||||
lerobot-setup-can --mode=setup --interfaces=can0,can1,can2,can3
|
||||
```
|
||||
|
||||
Test motors on a single interface:
|
||||
```shell
|
||||
lerobot-setup-can --mode=test --interfaces=can0
|
||||
```
|
||||
|
||||
Test motors on all interfaces:
|
||||
```shell
|
||||
lerobot-setup-can --mode=test --interfaces=can0,can1,can2,can3
|
||||
```
|
||||
|
||||
Speed test:
|
||||
```shell
|
||||
lerobot-setup-can --mode=speed --interfaces=can0
|
||||
```
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.utils.import_utils import is_package_available
|
||||
|
||||
MOTOR_NAMES = {
|
||||
0x01: "joint_1",
|
||||
0x02: "joint_2",
|
||||
0x03: "joint_3",
|
||||
0x04: "joint_4",
|
||||
0x05: "joint_5",
|
||||
0x06: "joint_6",
|
||||
0x07: "joint_7",
|
||||
0x08: "gripper",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class CANSetupConfig:
|
||||
mode: str = "test"
|
||||
interfaces: str = "can0" # Comma-separated, e.g. "can0,can1,can2,can3"
|
||||
bitrate: int = 1000000
|
||||
data_bitrate: int = 5000000
|
||||
use_fd: bool = True
|
||||
motor_ids: list[int] = field(default_factory=lambda: list(range(0x01, 0x09)))
|
||||
timeout: float = 1.0
|
||||
speed_iterations: int = 100
|
||||
|
||||
def get_interfaces(self) -> list[str]:
|
||||
return [i.strip() for i in self.interfaces.split(",") if i.strip()]
|
||||
|
||||
|
||||
def check_interface_status(interface: str) -> tuple[bool, str, bool]:
|
||||
"""Check if CAN interface is UP and configured."""
|
||||
try:
|
||||
result = subprocess.run(["ip", "link", "show", interface], capture_output=True, text=True) # nosec B607
|
||||
if result.returncode != 0:
|
||||
return False, "Interface not found", False
|
||||
|
||||
output = result.stdout
|
||||
is_up = "UP" in output
|
||||
is_fd = "fd on" in output.lower() or "canfd" in output.lower()
|
||||
status = "UP" if is_up else "DOWN"
|
||||
if is_fd:
|
||||
status += " (CAN FD)"
|
||||
|
||||
return is_up, status, is_fd
|
||||
except FileNotFoundError:
|
||||
return False, "ip command not found", False
|
||||
|
||||
|
||||
def setup_interface(interface: str, bitrate: int, data_bitrate: int, use_fd: bool) -> bool:
|
||||
"""Configure a CAN interface."""
|
||||
try:
|
||||
subprocess.run(["sudo", "ip", "link", "set", interface, "down"], check=False, capture_output=True) # nosec B607
|
||||
|
||||
cmd = ["sudo", "ip", "link", "set", interface, "type", "can", "bitrate", str(bitrate)]
|
||||
if use_fd:
|
||||
cmd.extend(["dbitrate", str(data_bitrate), "fd", "on"])
|
||||
|
||||
result = subprocess.run(cmd, capture_output=True, text=True) # nosec B607
|
||||
if result.returncode != 0:
|
||||
print(f" ✗ Failed to configure: {result.stderr}")
|
||||
return False
|
||||
|
||||
result = subprocess.run( # nosec B607
|
||||
["sudo", "ip", "link", "set", interface, "up"], capture_output=True, text=True
|
||||
)
|
||||
if result.returncode != 0:
|
||||
print(f" ✗ Failed to bring up: {result.stderr}")
|
||||
return False
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f" ✗ Error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_motor(bus, motor_id: int, timeout: float, use_fd: bool):
|
||||
"""Test a single motor and return responses."""
|
||||
import can
|
||||
|
||||
enable_msg = can.Message(
|
||||
arbitration_id=motor_id,
|
||||
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC],
|
||||
is_extended_id=False,
|
||||
is_fd=use_fd,
|
||||
)
|
||||
|
||||
try:
|
||||
bus.send(enable_msg)
|
||||
except Exception as e:
|
||||
return None, f"Send error: {e}"
|
||||
|
||||
responses = []
|
||||
start_time = time.time()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
msg = bus.recv(timeout=0.1)
|
||||
if msg:
|
||||
responses.append((msg.arbitration_id, msg.data.hex(), getattr(msg, "is_fd", False)))
|
||||
|
||||
disable_msg = can.Message(
|
||||
arbitration_id=motor_id,
|
||||
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFD],
|
||||
is_extended_id=False,
|
||||
is_fd=use_fd,
|
||||
)
|
||||
try:
|
||||
bus.send(disable_msg)
|
||||
except Exception:
|
||||
print(f"Error sending message to motor 0x{motor_id:02X}")
|
||||
|
||||
return responses, None
|
||||
|
||||
|
||||
def test_interface(cfg: CANSetupConfig, interface: str):
|
||||
"""Test all motors on a CAN interface."""
|
||||
import can
|
||||
|
||||
is_up, status, _ = check_interface_status(interface)
|
||||
print(f"\n{interface}: {status}")
|
||||
|
||||
if not is_up:
|
||||
print(f" ⚠ Interface is not UP. Run: lerobot-setup-can --mode=setup --interfaces {interface}")
|
||||
return {}
|
||||
|
||||
try:
|
||||
kwargs = {"channel": interface, "interface": "socketcan", "bitrate": cfg.bitrate}
|
||||
if cfg.use_fd:
|
||||
kwargs.update({"data_bitrate": cfg.data_bitrate, "fd": True})
|
||||
bus = can.interface.Bus(**kwargs)
|
||||
except Exception as e:
|
||||
print(f" ✗ Connection failed: {e}")
|
||||
return {}
|
||||
|
||||
results = {}
|
||||
try:
|
||||
while bus.recv(timeout=0.01):
|
||||
pass
|
||||
|
||||
for motor_id in cfg.motor_ids:
|
||||
motor_name = MOTOR_NAMES.get(motor_id, f"motor_0x{motor_id:02X}")
|
||||
responses, error = test_motor(bus, motor_id, cfg.timeout, cfg.use_fd)
|
||||
|
||||
if error:
|
||||
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ {error}")
|
||||
results[motor_id] = {"found": False, "error": error}
|
||||
elif responses:
|
||||
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✓ FOUND")
|
||||
for resp_id, data, is_fd in responses:
|
||||
fd_flag = " [FD]" if is_fd else ""
|
||||
print(f" → Response 0x{resp_id:02X}{fd_flag}: {data}")
|
||||
results[motor_id] = {"found": True, "responses": responses}
|
||||
else:
|
||||
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ No response")
|
||||
results[motor_id] = {"found": False}
|
||||
|
||||
time.sleep(0.05)
|
||||
finally:
|
||||
bus.shutdown()
|
||||
|
||||
found = sum(1 for r in results.values() if r.get("found"))
|
||||
print(f"\n Summary: {found}/{len(cfg.motor_ids)} motors found")
|
||||
return results
|
||||
|
||||
|
||||
def speed_test(cfg: CANSetupConfig, interface: str):
|
||||
"""Test communication speed with motors."""
|
||||
import can
|
||||
|
||||
is_up, status, _ = check_interface_status(interface)
|
||||
if not is_up:
|
||||
print(f"{interface}: {status} - skipping")
|
||||
return
|
||||
|
||||
print(f"\n{interface}: Running speed test ({cfg.speed_iterations} iterations)...")
|
||||
|
||||
try:
|
||||
kwargs = {"channel": interface, "interface": "socketcan", "bitrate": cfg.bitrate}
|
||||
if cfg.use_fd:
|
||||
kwargs.update({"data_bitrate": cfg.data_bitrate, "fd": True})
|
||||
bus = can.interface.Bus(**kwargs)
|
||||
except Exception as e:
|
||||
print(f" ✗ Connection failed: {e}")
|
||||
return
|
||||
|
||||
responding_motor = None
|
||||
for motor_id in cfg.motor_ids:
|
||||
responses, _ = test_motor(bus, motor_id, 0.5, cfg.use_fd)
|
||||
if responses:
|
||||
responding_motor = motor_id
|
||||
break
|
||||
|
||||
if not responding_motor:
|
||||
print(" ✗ No responding motors found")
|
||||
bus.shutdown()
|
||||
return
|
||||
|
||||
print(f" Testing with motor 0x{responding_motor:02X}...")
|
||||
latencies = []
|
||||
|
||||
for _ in range(cfg.speed_iterations):
|
||||
start = time.perf_counter()
|
||||
msg = can.Message(
|
||||
arbitration_id=responding_motor,
|
||||
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC],
|
||||
is_extended_id=False,
|
||||
is_fd=cfg.use_fd,
|
||||
)
|
||||
bus.send(msg)
|
||||
resp = bus.recv(timeout=0.1)
|
||||
if resp:
|
||||
latencies.append((time.perf_counter() - start) * 1000)
|
||||
|
||||
bus.shutdown()
|
||||
|
||||
if latencies:
|
||||
avg_latency = sum(latencies) / len(latencies)
|
||||
hz = 1000.0 / avg_latency if avg_latency > 0 else 0
|
||||
print(f" ✓ Success rate: {len(latencies)}/{cfg.speed_iterations}")
|
||||
print(f" ✓ Avg latency: {avg_latency:.2f} ms")
|
||||
print(f" ✓ Max frequency: {hz:.1f} Hz")
|
||||
else:
|
||||
print(" ✗ No successful responses")
|
||||
|
||||
|
||||
def run_setup(cfg: CANSetupConfig):
|
||||
"""Setup CAN interfaces."""
|
||||
print("=" * 50)
|
||||
print("CAN Interface Setup")
|
||||
print("=" * 50)
|
||||
print(f"Mode: {'CAN FD' if cfg.use_fd else 'CAN 2.0'}")
|
||||
print(f"Bitrate: {cfg.bitrate / 1_000_000:.1f} Mbps")
|
||||
if cfg.use_fd:
|
||||
print(f"Data bitrate: {cfg.data_bitrate / 1_000_000:.1f} Mbps")
|
||||
print()
|
||||
|
||||
interfaces = cfg.get_interfaces()
|
||||
for interface in interfaces:
|
||||
print(f"Configuring {interface}...")
|
||||
if setup_interface(interface, cfg.bitrate, cfg.data_bitrate, cfg.use_fd):
|
||||
is_up, status, _ = check_interface_status(interface)
|
||||
print(f" ✓ {interface}: {status}")
|
||||
else:
|
||||
print(f" ✗ {interface}: Failed")
|
||||
|
||||
print("\nSetup complete!")
|
||||
print("\nNext: Test motors with:")
|
||||
print(f" lerobot-setup-can --mode=test --interfaces {','.join(interfaces)}")
|
||||
|
||||
|
||||
def run_test(cfg: CANSetupConfig):
|
||||
"""Test motors on CAN interfaces."""
|
||||
print("=" * 50)
|
||||
print("CAN Motor Test")
|
||||
print("=" * 50)
|
||||
print(f"Testing motors 0x{min(cfg.motor_ids):02X}-0x{max(cfg.motor_ids):02X}")
|
||||
print(f"Mode: {'CAN FD' if cfg.use_fd else 'CAN 2.0'}")
|
||||
print()
|
||||
|
||||
interfaces = cfg.get_interfaces()
|
||||
all_results = {}
|
||||
for interface in interfaces:
|
||||
all_results[interface] = test_interface(cfg, interface)
|
||||
|
||||
total_found = sum(sum(1 for r in res.values() if r.get("found")) for res in all_results.values())
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("Summary")
|
||||
print("=" * 50)
|
||||
print(f"Total motors found: {total_found}")
|
||||
|
||||
if total_found == 0:
|
||||
print("\n⚠ No motors found! Check:")
|
||||
print(" 1. Motors are powered (24V)")
|
||||
print(" 2. CAN wiring (CANH, CANL, GND)")
|
||||
print(" 3. Motor timeout parameter > 0 (use Damiao tools)")
|
||||
print(" 4. 120Ω termination at both cable ends")
|
||||
print(f" 5. Interface configured: lerobot-setup-can --mode=setup --interfaces {interfaces[0]}")
|
||||
|
||||
|
||||
def run_speed(cfg: CANSetupConfig):
|
||||
"""Run speed tests on CAN interfaces."""
|
||||
print("=" * 50)
|
||||
print("CAN Speed Test")
|
||||
print("=" * 50)
|
||||
|
||||
for interface in cfg.get_interfaces():
|
||||
speed_test(cfg, interface)
|
||||
|
||||
|
||||
@draccus.wrap()
|
||||
def setup_can(cfg: CANSetupConfig):
|
||||
if not is_package_available("can"):
|
||||
print("Error: python-can not installed. Install with: pip install python-can")
|
||||
sys.exit(1)
|
||||
|
||||
if cfg.mode == "setup":
|
||||
run_setup(cfg)
|
||||
elif cfg.mode == "test":
|
||||
run_test(cfg)
|
||||
elif cfg.mode == "speed":
|
||||
run_speed(cfg)
|
||||
else:
|
||||
print(f"Unknown mode: {cfg.mode}")
|
||||
print("Available modes: setup, test, speed")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main():
|
||||
setup_can()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -70,18 +70,22 @@ from lerobot.processor import (
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_openarm_follower,
|
||||
bi_so_follower,
|
||||
earthrover_mini_plus,
|
||||
hope_jr,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
openarm_follower,
|
||||
reachy2,
|
||||
so_follower,
|
||||
unitree_g1 as unitree_g1_robot,
|
||||
)
|
||||
from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
TeleoperatorConfig,
|
||||
bi_openarm_leader,
|
||||
bi_so_leader,
|
||||
gamepad,
|
||||
homunculus,
|
||||
@@ -89,8 +93,10 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
koch_leader,
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
reachy2_teleoperator,
|
||||
so_leader,
|
||||
unitree_g1,
|
||||
)
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
@@ -148,92 +148,6 @@ def update_policy(
|
||||
return train_metrics, output_dict
|
||||
|
||||
|
||||
def get_default_peft_configuration(policy_type):
|
||||
"""Build a basic PEFT configuration for the given policy type assuming that we train a policy from a checkpoint."""
|
||||
|
||||
common_projections = "state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
|
||||
|
||||
if policy_type == "smolvla":
|
||||
return {
|
||||
"target_modules": rf"(model\.vlm_with_expert\.lm_expert\..*\.(q|v)_proj|model\.({common_projections}))",
|
||||
"modules_to_save": [],
|
||||
}
|
||||
elif policy_type in ("pi0", "pi05"):
|
||||
return {
|
||||
"target_modules": rf"(.*\.gemma_expert\..*\.self_attn.(q|v)_proj|model\.({common_projections}))",
|
||||
"modules_to_save": [],
|
||||
}
|
||||
|
||||
return {"modules_to_save": None}
|
||||
|
||||
|
||||
def wrap_policy_in_peft_model(cfg, policy):
|
||||
from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType, get_peft_model
|
||||
|
||||
# Disable all gradients because we'll only train the parameters selected by the PEFT method.
|
||||
# Layers that should receive gradients anyway need to be listed in `modules_to_save`.
|
||||
for p in policy.parameters():
|
||||
p.requires_grad_(False)
|
||||
|
||||
if not cfg.policy.pretrained_path:
|
||||
raise ValueError(
|
||||
"Training from scratch using PEFT. This is unlikely to yield good results. "
|
||||
"Supply a `policy.path` to fine-tune an existing model."
|
||||
)
|
||||
|
||||
if cfg.policy.type == "smolvla" and not cfg.policy.load_vlm_weights:
|
||||
logging.warning(
|
||||
"Training SmolVLA from scratch using PEFT. This is unlikely to yield good results. Set "
|
||||
"`load_vlm_weights=True` to fine-tune the existing policy."
|
||||
)
|
||||
|
||||
peft_config_policy = get_default_peft_configuration(cfg.policy.type)
|
||||
peft_config_cli = dataclasses.asdict(cfg.peft) if cfg.peft else {}
|
||||
peft_config_cli["modules_to_save"] = peft_config_cli["full_training_modules"] # compatibility with PEFT
|
||||
peft_method_type = PeftType[peft_config_cli["method_type"].upper()]
|
||||
peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type]
|
||||
|
||||
# Handle specific CLI overrides
|
||||
for key in ["target_modules", "modules_to_save", "r"]:
|
||||
if peft_config_cli[key] is not None:
|
||||
peft_config_policy[key] = peft_config_cli[key]
|
||||
|
||||
if "target_modules" not in peft_config_policy:
|
||||
raise ValueError(
|
||||
f"There is no default `target_modules` value for policy {cfg.policy.type}. Please pass it manually."
|
||||
)
|
||||
|
||||
# Init method depends on the used PEFT method, your specific PEFT method
|
||||
# might not be considered here, in that case an error is raised.
|
||||
if peft_config_cli["init_type"] is not None:
|
||||
if peft_method_type == "LORA":
|
||||
peft_config_policy["init_lora_weights"] = peft_config_cli["init_type"]
|
||||
elif peft_method_type == "MISS":
|
||||
peft_config_policy["init_weights"] = peft_config_cli["init_type"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Init type {peft_config_cli['init_type']} unknown for PEFT method {peft_method_type}."
|
||||
)
|
||||
|
||||
# PEFT uses this attribute to set adapter_config.base_name_or_path which we use for loading the
|
||||
# correct base model in `make_policy` since in a PEFT loading setting we only get the path to the
|
||||
# adapter, not the base model.
|
||||
if policy.config.pretrained_path:
|
||||
policy.name_or_path = str(policy.config.pretrained_path)
|
||||
|
||||
# Finally wrap the policy in a PEFT model
|
||||
policy = get_peft_model(
|
||||
policy,
|
||||
peft_config_cls(**peft_config_policy),
|
||||
)
|
||||
|
||||
# Make sure that the config is tagged as using PEFT so that the loading code can take the
|
||||
# appropriate steps to use the adapter weights and the PEFT config instead of the full model weights.
|
||||
policy.config.use_peft = True
|
||||
|
||||
return policy
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
"""
|
||||
@@ -263,8 +177,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
|
||||
# Force the device to be CPU when policy.device is set to CPU.
|
||||
# Note (maractin): cfg.policy may be None before validate() fully loads from pretrained_path
|
||||
force_cpu = cfg.policy is not None and cfg.policy.device == "cpu"
|
||||
force_cpu = cfg.policy.device == "cpu"
|
||||
accelerator = Accelerator(
|
||||
step_scheduler_with_optimizer=False,
|
||||
kwargs_handlers=[ddp_kwargs],
|
||||
@@ -312,9 +225,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
||||
# using the eval.py instead, with gym_dora environment and dora-rs.
|
||||
eval_env = None
|
||||
if cfg.eval_freq > 0 and cfg.env is not None:
|
||||
if is_main_process:
|
||||
logging.info("Creating env")
|
||||
if cfg.eval_freq > 0 and cfg.env is not None and is_main_process:
|
||||
logging.info("Creating env")
|
||||
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||
|
||||
if is_main_process:
|
||||
@@ -327,7 +239,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
|
||||
if cfg.peft is not None:
|
||||
logging.info("Using PEFT! Wrapping model.")
|
||||
policy = wrap_policy_in_peft_model(cfg, policy)
|
||||
# Convert CLI peft config to dict for overrides
|
||||
peft_cli_overrides = dataclasses.asdict(cfg.peft)
|
||||
policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides)
|
||||
|
||||
# Wait for all processes to finish policy creation before continuing
|
||||
accelerator.wait_for_everyone()
|
||||
@@ -423,13 +337,28 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
# create dataloader for offline training
|
||||
if hasattr(cfg.policy, "drop_n_last_frames"):
|
||||
# Filter out episodes - hardcoded list of bad episodes to discard
|
||||
episodes_to_discard = {
|
||||
133, 134, 502, 565, 568, 657, 910, 944, 1039, 1209, 1346, 1360, 1379,
|
||||
1605, 1690, 1790, 2105, 2106, 2122, 2118, 2156, 2575, 2764, 2876, 2925,
|
||||
3100, 3381, 3405, 3406, 68, 1214, 1456,
|
||||
}
|
||||
all_episodes = set(range(dataset.meta.total_episodes))
|
||||
episodes_to_use = dataset.episodes # May be None (all episodes) or a subset
|
||||
# If dataset.episodes is already filtered, start from that subset
|
||||
if episodes_to_use is not None:
|
||||
episodes_to_use = [ep for ep in episodes_to_use if ep not in episodes_to_discard]
|
||||
else:
|
||||
episodes_to_use = sorted(all_episodes - episodes_to_discard)
|
||||
|
||||
if hasattr(cfg.policy, "drop_n_last_frames") or episodes_to_use is not None:
|
||||
shuffle = False
|
||||
drop_n_last = getattr(cfg.policy, "drop_n_last_frames", 0)
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
dataset.meta.episodes["dataset_to_index"],
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||
episode_indices_to_use=episodes_to_use,
|
||||
drop_n_last_frames=drop_n_last,
|
||||
shuffle=True,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .bi_openarm_leader import BiOpenArmLeader
|
||||
from .config_bi_openarm_leader import BiOpenArmLeaderConfig
|
||||
|
||||
__all__ = ["BiOpenArmLeader", "BiOpenArmLeaderConfig"]
|
||||
@@ -0,0 +1,131 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfig
|
||||
|
||||
from ..openarm_leader import OpenArmLeader
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_bi_openarm_leader import BiOpenArmLeaderConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BiOpenArmLeader(Teleoperator):
|
||||
"""
|
||||
Bimanual OpenArm Leader Arms
|
||||
"""
|
||||
|
||||
config_class = BiOpenArmLeaderConfig
|
||||
name = "bi_openarm_leader"
|
||||
|
||||
def __init__(self, config: BiOpenArmLeaderConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
left_arm_config = OpenArmLeaderConfig(
|
||||
id=f"{config.id}_left" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.left_arm_config.port,
|
||||
can_interface=config.left_arm_config.can_interface,
|
||||
use_can_fd=config.left_arm_config.use_can_fd,
|
||||
can_bitrate=config.left_arm_config.can_bitrate,
|
||||
can_data_bitrate=config.left_arm_config.can_data_bitrate,
|
||||
motor_config=config.left_arm_config.motor_config,
|
||||
manual_control=config.left_arm_config.manual_control,
|
||||
position_kd=config.left_arm_config.position_kd,
|
||||
position_kp=config.left_arm_config.position_kp,
|
||||
)
|
||||
|
||||
right_arm_config = OpenArmLeaderConfig(
|
||||
id=f"{config.id}_right" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.right_arm_config.port,
|
||||
can_interface=config.right_arm_config.can_interface,
|
||||
use_can_fd=config.right_arm_config.use_can_fd,
|
||||
can_bitrate=config.right_arm_config.can_bitrate,
|
||||
can_data_bitrate=config.right_arm_config.can_data_bitrate,
|
||||
motor_config=config.right_arm_config.motor_config,
|
||||
manual_control=config.right_arm_config.manual_control,
|
||||
position_kd=config.right_arm_config.position_kd,
|
||||
position_kp=config.right_arm_config.position_kp,
|
||||
)
|
||||
|
||||
self.left_arm = OpenArmLeader(left_arm_config)
|
||||
self.right_arm = OpenArmLeader(right_arm_config)
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
left_arm_features = self.left_arm.action_features
|
||||
right_arm_features = self.right_arm.action_features
|
||||
|
||||
return {
|
||||
**{f"left_{k}": v for k, v in left_arm_features.items()},
|
||||
**{f"right_{k}": v for k, v in right_arm_features.items()},
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.left_arm.calibrate()
|
||||
self.right_arm.calibrate()
|
||||
|
||||
def configure(self) -> None:
|
||||
self.left_arm.configure()
|
||||
self.right_arm.configure()
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||
)
|
||||
|
||||
def get_action(self) -> RobotAction:
|
||||
action_dict = {}
|
||||
|
||||
# Add "left_" prefix
|
||||
left_action = self.left_arm.get_action()
|
||||
action_dict.update({f"left_{key}": value for key, value in left_action.items()})
|
||||
|
||||
# Add "right_" prefix
|
||||
right_action = self.right_arm.get_action()
|
||||
action_dict.update({f"right_{key}": value for key, value in right_action.items()})
|
||||
|
||||
return action_dict
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
# TODO: Implement force feedback
|
||||
raise NotImplementedError
|
||||
|
||||
def disconnect(self) -> None:
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
@@ -0,0 +1,30 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfigBase
|
||||
|
||||
from ..config import TeleoperatorConfig
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("bi_openarm_leader")
|
||||
@dataclass
|
||||
class BiOpenArmLeaderConfig(TeleoperatorConfig):
|
||||
"""Configuration class for Bi OpenArm Follower robots."""
|
||||
|
||||
left_arm_config: OpenArmLeaderConfigBase
|
||||
right_arm_config: OpenArmLeaderConfigBase
|
||||
@@ -18,7 +18,7 @@ import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.teleoperators.so_leader import SOLeaderTeleopConfig
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
|
||||
from ..so_leader import SOLeader
|
||||
from ..teleoperator import Teleoperator
|
||||
@@ -92,10 +92,8 @@ class BiSOLeader(Teleoperator):
|
||||
self.left_arm.setup_motors()
|
||||
self.right_arm.setup_motors()
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> dict[str, float]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
action_dict = {}
|
||||
|
||||
# Add "left_" prefix
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user