mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 16:27:04 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f147a4cd48 | |||
| c3fa269b21 | |||
| 385ba8d1b7 | |||
| f4ccf911fa | |||
| 0cb8c92fe4 |
@@ -7,6 +7,8 @@
|
||||
- sections:
|
||||
- local: il_robots
|
||||
title: Imitation Learning for Robots
|
||||
- local: cameras
|
||||
title: Cameras
|
||||
- local: bring_your_own_policies
|
||||
title: Bring Your Own Policies
|
||||
- local: integrate_hardware
|
||||
@@ -27,10 +29,6 @@
|
||||
title: Porting Large Datasets
|
||||
- local: using_dataset_tools
|
||||
title: Using the Dataset Tools
|
||||
- local: annotation_tools
|
||||
title: Using the Annotation Tools
|
||||
- local: dataset_subtask
|
||||
title: Using Subtasks in the Dataset
|
||||
title: "Datasets"
|
||||
- sections:
|
||||
- local: act
|
||||
@@ -59,6 +57,8 @@
|
||||
title: Use Async Inference
|
||||
- local: rtc
|
||||
title: Real-Time Chunking (RTC)
|
||||
- local: training_time_rtc
|
||||
title: Training-Time RTC
|
||||
title: "Inference"
|
||||
- sections:
|
||||
- local: envhub
|
||||
@@ -103,17 +103,11 @@
|
||||
title: Earth Rover Mini
|
||||
- local: omx
|
||||
title: OMX
|
||||
- local: openarm
|
||||
title: OpenArm
|
||||
title: "Robots"
|
||||
- sections:
|
||||
- local: phone_teleop
|
||||
title: Phone
|
||||
title: "Teleoperators"
|
||||
- sections:
|
||||
- local: cameras
|
||||
title: Cameras
|
||||
title: "Sensors"
|
||||
- sections:
|
||||
- local: torch_accelerators
|
||||
title: PyTorch accelerators
|
||||
|
||||
@@ -1,425 +0,0 @@
|
||||
# Dataset Annotation Tools
|
||||
|
||||
This guide explains how to use the automatic annotation tools to add skill labels and synthetic dialogue to your LeRobot datasets.
|
||||
|
||||
## Overview
|
||||
|
||||
The annotation pipeline consists of two main components:
|
||||
|
||||
1. **Subtask Annotation** (`subtask_annotate.py`): Automatically segments robot demonstrations into atomic skills using Vision-Language Models (VLMs)
|
||||
2. **High-Level Annotation** (`high_level_annotate.py`): Generates synthetic user prompts and robot utterances for hierarchical policy training
|
||||
|
||||
These tools enable you to transform raw robot demonstration data into richly annotated datasets suitable for training hierarchical policies.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
Before using the annotation tools, ensure you have the required dependencies:
|
||||
|
||||
```bash
|
||||
pip install transformers qwen-vl-utils opencv-python rich pandas pyarrow
|
||||
```
|
||||
|
||||
You'll also need FFmpeg for video processing:
|
||||
|
||||
```bash
|
||||
# Ubuntu/Debian
|
||||
sudo apt-get install ffmpeg
|
||||
|
||||
# macOS
|
||||
brew install ffmpeg
|
||||
```
|
||||
|
||||
## Part 1: Subtask Annotation
|
||||
|
||||
### What It Does
|
||||
|
||||
The subtask annotator segments each episode into short atomic manipulation skills (1-3 seconds each). For example, a "pick and place" episode might be segmented into:
|
||||
- "reach towards object" (0.0s - 1.2s)
|
||||
- "grasp object" (1.2s - 2.1s)
|
||||
- "lift object" (2.1s - 3.5s)
|
||||
- "move to target" (3.5s - 5.0s)
|
||||
- "release object" (5.0s - 6.2s)
|
||||
|
||||
### Usage
|
||||
|
||||
#### Basic Example
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/pi05_full/annotate/subtask_annotate.py \
|
||||
--repo-id your-username/your-dataset \
|
||||
--video-key observation.images.base \
|
||||
--output-dir /path/to/output
|
||||
```
|
||||
|
||||
#### With Local Dataset
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/pi05_full/annotate/subtask_annotate.py \
|
||||
--data-dir /path/to/local/dataset \
|
||||
--video-key observation.images.base \
|
||||
--output-dir /path/to/output
|
||||
```
|
||||
|
||||
#### Advanced Options
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/pi05_full/annotate/subtask_annotate.py \
|
||||
--repo-id your-username/your-dataset \
|
||||
--video-key observation.images.base \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--batch-size 16 \
|
||||
--output-dir /path/to/output \
|
||||
--push-to-hub
|
||||
```
|
||||
|
||||
### Parameters
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--repo-id` | HuggingFace Hub dataset ID | Required (or use --data-dir) |
|
||||
| `--data-dir` | Path to local dataset | Required (or use --repo-id) |
|
||||
| `--video-key` | Video observation key | Required |
|
||||
| `--model` | VLM model to use | `Qwen/Qwen2-VL-7B-Instruct` |
|
||||
| `--device` | Device to run model on | `cuda` |
|
||||
| `--dtype` | Model dtype | `bfloat16` |
|
||||
| `--batch-size` | Episodes per batch | `8` |
|
||||
| `--episodes` | Specific episodes to annotate | All episodes |
|
||||
| `--output-dir` | Output directory | Auto-generated |
|
||||
| `--push-to-hub` | Push to HuggingFace Hub | `False` |
|
||||
|
||||
### Supported Models
|
||||
|
||||
- **Qwen2-VL**: `Qwen/Qwen2-VL-2B-Instruct`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`
|
||||
- **Qwen3-VL**: `Qwen/Qwen3-VL-30B-A3B-Instruct`
|
||||
|
||||
### Output Files
|
||||
|
||||
The subtask annotation creates the following files in your dataset:
|
||||
|
||||
1. **`meta/subtasks.parquet`**: DataFrame with unique subtask names
|
||||
```python
|
||||
# Structure:
|
||||
# Index: subtask name (string)
|
||||
# Column: subtask_index (int64)
|
||||
```
|
||||
|
||||
2. **`meta/skills.json`**: Raw skill annotations with timestamps
|
||||
```json
|
||||
{
|
||||
"coarse_description": "Pick and place the object",
|
||||
"skill_to_subtask_index": {
|
||||
"reach towards object": 0,
|
||||
"grasp object": 1,
|
||||
...
|
||||
},
|
||||
"episodes": {
|
||||
"0": {
|
||||
"episode_index": 0,
|
||||
"description": "Pick and place the object",
|
||||
"skills": [
|
||||
{"name": "reach towards object", "start": 0.0, "end": 1.2},
|
||||
{"name": "grasp object", "start": 1.2, "end": 2.1},
|
||||
...
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
3. **`subtask_index` feature**: Added to each frame in the dataset
|
||||
- Type: `int64`
|
||||
- Shape: `(1,)`
|
||||
- Maps each frame to its corresponding subtask
|
||||
|
||||
### Accessing Subtask Annotations
|
||||
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Load annotated dataset
|
||||
dataset = LeRobotDataset(repo_id="your/dataset_with_subtasks")
|
||||
|
||||
# Get a frame
|
||||
frame = dataset[100]
|
||||
|
||||
# Get the subtask for this frame
|
||||
subtask_idx = frame["subtask_index"].item()
|
||||
subtask_name = dataset.meta.subtasks.iloc[subtask_idx].name
|
||||
|
||||
print(f"Frame 100 is performing: {subtask_name}")
|
||||
|
||||
# Load all subtasks
|
||||
subtasks_df = dataset.meta.subtasks
|
||||
print(subtasks_df)
|
||||
```
|
||||
|
||||
## Part 2: High-Level Annotation
|
||||
|
||||
### What It Does
|
||||
|
||||
The high-level annotator generates synthetic dialogue for hierarchical policy training. For each skill, it creates:
|
||||
- **User Prompt** (`ℓ_t`): A natural language request from the user
|
||||
- **Robot Utterance** (`u_t`): A natural language response from the robot
|
||||
|
||||
This enables training policies that can understand and respond to human instructions in natural dialogue.
|
||||
|
||||
### Prerequisites
|
||||
|
||||
**Important**: You must run subtask annotation first! High-level annotation requires the `skills.json` file generated by subtask annotation.
|
||||
|
||||
### Usage
|
||||
|
||||
#### Image Mode (Default)
|
||||
|
||||
Samples frames at regular intervals and passes images to the VLM:
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/pi05_full/annotate/high_level_annotate.py \
|
||||
--repo-id your/dataset_with_subtasks \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--image-key observation.images.base \
|
||||
--output-dir /path/to/output
|
||||
```
|
||||
|
||||
#### Video Mode
|
||||
|
||||
Passes entire episode videos to the VLM for better temporal understanding:
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/pi05_full/annotate/high_level_annotate.py \
|
||||
--repo-id your/dataset_with_subtasks \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--video-mode \
|
||||
--video-key observation.images.base \
|
||||
--video-batch-size 4 \
|
||||
--output-dir /path/to/output
|
||||
```
|
||||
|
||||
### Parameters
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--repo-id` | HuggingFace Hub dataset ID | Required (or use --data-dir) |
|
||||
| `--data-dir` | Path to local dataset | Required (or use --repo-id) |
|
||||
| `--model` | VLM model to use | `Qwen/Qwen2-VL-7B-Instruct` |
|
||||
| `--image-key` | Image observation key (image mode) | First camera key |
|
||||
| `--video-mode` | Use video instead of images | `False` |
|
||||
| `--video-key` | Video observation key (video mode) | Auto-detected |
|
||||
| `--video-batch-size` | Episodes per batch (video mode) | `1` |
|
||||
| `--sample-interval` | Sampling interval in seconds | `1.0` |
|
||||
| `--temperature` | Sampling temperature | `0.7` |
|
||||
| `--output-dir` | Output directory | Auto-generated |
|
||||
| `--push-to-hub` | Push to HuggingFace Hub | `False` |
|
||||
|
||||
### Output Files
|
||||
|
||||
The high-level annotation creates:
|
||||
|
||||
1. **`meta/tasks_high_level.parquet`**: DataFrame with high-level tasks
|
||||
```python
|
||||
# Structure:
|
||||
# Index: task string (concatenated user_prompt | robot_utterance)
|
||||
# Columns:
|
||||
# - task_index: int64
|
||||
# - user_prompt: string
|
||||
# - robot_utterance: string
|
||||
# - skill: string (associated subtask)
|
||||
# - scenario_type: string
|
||||
# - response_type: string
|
||||
```
|
||||
|
||||
2. **`meta/syn_annotations.jsonl`**: Debug annotations (JSONL format)
|
||||
```json
|
||||
{"episode_id": 0, "timestamp": 1.5, "skill_current": "grasp object", "user_prompt": "Can you pick that up?", "robot_utterance": "Sure, I'll grasp it now", ...}
|
||||
```
|
||||
|
||||
3. **`task_index_high_level` feature**: Added to each frame
|
||||
- Type: `int64`
|
||||
- Shape: `(1,)`
|
||||
- Maps each frame to its high-level task
|
||||
|
||||
### Dialogue Types Generated
|
||||
|
||||
The system generates diverse interaction types:
|
||||
|
||||
**Scenario Types:**
|
||||
- `specific_object`: "Pick up the red block"
|
||||
- `negative_task`: "Don't touch the blue one"
|
||||
- `situated_correction`: "Actually, move to the other box instead"
|
||||
- `implicit_request`: "I need something red for the tower"
|
||||
- `constraint_based`: "Make sure to handle it gently"
|
||||
|
||||
**Response Types:**
|
||||
- `confirmation`: "OK, I'll pick it up"
|
||||
- `clarification`: "Just to confirm, you want me to pick up the red block?"
|
||||
- `acknowledgment`: "Got it, picking up the red block"
|
||||
- `constraint_acknowledgment`: "Sure, I'll pick it up gently"
|
||||
|
||||
### Accessing High-Level Annotations
|
||||
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
import pandas as pd
|
||||
|
||||
# Load annotated dataset
|
||||
dataset = LeRobotDataset(repo_id="your/dataset_with_high_level_tasks")
|
||||
|
||||
# Get a frame
|
||||
frame = dataset[100]
|
||||
|
||||
# Get the high-level task
|
||||
task_idx = frame["task_index_high_level"].item()
|
||||
|
||||
# Load tasks metadata
|
||||
tasks_df = pd.read_parquet(dataset.root / "meta" / "tasks_high_level.parquet")
|
||||
task_row = tasks_df[tasks_df["task_index"] == task_idx].iloc[0]
|
||||
|
||||
print(f"User: {task_row['user_prompt']}")
|
||||
print(f"Robot: {task_row['robot_utterance']}")
|
||||
print(f"Skill: {task_row['skill']}")
|
||||
|
||||
# Use in a DataLoader
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
|
||||
batch = next(iter(dataloader))
|
||||
|
||||
print(f"Task indices: {batch['task_index_high_level']}")
|
||||
print(f"User prompts: {batch['user_prompt'][0]}")
|
||||
print(f"Robot utterances: {batch['robot_utterance'][0]}")
|
||||
```
|
||||
|
||||
## Complete Pipeline Example
|
||||
|
||||
Here's how to run both annotation stages:
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
|
||||
REPO_ID="your-username/your-dataset"
|
||||
MODEL="Qwen/Qwen2-VL-7B-Instruct"
|
||||
OUTPUT_DIR="/path/to/output"
|
||||
|
||||
# Step 1: Subtask Annotation
|
||||
python src/lerobot/policies/pi05_full/annotate/subtask_annotate.py \
|
||||
--repo-id "$REPO_ID" \
|
||||
--video-key observation.images.base \
|
||||
--model "$MODEL" \
|
||||
--batch-size 8 \
|
||||
--output-dir "${OUTPUT_DIR}/subtasks"
|
||||
|
||||
# Step 2: High-Level Annotation (Image Mode)
|
||||
python src/lerobot/policies/pi05_full/annotate/high_level_annotate.py \
|
||||
--data-dir "${OUTPUT_DIR}/subtasks" \
|
||||
--model "$MODEL" \
|
||||
--image-key observation.images.base \
|
||||
--sample-interval 1.0 \
|
||||
--output-dir "${OUTPUT_DIR}/final"
|
||||
|
||||
# Or Step 2: High-Level Annotation (Video Mode - Recommended)
|
||||
python src/lerobot/policies/pi05_full/annotate/high_level_annotate.py \
|
||||
--data-dir "${OUTPUT_DIR}/subtasks" \
|
||||
--model "$MODEL" \
|
||||
--video-mode \
|
||||
--video-key observation.images.base \
|
||||
--video-batch-size 4 \
|
||||
--output-dir "${OUTPUT_DIR}/final"
|
||||
```
|
||||
|
||||
## Performance Tips
|
||||
|
||||
### For Faster Processing
|
||||
|
||||
1. **Increase batch size**: Use `--batch-size 16` or higher (subtask annotation)
|
||||
2. **Increase video batch size**: Use `--video-batch-size 8` (high-level annotation in video mode)
|
||||
3. **Larger sampling interval**: Use `--sample-interval 5.0` for testing (samples every 5 seconds instead of 1)
|
||||
4. **Use smaller models**: `Qwen/Qwen2-VL-2B-Instruct` is faster than `Qwen2-VL-7B-Instruct`
|
||||
5. **Process specific episodes**: Use `--episodes 0 1 2 3` to annotate only a subset
|
||||
|
||||
### For Better Quality
|
||||
|
||||
1. **Use larger models**: `Qwen/Qwen3-VL-30B-A3B-Instruct` or `Qwen/Qwen2-VL-72B-Instruct`
|
||||
2. **Use video mode**: Provides better temporal context
|
||||
3. **Smaller sampling intervals**: `--sample-interval 0.5` for dense annotations
|
||||
4. **Adjust temperature**: Use `--temperature 0.9` for more diverse dialogue
|
||||
|
||||
## Memory Requirements
|
||||
|
||||
| Model | GPU Memory | Recommended Batch Size |
|
||||
|-------|------------|------------------------|
|
||||
| Qwen2-VL-2B | ~8 GB | 16-32 |
|
||||
| Qwen2-VL-7B | ~16 GB | 8-16 |
|
||||
| Qwen2-VL-72B | ~80 GB | 1-2 |
|
||||
| Qwen3-VL-30B | ~40 GB | 4-8 |
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "FFmpeg not found"
|
||||
```bash
|
||||
# Install FFmpeg
|
||||
sudo apt-get install ffmpeg # Ubuntu/Debian
|
||||
brew install ffmpeg # macOS
|
||||
```
|
||||
|
||||
### "CUDA out of memory"
|
||||
- Reduce batch size: `--batch-size 1` or `--video-batch-size 1`
|
||||
- Use smaller model: `Qwen/Qwen2-VL-2B-Instruct`
|
||||
- Use CPU: `--device cpu` (much slower)
|
||||
|
||||
### "No skills.json found"
|
||||
Run subtask annotation first before high-level annotation.
|
||||
|
||||
### "Video key not found"
|
||||
List available keys:
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
dataset = LeRobotDataset(repo_id="your/dataset")
|
||||
print("Video keys:", dataset.meta.video_keys)
|
||||
print("Camera keys:", dataset.meta.camera_keys)
|
||||
```
|
||||
|
||||
## Dataset Structure After Annotation
|
||||
|
||||
```
|
||||
your_dataset_with_high_level_tasks/
|
||||
├── meta/
|
||||
│ ├── info.json # Original metadata
|
||||
│ ├── tasks.parquet # Original tasks (preserved)
|
||||
│ ├── subtasks.parquet # NEW: Subtask names and indices
|
||||
│ ├── skills.json # NEW: Raw skill annotations with timestamps
|
||||
│ ├── tasks_high_level.parquet # NEW: High-level tasks with dialogue
|
||||
│ └── syn_annotations.jsonl # NEW: Debug annotations
|
||||
├── data/
|
||||
│ └── chunk-000/
|
||||
│ ├── observation.images.base.mp4
|
||||
│ ├── action.safetensors
|
||||
│ ├── subtask_index.safetensors # NEW: Subtask per frame
|
||||
│ └── task_index_high_level.safetensors # NEW: High-level task per frame
|
||||
└── videos/
|
||||
└── ...
|
||||
```
|
||||
|
||||
## Citation
|
||||
|
||||
If you use these annotation tools in your research, please cite:
|
||||
|
||||
```bibtex
|
||||
@article{lerobot2024,
|
||||
title={LeRobot: State-of-the-art Machine Learning for Real-World Robotics},
|
||||
author={LeRobot Contributors},
|
||||
year={2024},
|
||||
url={https://github.com/huggingface/lerobot}
|
||||
}
|
||||
```
|
||||
|
||||
## Next Steps
|
||||
|
||||
After annotation, you can:
|
||||
1. Train hierarchical policies using the subtask and high-level annotations
|
||||
2. Use the synthetic dialogue for instruction-following policy training
|
||||
3. Analyze skill distributions and dialogue patterns
|
||||
4. Share your annotated dataset on HuggingFace Hub with `--push-to-hub`
|
||||
|
||||
For training examples, see the [training documentation](../training/).
|
||||
|
||||
+81
-95
@@ -1,22 +1,12 @@
|
||||
# Cameras
|
||||
|
||||
LeRobot offers multiple options for video capture:
|
||||
LeRobot offers multiple options for video capture, including phone cameras, built-in laptop cameras, external webcams, and Intel RealSense cameras. To efficiently record frames from most cameras, you can use either the `OpenCVCamera` or `RealSenseCamera` class. For additional compatibility details on the `OpenCVCamera` class, refer to the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).
|
||||
|
||||
| Class | Supported Cameras |
|
||||
| ----------------- | ----------------------------------- |
|
||||
| `OpenCVCamera` | Phone, built-in laptop, USB webcams |
|
||||
| `ZMQCamera` | Network-connected cameras |
|
||||
| `RealSenseCamera` | Intel RealSense (with depth) |
|
||||
| `Reachy2Camera` | Reachy 2 robot cameras |
|
||||
### Finding your camera
|
||||
|
||||
> [!TIP]
|
||||
> For `OpenCVCamera` compatibility details, see the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).
|
||||
To instantiate a camera, you need a camera identifier. This identifier might change if you reboot your computer or re-plug your camera, a behavior mostly dependant on your operating system.
|
||||
|
||||
### Find your camera
|
||||
|
||||
Every camera requires a unique identifier to be instantiated, allowing you to distinguish between multiple connected devices.
|
||||
|
||||
`OpenCVCamera` and `RealSenseCamera` support auto-discovery. Run the command below to list available devices and their identifiers. Note that these identifiers may change after rebooting your computer or re-plugging the camera, depending on your operating system.
|
||||
To find the camera indices of the cameras plugged into your system, run the following script:
|
||||
|
||||
```bash
|
||||
lerobot-find-cameras opencv # or realsense for Intel Realsense cameras
|
||||
@@ -24,7 +14,7 @@ lerobot-find-cameras opencv # or realsense for Intel Realsense cameras
|
||||
|
||||
The output will look something like this if you have two cameras connected:
|
||||
|
||||
```bash
|
||||
```
|
||||
--- Detected Cameras ---
|
||||
Camera #0:
|
||||
Name: OpenCV Camera @ 0
|
||||
@@ -43,37 +33,13 @@ Camera #0:
|
||||
> [!WARNING]
|
||||
> When using Intel RealSense cameras in `macOS`, you could get this [error](https://github.com/IntelRealSense/librealsense/issues/12307): `Error finding RealSense cameras: failed to set power state`, this can be solved by running the same command with `sudo` permissions. Note that using RealSense cameras in `macOS` is unstable.
|
||||
|
||||
`ZMQCamera` and `Reachy2Camera` do not support auto-discovery. They must be configured manually by providing their network address and port or robot SDK settings.
|
||||
## Use Cameras
|
||||
|
||||
## Use cameras
|
||||
Below are two examples, demonstrating how to work with the API.
|
||||
|
||||
### Frame access modes
|
||||
|
||||
All camera classes implement three access modes for capturing frames:
|
||||
|
||||
| Method | Behavior | Blocks? | Best For |
|
||||
| ------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------- | ---------------------------------------- |
|
||||
| `read()` | Waits for the camera hardware to return a frame. May block for a long time depending on the camera and SDK. | Yes | Simple scripts, sequential capture |
|
||||
| `async_read(timeout_ms)` | Returns the latest unconsumed frame from background thread. Blocks only if buffer is empty, up to `timeout_ms`. Raises `TimeoutError` if no frame arrives. | With a timeout | Control loops synchronized to camera FPS |
|
||||
| `read_latest(max_age_ms)` | Peeks at the most recent frame in buffer (may be stale). Raises `TimeoutError` if frame is older than `max_age_ms`. | No | UI visualization, logging, monitoring |
|
||||
|
||||
### Usage examples
|
||||
|
||||
The following examples show how to use the camera API to configure and capture frames from different camera types.
|
||||
|
||||
- **Blocking and non-blocking frame capture** using an OpenCV-based camera
|
||||
- **Asynchronous frame capture** using an OpenCV-based camera
|
||||
- **Color and depth capture** using an Intel RealSense camera
|
||||
|
||||
> [!WARNING]
|
||||
> Failing to cleanly disconnect cameras can cause resource leaks. Use the context manager protocol to ensure automatic cleanup:
|
||||
>
|
||||
> ```python
|
||||
> with OpenCVCamera(config) as camera:
|
||||
> ...
|
||||
> ```
|
||||
>
|
||||
> You can also call `connect()` and `disconnect()` manually, but always use a `finally` block for the latter.
|
||||
|
||||
<hfoptions id="shell_restart">
|
||||
<hfoption id="Open CV Camera">
|
||||
|
||||
@@ -94,30 +60,16 @@ config = OpenCVCameraConfig(
|
||||
)
|
||||
|
||||
# Instantiate and connect an `OpenCVCamera`, performing a warm-up read (default).
|
||||
with OpenCVCamera(config) as camera:
|
||||
|
||||
# Read a frame synchronously — blocks until hardware delivers a new frame
|
||||
frame = camera.read()
|
||||
print(f"read() call returned frame with shape:", frame.shape)
|
||||
|
||||
# Read a frame asynchronously with a timeout — returns the latest unconsumed frame or waits up to timeout_ms for a new one
|
||||
try:
|
||||
for i in range(10):
|
||||
frame = camera.async_read(timeout_ms=200)
|
||||
print(f"async_read call returned frame {i} with shape:", frame.shape)
|
||||
except TimeoutError as e:
|
||||
print(f"No frame received within timeout: {e}")
|
||||
|
||||
# Instantly return a frame - returns the most recent frame captured by the camera
|
||||
try:
|
||||
initial_frame = camera.read_latest(max_age_ms=1000)
|
||||
for i in range(10):
|
||||
frame = camera.read_latest(max_age_ms=1000)
|
||||
print(f"read_latest call returned frame {i} with shape:", frame.shape)
|
||||
print(f"Was a new frame received by the camera? {not (initial_frame == frame).any()}")
|
||||
except TimeoutError as e:
|
||||
print(f"Frame too old: {e}")
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect()
|
||||
|
||||
# Read frames asynchronously in a loop via `async_read(timeout_ms)`
|
||||
try:
|
||||
for i in range(10):
|
||||
frame = camera.async_read(timeout_ms=200)
|
||||
print(f"Async frame {i} shape:", frame.shape)
|
||||
finally:
|
||||
camera.disconnect()
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
@@ -159,10 +111,10 @@ finally:
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Use your phone's camera
|
||||
## Use your phone
|
||||
|
||||
<hfoptions id="use phone">
|
||||
<hfoption id="iPhone & macOS">
|
||||
<hfoption id="Mac">
|
||||
|
||||
To use your iPhone as a camera on macOS, enable the Continuity Camera feature:
|
||||
|
||||
@@ -172,49 +124,83 @@ To use your iPhone as a camera on macOS, enable the Continuity Camera feature:
|
||||
|
||||
For more details, visit [Apple support](https://support.apple.com/en-gb/guide/mac-help/mchl77879b8a/mac).
|
||||
|
||||
Your iPhone should be detected automatically when running the camera setup script in the next section.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="OBS virtual camera">
|
||||
<hfoption id="Linux">
|
||||
|
||||
If you want to use your phone as a camera using OBS, follow these steps to set up a virtual camera.
|
||||
If you want to use your phone as a camera on Linux, follow these steps to set up a virtual camera
|
||||
|
||||
1. _(Linux only) Install `v4l2loopback-dkms` and `v4l-utils`_. These packages create virtual camera devices and verify their settings. Install with:
|
||||
1. _Install `v4l2loopback-dkms` and `v4l-utils`_. Those packages are required to create virtual camera devices (`v4l2loopback`) and verify their settings with the `v4l2-ctl` utility from `v4l-utils`. Install them using:
|
||||
|
||||
```bash
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
sudo apt install v4l2loopback-dkms v4l-utils
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
2. _Install the [DroidCam app](https://droidcam.app) on your phone_. This app is available for both iOS and Android.
|
||||
3. _Download and install [OBS Studio](https://obsproject.com)_.
|
||||
4. _Download and install the [DroidCam OBS plugin](https://droidcam.app/obs)_.
|
||||
5. _Start OBS Studio_.
|
||||
2. _Install [DroidCam](https://droidcam.app) on your phone_. This app is available for both iOS and Android.
|
||||
3. _Install [OBS Studio](https://obsproject.com)_. This software will help you manage the camera feed. Install it using [Flatpak](https://flatpak.org):
|
||||
|
||||
6. _Add your phone as a source_. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480` to avoid the watermarks.
|
||||
7. _Adjust resolution settings_. In OBS Studio, go to `File > Settings > Video` or `OBS > Preferences... > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it.
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
flatpak install flathub com.obsproject.Studio
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
4. _Install the DroidCam OBS plugin_. This plugin integrates DroidCam with OBS Studio. Install it with:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
flatpak install flathub com.obsproject.Studio.Plugin.DroidCam
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
5. _Start OBS Studio_. Launch with:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
flatpak run com.obsproject.Studio
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
6. _Add your phone as a source_. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480`.
|
||||
7. _Adjust resolution settings_. In OBS Studio, go to `File > Settings > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it in.
|
||||
8. _Start virtual camera_. In OBS Studio, follow the instructions [here](https://obsproject.com/kb/virtual-camera-guide).
|
||||
9. _Verify the virtual camera setup and resolution_.
|
||||
- **Linux**: Use `v4l2-ctl` to list devices and check resolution:
|
||||
```bash
|
||||
v4l2-ctl --list-devices # find VirtualCam and note its /dev/videoX path
|
||||
v4l2-ctl -d /dev/videoX --get-fmt-video # replace with your VirtualCam path
|
||||
```
|
||||
You should see `VirtualCam` listed and resolution `640x480`.
|
||||
- **macOS**: Open Photo Booth or FaceTime and select "OBS Virtual Camera" as the input.
|
||||
- **Windows**: The native Camera app doesn't support virtual cameras. Use a video conferencing app (Zoom, Teams) or run `lerobot-find-cameras opencv` directly to verify.
|
||||
9. _Verify the virtual camera setup_. Use `v4l2-ctl` to list the devices:
|
||||
|
||||
<details>
|
||||
<summary><strong>Troubleshooting</strong></summary>
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
v4l2-ctl --list-devices
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
> The virtual camera resolution is incorrect.
|
||||
You should see an entry like:
|
||||
|
||||
Delete the virtual camera source and recreate it. The resolution cannot be changed after creation.
|
||||
```
|
||||
VirtualCam (platform:v4l2loopback-000):
|
||||
/dev/video1
|
||||
```
|
||||
|
||||
> Error reading frame in background thread for OpenCVCamera(X): OpenCVCamera(X) frame width=640 or height=480 do not match configured width=1920 or height=1080.
|
||||
10. _Check the camera resolution_. Use `v4l2-ctl` to ensure that the virtual camera output resolution is `640x480`. Change `/dev/video1` to the port of your virtual camera from the output of `v4l2-ctl --list-devices`.
|
||||
|
||||
This error is caused by OBS Virtual Camera advertising a `1920x1080` resolution despite rescaling. The only fix for now is to comment out the width and height check in `_postprocess_image()`.
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
v4l2-ctl -d /dev/video1 --get-fmt-video
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
</details>
|
||||
You should see an entry like:
|
||||
|
||||
```
|
||||
>>> Format Video Capture:
|
||||
>>> Width/Height : 640/480
|
||||
>>> Pixel Format : 'YUYV' (YUYV 4:2:2)
|
||||
```
|
||||
|
||||
Troubleshooting: If the resolution is not correct you will have to delete the Virtual Camera port and try again as it cannot be changed.
|
||||
|
||||
If everything is set up correctly, you can proceed with the rest of the tutorial.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
If everything is set up correctly, your phone will appear as a standard OpenCV camera and can be used with `OpenCVCamera`.
|
||||
|
||||
@@ -1,278 +0,0 @@
|
||||
# Using Subtasks in LeRobot Datasets
|
||||
|
||||
Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for:
|
||||
|
||||
- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time
|
||||
- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models)
|
||||
- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps
|
||||
|
||||
LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks.
|
||||
|
||||
## What are Subtasks?
|
||||
|
||||
While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps:
|
||||
|
||||
1. "Approach the apple"
|
||||
2. "Grasp the apple"
|
||||
3. "Lift the apple"
|
||||
4. "Move to basket"
|
||||
5. "Release the apple"
|
||||
|
||||
Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/subtask-asset.png"
|
||||
alt="An overview of subtask annotation showing how frames are labeled with intermediate subtask stages"
|
||||
width="80%"
|
||||
/>
|
||||
|
||||
<p>
|
||||
<em>Figure: Overview of subtask annotation.</em>
|
||||
</p>
|
||||
|
||||
**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022.
|
||||
|
||||
## Dataset Structure
|
||||
|
||||
Subtask information is stored in the dataset metadata:
|
||||
|
||||
```
|
||||
my-dataset/
|
||||
├── data/
|
||||
│ └── ...
|
||||
├── meta/
|
||||
│ ├── info.json
|
||||
│ ├── stats.json
|
||||
│ ├── tasks.parquet
|
||||
│ ├── subtasks.parquet # Subtask index → subtask string mapping
|
||||
│ └── episodes/
|
||||
│ └── ...
|
||||
└── videos/
|
||||
└── ...
|
||||
```
|
||||
|
||||
### Subtasks Parquet File
|
||||
|
||||
The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions:
|
||||
|
||||
| subtask_index | subtask (index column) |
|
||||
| ------------- | ---------------------- |
|
||||
| 0 | "Approach the apple" |
|
||||
| 1 | "Grasp the apple" |
|
||||
| 2 | "Lift the apple" |
|
||||
| ... | ... |
|
||||
|
||||
### Frame-Level Annotations
|
||||
|
||||
Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file:
|
||||
|
||||
```python
|
||||
# Example frame data in the parquet file
|
||||
{
|
||||
"index": 42,
|
||||
"timestamp": 1.4,
|
||||
"episode_index": 0,
|
||||
"task_index": 0,
|
||||
"subtask_index": 2, # References "Lift the apple"
|
||||
"observation.state": [...],
|
||||
"action": [...],
|
||||
}
|
||||
```
|
||||
|
||||
## Annotating Datasets with Subtasks
|
||||
|
||||
We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks:
|
||||
|
||||
**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)**
|
||||
|
||||
After completing your annotation:
|
||||
|
||||
1. Click "Push to Hub" to upload your annotated dataset
|
||||
2. You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate)
|
||||
|
||||
## Loading Datasets with Subtasks
|
||||
|
||||
When you load a dataset with subtask annotations, the subtask information is automatically available:
|
||||
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Load a dataset with subtask annotations
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
# Access a sample
|
||||
sample = dataset[100]
|
||||
|
||||
# The sample includes both task and subtask information
|
||||
print(sample["task"]) # "Collect the fruit"
|
||||
print(sample["subtask"]) # "Grasp the apple"
|
||||
print(sample["task_index"]) # tensor(0)
|
||||
print(sample["subtask_index"]) # tensor(2)
|
||||
```
|
||||
|
||||
### Checking for Subtask Support
|
||||
|
||||
You can check if a dataset has subtask annotations:
|
||||
|
||||
```python
|
||||
# Check if subtasks are available
|
||||
has_subtasks = (
|
||||
"subtask_index" in dataset.features
|
||||
and dataset.meta.subtasks is not None
|
||||
)
|
||||
|
||||
if has_subtasks:
|
||||
print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks")
|
||||
print("Subtasks:", list(dataset.meta.subtasks.index))
|
||||
```
|
||||
|
||||
## Using Subtasks for Training
|
||||
|
||||
### With the Tokenizer Processor
|
||||
|
||||
The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models:
|
||||
|
||||
```python
|
||||
from lerobot.processor.tokenizer_processor import TokenizerProcessor
|
||||
from lerobot.processor.pipeline import ProcessorPipeline
|
||||
|
||||
# Create a tokenizer processor
|
||||
tokenizer_processor = TokenizerProcessor(
|
||||
tokenizer_name_or_path="google/paligemma-3b-pt-224",
|
||||
padding="max_length",
|
||||
max_length=64,
|
||||
)
|
||||
|
||||
# The processor will automatically tokenize subtasks if present in the batch
|
||||
# and add them to the observation under:
|
||||
# - "observation.subtask.tokens"
|
||||
# - "observation.subtask.attention_mask"
|
||||
```
|
||||
|
||||
When subtasks are available in the batch, the tokenizer processor adds:
|
||||
|
||||
- `observation.subtask.tokens`: Tokenized subtask text
|
||||
- `observation.subtask.attention_mask`: Attention mask for the subtask tokens
|
||||
|
||||
### DataLoader with Subtasks
|
||||
|
||||
```python
|
||||
import torch
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=16,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
for batch in dataloader:
|
||||
# Access subtask information in the batch
|
||||
subtasks = batch["subtask"] # List of subtask strings
|
||||
subtask_indices = batch["subtask_index"] # Tensor of subtask indices
|
||||
|
||||
# Use for training hierarchical policies or reward models
|
||||
print(f"Batch subtasks: {set(subtasks)}")
|
||||
```
|
||||
|
||||
## Example Datasets with Subtask Annotations
|
||||
|
||||
Try loading a dataset with subtask annotations:
|
||||
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Example dataset with subtask annotations
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
# Explore the subtasks
|
||||
print("Available subtasks:")
|
||||
for subtask_name in dataset.meta.subtasks.index:
|
||||
print(f" - {subtask_name}")
|
||||
|
||||
# Get subtask distribution
|
||||
subtask_counts = {}
|
||||
for i in range(len(dataset)):
|
||||
sample = dataset[i]
|
||||
subtask = sample["subtask"]
|
||||
subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1
|
||||
|
||||
print("\nSubtask distribution:")
|
||||
for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]):
|
||||
print(f" {subtask}: {count} frames")
|
||||
```
|
||||
|
||||
## Use Cases
|
||||
|
||||
### 1. Hierarchical Policy Training
|
||||
|
||||
Train policies that predict both actions and current subtask:
|
||||
|
||||
```python
|
||||
class HierarchicalPolicy(nn.Module):
|
||||
def __init__(self, num_subtasks):
|
||||
super().__init__()
|
||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
||||
self.subtask_head = nn.Linear(hidden_dim, num_subtasks)
|
||||
|
||||
def forward(self, observations):
|
||||
features = self.encoder(observations)
|
||||
actions = self.action_head(features)
|
||||
subtask_logits = self.subtask_head(features)
|
||||
return actions, subtask_logits
|
||||
```
|
||||
|
||||
### 2. Stage-Aware Reward Modeling (SARM)
|
||||
|
||||
Build reward models that understand task progression:
|
||||
|
||||
```python
|
||||
# SARM predicts:
|
||||
# - Stage: Which subtask is being executed (discrete)
|
||||
# - Progress: How far along the subtask (continuous 0-1)
|
||||
|
||||
class SARMRewardModel(nn.Module):
|
||||
def forward(self, observations):
|
||||
features = self.encoder(observations)
|
||||
stage_logits = self.stage_classifier(features)
|
||||
progress = self.progress_regressor(features)
|
||||
return stage_logits, progress
|
||||
```
|
||||
|
||||
### 3. Progress Visualization
|
||||
|
||||
Monitor robot execution by tracking subtask progression:
|
||||
|
||||
```python
|
||||
def visualize_execution(model, observations):
|
||||
for t, obs in enumerate(observations):
|
||||
action, subtask_logits = model(obs)
|
||||
predicted_subtask = subtask_names[subtask_logits.argmax()]
|
||||
print(f"t={t}: Executing '{predicted_subtask}'")
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### LeRobotDataset Properties
|
||||
|
||||
| Property | Type | Description |
|
||||
| --------------------------- | ---------------------- | ------------------------------------------ |
|
||||
| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices |
|
||||
| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present |
|
||||
|
||||
### Sample Keys
|
||||
|
||||
When subtasks are available, each sample includes:
|
||||
|
||||
| Key | Type | Description |
|
||||
| --------------- | -------------- | ------------------------------------ |
|
||||
| `subtask_index` | `torch.Tensor` | Integer index of the current subtask |
|
||||
| `subtask` | `str` | Natural language subtask description |
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation
|
||||
- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool
|
||||
- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation
|
||||
@@ -1,276 +0,0 @@
|
||||
# OpenArm
|
||||
|
||||
[OpenArm](https://openarm.dev) is an open-source 7DOF humanoid arm designed for physical AI research and deployment.
|
||||
|
||||
To get your OpenArm, assembled or DIY, and join the global community, browse verified and certified manufacturers worldwide at [openarm.dev](https://openarm.dev).
|
||||
|
||||
## What's Unique?
|
||||
|
||||
- **Human-Scale Design**: OpenArm is designed with human-like proportions, scaled for a person around 160-165cm tall. This provides an optimal balance between practical reach and manageable inertia for safe, responsive operation.
|
||||
|
||||
- **Safety-First Architecture**: Built with QDD backdrivable motors and high compliance, OpenArm prioritizes safe human-robot interaction while maintaining practical payload capabilities (6.0kg peak / 4.1kg nominal) for real-world tasks.
|
||||
|
||||
- **Built for Durability**: Critical structural components use aluminum and stainless steel construction, ensuring robust performance for repetitive data collection and continuous research use.
|
||||
|
||||
- **Fully Accessible & Buildable**: Every component, from CNC parts and 3D-printed casings to electrical wiring is designed to be purchasable and buildable by individual researchers and labs, with complete fabrication data provided.
|
||||
|
||||
- **Practical & Affordable**: At $6,500 USD for a complete bimanual system, OpenArm delivers research-grade capabilities at a fraction of traditional humanoid robot costs.
|
||||
|
||||
## Platform Requirements
|
||||
|
||||
<Tip warning={true}>
|
||||
**Linux Only**: OpenArm currently only works on Linux. The CAN bus USB adapter
|
||||
does not have macOS drivers and has not been tested on Windows.
|
||||
</Tip>
|
||||
|
||||
## Safety Guide
|
||||
|
||||
Before operating OpenArm, please read the [official safety guide](https://docs.openarm.dev/getting-started/safety-guide). Key points:
|
||||
|
||||
- **Secure installation**: Fasten the arm to a flat, stable surface with screws or clamps
|
||||
- **Safe distance**: Keep body parts and objects outside the range of motion during operation
|
||||
- **Protective equipment**: Always wear safety goggles; use additional PPE as needed
|
||||
- **Payload limits**: Do not exceed specified payload limits (6.0kg peak / 4.1kg nominal per arm)
|
||||
- **Emergency stop**: Know the location and operation of the emergency stop device
|
||||
- **Regular inspection**: Check for loose screws, damaged mechanical limits, unusual noises, and wiring damage
|
||||
|
||||
## Hardware Setup
|
||||
|
||||
Follow the official [OpenArm hardware documentation](https://docs.openarm.dev) for:
|
||||
|
||||
- Bill of materials and sourcing
|
||||
- 3D printing instructions
|
||||
- Mechanical assembly
|
||||
- Electrical wiring
|
||||
|
||||
The hardware repositories are available at [github.com/enactic/openarm](https://github.com/enactic/openarm).
|
||||
|
||||
## CAN Bus Setup
|
||||
|
||||
OpenArm uses CAN bus communication with Damiao motors. Once you have the CAN bus USB adapter plugged into your Linux PC, follow the [Damiao Motors and CAN Bus guide](./damiao) to configure the interface.
|
||||
|
||||
Quick setup:
|
||||
|
||||
```bash
|
||||
# Setup CAN interfaces
|
||||
lerobot-setup-can --mode=setup --interfaces=can0,can1
|
||||
|
||||
# Test motor communication
|
||||
lerobot-setup-can --mode=test --interfaces=can0,can1
|
||||
```
|
||||
|
||||
## Install LeRobot 🤗
|
||||
|
||||
Follow our [Installation Guide](./installation), then install the Damiao motor support:
|
||||
|
||||
```bash
|
||||
pip install -e ".[damiao]"
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Follower Arm (Robot)
|
||||
|
||||
<hfoptions id="follower">
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--robot.type=openarm_follower \
|
||||
--robot.port=can0 \
|
||||
--robot.side=right \
|
||||
--robot.id=my_openarm_follower
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
```python
|
||||
from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig
|
||||
|
||||
config = OpenArmFollowerConfig(
|
||||
port="can0",
|
||||
side="right", # or "left" for left arm
|
||||
id="my_openarm_follower",
|
||||
)
|
||||
|
||||
follower = OpenArmFollower(config)
|
||||
follower.connect()
|
||||
|
||||
# Read current state
|
||||
obs = follower.get_observation()
|
||||
print(obs)
|
||||
|
||||
# Send action (position in degrees)
|
||||
action = {
|
||||
"joint_1.pos": 0.0,
|
||||
"joint_2.pos": 0.0,
|
||||
"joint_3.pos": 0.0,
|
||||
"joint_4.pos": 45.0,
|
||||
"joint_5.pos": 0.0,
|
||||
"joint_6.pos": 0.0,
|
||||
"joint_7.pos": 0.0,
|
||||
"gripper.pos": 0.0,
|
||||
}
|
||||
follower.send_action(action)
|
||||
|
||||
follower.disconnect()
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Leader Arm (Teleoperator)
|
||||
|
||||
The leader arm is used for teleoperation - manually moving it to control the follower arm.
|
||||
|
||||
<hfoptions id="leader">
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--teleop.type=openarm_leader \
|
||||
--teleop.port=can1 \
|
||||
--teleop.id=my_openarm_leader
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
```python
|
||||
from lerobot.teleoperators.openarm_leader import OpenArmLeader, OpenArmLeaderConfig
|
||||
|
||||
config = OpenArmLeaderConfig(
|
||||
port="can1",
|
||||
id="my_openarm_leader",
|
||||
manual_control=True, # Disable torque for manual movement
|
||||
)
|
||||
|
||||
leader = OpenArmLeader(config)
|
||||
leader.connect()
|
||||
|
||||
# Read current position (as action to send to follower)
|
||||
action = leader.get_action()
|
||||
print(action)
|
||||
|
||||
leader.disconnect()
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Teleoperation
|
||||
|
||||
To teleoperate OpenArm with leader-follower control:
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=openarm_follower \
|
||||
--robot.port=can0 \
|
||||
--robot.side=right \
|
||||
--robot.id=my_follower \
|
||||
--teleop.type=openarm_leader \
|
||||
--teleop.port=can1 \
|
||||
--teleop.id=my_leader
|
||||
```
|
||||
|
||||
### Bimanual Teleoperation
|
||||
|
||||
To teleoperate a bimanual OpenArm setup with two leader and two follower arms:
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=bi_openarm_follower \
|
||||
--robot.left_arm_config.port=can0 \
|
||||
--robot.left_arm_config.side=left \
|
||||
--robot.right_arm_config.port=can1 \
|
||||
--robot.right_arm_config.side=right \
|
||||
--robot.id=my_bimanual_follower \
|
||||
--teleop.type=bi_openarm_leader \
|
||||
--teleop.left_arm_config.port=can2 \
|
||||
--teleop.right_arm_config.port=can3 \
|
||||
--teleop.id=my_bimanual_leader
|
||||
```
|
||||
|
||||
### Recording Data
|
||||
|
||||
To record a dataset during teleoperation:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=openarm_follower \
|
||||
--robot.port=can0 \
|
||||
--robot.side=right \
|
||||
--robot.id=my_follower \
|
||||
--teleop.type=openarm_leader \
|
||||
--teleop.port=can1 \
|
||||
--teleop.id=my_leader \
|
||||
--repo-id=my_hf_username/my_openarm_dataset \
|
||||
--fps=30 \
|
||||
--num-episodes=10
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### Follower Configuration
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| --------------------- | --------- | ---------------------------------------------------------- |
|
||||
| `port` | - | CAN interface (e.g., `can0`) |
|
||||
| `side` | `None` | Arm side: `"left"`, `"right"`, or `None` for custom limits |
|
||||
| `use_can_fd` | `True` | Enable CAN FD for higher data rates |
|
||||
| `can_bitrate` | `1000000` | Nominal bitrate (1 Mbps) |
|
||||
| `can_data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) |
|
||||
| `max_relative_target` | `None` | Safety limit for relative target positions |
|
||||
| `position_kp` | Per-joint | Position control proportional gains |
|
||||
| `position_kd` | Per-joint | Position control derivative gains |
|
||||
|
||||
### Leader Configuration
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| ------------------ | --------- | ----------------------------------- |
|
||||
| `port` | - | CAN interface (e.g., `can1`) |
|
||||
| `manual_control` | `True` | Disable torque for manual movement |
|
||||
| `use_can_fd` | `True` | Enable CAN FD for higher data rates |
|
||||
| `can_bitrate` | `1000000` | Nominal bitrate (1 Mbps) |
|
||||
| `can_data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) |
|
||||
|
||||
## Motor Configuration
|
||||
|
||||
OpenArm uses Damiao motors with the following default configuration:
|
||||
|
||||
| Joint | Motor Type | Send ID | Recv ID |
|
||||
| --------------------------- | ---------- | ------- | ------- |
|
||||
| joint_1 (Shoulder pan) | DM8009 | 0x01 | 0x11 |
|
||||
| joint_2 (Shoulder lift) | DM8009 | 0x02 | 0x12 |
|
||||
| joint_3 (Shoulder rotation) | DM4340 | 0x03 | 0x13 |
|
||||
| joint_4 (Elbow flex) | DM4340 | 0x04 | 0x14 |
|
||||
| joint_5 (Wrist roll) | DM4310 | 0x05 | 0x15 |
|
||||
| joint_6 (Wrist pitch) | DM4310 | 0x06 | 0x16 |
|
||||
| joint_7 (Wrist rotation) | DM4310 | 0x07 | 0x17 |
|
||||
| gripper | DM4310 | 0x08 | 0x18 |
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### No Response from Motors
|
||||
|
||||
1. Check power supply connections
|
||||
2. Verify CAN wiring (CAN-H, CAN-L, GND)
|
||||
3. Run diagnostics: `lerobot-setup-can --mode=test --interfaces=can0`
|
||||
4. See the [Damiao troubleshooting guide](./damiao#troubleshooting) for more details
|
||||
|
||||
### CAN Interface Not Found
|
||||
|
||||
Ensure the CAN interface is configured:
|
||||
|
||||
```bash
|
||||
ip link show can0
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- [OpenArm Website](https://openarm.dev)
|
||||
- [OpenArm Documentation](https://docs.openarm.dev)
|
||||
- [OpenArm GitHub](https://github.com/enactic/openarm)
|
||||
- [Safety Guide](https://docs.openarm.dev/getting-started/safety-guide)
|
||||
- [Damiao Motors and CAN Bus](./damiao)
|
||||
@@ -0,0 +1,86 @@
|
||||
# Training-Time RTC
|
||||
|
||||
Training-Time RTC teaches the model to handle inference delay during training.
|
||||
It feeds the **ground-truth action prefix** to the model and trains only on the remaining postfix actions.
|
||||
This keeps chunk transitions smooth without doing any inference-time inpainting.
|
||||
|
||||
Based on: [Training-Time Action Conditioning for Efficient Real-Time Chunking](https://arxiv.org/abs/2512.05964).
|
||||
|
||||
LeRobot supports this for `pi0`, `pi05` and `smolvla` without changing model parameters.
|
||||
|
||||
---
|
||||
|
||||
## How It Works
|
||||
|
||||
### At Training Time
|
||||
|
||||
- Sample a delay `d` per batch element.
|
||||
- Keep the first `d` action steps as **ground truth** (no noise).
|
||||
- Add noise only to the postfix actions.
|
||||
- Set the flow-matching timestep to **1.0** for prefix tokens and normal timesteps for postfix tokens.
|
||||
- Mask the loss to only train on the postfix.
|
||||
|
||||
### At Inference Time
|
||||
|
||||
When `rtc_training_config.enabled=true`, the model uses training-time RTC inference:
|
||||
|
||||
- Replace prefix positions in `x_t` with previous chunk's leftover actions.
|
||||
- Set timestep to **1.0** for prefix positions.
|
||||
|
||||
---
|
||||
|
||||
## Quick Start (CLI)
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.type=pi0 \
|
||||
--dataset.repo_id=your/dataset \
|
||||
--policy.rtc_training_config.enabled=true \
|
||||
--policy.rtc_training_config.min_delay=0 \
|
||||
--policy.rtc_training_config.max_delay=6 \
|
||||
--policy.rtc_training_config.delay_distribution=UNIFORM
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Inference with Training-Time RTC
|
||||
|
||||
After training with `rtc_training_config`, use the same config at inference. The model will automatically use training-time RTC inference:
|
||||
|
||||
```python
|
||||
policy = PI0Policy.from_pretrained("path/to/trained/model")
|
||||
# rtc_training_config is loaded from the saved config
|
||||
|
||||
actions = policy.predict_action_chunk(
|
||||
batch,
|
||||
inference_delay=5, # estimated delay in timesteps
|
||||
prev_chunk_left_over=previous_actions, # from previous chunk
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Key Parameters
|
||||
|
||||
`RTCTrainingConfig` is available on the policy config (`pi0`, `pi05`, `smolvla`, `xvla`):
|
||||
|
||||
- **`enabled`**: Toggle training-time RTC (both training and inference).
|
||||
- **`min_delay` / `max_delay`**: Delay range (inclusive).
|
||||
- **`delay_distribution`**:
|
||||
- `UNIFORM`: uniform in `[min_delay, max_delay]`
|
||||
- `EXP`: exponentially decayed distribution over delays
|
||||
- **`exp_decay`**: Exponential decay factor for `EXP` sampling.
|
||||
|
||||
---
|
||||
|
||||
## Notes and Recommendations
|
||||
|
||||
- Start with `min_delay=0` and `max_delay` around your expected worst-case inference delay.
|
||||
- Use `EXP` if you want more supervision on smaller delays.
|
||||
|
||||
---
|
||||
|
||||
## Related Docs
|
||||
|
||||
- [Real-Time Chunking (Inference-Time RTC)](./rtc)
|
||||
- [Pi0](./pi0), [Pi0.5](./pi05), [SmolVLA](./smolvla)
|
||||
@@ -188,105 +188,7 @@ Press `Ctrl+C` to stop the policy.
|
||||
|
||||
## Running in Simulation Mode (MuJoCo)
|
||||
|
||||
You can test policies before deploying on the physical robot using MuJoCo simulation. Set `is_simulation=True` in config or pass `--robot.is_simulation=true` via CLI.
|
||||
|
||||
### Calibrate Exoskeleton Teleoperator
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo
|
||||
```
|
||||
|
||||
### Teleoperate in Simulation
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=true \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--fps=100
|
||||
```
|
||||
|
||||
### Record Dataset in Simulation
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.lerobot_record \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=true \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--dataset.repo_id=your-username/dataset-name \
|
||||
--dataset.single_task="Test" \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.reset_time_s=5 \
|
||||
--dataset.push_to_hub=true
|
||||
```
|
||||
|
||||
Example simulation dataset: [nepyope/teleop_test_sim](https://huggingface.co/datasets/nepyope/teleop_test_sim)
|
||||
|
||||
---
|
||||
|
||||
## Running on Real Robot
|
||||
|
||||
Once the robot server is running on the G1 (see Part 3), you can teleoperate and record on the real robot.
|
||||
|
||||
### Start the Camera Server
|
||||
|
||||
On the robot, start the ZMQ image server:
|
||||
|
||||
```bash
|
||||
python src/lerobot/cameras/zmq/image_server.py
|
||||
```
|
||||
|
||||
Keep this running in a separate terminal for camera streaming during recording.
|
||||
|
||||
### Teleoperate Real Robot
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=false \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--fps=100
|
||||
```
|
||||
|
||||
### Record Dataset on Real Robot
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.lerobot_record \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=false \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--dataset.repo_id=your-username/dataset-name \
|
||||
--dataset.single_task="Test" \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.reset_time_s=5 \
|
||||
--dataset.push_to_hub=true
|
||||
```
|
||||
|
||||
**Note**: Update `server_address` to match your robot's camera server IP.
|
||||
|
||||
Example real robot dataset: [nepyope/teleop_test_real](https://huggingface.co/datasets/nepyope/teleop_test_real)
|
||||
|
||||
---
|
||||
You can now test policies before unleashing them on the physical robot using MuJoCo. To do so simply set `is_simulation=True` in config.
|
||||
|
||||
## Additional Resources
|
||||
|
||||
|
||||
@@ -81,25 +81,24 @@ def replay(cfg: ReplayConfig):
|
||||
actions = dataset.hf_dataset.select_columns(ACTION)
|
||||
robot.connect()
|
||||
|
||||
try:
|
||||
log_say("Replaying episode", cfg.play_sounds, blocking=True)
|
||||
for idx in range(dataset.num_frames):
|
||||
start_episode_t = time.perf_counter()
|
||||
log_say("Replaying episode", cfg.play_sounds, blocking=True)
|
||||
for idx in range(dataset.num_frames):
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action_array = actions[idx][ACTION]
|
||||
action = {}
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"]):
|
||||
key = f"{name.removeprefix('main_')}.pos"
|
||||
action[key] = action_array[i].item()
|
||||
action_array = actions[idx][ACTION]
|
||||
action = {}
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"]):
|
||||
key = f"{name.removeprefix('main_')}.pos"
|
||||
action[key] = action_array[i].item()
|
||||
|
||||
action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90)
|
||||
action["elbow_flex.pos"] -= 90
|
||||
robot.send_action(action)
|
||||
action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90)
|
||||
action["elbow_flex.pos"] -= 90
|
||||
robot.send_action(action)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
precise_sleep(max(1 / dataset.fps - dt_s, 0.0))
|
||||
finally:
|
||||
robot.disconnect()
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
precise_sleep(max(1 / dataset.fps - dt_s, 0.0))
|
||||
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
+43
-45
@@ -78,24 +78,40 @@ def main():
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="lekiwi_evaluate")
|
||||
|
||||
try:
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
|
||||
print("Starting evaluate loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
@@ -104,42 +120,24 @@ def main():
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
+44
-45
@@ -74,23 +74,40 @@ def main():
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="lekiwi_record")
|
||||
|
||||
try:
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting record loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {recorded_episodes}")
|
||||
print("Starting record loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {recorded_episodes}")
|
||||
|
||||
# Main record loop
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
dataset=dataset,
|
||||
teleop=[leader_arm, keyboard],
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
dataset=dataset,
|
||||
teleop=[leader_arm, keyboard],
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
@@ -98,44 +115,26 @@ def main():
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=[leader_arm, keyboard],
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
leader_arm.disconnect()
|
||||
keyboard.disconnect()
|
||||
listener.stop()
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
leader_arm.disconnect()
|
||||
keyboard.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
+15
-17
@@ -42,27 +42,25 @@ def main():
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
|
||||
try:
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get recorded action from dataset
|
||||
action = {
|
||||
name: float(actions[idx][ACTION][i])
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
# Get recorded action from dataset
|
||||
action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Send action to robot
|
||||
_ = robot.send_action(action)
|
||||
# Send action to robot
|
||||
_ = robot.send_action(action)
|
||||
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
finally:
|
||||
robot.disconnect()
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -142,24 +142,38 @@ def main():
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="phone_so100_evaluate")
|
||||
|
||||
try:
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
@@ -168,41 +182,24 @@ def main():
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -149,23 +149,38 @@ def main():
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="phone_so100_record")
|
||||
|
||||
try:
|
||||
if not robot.is_connected or not phone.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
if not robot.is_connected or not phone.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting record loop. Move your phone to teleoperate the robot...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
print("Starting record loop. Move your phone to teleoperate the robot...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
@@ -173,43 +188,25 @@ def main():
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
phone.disconnect()
|
||||
listener.stop()
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
phone.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -73,34 +73,32 @@ def main():
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
|
||||
try:
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i])
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
finally:
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -142,24 +142,38 @@ def main():
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="so100_so100_evaluate")
|
||||
|
||||
try:
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
@@ -168,41 +182,24 @@ def main():
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -146,23 +146,38 @@ def main():
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording_phone")
|
||||
|
||||
try:
|
||||
if not leader.is_connected or not follower.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
if not leader.is_connected or not follower.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting record loop...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
print("Starting record loop...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=leader,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=leader,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
@@ -170,44 +185,25 @@ def main():
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=leader,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
leader.disconnect()
|
||||
follower.disconnect()
|
||||
listener.stop()
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
leader.disconnect()
|
||||
follower.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -74,35 +74,32 @@ def main():
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
|
||||
try:
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i])
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
+1
-6
@@ -105,17 +105,12 @@ dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
|
||||
damiao = ["python-can>=4.2.0,<5.0.0"]
|
||||
|
||||
# Robots
|
||||
openarms = ["lerobot[damiao]"]
|
||||
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
|
||||
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
|
||||
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
|
||||
unitree_g1 = [
|
||||
"pyzmq>=26.2.1,<28.0.0",
|
||||
"onnxruntime>=1.16.0,<2.0.0",
|
||||
"pin>=3.0.0,<4.0.0",
|
||||
"meshcat>=0.3.0,<0.4.0",
|
||||
"matplotlib>=3.9.0,<4.0.0",
|
||||
"casadi>=3.6.0,<4.0.0",
|
||||
"onnxruntime>=1.16.0,<2.0.0"
|
||||
]
|
||||
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
||||
kinematics = ["lerobot[placo-dep]"]
|
||||
|
||||
@@ -15,12 +15,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
|
||||
|
||||
from .configs import CameraConfig
|
||||
from .configs import CameraConfig, ColorMode
|
||||
|
||||
|
||||
class Camera(abc.ABC):
|
||||
@@ -31,12 +30,20 @@ class Camera(abc.ABC):
|
||||
|
||||
Manages basic camera properties (FPS, resolution) and core operations:
|
||||
- Connection/disconnection
|
||||
- Frame capture (sync/async/latest)
|
||||
- Frame capture (sync/async)
|
||||
|
||||
Attributes:
|
||||
fps (int | None): Configured frames per second
|
||||
width (int | None): Frame width in pixels
|
||||
height (int | None): Frame height in pixels
|
||||
|
||||
Example:
|
||||
class MyCamera(Camera):
|
||||
def __init__(self, config): ...
|
||||
@property
|
||||
def is_connected(self) -> bool: ...
|
||||
def connect(self, warmup=True): ...
|
||||
# Plus other required methods
|
||||
"""
|
||||
|
||||
def __init__(self, config: CameraConfig):
|
||||
@@ -49,32 +56,6 @@ class Camera(abc.ABC):
|
||||
self.width: int | None = config.width
|
||||
self.height: int | None = config.height
|
||||
|
||||
def __enter__(self):
|
||||
"""
|
||||
Context manager entry.
|
||||
Automatically connects to the camera.
|
||||
"""
|
||||
self.connect()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback) -> None:
|
||||
"""
|
||||
Context manager exit.
|
||||
Automatically disconnects, ensuring resources are released even on error.
|
||||
"""
|
||||
self.disconnect()
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""
|
||||
Destructor safety net.
|
||||
Attempts to disconnect if the object is garbage collected without cleanup.
|
||||
"""
|
||||
try:
|
||||
if self.is_connected:
|
||||
self.disconnect()
|
||||
except Exception: # nosec B110
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_connected(self) -> bool:
|
||||
@@ -108,10 +89,12 @@ class Camera(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def read(self) -> NDArray[Any]:
|
||||
"""Capture and return a single frame from the camera synchronously.
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""Capture and return a single frame from the camera.
|
||||
|
||||
This is a blocking call that will wait for the hardware and its SDK.
|
||||
Args:
|
||||
color_mode: Desired color mode for the output frame. If None,
|
||||
uses the camera's default color mode.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Captured frame as a numpy array.
|
||||
@@ -120,64 +103,17 @@ class Camera(abc.ABC):
|
||||
|
||||
@abc.abstractmethod
|
||||
def async_read(self, timeout_ms: float = ...) -> NDArray[Any]:
|
||||
"""Return the most recent new frame.
|
||||
|
||||
This method retrieves the latest frame captured by the background thread.
|
||||
If a new frame is already available in the buffer (captured since the last call),
|
||||
it returns it immediately.
|
||||
|
||||
It blocks up to `timeout_ms` only if the buffer is empty or if the latest frame
|
||||
was already consumed by a previous `async_read` call.
|
||||
|
||||
Essentially, this method return the latest unconsumed frame, waiting if necessary
|
||||
for a new one to arrive within the specified timeout.
|
||||
|
||||
Usage:
|
||||
- Ideal for control loops where you want to ensure every processed frame
|
||||
is fresh, effectively synchronizing your loop to the camera's FPS.
|
||||
- Causes of a timeout usually include: very low camera FPS, heavy processing load,
|
||||
or if the camera is disconnected.
|
||||
"""Asynchronously capture and return a single frame from the camera.
|
||||
|
||||
Args:
|
||||
timeout_ms: Maximum time to wait for a new frame in milliseconds.
|
||||
Defaults to 200ms (0.2s).
|
||||
timeout_ms: Maximum time to wait for a frame in milliseconds.
|
||||
Defaults to implementation-specific timeout.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Captured frame as a numpy array.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If no new frame arrives within `timeout_ms`.
|
||||
"""
|
||||
pass
|
||||
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
memory buffer. The frame may be stale,
|
||||
meaning it could have been captured a while ago (hanging camera scenario e.g.).
|
||||
|
||||
Usage:
|
||||
Ideal for scenarios requiring zero latency or decoupled frequencies & when
|
||||
we want a guaranteed frame, such as UI visualization, logging, or
|
||||
non-critical monitoring.
|
||||
|
||||
Returns:
|
||||
NDArray[Any]: The frame image (numpy array).
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the latest frame is older than `max_age_ms`.
|
||||
NotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
warnings.warn(
|
||||
f"{self.__class__.__name__}.read_latest() is not implemented. "
|
||||
"Please override read_latest(); it will be required in future releases.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.async_read()
|
||||
|
||||
@abc.abstractmethod
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from the camera and release resources."""
|
||||
|
||||
@@ -70,24 +70,34 @@ class OpenCVCamera(Camera):
|
||||
Example:
|
||||
```python
|
||||
from lerobot.cameras.opencv import OpenCVCamera
|
||||
from lerobot.cameras.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.cameras.configuration_opencv import OpenCVCameraConfig, ColorMode, Cv2Rotation
|
||||
|
||||
# Basic usage with camera index 0
|
||||
config = OpenCVCameraConfig(index_or_path=0)
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect()
|
||||
|
||||
# Read 1 frame synchronously (blocking)
|
||||
# Read 1 frame synchronously
|
||||
color_image = camera.read()
|
||||
print(color_image.shape)
|
||||
|
||||
# Read 1 frame asynchronously (waits for new frame with a timeout)
|
||||
# Read 1 frame asynchronously
|
||||
async_image = camera.async_read()
|
||||
|
||||
# Get the latest frame immediately (no wait, returns timestamp)
|
||||
latest_image, timestamp = camera.read_latest()
|
||||
|
||||
# When done, properly disconnect the camera using
|
||||
camera.disconnect()
|
||||
|
||||
# Example with custom settings
|
||||
custom_config = OpenCVCameraConfig(
|
||||
index_or_path='/dev/video0', # Or use an index
|
||||
fps=30,
|
||||
width=1280,
|
||||
height=720,
|
||||
color_mode=ColorMode.RGB,
|
||||
rotation=Cv2Rotation.ROTATE_90
|
||||
)
|
||||
custom_camera = OpenCVCamera(custom_config)
|
||||
# ... connect, read, disconnect ...
|
||||
```
|
||||
"""
|
||||
|
||||
@@ -113,7 +123,6 @@ class OpenCVCamera(Camera):
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.latest_timestamp: float | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
self.rotation: int | None = get_cv2_rotation(config.rotation)
|
||||
@@ -137,16 +146,12 @@ class OpenCVCamera(Camera):
|
||||
Connects to the OpenCV camera specified in the configuration.
|
||||
|
||||
Initializes the OpenCV VideoCapture object, sets desired camera properties
|
||||
(FPS, width, height), starts the background reading thread and performs initial checks.
|
||||
|
||||
Args:
|
||||
warmup (bool): If True, waits at connect() time until at least one valid frame
|
||||
has been captured by the background thread. Defaults to True.
|
||||
(FPS, width, height), and performs initial checks.
|
||||
|
||||
Raises:
|
||||
DeviceAlreadyConnectedError: If the camera is already connected.
|
||||
ConnectionError: If the specified camera index/path is not found or fails to open.
|
||||
RuntimeError: If the camera opens but fails to apply requested settings.
|
||||
ConnectionError: If the specified camera index/path is not found or the camera is found but fails to open.
|
||||
RuntimeError: If the camera opens but fails to apply requested FPS/resolution settings.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
|
||||
@@ -165,16 +170,12 @@ class OpenCVCamera(Camera):
|
||||
)
|
||||
|
||||
self._configure_capture_settings()
|
||||
self._start_read_thread()
|
||||
|
||||
if warmup and self.warmup_s > 0:
|
||||
if warmup:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < self.warmup_s:
|
||||
self.async_read(timeout_ms=self.warmup_s * 1000)
|
||||
self.read()
|
||||
time.sleep(0.1)
|
||||
with self.frame_lock:
|
||||
if self.latest_frame is None:
|
||||
raise ConnectionError(f"{self} failed to capture frames during warmup.")
|
||||
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@@ -195,7 +196,8 @@ class OpenCVCamera(Camera):
|
||||
Raises:
|
||||
RuntimeError: If the camera fails to set any of the specified properties
|
||||
to the requested value.
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
DeviceNotConnectedError: If the camera is not connected when attempting
|
||||
to configure settings.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.")
|
||||
@@ -337,17 +339,6 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return found_cameras_info
|
||||
|
||||
def _read_from_hardware(self) -> NDArray[Any]:
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
ret, frame = self.videocapture.read()
|
||||
|
||||
if not ret:
|
||||
raise RuntimeError(f"{self} read failed (status={ret}).")
|
||||
|
||||
return frame
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
@@ -355,6 +346,11 @@ class OpenCVCamera(Camera):
|
||||
This is a blocking call. It waits for the next available frame from the
|
||||
camera hardware via OpenCV.
|
||||
|
||||
Args:
|
||||
color_mode (Optional[ColorMode]): If specified, overrides the default
|
||||
color mode (`self.color_mode`) for this read operation (e.g.,
|
||||
request RGB even if default is BGR).
|
||||
|
||||
Returns:
|
||||
np.ndarray: The captured frame as a NumPy array in the format
|
||||
(height, width, channels), using the specified or default
|
||||
@@ -366,34 +362,34 @@ class OpenCVCamera(Camera):
|
||||
received frame dimensions don't match expectations before rotation.
|
||||
ValueError: If an invalid `color_mode` is requested.
|
||||
"""
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if color_mode is not None:
|
||||
logger.warning(
|
||||
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
start_time = time.perf_counter()
|
||||
|
||||
self.new_frame_event.clear()
|
||||
frame = self.async_read(timeout_ms=10000)
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
ret, frame = self.videocapture.read()
|
||||
|
||||
if not ret or frame is None:
|
||||
raise RuntimeError(f"{self} read failed (status={ret}).")
|
||||
|
||||
processed_frame = self._postprocess_image(frame, color_mode)
|
||||
|
||||
read_duration_ms = (time.perf_counter() - start_time) * 1e3
|
||||
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
|
||||
|
||||
return frame
|
||||
return processed_frame
|
||||
|
||||
def _postprocess_image(self, image: NDArray[Any]) -> NDArray[Any]:
|
||||
def _postprocess_image(self, image: NDArray[Any], color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Applies color conversion, dimension validation, and rotation to a raw frame.
|
||||
|
||||
Args:
|
||||
image (np.ndarray): The raw image frame (expected BGR format from OpenCV).
|
||||
color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None,
|
||||
uses the instance's default `self.color_mode`.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The processed image frame.
|
||||
@@ -403,10 +399,11 @@ class OpenCVCamera(Camera):
|
||||
RuntimeError: If the raw frame dimensions do not match the configured
|
||||
`width` and `height`.
|
||||
"""
|
||||
requested_color_mode = self.color_mode if color_mode is None else color_mode
|
||||
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
if requested_color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"Invalid color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
|
||||
f"Invalid color mode '{requested_color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
|
||||
)
|
||||
|
||||
h, w, c = image.shape
|
||||
@@ -420,7 +417,7 @@ class OpenCVCamera(Camera):
|
||||
raise RuntimeError(f"{self} frame channels={c} do not match expected 3 channels (RGB/BGR).")
|
||||
|
||||
processed_image = image
|
||||
if self.color_mode == ColorMode.RGB:
|
||||
if requested_color_mode == ColorMode.RGB:
|
||||
processed_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]:
|
||||
@@ -434,7 +431,7 @@ class OpenCVCamera(Camera):
|
||||
|
||||
On each iteration:
|
||||
1. Reads a color frame
|
||||
2. Stores result in latest_frame and updates timestamp (thread-safe)
|
||||
2. Stores result in latest_frame (thread-safe)
|
||||
3. Sets new_frame_event to notify listeners
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
@@ -442,37 +439,30 @@ class OpenCVCamera(Camera):
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
failure_count = 0
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
raw_frame = self._read_from_hardware()
|
||||
processed_frame = self._postprocess_image(raw_frame)
|
||||
capture_time = time.perf_counter()
|
||||
color_image = self.read()
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = processed_frame
|
||||
self.latest_timestamp = capture_time
|
||||
self.latest_frame = color_image
|
||||
self.new_frame_event.set()
|
||||
failure_count = 0
|
||||
|
||||
except DeviceNotConnectedError:
|
||||
break
|
||||
except Exception as e:
|
||||
if failure_count <= 10:
|
||||
failure_count += 1
|
||||
logger.warning(f"Error reading frame in background thread for {self}: {e}")
|
||||
else:
|
||||
raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e
|
||||
logger.warning(f"Error reading frame in background thread for {self}: {e}")
|
||||
|
||||
def _start_read_thread(self) -> None:
|
||||
"""Starts or restarts the background read thread if it's not running."""
|
||||
self._stop_read_thread()
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=0.1)
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
|
||||
self.stop_event = Event()
|
||||
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
time.sleep(0.1)
|
||||
|
||||
def _stop_read_thread(self) -> None:
|
||||
"""Signals the background read thread to stop and waits for it to join."""
|
||||
@@ -485,11 +475,6 @@ class OpenCVCamera(Camera):
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame asynchronously.
|
||||
@@ -497,7 +482,6 @@ class OpenCVCamera(Camera):
|
||||
This method retrieves the most recent frame captured by the background
|
||||
read thread. It does not block waiting for the camera hardware directly,
|
||||
but may wait up to timeout_ms for the background thread to provide a frame.
|
||||
It is “best effort” under high FPS.
|
||||
|
||||
Args:
|
||||
timeout_ms (float): Maximum time in milliseconds to wait for a frame
|
||||
@@ -516,12 +500,13 @@ class OpenCVCamera(Camera):
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
self._start_read_thread()
|
||||
|
||||
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
|
||||
thread_alive = self.thread is not None and self.thread.is_alive()
|
||||
raise TimeoutError(
|
||||
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
|
||||
f"Read thread alive: {self.thread.is_alive()}."
|
||||
f"Read thread alive: {thread_alive}."
|
||||
)
|
||||
|
||||
with self.frame_lock:
|
||||
@@ -533,42 +518,6 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
memory buffer. The frame may be stale,
|
||||
meaning it could have been captured a while ago (hanging camera scenario e.g.).
|
||||
|
||||
Returns:
|
||||
NDArray[Any]: The frame image (numpy array).
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the latest frame is older than `max_age_ms`.
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
with self.frame_lock:
|
||||
frame = self.latest_frame
|
||||
timestamp = self.latest_timestamp
|
||||
|
||||
if frame is None or timestamp is None:
|
||||
raise RuntimeError(f"{self} has not captured any frames yet.")
|
||||
|
||||
age_ms = (time.perf_counter() - timestamp) * 1e3
|
||||
if age_ms > max_age_ms:
|
||||
raise TimeoutError(
|
||||
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
|
||||
)
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnects from the camera and cleans up resources.
|
||||
@@ -589,9 +538,4 @@ class OpenCVCamera(Camera):
|
||||
self.videocapture.release()
|
||||
self.videocapture = None
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
@@ -80,8 +80,6 @@ class Reachy2Camera(Camera):
|
||||
self.config = config
|
||||
|
||||
self.color_mode = config.color_mode
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.latest_timestamp: float | None = None
|
||||
|
||||
self.cam_manager: CameraManager | None = None
|
||||
|
||||
@@ -127,7 +125,12 @@ class Reachy2Camera(Camera):
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
|
||||
This method retrieves the most recent frame available in Reachy 2's low-level software.
|
||||
This is a blocking call.
|
||||
|
||||
Args:
|
||||
color_mode (Optional[ColorMode]): If specified, overrides the default
|
||||
color mode (`self.color_mode`) for this read operation (e.g.,
|
||||
request RGB even if default is BGR).
|
||||
|
||||
Returns:
|
||||
np.ndarray: The captured frame as a NumPy array in the format
|
||||
@@ -142,11 +145,6 @@ class Reachy2Camera(Camera):
|
||||
if self.cam_manager is None:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if color_mode is not None:
|
||||
logger.warning(
|
||||
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
frame: NDArray[Any] = np.empty((0, 0, 3), dtype=np.uint8)
|
||||
|
||||
if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"):
|
||||
@@ -167,18 +165,11 @@ class Reachy2Camera(Camera):
|
||||
raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")
|
||||
|
||||
if frame is None:
|
||||
raise RuntimeError(f"Internal error: No frame available for {self}.")
|
||||
return np.empty((0, 0, 3), dtype=np.uint8)
|
||||
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"Invalid color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
|
||||
)
|
||||
if self.color_mode == ColorMode.RGB:
|
||||
if self.config.color_mode == "rgb":
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
self.latest_frame = frame
|
||||
self.latest_timestamp = time.perf_counter()
|
||||
|
||||
read_duration_ms = (time.perf_counter() - start_time) * 1e3
|
||||
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
|
||||
|
||||
@@ -186,7 +177,13 @@ class Reachy2Camera(Camera):
|
||||
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Same as read()
|
||||
Reads the latest available frame.
|
||||
|
||||
This method retrieves the most recent frame available in Reachy 2's low-level software.
|
||||
|
||||
Args:
|
||||
timeout_ms (float): Maximum time in milliseconds to wait for a frame
|
||||
to become available. Defaults to 200ms (0.2 seconds).
|
||||
|
||||
Returns:
|
||||
np.ndarray: The latest captured frame as a NumPy array in the format
|
||||
@@ -200,38 +197,12 @@ class Reachy2Camera(Camera):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
return self.read()
|
||||
frame = self.read()
|
||||
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
if frame is None:
|
||||
raise RuntimeError(f"Internal error: No frame available for {self}.")
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
memory buffer. The frame may be stale,
|
||||
meaning it could have been captured a while ago (hanging camera scenario e.g.).
|
||||
|
||||
Returns:
|
||||
tuple[NDArray, float]:
|
||||
- The frame image (numpy array).
|
||||
- The timestamp (time.perf_counter) when this frame was captured.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the latest frame is older than `max_age_ms`.
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.latest_frame is None or self.latest_timestamp is None:
|
||||
raise RuntimeError(f"{self} has not captured any frames yet.")
|
||||
|
||||
age_ms = (time.perf_counter() - self.latest_timestamp) * 1e3
|
||||
if age_ms > max_age_ms:
|
||||
raise TimeoutError(
|
||||
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
|
||||
)
|
||||
|
||||
return self.latest_frame
|
||||
return frame
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
|
||||
@@ -72,14 +72,15 @@ class RealSenseCamera(Camera):
|
||||
camera = RealSenseCamera(config)
|
||||
camera.connect()
|
||||
|
||||
# Read 1 frame synchronously (blocking)
|
||||
# Read 1 frame synchronously
|
||||
color_image = camera.read()
|
||||
print(color_image.shape)
|
||||
|
||||
# Read 1 frame asynchronously (waits for new frame with a timeout)
|
||||
# Read 1 frame asynchronously
|
||||
async_image = camera.async_read()
|
||||
|
||||
# Get the latest frame immediately (no wait, returns timestamp)
|
||||
latest_image, timestamp = camera.read_latest()
|
||||
# When done, properly disconnect the camera using
|
||||
camera.disconnect()
|
||||
|
||||
# Example with depth capture and custom settings
|
||||
custom_config = RealSenseCameraConfig(
|
||||
@@ -132,9 +133,7 @@ class RealSenseCamera(Camera):
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_color_frame: NDArray[Any] | None = None
|
||||
self.latest_depth_frame: NDArray[Any] | None = None
|
||||
self.latest_timestamp: float | None = None
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
self.rotation: int | None = get_cv2_rotation(config.rotation)
|
||||
@@ -159,10 +158,6 @@ class RealSenseCamera(Camera):
|
||||
Initializes the RealSense pipeline, configures the required streams (color
|
||||
and optionally depth), starts the pipeline, and validates the actual stream settings.
|
||||
|
||||
Args:
|
||||
warmup (bool): If True, waits at connect() time until at least one valid frame
|
||||
has been captured by the background thread. Defaults to True.
|
||||
|
||||
Raises:
|
||||
DeviceAlreadyConnectedError: If the camera is already connected.
|
||||
ValueError: If the configuration is invalid (e.g., missing serial/name, name not unique).
|
||||
@@ -186,18 +181,15 @@ class RealSenseCamera(Camera):
|
||||
) from e
|
||||
|
||||
self._configure_capture_settings()
|
||||
self._start_read_thread()
|
||||
|
||||
# NOTE(Steven/Caroline): Enforcing at least one second of warmup as RS cameras need a bit of time before the first read. If we don't wait, the first read from the warmup will raise.
|
||||
self.warmup_s = max(self.warmup_s, 1)
|
||||
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < self.warmup_s:
|
||||
self.async_read(timeout_ms=self.warmup_s * 1000)
|
||||
time.sleep(0.1)
|
||||
with self.frame_lock:
|
||||
if self.latest_color_frame is None or self.use_depth and self.latest_depth_frame is None:
|
||||
raise ConnectionError(f"{self} failed to capture frames during warmup.")
|
||||
if warmup:
|
||||
time.sleep(
|
||||
1
|
||||
) # NOTE(Steven): RS cameras need a bit of time to warm up before the first read. If we don't wait, the first read from the warmup will raise.
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < self.warmup_s:
|
||||
self.read()
|
||||
time.sleep(0.1)
|
||||
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@@ -327,6 +319,9 @@ class RealSenseCamera(Camera):
|
||||
This is a blocking call. It waits for a coherent set of frames (depth)
|
||||
from the camera hardware via the RealSense pipeline.
|
||||
|
||||
Args:
|
||||
timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The depth map as a NumPy array (height, width)
|
||||
of type `np.uint16` (raw depth values in millimeters) and rotation.
|
||||
@@ -335,52 +330,44 @@ class RealSenseCamera(Camera):
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If reading frames from the pipeline fails or frames are invalid.
|
||||
"""
|
||||
if timeout_ms:
|
||||
logger.warning(
|
||||
f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
if not self.use_depth:
|
||||
raise RuntimeError(
|
||||
f"Failed to capture depth frame '.read_depth()'. Depth stream is not enabled for {self}."
|
||||
)
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
self.new_frame_event.clear()
|
||||
|
||||
_ = self.async_read(timeout_ms=10000)
|
||||
|
||||
with self.frame_lock:
|
||||
depth_map = self.latest_depth_frame
|
||||
|
||||
if depth_map is None:
|
||||
raise RuntimeError("No depth frame available. Ensure camera is streaming.")
|
||||
|
||||
return depth_map
|
||||
|
||||
def _read_from_hardware(self):
|
||||
if self.rs_pipeline is None:
|
||||
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
|
||||
|
||||
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=10000)
|
||||
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
|
||||
|
||||
if not ret or frame is None:
|
||||
raise RuntimeError(f"{self} read failed (status={ret}).")
|
||||
raise RuntimeError(f"{self} read_depth failed (status={ret}).")
|
||||
|
||||
return frame
|
||||
depth_frame = frame.get_depth_frame()
|
||||
depth_map = np.asanyarray(depth_frame.get_data())
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 0) -> NDArray[Any]:
|
||||
depth_map_processed = self._postprocess_image(depth_map, depth_frame=True)
|
||||
|
||||
read_duration_ms = (time.perf_counter() - start_time) * 1e3
|
||||
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
|
||||
|
||||
return depth_map_processed
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame (color) synchronously from the camera.
|
||||
|
||||
This is a blocking call. It waits for a coherent set of frames (color)
|
||||
from the camera hardware via the RealSense pipeline.
|
||||
|
||||
Args:
|
||||
timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The captured color frame as a NumPy array
|
||||
(height, width, channels), processed according to `color_mode` and rotation.
|
||||
@@ -391,39 +378,39 @@ class RealSenseCamera(Camera):
|
||||
ValueError: If an invalid `color_mode` is requested.
|
||||
"""
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if color_mode is not None:
|
||||
logger.warning(
|
||||
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
if timeout_ms:
|
||||
logger.warning(
|
||||
f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
start_time = time.perf_counter()
|
||||
|
||||
self.new_frame_event.clear()
|
||||
if self.rs_pipeline is None:
|
||||
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
|
||||
|
||||
frame = self.async_read(timeout_ms=10000)
|
||||
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
|
||||
|
||||
if not ret or frame is None:
|
||||
raise RuntimeError(f"{self} read failed (status={ret}).")
|
||||
|
||||
color_frame = frame.get_color_frame()
|
||||
color_image_raw = np.asanyarray(color_frame.get_data())
|
||||
|
||||
color_image_processed = self._postprocess_image(color_image_raw, color_mode)
|
||||
|
||||
read_duration_ms = (time.perf_counter() - start_time) * 1e3
|
||||
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
|
||||
|
||||
return frame
|
||||
return color_image_processed
|
||||
|
||||
def _postprocess_image(self, image: NDArray[Any], depth_frame: bool = False) -> NDArray[Any]:
|
||||
def _postprocess_image(
|
||||
self, image: NDArray[Any], color_mode: ColorMode | None = None, depth_frame: bool = False
|
||||
) -> NDArray[Any]:
|
||||
"""
|
||||
Applies color conversion, dimension validation, and rotation to a raw color frame.
|
||||
|
||||
Args:
|
||||
image (np.ndarray): The raw image frame (expected RGB format from RealSense).
|
||||
color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None,
|
||||
uses the instance's default `self.color_mode`.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The processed image frame according to `self.color_mode` and `self.rotation`.
|
||||
@@ -434,9 +421,9 @@ class RealSenseCamera(Camera):
|
||||
`width` and `height`.
|
||||
"""
|
||||
|
||||
if self.color_mode and self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
if color_mode and color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"Invalid requested color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
|
||||
f"Invalid requested color mode '{color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
|
||||
)
|
||||
|
||||
if depth_frame:
|
||||
@@ -467,7 +454,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
On each iteration:
|
||||
1. Reads a color frame with 500ms timeout
|
||||
2. Stores result in latest_frame and updates timestamp (thread-safe)
|
||||
2. Stores result in latest_frame (thread-safe)
|
||||
3. Sets new_frame_event to notify listeners
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
@@ -475,41 +462,25 @@ class RealSenseCamera(Camera):
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
failure_count = 0
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
frame = self._read_from_hardware()
|
||||
color_frame_raw = frame.get_color_frame()
|
||||
color_frame = np.asanyarray(color_frame_raw.get_data())
|
||||
processed_color_frame = self._postprocess_image(color_frame)
|
||||
|
||||
if self.use_depth:
|
||||
depth_frame_raw = frame.get_depth_frame()
|
||||
depth_frame = np.asanyarray(depth_frame_raw.get_data())
|
||||
processed_depth_frame = self._postprocess_image(depth_frame, depth_frame=True)
|
||||
|
||||
capture_time = time.perf_counter()
|
||||
color_image = self.read(timeout_ms=500)
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_color_frame = processed_color_frame
|
||||
if self.use_depth:
|
||||
self.latest_depth_frame = processed_depth_frame
|
||||
self.latest_timestamp = capture_time
|
||||
self.latest_frame = color_image
|
||||
self.new_frame_event.set()
|
||||
failure_count = 0
|
||||
|
||||
except DeviceNotConnectedError:
|
||||
break
|
||||
except Exception as e:
|
||||
if failure_count <= 10:
|
||||
failure_count += 1
|
||||
logger.warning(f"Error reading frame in background thread for {self}: {e}")
|
||||
else:
|
||||
raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e
|
||||
logger.warning(f"Error reading frame in background thread for {self}: {e}")
|
||||
|
||||
def _start_read_thread(self) -> None:
|
||||
"""Starts or restarts the background read thread if it's not running."""
|
||||
self._stop_read_thread()
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=0.1)
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
|
||||
self.stop_event = Event()
|
||||
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
|
||||
@@ -527,12 +498,6 @@ class RealSenseCamera(Camera):
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_color_frame = None
|
||||
self.latest_depth_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
@@ -541,7 +506,6 @@ class RealSenseCamera(Camera):
|
||||
This method retrieves the most recent color frame captured by the background
|
||||
read thread. It does not block waiting for the camera hardware directly,
|
||||
but may wait up to timeout_ms for the background thread to provide a frame.
|
||||
It is “best effort” under high FPS.
|
||||
|
||||
Args:
|
||||
timeout_ms (float): Maximum time in milliseconds to wait for a frame
|
||||
@@ -560,16 +524,17 @@ class RealSenseCamera(Camera):
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
self._start_read_thread()
|
||||
|
||||
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
|
||||
thread_alive = self.thread is not None and self.thread.is_alive()
|
||||
raise TimeoutError(
|
||||
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
|
||||
f"Read thread alive: {self.thread.is_alive()}."
|
||||
f"Read thread alive: {thread_alive}."
|
||||
)
|
||||
|
||||
with self.frame_lock:
|
||||
frame = self.latest_color_frame
|
||||
frame = self.latest_frame
|
||||
self.new_frame_event.clear()
|
||||
|
||||
if frame is None:
|
||||
@@ -577,43 +542,6 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent (color) frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
memory buffer. The frame may be stale,
|
||||
meaning it could have been captured a while ago (hanging camera scenario e.g.).
|
||||
|
||||
Returns:
|
||||
NDArray[Any]: The frame image (numpy array).
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the latest frame is older than `max_age_ms`.
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
with self.frame_lock:
|
||||
frame = self.latest_color_frame
|
||||
timestamp = self.latest_timestamp
|
||||
|
||||
if frame is None or timestamp is None:
|
||||
raise RuntimeError(f"{self} has not captured any frames yet.")
|
||||
|
||||
age_ms = (time.perf_counter() - timestamp) * 1e3
|
||||
if age_ms > max_age_ms:
|
||||
raise TimeoutError(
|
||||
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
|
||||
)
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnects from the camera, stops the pipeline, and cleans up resources.
|
||||
@@ -637,10 +565,4 @@ class RealSenseCamera(Camera):
|
||||
self.rs_pipeline = None
|
||||
self.rs_profile = None
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_color_frame = None
|
||||
self.latest_depth_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
@@ -45,12 +45,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ZMQCamera(Camera):
|
||||
"""
|
||||
Manages camera interactions via ZeroMQ for receiving frames from a remote server.
|
||||
|
||||
This class connects to a ZMQ Publisher, subscribes to frame topics, and decodes
|
||||
incoming JSON messages containing Base64 encoded images. It supports both
|
||||
synchronous and asynchronous frame reading patterns.
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
from lerobot.cameras.zmq import ZMQCamera, ZMQCameraConfig
|
||||
@@ -58,16 +52,7 @@ class ZMQCamera(Camera):
|
||||
config = ZMQCameraConfig(server_address="192.168.123.164", port=5555, camera_name="head_camera")
|
||||
camera = ZMQCamera(config)
|
||||
camera.connect()
|
||||
|
||||
# Read 1 frame synchronously (blocking)
|
||||
color_image = camera.read()
|
||||
|
||||
# Read 1 frame asynchronously (waits for new frame with a timeout)
|
||||
async_image = camera.async_read()
|
||||
|
||||
# Get the latest frame immediately (no wait, returns timestamp)
|
||||
latest_image, timestamp = camera.read_latest()
|
||||
|
||||
frame = camera.read()
|
||||
camera.disconnect()
|
||||
```
|
||||
"""
|
||||
@@ -83,17 +68,14 @@ class ZMQCamera(Camera):
|
||||
self.color_mode = config.color_mode
|
||||
self.timeout_ms = config.timeout_ms
|
||||
|
||||
# ZMQ Context and Socket
|
||||
self.context: zmq.Context | None = None
|
||||
self.socket: zmq.Socket | None = None
|
||||
self._connected = False
|
||||
|
||||
# Threading resources
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.latest_timestamp: float | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
def __str__(self) -> str:
|
||||
@@ -101,16 +83,10 @@ class ZMQCamera(Camera):
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Checks if the ZMQ socket is initialized and connected."""
|
||||
return self._connected and self.context is not None and self.socket is not None
|
||||
|
||||
def connect(self, warmup: bool = True) -> None:
|
||||
"""Connect to ZMQ camera server.
|
||||
|
||||
Args:
|
||||
warmup (bool): If True, waits for the camera to provide at least one
|
||||
valid frame before returning. Defaults to True.
|
||||
"""
|
||||
"""Connect to ZMQ camera server."""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
|
||||
|
||||
@@ -127,28 +103,17 @@ class ZMQCamera(Camera):
|
||||
self.socket.connect(f"tcp://{self.server_address}:{self.port}")
|
||||
self._connected = True
|
||||
|
||||
# Auto-detect resolution if not provided
|
||||
# Auto-detect resolution
|
||||
if self.width is None or self.height is None:
|
||||
# Read directly from hardware because the thread isn't running yet
|
||||
temp_frame = self._read_from_hardware()
|
||||
h, w = temp_frame.shape[:2]
|
||||
h, w = self.read().shape[:2]
|
||||
self.height = h
|
||||
self.width = w
|
||||
logger.info(f"{self} resolution detected: {w}x{h}")
|
||||
logger.info(f"{self} resolution: {w}x{h}")
|
||||
|
||||
self._start_read_thread()
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
if warmup:
|
||||
# Ensure we have captured at least one frame via the thread
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < (self.config.warmup_s): # Wait a bit more than timeout
|
||||
self.async_read(timeout_ms=self.config.warmup_s * 1000)
|
||||
time.sleep(0.1)
|
||||
|
||||
with self.frame_lock:
|
||||
if self.latest_frame is None:
|
||||
raise ConnectionError(f"{self} failed to capture frames during warmup.")
|
||||
time.sleep(0.1)
|
||||
|
||||
except Exception as e:
|
||||
self._cleanup()
|
||||
@@ -166,14 +131,15 @@ class ZMQCamera(Camera):
|
||||
|
||||
@staticmethod
|
||||
def find_cameras() -> list[dict[str, Any]]:
|
||||
"""
|
||||
Detection not implemented for ZMQ cameras. These cameras require manual configuration (server address/port).
|
||||
"""
|
||||
raise NotImplementedError("Camera detection is not implemented for ZMQ cameras.")
|
||||
"""ZMQ cameras require manual configuration (server address/port)."""
|
||||
return []
|
||||
|
||||
def _read_from_hardware(self) -> NDArray[Any]:
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame directly from the ZMQ socket.
|
||||
Read a single frame from the ZMQ camera.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Decoded frame (height, width, 3)
|
||||
"""
|
||||
if not self.is_connected or self.socket is None:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
@@ -181,7 +147,6 @@ class ZMQCamera(Camera):
|
||||
try:
|
||||
message = self.socket.recv_string()
|
||||
except Exception as e:
|
||||
# Check for ZMQ timeout (EAGAIN/Again) without requiring global zmq import
|
||||
if type(e).__name__ == "Again":
|
||||
raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e
|
||||
raise
|
||||
@@ -211,117 +176,42 @@ class ZMQCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
|
||||
This is a blocking call. It waits for the next available frame from the
|
||||
camera background thread.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Decoded frame (height, width, 3)
|
||||
"""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if color_mode is not None:
|
||||
logger.warning(
|
||||
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
self.new_frame_event.clear()
|
||||
frame = self.async_read(timeout_ms=10000)
|
||||
|
||||
read_duration_ms = (time.perf_counter() - start_time) * 1e3
|
||||
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
|
||||
|
||||
return frame
|
||||
|
||||
def _read_loop(self) -> None:
|
||||
"""
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
"""
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized.")
|
||||
|
||||
failure_count = 0
|
||||
while not self.stop_event.is_set():
|
||||
while self.stop_event and not self.stop_event.is_set():
|
||||
try:
|
||||
frame = self._read_from_hardware()
|
||||
capture_time = time.perf_counter()
|
||||
|
||||
frame = self.read()
|
||||
with self.frame_lock:
|
||||
self.latest_frame = frame
|
||||
self.latest_timestamp = capture_time
|
||||
self.new_frame_event.set()
|
||||
failure_count = 0
|
||||
|
||||
except DeviceNotConnectedError:
|
||||
break
|
||||
except (TimeoutError, Exception) as e:
|
||||
if failure_count <= 10:
|
||||
failure_count += 1
|
||||
logger.warning(f"Read error: {e}")
|
||||
else:
|
||||
raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e
|
||||
except TimeoutError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"Read error: {e}")
|
||||
|
||||
def _start_read_thread(self) -> None:
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
if self.thread and self.thread.is_alive():
|
||||
return
|
||||
self.stop_event = Event()
|
||||
self.thread = Thread(target=self._read_loop, daemon=True, name=f"{self}_read_loop")
|
||||
self.thread = Thread(target=self._read_loop, daemon=True)
|
||||
self.thread.start()
|
||||
time.sleep(0.1)
|
||||
|
||||
def _stop_read_thread(self) -> None:
|
||||
if self.stop_event is not None:
|
||||
if self.stop_event:
|
||||
self.stop_event.set()
|
||||
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
if self.thread and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame asynchronously.
|
||||
|
||||
Args:
|
||||
timeout_ms (float): Maximum time in milliseconds to wait for a frame
|
||||
to become available. Defaults to 200ms.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The latest captured frame.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
TimeoutError: If no frame data becomes available within the specified timeout.
|
||||
RuntimeError: If the background thread is not running.
|
||||
"""
|
||||
def async_read(self, timeout_ms: float = 10000) -> NDArray[Any]:
|
||||
"""Read latest frame asynchronously (non-blocking)."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
if not self.thread or not self.thread.is_alive():
|
||||
self._start_read_thread()
|
||||
|
||||
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
|
||||
raise TimeoutError(f"{self} async_read timeout after {timeout_ms}ms")
|
||||
@@ -335,55 +225,11 @@ class ZMQCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
memory buffer. The frame may be stale,
|
||||
meaning it could have been captured a while ago (hanging camera scenario e.g.).
|
||||
|
||||
Returns:
|
||||
NDArray[Any]: The frame image (numpy array).
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the latest frame is older than `max_age_ms`.
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
with self.frame_lock:
|
||||
frame = self.latest_frame
|
||||
timestamp = self.latest_timestamp
|
||||
|
||||
if frame is None or timestamp is None:
|
||||
raise RuntimeError(f"{self} has not captured any frames yet.")
|
||||
|
||||
age_ms = (time.perf_counter() - timestamp) * 1e3
|
||||
if age_ms > max_age_ms:
|
||||
raise TimeoutError(
|
||||
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
|
||||
)
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from ZMQ camera."""
|
||||
if not self.is_connected and self.thread is None:
|
||||
if not self.is_connected and not self.thread:
|
||||
raise DeviceNotConnectedError(f"{self} not connected.")
|
||||
|
||||
if self.thread is not None:
|
||||
self._stop_read_thread()
|
||||
|
||||
self._stop_read_thread()
|
||||
self._cleanup()
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
@@ -29,7 +29,6 @@ class ZMQCameraConfig(CameraConfig):
|
||||
camera_name: str = "zmq_camera"
|
||||
color_mode: ColorMode = ColorMode.RGB
|
||||
timeout_ms: int = 5000
|
||||
warmup_s: int = 1
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
|
||||
@@ -45,12 +45,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
Args:
|
||||
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
||||
current step and additional steps going back).
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
||||
input_normalization_modes: A dictionary with key representing the modality and the value specifies the
|
||||
normalization mode to apply.
|
||||
output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to
|
||||
the original scale.
|
||||
"""
|
||||
|
||||
n_obs_steps: int = 1
|
||||
|
||||
@@ -50,3 +50,8 @@ class RTCAttentionSchedule(str, Enum):
|
||||
ONES = "ONES"
|
||||
LINEAR = "LINEAR"
|
||||
EXP = "EXP"
|
||||
|
||||
|
||||
class RTCTrainingDelayDistribution(str, Enum):
|
||||
UNIFORM = "UNIFORM"
|
||||
EXP = "EXP"
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Example script to run synthetic data generation with Qwen VLM
|
||||
# This generates user prompts and robot utterances for hierarchical policy training
|
||||
|
||||
# Configuration
|
||||
REPO_ID="lerobot/libero_10"
|
||||
MODEL="Qwen/Qwen3-VL-30B-A3B-Instruct"
|
||||
# or: MODEL="Qwen/Qwen2-VL-7B-Instruct"
|
||||
|
||||
|
||||
OUTPUT_DIR="/fsx/jade_choghari/outputs/libero-10-annotate-high"
|
||||
|
||||
BATCH_SIZE=16
|
||||
TEMPERATURE=0.9
|
||||
SAMPLE_INTERVAL=5.0 # generate dialogue every 1 second (all episodes processed)
|
||||
|
||||
# Run subtask annotation
|
||||
# python /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05_full/annotate/subtask_annotate.py \
|
||||
# --repo-id "$REPO_ID" \
|
||||
# --video-key observation.images.image \
|
||||
# --output-dir "$OUTPUT_DIR" \
|
||||
# --skip-existing \
|
||||
# --output-repo-id "jadechoghari/libero10-annotate" \
|
||||
# --batch-size "$BATCH_SIZE" \
|
||||
# run synthetic data generation (all episodes processed)
|
||||
# python examples/dataset/annotate_pgen.py \
|
||||
# --repo-id "$REPO_ID" \
|
||||
# --model "$MODEL" \
|
||||
# --output-dir "$OUTPUT_DIR" \
|
||||
# --temperature "$TEMPERATURE" \
|
||||
# --batch-size "$BATCH_SIZE" \
|
||||
# --sample-interval "$SAMPLE_INTERVAL" \
|
||||
# --image-key observation.images.base \
|
||||
# --num-image-views-per-sample 1
|
||||
|
||||
# for faster testing, increase sample interval:
|
||||
# --sample-interval 5.0 # Samples every 5 seconds (much faster)
|
||||
|
||||
# to push to hub after generation:
|
||||
# add --push-to-hub flag
|
||||
|
||||
# efficient batch processing: 4 episodes at once
|
||||
python src/lerobot/data_processing/annotations/high_level_annotate.py \
|
||||
--data-dir "/fsx/jade_choghari/outputs/libero-10-annotate" \
|
||||
--output-dir "$OUTPUT_DIR" \
|
||||
--video-mode \
|
||||
--video-key observation.images.image \
|
||||
--video-batch-size "$BATCH_SIZE" \
|
||||
--sample-interval 5.0
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,52 +0,0 @@
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
import lerobot
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
# /fsx/jade_choghari/data/libero_10_subtasks_kw_converted
|
||||
dataset = LeRobotDataset(repo_id="lerobot/libero_10_image_subtask")
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
batch_size=2,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained(
|
||||
pretrained_name_or_path="/fsx/jade_choghari/models/pi05-base",
|
||||
)
|
||||
cfg.dtype = "bfloat16"
|
||||
|
||||
pre_processor, post_processor = make_pre_post_processors(
|
||||
policy_cfg=cfg,
|
||||
pretrained_path="/fsx/jade_choghari/models/pi05-base",
|
||||
)
|
||||
batch = next(iter(dataloader))
|
||||
breakpoint()
|
||||
batch1 = pre_processor(batch)
|
||||
breakpoint()
|
||||
print(batch.keys())
|
||||
# print(batch['task_index_high_level'].shape)
|
||||
# print(batch['task_index_high_level'])
|
||||
# print(batch['user_prompt'][0])
|
||||
# print(batch['robot_utterance'][0])
|
||||
# print(batch['task'][0])
|
||||
|
||||
valid_episode_list = []
|
||||
for episode_idx in range(len(dataset.meta.episodes)):
|
||||
subtask_index = dataset[episode_idx]["subtask_index"]
|
||||
valid_episode_list.append(episode_idx)
|
||||
|
||||
print(len(valid_episode_list))
|
||||
|
||||
# read this parquet /fsx/jade_choghari/outputs/pgen_annotations1/meta/tasks.parquett
|
||||
# import pandas as pd
|
||||
# tasks_df = pd.read_parquet('/fsx/jade_choghari/outputs/pgen_annotations1/meta/tasks.parquet')
|
||||
|
||||
# # print all
|
||||
# print(tasks_df.columns)
|
||||
# breakpoint()
|
||||
@@ -1,74 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Example script to run synthetic data generation with Qwen VLM
|
||||
# This generates user prompts and robot utterances for hierarchical policy training
|
||||
|
||||
# Configuration
|
||||
REPO_ID="jadechoghari/piper-demo-20260205_103303"
|
||||
# MODEL="Qwen/Qwen3-VL-30B-A3B-Thinking"
|
||||
MODEL="Qwen/Qwen3.5-27B"
|
||||
# or: MODEL="Qwen/Qwen2-VL-7B-Instruct"
|
||||
|
||||
|
||||
OUTPUT_DIR="/fsx/jade_choghari/outputs/collect-data-pgen_new"
|
||||
|
||||
BATCH_SIZE=2
|
||||
TEMPERATURE=0.9
|
||||
SAMPLE_INTERVAL=5.0 # generate dialogue every 1 second (all episodes processed)
|
||||
|
||||
# Run subtask annotation.
|
||||
# To use closed-vocabulary labels, add a line: --subtask-labels "label1" "label2" ...
|
||||
# Example (add backslash after "$MODEL" and uncomment the next line):
|
||||
# --model "$MODEL" \
|
||||
# --subtask-labels "pick_up_yellow_nut_bar" "pick_up_cake" "pick_up_biscuit_pack" "pick_up_soda_can"
|
||||
python /home/lerobot/src/lerobot/data_processing/annotations/subtask_annotate.py \
|
||||
--repo-id "$REPO_ID" \
|
||||
--video-key observation.images.top \
|
||||
--output-dir "$OUTPUT_DIR" \
|
||||
--output-repo-id "jadechoghari/piper-demo-annotated1" \
|
||||
--push-to-hub \
|
||||
--no-timer-overlay \
|
||||
--model "$MODEL" \
|
||||
--subtask-labels "pick_up_yellow_nut_bar" "pick_up_cake" "pick_up_biscuit_pack" "pick_up_soda_can" \
|
||||
--batch-size 2
|
||||
|
||||
# Run subtask annotation (image-window: frames as images for better accuracy)
|
||||
# python /admin/home/jade_choghari/lerobot/src/lerobot/data_processing/annotations/subtask_annotate_image.py \
|
||||
# --repo-id "$REPO_ID" \
|
||||
# --camera-key observation.images.wrist \
|
||||
# --output-dir "$OUTPUT_DIR" \
|
||||
# --output-repo-id "jadechoghari/piper-demo-annotated1-image" \
|
||||
# --push-to-hub \
|
||||
# --model "$MODEL" \
|
||||
# --window-size 184 \
|
||||
# --max-frames-per-window 16 \
|
||||
# --subtask-labels "pick_up_yellow_nut_bar" "pick_up_cake" "pick_up_biscuit_pack" "pick_up_soda_can" \
|
||||
# --batch-size 2
|
||||
|
||||
|
||||
# run synthetic data generation (all episodes processed)
|
||||
# python examples/dataset/annotate_pgen.py \
|
||||
# --repo-id "$REPO_ID" \
|
||||
# --model "$MODEL" \
|
||||
# --output-dir "$OUTPUT_DIR" \
|
||||
# --temperature "$TEMPERATURE" \
|
||||
# --batch-size "$BATCH_SIZE" \
|
||||
# --sample-interval "$SAMPLE_INTERVAL" \
|
||||
# --image-key observation.images.base \
|
||||
# --num-image-views-per-sample 1
|
||||
|
||||
# for faster testing, increase sample interval:
|
||||
# --sample-interval 5.0 # Samples every 5 seconds (much faster)
|
||||
|
||||
# to push to hub after generation:
|
||||
# add --push-to-hub flag
|
||||
|
||||
# efficient batch processing: 4 episodes at once
|
||||
# python examples/dataset/annotate_pgen.py \
|
||||
# --repo-id "$REPO_ID" \
|
||||
# --model "$MODEL" \
|
||||
# --output-dir "$OUTPUT_DIR" \
|
||||
# --video-mode \
|
||||
# --video-key observation.images.up \
|
||||
# --video-batch-size "$BATCH_SIZE" \
|
||||
# --sample-interval 1.0
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,561 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Image-window subtask annotation for LeRobot datasets using Qwen VLMs.
|
||||
|
||||
This script assigns a subtask to each window of consecutive frames by sending
|
||||
those frames as images to the VLM (instead of a video) for better accuracy.
|
||||
Supports Qwen2-VL and Qwen3-VL (same models as subtask_annotate.py).
|
||||
|
||||
Pipeline:
|
||||
1. Load a LeRobot dataset (local or Hub).
|
||||
2. For each episode, slide a window over frame indices.
|
||||
3. For each window, load the corresponding images (from image_key or decoded video_key).
|
||||
4. Send the window of images to Qwen2-VL with the same skill prompt; get one subtask name.
|
||||
5. Assign that subtask to all frames in the window.
|
||||
6. Write subtasks.parquet and add subtask_index via add_features (same as subtask_annotate).
|
||||
|
||||
Usage:
|
||||
python -m lerobot.data_processing.annotations.subtask_annotate_image \\
|
||||
--data-dir /path/to/dataset --camera-key observation.images.base \\
|
||||
--window-size 8 --stride 8 --output-dir ./output
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import random
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from rich.console import Console
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Reuse data structures and save/load from the video-based annotator
|
||||
from lerobot.data_processing.annotations.subtask_annotate import (
|
||||
EpisodeSkills,
|
||||
Skill,
|
||||
load_skill_annotations,
|
||||
save_skill_annotations,
|
||||
)
|
||||
|
||||
|
||||
def create_window_skill_prompt(
|
||||
coarse_goal: str | None = None,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> str:
|
||||
"""Prompt for labeling a single window of frames with one atomic skill.
|
||||
If subtask_labels are provided, the model must choose exactly one from that list.
|
||||
"""
|
||||
goal_context = f'The overall goal is: "{coarse_goal}".\n\n' if coarse_goal else ""
|
||||
if subtask_labels:
|
||||
labels_list = ", ".join(f'"{l}"' for l in subtask_labels)
|
||||
label_instruction = (
|
||||
f"You must choose exactly ONE skill from this list: [{labels_list}]. "
|
||||
"Do not create new labels. Reply with only that label.\n\n"
|
||||
)
|
||||
else:
|
||||
label_instruction = ""
|
||||
return textwrap.dedent(f"""\
|
||||
# Role
|
||||
You are a Robotics Vision System that labels short clips from robot manipulation demonstrations.
|
||||
|
||||
# Task
|
||||
{goal_context}{label_instruction}The following images are consecutive frames from a single short clip of a robot demonstration.
|
||||
What single atomic manipulation skill is being performed in this clip?
|
||||
|
||||
# Requirements
|
||||
- Reply with ONLY one short skill name (e.g. "pick up object", "move arm left", "release gripper").
|
||||
- No explanation, no timestamps, no JSON. Just the skill name.
|
||||
""").strip()
|
||||
|
||||
|
||||
def _run_image_segmenter(
|
||||
self,
|
||||
images: list[PIL.Image.Image],
|
||||
coarse_goal: str | None,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> str:
|
||||
"""Shared inference for Qwen2-VL and Qwen3-VL image window labeling."""
|
||||
prompt = create_window_skill_prompt(coarse_goal, subtask_labels)
|
||||
content = []
|
||||
for img in images:
|
||||
content.append({"type": "image", "image": img})
|
||||
content.append({"type": "text", "text": "What single atomic skill is shown in these frames? Reply with only the skill name."})
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": [{"type": "text", "text": prompt}]},
|
||||
{"role": "user", "content": content},
|
||||
]
|
||||
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
image_inputs, video_inputs = self.process_vision_info(messages)
|
||||
inputs = self.processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
generated_ids = self.model.generate(**inputs, max_new_tokens=128, do_sample=False)
|
||||
|
||||
response = self.processor.batch_decode(
|
||||
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids)],
|
||||
skip_special_tokens=True,
|
||||
)[0].strip()
|
||||
skill_name = response.split("\n")[0].strip().strip('."')
|
||||
return skill_name if skill_name else "unknown"
|
||||
|
||||
|
||||
def _run_image_segmenter_batch(
|
||||
self,
|
||||
batch_images: list[list[PIL.Image.Image]],
|
||||
coarse_goal: str | None,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""Run VLM on multiple windows at once; returns one skill name per window."""
|
||||
if not batch_images:
|
||||
return []
|
||||
prompt = create_window_skill_prompt(coarse_goal, subtask_labels)
|
||||
all_texts = []
|
||||
all_image_inputs = []
|
||||
all_video_inputs = []
|
||||
for images in batch_images:
|
||||
content = []
|
||||
for img in images:
|
||||
content.append({"type": "image", "image": img})
|
||||
content.append({"type": "text", "text": "What single atomic skill is shown in these frames? Reply with only the skill name."})
|
||||
messages = [
|
||||
{"role": "system", "content": [{"type": "text", "text": prompt}]},
|
||||
{"role": "user", "content": content},
|
||||
]
|
||||
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
image_inputs, video_inputs = self.process_vision_info(messages)
|
||||
all_texts.append(text)
|
||||
if image_inputs is not None:
|
||||
all_image_inputs.extend(image_inputs if isinstance(image_inputs, list) else [image_inputs])
|
||||
if video_inputs is not None:
|
||||
all_video_inputs.extend(video_inputs if isinstance(video_inputs, list) else [video_inputs])
|
||||
inputs = self.processor(
|
||||
text=all_texts,
|
||||
images=all_image_inputs if all_image_inputs else None,
|
||||
videos=all_video_inputs if all_video_inputs else None,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
with torch.no_grad():
|
||||
generated_ids = self.model.generate(**inputs, max_new_tokens=128, do_sample=False)
|
||||
responses = self.processor.batch_decode(
|
||||
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids)],
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
return [
|
||||
(r.split("\n")[0].strip().strip('."') or "unknown")
|
||||
for r in responses
|
||||
]
|
||||
|
||||
|
||||
class Qwen2VLImageSegmenter:
|
||||
"""Uses Qwen2-VL to assign one skill name to a window of images (same model as subtask_annotate)."""
|
||||
|
||||
def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
|
||||
from qwen_vl_utils import process_vision_info
|
||||
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
|
||||
|
||||
self.console = Console()
|
||||
self.device = device
|
||||
self.process_vision_info = process_vision_info
|
||||
self.console.print(f"[cyan]Loading Qwen2-VL for image-window labeling: {model_name}...[/cyan]")
|
||||
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
|
||||
)
|
||||
self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||
self.console.print(f"[green]✓ Model loaded on {device}[/green]")
|
||||
|
||||
def segment_skill_from_images(
|
||||
self,
|
||||
images: list[PIL.Image.Image],
|
||||
coarse_goal: str | None = None,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> str:
|
||||
"""Return a single skill name for the given window of images."""
|
||||
return _run_image_segmenter(self, images, coarse_goal, subtask_labels)
|
||||
|
||||
def segment_skill_from_images_batch(
|
||||
self,
|
||||
batch_images: list[list[PIL.Image.Image]],
|
||||
coarse_goal: str | None = None,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""Return one skill name per window; processes multiple windows in one forward pass."""
|
||||
return _run_image_segmenter_batch(self, batch_images, coarse_goal, subtask_labels)
|
||||
|
||||
|
||||
class Qwen3VLImageSegmenter:
|
||||
"""Uses Qwen3-VL (MoE) to assign one skill name to a window of images."""
|
||||
|
||||
def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
|
||||
from qwen_vl_utils import process_vision_info
|
||||
from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration
|
||||
|
||||
self.console = Console()
|
||||
self.device = device
|
||||
self.process_vision_info = process_vision_info
|
||||
self.console.print(f"[cyan]Loading Qwen3-VL for image-window labeling: {model_name}...[/cyan]")
|
||||
self.model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
|
||||
model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
|
||||
)
|
||||
self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||
self.console.print(f"[green]✓ Model loaded on {device}[/green]")
|
||||
|
||||
def segment_skill_from_images(
|
||||
self,
|
||||
images: list[PIL.Image.Image],
|
||||
coarse_goal: str | None = None,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> str:
|
||||
"""Return a single skill name for the given window of images."""
|
||||
return _run_image_segmenter(self, images, coarse_goal, subtask_labels)
|
||||
|
||||
def segment_skill_from_images_batch(
|
||||
self,
|
||||
batch_images: list[list[PIL.Image.Image]],
|
||||
coarse_goal: str | None = None,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""Return one skill name per window; processes multiple windows in one forward pass."""
|
||||
return _run_image_segmenter_batch(self, batch_images, coarse_goal, subtask_labels)
|
||||
|
||||
|
||||
def get_image_segmenter(
|
||||
model_name: str,
|
||||
device: str = "cuda",
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
):
|
||||
"""Return the appropriate image-window segmenter for the model (Qwen2-VL or Qwen3-VL)."""
|
||||
model_lower = model_name.lower()
|
||||
if "qwen3" in model_lower:
|
||||
return Qwen3VLImageSegmenter(model_name, device, torch_dtype)
|
||||
return Qwen2VLImageSegmenter(model_name, device, torch_dtype)
|
||||
|
||||
|
||||
def frame_to_pil(frame_value) -> PIL.Image.Image:
|
||||
"""Convert a single frame from dataset (tensor or PIL or path) to PIL.Image."""
|
||||
if isinstance(frame_value, PIL.Image.Image):
|
||||
return frame_value
|
||||
if isinstance(frame_value, (str, Path)):
|
||||
return PIL.Image.open(frame_value).convert("RGB")
|
||||
if hasattr(frame_value, "numpy"):
|
||||
arr = frame_value.numpy()
|
||||
else:
|
||||
arr = np.asarray(frame_value)
|
||||
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4):
|
||||
arr = np.transpose(arr, (1, 2, 0))
|
||||
if arr.dtype == np.float32 or arr.dtype == np.float64:
|
||||
arr = (np.clip(arr, 0, 1) * 255).astype(np.uint8)
|
||||
elif arr.dtype != np.uint8:
|
||||
arr = np.clip(arr, 0, 255).astype(np.uint8)
|
||||
if arr.shape[-1] == 1:
|
||||
arr = np.repeat(arr, 3, axis=-1)
|
||||
return PIL.Image.fromarray(arr)
|
||||
|
||||
|
||||
def _sample_window_indices(window_length: int, max_frames: int) -> list[int]:
|
||||
"""Return indices into a window of length window_length, at most max_frames, in order.
|
||||
If window_length <= max_frames, returns range(window_length).
|
||||
Otherwise returns sorted random sample of max_frames indices (temporal order preserved).
|
||||
"""
|
||||
if max_frames <= 0 or window_length <= max_frames:
|
||||
return list(range(window_length))
|
||||
return sorted(random.sample(range(window_length), max_frames))
|
||||
|
||||
|
||||
class SkillAnnotatorImage:
|
||||
"""Annotates episodes by sliding a window over frames and labeling each window with the VLM."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
segmenter: Qwen2VLImageSegmenter | Qwen3VLImageSegmenter,
|
||||
window_size: int = 8,
|
||||
stride: int | None = None,
|
||||
batch_size: int = 1,
|
||||
max_frames_per_window: int | None = None,
|
||||
console: Console | None = None,
|
||||
):
|
||||
self.segmenter = segmenter
|
||||
self.window_size = window_size
|
||||
self.stride = stride if stride is not None else window_size
|
||||
self.batch_size = max(1, batch_size)
|
||||
self.max_frames_per_window = max_frames_per_window
|
||||
self.console = console or Console()
|
||||
|
||||
def annotate_dataset(
|
||||
self,
|
||||
dataset: LeRobotDataset,
|
||||
camera_key: str,
|
||||
episodes: list[int] | None = None,
|
||||
skip_existing: bool = False,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> dict[int, EpisodeSkills]:
|
||||
"""Annotate episodes using image windows. camera_key can be an image_key or video_key."""
|
||||
episode_indices = episodes or list(range(dataset.meta.total_episodes))
|
||||
coarse_goal = self._get_coarse_goal(dataset)
|
||||
annotations: dict[int, EpisodeSkills] = {}
|
||||
|
||||
if skip_existing:
|
||||
existing = load_skill_annotations(dataset.root)
|
||||
if existing and existing.get("episodes"):
|
||||
existing_eps = {int(k) for k in existing["episodes"] if existing["episodes"][k].get("skills")}
|
||||
episode_indices = [i for i in episode_indices if i not in existing_eps]
|
||||
|
||||
for ep_idx in episode_indices:
|
||||
try:
|
||||
skills = self._annotate_episode(
|
||||
dataset, ep_idx, camera_key, coarse_goal, subtask_labels
|
||||
)
|
||||
if skills:
|
||||
annotations[ep_idx] = EpisodeSkills(
|
||||
episode_index=ep_idx,
|
||||
description=coarse_goal,
|
||||
skills=skills,
|
||||
)
|
||||
self.console.print(f"[green]✓ Episode {ep_idx}: {len(skills)} window skills[/green]")
|
||||
else:
|
||||
self.console.print(f"[yellow]⚠ Episode {ep_idx}: no skills[/yellow]")
|
||||
except Exception as e:
|
||||
self.console.print(f"[red]Episode {ep_idx} failed: {e}[/red]")
|
||||
|
||||
return annotations
|
||||
|
||||
def _get_coarse_goal(self, dataset: LeRobotDataset) -> str:
|
||||
if dataset.meta.tasks is not None and len(dataset.meta.tasks) > 0:
|
||||
return str(dataset.meta.tasks.index[0])
|
||||
return "Perform the demonstrated manipulation task."
|
||||
|
||||
def _annotate_episode(
|
||||
self,
|
||||
dataset: LeRobotDataset,
|
||||
episode_index: int,
|
||||
camera_key: str,
|
||||
coarse_goal: str,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> list[Skill]:
|
||||
ep = dataset.meta.episodes[episode_index]
|
||||
ep_from = int(ep["dataset_from_index"])
|
||||
ep_to = int(ep["dataset_to_index"])
|
||||
length = ep_to - ep_from
|
||||
fps = dataset.meta.fps
|
||||
if length == 0:
|
||||
return []
|
||||
|
||||
# Collect full windows: (images, t_start, t_end) using frame timestamps.
|
||||
# If max_frames_per_window is set and window is larger, sample that many frames (order preserved).
|
||||
window_specs: list[tuple[list[PIL.Image.Image], float, float]] = []
|
||||
start = 0
|
||||
while start + self.window_size <= length:
|
||||
offsets = _sample_window_indices(
|
||||
self.window_size,
|
||||
self.max_frames_per_window or self.window_size,
|
||||
)
|
||||
frame_indices = [ep_from + start + i for i in offsets]
|
||||
images = []
|
||||
t_start = float(dataset[frame_indices[0]]["timestamp"].item())
|
||||
for idx in frame_indices:
|
||||
item = dataset[idx]
|
||||
images.append(frame_to_pil(item[camera_key]))
|
||||
t_end = t_start + self.window_size / fps
|
||||
window_specs.append((images, t_start, t_end))
|
||||
start += self.stride
|
||||
|
||||
# Last partial window
|
||||
if start < length:
|
||||
partial_len = ep_to - (ep_from + start)
|
||||
offsets = _sample_window_indices(
|
||||
partial_len,
|
||||
self.max_frames_per_window or partial_len,
|
||||
)
|
||||
frame_indices = [ep_from + start + i for i in offsets]
|
||||
images = []
|
||||
t_start = float(dataset[frame_indices[0]]["timestamp"].item())
|
||||
for idx in frame_indices:
|
||||
item = dataset[idx]
|
||||
images.append(frame_to_pil(item[camera_key]))
|
||||
t_end = float(dataset[frame_indices[-1]]["timestamp"].item()) + 1.0 / fps
|
||||
window_specs.append((images, t_start, t_end))
|
||||
|
||||
# Run in batches
|
||||
skills: list[Skill] = []
|
||||
for i in range(0, len(window_specs), self.batch_size):
|
||||
chunk = window_specs[i : i + self.batch_size]
|
||||
batch_images = [spec[0] for spec in chunk]
|
||||
if len(batch_images) > 1:
|
||||
skill_names = self.segmenter.segment_skill_from_images_batch(
|
||||
batch_images, coarse_goal, subtask_labels
|
||||
)
|
||||
else:
|
||||
skill_names = [
|
||||
self.segmenter.segment_skill_from_images(
|
||||
batch_images[0], coarse_goal, subtask_labels
|
||||
)
|
||||
]
|
||||
for (_, t_start, t_end), name in zip(chunk, skill_names, strict=True):
|
||||
skills.append(Skill(name=name, start=t_start, end=t_end))
|
||||
|
||||
return skills
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Image-window subtask annotation using Qwen VLM (frames as images for better accuracy)",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog=textwrap.dedent("""\
|
||||
Examples:
|
||||
python -m lerobot.data_processing.annotations.subtask_annotate_image \\
|
||||
--data-dir /path/to/dataset --camera-key observation.images.base \\
|
||||
--window-size 8 --output-dir ./output
|
||||
|
||||
python -m lerobot.data_processing.annotations.subtask_annotate_image \\
|
||||
--repo-id user/dataset --camera-key observation.images.base \\
|
||||
--window-size 6 --stride 3 --model Qwen/Qwen2-VL-7B-Instruct
|
||||
|
||||
# Use Qwen3-VL (MoE)
|
||||
python -m lerobot.data_processing.annotations.subtask_annotate_image \\
|
||||
--data-dir /path/to/dataset --camera-key observation.images.base \\
|
||||
--model Qwen/Qwen3-VL-30B-A3B-Instruct
|
||||
"""),
|
||||
)
|
||||
data_group = parser.add_mutually_exclusive_group(required=True)
|
||||
data_group.add_argument("--data-dir", type=str, help="Path to local LeRobot dataset")
|
||||
data_group.add_argument("--repo-id", type=str, help="HuggingFace Hub dataset repository ID")
|
||||
|
||||
parser.add_argument(
|
||||
"--camera-key",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Image or video observation key (e.g. observation.images.base)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="Qwen/Qwen2-VL-7B-Instruct",
|
||||
help="VLM model: Qwen2-VL or Qwen3-VL (default: Qwen/Qwen2-VL-7B-Instruct)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--window-size",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Number of frames per window (default: 8)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stride",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Stride for sliding window (default: window_size = non-overlapping)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of windows to process in one VLM call (default: 1; increase for speed)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-frames-per-window",
|
||||
type=int,
|
||||
default=None,
|
||||
metavar="N",
|
||||
help="If window has more than N frames, randomly sample N frames (order kept) to avoid OOM (e.g. 16)",
|
||||
)
|
||||
parser.add_argument("--episodes", type=int, nargs="+", help="Episode indices to annotate (default: all)")
|
||||
parser.add_argument("--skip-existing", action="store_true", help="Skip episodes that already have annotations")
|
||||
parser.add_argument(
|
||||
"--subtask-labels",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=None,
|
||||
help="Closed vocabulary: model must choose only from these labels",
|
||||
)
|
||||
parser.add_argument("--output-dir", type=str, help="Output directory for dataset with subtask_index")
|
||||
parser.add_argument("--output-repo-id", type=str, help="Output repo id (default: <repo_id>_with_subtasks)")
|
||||
parser.add_argument("--push-to-hub", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
console = Console()
|
||||
|
||||
# Load dataset
|
||||
console.print("[cyan]Loading dataset...[/cyan]")
|
||||
if args.data_dir:
|
||||
dataset = LeRobotDataset(repo_id="local/dataset", root=args.data_dir, download_videos=False)
|
||||
else:
|
||||
dataset = LeRobotDataset(repo_id=args.repo_id, download_videos=True)
|
||||
camera_keys = dataset.meta.camera_keys
|
||||
if args.camera_key not in camera_keys:
|
||||
console.print(f"[red]Error: camera key '{args.camera_key}' not in {camera_keys}[/red]")
|
||||
return
|
||||
console.print(f"[green]✓ Loaded dataset, {dataset.meta.total_episodes} episodes[/green]")
|
||||
|
||||
# Same Qwen VLM as subtask_annotate (Qwen2-VL or Qwen3-VL), image windows instead of video
|
||||
segmenter = get_image_segmenter(args.model, args.device, torch.bfloat16)
|
||||
|
||||
annotator = SkillAnnotatorImage(
|
||||
segmenter=segmenter,
|
||||
window_size=args.window_size,
|
||||
stride=args.stride,
|
||||
batch_size=args.batch_size,
|
||||
max_frames_per_window=args.max_frames_per_window,
|
||||
console=console,
|
||||
)
|
||||
annotations = annotator.annotate_dataset(
|
||||
dataset=dataset,
|
||||
camera_key=args.camera_key,
|
||||
episodes=args.episodes,
|
||||
skip_existing=args.skip_existing,
|
||||
subtask_labels=args.subtask_labels,
|
||||
)
|
||||
|
||||
if not annotations:
|
||||
console.print("[yellow]No annotations to save.[/yellow]")
|
||||
return
|
||||
|
||||
output_dir = Path(args.output_dir) if args.output_dir else None
|
||||
output_repo_id = args.output_repo_id
|
||||
new_dataset = save_skill_annotations(dataset, annotations, output_dir, output_repo_id)
|
||||
|
||||
total_skills = sum(len(a.skills) for a in annotations.values())
|
||||
console.print(f"[bold green]✓ Done.[/bold green] Episodes: {len(annotations)}, total window skills: {total_skills}")
|
||||
console.print(f" Dataset with subtask_index: {new_dataset.root}")
|
||||
|
||||
if args.push_to_hub and not args.data_dir:
|
||||
console.print("[cyan]Pushing to Hub...[/cyan]")
|
||||
try:
|
||||
new_dataset.push_to_hub(push_videos=False)
|
||||
console.print("[green]✓ Pushed.[/green]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]Push failed: {e}[/red]")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -116,9 +116,6 @@ def update_meta_data(
|
||||
Adjusts all indices and timestamps to account for previously aggregated
|
||||
data and videos in the destination dataset.
|
||||
|
||||
For data file indices, uses the 'src_to_dst' mapping from aggregate_data()
|
||||
to correctly map source file indices to their destination locations.
|
||||
|
||||
Args:
|
||||
df: DataFrame containing the metadata to be updated.
|
||||
dst_meta: Destination dataset metadata.
|
||||
@@ -132,50 +129,8 @@ def update_meta_data(
|
||||
|
||||
df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"]
|
||||
df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"]
|
||||
|
||||
# Update data file indices using source-to-destination mapping
|
||||
# This is critical for handling datasets that are already results of a merge
|
||||
data_src_to_dst = data_idx.get("src_to_dst", {})
|
||||
if data_src_to_dst:
|
||||
# Store original indices for lookup
|
||||
df["_orig_data_chunk"] = df["data/chunk_index"].copy()
|
||||
df["_orig_data_file"] = df["data/file_index"].copy()
|
||||
|
||||
# Vectorized mapping from (src_chunk, src_file) to (dst_chunk, dst_file)
|
||||
# This is much faster than per-row iteration for large metadata tables
|
||||
mapping_index = pd.MultiIndex.from_tuples(
|
||||
list(data_src_to_dst.keys()),
|
||||
names=["chunk_index", "file_index"],
|
||||
)
|
||||
mapping_values = list(data_src_to_dst.values())
|
||||
mapping_df = pd.DataFrame(
|
||||
mapping_values,
|
||||
index=mapping_index,
|
||||
columns=["dst_chunk", "dst_file"],
|
||||
)
|
||||
|
||||
# Construct a MultiIndex for each row based on original data indices
|
||||
row_index = pd.MultiIndex.from_arrays(
|
||||
[df["_orig_data_chunk"], df["_orig_data_file"]],
|
||||
names=["chunk_index", "file_index"],
|
||||
)
|
||||
|
||||
# Align mapping to rows; missing keys fall back to the default destination
|
||||
reindexed = mapping_df.reindex(row_index)
|
||||
reindexed[["dst_chunk", "dst_file"]] = reindexed[["dst_chunk", "dst_file"]].fillna(
|
||||
{"dst_chunk": data_idx["chunk"], "dst_file": data_idx["file"]}
|
||||
)
|
||||
|
||||
# Assign mapped destination indices back to the DataFrame
|
||||
df["data/chunk_index"] = reindexed["dst_chunk"].to_numpy()
|
||||
df["data/file_index"] = reindexed["dst_file"].to_numpy()
|
||||
|
||||
# Clean up temporary columns
|
||||
df = df.drop(columns=["_orig_data_chunk", "_orig_data_file"])
|
||||
else:
|
||||
# Fallback to simple offset (backward compatibility for single-file sources)
|
||||
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
|
||||
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
|
||||
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
|
||||
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
|
||||
for key, video_idx in videos_idx.items():
|
||||
# Store original video file indices before updating
|
||||
orig_chunk_col = f"videos/{key}/chunk_index"
|
||||
@@ -191,7 +146,8 @@ def update_meta_data(
|
||||
if src_to_dst:
|
||||
# Map each episode to its correct destination file and apply offset
|
||||
for idx in df.index:
|
||||
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
|
||||
# Convert to Python int to avoid numpy type mismatch in dict lookup
|
||||
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
|
||||
|
||||
# Get destination chunk/file for this source file
|
||||
dst_chunk, dst_file = src_to_dst.get(src_key, (video_idx["chunk"], video_idx["file"]))
|
||||
@@ -207,7 +163,8 @@ def update_meta_data(
|
||||
df[orig_chunk_col] = video_idx["chunk"]
|
||||
df[orig_file_col] = video_idx["file"]
|
||||
for idx in df.index:
|
||||
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
|
||||
# Convert to Python int to avoid numpy type mismatch in dict lookup
|
||||
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
|
||||
offset = src_to_offset.get(src_key, 0)
|
||||
df.at[idx, f"videos/{key}/from_timestamp"] += offset
|
||||
df.at[idx, f"videos/{key}/to_timestamp"] += offset
|
||||
@@ -305,10 +262,6 @@ def aggregate_datasets(
|
||||
|
||||
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
|
||||
|
||||
# Clear the src_to_dst mapping after processing each source dataset
|
||||
# to avoid interference between different source datasets
|
||||
data_idx.pop("src_to_dst", None)
|
||||
|
||||
dst_meta.info["total_episodes"] += src_meta.total_episodes
|
||||
dst_meta.info["total_frames"] += src_meta.total_frames
|
||||
|
||||
@@ -359,6 +312,10 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
dst_file_durations = video_idx["dst_file_durations"]
|
||||
|
||||
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
|
||||
# Convert to Python int to ensure consistent dict keys
|
||||
src_chunk_idx = int(src_chunk_idx)
|
||||
src_file_idx = int(src_file_idx)
|
||||
|
||||
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
|
||||
video_key=key,
|
||||
chunk_index=src_chunk_idx,
|
||||
@@ -431,16 +388,10 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
Reads source data files, updates indices to match the aggregated dataset,
|
||||
and writes them to the destination with proper file rotation.
|
||||
|
||||
Tracks a `src_to_dst` mapping from source (chunk, file) to destination (chunk, file)
|
||||
which is critical for correctly updating episode metadata when source datasets
|
||||
have multiple data files (e.g., from a previous merge operation).
|
||||
|
||||
Args:
|
||||
src_meta: Source dataset metadata.
|
||||
dst_meta: Destination dataset metadata.
|
||||
data_idx: Dictionary tracking data chunk and file indices.
|
||||
data_files_size_in_mb: Maximum size for data files in MB.
|
||||
chunk_size: Maximum number of files per chunk.
|
||||
|
||||
Returns:
|
||||
dict: Updated data_idx with current chunk and file indices.
|
||||
@@ -458,10 +409,6 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
# retrieve features schema for proper image typing in parquet
|
||||
hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None
|
||||
|
||||
# Track source to destination file mapping for metadata update
|
||||
# This is critical for handling datasets that are already results of a merge
|
||||
src_to_dst: dict[tuple[int, int], tuple[int, int]] = {}
|
||||
|
||||
for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
|
||||
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
|
||||
chunk_index=src_chunk_idx, file_index=src_file_idx
|
||||
@@ -474,9 +421,7 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
df = pd.read_parquet(src_path)
|
||||
df = update_data_df(df, src_meta, dst_meta)
|
||||
|
||||
# Write data and get the actual destination file it was written to
|
||||
# This avoids duplicating the rotation logic here
|
||||
data_idx, (dst_chunk, dst_file) = append_or_create_parquet_file(
|
||||
data_idx = append_or_create_parquet_file(
|
||||
df,
|
||||
src_path,
|
||||
data_idx,
|
||||
@@ -488,12 +433,6 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
hf_features=hf_features,
|
||||
)
|
||||
|
||||
# Record the mapping from source to actual destination
|
||||
src_to_dst[(src_chunk_idx, src_file_idx)] = (dst_chunk, dst_file)
|
||||
|
||||
# Add the mapping to data_idx for use in metadata update
|
||||
data_idx["src_to_dst"] = src_to_dst
|
||||
|
||||
return data_idx
|
||||
|
||||
|
||||
@@ -534,7 +473,7 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
videos_idx,
|
||||
)
|
||||
|
||||
meta_idx, _ = append_or_create_parquet_file(
|
||||
meta_idx = append_or_create_parquet_file(
|
||||
df,
|
||||
src_path,
|
||||
meta_idx,
|
||||
@@ -562,7 +501,7 @@ def append_or_create_parquet_file(
|
||||
contains_images: bool = False,
|
||||
aggr_root: Path = None,
|
||||
hf_features: datasets.Features | None = None,
|
||||
) -> tuple[dict[str, int], tuple[int, int]]:
|
||||
):
|
||||
"""Appends data to an existing parquet file or creates a new one based on size constraints.
|
||||
|
||||
Manages file rotation when size limits are exceeded to prevent individual files
|
||||
@@ -580,11 +519,9 @@ def append_or_create_parquet_file(
|
||||
hf_features: Optional HuggingFace Features schema for proper image typing.
|
||||
|
||||
Returns:
|
||||
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
|
||||
and (dst_chunk, dst_file) is the actual destination file the data was written to.
|
||||
dict: Updated index dictionary with current chunk and file indices.
|
||||
"""
|
||||
dst_chunk, dst_file = idx["chunk"], idx["file"]
|
||||
dst_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
|
||||
dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
|
||||
|
||||
if not dst_path.exists():
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
@@ -592,15 +529,14 @@ def append_or_create_parquet_file(
|
||||
to_parquet_with_hf_images(df, dst_path, features=hf_features)
|
||||
else:
|
||||
df.to_parquet(dst_path)
|
||||
return idx, (dst_chunk, dst_file)
|
||||
return idx
|
||||
|
||||
src_size = get_parquet_file_size_in_mb(src_path)
|
||||
dst_size = get_parquet_file_size_in_mb(dst_path)
|
||||
|
||||
if dst_size + src_size >= max_mb:
|
||||
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
|
||||
dst_chunk, dst_file = idx["chunk"], idx["file"]
|
||||
new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
|
||||
new_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
final_df = df
|
||||
target_path = new_path
|
||||
@@ -619,7 +555,7 @@ def append_or_create_parquet_file(
|
||||
else:
|
||||
final_df.to_parquet(target_path)
|
||||
|
||||
return idx, (dst_chunk, dst_file)
|
||||
return idx
|
||||
|
||||
|
||||
def finalize_aggregation(aggr_meta, all_metadata):
|
||||
|
||||
@@ -1396,132 +1396,6 @@ BYTES_PER_KIB = 1024
|
||||
BYTES_PER_MIB = BYTES_PER_KIB * BYTES_PER_KIB
|
||||
|
||||
|
||||
def modify_tasks(
|
||||
dataset: LeRobotDataset,
|
||||
new_task: str | None = None,
|
||||
episode_tasks: dict[int, str] | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Modify tasks in a LeRobotDataset.
|
||||
|
||||
This function allows you to either:
|
||||
1. Set a single task for the entire dataset (using `new_task`)
|
||||
2. Set specific tasks for specific episodes (using `episode_tasks`)
|
||||
|
||||
You can combine both: `new_task` sets the default, and `episode_tasks` overrides
|
||||
specific episodes.
|
||||
|
||||
The dataset is modified in-place, updating only the task-related files:
|
||||
- meta/tasks.parquet
|
||||
- data/**/*.parquet (task_index column)
|
||||
- meta/episodes/**/*.parquet (tasks column)
|
||||
- meta/info.json (total_tasks)
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobotDataset to modify.
|
||||
new_task: A single task string to apply to all episodes. If None and episode_tasks
|
||||
is also None, raises an error.
|
||||
episode_tasks: Optional dict mapping episode indices to their task strings.
|
||||
Overrides `new_task` for specific episodes.
|
||||
|
||||
|
||||
Examples:
|
||||
Set a single task for all episodes:
|
||||
dataset = modify_tasks(dataset, new_task="Pick up the cube")
|
||||
|
||||
Set different tasks for specific episodes:
|
||||
dataset = modify_tasks(
|
||||
dataset,
|
||||
episode_tasks={0: "Task A", 1: "Task B", 2: "Task A"}
|
||||
)
|
||||
|
||||
Set a default task with overrides:
|
||||
dataset = modify_tasks(
|
||||
dataset,
|
||||
new_task="Default task",
|
||||
episode_tasks={5: "Special task for episode 5"}
|
||||
)
|
||||
"""
|
||||
if new_task is None and episode_tasks is None:
|
||||
raise ValueError("Must specify at least one of new_task or episode_tasks")
|
||||
|
||||
if episode_tasks is not None:
|
||||
valid_indices = set(range(dataset.meta.total_episodes))
|
||||
invalid = set(episode_tasks.keys()) - valid_indices
|
||||
if invalid:
|
||||
raise ValueError(f"Invalid episode indices: {invalid}")
|
||||
|
||||
# Ensure episodes metadata is loaded
|
||||
if dataset.meta.episodes is None:
|
||||
dataset.meta.episodes = load_episodes(dataset.root)
|
||||
|
||||
# Build the mapping from episode index to task string
|
||||
episode_to_task: dict[int, str] = {}
|
||||
for ep_idx in range(dataset.meta.total_episodes):
|
||||
if episode_tasks and ep_idx in episode_tasks:
|
||||
episode_to_task[ep_idx] = episode_tasks[ep_idx]
|
||||
elif new_task is not None:
|
||||
episode_to_task[ep_idx] = new_task
|
||||
else:
|
||||
# Keep original task if not overridden and no default provided
|
||||
original_tasks = dataset.meta.episodes[ep_idx]["tasks"]
|
||||
if not original_tasks:
|
||||
raise ValueError(f"Episode {ep_idx} has no tasks and no default task was provided")
|
||||
episode_to_task[ep_idx] = original_tasks[0]
|
||||
|
||||
# Collect all unique tasks and create new task mapping
|
||||
unique_tasks = sorted(set(episode_to_task.values()))
|
||||
new_task_df = pd.DataFrame({"task_index": list(range(len(unique_tasks)))}, index=unique_tasks)
|
||||
task_to_index = {task: idx for idx, task in enumerate(unique_tasks)}
|
||||
|
||||
logging.info(f"Modifying tasks in {dataset.repo_id}")
|
||||
logging.info(f"New tasks: {unique_tasks}")
|
||||
|
||||
root = dataset.root
|
||||
|
||||
# Update data files - modify task_index column
|
||||
logging.info("Updating data files...")
|
||||
data_dir = root / DATA_DIR
|
||||
|
||||
for parquet_path in tqdm(sorted(data_dir.rglob("*.parquet")), desc="Updating data"):
|
||||
df = pd.read_parquet(parquet_path)
|
||||
|
||||
# Build a mapping from episode_index to new task_index for rows in this file
|
||||
episode_indices_in_file = df["episode_index"].unique()
|
||||
ep_to_new_task_idx = {
|
||||
ep_idx: task_to_index[episode_to_task[ep_idx]] for ep_idx in episode_indices_in_file
|
||||
}
|
||||
|
||||
# Update task_index column
|
||||
df["task_index"] = df["episode_index"].map(ep_to_new_task_idx)
|
||||
df.to_parquet(parquet_path, index=False)
|
||||
|
||||
# Update episodes metadata - modify tasks column
|
||||
logging.info("Updating episodes metadata...")
|
||||
episodes_dir = root / "meta" / "episodes"
|
||||
|
||||
for parquet_path in tqdm(sorted(episodes_dir.rglob("*.parquet")), desc="Updating episodes"):
|
||||
df = pd.read_parquet(parquet_path)
|
||||
|
||||
# Update tasks column
|
||||
df["tasks"] = df["episode_index"].apply(lambda ep_idx: [episode_to_task[ep_idx]])
|
||||
df.to_parquet(parquet_path, index=False)
|
||||
|
||||
# Write new tasks.parquet
|
||||
write_tasks(new_task_df, root)
|
||||
|
||||
# Update info.json
|
||||
dataset.meta.info["total_tasks"] = len(unique_tasks)
|
||||
write_info(dataset.meta.info, root)
|
||||
|
||||
# Reload metadata to reflect changes
|
||||
dataset.meta.tasks = new_task_df
|
||||
dataset.meta.episodes = load_episodes(root)
|
||||
|
||||
logging.info(f"Tasks: {unique_tasks}")
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def convert_image_to_video_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
output_dir: Path,
|
||||
|
||||
@@ -57,9 +57,7 @@ from lerobot.datasets.utils import (
|
||||
load_info,
|
||||
load_nested_dataset,
|
||||
load_stats,
|
||||
load_subtasks,
|
||||
load_tasks,
|
||||
load_tasks_high_level,
|
||||
update_chunk_file_indices,
|
||||
validate_episode_buffer,
|
||||
validate_frame,
|
||||
@@ -164,8 +162,6 @@ class LeRobotDatasetMetadata:
|
||||
self.info = load_info(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks = load_tasks(self.root)
|
||||
self.tasks_high_level = load_tasks_high_level(self.root)
|
||||
self.subtasks = load_subtasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
|
||||
@@ -522,8 +518,6 @@ class LeRobotDatasetMetadata:
|
||||
_validate_feature_names(features)
|
||||
|
||||
obj.tasks = None
|
||||
obj.tasks_high_level = None
|
||||
obj.subtasks = None
|
||||
obj.episodes = None
|
||||
obj.stats = None
|
||||
obj.info = create_empty_dataset_info(
|
||||
@@ -1070,17 +1064,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if len(self.meta.video_keys) > 0:
|
||||
current_ts = item["timestamp"].item()
|
||||
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
|
||||
try:
|
||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||
except Exception as e:
|
||||
print("\n" + "=" * 120)
|
||||
print("[VIDEO DECODE FAILURE]")
|
||||
print(f"item={item}")
|
||||
print(f"query_indices={query_indices}")
|
||||
print(f"query_timestamps={query_timestamps}")
|
||||
print(f"ep_idx={ep_idx}")
|
||||
print("=" * 120 + "\n")
|
||||
raise
|
||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||
item = {**video_frames, **item}
|
||||
|
||||
if self.image_transforms is not None:
|
||||
@@ -1091,20 +1075,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Add task as a string
|
||||
task_idx = item["task_index"].item()
|
||||
item["task"] = self.meta.tasks.iloc[task_idx].name
|
||||
|
||||
# optionally add high level task index
|
||||
if "task_index_high_level" in self.features:
|
||||
high_level_task_idx = item["task_index_high_level"].item()
|
||||
item["robot_utterance"] = self.meta.tasks_high_level.iloc[high_level_task_idx]["robot_utterance"]
|
||||
item["user_prompt"] = self.meta.tasks_high_level.iloc[high_level_task_idx]["user_prompt"]
|
||||
|
||||
|
||||
|
||||
# add subtask information if available
|
||||
if "subtask_index" in self.features and self.meta.subtasks is not None:
|
||||
subtask_idx = item["subtask_index"].item()
|
||||
item["subtask"] = self.meta.subtasks.iloc[subtask_idx].name
|
||||
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
@@ -216,17 +216,16 @@ class ImageTransformsConfig:
|
||||
|
||||
|
||||
def make_transform_from_config(cfg: ImageTransformConfig):
|
||||
if cfg.type == "SharpnessJitter":
|
||||
if cfg.type == "Identity":
|
||||
return v2.Identity(**cfg.kwargs)
|
||||
elif cfg.type == "ColorJitter":
|
||||
return v2.ColorJitter(**cfg.kwargs)
|
||||
elif cfg.type == "SharpnessJitter":
|
||||
return SharpnessJitter(**cfg.kwargs)
|
||||
|
||||
transform_cls = getattr(v2, cfg.type, None)
|
||||
if isinstance(transform_cls, type) and issubclass(transform_cls, Transform):
|
||||
return transform_cls(**cfg.kwargs)
|
||||
|
||||
raise ValueError(
|
||||
f"Transform '{cfg.type}' is not valid. It must be a class in "
|
||||
f"torchvision.transforms.v2 or 'SharpnessJitter'."
|
||||
)
|
||||
elif cfg.type == "RandomAffine":
|
||||
return v2.RandomAffine(**cfg.kwargs)
|
||||
else:
|
||||
raise ValueError(f"Transform '{cfg.type}' is not valid.")
|
||||
|
||||
|
||||
class ImageTransforms(Transform):
|
||||
|
||||
@@ -60,10 +60,7 @@ VIDEO_DIR = "videos"
|
||||
|
||||
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
|
||||
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
|
||||
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
|
||||
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_TASKS_HIGH_LEVEL_PATH = "meta/tasks_high_level.parquet"
|
||||
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
|
||||
@@ -355,28 +352,6 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
|
||||
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
|
||||
return tasks
|
||||
|
||||
def load_tasks_high_level(local_dir: Path) -> pandas.DataFrame | None:
|
||||
"""Load high-level tasks from tasks_high_level.parquet if it exists."""
|
||||
tasks_high_level_path = local_dir / DEFAULT_TASKS_HIGH_LEVEL_PATH
|
||||
if tasks_high_level_path.exists():
|
||||
return pd.read_parquet(tasks_high_level_path)
|
||||
return None
|
||||
|
||||
|
||||
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
|
||||
"""Load subtasks from subtasks.parquet if it exists."""
|
||||
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
|
||||
if subtasks_path.exists():
|
||||
return pd.read_parquet(subtasks_path)
|
||||
return None
|
||||
|
||||
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
|
||||
"""Load subtasks from subtasks.parquet if it exists."""
|
||||
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
|
||||
if subtasks_path.exists():
|
||||
return pd.read_parquet(subtasks_path)
|
||||
return None
|
||||
|
||||
|
||||
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
|
||||
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
|
||||
|
||||
@@ -205,7 +205,6 @@ class ObservationConfig:
|
||||
|
||||
add_joint_velocity_to_observation: bool = False
|
||||
add_current_to_observation: bool = False
|
||||
add_ee_pose_to_observation: bool = False
|
||||
display_cameras: bool = False
|
||||
|
||||
|
||||
|
||||
@@ -28,11 +28,8 @@ from lerobot.utils.import_utils import _can_available
|
||||
if TYPE_CHECKING or _can_available:
|
||||
import can
|
||||
else:
|
||||
|
||||
class can: # noqa: N801
|
||||
Message = object
|
||||
interface = None
|
||||
|
||||
can.Message = object
|
||||
can.interface = None
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -209,31 +206,11 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
Raises ConnectionError if any motor fails to respond.
|
||||
"""
|
||||
logger.info("Starting handshake with motors...")
|
||||
|
||||
# Drain any pending messages
|
||||
while self.canbus.recv(timeout=0.01):
|
||||
pass
|
||||
|
||||
missing_motors = []
|
||||
|
||||
for motor_name in self.motors:
|
||||
motor_id = self._get_motor_id(motor_name)
|
||||
recv_id = self._get_motor_recv_id(motor_name)
|
||||
|
||||
# Send enable command
|
||||
data = [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, CAN_CMD_ENABLE]
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
self.canbus.send(msg)
|
||||
|
||||
# Wait for response with longer timeout
|
||||
response = None
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < 0.1:
|
||||
response = self.canbus.recv(timeout=0.1)
|
||||
if response and response.arbitration_id == recv_id:
|
||||
break
|
||||
response = None
|
||||
|
||||
if response is None:
|
||||
msg = self._refresh_motor(motor_name)
|
||||
if msg is None:
|
||||
missing_motors.append(motor_name)
|
||||
else:
|
||||
self._process_response(motor_name, msg)
|
||||
@@ -282,7 +259,7 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
motor_name = self._get_motor_name(motor)
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
data = [0xFF] * 7 + [command_byte]
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
if msg := self._recv_motor_response(expected_recv_id=recv_id):
|
||||
self._process_response(motor_name, msg)
|
||||
@@ -340,7 +317,7 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
motor_id = self._get_motor_id(motor)
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
|
||||
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
return self._recv_motor_response(expected_recv_id=recv_id)
|
||||
|
||||
@@ -462,7 +439,7 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
motor_type = self._motor_types[motor_name]
|
||||
|
||||
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
@@ -495,7 +472,7 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
motor_type = self._motor_types[motor_name]
|
||||
|
||||
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
|
||||
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
|
||||
@@ -660,10 +637,10 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
for motor in motors:
|
||||
motor_id = self._get_motor_id(motor)
|
||||
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
|
||||
msg = can.Message(
|
||||
arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd
|
||||
)
|
||||
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
# Small delay to reduce bus congestion if necessary, though removed in sync_read previously
|
||||
# precise_sleep(PRECISE_SLEEP_SEC)
|
||||
|
||||
# Collect responses
|
||||
expected_recv_ids = [self._get_motor_recv_id(m) for m in motors]
|
||||
@@ -699,9 +676,7 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
kd = self._gains[motor]["kd"]
|
||||
|
||||
data = self._encode_mit_packet(motor_type, kp, kd, float(value_degrees), 0.0, 0.0)
|
||||
msg = can.Message(
|
||||
arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd
|
||||
)
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
precise_sleep(PRECISE_TIMEOUT_SEC)
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ class ACTConfig(PreTrainedConfig):
|
||||
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_features` and `output_features`.
|
||||
Those are: `input_shapes` and 'output_shapes`.
|
||||
|
||||
Notes on the inputs and outputs:
|
||||
- Either:
|
||||
@@ -48,12 +48,21 @@ class ACTConfig(PreTrainedConfig):
|
||||
This should be no greater than the chunk size. For example, if the chunk size size 100, you may
|
||||
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
|
||||
environment, and throws the other 50 out.
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||
include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
|
||||
`None` means no pretrained weights.
|
||||
|
||||
@@ -30,7 +30,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_features` and `output_features`.
|
||||
Those are: `input_shapes` and `output_shapes`.
|
||||
|
||||
Notes on the inputs and outputs:
|
||||
- "observation.state" is required as an input key.
|
||||
@@ -48,12 +48,21 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
|
||||
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
||||
See `DiffusionPolicy.select_action` for more details.
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||
include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||
within the image size. If None, no cropping is done.
|
||||
@@ -64,7 +73,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
||||
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
|
||||
use_separate_rgb_encoder_per_camera: Whether to use a separate RGB encoder for each camera view.
|
||||
use_separate_rgb_encoders_per_camera: Whether to use a separate RGB encoder for each camera view.
|
||||
down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
|
||||
You may provide a variable number of dimensions, therefore also controlling the degree of
|
||||
downsampling.
|
||||
|
||||
@@ -34,7 +34,6 @@ from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pi05_full.configuration_pi05 import PI05FullConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
@@ -391,13 +390,6 @@ def make_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
elif isinstance(policy_cfg, PI05FullConfig):
|
||||
from lerobot.policies.pi05_full.processor_pi05 import make_pi05_full_pre_post_processors
|
||||
|
||||
processors = make_pi05_full_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
|
||||
@@ -20,7 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig, RTCTrainingConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
DEFAULT_IMAGE_SIZE = 224
|
||||
@@ -50,8 +50,9 @@ class PI0Config(PreTrainedConfig):
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
# Real-Time Chunking (RTC) configurations
|
||||
rtc_config: RTCConfig | None = None
|
||||
rtc_training_config: RTCTrainingConfig | None = None
|
||||
|
||||
image_resolution: tuple[int, int] = (
|
||||
DEFAULT_IMAGE_SIZE,
|
||||
|
||||
@@ -44,6 +44,12 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.policies.rtc.training_time import (
|
||||
apply_rtc_training_time,
|
||||
apply_training_time_rtc_inference,
|
||||
masked_mean,
|
||||
sample_rtc_delay,
|
||||
)
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
@@ -79,8 +85,8 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
if time.ndim not in (1, 2):
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size,)` or `(batch_size, T)`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
@@ -88,8 +94,14 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
if time.ndim == 1:
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
|
||||
time_flat = time.reshape(-1)
|
||||
sin_input = scaling_factor[None, :] * time_flat[:, None]
|
||||
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
return pos_emb.reshape(*time.shape, dimension)
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
|
||||
@@ -605,6 +617,9 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _training_time_rtc_inference_enabled(self):
|
||||
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Helper method to apply gradient checkpointing if enabled."""
|
||||
if self.gradient_checkpointing_enabled and self.training:
|
||||
@@ -714,7 +729,10 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
|
||||
|
||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||
if time_emb.dim() == 2:
|
||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||
elif time_emb.shape[:2] != action_emb.shape[:2]:
|
||||
raise ValueError(f"Expected time_emb shape {action_emb.shape[:2]}, got {time_emb.shape[:2]}")
|
||||
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
||||
|
||||
def mlp_func(action_time_emb):
|
||||
@@ -750,7 +768,12 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
if time.ndim == 1:
|
||||
time_expanded = time[:, None, None]
|
||||
elif time.ndim == 2:
|
||||
time_expanded = time[:, :, None]
|
||||
else:
|
||||
raise ValueError(f"Expected time shape (B,) or (B, T), got {time.shape}")
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
@@ -846,24 +869,37 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
dt = -1.0 / num_steps
|
||||
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
use_training_time_rtc = self._training_time_rtc_inference_enabled()
|
||||
|
||||
x_t = noise
|
||||
for step in range(num_steps):
|
||||
time = 1.0 + step * dt
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
return self.denoise_step(
|
||||
if use_training_time_rtc:
|
||||
x_t_cond, time_tensor = apply_training_time_rtc_inference(
|
||||
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
|
||||
)
|
||||
v_t = self.denoise_step(
|
||||
state=state,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=input_x_t,
|
||||
timestep=current_timestep,
|
||||
x_t=x_t_cond,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
elif self._rtc_enabled():
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
if self._rtc_enabled():
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
return self.denoise_step(
|
||||
state=state,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=input_x_t,
|
||||
timestep=current_timestep,
|
||||
)
|
||||
|
||||
v_t = self.rtc_processor.denoise_step(
|
||||
x_t=x_t,
|
||||
@@ -874,7 +910,14 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
state=state,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=x_t,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
|
||||
x_t = x_t + dt * v_t
|
||||
|
||||
@@ -1277,7 +1320,19 @@ class PI0Policy(PreTrainedPolicy):
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
# Compute loss
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
|
||||
postfix_mask = None
|
||||
rtc_cfg = self.config.rtc_training_config
|
||||
if rtc_cfg is not None and rtc_cfg.enabled and self.training:
|
||||
batch_size = actions.shape[0]
|
||||
time = self.model.sample_time(batch_size, actions.device)
|
||||
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||
delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device)
|
||||
time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1])
|
||||
losses = self.model.forward(
|
||||
images, img_masks, lang_tokens, lang_masks, state, actions, noise=noise, time=time
|
||||
)
|
||||
else:
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
|
||||
|
||||
# Truncate losses to actual action dimensions
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
@@ -1289,12 +1344,12 @@ class PI0Policy(PreTrainedPolicy):
|
||||
|
||||
if reduction == "none":
|
||||
# Return per-sample losses (B,) by averaging over time and action dims
|
||||
per_sample_loss = losses.mean(dim=(1, 2))
|
||||
per_sample_loss = masked_mean(losses, postfix_mask, reduce_dims=(1, 2))
|
||||
loss_dict["loss"] = per_sample_loss.mean().item()
|
||||
return per_sample_loss, loss_dict
|
||||
else:
|
||||
# Default: return scalar mean loss
|
||||
loss = losses.mean()
|
||||
loss = masked_mean(losses, postfix_mask, reduce_dims=(0, 1, 2))
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig, RTCTrainingConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
DEFAULT_IMAGE_SIZE = 224
|
||||
@@ -52,6 +52,7 @@ class PI05Config(PreTrainedConfig):
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
rtc_training_config: RTCTrainingConfig | None = None
|
||||
|
||||
image_resolution: tuple[int, int] = (
|
||||
DEFAULT_IMAGE_SIZE,
|
||||
|
||||
@@ -44,6 +44,12 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.policies.rtc.training_time import (
|
||||
apply_rtc_training_time,
|
||||
apply_training_time_rtc_inference,
|
||||
masked_mean,
|
||||
sample_rtc_delay,
|
||||
)
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
@@ -78,8 +84,8 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
if time.ndim not in (1, 2):
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size,)` or `(batch_size, T)`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
@@ -87,8 +93,14 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
if time.ndim == 1:
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
|
||||
time_flat = time.reshape(-1)
|
||||
sin_input = scaling_factor[None, :] * time_flat[:, None]
|
||||
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
return pos_emb.reshape(*time.shape, dimension)
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
|
||||
@@ -602,6 +614,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _training_time_rtc_inference_enabled(self):
|
||||
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Helper method to apply gradient checkpointing if enabled."""
|
||||
if self.gradient_checkpointing_enabled and self.training:
|
||||
@@ -729,7 +744,12 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
if time.ndim == 1:
|
||||
time_expanded = time[:, None, None]
|
||||
elif time.ndim == 2:
|
||||
time_expanded = time[:, :, None]
|
||||
else:
|
||||
raise ValueError(f"Expected time shape (B,) or (B, T), got {time.shape}")
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
@@ -820,23 +840,35 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
dt = -1.0 / num_steps
|
||||
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
use_training_time_rtc = self._training_time_rtc_inference_enabled()
|
||||
|
||||
x_t = noise
|
||||
for step in range(num_steps):
|
||||
time = 1.0 + step * dt
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
return self.denoise_step(
|
||||
if use_training_time_rtc:
|
||||
x_t_cond, time_tensor = apply_training_time_rtc_inference(
|
||||
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
|
||||
)
|
||||
v_t = self.denoise_step(
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=input_x_t,
|
||||
timestep=current_timestep,
|
||||
x_t=x_t_cond,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
elif self._rtc_enabled():
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
if self._rtc_enabled():
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
return self.denoise_step(
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=input_x_t,
|
||||
timestep=current_timestep,
|
||||
)
|
||||
|
||||
v_t = self.rtc_processor.denoise_step(
|
||||
x_t=x_t,
|
||||
@@ -847,7 +879,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=x_t,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
|
||||
x_t = x_t + dt * v_t
|
||||
|
||||
@@ -1250,7 +1288,17 @@ class PI05Policy(PreTrainedPolicy):
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
# Compute loss (no separate state needed for PI05)
|
||||
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
||||
postfix_mask = None
|
||||
rtc_cfg = self.config.rtc_training_config
|
||||
if rtc_cfg is not None and rtc_cfg.enabled and self.training:
|
||||
batch_size = actions.shape[0]
|
||||
time = self.model.sample_time(batch_size, actions.device)
|
||||
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||
delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device)
|
||||
time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1])
|
||||
losses = self.model.forward(images, img_masks, tokens, masks, actions, noise=noise, time=time)
|
||||
else:
|
||||
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
||||
|
||||
# Truncate losses to actual action dimensions
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
@@ -1262,12 +1310,12 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
if reduction == "none":
|
||||
# Return per-sample losses (B,) by averaging over time and action dims
|
||||
per_sample_loss = losses.mean(dim=(1, 2))
|
||||
per_sample_loss = masked_mean(losses, postfix_mask, reduce_dims=(1, 2))
|
||||
loss_dict["loss"] = per_sample_loss.mean().item()
|
||||
return per_sample_loss, loss_dict
|
||||
else:
|
||||
# Default: return scalar mean loss
|
||||
loss = losses.mean()
|
||||
loss = masked_mean(losses, postfix_mask, reduce_dims=(0, 1, 2))
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
# π₀.₅ (pi05)
|
||||
|
||||
This repository contains the Hugging Face port of **π₀.₅**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by the Physical Intelligence.
|
||||
It is designed as a **Vision-Language-Action model with open-world generalization**.
|
||||
|
||||
---
|
||||
|
||||
## Model Overview
|
||||
|
||||
| Feature | π₀ | π₀.₅ |
|
||||
| -------------------- | ------------------------------------------------------ | ----------------------------------------- |
|
||||
| Time Conditioning | Concatenates time with actions via `action_time_mlp_*` | Uses `time_mlp_*` for AdaRMS conditioning |
|
||||
| AdaRMS | Not used | Used in action expert |
|
||||
| Tokenizer Length | 48 tokens | 200 tokens |
|
||||
| Discrete State Input | False (Uses `state_proj` layer) | True |
|
||||
| Parameter Count | Higher (includes state embedding) | Lower (no state embedding) |
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this work, please cite both **OpenPI** and the π₀.₅ paper:
|
||||
|
||||
```bibtex
|
||||
@misc{openpi2024,
|
||||
author = {Physical Intelligence Lab},
|
||||
title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
|
||||
year = {2024},
|
||||
publisher = {GitHub},
|
||||
howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
|
||||
license = {Apache-2.0}
|
||||
}
|
||||
|
||||
@misc{intelligence2025pi05visionlanguageactionmodelopenworld,
|
||||
title = {π₀.₅: a Vision-Language-Action Model with Open-World Generalization},
|
||||
author = {Physical Intelligence and Kevin Black and Noah Brown and James Darpinian and Karan Dhabalia and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Manuel Y. Galliker and Dibya Ghosh and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Devin LeBlanc and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Allen Z. Ren and Lucy Xiaoyang Shi and Laura Smith and Jost Tobias Springenberg and Kyle Stachowicz and James Tanner and Quan Vuong and Homer Walke and Anna Walling and Haohuan Wang and Lili Yu and Ury Zhilinsky},
|
||||
year = {2025},
|
||||
eprint = {2504.16054},
|
||||
archivePrefix= {arXiv},
|
||||
primaryClass = {cs.LG},
|
||||
url = {https://arxiv.org/abs/2504.16054},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
|
||||
@@ -1,21 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_pi05 import PI05FullConfig
|
||||
from .modeling_pi05 import PI05FullPolicy
|
||||
from .processor_pi05 import make_pi05_full_pre_post_processors
|
||||
|
||||
__all__ = ["PI05FullConfig", "PI05FullPolicy", "make_pi05_full_pre_post_processors"]
|
||||
@@ -1,50 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Example script to run synthetic data generation with Qwen VLM
|
||||
# This generates user prompts and robot utterances for hierarchical policy training
|
||||
|
||||
# Configuration
|
||||
REPO_ID="lerobot/libero_10"
|
||||
MODEL="Qwen/Qwen3-VL-30B-A3B-Instruct"
|
||||
# or: MODEL="Qwen/Qwen2-VL-7B-Instruct"
|
||||
|
||||
|
||||
OUTPUT_DIR="/fsx/jade_choghari/outputs/libero-10-annotate-high"
|
||||
|
||||
BATCH_SIZE=16
|
||||
TEMPERATURE=0.9
|
||||
SAMPLE_INTERVAL=5.0 # generate dialogue every 1 second (all episodes processed)
|
||||
|
||||
# Run subtask annotation
|
||||
# python /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05_full/annotate/subtask_annotate.py \
|
||||
# --repo-id "$REPO_ID" \
|
||||
# --video-key observation.images.image \
|
||||
# --output-dir "$OUTPUT_DIR" \
|
||||
# --skip-existing \
|
||||
# --output-repo-id "jadechoghari/libero10-annotate" \
|
||||
# --batch-size "$BATCH_SIZE" \
|
||||
# run synthetic data generation (all episodes processed)
|
||||
# python examples/dataset/annotate_pgen.py \
|
||||
# --repo-id "$REPO_ID" \
|
||||
# --model "$MODEL" \
|
||||
# --output-dir "$OUTPUT_DIR" \
|
||||
# --temperature "$TEMPERATURE" \
|
||||
# --batch-size "$BATCH_SIZE" \
|
||||
# --sample-interval "$SAMPLE_INTERVAL" \
|
||||
# --image-key observation.images.base \
|
||||
# --num-image-views-per-sample 1
|
||||
|
||||
# for faster testing, increase sample interval:
|
||||
# --sample-interval 5.0 # Samples every 5 seconds (much faster)
|
||||
|
||||
# to push to hub after generation:
|
||||
# add --push-to-hub flag
|
||||
|
||||
# efficient batch processing: 4 episodes at once
|
||||
python /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05_full/annotate/high_level_annotate.py \
|
||||
--data-dir "/fsx/jade_choghari/outputs/libero-10-annotate" \
|
||||
--output-dir "$OUTPUT_DIR" \
|
||||
--video-mode \
|
||||
--video-key observation.images.image \
|
||||
--video-batch-size "$BATCH_SIZE" \
|
||||
--sample-interval 5.0
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,52 +0,0 @@
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
import lerobot
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
# /fsx/jade_choghari/data/libero_10_subtasks_kw_converted
|
||||
dataset = LeRobotDataset(repo_id="lerobot/libero_10_image_subtask")
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
batch_size=2,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained(
|
||||
pretrained_name_or_path="/fsx/jade_choghari/models/pi05-base",
|
||||
)
|
||||
cfg.dtype = "bfloat16"
|
||||
|
||||
pre_processor, post_processor = make_pre_post_processors(
|
||||
policy_cfg=cfg,
|
||||
pretrained_path="/fsx/jade_choghari/models/pi05-base",
|
||||
)
|
||||
batch = next(iter(dataloader))
|
||||
breakpoint()
|
||||
batch1 = pre_processor(batch)
|
||||
breakpoint()
|
||||
print(batch.keys())
|
||||
# print(batch['task_index_high_level'].shape)
|
||||
# print(batch['task_index_high_level'])
|
||||
# print(batch['user_prompt'][0])
|
||||
# print(batch['robot_utterance'][0])
|
||||
# print(batch['task'][0])
|
||||
|
||||
valid_episode_list = []
|
||||
for episode_idx in range(len(dataset.meta.episodes)):
|
||||
subtask_index = dataset[episode_idx]["subtask_index"]
|
||||
valid_episode_list.append(episode_idx)
|
||||
|
||||
print(len(valid_episode_list))
|
||||
|
||||
# read this parquet /fsx/jade_choghari/outputs/pgen_annotations1/meta/tasks.parquett
|
||||
# import pandas as pd
|
||||
# tasks_df = pd.read_parquet('/fsx/jade_choghari/outputs/pgen_annotations1/meta/tasks.parquet')
|
||||
|
||||
# # print all
|
||||
# print(tasks_df.columns)
|
||||
# breakpoint()
|
||||
@@ -1,49 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Example script to run synthetic data generation with Qwen VLM
|
||||
# This generates user prompts and robot utterances for hierarchical policy training
|
||||
|
||||
# Configuration
|
||||
REPO_ID="jadechoghari/collect-data"
|
||||
MODEL="Qwen/Qwen3-VL-30B-A3B-Instruct"
|
||||
# or: MODEL="Qwen/Qwen2-VL-7B-Instruct"
|
||||
|
||||
|
||||
OUTPUT_DIR="/fsx/jade_choghari/outputs/collect-data-pgen_new"
|
||||
|
||||
BATCH_SIZE=32
|
||||
TEMPERATURE=0.9
|
||||
SAMPLE_INTERVAL=5.0 # generate dialogue every 1 second (all episodes processed)
|
||||
|
||||
# Run subtask annotation
|
||||
python /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05_full/annotate/subtask_annotate.py \
|
||||
--repo-id "$REPO_ID" \
|
||||
--video-key observation.images.base \
|
||||
--output-dir "$OUTPUT_DIR" \
|
||||
--output-repo-id "jadechoghari/collect-data-with-subtasks"
|
||||
# run synthetic data generation (all episodes processed)
|
||||
# python examples/dataset/annotate_pgen.py \
|
||||
# --repo-id "$REPO_ID" \
|
||||
# --model "$MODEL" \
|
||||
# --output-dir "$OUTPUT_DIR" \
|
||||
# --temperature "$TEMPERATURE" \
|
||||
# --batch-size "$BATCH_SIZE" \
|
||||
# --sample-interval "$SAMPLE_INTERVAL" \
|
||||
# --image-key observation.images.base \
|
||||
# --num-image-views-per-sample 1
|
||||
|
||||
# for faster testing, increase sample interval:
|
||||
# --sample-interval 5.0 # Samples every 5 seconds (much faster)
|
||||
|
||||
# to push to hub after generation:
|
||||
# add --push-to-hub flag
|
||||
|
||||
# efficient batch processing: 4 episodes at once
|
||||
# python examples/dataset/annotate_pgen.py \
|
||||
# --repo-id "$REPO_ID" \
|
||||
# --model "$MODEL" \
|
||||
# --output-dir "$OUTPUT_DIR" \
|
||||
# --video-mode \
|
||||
# --video-key observation.images.up \
|
||||
# --video-batch-size "$BATCH_SIZE" \
|
||||
# --sample-interval 1.0
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,183 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
DEFAULT_IMAGE_SIZE = 224
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi05_full")
|
||||
@dataclass
|
||||
class PI05FullConfig(PreTrainedConfig):
|
||||
paligemma_variant: str = "gemma_2b"
|
||||
action_expert_variant: str = "gemma_300m"
|
||||
dtype: str = "float32" # Options: "bfloat16", "float32"
|
||||
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
|
||||
n_action_steps: int = 50 # Number of action steps to execute
|
||||
|
||||
# Shorter state and action vectors will be padded to these dimensions
|
||||
max_state_dim: int = 32
|
||||
max_action_dim: int = 32
|
||||
|
||||
# Flow matching parameters: see openpi `PI0Pytorch`
|
||||
num_inference_steps: int = 10
|
||||
time_sampling_beta_alpha: float = 1.5
|
||||
time_sampling_beta_beta: float = 1.0
|
||||
time_sampling_scale: float = 0.999
|
||||
time_sampling_offset: float = 0.001
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
image_resolution: tuple[int, int] = (
|
||||
DEFAULT_IMAGE_SIZE,
|
||||
DEFAULT_IMAGE_SIZE,
|
||||
) # see openpi `preprocessing_pytorch.py`
|
||||
|
||||
# Add empty images. Used to add empty cameras when no image features are present.
|
||||
empty_cameras: int = 0
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD, # Pi0.5 uses quantiles for state
|
||||
"ACTION": NormalizationMode.MEAN_STD, # Pi0.5 uses quantiles for action
|
||||
}
|
||||
)
|
||||
|
||||
action_tokenizer_name: str = "physical-intelligence/fast"
|
||||
text_tokenizer_name: str = "google/paligemma-3b-pt-224"
|
||||
max_action_tokens: int = 256
|
||||
fast_skip_tokens: int = 128
|
||||
|
||||
# subtask stuff
|
||||
max_decoding_steps: int = 200
|
||||
temperature: float = 0.0
|
||||
subtask_regeneration_interval: float = 1.0 # Regenerate subtask tokens every N seconds (0 = every call)
|
||||
|
||||
# Training settings
|
||||
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
|
||||
compile_model: bool = False # Whether to use torch.compile for model optimization
|
||||
compile_mode: str = "max-autotune" # Torch compile mode
|
||||
device: str | None = None # Device to use for the model (None = auto-detect)
|
||||
|
||||
# Finetuning settings
|
||||
freeze_vision_encoder: bool = False # Freeze only the vision encoder
|
||||
train_expert_only: bool = False # Freeze entire VLM, train only action expert and projections
|
||||
knowledge_insulation: bool = True # Enable knowledge insulation in attention (blocks gradients from action to VLM K/V)
|
||||
|
||||
# Loss weights (used when knowledge_insulation is enabled)
|
||||
loss_weight_flow: float = 1.0 # Weight for flow matching MSE loss (continuous actions)
|
||||
loss_weight_action_ce: float = 1.0 # Weight for FAST action token cross-entropy loss
|
||||
loss_weight_subtask_ce: float = 1.0 # Weight for subtask token cross-entropy loss
|
||||
|
||||
# Optimizer settings: see openpi `AdamW`
|
||||
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 0.01
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
|
||||
# Scheduler settings: see openpi `CosineDecaySchedule`
|
||||
# Note: These will auto-scale if --steps < scheduler_decay_steps
|
||||
# For example, --steps=3000 will scale warmup to 100 and decay to 3000
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
tokenizer_max_length: int = 48 # see openpi `__post_init__`
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
# Validate configuration
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
|
||||
)
|
||||
|
||||
if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
|
||||
raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}")
|
||||
|
||||
if self.action_expert_variant not in ["gemma_300m", "gemma_2b"]:
|
||||
raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}")
|
||||
|
||||
if self.dtype not in ["bfloat16", "float32"]:
|
||||
raise ValueError(f"Invalid dtype: {self.dtype}")
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up input/output features."""
|
||||
for i in range(self.empty_cameras):
|
||||
key = OBS_IMAGES + f".empty_camera_{i}"
|
||||
empty_camera = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, *self.image_resolution), # Use configured image resolution
|
||||
)
|
||||
self.input_features[key] = empty_camera
|
||||
|
||||
if OBS_STATE not in self.input_features:
|
||||
state_feature = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.max_state_dim,), # Padded to max_state_dim
|
||||
)
|
||||
self.input_features[OBS_STATE] = state_feature
|
||||
|
||||
if ACTION not in self.output_features:
|
||||
action_feature = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.max_action_dim,), # Padded to max_action_dim
|
||||
)
|
||||
self.output_features[ACTION] = action_feature
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
@@ -1,92 +0,0 @@
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
import lerobot
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
# import make_pre_post_processors
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.factory import make_policy, make_policy_config
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained(
|
||||
pretrained_name_or_path="/fsx/jade_choghari/models/pi05-base",
|
||||
)
|
||||
cfg.dtype = "bfloat16"
|
||||
|
||||
pre_processor, post_processor = make_pre_post_processors(
|
||||
policy_cfg=cfg,
|
||||
pretrained_path="/fsx/jade_choghari/models/pi05-base",
|
||||
)
|
||||
|
||||
delta_timestamps = {'action': [0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333, 0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3, 0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335, 0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6, 0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333, 0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9, 0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334, 1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2, 1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333, 1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5, 1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333]}
|
||||
|
||||
dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1", delta_timestamps=delta_timestamps)
|
||||
|
||||
# rename map --rename_map='{
|
||||
# "observation.images.side": "observation.images.base_0_rgb",
|
||||
# "observation.images.up": "observation.images.left_wrist_0_rgb"
|
||||
# }'
|
||||
rename_map = {
|
||||
"observation.images.side": "observation.images.base_0_rgb",
|
||||
"observation.images.up": "observation.images.left_wrist_0_rgb"
|
||||
}
|
||||
policy = make_policy(
|
||||
cfg=cfg,
|
||||
ds_meta=dataset.meta,
|
||||
rename_map=rename_map,
|
||||
)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
batch_size=4,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
breakpoint()
|
||||
batch = pre_processor(batch)
|
||||
policy.train()
|
||||
# run inference
|
||||
# action = policy.select_action(batch)
|
||||
loss, loss_dict = policy.forward(batch)
|
||||
breakpoint()
|
||||
# import requests
|
||||
# from PIL import Image
|
||||
# from transformers import AutoProcessor
|
||||
# model = policy.model.paligemma_with_expert.paligemma
|
||||
# model = model.to(device="cuda", dtype=torch.bfloat16)
|
||||
# model.eval()
|
||||
# prompt = "Describe this image."
|
||||
# url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
||||
# image = Image.open(requests.get(url, stream=True).raw)
|
||||
# processor = AutoProcessor.from_pretrained(
|
||||
# "google/paligemma-3b-pt-224",
|
||||
# )
|
||||
# inputs = processor(image, prompt, return_tensors="pt").to(model.device)
|
||||
# print("generating...")
|
||||
# output = model.generate(
|
||||
# **inputs,
|
||||
# max_new_tokens=50,
|
||||
# use_cache=True, # default dynamic cache
|
||||
# )
|
||||
# print(processor.decode(output[0], skip_special_tokens=True))
|
||||
|
||||
|
||||
# # other model
|
||||
# from transformers import PaliGemmaForConditionalGeneration
|
||||
# model = PaliGemmaForConditionalGeneration.from_pretrained(
|
||||
# "google/paligemma2-3b-pt-224",
|
||||
# torch_dtype=torch.bfloat16,
|
||||
# device_map="auto",
|
||||
# )
|
||||
# model.eval()
|
||||
# print("generating...")
|
||||
# output = model.generate(
|
||||
# **inputs,
|
||||
# max_new_tokens=100,
|
||||
# use_cache=True, # default dynamic cache
|
||||
# )
|
||||
# print("Model 2 output:")
|
||||
# print(processor.decode(output[0], skip_special_tokens=True))
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,194 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.pi05_full.configuration_pi05 import PI05FullConfig
|
||||
from lerobot.policies.pi05_full.modeling_pi05 import pad_vector
|
||||
from lerobot.processor import (
|
||||
ActionTokenizerProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_STATE,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="pi05_full_prepare_state_tokenizer_processor_step")
|
||||
@dataclass
|
||||
class Pi05FullPrepareStateTokenizerProcessorStep(ProcessorStep):
|
||||
"""
|
||||
Processor step to prepare the state and tokenize the language input.
|
||||
"""
|
||||
|
||||
max_state_dim: int = 32
|
||||
task_key: str = "task"
|
||||
subtask_key: str = "subtask"
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
transition = transition.copy()
|
||||
|
||||
state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE)
|
||||
if state is None:
|
||||
raise ValueError("State is required for PI05")
|
||||
user_prompts = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key)
|
||||
if user_prompts is None:
|
||||
raise ValueError("No user prompts found in complementary data")
|
||||
commands = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.subtask_key)
|
||||
|
||||
# TODO: check if this necessary
|
||||
state = deepcopy(state)
|
||||
|
||||
# Prepare state (pad to max_state_dim)
|
||||
state = pad_vector(state, self.max_state_dim)
|
||||
|
||||
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
state_np = state.cpu().numpy()
|
||||
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||
|
||||
full_prompts = []
|
||||
for i, user_prompt in enumerate(user_prompts):
|
||||
cleaned_text = user_prompt.strip().replace("_", " ").replace("\n", " ")
|
||||
cleaned_text = cleaned_text.lower() # all lowercase # NOTE: added by (jadechoghari)
|
||||
state_str = " ".join(map(str, discretized_states[i]))
|
||||
full_prompt = f"Task: {cleaned_text}, State: {state_str};\n"
|
||||
full_prompts.append(full_prompt)
|
||||
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
|
||||
|
||||
# process commands (optional)
|
||||
if commands is not None:
|
||||
full_commands = []
|
||||
for i, command in enumerate(commands):
|
||||
cleaned_text = command.strip().replace("_", " ").replace("\n", " ")
|
||||
cleaned_text = cleaned_text.lower() # all lowercase # NOTE: added by (jadechoghari)
|
||||
full_command = f"Subtask: {cleaned_text};\n"
|
||||
full_commands.append(full_command)
|
||||
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.subtask_key] = full_commands
|
||||
|
||||
# note: action tokens will be processed in the ActionTokenizerProcessorStep
|
||||
# Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!)
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
return transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
This step does not alter the feature definitions.
|
||||
"""
|
||||
return features
|
||||
|
||||
|
||||
def make_pi05_full_pre_post_processors(
|
||||
config: PI05FullConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the PI0 policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features to match pretrained configurations.
|
||||
2. Normalizing input and output features based on dataset statistics.
|
||||
3. Adding a batch dimension.
|
||||
4. Appending a newline character to the task description for tokenizer compatibility.
|
||||
5. Tokenizing the text prompt using the PaliGemma tokenizer.
|
||||
6. Moving all data to the specified device.
|
||||
|
||||
The post-processing pipeline handles the model's output by:
|
||||
1. Moving data to the CPU.
|
||||
2. Unnormalizing the output features to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the PI0 policy.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
|
||||
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
|
||||
# Add remaining processors
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
AddBatchDimensionProcessorStep(),
|
||||
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
|
||||
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
Pi05FullPrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name=config.text_tokenizer_name,
|
||||
max_length=config.tokenizer_max_length,
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
),
|
||||
ActionTokenizerProcessorStep(
|
||||
action_tokenizer_name=config.action_tokenizer_name,
|
||||
max_action_tokens=config.max_action_tokens,
|
||||
fast_skip_tokens=config.fast_skip_tokens,
|
||||
paligemma_tokenizer_name=config.text_tokenizer_name,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
output_steps: list[ProcessorStep] = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
@@ -23,7 +23,7 @@ Based on:
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.configs.types import RTCAttentionSchedule, RTCTrainingDelayDistribution
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -53,3 +53,22 @@ class RTCConfig:
|
||||
raise ValueError(f"max_guidance_weight must be positive, got {self.max_guidance_weight}")
|
||||
if self.debug_maxlen <= 0:
|
||||
raise ValueError(f"debug_maxlen must be positive, got {self.debug_maxlen}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RTCTrainingConfig:
|
||||
"""Configuration for training-time RTC action prefix conditioning."""
|
||||
|
||||
enabled: bool = False
|
||||
min_delay: int = 0
|
||||
max_delay: int = 0
|
||||
delay_distribution: RTCTrainingDelayDistribution = RTCTrainingDelayDistribution.UNIFORM
|
||||
exp_decay: float = 1.0
|
||||
|
||||
def __post_init__(self):
|
||||
if self.min_delay < 0:
|
||||
raise ValueError(f"min_delay must be >= 0, got {self.min_delay}")
|
||||
if self.max_delay < self.min_delay:
|
||||
raise ValueError(f"max_delay ({self.max_delay}) must be >= min_delay ({self.min_delay})")
|
||||
if self.exp_decay <= 0:
|
||||
raise ValueError(f"exp_decay must be positive, got {self.exp_decay}")
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import RTCTrainingDelayDistribution
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCTrainingConfig
|
||||
|
||||
|
||||
def sample_rtc_delay(cfg: RTCTrainingConfig, batch_size: int, device: torch.device) -> torch.Tensor:
|
||||
if cfg.max_delay == cfg.min_delay:
|
||||
return torch.full((batch_size,), cfg.min_delay, device=device, dtype=torch.long)
|
||||
|
||||
if cfg.delay_distribution == RTCTrainingDelayDistribution.UNIFORM:
|
||||
return torch.randint(cfg.min_delay, cfg.max_delay + 1, (batch_size,), device=device, dtype=torch.long)
|
||||
|
||||
delay_values = torch.arange(cfg.min_delay, cfg.max_delay + 1, device=device, dtype=torch.long)
|
||||
weights = torch.exp(-cfg.exp_decay * delay_values.to(dtype=torch.float32))
|
||||
probs = weights / weights.sum()
|
||||
samples = torch.multinomial(probs, batch_size, replacement=True)
|
||||
return delay_values[samples]
|
||||
|
||||
|
||||
def apply_rtc_training_time(
|
||||
time: torch.Tensor, delay: torch.Tensor, seq_len: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
device = time.device
|
||||
delay = torch.clamp(delay, max=seq_len)
|
||||
prefix_mask = torch.arange(seq_len, device=device)[None, :] < delay[:, None]
|
||||
time_tokens = time[:, None].expand(-1, seq_len)
|
||||
time_tokens = time_tokens.masked_fill(prefix_mask, 0.0)
|
||||
postfix_mask = ~prefix_mask
|
||||
return time_tokens, postfix_mask
|
||||
|
||||
|
||||
def masked_mean(
|
||||
losses: torch.Tensor, mask: torch.Tensor | None, reduce_dims: tuple[int, ...], eps: float = 1e-8
|
||||
) -> torch.Tensor:
|
||||
if mask is None:
|
||||
return losses.mean(dim=reduce_dims)
|
||||
|
||||
mask = mask.to(dtype=losses.dtype)
|
||||
while mask.dim() < losses.dim():
|
||||
mask = mask.unsqueeze(-1)
|
||||
masked = losses * mask
|
||||
denom = mask.sum(dim=reduce_dims).clamp_min(eps)
|
||||
return masked.sum(dim=reduce_dims) / denom
|
||||
|
||||
|
||||
def apply_training_time_rtc_inference(
|
||||
x_t: torch.Tensor,
|
||||
time: float,
|
||||
inference_delay: int | None,
|
||||
prev_chunk_left_over: torch.Tensor | None,
|
||||
chunk_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Apply training-time RTC conditioning during inference.
|
||||
|
||||
Based on Algorithm 1 from "Training-Time Action Conditioning for Efficient Real-Time Chunking".
|
||||
|
||||
At each denoising step:
|
||||
1. Replace prefix positions in x_t with ground truth from previous chunk
|
||||
2. Create per-token timesteps with 1.0 for prefix positions
|
||||
|
||||
Args:
|
||||
x_t: Current noisy actions (B, T, D)
|
||||
time: Current flow matching timestep (scalar)
|
||||
inference_delay: Number of prefix actions to condition on
|
||||
prev_chunk_left_over: Previous chunk's leftover actions (B, T, D)
|
||||
chunk_size: Total chunk size T
|
||||
|
||||
Returns:
|
||||
x_t_conditioned: x_t with prefix replaced by previous actions
|
||||
time_per_token: Per-token timesteps (B, T) with 1.0 for prefix
|
||||
"""
|
||||
batch_size = x_t.shape[0]
|
||||
device = x_t.device
|
||||
|
||||
if inference_delay is None or inference_delay <= 0 or prev_chunk_left_over is None:
|
||||
time_scalar = torch.full((batch_size,), time, device=device, dtype=torch.float32)
|
||||
return x_t, time_scalar
|
||||
|
||||
delay = min(inference_delay, chunk_size)
|
||||
prefix_mask = torch.arange(chunk_size, device=device)[None, :] < delay
|
||||
|
||||
x_t_conditioned = torch.where(
|
||||
prefix_mask[:, :, None].expand_as(x_t),
|
||||
prev_chunk_left_over[:, :chunk_size, :],
|
||||
x_t,
|
||||
)
|
||||
|
||||
time_per_token = torch.full((batch_size, chunk_size), time, device=device, dtype=torch.float32)
|
||||
time_per_token = time_per_token.masked_fill(prefix_mask, 1.0)
|
||||
|
||||
return x_t_conditioned, time_per_token
|
||||
@@ -239,10 +239,8 @@ class SACPolicy(
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
@property
|
||||
def temperature(self) -> float:
|
||||
"""Return the current temperature value, always in sync with log_alpha."""
|
||||
return self.log_alpha.exp().item()
|
||||
def update_temperature(self):
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
|
||||
def compute_loss_critic(
|
||||
self,
|
||||
@@ -459,10 +457,11 @@ class SACPolicy(
|
||||
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
|
||||
self.target_entropy = -np.prod(dim) / 2
|
||||
|
||||
def _init_temperature(self) -> None:
|
||||
"""Set up temperature parameter (log_alpha)."""
|
||||
def _init_temperature(self):
|
||||
"""Set up temperature parameter and initial log_alpha."""
|
||||
temp_init = self.config.temperature_init
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
|
||||
|
||||
class SACObservationEncoder(nn.Module):
|
||||
|
||||
@@ -20,7 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig, RTCTrainingConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
|
||||
@@ -103,8 +103,9 @@ class SmolVLAConfig(PreTrainedConfig):
|
||||
min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding
|
||||
max_period: float = 4.0
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
# Real-Time Chunking (RTC) configurations
|
||||
rtc_config: RTCConfig | None = None
|
||||
rtc_training_config: RTCTrainingConfig | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
@@ -63,6 +63,12 @@ from typing_extensions import Unpack
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.policies.rtc.training_time import (
|
||||
apply_rtc_training_time,
|
||||
apply_training_time_rtc_inference,
|
||||
masked_mean,
|
||||
sample_rtc_delay,
|
||||
)
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
|
||||
from lerobot.policies.utils import (
|
||||
@@ -85,8 +91,8 @@ def create_sinusoidal_pos_embedding(
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
if time.ndim not in (1, 2):
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size,)` or `(batch_size, T)`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
@@ -94,9 +100,14 @@ def create_sinusoidal_pos_embedding(
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
if time.ndim == 1:
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
|
||||
time_flat = time.reshape(-1)
|
||||
sin_input = scaling_factor[None, :] * time_flat[:, None]
|
||||
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
return pos_emb
|
||||
return pos_emb.reshape(*time.shape, dimension)
|
||||
|
||||
|
||||
def make_att_2d_masks(pad_masks, att_masks):
|
||||
@@ -375,6 +386,16 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
|
||||
lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
actions = self.prepare_action(batch)
|
||||
postfix_mask = None
|
||||
rtc_cfg = self.config.rtc_training_config
|
||||
if rtc_cfg is not None and rtc_cfg.enabled and self.training:
|
||||
batch_size = actions.shape[0]
|
||||
if time is None:
|
||||
time = self.model.sample_time(batch_size, actions.device)
|
||||
if noise is None:
|
||||
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||
delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device)
|
||||
time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1])
|
||||
actions_is_pad = batch.get("actions_id_pad")
|
||||
loss_dict = {}
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
||||
@@ -384,6 +405,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
in_episode_bound = ~actions_is_pad
|
||||
losses = losses * in_episode_bound.unsqueeze(-1)
|
||||
loss_dict["losses_after_in_ep_bound"] = losses.clone()
|
||||
postfix_mask = in_episode_bound if postfix_mask is None else (postfix_mask & in_episode_bound)
|
||||
|
||||
# Remove padding
|
||||
losses = losses[:, :, : self.config.max_action_dim]
|
||||
@@ -391,12 +413,12 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
|
||||
if reduction == "none":
|
||||
# Return per-sample losses (B,) by averaging over time and action dims
|
||||
per_sample_loss = losses.mean(dim=(1, 2))
|
||||
per_sample_loss = masked_mean(losses, postfix_mask, reduce_dims=(1, 2))
|
||||
loss_dict["loss"] = per_sample_loss.mean().item()
|
||||
return per_sample_loss, loss_dict
|
||||
else:
|
||||
# Default: return scalar mean loss
|
||||
loss = losses.mean()
|
||||
loss = masked_mean(losses, postfix_mask, reduce_dims=(0, 1, 2))
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
|
||||
@@ -596,6 +618,9 @@ class VLAFlowMatching(nn.Module):
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _training_time_rtc_inference_enabled(self):
|
||||
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
|
||||
|
||||
def set_requires_grad(self):
|
||||
for params in self.state_proj.parameters():
|
||||
params.requires_grad = self.config.train_state_proj
|
||||
@@ -731,7 +756,10 @@ class VLAFlowMatching(nn.Module):
|
||||
)
|
||||
time_emb = time_emb.type(dtype=dtype)
|
||||
|
||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||
if time_emb.dim() == 2:
|
||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||
elif time_emb.shape[:2] != action_emb.shape[:2]:
|
||||
raise ValueError(f"Expected time_emb shape {action_emb.shape[:2]}, got {time_emb.shape[:2]}")
|
||||
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
||||
|
||||
action_time_emb = self.action_time_mlp_in(action_time_emb)
|
||||
@@ -763,7 +791,12 @@ class VLAFlowMatching(nn.Module):
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
if time.ndim == 1:
|
||||
time_expanded = time[:, None, None]
|
||||
elif time.ndim == 2:
|
||||
time_expanded = time[:, :, None]
|
||||
else:
|
||||
raise ValueError(f"Expected time shape (B,) or (B, T), got {time.shape}")
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
@@ -826,23 +859,35 @@ class VLAFlowMatching(nn.Module):
|
||||
num_steps = self.config.num_steps
|
||||
dt = -1.0 / num_steps
|
||||
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
use_training_time_rtc = self._training_time_rtc_inference_enabled()
|
||||
|
||||
x_t = noise
|
||||
for step in range(num_steps):
|
||||
time = 1.0 + step * dt
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
return self.denoise_step(
|
||||
x_t=input_x_t,
|
||||
if use_training_time_rtc:
|
||||
x_t_cond, time_tensor = apply_training_time_rtc_inference(
|
||||
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
|
||||
)
|
||||
v_t = self.denoise_step(
|
||||
x_t=x_t_cond,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
timestep=current_timestep,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
elif self._rtc_enabled():
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
if self._rtc_enabled():
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
return self.denoise_step(
|
||||
x_t=input_x_t,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
timestep=current_timestep,
|
||||
)
|
||||
|
||||
v_t = self.rtc_processor.denoise_step(
|
||||
x_t=x_t,
|
||||
@@ -853,7 +898,13 @@ class VLAFlowMatching(nn.Module):
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
x_t=x_t,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
|
||||
x_t = x_t + dt * v_t
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ class TDMPCConfig(PreTrainedConfig):
|
||||
camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_features`, `output_features`, and perhaps `max_random_shift_ratio`.
|
||||
Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`.
|
||||
|
||||
Args:
|
||||
n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
|
||||
@@ -40,12 +40,24 @@ class TDMPCConfig(PreTrainedConfig):
|
||||
is an alternative to using action repeats. If this is set to more than 1, then we require
|
||||
`n_action_repeats == 1`, `use_mpc == True` and `n_action_steps <= horizon`. Note that this
|
||||
approach of using multiple steps from the plan is not in the original implementation.
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||
include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range. Note that here this defaults to None meaning inputs are not normalized. This is to
|
||||
match the original implementation.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets. NOTE: Clipping
|
||||
to [-1, +1] is used during MPPI/CEM. Therefore, it is recommended that you stick with "min_max"
|
||||
normalization mode here.
|
||||
image_encoder_hidden_dim: Number of channels for the convolutional layers used for image encoding.
|
||||
state_encoder_hidden_dim: Hidden dimension for MLP used for state vector encoding.
|
||||
latent_dim: Observation's latent embedding dimension.
|
||||
|
||||
@@ -32,7 +32,7 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_features` and `output_features`.
|
||||
Those are: `input_shapes` and `output_shapes`.
|
||||
|
||||
Notes on the inputs and outputs:
|
||||
- "observation.state" is required as an input key.
|
||||
@@ -46,12 +46,21 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
current step and additional steps going back).
|
||||
n_action_pred_token: Total number of current token and future tokens that VQ-BeT predicts.
|
||||
action_chunk_size: Action chunk size of each action prediction token.
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
||||
The key represents the input data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "observation.image" refers to an input from
|
||||
a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
|
||||
Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
||||
The key represents the output data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
||||
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||
within the image size. If None, no cropping is done.
|
||||
|
||||
@@ -168,14 +168,11 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
|
||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
||||
user_prompt_key = {"user_prompt": batch["user_prompt"]} if "user_prompt" in batch else {}
|
||||
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
|
||||
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
|
||||
|
||||
return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key, **user_prompt_key, **subtask_key}
|
||||
return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key}
|
||||
|
||||
|
||||
def create_transition(
|
||||
|
||||
@@ -18,18 +18,16 @@
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, runtime_checkable
|
||||
from typing import Any, Protocol, TypeVar, runtime_checkable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as F # noqa: N812
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
|
||||
from .core import EnvTransition, PolicyAction, TransitionKey
|
||||
from .pipeline import (
|
||||
ComplementaryDataProcessorStep,
|
||||
@@ -71,10 +69,10 @@ class HasTeleopEvents(Protocol):
|
||||
|
||||
|
||||
# Type variable constrained to Teleoperator subclasses that also implement events
|
||||
TeleopWithEvents = TypeVar("TeleopWithEvents", bound="Teleoperator")
|
||||
TeleopWithEvents = TypeVar("TeleopWithEvents", bound=Teleoperator)
|
||||
|
||||
|
||||
def _check_teleop_with_events(teleop: "Teleoperator") -> None:
|
||||
def _check_teleop_with_events(teleop: Teleoperator) -> None:
|
||||
"""
|
||||
Runtime check that a teleoperator implements the `HasTeleopEvents` protocol.
|
||||
|
||||
@@ -105,7 +103,7 @@ class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
|
||||
teleop_device: The teleoperator instance to get the action from.
|
||||
"""
|
||||
|
||||
teleop_device: "Teleoperator"
|
||||
teleop_device: Teleoperator
|
||||
|
||||
def complementary_data(self, complementary_data: dict) -> dict:
|
||||
"""
|
||||
@@ -314,7 +312,7 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||
class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
||||
"""
|
||||
Applies a penalty for inefficient gripper usage.
|
||||
|
||||
@@ -329,27 +327,26 @@ class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
penalty: float = -0.01
|
||||
max_gripper_pos: float = 30.0
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
def complementary_data(self, complementary_data: dict) -> dict:
|
||||
"""
|
||||
Calculates the gripper penalty and adds it to the complementary data.
|
||||
|
||||
Args:
|
||||
transition: The incoming environment transition.
|
||||
complementary_data: The incoming complementary data, which should contain
|
||||
raw joint positions.
|
||||
|
||||
Returns:
|
||||
The modified transition with the penalty added to complementary data.
|
||||
A new complementary data dictionary with the `discrete_penalty` key added.
|
||||
"""
|
||||
new_transition = transition.copy()
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
action = self.transition.get(TransitionKey.ACTION)
|
||||
|
||||
raw_joint_positions = complementary_data.get("raw_joint_positions")
|
||||
if raw_joint_positions is None:
|
||||
return new_transition
|
||||
return complementary_data
|
||||
|
||||
current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY, None)
|
||||
if current_gripper_pos is None:
|
||||
return new_transition
|
||||
return complementary_data
|
||||
|
||||
# Gripper action is a PolicyAction at this stage
|
||||
gripper_action = action[-1].item()
|
||||
@@ -365,12 +362,11 @@ class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
|
||||
gripper_penalty = self.penalty * int(gripper_penalty_bool)
|
||||
|
||||
# Update complementary data with penalty info
|
||||
# Create new complementary data with penalty info
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
|
||||
return new_transition
|
||||
return new_complementary_data
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
|
||||
@@ -34,12 +34,7 @@ from lerobot.utils.constants import (
|
||||
ACTION_TOKEN_MASK,
|
||||
ACTION_TOKENS,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_SUBTASK_TOKENS,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_LANGUAGE_USER_PROMPT,
|
||||
OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_USER_PROMPT_TOKENS,
|
||||
)
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
@@ -144,70 +139,18 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
|
||||
return None
|
||||
|
||||
def get_user_prompt(self, transition: EnvTransition) -> list[str] | None:
|
||||
"""
|
||||
Extracts the user_prompt from the transition's complementary data.
|
||||
|
||||
Args:
|
||||
transition: The environment transition.
|
||||
|
||||
Returns:
|
||||
A list of user_prompt strings, or None if the user_prompt key is not found or the value is None.
|
||||
"""
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data is None:
|
||||
return None
|
||||
|
||||
user_prompt = complementary_data.get("user_prompt")
|
||||
if user_prompt is None:
|
||||
return None
|
||||
|
||||
# Standardize to a list of strings for the tokenizer
|
||||
if isinstance(user_prompt, str):
|
||||
return [user_prompt]
|
||||
elif isinstance(user_prompt, list) and all(isinstance(t, str) for t in user_prompt):
|
||||
return user_prompt
|
||||
|
||||
return None
|
||||
|
||||
def get_subtask(self, transition: EnvTransition) -> list[str] | None:
|
||||
"""
|
||||
Extracts the subtask from the transition's complementary data.
|
||||
|
||||
Args:
|
||||
transition: The environment transition.
|
||||
|
||||
Returns:
|
||||
A list of subtask strings, or None if the subtask key is not found or the value is None.
|
||||
"""
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data is None:
|
||||
return None
|
||||
|
||||
subtask = complementary_data.get("subtask")
|
||||
if subtask is None:
|
||||
return None
|
||||
|
||||
# Standardize to a list of strings for the tokenizer
|
||||
if isinstance(subtask, str):
|
||||
return [subtask]
|
||||
elif isinstance(subtask, list) and all(isinstance(t, str) for t in subtask):
|
||||
return subtask
|
||||
|
||||
return None
|
||||
|
||||
def observation(self, observation: RobotObservation) -> RobotObservation:
|
||||
"""
|
||||
Tokenizes the task description and user_prompt (if available) and adds them to the observation dictionary.
|
||||
Tokenizes the task description and adds it to the observation dictionary.
|
||||
|
||||
This method retrieves the task and user_prompt, tokenizes them, moves the resulting tensors to the
|
||||
This method retrieves the task, tokenizes it, moves the resulting tensors to the
|
||||
same device as other data in the transition, and updates the observation.
|
||||
|
||||
Args:
|
||||
observation: The original observation dictionary.
|
||||
|
||||
Returns:
|
||||
The updated observation dictionary including token IDs and attention masks.
|
||||
The updated observation dictionary including token IDs and an attention mask.
|
||||
"""
|
||||
task = self.get_task(self.transition)
|
||||
if task is None:
|
||||
@@ -233,58 +176,6 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
# Tokenize user_prompt if available
|
||||
user_prompt = self.get_user_prompt(self.transition)
|
||||
if user_prompt is not None:
|
||||
tokenized_user_prompt = self._tokenize_text(user_prompt)
|
||||
|
||||
# Move new tokenized tensors to the detected device
|
||||
if target_device is not None:
|
||||
tokenized_user_prompt = {
|
||||
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in tokenized_user_prompt.items()
|
||||
}
|
||||
|
||||
# Add tokenized user_prompt to the observation
|
||||
new_observation[OBS_LANGUAGE_USER_PROMPT_TOKENS] = tokenized_user_prompt["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK] = tokenized_user_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
# Tokenize subtask if available
|
||||
subtask = self.get_subtask(self.transition)
|
||||
if subtask is not None:
|
||||
tokenized_subtask = self._tokenize_text(subtask)
|
||||
|
||||
# Add EOS token at the end of each subtask sequence (before padding)
|
||||
eos_token_id = self.input_tokenizer.eos_token_id
|
||||
input_ids = tokenized_subtask["input_ids"]
|
||||
attention_mask = tokenized_subtask["attention_mask"]
|
||||
for i in range(input_ids.size(0)):
|
||||
# Find the length of actual tokens (sum of attention mask)
|
||||
seq_len = attention_mask[i].sum().item()
|
||||
|
||||
max_len = input_ids.size(1)
|
||||
if seq_len >= max_len:
|
||||
raise ValueError(
|
||||
f"No room to append EOS: seq_len={seq_len} equals max_length={max_len}. "
|
||||
"Increase max_length or tokenize with padding=False then pad after adding EOS."
|
||||
)
|
||||
# Add EOS token at the end
|
||||
input_ids[i, seq_len] = eos_token_id
|
||||
attention_mask[i, seq_len] = 1
|
||||
|
||||
# Move new tokenized tensors to the detected device
|
||||
if target_device is not None:
|
||||
tokenized_subtask = {
|
||||
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in tokenized_subtask.items()
|
||||
}
|
||||
|
||||
# Add tokenized subtask to the observation
|
||||
new_observation[OBS_LANGUAGE_SUBTASK_TOKENS] = tokenized_subtask["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] = tokenized_subtask["attention_mask"].to(
|
||||
dtype=torch.bool
|
||||
)
|
||||
|
||||
return new_observation
|
||||
|
||||
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
|
||||
@@ -383,28 +274,6 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
# Add features for user_prompt tokens and attention mask if they don't already exist
|
||||
if OBS_LANGUAGE_USER_PROMPT_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_USER_PROMPT_TOKENS] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
if OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
# Add features for subtask tokens and attention mask if they don't already exist
|
||||
if OBS_LANGUAGE_SUBTASK_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
if OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@@ -658,4 +527,4 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
Returns:
|
||||
The updated dictionary of policy features.
|
||||
"""
|
||||
return features
|
||||
return features
|
||||
|
||||
@@ -412,10 +412,7 @@ def make_processors(
|
||||
if cfg.processor.observation.add_current_to_observation:
|
||||
env_pipeline_steps.append(MotorCurrentProcessorStep(robot=env.robot))
|
||||
|
||||
add_ee_pose = (
|
||||
cfg.processor.observation is not None and cfg.processor.observation.add_ee_pose_to_observation
|
||||
)
|
||||
if kinematics_solver is not None and add_ee_pose:
|
||||
if kinematics_solver is not None:
|
||||
env_pipeline_steps.append(
|
||||
ForwardKinematicsJointsToEEObservation(
|
||||
kinematics=kinematics_solver,
|
||||
@@ -438,12 +435,7 @@ def make_processors(
|
||||
)
|
||||
|
||||
# Add gripper penalty processor if gripper config exists and enabled
|
||||
# Only add if max_gripper_pos is explicitly configured (required for normalization)
|
||||
if (
|
||||
cfg.processor.gripper is not None
|
||||
and cfg.processor.gripper.use_gripper
|
||||
and cfg.processor.max_gripper_pos is not None
|
||||
):
|
||||
if cfg.processor.gripper is not None and cfg.processor.gripper.use_gripper:
|
||||
env_pipeline_steps.append(
|
||||
GripperPenaltyProcessorStep(
|
||||
penalty=cfg.processor.gripper.gripper_penalty,
|
||||
|
||||
@@ -545,6 +545,9 @@ def add_actor_information_and_train(
|
||||
training_infos["temperature_grad_norm"] = temp_grad_norm
|
||||
training_infos["temperature"] = policy.temperature
|
||||
|
||||
# Update temperature
|
||||
policy.update_temperature()
|
||||
|
||||
# Push policy to actors if needed
|
||||
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
|
||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
||||
|
||||
@@ -26,21 +26,8 @@ from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.utils.constants import PRETRAINED_MODEL_DIR
|
||||
|
||||
|
||||
def cfg_to_group(
|
||||
cfg: TrainPipelineConfig, return_list: bool = False, truncate_tags: bool = False, max_tag_length: int = 64
|
||||
) -> list[str] | str:
|
||||
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
|
||||
"""Return a group name for logging. Optionally returns group name as list."""
|
||||
|
||||
def _maybe_truncate(tag: str) -> str:
|
||||
"""Truncate tag to max_tag_length characters if required.
|
||||
|
||||
wandb rejects tags longer than 64 characters.
|
||||
See: https://github.com/wandb/wandb/blob/main/wandb/sdk/wandb_settings.py
|
||||
"""
|
||||
if len(tag) <= max_tag_length:
|
||||
return tag
|
||||
return tag[:max_tag_length]
|
||||
|
||||
lst = [
|
||||
f"policy:{cfg.policy.type}",
|
||||
f"seed:{cfg.seed}",
|
||||
@@ -49,8 +36,6 @@ def cfg_to_group(
|
||||
lst.append(f"dataset:{cfg.dataset.repo_id}")
|
||||
if cfg.env is not None:
|
||||
lst.append(f"env:{cfg.env.type}")
|
||||
if truncate_tags:
|
||||
lst = [_maybe_truncate(tag) for tag in lst]
|
||||
return lst if return_list else "-".join(lst)
|
||||
|
||||
|
||||
@@ -98,7 +83,7 @@ class WandBLogger:
|
||||
entity=self.cfg.entity,
|
||||
name=self.job_name,
|
||||
notes=self.cfg.notes,
|
||||
tags=cfg_to_group(cfg, return_list=True, truncate_tags=True),
|
||||
tags=cfg_to_group(cfg, return_list=True),
|
||||
dir=self.log_dir,
|
||||
config=cfg.to_dict(),
|
||||
# TODO(rcadene): try set to True
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .bi_openarm_follower import BiOpenArmFollower
|
||||
from .config_bi_openarm_follower import BiOpenArmFollowerConfig
|
||||
|
||||
__all__ = ["BiOpenArmFollower", "BiOpenArmFollowerConfig"]
|
||||
@@ -1,175 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_bi_openarm_follower import BiOpenArmFollowerConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BiOpenArmFollower(Robot):
|
||||
"""
|
||||
Bimanual OpenArm Follower Arms
|
||||
"""
|
||||
|
||||
config_class = BiOpenArmFollowerConfig
|
||||
name = "bi_openarm_follower"
|
||||
|
||||
def __init__(self, config: BiOpenArmFollowerConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
left_arm_config = OpenArmFollowerConfig(
|
||||
id=f"{config.id}_left" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.left_arm_config.port,
|
||||
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
||||
max_relative_target=config.left_arm_config.max_relative_target,
|
||||
cameras=config.left_arm_config.cameras,
|
||||
side=config.left_arm_config.side,
|
||||
can_interface=config.left_arm_config.can_interface,
|
||||
use_can_fd=config.left_arm_config.use_can_fd,
|
||||
can_bitrate=config.left_arm_config.can_bitrate,
|
||||
can_data_bitrate=config.left_arm_config.can_data_bitrate,
|
||||
motor_config=config.left_arm_config.motor_config,
|
||||
position_kd=config.left_arm_config.position_kd,
|
||||
position_kp=config.left_arm_config.position_kp,
|
||||
joint_limits=config.left_arm_config.joint_limits,
|
||||
)
|
||||
|
||||
right_arm_config = OpenArmFollowerConfig(
|
||||
id=f"{config.id}_right" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.right_arm_config.port,
|
||||
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
|
||||
max_relative_target=config.right_arm_config.max_relative_target,
|
||||
cameras=config.right_arm_config.cameras,
|
||||
side=config.right_arm_config.side,
|
||||
can_interface=config.right_arm_config.can_interface,
|
||||
use_can_fd=config.right_arm_config.use_can_fd,
|
||||
can_bitrate=config.right_arm_config.can_bitrate,
|
||||
can_data_bitrate=config.right_arm_config.can_data_bitrate,
|
||||
motor_config=config.right_arm_config.motor_config,
|
||||
position_kd=config.right_arm_config.position_kd,
|
||||
position_kp=config.right_arm_config.position_kp,
|
||||
joint_limits=config.right_arm_config.joint_limits,
|
||||
)
|
||||
|
||||
self.left_arm = OpenArmFollower(left_arm_config)
|
||||
self.right_arm = OpenArmFollower(right_arm_config)
|
||||
|
||||
# Only for compatibility with other parts of the codebase that expect a `robot.cameras` attribute
|
||||
self.cameras = {**self.left_arm.cameras, **self.right_arm.cameras}
|
||||
|
||||
@property
|
||||
def _motors_ft(self) -> dict[str, type]:
|
||||
left_arm_motors_ft = self.left_arm._motors_ft
|
||||
right_arm_motors_ft = self.right_arm._motors_ft
|
||||
|
||||
return {
|
||||
**{f"left_{k}": v for k, v in left_arm_motors_ft.items()},
|
||||
**{f"right_{k}": v for k, v in right_arm_motors_ft.items()},
|
||||
}
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
left_arm_cameras_ft = self.left_arm._cameras_ft
|
||||
right_arm_cameras_ft = self.right_arm._cameras_ft
|
||||
|
||||
return {
|
||||
**{f"left_{k}": v for k, v in left_arm_cameras_ft.items()},
|
||||
**{f"right_{k}": v for k, v in right_arm_cameras_ft.items()},
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.left_arm.calibrate()
|
||||
self.right_arm.calibrate()
|
||||
|
||||
def configure(self) -> None:
|
||||
self.left_arm.configure()
|
||||
self.right_arm.configure()
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||
)
|
||||
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict = {}
|
||||
|
||||
# Add "left_" prefix
|
||||
left_obs = self.left_arm.get_observation()
|
||||
obs_dict.update({f"left_{key}": value for key, value in left_obs.items()})
|
||||
|
||||
# Add "right_" prefix
|
||||
right_obs = self.right_arm.get_observation()
|
||||
obs_dict.update({f"right_{key}": value for key, value in right_obs.items()})
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(
|
||||
self,
|
||||
action: RobotAction,
|
||||
custom_kp: dict[str, float] | None = None,
|
||||
custom_kd: dict[str, float] | None = None,
|
||||
) -> RobotAction:
|
||||
# Remove "left_" prefix
|
||||
left_action = {
|
||||
key.removeprefix("left_"): value for key, value in action.items() if key.startswith("left_")
|
||||
}
|
||||
# Remove "right_" prefix
|
||||
right_action = {
|
||||
key.removeprefix("right_"): value for key, value in action.items() if key.startswith("right_")
|
||||
}
|
||||
|
||||
sent_action_left = self.left_arm.send_action(left_action, custom_kp, custom_kd)
|
||||
sent_action_right = self.right_arm.send_action(right_action, custom_kp, custom_kd)
|
||||
|
||||
# Add prefixes back
|
||||
prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()}
|
||||
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
|
||||
|
||||
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
|
||||
|
||||
def disconnect(self):
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
@@ -1,30 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.robots.openarm_follower import OpenArmFollowerConfigBase
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("bi_openarm_follower")
|
||||
@dataclass
|
||||
class BiOpenArmFollowerConfig(RobotConfig):
|
||||
"""Configuration class for Bi OpenArm Follower robots."""
|
||||
|
||||
left_arm_config: OpenArmFollowerConfigBase
|
||||
right_arm_config: OpenArmFollowerConfigBase
|
||||
@@ -1,20 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_openarm_follower import OpenArmFollowerConfig, OpenArmFollowerConfigBase
|
||||
from .openarm_follower import OpenArmFollower
|
||||
|
||||
__all__ = ["OpenArmFollower", "OpenArmFollowerConfig", "OpenArmFollowerConfigBase"]
|
||||
@@ -1,122 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
LEFT_DEFAULT_JOINTS_LIMITS: dict[str, tuple[float, float]] = {
|
||||
"joint_1": (-75.0, 75.0),
|
||||
"joint_2": (-90.0, 9.0),
|
||||
"joint_3": (-85.0, 85.0),
|
||||
"joint_4": (0.0, 135.0),
|
||||
"joint_5": (-85.0, 85.0),
|
||||
"joint_6": (-40.0, 40.0),
|
||||
"joint_7": (-80.0, 80.0),
|
||||
"gripper": (-65.0, 0.0),
|
||||
}
|
||||
|
||||
RIGHT_DEFAULT_JOINTS_LIMITS: dict[str, tuple[float, float]] = {
|
||||
"joint_1": (-75.0, 75.0),
|
||||
"joint_2": (-9.0, 90.0),
|
||||
"joint_3": (-85.0, 85.0),
|
||||
"joint_4": (0.0, 135.0),
|
||||
"joint_5": (-85.0, 85.0),
|
||||
"joint_6": (-40.0, 40.0),
|
||||
"joint_7": (-80.0, 80.0),
|
||||
"gripper": (-65.0, 0.0),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenArmFollowerConfigBase:
|
||||
"""Base configuration for the OpenArms follower robot with Damiao motors."""
|
||||
|
||||
# CAN interfaces - one per arm
|
||||
# arm CAN interface (e.g., "can1")
|
||||
# Linux: "can0", "can1", etc.
|
||||
port: str
|
||||
|
||||
# side of the arm: "left" or "right". If "None" default values will be used
|
||||
side: str | None = None
|
||||
|
||||
# CAN interface type: "socketcan" (Linux), "slcan" (serial), or "auto" (auto-detect)
|
||||
can_interface: str = "socketcan"
|
||||
|
||||
# CAN FD settings (OpenArms uses CAN FD by default)
|
||||
use_can_fd: bool = True
|
||||
can_bitrate: int = 1000000 # Nominal bitrate (1 Mbps)
|
||||
can_data_bitrate: int = 5000000 # Data bitrate for CAN FD (5 Mbps)
|
||||
|
||||
# Whether to disable torque when disconnecting
|
||||
disable_torque_on_disconnect: bool = True
|
||||
|
||||
# Safety limit for relative target positions
|
||||
# Set to a positive scalar for all motors, or a dict mapping motor names to limits
|
||||
max_relative_target: float | dict[str, float] | None = None
|
||||
|
||||
# Camera configurations
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
# Motor configuration for OpenArms (7 DOF per arm)
|
||||
# Maps motor names to (send_can_id, recv_can_id, motor_type)
|
||||
# Based on: https://docs.openarm.dev/software/setup/configure-test
|
||||
# OpenArms uses 4 types of motors:
|
||||
# - DM8009 (DM-J8009P-2EC) for shoulders (high torque)
|
||||
# - DM4340P and DM4340 for shoulder rotation and elbow
|
||||
# - DM4310 (DM-J4310-2EC V1.1) for wrist and gripper
|
||||
motor_config: dict[str, tuple[int, int, str]] = field(
|
||||
default_factory=lambda: {
|
||||
"joint_1": (0x01, 0x11, "dm8009"), # J1 - Shoulder pan (DM8009)
|
||||
"joint_2": (0x02, 0x12, "dm8009"), # J2 - Shoulder lift (DM8009)
|
||||
"joint_3": (0x03, 0x13, "dm4340"), # J3 - Shoulder rotation (DM4340)
|
||||
"joint_4": (0x04, 0x14, "dm4340"), # J4 - Elbow flex (DM4340)
|
||||
"joint_5": (0x05, 0x15, "dm4310"), # J5 - Wrist roll (DM4310)
|
||||
"joint_6": (0x06, 0x16, "dm4310"), # J6 - Wrist pitch (DM4310)
|
||||
"joint_7": (0x07, 0x17, "dm4310"), # J7 - Wrist rotation (DM4310)
|
||||
"gripper": (0x08, 0x18, "dm4310"), # J8 - Gripper (DM4310)
|
||||
}
|
||||
)
|
||||
|
||||
# MIT control parameters for position control (used in send_action)
|
||||
# List of 8 values: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper]
|
||||
position_kp: list[float] = field(
|
||||
default_factory=lambda: [240.0, 240.0, 240.0, 240.0, 24.0, 31.0, 25.0, 25.0]
|
||||
)
|
||||
position_kd: list[float] = field(default_factory=lambda: [5.0, 5.0, 3.0, 5.0, 0.3, 0.3, 0.3, 0.3])
|
||||
|
||||
# Values for joint limits. Can be overridden via CLI (for custom values) or by setting config.side to either 'left' or 'right'.
|
||||
# If config.side is left set to None and no CLI values are passed, the default joint limit values are small for safety.
|
||||
joint_limits: dict[str, tuple[float, float]] = field(
|
||||
default_factory=lambda: {
|
||||
"joint_1": (-5.0, 5.0),
|
||||
"joint_2": (-5.0, 5.0),
|
||||
"joint_3": (-5.0, 5.0),
|
||||
"joint_4": (0.0, 5.0),
|
||||
"joint_5": (-5.0, 5.0),
|
||||
"joint_6": (-5.0, 5.0),
|
||||
"joint_7": (-5.0, 5.0),
|
||||
"gripper": (-5.0, 0.0),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("openarm_follower")
|
||||
@dataclass
|
||||
class OpenArmFollowerConfig(RobotConfig, OpenArmFollowerConfigBase):
|
||||
pass
|
||||
@@ -1,348 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.damiao import DamiaoMotorsBus
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
from .config_openarm_follower import (
|
||||
LEFT_DEFAULT_JOINTS_LIMITS,
|
||||
RIGHT_DEFAULT_JOINTS_LIMITS,
|
||||
OpenArmFollowerConfig,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenArmFollower(Robot):
|
||||
"""
|
||||
OpenArms Follower Robot which uses CAN bus communication to control 7 DOF arm with a gripper.
|
||||
The arm uses Damiao motors in MIT control mode.
|
||||
"""
|
||||
|
||||
config_class = OpenArmFollowerConfig
|
||||
name = "openarm_follower"
|
||||
|
||||
def __init__(self, config: OpenArmFollowerConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Arm motors
|
||||
motors: dict[str, Motor] = {}
|
||||
for motor_name, (send_id, recv_id, motor_type_str) in config.motor_config.items():
|
||||
motor = Motor(
|
||||
send_id, motor_type_str, MotorNormMode.DEGREES
|
||||
) # Always use degrees for Damiao motors
|
||||
motor.recv_id = recv_id
|
||||
motor.motor_type_str = motor_type_str
|
||||
motors[motor_name] = motor
|
||||
|
||||
self.bus = DamiaoMotorsBus(
|
||||
port=self.config.port,
|
||||
motors=motors,
|
||||
calibration=self.calibration,
|
||||
can_interface=self.config.can_interface,
|
||||
use_can_fd=self.config.use_can_fd,
|
||||
bitrate=self.config.can_bitrate,
|
||||
data_bitrate=self.config.can_data_bitrate if self.config.use_can_fd else None,
|
||||
)
|
||||
|
||||
if config.side is not None:
|
||||
if config.side == "left":
|
||||
config.joint_limits = LEFT_DEFAULT_JOINTS_LIMITS
|
||||
elif config.side == "right":
|
||||
config.joint_limits = RIGHT_DEFAULT_JOINTS_LIMITS
|
||||
else:
|
||||
raise ValueError(
|
||||
"config.side must be either 'left', 'right' (for default values) or 'None' (for CLI values)"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Set config.side to either 'left' or 'right' to use pre-configured values for joint limits."
|
||||
)
|
||||
logger.info(f"Values used for joint limits: {config.joint_limits}.")
|
||||
|
||||
# Initialize cameras
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
|
||||
@property
|
||||
def _motors_ft(self) -> dict[str, type]:
|
||||
"""Motor features for observation and action spaces."""
|
||||
features: dict[str, type] = {}
|
||||
for motor in self.bus.motors:
|
||||
features[f"{motor}.pos"] = float
|
||||
features[f"{motor}.vel"] = float # Add this
|
||||
features[f"{motor}.torque"] = float # Add this
|
||||
return features
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
"""Camera features for observation space."""
|
||||
return {
|
||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
"""Combined observation features from motors and cameras."""
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
"""Action features."""
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if robot is connected."""
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""
|
||||
Connect to the robot and optionally calibrate.
|
||||
|
||||
We assume that at connection time, the arms are in a safe rest position,
|
||||
and torque can be safely disabled to run calibration if needed.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
# Connect to CAN bus
|
||||
logger.info(f"Connecting arm on {self.config.port}...")
|
||||
self.bus.connect()
|
||||
|
||||
# Run calibration if needed
|
||||
if not self.is_calibrated and calibrate:
|
||||
logger.info(
|
||||
"Mismatch between calibration values in the motor and the calibration file or no calibration file found"
|
||||
)
|
||||
self.calibrate()
|
||||
|
||||
for cam in self.cameras.values():
|
||||
cam.connect()
|
||||
|
||||
self.configure()
|
||||
|
||||
if self.is_calibrated:
|
||||
self.bus.set_zero_position()
|
||||
|
||||
self.bus.enable_torque()
|
||||
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
"""Check if robot is calibrated."""
|
||||
return self.bus.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
"""
|
||||
Run calibration procedure for OpenArms robot.
|
||||
|
||||
The calibration procedure:
|
||||
1. Disable torque
|
||||
2. Ask user to position arms in hanging position with grippers closed
|
||||
3. Set this as zero position
|
||||
4. Record range of motion for each joint
|
||||
5. Save calibration
|
||||
"""
|
||||
if self.calibration:
|
||||
# Calibration file exists, ask user whether to use it or run new calibration
|
||||
user_input = input(
|
||||
f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: "
|
||||
)
|
||||
if user_input.strip().lower() != "c":
|
||||
logger.info(f"Writing calibration file associated with the id {self.id} to the motors")
|
||||
self.bus.write_calibration(self.calibration)
|
||||
return
|
||||
|
||||
logger.info(f"\nRunning calibration for {self}")
|
||||
self.bus.disable_torque()
|
||||
|
||||
# Step 1: Set zero position
|
||||
input(
|
||||
"\nCalibration: Set Zero Position)\n"
|
||||
"Position the arm in the following configuration:\n"
|
||||
" - Arm hanging straight down\n"
|
||||
" - Gripper closed\n"
|
||||
"Press ENTER when ready..."
|
||||
)
|
||||
|
||||
# Set current position as zero for all motors
|
||||
self.bus.set_zero_position()
|
||||
logger.info("Arm zero position set.")
|
||||
|
||||
logger.info("Setting range: -90° to +90° for safety by default for all joints")
|
||||
for motor_name, motor in self.bus.motors.items():
|
||||
self.calibration[motor_name] = MotorCalibration(
|
||||
id=motor.id,
|
||||
drive_mode=0,
|
||||
homing_offset=0,
|
||||
range_min=-90,
|
||||
range_max=90,
|
||||
)
|
||||
|
||||
self.bus.write_calibration(self.calibration)
|
||||
self._save_calibration()
|
||||
print(f"Calibration saved to {self.calibration_fpath}")
|
||||
|
||||
def configure(self) -> None:
|
||||
"""Configure motors with appropriate settings."""
|
||||
# TODO(Steven, Pepijn): Slightly different from what it is happening in the leader
|
||||
with self.bus.torque_disabled():
|
||||
self.bus.configure_motors()
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||
)
|
||||
|
||||
def get_observation(self) -> RobotObservation:
|
||||
"""
|
||||
Get current observation from robot including position, velocity, and torque.
|
||||
|
||||
Reads all motor states (pos/vel/torque) in one CAN refresh cycle
|
||||
instead of 3 separate reads.
|
||||
"""
|
||||
start = time.perf_counter()
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
obs_dict: dict[str, Any] = {}
|
||||
|
||||
states = self.bus.sync_read_all_states()
|
||||
|
||||
for motor in self.bus.motors:
|
||||
state = states.get(motor, {})
|
||||
obs_dict[f"{motor}.pos"] = state.get("position", 0.0)
|
||||
obs_dict[f"{motor}.vel"] = state.get("velocity", 0.0)
|
||||
obs_dict[f"{motor}.torque"] = state.get("torque", 0.0)
|
||||
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} get_observation took: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(
|
||||
self,
|
||||
action: RobotAction,
|
||||
custom_kp: dict[str, float] | None = None,
|
||||
custom_kd: dict[str, float] | None = None,
|
||||
) -> RobotAction:
|
||||
"""
|
||||
Send action command to robot.
|
||||
|
||||
The action magnitude may be clipped based on safety limits.
|
||||
|
||||
Args:
|
||||
action: Dictionary with motor positions (e.g., "joint_1.pos", "joint_2.pos")
|
||||
custom_kp: Optional custom kp gains per motor (e.g., {"joint_1": 120.0, "joint_2": 150.0})
|
||||
custom_kd: Optional custom kd gains per motor (e.g., {"joint_1": 1.5, "joint_2": 2.0})
|
||||
|
||||
Returns:
|
||||
The action actually sent (potentially clipped)
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
|
||||
# Apply joint limit clipping to arm
|
||||
for motor_name, position in goal_pos.items():
|
||||
if motor_name in self.config.joint_limits:
|
||||
min_limit, max_limit = self.config.joint_limits[motor_name]
|
||||
clipped_position = max(min_limit, min(max_limit, position))
|
||||
if clipped_position != position:
|
||||
logger.debug(f"Clipped {motor_name} from {position:.2f}° to {clipped_position:.2f}°")
|
||||
goal_pos[motor_name] = clipped_position
|
||||
|
||||
# Cap goal position when too far away from present position.
|
||||
# /!\ Slower fps expected due to reading from the follower.
|
||||
if self.config.max_relative_target is not None:
|
||||
present_pos = self.bus.sync_read("Present_Position")
|
||||
goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()}
|
||||
goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
|
||||
|
||||
# TODO(Steven, Pepijn): Refactor writing
|
||||
# Motor name to index mapping for gains
|
||||
motor_index = {
|
||||
"joint_1": 0,
|
||||
"joint_2": 1,
|
||||
"joint_3": 2,
|
||||
"joint_4": 3,
|
||||
"joint_5": 4,
|
||||
"joint_6": 5,
|
||||
"joint_7": 6,
|
||||
"gripper": 7,
|
||||
}
|
||||
|
||||
# Use batch MIT control for arm (sends all commands, then collects responses)
|
||||
commands = {}
|
||||
for motor_name, position_degrees in goal_pos.items():
|
||||
idx = motor_index.get(motor_name, 0)
|
||||
# Use custom gains if provided, otherwise use config defaults
|
||||
if custom_kp is not None and motor_name in custom_kp:
|
||||
kp = custom_kp[motor_name]
|
||||
else:
|
||||
kp = (
|
||||
self.config.position_kp[idx]
|
||||
if isinstance(self.config.position_kp, list)
|
||||
else self.config.position_kp
|
||||
)
|
||||
if custom_kd is not None and motor_name in custom_kd:
|
||||
kd = custom_kd[motor_name]
|
||||
else:
|
||||
kd = (
|
||||
self.config.position_kd[idx]
|
||||
if isinstance(self.config.position_kd, list)
|
||||
else self.config.position_kd
|
||||
)
|
||||
commands[motor_name] = (kp, kd, position_degrees, 0.0, 0.0)
|
||||
|
||||
self.bus._mit_control_batch(commands)
|
||||
|
||||
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
|
||||
|
||||
def disconnect(self):
|
||||
"""Disconnect from robot."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Disconnect CAN bus
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
|
||||
# Disconnect cameras
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
@@ -65,6 +65,3 @@ class UnitreeG1Config(RobotConfig):
|
||||
|
||||
# Cameras (ZMQ-based remote cameras)
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
# Compensates for gravity on the unitree's arms using the arm ik solver
|
||||
gravity_compensation: bool = False
|
||||
|
||||
@@ -18,7 +18,7 @@ from enum import IntEnum
|
||||
|
||||
# ruff: noqa: N801, N815
|
||||
|
||||
NUM_MOTORS = 29
|
||||
NUM_MOTORS = 35
|
||||
|
||||
|
||||
class G1_29_JointArmIndex(IntEnum):
|
||||
|
||||
@@ -1,313 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
parent2_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(parent2_dir)
|
||||
|
||||
|
||||
class WeightedMovingFilter:
|
||||
def __init__(self, weights, data_size=14):
|
||||
self._window_size = len(weights)
|
||||
self._weights = np.array(weights)
|
||||
self._data_size = data_size
|
||||
self._filtered_data = np.zeros(self._data_size)
|
||||
self._data_queue = []
|
||||
|
||||
def _apply_filter(self):
|
||||
if len(self._data_queue) < self._window_size:
|
||||
return self._data_queue[-1]
|
||||
|
||||
data_array = np.array(self._data_queue)
|
||||
temp_filtered_data = np.zeros(self._data_size)
|
||||
for i in range(self._data_size):
|
||||
temp_filtered_data[i] = np.convolve(data_array[:, i], self._weights, mode="valid")[-1]
|
||||
|
||||
return temp_filtered_data
|
||||
|
||||
def add_data(self, new_data):
|
||||
assert len(new_data) == self._data_size
|
||||
|
||||
if len(self._data_queue) > 0 and np.array_equal(
|
||||
new_data, self._data_queue[-1]
|
||||
): # skip duplicate data
|
||||
return
|
||||
|
||||
if len(self._data_queue) >= self._window_size:
|
||||
self._data_queue.pop(0)
|
||||
|
||||
self._data_queue.append(new_data)
|
||||
self._filtered_data = self._apply_filter()
|
||||
|
||||
@property
|
||||
def filtered_data(self):
|
||||
return self._filtered_data
|
||||
|
||||
|
||||
class G1_29_ArmIK: # noqa: N801
|
||||
def __init__(self, unit_test=False):
|
||||
import casadi
|
||||
import pinocchio as pin
|
||||
from huggingface_hub import snapshot_download
|
||||
from pinocchio import casadi as cpin
|
||||
|
||||
self._pin = pin
|
||||
np.set_printoptions(precision=5, suppress=True, linewidth=200)
|
||||
|
||||
self.unit_test = unit_test
|
||||
|
||||
self.repo_path = snapshot_download("lerobot/unitree-g1-mujoco")
|
||||
urdf_path = os.path.join(self.repo_path, "assets", "g1_body29_hand14.urdf")
|
||||
mesh_dir = os.path.join(self.repo_path, "assets")
|
||||
|
||||
self.robot = self._pin.RobotWrapper.BuildFromURDF(urdf_path, mesh_dir)
|
||||
|
||||
self.mixed_jointsToLockIDs = [
|
||||
"left_hip_pitch_joint",
|
||||
"left_hip_roll_joint",
|
||||
"left_hip_yaw_joint",
|
||||
"left_knee_joint",
|
||||
"left_ankle_pitch_joint",
|
||||
"left_ankle_roll_joint",
|
||||
"right_hip_pitch_joint",
|
||||
"right_hip_roll_joint",
|
||||
"right_hip_yaw_joint",
|
||||
"right_knee_joint",
|
||||
"right_ankle_pitch_joint",
|
||||
"right_ankle_roll_joint",
|
||||
"waist_yaw_joint",
|
||||
"waist_roll_joint",
|
||||
"waist_pitch_joint",
|
||||
"left_hand_thumb_0_joint",
|
||||
"left_hand_thumb_1_joint",
|
||||
"left_hand_thumb_2_joint",
|
||||
"left_hand_middle_0_joint",
|
||||
"left_hand_middle_1_joint",
|
||||
"left_hand_index_0_joint",
|
||||
"left_hand_index_1_joint",
|
||||
"right_hand_thumb_0_joint",
|
||||
"right_hand_thumb_1_joint",
|
||||
"right_hand_thumb_2_joint",
|
||||
"right_hand_index_0_joint",
|
||||
"right_hand_index_1_joint",
|
||||
"right_hand_middle_0_joint",
|
||||
"right_hand_middle_1_joint",
|
||||
]
|
||||
|
||||
self.reduced_robot = self.robot.buildReducedRobot(
|
||||
list_of_joints_to_lock=self.mixed_jointsToLockIDs,
|
||||
reference_configuration=np.array([0.0] * self.robot.model.nq),
|
||||
)
|
||||
|
||||
# Arm joint names in G1 motor order (G1_29_JointArmIndex)
|
||||
self._arm_joint_names_g1 = [
|
||||
"left_shoulder_pitch_joint",
|
||||
"left_shoulder_roll_joint",
|
||||
"left_shoulder_yaw_joint",
|
||||
"left_elbow_joint",
|
||||
"left_wrist_roll_joint",
|
||||
"left_wrist_pitch_joint",
|
||||
"left_wrist_yaw_joint",
|
||||
"right_shoulder_pitch_joint",
|
||||
"right_shoulder_roll_joint",
|
||||
"right_shoulder_yaw_joint",
|
||||
"right_elbow_joint",
|
||||
"right_wrist_roll_joint",
|
||||
"right_wrist_pitch_joint",
|
||||
"right_wrist_yaw_joint",
|
||||
]
|
||||
# Pinocchio uses its own joint order in q; build index mapping.
|
||||
self._arm_joint_names_pin = sorted(
|
||||
self._arm_joint_names_g1,
|
||||
key=lambda name: self.reduced_robot.model.idx_qs[self.reduced_robot.model.getJointId(name)],
|
||||
)
|
||||
logger.info(f"Pinocchio arm joint order: {self._arm_joint_names_pin}")
|
||||
self._arm_reorder_g1_to_pin = [
|
||||
self._arm_joint_names_g1.index(name) for name in self._arm_joint_names_pin
|
||||
]
|
||||
# Inverse mapping to return tau in G1 motor order.
|
||||
self._arm_reorder_pin_to_g1 = np.argsort(self._arm_reorder_g1_to_pin)
|
||||
|
||||
self.reduced_robot.model.addFrame(
|
||||
self._pin.Frame(
|
||||
"L_ee",
|
||||
self.reduced_robot.model.getJointId("left_wrist_yaw_joint"),
|
||||
self._pin.SE3(np.eye(3), np.array([0.05, 0, 0]).T),
|
||||
self._pin.FrameType.OP_FRAME,
|
||||
)
|
||||
)
|
||||
|
||||
self.reduced_robot.model.addFrame(
|
||||
self._pin.Frame(
|
||||
"R_ee",
|
||||
self.reduced_robot.model.getJointId("right_wrist_yaw_joint"),
|
||||
self._pin.SE3(np.eye(3), np.array([0.05, 0, 0]).T),
|
||||
self._pin.FrameType.OP_FRAME,
|
||||
)
|
||||
)
|
||||
|
||||
# Creating Casadi models and data for symbolic computing
|
||||
self.cmodel = cpin.Model(self.reduced_robot.model)
|
||||
self.cdata = self.cmodel.createData()
|
||||
|
||||
# Creating symbolic variables
|
||||
self.cq = casadi.SX.sym("q", self.reduced_robot.model.nq, 1)
|
||||
self.cTf_l = casadi.SX.sym("tf_l", 4, 4)
|
||||
self.cTf_r = casadi.SX.sym("tf_r", 4, 4)
|
||||
cpin.framesForwardKinematics(self.cmodel, self.cdata, self.cq)
|
||||
|
||||
# Get the hand joint ID and define the error function
|
||||
self.L_hand_id = self.reduced_robot.model.getFrameId("L_ee")
|
||||
self.R_hand_id = self.reduced_robot.model.getFrameId("R_ee")
|
||||
|
||||
self.translational_error = casadi.Function(
|
||||
"translational_error",
|
||||
[self.cq, self.cTf_l, self.cTf_r],
|
||||
[
|
||||
casadi.vertcat(
|
||||
self.cdata.oMf[self.L_hand_id].translation - self.cTf_l[:3, 3],
|
||||
self.cdata.oMf[self.R_hand_id].translation - self.cTf_r[:3, 3],
|
||||
)
|
||||
],
|
||||
)
|
||||
self.rotational_error = casadi.Function(
|
||||
"rotational_error",
|
||||
[self.cq, self.cTf_l, self.cTf_r],
|
||||
[
|
||||
casadi.vertcat(
|
||||
cpin.log3(self.cdata.oMf[self.L_hand_id].rotation @ self.cTf_l[:3, :3].T),
|
||||
cpin.log3(self.cdata.oMf[self.R_hand_id].rotation @ self.cTf_r[:3, :3].T),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# Defining the optimization problem
|
||||
self.opti = casadi.Opti()
|
||||
self.var_q = self.opti.variable(self.reduced_robot.model.nq)
|
||||
self.var_q_last = self.opti.parameter(self.reduced_robot.model.nq) # for smooth
|
||||
self.param_tf_l = self.opti.parameter(4, 4)
|
||||
self.param_tf_r = self.opti.parameter(4, 4)
|
||||
self.translational_cost = casadi.sumsqr(
|
||||
self.translational_error(self.var_q, self.param_tf_l, self.param_tf_r)
|
||||
)
|
||||
self.rotation_cost = casadi.sumsqr(
|
||||
self.rotational_error(self.var_q, self.param_tf_l, self.param_tf_r)
|
||||
)
|
||||
self.regularization_cost = casadi.sumsqr(self.var_q)
|
||||
self.smooth_cost = casadi.sumsqr(self.var_q - self.var_q_last)
|
||||
|
||||
# Setting optimization constraints and goals
|
||||
self.opti.subject_to(
|
||||
self.opti.bounded(
|
||||
self.reduced_robot.model.lowerPositionLimit,
|
||||
self.var_q,
|
||||
self.reduced_robot.model.upperPositionLimit,
|
||||
)
|
||||
)
|
||||
self.opti.minimize(
|
||||
50 * self.translational_cost
|
||||
+ self.rotation_cost
|
||||
+ 0.02 * self.regularization_cost
|
||||
+ 0.1 * self.smooth_cost
|
||||
)
|
||||
|
||||
opts = {
|
||||
"ipopt": {"print_level": 0, "max_iter": 50, "tol": 1e-6},
|
||||
"print_time": False, # print or not
|
||||
"calc_lam_p": False, # https://github.com/casadi/casadi/wiki/FAQ:-Why-am-I-getting-%22NaN-detected%22in-my-optimization%3F
|
||||
}
|
||||
self.opti.solver("ipopt", opts)
|
||||
|
||||
self.init_data = np.zeros(self.reduced_robot.model.nq)
|
||||
self.smooth_filter = WeightedMovingFilter(np.array([0.4, 0.3, 0.2, 0.1]), 14)
|
||||
|
||||
def solve_ik(self, left_wrist, right_wrist, current_lr_arm_motor_q=None, current_lr_arm_motor_dq=None):
|
||||
if current_lr_arm_motor_q is not None:
|
||||
self.init_data = current_lr_arm_motor_q
|
||||
self.opti.set_initial(self.var_q, self.init_data)
|
||||
|
||||
self.opti.set_value(self.param_tf_l, left_wrist)
|
||||
self.opti.set_value(self.param_tf_r, right_wrist)
|
||||
self.opti.set_value(self.var_q_last, self.init_data) # for smooth
|
||||
|
||||
try:
|
||||
self.opti.solve()
|
||||
|
||||
sol_q = self.opti.value(self.var_q)
|
||||
self.smooth_filter.add_data(sol_q)
|
||||
sol_q = self.smooth_filter.filtered_data
|
||||
|
||||
if current_lr_arm_motor_dq is not None:
|
||||
v = current_lr_arm_motor_dq * 0.0
|
||||
else:
|
||||
v = (sol_q - self.init_data) * 0.0
|
||||
|
||||
self.init_data = sol_q
|
||||
|
||||
sol_tauff = self._pin.rnea(
|
||||
self.reduced_robot.model,
|
||||
self.reduced_robot.data,
|
||||
sol_q,
|
||||
v,
|
||||
np.zeros(self.reduced_robot.model.nv),
|
||||
)
|
||||
|
||||
return sol_q, sol_tauff
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ERROR in convergence, plotting debug info.{e}")
|
||||
|
||||
sol_q = self.opti.debug.value(self.var_q)
|
||||
self.smooth_filter.add_data(sol_q)
|
||||
sol_q = self.smooth_filter.filtered_data
|
||||
|
||||
if current_lr_arm_motor_dq is not None:
|
||||
v = current_lr_arm_motor_dq * 0.0
|
||||
else:
|
||||
v = (sol_q - self.init_data) * 0.0
|
||||
|
||||
self.init_data = sol_q
|
||||
|
||||
logger.error(
|
||||
f"sol_q:{sol_q} \nmotorstate: \n{current_lr_arm_motor_q} \nleft_pose: \n{left_wrist} \nright_pose: \n{right_wrist}"
|
||||
)
|
||||
|
||||
return current_lr_arm_motor_q, np.zeros(self.reduced_robot.model.nv)
|
||||
|
||||
def solve_tau(self, current_lr_arm_motor_q=None, current_lr_arm_motor_dq=None):
|
||||
try:
|
||||
q_g1 = np.array(current_lr_arm_motor_q, dtype=float)
|
||||
if q_g1.shape[0] != len(self._arm_joint_names_g1):
|
||||
raise ValueError(f"Expected {len(self._arm_joint_names_g1)} arm joints, got {q_g1.shape[0]}")
|
||||
q_pin = q_g1[self._arm_reorder_g1_to_pin]
|
||||
sol_tauff = self._pin.rnea(
|
||||
self.reduced_robot.model,
|
||||
self.reduced_robot.data,
|
||||
q_pin,
|
||||
np.zeros(self.reduced_robot.model.nv),
|
||||
np.zeros(self.reduced_robot.model.nv),
|
||||
)
|
||||
return sol_tauff[self._arm_reorder_pin_to_g1]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ERROR in convergence, plotting debug info.{e}")
|
||||
return np.zeros(self.reduced_robot.model.nv)
|
||||
@@ -27,8 +27,7 @@ import numpy as np
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointArmIndex, G1_29_JointIndex
|
||||
from lerobot.robots.unitree_g1.robot_kinematic_processor import G1_29_ArmIK
|
||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_unitree_g1 import UnitreeG1Config
|
||||
@@ -128,8 +127,6 @@ class UnitreeG1(Robot):
|
||||
self.subscribe_thread = None
|
||||
self.remote_controller = self.RemoteController()
|
||||
|
||||
self.arm_ik = G1_29_ArmIK()
|
||||
|
||||
def _subscribe_motor_state(self): # polls robot state @ 250Hz
|
||||
while not self._shutdown_event.is_set():
|
||||
start_time = time.time()
|
||||
@@ -364,20 +361,6 @@ class UnitreeG1(Robot):
|
||||
self.msg.motor_cmd[motor.value].kd = self.kd[motor.value]
|
||||
self.msg.motor_cmd[motor.value].tau = 0
|
||||
|
||||
if self.config.gravity_compensation:
|
||||
# Build action_np from motor commands (arm joints are indices 15-28, local indices 0-13)
|
||||
action_np = np.zeros(14)
|
||||
arm_start_idx = G1_29_JointArmIndex.kLeftShoulderPitch.value # 15
|
||||
for joint in G1_29_JointArmIndex:
|
||||
local_idx = joint.value - arm_start_idx
|
||||
action_np[local_idx] = self.msg.motor_cmd[joint.value].q
|
||||
tau = self.arm_ik.solve_tau(action_np)
|
||||
|
||||
# Apply tau back to motor commands
|
||||
for joint in G1_29_JointArmIndex:
|
||||
local_idx = joint.value - arm_start_idx
|
||||
self.msg.motor_cmd[joint.value].tau = tau[local_idx]
|
||||
|
||||
self.msg.crc = self.crc.Crc(self.msg)
|
||||
self.lowcmd_publisher.Write(self.msg)
|
||||
return action
|
||||
|
||||
@@ -60,14 +60,6 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
|
||||
from .reachy2 import Reachy2Robot
|
||||
|
||||
return Reachy2Robot(config)
|
||||
elif config.type == "openarm_follower":
|
||||
from .openarm_follower import OpenArmFollower
|
||||
|
||||
return OpenArmFollower(config)
|
||||
elif config.type == "bi_openarm_follower":
|
||||
from .bi_openarm_follower import BiOpenArmFollower
|
||||
|
||||
return BiOpenArmFollower(config)
|
||||
elif config.type == "mock_robot":
|
||||
from tests.mocks.mock_robot import MockRobot
|
||||
|
||||
|
||||
@@ -36,28 +36,23 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_openarm_follower,
|
||||
bi_so_follower,
|
||||
hope_jr,
|
||||
koch_follower,
|
||||
lekiwi,
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
openarm_follower,
|
||||
so_follower,
|
||||
)
|
||||
from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
TeleoperatorConfig,
|
||||
bi_openarm_leader,
|
||||
bi_so_leader,
|
||||
homunculus,
|
||||
koch_leader,
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
so_leader,
|
||||
unitree_g1,
|
||||
)
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.utils import init_logging
|
||||
@@ -86,11 +81,8 @@ def calibrate(cfg: CalibrateConfig):
|
||||
device = make_teleoperator_from_config(cfg.device)
|
||||
|
||||
device.connect(calibrate=False)
|
||||
|
||||
try:
|
||||
device.calibrate()
|
||||
finally:
|
||||
device.disconnect()
|
||||
device.calibrate()
|
||||
device.disconnect()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
Edit LeRobot datasets using various transformation tools.
|
||||
|
||||
This script allows you to delete episodes, split datasets, merge datasets,
|
||||
remove features, modify tasks, and convert image datasets to video format.
|
||||
remove features, and convert image datasets to video format.
|
||||
When new_repo_id is specified, creates a new dataset.
|
||||
|
||||
Usage Examples:
|
||||
@@ -66,25 +66,6 @@ Remove camera feature:
|
||||
--operation.type remove_feature \
|
||||
--operation.feature_names "['observation.images.top']"
|
||||
|
||||
Modify tasks - set a single task for all episodes (WARNING: modifies in-place):
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type modify_tasks \
|
||||
--operation.new_task "Pick up the cube and place it"
|
||||
|
||||
Modify tasks - set different tasks for specific episodes (WARNING: modifies in-place):
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type modify_tasks \
|
||||
--operation.episode_tasks '{"0": "Task A", "1": "Task B", "2": "Task A"}'
|
||||
|
||||
Modify tasks - set default task with overrides for specific episodes (WARNING: modifies in-place):
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type modify_tasks \
|
||||
--operation.new_task "Default task" \
|
||||
--operation.episode_tasks '{"5": "Special task for episode 5"}'
|
||||
|
||||
Convert image dataset to video format and save locally:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
@@ -119,7 +100,6 @@ from lerobot.datasets.dataset_tools import (
|
||||
convert_image_to_video_dataset,
|
||||
delete_episodes,
|
||||
merge_datasets,
|
||||
modify_tasks,
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
@@ -152,13 +132,6 @@ class RemoveFeatureConfig:
|
||||
feature_names: list[str] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModifyTasksConfig:
|
||||
type: str = "modify_tasks"
|
||||
new_task: str | None = None
|
||||
episode_tasks: dict[str, str] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConvertImageToVideoConfig:
|
||||
type: str = "convert_image_to_video"
|
||||
@@ -178,12 +151,7 @@ class ConvertImageToVideoConfig:
|
||||
class EditDatasetConfig:
|
||||
repo_id: str
|
||||
operation: (
|
||||
DeleteEpisodesConfig
|
||||
| SplitConfig
|
||||
| MergeConfig
|
||||
| RemoveFeatureConfig
|
||||
| ModifyTasksConfig
|
||||
| ConvertImageToVideoConfig
|
||||
DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertImageToVideoConfig
|
||||
)
|
||||
root: str | None = None
|
||||
new_repo_id: str | None = None
|
||||
@@ -328,48 +296,6 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
|
||||
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
|
||||
|
||||
|
||||
def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
|
||||
if not isinstance(cfg.operation, ModifyTasksConfig):
|
||||
raise ValueError("Operation config must be ModifyTasksConfig")
|
||||
|
||||
new_task = cfg.operation.new_task
|
||||
episode_tasks_raw = cfg.operation.episode_tasks
|
||||
|
||||
if new_task is None and episode_tasks_raw is None:
|
||||
raise ValueError("Must specify at least one of new_task or episode_tasks for modify_tasks operation")
|
||||
|
||||
# Warn about in-place modification behavior
|
||||
if cfg.new_repo_id is not None:
|
||||
logging.warning("modify_tasks modifies datasets in-place. The --new_repo_id parameter is ignored.")
|
||||
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
logging.warning(f"Modifying dataset in-place at {dataset.root}. Original data will be overwritten.")
|
||||
|
||||
# Convert episode_tasks keys from string to int if needed (CLI passes strings)
|
||||
episode_tasks: dict[int, str] | None = None
|
||||
if episode_tasks_raw is not None:
|
||||
episode_tasks = {int(k): v for k, v in episode_tasks_raw.items()}
|
||||
|
||||
logging.info(f"Modifying tasks in {cfg.repo_id}")
|
||||
if new_task:
|
||||
logging.info(f" Default task: '{new_task}'")
|
||||
if episode_tasks:
|
||||
logging.info(f" Episode-specific tasks: {episode_tasks}")
|
||||
|
||||
modified_dataset = modify_tasks(
|
||||
dataset,
|
||||
new_task=new_task,
|
||||
episode_tasks=episode_tasks,
|
||||
)
|
||||
|
||||
logging.info(f"Dataset modified at {dataset.root}")
|
||||
logging.info(f"Tasks: {list(modified_dataset.meta.tasks.index)}")
|
||||
|
||||
if cfg.push_to_hub:
|
||||
logging.info(f"Pushing to hub as {cfg.repo_id}")
|
||||
modified_dataset.push_to_hub()
|
||||
|
||||
|
||||
def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
|
||||
# Note: Parser may create any config type with the right fields, so we access fields directly
|
||||
# instead of checking isinstance()
|
||||
@@ -445,14 +371,12 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
handle_merge(cfg)
|
||||
elif operation_type == "remove_feature":
|
||||
handle_remove_feature(cfg)
|
||||
elif operation_type == "modify_tasks":
|
||||
handle_modify_tasks(cfg)
|
||||
elif operation_type == "convert_image_to_video":
|
||||
handle_convert_image_to_video(cfg)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown operation type: {operation_type}\n"
|
||||
f"Available operations: delete_episodes, split, merge, remove_feature, modify_tasks, convert_image_to_video"
|
||||
f"Available operations: delete_episodes, split, merge, remove_feature, convert_to_video"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -44,23 +44,19 @@ import numpy as np
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
RobotConfig,
|
||||
bi_openarm_follower,
|
||||
bi_so_follower,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
openarm_follower,
|
||||
so_follower,
|
||||
)
|
||||
from lerobot.teleoperators import ( # noqa: F401
|
||||
TeleoperatorConfig,
|
||||
bi_openarm_leader,
|
||||
bi_so_leader,
|
||||
gamepad,
|
||||
koch_leader,
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
so_leader,
|
||||
)
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
@@ -98,31 +98,26 @@ from lerobot.processor.rename_processor import rename_stats
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_openarm_follower,
|
||||
bi_so_follower,
|
||||
earthrover_mini_plus,
|
||||
hope_jr,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
openarm_follower,
|
||||
reachy2,
|
||||
so_follower,
|
||||
unitree_g1 as unitree_g1_robot,
|
||||
unitree_g1,
|
||||
)
|
||||
from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
TeleoperatorConfig,
|
||||
bi_openarm_leader,
|
||||
bi_so_leader,
|
||||
homunculus,
|
||||
koch_leader,
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
reachy2_teleoperator,
|
||||
so_leader,
|
||||
unitree_g1,
|
||||
)
|
||||
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
|
||||
@@ -53,14 +53,12 @@ from lerobot.processor import (
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_openarm_follower,
|
||||
bi_so_follower,
|
||||
earthrover_mini_plus,
|
||||
hope_jr,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
openarm_follower,
|
||||
reachy2,
|
||||
so_follower,
|
||||
unitree_g1,
|
||||
@@ -110,26 +108,25 @@ def replay(cfg: ReplayConfig):
|
||||
|
||||
robot.connect()
|
||||
|
||||
try:
|
||||
log_say("Replaying episode", cfg.play_sounds, blocking=True)
|
||||
for idx in range(len(episode_frames)):
|
||||
start_episode_t = time.perf_counter()
|
||||
log_say("Replaying episode", cfg.play_sounds, blocking=True)
|
||||
for idx in range(len(episode_frames)):
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action_array = actions[idx][ACTION]
|
||||
action = {}
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"]):
|
||||
action[name] = action_array[i]
|
||||
action_array = actions[idx][ACTION]
|
||||
action = {}
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"]):
|
||||
action[name] = action_array[i]
|
||||
|
||||
robot_obs = robot.get_observation()
|
||||
robot_obs = robot.get_observation()
|
||||
|
||||
processed_action = robot_action_processor((action, robot_obs))
|
||||
processed_action = robot_action_processor((action, robot_obs))
|
||||
|
||||
_ = robot.send_action(processed_action)
|
||||
_ = robot.send_action(processed_action)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
precise_sleep(max(1 / dataset.fps - dt_s, 0.0))
|
||||
finally:
|
||||
robot.disconnect()
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
precise_sleep(max(1 / dataset.fps - dt_s, 0.0))
|
||||
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -70,22 +70,18 @@ from lerobot.processor import (
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_openarm_follower,
|
||||
bi_so_follower,
|
||||
earthrover_mini_plus,
|
||||
hope_jr,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
openarm_follower,
|
||||
reachy2,
|
||||
so_follower,
|
||||
unitree_g1 as unitree_g1_robot,
|
||||
)
|
||||
from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
TeleoperatorConfig,
|
||||
bi_openarm_leader,
|
||||
bi_so_leader,
|
||||
gamepad,
|
||||
homunculus,
|
||||
@@ -93,10 +89,8 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
koch_leader,
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
reachy2_teleoperator,
|
||||
so_leader,
|
||||
unitree_g1,
|
||||
)
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
@@ -338,21 +338,11 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
|
||||
# create dataloader for offline training
|
||||
if hasattr(cfg.policy, "drop_n_last_frames"):
|
||||
# loop over dataset subtask parquet file to find episode indices that don't have subtask index != -1
|
||||
# valid_episode_list passed to episode_indexes_to_use
|
||||
valid_episode_list = []
|
||||
for episode_idx in range(len(dataset.meta.episodes)):
|
||||
subtask_index = dataset[episode_idx]["subtask_index"]
|
||||
if subtask_index != -1:
|
||||
valid_episode_list.append(episode_idx)
|
||||
|
||||
episode_indices_to_use = valid_episode_list
|
||||
|
||||
shuffle = False
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
dataset.meta.episodes["dataset_to_index"],
|
||||
episode_indices_to_use=episode_indices_to_use,
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .bi_openarm_leader import BiOpenArmLeader
|
||||
from .config_bi_openarm_leader import BiOpenArmLeaderConfig
|
||||
|
||||
__all__ = ["BiOpenArmLeader", "BiOpenArmLeaderConfig"]
|
||||
@@ -1,131 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfig
|
||||
|
||||
from ..openarm_leader import OpenArmLeader
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_bi_openarm_leader import BiOpenArmLeaderConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BiOpenArmLeader(Teleoperator):
|
||||
"""
|
||||
Bimanual OpenArm Leader Arms
|
||||
"""
|
||||
|
||||
config_class = BiOpenArmLeaderConfig
|
||||
name = "bi_openarm_leader"
|
||||
|
||||
def __init__(self, config: BiOpenArmLeaderConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
left_arm_config = OpenArmLeaderConfig(
|
||||
id=f"{config.id}_left" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.left_arm_config.port,
|
||||
can_interface=config.left_arm_config.can_interface,
|
||||
use_can_fd=config.left_arm_config.use_can_fd,
|
||||
can_bitrate=config.left_arm_config.can_bitrate,
|
||||
can_data_bitrate=config.left_arm_config.can_data_bitrate,
|
||||
motor_config=config.left_arm_config.motor_config,
|
||||
manual_control=config.left_arm_config.manual_control,
|
||||
position_kd=config.left_arm_config.position_kd,
|
||||
position_kp=config.left_arm_config.position_kp,
|
||||
)
|
||||
|
||||
right_arm_config = OpenArmLeaderConfig(
|
||||
id=f"{config.id}_right" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.right_arm_config.port,
|
||||
can_interface=config.right_arm_config.can_interface,
|
||||
use_can_fd=config.right_arm_config.use_can_fd,
|
||||
can_bitrate=config.right_arm_config.can_bitrate,
|
||||
can_data_bitrate=config.right_arm_config.can_data_bitrate,
|
||||
motor_config=config.right_arm_config.motor_config,
|
||||
manual_control=config.right_arm_config.manual_control,
|
||||
position_kd=config.right_arm_config.position_kd,
|
||||
position_kp=config.right_arm_config.position_kp,
|
||||
)
|
||||
|
||||
self.left_arm = OpenArmLeader(left_arm_config)
|
||||
self.right_arm = OpenArmLeader(right_arm_config)
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
left_arm_features = self.left_arm.action_features
|
||||
right_arm_features = self.right_arm.action_features
|
||||
|
||||
return {
|
||||
**{f"left_{k}": v for k, v in left_arm_features.items()},
|
||||
**{f"right_{k}": v for k, v in right_arm_features.items()},
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.left_arm.calibrate()
|
||||
self.right_arm.calibrate()
|
||||
|
||||
def configure(self) -> None:
|
||||
self.left_arm.configure()
|
||||
self.right_arm.configure()
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||
)
|
||||
|
||||
def get_action(self) -> RobotAction:
|
||||
action_dict = {}
|
||||
|
||||
# Add "left_" prefix
|
||||
left_action = self.left_arm.get_action()
|
||||
action_dict.update({f"left_{key}": value for key, value in left_action.items()})
|
||||
|
||||
# Add "right_" prefix
|
||||
right_action = self.right_arm.get_action()
|
||||
action_dict.update({f"right_{key}": value for key, value in right_action.items()})
|
||||
|
||||
return action_dict
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
# TODO: Implement force feedback
|
||||
raise NotImplementedError
|
||||
|
||||
def disconnect(self) -> None:
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
@@ -1,30 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfigBase
|
||||
|
||||
from ..config import TeleoperatorConfig
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("bi_openarm_leader")
|
||||
@dataclass
|
||||
class BiOpenArmLeaderConfig(TeleoperatorConfig):
|
||||
"""Configuration class for Bi OpenArm Follower robots."""
|
||||
|
||||
left_arm_config: OpenArmLeaderConfigBase
|
||||
right_arm_config: OpenArmLeaderConfigBase
|
||||
@@ -1,20 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_openarm_leader import OpenArmLeaderConfig, OpenArmLeaderConfigBase
|
||||
from .openarm_leader import OpenArmLeader
|
||||
|
||||
__all__ = ["OpenArmLeader", "OpenArmLeaderConfig", "OpenArmLeaderConfigBase"]
|
||||
@@ -1,75 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from ..config import TeleoperatorConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenArmLeaderConfigBase:
|
||||
"""Base configuration for the OpenArms leader/teleoperator with Damiao motors."""
|
||||
|
||||
# CAN interfaces - one per arm
|
||||
# Arm CAN interface (e.g., "can3")
|
||||
# Linux: "can0", "can1", etc.
|
||||
port: str
|
||||
|
||||
# CAN interface type: "socketcan" (Linux), "slcan" (serial), or "auto" (auto-detect)
|
||||
can_interface: str = "socketcan"
|
||||
|
||||
# CAN FD settings (OpenArms uses CAN FD by default)
|
||||
use_can_fd: bool = True
|
||||
can_bitrate: int = 1000000 # Nominal bitrate (1 Mbps)
|
||||
can_data_bitrate: int = 5000000 # Data bitrate for CAN FD (5 Mbps)
|
||||
|
||||
# Motor configuration for OpenArms (7 DOF per arm)
|
||||
# Maps motor names to (send_can_id, recv_can_id, motor_type)
|
||||
# Based on: https://docs.openarm.dev/software/setup/configure-test
|
||||
# OpenArms uses 4 types of motors:
|
||||
# - DM8009 (DM-J8009P-2EC) for shoulders (high torque)
|
||||
# - DM4340P and DM4340 for shoulder rotation and elbow
|
||||
# - DM4310 (DM-J4310-2EC V1.1) for wrist and gripper
|
||||
motor_config: dict[str, tuple[int, int, str]] = field(
|
||||
default_factory=lambda: {
|
||||
"joint_1": (0x01, 0x11, "dm8009"), # J1 - Shoulder pan (DM8009)
|
||||
"joint_2": (0x02, 0x12, "dm8009"), # J2 - Shoulder lift (DM8009)
|
||||
"joint_3": (0x03, 0x13, "dm4340"), # J3 - Shoulder rotation (DM4340)
|
||||
"joint_4": (0x04, 0x14, "dm4340"), # J4 - Elbow flex (DM4340)
|
||||
"joint_5": (0x05, 0x15, "dm4310"), # J5 - Wrist roll (DM4310)
|
||||
"joint_6": (0x06, 0x16, "dm4310"), # J6 - Wrist pitch (DM4310)
|
||||
"joint_7": (0x07, 0x17, "dm4310"), # J7 - Wrist rotation (DM4310)
|
||||
"gripper": (0x08, 0x18, "dm4310"), # J8 - Gripper (DM4310)
|
||||
}
|
||||
)
|
||||
|
||||
# Torque mode settings for manual control
|
||||
# When enabled, motors have torque disabled for manual movement
|
||||
manual_control: bool = True
|
||||
|
||||
# TODO(Steven, Pepijn): Not used ... ?
|
||||
# MIT control parameters (used when manual_control=False for torque control)
|
||||
# List of 8 values: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper]
|
||||
position_kp: list[float] = field(
|
||||
default_factory=lambda: [240.0, 240.0, 240.0, 240.0, 24.0, 31.0, 25.0, 16.0]
|
||||
)
|
||||
position_kd: list[float] = field(default_factory=lambda: [3.0, 3.0, 3.0, 3.0, 0.2, 0.2, 0.2, 0.2])
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("openarm_leader")
|
||||
@dataclass
|
||||
class OpenArmLeaderConfig(TeleoperatorConfig, OpenArmLeaderConfigBase):
|
||||
pass
|
||||
@@ -1,225 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.damiao import DamiaoMotorsBus
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_openarm_leader import OpenArmLeaderConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenArmLeader(Teleoperator):
|
||||
"""
|
||||
OpenArm Leader/Teleoperator Arm with Damiao motors.
|
||||
|
||||
This teleoperator uses CAN bus communication to read positions from
|
||||
Damiao motors that are manually moved (torque disabled).
|
||||
"""
|
||||
|
||||
config_class = OpenArmLeaderConfig
|
||||
name = "openarm_leader"
|
||||
|
||||
def __init__(self, config: OpenArmLeaderConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Arm motors
|
||||
motors: dict[str, Motor] = {}
|
||||
for motor_name, (send_id, recv_id, motor_type_str) in config.motor_config.items():
|
||||
motor = Motor(
|
||||
send_id, motor_type_str, MotorNormMode.DEGREES
|
||||
) # Always use degrees for Damiao motors
|
||||
motor.recv_id = recv_id
|
||||
motor.motor_type_str = motor_type_str
|
||||
motors[motor_name] = motor
|
||||
|
||||
self.bus = DamiaoMotorsBus(
|
||||
port=self.config.port,
|
||||
motors=motors,
|
||||
calibration=self.calibration,
|
||||
can_interface=self.config.can_interface,
|
||||
use_can_fd=self.config.use_can_fd,
|
||||
bitrate=self.config.can_bitrate,
|
||||
data_bitrate=self.config.can_data_bitrate if self.config.use_can_fd else None,
|
||||
)
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
"""Features produced by this teleoperator."""
|
||||
features: dict[str, type] = {}
|
||||
for motor in self.bus.motors:
|
||||
features[f"{motor}.pos"] = float
|
||||
features[f"{motor}.vel"] = float
|
||||
features[f"{motor}.torque"] = float
|
||||
return features
|
||||
|
||||
@property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
"""Feedback features (not implemented for OpenArms)."""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if teleoperator is connected."""
|
||||
return self.bus.is_connected
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""
|
||||
Connect to the teleoperator.
|
||||
|
||||
For manual control, we disable torque after connecting so the
|
||||
arm can be moved by hand.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
# Connect to CAN bus
|
||||
logger.info(f"Connecting arm on {self.config.port}...")
|
||||
self.bus.connect()
|
||||
|
||||
# Run calibration if needed
|
||||
if not self.is_calibrated and calibrate:
|
||||
logger.info(
|
||||
"Mismatch between calibration values in the motor and the calibration file or no calibration file found"
|
||||
)
|
||||
self.calibrate()
|
||||
|
||||
self.configure()
|
||||
|
||||
if self.is_calibrated:
|
||||
self.bus.set_zero_position()
|
||||
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
"""Check if teleoperator is calibrated."""
|
||||
return self.bus.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
"""
|
||||
Run calibration procedure for OpenArms leader.
|
||||
|
||||
The calibration procedure:
|
||||
1. Disable torque (if not already disabled)
|
||||
2. Ask user to position arm in zero position (hanging with gripper closed)
|
||||
3. Set this as zero position
|
||||
4. Record range of motion for each joint
|
||||
5. Save calibration
|
||||
"""
|
||||
if self.calibration:
|
||||
# Calibration file exists, ask user whether to use it or run new calibration
|
||||
user_input = input(
|
||||
f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: "
|
||||
)
|
||||
if user_input.strip().lower() != "c":
|
||||
logger.info(f"Writing calibration file associated with the id {self.id} to the motors")
|
||||
self.bus.write_calibration(self.calibration)
|
||||
return
|
||||
|
||||
logger.info(f"\nRunning calibration for {self}")
|
||||
self.bus.disable_torque()
|
||||
|
||||
# Step 1: Set zero position
|
||||
input(
|
||||
"\nCalibration: Set Zero Position)\n"
|
||||
"Position the arm in the following configuration:\n"
|
||||
" - Arm hanging straight down\n"
|
||||
" - Gripper closed\n"
|
||||
"Press ENTER when ready..."
|
||||
)
|
||||
|
||||
# Set current position as zero for all motors
|
||||
self.bus.set_zero_position()
|
||||
logger.info("Arm zero position set.")
|
||||
|
||||
logger.info("Setting range: -90° to +90° by default for all joints")
|
||||
# TODO(Steven, Pepijn): Check if MotorCalibration is actually needed here given that we only use Degrees
|
||||
for motor_name, motor in self.bus.motors.items():
|
||||
self.calibration[motor_name] = MotorCalibration(
|
||||
id=motor.id,
|
||||
drive_mode=0,
|
||||
homing_offset=0,
|
||||
range_min=-90,
|
||||
range_max=90,
|
||||
)
|
||||
|
||||
self.bus.write_calibration(self.calibration)
|
||||
self._save_calibration()
|
||||
print(f"Calibration saved to {self.calibration_fpath}")
|
||||
|
||||
def configure(self) -> None:
|
||||
"""
|
||||
Configure motors for manual teleoperation.
|
||||
|
||||
For manual control, we disable torque so the arm can be moved by hand.
|
||||
"""
|
||||
|
||||
return self.bus.disable_torque() if self.config.manual_control else self.bus.configure_motors()
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||
)
|
||||
|
||||
def get_action(self) -> RobotAction:
|
||||
"""
|
||||
Get current action from the leader arm.
|
||||
|
||||
This is the main method for teleoperators - it reads the current state
|
||||
of the leader arm and returns it as an action that can be sent to a follower.
|
||||
|
||||
Reads all motor states (pos/vel/torque) in one CAN refresh cycle.
|
||||
"""
|
||||
start = time.perf_counter()
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
action_dict: dict[str, Any] = {}
|
||||
|
||||
# Use sync_read_all_states to get pos/vel/torque in one go
|
||||
states = self.bus.sync_read_all_states()
|
||||
for motor in self.bus.motors:
|
||||
state = states.get(motor, {})
|
||||
action_dict[f"{motor}.pos"] = state.get("position")
|
||||
action_dict[f"{motor}.vel"] = state.get("velocity")
|
||||
action_dict[f"{motor}.torque"] = state.get("torque")
|
||||
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
|
||||
|
||||
return action_dict
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
raise NotImplementedError("Feedback is not yet implemented for OpenArm leader.")
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from teleoperator."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Disconnect CAN bus
|
||||
# For manual control, ensure torque is disabled before disconnecting
|
||||
self.bus.disconnect(disable_torque=self.config.manual_control)
|
||||
logger.info(f"{self} disconnected.")
|
||||
@@ -1,21 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_unitree_g1 import ExoskeletonArmPortConfig, UnitreeG1TeleoperatorConfig
|
||||
from .exo_calib import ExoskeletonCalibration, ExoskeletonJointCalibration
|
||||
from .exo_ik import ExoskeletonIKHelper
|
||||
from .exo_serial import ExoskeletonArm
|
||||
from .unitree_g1 import UnitreeG1Teleoperator
|
||||
@@ -1,37 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from ..config import TeleoperatorConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExoskeletonArmPortConfig:
|
||||
"""Serial port configuration for individual exoskeleton arm."""
|
||||
|
||||
port: str = ""
|
||||
baud_rate: int = 115200
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("unitree_g1")
|
||||
@dataclass
|
||||
class UnitreeG1TeleoperatorConfig(TeleoperatorConfig):
|
||||
left_arm_config: ExoskeletonArmPortConfig = field(default_factory=ExoskeletonArmPortConfig)
|
||||
right_arm_config: ExoskeletonArmPortConfig = field(default_factory=ExoskeletonArmPortConfig)
|
||||
|
||||
# Frozen joints (comma-separated joint names that won't be moved by IK)
|
||||
frozen_joints: str = ""
|
||||
@@ -1,446 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This module handles calibration of hall effect sensors used in the exoskeleton.
|
||||
Each joint has a pair of ADC channels outputting sin and cos values that trace an ellipse
|
||||
as the joint rotates due to imprecision in magnet/sensor placement. We fit this ellipse to a unit circle,
|
||||
and calculate arctan2 of the unit circle to get the joint angle.
|
||||
We then store the ellipse parameters and the zero offset for each joint to be used at runtime.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import serial
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# exoskeleton joint names -> ADC channel pairs. TODO: add wrist pitch and wrist yaw
|
||||
JOINTS = {
|
||||
"shoulder_pitch": (0, 1),
|
||||
"shoulder_yaw": (2, 3),
|
||||
"shoulder_roll": (4, 5),
|
||||
"elbow_flex": (6, 7),
|
||||
"wrist_roll": (14, 15),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExoskeletonJointCalibration:
|
||||
name: str # joint name
|
||||
center_fit: list[float] # center of the ellipse
|
||||
T: list[list[float]] # 2x2 transformation matrix
|
||||
zero_offset: float = 0.0 # angle at neutral pose
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExoskeletonCalibration:
|
||||
"""Full calibration data for an exoskeleton arm."""
|
||||
|
||||
version: int = 2
|
||||
side: str = ""
|
||||
adc_max: int = 2**12 - 1
|
||||
joints: list[ExoskeletonJointCalibration] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"version": self.version,
|
||||
"side": self.side,
|
||||
"adc_max": self.adc_max,
|
||||
"joints": [
|
||||
{
|
||||
"name": j.name,
|
||||
"center_fit": j.center_fit,
|
||||
"T": j.T,
|
||||
"zero_offset": j.zero_offset,
|
||||
}
|
||||
for j in self.joints
|
||||
],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "ExoskeletonCalibration":
|
||||
joints = [
|
||||
ExoskeletonJointCalibration(
|
||||
name=j["name"],
|
||||
center_fit=j["center_fit"],
|
||||
T=j["T"],
|
||||
zero_offset=j.get("zero_offset", 0.0),
|
||||
)
|
||||
for j in data.get("joints", [])
|
||||
]
|
||||
return cls(
|
||||
version=data.get("version", 2),
|
||||
side=data.get("side", ""),
|
||||
adc_max=data.get("adc_max", 2**12 - 1),
|
||||
joints=joints,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CalibParams:
|
||||
fit_every: float = 0.15
|
||||
min_fit_points: int = 60
|
||||
fit_window: int = 900
|
||||
max_fit_points: int = 300
|
||||
trim_low: float = 0.05
|
||||
trim_high: float = 0.95
|
||||
median_window: int = 5
|
||||
history: int = 3500
|
||||
draw_hz: float = 120.0
|
||||
sample_count: int = 50
|
||||
|
||||
|
||||
def normalize_angle(angle: float) -> float:
|
||||
while angle > np.pi:
|
||||
angle -= 2 * np.pi
|
||||
while angle < -np.pi:
|
||||
angle += 2 * np.pi
|
||||
return angle
|
||||
|
||||
|
||||
def joint_z_and_angle(raw16: list[int], j: ExoskeletonJointCalibration) -> tuple[np.ndarray, float]:
|
||||
"""
|
||||
Applies calibration to each joint: raw → centered → ellipse-to-circle → angle.
|
||||
"""
|
||||
pair = JOINTS[j.name]
|
||||
s, c = raw16[pair[0]], raw16[pair[1]] # get sin and cos
|
||||
p = np.array([float(c) - (2**12 - 1) / 2, float(s) - (2**12 - 1) / 2]) # center the raw values
|
||||
z = np.asarray(j.T) @ (
|
||||
p - np.asarray(j.center_fit)
|
||||
) # center the ellipse and invert the transformation matrix to get unit circle coords
|
||||
ang = float(np.arctan2(z[1], z[0])) - j.zero_offset # calculate the anvgle and apply the zero offset
|
||||
return z, normalize_angle(-ang) # ensure range is [-pi, pi]
|
||||
|
||||
|
||||
def exo_raw_to_angles(raw16: list[int], calib: ExoskeletonCalibration) -> dict[str, float]:
|
||||
"""Convert raw sensor readings to joint angles using calibration."""
|
||||
return {j.name: joint_z_and_angle(raw16, j)[1] for j in calib.joints}
|
||||
|
||||
|
||||
def run_exo_calibration(
|
||||
ser: serial.Serial,
|
||||
side: str,
|
||||
save_path: Path,
|
||||
params: CalibParams | None = None,
|
||||
) -> ExoskeletonCalibration:
|
||||
"""
|
||||
Run interactive calibration for an exoskeleton arm.
|
||||
"""
|
||||
try:
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Calibration requires matplotlib and opencv-python. "
|
||||
"Install with: pip install matplotlib opencv-python"
|
||||
) from e
|
||||
|
||||
from .exo_serial import read_raw_from_serial
|
||||
|
||||
params = params or CalibParams()
|
||||
joint_list = list(JOINTS.items()) # Convert dict to list for indexing
|
||||
logger.info(f"Starting calibration for {side} exoskeleton arm")
|
||||
|
||||
def running_median(win: deque) -> float:
|
||||
return float(np.median(np.fromiter(win, dtype=float)))
|
||||
|
||||
def read_joint_point(raw16: list[int], pair: tuple[int, int]):
|
||||
s, c = raw16[pair[0]], raw16[pair[1]]
|
||||
return float(c) - (2**12 - 1) / 2, float(s) - (2**12 - 1) / 2, float(s), float(c)
|
||||
|
||||
def select_fit_subset(xs, ys):
|
||||
"""Select and filter points for ellipse fitting. Trims outliers by radius and downsamples."""
|
||||
n = min(params.fit_window, len(xs))
|
||||
if n <= 0:
|
||||
return None, None
|
||||
x = np.asarray(list(xs)[-n:], dtype=float) # most recent n samples
|
||||
y = np.asarray(list(ys)[-n:], dtype=float)
|
||||
r = np.sqrt(x * x + y * y) # radius from origin
|
||||
if len(r) >= 20:
|
||||
lo, hi = np.quantile(r, params.trim_low), np.quantile(r, params.trim_high) # outlier bounds
|
||||
keep = (r >= lo) & (r <= hi)
|
||||
x, y = x[keep], y[keep] # remove outliers
|
||||
if len(x) > params.max_fit_points:
|
||||
idx = np.linspace(0, len(x) - 1, params.max_fit_points).astype(int) # downsample evenly
|
||||
x, y = x[idx], y[idx]
|
||||
return x, y
|
||||
|
||||
def fit_ellipse_opencv(x, y):
|
||||
"""Fit ellipse to (x,y) points using OpenCV. Returns center, axes, rotation matrix, and outline."""
|
||||
x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
|
||||
if len(x) < 5:
|
||||
return None
|
||||
pts = np.stack([x, y], axis=1).astype(np.float32).reshape(-1, 1, 2)
|
||||
try:
|
||||
(xc, yc), (w, h), angle_deg = cv2.fitEllipse(pts) # returns center, axes, rotation in degrees
|
||||
except cv2.error:
|
||||
return None
|
||||
a, b = float(w) * 0.5, float(h) * 0.5 # get ellipse major and minor semi-axes
|
||||
phi = np.deg2rad(float(angle_deg)) # to rad
|
||||
if b > a: # ensure major axis is a
|
||||
a, b = b, a
|
||||
phi += np.pi / 2.0
|
||||
if not np.isfinite(a) or not np.isfinite(b) or a <= 1e-6 or b <= 1e-6:
|
||||
return None
|
||||
cp, sp = float(np.cos(phi)), float(np.sin(phi)) #
|
||||
rot = np.array([[cp, -sp], [sp, cp]], dtype=float) # 2x2 rotation matrix
|
||||
center = np.array([float(xc), float(yc)], dtype=float) # offset vector
|
||||
tt = np.linspace(0, 2 * np.pi, 360)
|
||||
outline = (rot @ np.stack([a * np.cos(tt), b * np.sin(tt)])).T + center # for viz
|
||||
return {"center": center, "a": a, "b": b, "R": rot, "ex": outline[:, 0], "ey": outline[:, 1]}
|
||||
|
||||
# Setup matplotlib
|
||||
plt.ion()
|
||||
fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(12, 6))
|
||||
ax0.set_xlabel("cos - center")
|
||||
ax0.set_ylabel("sin - center")
|
||||
ax0.grid(True, alpha=0.25)
|
||||
ax0.set_aspect("equal", adjustable="box")
|
||||
ax1.set_title("Unit circle + angle")
|
||||
ax1.set_xlabel("x")
|
||||
ax1.set_ylabel("y")
|
||||
ax1.grid(True, alpha=0.25)
|
||||
ax1.set_aspect("equal", adjustable="box")
|
||||
tt = np.linspace(0, 2 * np.pi, 360)
|
||||
ax1.plot(np.cos(tt), np.sin(tt), "k-", linewidth=1)
|
||||
ax0.set_xlim(-2200, 2200)
|
||||
ax0.set_ylim(-2200, 2200)
|
||||
ax1.set_xlim(-1.4, 1.4)
|
||||
ax1.set_ylim(-1.4, 1.4)
|
||||
|
||||
sc0 = ax0.scatter([], [], s=6, animated=True)
|
||||
(ell_line,) = ax0.plot([], [], "r-", linewidth=2, animated=True)
|
||||
sc1 = ax1.scatter([], [], s=6, animated=True)
|
||||
(radius_line,) = ax1.plot([], [], "g-", linewidth=2, animated=True)
|
||||
angle_text = ax1.text(
|
||||
0.02, 0.98, "", transform=ax1.transAxes, va="top", ha="left", fontsize=12, animated=True
|
||||
)
|
||||
|
||||
fig.canvas.draw()
|
||||
bg0 = fig.canvas.copy_from_bbox(ax0.bbox)
|
||||
bg1 = fig.canvas.copy_from_bbox(ax1.bbox)
|
||||
|
||||
# State
|
||||
joints_out = []
|
||||
joint_idx = 0
|
||||
phase = "ellipse"
|
||||
advance_requested = False
|
||||
zero_samples = []
|
||||
|
||||
def on_key(event):
|
||||
nonlocal advance_requested
|
||||
if event.key in ("n", "N", "enter", " "):
|
||||
advance_requested = True
|
||||
|
||||
fig.canvas.mpl_connect("key_press_event", on_key)
|
||||
|
||||
def reset_state():
|
||||
return {
|
||||
"xs": deque(maxlen=params.history),
|
||||
"ys": deque(maxlen=params.history),
|
||||
"xu": deque(maxlen=params.history),
|
||||
"yu": deque(maxlen=params.history),
|
||||
"win_s": deque(maxlen=params.median_window),
|
||||
"win_c": deque(maxlen=params.median_window),
|
||||
"ellipse_cache": None,
|
||||
"T": None,
|
||||
"center_fit": None,
|
||||
"have_transform": False,
|
||||
"latest_z": None,
|
||||
"last_fit": 0.0,
|
||||
}
|
||||
|
||||
state = reset_state()
|
||||
last_draw = 0.0
|
||||
name, pair = joint_list[joint_idx]
|
||||
fig.canvas.manager.set_window_title(f"[{joint_idx + 1}/{len(joint_list)}] {name} - ELLIPSE")
|
||||
ax0.set_title(f"{name} raw (filtered)")
|
||||
logger.info(f"[{joint_idx + 1}/{len(joint_list)}] Calibrating {name}")
|
||||
logger.info("Step 1: Move joint around to map ellipse, then press 'n'")
|
||||
|
||||
try:
|
||||
while plt.fignum_exists(fig.number):
|
||||
name, pair = joint_list[joint_idx]
|
||||
|
||||
# Handles calibration GUI state: ellipse → zero_pose → next joint -> ellipse -> ...
|
||||
if phase == "ellipse" and advance_requested and state["have_transform"]:
|
||||
joints_out.append(
|
||||
{
|
||||
"name": name,
|
||||
"center_fit": state["center_fit"].tolist(),
|
||||
"T": state["T"].tolist(),
|
||||
}
|
||||
)
|
||||
logger.info(f" -> Ellipse saved for {name}")
|
||||
phase, zero_samples, advance_requested = "zero_pose", [], False
|
||||
fig.canvas.manager.set_window_title(f"[{joint_idx + 1}/{len(joint_list)}] {name} - ZERO POSE")
|
||||
ax0.set_title(f"{name} - hold zero pose")
|
||||
fig.canvas.draw()
|
||||
bg0, bg1 = fig.canvas.copy_from_bbox(ax0.bbox), fig.canvas.copy_from_bbox(ax1.bbox)
|
||||
logger.info(f"Step 2: Hold {name} in zero position, then press 'n'")
|
||||
|
||||
elif phase == "ellipse" and advance_requested and not state["have_transform"]:
|
||||
logger.info(" (Need valid fit first - keep moving the joint)")
|
||||
advance_requested = False
|
||||
|
||||
elif phase == "zero_pose" and advance_requested:
|
||||
if len(zero_samples) >= params.sample_count:
|
||||
zero_offset = float(np.mean(zero_samples[-params.sample_count :]))
|
||||
joints_out[-1]["zero_offset"] = zero_offset
|
||||
logger.info(f" -> {name} zero: {zero_offset:+.3f} rad ({np.degrees(zero_offset):+.1f}°)")
|
||||
joint_idx += 1
|
||||
advance_requested = False
|
||||
|
||||
if joint_idx >= len(joint_list):
|
||||
# All joints done
|
||||
calib = ExoskeletonCalibration(
|
||||
version=2,
|
||||
side=side,
|
||||
adc_max=2**12 - 1,
|
||||
joints=[
|
||||
ExoskeletonJointCalibration(
|
||||
name=j["name"],
|
||||
center_fit=j["center_fit"],
|
||||
T=j["T"],
|
||||
zero_offset=j.get("zero_offset", 0.0),
|
||||
)
|
||||
for j in joints_out
|
||||
],
|
||||
)
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(save_path, "w") as f:
|
||||
json.dump(calib.to_dict(), f, indent=2)
|
||||
logger.info(f"Saved calibration to {save_path}")
|
||||
logger.info("Calibration complete!")
|
||||
plt.close(fig)
|
||||
return calib
|
||||
|
||||
# Next joint
|
||||
phase, state = "ellipse", reset_state()
|
||||
name, pair = joint_list[joint_idx]
|
||||
fig.canvas.manager.set_window_title(
|
||||
f"[{joint_idx + 1}/{len(joint_list)}] {name} - ELLIPSE"
|
||||
)
|
||||
ax0.set_title(f"{name} raw (filtered)")
|
||||
fig.canvas.draw()
|
||||
bg0, bg1 = fig.canvas.copy_from_bbox(ax0.bbox), fig.canvas.copy_from_bbox(ax1.bbox)
|
||||
logger.info(f"[{joint_idx + 1}/{len(joint_list)}] Calibrating {name}")
|
||||
logger.info("Step 1: Move joint around to map ellipse, then press 'n'")
|
||||
else:
|
||||
logger.info(
|
||||
f" (Collecting samples: {len(zero_samples)}/{params.sample_count} - hold still)"
|
||||
)
|
||||
advance_requested = False
|
||||
|
||||
# Read sensor
|
||||
raw16 = read_raw_from_serial(ser)
|
||||
if raw16 is not None:
|
||||
x_raw, y_raw, s_raw, c_raw = read_joint_point(raw16, pair)
|
||||
|
||||
if phase == "ellipse":
|
||||
if state["have_transform"]:
|
||||
z = state["T"] @ (np.array([x_raw, y_raw]) - state["center_fit"])
|
||||
state["xu"].append(float(z[0]))
|
||||
state["yu"].append(float(z[1]))
|
||||
state["latest_z"] = (float(z[0]), float(z[1]))
|
||||
state["win_s"].append(s_raw)
|
||||
state["win_c"].append(c_raw)
|
||||
if len(state["win_s"]) >= max(3, params.median_window):
|
||||
state["ys"].append(running_median(state["win_s"]) - (2**12 - 1) / 2)
|
||||
state["xs"].append(running_median(state["win_c"]) - (2**12 - 1) / 2)
|
||||
else:
|
||||
jdata = joints_out[-1]
|
||||
z = np.array(jdata["T"]) @ (np.array([x_raw, y_raw]) - np.array(jdata["center_fit"]))
|
||||
zero_samples.append(float(np.arctan2(z[1], z[0])))
|
||||
state["latest_z"] = (float(z[0]), float(z[1]))
|
||||
|
||||
# Ellipse fitting
|
||||
t = time.time()
|
||||
if (
|
||||
phase == "ellipse"
|
||||
and (t - state["last_fit"]) >= params.fit_every
|
||||
and len(state["xs"]) >= params.min_fit_points
|
||||
):
|
||||
xfit, yfit = select_fit_subset(state["xs"], state["ys"])
|
||||
if xfit is not None and len(xfit) >= params.min_fit_points:
|
||||
fit = fit_ellipse_opencv(xfit, yfit)
|
||||
if fit is not None:
|
||||
state["center_fit"] = fit["center"]
|
||||
state["T"] = np.diag([1.0 / fit["a"], 1.0 / fit["b"]]) @ fit["R"].T
|
||||
state["ellipse_cache"] = (fit["ex"], fit["ey"])
|
||||
state["have_transform"] = True
|
||||
state["last_fit"] = t
|
||||
|
||||
# Drawing
|
||||
if (t - last_draw) >= 1.0 / params.draw_hz:
|
||||
fig.canvas.restore_region(bg0)
|
||||
fig.canvas.restore_region(bg1)
|
||||
|
||||
if phase == "ellipse":
|
||||
sc0.set_offsets(np.c_[state["xs"], state["ys"]] if state["xs"] else np.empty((0, 2)))
|
||||
ax0.draw_artist(sc0)
|
||||
ell_line.set_data(*state["ellipse_cache"] if state["ellipse_cache"] else ([], []))
|
||||
ax0.draw_artist(ell_line)
|
||||
sc1.set_offsets(np.c_[state["xu"], state["yu"]] if state["xu"] else np.empty((0, 2)))
|
||||
ax1.draw_artist(sc1)
|
||||
if state["latest_z"]:
|
||||
zx, zy = state["latest_z"]
|
||||
radius_line.set_data([0.0, zx], [0.0, zy])
|
||||
ang = float(np.arctan2(zy, zx))
|
||||
angle_text.set_text(
|
||||
f"angle: {ang:+.3f} rad ({np.degrees(ang):+.1f}°)\nmove {name}, press 'n' to advance"
|
||||
)
|
||||
else:
|
||||
radius_line.set_data([], [])
|
||||
angle_text.set_text("(waiting for fit)")
|
||||
else:
|
||||
sc0.set_offsets(np.empty((0, 2)))
|
||||
ax0.draw_artist(sc0)
|
||||
ell_line.set_data([], [])
|
||||
ax0.draw_artist(ell_line)
|
||||
if state["latest_z"]:
|
||||
zx, zy = state["latest_z"]
|
||||
sc1.set_offsets([[zx, zy]])
|
||||
radius_line.set_data([0.0, zx], [0.0, zy])
|
||||
ang = float(np.arctan2(zy, zx))
|
||||
angle_text.set_text(
|
||||
f"Zero pose for {name}\nangle: {ang:+.3f} rad\nsamples: {len(zero_samples)}/{params.sample_count}\nhold still, press 'n'"
|
||||
)
|
||||
else:
|
||||
sc1.set_offsets(np.empty((0, 2)))
|
||||
radius_line.set_data([], [])
|
||||
angle_text.set_text("(waiting for data)")
|
||||
ax1.draw_artist(sc1)
|
||||
|
||||
ax1.draw_artist(radius_line)
|
||||
ax1.draw_artist(angle_text)
|
||||
fig.canvas.blit(ax0.bbox)
|
||||
fig.canvas.blit(ax1.bbox)
|
||||
fig.canvas.flush_events()
|
||||
last_draw = t
|
||||
|
||||
plt.pause(0.001)
|
||||
|
||||
finally:
|
||||
plt.close(fig)
|
||||
@@ -1,353 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
IK helper for exoskeleton-to-G1 teleoperation. We map Exoskeleton joint angles to end-effector pose in world frame,
|
||||
visualizing the result in meshcat after calibration.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointArmIndex
|
||||
from lerobot.robots.unitree_g1.robot_kinematic_processor import G1_29_ArmIK
|
||||
|
||||
from .exo_calib import JOINTS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _frame_id(model, name: str) -> int | None:
|
||||
try:
|
||||
fid = model.getFrameId(name)
|
||||
return fid if 0 <= fid < model.nframes else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArmCfg:
|
||||
side: str # "left" | "right"
|
||||
urdf: str # exo_left.urdf / exo_right.urdf
|
||||
root: str # "exo_left" / "exo_right"
|
||||
g1_ee: str # "l_ee" / "r_ee"
|
||||
offset: np.ndarray # world offset for viz + target
|
||||
marker_prefix: str # "left" / "right"
|
||||
|
||||
|
||||
class Markers:
|
||||
"""Creates meshcat visualization primitives, showing end-effector frames of exoskeleton and G1"""
|
||||
|
||||
def __init__(self, viewer):
|
||||
self.v = viewer
|
||||
|
||||
def sphere(self, path: str, r: float, rgba: tuple[float, float, float, float]):
|
||||
import meshcat.geometry as mg
|
||||
|
||||
c = (int(rgba[0] * 255) << 16) | (int(rgba[1] * 255) << 8) | int(rgba[2] * 255)
|
||||
self.v[path].set_object(
|
||||
mg.Sphere(r),
|
||||
mg.MeshPhongMaterial(color=c, opacity=rgba[3], transparent=rgba[3] < 1.0),
|
||||
)
|
||||
|
||||
def axes(self, path: str, axis_len: float = 0.1, axis_w: int = 6):
|
||||
import meshcat.geometry as mg
|
||||
|
||||
pts = np.array(
|
||||
[[0, 0, 0], [axis_len, 0, 0], [0, 0, 0], [0, axis_len, 0], [0, 0, 0], [0, 0, axis_len]],
|
||||
dtype=np.float32,
|
||||
).T
|
||||
cols = np.array(
|
||||
[[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]],
|
||||
dtype=np.float32,
|
||||
).T
|
||||
self.v[path].set_object(
|
||||
mg.LineSegments(
|
||||
mg.PointsGeometry(position=pts, color=cols),
|
||||
mg.LineBasicMaterial(linewidth=axis_w, vertexColors=True),
|
||||
)
|
||||
)
|
||||
|
||||
def tf(self, path: str, mat: np.ndarray):
|
||||
self.v[path].set_transform(mat)
|
||||
|
||||
|
||||
class ExoskeletonIKHelper:
|
||||
"""
|
||||
- Loads G1 robot and exoskeleton URDF models via Pinocchio
|
||||
- Computes forward kinematics on exoskeleton to get end-effector poses
|
||||
- Solves inverse kinematics on G1 to match those poses
|
||||
- Provides meshcat visualization showing both robots and targets
|
||||
|
||||
Args:
|
||||
frozen_joints: List of G1 joint names to exclude from IK (kept at neutral).
|
||||
"""
|
||||
|
||||
def __init__(self, frozen_joints: list[str] | None = None):
|
||||
try:
|
||||
import pinocchio as pin
|
||||
except ImportError as e:
|
||||
raise ImportError("ik mode needs pinocchio: pip install pin") from e
|
||||
|
||||
self.pin = pin
|
||||
self.frozen_joints = frozen_joints or []
|
||||
|
||||
self.g1_ik = G1_29_ArmIK()
|
||||
self.robot_g1 = self.g1_ik.reduced_robot
|
||||
self.robot_g1.data = self.robot_g1.model.createData()
|
||||
self.q_g1 = pin.neutral(self.robot_g1.model)
|
||||
|
||||
assets_dir = os.path.join(self.g1_ik.repo_path, "assets")
|
||||
|
||||
self.frozen_idx = self._frozen_joint_indices()
|
||||
|
||||
self.arms = [
|
||||
ArmCfg(
|
||||
side="left",
|
||||
urdf=os.path.join(assets_dir, "exo_left.urdf"),
|
||||
root="exo_left",
|
||||
g1_ee="L_ee",
|
||||
offset=np.array([0.6, 0.3, 0.0]),
|
||||
marker_prefix="left",
|
||||
),
|
||||
ArmCfg(
|
||||
side="right",
|
||||
urdf=os.path.join(assets_dir, "exo_right.urdf"),
|
||||
root="exo_right",
|
||||
g1_ee="R_ee",
|
||||
offset=np.array([0.6, -0.3, 0.0]),
|
||||
marker_prefix="right",
|
||||
),
|
||||
]
|
||||
|
||||
self.exo = {} # side -> pin.RobotWrapper
|
||||
self.q_exo = {} # side -> q
|
||||
self.ee_id_exo = {} # side -> frame id
|
||||
self.qmap = {} # side -> {joint_name: q_idx}
|
||||
self.ee_id_g1 = {} # side -> frame id
|
||||
|
||||
self._load_exo_models(assets_dir)
|
||||
for a in self.arms:
|
||||
self.ee_id_g1[a.side] = _frame_id(self.robot_g1.model, a.g1_ee)
|
||||
|
||||
self.viewer = None
|
||||
self.markers: Markers | None = None
|
||||
self.viz_g1 = None
|
||||
self.viz_exo = {} # side -> viz
|
||||
|
||||
def _frozen_joint_indices(self) -> dict[str, int]:
|
||||
out = {}
|
||||
m = self.robot_g1.model
|
||||
for name in self.frozen_joints:
|
||||
if name in m.names:
|
||||
jid = m.getJointId(name)
|
||||
out[name] = m.idx_qs[jid]
|
||||
logger.info(f"freezing joint: {name} (q_idx={out[name]})")
|
||||
return out
|
||||
|
||||
def _find_exo_ee(self, model, ee_name: str = "ee") -> int:
|
||||
ee = _frame_id(model, ee_name)
|
||||
if ee is not None:
|
||||
return ee
|
||||
for fid in reversed(range(model.nframes)):
|
||||
if model.frames[fid].type == self.pin.FrameType.BODY:
|
||||
return fid
|
||||
return 0
|
||||
|
||||
def _build_joint_map(self, robot) -> dict[str, int]:
|
||||
m = robot.model
|
||||
return {n: m.idx_qs[m.getJointId(n)] for n in JOINTS if n in m.names}
|
||||
|
||||
def _load_exo_models(self, assets_dir: str):
|
||||
pin = self.pin
|
||||
for a in self.arms:
|
||||
if not os.path.exists(a.urdf):
|
||||
logger.warning(f"{a.side} exo urdf not found: {a.urdf}")
|
||||
continue
|
||||
r = pin.RobotWrapper.BuildFromURDF(a.urdf, assets_dir)
|
||||
self.exo[a.side] = r
|
||||
self.q_exo[a.side] = pin.neutral(r.model)
|
||||
self.ee_id_exo[a.side] = self._find_exo_ee(r.model)
|
||||
self.qmap[a.side] = self._build_joint_map(r)
|
||||
logger.info(f"loaded {a.side} exo urdf: {a.urdf}")
|
||||
|
||||
def init_visualization(self):
|
||||
"""
|
||||
Creates a browser-based visualization of exoskeleton and G1 robot,
|
||||
highlighting end-effector frames and target positions.
|
||||
"""
|
||||
try:
|
||||
from pinocchio.visualize import MeshcatVisualizer
|
||||
except ImportError as e:
|
||||
logger.warning(f"meshcat viz unavailable: {e}")
|
||||
return
|
||||
|
||||
# g1
|
||||
self.viz_g1 = MeshcatVisualizer(
|
||||
self.robot_g1.model, self.robot_g1.collision_model, self.robot_g1.visual_model
|
||||
)
|
||||
self.viz_g1.initViewer(open=True)
|
||||
self.viz_g1.loadViewerModel("g1")
|
||||
self.viz_g1.display(self.q_g1)
|
||||
|
||||
self.viewer = self.viz_g1.viewer
|
||||
self.markers = Markers(self.viewer)
|
||||
|
||||
# exos
|
||||
for a in self.arms:
|
||||
if a.side not in self.exo:
|
||||
continue
|
||||
r = self.exo[a.side]
|
||||
v = MeshcatVisualizer(r.model, r.collision_model, r.visual_model)
|
||||
v.initViewer(open=False)
|
||||
v.viewer = self.viewer
|
||||
v.loadViewerModel(a.root)
|
||||
offset_tf = np.eye(4)
|
||||
offset_tf[:3, 3] = a.offset
|
||||
self.viewer[a.root].set_transform(offset_tf)
|
||||
v.display(self.q_exo[a.side])
|
||||
self.viz_exo[a.side] = v
|
||||
|
||||
# markers
|
||||
for a in self.arms:
|
||||
p = a.marker_prefix
|
||||
self.markers.sphere(f"markers/{p}_exo_ee", 0.012, (0.2, 1.0, 0.2, 0.9))
|
||||
self.markers.sphere(f"markers/{p}_g1_ee", 0.015, (1.0, 0.2, 0.2, 0.9))
|
||||
self.markers.sphere(f"markers/{p}_ik_target", 0.015, (0.1, 0.3, 1.0, 0.9))
|
||||
self.markers.axes(f"markers/{p}_exo_axes", 0.06)
|
||||
self.markers.axes(f"markers/{p}_g1_axes", 0.08)
|
||||
|
||||
logger.info(f"meshcat viz initialized: {self.viewer.url()}")
|
||||
print(f"\nmeshcat url: {self.viewer.url()}\n")
|
||||
|
||||
def _fk_target_world(self, side: str, angles: dict[str, float]) -> np.ndarray | None:
|
||||
"""returns wrist frame target to be used for G1 IK in 4x4 homogeneous transform. Takes offset into account."""
|
||||
if side not in self.exo or not angles:
|
||||
return None
|
||||
|
||||
pin = self.pin
|
||||
q = self.q_exo[side]
|
||||
qmap = self.qmap[side]
|
||||
|
||||
for name, ang in angles.items():
|
||||
idx = qmap.get(name)
|
||||
if idx is not None:
|
||||
q[idx] = float(ang)
|
||||
|
||||
r = self.exo[side]
|
||||
pin.forwardKinematics(r.model, r.data, q)
|
||||
pin.updateFramePlacements(r.model, r.data)
|
||||
|
||||
ee = r.data.oMf[self.ee_id_exo[side]]
|
||||
target = np.eye(4)
|
||||
target[:3, :3] = ee.rotation
|
||||
# offset gets applied in world space
|
||||
cfg = next(a for a in self.arms if a.side == side)
|
||||
target[:3, 3] = cfg.offset + ee.translation
|
||||
return target
|
||||
|
||||
def update_visualization(self):
|
||||
if self.viewer is None or self.markers is None:
|
||||
return
|
||||
|
||||
pin = self.pin
|
||||
|
||||
# g1
|
||||
if self.viz_g1 is not None:
|
||||
self.viz_g1.display(self.q_g1)
|
||||
pin.forwardKinematics(self.robot_g1.model, self.robot_g1.data, self.q_g1)
|
||||
pin.updateFramePlacements(self.robot_g1.model, self.robot_g1.data)
|
||||
|
||||
for a in self.arms:
|
||||
fid = self.ee_id_g1.get(a.side)
|
||||
if fid is None:
|
||||
continue
|
||||
ee_tf = self.robot_g1.data.oMf[fid].homogeneous
|
||||
p = a.marker_prefix
|
||||
self.markers.tf(f"markers/{p}_g1_ee", ee_tf)
|
||||
self.markers.tf(f"markers/{p}_g1_axes", ee_tf)
|
||||
|
||||
# exos
|
||||
for a in self.arms:
|
||||
side = a.side
|
||||
v = self.viz_exo.get(side)
|
||||
if v is None:
|
||||
continue
|
||||
|
||||
v.display(self.q_exo[side])
|
||||
r = self.exo[side]
|
||||
pin.forwardKinematics(r.model, r.data, self.q_exo[side])
|
||||
pin.updateFramePlacements(r.model, r.data)
|
||||
|
||||
ee = r.data.oMf[self.ee_id_exo[side]]
|
||||
world_tf = (pin.SE3(np.eye(3), a.offset) * ee).homogeneous
|
||||
p = a.marker_prefix
|
||||
self.markers.tf(f"markers/{p}_exo_ee", world_tf)
|
||||
self.markers.tf(f"markers/{p}_exo_axes", world_tf)
|
||||
|
||||
target_tf = np.eye(4)
|
||||
target_tf[:3, :3] = ee.rotation
|
||||
target_tf[:3, 3] = a.offset + ee.translation
|
||||
self.markers.tf(f"markers/{p}_ik_target", target_tf)
|
||||
|
||||
def compute_g1_joints_from_exo(
|
||||
self,
|
||||
left_angles: dict[str, float],
|
||||
right_angles: dict[str, float],
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
Performs FK on exoskeleton to get end-effector poses in world frame,
|
||||
after which it solves IK on G1 to return joint angles matching those poses in G1 motor order.
|
||||
"""
|
||||
pin = self.pin
|
||||
|
||||
targets = {
|
||||
"left": self._fk_target_world("left", left_angles),
|
||||
"right": self._fk_target_world("right", right_angles),
|
||||
}
|
||||
|
||||
# fallback to current g1 ee pose if missing target
|
||||
pin.forwardKinematics(self.robot_g1.model, self.robot_g1.data, self.q_g1)
|
||||
pin.updateFramePlacements(self.robot_g1.model, self.robot_g1.data)
|
||||
|
||||
for a in self.arms:
|
||||
if targets[a.side] is not None:
|
||||
continue
|
||||
fid = self.ee_id_g1.get(a.side)
|
||||
if fid is not None:
|
||||
targets[a.side] = self.robot_g1.data.oMf[fid].homogeneous
|
||||
|
||||
if targets["left"] is None or targets["right"] is None:
|
||||
logger.warning("missing ik targets, returning current pose")
|
||||
return {}
|
||||
|
||||
frozen_vals = {n: self.q_g1[i] for n, i in self.frozen_idx.items()}
|
||||
|
||||
self.q_g1, _ = self.g1_ik.solve_ik(
|
||||
targets["left"], targets["right"], current_lr_arm_motor_q=self.q_g1
|
||||
)
|
||||
|
||||
for n, i in self.frozen_idx.items():
|
||||
self.q_g1[i] = frozen_vals[n]
|
||||
|
||||
return {
|
||||
f"{j.name}.q": float(self.q_g1[i])
|
||||
for i, j in enumerate(G1_29_JointArmIndex)
|
||||
if i < len(self.q_g1)
|
||||
}
|
||||
@@ -1,119 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import serial
|
||||
|
||||
from .exo_calib import ExoskeletonCalibration, exo_raw_to_angles, run_exo_calibration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_raw16(line: bytes) -> list[int] | None:
|
||||
try:
|
||||
parts = line.decode("utf-8", errors="ignore").split()
|
||||
if len(parts) < 16:
|
||||
return None
|
||||
return [int(x) for x in parts[:16]]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def read_raw_from_serial(ser) -> list[int] | None:
|
||||
"""Read latest sample from serial; if buffer is backed up, keep only the newest."""
|
||||
last = None
|
||||
while ser.in_waiting > 0:
|
||||
b = ser.readline()
|
||||
if not b:
|
||||
break
|
||||
raw16 = parse_raw16(b)
|
||||
if raw16 is not None:
|
||||
last = raw16
|
||||
if last is None:
|
||||
b = ser.readline()
|
||||
if b:
|
||||
last = parse_raw16(b)
|
||||
return last
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExoskeletonArm:
|
||||
port: str
|
||||
calibration_fpath: Path
|
||||
side: str
|
||||
baud_rate: int = 115200
|
||||
|
||||
_ser: serial.Serial | None = None
|
||||
calibration: ExoskeletonCalibration | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.calibration_fpath.is_file():
|
||||
self._load_calibration()
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._ser is not None and getattr(self._ser, "is_open", False)
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.calibration is not None
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
return
|
||||
try:
|
||||
self._ser = serial.Serial(self.port, self.baud_rate, timeout=0.02)
|
||||
self._ser.reset_input_buffer()
|
||||
logger.info(f"connected: {self.port}")
|
||||
except serial.SerialException as e:
|
||||
raise ConnectionError(f"failed to connect to {self.port}: {e}") from e
|
||||
|
||||
if calibrate and not self.is_calibrated:
|
||||
self.calibrate()
|
||||
|
||||
def disconnect(self) -> None:
|
||||
if self._ser:
|
||||
try:
|
||||
self._ser.close()
|
||||
finally:
|
||||
self._ser = None
|
||||
|
||||
def _load_calibration(self) -> None:
|
||||
try:
|
||||
data = json.loads(self.calibration_fpath.read_text())
|
||||
self.calibration = ExoskeletonCalibration.from_dict(data)
|
||||
logger.info(f"loaded calibration: {self.calibration_fpath}")
|
||||
except Exception as e:
|
||||
logger.warning(f"failed to load calibration: {e}")
|
||||
|
||||
def read_raw(self) -> list[int] | None:
|
||||
if not self._ser:
|
||||
return None
|
||||
return read_raw_from_serial(self._ser)
|
||||
|
||||
def get_angles(self) -> dict[str, float]:
|
||||
if not self.calibration:
|
||||
raise RuntimeError("exoskeleton not calibrated")
|
||||
raw = self.read_raw()
|
||||
return {} if raw is None else exo_raw_to_angles(raw, self.calibration)
|
||||
|
||||
def calibrate(self) -> None:
|
||||
ser = self._ser
|
||||
self.calibration = run_exo_calibration(ser, self.side, self.calibration_fpath)
|
||||
@@ -1,157 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
|
||||
from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, TELEOPERATORS
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_unitree_g1 import UnitreeG1TeleoperatorConfig
|
||||
from .exo_ik import ExoskeletonIKHelper
|
||||
from .exo_serial import ExoskeletonArm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnitreeG1Teleoperator(Teleoperator):
|
||||
"""
|
||||
Bimanual exoskeleton arms teleoperator for Unitree G1 arms.
|
||||
|
||||
Uses inverse kinematics: exoskeleton FK computes end-effector pose,
|
||||
G1 IK solves for joint angles.
|
||||
"""
|
||||
|
||||
config_class = UnitreeG1TeleoperatorConfig
|
||||
name = "unitree_g1"
|
||||
|
||||
def __init__(self, config: UnitreeG1TeleoperatorConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Setup calibration directory
|
||||
self.calibration_dir = (
|
||||
config.calibration_dir
|
||||
if config.calibration_dir
|
||||
else HF_LEROBOT_CALIBRATION / TELEOPERATORS / self.name
|
||||
)
|
||||
self.calibration_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
left_id = f"{config.id}_left" if config.id else "left"
|
||||
right_id = f"{config.id}_right" if config.id else "right"
|
||||
|
||||
# Create exoskeleton arm instances
|
||||
self.left_arm = ExoskeletonArm(
|
||||
port=config.left_arm_config.port,
|
||||
baud_rate=config.left_arm_config.baud_rate,
|
||||
calibration_fpath=self.calibration_dir / f"{left_id}.json",
|
||||
side="left",
|
||||
)
|
||||
self.right_arm = ExoskeletonArm(
|
||||
port=config.right_arm_config.port,
|
||||
baud_rate=config.right_arm_config.baud_rate,
|
||||
calibration_fpath=self.calibration_dir / f"{right_id}.json",
|
||||
side="right",
|
||||
)
|
||||
|
||||
self.ik_helper: ExoskeletonIKHelper | None = None
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return {f"{name}.q": float for name in self._g1_joint_names}
|
||||
|
||||
@cached_property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
frozen_joints = [j.strip() for j in self.config.frozen_joints.split(",") if j.strip()]
|
||||
self.ik_helper = ExoskeletonIKHelper(frozen_joints=frozen_joints)
|
||||
logger.info("IK helper initialized")
|
||||
|
||||
def calibrate(self) -> None:
|
||||
if not self.left_arm.is_calibrated:
|
||||
logger.info("Starting calibration for left arm...")
|
||||
self.left_arm.calibrate()
|
||||
else:
|
||||
logger.info("Left arm already calibrated. Skipping.")
|
||||
|
||||
if not self.right_arm.is_calibrated:
|
||||
logger.info("Starting calibration for right arm...")
|
||||
self.right_arm.calibrate()
|
||||
else:
|
||||
logger.info("Right arm already calibrated. Skipping.")
|
||||
|
||||
logger.info("Starting visualization to verify calibration...")
|
||||
self.run_visualization_loop()
|
||||
|
||||
def configure(self) -> None:
|
||||
pass
|
||||
|
||||
def get_action(self) -> dict[str, float]:
|
||||
left_angles = self.left_arm.get_angles()
|
||||
right_angles = self.right_arm.get_angles()
|
||||
return self.ik_helper.compute_g1_joints_from_exo(left_angles, right_angles)
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
raise NotImplementedError("Exoskeleton arms do not support feedback")
|
||||
|
||||
def disconnect(self) -> None:
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
|
||||
def run_visualization_loop(self):
|
||||
"""Run interactive Meshcat visualization loop to verify tracking."""
|
||||
if self.ik_helper is None:
|
||||
frozen_joints = [j.strip() for j in self.config.frozen_joints.split(",") if j.strip()]
|
||||
self.ik_helper = ExoskeletonIKHelper(frozen_joints=frozen_joints)
|
||||
|
||||
self.ik_helper.init_visualization()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Visualization running! Move the exoskeletons to test tracking.")
|
||||
print("Press Ctrl+C to exit.")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
try:
|
||||
while True:
|
||||
left_angles = self.left_arm.get_angles()
|
||||
right_angles = self.right_arm.get_angles()
|
||||
|
||||
self.ik_helper.compute_g1_joints_from_exo(left_angles, right_angles)
|
||||
self.ik_helper.update_visualization()
|
||||
|
||||
time.sleep(0.01)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nVisualization stopped.")
|
||||
|
||||
@cached_property
|
||||
def _g1_joint_names(self) -> list[str]:
|
||||
return [joint.name for joint in G1_29_JointIndex]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user