Compare commits

...

12 Commits

Author SHA1 Message Date
Jade Choghari 435f12f6e4 pre-commit 2026-02-10 13:07:40 +01:00
Stepan Feduniak a1193df2d7 fix the types 2026-01-31 09:56:15 +00:00
Jade Choghari b18cef2e26 feat(dataset): add subtask support (#2860)
* add subtask

* remove folder

* add docs

* update doc

* add testing

* update test

* update constant naming + doc

* more docs
2026-01-30 19:29:37 +01:00
Caroline Pascal 5c6182176f fix(find zmq): adding a clearer not implemented warning for the ZMQ find_cameras method (#2879)
Co-authored-by: Martino Russi <77496684+nepyope@users.noreply.github.com>
2026-01-30 16:58:13 +01:00
Caroline Pascal 55c0471db9 docs(cameras): revising and improving docs on cameras (#2878)
* docs(cameras): revising and improving docs on cameras

* resolving copilot comments
2026-01-30 16:57:56 +01:00
Michel Aractingi ec04b7ce3a Feat(dataset_tools.py) Add modify tasks tool (#2875)
* feat(datasets): add modify_tasks function for in-place task editing

Add a new utility function to modify tasks in LeRobotDataset in-place.
This allows users to:
- Set a single task for all episodes
- Set specific tasks for individual episodes
- Combine a default task with per-episode overrides

* feat(edit-dataset): add CLI support for modify_tasks operation

Integrate the modify_tasks function into lerobot_edit_dataset CLI.
Users can now modify dataset tasks via the command line, setting a default task, per-episode tasks, or both combined.
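A hedged sketch of the three modes described above; the exact signature of `modify_tasks` is an assumption for illustration, not the actual API:

```python
# Hypothetical signature, for illustration only; see dataset_tools.py for the real API.
from lerobot.datasets.dataset_tools import modify_tasks  # import path assumed

modify_tasks(dataset, task="Pick the cube")                     # one task for all episodes
modify_tasks(dataset, episode_tasks={0: "Pick", 1: "Place"})    # per-episode tasks
modify_tasks(dataset, task="Pick", episode_tasks={3: "Place"})  # default + per-episode overrides
```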

* test(datasets): add tests for modify_tasks function

Add comprehensive test coverage for the modify_tasks utility:
- Single task for all episodes
- Episode-specific task assignment
- Default task with per-episode overrides
- Error handling for missing/invalid arguments
- Verification of task_index correctness
- In-place modification behavior
- Metadata preservation

* respond to copilot review
2026-01-30 13:19:42 +01:00
Michel Aractingi 04cbf669cf fix(sac): make temperature a property to fix checkpoint resume bug (#2877)
* fix(sac): make temperature a property to fix checkpoint resume bug

Temperature was stored as a plain float and not restored after loading
a checkpoint, causing incorrect loss computations until update_temperature()
was called. Changed to a property that always computes from log_alpha,
ensuring correct behavior after checkpoint loading.
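The pattern described above, as a minimal hedged sketch (not the actual SAC source):

```python
import torch
import torch.nn as nn

class TemperatureSketch(nn.Module):
    """Hedged sketch of the fix described above: temperature is always derived
    from log_alpha, so it stays consistent after a checkpoint restores log_alpha."""

    def __init__(self, initial_temperature: float = 1.0):
        super().__init__()
        self.log_alpha = nn.Parameter(torch.tensor(initial_temperature).log())

    @property
    def temperature(self) -> float:
        return self.log_alpha.exp().item()
```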

* simplify docstrings
2026-01-30 12:23:22 +01:00
Steven Palma 3409ef0dc2 refactor(cameras): cameras API extension (#2808)
* feat(cameras): add new read_latest() method

* fix(cameras): fix threading bug + clear state

* refactor(cameras): multiple improvements

* feat(camera): add context manager to camera base class

* chore(camera): slight modifications to opencv

* test(cameras): update opencv tests according to the changes

* refactor(cameras): reflect design changes to realsense + deal with depth

* test(cameras): fix realsense tests accordingly to new changes

* refactor(cameras): update reachymini and zmq accordingly

* chore: wrap resource sensitive examples into a try/finally

* test(cameras): add test for new read_latest

* test(cameras): fix problem with image artifact in opencv tests

* test(cameras): fix test_read_latest_high_frequency expectations

* Apply suggestions from code review 1

Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* chore(cameras): address feedback

* feat(cameras): add max_age_ms check in read_latest

* test(cameras): fix read_latest tests

* chore(redundancies): removing redundancies in Reachy 2 camera class

* fix(warmup): replacing the arbitrary time.sleep with an actual warmup in the RealSense camera class

* chore(format): formatting latest changes

* chore(warning): adding a "to be implemented" warning for read_latest() in Camera base class

* chore(warning): making read_latest() warning message shorter and clearer

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
2026-01-29 11:07:47 +01:00
Steven Palma 4483184875 feat(robots): add bi manual openarm follower and leader (#2835)
* fix(motors): cleanup imports + fix signatures

* feat(motors): add damiao canbus + multiple fixes

* fix(motors): address comments -> last_state + different gains + sleep

* refactor(motors): reduce duplicated code + addressed some comments in the PR

* chore(motors): better timeouts

* tests(motors): damiao test and imports

* chore(deps): fix space

* feat(robot): add openarm leader

Co-authored-by: Pepijn <pepijn@huggingface.co>

* feat(robot): add openarm follower

Co-authored-by: Pepijn <pepijn@huggingface.co>

* refactor(robot): remove mechanical compensations and double arm assumption + rename

* chore(robots): remove left arm references

* refactor(teleop): multiple improvements to leader

* refactor(teleop): multiple improvements to leader

* feat(robots): add open arm to util CLI

* chore(robot): add alias openarm

* Apply suggestions from code review

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* chore(motors): remove normalization tables damiao

* fix(motors): imports and signatures

* feat(motors): add motor_type_str + recv_id to motor class and _get_motor_recv_id raises if no motor_obj.recv_id

* chore(motors): remove normalize from base motor class and damiao

* tests(motors): remove bad tests (to be replaced)

* chore(motors): updated import check

* fix(robots): open arm mirrored config for joint limits

* chore(motors): update position_kd gain values

* chore(robots): set to 0 if openarm is calibrated at connect time

* chore(robots): remove macos in open arm as can doesn't support it

* chore(robots): update for motor_type_str in Motor class

* chore(robots): no default value for can port in open arms

* feat(robots): add bi manual openarm follower and leader

* use constant for kp and kd range and check responses in mit_control_batch()

* Add docs on setting up the CAN bus and using the Damiao motor bus; also add lerobot_setup_can.py and log if there is no response from a write command

* precommit format

* suppress bandit as these are intentional cli commands

* fix setup-can

* add test

* skip test in ci

* nit precommit

* update doc example

* dont import can for tests

* remove comment

* Add openarms docs

* format

* update purchase link

* can to none if not available

* add canfd option in bus

* make handshake logic similar to lerobot-can

* type hint

* type check

* add temp teleop test

* remove script

* mock class

* mock class

* ignore linter

* pre-commit

* Add command for bimanual openarm

* fix import

* fix import leader

* fix import draccus

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Pepijn <pepijn@huggingface.co>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-01-28 17:25:57 +01:00
Martino Russi 149628dfd5 add g1 teleoperation (#2791)
* add gravity compensation

* add g1 teleoperation

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2026-01-28 15:17:38 +01:00
Steven Palma bf337e716d feat(robots): add OpenArm robot & teleoperator (#2795)
* fix(motors): cleanup imports + fix signatures

* feat(motors): add damiao canbus + multiple fixes

* fix(motors): address comments -> last_state + different gains + sleep

* refactor(motors): reduce duplicated code + addressed some comments in the PR

* chore(motors): better timeouts

* tests(motors): damiao test and imports

* chore(deps): fix space

* feat(robot): add openarm leader

Co-authored-by: Pepijn <pepijn@huggingface.co>

* feat(robot): add openarm follower

Co-authored-by: Pepijn <pepijn@huggingface.co>

* refactor(robot): remove mechanical compensations and double arm assumption + rename

* chore(robots): remove left arm references

* refactor(teleop): multiple improvements to leader

* refactor(teleop): multiple improvements to leader

* feat(robots): add open arm to util CLI

* chore(robot): add alias openarm

* Apply suggestions from code review

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* chore(motors): remove normalization tables damiao

* fix(motors): imports and signatures

* feat(motors): add motor_type_str + recv_id to motor class and _get_motor_recv_id raises if no motor_obj.recv_id

* chore(motors): remove normalize from base motor class and damiao

* tests(motors): remove bad tests (to be replaced)

* chore(motors): updated import check

* fix(robots): open arm mirrored config for joint limits

* chore(motors): update position_kd gain values

* chore(robots): set to 0 if openarm is calibrated at connect time

* chore(robots): remove macos in open arm as can doesn't support it

* chore(robots): update for motor_type_str in Motor class

* chore(robots): no default value for can port in open arms

* use constant for kp and kd range and check responses in mit_control_batch()

* Add docs on setting up the CAN bus and using the Damiao motor bus; also add lerobot_setup_can.py and log if there is no response from a write command

* precommit format

* suppress bandit as these are intentional cli commands

* fix setup-can

* add test

* skip test in ci

* nit precommit

* update doc example

* dont import can for tests

* remove comment

* Add openarms docs

* format

* update purchase link

* can to none if not available

* add canfd option in bus

* make handshake logic similar to lerobot-can

* type hint

* type check

* add temp teleop test

* remove script

* mock class

* ignore linter

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Pepijn <pepijn@huggingface.co>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-01-28 14:28:51 +01:00
Michel Aractingi 736b43f3cf Fix(aggregate.py) Aggregation of datasets when sub-datasets are already a result of a previous merge (#2861)
* Fix aggregation of datasets when sub-datasets are already a result of a previous merge

* docstring

* respond to copilot review + add regression test

* Remove unnecessary int conversion for indices
2026-01-28 13:31:27 +01:00
72 changed files with 6003 additions and 791 deletions
+8 -2
@@ -7,8 +7,6 @@
- sections:
- local: il_robots
title: Imitation Learning for Robots
- local: cameras
title: Cameras
- local: bring_your_own_policies
title: Bring Your Own Policies
- local: integrate_hardware
@@ -29,6 +27,8 @@
title: Porting Large Datasets
- local: using_dataset_tools
title: Using the Dataset Tools
- local: dataset_subtask
title: Using Subtasks in the Dataset
title: "Datasets"
- sections:
- local: act
@@ -101,11 +101,17 @@
title: Earth Rover Mini
- local: omx
title: OMX
- local: openarm
title: OpenArm
title: "Robots"
- sections:
- local: phone_teleop
title: Phone
title: "Teleoperators"
- sections:
- local: cameras
title: Cameras
title: "Sensors"
- sections:
- local: torch_accelerators
title: PyTorch accelerators
+95 -81
@@ -1,12 +1,22 @@
# Cameras
LeRobot offers multiple options for video capture, including phone cameras, built-in laptop cameras, external webcams, and Intel RealSense cameras. To efficiently record frames from most cameras, you can use either the `OpenCVCamera` or `RealSenseCamera` class. For additional compatibility details on the `OpenCVCamera` class, refer to the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).
LeRobot offers multiple options for video capture:
### Finding your camera
| Class | Supported Cameras |
| ----------------- | ----------------------------------- |
| `OpenCVCamera` | Phone, built-in laptop, USB webcams |
| `ZMQCamera` | Network-connected cameras |
| `RealSenseCamera` | Intel RealSense (with depth) |
| `Reachy2Camera` | Reachy 2 robot cameras |
To instantiate a camera, you need a camera identifier. This identifier might change if you reboot your computer or re-plug your camera, a behavior mostly dependent on your operating system.
> [!TIP]
> For `OpenCVCamera` compatibility details, see the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).
To find the camera indices of the cameras plugged into your system, run the following script:
### Find your camera
Every camera requires a unique identifier to be instantiated, allowing you to distinguish between multiple connected devices.
`OpenCVCamera` and `RealSenseCamera` support auto-discovery. Run the command below to list available devices and their identifiers. Note that these identifiers may change after rebooting your computer or re-plugging the camera, depending on your operating system.
```bash
lerobot-find-cameras opencv # or realsense for Intel Realsense cameras
@@ -14,7 +24,7 @@ lerobot-find-cameras opencv # or realsense for Intel Realsense cameras
The output will look something like this if you have two cameras connected:
```
```bash
--- Detected Cameras ---
Camera #0:
Name: OpenCV Camera @ 0
@@ -33,13 +43,37 @@ Camera #0:
> [!WARNING]
> When using Intel RealSense cameras on `macOS`, you may get this [error](https://github.com/IntelRealSense/librealsense/issues/12307): `Error finding RealSense cameras: failed to set power state`. This can be solved by running the same command with `sudo` permissions. Note that using RealSense cameras on `macOS` is unstable.
## Use Cameras
`ZMQCamera` and `Reachy2Camera` do not support auto-discovery. They must be configured manually, by providing a network address and port (ZMQ) or robot SDK settings (Reachy 2).
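For example, a hedged sketch of a manual ZMQ configuration (the class names and import path are assumptions; the field names mirror the ZMQ camera dictionary used in the recording examples elsewhere in these docs):

```python
from lerobot.cameras.zmq import ZMQCamera, ZMQCameraConfig  # names assumed

# Point the camera client at the machine running the ZMQ image server.
config = ZMQCameraConfig(
    server_address="localhost",
    port=5555,
    camera_name="head_camera",
    width=640,
    height=480,
    fps=30,
)
camera = ZMQCamera(config)
```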
Below are two examples, demonstrating how to work with the API.
## Use cameras
- **Asynchronous frame capture** using an OpenCV-based camera
### Frame access modes
All camera classes implement three access modes for capturing frames:
| Method | Behavior | Blocks? | Best For |
| ------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------- | ---------------------------------------- |
| `read()` | Waits for the camera hardware to return a frame. May block for a long time depending on the camera and SDK. | Yes | Simple scripts, sequential capture |
| `async_read(timeout_ms)` | Returns the latest unconsumed frame from the background thread. Blocks only if the buffer is empty, up to `timeout_ms`. Raises `TimeoutError` if no frame arrives. | With a timeout | Control loops synchronized to camera FPS |
| `read_latest(max_age_ms)` | Peeks at the most recent frame in the buffer (may be stale). Raises `TimeoutError` if the frame is older than `max_age_ms`. | No | UI visualization, logging, monitoring |
### Usage examples
The following examples show how to use the camera API to configure and capture frames from different camera types.
- **Blocking and non-blocking frame capture** using an OpenCV-based camera
- **Color and depth capture** using an Intel RealSense camera
> [!WARNING]
> Failing to cleanly disconnect cameras can cause resource leaks. Use the context manager protocol to ensure automatic cleanup:
>
> ```python
> with OpenCVCamera(config) as camera:
> ...
> ```
>
> You can also call `connect()` and `disconnect()` manually, but always use a `finally` block for the latter.
<hfoptions id="shell_restart">
<hfoption id="Open CV Camera">
@@ -60,16 +94,30 @@ config = OpenCVCameraConfig(
)
# Instantiate and connect an `OpenCVCamera`, performing a warm-up read (default).
camera = OpenCVCamera(config)
camera.connect()
with OpenCVCamera(config) as camera:
# Read a frame synchronously — blocks until hardware delivers a new frame
frame = camera.read()
print(f"read() call returned frame with shape:", frame.shape)
# Read a frame asynchronously with a timeout — returns the latest unconsumed frame or waits up to timeout_ms for a new one
try:
for i in range(10):
frame = camera.async_read(timeout_ms=200)
print(f"async_read call returned frame {i} with shape:", frame.shape)
except TimeoutError as e:
print(f"No frame received within timeout: {e}")
# Instantly return a frame: the most recent one captured by the camera
try:
initial_frame = camera.read_latest(max_age_ms=1000)
for i in range(10):
frame = camera.read_latest(max_age_ms=1000)
print(f"read_latest call returned frame {i} with shape:", frame.shape)
print(f"Was a new frame received by the camera? {not (initial_frame == frame).any()}")
except TimeoutError as e:
print(f"Frame too old: {e}")
# Read frames asynchronously in a loop via `async_read(timeout_ms)`
try:
for i in range(10):
frame = camera.async_read(timeout_ms=200)
print(f"Async frame {i} shape:", frame.shape)
finally:
camera.disconnect()
```
<!-- prettier-ignore-end -->
@@ -111,10 +159,10 @@ finally:
</hfoption>
</hfoptions>
## Use your phone
## Use your phone's camera
<hfoptions id="use phone">
<hfoption id="Mac">
<hfoption id="iPhone & macOS">
To use your iPhone as a camera on macOS, enable the Continuity Camera feature:
@@ -124,83 +172,49 @@ To use your iPhone as a camera on macOS, enable the Continuity Camera feature:
For more details, visit [Apple support](https://support.apple.com/en-gb/guide/mac-help/mchl77879b8a/mac).
Your iPhone should be detected automatically when running the camera setup script in the next section.
</hfoption>
<hfoption id="Linux">
<hfoption id="OBS virtual camera">
If you want to use your phone as a camera on Linux, follow these steps to set up a virtual camera
If you want to use your phone as a camera using OBS, follow these steps to set up a virtual camera.
1. _Install `v4l2loopback-dkms` and `v4l-utils`_. Those packages are required to create virtual camera devices (`v4l2loopback`) and verify their settings with the `v4l2-ctl` utility from `v4l-utils`. Install them using:
1. _(Linux only) Install `v4l2loopback-dkms` and `v4l-utils`_. These packages create virtual camera devices and verify their settings. Install with:
<!-- prettier-ignore-start -->
```python
```bash
sudo apt install v4l2loopback-dkms v4l-utils
```
<!-- prettier-ignore-end -->
2. _Install [DroidCam](https://droidcam.app) on your phone_. This app is available for both iOS and Android.
3. _Install [OBS Studio](https://obsproject.com)_. This software will help you manage the camera feed. Install it using [Flatpak](https://flatpak.org):
2. _Install the [DroidCam app](https://droidcam.app) on your phone_. This app is available for both iOS and Android.
3. _Download and install [OBS Studio](https://obsproject.com)_.
4. _Download and install the [DroidCam OBS plugin](https://droidcam.app/obs)_.
5. _Start OBS Studio_.
<!-- prettier-ignore-start -->
```python
flatpak install flathub com.obsproject.Studio
```
<!-- prettier-ignore-end -->
4. _Install the DroidCam OBS plugin_. This plugin integrates DroidCam with OBS Studio. Install it with:
<!-- prettier-ignore-start -->
```python
flatpak install flathub com.obsproject.Studio.Plugin.DroidCam
```
<!-- prettier-ignore-end -->
5. _Start OBS Studio_. Launch with:
<!-- prettier-ignore-start -->
```python
flatpak run com.obsproject.Studio
```
<!-- prettier-ignore-end -->
6. _Add your phone as a source_. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480`.
7. _Adjust resolution settings_. In OBS Studio, go to `File > Settings > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it in.
6. _Add your phone as a source_. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480` to avoid watermarks.
7. _Adjust resolution settings_. In OBS Studio, go to `File > Settings > Video` or `OBS > Preferences... > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it.
8. _Start virtual camera_. In OBS Studio, follow the instructions [here](https://obsproject.com/kb/virtual-camera-guide).
9. _Verify the virtual camera setup_. Use `v4l2-ctl` to list the devices:
9. _Verify the virtual camera setup and resolution_.
- **Linux**: Use `v4l2-ctl` to list devices and check resolution:
```bash
v4l2-ctl --list-devices # find VirtualCam and note its /dev/videoX path
v4l2-ctl -d /dev/videoX --get-fmt-video # replace with your VirtualCam path
```
You should see `VirtualCam` listed and resolution `640x480`.
- **macOS**: Open Photo Booth or FaceTime and select "OBS Virtual Camera" as the input.
- **Windows**: The native Camera app doesn't support virtual cameras. Use a video conferencing app (Zoom, Teams) or run `lerobot-find-cameras opencv` directly to verify.
<!-- prettier-ignore-start -->
```python
v4l2-ctl --list-devices
```
<!-- prettier-ignore-end -->
<details>
<summary><strong>Troubleshooting</strong></summary>
You should see an entry like:
> The virtual camera resolution is incorrect.
```
VirtualCam (platform:v4l2loopback-000):
/dev/video1
```
Delete the virtual camera source and recreate it. The resolution cannot be changed after creation.
10. _Check the camera resolution_. Use `v4l2-ctl` to ensure that the virtual camera output resolution is `640x480`. Change `/dev/video1` to the port of your virtual camera from the output of `v4l2-ctl --list-devices`.
> Error reading frame in background thread for OpenCVCamera(X): OpenCVCamera(X) frame width=640 or height=480 do not match configured width=1920 or height=1080.
<!-- prettier-ignore-start -->
```python
v4l2-ctl -d /dev/video1 --get-fmt-video
```
<!-- prettier-ignore-end -->
This error is caused by OBS Virtual Camera advertising a `1920x1080` resolution despite rescaling. The only fix for now is to comment out the width and height check in `_postprocess_image()`.
You should see an entry like:
```
>>> Format Video Capture:
>>> Width/Height : 640/480
>>> Pixel Format : 'YUYV' (YUYV 4:2:2)
```
Troubleshooting: If the resolution is not correct you will have to delete the Virtual Camera port and try again as it cannot be changed.
If everything is set up correctly, you can proceed with the rest of the tutorial.
</details>
</hfoption>
</hfoptions>
If everything is set up correctly, your phone will appear as a standard OpenCV camera and can be used with `OpenCVCamera`.
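As a hedged sketch, connecting to the virtual camera then works like any other webcam (the index `1` and exact config fields are assumptions; use the identifier reported by `lerobot-find-cameras opencv`):

```python
from lerobot.cameras.opencv import OpenCVCamera, OpenCVCameraConfig

# Index 1 is a placeholder; substitute the identifier found above.
config = OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
with OpenCVCamera(config) as camera:
    frame = camera.read()
    print("Frame shape:", frame.shape)  # expect (480, 640, 3)
```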
+278
@@ -0,0 +1,278 @@
# Using Subtasks in LeRobot Datasets
Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for:
- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time
- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models)
- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps
LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks.
## What are Subtasks?
While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps:
1. "Approach the apple"
2. "Grasp the apple"
3. "Lift the apple"
4. "Move to basket"
5. "Release the apple"
Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages.
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/subtask-asset.png"
alt="An overview of subtask annotation showing how frames are labeled with intermediate subtask stages"
width="80%"
/>
<p>
<em>Figure: Overview of subtask annotation.</em>
</p>
**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_ (published 19 April 2022).
## Dataset Structure
Subtask information is stored in the dataset metadata:
```
my-dataset/
├── data/
│ └── ...
├── meta/
│ ├── info.json
│ ├── stats.json
│ ├── tasks.parquet
│ ├── subtasks.parquet # Subtask index → subtask string mapping
│ └── episodes/
│ └── ...
└── videos/
└── ...
```
### Subtasks Parquet File
The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions:
| subtask_index | subtask (index column) |
| ------------- | ---------------------- |
| 0 | "Approach the apple" |
| 1 | "Grasp the apple" |
| 2 | "Lift the apple" |
| ... | ... |
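Since this is a plain Parquet file, you can inspect the mapping directly; a minimal sketch using pandas (the path assumes the layout shown above):

```python
import pandas as pd

# Load the subtask-index-to-description mapping from the dataset metadata.
subtasks = pd.read_parquet("my-dataset/meta/subtasks.parquet")
print(subtasks)
```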
### Frame-Level Annotations
Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file:
```python
# Example frame data in the parquet file
{
"index": 42,
"timestamp": 1.4,
"episode_index": 0,
"task_index": 0,
"subtask_index": 2, # References "Lift the apple"
"observation.state": [...],
"action": [...],
}
```
## Annotating Datasets with Subtasks
We provide a Hugging Face Space for easily annotating any LeRobotDataset with subtasks:
**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)**
After completing your annotation, click "Push to Hub" to upload your annotated dataset.
You can also run the annotation Space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate).
## Loading Datasets with Subtasks
When you load a dataset with subtask annotations, the subtask information is automatically available:
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Load a dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
# Access a sample
sample = dataset[100]
# The sample includes both task and subtask information
print(sample["task"]) # "Collect the fruit"
print(sample["subtask"]) # "Grasp the apple"
print(sample["task_index"]) # tensor(0)
print(sample["subtask_index"]) # tensor(2)
```
### Checking for Subtask Support
You can check if a dataset has subtask annotations:
```python
# Check if subtasks are available
has_subtasks = (
"subtask_index" in dataset.features
and dataset.meta.subtasks is not None
)
if has_subtasks:
print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks")
print("Subtasks:", list(dataset.meta.subtasks.index))
```
## Using Subtasks for Training
### With the Tokenizer Processor
The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language-Action (VLA) models:
```python
from lerobot.processor.tokenizer_processor import TokenizerProcessor
from lerobot.processor.pipeline import ProcessorPipeline
# Create a tokenizer processor
tokenizer_processor = TokenizerProcessor(
tokenizer_name_or_path="google/paligemma-3b-pt-224",
padding="max_length",
max_length=64,
)
# The processor will automatically tokenize subtasks if present in the batch
# and add them to the observation under:
# - "observation.subtask.tokens"
# - "observation.subtask.attention_mask"
```
When subtasks are available in the batch, the tokenizer processor adds:
- `observation.subtask.tokens`: Tokenized subtask text
- `observation.subtask.attention_mask`: Attention mask for the subtask tokens
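Under the hood this is standard Hugging Face tokenization; a hedged sketch of the equivalent operation, assuming the same checkpoint and padding settings as above (the PaliGemma checkpoint is gated, so you may need to accept its license on the Hub first):

```python
from transformers import AutoTokenizer

# Tokenize subtask strings the way the processor is configured above.
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
encoded = tokenizer(
    ["Grasp the apple", "Lift the apple"],
    padding="max_length",
    max_length=64,
    return_tensors="pt",
)
print(encoded["input_ids"].shape)       # torch.Size([2, 64])
print(encoded["attention_mask"].shape)  # torch.Size([2, 64])
```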
### DataLoader with Subtasks
```python
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=16,
shuffle=True,
)
for batch in dataloader:
# Access subtask information in the batch
subtasks = batch["subtask"] # List of subtask strings
subtask_indices = batch["subtask_index"] # Tensor of subtask indices
# Use for training hierarchical policies or reward models
print(f"Batch subtasks: {set(subtasks)}")
```
## Example Datasets with Subtask Annotations
Try loading a dataset with subtask annotations:
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Example dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
# Explore the subtasks
print("Available subtasks:")
for subtask_name in dataset.meta.subtasks.index:
print(f" - {subtask_name}")
# Get subtask distribution
subtask_counts = {}
for i in range(len(dataset)):
sample = dataset[i]
subtask = sample["subtask"]
subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1
print("\nSubtask distribution:")
for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]):
print(f" {subtask}: {count} frames")
```
## Use Cases
### 1. Hierarchical Policy Training
Train policies that predict both actions and current subtask:
```python
import torch.nn as nn

class HierarchicalPolicy(nn.Module):
    def __init__(self, encoder, hidden_dim, action_dim, num_subtasks):
        super().__init__()
        self.encoder = encoder  # any observation encoder producing hidden_dim features
        self.action_head = nn.Linear(hidden_dim, action_dim)
        self.subtask_head = nn.Linear(hidden_dim, num_subtasks)

    def forward(self, observations):
        features = self.encoder(observations)
        actions = self.action_head(features)          # continuous action prediction
        subtask_logits = self.subtask_head(features)  # current-subtask classification
        return actions, subtask_logits
```
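During training, the two heads can be supervised jointly; a hedged sketch of a combined objective, assuming the batch layout shown earlier in this page (`action` targets plus `subtask_index` labels) and that `policy` is the module above:

```python
import torch.nn.functional as F

# Joint loss: action regression plus subtask classification.
pred_actions, subtask_logits = policy(batch["observation.state"])  # input key assumed
loss = F.mse_loss(pred_actions, batch["action"]) + F.cross_entropy(
    subtask_logits, batch["subtask_index"]
)
loss.backward()
```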
### 2. Stage-Aware Reward Modeling (SARM)
Build reward models that understand task progression:
```python
import torch.nn as nn

# SARM predicts:
# - Stage: which subtask is being executed (discrete)
# - Progress: how far along the subtask is (continuous, 0-1)
class SARMRewardModel(nn.Module):
    def __init__(self, encoder, hidden_dim, num_stages):
        super().__init__()
        self.encoder = encoder  # any observation encoder producing hidden_dim features
        self.stage_classifier = nn.Linear(hidden_dim, num_stages)
        self.progress_regressor = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())

    def forward(self, observations):
        features = self.encoder(observations)
        stage_logits = self.stage_classifier(features)
        progress = self.progress_regressor(features)  # bounded to [0, 1]
        return stage_logits, progress
```
### 3. Progress Visualization
Monitor robot execution by tracking subtask progression:
```python
def visualize_execution(model, observations, subtask_names):
    # subtask_names: ordered subtask strings, e.g. list(dataset.meta.subtasks.index)
    for t, obs in enumerate(observations):
        action, subtask_logits = model(obs)
        predicted_subtask = subtask_names[subtask_logits.argmax()]
        print(f"t={t}: Executing '{predicted_subtask}'")
```
## API Reference
### LeRobotDataset Properties
| Property | Type | Description |
| --------------------------- | ---------------------- | ------------------------------------------ |
| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices |
| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present |
### Sample Keys
When subtasks are available, each sample includes:
| Key | Type | Description |
| --------------- | -------------- | ------------------------------------ |
| `subtask_index` | `torch.Tensor` | Integer index of the current subtask |
| `subtask` | `str` | Natural language subtask description |
## Related Resources
- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation
- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool
- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation
+276
@@ -0,0 +1,276 @@
# OpenArm
[OpenArm](https://openarm.dev) is an open-source 7DOF humanoid arm designed for physical AI research and deployment.
To get an OpenArm, assembled or as a DIY build, and to join the global community, browse verified and certified manufacturers worldwide at [openarm.dev](https://openarm.dev).
## What's Unique?
- **Human-Scale Design**: OpenArm is designed with human-like proportions, scaled for a person around 160-165cm tall. This provides an optimal balance between practical reach and manageable inertia for safe, responsive operation.
- **Safety-First Architecture**: Built with QDD backdrivable motors and high compliance, OpenArm prioritizes safe human-robot interaction while maintaining practical payload capabilities (6.0kg peak / 4.1kg nominal) for real-world tasks.
- **Built for Durability**: Critical structural components use aluminum and stainless steel construction, ensuring robust performance for repetitive data collection and continuous research use.
- **Fully Accessible & Buildable**: Every component, from CNC parts and 3D-printed casings to electrical wiring is designed to be purchasable and buildable by individual researchers and labs, with complete fabrication data provided.
- **Practical & Affordable**: At $6,500 USD for a complete bimanual system, OpenArm delivers research-grade capabilities at a fraction of traditional humanoid robot costs.
## Platform Requirements
<Tip warning={true}>
**Linux Only**: OpenArm currently only works on Linux. The CAN bus USB adapter
does not have macOS drivers and has not been tested on Windows.
</Tip>
## Safety Guide
Before operating OpenArm, please read the [official safety guide](https://docs.openarm.dev/getting-started/safety-guide). Key points:
- **Secure installation**: Fasten the arm to a flat, stable surface with screws or clamps
- **Safe distance**: Keep body parts and objects outside the range of motion during operation
- **Protective equipment**: Always wear safety goggles; use additional PPE as needed
- **Payload limits**: Do not exceed specified payload limits (6.0kg peak / 4.1kg nominal per arm)
- **Emergency stop**: Know the location and operation of the emergency stop device
- **Regular inspection**: Check for loose screws, damaged mechanical limits, unusual noises, and wiring damage
## Hardware Setup
Follow the official [OpenArm hardware documentation](https://docs.openarm.dev) for:
- Bill of materials and sourcing
- 3D printing instructions
- Mechanical assembly
- Electrical wiring
The hardware repositories are available at [github.com/enactic/openarm](https://github.com/enactic/openarm).
## CAN Bus Setup
OpenArm uses CAN bus communication with Damiao motors. Once you have the CAN bus USB adapter plugged into your Linux PC, follow the [Damiao Motors and CAN Bus guide](./damiao) to configure the interface.
Quick setup:
```bash
# Setup CAN interfaces
lerobot-setup-can --mode=setup --interfaces=can0,can1
# Test motor communication
lerobot-setup-can --mode=test --interfaces=can0,can1
```
## Install LeRobot 🤗
Follow our [Installation Guide](./installation), then install the Damiao motor support:
```bash
pip install -e ".[damiao]"
```
## Usage
### Follower Arm (Robot)
<hfoptions id="follower">
<hfoption id="Command">
```bash
lerobot-calibrate \
--robot.type=openarm_follower \
--robot.port=can0 \
--robot.side=right \
--robot.id=my_openarm_follower
```
</hfoption>
<hfoption id="API example">
```python
from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig
config = OpenArmFollowerConfig(
port="can0",
side="right", # or "left" for left arm
id="my_openarm_follower",
)
follower = OpenArmFollower(config)
follower.connect()
# Read current state
obs = follower.get_observation()
print(obs)
# Send action (position in degrees)
action = {
"joint_1.pos": 0.0,
"joint_2.pos": 0.0,
"joint_3.pos": 0.0,
"joint_4.pos": 45.0,
"joint_5.pos": 0.0,
"joint_6.pos": 0.0,
"joint_7.pos": 0.0,
"gripper.pos": 0.0,
}
follower.send_action(action)
follower.disconnect()
```
</hfoption>
</hfoptions>
### Leader Arm (Teleoperator)
The leader arm is used for teleoperation: you move it by hand to control the follower arm.
<hfoptions id="leader">
<hfoption id="Command">
```bash
lerobot-calibrate \
--teleop.type=openarm_leader \
--teleop.port=can1 \
--teleop.id=my_openarm_leader
```
</hfoption>
<hfoption id="API example">
```python
from lerobot.teleoperators.openarm_leader import OpenArmLeader, OpenArmLeaderConfig
config = OpenArmLeaderConfig(
port="can1",
id="my_openarm_leader",
manual_control=True, # Disable torque for manual movement
)
leader = OpenArmLeader(config)
leader.connect()
# Read current position (as action to send to follower)
action = leader.get_action()
print(action)
leader.disconnect()
```
</hfoption>
</hfoptions>
### Teleoperation
To teleoperate OpenArm with leader-follower control:
```bash
lerobot-teleoperate \
--robot.type=openarm_follower \
--robot.port=can0 \
--robot.side=right \
--robot.id=my_follower \
--teleop.type=openarm_leader \
--teleop.port=can1 \
--teleop.id=my_leader
```
### Bimanual Teleoperation
To teleoperate a bimanual OpenArm setup with two leader and two follower arms:
```bash
lerobot-teleoperate \
--robot.type=bi_openarm_follower \
--robot.left_arm_config.port=can0 \
--robot.left_arm_config.side=left \
--robot.right_arm_config.port=can1 \
--robot.right_arm_config.side=right \
--robot.id=my_bimanual_follower \
--teleop.type=bi_openarm_leader \
--teleop.left_arm_config.port=can2 \
--teleop.right_arm_config.port=can3 \
--teleop.id=my_bimanual_leader
```
### Recording Data
To record a dataset during teleoperation:
```bash
lerobot-record \
--robot.type=openarm_follower \
--robot.port=can0 \
--robot.side=right \
--robot.id=my_follower \
--teleop.type=openarm_leader \
--teleop.port=can1 \
--teleop.id=my_leader \
--dataset.repo_id=my_hf_username/my_openarm_dataset \
--dataset.fps=30 \
--dataset.num_episodes=10
```
## Configuration Options
### Follower Configuration
| Parameter | Default | Description |
| --------------------- | --------- | ---------------------------------------------------------- |
| `port` | - | CAN interface (e.g., `can0`) |
| `side` | `None` | Arm side: `"left"`, `"right"`, or `None` for custom limits |
| `use_can_fd` | `True` | Enable CAN FD for higher data rates |
| `can_bitrate` | `1000000` | Nominal bitrate (1 Mbps) |
| `can_data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) |
| `max_relative_target` | `None` | Safety limit for relative target positions |
| `position_kp` | Per-joint | Position control proportional gains |
| `position_kd` | Per-joint | Position control derivative gains |
### Leader Configuration
| Parameter | Default | Description |
| ------------------ | --------- | ----------------------------------- |
| `port` | - | CAN interface (e.g., `can1`) |
| `manual_control` | `True` | Disable torque for manual movement |
| `use_can_fd` | `True` | Enable CAN FD for higher data rates |
| `can_bitrate` | `1000000` | Nominal bitrate (1 Mbps) |
| `can_data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) |
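These fields can be overridden from the CLI using the same `--robot.<field>` / `--teleop.<field>` pattern as the commands above; a hedged sketch (flag spelling follows that pattern, values are illustrative):

```bash
lerobot-calibrate \
    --robot.type=openarm_follower \
    --robot.port=can0 \
    --robot.side=right \
    --robot.use_can_fd=false \
    --robot.can_bitrate=500000 \
    --robot.id=my_openarm_follower
```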
## Motor Configuration
OpenArm uses Damiao motors with the following default configuration:
| Joint | Motor Type | Send ID | Recv ID |
| --------------------------- | ---------- | ------- | ------- |
| joint_1 (Shoulder pan) | DM8009 | 0x01 | 0x11 |
| joint_2 (Shoulder lift) | DM8009 | 0x02 | 0x12 |
| joint_3 (Shoulder rotation) | DM4340 | 0x03 | 0x13 |
| joint_4 (Elbow flex) | DM4340 | 0x04 | 0x14 |
| joint_5 (Wrist roll) | DM4310 | 0x05 | 0x15 |
| joint_6 (Wrist pitch) | DM4310 | 0x06 | 0x16 |
| joint_7 (Wrist rotation) | DM4310 | 0x07 | 0x17 |
| gripper | DM4310 | 0x08 | 0x18 |
## Troubleshooting
### No Response from Motors
1. Check power supply connections
2. Verify CAN wiring (CAN-H, CAN-L, GND)
3. Run diagnostics: `lerobot-setup-can --mode=test --interfaces=can0`
4. See the [Damiao troubleshooting guide](./damiao#troubleshooting) for more details
### CAN Interface Not Found
Ensure the CAN interface is configured:
```bash
ip link show can0
```
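If the interface exists but is reported `DOWN`, you can usually bring it up manually with standard SocketCAN tooling (plain iproute2 usage, not a LeRobot command; bitrates match the defaults above):

```bash
sudo ip link set can0 type can bitrate 1000000 dbitrate 5000000 fd on
sudo ip link set can0 up
```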
## Resources
- [OpenArm Website](https://openarm.dev)
- [OpenArm Documentation](https://docs.openarm.dev)
- [OpenArm GitHub](https://github.com/enactic/openarm)
- [Safety Guide](https://docs.openarm.dev/getting-started/safety-guide)
- [Damiao Motors and CAN Bus](./damiao)
+99 -1
@@ -188,7 +188,105 @@ Press `Ctrl+C` to stop the policy.
## Running in Simulation Mode (MuJoCo)
You can now test policies before unleashing them on the physical robot using MuJoCo. To do so simply set `is_simulation=True` in config.
You can test policies before deploying on the physical robot using MuJoCo simulation. Set `is_simulation=True` in config or pass `--robot.is_simulation=true` via CLI.
### Calibrate Exoskeleton Teleoperator
```bash
lerobot-calibrate \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo
```
### Teleoperate in Simulation
```bash
lerobot-teleoperate \
--robot.type=unitree_g1 \
--robot.is_simulation=true \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo \
--fps=100
```
### Record Dataset in Simulation
```bash
python -m lerobot.scripts.lerobot_record \
--robot.type=unitree_g1 \
--robot.is_simulation=true \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo \
--dataset.repo_id=your-username/dataset-name \
--dataset.single_task="Test" \
--dataset.num_episodes=2 \
--dataset.episode_time_s=5 \
--dataset.reset_time_s=5 \
--dataset.push_to_hub=true
```
Example simulation dataset: [nepyope/teleop_test_sim](https://huggingface.co/datasets/nepyope/teleop_test_sim)
---
## Running on Real Robot
Once the robot server is running on the G1 (see Part 3), you can teleoperate and record on the real robot.
### Start the Camera Server
On the robot, start the ZMQ image server:
```bash
python src/lerobot/cameras/zmq/image_server.py
```
Keep this running in a separate terminal for camera streaming during recording.
### Teleoperate Real Robot
```bash
lerobot-teleoperate \
--robot.type=unitree_g1 \
--robot.is_simulation=false \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo \
--fps=100
```
### Record Dataset on Real Robot
```bash
python -m lerobot.scripts.lerobot_record \
--robot.type=unitree_g1 \
--robot.is_simulation=false \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo \
--dataset.repo_id=your-username/dataset-name \
--dataset.single_task="Test" \
--dataset.num_episodes=2 \
--dataset.episode_time_s=5 \
--dataset.reset_time_s=5 \
--dataset.push_to_hub=true
```
**Note**: Update `server_address` to match your robot's camera server IP.
Example real robot dataset: [nepyope/teleop_test_real](https://huggingface.co/datasets/nepyope/teleop_test_real)
---
## Additional Resources
+16 -15
@@ -81,24 +81,25 @@ def replay(cfg: ReplayConfig):
actions = dataset.hf_dataset.select_columns(ACTION)
robot.connect()
log_say("Replaying episode", cfg.play_sounds, blocking=True)
for idx in range(dataset.num_frames):
start_episode_t = time.perf_counter()
try:
log_say("Replaying episode", cfg.play_sounds, blocking=True)
for idx in range(dataset.num_frames):
start_episode_t = time.perf_counter()
action_array = actions[idx][ACTION]
action = {}
for i, name in enumerate(dataset.features[ACTION]["names"]):
key = f"{name.removeprefix('main_')}.pos"
action[key] = action_array[i].item()
action_array = actions[idx][ACTION]
action = {}
for i, name in enumerate(dataset.features[ACTION]["names"]):
key = f"{name.removeprefix('main_')}.pos"
action[key] = action_array[i].item()
action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90)
action["elbow_flex.pos"] -= 90
robot.send_action(action)
action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90)
action["elbow_flex.pos"] -= 90
robot.send_action(action)
dt_s = time.perf_counter() - start_episode_t
precise_sleep(max(1 / dataset.fps - dt_s, 0.0))
robot.disconnect()
dt_s = time.perf_counter() - start_episode_t
precise_sleep(max(1 / dataset.fps - dt_s, 0.0))
finally:
robot.disconnect()
if __name__ == "__main__":
+45 -43
@@ -78,40 +78,24 @@ def main():
listener, events = init_keyboard_listener()
init_rerun(session_name="lekiwi_evaluate")
if not robot.is_connected:
raise ValueError("Robot is not connected!")
try:
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting evaluate loop...")
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
print("Starting evaluate loop...")
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
@@ -120,24 +104,42 @@ def main():
robot_observation_processor=robot_observation_processor,
)
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
# Save episode
dataset.save_episode()
recorded_episodes += 1
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
# Save episode
dataset.save_episode()
recorded_episodes += 1
dataset.finalize()
dataset.push_to_hub()
finally:
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":
+45 -44
@@ -74,40 +74,23 @@ def main():
listener, events = init_keyboard_listener()
init_rerun(session_name="lekiwi_record")
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
raise ValueError("Robot or teleop is not connected!")
try:
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
raise ValueError("Robot or teleop is not connected!")
print("Starting record loop...")
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {recorded_episodes}")
print("Starting record loop...")
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {recorded_episodes}")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
dataset=dataset,
teleop=[leader_arm, keyboard],
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
dataset=dataset,
teleop=[leader_arm, keyboard],
control_time_s=RESET_TIME_SEC,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
@@ -115,26 +98,44 @@ def main():
robot_observation_processor=robot_observation_processor,
)
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop=[leader_arm, keyboard],
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
# Save episode
dataset.save_episode()
recorded_episodes += 1
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Clean up
log_say("Stop recording")
robot.disconnect()
leader_arm.disconnect()
keyboard.disconnect()
listener.stop()
# Save episode
dataset.save_episode()
recorded_episodes += 1
finally:
# Clean up
log_say("Stop recording")
robot.disconnect()
leader_arm.disconnect()
keyboard.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":
+17 -15
@@ -42,25 +42,27 @@ def main():
# Connect to the robot
robot.connect()
if not robot.is_connected:
raise ValueError("Robot is not connected!")
try:
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()
# Get recorded action from dataset
action = {
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Get recorded action from dataset
action = {
name: float(actions[idx][ACTION][i])
for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Send action to robot
_ = robot.send_action(action)
# Send action to robot
_ = robot.send_action(action)
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
robot.disconnect()
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
finally:
robot.disconnect()
if __name__ == "__main__":
+44 -41
@@ -142,38 +142,24 @@ def main():
listener, events = init_keyboard_listener()
init_rerun(session_name="phone_so100_evaluate")
if not robot.is_connected:
raise ValueError("Robot is not connected!")
try:
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting evaluate loop...")
episode_idx = 0
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
print("Starting evaluate loop...")
episode_idx = 0
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
log_say("Reset the environment")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
@@ -182,24 +168,41 @@ def main():
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
# Save episode
dataset.save_episode()
episode_idx += 1
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
# Save episode
dataset.save_episode()
episode_idx += 1
finally:
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":
+44 -41
@@ -149,38 +149,23 @@ def main():
listener, events = init_keyboard_listener()
init_rerun(session_name="phone_so100_record")
if not robot.is_connected or not phone.is_connected:
raise ValueError("Robot or teleop is not connected!")
try:
if not robot.is_connected or not phone.is_connected:
raise ValueError("Robot or teleop is not connected!")
print("Starting record loop. Move your phone to teleoperate the robot...")
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
print("Starting record loop. Move your phone to teleoperate the robot...")
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop=phone,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop=phone,
control_time_s=RESET_TIME_SEC,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose_processor,
@@ -188,25 +173,43 @@ def main():
robot_observation_processor=robot_joints_to_ee_pose,
)
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop=phone,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
)
# Save episode
dataset.save_episode()
episode_idx += 1
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Clean up
log_say("Stop recording")
robot.disconnect()
phone.disconnect()
listener.stop()
# Save episode
dataset.save_episode()
episode_idx += 1
finally:
# Clean up
log_say("Stop recording")
robot.disconnect()
phone.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":
+22 -20
View File
@@ -73,32 +73,34 @@ def main():
# Connect to the robot
robot.connect()

try:
    if not robot.is_connected:
        raise ValueError("Robot is not connected!")

    print("Starting replay loop...")
    log_say(f"Replaying episode {EPISODE_IDX}")
    for idx in range(len(episode_frames)):
        t0 = time.perf_counter()

        # Get recorded action from dataset
        ee_action = {
            name: float(actions[idx][ACTION][i])
            for i, name in enumerate(dataset.features[ACTION]["names"])
        }

        # Get robot observation
        robot_obs = robot.get_observation()

        # Dataset EE -> robot joints
        joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))

        # Send action to robot
        _ = robot.send_action(joint_action)

        precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
finally:
    # Clean up
    robot.disconnect()
if __name__ == "__main__":
+44 -41
@@ -142,38 +142,24 @@ def main():
listener, events = init_keyboard_listener()
init_rerun(session_name="so100_so100_evaluate")

try:
    if not robot.is_connected:
        raise ValueError("Robot is not connected!")

    print("Starting evaluate loop...")
    episode_idx = 0
    for episode_idx in range(NUM_EPISODES):
        log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")

        # Main record loop
        record_loop(
            robot=robot,
            events=events,
            fps=FPS,
            policy=policy,
            preprocessor=preprocessor,  # Pass the pre and post policy processors
            postprocessor=postprocessor,
            dataset=dataset,
            control_time_s=EPISODE_TIME_SEC,
            single_task=TASK_DESCRIPTION,
            display_data=True,
            teleop_action_processor=make_default_teleop_action_processor(),
            robot_action_processor=robot_ee_to_joints_processor,
            robot_observation_processor=robot_joints_to_ee_pose_processor,
        )

        # Reset the environment if not stopping or re-recording
        if not events["stop_recording"] and (
            (episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
        ):
            log_say("Reset the environment")
            record_loop(
                robot=robot,
                events=events,
                fps=FPS,
                control_time_s=EPISODE_TIME_SEC,
                single_task=TASK_DESCRIPTION,
                display_data=True,
                teleop_action_processor=make_default_teleop_action_processor(),
                robot_action_processor=robot_ee_to_joints_processor,
                robot_observation_processor=robot_joints_to_ee_pose_processor,
            )

        if events["rerecord_episode"]:
            log_say("Re-record episode")
            events["rerecord_episode"] = False
            events["exit_early"] = False
            dataset.clear_episode_buffer()
            continue

        # Save episode
        dataset.save_episode()
        episode_idx += 1
finally:
    # Clean up
    log_say("Stop recording")
    robot.disconnect()
    listener.stop()

    dataset.finalize()
    dataset.push_to_hub()
if __name__ == "__main__":
+45 -41
@@ -146,38 +146,23 @@ def main():
listener, events = init_keyboard_listener()
init_rerun(session_name="recording_phone")

try:
    if not leader.is_connected or not follower.is_connected:
        raise ValueError("Robot or teleop is not connected!")

    print("Starting record loop...")
    episode_idx = 0
    while episode_idx < NUM_EPISODES and not events["stop_recording"]:
        log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")

        # Main record loop
        record_loop(
            robot=follower,
            events=events,
            fps=FPS,
            teleop=leader,
            dataset=dataset,
            control_time_s=EPISODE_TIME_SEC,
            single_task=TASK_DESCRIPTION,
            display_data=True,
            teleop_action_processor=leader_joints_to_ee,
            robot_action_processor=ee_to_follower_joints,
            robot_observation_processor=follower_joints_to_ee,
        )

        # Reset the environment if not stopping or re-recording
        if not events["stop_recording"] and (
            episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]
        ):
            log_say("Reset the environment")
            record_loop(
                robot=follower,
                events=events,
                fps=FPS,
                teleop=leader,
                control_time_s=RESET_TIME_SEC,
                single_task=TASK_DESCRIPTION,
                display_data=True,
                teleop_action_processor=leader_joints_to_ee,
                robot_action_processor=ee_to_follower_joints,
                robot_observation_processor=follower_joints_to_ee,
            )

        if events["rerecord_episode"]:
            log_say("Re-recording episode")
            events["rerecord_episode"] = False
            events["exit_early"] = False
            dataset.clear_episode_buffer()
            continue

        # Save episode
        dataset.save_episode()
        episode_idx += 1
finally:
    # Clean up
    log_say("Stop recording")
    leader.disconnect()
    follower.disconnect()
    listener.stop()

    dataset.finalize()
    dataset.push_to_hub()
if __name__ == "__main__":
+22 -19
@@ -74,32 +74,35 @@ def main():
# Connect to the robot
robot.connect()

try:
    if not robot.is_connected:
        raise ValueError("Robot is not connected!")

    print("Starting replay loop...")
    log_say(f"Replaying episode {EPISODE_IDX}")
    for idx in range(len(episode_frames)):
        t0 = time.perf_counter()

        # Get recorded action from dataset
        ee_action = {
            name: float(actions[idx][ACTION][i])
            for i, name in enumerate(dataset.features[ACTION]["names"])
        }

        # Get robot observation
        robot_obs = robot.get_observation()

        # Dataset EE -> robot joints
        joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))

        # Send action to robot
        _ = robot.send_action(joint_action)

        precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
finally:
    # Clean up
    robot.disconnect()
if __name__ == "__main__":
+6 -1
@@ -105,12 +105,17 @@ dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
damiao = ["python-can>=4.2.0,<5.0.0"]
# Robots
openarms = ["lerobot[damiao]"]
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
unitree_g1 = [
"pyzmq>=26.2.1,<28.0.0",
"onnxruntime>=1.16.0,<2.0.0"
"onnxruntime>=1.16.0,<2.0.0",
"pin>=3.0.0,<4.0.0",
"meshcat>=0.3.0,<0.4.0",
"matplotlib>=3.9.0,<4.0.0",
"casadi>=3.6.0,<4.0.0",
]
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
kinematics = ["lerobot[placo-dep]"]
+82 -18
@@ -15,11 +15,12 @@
# limitations under the License.
import abc
import warnings
from typing import Any
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
from .configs import CameraConfig
class Camera(abc.ABC):
@@ -30,20 +31,12 @@ class Camera(abc.ABC):
Manages basic camera properties (FPS, resolution) and core operations:
- Connection/disconnection
- Frame capture (sync/async/latest)
Attributes:
fps (int | None): Configured frames per second
width (int | None): Frame width in pixels
height (int | None): Frame height in pixels
"""
def __init__(self, config: CameraConfig):
@@ -56,6 +49,32 @@ class Camera(abc.ABC):
self.width: int | None = config.width
self.height: int | None = config.height
def __enter__(self):
"""
Context manager entry.
Automatically connects to the camera.
"""
self.connect()
return self
def __exit__(self, exc_type, exc_value, traceback) -> None:
"""
Context manager exit.
Automatically disconnects, ensuring resources are released even on error.
"""
self.disconnect()
def __del__(self) -> None:
"""
Destructor safety net.
Attempts to disconnect if the object is garbage collected without cleanup.
"""
try:
if self.is_connected:
self.disconnect()
except Exception: # nosec B110
pass
@property
@abc.abstractmethod
def is_connected(self) -> bool:
@@ -89,12 +108,10 @@ class Camera(abc.ABC):
pass
@abc.abstractmethod
def read(self) -> NDArray[Any]:
"""Capture and return a single frame from the camera synchronously.
This is a blocking call that will wait for the hardware and its SDK.
Returns:
np.ndarray: Captured frame as a numpy array.
@@ -103,17 +120,64 @@ class Camera(abc.ABC):
@abc.abstractmethod
def async_read(self, timeout_ms: float = ...) -> NDArray[Any]:
"""Asynchronously capture and return a single frame from the camera.
"""Return the most recent new frame.
This method retrieves the latest frame captured by the background thread.
If a new frame is already available in the buffer (captured since the last call),
it returns it immediately.
It blocks up to `timeout_ms` only if the buffer is empty or if the latest frame
was already consumed by a previous `async_read` call.
Essentially, this method returns the latest unconsumed frame, waiting if
necessary for a new one to arrive within the specified timeout.
Usage:
- Ideal for control loops where you want to ensure every processed frame
is fresh, effectively synchronizing your loop to the camera's FPS.
- Causes of a timeout usually include: very low camera FPS, heavy processing load,
or if the camera is disconnected.
Args:
timeout_ms: Maximum time to wait for a new frame in milliseconds.
Defaults to 200ms (0.2s).
Returns:
np.ndarray: Captured frame as a numpy array.
Raises:
TimeoutError: If no new frame arrives within `timeout_ms`.
"""
pass
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
memory buffer. The frame may be stale, i.e., captured some time ago
(e.g., if the camera capture has hung).
Usage:
Ideal for scenarios requiring zero latency or decoupled frequencies, and
when a guaranteed frame is needed, such as UI visualization, logging, or
non-critical monitoring.
Returns:
NDArray[Any]: The frame image (numpy array).
Raises:
TimeoutError: If the latest frame is older than `max_age_ms`.
NotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
warnings.warn(
f"{self.__class__.__name__}.read_latest() is not implemented. "
"Please override read_latest(); it will be required in future releases.",
FutureWarning,
stacklevel=2,
)
return self.async_read()
@abc.abstractmethod
def disconnect(self) -> None:
"""Disconnect from the camera and release resources."""
+111 -55
@@ -70,34 +70,24 @@ class OpenCVCamera(Camera):
Example:
```python
from lerobot.cameras.opencv import OpenCVCamera
from lerobot.cameras.configuration_opencv import OpenCVCameraConfig
# Basic usage with camera index 0
config = OpenCVCameraConfig(index_or_path=0)
camera = OpenCVCamera(config)
camera.connect()
# Read 1 frame synchronously (blocking)
color_image = camera.read()
print(color_image.shape)
# Read 1 frame asynchronously (waits for new frame with a timeout)
async_image = camera.async_read()
# Get the latest frame immediately (non-blocking; may be stale)
latest_image = camera.read_latest()
# When done, properly disconnect the camera using
camera.disconnect()
```
"""
@@ -123,6 +113,7 @@ class OpenCVCamera(Camera):
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: NDArray[Any] | None = None
self.latest_timestamp: float | None = None
self.new_frame_event: Event = Event()
self.rotation: int | None = get_cv2_rotation(config.rotation)
@@ -146,12 +137,16 @@ class OpenCVCamera(Camera):
Connects to the OpenCV camera specified in the configuration.
Initializes the OpenCV VideoCapture object, sets desired camera properties
(FPS, width, height), starts the background reading thread and performs initial checks.
Args:
warmup (bool): If True, waits at connect() time until at least one valid frame
has been captured by the background thread. Defaults to True.
Raises:
DeviceAlreadyConnectedError: If the camera is already connected.
ConnectionError: If the specified camera index/path is not found or fails to open.
RuntimeError: If the camera opens but fails to apply requested settings.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
@@ -170,12 +165,16 @@ class OpenCVCamera(Camera):
)
self._configure_capture_settings()
self._start_read_thread()
if warmup and self.warmup_s > 0:
start_time = time.time()
while time.time() - start_time < self.warmup_s:
self.async_read(timeout_ms=self.warmup_s * 1000)
time.sleep(0.1)
with self.frame_lock:
if self.latest_frame is None:
raise ConnectionError(f"{self} failed to capture frames during warmup.")
logger.info(f"{self} connected.")
@@ -196,8 +195,7 @@ class OpenCVCamera(Camera):
Raises:
RuntimeError: If the camera fails to set any of the specified properties
to the requested value.
DeviceNotConnectedError: If the camera is not connected.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.")
@@ -339,6 +337,17 @@ class OpenCVCamera(Camera):
return found_cameras_info
def _read_from_hardware(self) -> NDArray[Any]:
if self.videocapture is None:
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
ret, frame = self.videocapture.read()
if not ret:
raise RuntimeError(f"{self} read failed (status={ret}).")
return frame
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the camera.
@@ -346,11 +355,6 @@ class OpenCVCamera(Camera):
This is a blocking call. It waits for the next available frame from the
camera hardware via OpenCV.
Returns:
np.ndarray: The captured frame as a NumPy array in the format
(height, width, channels), using the specified or default
@@ -362,34 +366,34 @@ class OpenCVCamera(Camera):
received frame dimensions don't match expectations before rotation.
ValueError: If an invalid `color_mode` is requested.
"""
start_time = time.perf_counter()

if color_mode is not None:
logger.warning(
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
)

if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")

if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")

self.new_frame_event.clear()
frame = self.async_read(timeout_ms=10000)

read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")

return frame
def _postprocess_image(self, image: NDArray[Any]) -> NDArray[Any]:
"""
Applies color conversion, dimension validation, and rotation to a raw frame.
Args:
image (np.ndarray): The raw image frame (expected BGR format from OpenCV).
Returns:
np.ndarray: The processed image frame.
@@ -399,11 +403,10 @@ class OpenCVCamera(Camera):
RuntimeError: If the raw frame dimensions do not match the configured
`width` and `height`.
"""
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"Invalid color mode '{requested_color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
f"Invalid color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
)
h, w, c = image.shape
@@ -417,7 +420,7 @@ class OpenCVCamera(Camera):
raise RuntimeError(f"{self} frame channels={c} do not match expected 3 channels (RGB/BGR).")
processed_image = image
if self.color_mode == ColorMode.RGB:
processed_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]:
@@ -431,7 +434,7 @@ class OpenCVCamera(Camera):
On each iteration:
1. Reads a color frame
2. Stores result in latest_frame and updates timestamp (thread-safe)
3. Sets new_frame_event to notify listeners
Stops on DeviceNotConnectedError; logs other errors and raises after repeated consecutive failures.
@@ -439,30 +442,37 @@ class OpenCVCamera(Camera):
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
failure_count = 0
while not self.stop_event.is_set():
try:
raw_frame = self._read_from_hardware()
processed_frame = self._postprocess_image(raw_frame)
capture_time = time.perf_counter()
with self.frame_lock:
self.latest_frame = processed_frame
self.latest_timestamp = capture_time
self.new_frame_event.set()
failure_count = 0
except DeviceNotConnectedError:
break
except Exception as e:
if failure_count <= 10:
failure_count += 1
logger.warning(f"Error reading frame in background thread for {self}: {e}")
else:
raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e
def _start_read_thread(self) -> None:
"""Starts or restarts the background read thread if it's not running."""
self._stop_read_thread()
self.stop_event = Event()
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
self.thread.daemon = True
self.thread.start()
time.sleep(0.1)
def _stop_read_thread(self) -> None:
"""Signals the background read thread to stop and waits for it to join."""
@@ -475,6 +485,11 @@ class OpenCVCamera(Camera):
self.thread = None
self.stop_event = None
with self.frame_lock:
self.latest_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame asynchronously.
@@ -482,6 +497,7 @@ class OpenCVCamera(Camera):
This method retrieves the most recent frame captured by the background
read thread. It does not block waiting for the camera hardware directly,
but may wait up to timeout_ms for the background thread to provide a frame.
Under high FPS this is best-effort: intermediate frames may be skipped.
Args:
timeout_ms (float): Maximum time in milliseconds to wait for a frame
@@ -500,13 +516,12 @@ class OpenCVCamera(Camera):
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
raise TimeoutError(
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
f"Read thread alive: {self.thread.is_alive()}."
)
with self.frame_lock:
@@ -518,6 +533,42 @@ class OpenCVCamera(Camera):
return frame
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
memory buffer. The frame may be stale, i.e., captured some time ago
(e.g., if the camera capture has hung).
Returns:
NDArray[Any]: The frame image (numpy array).
Raises:
TimeoutError: If the latest frame is older than `max_age_ms`.
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
with self.frame_lock:
frame = self.latest_frame
timestamp = self.latest_timestamp
if frame is None or timestamp is None:
raise RuntimeError(f"{self} has not captured any frames yet.")
age_ms = (time.perf_counter() - timestamp) * 1e3
if age_ms > max_age_ms:
raise TimeoutError(
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
)
return frame
def disconnect(self) -> None:
"""
Disconnects from the camera and cleans up resources.
@@ -538,4 +589,9 @@ class OpenCVCamera(Camera):
self.videocapture.release()
self.videocapture = None
with self.frame_lock:
self.latest_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
logger.info(f"{self} disconnected.")
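For a sense of when each read mode fits, here is a hedged control-loop sketch; it is not library code, and process() and visualize() are hypothetical placeholders:

```python
import time

def control_loop(camera, duration_s: float = 5.0) -> None:
    # Sketch: sync the loop to the camera via async_read(), and use the
    # non-blocking read_latest() for best-effort visualization.
    t_end = time.perf_counter() + duration_s
    while time.perf_counter() < t_end:
        frame = camera.async_read(timeout_ms=200)  # blocks until a fresh frame
        process(frame)  # hypothetical: consume each new frame exactly once
        try:
            preview = camera.read_latest(max_age_ms=500)  # never blocks
            visualize(preview)  # hypothetical: best-effort display
        except TimeoutError:
            pass  # buffered frame too old this tick; skip the preview
```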
@@ -80,6 +80,8 @@ class Reachy2Camera(Camera):
self.config = config
self.color_mode = config.color_mode
self.latest_frame: NDArray[Any] | None = None
self.latest_timestamp: float | None = None
self.cam_manager: CameraManager | None = None
@@ -125,12 +127,7 @@ class Reachy2Camera(Camera):
"""
Reads a single frame synchronously from the camera.
This is a blocking call.
This method retrieves the most recent frame available in Reachy 2's low-level software.
Returns:
np.ndarray: The captured frame as a NumPy array in the format
@@ -145,6 +142,11 @@ class Reachy2Camera(Camera):
if self.cam_manager is None:
raise DeviceNotConnectedError(f"{self} is not connected.")
if color_mode is not None:
logger.warning(
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
)
frame: NDArray[Any] = np.empty((0, 0, 3), dtype=np.uint8)
if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"):
@@ -165,11 +167,18 @@ class Reachy2Camera(Camera):
raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")
if frame is None:
raise RuntimeError(f"Internal error: No frame available for {self}.")
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"Invalid color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
)
if self.color_mode == ColorMode.RGB:
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
self.latest_frame = frame
self.latest_timestamp = time.perf_counter()
read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
@@ -177,13 +186,7 @@ class Reachy2Camera(Camera):
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame.
Same as read().
Returns:
np.ndarray: The latest captured frame as a NumPy array in the format
@@ -197,12 +200,38 @@ class Reachy2Camera(Camera):
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
return self.read()
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
memory buffer. The frame may be stale, i.e., captured some time ago
(e.g., if the camera capture has hung).
Returns:
NDArray[Any]: The frame image (numpy array).
Raises:
TimeoutError: If the latest frame is older than `max_age_ms`.
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.latest_frame is None or self.latest_timestamp is None:
raise RuntimeError(f"{self} has not captured any frames yet.")
age_ms = (time.perf_counter() - self.latest_timestamp) * 1e3
if age_ms > max_age_ms:
raise TimeoutError(
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
)
return self.latest_frame
def disconnect(self) -> None:
"""
+145 -67
@@ -72,15 +72,14 @@ class RealSenseCamera(Camera):
camera = RealSenseCamera(config)
camera.connect()
# Read 1 frame synchronously (blocking)
color_image = camera.read()
print(color_image.shape)
# Read 1 frame asynchronously (waits for new frame with a timeout)
async_image = camera.async_read()
# Get the latest frame immediately (non-blocking; may be stale)
latest_image = camera.read_latest()

# When done, properly disconnect the camera using
camera.disconnect()
# Example with depth capture and custom settings
custom_config = RealSenseCameraConfig(
@@ -133,7 +132,9 @@ class RealSenseCamera(Camera):
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: NDArray[Any] | None = None
self.latest_color_frame: NDArray[Any] | None = None
self.latest_depth_frame: NDArray[Any] | None = None
self.latest_timestamp: float | None = None
self.new_frame_event: Event = Event()
self.rotation: int | None = get_cv2_rotation(config.rotation)
@@ -158,6 +159,10 @@ class RealSenseCamera(Camera):
Initializes the RealSense pipeline, configures the required streams (color
and optionally depth), starts the pipeline, and validates the actual stream settings.
Args:
warmup (bool): If True, waits at connect() time until at least one valid frame
has been captured by the background thread. Defaults to True.
Raises:
DeviceAlreadyConnectedError: If the camera is already connected.
ValueError: If the configuration is invalid (e.g., missing serial/name, name not unique).
@@ -181,15 +186,18 @@ class RealSenseCamera(Camera):
) from e
self._configure_capture_settings()
self._start_read_thread()
if warmup:
# NOTE(Steven/Caroline): Enforcing at least one second of warmup as RS cameras need a bit of time before the first read. If we don't wait, the first read from the warmup will raise.
self.warmup_s = max(self.warmup_s, 1)
start_time = time.time()
while time.time() - start_time < self.warmup_s:
self.async_read(timeout_ms=self.warmup_s * 1000)
time.sleep(0.1)
with self.frame_lock:
if self.latest_color_frame is None or self.use_depth and self.latest_depth_frame is None:
raise ConnectionError(f"{self} failed to capture frames during warmup.")
logger.info(f"{self} connected.")
@@ -319,9 +327,6 @@ class RealSenseCamera(Camera):
This is a blocking call. It waits for a coherent set of frames (depth)
from the camera hardware via the RealSense pipeline.
Returns:
np.ndarray: The depth map as a NumPy array (height, width)
of type `np.uint16` (raw depth values in millimeters) and rotation.
@@ -330,44 +335,52 @@ class RealSenseCamera(Camera):
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If reading frames from the pipeline fails or frames are invalid.
"""
if timeout_ms:
logger.warning(
f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions."
)
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if not self.use_depth:
raise RuntimeError(
f"Failed to capture depth frame '.read_depth()'. Depth stream is not enabled for {self}."
)
start_time = time.perf_counter()
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
self.new_frame_event.clear()
_ = self.async_read(timeout_ms=10000)
with self.frame_lock:
depth_map = self.latest_depth_frame
if depth_map is None:
raise RuntimeError("No depth frame available. Ensure camera is streaming.")
return depth_map
def _read_from_hardware(self):
if self.rs_pipeline is None:
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=10000)
if not ret or frame is None:
raise RuntimeError(f"{self} read failed (status={ret}).")
return frame
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 0) -> NDArray[Any]:
"""
Reads a single frame (color) synchronously from the camera.
This is a blocking call. It waits for a coherent set of frames (color)
from the camera hardware via the RealSense pipeline.
Returns:
np.ndarray: The captured color frame as a NumPy array
(height, width, channels), processed according to `color_mode` and rotation.
@@ -378,39 +391,39 @@ class RealSenseCamera(Camera):
ValueError: If an invalid `color_mode` is requested.
"""
start_time = time.perf_counter()
if color_mode is not None:
logger.warning(
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
)
if timeout_ms:
logger.warning(
f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions."
)
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
self.new_frame_event.clear()
frame = self.async_read(timeout_ms=10000)
read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
return frame
def _postprocess_image(self, image: NDArray[Any], depth_frame: bool = False) -> NDArray[Any]:
"""
Applies color conversion, dimension validation, and rotation to a raw color frame.
Args:
image (np.ndarray): The raw image frame (expected RGB format from RealSense).
Returns:
np.ndarray: The processed image frame according to `self.color_mode` and `self.rotation`.
@@ -421,9 +434,9 @@ class RealSenseCamera(Camera):
`width` and `height`.
"""
if self.color_mode and self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"Invalid requested color mode '{color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
f"Invalid requested color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
)
if depth_frame:
@@ -454,7 +467,7 @@ class RealSenseCamera(Camera):
On each iteration:
1. Reads a color frame from the hardware
2. Stores result in latest_frame and updates timestamp (thread-safe)
3. Sets new_frame_event to notify listeners
Stops on DeviceNotConnectedError; logs other errors and raises after repeated consecutive failures.
@@ -462,25 +475,41 @@ class RealSenseCamera(Camera):
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
failure_count = 0
while not self.stop_event.is_set():
try:
frame = self._read_from_hardware()
color_frame_raw = frame.get_color_frame()
color_frame = np.asanyarray(color_frame_raw.get_data())
processed_color_frame = self._postprocess_image(color_frame)
if self.use_depth:
depth_frame_raw = frame.get_depth_frame()
depth_frame = np.asanyarray(depth_frame_raw.get_data())
processed_depth_frame = self._postprocess_image(depth_frame, depth_frame=True)
capture_time = time.perf_counter()
with self.frame_lock:
self.latest_color_frame = processed_color_frame
if self.use_depth:
self.latest_depth_frame = processed_depth_frame
self.latest_timestamp = capture_time
self.new_frame_event.set()
failure_count = 0
except DeviceNotConnectedError:
break
except Exception as e:
if failure_count <= 10:
failure_count += 1
logger.warning(f"Error reading frame in background thread for {self}: {e}")
else:
raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e
def _start_read_thread(self) -> None:
"""Starts or restarts the background read thread if it's not running."""
self._stop_read_thread()
self.stop_event = Event()
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
@@ -498,6 +527,12 @@ class RealSenseCamera(Camera):
self.thread = None
self.stop_event = None
with self.frame_lock:
self.latest_color_frame = None
self.latest_depth_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
# NOTE(Steven): Missing implementation for depth for now
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
@@ -506,6 +541,7 @@ class RealSenseCamera(Camera):
This method retrieves the most recent color frame captured by the background
read thread. It does not block waiting for the camera hardware directly,
but may wait up to timeout_ms for the background thread to provide a frame.
Under high FPS this is best-effort: intermediate frames may be skipped.
Args:
timeout_ms (float): Maximum time in milliseconds to wait for a frame
@@ -524,17 +560,16 @@ class RealSenseCamera(Camera):
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
raise TimeoutError(
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
f"Read thread alive: {thread_alive}."
f"Read thread alive: {self.thread.is_alive()}."
)
with self.frame_lock:
frame = self.latest_color_frame
self.new_frame_event.clear()
if frame is None:
@@ -542,6 +577,43 @@ class RealSenseCamera(Camera):
return frame
# NOTE(Steven): Missing implementation for depth for now
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent (color) frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
memory buffer. The frame may be stale, i.e., captured some time ago
(e.g., if the camera capture has hung).
Returns:
NDArray[Any]: The frame image (numpy array).
Raises:
TimeoutError: If the latest frame is older than `max_age_ms`.
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
with self.frame_lock:
frame = self.latest_color_frame
timestamp = self.latest_timestamp
if frame is None or timestamp is None:
raise RuntimeError(f"{self} has not captured any frames yet.")
age_ms = (time.perf_counter() - timestamp) * 1e3
if age_ms > max_age_ms:
raise TimeoutError(
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
)
return frame
def disconnect(self) -> None:
"""
Disconnects from the camera, stops the pipeline, and cleans up resources.
@@ -565,4 +637,10 @@ class RealSenseCamera(Camera):
self.rs_pipeline = None
self.rs_profile = None
with self.frame_lock:
self.latest_color_frame = None
self.latest_depth_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
logger.info(f"{self} disconnected.")
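A sketch of color plus depth capture under the revised flow. This is a sketch only: the config module path and the serial_number_or_name/use_depth field names are assumed by analogy with the docstrings above, and the serial number is a placeholder:

```python
from lerobot.cameras.realsense import RealSenseCamera
from lerobot.cameras.configuration_realsense import RealSenseCameraConfig

# Assumed field names; check your installed config class before relying on them.
config = RealSenseCameraConfig(serial_number_or_name="0123456789", use_depth=True)
camera = RealSenseCamera(config)
camera.connect()  # warmup now enforces at least 1 s for RealSense devices
try:
    color = camera.read()        # delegates to async_read() internally
    depth = camera.read_depth()  # uint16 depth map (raw values in millimeters)
finally:
    camera.disconnect()
```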
+199 -45
@@ -45,6 +45,12 @@ logger = logging.getLogger(__name__)
class ZMQCamera(Camera):
"""
Manages camera interactions via ZeroMQ for receiving frames from a remote server.
This class connects to a ZMQ Publisher, subscribes to frame topics, and decodes
incoming JSON messages containing Base64 encoded images. It supports both
synchronous and asynchronous frame reading patterns.
Example usage:
```python
from lerobot.cameras.zmq import ZMQCamera, ZMQCameraConfig
@@ -52,7 +58,16 @@ class ZMQCamera(Camera):
config = ZMQCameraConfig(server_address="192.168.123.164", port=5555, camera_name="head_camera")
camera = ZMQCamera(config)
camera.connect()
# Read 1 frame synchronously (blocking)
color_image = camera.read()
# Read 1 frame asynchronously (waits for new frame with a timeout)
async_image = camera.async_read()
# Get the latest frame immediately (non-blocking; may be stale)
latest_image = camera.read_latest()
camera.disconnect()
```
"""
@@ -68,14 +83,17 @@ class ZMQCamera(Camera):
self.color_mode = config.color_mode
self.timeout_ms = config.timeout_ms
# ZMQ Context and Socket
self.context: zmq.Context | None = None
self.socket: zmq.Socket | None = None
self._connected = False
# Threading resources
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: NDArray[Any] | None = None
self.latest_timestamp: float | None = None
self.new_frame_event: Event = Event()
def __str__(self) -> str:
@@ -83,10 +101,16 @@ class ZMQCamera(Camera):
@property
def is_connected(self) -> bool:
"""Checks if the ZMQ socket is initialized and connected."""
return self._connected and self.context is not None and self.socket is not None
def connect(self, warmup: bool = True) -> None:
"""Connect to ZMQ camera server."""
"""Connect to ZMQ camera server.
Args:
warmup (bool): If True, waits for the camera to provide at least one
valid frame before returning. Defaults to True.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
@@ -103,17 +127,28 @@ class ZMQCamera(Camera):
self.socket.connect(f"tcp://{self.server_address}:{self.port}")
self._connected = True
# Auto-detect resolution if not provided
if self.width is None or self.height is None:
# Read directly from hardware because the thread isn't running yet
temp_frame = self._read_from_hardware()
h, w = temp_frame.shape[:2]
self.height = h
self.width = w
logger.info(f"{self} resolution detected: {w}x{h}")
self._start_read_thread()
logger.info(f"{self} connected.")
if warmup:
time.sleep(0.1)
# Ensure we have captured at least one frame via the thread
start_time = time.time()
while time.time() - start_time < self.config.warmup_s:
self.async_read(timeout_ms=self.config.warmup_s * 1000)
time.sleep(0.1)
with self.frame_lock:
if self.latest_frame is None:
raise ConnectionError(f"{self} failed to capture frames during warmup.")
except Exception as e:
self._cleanup()
@@ -131,15 +166,14 @@ class ZMQCamera(Camera):
@staticmethod
def find_cameras() -> list[dict[str, Any]]:
"""ZMQ cameras require manual configuration (server address/port)."""
return []
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Read a single frame from the ZMQ camera.
Detection not implemented for ZMQ cameras. These cameras require manual configuration (server address/port).
"""
raise NotImplementedError("Camera detection is not implemented for ZMQ cameras.")
Returns:
np.ndarray: Decoded frame (height, width, 3)
def _read_from_hardware(self) -> NDArray[Any]:
"""
Reads a single frame directly from the ZMQ socket.
"""
if not self.is_connected or self.socket is None:
raise DeviceNotConnectedError(f"{self} is not connected.")
@@ -147,6 +181,7 @@ class ZMQCamera(Camera):
try:
message = self.socket.recv_string()
except Exception as e:
# Check for ZMQ timeout (EAGAIN/Again) without requiring global zmq import
if type(e).__name__ == "Again":
raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e
raise
@@ -176,42 +211,117 @@ class ZMQCamera(Camera):
return frame
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the camera.

This is a blocking call. It waits for the next available frame from the
camera background thread.

Returns:
np.ndarray: Decoded frame (height, width, 3)
"""
start_time = time.perf_counter()

if color_mode is not None:
logger.warning(
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
)

if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")

if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")

self.new_frame_event.clear()
frame = self.async_read(timeout_ms=10000)

read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
return frame
def _read_loop(self) -> None:
"""
Internal loop run by the background thread for asynchronous reading.
"""
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized.")
failure_count = 0
while not self.stop_event.is_set():
try:
frame = self._read_from_hardware()
capture_time = time.perf_counter()
with self.frame_lock:
self.latest_frame = frame
self.latest_timestamp = capture_time
self.new_frame_event.set()
failure_count = 0
except DeviceNotConnectedError:
break
except (TimeoutError, Exception) as e:
if failure_count <= 10:
failure_count += 1
logger.warning(f"Read error: {e}")
else:
raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e
def _start_read_thread(self) -> None:
if self.stop_event is not None:
self.stop_event.set()
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=2.0)
with self.frame_lock:
self.latest_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
self.stop_event = Event()
self.thread = Thread(target=self._read_loop, daemon=True, name=f"{self}_read_loop")
self.thread.start()
time.sleep(0.1)
def _stop_read_thread(self) -> None:
if self.stop_event is not None:
self.stop_event.set()
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=2.0)
self.thread = None
self.stop_event = None
with self.frame_lock:
self.latest_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame asynchronously.
Args:
timeout_ms (float): Maximum time in milliseconds to wait for a frame
to become available. Defaults to 200ms.
Returns:
np.ndarray: The latest captured frame.
Raises:
DeviceNotConnectedError: If the camera is not connected.
TimeoutError: If no frame data becomes available within the specified timeout.
RuntimeError: If the background thread is not running.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
raise TimeoutError(f"{self} async_read timeout after {timeout_ms}ms")
@@ -225,11 +335,55 @@ class ZMQCamera(Camera):
return frame
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
memory buffer. The frame may be stale, i.e., captured some time ago
(e.g., if the camera capture has hung).
Returns:
NDArray[Any]: The frame image (numpy array).
Raises:
TimeoutError: If the latest frame is older than `max_age_ms`.
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
with self.frame_lock:
frame = self.latest_frame
timestamp = self.latest_timestamp
if frame is None or timestamp is None:
raise RuntimeError(f"{self} has not captured any frames yet.")
age_ms = (time.perf_counter() - timestamp) * 1e3
if age_ms > max_age_ms:
raise TimeoutError(
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
)
return frame
def disconnect(self) -> None:
"""Disconnect from ZMQ camera."""
if not self.is_connected and self.thread is None:
raise DeviceNotConnectedError(f"{self} not connected.")
if self.thread is not None:
self._stop_read_thread()
self._cleanup()
with self.frame_lock:
self.latest_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
logger.info(f"{self} disconnected.")
@@ -29,6 +29,7 @@ class ZMQCameraConfig(CameraConfig):
camera_name: str = "zmq_camera"
color_mode: ColorMode = ColorMode.RGB
timeout_ms: int = 5000
warmup_s: int = 1
def __post_init__(self) -> None:
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
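Putting the ZMQ pieces together, a configuration sketch using the new warmup_s field; values are illustrative, and server_address/port must match your publisher:

```python
from lerobot.cameras.zmq import ZMQCamera, ZMQCameraConfig

config = ZMQCameraConfig(
    server_address="192.168.123.164",
    port=5555,
    camera_name="head_camera",
    timeout_ms=5000,
    warmup_s=1,  # new field: seconds spent waiting for a first valid frame on connect
)
with ZMQCamera(config) as camera:  # context manager from the Camera base class
    frame = camera.async_read(timeout_ms=200)
```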
+82 -18
@@ -116,6 +116,9 @@ def update_meta_data(
Adjusts all indices and timestamps to account for previously aggregated
data and videos in the destination dataset.
For data file indices, uses the 'src_to_dst' mapping from aggregate_data()
to correctly map source file indices to their destination locations.
Args:
df: DataFrame containing the metadata to be updated.
dst_meta: Destination dataset metadata.
@@ -129,8 +132,50 @@ def update_meta_data(
df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"]
df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"]
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
# Update data file indices using source-to-destination mapping
# This is critical for handling datasets that are already results of a merge
data_src_to_dst = data_idx.get("src_to_dst", {})
if data_src_to_dst:
# Store original indices for lookup
df["_orig_data_chunk"] = df["data/chunk_index"].copy()
df["_orig_data_file"] = df["data/file_index"].copy()
# Vectorized mapping from (src_chunk, src_file) to (dst_chunk, dst_file)
# This is much faster than per-row iteration for large metadata tables
mapping_index = pd.MultiIndex.from_tuples(
list(data_src_to_dst.keys()),
names=["chunk_index", "file_index"],
)
mapping_values = list(data_src_to_dst.values())
mapping_df = pd.DataFrame(
mapping_values,
index=mapping_index,
columns=["dst_chunk", "dst_file"],
)
# Construct a MultiIndex for each row based on original data indices
row_index = pd.MultiIndex.from_arrays(
[df["_orig_data_chunk"], df["_orig_data_file"]],
names=["chunk_index", "file_index"],
)
# Align mapping to rows; missing keys fall back to the default destination
reindexed = mapping_df.reindex(row_index)
reindexed[["dst_chunk", "dst_file"]] = reindexed[["dst_chunk", "dst_file"]].fillna(
{"dst_chunk": data_idx["chunk"], "dst_file": data_idx["file"]}
)
# Assign mapped destination indices back to the DataFrame
df["data/chunk_index"] = reindexed["dst_chunk"].to_numpy()
df["data/file_index"] = reindexed["dst_file"].to_numpy()
# Clean up temporary columns
df = df.drop(columns=["_orig_data_chunk", "_orig_data_file"])
else:
# Fallback to simple offset (backward compatibility for single-file sources)
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
for key, video_idx in videos_idx.items():
# Store original video file indices before updating
orig_chunk_col = f"videos/{key}/chunk_index"
@@ -146,8 +191,7 @@ def update_meta_data(
if src_to_dst:
# Map each episode to its correct destination file and apply offset
for idx in df.index:
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
# Get destination chunk/file for this source file
dst_chunk, dst_file = src_to_dst.get(src_key, (video_idx["chunk"], video_idx["file"]))
@@ -163,8 +207,7 @@ def update_meta_data(
df[orig_chunk_col] = video_idx["chunk"]
df[orig_file_col] = video_idx["file"]
for idx in df.index:
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
offset = src_to_offset.get(src_key, 0)
df.at[idx, f"videos/{key}/from_timestamp"] += offset
df.at[idx, f"videos/{key}/to_timestamp"] += offset
@@ -262,6 +305,10 @@ def aggregate_datasets(
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
# Clear the src_to_dst mapping after processing each source dataset
# to avoid interference between different source datasets
data_idx.pop("src_to_dst", None)
dst_meta.info["total_episodes"] += src_meta.total_episodes
dst_meta.info["total_frames"] += src_meta.total_frames
@@ -312,10 +359,6 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
dst_file_durations = video_idx["dst_file_durations"]
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
video_key=key,
chunk_index=src_chunk_idx,
@@ -388,10 +431,16 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
Reads source data files, updates indices to match the aggregated dataset,
and writes them to the destination with proper file rotation.
Tracks a `src_to_dst` mapping from source (chunk, file) to destination (chunk, file)
which is critical for correctly updating episode metadata when source datasets
have multiple data files (e.g., from a previous merge operation).
Args:
src_meta: Source dataset metadata.
dst_meta: Destination dataset metadata.
data_idx: Dictionary tracking data chunk and file indices.
data_files_size_in_mb: Maximum size for data files in MB.
chunk_size: Maximum number of files per chunk.
Returns:
dict: Updated data_idx with current chunk and file indices.
@@ -409,6 +458,10 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
# retrieve features schema for proper image typing in parquet
hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None
# Track source to destination file mapping for metadata update
# This is critical for handling datasets that are already results of a merge
src_to_dst: dict[tuple[int, int], tuple[int, int]] = {}
for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
chunk_index=src_chunk_idx, file_index=src_file_idx
@@ -421,7 +474,9 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
df = pd.read_parquet(src_path)
df = update_data_df(df, src_meta, dst_meta)
data_idx = append_or_create_parquet_file(
# Write data and get the actual destination file it was written to
# This avoids duplicating the rotation logic here
data_idx, (dst_chunk, dst_file) = append_or_create_parquet_file(
df,
src_path,
data_idx,
@@ -433,6 +488,12 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
hf_features=hf_features,
)
# Record the mapping from source to actual destination
src_to_dst[(src_chunk_idx, src_file_idx)] = (dst_chunk, dst_file)
# Add the mapping to data_idx for use in metadata update
data_idx["src_to_dst"] = src_to_dst
return data_idx
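To make the mapping concrete, here is a minimal standalone sketch (hypothetical index values, plain pandas, not the repo's exact code) of how a src_to_dst dict remaps the per-episode data/chunk_index and data/file_index columns during the metadata update:

import pandas as pd

# Hypothetical mapping from aggregate_data: two source files that landed
# in destination chunk 0, files 3 and 4.
src_to_dst = {(0, 0): (0, 3), (0, 1): (0, 4)}

df = pd.DataFrame({
    "episode_index": [0, 1, 2],
    "data/chunk_index": [0, 0, 0],
    "data/file_index": [0, 0, 1],
})

# Replace each episode's source (chunk, file) pair with its destination pair;
# casting to plain int keeps the keys consistent with the mapping's int keys.
dst = [
    src_to_dst[(int(c), int(f))]
    for c, f in zip(df["data/chunk_index"], df["data/file_index"])
]
df["data/chunk_index"] = [c for c, _ in dst]
df["data/file_index"] = [f for _, f in dst]
print(df)  # episodes 0-1 -> file 3, episode 2 -> file 4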
@@ -473,7 +534,7 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
videos_idx,
)
meta_idx = append_or_create_parquet_file(
meta_idx, _ = append_or_create_parquet_file(
df,
src_path,
meta_idx,
@@ -501,7 +562,7 @@ def append_or_create_parquet_file(
contains_images: bool = False,
aggr_root: Path = None,
hf_features: datasets.Features | None = None,
):
) -> tuple[dict[str, int], tuple[int, int]]:
"""Appends data to an existing parquet file or creates a new one based on size constraints.
Manages file rotation when size limits are exceeded to prevent individual files
@@ -519,9 +580,11 @@ def append_or_create_parquet_file(
hf_features: Optional HuggingFace Features schema for proper image typing.
Returns:
dict: Updated index dictionary with current chunk and file indices.
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
and (dst_chunk, dst_file) is the actual destination file the data was written to.
"""
dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
dst_chunk, dst_file = idx["chunk"], idx["file"]
dst_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
if not dst_path.exists():
dst_path.parent.mkdir(parents=True, exist_ok=True)
@@ -529,14 +592,15 @@ def append_or_create_parquet_file(
to_parquet_with_hf_images(df, dst_path, features=hf_features)
else:
df.to_parquet(dst_path)
return idx
return idx, (dst_chunk, dst_file)
src_size = get_parquet_file_size_in_mb(src_path)
dst_size = get_parquet_file_size_in_mb(dst_path)
if dst_size + src_size >= max_mb:
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
new_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
dst_chunk, dst_file = idx["chunk"], idx["file"]
new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
new_path.parent.mkdir(parents=True, exist_ok=True)
final_df = df
target_path = new_path
@@ -555,7 +619,7 @@ def append_or_create_parquet_file(
else:
final_df.to_parquet(target_path)
return idx
return idx, (dst_chunk, dst_file)
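For reference, the rotation call above advances the (chunk, file) indices. A plausible sketch of that wraparound rule follows (an assumption about the behavior, not necessarily the exact body of update_chunk_file_indices in lerobot.datasets.utils):

def update_chunk_file_indices(chunk: int, file: int, chunk_size: int) -> tuple[int, int]:
    # Move to the next file slot, rolling over into a new chunk once the
    # current chunk holds `chunk_size` files.
    file += 1
    if file >= chunk_size:
        chunk += 1
        file = 0
    return chunk, file

assert update_chunk_file_indices(0, 8, 10) == (0, 9)
assert update_chunk_file_indices(0, 9, 10) == (1, 0)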
def finalize_aggregation(aggr_meta, all_metadata):
+126
@@ -1396,6 +1396,132 @@ BYTES_PER_KIB = 1024
BYTES_PER_MIB = BYTES_PER_KIB * BYTES_PER_KIB
def modify_tasks(
dataset: LeRobotDataset,
new_task: str | None = None,
episode_tasks: dict[int, str] | None = None,
) -> LeRobotDataset:
"""Modify tasks in a LeRobotDataset.
This function allows you to either:
1. Set a single task for the entire dataset (using `new_task`)
2. Set specific tasks for specific episodes (using `episode_tasks`)
You can combine both: `new_task` sets the default, and `episode_tasks` overrides
specific episodes.
The dataset is modified in-place, updating only the task-related files:
- meta/tasks.parquet
- data/**/*.parquet (task_index column)
- meta/episodes/**/*.parquet (tasks column)
- meta/info.json (total_tasks)
Args:
dataset: The source LeRobotDataset to modify.
new_task: A single task string to apply to all episodes. If None and episode_tasks
is also None, raises an error.
episode_tasks: Optional dict mapping episode indices to their task strings.
Overrides `new_task` for specific episodes.
Examples:
Set a single task for all episodes:
dataset = modify_tasks(dataset, new_task="Pick up the cube")
Set different tasks for specific episodes:
dataset = modify_tasks(
dataset,
episode_tasks={0: "Task A", 1: "Task B", 2: "Task A"}
)
Set a default task with overrides:
dataset = modify_tasks(
dataset,
new_task="Default task",
episode_tasks={5: "Special task for episode 5"}
)
"""
if new_task is None and episode_tasks is None:
raise ValueError("Must specify at least one of new_task or episode_tasks")
if episode_tasks is not None:
valid_indices = set(range(dataset.meta.total_episodes))
invalid = set(episode_tasks.keys()) - valid_indices
if invalid:
raise ValueError(f"Invalid episode indices: {invalid}")
# Ensure episodes metadata is loaded
if dataset.meta.episodes is None:
dataset.meta.episodes = load_episodes(dataset.root)
# Build the mapping from episode index to task string
episode_to_task: dict[int, str] = {}
for ep_idx in range(dataset.meta.total_episodes):
if episode_tasks and ep_idx in episode_tasks:
episode_to_task[ep_idx] = episode_tasks[ep_idx]
elif new_task is not None:
episode_to_task[ep_idx] = new_task
else:
# Keep original task if not overridden and no default provided
original_tasks = dataset.meta.episodes[ep_idx]["tasks"]
if not original_tasks:
raise ValueError(f"Episode {ep_idx} has no tasks and no default task was provided")
episode_to_task[ep_idx] = original_tasks[0]
# Collect all unique tasks and create new task mapping
unique_tasks = sorted(set(episode_to_task.values()))
new_task_df = pd.DataFrame({"task_index": list(range(len(unique_tasks)))}, index=unique_tasks)
task_to_index = {task: idx for idx, task in enumerate(unique_tasks)}
logging.info(f"Modifying tasks in {dataset.repo_id}")
logging.info(f"New tasks: {unique_tasks}")
root = dataset.root
# Update data files - modify task_index column
logging.info("Updating data files...")
data_dir = root / DATA_DIR
for parquet_path in tqdm(sorted(data_dir.rglob("*.parquet")), desc="Updating data"):
df = pd.read_parquet(parquet_path)
# Build a mapping from episode_index to new task_index for rows in this file
episode_indices_in_file = df["episode_index"].unique()
ep_to_new_task_idx = {
ep_idx: task_to_index[episode_to_task[ep_idx]] for ep_idx in episode_indices_in_file
}
# Update task_index column
df["task_index"] = df["episode_index"].map(ep_to_new_task_idx)
df.to_parquet(parquet_path, index=False)
# Update episodes metadata - modify tasks column
logging.info("Updating episodes metadata...")
episodes_dir = root / "meta" / "episodes"
for parquet_path in tqdm(sorted(episodes_dir.rglob("*.parquet")), desc="Updating episodes"):
df = pd.read_parquet(parquet_path)
# Update tasks column
df["tasks"] = df["episode_index"].apply(lambda ep_idx: [episode_to_task[ep_idx]])
df.to_parquet(parquet_path, index=False)
# Write new tasks.parquet
write_tasks(new_task_df, root)
# Update info.json
dataset.meta.info["total_tasks"] = len(unique_tasks)
write_info(dataset.meta.info, root)
# Reload metadata to reflect changes
dataset.meta.tasks = new_task_df
dataset.meta.episodes = load_episodes(root)
logging.info(f"Tasks: {unique_tasks}")
return dataset
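Since modify_tasks rewrites the dataset in-place, a quick sanity check is to re-read the rewritten files directly. A hedged usage sketch (hypothetical repo id; import path assumed):

import pandas as pd
from lerobot.datasets.lerobot_dataset import LeRobotDataset  # assumed import path

ds = LeRobotDataset("your_name/your_dataset")  # hypothetical repo id
ds = modify_tasks(ds, new_task="Pick up the cube")

# tasks.parquet should now contain exactly the new unique task(s)
print(pd.read_parquet(ds.root / "meta" / "tasks.parquet"))

# and every frame should resolve its task string through the new index
print(ds[0]["task"])  # -> "Pick up the cube"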
def convert_image_to_video_dataset(
dataset: LeRobotDataset,
output_dir: Path,
+9
@@ -57,6 +57,7 @@ from lerobot.datasets.utils import (
load_info,
load_nested_dataset,
load_stats,
load_subtasks,
load_tasks,
update_chunk_file_indices,
validate_episode_buffer,
@@ -162,6 +163,7 @@ class LeRobotDatasetMetadata:
self.info = load_info(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
self.tasks = load_tasks(self.root)
self.subtasks = load_subtasks(self.root)
self.episodes = load_episodes(self.root)
self.stats = load_stats(self.root)
@@ -518,6 +520,7 @@ class LeRobotDatasetMetadata:
_validate_feature_names(features)
obj.tasks = None
obj.subtasks = None
obj.episodes = None
obj.stats = None
obj.info = create_empty_dataset_info(
@@ -1075,6 +1078,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Add task as a string
task_idx = item["task_index"].item()
item["task"] = self.meta.tasks.iloc[task_idx].name
# add subtask information if available
if "subtask_index" in self.features and self.meta.subtasks is not None:
subtask_idx = item["subtask_index"].item()
item["subtask"] = self.meta.subtasks.iloc[subtask_idx].name
return item
def __repr__(self):
+9
@@ -60,6 +60,7 @@ VIDEO_DIR = "videos"
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
@@ -353,6 +354,14 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
return tasks
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
"""Load subtasks from subtasks.parquet if it exists."""
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
if subtasks_path.exists():
return pd.read_parquet(subtasks_path)
return None
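Because the loader returns None rather than raising when meta/subtasks.parquet is absent, callers can probe for subtask support directly. A small hedged usage sketch (hypothetical local path; the subtask strings are assumed to form the index, mirroring tasks.parquet):

from pathlib import Path

root = Path("~/.cache/huggingface/lerobot/your_name/your_dataset").expanduser()  # hypothetical path
subtasks = load_subtasks(root)
if subtasks is None:
    print("dataset has no subtask annotations")
else:
    print(f"{len(subtasks)} subtasks:", list(subtasks.index)[:5])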
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
This function writes episode-level metadata to a single parquet file.
+38 -13
@@ -28,8 +28,11 @@ from lerobot.utils.import_utils import _can_available
if TYPE_CHECKING or _can_available:
import can
else:
can.Message = object
can.interface = None
class can: # noqa: N801
Message = object
interface = None
import numpy as np
@@ -206,11 +209,31 @@ class DamiaoMotorsBus(MotorsBusBase):
Raises ConnectionError if any motor fails to respond.
"""
logger.info("Starting handshake with motors...")
missing_motors = []
# Drain any pending messages
while self.canbus.recv(timeout=0.01):
pass
missing_motors = []
for motor_name in self.motors:
msg = self._refresh_motor(motor_name)
if msg is None:
motor_id = self._get_motor_id(motor_name)
recv_id = self._get_motor_recv_id(motor_name)
# Send enable command
data = [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, CAN_CMD_ENABLE]
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
self.canbus.send(msg)
# Wait for response with longer timeout
response = None
start_time = time.time()
while time.time() - start_time < 0.1:
response = self.canbus.recv(timeout=0.1)
if response and response.arbitration_id == recv_id:
break
response = None
if response is None:
missing_motors.append(motor_name)
else:
self._process_response(motor_name, msg)
@@ -259,7 +282,7 @@ class DamiaoMotorsBus(MotorsBusBase):
motor_name = self._get_motor_name(motor)
recv_id = self._get_motor_recv_id(motor)
data = [0xFF] * 7 + [command_byte]
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
self.canbus.send(msg)
if msg := self._recv_motor_response(expected_recv_id=recv_id):
self._process_response(motor_name, msg)
@@ -317,7 +340,7 @@ class DamiaoMotorsBus(MotorsBusBase):
motor_id = self._get_motor_id(motor)
recv_id = self._get_motor_recv_id(motor)
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd)
self.canbus.send(msg)
return self._recv_motor_response(expected_recv_id=recv_id)
@@ -439,7 +462,7 @@ class DamiaoMotorsBus(MotorsBusBase):
motor_type = self._motor_types[motor_name]
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
self.canbus.send(msg)
recv_id = self._get_motor_recv_id(motor)
@@ -472,7 +495,7 @@ class DamiaoMotorsBus(MotorsBusBase):
motor_type = self._motor_types[motor_name]
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
self.canbus.send(msg)
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
@@ -637,10 +660,10 @@ class DamiaoMotorsBus(MotorsBusBase):
for motor in motors:
motor_id = self._get_motor_id(motor)
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
msg = can.Message(
arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd
)
self.canbus.send(msg)
# Optional small delay to reduce bus congestion (currently disabled; a similar delay was previously removed from sync_read):
# precise_sleep(PRECISE_SLEEP_SEC)
# Collect responses
expected_recv_ids = [self._get_motor_recv_id(m) for m in motors]
@@ -676,7 +699,9 @@ class DamiaoMotorsBus(MotorsBusBase):
kd = self._gains[motor]["kd"]
data = self._encode_mit_packet(motor_type, kp, kd, float(value_degrees), 0.0, 0.0)
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
msg = can.Message(
arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd
)
self.canbus.send(msg)
precise_sleep(PRECISE_TIMEOUT_SEC)
+6 -5
@@ -239,8 +239,10 @@ class SACPolicy(
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
def update_temperature(self):
self.temperature = self.log_alpha.exp().item()
@property
def temperature(self) -> float:
"""Return the current temperature value, always in sync with log_alpha."""
return self.log_alpha.exp().item()
def compute_loss_critic(
self,
@@ -457,11 +459,10 @@ class SACPolicy(
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
self.target_entropy = -np.prod(dim) / 2
def _init_temperature(self):
"""Set up temperature parameter and initial log_alpha."""
def _init_temperature(self) -> None:
"""Set up temperature parameter (log_alpha)."""
temp_init = self.config.temperature_init
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
self.temperature = self.log_alpha.exp().item()
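The checkpoint-resume bug this change fixes is easy to reproduce in isolation. A minimal toy sketch (not the actual SACPolicy) showing why a property derived from log_alpha survives a state_dict round-trip, while a cached float would have kept its stale value:

import math
import torch
from torch import nn

class Temp(nn.Module):
    def __init__(self, temperature_init: float):
        super().__init__()
        self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))

    @property
    def temperature(self) -> float:
        # Recomputed from log_alpha on every access, so loading a
        # checkpoint automatically restores it.
        return self.log_alpha.exp().item()

src, dst = Temp(0.5), Temp(1.0)
dst.load_state_dict(src.state_dict())
assert abs(dst.temperature - 0.5) < 1e-6  # in sync after loading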
class SACObservationEncoder(nn.Module):
+2 -1
@@ -168,11 +168,12 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
"""
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
task_key = {"task": batch["task"]} if "task" in batch else {}
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key}
return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key}
def create_transition(
+4 -6
@@ -17,7 +17,7 @@ from dataclasses import dataclass
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
@@ -92,7 +92,7 @@ class LiberoProcessorStep(ObservationProcessorStep):
# copy over non-STATE features
for ft, feats in features.items():
if ft != PipelineFeatureType.STATE:
if ft != FeatureType.STATE:
new_features[ft] = feats.copy()
# rebuild STATE features
@@ -100,13 +100,11 @@ class LiberoProcessorStep(ObservationProcessorStep):
# add our new flattened state
state_feats[OBS_STATE] = PolicyFeature(
key=OBS_STATE,
type=FeatureType.STATE,
shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)]
dtype="float32",
description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."),
)
new_features[PipelineFeatureType.STATE] = state_feats
new_features[FeatureType.STATE] = state_feats
return new_features
+7 -5
@@ -18,16 +18,18 @@
import math
import time
from dataclasses import dataclass
from typing import Any, Protocol, TypeVar, runtime_checkable
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, runtime_checkable
import numpy as np
import torch
import torchvision.transforms.functional as F # noqa: N812
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.teleoperators.utils import TeleopEvents
if TYPE_CHECKING:
from lerobot.teleoperators.teleoperator import Teleoperator
from .core import EnvTransition, PolicyAction, TransitionKey
from .pipeline import (
ComplementaryDataProcessorStep,
@@ -69,10 +71,10 @@ class HasTeleopEvents(Protocol):
# Type variable constrained to Teleoperator subclasses that also implement events
TeleopWithEvents = TypeVar("TeleopWithEvents", bound=Teleoperator)
TeleopWithEvents = TypeVar("TeleopWithEvents", bound="Teleoperator")
def _check_teleop_with_events(teleop: Teleoperator) -> None:
def _check_teleop_with_events(teleop: "Teleoperator") -> None:
"""
Runtime check that a teleoperator implements the `HasTeleopEvents` protocol.
@@ -103,7 +105,7 @@ class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
teleop_device: The teleoperator instance to get the action from.
"""
teleop_device: Teleoperator
teleop_device: "Teleoperator"
def complementary_data(self, complementary_data: dict) -> dict:
"""
@@ -34,6 +34,8 @@ from lerobot.utils.constants import (
ACTION_TOKEN_MASK,
ACTION_TOKENS,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK,
OBS_LANGUAGE_SUBTASK_TOKENS,
OBS_LANGUAGE_TOKENS,
)
from lerobot.utils.import_utils import _transformers_available
@@ -139,6 +141,32 @@ class TokenizerProcessorStep(ObservationProcessorStep):
return None
def get_subtask(self, transition: EnvTransition) -> list[str] | None:
"""
Extracts the subtask from the transition's complementary data.
Args:
transition: The environment transition.
Returns:
A list of subtask strings, or None if the subtask key is not found or the value is None.
"""
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None:
return None
subtask = complementary_data.get("subtask")
if subtask is None:
return None
# Standardize to a list of strings for the tokenizer
if isinstance(subtask, str):
return [subtask]
elif isinstance(subtask, list) and all(isinstance(t, str) for t in subtask):
return subtask
return None
def observation(self, observation: RobotObservation) -> RobotObservation:
"""
Tokenizes the task description and adds it to the observation dictionary.
@@ -176,6 +204,24 @@ class TokenizerProcessorStep(ObservationProcessorStep):
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
# Tokenize subtask if available
subtask = self.get_subtask(self.transition)
if subtask is not None:
tokenized_subtask = self._tokenize_text(subtask)
# Move new tokenized tensors to the detected device
if target_device is not None:
tokenized_subtask = {
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
for k, v in tokenized_subtask.items()
}
# Add tokenized subtask to the observation
new_observation[OBS_LANGUAGE_SUBTASK_TOKENS] = tokenized_subtask["input_ids"]
new_observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] = tokenized_subtask["attention_mask"].to(
dtype=torch.bool
)
return new_observation
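The str-or-list normalization in get_subtask is self-contained enough to check in isolation; a standalone sketch of the same rule:

def normalize_text_field(value) -> list[str] | None:
    # Accept a single string or a list of strings; anything else yields None.
    if value is None:
        return None
    if isinstance(value, str):
        return [value]
    if isinstance(value, list) and all(isinstance(v, str) for v in value):
        return value
    return None

assert normalize_text_field("open the drawer") == ["open the drawer"]
assert normalize_text_field(["grasp", "lift"]) == ["grasp", "lift"]
assert normalize_text_field(42) is None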
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
-3
@@ -545,9 +545,6 @@ def add_actor_information_and_train(
training_infos["temperature_grad_norm"] = temp_grad_norm
training_infos["temperature"] = policy.temperature
# Update temperature
policy.update_temperature()
# Push policy to actors if needed
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
@@ -0,0 +1,20 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .bi_openarm_follower import BiOpenArmFollower
from .config_bi_openarm_follower import BiOpenArmFollowerConfig
__all__ = ["BiOpenArmFollower", "BiOpenArmFollowerConfig"]
@@ -0,0 +1,175 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from functools import cached_property
from lerobot.processor import RobotAction, RobotObservation
from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig
from ..robot import Robot
from .config_bi_openarm_follower import BiOpenArmFollowerConfig
logger = logging.getLogger(__name__)
class BiOpenArmFollower(Robot):
"""
Bimanual OpenArm Follower Arms
"""
config_class = BiOpenArmFollowerConfig
name = "bi_openarm_follower"
def __init__(self, config: BiOpenArmFollowerConfig):
super().__init__(config)
self.config = config
left_arm_config = OpenArmFollowerConfig(
id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir,
port=config.left_arm_config.port,
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
max_relative_target=config.left_arm_config.max_relative_target,
cameras=config.left_arm_config.cameras,
side=config.left_arm_config.side,
can_interface=config.left_arm_config.can_interface,
use_can_fd=config.left_arm_config.use_can_fd,
can_bitrate=config.left_arm_config.can_bitrate,
can_data_bitrate=config.left_arm_config.can_data_bitrate,
motor_config=config.left_arm_config.motor_config,
position_kd=config.left_arm_config.position_kd,
position_kp=config.left_arm_config.position_kp,
joint_limits=config.left_arm_config.joint_limits,
)
right_arm_config = OpenArmFollowerConfig(
id=f"{config.id}_right" if config.id else None,
calibration_dir=config.calibration_dir,
port=config.right_arm_config.port,
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
max_relative_target=config.right_arm_config.max_relative_target,
cameras=config.right_arm_config.cameras,
side=config.right_arm_config.side,
can_interface=config.right_arm_config.can_interface,
use_can_fd=config.right_arm_config.use_can_fd,
can_bitrate=config.right_arm_config.can_bitrate,
can_data_bitrate=config.right_arm_config.can_data_bitrate,
motor_config=config.right_arm_config.motor_config,
position_kd=config.right_arm_config.position_kd,
position_kp=config.right_arm_config.position_kp,
joint_limits=config.right_arm_config.joint_limits,
)
self.left_arm = OpenArmFollower(left_arm_config)
self.right_arm = OpenArmFollower(right_arm_config)
# Only for compatibility with other parts of the codebase that expect a `robot.cameras` attribute
self.cameras = {**self.left_arm.cameras, **self.right_arm.cameras}
@property
def _motors_ft(self) -> dict[str, type]:
left_arm_motors_ft = self.left_arm._motors_ft
right_arm_motors_ft = self.right_arm._motors_ft
return {
**{f"left_{k}": v for k, v in left_arm_motors_ft.items()},
**{f"right_{k}": v for k, v in right_arm_motors_ft.items()},
}
@property
def _cameras_ft(self) -> dict[str, tuple]:
left_arm_cameras_ft = self.left_arm._cameras_ft
right_arm_cameras_ft = self.right_arm._cameras_ft
return {
**{f"left_{k}": v for k, v in left_arm_cameras_ft.items()},
**{f"right_{k}": v for k, v in right_arm_cameras_ft.items()},
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
return {**self._motors_ft, **self._cameras_ft}
@cached_property
def action_features(self) -> dict[str, type]:
return self._motors_ft
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None:
raise NotImplementedError(
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
)
def get_observation(self) -> RobotObservation:
obs_dict = {}
# Add "left_" prefix
left_obs = self.left_arm.get_observation()
obs_dict.update({f"left_{key}": value for key, value in left_obs.items()})
# Add "right_" prefix
right_obs = self.right_arm.get_observation()
obs_dict.update({f"right_{key}": value for key, value in right_obs.items()})
return obs_dict
def send_action(
self,
action: RobotAction,
custom_kp: dict[str, float] | None = None,
custom_kd: dict[str, float] | None = None,
) -> RobotAction:
# Remove "left_" prefix
left_action = {
key.removeprefix("left_"): value for key, value in action.items() if key.startswith("left_")
}
# Remove "right_" prefix
right_action = {
key.removeprefix("right_"): value for key, value in action.items() if key.startswith("right_")
}
sent_action_left = self.left_arm.send_action(left_action, custom_kp, custom_kd)
sent_action_right = self.right_arm.send_action(right_action, custom_kp, custom_kd)
# Add prefixes back
prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()}
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
def disconnect(self):
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -0,0 +1,30 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from lerobot.robots.openarm_follower import OpenArmFollowerConfigBase
from ..config import RobotConfig
@RobotConfig.register_subclass("bi_openarm_follower")
@dataclass
class BiOpenArmFollowerConfig(RobotConfig):
"""Configuration class for Bi OpenArm Follower robots."""
left_arm_config: OpenArmFollowerConfigBase
right_arm_config: OpenArmFollowerConfigBase
@@ -0,0 +1,20 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_openarm_follower import OpenArmFollowerConfig, OpenArmFollowerConfigBase
from .openarm_follower import OpenArmFollower
__all__ = ["OpenArmFollower", "OpenArmFollowerConfig", "OpenArmFollowerConfigBase"]
@@ -0,0 +1,122 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from ..config import RobotConfig
LEFT_DEFAULT_JOINTS_LIMITS: dict[str, tuple[float, float]] = {
"joint_1": (-75.0, 75.0),
"joint_2": (-90.0, 9.0),
"joint_3": (-85.0, 85.0),
"joint_4": (0.0, 135.0),
"joint_5": (-85.0, 85.0),
"joint_6": (-40.0, 40.0),
"joint_7": (-80.0, 80.0),
"gripper": (-65.0, 0.0),
}
RIGHT_DEFAULT_JOINTS_LIMITS: dict[str, tuple[float, float]] = {
"joint_1": (-75.0, 75.0),
"joint_2": (-9.0, 90.0),
"joint_3": (-85.0, 85.0),
"joint_4": (0.0, 135.0),
"joint_5": (-85.0, 85.0),
"joint_6": (-40.0, 40.0),
"joint_7": (-80.0, 80.0),
"gripper": (-65.0, 0.0),
}
@dataclass
class OpenArmFollowerConfigBase:
"""Base configuration for the OpenArms follower robot with Damiao motors."""
# CAN interface name for this arm (Linux socketcan, e.g., "can0", "can1")
port: str
# Side of the arm: "left" or "right". If None, the conservative joint_limits defined below (or CLI overrides) are used.
side: str | None = None
# CAN interface type: "socketcan" (Linux), "slcan" (serial), or "auto" (auto-detect)
can_interface: str = "socketcan"
# CAN FD settings (OpenArms uses CAN FD by default)
use_can_fd: bool = True
can_bitrate: int = 1000000 # Nominal bitrate (1 Mbps)
can_data_bitrate: int = 5000000 # Data bitrate for CAN FD (5 Mbps)
# Whether to disable torque when disconnecting
disable_torque_on_disconnect: bool = True
# Safety limit for relative target positions
# Set to a positive scalar for all motors, or a dict mapping motor names to limits
max_relative_target: float | dict[str, float] | None = None
# Camera configurations
cameras: dict[str, CameraConfig] = field(default_factory=dict)
# Motor configuration for OpenArms (7 DOF per arm)
# Maps motor names to (send_can_id, recv_can_id, motor_type)
# Based on: https://docs.openarm.dev/software/setup/configure-test
# OpenArms uses 4 types of motors:
# - DM8009 (DM-J8009P-2EC) for shoulders (high torque)
# - DM4340P and DM4340 for shoulder rotation and elbow
# - DM4310 (DM-J4310-2EC V1.1) for wrist and gripper
motor_config: dict[str, tuple[int, int, str]] = field(
default_factory=lambda: {
"joint_1": (0x01, 0x11, "dm8009"), # J1 - Shoulder pan (DM8009)
"joint_2": (0x02, 0x12, "dm8009"), # J2 - Shoulder lift (DM8009)
"joint_3": (0x03, 0x13, "dm4340"), # J3 - Shoulder rotation (DM4340)
"joint_4": (0x04, 0x14, "dm4340"), # J4 - Elbow flex (DM4340)
"joint_5": (0x05, 0x15, "dm4310"), # J5 - Wrist roll (DM4310)
"joint_6": (0x06, 0x16, "dm4310"), # J6 - Wrist pitch (DM4310)
"joint_7": (0x07, 0x17, "dm4310"), # J7 - Wrist rotation (DM4310)
"gripper": (0x08, 0x18, "dm4310"), # J8 - Gripper (DM4310)
}
)
# MIT control parameters for position control (used in send_action)
# List of 8 values: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper]
position_kp: list[float] = field(
default_factory=lambda: [240.0, 240.0, 240.0, 240.0, 24.0, 31.0, 25.0, 25.0]
)
position_kd: list[float] = field(default_factory=lambda: [5.0, 5.0, 3.0, 5.0, 0.3, 0.3, 0.3, 0.3])
# Joint limit values. Can be overridden via CLI (for custom values) or by setting config.side to either 'left' or 'right'.
# If config.side is left as None and no CLI values are passed, these deliberately small defaults are used for safety.
joint_limits: dict[str, tuple[float, float]] = field(
default_factory=lambda: {
"joint_1": (-5.0, 5.0),
"joint_2": (-5.0, 5.0),
"joint_3": (-5.0, 5.0),
"joint_4": (0.0, 5.0),
"joint_5": (-5.0, 5.0),
"joint_6": (-5.0, 5.0),
"joint_7": (-5.0, 5.0),
"gripper": (-5.0, 0.0),
}
)
@RobotConfig.register_subclass("openarm_follower")
@dataclass
class OpenArmFollowerConfig(RobotConfig, OpenArmFollowerConfigBase):
pass
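A hedged construction example (hypothetical id and port): setting side="left" swaps in LEFT_DEFAULT_JOINTS_LIMITS at construction time instead of the conservative ±5° defaults above.

from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig

config = OpenArmFollowerConfig(
    id="my_openarm",  # hypothetical id
    port="can0",      # hypothetical CAN interface
    side="left",      # use LEFT_DEFAULT_JOINTS_LIMITS
)
robot = OpenArmFollower(config)
robot.connect(calibrate=True)  # requires actual hardware on the bus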
@@ -0,0 +1,348 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from functools import cached_property
from typing import Any
from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.damiao import DamiaoMotorsBus
from lerobot.processor import RobotAction, RobotObservation
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
from ..utils import ensure_safe_goal_position
from .config_openarm_follower import (
LEFT_DEFAULT_JOINTS_LIMITS,
RIGHT_DEFAULT_JOINTS_LIMITS,
OpenArmFollowerConfig,
)
logger = logging.getLogger(__name__)
class OpenArmFollower(Robot):
"""
OpenArms follower robot that uses CAN bus communication to control a 7-DOF arm with a gripper.
The arm uses Damiao motors in MIT control mode.
"""
config_class = OpenArmFollowerConfig
name = "openarm_follower"
def __init__(self, config: OpenArmFollowerConfig):
super().__init__(config)
self.config = config
# Arm motors
motors: dict[str, Motor] = {}
for motor_name, (send_id, recv_id, motor_type_str) in config.motor_config.items():
motor = Motor(
send_id, motor_type_str, MotorNormMode.DEGREES
) # Always use degrees for Damiao motors
motor.recv_id = recv_id
motor.motor_type_str = motor_type_str
motors[motor_name] = motor
self.bus = DamiaoMotorsBus(
port=self.config.port,
motors=motors,
calibration=self.calibration,
can_interface=self.config.can_interface,
use_can_fd=self.config.use_can_fd,
bitrate=self.config.can_bitrate,
data_bitrate=self.config.can_data_bitrate if self.config.use_can_fd else None,
)
if config.side is not None:
if config.side == "left":
config.joint_limits = LEFT_DEFAULT_JOINTS_LIMITS
elif config.side == "right":
config.joint_limits = RIGHT_DEFAULT_JOINTS_LIMITS
else:
raise ValueError(
"config.side must be either 'left', 'right' (for default values) or 'None' (for CLI values)"
)
else:
logger.info(
"Set config.side to either 'left' or 'right' to use pre-configured values for joint limits."
)
logger.info(f"Values used for joint limits: {config.joint_limits}.")
# Initialize cameras
self.cameras = make_cameras_from_configs(config.cameras)
@property
def _motors_ft(self) -> dict[str, type]:
"""Motor features for observation and action spaces."""
features: dict[str, type] = {}
for motor in self.bus.motors:
features[f"{motor}.pos"] = float
features[f"{motor}.vel"] = float # Add this
features[f"{motor}.torque"] = float # Add this
return features
@property
def _cameras_ft(self) -> dict[str, tuple]:
"""Camera features for observation space."""
return {
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
"""Combined observation features from motors and cameras."""
return {**self._motors_ft, **self._cameras_ft}
@cached_property
def action_features(self) -> dict[str, type]:
"""Action features."""
return self._motors_ft
@property
def is_connected(self) -> bool:
"""Check if robot is connected."""
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
def connect(self, calibrate: bool = True) -> None:
"""
Connect to the robot and optionally calibrate.
We assume that at connection time, the arms are in a safe rest position,
and torque can be safely disabled to run calibration if needed.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
# Connect to CAN bus
logger.info(f"Connecting arm on {self.config.port}...")
self.bus.connect()
# Run calibration if needed
if not self.is_calibrated and calibrate:
logger.info(
"Mismatch between calibration values in the motor and the calibration file or no calibration file found"
)
self.calibrate()
for cam in self.cameras.values():
cam.connect()
self.configure()
if self.is_calibrated:
self.bus.set_zero_position()
self.bus.enable_torque()
logger.info(f"{self} connected.")
@property
def is_calibrated(self) -> bool:
"""Check if robot is calibrated."""
return self.bus.is_calibrated
def calibrate(self) -> None:
"""
Run calibration procedure for OpenArms robot.
The calibration procedure:
1. Disable torque
2. Ask user to position arms in hanging position with grippers closed
3. Set this as zero position
4. Record range of motion for each joint
5. Save calibration
"""
if self.calibration:
# Calibration file exists, ask user whether to use it or run new calibration
user_input = input(
f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: "
)
if user_input.strip().lower() != "c":
logger.info(f"Writing calibration file associated with the id {self.id} to the motors")
self.bus.write_calibration(self.calibration)
return
logger.info(f"\nRunning calibration for {self}")
self.bus.disable_torque()
# Step 1: Set zero position
input(
"\nCalibration: Set Zero Position)\n"
"Position the arm in the following configuration:\n"
" - Arm hanging straight down\n"
" - Gripper closed\n"
"Press ENTER when ready..."
)
# Set current position as zero for all motors
self.bus.set_zero_position()
logger.info("Arm zero position set.")
logger.info("Setting range: -90° to +90° for safety by default for all joints")
for motor_name, motor in self.bus.motors.items():
self.calibration[motor_name] = MotorCalibration(
id=motor.id,
drive_mode=0,
homing_offset=0,
range_min=-90,
range_max=90,
)
self.bus.write_calibration(self.calibration)
self._save_calibration()
print(f"Calibration saved to {self.calibration_fpath}")
def configure(self) -> None:
"""Configure motors with appropriate settings."""
# TODO(Steven, Pepijn): Slightly different from what it is happening in the leader
with self.bus.torque_disabled():
self.bus.configure_motors()
def setup_motors(self) -> None:
raise NotImplementedError(
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
)
def get_observation(self) -> RobotObservation:
"""
Get current observation from robot including position, velocity, and torque.
Reads all motor states (pos/vel/torque) in one CAN refresh cycle
instead of 3 separate reads.
"""
start = time.perf_counter()
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
obs_dict: dict[str, Any] = {}
states = self.bus.sync_read_all_states()
for motor in self.bus.motors:
state = states.get(motor, {})
obs_dict[f"{motor}.pos"] = state.get("position", 0.0)
obs_dict[f"{motor}.vel"] = state.get("velocity", 0.0)
obs_dict[f"{motor}.torque"] = state.get("torque", 0.0)
# Capture images from cameras
for cam_key, cam in self.cameras.items():
    cam_start = time.perf_counter()
    obs_dict[cam_key] = cam.async_read()
    dt_ms = (time.perf_counter() - cam_start) * 1e3
    logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} get_observation took: {dt_ms:.1f}ms")
return obs_dict
def send_action(
self,
action: RobotAction,
custom_kp: dict[str, float] | None = None,
custom_kd: dict[str, float] | None = None,
) -> RobotAction:
"""
Send action command to robot.
The action magnitude may be clipped based on safety limits.
Args:
action: Dictionary with motor positions (e.g., "joint_1.pos", "joint_2.pos")
custom_kp: Optional custom kp gains per motor (e.g., {"joint_1": 120.0, "joint_2": 150.0})
custom_kd: Optional custom kd gains per motor (e.g., {"joint_1": 1.5, "joint_2": 2.0})
Returns:
The action actually sent (potentially clipped)
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
# Apply joint limit clipping to arm
for motor_name, position in goal_pos.items():
if motor_name in self.config.joint_limits:
min_limit, max_limit = self.config.joint_limits[motor_name]
clipped_position = max(min_limit, min(max_limit, position))
if clipped_position != position:
logger.debug(f"Clipped {motor_name} from {position:.2f}° to {clipped_position:.2f}°")
goal_pos[motor_name] = clipped_position
# Cap goal position when too far away from present position.
# /!\ Slower fps expected due to reading from the follower.
if self.config.max_relative_target is not None:
present_pos = self.bus.sync_read("Present_Position")
goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()}
goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
# TODO(Steven, Pepijn): Refactor writing
# Motor name to index mapping for gains
motor_index = {
"joint_1": 0,
"joint_2": 1,
"joint_3": 2,
"joint_4": 3,
"joint_5": 4,
"joint_6": 5,
"joint_7": 6,
"gripper": 7,
}
# Use batch MIT control for arm (sends all commands, then collects responses)
commands = {}
for motor_name, position_degrees in goal_pos.items():
idx = motor_index.get(motor_name, 0)
# Use custom gains if provided, otherwise use config defaults
if custom_kp is not None and motor_name in custom_kp:
kp = custom_kp[motor_name]
else:
kp = (
self.config.position_kp[idx]
if isinstance(self.config.position_kp, list)
else self.config.position_kp
)
if custom_kd is not None and motor_name in custom_kd:
kd = custom_kd[motor_name]
else:
kd = (
self.config.position_kd[idx]
if isinstance(self.config.position_kd, list)
else self.config.position_kd
)
commands[motor_name] = (kp, kd, position_degrees, 0.0, 0.0)
self.bus._mit_control_batch(commands)
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
def disconnect(self):
"""Disconnect from robot."""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Disconnect CAN bus
self.bus.disconnect(self.config.disable_torque_on_disconnect)
# Disconnect cameras
for cam in self.cameras.values():
cam.disconnect()
logger.info(f"{self} disconnected.")
@@ -65,3 +65,6 @@ class UnitreeG1Config(RobotConfig):
# Cameras (ZMQ-based remote cameras)
cameras: dict[str, CameraConfig] = field(default_factory=dict)
# Compensates for gravity on the unitree's arms using the arm ik solver
gravity_compensation: bool = False
+1 -1
@@ -18,7 +18,7 @@ from enum import IntEnum
# ruff: noqa: N801, N815
NUM_MOTORS = 35
NUM_MOTORS = 29
class G1_29_JointArmIndex(IntEnum):
@@ -0,0 +1,313 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import numpy as np
logger = logging.getLogger(__name__)
parent2_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(parent2_dir)
class WeightedMovingFilter:
def __init__(self, weights, data_size=14):
self._window_size = len(weights)
self._weights = np.array(weights)
self._data_size = data_size
self._filtered_data = np.zeros(self._data_size)
self._data_queue = []
def _apply_filter(self):
if len(self._data_queue) < self._window_size:
return self._data_queue[-1]
data_array = np.array(self._data_queue)
temp_filtered_data = np.zeros(self._data_size)
for i in range(self._data_size):
temp_filtered_data[i] = np.convolve(data_array[:, i], self._weights, mode="valid")[-1]
return temp_filtered_data
def add_data(self, new_data):
assert len(new_data) == self._data_size
if len(self._data_queue) > 0 and np.array_equal(
new_data, self._data_queue[-1]
): # skip duplicate data
return
if len(self._data_queue) >= self._window_size:
self._data_queue.pop(0)
self._data_queue.append(new_data)
self._filtered_data = self._apply_filter()
@property
def filtered_data(self):
return self._filtered_data
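A quick numeric check of the filter: np.convolve reverses the kernel, so the first weight (0.4) applies to the newest sample once the window is full.

import numpy as np

f = WeightedMovingFilter(np.array([0.4, 0.3, 0.2, 0.1]), data_size=2)
for q in ([0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0]):
    f.add_data(np.array(q))

# Window is full: 3*0.4 + 2*0.3 + 1*0.2 + 0*0.1 = 2.0 per dimension
print(f.filtered_data)  # ~[2.0, 2.0]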
class G1_29_ArmIK: # noqa: N801
def __init__(self, unit_test=False):
import casadi
import pinocchio as pin
from huggingface_hub import snapshot_download
from pinocchio import casadi as cpin
self._pin = pin
np.set_printoptions(precision=5, suppress=True, linewidth=200)
self.unit_test = unit_test
self.repo_path = snapshot_download("lerobot/unitree-g1-mujoco")
urdf_path = os.path.join(self.repo_path, "assets", "g1_body29_hand14.urdf")
mesh_dir = os.path.join(self.repo_path, "assets")
self.robot = self._pin.RobotWrapper.BuildFromURDF(urdf_path, mesh_dir)
self.mixed_jointsToLockIDs = [
"left_hip_pitch_joint",
"left_hip_roll_joint",
"left_hip_yaw_joint",
"left_knee_joint",
"left_ankle_pitch_joint",
"left_ankle_roll_joint",
"right_hip_pitch_joint",
"right_hip_roll_joint",
"right_hip_yaw_joint",
"right_knee_joint",
"right_ankle_pitch_joint",
"right_ankle_roll_joint",
"waist_yaw_joint",
"waist_roll_joint",
"waist_pitch_joint",
"left_hand_thumb_0_joint",
"left_hand_thumb_1_joint",
"left_hand_thumb_2_joint",
"left_hand_middle_0_joint",
"left_hand_middle_1_joint",
"left_hand_index_0_joint",
"left_hand_index_1_joint",
"right_hand_thumb_0_joint",
"right_hand_thumb_1_joint",
"right_hand_thumb_2_joint",
"right_hand_index_0_joint",
"right_hand_index_1_joint",
"right_hand_middle_0_joint",
"right_hand_middle_1_joint",
]
self.reduced_robot = self.robot.buildReducedRobot(
list_of_joints_to_lock=self.mixed_jointsToLockIDs,
reference_configuration=np.array([0.0] * self.robot.model.nq),
)
# Arm joint names in G1 motor order (G1_29_JointArmIndex)
self._arm_joint_names_g1 = [
"left_shoulder_pitch_joint",
"left_shoulder_roll_joint",
"left_shoulder_yaw_joint",
"left_elbow_joint",
"left_wrist_roll_joint",
"left_wrist_pitch_joint",
"left_wrist_yaw_joint",
"right_shoulder_pitch_joint",
"right_shoulder_roll_joint",
"right_shoulder_yaw_joint",
"right_elbow_joint",
"right_wrist_roll_joint",
"right_wrist_pitch_joint",
"right_wrist_yaw_joint",
]
# Pinocchio uses its own joint order in q; build index mapping.
self._arm_joint_names_pin = sorted(
self._arm_joint_names_g1,
key=lambda name: self.reduced_robot.model.idx_qs[self.reduced_robot.model.getJointId(name)],
)
logger.info(f"Pinocchio arm joint order: {self._arm_joint_names_pin}")
self._arm_reorder_g1_to_pin = [
self._arm_joint_names_g1.index(name) for name in self._arm_joint_names_pin
]
# Inverse mapping to return tau in G1 motor order.
self._arm_reorder_pin_to_g1 = np.argsort(self._arm_reorder_g1_to_pin)
self.reduced_robot.model.addFrame(
self._pin.Frame(
"L_ee",
self.reduced_robot.model.getJointId("left_wrist_yaw_joint"),
self._pin.SE3(np.eye(3), np.array([0.05, 0, 0]).T),
self._pin.FrameType.OP_FRAME,
)
)
self.reduced_robot.model.addFrame(
self._pin.Frame(
"R_ee",
self.reduced_robot.model.getJointId("right_wrist_yaw_joint"),
self._pin.SE3(np.eye(3), np.array([0.05, 0, 0]).T),
self._pin.FrameType.OP_FRAME,
)
)
# Creating Casadi models and data for symbolic computing
self.cmodel = cpin.Model(self.reduced_robot.model)
self.cdata = self.cmodel.createData()
# Creating symbolic variables
self.cq = casadi.SX.sym("q", self.reduced_robot.model.nq, 1)
self.cTf_l = casadi.SX.sym("tf_l", 4, 4)
self.cTf_r = casadi.SX.sym("tf_r", 4, 4)
cpin.framesForwardKinematics(self.cmodel, self.cdata, self.cq)
# Get the hand joint ID and define the error function
self.L_hand_id = self.reduced_robot.model.getFrameId("L_ee")
self.R_hand_id = self.reduced_robot.model.getFrameId("R_ee")
self.translational_error = casadi.Function(
"translational_error",
[self.cq, self.cTf_l, self.cTf_r],
[
casadi.vertcat(
self.cdata.oMf[self.L_hand_id].translation - self.cTf_l[:3, 3],
self.cdata.oMf[self.R_hand_id].translation - self.cTf_r[:3, 3],
)
],
)
self.rotational_error = casadi.Function(
"rotational_error",
[self.cq, self.cTf_l, self.cTf_r],
[
casadi.vertcat(
cpin.log3(self.cdata.oMf[self.L_hand_id].rotation @ self.cTf_l[:3, :3].T),
cpin.log3(self.cdata.oMf[self.R_hand_id].rotation @ self.cTf_r[:3, :3].T),
)
],
)
# Defining the optimization problem
self.opti = casadi.Opti()
self.var_q = self.opti.variable(self.reduced_robot.model.nq)
self.var_q_last = self.opti.parameter(self.reduced_robot.model.nq) # for smooth
self.param_tf_l = self.opti.parameter(4, 4)
self.param_tf_r = self.opti.parameter(4, 4)
self.translational_cost = casadi.sumsqr(
self.translational_error(self.var_q, self.param_tf_l, self.param_tf_r)
)
self.rotation_cost = casadi.sumsqr(
self.rotational_error(self.var_q, self.param_tf_l, self.param_tf_r)
)
self.regularization_cost = casadi.sumsqr(self.var_q)
self.smooth_cost = casadi.sumsqr(self.var_q - self.var_q_last)
# Setting optimization constraints and goals
self.opti.subject_to(
self.opti.bounded(
self.reduced_robot.model.lowerPositionLimit,
self.var_q,
self.reduced_robot.model.upperPositionLimit,
)
)
self.opti.minimize(
50 * self.translational_cost
+ self.rotation_cost
+ 0.02 * self.regularization_cost
+ 0.1 * self.smooth_cost
)
opts = {
"ipopt": {"print_level": 0, "max_iter": 50, "tol": 1e-6},
"print_time": False, # print or not
"calc_lam_p": False, # https://github.com/casadi/casadi/wiki/FAQ:-Why-am-I-getting-%22NaN-detected%22in-my-optimization%3F
}
self.opti.solver("ipopt", opts)
self.init_data = np.zeros(self.reduced_robot.model.nq)
self.smooth_filter = WeightedMovingFilter(np.array([0.4, 0.3, 0.2, 0.1]), 14)
def solve_ik(self, left_wrist, right_wrist, current_lr_arm_motor_q=None, current_lr_arm_motor_dq=None):
if current_lr_arm_motor_q is not None:
self.init_data = current_lr_arm_motor_q
self.opti.set_initial(self.var_q, self.init_data)
self.opti.set_value(self.param_tf_l, left_wrist)
self.opti.set_value(self.param_tf_r, right_wrist)
self.opti.set_value(self.var_q_last, self.init_data) # for smooth
try:
self.opti.solve()
sol_q = self.opti.value(self.var_q)
self.smooth_filter.add_data(sol_q)
sol_q = self.smooth_filter.filtered_data
if current_lr_arm_motor_dq is not None:
v = current_lr_arm_motor_dq * 0.0
else:
v = (sol_q - self.init_data) * 0.0
self.init_data = sol_q
sol_tauff = self._pin.rnea(
self.reduced_robot.model,
self.reduced_robot.data,
sol_q,
v,
np.zeros(self.reduced_robot.model.nv),
)
return sol_q, sol_tauff
except Exception as e:
logger.error(f"ERROR in convergence, plotting debug info.{e}")
sol_q = self.opti.debug.value(self.var_q)
self.smooth_filter.add_data(sol_q)
sol_q = self.smooth_filter.filtered_data
if current_lr_arm_motor_dq is not None:
v = current_lr_arm_motor_dq * 0.0
else:
v = (sol_q - self.init_data) * 0.0
self.init_data = sol_q
logger.error(
f"sol_q:{sol_q} \nmotorstate: \n{current_lr_arm_motor_q} \nleft_pose: \n{left_wrist} \nright_pose: \n{right_wrist}"
)
return current_lr_arm_motor_q, np.zeros(self.reduced_robot.model.nv)
def solve_tau(self, current_lr_arm_motor_q=None, current_lr_arm_motor_dq=None):
try:
q_g1 = np.array(current_lr_arm_motor_q, dtype=float)
if q_g1.shape[0] != len(self._arm_joint_names_g1):
raise ValueError(f"Expected {len(self._arm_joint_names_g1)} arm joints, got {q_g1.shape[0]}")
q_pin = q_g1[self._arm_reorder_g1_to_pin]
sol_tauff = self._pin.rnea(
self.reduced_robot.model,
self.reduced_robot.data,
q_pin,
np.zeros(self.reduced_robot.model.nv),
np.zeros(self.reduced_robot.model.nv),
)
return sol_tauff[self._arm_reorder_pin_to_g1]
except Exception as e:
logger.error(f"ERROR in convergence, plotting debug info.{e}")
return np.zeros(self.reduced_robot.model.nv)
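The G1-to-Pinocchio reordering above leans on the fact that np.argsort of a permutation yields its inverse; a short check:

import numpy as np

order = np.array([2, 0, 3, 1])           # e.g., g1 -> pin reorder
inverse = np.argsort(order)              # pin -> g1 reorder
x = np.array([10.0, 11.0, 12.0, 13.0])   # values in g1 order
assert np.array_equal(x[order][inverse], x)  # round-trips exactly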
+18 -1
@@ -27,7 +27,8 @@ import numpy as np
from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.envs.factory import make_env
from lerobot.processor import RobotAction, RobotObservation
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointArmIndex, G1_29_JointIndex
from lerobot.robots.unitree_g1.robot_kinematic_processor import G1_29_ArmIK
from ..robot import Robot
from .config_unitree_g1 import UnitreeG1Config
@@ -127,6 +128,8 @@ class UnitreeG1(Robot):
self.subscribe_thread = None
self.remote_controller = self.RemoteController()
self.arm_ik = G1_29_ArmIK()
def _subscribe_motor_state(self): # polls robot state @ 250Hz
while not self._shutdown_event.is_set():
start_time = time.time()
@@ -361,6 +364,20 @@ class UnitreeG1(Robot):
self.msg.motor_cmd[motor.value].kd = self.kd[motor.value]
self.msg.motor_cmd[motor.value].tau = 0
if self.config.gravity_compensation:
# Build action_np from motor commands (arm joints are indices 15-28, local indices 0-13)
action_np = np.zeros(14)
arm_start_idx = G1_29_JointArmIndex.kLeftShoulderPitch.value # 15
for joint in G1_29_JointArmIndex:
local_idx = joint.value - arm_start_idx
action_np[local_idx] = self.msg.motor_cmd[joint.value].q
tau = self.arm_ik.solve_tau(action_np)
# Apply tau back to motor commands
for joint in G1_29_JointArmIndex:
local_idx = joint.value - arm_start_idx
self.msg.motor_cmd[joint.value].tau = tau[local_idx]
self.msg.crc = self.crc.Crc(self.msg)
self.lowcmd_publisher.Write(self.msg)
return action
+8
@@ -60,6 +60,14 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
from .reachy2 import Reachy2Robot
return Reachy2Robot(config)
elif config.type == "openarm_follower":
from .openarm_follower import OpenArmFollower
return OpenArmFollower(config)
elif config.type == "bi_openarm_follower":
from .bi_openarm_follower import BiOpenArmFollower
return BiOpenArmFollower(config)
elif config.type == "mock_robot":
from tests.mocks.mock_robot import MockRobot
+10 -2
@@ -36,23 +36,28 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_openarm_follower,
bi_so_follower,
hope_jr,
koch_follower,
lekiwi,
make_robot_from_config,
omx_follower,
openarm_follower,
so_follower,
)
from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
bi_openarm_leader,
bi_so_leader,
homunculus,
koch_leader,
make_teleoperator_from_config,
omx_leader,
openarm_leader,
so_leader,
unitree_g1,
)
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.utils import init_logging
@@ -81,8 +86,11 @@ def calibrate(cfg: CalibrateConfig):
device = make_teleoperator_from_config(cfg.device)
device.connect(calibrate=False)
device.calibrate()
device.disconnect()
try:
device.calibrate()
finally:
device.disconnect()
def main():
+79 -3
@@ -18,7 +18,7 @@
Edit LeRobot datasets using various transformation tools.
This script allows you to delete episodes, split datasets, merge datasets,
remove features, and convert image datasets to video format.
remove features, modify tasks, and convert image datasets to video format.
When new_repo_id is specified, creates a new dataset.
Usage Examples:
@@ -66,6 +66,25 @@ Remove camera feature:
--operation.type remove_feature \
--operation.feature_names "['observation.images.top']"
Modify tasks - set a single task for all episodes (WARNING: modifies in-place):
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type modify_tasks \
--operation.new_task "Pick up the cube and place it"
Modify tasks - set different tasks for specific episodes (WARNING: modifies in-place):
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type modify_tasks \
--operation.episode_tasks '{"0": "Task A", "1": "Task B", "2": "Task A"}'
Modify tasks - set default task with overrides for specific episodes (WARNING: modifies in-place):
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type modify_tasks \
--operation.new_task "Default task" \
--operation.episode_tasks '{"5": "Special task for episode 5"}'
Convert image dataset to video format and save locally:
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht_image \
@@ -100,6 +119,7 @@ from lerobot.datasets.dataset_tools import (
convert_image_to_video_dataset,
delete_episodes,
merge_datasets,
modify_tasks,
remove_feature,
split_dataset,
)
@@ -132,6 +152,13 @@ class RemoveFeatureConfig:
feature_names: list[str] | None = None
@dataclass
class ModifyTasksConfig:
type: str = "modify_tasks"
new_task: str | None = None
episode_tasks: dict[str, str] | None = None
@dataclass
class ConvertImageToVideoConfig:
type: str = "convert_image_to_video"
@@ -151,7 +178,12 @@ class ConvertImageToVideoConfig:
class EditDatasetConfig:
repo_id: str
operation: (
DeleteEpisodesConfig
| SplitConfig
| MergeConfig
| RemoveFeatureConfig
| ModifyTasksConfig
| ConvertImageToVideoConfig
root: str | None = None
new_repo_id: str | None = None
@@ -296,6 +328,48 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
if not isinstance(cfg.operation, ModifyTasksConfig):
raise ValueError("Operation config must be ModifyTasksConfig")
new_task = cfg.operation.new_task
episode_tasks_raw = cfg.operation.episode_tasks
if new_task is None and episode_tasks_raw is None:
raise ValueError("Must specify at least one of new_task or episode_tasks for modify_tasks operation")
# Warn about in-place modification behavior
if cfg.new_repo_id is not None:
logging.warning("modify_tasks modifies datasets in-place. The --new_repo_id parameter is ignored.")
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
logging.warning(f"Modifying dataset in-place at {dataset.root}. Original data will be overwritten.")
# Convert episode_tasks keys from string to int if needed (CLI passes strings)
episode_tasks: dict[int, str] | None = None
if episode_tasks_raw is not None:
episode_tasks = {int(k): v for k, v in episode_tasks_raw.items()}
logging.info(f"Modifying tasks in {cfg.repo_id}")
if new_task:
logging.info(f" Default task: '{new_task}'")
if episode_tasks:
logging.info(f" Episode-specific tasks: {episode_tasks}")
modified_dataset = modify_tasks(
dataset,
new_task=new_task,
episode_tasks=episode_tasks,
)
logging.info(f"Dataset modified at {dataset.root}")
logging.info(f"Tasks: {list(modified_dataset.meta.tasks.index)}")
if cfg.push_to_hub:
logging.info(f"Pushing to hub as {cfg.repo_id}")
modified_dataset.push_to_hub()
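For reference, a minimal sketch of calling modify_tasks from Python rather than the CLI, assuming the import paths used elsewhere in this diff; the repo_id is a placeholder and the edit happens in-place.
from lerobot.datasets.dataset_tools import modify_tasks
from lerobot.datasets.lerobot_dataset import LeRobotDataset  # assumed module path

dataset = LeRobotDataset("user/my_dataset")  # placeholder repo_id
modified = modify_tasks(
    dataset,
    new_task="Default task",            # default applied to all episodes
    episode_tasks={5: "Special task"},  # per-episode override
)
print(list(modified.meta.tasks.index))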
def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
# Note: Parser may create any config type with the right fields, so we access fields directly
# instead of checking isinstance()
@@ -371,12 +445,14 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
handle_merge(cfg)
elif operation_type == "remove_feature":
handle_remove_feature(cfg)
elif operation_type == "modify_tasks":
handle_modify_tasks(cfg)
elif operation_type == "convert_image_to_video":
handle_convert_image_to_video(cfg)
else:
raise ValueError(
f"Unknown operation type: {operation_type}\n"
f"Available operations: delete_episodes, split, merge, remove_feature, convert_to_video"
f"Available operations: delete_episodes, split, merge, remove_feature, modify_tasks, convert_image_to_video"
)
@@ -44,19 +44,23 @@ import numpy as np
from lerobot.model.kinematics import RobotKinematics
from lerobot.robots import ( # noqa: F401
RobotConfig,
bi_openarm_follower,
bi_so_follower,
koch_follower,
make_robot_from_config,
omx_follower,
openarm_follower,
so_follower,
)
from lerobot.teleoperators import ( # noqa: F401
TeleoperatorConfig,
bi_openarm_leader,
bi_so_leader,
gamepad,
koch_leader,
make_teleoperator_from_config,
omx_leader,
openarm_leader,
so_leader,
)
from lerobot.utils.robot_utils import precise_sleep
+6 -1
@@ -98,26 +98,31 @@ from lerobot.processor.rename_processor import rename_stats
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_openarm_follower,
bi_so_follower,
earthrover_mini_plus,
hope_jr,
koch_follower,
make_robot_from_config,
omx_follower,
openarm_follower,
reachy2,
so_follower,
unitree_g1 as unitree_g1_robot,
)
from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
bi_openarm_leader,
bi_so_leader,
homunculus,
koch_leader,
make_teleoperator_from_config,
omx_leader,
openarm_leader,
reachy2_teleoperator,
so_leader,
unitree_g1,
)
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop
from lerobot.utils.constants import ACTION, OBS_STR
+17 -14
@@ -53,12 +53,14 @@ from lerobot.processor import (
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_openarm_follower,
bi_so_follower,
earthrover_mini_plus,
hope_jr,
koch_follower,
make_robot_from_config,
omx_follower,
openarm_follower,
reachy2,
so_follower,
unitree_g1,
@@ -108,25 +110,26 @@ def replay(cfg: ReplayConfig):
robot.connect()
log_say("Replaying episode", cfg.play_sounds, blocking=True)
for idx in range(len(episode_frames)):
start_episode_t = time.perf_counter()
try:
log_say("Replaying episode", cfg.play_sounds, blocking=True)
for idx in range(len(episode_frames)):
start_episode_t = time.perf_counter()
action_array = actions[idx][ACTION]
action = {}
for i, name in enumerate(dataset.features[ACTION]["names"]):
action[name] = action_array[i]
action_array = actions[idx][ACTION]
action = {}
for i, name in enumerate(dataset.features[ACTION]["names"]):
action[name] = action_array[i]
robot_obs = robot.get_observation()
robot_obs = robot.get_observation()
processed_action = robot_action_processor((action, robot_obs))
processed_action = robot_action_processor((action, robot_obs))
_ = robot.send_action(processed_action)
_ = robot.send_action(processed_action)
dt_s = time.perf_counter() - start_episode_t
precise_sleep(max(1 / dataset.fps - dt_s, 0.0))
robot.disconnect()
dt_s = time.perf_counter() - start_episode_t
precise_sleep(max(1 / dataset.fps - dt_s, 0.0))
finally:
robot.disconnect()
def main():
@@ -70,18 +70,22 @@ from lerobot.processor import (
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_openarm_follower,
bi_so_follower,
earthrover_mini_plus,
hope_jr,
koch_follower,
make_robot_from_config,
omx_follower,
openarm_follower,
reachy2,
so_follower,
unitree_g1 as unitree_g1_robot,
)
from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
bi_openarm_leader,
bi_so_leader,
gamepad,
homunculus,
@@ -89,8 +93,10 @@ from lerobot.teleoperators import ( # noqa: F401
koch_leader,
make_teleoperator_from_config,
omx_leader,
openarm_leader,
reachy2_teleoperator,
so_leader,
unitree_g1,
)
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.robot_utils import precise_sleep
@@ -0,0 +1,20 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .bi_openarm_leader import BiOpenArmLeader
from .config_bi_openarm_leader import BiOpenArmLeaderConfig
__all__ = ["BiOpenArmLeader", "BiOpenArmLeaderConfig"]
@@ -0,0 +1,131 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from functools import cached_property
from lerobot.processor import RobotAction
from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfig
from ..openarm_leader import OpenArmLeader
from ..teleoperator import Teleoperator
from .config_bi_openarm_leader import BiOpenArmLeaderConfig
logger = logging.getLogger(__name__)
class BiOpenArmLeader(Teleoperator):
"""
Bimanual OpenArm Leader Arms
"""
config_class = BiOpenArmLeaderConfig
name = "bi_openarm_leader"
def __init__(self, config: BiOpenArmLeaderConfig):
super().__init__(config)
self.config = config
left_arm_config = OpenArmLeaderConfig(
id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir,
port=config.left_arm_config.port,
can_interface=config.left_arm_config.can_interface,
use_can_fd=config.left_arm_config.use_can_fd,
can_bitrate=config.left_arm_config.can_bitrate,
can_data_bitrate=config.left_arm_config.can_data_bitrate,
motor_config=config.left_arm_config.motor_config,
manual_control=config.left_arm_config.manual_control,
position_kd=config.left_arm_config.position_kd,
position_kp=config.left_arm_config.position_kp,
)
right_arm_config = OpenArmLeaderConfig(
id=f"{config.id}_right" if config.id else None,
calibration_dir=config.calibration_dir,
port=config.right_arm_config.port,
can_interface=config.right_arm_config.can_interface,
use_can_fd=config.right_arm_config.use_can_fd,
can_bitrate=config.right_arm_config.can_bitrate,
can_data_bitrate=config.right_arm_config.can_data_bitrate,
motor_config=config.right_arm_config.motor_config,
manual_control=config.right_arm_config.manual_control,
position_kd=config.right_arm_config.position_kd,
position_kp=config.right_arm_config.position_kp,
)
self.left_arm = OpenArmLeader(left_arm_config)
self.right_arm = OpenArmLeader(right_arm_config)
@cached_property
def action_features(self) -> dict[str, type]:
left_arm_features = self.left_arm.action_features
right_arm_features = self.right_arm.action_features
return {
**{f"left_{k}": v for k, v in left_arm_features.items()},
**{f"right_{k}": v for k, v in right_arm_features.items()},
}
@cached_property
def feedback_features(self) -> dict[str, type]:
return {}
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None:
raise NotImplementedError(
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
)
def get_action(self) -> RobotAction:
action_dict = {}
# Add "left_" prefix
left_action = self.left_arm.get_action()
action_dict.update({f"left_{key}": value for key, value in left_action.items()})
# Add "right_" prefix
right_action = self.right_arm.get_action()
action_dict.update({f"right_{key}": value for key, value in right_action.items()})
return action_dict
def send_feedback(self, feedback: dict[str, float]) -> None:
# TODO: Implement force feedback
raise NotImplementedError
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -0,0 +1,30 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfigBase
from ..config import TeleoperatorConfig
@TeleoperatorConfig.register_subclass("bi_openarm_leader")
@dataclass
class BiOpenArmLeaderConfig(TeleoperatorConfig):
"""Configuration class for Bi OpenArm Follower robots."""
left_arm_config: OpenArmLeaderConfigBase
right_arm_config: OpenArmLeaderConfigBase
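A minimal construction sketch for this config; the CAN ports and id are placeholders and the remaining OpenArmLeaderConfigBase fields keep their defaults.
from lerobot.teleoperators.bi_openarm_leader import BiOpenArmLeaderConfig
from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfigBase

config = BiOpenArmLeaderConfig(
    id="my_bi_openarm",  # hypothetical id; arms get the _left/_right suffixes
    left_arm_config=OpenArmLeaderConfigBase(port="can0"),   # placeholder port
    right_arm_config=OpenArmLeaderConfigBase(port="can1"),  # placeholder port
)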
@@ -0,0 +1,20 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_openarm_leader import OpenArmLeaderConfig, OpenArmLeaderConfigBase
from .openarm_leader import OpenArmLeader
__all__ = ["OpenArmLeader", "OpenArmLeaderConfig", "OpenArmLeaderConfigBase"]
@@ -0,0 +1,75 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from ..config import TeleoperatorConfig
@dataclass
class OpenArmLeaderConfigBase:
"""Base configuration for the OpenArms leader/teleoperator with Damiao motors."""
# CAN interfaces - one per arm
# Arm CAN interface (e.g., "can3")
# Linux: "can0", "can1", etc.
port: str
# CAN interface type: "socketcan" (Linux), "slcan" (serial), or "auto" (auto-detect)
can_interface: str = "socketcan"
# CAN FD settings (OpenArms uses CAN FD by default)
use_can_fd: bool = True
can_bitrate: int = 1000000 # Nominal bitrate (1 Mbps)
can_data_bitrate: int = 5000000 # Data bitrate for CAN FD (5 Mbps)
# Motor configuration for OpenArms (7 DOF per arm)
# Maps motor names to (send_can_id, recv_can_id, motor_type)
# Based on: https://docs.openarm.dev/software/setup/configure-test
# OpenArms uses 4 types of motors:
# - DM8009 (DM-J8009P-2EC) for shoulders (high torque)
# - DM4340P and DM4340 for shoulder rotation and elbow
# - DM4310 (DM-J4310-2EC V1.1) for wrist and gripper
motor_config: dict[str, tuple[int, int, str]] = field(
default_factory=lambda: {
"joint_1": (0x01, 0x11, "dm8009"), # J1 - Shoulder pan (DM8009)
"joint_2": (0x02, 0x12, "dm8009"), # J2 - Shoulder lift (DM8009)
"joint_3": (0x03, 0x13, "dm4340"), # J3 - Shoulder rotation (DM4340)
"joint_4": (0x04, 0x14, "dm4340"), # J4 - Elbow flex (DM4340)
"joint_5": (0x05, 0x15, "dm4310"), # J5 - Wrist roll (DM4310)
"joint_6": (0x06, 0x16, "dm4310"), # J6 - Wrist pitch (DM4310)
"joint_7": (0x07, 0x17, "dm4310"), # J7 - Wrist rotation (DM4310)
"gripper": (0x08, 0x18, "dm4310"), # J8 - Gripper (DM4310)
}
)
# Torque mode settings for manual control
# When enabled, motors have torque disabled for manual movement
manual_control: bool = True
# TODO(Steven, Pepijn): Not used ... ?
# MIT control parameters (used when manual_control=False for torque control)
# List of 8 values: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper]
position_kp: list[float] = field(
default_factory=lambda: [240.0, 240.0, 240.0, 240.0, 24.0, 31.0, 25.0, 16.0]
)
position_kd: list[float] = field(default_factory=lambda: [3.0, 3.0, 3.0, 3.0, 0.2, 0.2, 0.2, 0.2])
@TeleoperatorConfig.register_subclass("openarm_leader")
@dataclass
class OpenArmLeaderConfig(TeleoperatorConfig, OpenArmLeaderConfigBase):
pass
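A short sketch of reading the default motor_config layout above; each entry maps a joint name to a (send_can_id, recv_can_id, motor_type) tuple, and the asserted values come from the defaults shown.
from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfig

cfg = OpenArmLeaderConfig(port="can0")  # placeholder CAN port
send_id, recv_id, motor_type = cfg.motor_config["joint_1"]
assert (send_id, recv_id, motor_type) == (0x01, 0x11, "dm8009")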
@@ -0,0 +1,225 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from typing import Any
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.damiao import DamiaoMotorsBus
from lerobot.processor import RobotAction
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..teleoperator import Teleoperator
from .config_openarm_leader import OpenArmLeaderConfig
logger = logging.getLogger(__name__)
class OpenArmLeader(Teleoperator):
"""
OpenArm Leader/Teleoperator Arm with Damiao motors.
This teleoperator uses CAN bus communication to read positions from
Damiao motors that are manually moved (torque disabled).
"""
config_class = OpenArmLeaderConfig
name = "openarm_leader"
def __init__(self, config: OpenArmLeaderConfig):
super().__init__(config)
self.config = config
# Arm motors
motors: dict[str, Motor] = {}
for motor_name, (send_id, recv_id, motor_type_str) in config.motor_config.items():
motor = Motor(
send_id, motor_type_str, MotorNormMode.DEGREES
) # Always use degrees for Damiao motors
motor.recv_id = recv_id
motor.motor_type_str = motor_type_str
motors[motor_name] = motor
self.bus = DamiaoMotorsBus(
port=self.config.port,
motors=motors,
calibration=self.calibration,
can_interface=self.config.can_interface,
use_can_fd=self.config.use_can_fd,
bitrate=self.config.can_bitrate,
data_bitrate=self.config.can_data_bitrate if self.config.use_can_fd else None,
)
@property
def action_features(self) -> dict[str, type]:
"""Features produced by this teleoperator."""
features: dict[str, type] = {}
for motor in self.bus.motors:
features[f"{motor}.pos"] = float
features[f"{motor}.vel"] = float
features[f"{motor}.torque"] = float
return features
@property
def feedback_features(self) -> dict[str, type]:
"""Feedback features (not implemented for OpenArms)."""
return {}
@property
def is_connected(self) -> bool:
"""Check if teleoperator is connected."""
return self.bus.is_connected
def connect(self, calibrate: bool = True) -> None:
"""
Connect to the teleoperator.
For manual control, we disable torque after connecting so the
arm can be moved by hand.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
# Connect to CAN bus
logger.info(f"Connecting arm on {self.config.port}...")
self.bus.connect()
# Run calibration if needed
if not self.is_calibrated and calibrate:
logger.info(
"Mismatch between calibration values in the motor and the calibration file or no calibration file found"
)
self.calibrate()
self.configure()
if self.is_calibrated:
self.bus.set_zero_position()
logger.info(f"{self} connected.")
@property
def is_calibrated(self) -> bool:
"""Check if teleoperator is calibrated."""
return self.bus.is_calibrated
def calibrate(self) -> None:
"""
Run calibration procedure for OpenArms leader.
The calibration procedure:
1. Disable torque (if not already disabled)
2. Ask user to position arm in zero position (hanging with gripper closed)
3. Set this as zero position
4. Record range of motion for each joint
5. Save calibration
"""
if self.calibration:
# Calibration file exists, ask user whether to use it or run new calibration
user_input = input(
f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: "
)
if user_input.strip().lower() != "c":
logger.info(f"Writing calibration file associated with the id {self.id} to the motors")
self.bus.write_calibration(self.calibration)
return
logger.info(f"\nRunning calibration for {self}")
self.bus.disable_torque()
# Step 1: Set zero position
input(
"\nCalibration: Set Zero Position)\n"
"Position the arm in the following configuration:\n"
" - Arm hanging straight down\n"
" - Gripper closed\n"
"Press ENTER when ready..."
)
# Set current position as zero for all motors
self.bus.set_zero_position()
logger.info("Arm zero position set.")
logger.info("Setting range: -90° to +90° by default for all joints")
# TODO(Steven, Pepijn): Check if MotorCalibration is actually needed here given that we only use Degrees
for motor_name, motor in self.bus.motors.items():
self.calibration[motor_name] = MotorCalibration(
id=motor.id,
drive_mode=0,
homing_offset=0,
range_min=-90,
range_max=90,
)
self.bus.write_calibration(self.calibration)
self._save_calibration()
print(f"Calibration saved to {self.calibration_fpath}")
def configure(self) -> None:
"""
Configure motors for manual teleoperation.
For manual control, we disable torque so the arm can be moved by hand.
"""
return self.bus.disable_torque() if self.config.manual_control else self.bus.configure_motors()
def setup_motors(self) -> None:
raise NotImplementedError(
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
)
def get_action(self) -> RobotAction:
"""
Get current action from the leader arm.
This is the main method for teleoperators - it reads the current state
of the leader arm and returns it as an action that can be sent to a follower.
Reads all motor states (pos/vel/torque) in one CAN refresh cycle.
"""
start = time.perf_counter()
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
action_dict: dict[str, Any] = {}
# Use sync_read_all_states to get pos/vel/torque in one go
states = self.bus.sync_read_all_states()
for motor in self.bus.motors:
state = states.get(motor, {})
action_dict[f"{motor}.pos"] = state.get("position")
action_dict[f"{motor}.vel"] = state.get("velocity")
action_dict[f"{motor}.torque"] = state.get("torque")
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
return action_dict
def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError("Feedback is not yet implemented for OpenArm leader.")
def disconnect(self) -> None:
"""Disconnect from teleoperator."""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Disconnect CAN bus
# For manual control, ensure torque is disabled before disconnecting
self.bus.disconnect(disable_torque=self.config.manual_control)
logger.info(f"{self} disconnected.")
@@ -0,0 +1,21 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_unitree_g1 import ExoskeletonArmPortConfig, UnitreeG1TeleoperatorConfig
from .exo_calib import ExoskeletonCalibration, ExoskeletonJointCalibration
from .exo_ik import ExoskeletonIKHelper
from .exo_serial import ExoskeletonArm
from .unitree_g1 import UnitreeG1Teleoperator
@@ -0,0 +1,37 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from ..config import TeleoperatorConfig
@dataclass
class ExoskeletonArmPortConfig:
"""Serial port configuration for individual exoskeleton arm."""
port: str = ""
baud_rate: int = 115200
@TeleoperatorConfig.register_subclass("unitree_g1")
@dataclass
class UnitreeG1TeleoperatorConfig(TeleoperatorConfig):
left_arm_config: ExoskeletonArmPortConfig = field(default_factory=ExoskeletonArmPortConfig)
right_arm_config: ExoskeletonArmPortConfig = field(default_factory=ExoskeletonArmPortConfig)
# Frozen joints (comma-separated joint names that won't be moved by IK)
frozen_joints: str = ""
@@ -0,0 +1,446 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module handles calibration of hall effect sensors used in the exoskeleton.
Each joint has a pair of ADC channels outputting sin and cos values that trace an ellipse
(rather than a circle) as the joint rotates, due to imprecision in magnet/sensor placement.
We fit this ellipse, map it to a unit circle, and take the arctan2 of the unit-circle
coordinates to get the joint angle. We then store the ellipse parameters and the zero
offset for each joint to be used at runtime.
"""
import json
import logging
import time
from collections import deque
from dataclasses import dataclass, field
from pathlib import Path
import numpy as np
import serial
logger = logging.getLogger(__name__)
# exoskeleton joint names -> ADC channel pairs. TODO: add wrist pitch and wrist yaw
JOINTS = {
"shoulder_pitch": (0, 1),
"shoulder_yaw": (2, 3),
"shoulder_roll": (4, 5),
"elbow_flex": (6, 7),
"wrist_roll": (14, 15),
}
@dataclass
class ExoskeletonJointCalibration:
name: str # joint name
center_fit: list[float] # center of the ellipse
T: list[list[float]] # 2x2 transformation matrix
zero_offset: float = 0.0 # angle at neutral pose
@dataclass
class ExoskeletonCalibration:
"""Full calibration data for an exoskeleton arm."""
version: int = 2
side: str = ""
adc_max: int = 2**12 - 1
joints: list[ExoskeletonJointCalibration] = field(default_factory=list)
def to_dict(self) -> dict:
return {
"version": self.version,
"side": self.side,
"adc_max": self.adc_max,
"joints": [
{
"name": j.name,
"center_fit": j.center_fit,
"T": j.T,
"zero_offset": j.zero_offset,
}
for j in self.joints
],
}
@classmethod
def from_dict(cls, data: dict) -> "ExoskeletonCalibration":
joints = [
ExoskeletonJointCalibration(
name=j["name"],
center_fit=j["center_fit"],
T=j["T"],
zero_offset=j.get("zero_offset", 0.0),
)
for j in data.get("joints", [])
]
return cls(
version=data.get("version", 2),
side=data.get("side", ""),
adc_max=data.get("adc_max", 2**12 - 1),
joints=joints,
)
@dataclass(frozen=True)
class CalibParams:
fit_every: float = 0.15
min_fit_points: int = 60
fit_window: int = 900
max_fit_points: int = 300
trim_low: float = 0.05
trim_high: float = 0.95
median_window: int = 5
history: int = 3500
draw_hz: float = 120.0
sample_count: int = 50
def normalize_angle(angle: float) -> float:
while angle > np.pi:
angle -= 2 * np.pi
while angle < -np.pi:
angle += 2 * np.pi
return angle
def joint_z_and_angle(raw16: list[int], j: ExoskeletonJointCalibration) -> tuple[np.ndarray, float]:
"""
Apply calibration to one joint: center the raw reading, map the ellipse onto the unit circle, and return the unit-circle point and joint angle.
"""
pair = JOINTS[j.name]
s, c = raw16[pair[0]], raw16[pair[1]] # get sin and cos
p = np.array([float(c) - (2**12 - 1) / 2, float(s) - (2**12 - 1) / 2]) # center the raw values
z = np.asarray(j.T) @ (
p - np.asarray(j.center_fit)
) # center the ellipse and invert the transformation matrix to get unit circle coords
ang = float(np.arctan2(z[1], z[0])) - j.zero_offset  # calculate the angle and apply the zero offset
return z, normalize_angle(-ang) # ensure range is [-pi, pi]
def exo_raw_to_angles(raw16: list[int], calib: ExoskeletonCalibration) -> dict[str, float]:
"""Convert raw sensor readings to joint angles using calibration."""
return {j.name: joint_z_and_angle(raw16, j)[1] for j in calib.joints}
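A small numeric check of the mapping implemented by joint_z_and_angle, using an illustrative axis-aligned ellipse (semi-axes a=2, b=1, identity rotation) rather than real sensor data.
import numpy as np

center = np.array([100.0, 50.0])  # illustrative ellipse center
T = np.diag([1 / 2.0, 1 / 1.0])   # inverse semi-axes, no rotation
p = center + np.array([2.0 * np.cos(0.7), 1.0 * np.sin(0.7)])  # point at sweep angle 0.7
z = T @ (p - center)
assert np.isclose(np.hypot(z[0], z[1]), 1.0)    # back on the unit circle
assert np.isclose(np.arctan2(z[1], z[0]), 0.7)  # sweep angle recovered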
def run_exo_calibration(
ser: serial.Serial,
side: str,
save_path: Path,
params: CalibParams | None = None,
) -> ExoskeletonCalibration:
"""
Run interactive calibration for an exoskeleton arm.
"""
try:
import cv2
import matplotlib.pyplot as plt
except ImportError as e:
raise ImportError(
"Calibration requires matplotlib and opencv-python. "
"Install with: pip install matplotlib opencv-python"
) from e
from .exo_serial import read_raw_from_serial
params = params or CalibParams()
joint_list = list(JOINTS.items()) # Convert dict to list for indexing
logger.info(f"Starting calibration for {side} exoskeleton arm")
def running_median(win: deque) -> float:
return float(np.median(np.fromiter(win, dtype=float)))
def read_joint_point(raw16: list[int], pair: tuple[int, int]):
s, c = raw16[pair[0]], raw16[pair[1]]
return float(c) - (2**12 - 1) / 2, float(s) - (2**12 - 1) / 2, float(s), float(c)
def select_fit_subset(xs, ys):
"""Select and filter points for ellipse fitting. Trims outliers by radius and downsamples."""
n = min(params.fit_window, len(xs))
if n <= 0:
return None, None
x = np.asarray(list(xs)[-n:], dtype=float) # most recent n samples
y = np.asarray(list(ys)[-n:], dtype=float)
r = np.sqrt(x * x + y * y) # radius from origin
if len(r) >= 20:
lo, hi = np.quantile(r, params.trim_low), np.quantile(r, params.trim_high) # outlier bounds
keep = (r >= lo) & (r <= hi)
x, y = x[keep], y[keep] # remove outliers
if len(x) > params.max_fit_points:
idx = np.linspace(0, len(x) - 1, params.max_fit_points).astype(int) # downsample evenly
x, y = x[idx], y[idx]
return x, y
def fit_ellipse_opencv(x, y):
"""Fit ellipse to (x,y) points using OpenCV. Returns center, axes, rotation matrix, and outline."""
x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
if len(x) < 5:
return None
pts = np.stack([x, y], axis=1).astype(np.float32).reshape(-1, 1, 2)
try:
(xc, yc), (w, h), angle_deg = cv2.fitEllipse(pts) # returns center, axes, rotation in degrees
except cv2.error:
return None
a, b = float(w) * 0.5, float(h) * 0.5 # get ellipse major and minor semi-axes
phi = np.deg2rad(float(angle_deg)) # to rad
if b > a: # ensure major axis is a
a, b = b, a
phi += np.pi / 2.0
if not np.isfinite(a) or not np.isfinite(b) or a <= 1e-6 or b <= 1e-6:
return None
cp, sp = float(np.cos(phi)), float(np.sin(phi))  # cos/sin of the ellipse rotation angle
rot = np.array([[cp, -sp], [sp, cp]], dtype=float) # 2x2 rotation matrix
center = np.array([float(xc), float(yc)], dtype=float) # offset vector
tt = np.linspace(0, 2 * np.pi, 360)
outline = (rot @ np.stack([a * np.cos(tt), b * np.sin(tt)])).T + center # for viz
return {"center": center, "a": a, "b": b, "R": rot, "ex": outline[:, 0], "ey": outline[:, 1]}
# Setup matplotlib
plt.ion()
fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(12, 6))
ax0.set_xlabel("cos - center")
ax0.set_ylabel("sin - center")
ax0.grid(True, alpha=0.25)
ax0.set_aspect("equal", adjustable="box")
ax1.set_title("Unit circle + angle")
ax1.set_xlabel("x")
ax1.set_ylabel("y")
ax1.grid(True, alpha=0.25)
ax1.set_aspect("equal", adjustable="box")
tt = np.linspace(0, 2 * np.pi, 360)
ax1.plot(np.cos(tt), np.sin(tt), "k-", linewidth=1)
ax0.set_xlim(-2200, 2200)
ax0.set_ylim(-2200, 2200)
ax1.set_xlim(-1.4, 1.4)
ax1.set_ylim(-1.4, 1.4)
sc0 = ax0.scatter([], [], s=6, animated=True)
(ell_line,) = ax0.plot([], [], "r-", linewidth=2, animated=True)
sc1 = ax1.scatter([], [], s=6, animated=True)
(radius_line,) = ax1.plot([], [], "g-", linewidth=2, animated=True)
angle_text = ax1.text(
0.02, 0.98, "", transform=ax1.transAxes, va="top", ha="left", fontsize=12, animated=True
)
fig.canvas.draw()
bg0 = fig.canvas.copy_from_bbox(ax0.bbox)
bg1 = fig.canvas.copy_from_bbox(ax1.bbox)
# State
joints_out = []
joint_idx = 0
phase = "ellipse"
advance_requested = False
zero_samples = []
def on_key(event):
nonlocal advance_requested
if event.key in ("n", "N", "enter", " "):
advance_requested = True
fig.canvas.mpl_connect("key_press_event", on_key)
def reset_state():
return {
"xs": deque(maxlen=params.history),
"ys": deque(maxlen=params.history),
"xu": deque(maxlen=params.history),
"yu": deque(maxlen=params.history),
"win_s": deque(maxlen=params.median_window),
"win_c": deque(maxlen=params.median_window),
"ellipse_cache": None,
"T": None,
"center_fit": None,
"have_transform": False,
"latest_z": None,
"last_fit": 0.0,
}
state = reset_state()
last_draw = 0.0
name, pair = joint_list[joint_idx]
fig.canvas.manager.set_window_title(f"[{joint_idx + 1}/{len(joint_list)}] {name} - ELLIPSE")
ax0.set_title(f"{name} raw (filtered)")
logger.info(f"[{joint_idx + 1}/{len(joint_list)}] Calibrating {name}")
logger.info("Step 1: Move joint around to map ellipse, then press 'n'")
try:
while plt.fignum_exists(fig.number):
name, pair = joint_list[joint_idx]
# Handles calibration GUI state: ellipse -> zero_pose -> next joint -> ellipse -> ...
if phase == "ellipse" and advance_requested and state["have_transform"]:
joints_out.append(
{
"name": name,
"center_fit": state["center_fit"].tolist(),
"T": state["T"].tolist(),
}
)
logger.info(f" -> Ellipse saved for {name}")
phase, zero_samples, advance_requested = "zero_pose", [], False
fig.canvas.manager.set_window_title(f"[{joint_idx + 1}/{len(joint_list)}] {name} - ZERO POSE")
ax0.set_title(f"{name} - hold zero pose")
fig.canvas.draw()
bg0, bg1 = fig.canvas.copy_from_bbox(ax0.bbox), fig.canvas.copy_from_bbox(ax1.bbox)
logger.info(f"Step 2: Hold {name} in zero position, then press 'n'")
elif phase == "ellipse" and advance_requested and not state["have_transform"]:
logger.info(" (Need valid fit first - keep moving the joint)")
advance_requested = False
elif phase == "zero_pose" and advance_requested:
if len(zero_samples) >= params.sample_count:
zero_offset = float(np.mean(zero_samples[-params.sample_count :]))
joints_out[-1]["zero_offset"] = zero_offset
logger.info(f" -> {name} zero: {zero_offset:+.3f} rad ({np.degrees(zero_offset):+.1f}°)")
joint_idx += 1
advance_requested = False
if joint_idx >= len(joint_list):
# All joints done
calib = ExoskeletonCalibration(
version=2,
side=side,
adc_max=2**12 - 1,
joints=[
ExoskeletonJointCalibration(
name=j["name"],
center_fit=j["center_fit"],
T=j["T"],
zero_offset=j.get("zero_offset", 0.0),
)
for j in joints_out
],
)
save_path.parent.mkdir(parents=True, exist_ok=True)
with open(save_path, "w") as f:
json.dump(calib.to_dict(), f, indent=2)
logger.info(f"Saved calibration to {save_path}")
logger.info("Calibration complete!")
plt.close(fig)
return calib
# Next joint
phase, state = "ellipse", reset_state()
name, pair = joint_list[joint_idx]
fig.canvas.manager.set_window_title(
f"[{joint_idx + 1}/{len(joint_list)}] {name} - ELLIPSE"
)
ax0.set_title(f"{name} raw (filtered)")
fig.canvas.draw()
bg0, bg1 = fig.canvas.copy_from_bbox(ax0.bbox), fig.canvas.copy_from_bbox(ax1.bbox)
logger.info(f"[{joint_idx + 1}/{len(joint_list)}] Calibrating {name}")
logger.info("Step 1: Move joint around to map ellipse, then press 'n'")
else:
logger.info(
f" (Collecting samples: {len(zero_samples)}/{params.sample_count} - hold still)"
)
advance_requested = False
# Read sensor
raw16 = read_raw_from_serial(ser)
if raw16 is not None:
x_raw, y_raw, s_raw, c_raw = read_joint_point(raw16, pair)
if phase == "ellipse":
if state["have_transform"]:
z = state["T"] @ (np.array([x_raw, y_raw]) - state["center_fit"])
state["xu"].append(float(z[0]))
state["yu"].append(float(z[1]))
state["latest_z"] = (float(z[0]), float(z[1]))
state["win_s"].append(s_raw)
state["win_c"].append(c_raw)
if len(state["win_s"]) >= max(3, params.median_window):
state["ys"].append(running_median(state["win_s"]) - (2**12 - 1) / 2)
state["xs"].append(running_median(state["win_c"]) - (2**12 - 1) / 2)
else:
jdata = joints_out[-1]
z = np.array(jdata["T"]) @ (np.array([x_raw, y_raw]) - np.array(jdata["center_fit"]))
zero_samples.append(float(np.arctan2(z[1], z[0])))
state["latest_z"] = (float(z[0]), float(z[1]))
# Ellipse fitting
t = time.time()
if (
phase == "ellipse"
and (t - state["last_fit"]) >= params.fit_every
and len(state["xs"]) >= params.min_fit_points
):
xfit, yfit = select_fit_subset(state["xs"], state["ys"])
if xfit is not None and len(xfit) >= params.min_fit_points:
fit = fit_ellipse_opencv(xfit, yfit)
if fit is not None:
state["center_fit"] = fit["center"]
state["T"] = np.diag([1.0 / fit["a"], 1.0 / fit["b"]]) @ fit["R"].T
state["ellipse_cache"] = (fit["ex"], fit["ey"])
state["have_transform"] = True
state["last_fit"] = t
# Drawing
if (t - last_draw) >= 1.0 / params.draw_hz:
fig.canvas.restore_region(bg0)
fig.canvas.restore_region(bg1)
if phase == "ellipse":
sc0.set_offsets(np.c_[state["xs"], state["ys"]] if state["xs"] else np.empty((0, 2)))
ax0.draw_artist(sc0)
ell_line.set_data(*state["ellipse_cache"] if state["ellipse_cache"] else ([], []))
ax0.draw_artist(ell_line)
sc1.set_offsets(np.c_[state["xu"], state["yu"]] if state["xu"] else np.empty((0, 2)))
ax1.draw_artist(sc1)
if state["latest_z"]:
zx, zy = state["latest_z"]
radius_line.set_data([0.0, zx], [0.0, zy])
ang = float(np.arctan2(zy, zx))
angle_text.set_text(
f"angle: {ang:+.3f} rad ({np.degrees(ang):+.1f}°)\nmove {name}, press 'n' to advance"
)
else:
radius_line.set_data([], [])
angle_text.set_text("(waiting for fit)")
else:
sc0.set_offsets(np.empty((0, 2)))
ax0.draw_artist(sc0)
ell_line.set_data([], [])
ax0.draw_artist(ell_line)
if state["latest_z"]:
zx, zy = state["latest_z"]
sc1.set_offsets([[zx, zy]])
radius_line.set_data([0.0, zx], [0.0, zy])
ang = float(np.arctan2(zy, zx))
angle_text.set_text(
f"Zero pose for {name}\nangle: {ang:+.3f} rad\nsamples: {len(zero_samples)}/{params.sample_count}\nhold still, press 'n'"
)
else:
sc1.set_offsets(np.empty((0, 2)))
radius_line.set_data([], [])
angle_text.set_text("(waiting for data)")
ax1.draw_artist(sc1)
ax1.draw_artist(radius_line)
ax1.draw_artist(angle_text)
fig.canvas.blit(ax0.bbox)
fig.canvas.blit(ax1.bbox)
fig.canvas.flush_events()
last_draw = t
plt.pause(0.001)
finally:
plt.close(fig)
@@ -0,0 +1,353 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
IK helper for exoskeleton-to-G1 teleoperation. We map exoskeleton joint angles to an
end-effector pose in the world frame, and visualize the result in meshcat after calibration.
"""
import logging
import os
from dataclasses import dataclass
import numpy as np
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointArmIndex
from lerobot.robots.unitree_g1.robot_kinematic_processor import G1_29_ArmIK
from .exo_calib import JOINTS
logger = logging.getLogger(__name__)
def _frame_id(model, name: str) -> int | None:
try:
fid = model.getFrameId(name)
return fid if 0 <= fid < model.nframes else None
except Exception:
return None
@dataclass
class ArmCfg:
side: str # "left" | "right"
urdf: str # exo_left.urdf / exo_right.urdf
root: str # "exo_left" / "exo_right"
g1_ee: str # "l_ee" / "r_ee"
offset: np.ndarray # world offset for viz + target
marker_prefix: str # "left" / "right"
class Markers:
"""Creates meshcat visualization primitives, showing end-effector frames of exoskeleton and G1"""
def __init__(self, viewer):
self.v = viewer
def sphere(self, path: str, r: float, rgba: tuple[float, float, float, float]):
import meshcat.geometry as mg
c = (int(rgba[0] * 255) << 16) | (int(rgba[1] * 255) << 8) | int(rgba[2] * 255)
self.v[path].set_object(
mg.Sphere(r),
mg.MeshPhongMaterial(color=c, opacity=rgba[3], transparent=rgba[3] < 1.0),
)
def axes(self, path: str, axis_len: float = 0.1, axis_w: int = 6):
import meshcat.geometry as mg
pts = np.array(
[[0, 0, 0], [axis_len, 0, 0], [0, 0, 0], [0, axis_len, 0], [0, 0, 0], [0, 0, axis_len]],
dtype=np.float32,
).T
cols = np.array(
[[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]],
dtype=np.float32,
).T
self.v[path].set_object(
mg.LineSegments(
mg.PointsGeometry(position=pts, color=cols),
mg.LineBasicMaterial(linewidth=axis_w, vertexColors=True),
)
)
def tf(self, path: str, mat: np.ndarray):
self.v[path].set_transform(mat)
class ExoskeletonIKHelper:
"""
- Loads G1 robot and exoskeleton URDF models via Pinocchio
- Computes forward kinematics on exoskeleton to get end-effector poses
- Solves inverse kinematics on G1 to match those poses
- Provides meshcat visualization showing both robots and targets
Args:
frozen_joints: List of G1 joint names to exclude from IK (kept at neutral).
"""
def __init__(self, frozen_joints: list[str] | None = None):
try:
import pinocchio as pin
except ImportError as e:
raise ImportError("ik mode needs pinocchio: pip install pin") from e
self.pin = pin
self.frozen_joints = frozen_joints or []
self.g1_ik = G1_29_ArmIK()
self.robot_g1 = self.g1_ik.reduced_robot
self.robot_g1.data = self.robot_g1.model.createData()
self.q_g1 = pin.neutral(self.robot_g1.model)
assets_dir = os.path.join(self.g1_ik.repo_path, "assets")
self.frozen_idx = self._frozen_joint_indices()
self.arms = [
ArmCfg(
side="left",
urdf=os.path.join(assets_dir, "exo_left.urdf"),
root="exo_left",
g1_ee="L_ee",
offset=np.array([0.6, 0.3, 0.0]),
marker_prefix="left",
),
ArmCfg(
side="right",
urdf=os.path.join(assets_dir, "exo_right.urdf"),
root="exo_right",
g1_ee="R_ee",
offset=np.array([0.6, -0.3, 0.0]),
marker_prefix="right",
),
]
self.exo = {} # side -> pin.RobotWrapper
self.q_exo = {} # side -> q
self.ee_id_exo = {} # side -> frame id
self.qmap = {} # side -> {joint_name: q_idx}
self.ee_id_g1 = {} # side -> frame id
self._load_exo_models(assets_dir)
for a in self.arms:
self.ee_id_g1[a.side] = _frame_id(self.robot_g1.model, a.g1_ee)
self.viewer = None
self.markers: Markers | None = None
self.viz_g1 = None
self.viz_exo = {} # side -> viz
def _frozen_joint_indices(self) -> dict[str, int]:
out = {}
m = self.robot_g1.model
for name in self.frozen_joints:
if name in m.names:
jid = m.getJointId(name)
out[name] = m.idx_qs[jid]
logger.info(f"freezing joint: {name} (q_idx={out[name]})")
return out
def _find_exo_ee(self, model, ee_name: str = "ee") -> int:
ee = _frame_id(model, ee_name)
if ee is not None:
return ee
for fid in reversed(range(model.nframes)):
if model.frames[fid].type == self.pin.FrameType.BODY:
return fid
return 0
def _build_joint_map(self, robot) -> dict[str, int]:
m = robot.model
return {n: m.idx_qs[m.getJointId(n)] for n in JOINTS if n in m.names}
def _load_exo_models(self, assets_dir: str):
pin = self.pin
for a in self.arms:
if not os.path.exists(a.urdf):
logger.warning(f"{a.side} exo urdf not found: {a.urdf}")
continue
r = pin.RobotWrapper.BuildFromURDF(a.urdf, assets_dir)
self.exo[a.side] = r
self.q_exo[a.side] = pin.neutral(r.model)
self.ee_id_exo[a.side] = self._find_exo_ee(r.model)
self.qmap[a.side] = self._build_joint_map(r)
logger.info(f"loaded {a.side} exo urdf: {a.urdf}")
def init_visualization(self):
"""
Creates a browser-based visualization of exoskeleton and G1 robot,
highlighting end-effector frames and target positions.
"""
try:
from pinocchio.visualize import MeshcatVisualizer
except ImportError as e:
logger.warning(f"meshcat viz unavailable: {e}")
return
# g1
self.viz_g1 = MeshcatVisualizer(
self.robot_g1.model, self.robot_g1.collision_model, self.robot_g1.visual_model
)
self.viz_g1.initViewer(open=True)
self.viz_g1.loadViewerModel("g1")
self.viz_g1.display(self.q_g1)
self.viewer = self.viz_g1.viewer
self.markers = Markers(self.viewer)
# exos
for a in self.arms:
if a.side not in self.exo:
continue
r = self.exo[a.side]
v = MeshcatVisualizer(r.model, r.collision_model, r.visual_model)
v.initViewer(open=False)
v.viewer = self.viewer
v.loadViewerModel(a.root)
offset_tf = np.eye(4)
offset_tf[:3, 3] = a.offset
self.viewer[a.root].set_transform(offset_tf)
v.display(self.q_exo[a.side])
self.viz_exo[a.side] = v
# markers
for a in self.arms:
p = a.marker_prefix
self.markers.sphere(f"markers/{p}_exo_ee", 0.012, (0.2, 1.0, 0.2, 0.9))
self.markers.sphere(f"markers/{p}_g1_ee", 0.015, (1.0, 0.2, 0.2, 0.9))
self.markers.sphere(f"markers/{p}_ik_target", 0.015, (0.1, 0.3, 1.0, 0.9))
self.markers.axes(f"markers/{p}_exo_axes", 0.06)
self.markers.axes(f"markers/{p}_g1_axes", 0.08)
logger.info(f"meshcat viz initialized: {self.viewer.url()}")
print(f"\nmeshcat url: {self.viewer.url()}\n")
def _fk_target_world(self, side: str, angles: dict[str, float]) -> np.ndarray | None:
"""returns wrist frame target to be used for G1 IK in 4x4 homogeneous transform. Takes offset into account."""
if side not in self.exo or not angles:
return None
pin = self.pin
q = self.q_exo[side]
qmap = self.qmap[side]
for name, ang in angles.items():
idx = qmap.get(name)
if idx is not None:
q[idx] = float(ang)
r = self.exo[side]
pin.forwardKinematics(r.model, r.data, q)
pin.updateFramePlacements(r.model, r.data)
ee = r.data.oMf[self.ee_id_exo[side]]
target = np.eye(4)
target[:3, :3] = ee.rotation
# offset gets applied in world space
cfg = next(a for a in self.arms if a.side == side)
target[:3, 3] = cfg.offset + ee.translation
return target
def update_visualization(self):
if self.viewer is None or self.markers is None:
return
pin = self.pin
# g1
if self.viz_g1 is not None:
self.viz_g1.display(self.q_g1)
pin.forwardKinematics(self.robot_g1.model, self.robot_g1.data, self.q_g1)
pin.updateFramePlacements(self.robot_g1.model, self.robot_g1.data)
for a in self.arms:
fid = self.ee_id_g1.get(a.side)
if fid is None:
continue
ee_tf = self.robot_g1.data.oMf[fid].homogeneous
p = a.marker_prefix
self.markers.tf(f"markers/{p}_g1_ee", ee_tf)
self.markers.tf(f"markers/{p}_g1_axes", ee_tf)
# exos
for a in self.arms:
side = a.side
v = self.viz_exo.get(side)
if v is None:
continue
v.display(self.q_exo[side])
r = self.exo[side]
pin.forwardKinematics(r.model, r.data, self.q_exo[side])
pin.updateFramePlacements(r.model, r.data)
ee = r.data.oMf[self.ee_id_exo[side]]
world_tf = (pin.SE3(np.eye(3), a.offset) * ee).homogeneous
p = a.marker_prefix
self.markers.tf(f"markers/{p}_exo_ee", world_tf)
self.markers.tf(f"markers/{p}_exo_axes", world_tf)
target_tf = np.eye(4)
target_tf[:3, :3] = ee.rotation
target_tf[:3, 3] = a.offset + ee.translation
self.markers.tf(f"markers/{p}_ik_target", target_tf)
def compute_g1_joints_from_exo(
self,
left_angles: dict[str, float],
right_angles: dict[str, float],
) -> dict[str, float]:
"""
Performs FK on exoskeleton to get end-effector poses in world frame,
after which it solves IK on G1 to return joint angles matching those poses in G1 motor order.
"""
pin = self.pin
targets = {
"left": self._fk_target_world("left", left_angles),
"right": self._fk_target_world("right", right_angles),
}
# fallback to current g1 ee pose if missing target
pin.forwardKinematics(self.robot_g1.model, self.robot_g1.data, self.q_g1)
pin.updateFramePlacements(self.robot_g1.model, self.robot_g1.data)
for a in self.arms:
if targets[a.side] is not None:
continue
fid = self.ee_id_g1.get(a.side)
if fid is not None:
targets[a.side] = self.robot_g1.data.oMf[fid].homogeneous
if targets["left"] is None or targets["right"] is None:
logger.warning("missing ik targets, returning current pose")
return {}
frozen_vals = {n: self.q_g1[i] for n, i in self.frozen_idx.items()}
self.q_g1, _ = self.g1_ik.solve_ik(
targets["left"], targets["right"], current_lr_arm_motor_q=self.q_g1
)
for n, i in self.frozen_idx.items():
self.q_g1[i] = frozen_vals[n]
return {
f"{j.name}.q": float(self.q_g1[i])
for i, j in enumerate(G1_29_JointArmIndex)
if i < len(self.q_g1)
}
@@ -0,0 +1,119 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from dataclasses import dataclass
from pathlib import Path
import serial
from .exo_calib import ExoskeletonCalibration, exo_raw_to_angles, run_exo_calibration
logger = logging.getLogger(__name__)
def parse_raw16(line: bytes) -> list[int] | None:
try:
parts = line.decode("utf-8", errors="ignore").split()
if len(parts) < 16:
return None
return [int(x) for x in parts[:16]]
except Exception:
return None
def read_raw_from_serial(ser) -> list[int] | None:
"""Read latest sample from serial; if buffer is backed up, keep only the newest."""
last = None
while ser.in_waiting > 0:
b = ser.readline()
if not b:
break
raw16 = parse_raw16(b)
if raw16 is not None:
last = raw16
if last is None:
b = ser.readline()
if b:
last = parse_raw16(b)
return last
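An illustrative example of the line format parse_raw16 accepts: at least 16 whitespace-separated integers, with anything past the first 16 ignored.
line = b"512 1023 0 2047 300 301 302 303 400 401 402 403 500 501 14 15\n"
assert parse_raw16(line) == [512, 1023, 0, 2047, 300, 301, 302, 303,
                             400, 401, 402, 403, 500, 501, 14, 15]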
@dataclass
class ExoskeletonArm:
port: str
calibration_fpath: Path
side: str
baud_rate: int = 115200
_ser: serial.Serial | None = None
calibration: ExoskeletonCalibration | None = None
def __post_init__(self):
if self.calibration_fpath.is_file():
self._load_calibration()
@property
def is_connected(self) -> bool:
return self._ser is not None and getattr(self._ser, "is_open", False)
@property
def is_calibrated(self) -> bool:
return self.calibration is not None
def connect(self, calibrate: bool = True) -> None:
if self.is_connected:
return
try:
self._ser = serial.Serial(self.port, self.baud_rate, timeout=0.02)
self._ser.reset_input_buffer()
logger.info(f"connected: {self.port}")
except serial.SerialException as e:
raise ConnectionError(f"failed to connect to {self.port}: {e}") from e
if calibrate and not self.is_calibrated:
self.calibrate()
def disconnect(self) -> None:
if self._ser:
try:
self._ser.close()
finally:
self._ser = None
def _load_calibration(self) -> None:
try:
data = json.loads(self.calibration_fpath.read_text())
self.calibration = ExoskeletonCalibration.from_dict(data)
logger.info(f"loaded calibration: {self.calibration_fpath}")
except Exception as e:
logger.warning(f"failed to load calibration: {e}")
def read_raw(self) -> list[int] | None:
if not self._ser:
return None
return read_raw_from_serial(self._ser)
def get_angles(self) -> dict[str, float]:
if not self.calibration:
raise RuntimeError("exoskeleton not calibrated")
raw = self.read_raw()
return {} if raw is None else exo_raw_to_angles(raw, self.calibration)
def calibrate(self) -> None:
    if self._ser is None:
        raise ConnectionError(f"{self.port} is not connected; call connect() before calibrate()")
    self.calibration = run_exo_calibration(self._ser, self.side, self.calibration_fpath)
@@ -0,0 +1,157 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from functools import cached_property
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, TELEOPERATORS
from ..teleoperator import Teleoperator
from .config_unitree_g1 import UnitreeG1TeleoperatorConfig
from .exo_ik import ExoskeletonIKHelper
from .exo_serial import ExoskeletonArm
logger = logging.getLogger(__name__)
class UnitreeG1Teleoperator(Teleoperator):
"""
Bimanual exoskeleton arms teleoperator for Unitree G1 arms.
Uses inverse kinematics: exoskeleton FK computes end-effector pose,
G1 IK solves for joint angles.
"""
config_class = UnitreeG1TeleoperatorConfig
name = "unitree_g1"
def __init__(self, config: UnitreeG1TeleoperatorConfig):
super().__init__(config)
self.config = config
# Setup calibration directory
self.calibration_dir = (
config.calibration_dir
if config.calibration_dir
else HF_LEROBOT_CALIBRATION / TELEOPERATORS / self.name
)
self.calibration_dir.mkdir(parents=True, exist_ok=True)
left_id = f"{config.id}_left" if config.id else "left"
right_id = f"{config.id}_right" if config.id else "right"
# Create exoskeleton arm instances
self.left_arm = ExoskeletonArm(
port=config.left_arm_config.port,
baud_rate=config.left_arm_config.baud_rate,
calibration_fpath=self.calibration_dir / f"{left_id}.json",
side="left",
)
self.right_arm = ExoskeletonArm(
port=config.right_arm_config.port,
baud_rate=config.right_arm_config.baud_rate,
calibration_fpath=self.calibration_dir / f"{right_id}.json",
side="right",
)
self.ik_helper: ExoskeletonIKHelper | None = None
@cached_property
def action_features(self) -> dict[str, type]:
return {f"{name}.q": float for name in self._g1_joint_names}
@cached_property
def feedback_features(self) -> dict[str, type]:
return {}
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
frozen_joints = [j.strip() for j in self.config.frozen_joints.split(",") if j.strip()]
self.ik_helper = ExoskeletonIKHelper(frozen_joints=frozen_joints)
logger.info("IK helper initialized")
def calibrate(self) -> None:
if not self.left_arm.is_calibrated:
logger.info("Starting calibration for left arm...")
self.left_arm.calibrate()
else:
logger.info("Left arm already calibrated. Skipping.")
if not self.right_arm.is_calibrated:
logger.info("Starting calibration for right arm...")
self.right_arm.calibrate()
else:
logger.info("Right arm already calibrated. Skipping.")
logger.info("Starting visualization to verify calibration...")
self.run_visualization_loop()
def configure(self) -> None:
pass
def get_action(self) -> dict[str, float]:
left_angles = self.left_arm.get_angles()
right_angles = self.right_arm.get_angles()
return self.ik_helper.compute_g1_joints_from_exo(left_angles, right_angles)
def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError("Exoskeleton arms do not support feedback")
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
def run_visualization_loop(self):
"""Run interactive Meshcat visualization loop to verify tracking."""
if self.ik_helper is None:
frozen_joints = [j.strip() for j in self.config.frozen_joints.split(",") if j.strip()]
self.ik_helper = ExoskeletonIKHelper(frozen_joints=frozen_joints)
self.ik_helper.init_visualization()
print("\n" + "=" * 60)
print("Visualization running! Move the exoskeletons to test tracking.")
print("Press Ctrl+C to exit.")
print("=" * 60 + "\n")
try:
while True:
left_angles = self.left_arm.get_angles()
right_angles = self.right_arm.get_angles()
self.ik_helper.compute_g1_joints_from_exo(left_angles, right_angles)
self.ik_helper.update_visualization()
time.sleep(0.01)
except KeyboardInterrupt:
print("\n\nVisualization stopped.")
@cached_property
def _g1_joint_names(self) -> list[str]:
return [joint.name for joint in G1_29_JointIndex]
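For reference, a minimal sketch of how this teleoperator is driven once connected. The config construction below is an assumption (field names mirror the __init__ above), not a tested invocation:

import time

from lerobot.teleoperators.unitree_g1 import UnitreeG1Teleoperator
from lerobot.teleoperators.unitree_g1.config_unitree_g1 import UnitreeG1TeleoperatorConfig

# Hypothetical config: left_arm_config/right_arm_config carry the serial port and
# baud rate consumed by ExoskeletonArm in __init__ above.
config = UnitreeG1TeleoperatorConfig(
    left_arm_config=...,   # e.g. port="/dev/ttyUSB0" plus baud_rate (illustrative)
    right_arm_config=...,  # e.g. port="/dev/ttyUSB1" plus baud_rate (illustrative)
)
teleop = UnitreeG1Teleoperator(config)
teleop.connect(calibrate=True)  # connects both arms and builds the IK helper
try:
    while True:
        # {"<g1_joint>.q": float} for every joint in G1_29_JointIndex
        action = teleop.get_action()
        time.sleep(0.01)  # ~100 Hz, matching the visualization loop above
finally:
    teleop.disconnect()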
+18 -4
View File
@@ -13,12 +13,14 @@
# limitations under the License.
from enum import Enum
from typing import cast
from typing import TYPE_CHECKING, cast
from lerobot.utils.import_utils import make_device_from_device_class
from .config import TeleoperatorConfig
from .teleoperator import Teleoperator
if TYPE_CHECKING:
from .teleoperator import Teleoperator
class TeleopEvents(Enum):
@@ -31,7 +33,7 @@ class TeleopEvents(Enum):
TERMINATE_EPISODE = "terminate_episode"
def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator":
# TODO(Steven): Consider just using the make_device_from_device_class for all types
if config.type == "keyboard":
from .keyboard import KeyboardTeleop
@@ -73,6 +75,10 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
from .homunculus import HomunculusArm
return HomunculusArm(config)
elif config.type == "unitree_g1":
from .unitree_g1 import UnitreeG1Teleoperator
return UnitreeG1Teleoperator(config)
elif config.type == "bi_so_leader":
from .bi_so_leader import BiSOLeader
@@ -81,8 +87,16 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
from .reachy2_teleoperator import Reachy2Teleoperator
return Reachy2Teleoperator(config)
elif config.type == "openarm_leader":
from .openarm_leader import OpenArmLeader
return OpenArmLeader(config)
elif config.type == "bi_openarm_leader":
from .bi_openarm_leader import BiOpenArmLeader
return BiOpenArmLeader(config)
else:
try:
return cast(Teleoperator, make_device_from_device_class(config))
return cast("Teleoperator", make_device_from_device_class(config))
except Exception as e:
raise ValueError(f"Error creating robot with config {config}: {e}") from e
+3
View File
@@ -26,6 +26,9 @@ OBS_IMAGES = OBS_IMAGE + "s"
OBS_LANGUAGE = OBS_STR + ".language"
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
OBS_LANGUAGE_SUBTASK = OBS_STR + ".subtask"
OBS_LANGUAGE_SUBTASK_TOKENS = OBS_LANGUAGE_SUBTASK + ".tokens"
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK + ".attention_mask"
ACTION = "action"
ACTION_PREFIX = ACTION + "."
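Assuming OBS_STR == "observation" (its value elsewhere in this constants module), the new subtask constants resolve to the following dotted keys; a small sanity check:

from lerobot.utils.constants import (
    OBS_LANGUAGE_SUBTASK,
    OBS_LANGUAGE_SUBTASK_ATTENTION_MASK,
    OBS_LANGUAGE_SUBTASK_TOKENS,
)

# Note: the subtask key hangs off OBS_STR directly, not off OBS_LANGUAGE,
# so it reads "observation.subtask" rather than "observation.language.subtask".
assert OBS_LANGUAGE_SUBTASK == "observation.subtask"
assert OBS_LANGUAGE_SUBTASK_TOKENS == "observation.subtask.tokens"
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK == "observation.subtask.attention_mask"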
+127 -57
View File
@@ -20,7 +20,9 @@
# ```
from pathlib import Path
from unittest.mock import patch
import cv2
import numpy as np
import pytest
@@ -28,6 +30,50 @@ from lerobot.cameras.configs import Cv2Rotation
from lerobot.cameras.opencv import OpenCVCamera, OpenCVCameraConfig
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
RealVideoCapture = cv2.VideoCapture
class MockLoopingVideoCapture:
"""
Wraps the real OpenCV VideoCapture.
    Motivation: cv2.VideoCapture opened on a still image (e.g. a .png file) only yields one successful read.
    Strategy: read the file once and return the cached frame on subsequent reads.
    Consequence: no recurring I/O, and the test artifacts stay simple.
"""
def __init__(self, *args, **kwargs):
args_clean = [str(a) if isinstance(a, Path) else a for a in args]
self._real_vc = RealVideoCapture(*args_clean, **kwargs)
self._cached_frame = None
def read(self):
ret, frame = self._real_vc.read()
if ret:
self._cached_frame = frame
return ret, frame
if not ret and self._cached_frame is not None:
return True, self._cached_frame.copy()
return ret, frame
def __getattr__(self, name):
return getattr(self._real_vc, name)
@pytest.fixture(autouse=True)
def patch_opencv_videocapture():
"""
Automatically patches cv2.VideoCapture for all tests.
"""
module_path = OpenCVCamera.__module__
target = f"{module_path}.cv2.VideoCapture"
with patch(target, new=MockLoopingVideoCapture):
yield
# NOTE(Steven): more tests + assertions?
TEST_ARTIFACTS_DIR = Path(__file__).parent.parent / "artifacts" / "cameras"
DEFAULT_PNG_FILE_PATH = TEST_ARTIFACTS_DIR / "image_160x120.png"
@@ -43,25 +89,22 @@ def test_abc_implementation():
def test_connect():
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
camera = OpenCVCamera(config)
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0)
camera.connect(warmup=False)
assert camera.is_connected
with OpenCVCamera(config) as camera:
assert camera.is_connected
def test_connect_already_connected():
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
camera = OpenCVCamera(config)
camera.connect(warmup=False)
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0)
with pytest.raises(DeviceAlreadyConnectedError):
camera.connect(warmup=False)
with OpenCVCamera(config) as camera, pytest.raises(DeviceAlreadyConnectedError):
camera.connect()
def test_connect_invalid_camera_path():
config = OpenCVCameraConfig(index_or_path="nonexistent/camera.png")
camera = OpenCVCamera(config)
with pytest.raises(ConnectionError):
@@ -74,27 +117,25 @@ def test_invalid_width_connect():
width=99999, # Invalid width to trigger error
height=480,
)
camera = OpenCVCamera(config)
camera = OpenCVCamera(config)
with pytest.raises(RuntimeError):
camera.connect(warmup=False)
@pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES)
def test_read(index_or_path):
config = OpenCVCameraConfig(index_or_path=index_or_path)
camera = OpenCVCamera(config)
camera.connect(warmup=False)
config = OpenCVCameraConfig(index_or_path=index_or_path, warmup_s=0)
img = camera.read()
assert isinstance(img, np.ndarray)
with OpenCVCamera(config) as camera:
img = camera.read()
assert isinstance(img, np.ndarray)
def test_read_before_connect():
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
camera = OpenCVCamera(config)
camera = OpenCVCamera(config)
with pytest.raises(DeviceNotConnectedError):
_ = camera.read()
@@ -119,32 +160,22 @@ def test_disconnect_before_connect():
@pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES)
def test_async_read(index_or_path):
config = OpenCVCameraConfig(index_or_path=index_or_path)
camera = OpenCVCamera(config)
camera.connect(warmup=False)
config = OpenCVCameraConfig(index_or_path=index_or_path, warmup_s=0)
try:
with OpenCVCamera(config) as camera:
img = camera.async_read()
assert camera.thread is not None
assert camera.thread.is_alive()
assert isinstance(img, np.ndarray)
finally:
if camera.is_connected:
            camera.disconnect()  # Stop/join the thread; otherwise we get warnings when the test ends
def test_async_read_timeout():
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
camera = OpenCVCamera(config)
camera.connect(warmup=False)
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0)
try:
with pytest.raises(TimeoutError):
camera.async_read(timeout_ms=0)
finally:
if camera.is_connected:
camera.disconnect()
with OpenCVCamera(config) as camera, pytest.raises(TimeoutError):
        camera.async_read(timeout_ms=0)  # consume any frame that is already available
        camera.async_read(timeout_ms=0)  # immediately request another one
def test_async_read_before_connect():
@@ -155,6 +186,50 @@ def test_async_read_before_connect():
_ = camera.async_read()
def test_read_latest():
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0)
with OpenCVCamera(config) as camera:
# ensure at least one fresh frame is captured
frame = camera.read()
latest = camera.read_latest()
assert isinstance(latest, np.ndarray)
assert latest.shape == frame.shape
def test_read_latest_before_connect():
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
camera = OpenCVCamera(config)
with pytest.raises(DeviceNotConnectedError):
_ = camera.read_latest()
def test_read_latest_high_frequency():
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0)
with OpenCVCamera(config) as camera:
# prime to ensure frames are available
ref = camera.read()
for _ in range(20):
latest = camera.read_latest()
assert isinstance(latest, np.ndarray)
assert latest.shape == ref.shape
def test_read_latest_too_old():
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0)
with OpenCVCamera(config) as camera:
# prime to ensure frames are available
_ = camera.read()
with pytest.raises(TimeoutError):
_ = camera.read_latest(max_age_ms=0) # immediately too old
def test_fourcc_configuration():
"""Test FourCC configuration validation and application."""
@@ -181,18 +256,15 @@ def test_fourcc_configuration():
def test_fourcc_with_camera():
"""Test FourCC functionality with actual camera connection."""
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, fourcc="MJPG")
camera = OpenCVCamera(config)
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, fourcc="MJPG", warmup_s=0)
# Connect should work with MJPG specified
camera.connect(warmup=False)
assert camera.is_connected
with OpenCVCamera(config) as camera:
assert camera.is_connected
# Read should work normally
img = camera.read()
assert isinstance(img, np.ndarray)
camera.disconnect()
# Read should work normally
img = camera.read()
assert isinstance(img, np.ndarray)
@pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES)
@@ -211,18 +283,16 @@ def test_rotation(rotation, index_or_path):
    dimensions = filename.split("_")[-1].split(".")[0]  # Assumes filename format *_<width>x<height>.png
original_width, original_height = map(int, dimensions.split("x"))
config = OpenCVCameraConfig(index_or_path=index_or_path, rotation=rotation)
camera = OpenCVCamera(config)
camera.connect(warmup=False)
config = OpenCVCameraConfig(index_or_path=index_or_path, rotation=rotation, warmup_s=0)
with OpenCVCamera(config) as camera:
img = camera.read()
assert isinstance(img, np.ndarray)
img = camera.read()
assert isinstance(img, np.ndarray)
if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270):
assert camera.width == original_height
assert camera.height == original_width
assert img.shape[:2] == (original_width, original_height)
else:
assert camera.width == original_width
assert camera.height == original_height
assert img.shape[:2] == (original_height, original_width)
if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270):
assert camera.width == original_height
assert camera.height == original_width
assert img.shape[:2] == (original_width, original_height)
else:
assert camera.width == original_width
assert camera.height == original_height
assert img.shape[:2] == (original_height, original_width)
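The read_latest tests above pin down its contract: read() or the background capture must have produced at least one frame, and max_age_ms bounds how stale that frame may be (TimeoutError otherwise). A hedged usage sketch, with an illustrative device index:

from lerobot.cameras.opencv import OpenCVCamera, OpenCVCameraConfig

config = OpenCVCameraConfig(index_or_path=0, warmup_s=0)  # 0 = first local webcam (illustrative)
with OpenCVCamera(config) as camera:  # __enter__ connects, __exit__ disconnects
    camera.read()  # prime the capture so a frame exists
    frame = camera.read_latest(max_age_ms=100)  # reuse the last frame if it is under 100 ms old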
+38
View File
@@ -150,6 +150,44 @@ def test_async_read_before_connect(camera):
_ = camera.async_read()
def test_read_latest(camera):
camera.connect()
frame = camera.read()
latest = camera.read_latest()
assert isinstance(latest, np.ndarray)
assert latest.shape == frame.shape
def test_read_latest_before_connect(camera):
# camera fixture yields an unconnected camera instance
with pytest.raises(DeviceNotConnectedError):
_ = camera.read_latest()
def test_read_latest_high_frequency(camera):
camera.connect()
# prime to ensure frames are available
ref = camera.read()
for _ in range(20):
latest = camera.read_latest()
assert isinstance(latest, np.ndarray)
assert latest.shape == ref.shape
def test_read_latest_too_old(camera):
camera.connect()
# prime to ensure frames are available
_ = camera.read()
with pytest.raises(TimeoutError):
_ = camera.read_latest(max_age_ms=0) # immediately too old
def test_wrong_camera_name():
with pytest.raises(ValueError):
_ = Reachy2CameraConfig(name="wrong-name", image_type="left")
+68 -46
View File
@@ -62,19 +62,15 @@ def test_abc_implementation():
def test_connect():
config = RealSenseCameraConfig(serial_number_or_name="042")
camera = RealSenseCamera(config)
config = RealSenseCameraConfig(serial_number_or_name="042", warmup_s=0)
camera.connect(warmup=False)
assert camera.is_connected
with RealSenseCamera(config) as camera:
assert camera.is_connected
def test_connect_already_connected():
config = RealSenseCameraConfig(serial_number_or_name="042")
camera = RealSenseCamera(config)
camera.connect(warmup=False)
with pytest.raises(DeviceAlreadyConnectedError):
config = RealSenseCameraConfig(serial_number_or_name="042", warmup_s=0)
with RealSenseCamera(config) as camera, pytest.raises(DeviceAlreadyConnectedError):
camera.connect(warmup=False)
@@ -96,12 +92,10 @@ def test_invalid_width_connect():
def test_read():
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30)
camera = RealSenseCamera(config)
camera.connect(warmup=False)
img = camera.read()
assert isinstance(img, np.ndarray)
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, warmup_s=0)
with RealSenseCamera(config) as camera:
img = camera.read()
assert isinstance(img, np.ndarray)
# TODO(Steven): Fix this test for the latest version of pyrealsense2.
@@ -142,32 +136,21 @@ def test_disconnect_before_connect():
def test_async_read():
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30)
camera = RealSenseCamera(config)
camera.connect(warmup=False)
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, warmup_s=0)
try:
with RealSenseCamera(config) as camera:
img = camera.async_read()
assert camera.thread is not None
assert camera.thread.is_alive()
assert isinstance(img, np.ndarray)
finally:
if camera.is_connected:
            camera.disconnect()  # Stop/join the thread; otherwise we get warnings when the test ends
def test_async_read_timeout():
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30)
camera = RealSenseCamera(config)
camera.connect(warmup=False)
try:
with pytest.raises(TimeoutError):
camera.async_read(timeout_ms=0)
finally:
if camera.is_connected:
camera.disconnect()
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, warmup_s=0)
with RealSenseCamera(config) as camera, pytest.raises(TimeoutError):
        camera.async_read(timeout_ms=0)  # consume any frame that is already available
        camera.async_read(timeout_ms=0)  # immediately request another one
def test_async_read_before_connect():
@@ -178,6 +161,47 @@ def test_async_read_before_connect():
_ = camera.async_read()
def test_read_latest():
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, warmup_s=0)
with RealSenseCamera(config) as camera:
img = camera.read()
latest = camera.read_latest()
assert isinstance(latest, np.ndarray)
assert latest.shape == img.shape
def test_read_latest_high_frequency():
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, warmup_s=0)
with RealSenseCamera(config) as camera:
# prime with one read to ensure frames are available
ref = camera.read()
for _ in range(20):
latest = camera.read_latest()
assert isinstance(latest, np.ndarray)
assert latest.shape == ref.shape
def test_read_latest_before_connect():
config = RealSenseCameraConfig(serial_number_or_name="042")
camera = RealSenseCamera(config)
with pytest.raises(DeviceNotConnectedError):
_ = camera.read_latest()
def test_read_latest_too_old():
config = RealSenseCameraConfig(serial_number_or_name="042")
with RealSenseCamera(config) as camera:
# prime to ensure frames are available
_ = camera.read()
with pytest.raises(TimeoutError):
_ = camera.read_latest(max_age_ms=0) # immediately too old
@pytest.mark.parametrize(
"rotation",
[
@@ -189,18 +213,16 @@ def test_async_read_before_connect():
ids=["no_rot", "rot90", "rot180", "rot270"],
)
def test_rotation(rotation):
config = RealSenseCameraConfig(serial_number_or_name="042", rotation=rotation)
camera = RealSenseCamera(config)
camera.connect(warmup=False)
config = RealSenseCameraConfig(serial_number_or_name="042", rotation=rotation, warmup_s=0)
with RealSenseCamera(config) as camera:
img = camera.read()
assert isinstance(img, np.ndarray)
img = camera.read()
assert isinstance(img, np.ndarray)
if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270):
assert camera.width == 480
assert camera.height == 640
assert img.shape[:2] == (640, 480)
else:
assert camera.width == 640
assert camera.height == 480
assert img.shape[:2] == (480, 640)
if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270):
assert camera.width == 480
assert camera.height == 640
assert img.shape[:2] == (640, 480)
else:
assert camera.width == 640
assert camera.height == 480
assert img.shape[:2] == (480, 640)
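All of these tests lean on the context manager added to the camera base class; a minimal sketch of the assumed behavior (not the actual implementation):

class CameraContextSketch:
    """Assumed semantics of `with Camera(config) as camera:` used throughout the tests."""

    def __enter__(self):
        self.connect()  # so the yielded camera is already connected
        return self

    def __exit__(self, exc_type, exc, tb):
        if self.is_connected:
            self.disconnect()  # joins any async_read thread, even when the test body raises
        return False  # never swallow exceptions from the with-block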
+89
View File
@@ -525,3 +525,92 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
assert img.shape[0] == 3, f"Image {image_key} should have 3 channels"
assert_dataset_iteration_works(aggr_ds)
def test_aggregate_already_merged_dataset(tmp_path, lerobot_dataset_factory):
"""Regression test for aggregating a dataset that is itself a result of a previous merge.
This test reproduces the bug where merging datasets with multiple parquet files
(e.g., from a previous merge with file rotation) would cause FileNotFoundError
because metadata file indices were incorrectly preserved instead of being mapped
to their actual destination files.
The fix adds src_to_dst tracking in aggregate_data() to correctly map source
file indices to destination file indices.
"""
# Step 1: Create datasets A and B
ds_a = lerobot_dataset_factory(
root=tmp_path / "ds_a",
repo_id=f"{DUMMY_REPO_ID}_a",
total_episodes=4,
total_frames=200,
)
ds_b = lerobot_dataset_factory(
root=tmp_path / "ds_b",
repo_id=f"{DUMMY_REPO_ID}_b",
total_episodes=4,
total_frames=200,
)
# Step 2: Merge A+B into AB with small file size to force multiple files
aggregate_datasets(
repo_ids=[ds_a.repo_id, ds_b.repo_id],
roots=[ds_a.root, ds_b.root],
aggr_repo_id=f"{DUMMY_REPO_ID}_ab",
aggr_root=tmp_path / "ds_ab",
data_files_size_in_mb=0.01, # Force file rotation
)
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "ds_ab")
ds_ab = LeRobotDataset(f"{DUMMY_REPO_ID}_ab", root=tmp_path / "ds_ab")
# Verify AB has multiple data files (file rotation occurred)
ab_data_files = list((tmp_path / "ds_ab" / "data").rglob("*.parquet"))
assert len(ab_data_files) > 1, "First merge should create multiple parquet files"
# Step 3: Create dataset C
ds_c = lerobot_dataset_factory(
root=tmp_path / "ds_c",
repo_id=f"{DUMMY_REPO_ID}_c",
total_episodes=2,
total_frames=100,
)
# Step 4: Merge AB+C into final - THIS IS WHERE THE BUG OCCURRED
aggregate_datasets(
repo_ids=[ds_ab.repo_id, ds_c.repo_id],
roots=[ds_ab.root, ds_c.root],
aggr_repo_id=f"{DUMMY_REPO_ID}_abc",
aggr_root=tmp_path / "ds_abc",
)
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "ds_abc")
ds_abc = LeRobotDataset(f"{DUMMY_REPO_ID}_abc", root=tmp_path / "ds_abc")
# Step 5: Verify all data files referenced in metadata actually exist
for ep_idx in range(ds_abc.num_episodes):
data_file_path = ds_abc.root / ds_abc.meta.get_data_file_path(ep_idx)
assert data_file_path.exists(), (
f"Episode {ep_idx} references non-existent file: {data_file_path}\n"
"This indicates the src_to_dst mapping fix is not working correctly."
)
# Step 6: Verify we can iterate through the entire dataset without FileNotFoundError
expected_episodes = ds_a.num_episodes + ds_b.num_episodes + ds_c.num_episodes
expected_frames = ds_a.num_frames + ds_b.num_frames + ds_c.num_frames
assert ds_abc.num_episodes == expected_episodes
assert ds_abc.num_frames == expected_frames
# This would raise FileNotFoundError before the fix
assert_dataset_iteration_works(ds_abc)
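The src_to_dst fix the docstring references can be pictured as a plain index remap; the metadata key below is an assumption for illustration, not the exact schema:

# Source datasets AB (two parquet files) and C (one file) land in new destination
# files during the merge; episode metadata must point at the destination indices.
src_to_dst = {("ds_ab", 0): 0, ("ds_ab", 1): 0, ("ds_c", 0): 1}
episode_rows = [{"src": ("ds_ab", 1), "data/file_index": 1}]  # key name assumed
for row in episode_rows:
    row["data/file_index"] = src_to_dst[row.pop("src")]
assert episode_rows == [{"data/file_index": 0}]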
+169
View File
@@ -26,6 +26,7 @@ from lerobot.datasets.dataset_tools import (
delete_episodes,
merge_datasets,
modify_features,
modify_tasks,
remove_feature,
split_dataset,
)
@@ -1050,6 +1051,174 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path):
assert "reward" in modified_dataset.meta.features
def test_modify_tasks_single_task_for_all(sample_dataset):
"""Test setting a single task for all episodes."""
new_task = "Pick up the cube and place it"
modified_dataset = modify_tasks(sample_dataset, new_task=new_task)
# Verify all episodes have the new task
assert len(modified_dataset.meta.tasks) == 1
assert new_task in modified_dataset.meta.tasks.index
# Verify task_index is 0 for all frames (only one task)
for i in range(len(modified_dataset)):
item = modified_dataset[i]
assert item["task_index"].item() == 0
assert item["task"] == new_task
def test_modify_tasks_episode_specific(sample_dataset):
"""Test setting different tasks for specific episodes."""
episode_tasks = {
0: "Task A",
1: "Task B",
2: "Task A",
3: "Task C",
4: "Task B",
}
modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks)
# Verify correct number of unique tasks
unique_tasks = set(episode_tasks.values())
assert len(modified_dataset.meta.tasks) == len(unique_tasks)
# Verify each episode has the correct task
for ep_idx, expected_task in episode_tasks.items():
ep_data = modified_dataset.meta.episodes[ep_idx]
assert ep_data["tasks"][0] == expected_task
def test_modify_tasks_default_with_overrides(sample_dataset):
"""Test setting a default task with specific overrides."""
default_task = "Default task"
override_task = "Special task"
episode_tasks = {2: override_task, 4: override_task}
modified_dataset = modify_tasks(
sample_dataset,
new_task=default_task,
episode_tasks=episode_tasks,
)
# Verify correct number of unique tasks
assert len(modified_dataset.meta.tasks) == 2
assert default_task in modified_dataset.meta.tasks.index
assert override_task in modified_dataset.meta.tasks.index
# Verify episodes have correct tasks
for ep_idx in range(5):
ep_data = modified_dataset.meta.episodes[ep_idx]
if ep_idx in episode_tasks:
assert ep_data["tasks"][0] == override_task
else:
assert ep_data["tasks"][0] == default_task
def test_modify_tasks_no_task_specified(sample_dataset):
"""Test error when no task is specified."""
with pytest.raises(ValueError, match="Must specify at least one of new_task or episode_tasks"):
modify_tasks(sample_dataset)
def test_modify_tasks_invalid_episode_indices(sample_dataset):
"""Test error with invalid episode indices."""
with pytest.raises(ValueError, match="Invalid episode indices"):
modify_tasks(sample_dataset, episode_tasks={10: "Task", 20: "Task"})
def test_modify_tasks_updates_info_json(sample_dataset):
"""Test that total_tasks is updated in info.json."""
episode_tasks = {0: "Task A", 1: "Task B", 2: "Task C", 3: "Task A", 4: "Task B"}
modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks)
# Verify total_tasks is updated
assert modified_dataset.meta.total_tasks == 3
def test_modify_tasks_preserves_other_metadata(sample_dataset):
"""Test that modifying tasks preserves other metadata."""
original_frames = sample_dataset.meta.total_frames
original_episodes = sample_dataset.meta.total_episodes
original_fps = sample_dataset.meta.fps
modified_dataset = modify_tasks(sample_dataset, new_task="New task")
# Verify other metadata is preserved
assert modified_dataset.meta.total_frames == original_frames
assert modified_dataset.meta.total_episodes == original_episodes
assert modified_dataset.meta.fps == original_fps
def test_modify_tasks_task_index_correct(sample_dataset):
"""Test that task_index values are correct in data files."""
# Create tasks that will have predictable indices (sorted alphabetically)
episode_tasks = {
0: "Alpha task", # Will be index 0
1: "Beta task", # Will be index 1
2: "Alpha task", # Will be index 0
3: "Gamma task", # Will be index 2
4: "Beta task", # Will be index 1
}
modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks)
# Verify task indices are correct
task_to_expected_idx = {
"Alpha task": 0,
"Beta task": 1,
"Gamma task": 2,
}
for i in range(len(modified_dataset)):
item = modified_dataset[i]
ep_idx = item["episode_index"].item()
expected_task = episode_tasks[ep_idx]
expected_idx = task_to_expected_idx[expected_task]
assert item["task_index"].item() == expected_idx
assert item["task"] == expected_task
def test_modify_tasks_in_place(sample_dataset):
"""Test that modify_tasks modifies the dataset in-place."""
original_root = sample_dataset.root
modified_dataset = modify_tasks(sample_dataset, new_task="New task")
# Verify same instance is returned and root is unchanged
assert modified_dataset is sample_dataset
assert modified_dataset.root == original_root
def test_modify_tasks_keeps_original_when_not_overridden(sample_dataset):
"""Test that original tasks are kept when using episode_tasks without new_task."""
from lerobot.datasets.utils import load_episodes
# Ensure episodes metadata is loaded
if sample_dataset.meta.episodes is None:
sample_dataset.meta.episodes = load_episodes(sample_dataset.meta.root)
# Get original tasks for episodes not being overridden
original_task_ep0 = sample_dataset.meta.episodes[0]["tasks"][0]
original_task_ep1 = sample_dataset.meta.episodes[1]["tasks"][0]
# Only override episodes 2, 3, 4
episode_tasks = {2: "New Task A", 3: "New Task B", 4: "New Task A"}
modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks)
# Verify original tasks are kept for episodes 0 and 1
assert modified_dataset.meta.episodes[0]["tasks"][0] == original_task_ep0
assert modified_dataset.meta.episodes[1]["tasks"][0] == original_task_ep1
# Verify new tasks for overridden episodes
assert modified_dataset.meta.episodes[2]["tasks"][0] == "New Task A"
assert modified_dataset.meta.episodes[3]["tasks"][0] == "New Task B"
assert modified_dataset.meta.episodes[4]["tasks"][0] == "New Task A"
def test_convert_image_to_video_dataset(tmp_path):
"""Test converting lerobot/pusht_image dataset to video format."""
from lerobot.datasets.lerobot_dataset import LeRobotDataset
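Summarizing the modify_tasks surface exercised above, a hedged usage sketch (the repo_id is hypothetical):

from lerobot.datasets.dataset_tools import modify_tasks
from lerobot.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("user/my_dataset")  # hypothetical repo_id
dataset = modify_tasks(
    dataset,
    new_task="Default task",            # applied to every episode...
    episode_tasks={2: "Special task"},  # ...except the ones overridden here
)
# In-place: the same instance comes back, with meta.tasks and task_index rewritten.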
+190
View File
@@ -0,0 +1,190 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for subtask functionality in LeRobotDataset.
These tests verify that:
- Subtask information is correctly loaded from datasets that have subtask data
- The __getitem__ method correctly adds subtask strings to returned items
- Subtask handling gracefully handles missing data
"""
import pandas as pd
import pytest
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
class TestSubtaskDataset:
"""Tests for subtask handling in LeRobotDataset."""
@pytest.fixture
def subtask_dataset(self):
"""Load the test subtask dataset from the hub."""
        # Use the lerobot/pusht-subtask dataset, loading episodes 1-11
return LeRobotDataset(
repo_id="lerobot/pusht-subtask",
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
)
def test_subtask_dataset_loads(self, subtask_dataset):
"""Test that the subtask dataset loads successfully."""
assert subtask_dataset is not None
assert len(subtask_dataset) > 0
def test_subtask_metadata_loaded(self, subtask_dataset):
"""Test that subtask metadata is loaded when present in dataset."""
# The dataset should have subtasks metadata loaded
assert subtask_dataset.meta.subtasks is not None
assert isinstance(subtask_dataset.meta.subtasks, pd.DataFrame)
def test_subtask_index_in_features(self, subtask_dataset):
"""Test that subtask_index is a feature when dataset has subtasks."""
assert "subtask_index" in subtask_dataset.features
def test_getitem_returns_subtask_string(self, subtask_dataset):
"""Test that __getitem__ correctly adds subtask string to returned item."""
item = subtask_dataset[0]
# Subtask should be present in the returned item
assert "subtask" in item
assert isinstance(item["subtask"], str)
assert len(item["subtask"]) > 0 # Should not be empty
def test_getitem_has_subtask_index(self, subtask_dataset):
"""Test that __getitem__ includes subtask_index."""
item = subtask_dataset[0]
assert "subtask_index" in item
assert isinstance(item["subtask_index"], torch.Tensor)
def test_subtask_index_maps_to_valid_subtask(self, subtask_dataset):
"""Test that subtask_index correctly maps to a subtask in metadata."""
item = subtask_dataset[0]
subtask_idx = item["subtask_index"].item()
subtask_from_metadata = subtask_dataset.meta.subtasks.iloc[subtask_idx].name
assert item["subtask"] == subtask_from_metadata
def test_all_items_have_subtask(self, subtask_dataset):
"""Test that all items in the dataset have subtask information."""
for i in range(min(len(subtask_dataset), 5)): # Check first 5 items
item = subtask_dataset[i]
assert "subtask" in item
assert isinstance(item["subtask"], str)
def test_task_and_subtask_coexist(self, subtask_dataset):
"""Test that both task and subtask are present in returned items."""
item = subtask_dataset[0]
# Both task and subtask should be present
assert "task" in item
assert "subtask" in item
assert isinstance(item["task"], str)
assert isinstance(item["subtask"], str)
class TestSubtaskDatasetMissing:
"""Tests for graceful handling when subtask data is missing."""
@pytest.fixture
def dataset_without_subtasks(self, tmp_path, empty_lerobot_dataset_factory):
"""Create a dataset without subtask information."""
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "no_subtask", features=features)
# Add some frames and save
for _ in range(5):
dataset.add_frame({"state": torch.randn(2), "task": "Test task"})
dataset.save_episode()
dataset.finalize()
# Reload the dataset
return LeRobotDataset(dataset.repo_id, root=dataset.root)
def test_no_subtask_in_features(self, dataset_without_subtasks):
"""Test that subtask_index is not in features when not provided."""
assert "subtask_index" not in dataset_without_subtasks.features
def test_getitem_without_subtask(self, dataset_without_subtasks):
"""Test that __getitem__ works when subtask is not present."""
item = dataset_without_subtasks[0]
# Item should still be retrievable
assert item is not None
assert "state" in item
assert "task" in item
# Subtask should NOT be present
assert "subtask" not in item
def test_subtasks_metadata_is_none(self, dataset_without_subtasks):
"""Test that subtasks metadata is None when not present."""
assert dataset_without_subtasks.meta.subtasks is None
class TestSubtaskEdgeCases:
"""Edge case tests for subtask handling."""
def test_subtask_with_multiple_episodes(self):
"""Test subtask handling with multiple episodes if available."""
try:
dataset = LeRobotDataset(
repo_id="lerobot/pusht-subtask",
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
)
except Exception:
pytest.skip("Could not load test-subtask dataset")
# Check first and last items have valid subtasks
first_item = dataset[0]
last_item = dataset[len(dataset) - 1]
assert "subtask" in first_item
assert "subtask" in last_item
assert isinstance(first_item["subtask"], str)
assert isinstance(last_item["subtask"], str)
def test_subtask_index_consistency(self):
"""Test that same subtask_index returns same subtask string."""
try:
dataset = LeRobotDataset(
repo_id="lerobot/pusht-subtask",
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
)
except Exception:
pytest.skip("Could not load test-subtask dataset")
if len(dataset) < 2:
pytest.skip("Dataset too small for this test")
# Collect subtask_index to subtask mappings
subtask_map = {}
for i in range(min(len(dataset), 10)):
item = dataset[i]
idx = item["subtask_index"].item()
subtask = item["subtask"]
if idx in subtask_map:
# Same index should always return same subtask
assert subtask_map[idx] == subtask, (
f"Inconsistent subtask for index {idx}: '{subtask_map[idx]}' vs '{subtask}'"
)
else:
subtask_map[idx] = subtask
+2 -1
View File
@@ -441,12 +441,13 @@ def test_sac_policy_with_predefined_entropy():
def test_sac_policy_update_temperature():
"""Test that temperature property is always in sync with log_alpha."""
config = create_default_config(continuous_action_dim=10, state_dim=10)
policy = SACPolicy(config=config)
assert policy.temperature == pytest.approx(1.0)
policy.log_alpha.data = torch.tensor([math.log(0.1)])
policy.update_temperature()
# Temperature property automatically reflects log_alpha changes
assert policy.temperature == pytest.approx(0.1)
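The fix this test covers amounts to deriving temperature from log_alpha on every access; a minimal standalone sketch of that parameterization (the real SACPolicy has more to it):

import math

import torch
from torch import nn

class TemperatureSketch(nn.Module):
    def __init__(self, init_temperature: float = 1.0):
        super().__init__()
        # log_alpha is a Parameter, so it is saved and restored with checkpoints
        self.log_alpha = nn.Parameter(torch.tensor([math.log(init_temperature)]))

    @property
    def temperature(self) -> float:
        # Recomputed from log_alpha on every access: correct right after a
        # checkpoint load, with no update_temperature() call required.
        return self.log_alpha.exp().item()

sketch = TemperatureSketch()
assert abs(sketch.temperature - 1.0) < 1e-6
sketch.log_alpha.data = torch.tensor([math.log(0.1)])
assert abs(sketch.temperature - 0.1) < 1e-6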
+464 -1
View File
@@ -27,7 +27,14 @@ import torch
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey
from lerobot.processor.converters import create_transition, identity_transition
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_LANGUAGE, OBS_STATE
from lerobot.utils.constants import (
ACTION,
OBS_IMAGE,
OBS_LANGUAGE,
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK,
OBS_LANGUAGE_SUBTASK_TOKENS,
OBS_STATE,
)
from tests.utils import require_package
@@ -1038,3 +1045,459 @@ def test_simulated_accelerate_scenario():
# MockTokenizer squeezes single-item batches, so shape is (max_length,) not (1, max_length)
assert tokens.shape == (10,) # MockTokenizer behavior for single string in list
assert attention_mask.shape == (10,)
# =============================================================================
# Tests for get_subtask method
# =============================================================================
@require_package("transformers")
def test_get_subtask_missing_key():
"""Test get_subtask returns None when subtask key is missing from complementary_data."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task"}, # No "subtask" key
)
result = processor.get_subtask(transition)
assert result is None
@require_package("transformers")
def test_get_subtask_none_value():
"""Test get_subtask returns None when subtask value is None."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": None},
)
result = processor.get_subtask(transition)
assert result is None
@require_package("transformers")
def test_get_subtask_none_complementary_data():
"""Test get_subtask returns None when complementary_data is None."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data=None, # No complementary data
)
result = processor.get_subtask(transition)
assert result is None
@require_package("transformers")
def test_get_subtask_string():
"""Test get_subtask returns list with single string when subtask is a string."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": "pick up the cube"},
)
result = processor.get_subtask(transition)
assert result == ["pick up the cube"]
assert isinstance(result, list)
assert len(result) == 1
@require_package("transformers")
def test_get_subtask_list_of_strings():
"""Test get_subtask returns the list when subtask is already a list of strings."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
subtask_list = ["pick up", "move to target", "place down"]
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": subtask_list},
)
result = processor.get_subtask(transition)
assert result == subtask_list
assert isinstance(result, list)
assert len(result) == 3
@require_package("transformers")
def test_get_subtask_unsupported_type_integer():
"""Test get_subtask returns None when subtask is an unsupported type (integer)."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": 123},
)
result = processor.get_subtask(transition)
assert result is None
@require_package("transformers")
def test_get_subtask_unsupported_type_mixed_list():
"""Test get_subtask returns None when subtask is a list with mixed types."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": ["valid string", 123, "another string"]},
)
result = processor.get_subtask(transition)
assert result is None
@require_package("transformers")
def test_get_subtask_unsupported_type_dict():
"""Test get_subtask returns None when subtask is a dictionary."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": {"key": "value"}},
)
result = processor.get_subtask(transition)
assert result is None
@require_package("transformers")
def test_get_subtask_empty_string():
"""Test get_subtask with empty string returns list with empty string."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": ""},
)
result = processor.get_subtask(transition)
assert result == [""]
@require_package("transformers")
def test_get_subtask_empty_list():
"""Test get_subtask with empty list returns empty list."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": []},
)
result = processor.get_subtask(transition)
assert result == []
# =============================================================================
# Tests for subtask tokenization in observation method
# =============================================================================
@require_package("transformers")
def test_subtask_tokenization_when_present():
"""Test that subtask is tokenized and added to observation when present."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": "pick up the red cube"},
)
result = processor(transition)
# Check that subtask tokens were added to observation
observation = result[TransitionKey.OBSERVATION]
assert OBS_LANGUAGE_SUBTASK_TOKENS in observation
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation
# Check token structure
subtask_tokens = observation[OBS_LANGUAGE_SUBTASK_TOKENS]
subtask_attention_mask = observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
assert isinstance(subtask_tokens, torch.Tensor)
assert isinstance(subtask_attention_mask, torch.Tensor)
assert subtask_tokens.shape == (8,)
assert subtask_attention_mask.shape == (8,)
assert subtask_attention_mask.dtype == torch.bool
@require_package("transformers")
def test_subtask_tokenization_not_added_when_none():
"""Test that subtask tokens are NOT added to observation when subtask is None."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task"}, # No subtask
)
result = processor(transition)
# Check that subtask tokens were NOT added to observation
observation = result[TransitionKey.OBSERVATION]
assert OBS_LANGUAGE_SUBTASK_TOKENS not in observation
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in observation
# But main task tokens should still be present
assert f"{OBS_LANGUAGE}.tokens" in observation
assert f"{OBS_LANGUAGE}.attention_mask" in observation
@require_package("transformers")
def test_subtask_tokenization_not_added_when_subtask_value_is_none():
"""Test that subtask tokens are NOT added when subtask value is explicitly None."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": None},
)
result = processor(transition)
# Check that subtask tokens were NOT added to observation
observation = result[TransitionKey.OBSERVATION]
assert OBS_LANGUAGE_SUBTASK_TOKENS not in observation
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in observation
@require_package("transformers")
def test_subtask_tokenization_list_of_strings():
"""Test subtask tokenization with list of strings."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": ["pick up", "place down"]},
)
result = processor(transition)
# Check that subtask tokens were added to observation
observation = result[TransitionKey.OBSERVATION]
assert OBS_LANGUAGE_SUBTASK_TOKENS in observation
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation
# Check token structure for batch
subtask_tokens = observation[OBS_LANGUAGE_SUBTASK_TOKENS]
subtask_attention_mask = observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
assert subtask_tokens.shape == (2, 8) # batch_size=2, seq_len=8
assert subtask_attention_mask.shape == (2, 8)
@require_package("transformers")
def test_subtask_tokenization_device_cpu():
"""Test that subtask tokens are on CPU when other tensors are on CPU."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
# Create transition with CPU tensors
observation = {OBS_STATE: torch.randn(10)} # CPU tensor
action = torch.randn(5) # CPU tensor
transition = create_transition(
observation=observation,
action=action,
complementary_data={"task": "main task", "subtask": "pick up cube"},
)
result = processor(transition)
# Check that subtask tokens are on CPU
subtask_tokens = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS]
subtask_attention_mask = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
assert subtask_tokens.device.type == "cpu"
assert subtask_attention_mask.device.type == "cpu"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@require_package("transformers")
def test_subtask_tokenization_device_cuda():
"""Test that subtask tokens are moved to CUDA when other tensors are on CUDA."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
# Create transition with CUDA tensors
observation = {OBS_STATE: torch.randn(10).cuda()} # CUDA tensor
action = torch.randn(5).cuda() # CUDA tensor
transition = create_transition(
observation=observation,
action=action,
complementary_data={"task": "main task", "subtask": "pick up cube"},
)
result = processor(transition)
# Check that subtask tokens are on CUDA
subtask_tokens = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS]
subtask_attention_mask = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
assert subtask_tokens.device.type == "cuda"
assert subtask_attention_mask.device.type == "cuda"
@require_package("transformers")
def test_subtask_tokenization_preserves_other_observation_data():
"""Test that subtask tokenization preserves other observation data."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
original_state = torch.tensor([1.0, 2.0, 3.0])
transition = create_transition(
observation={"state": original_state.clone()},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": "pick up cube"},
)
result = processor(transition)
observation = result[TransitionKey.OBSERVATION]
# Check that original observation data is preserved
assert torch.equal(observation["state"], original_state)
# Check that both task and subtask tokens are present
assert f"{OBS_LANGUAGE}.tokens" in observation
assert f"{OBS_LANGUAGE}.attention_mask" in observation
assert OBS_LANGUAGE_SUBTASK_TOKENS in observation
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation
@require_package("transformers")
def test_subtask_attention_mask_dtype():
"""Test that subtask attention mask has correct dtype (bool)."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": "pick up cube"},
)
result = processor(transition)
observation = result[TransitionKey.OBSERVATION]
subtask_attention_mask = observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
assert subtask_attention_mask.dtype == torch.bool
@require_package("transformers")
def test_subtask_tokenization_deterministic():
"""Test that subtask tokenization is deterministic for the same input."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": "consistent subtask"},
)
result1 = processor(transition)
result2 = processor(transition)
subtask_tokens1 = result1[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS]
subtask_tokens2 = result2[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS]
subtask_mask1 = result1[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
subtask_mask2 = result2[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
# Results should be identical
assert torch.equal(subtask_tokens1, subtask_tokens2)
assert torch.equal(subtask_mask1, subtask_mask2)
@require_package("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_subtask_tokenization_integration_with_pipeline(mock_auto_tokenizer):
"""Test subtask tokenization works correctly with DataProcessorPipeline."""
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
tokenizer_processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=6)
robot_processor = DataProcessorPipeline(
[tokenizer_processor], to_transition=identity_transition, to_output=identity_transition
)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": "subtask instruction"},
)
result = robot_processor(transition)
# Check that observation exists and both tokenizations were applied
assert TransitionKey.OBSERVATION in result
observation = result[TransitionKey.OBSERVATION]
# Check task tokens
assert f"{OBS_LANGUAGE}.tokens" in observation
assert f"{OBS_LANGUAGE}.attention_mask" in observation
# Check subtask tokens
assert OBS_LANGUAGE_SUBTASK_TOKENS in observation
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation
# Check shapes
assert observation[f"{OBS_LANGUAGE}.tokens"].shape == (6,)
assert observation[OBS_LANGUAGE_SUBTASK_TOKENS].shape == (6,)
@require_package("transformers")
def test_subtask_not_added_for_unsupported_types():
"""Test that subtask tokens are not added when subtask has unsupported type."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8)
# Test with integer subtask
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": 123},
)
result = processor(transition)
observation = result[TransitionKey.OBSERVATION]
# Subtask tokens should NOT be added for unsupported types
assert OBS_LANGUAGE_SUBTASK_TOKENS not in observation
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in observation
# But main task tokens should still be present
assert f"{OBS_LANGUAGE}.tokens" in observation