mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 16:27:04 +00:00
Compare commits
71 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e29e89e4ed | |||
| 3d55c5e484 | |||
| 51b3b31927 | |||
| 4503019d18 | |||
| 6aa0cc267f | |||
| 6629b454b2 | |||
| 0059ca7924 | |||
| 97e7e0f9ed | |||
| 0f39248445 | |||
| a6370dd783 | |||
| 14a15f90e7 | |||
| 9c24a09665 | |||
| 6c94fcd1b1 | |||
| 092f4617ca | |||
| b18cef2e26 | |||
| 5c6182176f | |||
| 55c0471db9 | |||
| ec04b7ce3a | |||
| 04cbf669cf | |||
| 6380c0d0dd | |||
| 3409ef0dc2 | |||
| 0947111edd | |||
| 4483184875 | |||
| 149628dfd5 | |||
| bf337e716d | |||
| 477204d485 | |||
| 736b43f3cf | |||
| 4eb912da30 | |||
| 99dbbd56c2 | |||
| 6a6912ec37 | |||
| f6b1c39b78 | |||
| 0c0c171d35 | |||
| 2bf6359d24 | |||
| 9cfb5ce546 | |||
| 366bef915c | |||
| 4c694e20c7 | |||
| 5e609426fd | |||
| 9e10eb4a77 | |||
| 6d34a986de | |||
| 961277d86e | |||
| d0b6a66f34 | |||
| dc85e9b742 | |||
| 0b067df57d | |||
| 9ca680dce2 | |||
| 9919b16b36 | |||
| d36dfcdf71 | |||
| 90d9698c7e | |||
| 13bfee1aa4 | |||
| 79688a09f2 | |||
| bbef8bb077 | |||
| b2ff219624 | |||
| 66929c5935 | |||
| 80417111d3 | |||
| d44f3a3bd9 | |||
| 5286ef8439 | |||
| fe068df711 | |||
| da41646073 | |||
| b864c13dfb | |||
| 46e19ae579 | |||
| 77dc49b3a3 | |||
| 33910673ec | |||
| 19dce78457 | |||
| 112b2d173a | |||
| b825880c40 | |||
| fd917e4fa0 | |||
| 966fedfeef | |||
| 6e88d6f387 | |||
| 83276eeb2f | |||
| 72b0af4ed7 | |||
| b57504b89e | |||
| 72f7aaedb5 |
@@ -18,6 +18,11 @@ name: Documentation
|
||||
on:
|
||||
# Allows running this workflow manually from the Actions tab
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: 'Version tag (e.g. v0.1.2) - Leave empty for standard main build'
|
||||
required: false
|
||||
type: string
|
||||
|
||||
# Triggers the workflow on push events to main for the docs folder
|
||||
push:
|
||||
@@ -54,7 +59,13 @@ jobs:
|
||||
with:
|
||||
commit_sha: ${{ github.sha }}
|
||||
package: lerobot
|
||||
additional_args: --not_python_module ${{ github.event_name == 'release' && format('--version {0}', github.event.release.tag_name) || '' }}
|
||||
additional_args: >-
|
||||
--not_python_module
|
||||
${{
|
||||
(github.event_name == 'release' && format('--version {0}', github.event.release.tag_name)) ||
|
||||
(inputs.version != '' && format('--version {0}', inputs.version)) ||
|
||||
''
|
||||
}}
|
||||
secrets:
|
||||
token: ${{ secrets.HUGGINGFACE_PUSH }}
|
||||
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
|
||||
|
||||
@@ -20,8 +20,8 @@ on:
|
||||
workflow_dispatch:
|
||||
|
||||
# Run on the 1st and 15th of every month at 09:00 UTC
|
||||
schedule:
|
||||
- cron: '0 2 1,15 * *'
|
||||
# schedule:
|
||||
# - cron: '0 2 1,15 * *'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
+1
-1
@@ -14,7 +14,7 @@ You can contribute in many ways:
|
||||
- **Documentation:** Improve examples, guides, and docstrings.
|
||||
- **Feedback:** Submit tickets related to bugs or desired new features.
|
||||
|
||||
If you are unsure where to start, join our [Discord Channel](https://discord.gg/JkrYNdmw).
|
||||
If you are unsure where to start, join our [Discord Channel](https://discord.gg/q8Dzzpym3f).
|
||||
|
||||
## Development Setup
|
||||
|
||||
|
||||
@@ -128,6 +128,7 @@ Learn how to implement your own simulation environment or benchmark and distribu
|
||||
## Resources
|
||||
|
||||
- **[Documentation](https://huggingface.co/docs/lerobot/index):** The complete guide to tutorials & API.
|
||||
- **[Chinese Tutorials: LeRobot+SO-ARM101中文教程-同济子豪兄](https://zihao-ai.feishu.cn/wiki/space/7589642043471924447)** Detailed doc for assembling, teleoperate, dataset, train, deploy. Verified by Seed Studio and 5 global hackathon players.
|
||||
- **[Discord](https://discord.gg/q8Dzzpym3f):** Join the `LeRobot` server to discuss with the community.
|
||||
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
|
||||
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
|
||||
|
||||
+48
@@ -0,0 +1,48 @@
|
||||
# Security Policy
|
||||
|
||||
## Project Status & Philosophy
|
||||
|
||||
`lerobot` has so far been primarily a research and prototyping tool, which is why deployment security hasn’t been a strong focus until now. As `lerobot` continues to be adopted and deployed in production, we are paying much closer attention to these kinds of issues.
|
||||
|
||||
Fortunately, being an open-source project, the community can also help by reporting and fixing vulnerabilities. We appreciate your efforts to responsibly disclose your findings and will make every effort to acknowledge your contributions.
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
To report a security issue, please use the GitHub Security Advisory ["Report a Vulnerability"](https://github.com/huggingface/lerobot/security/advisories/new) tab.
|
||||
|
||||
The `lerobot` team will send a response indicating the next steps in handling your report. After the initial reply to your report, the security team will keep you informed of the progress towards a fix and full announcement, and may ask for additional information or guidance.
|
||||
|
||||
#### Hugging Face Security Team
|
||||
|
||||
Since this project is part of the Hugging Face ecosystem, feel free to submit vulnerability reports directly to: **[security@huggingface.co](mailto:security@huggingface.co)**. Someone from the HF security team will review the report and recommend next steps.
|
||||
|
||||
#### Open Source Disclosures
|
||||
|
||||
If reporting a vulnerability specific to the open-source codebase (and not the underlying Hub infrastructure), you may also use [Huntr](https://huntr.com), a vulnerability disclosure program for open source software.
|
||||
|
||||
## Supported Versions
|
||||
|
||||
Currently, we treat `lerobot` as a rolling release. We prioritize security updates for the latest available version (`main` branch).
|
||||
|
||||
| Version | Supported |
|
||||
| -------- | --------- |
|
||||
| Latest | ✅ |
|
||||
| < Latest | ❌ |
|
||||
|
||||
## Secure Usage Guidelines
|
||||
|
||||
`lerobot` is tightly coupled to the Hugging Face Hub for sharing data and pretrained policies. When downloading artifacts uploaded by others, you expose yourself to risks. Please read below for recommendations to keep your runtime and robot environment safe.
|
||||
|
||||
### Remote Artefacts (Weights & Policies)
|
||||
|
||||
Models and policies uploaded to the Hugging Face Hub come in different formats. We heavily recommend uploading and downloading models in the [`safetensors`](https://github.com/huggingface/safetensors) format.
|
||||
|
||||
`safetensors` was developed specifically to prevent arbitrary code execution on your system, which is critical when running software on physical hardware/robots.
|
||||
|
||||
To avoid loading models from unsafe formats (e.g., `pickle`), you should ensure you are prioritizing `safetensors` files.
|
||||
|
||||
### Remote Code
|
||||
|
||||
Some models or environments on the Hub may require `trust_remote_code=True` to run custom architecture code.
|
||||
|
||||
Please **always** verify the content of the modeling files when using this argument. We recommend setting a specific `revision` (commit hash) when loading remote code to ensure you protect yourself from unverified updates to the repository.
|
||||
@@ -7,8 +7,6 @@
|
||||
- sections:
|
||||
- local: il_robots
|
||||
title: Imitation Learning for Robots
|
||||
- local: cameras
|
||||
title: Cameras
|
||||
- local: bring_your_own_policies
|
||||
title: Bring Your Own Policies
|
||||
- local: integrate_hardware
|
||||
@@ -29,6 +27,10 @@
|
||||
title: Porting Large Datasets
|
||||
- local: using_dataset_tools
|
||||
title: Using the Dataset Tools
|
||||
- local: annotation_tools
|
||||
title: Using the Annotation Tools
|
||||
- local: dataset_subtask
|
||||
title: Using Subtasks in the Dataset
|
||||
title: "Datasets"
|
||||
- sections:
|
||||
- local: act
|
||||
@@ -99,11 +101,19 @@
|
||||
title: Unitree G1
|
||||
- local: earthrover_mini_plus
|
||||
title: Earth Rover Mini
|
||||
- local: omx
|
||||
title: OMX
|
||||
- local: openarm
|
||||
title: OpenArm
|
||||
title: "Robots"
|
||||
- sections:
|
||||
- local: phone_teleop
|
||||
title: Phone
|
||||
title: "Teleoperators"
|
||||
- sections:
|
||||
- local: cameras
|
||||
title: Cameras
|
||||
title: "Sensors"
|
||||
- sections:
|
||||
- local: torch_accelerators
|
||||
title: PyTorch accelerators
|
||||
@@ -113,6 +123,8 @@
|
||||
title: Notebooks
|
||||
- local: feetech
|
||||
title: Updating Feetech Firmware
|
||||
- local: damiao
|
||||
title: Damiao Motors and CAN Bus
|
||||
title: "Resources"
|
||||
- sections:
|
||||
- local: contributing
|
||||
|
||||
@@ -0,0 +1,425 @@
|
||||
# Dataset Annotation Tools
|
||||
|
||||
This guide explains how to use the automatic annotation tools to add skill labels and synthetic dialogue to your LeRobot datasets.
|
||||
|
||||
## Overview
|
||||
|
||||
The annotation pipeline consists of two main components:
|
||||
|
||||
1. **Subtask Annotation** (`subtask_annotate.py`): Automatically segments robot demonstrations into atomic skills using Vision-Language Models (VLMs)
|
||||
2. **High-Level Annotation** (`high_level_annotate.py`): Generates synthetic user prompts and robot utterances for hierarchical policy training
|
||||
|
||||
These tools enable you to transform raw robot demonstration data into richly annotated datasets suitable for training hierarchical policies.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
Before using the annotation tools, ensure you have the required dependencies:
|
||||
|
||||
```bash
|
||||
pip install transformers qwen-vl-utils opencv-python rich pandas pyarrow
|
||||
```
|
||||
|
||||
You'll also need FFmpeg for video processing:
|
||||
|
||||
```bash
|
||||
# Ubuntu/Debian
|
||||
sudo apt-get install ffmpeg
|
||||
|
||||
# macOS
|
||||
brew install ffmpeg
|
||||
```
|
||||
|
||||
## Part 1: Subtask Annotation
|
||||
|
||||
### What It Does
|
||||
|
||||
The subtask annotator segments each episode into short atomic manipulation skills (1-3 seconds each). For example, a "pick and place" episode might be segmented into:
|
||||
- "reach towards object" (0.0s - 1.2s)
|
||||
- "grasp object" (1.2s - 2.1s)
|
||||
- "lift object" (2.1s - 3.5s)
|
||||
- "move to target" (3.5s - 5.0s)
|
||||
- "release object" (5.0s - 6.2s)
|
||||
|
||||
### Usage
|
||||
|
||||
#### Basic Example
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/pi05_full/annotate/subtask_annotate.py \
|
||||
--repo-id your-username/your-dataset \
|
||||
--video-key observation.images.base \
|
||||
--output-dir /path/to/output
|
||||
```
|
||||
|
||||
#### With Local Dataset
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/pi05_full/annotate/subtask_annotate.py \
|
||||
--data-dir /path/to/local/dataset \
|
||||
--video-key observation.images.base \
|
||||
--output-dir /path/to/output
|
||||
```
|
||||
|
||||
#### Advanced Options
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/pi05_full/annotate/subtask_annotate.py \
|
||||
--repo-id your-username/your-dataset \
|
||||
--video-key observation.images.base \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--batch-size 16 \
|
||||
--output-dir /path/to/output \
|
||||
--push-to-hub
|
||||
```
|
||||
|
||||
### Parameters
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--repo-id` | HuggingFace Hub dataset ID | Required (or use --data-dir) |
|
||||
| `--data-dir` | Path to local dataset | Required (or use --repo-id) |
|
||||
| `--video-key` | Video observation key | Required |
|
||||
| `--model` | VLM model to use | `Qwen/Qwen2-VL-7B-Instruct` |
|
||||
| `--device` | Device to run model on | `cuda` |
|
||||
| `--dtype` | Model dtype | `bfloat16` |
|
||||
| `--batch-size` | Episodes per batch | `8` |
|
||||
| `--episodes` | Specific episodes to annotate | All episodes |
|
||||
| `--output-dir` | Output directory | Auto-generated |
|
||||
| `--push-to-hub` | Push to HuggingFace Hub | `False` |
|
||||
|
||||
### Supported Models
|
||||
|
||||
- **Qwen2-VL**: `Qwen/Qwen2-VL-2B-Instruct`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`
|
||||
- **Qwen3-VL**: `Qwen/Qwen3-VL-30B-A3B-Instruct`
|
||||
|
||||
### Output Files
|
||||
|
||||
The subtask annotation creates the following files in your dataset:
|
||||
|
||||
1. **`meta/subtasks.parquet`**: DataFrame with unique subtask names
|
||||
```python
|
||||
# Structure:
|
||||
# Index: subtask name (string)
|
||||
# Column: subtask_index (int64)
|
||||
```
|
||||
|
||||
2. **`meta/skills.json`**: Raw skill annotations with timestamps
|
||||
```json
|
||||
{
|
||||
"coarse_description": "Pick and place the object",
|
||||
"skill_to_subtask_index": {
|
||||
"reach towards object": 0,
|
||||
"grasp object": 1,
|
||||
...
|
||||
},
|
||||
"episodes": {
|
||||
"0": {
|
||||
"episode_index": 0,
|
||||
"description": "Pick and place the object",
|
||||
"skills": [
|
||||
{"name": "reach towards object", "start": 0.0, "end": 1.2},
|
||||
{"name": "grasp object", "start": 1.2, "end": 2.1},
|
||||
...
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
3. **`subtask_index` feature**: Added to each frame in the dataset
|
||||
- Type: `int64`
|
||||
- Shape: `(1,)`
|
||||
- Maps each frame to its corresponding subtask
|
||||
|
||||
### Accessing Subtask Annotations
|
||||
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Load annotated dataset
|
||||
dataset = LeRobotDataset(repo_id="your/dataset_with_subtasks")
|
||||
|
||||
# Get a frame
|
||||
frame = dataset[100]
|
||||
|
||||
# Get the subtask for this frame
|
||||
subtask_idx = frame["subtask_index"].item()
|
||||
subtask_name = dataset.meta.subtasks.iloc[subtask_idx].name
|
||||
|
||||
print(f"Frame 100 is performing: {subtask_name}")
|
||||
|
||||
# Load all subtasks
|
||||
subtasks_df = dataset.meta.subtasks
|
||||
print(subtasks_df)
|
||||
```
|
||||
|
||||
## Part 2: High-Level Annotation
|
||||
|
||||
### What It Does
|
||||
|
||||
The high-level annotator generates synthetic dialogue for hierarchical policy training. For each skill, it creates:
|
||||
- **User Prompt** (`ℓ_t`): A natural language request from the user
|
||||
- **Robot Utterance** (`u_t`): A natural language response from the robot
|
||||
|
||||
This enables training policies that can understand and respond to human instructions in natural dialogue.
|
||||
|
||||
### Prerequisites
|
||||
|
||||
**Important**: You must run subtask annotation first! High-level annotation requires the `skills.json` file generated by subtask annotation.
|
||||
|
||||
### Usage
|
||||
|
||||
#### Image Mode (Default)
|
||||
|
||||
Samples frames at regular intervals and passes images to the VLM:
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/pi05_full/annotate/high_level_annotate.py \
|
||||
--repo-id your/dataset_with_subtasks \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--image-key observation.images.base \
|
||||
--output-dir /path/to/output
|
||||
```
|
||||
|
||||
#### Video Mode
|
||||
|
||||
Passes entire episode videos to the VLM for better temporal understanding:
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/pi05_full/annotate/high_level_annotate.py \
|
||||
--repo-id your/dataset_with_subtasks \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--video-mode \
|
||||
--video-key observation.images.base \
|
||||
--video-batch-size 4 \
|
||||
--output-dir /path/to/output
|
||||
```
|
||||
|
||||
### Parameters
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--repo-id` | HuggingFace Hub dataset ID | Required (or use --data-dir) |
|
||||
| `--data-dir` | Path to local dataset | Required (or use --repo-id) |
|
||||
| `--model` | VLM model to use | `Qwen/Qwen2-VL-7B-Instruct` |
|
||||
| `--image-key` | Image observation key (image mode) | First camera key |
|
||||
| `--video-mode` | Use video instead of images | `False` |
|
||||
| `--video-key` | Video observation key (video mode) | Auto-detected |
|
||||
| `--video-batch-size` | Episodes per batch (video mode) | `1` |
|
||||
| `--sample-interval` | Sampling interval in seconds | `1.0` |
|
||||
| `--temperature` | Sampling temperature | `0.7` |
|
||||
| `--output-dir` | Output directory | Auto-generated |
|
||||
| `--push-to-hub` | Push to HuggingFace Hub | `False` |
|
||||
|
||||
### Output Files
|
||||
|
||||
The high-level annotation creates:
|
||||
|
||||
1. **`meta/tasks_high_level.parquet`**: DataFrame with high-level tasks
|
||||
```python
|
||||
# Structure:
|
||||
# Index: task string (concatenated user_prompt | robot_utterance)
|
||||
# Columns:
|
||||
# - task_index: int64
|
||||
# - user_prompt: string
|
||||
# - robot_utterance: string
|
||||
# - skill: string (associated subtask)
|
||||
# - scenario_type: string
|
||||
# - response_type: string
|
||||
```
|
||||
|
||||
2. **`meta/syn_annotations.jsonl`**: Debug annotations (JSONL format)
|
||||
```json
|
||||
{"episode_id": 0, "timestamp": 1.5, "skill_current": "grasp object", "user_prompt": "Can you pick that up?", "robot_utterance": "Sure, I'll grasp it now", ...}
|
||||
```
|
||||
|
||||
3. **`task_index_high_level` feature**: Added to each frame
|
||||
- Type: `int64`
|
||||
- Shape: `(1,)`
|
||||
- Maps each frame to its high-level task
|
||||
|
||||
### Dialogue Types Generated
|
||||
|
||||
The system generates diverse interaction types:
|
||||
|
||||
**Scenario Types:**
|
||||
- `specific_object`: "Pick up the red block"
|
||||
- `negative_task`: "Don't touch the blue one"
|
||||
- `situated_correction`: "Actually, move to the other box instead"
|
||||
- `implicit_request`: "I need something red for the tower"
|
||||
- `constraint_based`: "Make sure to handle it gently"
|
||||
|
||||
**Response Types:**
|
||||
- `confirmation`: "OK, I'll pick it up"
|
||||
- `clarification`: "Just to confirm, you want me to pick up the red block?"
|
||||
- `acknowledgment`: "Got it, picking up the red block"
|
||||
- `constraint_acknowledgment`: "Sure, I'll pick it up gently"
|
||||
|
||||
### Accessing High-Level Annotations
|
||||
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
import pandas as pd
|
||||
|
||||
# Load annotated dataset
|
||||
dataset = LeRobotDataset(repo_id="your/dataset_with_high_level_tasks")
|
||||
|
||||
# Get a frame
|
||||
frame = dataset[100]
|
||||
|
||||
# Get the high-level task
|
||||
task_idx = frame["task_index_high_level"].item()
|
||||
|
||||
# Load tasks metadata
|
||||
tasks_df = pd.read_parquet(dataset.root / "meta" / "tasks_high_level.parquet")
|
||||
task_row = tasks_df[tasks_df["task_index"] == task_idx].iloc[0]
|
||||
|
||||
print(f"User: {task_row['user_prompt']}")
|
||||
print(f"Robot: {task_row['robot_utterance']}")
|
||||
print(f"Skill: {task_row['skill']}")
|
||||
|
||||
# Use in a DataLoader
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
|
||||
batch = next(iter(dataloader))
|
||||
|
||||
print(f"Task indices: {batch['task_index_high_level']}")
|
||||
print(f"User prompts: {batch['user_prompt'][0]}")
|
||||
print(f"Robot utterances: {batch['robot_utterance'][0]}")
|
||||
```
|
||||
|
||||
## Complete Pipeline Example
|
||||
|
||||
Here's how to run both annotation stages:
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
|
||||
REPO_ID="your-username/your-dataset"
|
||||
MODEL="Qwen/Qwen2-VL-7B-Instruct"
|
||||
OUTPUT_DIR="/path/to/output"
|
||||
|
||||
# Step 1: Subtask Annotation
|
||||
python src/lerobot/policies/pi05_full/annotate/subtask_annotate.py \
|
||||
--repo-id "$REPO_ID" \
|
||||
--video-key observation.images.base \
|
||||
--model "$MODEL" \
|
||||
--batch-size 8 \
|
||||
--output-dir "${OUTPUT_DIR}/subtasks"
|
||||
|
||||
# Step 2: High-Level Annotation (Image Mode)
|
||||
python src/lerobot/policies/pi05_full/annotate/high_level_annotate.py \
|
||||
--data-dir "${OUTPUT_DIR}/subtasks" \
|
||||
--model "$MODEL" \
|
||||
--image-key observation.images.base \
|
||||
--sample-interval 1.0 \
|
||||
--output-dir "${OUTPUT_DIR}/final"
|
||||
|
||||
# Or Step 2: High-Level Annotation (Video Mode - Recommended)
|
||||
python src/lerobot/policies/pi05_full/annotate/high_level_annotate.py \
|
||||
--data-dir "${OUTPUT_DIR}/subtasks" \
|
||||
--model "$MODEL" \
|
||||
--video-mode \
|
||||
--video-key observation.images.base \
|
||||
--video-batch-size 4 \
|
||||
--output-dir "${OUTPUT_DIR}/final"
|
||||
```
|
||||
|
||||
## Performance Tips
|
||||
|
||||
### For Faster Processing
|
||||
|
||||
1. **Increase batch size**: Use `--batch-size 16` or higher (subtask annotation)
|
||||
2. **Increase video batch size**: Use `--video-batch-size 8` (high-level annotation in video mode)
|
||||
3. **Larger sampling interval**: Use `--sample-interval 5.0` for testing (samples every 5 seconds instead of 1)
|
||||
4. **Use smaller models**: `Qwen/Qwen2-VL-2B-Instruct` is faster than `Qwen2-VL-7B-Instruct`
|
||||
5. **Process specific episodes**: Use `--episodes 0 1 2 3` to annotate only a subset
|
||||
|
||||
### For Better Quality
|
||||
|
||||
1. **Use larger models**: `Qwen/Qwen3-VL-30B-A3B-Instruct` or `Qwen/Qwen2-VL-72B-Instruct`
|
||||
2. **Use video mode**: Provides better temporal context
|
||||
3. **Smaller sampling intervals**: `--sample-interval 0.5` for dense annotations
|
||||
4. **Adjust temperature**: Use `--temperature 0.9` for more diverse dialogue
|
||||
|
||||
## Memory Requirements
|
||||
|
||||
| Model | GPU Memory | Recommended Batch Size |
|
||||
|-------|------------|------------------------|
|
||||
| Qwen2-VL-2B | ~8 GB | 16-32 |
|
||||
| Qwen2-VL-7B | ~16 GB | 8-16 |
|
||||
| Qwen2-VL-72B | ~80 GB | 1-2 |
|
||||
| Qwen3-VL-30B | ~40 GB | 4-8 |
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "FFmpeg not found"
|
||||
```bash
|
||||
# Install FFmpeg
|
||||
sudo apt-get install ffmpeg # Ubuntu/Debian
|
||||
brew install ffmpeg # macOS
|
||||
```
|
||||
|
||||
### "CUDA out of memory"
|
||||
- Reduce batch size: `--batch-size 1` or `--video-batch-size 1`
|
||||
- Use smaller model: `Qwen/Qwen2-VL-2B-Instruct`
|
||||
- Use CPU: `--device cpu` (much slower)
|
||||
|
||||
### "No skills.json found"
|
||||
Run subtask annotation first before high-level annotation.
|
||||
|
||||
### "Video key not found"
|
||||
List available keys:
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
dataset = LeRobotDataset(repo_id="your/dataset")
|
||||
print("Video keys:", dataset.meta.video_keys)
|
||||
print("Camera keys:", dataset.meta.camera_keys)
|
||||
```
|
||||
|
||||
## Dataset Structure After Annotation
|
||||
|
||||
```
|
||||
your_dataset_with_high_level_tasks/
|
||||
├── meta/
|
||||
│ ├── info.json # Original metadata
|
||||
│ ├── tasks.parquet # Original tasks (preserved)
|
||||
│ ├── subtasks.parquet # NEW: Subtask names and indices
|
||||
│ ├── skills.json # NEW: Raw skill annotations with timestamps
|
||||
│ ├── tasks_high_level.parquet # NEW: High-level tasks with dialogue
|
||||
│ └── syn_annotations.jsonl # NEW: Debug annotations
|
||||
├── data/
|
||||
│ └── chunk-000/
|
||||
│ ├── observation.images.base.mp4
|
||||
│ ├── action.safetensors
|
||||
│ ├── subtask_index.safetensors # NEW: Subtask per frame
|
||||
│ └── task_index_high_level.safetensors # NEW: High-level task per frame
|
||||
└── videos/
|
||||
└── ...
|
||||
```
|
||||
|
||||
## Citation
|
||||
|
||||
If you use these annotation tools in your research, please cite:
|
||||
|
||||
```bibtex
|
||||
@article{lerobot2024,
|
||||
title={LeRobot: State-of-the-art Machine Learning for Real-World Robotics},
|
||||
author={LeRobot Contributors},
|
||||
year={2024},
|
||||
url={https://github.com/huggingface/lerobot}
|
||||
}
|
||||
```
|
||||
|
||||
## Next Steps
|
||||
|
||||
After annotation, you can:
|
||||
1. Train hierarchical policies using the subtask and high-level annotations
|
||||
2. Use the synthetic dialogue for instruction-following policy training
|
||||
3. Analyze skill distributions and dialogue patterns
|
||||
4. Share your annotated dataset on HuggingFace Hub with `--push-to-hub`
|
||||
|
||||
For training examples, see the [training documentation](../training/).
|
||||
|
||||
@@ -195,6 +195,7 @@ client_cfg = RobotClientConfig(
|
||||
robot=robot_cfg,
|
||||
server_address="localhost:8080",
|
||||
policy_device="mps",
|
||||
client_device="cpu",
|
||||
policy_type="smolvla",
|
||||
pretrained_name_or_path="<user>/smolvla_async",
|
||||
chunk_size_threshold=0.5,
|
||||
|
||||
+95
-81
@@ -1,12 +1,22 @@
|
||||
# Cameras
|
||||
|
||||
LeRobot offers multiple options for video capture, including phone cameras, built-in laptop cameras, external webcams, and Intel RealSense cameras. To efficiently record frames from most cameras, you can use either the `OpenCVCamera` or `RealSenseCamera` class. For additional compatibility details on the `OpenCVCamera` class, refer to the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).
|
||||
LeRobot offers multiple options for video capture:
|
||||
|
||||
### Finding your camera
|
||||
| Class | Supported Cameras |
|
||||
| ----------------- | ----------------------------------- |
|
||||
| `OpenCVCamera` | Phone, built-in laptop, USB webcams |
|
||||
| `ZMQCamera` | Network-connected cameras |
|
||||
| `RealSenseCamera` | Intel RealSense (with depth) |
|
||||
| `Reachy2Camera` | Reachy 2 robot cameras |
|
||||
|
||||
To instantiate a camera, you need a camera identifier. This identifier might change if you reboot your computer or re-plug your camera, a behavior mostly dependant on your operating system.
|
||||
> [!TIP]
|
||||
> For `OpenCVCamera` compatibility details, see the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).
|
||||
|
||||
To find the camera indices of the cameras plugged into your system, run the following script:
|
||||
### Find your camera
|
||||
|
||||
Every camera requires a unique identifier to be instantiated, allowing you to distinguish between multiple connected devices.
|
||||
|
||||
`OpenCVCamera` and `RealSenseCamera` support auto-discovery. Run the command below to list available devices and their identifiers. Note that these identifiers may change after rebooting your computer or re-plugging the camera, depending on your operating system.
|
||||
|
||||
```bash
|
||||
lerobot-find-cameras opencv # or realsense for Intel Realsense cameras
|
||||
@@ -14,7 +24,7 @@ lerobot-find-cameras opencv # or realsense for Intel Realsense cameras
|
||||
|
||||
The output will look something like this if you have two cameras connected:
|
||||
|
||||
```
|
||||
```bash
|
||||
--- Detected Cameras ---
|
||||
Camera #0:
|
||||
Name: OpenCV Camera @ 0
|
||||
@@ -33,13 +43,37 @@ Camera #0:
|
||||
> [!WARNING]
|
||||
> When using Intel RealSense cameras in `macOS`, you could get this [error](https://github.com/IntelRealSense/librealsense/issues/12307): `Error finding RealSense cameras: failed to set power state`, this can be solved by running the same command with `sudo` permissions. Note that using RealSense cameras in `macOS` is unstable.
|
||||
|
||||
## Use Cameras
|
||||
`ZMQCamera` and `Reachy2Camera` do not support auto-discovery. They must be configured manually by providing their network address and port or robot SDK settings.
|
||||
|
||||
Below are two examples, demonstrating how to work with the API.
|
||||
## Use cameras
|
||||
|
||||
- **Asynchronous frame capture** using an OpenCV-based camera
|
||||
### Frame access modes
|
||||
|
||||
All camera classes implement three access modes for capturing frames:
|
||||
|
||||
| Method | Behavior | Blocks? | Best For |
|
||||
| ------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------- | ---------------------------------------- |
|
||||
| `read()` | Waits for the camera hardware to return a frame. May block for a long time depending on the camera and SDK. | Yes | Simple scripts, sequential capture |
|
||||
| `async_read(timeout_ms)` | Returns the latest unconsumed frame from background thread. Blocks only if buffer is empty, up to `timeout_ms`. Raises `TimeoutError` if no frame arrives. | With a timeout | Control loops synchronized to camera FPS |
|
||||
| `read_latest(max_age_ms)` | Peeks at the most recent frame in buffer (may be stale). Raises `TimeoutError` if frame is older than `max_age_ms`. | No | UI visualization, logging, monitoring |
|
||||
|
||||
### Usage examples
|
||||
|
||||
The following examples show how to use the camera API to configure and capture frames from different camera types.
|
||||
|
||||
- **Blocking and non-blocking frame capture** using an OpenCV-based camera
|
||||
- **Color and depth capture** using an Intel RealSense camera
|
||||
|
||||
> [!WARNING]
|
||||
> Failing to cleanly disconnect cameras can cause resource leaks. Use the context manager protocol to ensure automatic cleanup:
|
||||
>
|
||||
> ```python
|
||||
> with OpenCVCamera(config) as camera:
|
||||
> ...
|
||||
> ```
|
||||
>
|
||||
> You can also call `connect()` and `disconnect()` manually, but always use a `finally` block for the latter.
|
||||
|
||||
<hfoptions id="shell_restart">
|
||||
<hfoption id="Open CV Camera">
|
||||
|
||||
@@ -60,16 +94,30 @@ config = OpenCVCameraConfig(
|
||||
)
|
||||
|
||||
# Instantiate and connect an `OpenCVCamera`, performing a warm-up read (default).
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect()
|
||||
with OpenCVCamera(config) as camera:
|
||||
|
||||
# Read a frame synchronously — blocks until hardware delivers a new frame
|
||||
frame = camera.read()
|
||||
print(f"read() call returned frame with shape:", frame.shape)
|
||||
|
||||
# Read a frame asynchronously with a timeout — returns the latest unconsumed frame or waits up to timeout_ms for a new one
|
||||
try:
|
||||
for i in range(10):
|
||||
frame = camera.async_read(timeout_ms=200)
|
||||
print(f"async_read call returned frame {i} with shape:", frame.shape)
|
||||
except TimeoutError as e:
|
||||
print(f"No frame received within timeout: {e}")
|
||||
|
||||
# Instantly return a frame - returns the most recent frame captured by the camera
|
||||
try:
|
||||
initial_frame = camera.read_latest(max_age_ms=1000)
|
||||
for i in range(10):
|
||||
frame = camera.read_latest(max_age_ms=1000)
|
||||
print(f"read_latest call returned frame {i} with shape:", frame.shape)
|
||||
print(f"Was a new frame received by the camera? {not (initial_frame == frame).any()}")
|
||||
except TimeoutError as e:
|
||||
print(f"Frame too old: {e}")
|
||||
|
||||
# Read frames asynchronously in a loop via `async_read(timeout_ms)`
|
||||
try:
|
||||
for i in range(10):
|
||||
frame = camera.async_read(timeout_ms=200)
|
||||
print(f"Async frame {i} shape:", frame.shape)
|
||||
finally:
|
||||
camera.disconnect()
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
@@ -111,10 +159,10 @@ finally:
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Use your phone
|
||||
## Use your phone's camera
|
||||
|
||||
<hfoptions id="use phone">
|
||||
<hfoption id="Mac">
|
||||
<hfoption id="iPhone & macOS">
|
||||
|
||||
To use your iPhone as a camera on macOS, enable the Continuity Camera feature:
|
||||
|
||||
@@ -124,83 +172,49 @@ To use your iPhone as a camera on macOS, enable the Continuity Camera feature:
|
||||
|
||||
For more details, visit [Apple support](https://support.apple.com/en-gb/guide/mac-help/mchl77879b8a/mac).
|
||||
|
||||
Your iPhone should be detected automatically when running the camera setup script in the next section.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Linux">
|
||||
<hfoption id="OBS virtual camera">
|
||||
|
||||
If you want to use your phone as a camera on Linux, follow these steps to set up a virtual camera
|
||||
If you want to use your phone as a camera using OBS, follow these steps to set up a virtual camera.
|
||||
|
||||
1. _Install `v4l2loopback-dkms` and `v4l-utils`_. Those packages are required to create virtual camera devices (`v4l2loopback`) and verify their settings with the `v4l2-ctl` utility from `v4l-utils`. Install them using:
|
||||
1. _(Linux only) Install `v4l2loopback-dkms` and `v4l-utils`_. These packages create virtual camera devices and verify their settings. Install with:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
```bash
|
||||
sudo apt install v4l2loopback-dkms v4l-utils
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
2. _Install [DroidCam](https://droidcam.app) on your phone_. This app is available for both iOS and Android.
|
||||
3. _Install [OBS Studio](https://obsproject.com)_. This software will help you manage the camera feed. Install it using [Flatpak](https://flatpak.org):
|
||||
2. _Install the [DroidCam app](https://droidcam.app) on your phone_. This app is available for both iOS and Android.
|
||||
3. _Download and install [OBS Studio](https://obsproject.com)_.
|
||||
4. _Download and install the [DroidCam OBS plugin](https://droidcam.app/obs)_.
|
||||
5. _Start OBS Studio_.
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
flatpak install flathub com.obsproject.Studio
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
4. _Install the DroidCam OBS plugin_. This plugin integrates DroidCam with OBS Studio. Install it with:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
flatpak install flathub com.obsproject.Studio.Plugin.DroidCam
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
5. _Start OBS Studio_. Launch with:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
flatpak run com.obsproject.Studio
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
6. _Add your phone as a source_. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480`.
|
||||
7. _Adjust resolution settings_. In OBS Studio, go to `File > Settings > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it in.
|
||||
6. _Add your phone as a source_. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480` to avoid the watermarks.
|
||||
7. _Adjust resolution settings_. In OBS Studio, go to `File > Settings > Video` or `OBS > Preferences... > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it.
|
||||
8. _Start virtual camera_. In OBS Studio, follow the instructions [here](https://obsproject.com/kb/virtual-camera-guide).
|
||||
9. _Verify the virtual camera setup_. Use `v4l2-ctl` to list the devices:
|
||||
9. _Verify the virtual camera setup and resolution_.
|
||||
- **Linux**: Use `v4l2-ctl` to list devices and check resolution:
|
||||
```bash
|
||||
v4l2-ctl --list-devices # find VirtualCam and note its /dev/videoX path
|
||||
v4l2-ctl -d /dev/videoX --get-fmt-video # replace with your VirtualCam path
|
||||
```
|
||||
You should see `VirtualCam` listed and resolution `640x480`.
|
||||
- **macOS**: Open Photo Booth or FaceTime and select "OBS Virtual Camera" as the input.
|
||||
- **Windows**: The native Camera app doesn't support virtual cameras. Use a video conferencing app (Zoom, Teams) or run `lerobot-find-cameras opencv` directly to verify.
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
v4l2-ctl --list-devices
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
<details>
|
||||
<summary><strong>Troubleshooting</strong></summary>
|
||||
|
||||
You should see an entry like:
|
||||
> The virtual camera resolution is incorrect.
|
||||
|
||||
```
|
||||
VirtualCam (platform:v4l2loopback-000):
|
||||
/dev/video1
|
||||
```
|
||||
Delete the virtual camera source and recreate it. The resolution cannot be changed after creation.
|
||||
|
||||
10. _Check the camera resolution_. Use `v4l2-ctl` to ensure that the virtual camera output resolution is `640x480`. Change `/dev/video1` to the port of your virtual camera from the output of `v4l2-ctl --list-devices`.
|
||||
> Error reading frame in background thread for OpenCVCamera(X): OpenCVCamera(X) frame width=640 or height=480 do not match configured width=1920 or height=1080.
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
v4l2-ctl -d /dev/video1 --get-fmt-video
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
This error is caused by OBS Virtual Camera advertising a `1920x1080` resolution despite rescaling. The only fix for now is to comment out the width and height check in `_postprocess_image()`.
|
||||
|
||||
You should see an entry like:
|
||||
|
||||
```
|
||||
>>> Format Video Capture:
|
||||
>>> Width/Height : 640/480
|
||||
>>> Pixel Format : 'YUYV' (YUYV 4:2:2)
|
||||
```
|
||||
|
||||
Troubleshooting: If the resolution is not correct you will have to delete the Virtual Camera port and try again as it cannot be changed.
|
||||
|
||||
If everything is set up correctly, you can proceed with the rest of the tutorial.
|
||||
</details>
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
If everything is set up correctly, your phone will appear as a standard OpenCV camera and can be used with `OpenCVCamera`.
|
||||
|
||||
@@ -0,0 +1,165 @@
|
||||
# Damiao Motors and CAN Bus
|
||||
|
||||
This guide covers setup and usage of Damiao motors with LeRobot via CAN bus communication.
|
||||
|
||||
Currently, only Linux is supported, as the OpenArms CAN adapter only has drivers for Linux.
|
||||
|
||||
## Linux CAN Setup
|
||||
|
||||
Before using Damiao motors, you need to set up the CAN interface on your Linux system.
|
||||
|
||||
### Install CAN Utilities
|
||||
|
||||
```bash
|
||||
sudo apt-get install can-utils
|
||||
```
|
||||
|
||||
### Configure CAN Interface (Manual)
|
||||
|
||||
For standard CAN FD (recommended for OpenArms):
|
||||
|
||||
```bash
|
||||
sudo ip link set can0 down
|
||||
sudo ip link set can0 type can bitrate 1000000 dbitrate 5000000 fd on
|
||||
sudo ip link set can0 up
|
||||
```
|
||||
|
||||
For standard CAN (without FD):
|
||||
|
||||
```bash
|
||||
sudo ip link set can0 down
|
||||
sudo ip link set can0 type can bitrate 1000000
|
||||
sudo ip link set can0 up
|
||||
```
|
||||
|
||||
### Configure CAN Interface (Using LeRobot)
|
||||
|
||||
LeRobot provides a utility script to setup and test CAN interfaces:
|
||||
|
||||
```bash
|
||||
# Setup multiple interfaces (e.g., OpenArms Followers with 2 CAN buses)
|
||||
lerobot-setup-can --mode=setup --interfaces=can0,can1
|
||||
```
|
||||
|
||||
## Debugging CAN Communication
|
||||
|
||||
Use the built-in debug tools to test motor communication:
|
||||
|
||||
```bash
|
||||
# Test motors on all interfaces
|
||||
lerobot-setup-can --mode=test --interfaces=can0,can1
|
||||
|
||||
# Run speed/latency test
|
||||
lerobot-setup-can --mode=speed --interfaces=can0
|
||||
```
|
||||
|
||||
The test mode will scan for motors (IDs 0x01-0x08) and report which ones respond. Example output:
|
||||
|
||||
```
|
||||
can0: UP (CAN FD)
|
||||
Motor 0x01 (joint_1): ✓ FOUND
|
||||
→ Response 0x11 [FD]: 00112233...
|
||||
Motor 0x02 (joint_2): ✓ FOUND
|
||||
Motor 0x03 (joint_3): ✗ No response
|
||||
...
|
||||
Summary: 2/8 motors found
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Setup
|
||||
|
||||
```python
|
||||
from lerobot.motors import Motor
|
||||
from lerobot.motors.damiao import DamiaoMotorsBus
|
||||
|
||||
# Define your motors with send/receive CAN IDs
|
||||
motors = {
|
||||
"joint_1": Motor(id=0x01, motor_type_str="dm8009", recv_id=0x11),
|
||||
"joint_2": Motor(id=0x02, motor_type_str="dm4340", recv_id=0x12),
|
||||
"joint_3": Motor(id=0x03, motor_type_str="dm4310", recv_id=0x13),
|
||||
}
|
||||
|
||||
# Create the bus
|
||||
bus = DamiaoMotorsBus(
|
||||
port="can0", # Linux socketcan interface
|
||||
motors=motors,
|
||||
)
|
||||
|
||||
# Connect
|
||||
bus.connect()
|
||||
```
|
||||
|
||||
### Reading Motor States
|
||||
|
||||
```python
|
||||
# Read single motor position (degrees)
|
||||
position = bus.read("Present_Position", "joint_1")
|
||||
|
||||
# Read from multiple motors
|
||||
positions = bus.sync_read("Present_Position") # All motors
|
||||
positions = bus.sync_read("Present_Position", ["joint_1", "joint_2"])
|
||||
|
||||
# Read all states at once (position, velocity, torque)
|
||||
states = bus.sync_read_all_states()
|
||||
# Returns: {'joint_1': {'position': 45.2, 'velocity': 1.3, 'torque': 0.5}, ...}
|
||||
```
|
||||
|
||||
### Writing Motor Commands
|
||||
|
||||
```python
|
||||
# Enable torque
|
||||
bus.enable_torque()
|
||||
|
||||
# Set goal position (degrees)
|
||||
bus.write("Goal_Position", "joint_1", 45.0)
|
||||
|
||||
# Set positions for multiple motors
|
||||
bus.sync_write("Goal_Position", {
|
||||
"joint_1": 45.0,
|
||||
"joint_2": -30.0,
|
||||
"joint_3": 90.0,
|
||||
})
|
||||
|
||||
# Disable torque
|
||||
bus.disable_torque()
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| -------------- | --------- | ----------------------------------------------------------- |
|
||||
| `port` | - | CAN interface (`can0`) or serial port (`/dev/cu.usbmodem*`) |
|
||||
| `use_can_fd` | `True` | Enable CAN FD for higher data rates |
|
||||
| `bitrate` | `1000000` | Nominal bitrate (1 Mbps) |
|
||||
| `data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) |
|
||||
|
||||
## Motor Configuration
|
||||
|
||||
Each motor requires:
|
||||
|
||||
- `id`: CAN ID for sending commands
|
||||
- `motor_type`: One of the supported motor types (e.g., `"dm8009"`, `"dm4340"`)
|
||||
- `recv_id`: CAN ID for receiving responses
|
||||
|
||||
OpenArms default IDs follow the pattern: send ID `0x0N`, receive ID `0x1N` where N is the joint number.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### No Response from Motors
|
||||
|
||||
1. **Check power**
|
||||
2. **Verify CAN wiring**: Check CAN-H, CAN-L, and GND connections
|
||||
3. **Check motor IDs**: Use Damiao Debugging Tools to verify/configure IDs
|
||||
4. **Test CAN interface**: Run `candump can0` to see if messages are being received
|
||||
5. **Run diagnostics**: `lerobot-setup-can --mode=test --interfaces=can0`
|
||||
|
||||
### Motor Timeout Parameter
|
||||
|
||||
If motors were configured with timeout=0, they won't respond to commands. Use Damiao Debugging Tools to set a non-zero timeout value.
|
||||
|
||||
### Verify CAN FD Status
|
||||
|
||||
```bash
|
||||
ip -d link show can0 | grep fd
|
||||
```
|
||||
@@ -0,0 +1,278 @@
|
||||
# Using Subtasks in LeRobot Datasets
|
||||
|
||||
Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for:
|
||||
|
||||
- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time
|
||||
- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models)
|
||||
- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps
|
||||
|
||||
LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks.
|
||||
|
||||
## What are Subtasks?
|
||||
|
||||
While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps:
|
||||
|
||||
1. "Approach the apple"
|
||||
2. "Grasp the apple"
|
||||
3. "Lift the apple"
|
||||
4. "Move to basket"
|
||||
5. "Release the apple"
|
||||
|
||||
Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/subtask-asset.png"
|
||||
alt="An overview of subtask annotation showing how frames are labeled with intermediate subtask stages"
|
||||
width="80%"
|
||||
/>
|
||||
|
||||
<p>
|
||||
<em>Figure: Overview of subtask annotation.</em>
|
||||
</p>
|
||||
|
||||
**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022.
|
||||
|
||||
## Dataset Structure
|
||||
|
||||
Subtask information is stored in the dataset metadata:
|
||||
|
||||
```
|
||||
my-dataset/
|
||||
├── data/
|
||||
│ └── ...
|
||||
├── meta/
|
||||
│ ├── info.json
|
||||
│ ├── stats.json
|
||||
│ ├── tasks.parquet
|
||||
│ ├── subtasks.parquet # Subtask index → subtask string mapping
|
||||
│ └── episodes/
|
||||
│ └── ...
|
||||
└── videos/
|
||||
└── ...
|
||||
```
|
||||
|
||||
### Subtasks Parquet File
|
||||
|
||||
The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions:
|
||||
|
||||
| subtask_index | subtask (index column) |
|
||||
| ------------- | ---------------------- |
|
||||
| 0 | "Approach the apple" |
|
||||
| 1 | "Grasp the apple" |
|
||||
| 2 | "Lift the apple" |
|
||||
| ... | ... |
|
||||
|
||||
### Frame-Level Annotations
|
||||
|
||||
Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file:
|
||||
|
||||
```python
|
||||
# Example frame data in the parquet file
|
||||
{
|
||||
"index": 42,
|
||||
"timestamp": 1.4,
|
||||
"episode_index": 0,
|
||||
"task_index": 0,
|
||||
"subtask_index": 2, # References "Lift the apple"
|
||||
"observation.state": [...],
|
||||
"action": [...],
|
||||
}
|
||||
```
|
||||
|
||||
## Annotating Datasets with Subtasks
|
||||
|
||||
We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks:
|
||||
|
||||
**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)**
|
||||
|
||||
After completing your annotation:
|
||||
|
||||
1. Click "Push to Hub" to upload your annotated dataset
|
||||
2. You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate)
|
||||
|
||||
## Loading Datasets with Subtasks
|
||||
|
||||
When you load a dataset with subtask annotations, the subtask information is automatically available:
|
||||
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Load a dataset with subtask annotations
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
# Access a sample
|
||||
sample = dataset[100]
|
||||
|
||||
# The sample includes both task and subtask information
|
||||
print(sample["task"]) # "Collect the fruit"
|
||||
print(sample["subtask"]) # "Grasp the apple"
|
||||
print(sample["task_index"]) # tensor(0)
|
||||
print(sample["subtask_index"]) # tensor(2)
|
||||
```
|
||||
|
||||
### Checking for Subtask Support
|
||||
|
||||
You can check if a dataset has subtask annotations:
|
||||
|
||||
```python
|
||||
# Check if subtasks are available
|
||||
has_subtasks = (
|
||||
"subtask_index" in dataset.features
|
||||
and dataset.meta.subtasks is not None
|
||||
)
|
||||
|
||||
if has_subtasks:
|
||||
print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks")
|
||||
print("Subtasks:", list(dataset.meta.subtasks.index))
|
||||
```
|
||||
|
||||
## Using Subtasks for Training
|
||||
|
||||
### With the Tokenizer Processor
|
||||
|
||||
The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models:
|
||||
|
||||
```python
|
||||
from lerobot.processor.tokenizer_processor import TokenizerProcessor
|
||||
from lerobot.processor.pipeline import ProcessorPipeline
|
||||
|
||||
# Create a tokenizer processor
|
||||
tokenizer_processor = TokenizerProcessor(
|
||||
tokenizer_name_or_path="google/paligemma-3b-pt-224",
|
||||
padding="max_length",
|
||||
max_length=64,
|
||||
)
|
||||
|
||||
# The processor will automatically tokenize subtasks if present in the batch
|
||||
# and add them to the observation under:
|
||||
# - "observation.subtask.tokens"
|
||||
# - "observation.subtask.attention_mask"
|
||||
```
|
||||
|
||||
When subtasks are available in the batch, the tokenizer processor adds:
|
||||
|
||||
- `observation.subtask.tokens`: Tokenized subtask text
|
||||
- `observation.subtask.attention_mask`: Attention mask for the subtask tokens
|
||||
|
||||
### DataLoader with Subtasks
|
||||
|
||||
```python
|
||||
import torch
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=16,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
for batch in dataloader:
|
||||
# Access subtask information in the batch
|
||||
subtasks = batch["subtask"] # List of subtask strings
|
||||
subtask_indices = batch["subtask_index"] # Tensor of subtask indices
|
||||
|
||||
# Use for training hierarchical policies or reward models
|
||||
print(f"Batch subtasks: {set(subtasks)}")
|
||||
```
|
||||
|
||||
## Example Datasets with Subtask Annotations
|
||||
|
||||
Try loading a dataset with subtask annotations:
|
||||
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Example dataset with subtask annotations
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
# Explore the subtasks
|
||||
print("Available subtasks:")
|
||||
for subtask_name in dataset.meta.subtasks.index:
|
||||
print(f" - {subtask_name}")
|
||||
|
||||
# Get subtask distribution
|
||||
subtask_counts = {}
|
||||
for i in range(len(dataset)):
|
||||
sample = dataset[i]
|
||||
subtask = sample["subtask"]
|
||||
subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1
|
||||
|
||||
print("\nSubtask distribution:")
|
||||
for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]):
|
||||
print(f" {subtask}: {count} frames")
|
||||
```
|
||||
|
||||
## Use Cases
|
||||
|
||||
### 1. Hierarchical Policy Training
|
||||
|
||||
Train policies that predict both actions and current subtask:
|
||||
|
||||
```python
|
||||
class HierarchicalPolicy(nn.Module):
|
||||
def __init__(self, num_subtasks):
|
||||
super().__init__()
|
||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
||||
self.subtask_head = nn.Linear(hidden_dim, num_subtasks)
|
||||
|
||||
def forward(self, observations):
|
||||
features = self.encoder(observations)
|
||||
actions = self.action_head(features)
|
||||
subtask_logits = self.subtask_head(features)
|
||||
return actions, subtask_logits
|
||||
```
|
||||
|
||||
### 2. Stage-Aware Reward Modeling (SARM)
|
||||
|
||||
Build reward models that understand task progression:
|
||||
|
||||
```python
|
||||
# SARM predicts:
|
||||
# - Stage: Which subtask is being executed (discrete)
|
||||
# - Progress: How far along the subtask (continuous 0-1)
|
||||
|
||||
class SARMRewardModel(nn.Module):
|
||||
def forward(self, observations):
|
||||
features = self.encoder(observations)
|
||||
stage_logits = self.stage_classifier(features)
|
||||
progress = self.progress_regressor(features)
|
||||
return stage_logits, progress
|
||||
```
|
||||
|
||||
### 3. Progress Visualization
|
||||
|
||||
Monitor robot execution by tracking subtask progression:
|
||||
|
||||
```python
|
||||
def visualize_execution(model, observations):
|
||||
for t, obs in enumerate(observations):
|
||||
action, subtask_logits = model(obs)
|
||||
predicted_subtask = subtask_names[subtask_logits.argmax()]
|
||||
print(f"t={t}: Executing '{predicted_subtask}'")
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### LeRobotDataset Properties
|
||||
|
||||
| Property | Type | Description |
|
||||
| --------------------------- | ---------------------- | ------------------------------------------ |
|
||||
| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices |
|
||||
| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present |
|
||||
|
||||
### Sample Keys
|
||||
|
||||
When subtasks are available, each sample includes:
|
||||
|
||||
| Key | Type | Description |
|
||||
| --------------- | -------------- | ------------------------------------ |
|
||||
| `subtask_index` | `torch.Tensor` | Integer index of the current subtask |
|
||||
| `subtask` | `str` | Natural language subtask description |
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation
|
||||
- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool
|
||||
- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation
|
||||
@@ -1,5 +1,11 @@
|
||||
# EarthRover Mini Plus
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Earth_Rover_Mini_5_240c9adc-4f9e-44b7-982f-5d1dc24af1d8.png.webp"
|
||||
alt="EarthRover Mini Plus"
|
||||
width="70%"
|
||||
/>
|
||||
|
||||
The EarthRover Mini Plus is a fully open source mobile robot that connects through the cloud using the Frodobots SDK. This lets you control the robot and record datasets for training AI models.
|
||||
|
||||
## What You Need
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
# LeKiwi
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/1740517739083.jpeg"
|
||||
alt="LeKiwi"
|
||||
width="70%"
|
||||
/>
|
||||
|
||||
In the steps below, we explain how to assemble the LeKiwi mobile robot.
|
||||
|
||||
## Source the parts
|
||||
|
||||
@@ -42,6 +42,7 @@ lerobot-eval \
|
||||
```
|
||||
|
||||
- `--env.task` picks the suite (`libero_object`, `libero_spatial`, etc.).
|
||||
- `--env.task_ids` picks task ids to run (`[0]`, `[1,2,3]`, etc.). Omit this flag (or set it to `null`) to run all tasks in the suite.
|
||||
- `--eval.batch_size` controls how many environments run in parallel.
|
||||
- `--eval.n_episodes` sets how many episodes to run in total.
|
||||
|
||||
|
||||
@@ -0,0 +1,197 @@
|
||||
## Order and Assemble the parts
|
||||
|
||||
First, assemble the OMX hardware following the official assembly guide.
|
||||
|
||||
OMX Assembly Guide: https://ai.robotis.com/omx/assembly_guide_omx.html
|
||||
|
||||
OMX robots are shipped preconfigured from the factory. Motor IDs, communication parameters, and joint offsets are already set, so no additional motor setup or calibration is required before using LeRobot.
|
||||
|
||||
## Install LeRobot 🤗
|
||||
|
||||
To install LeRobot, follow our [Installation Guide](./installation)
|
||||
|
||||
In addition to these instructions, you need to install the Dynamixel SDK:
|
||||
|
||||
```bash
|
||||
pip install -e ".[dynamixel]"
|
||||
```
|
||||
|
||||
## Connect the robot
|
||||
|
||||
To find the port for each bus servo adapter, run this script:
|
||||
|
||||
```bash
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
This command runs and when prompted, disconnect the USB cable from either the leader or follower arm and press Enter. The output will show 'The port of this MotorsBus is [port]'. This identifies the port for the disconnected arm. Repeat for the other arm to identify both ports.
|
||||
|
||||
<hfoptions id="find_port">
|
||||
<hfoption id="Mac">
|
||||
|
||||
Example output on macOS:
|
||||
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the USB cable from your MotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect corresponding leader or follower arm and press Enter...]
|
||||
|
||||
The port of this MotorsBus is /dev/tty.usbmodem575E0032081
|
||||
Reconnect the USB cable.
|
||||
```
|
||||
|
||||
Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your leader or follower arm.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Linux">
|
||||
|
||||
On Linux, we strongly recommend using udev rules to assign persistent and human-readable device names to the OMX leader and follower arms. This avoids issues where device names such as ttyACM0 and ttyACM1 change when the robot is unplugged, replugged, or when the system is rebooted.
|
||||
|
||||
#### 1. Find your device serial numbers
|
||||
|
||||
You should have obtained the port numbers like ../../ttyACM? for the leader and follower using `lerobot-find-port`. You can match those results with the serial numbers using the `ls -l /dev/serial/by-id/` command.
|
||||
To create udev rules, you need the unique serial number for each OMX device. The easiest way is to list devices under:
|
||||
|
||||
```bash
|
||||
ls -l /dev/serial/by-id/
|
||||
```
|
||||
|
||||
You will see output similar to:
|
||||
|
||||
```bash
|
||||
usb-ROBOTIS_OpenRB-150_228BDD7B503059384C2E3120FF0A2B19-if00 -> ../../ttyACM0
|
||||
usb-ROBOTIS_OpenRB-150_67E1ED68503059384C2E3120FF092234-if00 -> ../../ttyACM1
|
||||
```
|
||||
|
||||
In each line, the serial number is the long string after `usb-ROBOTIS_OpenRB-150_` and before `-if00`.
|
||||
|
||||
Follower serial: `228BDD7B503059384C2E3120FF0A2B19`
|
||||
|
||||
Leader serial: `67E1ED68503059384C2E3120FF092234`
|
||||
|
||||
#### 2. Create the udev rule
|
||||
|
||||
Create a new udev rule file:
|
||||
|
||||
```bash
|
||||
sudo nano /etc/udev/rules.d/99-omx.rules
|
||||
```
|
||||
|
||||
Paste the following lines, replacing the serial numbers with the values you found above:
|
||||
|
||||
```bash
|
||||
SUBSYSTEM=="tty", ATTRS{idVendor}=="0403", ATTRS{serial}=="228BDD7B503059384C2E3120FF0A2B19", SYMLINK+="omx_follower"
|
||||
SUBSYSTEM=="tty", ATTRS{idVendor}=="0403", ATTRS{serial}=="67E1ED68503059384C2E3120FF092234", SYMLINK+="omx_leader"
|
||||
```
|
||||
|
||||
Save the file and reload udev rules:
|
||||
|
||||
```bash
|
||||
sudo udevadm control --reload-rules
|
||||
sudo udevadm trigger
|
||||
```
|
||||
|
||||
Now unplug and replug both devices once.
|
||||
|
||||
#### 3. Verify the symlinks
|
||||
|
||||
Check that the persistent device names exist:
|
||||
|
||||
```bash
|
||||
ls -l /dev/omx_follower /dev/omx_leader
|
||||
```
|
||||
|
||||
You should see them pointing to ttyACM\* devices:
|
||||
|
||||
```bash
|
||||
/dev/omx_follower -> ttyACM*
|
||||
/dev/omx_leader -> ttyACM*
|
||||
```
|
||||
|
||||
These names remain stable across reboots and reconnections.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Teleoperate
|
||||
|
||||
After identifying the correct ports, you can directly teleoperate the follower arm using the leader arm.
|
||||
|
||||
<hfoptions id="teleoperate">
|
||||
<hfoption id="Mac">
|
||||
|
||||
### Teleoperate without camera
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=<your_follower_port> \
|
||||
--robot.id=omx_follower_arm \
|
||||
--teleop.type=omx_leader \
|
||||
--teleop.port=<your_leader_port> \
|
||||
--teleop.id=omx_leader_arm
|
||||
```
|
||||
|
||||
During teleoperation, motions of the leader arm are mirrored in real time by the follower arm. OMX is already preconfigured, teleoperation can begin immediately without any calibration steps.
|
||||
|
||||
### Teleoperate with camera
|
||||
|
||||
You can also enable camera input during teleoperation by providing a camera configuration for the follower arm.
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=<your_follower_port> \
|
||||
--robot.id=omx_follower_arm \
|
||||
--robot.cameras="{front: {type: opencv, index_or_path: '/dev/video0', width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=omx_leader \
|
||||
--teleop.port=<your_leader_port> \
|
||||
--teleop.id=omx_leader_arm \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
When the camera is enabled, the camera stream is displayed in real time and synchronized with the robot state. This setup is useful for visual monitoring and can be reused later for demonstration recording and imitation learning.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Linux">
|
||||
|
||||
### Teleoperate without camera
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=/dev/omx_follower \
|
||||
--robot.id=omx_follower_arm \
|
||||
--teleop.type=omx_leader \
|
||||
--teleop.port=/dev/omx_leader \
|
||||
--teleop.id=omx_leader_arm
|
||||
```
|
||||
|
||||
During teleoperation, motions of the leader arm are mirrored in real time by the follower arm. OMX is already preconfigured, teleoperation can begin immediately without any calibration steps.
|
||||
|
||||
### Teleoperate with camera
|
||||
|
||||
You can also enable camera input during teleoperation by providing a camera configuration for the follower arm.
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=/dev/omx_follower \
|
||||
--robot.id=omx_follower_arm \
|
||||
--robot.cameras="{front: {type: opencv, index_or_path: '/dev/video0', width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=omx_leader \
|
||||
--teleop.port=/dev/omx_leader \
|
||||
--teleop.id=omx_leader_arm \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
When the camera is enabled, the camera stream is displayed in real time and synchronized with the robot state. This setup is useful for visual monitoring and can be reused later for demonstration recording and imitation learning.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Congrats 🎉, your robot is all set to learn a task on its own.
|
||||
|
||||
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/robotis).
|
||||
@@ -0,0 +1,276 @@
|
||||
# OpenArm
|
||||
|
||||
[OpenArm](https://openarm.dev) is an open-source 7DOF humanoid arm designed for physical AI research and deployment.
|
||||
|
||||
To get your OpenArm, assembled or DIY, and join the global community, browse verified and certified manufacturers worldwide at [openarm.dev](https://openarm.dev).
|
||||
|
||||
## What's Unique?
|
||||
|
||||
- **Human-Scale Design**: OpenArm is designed with human-like proportions, scaled for a person around 160-165cm tall. This provides an optimal balance between practical reach and manageable inertia for safe, responsive operation.
|
||||
|
||||
- **Safety-First Architecture**: Built with QDD backdrivable motors and high compliance, OpenArm prioritizes safe human-robot interaction while maintaining practical payload capabilities (6.0kg peak / 4.1kg nominal) for real-world tasks.
|
||||
|
||||
- **Built for Durability**: Critical structural components use aluminum and stainless steel construction, ensuring robust performance for repetitive data collection and continuous research use.
|
||||
|
||||
- **Fully Accessible & Buildable**: Every component, from CNC parts and 3D-printed casings to electrical wiring is designed to be purchasable and buildable by individual researchers and labs, with complete fabrication data provided.
|
||||
|
||||
- **Practical & Affordable**: At $6,500 USD for a complete bimanual system, OpenArm delivers research-grade capabilities at a fraction of traditional humanoid robot costs.
|
||||
|
||||
## Platform Requirements
|
||||
|
||||
<Tip warning={true}>
|
||||
**Linux Only**: OpenArm currently only works on Linux. The CAN bus USB adapter
|
||||
does not have macOS drivers and has not been tested on Windows.
|
||||
</Tip>
|
||||
|
||||
## Safety Guide
|
||||
|
||||
Before operating OpenArm, please read the [official safety guide](https://docs.openarm.dev/getting-started/safety-guide). Key points:
|
||||
|
||||
- **Secure installation**: Fasten the arm to a flat, stable surface with screws or clamps
|
||||
- **Safe distance**: Keep body parts and objects outside the range of motion during operation
|
||||
- **Protective equipment**: Always wear safety goggles; use additional PPE as needed
|
||||
- **Payload limits**: Do not exceed specified payload limits (6.0kg peak / 4.1kg nominal per arm)
|
||||
- **Emergency stop**: Know the location and operation of the emergency stop device
|
||||
- **Regular inspection**: Check for loose screws, damaged mechanical limits, unusual noises, and wiring damage
|
||||
|
||||
## Hardware Setup
|
||||
|
||||
Follow the official [OpenArm hardware documentation](https://docs.openarm.dev) for:
|
||||
|
||||
- Bill of materials and sourcing
|
||||
- 3D printing instructions
|
||||
- Mechanical assembly
|
||||
- Electrical wiring
|
||||
|
||||
The hardware repositories are available at [github.com/enactic/openarm](https://github.com/enactic/openarm).
|
||||
|
||||
## CAN Bus Setup
|
||||
|
||||
OpenArm uses CAN bus communication with Damiao motors. Once you have the CAN bus USB adapter plugged into your Linux PC, follow the [Damiao Motors and CAN Bus guide](./damiao) to configure the interface.
|
||||
|
||||
Quick setup:
|
||||
|
||||
```bash
|
||||
# Setup CAN interfaces
|
||||
lerobot-setup-can --mode=setup --interfaces=can0,can1
|
||||
|
||||
# Test motor communication
|
||||
lerobot-setup-can --mode=test --interfaces=can0,can1
|
||||
```
|
||||
|
||||
## Install LeRobot 🤗
|
||||
|
||||
Follow our [Installation Guide](./installation), then install the Damiao motor support:
|
||||
|
||||
```bash
|
||||
pip install -e ".[damiao]"
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Follower Arm (Robot)
|
||||
|
||||
<hfoptions id="follower">
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--robot.type=openarm_follower \
|
||||
--robot.port=can0 \
|
||||
--robot.side=right \
|
||||
--robot.id=my_openarm_follower
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
```python
|
||||
from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig
|
||||
|
||||
config = OpenArmFollowerConfig(
|
||||
port="can0",
|
||||
side="right", # or "left" for left arm
|
||||
id="my_openarm_follower",
|
||||
)
|
||||
|
||||
follower = OpenArmFollower(config)
|
||||
follower.connect()
|
||||
|
||||
# Read current state
|
||||
obs = follower.get_observation()
|
||||
print(obs)
|
||||
|
||||
# Send action (position in degrees)
|
||||
action = {
|
||||
"joint_1.pos": 0.0,
|
||||
"joint_2.pos": 0.0,
|
||||
"joint_3.pos": 0.0,
|
||||
"joint_4.pos": 45.0,
|
||||
"joint_5.pos": 0.0,
|
||||
"joint_6.pos": 0.0,
|
||||
"joint_7.pos": 0.0,
|
||||
"gripper.pos": 0.0,
|
||||
}
|
||||
follower.send_action(action)
|
||||
|
||||
follower.disconnect()
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Leader Arm (Teleoperator)
|
||||
|
||||
The leader arm is used for teleoperation - manually moving it to control the follower arm.
|
||||
|
||||
<hfoptions id="leader">
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--teleop.type=openarm_leader \
|
||||
--teleop.port=can1 \
|
||||
--teleop.id=my_openarm_leader
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
```python
|
||||
from lerobot.teleoperators.openarm_leader import OpenArmLeader, OpenArmLeaderConfig
|
||||
|
||||
config = OpenArmLeaderConfig(
|
||||
port="can1",
|
||||
id="my_openarm_leader",
|
||||
manual_control=True, # Disable torque for manual movement
|
||||
)
|
||||
|
||||
leader = OpenArmLeader(config)
|
||||
leader.connect()
|
||||
|
||||
# Read current position (as action to send to follower)
|
||||
action = leader.get_action()
|
||||
print(action)
|
||||
|
||||
leader.disconnect()
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Teleoperation
|
||||
|
||||
To teleoperate OpenArm with leader-follower control:
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=openarm_follower \
|
||||
--robot.port=can0 \
|
||||
--robot.side=right \
|
||||
--robot.id=my_follower \
|
||||
--teleop.type=openarm_leader \
|
||||
--teleop.port=can1 \
|
||||
--teleop.id=my_leader
|
||||
```
|
||||
|
||||
### Bimanual Teleoperation
|
||||
|
||||
To teleoperate a bimanual OpenArm setup with two leader and two follower arms:
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=bi_openarm_follower \
|
||||
--robot.left_arm_config.port=can0 \
|
||||
--robot.left_arm_config.side=left \
|
||||
--robot.right_arm_config.port=can1 \
|
||||
--robot.right_arm_config.side=right \
|
||||
--robot.id=my_bimanual_follower \
|
||||
--teleop.type=bi_openarm_leader \
|
||||
--teleop.left_arm_config.port=can2 \
|
||||
--teleop.right_arm_config.port=can3 \
|
||||
--teleop.id=my_bimanual_leader
|
||||
```
|
||||
|
||||
### Recording Data
|
||||
|
||||
To record a dataset during teleoperation:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=openarm_follower \
|
||||
--robot.port=can0 \
|
||||
--robot.side=right \
|
||||
--robot.id=my_follower \
|
||||
--teleop.type=openarm_leader \
|
||||
--teleop.port=can1 \
|
||||
--teleop.id=my_leader \
|
||||
--repo-id=my_hf_username/my_openarm_dataset \
|
||||
--fps=30 \
|
||||
--num-episodes=10
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### Follower Configuration
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| --------------------- | --------- | ---------------------------------------------------------- |
|
||||
| `port` | - | CAN interface (e.g., `can0`) |
|
||||
| `side` | `None` | Arm side: `"left"`, `"right"`, or `None` for custom limits |
|
||||
| `use_can_fd` | `True` | Enable CAN FD for higher data rates |
|
||||
| `can_bitrate` | `1000000` | Nominal bitrate (1 Mbps) |
|
||||
| `can_data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) |
|
||||
| `max_relative_target` | `None` | Safety limit for relative target positions |
|
||||
| `position_kp` | Per-joint | Position control proportional gains |
|
||||
| `position_kd` | Per-joint | Position control derivative gains |
|
||||
|
||||
### Leader Configuration
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| ------------------ | --------- | ----------------------------------- |
|
||||
| `port` | - | CAN interface (e.g., `can1`) |
|
||||
| `manual_control` | `True` | Disable torque for manual movement |
|
||||
| `use_can_fd` | `True` | Enable CAN FD for higher data rates |
|
||||
| `can_bitrate` | `1000000` | Nominal bitrate (1 Mbps) |
|
||||
| `can_data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) |
|
||||
|
||||
## Motor Configuration
|
||||
|
||||
OpenArm uses Damiao motors with the following default configuration:
|
||||
|
||||
| Joint | Motor Type | Send ID | Recv ID |
|
||||
| --------------------------- | ---------- | ------- | ------- |
|
||||
| joint_1 (Shoulder pan) | DM8009 | 0x01 | 0x11 |
|
||||
| joint_2 (Shoulder lift) | DM8009 | 0x02 | 0x12 |
|
||||
| joint_3 (Shoulder rotation) | DM4340 | 0x03 | 0x13 |
|
||||
| joint_4 (Elbow flex) | DM4340 | 0x04 | 0x14 |
|
||||
| joint_5 (Wrist roll) | DM4310 | 0x05 | 0x15 |
|
||||
| joint_6 (Wrist pitch) | DM4310 | 0x06 | 0x16 |
|
||||
| joint_7 (Wrist rotation) | DM4310 | 0x07 | 0x17 |
|
||||
| gripper | DM4310 | 0x08 | 0x18 |
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### No Response from Motors
|
||||
|
||||
1. Check power supply connections
|
||||
2. Verify CAN wiring (CAN-H, CAN-L, GND)
|
||||
3. Run diagnostics: `lerobot-setup-can --mode=test --interfaces=can0`
|
||||
4. See the [Damiao troubleshooting guide](./damiao#troubleshooting) for more details
|
||||
|
||||
### CAN Interface Not Found
|
||||
|
||||
Ensure the CAN interface is configured:
|
||||
|
||||
```bash
|
||||
ip link show can0
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- [OpenArm Website](https://openarm.dev)
|
||||
- [OpenArm Documentation](https://docs.openarm.dev)
|
||||
- [OpenArm GitHub](https://github.com/enactic/openarm)
|
||||
- [Safety Guide](https://docs.openarm.dev/getting-started/safety-guide)
|
||||
- [Damiao Motors and CAN Bus](./damiao)
|
||||
@@ -1,5 +1,18 @@
|
||||
# SO-101
|
||||
|
||||
<div style="display: flex; align-items: center; gap: 10px;">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/SO101_Follower.webp"
|
||||
alt="SO-101"
|
||||
width="60%"
|
||||
/>
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/SO101_Leader.webp"
|
||||
alt="SO-101"
|
||||
width="60%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
In the steps below, we explain how to assemble our flagship robot, the SO-101.
|
||||
|
||||
## Source the parts
|
||||
|
||||
@@ -188,7 +188,105 @@ Press `Ctrl+C` to stop the policy.
|
||||
|
||||
## Running in Simulation Mode (MuJoCo)
|
||||
|
||||
You can now test policies before unleashing them on the physical robot using MuJoCo. To do so simply set `is_simulation=True` in config.
|
||||
You can test policies before deploying on the physical robot using MuJoCo simulation. Set `is_simulation=True` in config or pass `--robot.is_simulation=true` via CLI.
|
||||
|
||||
### Calibrate Exoskeleton Teleoperator
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo
|
||||
```
|
||||
|
||||
### Teleoperate in Simulation
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=true \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--fps=100
|
||||
```
|
||||
|
||||
### Record Dataset in Simulation
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.lerobot_record \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=true \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--dataset.repo_id=your-username/dataset-name \
|
||||
--dataset.single_task="Test" \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.reset_time_s=5 \
|
||||
--dataset.push_to_hub=true
|
||||
```
|
||||
|
||||
Example simulation dataset: [nepyope/teleop_test_sim](https://huggingface.co/datasets/nepyope/teleop_test_sim)
|
||||
|
||||
---
|
||||
|
||||
## Running on Real Robot
|
||||
|
||||
Once the robot server is running on the G1 (see Part 3), you can teleoperate and record on the real robot.
|
||||
|
||||
### Start the Camera Server
|
||||
|
||||
On the robot, start the ZMQ image server:
|
||||
|
||||
```bash
|
||||
python src/lerobot/cameras/zmq/image_server.py
|
||||
```
|
||||
|
||||
Keep this running in a separate terminal for camera streaming during recording.
|
||||
|
||||
### Teleoperate Real Robot
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=false \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--fps=100
|
||||
```
|
||||
|
||||
### Record Dataset on Real Robot
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.lerobot_record \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=false \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--dataset.repo_id=your-username/dataset-name \
|
||||
--dataset.single_task="Test" \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.reset_time_s=5 \
|
||||
--dataset.push_to_hub=true
|
||||
```
|
||||
|
||||
**Note**: Update `server_address` to match your robot's camera server IP.
|
||||
|
||||
Example real robot dataset: [nepyope/teleop_test_real](https://huggingface.co/datasets/nepyope/teleop_test_real)
|
||||
|
||||
---
|
||||
|
||||
## Additional Resources
|
||||
|
||||
|
||||
@@ -95,26 +95,26 @@ Convert an image-based dataset to video format, creating a new LeRobotDataset wh
|
||||
# Local-only: Save to a custom output directory (no hub push)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir /path/to/output/pusht_video
|
||||
|
||||
# Save with new repo_id (local storage)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_to_video
|
||||
--operation.type convert_image_to_video
|
||||
|
||||
# Convert and push to Hugging Face Hub
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_to_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--push_to_hub true
|
||||
|
||||
# Convert with custom video codec and quality settings
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.vcodec libsvtav1 \
|
||||
--operation.pix_fmt yuv420p \
|
||||
@@ -124,16 +124,23 @@ lerobot-edit-dataset \
|
||||
# Convert only specific episodes
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.episode_indices "[0, 1, 2, 5, 10]"
|
||||
|
||||
# Convert with multiple workers for parallel processing
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.num_workers 8
|
||||
|
||||
# For memory-constrained systems, users can now specify limits:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.max_episodes_per_batch 50 \
|
||||
--operation.max_frames_per_batch 10000
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
|
||||
@@ -81,24 +81,25 @@ def replay(cfg: ReplayConfig):
|
||||
actions = dataset.hf_dataset.select_columns(ACTION)
|
||||
robot.connect()
|
||||
|
||||
log_say("Replaying episode", cfg.play_sounds, blocking=True)
|
||||
for idx in range(dataset.num_frames):
|
||||
start_episode_t = time.perf_counter()
|
||||
try:
|
||||
log_say("Replaying episode", cfg.play_sounds, blocking=True)
|
||||
for idx in range(dataset.num_frames):
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action_array = actions[idx][ACTION]
|
||||
action = {}
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"]):
|
||||
key = f"{name.removeprefix('main_')}.pos"
|
||||
action[key] = action_array[i].item()
|
||||
action_array = actions[idx][ACTION]
|
||||
action = {}
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"]):
|
||||
key = f"{name.removeprefix('main_')}.pos"
|
||||
action[key] = action_array[i].item()
|
||||
|
||||
action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90)
|
||||
action["elbow_flex.pos"] -= 90
|
||||
robot.send_action(action)
|
||||
action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90)
|
||||
action["elbow_flex.pos"] -= 90
|
||||
robot.send_action(action)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
precise_sleep(max(1 / dataset.fps - dt_s, 0.0))
|
||||
|
||||
robot.disconnect()
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
precise_sleep(max(1 / dataset.fps - dt_s, 0.0))
|
||||
finally:
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
+45
-43
@@ -78,40 +78,24 @@ def main():
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="lekiwi_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
try:
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
|
||||
print("Starting evaluate loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
@@ -120,24 +104,42 @@ def main():
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
+45
-44
@@ -74,40 +74,23 @@ def main():
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="lekiwi_record")
|
||||
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
try:
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting record loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {recorded_episodes}")
|
||||
print("Starting record loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {recorded_episodes}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
dataset=dataset,
|
||||
teleop=[leader_arm, keyboard],
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
dataset=dataset,
|
||||
teleop=[leader_arm, keyboard],
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
@@ -115,26 +98,44 @@ def main():
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=[leader_arm, keyboard],
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
leader_arm.disconnect()
|
||||
keyboard.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
leader_arm.disconnect()
|
||||
keyboard.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
+17
-15
@@ -42,25 +42,27 @@ def main():
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
try:
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get recorded action from dataset
|
||||
action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
# Get recorded action from dataset
|
||||
action = {
|
||||
name: float(actions[idx][ACTION][i])
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Send action to robot
|
||||
_ = robot.send_action(action)
|
||||
# Send action to robot
|
||||
_ = robot.send_action(action)
|
||||
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
robot.disconnect()
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
finally:
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -142,38 +142,24 @@ def main():
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="phone_so100_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
try:
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
@@ -182,24 +168,41 @@ def main():
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -149,38 +149,23 @@ def main():
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="phone_so100_record")
|
||||
|
||||
if not robot.is_connected or not phone.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
try:
|
||||
if not robot.is_connected or not phone.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting record loop. Move your phone to teleoperate the robot...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
print("Starting record loop. Move your phone to teleoperate the robot...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
@@ -188,25 +173,43 @@ def main():
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
phone.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
phone.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -73,32 +73,34 @@ def main():
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
try:
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i])
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
finally:
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -142,38 +142,24 @@ def main():
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="so100_so100_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
try:
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
@@ -182,24 +168,41 @@ def main():
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -146,38 +146,23 @@ def main():
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording_phone")
|
||||
|
||||
if not leader.is_connected or not follower.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
try:
|
||||
if not leader.is_connected or not follower.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting record loop...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
print("Starting record loop...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=leader,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=leader,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
@@ -185,25 +170,44 @@ def main():
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=leader,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
leader.disconnect()
|
||||
follower.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
leader.disconnect()
|
||||
follower.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -74,32 +74,35 @@ def main():
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
try:
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i])
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
finally:
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -30,6 +30,7 @@ def main():
|
||||
robot=robot_cfg,
|
||||
server_address=server_address,
|
||||
policy_device="mps",
|
||||
client_device="cpu",
|
||||
policy_type="act",
|
||||
pretrained_name_or_path="<user>/robot_learning_tutorial_act",
|
||||
chunk_size_threshold=0.5, # g
|
||||
|
||||
+10
-2
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
|
||||
|
||||
[project]
|
||||
name = "lerobot"
|
||||
version = "0.4.3"
|
||||
version = "0.4.4"
|
||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||
dynamic = ["readme"]
|
||||
license = { text = "Apache-2.0" }
|
||||
@@ -102,14 +102,20 @@ grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
# Motors
|
||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
|
||||
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
|
||||
damiao = ["python-can>=4.2.0,<5.0.0"]
|
||||
|
||||
# Robots
|
||||
openarms = ["lerobot[damiao]"]
|
||||
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
|
||||
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
|
||||
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
|
||||
unitree_g1 = [
|
||||
"pyzmq>=26.2.1,<28.0.0",
|
||||
"onnxruntime>=1.16.0,<2.0.0"
|
||||
"onnxruntime>=1.16.0,<2.0.0",
|
||||
"pin>=3.0.0,<4.0.0",
|
||||
"meshcat>=0.3.0,<0.4.0",
|
||||
"matplotlib>=3.9.0,<4.0.0",
|
||||
"casadi>=3.6.0,<4.0.0",
|
||||
]
|
||||
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
||||
kinematics = ["lerobot[placo-dep]"]
|
||||
@@ -203,6 +209,7 @@ lerobot-info="lerobot.scripts.lerobot_info:main"
|
||||
lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
|
||||
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
[tool.setuptools.packages.find]
|
||||
@@ -278,6 +285,7 @@ default.extend-ignore-identifiers-re = [
|
||||
"thw",
|
||||
"inpt",
|
||||
"ROBOTIS",
|
||||
"OT_VALUE"
|
||||
]
|
||||
|
||||
# TODO: Uncomment when ready to use
|
||||
|
||||
@@ -126,6 +126,12 @@ class RobotClientConfig:
|
||||
|
||||
# Device configuration
|
||||
policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"})
|
||||
client_device: str = field(
|
||||
default="cpu",
|
||||
metadata={
|
||||
"help": "Device to move actions to after receiving from server (e.g., for downstream planners)"
|
||||
},
|
||||
)
|
||||
|
||||
# Control behavior configuration
|
||||
chunk_size_threshold: float = field(default=0.5, metadata={"help": "Threshold for chunk size control"})
|
||||
@@ -161,6 +167,9 @@ class RobotClientConfig:
|
||||
if not self.policy_device:
|
||||
raise ValueError("policy_device cannot be empty")
|
||||
|
||||
if not self.client_device:
|
||||
raise ValueError("client_device cannot be empty")
|
||||
|
||||
if self.chunk_size_threshold < 0 or self.chunk_size_threshold > 1:
|
||||
raise ValueError(f"chunk_size_threshold must be between 0 and 1, got {self.chunk_size_threshold}")
|
||||
|
||||
@@ -184,6 +193,7 @@ class RobotClientConfig:
|
||||
"policy_type": self.policy_type,
|
||||
"pretrained_name_or_path": self.pretrained_name_or_path,
|
||||
"policy_device": self.policy_device,
|
||||
"client_device": self.client_device,
|
||||
"chunk_size_threshold": self.chunk_size_threshold,
|
||||
"fps": self.fps,
|
||||
"actions_per_chunk": self.actions_per_chunk,
|
||||
|
||||
@@ -23,7 +23,7 @@ DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS
|
||||
DEFAULT_OBS_QUEUE_TIMEOUT = 2
|
||||
|
||||
# All action chunking policies
|
||||
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"]
|
||||
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05", "groot"]
|
||||
|
||||
# TODO: Add all other robots
|
||||
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so_follower", "omx_follower"]
|
||||
|
||||
@@ -18,6 +18,7 @@ import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
@@ -39,8 +40,8 @@ from lerobot.utils.utils import init_logging
|
||||
|
||||
Action = torch.Tensor
|
||||
|
||||
# observation as received from the robot
|
||||
RawObservation = dict[str, torch.Tensor]
|
||||
# observation as received from the robot (can be numpy arrays, floats, etc.)
|
||||
RawObservation = dict[str, Any]
|
||||
|
||||
# observation as those recorded in LeRobot dataset (keys are different)
|
||||
LeRobotObservation = dict[str, torch.Tensor]
|
||||
|
||||
@@ -381,6 +381,8 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
action_tensor = torch.stack(processed_actions, dim=1).squeeze(0)
|
||||
self.logger.debug(f"Postprocessed action shape: {action_tensor.shape}")
|
||||
|
||||
action_tensor = action_tensor.detach().cpu()
|
||||
|
||||
"""5. Convert to TimedAction list"""
|
||||
action_chunk = self._time_action_chunk(
|
||||
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
|
||||
|
||||
@@ -25,6 +25,7 @@ python src/lerobot/async_inference/robot_client.py \
|
||||
--policy_type=act \
|
||||
--pretrained_name_or_path=user/model \
|
||||
--policy_device=mps \
|
||||
--client_device=cpu \
|
||||
--actions_per_chunk=50 \
|
||||
--chunk_size_threshold=0.5 \
|
||||
--aggregate_fn_name=weighted_average \
|
||||
@@ -40,6 +41,7 @@ from collections.abc import Callable
|
||||
from dataclasses import asdict
|
||||
from pprint import pformat
|
||||
from queue import Queue
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
import grpc
|
||||
@@ -47,7 +49,6 @@ import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
@@ -285,6 +286,21 @@ class RobotClient:
|
||||
timed_actions = pickle.loads(actions_chunk.data) # nosec
|
||||
deserialize_time = time.perf_counter() - deserialize_start
|
||||
|
||||
# Log device type of received actions
|
||||
if len(timed_actions) > 0:
|
||||
received_device = timed_actions[0].get_action().device.type
|
||||
self.logger.debug(f"Received actions on device: {received_device}")
|
||||
|
||||
# Move actions to client_device (e.g., for downstream planners that need GPU)
|
||||
client_device = self.config.client_device
|
||||
if client_device != "cpu":
|
||||
for timed_action in timed_actions:
|
||||
if timed_action.get_action().device.type != client_device:
|
||||
timed_action.action = timed_action.get_action().to(client_device)
|
||||
self.logger.debug(f"Converted actions to device: {client_device}")
|
||||
else:
|
||||
self.logger.debug(f"Actions kept on device: {client_device}")
|
||||
|
||||
self.action_chunk_size = max(self.action_chunk_size, len(timed_actions))
|
||||
|
||||
# Calculate network latency if we have matching observations
|
||||
@@ -351,7 +367,7 @@ class RobotClient:
|
||||
action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)}
|
||||
return action
|
||||
|
||||
def control_loop_action(self, verbose: bool = False) -> RobotAction:
|
||||
def control_loop_action(self, verbose: bool = False) -> dict[str, Any]:
|
||||
"""Reading and performing actions in local queue"""
|
||||
|
||||
# Lock only for queue operations
|
||||
|
||||
@@ -15,11 +15,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
|
||||
|
||||
from .configs import CameraConfig, ColorMode
|
||||
from .configs import CameraConfig
|
||||
|
||||
|
||||
class Camera(abc.ABC):
|
||||
@@ -30,20 +31,12 @@ class Camera(abc.ABC):
|
||||
|
||||
Manages basic camera properties (FPS, resolution) and core operations:
|
||||
- Connection/disconnection
|
||||
- Frame capture (sync/async)
|
||||
- Frame capture (sync/async/latest)
|
||||
|
||||
Attributes:
|
||||
fps (int | None): Configured frames per second
|
||||
width (int | None): Frame width in pixels
|
||||
height (int | None): Frame height in pixels
|
||||
|
||||
Example:
|
||||
class MyCamera(Camera):
|
||||
def __init__(self, config): ...
|
||||
@property
|
||||
def is_connected(self) -> bool: ...
|
||||
def connect(self, warmup=True): ...
|
||||
# Plus other required methods
|
||||
"""
|
||||
|
||||
def __init__(self, config: CameraConfig):
|
||||
@@ -56,6 +49,32 @@ class Camera(abc.ABC):
|
||||
self.width: int | None = config.width
|
||||
self.height: int | None = config.height
|
||||
|
||||
def __enter__(self):
|
||||
"""
|
||||
Context manager entry.
|
||||
Automatically connects to the camera.
|
||||
"""
|
||||
self.connect()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback) -> None:
|
||||
"""
|
||||
Context manager exit.
|
||||
Automatically disconnects, ensuring resources are released even on error.
|
||||
"""
|
||||
self.disconnect()
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""
|
||||
Destructor safety net.
|
||||
Attempts to disconnect if the object is garbage collected without cleanup.
|
||||
"""
|
||||
try:
|
||||
if self.is_connected:
|
||||
self.disconnect()
|
||||
except Exception: # nosec B110
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_connected(self) -> bool:
|
||||
@@ -89,12 +108,10 @@ class Camera(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""Capture and return a single frame from the camera.
|
||||
def read(self) -> NDArray[Any]:
|
||||
"""Capture and return a single frame from the camera synchronously.
|
||||
|
||||
Args:
|
||||
color_mode: Desired color mode for the output frame. If None,
|
||||
uses the camera's default color mode.
|
||||
This is a blocking call that will wait for the hardware and its SDK.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Captured frame as a numpy array.
|
||||
@@ -103,17 +120,64 @@ class Camera(abc.ABC):
|
||||
|
||||
@abc.abstractmethod
|
||||
def async_read(self, timeout_ms: float = ...) -> NDArray[Any]:
|
||||
"""Asynchronously capture and return a single frame from the camera.
|
||||
"""Return the most recent new frame.
|
||||
|
||||
This method retrieves the latest frame captured by the background thread.
|
||||
If a new frame is already available in the buffer (captured since the last call),
|
||||
it returns it immediately.
|
||||
|
||||
It blocks up to `timeout_ms` only if the buffer is empty or if the latest frame
|
||||
was already consumed by a previous `async_read` call.
|
||||
|
||||
Essentially, this method return the latest unconsumed frame, waiting if necessary
|
||||
for a new one to arrive within the specified timeout.
|
||||
|
||||
Usage:
|
||||
- Ideal for control loops where you want to ensure every processed frame
|
||||
is fresh, effectively synchronizing your loop to the camera's FPS.
|
||||
- Causes of a timeout usually include: very low camera FPS, heavy processing load,
|
||||
or if the camera is disconnected.
|
||||
|
||||
Args:
|
||||
timeout_ms: Maximum time to wait for a frame in milliseconds.
|
||||
Defaults to implementation-specific timeout.
|
||||
timeout_ms: Maximum time to wait for a new frame in milliseconds.
|
||||
Defaults to 200ms (0.2s).
|
||||
|
||||
Returns:
|
||||
np.ndarray: Captured frame as a numpy array.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If no new frame arrives within `timeout_ms`.
|
||||
"""
|
||||
pass
|
||||
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
memory buffer. The frame may be stale,
|
||||
meaning it could have been captured a while ago (hanging camera scenario e.g.).
|
||||
|
||||
Usage:
|
||||
Ideal for scenarios requiring zero latency or decoupled frequencies & when
|
||||
we want a guaranteed frame, such as UI visualization, logging, or
|
||||
non-critical monitoring.
|
||||
|
||||
Returns:
|
||||
NDArray[Any]: The frame image (numpy array).
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the latest frame is older than `max_age_ms`.
|
||||
NotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
warnings.warn(
|
||||
f"{self.__class__.__name__}.read_latest() is not implemented. "
|
||||
"Please override read_latest(); it will be required in future releases.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.async_read()
|
||||
|
||||
@abc.abstractmethod
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from the camera and release resources."""
|
||||
|
||||
@@ -70,34 +70,24 @@ class OpenCVCamera(Camera):
|
||||
Example:
|
||||
```python
|
||||
from lerobot.cameras.opencv import OpenCVCamera
|
||||
from lerobot.cameras.configuration_opencv import OpenCVCameraConfig, ColorMode, Cv2Rotation
|
||||
from lerobot.cameras.configuration_opencv import OpenCVCameraConfig
|
||||
|
||||
# Basic usage with camera index 0
|
||||
config = OpenCVCameraConfig(index_or_path=0)
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect()
|
||||
|
||||
# Read 1 frame synchronously
|
||||
# Read 1 frame synchronously (blocking)
|
||||
color_image = camera.read()
|
||||
print(color_image.shape)
|
||||
|
||||
# Read 1 frame asynchronously
|
||||
# Read 1 frame asynchronously (waits for new frame with a timeout)
|
||||
async_image = camera.async_read()
|
||||
|
||||
# Get the latest frame immediately (no wait, returns timestamp)
|
||||
latest_image, timestamp = camera.read_latest()
|
||||
|
||||
# When done, properly disconnect the camera using
|
||||
camera.disconnect()
|
||||
|
||||
# Example with custom settings
|
||||
custom_config = OpenCVCameraConfig(
|
||||
index_or_path='/dev/video0', # Or use an index
|
||||
fps=30,
|
||||
width=1280,
|
||||
height=720,
|
||||
color_mode=ColorMode.RGB,
|
||||
rotation=Cv2Rotation.ROTATE_90
|
||||
)
|
||||
custom_camera = OpenCVCamera(custom_config)
|
||||
# ... connect, read, disconnect ...
|
||||
```
|
||||
"""
|
||||
|
||||
@@ -123,6 +113,7 @@ class OpenCVCamera(Camera):
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.latest_timestamp: float | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
self.rotation: int | None = get_cv2_rotation(config.rotation)
|
||||
@@ -146,12 +137,16 @@ class OpenCVCamera(Camera):
|
||||
Connects to the OpenCV camera specified in the configuration.
|
||||
|
||||
Initializes the OpenCV VideoCapture object, sets desired camera properties
|
||||
(FPS, width, height), and performs initial checks.
|
||||
(FPS, width, height), starts the background reading thread and performs initial checks.
|
||||
|
||||
Args:
|
||||
warmup (bool): If True, waits at connect() time until at least one valid frame
|
||||
has been captured by the background thread. Defaults to True.
|
||||
|
||||
Raises:
|
||||
DeviceAlreadyConnectedError: If the camera is already connected.
|
||||
ConnectionError: If the specified camera index/path is not found or the camera is found but fails to open.
|
||||
RuntimeError: If the camera opens but fails to apply requested FPS/resolution settings.
|
||||
ConnectionError: If the specified camera index/path is not found or fails to open.
|
||||
RuntimeError: If the camera opens but fails to apply requested settings.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
|
||||
@@ -170,12 +165,16 @@ class OpenCVCamera(Camera):
|
||||
)
|
||||
|
||||
self._configure_capture_settings()
|
||||
self._start_read_thread()
|
||||
|
||||
if warmup:
|
||||
if warmup and self.warmup_s > 0:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < self.warmup_s:
|
||||
self.read()
|
||||
self.async_read(timeout_ms=self.warmup_s * 1000)
|
||||
time.sleep(0.1)
|
||||
with self.frame_lock:
|
||||
if self.latest_frame is None:
|
||||
raise ConnectionError(f"{self} failed to capture frames during warmup.")
|
||||
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@@ -196,8 +195,7 @@ class OpenCVCamera(Camera):
|
||||
Raises:
|
||||
RuntimeError: If the camera fails to set any of the specified properties
|
||||
to the requested value.
|
||||
DeviceNotConnectedError: If the camera is not connected when attempting
|
||||
to configure settings.
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.")
|
||||
@@ -339,6 +337,17 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return found_cameras_info
|
||||
|
||||
def _read_from_hardware(self) -> NDArray[Any]:
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
ret, frame = self.videocapture.read()
|
||||
|
||||
if not ret:
|
||||
raise RuntimeError(f"{self} read failed (status={ret}).")
|
||||
|
||||
return frame
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
@@ -346,11 +355,6 @@ class OpenCVCamera(Camera):
|
||||
This is a blocking call. It waits for the next available frame from the
|
||||
camera hardware via OpenCV.
|
||||
|
||||
Args:
|
||||
color_mode (Optional[ColorMode]): If specified, overrides the default
|
||||
color mode (`self.color_mode`) for this read operation (e.g.,
|
||||
request RGB even if default is BGR).
|
||||
|
||||
Returns:
|
||||
np.ndarray: The captured frame as a NumPy array in the format
|
||||
(height, width, channels), using the specified or default
|
||||
@@ -362,34 +366,34 @@ class OpenCVCamera(Camera):
|
||||
received frame dimensions don't match expectations before rotation.
|
||||
ValueError: If an invalid `color_mode` is requested.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
if color_mode is not None:
|
||||
logger.warning(
|
||||
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
ret, frame = self.videocapture.read()
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if not ret or frame is None:
|
||||
raise RuntimeError(f"{self} read failed (status={ret}).")
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
processed_frame = self._postprocess_image(frame, color_mode)
|
||||
self.new_frame_event.clear()
|
||||
frame = self.async_read(timeout_ms=10000)
|
||||
|
||||
read_duration_ms = (time.perf_counter() - start_time) * 1e3
|
||||
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
|
||||
|
||||
return processed_frame
|
||||
return frame
|
||||
|
||||
def _postprocess_image(self, image: NDArray[Any], color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
def _postprocess_image(self, image: NDArray[Any]) -> NDArray[Any]:
|
||||
"""
|
||||
Applies color conversion, dimension validation, and rotation to a raw frame.
|
||||
|
||||
Args:
|
||||
image (np.ndarray): The raw image frame (expected BGR format from OpenCV).
|
||||
color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None,
|
||||
uses the instance's default `self.color_mode`.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The processed image frame.
|
||||
@@ -399,11 +403,10 @@ class OpenCVCamera(Camera):
|
||||
RuntimeError: If the raw frame dimensions do not match the configured
|
||||
`width` and `height`.
|
||||
"""
|
||||
requested_color_mode = self.color_mode if color_mode is None else color_mode
|
||||
|
||||
if requested_color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"Invalid color mode '{requested_color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
|
||||
f"Invalid color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
|
||||
)
|
||||
|
||||
h, w, c = image.shape
|
||||
@@ -417,7 +420,7 @@ class OpenCVCamera(Camera):
|
||||
raise RuntimeError(f"{self} frame channels={c} do not match expected 3 channels (RGB/BGR).")
|
||||
|
||||
processed_image = image
|
||||
if requested_color_mode == ColorMode.RGB:
|
||||
if self.color_mode == ColorMode.RGB:
|
||||
processed_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]:
|
||||
@@ -431,7 +434,7 @@ class OpenCVCamera(Camera):
|
||||
|
||||
On each iteration:
|
||||
1. Reads a color frame
|
||||
2. Stores result in latest_frame (thread-safe)
|
||||
2. Stores result in latest_frame and updates timestamp (thread-safe)
|
||||
3. Sets new_frame_event to notify listeners
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
@@ -439,30 +442,37 @@ class OpenCVCamera(Camera):
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
failure_count = 0
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
color_image = self.read()
|
||||
raw_frame = self._read_from_hardware()
|
||||
processed_frame = self._postprocess_image(raw_frame)
|
||||
capture_time = time.perf_counter()
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = color_image
|
||||
self.latest_frame = processed_frame
|
||||
self.latest_timestamp = capture_time
|
||||
self.new_frame_event.set()
|
||||
failure_count = 0
|
||||
|
||||
except DeviceNotConnectedError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Error reading frame in background thread for {self}: {e}")
|
||||
if failure_count <= 10:
|
||||
failure_count += 1
|
||||
logger.warning(f"Error reading frame in background thread for {self}: {e}")
|
||||
else:
|
||||
raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e
|
||||
|
||||
def _start_read_thread(self) -> None:
|
||||
"""Starts or restarts the background read thread if it's not running."""
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=0.1)
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
self._stop_read_thread()
|
||||
|
||||
self.stop_event = Event()
|
||||
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
time.sleep(0.1)
|
||||
|
||||
def _stop_read_thread(self) -> None:
|
||||
"""Signals the background read thread to stop and waits for it to join."""
|
||||
@@ -475,6 +485,11 @@ class OpenCVCamera(Camera):
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame asynchronously.
|
||||
@@ -482,6 +497,7 @@ class OpenCVCamera(Camera):
|
||||
This method retrieves the most recent frame captured by the background
|
||||
read thread. It does not block waiting for the camera hardware directly,
|
||||
but may wait up to timeout_ms for the background thread to provide a frame.
|
||||
It is “best effort” under high FPS.
|
||||
|
||||
Args:
|
||||
timeout_ms (float): Maximum time in milliseconds to wait for a frame
|
||||
@@ -500,13 +516,12 @@ class OpenCVCamera(Camera):
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
self._start_read_thread()
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
|
||||
thread_alive = self.thread is not None and self.thread.is_alive()
|
||||
raise TimeoutError(
|
||||
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
|
||||
f"Read thread alive: {thread_alive}."
|
||||
f"Read thread alive: {self.thread.is_alive()}."
|
||||
)
|
||||
|
||||
with self.frame_lock:
|
||||
@@ -518,6 +533,42 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
memory buffer. The frame may be stale,
|
||||
meaning it could have been captured a while ago (hanging camera scenario e.g.).
|
||||
|
||||
Returns:
|
||||
NDArray[Any]: The frame image (numpy array).
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the latest frame is older than `max_age_ms`.
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
with self.frame_lock:
|
||||
frame = self.latest_frame
|
||||
timestamp = self.latest_timestamp
|
||||
|
||||
if frame is None or timestamp is None:
|
||||
raise RuntimeError(f"{self} has not captured any frames yet.")
|
||||
|
||||
age_ms = (time.perf_counter() - timestamp) * 1e3
|
||||
if age_ms > max_age_ms:
|
||||
raise TimeoutError(
|
||||
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
|
||||
)
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnects from the camera and cleans up resources.
|
||||
@@ -538,4 +589,9 @@ class OpenCVCamera(Camera):
|
||||
self.videocapture.release()
|
||||
self.videocapture = None
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
@@ -80,6 +80,8 @@ class Reachy2Camera(Camera):
|
||||
self.config = config
|
||||
|
||||
self.color_mode = config.color_mode
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.latest_timestamp: float | None = None
|
||||
|
||||
self.cam_manager: CameraManager | None = None
|
||||
|
||||
@@ -125,12 +127,7 @@ class Reachy2Camera(Camera):
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
|
||||
This is a blocking call.
|
||||
|
||||
Args:
|
||||
color_mode (Optional[ColorMode]): If specified, overrides the default
|
||||
color mode (`self.color_mode`) for this read operation (e.g.,
|
||||
request RGB even if default is BGR).
|
||||
This method retrieves the most recent frame available in Reachy 2's low-level software.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The captured frame as a NumPy array in the format
|
||||
@@ -145,6 +142,11 @@ class Reachy2Camera(Camera):
|
||||
if self.cam_manager is None:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if color_mode is not None:
|
||||
logger.warning(
|
||||
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
frame: NDArray[Any] = np.empty((0, 0, 3), dtype=np.uint8)
|
||||
|
||||
if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"):
|
||||
@@ -165,11 +167,18 @@ class Reachy2Camera(Camera):
|
||||
raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")
|
||||
|
||||
if frame is None:
|
||||
return np.empty((0, 0, 3), dtype=np.uint8)
|
||||
raise RuntimeError(f"Internal error: No frame available for {self}.")
|
||||
|
||||
if self.config.color_mode == "rgb":
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"Invalid color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
|
||||
)
|
||||
if self.color_mode == ColorMode.RGB:
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
self.latest_frame = frame
|
||||
self.latest_timestamp = time.perf_counter()
|
||||
|
||||
read_duration_ms = (time.perf_counter() - start_time) * 1e3
|
||||
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
|
||||
|
||||
@@ -177,13 +186,7 @@ class Reachy2Camera(Camera):
|
||||
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame.
|
||||
|
||||
This method retrieves the most recent frame available in Reachy 2's low-level software.
|
||||
|
||||
Args:
|
||||
timeout_ms (float): Maximum time in milliseconds to wait for a frame
|
||||
to become available. Defaults to 200ms (0.2 seconds).
|
||||
Same as read()
|
||||
|
||||
Returns:
|
||||
np.ndarray: The latest captured frame as a NumPy array in the format
|
||||
@@ -197,12 +200,38 @@ class Reachy2Camera(Camera):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
frame = self.read()
|
||||
return self.read()
|
||||
|
||||
if frame is None:
|
||||
raise RuntimeError(f"Internal error: No frame available for {self}.")
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
return frame
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
memory buffer. The frame may be stale,
|
||||
meaning it could have been captured a while ago (hanging camera scenario e.g.).
|
||||
|
||||
Returns:
|
||||
tuple[NDArray, float]:
|
||||
- The frame image (numpy array).
|
||||
- The timestamp (time.perf_counter) when this frame was captured.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the latest frame is older than `max_age_ms`.
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.latest_frame is None or self.latest_timestamp is None:
|
||||
raise RuntimeError(f"{self} has not captured any frames yet.")
|
||||
|
||||
age_ms = (time.perf_counter() - self.latest_timestamp) * 1e3
|
||||
if age_ms > max_age_ms:
|
||||
raise TimeoutError(
|
||||
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
|
||||
)
|
||||
|
||||
return self.latest_frame
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
|
||||
@@ -72,15 +72,14 @@ class RealSenseCamera(Camera):
|
||||
camera = RealSenseCamera(config)
|
||||
camera.connect()
|
||||
|
||||
# Read 1 frame synchronously
|
||||
# Read 1 frame synchronously (blocking)
|
||||
color_image = camera.read()
|
||||
print(color_image.shape)
|
||||
|
||||
# Read 1 frame asynchronously
|
||||
# Read 1 frame asynchronously (waits for new frame with a timeout)
|
||||
async_image = camera.async_read()
|
||||
|
||||
# When done, properly disconnect the camera using
|
||||
camera.disconnect()
|
||||
# Get the latest frame immediately (no wait, returns timestamp)
|
||||
latest_image, timestamp = camera.read_latest()
|
||||
|
||||
# Example with depth capture and custom settings
|
||||
custom_config = RealSenseCameraConfig(
|
||||
@@ -133,7 +132,9 @@ class RealSenseCamera(Camera):
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.latest_color_frame: NDArray[Any] | None = None
|
||||
self.latest_depth_frame: NDArray[Any] | None = None
|
||||
self.latest_timestamp: float | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
self.rotation: int | None = get_cv2_rotation(config.rotation)
|
||||
@@ -158,6 +159,10 @@ class RealSenseCamera(Camera):
|
||||
Initializes the RealSense pipeline, configures the required streams (color
|
||||
and optionally depth), starts the pipeline, and validates the actual stream settings.
|
||||
|
||||
Args:
|
||||
warmup (bool): If True, waits at connect() time until at least one valid frame
|
||||
has been captured by the background thread. Defaults to True.
|
||||
|
||||
Raises:
|
||||
DeviceAlreadyConnectedError: If the camera is already connected.
|
||||
ValueError: If the configuration is invalid (e.g., missing serial/name, name not unique).
|
||||
@@ -181,15 +186,18 @@ class RealSenseCamera(Camera):
|
||||
) from e
|
||||
|
||||
self._configure_capture_settings()
|
||||
self._start_read_thread()
|
||||
|
||||
if warmup:
|
||||
time.sleep(
|
||||
1
|
||||
) # NOTE(Steven): RS cameras need a bit of time to warm up before the first read. If we don't wait, the first read from the warmup will raise.
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < self.warmup_s:
|
||||
self.read()
|
||||
time.sleep(0.1)
|
||||
# NOTE(Steven/Caroline): Enforcing at least one second of warmup as RS cameras need a bit of time before the first read. If we don't wait, the first read from the warmup will raise.
|
||||
self.warmup_s = max(self.warmup_s, 1)
|
||||
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < self.warmup_s:
|
||||
self.async_read(timeout_ms=self.warmup_s * 1000)
|
||||
time.sleep(0.1)
|
||||
with self.frame_lock:
|
||||
if self.latest_color_frame is None or self.use_depth and self.latest_depth_frame is None:
|
||||
raise ConnectionError(f"{self} failed to capture frames during warmup.")
|
||||
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@@ -319,9 +327,6 @@ class RealSenseCamera(Camera):
|
||||
This is a blocking call. It waits for a coherent set of frames (depth)
|
||||
from the camera hardware via the RealSense pipeline.
|
||||
|
||||
Args:
|
||||
timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The depth map as a NumPy array (height, width)
|
||||
of type `np.uint16` (raw depth values in millimeters) and rotation.
|
||||
@@ -330,44 +335,52 @@ class RealSenseCamera(Camera):
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If reading frames from the pipeline fails or frames are invalid.
|
||||
"""
|
||||
if timeout_ms:
|
||||
logger.warning(
|
||||
f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
if not self.use_depth:
|
||||
raise RuntimeError(
|
||||
f"Failed to capture depth frame '.read_depth()'. Depth stream is not enabled for {self}."
|
||||
)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
self.new_frame_event.clear()
|
||||
|
||||
_ = self.async_read(timeout_ms=10000)
|
||||
|
||||
with self.frame_lock:
|
||||
depth_map = self.latest_depth_frame
|
||||
|
||||
if depth_map is None:
|
||||
raise RuntimeError("No depth frame available. Ensure camera is streaming.")
|
||||
|
||||
return depth_map
|
||||
|
||||
def _read_from_hardware(self):
|
||||
if self.rs_pipeline is None:
|
||||
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
|
||||
|
||||
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
|
||||
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=10000)
|
||||
|
||||
if not ret or frame is None:
|
||||
raise RuntimeError(f"{self} read_depth failed (status={ret}).")
|
||||
raise RuntimeError(f"{self} read failed (status={ret}).")
|
||||
|
||||
depth_frame = frame.get_depth_frame()
|
||||
depth_map = np.asanyarray(depth_frame.get_data())
|
||||
return frame
|
||||
|
||||
depth_map_processed = self._postprocess_image(depth_map, depth_frame=True)
|
||||
|
||||
read_duration_ms = (time.perf_counter() - start_time) * 1e3
|
||||
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
|
||||
|
||||
return depth_map_processed
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> NDArray[Any]:
|
||||
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 0) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame (color) synchronously from the camera.
|
||||
|
||||
This is a blocking call. It waits for a coherent set of frames (color)
|
||||
from the camera hardware via the RealSense pipeline.
|
||||
|
||||
Args:
|
||||
timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The captured color frame as a NumPy array
|
||||
(height, width, channels), processed according to `color_mode` and rotation.
|
||||
@@ -378,39 +391,39 @@ class RealSenseCamera(Camera):
|
||||
ValueError: If an invalid `color_mode` is requested.
|
||||
"""
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if color_mode is not None:
|
||||
logger.warning(
|
||||
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
if timeout_ms:
|
||||
logger.warning(
|
||||
f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
if self.rs_pipeline is None:
|
||||
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
|
||||
self.new_frame_event.clear()
|
||||
|
||||
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
|
||||
|
||||
if not ret or frame is None:
|
||||
raise RuntimeError(f"{self} read failed (status={ret}).")
|
||||
|
||||
color_frame = frame.get_color_frame()
|
||||
color_image_raw = np.asanyarray(color_frame.get_data())
|
||||
|
||||
color_image_processed = self._postprocess_image(color_image_raw, color_mode)
|
||||
frame = self.async_read(timeout_ms=10000)
|
||||
|
||||
read_duration_ms = (time.perf_counter() - start_time) * 1e3
|
||||
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
|
||||
|
||||
return color_image_processed
|
||||
return frame
|
||||
|
||||
def _postprocess_image(
|
||||
self, image: NDArray[Any], color_mode: ColorMode | None = None, depth_frame: bool = False
|
||||
) -> NDArray[Any]:
|
||||
def _postprocess_image(self, image: NDArray[Any], depth_frame: bool = False) -> NDArray[Any]:
|
||||
"""
|
||||
Applies color conversion, dimension validation, and rotation to a raw color frame.
|
||||
|
||||
Args:
|
||||
image (np.ndarray): The raw image frame (expected RGB format from RealSense).
|
||||
color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None,
|
||||
uses the instance's default `self.color_mode`.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The processed image frame according to `self.color_mode` and `self.rotation`.
|
||||
@@ -421,9 +434,9 @@ class RealSenseCamera(Camera):
|
||||
`width` and `height`.
|
||||
"""
|
||||
|
||||
if color_mode and color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
if self.color_mode and self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"Invalid requested color mode '{color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
|
||||
f"Invalid requested color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
|
||||
)
|
||||
|
||||
if depth_frame:
|
||||
@@ -454,7 +467,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
On each iteration:
|
||||
1. Reads a color frame with 500ms timeout
|
||||
2. Stores result in latest_frame (thread-safe)
|
||||
2. Stores result in latest_frame and updates timestamp (thread-safe)
|
||||
3. Sets new_frame_event to notify listeners
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
@@ -462,25 +475,41 @@ class RealSenseCamera(Camera):
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
failure_count = 0
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
color_image = self.read(timeout_ms=500)
|
||||
frame = self._read_from_hardware()
|
||||
color_frame_raw = frame.get_color_frame()
|
||||
color_frame = np.asanyarray(color_frame_raw.get_data())
|
||||
processed_color_frame = self._postprocess_image(color_frame)
|
||||
|
||||
if self.use_depth:
|
||||
depth_frame_raw = frame.get_depth_frame()
|
||||
depth_frame = np.asanyarray(depth_frame_raw.get_data())
|
||||
processed_depth_frame = self._postprocess_image(depth_frame, depth_frame=True)
|
||||
|
||||
capture_time = time.perf_counter()
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = color_image
|
||||
self.latest_color_frame = processed_color_frame
|
||||
if self.use_depth:
|
||||
self.latest_depth_frame = processed_depth_frame
|
||||
self.latest_timestamp = capture_time
|
||||
self.new_frame_event.set()
|
||||
failure_count = 0
|
||||
|
||||
except DeviceNotConnectedError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Error reading frame in background thread for {self}: {e}")
|
||||
if failure_count <= 10:
|
||||
failure_count += 1
|
||||
logger.warning(f"Error reading frame in background thread for {self}: {e}")
|
||||
else:
|
||||
raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e
|
||||
|
||||
def _start_read_thread(self) -> None:
|
||||
"""Starts or restarts the background read thread if it's not running."""
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=0.1)
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
self._stop_read_thread()
|
||||
|
||||
self.stop_event = Event()
|
||||
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
|
||||
@@ -498,6 +527,12 @@ class RealSenseCamera(Camera):
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_color_frame = None
|
||||
self.latest_depth_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
@@ -506,6 +541,7 @@ class RealSenseCamera(Camera):
|
||||
This method retrieves the most recent color frame captured by the background
|
||||
read thread. It does not block waiting for the camera hardware directly,
|
||||
but may wait up to timeout_ms for the background thread to provide a frame.
|
||||
It is “best effort” under high FPS.
|
||||
|
||||
Args:
|
||||
timeout_ms (float): Maximum time in milliseconds to wait for a frame
|
||||
@@ -524,17 +560,16 @@ class RealSenseCamera(Camera):
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
self._start_read_thread()
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
|
||||
thread_alive = self.thread is not None and self.thread.is_alive()
|
||||
raise TimeoutError(
|
||||
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
|
||||
f"Read thread alive: {thread_alive}."
|
||||
f"Read thread alive: {self.thread.is_alive()}."
|
||||
)
|
||||
|
||||
with self.frame_lock:
|
||||
frame = self.latest_frame
|
||||
frame = self.latest_color_frame
|
||||
self.new_frame_event.clear()
|
||||
|
||||
if frame is None:
|
||||
@@ -542,6 +577,43 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent (color) frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
memory buffer. The frame may be stale,
|
||||
meaning it could have been captured a while ago (hanging camera scenario e.g.).
|
||||
|
||||
Returns:
|
||||
NDArray[Any]: The frame image (numpy array).
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the latest frame is older than `max_age_ms`.
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
with self.frame_lock:
|
||||
frame = self.latest_color_frame
|
||||
timestamp = self.latest_timestamp
|
||||
|
||||
if frame is None or timestamp is None:
|
||||
raise RuntimeError(f"{self} has not captured any frames yet.")
|
||||
|
||||
age_ms = (time.perf_counter() - timestamp) * 1e3
|
||||
if age_ms > max_age_ms:
|
||||
raise TimeoutError(
|
||||
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
|
||||
)
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnects from the camera, stops the pipeline, and cleans up resources.
|
||||
@@ -565,4 +637,10 @@ class RealSenseCamera(Camera):
|
||||
self.rs_pipeline = None
|
||||
self.rs_profile = None
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_color_frame = None
|
||||
self.latest_depth_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
@@ -45,6 +45,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ZMQCamera(Camera):
|
||||
"""
|
||||
Manages camera interactions via ZeroMQ for receiving frames from a remote server.
|
||||
|
||||
This class connects to a ZMQ Publisher, subscribes to frame topics, and decodes
|
||||
incoming JSON messages containing Base64 encoded images. It supports both
|
||||
synchronous and asynchronous frame reading patterns.
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
from lerobot.cameras.zmq import ZMQCamera, ZMQCameraConfig
|
||||
@@ -52,7 +58,16 @@ class ZMQCamera(Camera):
|
||||
config = ZMQCameraConfig(server_address="192.168.123.164", port=5555, camera_name="head_camera")
|
||||
camera = ZMQCamera(config)
|
||||
camera.connect()
|
||||
frame = camera.read()
|
||||
|
||||
# Read 1 frame synchronously (blocking)
|
||||
color_image = camera.read()
|
||||
|
||||
# Read 1 frame asynchronously (waits for new frame with a timeout)
|
||||
async_image = camera.async_read()
|
||||
|
||||
# Get the latest frame immediately (no wait, returns timestamp)
|
||||
latest_image, timestamp = camera.read_latest()
|
||||
|
||||
camera.disconnect()
|
||||
```
|
||||
"""
|
||||
@@ -68,14 +83,17 @@ class ZMQCamera(Camera):
|
||||
self.color_mode = config.color_mode
|
||||
self.timeout_ms = config.timeout_ms
|
||||
|
||||
# ZMQ Context and Socket
|
||||
self.context: zmq.Context | None = None
|
||||
self.socket: zmq.Socket | None = None
|
||||
self._connected = False
|
||||
|
||||
# Threading resources
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.latest_timestamp: float | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
def __str__(self) -> str:
|
||||
@@ -83,10 +101,16 @@ class ZMQCamera(Camera):
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Checks if the ZMQ socket is initialized and connected."""
|
||||
return self._connected and self.context is not None and self.socket is not None
|
||||
|
||||
def connect(self, warmup: bool = True) -> None:
|
||||
"""Connect to ZMQ camera server."""
|
||||
"""Connect to ZMQ camera server.
|
||||
|
||||
Args:
|
||||
warmup (bool): If True, waits for the camera to provide at least one
|
||||
valid frame before returning. Defaults to True.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
|
||||
|
||||
@@ -103,17 +127,28 @@ class ZMQCamera(Camera):
|
||||
self.socket.connect(f"tcp://{self.server_address}:{self.port}")
|
||||
self._connected = True
|
||||
|
||||
# Auto-detect resolution
|
||||
# Auto-detect resolution if not provided
|
||||
if self.width is None or self.height is None:
|
||||
h, w = self.read().shape[:2]
|
||||
# Read directly from hardware because the thread isn't running yet
|
||||
temp_frame = self._read_from_hardware()
|
||||
h, w = temp_frame.shape[:2]
|
||||
self.height = h
|
||||
self.width = w
|
||||
logger.info(f"{self} resolution: {w}x{h}")
|
||||
logger.info(f"{self} resolution detected: {w}x{h}")
|
||||
|
||||
self._start_read_thread()
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
if warmup:
|
||||
time.sleep(0.1)
|
||||
# Ensure we have captured at least one frame via the thread
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < (self.config.warmup_s): # Wait a bit more than timeout
|
||||
self.async_read(timeout_ms=self.config.warmup_s * 1000)
|
||||
time.sleep(0.1)
|
||||
|
||||
with self.frame_lock:
|
||||
if self.latest_frame is None:
|
||||
raise ConnectionError(f"{self} failed to capture frames during warmup.")
|
||||
|
||||
except Exception as e:
|
||||
self._cleanup()
|
||||
@@ -131,15 +166,14 @@ class ZMQCamera(Camera):
|
||||
|
||||
@staticmethod
|
||||
def find_cameras() -> list[dict[str, Any]]:
|
||||
"""ZMQ cameras require manual configuration (server address/port)."""
|
||||
return []
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Read a single frame from the ZMQ camera.
|
||||
Detection not implemented for ZMQ cameras. These cameras require manual configuration (server address/port).
|
||||
"""
|
||||
raise NotImplementedError("Camera detection is not implemented for ZMQ cameras.")
|
||||
|
||||
Returns:
|
||||
np.ndarray: Decoded frame (height, width, 3)
|
||||
def _read_from_hardware(self) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame directly from the ZMQ socket.
|
||||
"""
|
||||
if not self.is_connected or self.socket is None:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
@@ -147,6 +181,7 @@ class ZMQCamera(Camera):
|
||||
try:
|
||||
message = self.socket.recv_string()
|
||||
except Exception as e:
|
||||
# Check for ZMQ timeout (EAGAIN/Again) without requiring global zmq import
|
||||
if type(e).__name__ == "Again":
|
||||
raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e
|
||||
raise
|
||||
@@ -176,42 +211,117 @@ class ZMQCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
def _read_loop(self) -> None:
|
||||
while self.stop_event and not self.stop_event.is_set():
|
||||
try:
|
||||
frame = self.read()
|
||||
with self.frame_lock:
|
||||
self.latest_frame = frame
|
||||
self.new_frame_event.set()
|
||||
except DeviceNotConnectedError:
|
||||
break
|
||||
except TimeoutError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"Read error: {e}")
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
|
||||
def _start_read_thread(self) -> None:
|
||||
if self.thread and self.thread.is_alive():
|
||||
return
|
||||
self.stop_event = Event()
|
||||
self.thread = Thread(target=self._read_loop, daemon=True)
|
||||
self.thread.start()
|
||||
This is a blocking call. It waits for the next available frame from the
|
||||
camera background thread.
|
||||
|
||||
def _stop_read_thread(self) -> None:
|
||||
if self.stop_event:
|
||||
self.stop_event.set()
|
||||
if self.thread and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
Returns:
|
||||
np.ndarray: Decoded frame (height, width, 3)
|
||||
"""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if color_mode is not None:
|
||||
logger.warning(
|
||||
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
def async_read(self, timeout_ms: float = 10000) -> NDArray[Any]:
|
||||
"""Read latest frame asynchronously (non-blocking)."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if not self.thread or not self.thread.is_alive():
|
||||
self._start_read_thread()
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
self.new_frame_event.clear()
|
||||
frame = self.async_read(timeout_ms=10000)
|
||||
|
||||
read_duration_ms = (time.perf_counter() - start_time) * 1e3
|
||||
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
|
||||
|
||||
return frame
|
||||
|
||||
def _read_loop(self) -> None:
|
||||
"""
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
"""
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized.")
|
||||
|
||||
failure_count = 0
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
frame = self._read_from_hardware()
|
||||
capture_time = time.perf_counter()
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = frame
|
||||
self.latest_timestamp = capture_time
|
||||
self.new_frame_event.set()
|
||||
failure_count = 0
|
||||
|
||||
except DeviceNotConnectedError:
|
||||
break
|
||||
except (TimeoutError, Exception) as e:
|
||||
if failure_count <= 10:
|
||||
failure_count += 1
|
||||
logger.warning(f"Read error: {e}")
|
||||
else:
|
||||
raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e
|
||||
|
||||
def _start_read_thread(self) -> None:
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
self.stop_event = Event()
|
||||
self.thread = Thread(target=self._read_loop, daemon=True, name=f"{self}_read_loop")
|
||||
self.thread.start()
|
||||
time.sleep(0.1)
|
||||
|
||||
def _stop_read_thread(self) -> None:
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame asynchronously.
|
||||
|
||||
Args:
|
||||
timeout_ms (float): Maximum time in milliseconds to wait for a frame
|
||||
to become available. Defaults to 200ms.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The latest captured frame.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
TimeoutError: If no frame data becomes available within the specified timeout.
|
||||
RuntimeError: If the background thread is not running.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
|
||||
raise TimeoutError(f"{self} async_read timeout after {timeout_ms}ms")
|
||||
@@ -225,11 +335,55 @@ class ZMQCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
memory buffer. The frame may be stale,
|
||||
meaning it could have been captured a while ago (hanging camera scenario e.g.).
|
||||
|
||||
Returns:
|
||||
NDArray[Any]: The frame image (numpy array).
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the latest frame is older than `max_age_ms`.
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
with self.frame_lock:
|
||||
frame = self.latest_frame
|
||||
timestamp = self.latest_timestamp
|
||||
|
||||
if frame is None or timestamp is None:
|
||||
raise RuntimeError(f"{self} has not captured any frames yet.")
|
||||
|
||||
age_ms = (time.perf_counter() - timestamp) * 1e3
|
||||
if age_ms > max_age_ms:
|
||||
raise TimeoutError(
|
||||
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
|
||||
)
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from ZMQ camera."""
|
||||
if not self.is_connected and not self.thread:
|
||||
if not self.is_connected and self.thread is None:
|
||||
raise DeviceNotConnectedError(f"{self} not connected.")
|
||||
|
||||
self._stop_read_thread()
|
||||
if self.thread is not None:
|
||||
self._stop_read_thread()
|
||||
|
||||
self._cleanup()
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
@@ -29,6 +29,7 @@ class ZMQCameraConfig(CameraConfig):
|
||||
camera_name: str = "zmq_camera"
|
||||
color_mode: ColorMode = ColorMode.RGB
|
||||
timeout_ms: int = 5000
|
||||
warmup_s: int = 1
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
|
||||
@@ -45,12 +45,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
Args:
|
||||
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
||||
current step and additional steps going back).
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
||||
input_normalization_modes: A dictionary with key representing the modality and the value specifies the
|
||||
normalization mode to apply.
|
||||
output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to
|
||||
the original scale.
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
"""
|
||||
|
||||
n_obs_steps: int = 1
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Example script to run synthetic data generation with Qwen VLM
|
||||
# This generates user prompts and robot utterances for hierarchical policy training
|
||||
|
||||
# Configuration
|
||||
REPO_ID="lerobot/libero_10"
|
||||
MODEL="Qwen/Qwen3-VL-30B-A3B-Instruct"
|
||||
# or: MODEL="Qwen/Qwen2-VL-7B-Instruct"
|
||||
|
||||
|
||||
OUTPUT_DIR="/fsx/jade_choghari/outputs/libero-10-annotate-high"
|
||||
|
||||
BATCH_SIZE=16
|
||||
TEMPERATURE=0.9
|
||||
SAMPLE_INTERVAL=5.0 # generate dialogue every 1 second (all episodes processed)
|
||||
|
||||
# Run subtask annotation
|
||||
# python /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05_full/annotate/subtask_annotate.py \
|
||||
# --repo-id "$REPO_ID" \
|
||||
# --video-key observation.images.image \
|
||||
# --output-dir "$OUTPUT_DIR" \
|
||||
# --skip-existing \
|
||||
# --output-repo-id "jadechoghari/libero10-annotate" \
|
||||
# --batch-size "$BATCH_SIZE" \
|
||||
# run synthetic data generation (all episodes processed)
|
||||
# python examples/dataset/annotate_pgen.py \
|
||||
# --repo-id "$REPO_ID" \
|
||||
# --model "$MODEL" \
|
||||
# --output-dir "$OUTPUT_DIR" \
|
||||
# --temperature "$TEMPERATURE" \
|
||||
# --batch-size "$BATCH_SIZE" \
|
||||
# --sample-interval "$SAMPLE_INTERVAL" \
|
||||
# --image-key observation.images.base \
|
||||
# --num-image-views-per-sample 1
|
||||
|
||||
# for faster testing, increase sample interval:
|
||||
# --sample-interval 5.0 # Samples every 5 seconds (much faster)
|
||||
|
||||
# to push to hub after generation:
|
||||
# add --push-to-hub flag
|
||||
|
||||
# efficient batch processing: 4 episodes at once
|
||||
python src/lerobot/data_processing/annotations/high_level_annotate.py \
|
||||
--data-dir "/fsx/jade_choghari/outputs/libero-10-annotate" \
|
||||
--output-dir "$OUTPUT_DIR" \
|
||||
--video-mode \
|
||||
--video-key observation.images.image \
|
||||
--video-batch-size "$BATCH_SIZE" \
|
||||
--sample-interval 5.0
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,52 @@
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
import lerobot
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
# /fsx/jade_choghari/data/libero_10_subtasks_kw_converted
|
||||
dataset = LeRobotDataset(repo_id="lerobot/libero_10_image_subtask")
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
batch_size=2,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained(
|
||||
pretrained_name_or_path="/fsx/jade_choghari/models/pi05-base",
|
||||
)
|
||||
cfg.dtype = "bfloat16"
|
||||
|
||||
pre_processor, post_processor = make_pre_post_processors(
|
||||
policy_cfg=cfg,
|
||||
pretrained_path="/fsx/jade_choghari/models/pi05-base",
|
||||
)
|
||||
batch = next(iter(dataloader))
|
||||
breakpoint()
|
||||
batch1 = pre_processor(batch)
|
||||
breakpoint()
|
||||
print(batch.keys())
|
||||
# print(batch['task_index_high_level'].shape)
|
||||
# print(batch['task_index_high_level'])
|
||||
# print(batch['user_prompt'][0])
|
||||
# print(batch['robot_utterance'][0])
|
||||
# print(batch['task'][0])
|
||||
|
||||
valid_episode_list = []
|
||||
for episode_idx in range(len(dataset.meta.episodes)):
|
||||
subtask_index = dataset[episode_idx]["subtask_index"]
|
||||
valid_episode_list.append(episode_idx)
|
||||
|
||||
print(len(valid_episode_list))
|
||||
|
||||
# read this parquet /fsx/jade_choghari/outputs/pgen_annotations1/meta/tasks.parquett
|
||||
# import pandas as pd
|
||||
# tasks_df = pd.read_parquet('/fsx/jade_choghari/outputs/pgen_annotations1/meta/tasks.parquet')
|
||||
|
||||
# # print all
|
||||
# print(tasks_df.columns)
|
||||
# breakpoint()
|
||||
@@ -0,0 +1,74 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Example script to run synthetic data generation with Qwen VLM
|
||||
# This generates user prompts and robot utterances for hierarchical policy training
|
||||
|
||||
# Configuration
|
||||
REPO_ID="jadechoghari/piper-demo-20260205_103303"
|
||||
# MODEL="Qwen/Qwen3-VL-30B-A3B-Thinking"
|
||||
MODEL="Qwen/Qwen3.5-27B"
|
||||
# or: MODEL="Qwen/Qwen2-VL-7B-Instruct"
|
||||
|
||||
|
||||
OUTPUT_DIR="/fsx/jade_choghari/outputs/collect-data-pgen_new"
|
||||
|
||||
BATCH_SIZE=2
|
||||
TEMPERATURE=0.9
|
||||
SAMPLE_INTERVAL=5.0 # generate dialogue every 1 second (all episodes processed)
|
||||
|
||||
# Run subtask annotation.
|
||||
# To use closed-vocabulary labels, add a line: --subtask-labels "label1" "label2" ...
|
||||
# Example (add backslash after "$MODEL" and uncomment the next line):
|
||||
# --model "$MODEL" \
|
||||
# --subtask-labels "pick_up_yellow_nut_bar" "pick_up_cake" "pick_up_biscuit_pack" "pick_up_soda_can"
|
||||
python /home/lerobot/src/lerobot/data_processing/annotations/subtask_annotate.py \
|
||||
--repo-id "$REPO_ID" \
|
||||
--video-key observation.images.top \
|
||||
--output-dir "$OUTPUT_DIR" \
|
||||
--output-repo-id "jadechoghari/piper-demo-annotated1" \
|
||||
--push-to-hub \
|
||||
--no-timer-overlay \
|
||||
--model "$MODEL" \
|
||||
--subtask-labels "pick_up_yellow_nut_bar" "pick_up_cake" "pick_up_biscuit_pack" "pick_up_soda_can" \
|
||||
--batch-size 2
|
||||
|
||||
# Run subtask annotation (image-window: frames as images for better accuracy)
|
||||
# python /admin/home/jade_choghari/lerobot/src/lerobot/data_processing/annotations/subtask_annotate_image.py \
|
||||
# --repo-id "$REPO_ID" \
|
||||
# --camera-key observation.images.wrist \
|
||||
# --output-dir "$OUTPUT_DIR" \
|
||||
# --output-repo-id "jadechoghari/piper-demo-annotated1-image" \
|
||||
# --push-to-hub \
|
||||
# --model "$MODEL" \
|
||||
# --window-size 184 \
|
||||
# --max-frames-per-window 16 \
|
||||
# --subtask-labels "pick_up_yellow_nut_bar" "pick_up_cake" "pick_up_biscuit_pack" "pick_up_soda_can" \
|
||||
# --batch-size 2
|
||||
|
||||
|
||||
# run synthetic data generation (all episodes processed)
|
||||
# python examples/dataset/annotate_pgen.py \
|
||||
# --repo-id "$REPO_ID" \
|
||||
# --model "$MODEL" \
|
||||
# --output-dir "$OUTPUT_DIR" \
|
||||
# --temperature "$TEMPERATURE" \
|
||||
# --batch-size "$BATCH_SIZE" \
|
||||
# --sample-interval "$SAMPLE_INTERVAL" \
|
||||
# --image-key observation.images.base \
|
||||
# --num-image-views-per-sample 1
|
||||
|
||||
# for faster testing, increase sample interval:
|
||||
# --sample-interval 5.0 # Samples every 5 seconds (much faster)
|
||||
|
||||
# to push to hub after generation:
|
||||
# add --push-to-hub flag
|
||||
|
||||
# efficient batch processing: 4 episodes at once
|
||||
# python examples/dataset/annotate_pgen.py \
|
||||
# --repo-id "$REPO_ID" \
|
||||
# --model "$MODEL" \
|
||||
# --output-dir "$OUTPUT_DIR" \
|
||||
# --video-mode \
|
||||
# --video-key observation.images.up \
|
||||
# --video-batch-size "$BATCH_SIZE" \
|
||||
# --sample-interval 1.0
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,561 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Image-window subtask annotation for LeRobot datasets using Qwen VLMs.
|
||||
|
||||
This script assigns a subtask to each window of consecutive frames by sending
|
||||
those frames as images to the VLM (instead of a video) for better accuracy.
|
||||
Supports Qwen2-VL and Qwen3-VL (same models as subtask_annotate.py).
|
||||
|
||||
Pipeline:
|
||||
1. Load a LeRobot dataset (local or Hub).
|
||||
2. For each episode, slide a window over frame indices.
|
||||
3. For each window, load the corresponding images (from image_key or decoded video_key).
|
||||
4. Send the window of images to Qwen2-VL with the same skill prompt; get one subtask name.
|
||||
5. Assign that subtask to all frames in the window.
|
||||
6. Write subtasks.parquet and add subtask_index via add_features (same as subtask_annotate).
|
||||
|
||||
Usage:
|
||||
python -m lerobot.data_processing.annotations.subtask_annotate_image \\
|
||||
--data-dir /path/to/dataset --camera-key observation.images.base \\
|
||||
--window-size 8 --stride 8 --output-dir ./output
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import random
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from rich.console import Console
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Reuse data structures and save/load from the video-based annotator
|
||||
from lerobot.data_processing.annotations.subtask_annotate import (
|
||||
EpisodeSkills,
|
||||
Skill,
|
||||
load_skill_annotations,
|
||||
save_skill_annotations,
|
||||
)
|
||||
|
||||
|
||||
def create_window_skill_prompt(
|
||||
coarse_goal: str | None = None,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> str:
|
||||
"""Prompt for labeling a single window of frames with one atomic skill.
|
||||
If subtask_labels are provided, the model must choose exactly one from that list.
|
||||
"""
|
||||
goal_context = f'The overall goal is: "{coarse_goal}".\n\n' if coarse_goal else ""
|
||||
if subtask_labels:
|
||||
labels_list = ", ".join(f'"{l}"' for l in subtask_labels)
|
||||
label_instruction = (
|
||||
f"You must choose exactly ONE skill from this list: [{labels_list}]. "
|
||||
"Do not create new labels. Reply with only that label.\n\n"
|
||||
)
|
||||
else:
|
||||
label_instruction = ""
|
||||
return textwrap.dedent(f"""\
|
||||
# Role
|
||||
You are a Robotics Vision System that labels short clips from robot manipulation demonstrations.
|
||||
|
||||
# Task
|
||||
{goal_context}{label_instruction}The following images are consecutive frames from a single short clip of a robot demonstration.
|
||||
What single atomic manipulation skill is being performed in this clip?
|
||||
|
||||
# Requirements
|
||||
- Reply with ONLY one short skill name (e.g. "pick up object", "move arm left", "release gripper").
|
||||
- No explanation, no timestamps, no JSON. Just the skill name.
|
||||
""").strip()
|
||||
|
||||
|
||||
def _run_image_segmenter(
|
||||
self,
|
||||
images: list[PIL.Image.Image],
|
||||
coarse_goal: str | None,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> str:
|
||||
"""Shared inference for Qwen2-VL and Qwen3-VL image window labeling."""
|
||||
prompt = create_window_skill_prompt(coarse_goal, subtask_labels)
|
||||
content = []
|
||||
for img in images:
|
||||
content.append({"type": "image", "image": img})
|
||||
content.append({"type": "text", "text": "What single atomic skill is shown in these frames? Reply with only the skill name."})
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": [{"type": "text", "text": prompt}]},
|
||||
{"role": "user", "content": content},
|
||||
]
|
||||
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
image_inputs, video_inputs = self.process_vision_info(messages)
|
||||
inputs = self.processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
generated_ids = self.model.generate(**inputs, max_new_tokens=128, do_sample=False)
|
||||
|
||||
response = self.processor.batch_decode(
|
||||
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids)],
|
||||
skip_special_tokens=True,
|
||||
)[0].strip()
|
||||
skill_name = response.split("\n")[0].strip().strip('."')
|
||||
return skill_name if skill_name else "unknown"
|
||||
|
||||
|
||||
def _run_image_segmenter_batch(
|
||||
self,
|
||||
batch_images: list[list[PIL.Image.Image]],
|
||||
coarse_goal: str | None,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""Run VLM on multiple windows at once; returns one skill name per window."""
|
||||
if not batch_images:
|
||||
return []
|
||||
prompt = create_window_skill_prompt(coarse_goal, subtask_labels)
|
||||
all_texts = []
|
||||
all_image_inputs = []
|
||||
all_video_inputs = []
|
||||
for images in batch_images:
|
||||
content = []
|
||||
for img in images:
|
||||
content.append({"type": "image", "image": img})
|
||||
content.append({"type": "text", "text": "What single atomic skill is shown in these frames? Reply with only the skill name."})
|
||||
messages = [
|
||||
{"role": "system", "content": [{"type": "text", "text": prompt}]},
|
||||
{"role": "user", "content": content},
|
||||
]
|
||||
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
image_inputs, video_inputs = self.process_vision_info(messages)
|
||||
all_texts.append(text)
|
||||
if image_inputs is not None:
|
||||
all_image_inputs.extend(image_inputs if isinstance(image_inputs, list) else [image_inputs])
|
||||
if video_inputs is not None:
|
||||
all_video_inputs.extend(video_inputs if isinstance(video_inputs, list) else [video_inputs])
|
||||
inputs = self.processor(
|
||||
text=all_texts,
|
||||
images=all_image_inputs if all_image_inputs else None,
|
||||
videos=all_video_inputs if all_video_inputs else None,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
with torch.no_grad():
|
||||
generated_ids = self.model.generate(**inputs, max_new_tokens=128, do_sample=False)
|
||||
responses = self.processor.batch_decode(
|
||||
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids)],
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
return [
|
||||
(r.split("\n")[0].strip().strip('."') or "unknown")
|
||||
for r in responses
|
||||
]
|
||||
|
||||
|
||||
class Qwen2VLImageSegmenter:
|
||||
"""Uses Qwen2-VL to assign one skill name to a window of images (same model as subtask_annotate)."""
|
||||
|
||||
def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
|
||||
from qwen_vl_utils import process_vision_info
|
||||
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
|
||||
|
||||
self.console = Console()
|
||||
self.device = device
|
||||
self.process_vision_info = process_vision_info
|
||||
self.console.print(f"[cyan]Loading Qwen2-VL for image-window labeling: {model_name}...[/cyan]")
|
||||
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
|
||||
)
|
||||
self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||
self.console.print(f"[green]✓ Model loaded on {device}[/green]")
|
||||
|
||||
def segment_skill_from_images(
|
||||
self,
|
||||
images: list[PIL.Image.Image],
|
||||
coarse_goal: str | None = None,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> str:
|
||||
"""Return a single skill name for the given window of images."""
|
||||
return _run_image_segmenter(self, images, coarse_goal, subtask_labels)
|
||||
|
||||
def segment_skill_from_images_batch(
|
||||
self,
|
||||
batch_images: list[list[PIL.Image.Image]],
|
||||
coarse_goal: str | None = None,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""Return one skill name per window; processes multiple windows in one forward pass."""
|
||||
return _run_image_segmenter_batch(self, batch_images, coarse_goal, subtask_labels)
|
||||
|
||||
|
||||
class Qwen3VLImageSegmenter:
|
||||
"""Uses Qwen3-VL (MoE) to assign one skill name to a window of images."""
|
||||
|
||||
def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
|
||||
from qwen_vl_utils import process_vision_info
|
||||
from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration
|
||||
|
||||
self.console = Console()
|
||||
self.device = device
|
||||
self.process_vision_info = process_vision_info
|
||||
self.console.print(f"[cyan]Loading Qwen3-VL for image-window labeling: {model_name}...[/cyan]")
|
||||
self.model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
|
||||
model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
|
||||
)
|
||||
self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||
self.console.print(f"[green]✓ Model loaded on {device}[/green]")
|
||||
|
||||
def segment_skill_from_images(
|
||||
self,
|
||||
images: list[PIL.Image.Image],
|
||||
coarse_goal: str | None = None,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> str:
|
||||
"""Return a single skill name for the given window of images."""
|
||||
return _run_image_segmenter(self, images, coarse_goal, subtask_labels)
|
||||
|
||||
def segment_skill_from_images_batch(
|
||||
self,
|
||||
batch_images: list[list[PIL.Image.Image]],
|
||||
coarse_goal: str | None = None,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""Return one skill name per window; processes multiple windows in one forward pass."""
|
||||
return _run_image_segmenter_batch(self, batch_images, coarse_goal, subtask_labels)
|
||||
|
||||
|
||||
def get_image_segmenter(
|
||||
model_name: str,
|
||||
device: str = "cuda",
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
):
|
||||
"""Return the appropriate image-window segmenter for the model (Qwen2-VL or Qwen3-VL)."""
|
||||
model_lower = model_name.lower()
|
||||
if "qwen3" in model_lower:
|
||||
return Qwen3VLImageSegmenter(model_name, device, torch_dtype)
|
||||
return Qwen2VLImageSegmenter(model_name, device, torch_dtype)
|
||||
|
||||
|
||||
def frame_to_pil(frame_value) -> PIL.Image.Image:
|
||||
"""Convert a single frame from dataset (tensor or PIL or path) to PIL.Image."""
|
||||
if isinstance(frame_value, PIL.Image.Image):
|
||||
return frame_value
|
||||
if isinstance(frame_value, (str, Path)):
|
||||
return PIL.Image.open(frame_value).convert("RGB")
|
||||
if hasattr(frame_value, "numpy"):
|
||||
arr = frame_value.numpy()
|
||||
else:
|
||||
arr = np.asarray(frame_value)
|
||||
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4):
|
||||
arr = np.transpose(arr, (1, 2, 0))
|
||||
if arr.dtype == np.float32 or arr.dtype == np.float64:
|
||||
arr = (np.clip(arr, 0, 1) * 255).astype(np.uint8)
|
||||
elif arr.dtype != np.uint8:
|
||||
arr = np.clip(arr, 0, 255).astype(np.uint8)
|
||||
if arr.shape[-1] == 1:
|
||||
arr = np.repeat(arr, 3, axis=-1)
|
||||
return PIL.Image.fromarray(arr)
|
||||
|
||||
|
||||
def _sample_window_indices(window_length: int, max_frames: int) -> list[int]:
|
||||
"""Return indices into a window of length window_length, at most max_frames, in order.
|
||||
If window_length <= max_frames, returns range(window_length).
|
||||
Otherwise returns sorted random sample of max_frames indices (temporal order preserved).
|
||||
"""
|
||||
if max_frames <= 0 or window_length <= max_frames:
|
||||
return list(range(window_length))
|
||||
return sorted(random.sample(range(window_length), max_frames))
|
||||
|
||||
|
||||
class SkillAnnotatorImage:
|
||||
"""Annotates episodes by sliding a window over frames and labeling each window with the VLM."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
segmenter: Qwen2VLImageSegmenter | Qwen3VLImageSegmenter,
|
||||
window_size: int = 8,
|
||||
stride: int | None = None,
|
||||
batch_size: int = 1,
|
||||
max_frames_per_window: int | None = None,
|
||||
console: Console | None = None,
|
||||
):
|
||||
self.segmenter = segmenter
|
||||
self.window_size = window_size
|
||||
self.stride = stride if stride is not None else window_size
|
||||
self.batch_size = max(1, batch_size)
|
||||
self.max_frames_per_window = max_frames_per_window
|
||||
self.console = console or Console()
|
||||
|
||||
def annotate_dataset(
|
||||
self,
|
||||
dataset: LeRobotDataset,
|
||||
camera_key: str,
|
||||
episodes: list[int] | None = None,
|
||||
skip_existing: bool = False,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> dict[int, EpisodeSkills]:
|
||||
"""Annotate episodes using image windows. camera_key can be an image_key or video_key."""
|
||||
episode_indices = episodes or list(range(dataset.meta.total_episodes))
|
||||
coarse_goal = self._get_coarse_goal(dataset)
|
||||
annotations: dict[int, EpisodeSkills] = {}
|
||||
|
||||
if skip_existing:
|
||||
existing = load_skill_annotations(dataset.root)
|
||||
if existing and existing.get("episodes"):
|
||||
existing_eps = {int(k) for k in existing["episodes"] if existing["episodes"][k].get("skills")}
|
||||
episode_indices = [i for i in episode_indices if i not in existing_eps]
|
||||
|
||||
for ep_idx in episode_indices:
|
||||
try:
|
||||
skills = self._annotate_episode(
|
||||
dataset, ep_idx, camera_key, coarse_goal, subtask_labels
|
||||
)
|
||||
if skills:
|
||||
annotations[ep_idx] = EpisodeSkills(
|
||||
episode_index=ep_idx,
|
||||
description=coarse_goal,
|
||||
skills=skills,
|
||||
)
|
||||
self.console.print(f"[green]✓ Episode {ep_idx}: {len(skills)} window skills[/green]")
|
||||
else:
|
||||
self.console.print(f"[yellow]⚠ Episode {ep_idx}: no skills[/yellow]")
|
||||
except Exception as e:
|
||||
self.console.print(f"[red]Episode {ep_idx} failed: {e}[/red]")
|
||||
|
||||
return annotations
|
||||
|
||||
def _get_coarse_goal(self, dataset: LeRobotDataset) -> str:
|
||||
if dataset.meta.tasks is not None and len(dataset.meta.tasks) > 0:
|
||||
return str(dataset.meta.tasks.index[0])
|
||||
return "Perform the demonstrated manipulation task."
|
||||
|
||||
def _annotate_episode(
|
||||
self,
|
||||
dataset: LeRobotDataset,
|
||||
episode_index: int,
|
||||
camera_key: str,
|
||||
coarse_goal: str,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> list[Skill]:
|
||||
ep = dataset.meta.episodes[episode_index]
|
||||
ep_from = int(ep["dataset_from_index"])
|
||||
ep_to = int(ep["dataset_to_index"])
|
||||
length = ep_to - ep_from
|
||||
fps = dataset.meta.fps
|
||||
if length == 0:
|
||||
return []
|
||||
|
||||
# Collect full windows: (images, t_start, t_end) using frame timestamps.
|
||||
# If max_frames_per_window is set and window is larger, sample that many frames (order preserved).
|
||||
window_specs: list[tuple[list[PIL.Image.Image], float, float]] = []
|
||||
start = 0
|
||||
while start + self.window_size <= length:
|
||||
offsets = _sample_window_indices(
|
||||
self.window_size,
|
||||
self.max_frames_per_window or self.window_size,
|
||||
)
|
||||
frame_indices = [ep_from + start + i for i in offsets]
|
||||
images = []
|
||||
t_start = float(dataset[frame_indices[0]]["timestamp"].item())
|
||||
for idx in frame_indices:
|
||||
item = dataset[idx]
|
||||
images.append(frame_to_pil(item[camera_key]))
|
||||
t_end = t_start + self.window_size / fps
|
||||
window_specs.append((images, t_start, t_end))
|
||||
start += self.stride
|
||||
|
||||
# Last partial window
|
||||
if start < length:
|
||||
partial_len = ep_to - (ep_from + start)
|
||||
offsets = _sample_window_indices(
|
||||
partial_len,
|
||||
self.max_frames_per_window or partial_len,
|
||||
)
|
||||
frame_indices = [ep_from + start + i for i in offsets]
|
||||
images = []
|
||||
t_start = float(dataset[frame_indices[0]]["timestamp"].item())
|
||||
for idx in frame_indices:
|
||||
item = dataset[idx]
|
||||
images.append(frame_to_pil(item[camera_key]))
|
||||
t_end = float(dataset[frame_indices[-1]]["timestamp"].item()) + 1.0 / fps
|
||||
window_specs.append((images, t_start, t_end))
|
||||
|
||||
# Run in batches
|
||||
skills: list[Skill] = []
|
||||
for i in range(0, len(window_specs), self.batch_size):
|
||||
chunk = window_specs[i : i + self.batch_size]
|
||||
batch_images = [spec[0] for spec in chunk]
|
||||
if len(batch_images) > 1:
|
||||
skill_names = self.segmenter.segment_skill_from_images_batch(
|
||||
batch_images, coarse_goal, subtask_labels
|
||||
)
|
||||
else:
|
||||
skill_names = [
|
||||
self.segmenter.segment_skill_from_images(
|
||||
batch_images[0], coarse_goal, subtask_labels
|
||||
)
|
||||
]
|
||||
for (_, t_start, t_end), name in zip(chunk, skill_names, strict=True):
|
||||
skills.append(Skill(name=name, start=t_start, end=t_end))
|
||||
|
||||
return skills
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Image-window subtask annotation using Qwen VLM (frames as images for better accuracy)",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog=textwrap.dedent("""\
|
||||
Examples:
|
||||
python -m lerobot.data_processing.annotations.subtask_annotate_image \\
|
||||
--data-dir /path/to/dataset --camera-key observation.images.base \\
|
||||
--window-size 8 --output-dir ./output
|
||||
|
||||
python -m lerobot.data_processing.annotations.subtask_annotate_image \\
|
||||
--repo-id user/dataset --camera-key observation.images.base \\
|
||||
--window-size 6 --stride 3 --model Qwen/Qwen2-VL-7B-Instruct
|
||||
|
||||
# Use Qwen3-VL (MoE)
|
||||
python -m lerobot.data_processing.annotations.subtask_annotate_image \\
|
||||
--data-dir /path/to/dataset --camera-key observation.images.base \\
|
||||
--model Qwen/Qwen3-VL-30B-A3B-Instruct
|
||||
"""),
|
||||
)
|
||||
data_group = parser.add_mutually_exclusive_group(required=True)
|
||||
data_group.add_argument("--data-dir", type=str, help="Path to local LeRobot dataset")
|
||||
data_group.add_argument("--repo-id", type=str, help="HuggingFace Hub dataset repository ID")
|
||||
|
||||
parser.add_argument(
|
||||
"--camera-key",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Image or video observation key (e.g. observation.images.base)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="Qwen/Qwen2-VL-7B-Instruct",
|
||||
help="VLM model: Qwen2-VL or Qwen3-VL (default: Qwen/Qwen2-VL-7B-Instruct)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--window-size",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Number of frames per window (default: 8)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stride",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Stride for sliding window (default: window_size = non-overlapping)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of windows to process in one VLM call (default: 1; increase for speed)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-frames-per-window",
|
||||
type=int,
|
||||
default=None,
|
||||
metavar="N",
|
||||
help="If window has more than N frames, randomly sample N frames (order kept) to avoid OOM (e.g. 16)",
|
||||
)
|
||||
parser.add_argument("--episodes", type=int, nargs="+", help="Episode indices to annotate (default: all)")
|
||||
parser.add_argument("--skip-existing", action="store_true", help="Skip episodes that already have annotations")
|
||||
parser.add_argument(
|
||||
"--subtask-labels",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=None,
|
||||
help="Closed vocabulary: model must choose only from these labels",
|
||||
)
|
||||
parser.add_argument("--output-dir", type=str, help="Output directory for dataset with subtask_index")
|
||||
parser.add_argument("--output-repo-id", type=str, help="Output repo id (default: <repo_id>_with_subtasks)")
|
||||
parser.add_argument("--push-to-hub", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
console = Console()
|
||||
|
||||
# Load dataset
|
||||
console.print("[cyan]Loading dataset...[/cyan]")
|
||||
if args.data_dir:
|
||||
dataset = LeRobotDataset(repo_id="local/dataset", root=args.data_dir, download_videos=False)
|
||||
else:
|
||||
dataset = LeRobotDataset(repo_id=args.repo_id, download_videos=True)
|
||||
camera_keys = dataset.meta.camera_keys
|
||||
if args.camera_key not in camera_keys:
|
||||
console.print(f"[red]Error: camera key '{args.camera_key}' not in {camera_keys}[/red]")
|
||||
return
|
||||
console.print(f"[green]✓ Loaded dataset, {dataset.meta.total_episodes} episodes[/green]")
|
||||
|
||||
# Same Qwen VLM as subtask_annotate (Qwen2-VL or Qwen3-VL), image windows instead of video
|
||||
segmenter = get_image_segmenter(args.model, args.device, torch.bfloat16)
|
||||
|
||||
annotator = SkillAnnotatorImage(
|
||||
segmenter=segmenter,
|
||||
window_size=args.window_size,
|
||||
stride=args.stride,
|
||||
batch_size=args.batch_size,
|
||||
max_frames_per_window=args.max_frames_per_window,
|
||||
console=console,
|
||||
)
|
||||
annotations = annotator.annotate_dataset(
|
||||
dataset=dataset,
|
||||
camera_key=args.camera_key,
|
||||
episodes=args.episodes,
|
||||
skip_existing=args.skip_existing,
|
||||
subtask_labels=args.subtask_labels,
|
||||
)
|
||||
|
||||
if not annotations:
|
||||
console.print("[yellow]No annotations to save.[/yellow]")
|
||||
return
|
||||
|
||||
output_dir = Path(args.output_dir) if args.output_dir else None
|
||||
output_repo_id = args.output_repo_id
|
||||
new_dataset = save_skill_annotations(dataset, annotations, output_dir, output_repo_id)
|
||||
|
||||
total_skills = sum(len(a.skills) for a in annotations.values())
|
||||
console.print(f"[bold green]✓ Done.[/bold green] Episodes: {len(annotations)}, total window skills: {total_skills}")
|
||||
console.print(f" Dataset with subtask_index: {new_dataset.root}")
|
||||
|
||||
if args.push_to_hub and not args.data_dir:
|
||||
console.print("[cyan]Pushing to Hub...[/cyan]")
|
||||
try:
|
||||
new_dataset.push_to_hub(push_videos=False)
|
||||
console.print("[green]✓ Pushed.[/green]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]Push failed: {e}[/red]")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -19,6 +19,7 @@ import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import pandas as pd
|
||||
import tqdm
|
||||
|
||||
@@ -32,6 +33,7 @@ from lerobot.datasets.utils import (
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
get_file_size_in_mb,
|
||||
get_hf_features_from_features,
|
||||
get_parquet_file_size_in_mb,
|
||||
to_parquet_with_hf_images,
|
||||
update_chunk_file_indices,
|
||||
@@ -114,6 +116,9 @@ def update_meta_data(
|
||||
Adjusts all indices and timestamps to account for previously aggregated
|
||||
data and videos in the destination dataset.
|
||||
|
||||
For data file indices, uses the 'src_to_dst' mapping from aggregate_data()
|
||||
to correctly map source file indices to their destination locations.
|
||||
|
||||
Args:
|
||||
df: DataFrame containing the metadata to be updated.
|
||||
dst_meta: Destination dataset metadata.
|
||||
@@ -127,8 +132,50 @@ def update_meta_data(
|
||||
|
||||
df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"]
|
||||
df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"]
|
||||
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
|
||||
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
|
||||
|
||||
# Update data file indices using source-to-destination mapping
|
||||
# This is critical for handling datasets that are already results of a merge
|
||||
data_src_to_dst = data_idx.get("src_to_dst", {})
|
||||
if data_src_to_dst:
|
||||
# Store original indices for lookup
|
||||
df["_orig_data_chunk"] = df["data/chunk_index"].copy()
|
||||
df["_orig_data_file"] = df["data/file_index"].copy()
|
||||
|
||||
# Vectorized mapping from (src_chunk, src_file) to (dst_chunk, dst_file)
|
||||
# This is much faster than per-row iteration for large metadata tables
|
||||
mapping_index = pd.MultiIndex.from_tuples(
|
||||
list(data_src_to_dst.keys()),
|
||||
names=["chunk_index", "file_index"],
|
||||
)
|
||||
mapping_values = list(data_src_to_dst.values())
|
||||
mapping_df = pd.DataFrame(
|
||||
mapping_values,
|
||||
index=mapping_index,
|
||||
columns=["dst_chunk", "dst_file"],
|
||||
)
|
||||
|
||||
# Construct a MultiIndex for each row based on original data indices
|
||||
row_index = pd.MultiIndex.from_arrays(
|
||||
[df["_orig_data_chunk"], df["_orig_data_file"]],
|
||||
names=["chunk_index", "file_index"],
|
||||
)
|
||||
|
||||
# Align mapping to rows; missing keys fall back to the default destination
|
||||
reindexed = mapping_df.reindex(row_index)
|
||||
reindexed[["dst_chunk", "dst_file"]] = reindexed[["dst_chunk", "dst_file"]].fillna(
|
||||
{"dst_chunk": data_idx["chunk"], "dst_file": data_idx["file"]}
|
||||
)
|
||||
|
||||
# Assign mapped destination indices back to the DataFrame
|
||||
df["data/chunk_index"] = reindexed["dst_chunk"].to_numpy()
|
||||
df["data/file_index"] = reindexed["dst_file"].to_numpy()
|
||||
|
||||
# Clean up temporary columns
|
||||
df = df.drop(columns=["_orig_data_chunk", "_orig_data_file"])
|
||||
else:
|
||||
# Fallback to simple offset (backward compatibility for single-file sources)
|
||||
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
|
||||
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
|
||||
for key, video_idx in videos_idx.items():
|
||||
# Store original video file indices before updating
|
||||
orig_chunk_col = f"videos/{key}/chunk_index"
|
||||
@@ -144,8 +191,7 @@ def update_meta_data(
|
||||
if src_to_dst:
|
||||
# Map each episode to its correct destination file and apply offset
|
||||
for idx in df.index:
|
||||
# Convert to Python int to avoid numpy type mismatch in dict lookup
|
||||
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
|
||||
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
|
||||
|
||||
# Get destination chunk/file for this source file
|
||||
dst_chunk, dst_file = src_to_dst.get(src_key, (video_idx["chunk"], video_idx["file"]))
|
||||
@@ -161,8 +207,7 @@ def update_meta_data(
|
||||
df[orig_chunk_col] = video_idx["chunk"]
|
||||
df[orig_file_col] = video_idx["file"]
|
||||
for idx in df.index:
|
||||
# Convert to Python int to avoid numpy type mismatch in dict lookup
|
||||
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
|
||||
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
|
||||
offset = src_to_offset.get(src_key, 0)
|
||||
df.at[idx, f"videos/{key}/from_timestamp"] += offset
|
||||
df.at[idx, f"videos/{key}/to_timestamp"] += offset
|
||||
@@ -260,6 +305,10 @@ def aggregate_datasets(
|
||||
|
||||
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
|
||||
|
||||
# Clear the src_to_dst mapping after processing each source dataset
|
||||
# to avoid interference between different source datasets
|
||||
data_idx.pop("src_to_dst", None)
|
||||
|
||||
dst_meta.info["total_episodes"] += src_meta.total_episodes
|
||||
dst_meta.info["total_frames"] += src_meta.total_frames
|
||||
|
||||
@@ -310,10 +359,6 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
dst_file_durations = video_idx["dst_file_durations"]
|
||||
|
||||
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
|
||||
# Convert to Python int to ensure consistent dict keys
|
||||
src_chunk_idx = int(src_chunk_idx)
|
||||
src_file_idx = int(src_file_idx)
|
||||
|
||||
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
|
||||
video_key=key,
|
||||
chunk_index=src_chunk_idx,
|
||||
@@ -386,10 +431,16 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
Reads source data files, updates indices to match the aggregated dataset,
|
||||
and writes them to the destination with proper file rotation.
|
||||
|
||||
Tracks a `src_to_dst` mapping from source (chunk, file) to destination (chunk, file)
|
||||
which is critical for correctly updating episode metadata when source datasets
|
||||
have multiple data files (e.g., from a previous merge operation).
|
||||
|
||||
Args:
|
||||
src_meta: Source dataset metadata.
|
||||
dst_meta: Destination dataset metadata.
|
||||
data_idx: Dictionary tracking data chunk and file indices.
|
||||
data_files_size_in_mb: Maximum size for data files in MB.
|
||||
chunk_size: Maximum number of files per chunk.
|
||||
|
||||
Returns:
|
||||
dict: Updated data_idx with current chunk and file indices.
|
||||
@@ -402,25 +453,47 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
}
|
||||
|
||||
unique_chunk_file_ids = sorted(unique_chunk_file_ids)
|
||||
contains_images = len(dst_meta.image_keys) > 0
|
||||
|
||||
# retrieve features schema for proper image typing in parquet
|
||||
hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None
|
||||
|
||||
# Track source to destination file mapping for metadata update
|
||||
# This is critical for handling datasets that are already results of a merge
|
||||
src_to_dst: dict[tuple[int, int], tuple[int, int]] = {}
|
||||
|
||||
for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
|
||||
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
|
||||
chunk_index=src_chunk_idx, file_index=src_file_idx
|
||||
)
|
||||
df = pd.read_parquet(src_path)
|
||||
if contains_images:
|
||||
# Use HuggingFace datasets to read source data to preserve image format
|
||||
src_ds = datasets.Dataset.from_parquet(str(src_path))
|
||||
df = src_ds.to_pandas()
|
||||
else:
|
||||
df = pd.read_parquet(src_path)
|
||||
df = update_data_df(df, src_meta, dst_meta)
|
||||
|
||||
data_idx = append_or_create_parquet_file(
|
||||
# Write data and get the actual destination file it was written to
|
||||
# This avoids duplicating the rotation logic here
|
||||
data_idx, (dst_chunk, dst_file) = append_or_create_parquet_file(
|
||||
df,
|
||||
src_path,
|
||||
data_idx,
|
||||
data_files_size_in_mb,
|
||||
chunk_size,
|
||||
DEFAULT_DATA_PATH,
|
||||
contains_images=len(dst_meta.image_keys) > 0,
|
||||
contains_images=contains_images,
|
||||
aggr_root=dst_meta.root,
|
||||
hf_features=hf_features,
|
||||
)
|
||||
|
||||
# Record the mapping from source to actual destination
|
||||
src_to_dst[(src_chunk_idx, src_file_idx)] = (dst_chunk, dst_file)
|
||||
|
||||
# Add the mapping to data_idx for use in metadata update
|
||||
data_idx["src_to_dst"] = src_to_dst
|
||||
|
||||
return data_idx
|
||||
|
||||
|
||||
@@ -461,7 +534,7 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
videos_idx,
|
||||
)
|
||||
|
||||
meta_idx = append_or_create_parquet_file(
|
||||
meta_idx, _ = append_or_create_parquet_file(
|
||||
df,
|
||||
src_path,
|
||||
meta_idx,
|
||||
@@ -488,7 +561,8 @@ def append_or_create_parquet_file(
|
||||
default_path: str,
|
||||
contains_images: bool = False,
|
||||
aggr_root: Path = None,
|
||||
):
|
||||
hf_features: datasets.Features | None = None,
|
||||
) -> tuple[dict[str, int], tuple[int, int]]:
|
||||
"""Appends data to an existing parquet file or creates a new one based on size constraints.
|
||||
|
||||
Manages file rotation when size limits are exceeded to prevent individual files
|
||||
@@ -503,40 +577,49 @@ def append_or_create_parquet_file(
|
||||
default_path: Format string for generating file paths.
|
||||
contains_images: Whether the data contains images requiring special handling.
|
||||
aggr_root: Root path for the aggregated dataset.
|
||||
hf_features: Optional HuggingFace Features schema for proper image typing.
|
||||
|
||||
Returns:
|
||||
dict: Updated index dictionary with current chunk and file indices.
|
||||
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
|
||||
and (dst_chunk, dst_file) is the actual destination file the data was written to.
|
||||
"""
|
||||
dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
|
||||
dst_chunk, dst_file = idx["chunk"], idx["file"]
|
||||
dst_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
|
||||
|
||||
if not dst_path.exists():
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if contains_images:
|
||||
to_parquet_with_hf_images(df, dst_path)
|
||||
to_parquet_with_hf_images(df, dst_path, features=hf_features)
|
||||
else:
|
||||
df.to_parquet(dst_path)
|
||||
return idx
|
||||
return idx, (dst_chunk, dst_file)
|
||||
|
||||
src_size = get_parquet_file_size_in_mb(src_path)
|
||||
dst_size = get_parquet_file_size_in_mb(dst_path)
|
||||
|
||||
if dst_size + src_size >= max_mb:
|
||||
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
|
||||
new_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
|
||||
dst_chunk, dst_file = idx["chunk"], idx["file"]
|
||||
new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
final_df = df
|
||||
target_path = new_path
|
||||
else:
|
||||
existing_df = pd.read_parquet(dst_path)
|
||||
if contains_images:
|
||||
# Use HuggingFace datasets to read existing data to preserve image format
|
||||
existing_ds = datasets.Dataset.from_parquet(str(dst_path))
|
||||
existing_df = existing_ds.to_pandas()
|
||||
else:
|
||||
existing_df = pd.read_parquet(dst_path)
|
||||
final_df = pd.concat([existing_df, df], ignore_index=True)
|
||||
target_path = dst_path
|
||||
|
||||
if contains_images:
|
||||
to_parquet_with_hf_images(final_df, target_path)
|
||||
to_parquet_with_hf_images(final_df, target_path, features=hf_features)
|
||||
else:
|
||||
final_df.to_parquet(target_path)
|
||||
|
||||
return idx
|
||||
return idx, (dst_chunk, dst_file)
|
||||
|
||||
|
||||
def finalize_aggregation(aggr_meta, all_metadata):
|
||||
|
||||
@@ -26,6 +26,7 @@ This module provides utilities for:
|
||||
import logging
|
||||
import shutil
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
@@ -51,7 +52,8 @@ from lerobot.datasets.utils import (
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
from lerobot.datasets.video_utils import encode_video_frames, get_video_info
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE
|
||||
|
||||
|
||||
def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict:
|
||||
@@ -1083,3 +1085,687 @@ def _copy_episodes_metadata_and_stats(
|
||||
else:
|
||||
if src_dataset.meta.stats:
|
||||
write_stats(src_dataset.meta.stats, dst_meta.root)
|
||||
|
||||
|
||||
def _save_episode_images_for_video(
|
||||
dataset: LeRobotDataset,
|
||||
imgs_dir: Path,
|
||||
img_key: str,
|
||||
episode_index: int,
|
||||
num_workers: int = 4,
|
||||
) -> None:
|
||||
"""Save images from a specific episode and camera to disk for video encoding.
|
||||
|
||||
Args:
|
||||
dataset: The LeRobot dataset to extract images from
|
||||
imgs_dir: Directory to save images to
|
||||
img_key: The image key (camera) to extract
|
||||
episode_index: Index of the episode to save
|
||||
num_workers: Number of threads for parallel image saving
|
||||
"""
|
||||
# Create directory
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get dataset without torch format for PIL image access
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
|
||||
# Select only this camera's images
|
||||
imgs_dataset = hf_dataset.select_columns(img_key)
|
||||
|
||||
# Get episode start and end indices
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
||||
|
||||
# Get all items for this episode
|
||||
episode_dataset = imgs_dataset.select(range(from_idx, to_idx))
|
||||
|
||||
# Define function to save a single image
|
||||
def save_single_image(i_item_tuple):
|
||||
i, item = i_item_tuple
|
||||
img = item[img_key]
|
||||
# Use frame-XXXXXX.png format to match encode_video_frames expectations
|
||||
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
|
||||
return i
|
||||
|
||||
# Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png)
|
||||
items = list(enumerate(episode_dataset))
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = [executor.submit(save_single_image, item) for item in items]
|
||||
for future in as_completed(futures):
|
||||
future.result() # This will raise any exceptions that occurred
|
||||
|
||||
|
||||
def _save_batch_episodes_images(
|
||||
dataset: LeRobotDataset,
|
||||
imgs_dir: Path,
|
||||
img_key: str,
|
||||
episode_indices: list[int],
|
||||
num_workers: int = 4,
|
||||
) -> list[float]:
|
||||
"""Save images from multiple episodes to disk for batch video encoding.
|
||||
|
||||
Args:
|
||||
dataset: The LeRobot dataset to extract images from
|
||||
imgs_dir: Directory to save images to
|
||||
img_key: The image key (camera) to extract
|
||||
episode_indices: List of episode indices to save
|
||||
num_workers: Number of threads for parallel image saving
|
||||
|
||||
Returns:
|
||||
List of episode durations in seconds
|
||||
"""
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
imgs_dataset = hf_dataset.select_columns(img_key)
|
||||
|
||||
# Define function to save a single image with global frame index
|
||||
# Defined once outside the loop to avoid repeated closure creation
|
||||
def save_single_image(i_item_tuple, base_frame_idx, img_key_param):
|
||||
i, item = i_item_tuple
|
||||
img = item[img_key_param]
|
||||
# Use global frame index for naming
|
||||
img.save(str(imgs_dir / f"frame-{base_frame_idx + i:06d}.png"), quality=100)
|
||||
return i
|
||||
|
||||
episode_durations = []
|
||||
frame_idx = 0
|
||||
|
||||
for ep_idx in episode_indices:
|
||||
# Get episode range
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][ep_idx]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][ep_idx]
|
||||
episode_length = to_idx - from_idx
|
||||
episode_durations.append(episode_length / dataset.fps)
|
||||
|
||||
# Get episode images
|
||||
episode_dataset = imgs_dataset.select(range(from_idx, to_idx))
|
||||
|
||||
# Save images
|
||||
items = list(enumerate(episode_dataset))
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = [executor.submit(save_single_image, item, frame_idx, img_key) for item in items]
|
||||
for future in as_completed(futures):
|
||||
future.result()
|
||||
|
||||
frame_idx += episode_length
|
||||
|
||||
return episode_durations
|
||||
|
||||
|
||||
def _iter_episode_batches(
|
||||
episode_indices: list[int],
|
||||
episode_lengths: dict[int, int],
|
||||
size_per_frame_mb: float,
|
||||
video_file_size_limit: float,
|
||||
max_episodes: int | None,
|
||||
max_frames: int | None,
|
||||
):
|
||||
"""Generator that yields batches of episode indices for video encoding.
|
||||
|
||||
Groups episodes into batches that respect size and memory constraints:
|
||||
- Stays under video file size limit
|
||||
- Respects maximum episodes per batch (if specified)
|
||||
- Respects maximum frames per batch (if specified)
|
||||
|
||||
Args:
|
||||
episode_indices: List of episode indices to batch
|
||||
episode_lengths: Dictionary mapping episode index to episode length
|
||||
size_per_frame_mb: Estimated size per frame in MB
|
||||
video_file_size_limit: Maximum video file size in MB
|
||||
max_episodes: Maximum number of episodes per batch (None = no limit)
|
||||
max_frames: Maximum number of frames per batch (None = no limit)
|
||||
|
||||
Yields:
|
||||
List of episode indices for each batch
|
||||
"""
|
||||
batch_episodes = []
|
||||
estimated_size = 0.0
|
||||
total_frames = 0
|
||||
|
||||
for ep_idx in episode_indices:
|
||||
ep_length = episode_lengths[ep_idx]
|
||||
ep_estimated_size = ep_length * size_per_frame_mb
|
||||
|
||||
# we check if adding this episode would exceed any constraint
|
||||
would_exceed_size = estimated_size > 0 and estimated_size + ep_estimated_size >= video_file_size_limit
|
||||
would_exceed_episodes = max_episodes is not None and len(batch_episodes) >= max_episodes
|
||||
would_exceed_frames = max_frames is not None and total_frames + ep_length > max_frames
|
||||
|
||||
if batch_episodes and (would_exceed_size or would_exceed_episodes or would_exceed_frames):
|
||||
# yield current batch before adding this episode
|
||||
yield batch_episodes
|
||||
# start a new batch with current episode
|
||||
batch_episodes = [ep_idx]
|
||||
estimated_size = ep_estimated_size
|
||||
total_frames = ep_length
|
||||
else:
|
||||
# add to current batch
|
||||
batch_episodes.append(ep_idx)
|
||||
estimated_size += ep_estimated_size
|
||||
total_frames += ep_length
|
||||
|
||||
# yield final batch if not empty
|
||||
if batch_episodes:
|
||||
yield batch_episodes
|
||||
|
||||
|
||||
def _estimate_frame_size_via_calibration(
|
||||
dataset: LeRobotDataset,
|
||||
img_key: str,
|
||||
episode_indices: list[int],
|
||||
temp_dir: Path,
|
||||
fps: int,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
g: int,
|
||||
crf: int,
|
||||
fast_decode: int,
|
||||
num_calibration_frames: int = 30,
|
||||
) -> float:
|
||||
"""Estimate MB per frame by encoding a small calibration sample.
|
||||
|
||||
Encodes a representative sample of frames using the exact codec parameters
|
||||
to measure actual compression ratio, which is more accurate than heuristics.
|
||||
|
||||
Args:
|
||||
dataset: Source dataset with images.
|
||||
img_key: Image key to calibrate (e.g., "observation.images.top").
|
||||
episode_indices: List of episode indices being processed.
|
||||
temp_dir: Temporary directory for calibration files.
|
||||
fps: Frames per second for video encoding.
|
||||
vcodec: Video codec (libsvtav1, h264, hevc).
|
||||
pix_fmt: Pixel format (yuv420p, etc.).
|
||||
g: GOP size (group of pictures).
|
||||
crf: Constant Rate Factor (quality).
|
||||
fast_decode: Fast decode tuning parameter.
|
||||
num_calibration_frames: Number of frames to use for calibration (default: 30).
|
||||
|
||||
Returns:
|
||||
Estimated size in MB per frame based on actual encoding.
|
||||
"""
|
||||
calibration_dir = temp_dir / "calibration" / img_key
|
||||
calibration_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
# Select a representative episode (prefer middle episode if available)
|
||||
calibration_ep_idx = episode_indices[len(episode_indices) // 2]
|
||||
|
||||
# Get episode range
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][calibration_ep_idx]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][calibration_ep_idx]
|
||||
episode_length = to_idx - from_idx
|
||||
|
||||
# Use up to num_calibration_frames from this episode
|
||||
num_frames = min(num_calibration_frames, episode_length)
|
||||
|
||||
# Get frames from dataset
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
sample_indices = range(from_idx, from_idx + num_frames)
|
||||
|
||||
# Save calibration frames
|
||||
for i, idx in enumerate(sample_indices):
|
||||
img = hf_dataset[idx][img_key]
|
||||
img.save(str(calibration_dir / f"frame-{i:06d}.png"), quality=100)
|
||||
|
||||
# Encode calibration video
|
||||
calibration_video_path = calibration_dir / "calibration.mp4"
|
||||
encode_video_frames(
|
||||
imgs_dir=calibration_dir,
|
||||
video_path=calibration_video_path,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
# Measure actual compressed size
|
||||
video_size_bytes = calibration_video_path.stat().st_size
|
||||
video_size_mb = video_size_bytes / BYTES_PER_MIB
|
||||
size_per_frame_mb = video_size_mb / num_frames
|
||||
|
||||
logging.info(
|
||||
f" Calibration: {num_frames} frames -> {video_size_mb:.2f} MB "
|
||||
f"= {size_per_frame_mb:.4f} MB/frame for {img_key}"
|
||||
)
|
||||
|
||||
return size_per_frame_mb
|
||||
|
||||
finally:
|
||||
# Clean up calibration files
|
||||
if calibration_dir.exists():
|
||||
shutil.rmtree(calibration_dir)
|
||||
|
||||
|
||||
def _copy_data_without_images(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
episode_indices: list[int],
|
||||
img_keys: list[str],
|
||||
) -> None:
|
||||
"""Copy data files without image columns.
|
||||
|
||||
Args:
|
||||
src_dataset: Source dataset
|
||||
dst_meta: Destination metadata
|
||||
episode_indices: Episodes to include
|
||||
img_keys: Image keys to remove
|
||||
"""
|
||||
from lerobot.datasets.utils import DATA_DIR
|
||||
|
||||
data_dir = src_dataset.root / DATA_DIR
|
||||
parquet_files = sorted(data_dir.glob("*/*.parquet"))
|
||||
|
||||
if not parquet_files:
|
||||
raise ValueError(f"No parquet files found in {data_dir}")
|
||||
|
||||
episode_set = set(episode_indices)
|
||||
|
||||
for src_path in tqdm(parquet_files, desc="Processing data files"):
|
||||
df = pd.read_parquet(src_path).reset_index(drop=True)
|
||||
|
||||
# Filter to only include selected episodes
|
||||
df = df[df["episode_index"].isin(episode_set)].copy()
|
||||
|
||||
if len(df) == 0:
|
||||
continue
|
||||
|
||||
# Remove image columns
|
||||
columns_to_drop = [col for col in img_keys if col in df.columns]
|
||||
if columns_to_drop:
|
||||
df = df.drop(columns=columns_to_drop)
|
||||
|
||||
# Get chunk and file indices from path
|
||||
relative_path = src_path.relative_to(src_dataset.root)
|
||||
chunk_dir = relative_path.parts[1]
|
||||
file_name = relative_path.parts[2]
|
||||
chunk_idx = int(chunk_dir.split("-")[1])
|
||||
file_idx = int(file_name.split("-")[1].split(".")[0])
|
||||
|
||||
# Write to destination without pandas index
|
||||
dst_path = dst_meta.root / f"data/chunk-{chunk_idx:03d}/file-{file_idx:03d}.parquet"
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_parquet(dst_path, index=False)
|
||||
|
||||
|
||||
# Video conversion constants
|
||||
BYTES_PER_KIB = 1024
|
||||
BYTES_PER_MIB = BYTES_PER_KIB * BYTES_PER_KIB
|
||||
|
||||
|
||||
def modify_tasks(
|
||||
dataset: LeRobotDataset,
|
||||
new_task: str | None = None,
|
||||
episode_tasks: dict[int, str] | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Modify tasks in a LeRobotDataset.
|
||||
|
||||
This function allows you to either:
|
||||
1. Set a single task for the entire dataset (using `new_task`)
|
||||
2. Set specific tasks for specific episodes (using `episode_tasks`)
|
||||
|
||||
You can combine both: `new_task` sets the default, and `episode_tasks` overrides
|
||||
specific episodes.
|
||||
|
||||
The dataset is modified in-place, updating only the task-related files:
|
||||
- meta/tasks.parquet
|
||||
- data/**/*.parquet (task_index column)
|
||||
- meta/episodes/**/*.parquet (tasks column)
|
||||
- meta/info.json (total_tasks)
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobotDataset to modify.
|
||||
new_task: A single task string to apply to all episodes. If None and episode_tasks
|
||||
is also None, raises an error.
|
||||
episode_tasks: Optional dict mapping episode indices to their task strings.
|
||||
Overrides `new_task` for specific episodes.
|
||||
|
||||
|
||||
Examples:
|
||||
Set a single task for all episodes:
|
||||
dataset = modify_tasks(dataset, new_task="Pick up the cube")
|
||||
|
||||
Set different tasks for specific episodes:
|
||||
dataset = modify_tasks(
|
||||
dataset,
|
||||
episode_tasks={0: "Task A", 1: "Task B", 2: "Task A"}
|
||||
)
|
||||
|
||||
Set a default task with overrides:
|
||||
dataset = modify_tasks(
|
||||
dataset,
|
||||
new_task="Default task",
|
||||
episode_tasks={5: "Special task for episode 5"}
|
||||
)
|
||||
"""
|
||||
if new_task is None and episode_tasks is None:
|
||||
raise ValueError("Must specify at least one of new_task or episode_tasks")
|
||||
|
||||
if episode_tasks is not None:
|
||||
valid_indices = set(range(dataset.meta.total_episodes))
|
||||
invalid = set(episode_tasks.keys()) - valid_indices
|
||||
if invalid:
|
||||
raise ValueError(f"Invalid episode indices: {invalid}")
|
||||
|
||||
# Ensure episodes metadata is loaded
|
||||
if dataset.meta.episodes is None:
|
||||
dataset.meta.episodes = load_episodes(dataset.root)
|
||||
|
||||
# Build the mapping from episode index to task string
|
||||
episode_to_task: dict[int, str] = {}
|
||||
for ep_idx in range(dataset.meta.total_episodes):
|
||||
if episode_tasks and ep_idx in episode_tasks:
|
||||
episode_to_task[ep_idx] = episode_tasks[ep_idx]
|
||||
elif new_task is not None:
|
||||
episode_to_task[ep_idx] = new_task
|
||||
else:
|
||||
# Keep original task if not overridden and no default provided
|
||||
original_tasks = dataset.meta.episodes[ep_idx]["tasks"]
|
||||
if not original_tasks:
|
||||
raise ValueError(f"Episode {ep_idx} has no tasks and no default task was provided")
|
||||
episode_to_task[ep_idx] = original_tasks[0]
|
||||
|
||||
# Collect all unique tasks and create new task mapping
|
||||
unique_tasks = sorted(set(episode_to_task.values()))
|
||||
new_task_df = pd.DataFrame({"task_index": list(range(len(unique_tasks)))}, index=unique_tasks)
|
||||
task_to_index = {task: idx for idx, task in enumerate(unique_tasks)}
|
||||
|
||||
logging.info(f"Modifying tasks in {dataset.repo_id}")
|
||||
logging.info(f"New tasks: {unique_tasks}")
|
||||
|
||||
root = dataset.root
|
||||
|
||||
# Update data files - modify task_index column
|
||||
logging.info("Updating data files...")
|
||||
data_dir = root / DATA_DIR
|
||||
|
||||
for parquet_path in tqdm(sorted(data_dir.rglob("*.parquet")), desc="Updating data"):
|
||||
df = pd.read_parquet(parquet_path)
|
||||
|
||||
# Build a mapping from episode_index to new task_index for rows in this file
|
||||
episode_indices_in_file = df["episode_index"].unique()
|
||||
ep_to_new_task_idx = {
|
||||
ep_idx: task_to_index[episode_to_task[ep_idx]] for ep_idx in episode_indices_in_file
|
||||
}
|
||||
|
||||
# Update task_index column
|
||||
df["task_index"] = df["episode_index"].map(ep_to_new_task_idx)
|
||||
df.to_parquet(parquet_path, index=False)
|
||||
|
||||
# Update episodes metadata - modify tasks column
|
||||
logging.info("Updating episodes metadata...")
|
||||
episodes_dir = root / "meta" / "episodes"
|
||||
|
||||
for parquet_path in tqdm(sorted(episodes_dir.rglob("*.parquet")), desc="Updating episodes"):
|
||||
df = pd.read_parquet(parquet_path)
|
||||
|
||||
# Update tasks column
|
||||
df["tasks"] = df["episode_index"].apply(lambda ep_idx: [episode_to_task[ep_idx]])
|
||||
df.to_parquet(parquet_path, index=False)
|
||||
|
||||
# Write new tasks.parquet
|
||||
write_tasks(new_task_df, root)
|
||||
|
||||
# Update info.json
|
||||
dataset.meta.info["total_tasks"] = len(unique_tasks)
|
||||
write_info(dataset.meta.info, root)
|
||||
|
||||
# Reload metadata to reflect changes
|
||||
dataset.meta.tasks = new_task_df
|
||||
dataset.meta.episodes = load_episodes(root)
|
||||
|
||||
logging.info(f"Tasks: {unique_tasks}")
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def convert_image_to_video_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
output_dir: Path,
|
||||
repo_id: str | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int = 2,
|
||||
crf: int = 30,
|
||||
fast_decode: int = 0,
|
||||
episode_indices: list[int] | None = None,
|
||||
num_workers: int = 4,
|
||||
max_episodes_per_batch: int | None = None,
|
||||
max_frames_per_batch: int | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Convert image-to-video dataset.
|
||||
|
||||
Creates a new LeRobotDataset with images encoded as videos, following the proper
|
||||
LeRobot dataset structure with videos stored in chunked MP4 files.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobot dataset with images
|
||||
output_dir: Directory to save the new video dataset
|
||||
repo_id: Repository ID for the new dataset (default: original_id + "_video")
|
||||
vcodec: Video codec (default: libsvtav1)
|
||||
pix_fmt: Pixel format (default: yuv420p)
|
||||
g: Group of pictures size (default: 2)
|
||||
crf: Constant rate factor (default: 30)
|
||||
fast_decode: Fast decode tuning (default: 0)
|
||||
episode_indices: List of episode indices to convert (None = all episodes)
|
||||
num_workers: Number of threads for parallel processing (default: 4)
|
||||
max_episodes_per_batch: Maximum episodes per video batch to avoid memory issues (None = no limit)
|
||||
max_frames_per_batch: Maximum frames per video batch to avoid memory issues (None = no limit)
|
||||
|
||||
Returns:
|
||||
New LeRobotDataset with images encoded as videos
|
||||
"""
|
||||
# Check that it's an image dataset
|
||||
if len(dataset.meta.video_keys) > 0:
|
||||
raise ValueError(
|
||||
f"This operation is for image datasets only. Video dataset provided: {dataset.repo_id}"
|
||||
)
|
||||
|
||||
# Get all image keys
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
|
||||
|
||||
if len(img_keys) == 0:
|
||||
raise ValueError(f"No image keys found in dataset {dataset.repo_id}")
|
||||
|
||||
# Determine which episodes to process
|
||||
if episode_indices is None:
|
||||
episode_indices = list(range(dataset.meta.total_episodes))
|
||||
|
||||
if repo_id is None:
|
||||
repo_id = f"{dataset.repo_id}_video"
|
||||
|
||||
logging.info(
|
||||
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
|
||||
)
|
||||
logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}")
|
||||
|
||||
# Create new features dict, converting image features to video features
|
||||
new_features = {}
|
||||
for key, value in dataset.meta.features.items():
|
||||
if key not in img_keys:
|
||||
new_features[key] = value
|
||||
else:
|
||||
# Convert image key to video format
|
||||
new_features[key] = value.copy()
|
||||
new_features[key]["dtype"] = "video" # Change dtype from "image" to "video"
|
||||
# Video info will be updated after episodes are encoded
|
||||
|
||||
# Create new metadata for video dataset
|
||||
new_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
fps=dataset.meta.fps,
|
||||
features=new_features,
|
||||
robot_type=dataset.meta.robot_type,
|
||||
root=output_dir,
|
||||
use_videos=True,
|
||||
chunks_size=dataset.meta.chunks_size,
|
||||
data_files_size_in_mb=dataset.meta.data_files_size_in_mb,
|
||||
video_files_size_in_mb=dataset.meta.video_files_size_in_mb,
|
||||
)
|
||||
|
||||
# Create temporary directory for image extraction
|
||||
temp_dir = output_dir / "temp_images"
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Process all episodes and batch encode videos
|
||||
# Use dictionary for O(1) episode metadata lookups instead of O(n) linear search
|
||||
all_episode_metadata = {}
|
||||
fps = int(dataset.fps)
|
||||
|
||||
try:
|
||||
# Build episode metadata entries first
|
||||
logging.info("Building episode metadata...")
|
||||
cumulative_frame_idx = 0
|
||||
for ep_idx in episode_indices:
|
||||
src_episode = dataset.meta.episodes[ep_idx]
|
||||
ep_length = src_episode["length"]
|
||||
ep_meta = {
|
||||
"episode_index": ep_idx,
|
||||
"length": ep_length,
|
||||
"dataset_from_index": cumulative_frame_idx,
|
||||
"dataset_to_index": cumulative_frame_idx + ep_length,
|
||||
}
|
||||
if "data/chunk_index" in src_episode:
|
||||
ep_meta["data/chunk_index"] = src_episode["data/chunk_index"]
|
||||
ep_meta["data/file_index"] = src_episode["data/file_index"]
|
||||
all_episode_metadata[ep_idx] = ep_meta
|
||||
cumulative_frame_idx += ep_length
|
||||
|
||||
# Process each camera and batch encode multiple episodes together
|
||||
video_file_size_limit = new_meta.video_files_size_in_mb
|
||||
|
||||
# Pre-compute episode lengths for batching
|
||||
episode_lengths = {ep_idx: dataset.meta.episodes["length"][ep_idx] for ep_idx in episode_indices}
|
||||
|
||||
for img_key in tqdm(img_keys, desc="Processing cameras"):
|
||||
# Estimate size per frame by encoding a small calibration sample
|
||||
# This provides accurate compression ratio for the specific codec parameters
|
||||
size_per_frame_mb = _estimate_frame_size_via_calibration(
|
||||
dataset=dataset,
|
||||
img_key=img_key,
|
||||
episode_indices=episode_indices,
|
||||
temp_dir=temp_dir,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
)
|
||||
|
||||
logging.info(f"Processing camera: {img_key}")
|
||||
chunk_idx, file_idx = 0, 0
|
||||
cumulative_timestamp = 0.0
|
||||
|
||||
# Process episodes in batches to stay under size limit
|
||||
for batch_episodes in _iter_episode_batches(
|
||||
episode_indices=episode_indices,
|
||||
episode_lengths=episode_lengths,
|
||||
size_per_frame_mb=size_per_frame_mb,
|
||||
video_file_size_limit=video_file_size_limit,
|
||||
max_episodes=max_episodes_per_batch,
|
||||
max_frames=max_frames_per_batch,
|
||||
):
|
||||
total_frames_in_batch = sum(episode_lengths[idx] for idx in batch_episodes)
|
||||
logging.info(
|
||||
f" Encoding batch of {len(batch_episodes)} episodes "
|
||||
f"({batch_episodes[0]}-{batch_episodes[-1]}) = {total_frames_in_batch} frames"
|
||||
)
|
||||
|
||||
# Save images for all episodes in this batch
|
||||
imgs_dir = temp_dir / f"batch_{chunk_idx}_{file_idx}" / img_key
|
||||
episode_durations = _save_batch_episodes_images(
|
||||
dataset=dataset,
|
||||
imgs_dir=imgs_dir,
|
||||
img_key=img_key,
|
||||
episode_indices=batch_episodes,
|
||||
num_workers=num_workers,
|
||||
)
|
||||
|
||||
# Encode all batched episodes into single video
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
encode_video_frames(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
# Clean up temporary images
|
||||
shutil.rmtree(imgs_dir)
|
||||
|
||||
# Update metadata for each episode in the batch
|
||||
for ep_idx, duration in zip(batch_episodes, episode_durations, strict=True):
|
||||
from_timestamp = cumulative_timestamp
|
||||
to_timestamp = cumulative_timestamp + duration
|
||||
cumulative_timestamp = to_timestamp
|
||||
|
||||
# Find episode metadata entry and add video metadata (O(1) dictionary lookup)
|
||||
ep_meta = all_episode_metadata[ep_idx]
|
||||
ep_meta[f"videos/{img_key}/chunk_index"] = chunk_idx
|
||||
ep_meta[f"videos/{img_key}/file_index"] = file_idx
|
||||
ep_meta[f"videos/{img_key}/from_timestamp"] = from_timestamp
|
||||
ep_meta[f"videos/{img_key}/to_timestamp"] = to_timestamp
|
||||
|
||||
# Move to next video file for next batch
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, new_meta.chunks_size)
|
||||
cumulative_timestamp = 0.0
|
||||
|
||||
# Copy and transform data files (removing image columns)
|
||||
_copy_data_without_images(dataset, new_meta, episode_indices, img_keys)
|
||||
|
||||
# Save episode metadata
|
||||
episodes_df = pd.DataFrame(list(all_episode_metadata.values()))
|
||||
episodes_path = new_meta.root / "meta" / "episodes" / "chunk-000" / "file-000.parquet"
|
||||
episodes_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
episodes_df.to_parquet(episodes_path, index=False)
|
||||
|
||||
# Update metadata info
|
||||
new_meta.info["total_episodes"] = len(episode_indices)
|
||||
new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata.values())
|
||||
new_meta.info["total_tasks"] = dataset.meta.total_tasks
|
||||
new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"}
|
||||
|
||||
# Update video info for all image keys (now videos)
|
||||
# We need to manually set video info since update_video_info() checks video_keys first
|
||||
for img_key in img_keys:
|
||||
if not new_meta.features[img_key].get("info", None):
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=0, file_index=0
|
||||
)
|
||||
new_meta.info["features"][img_key]["info"] = get_video_info(video_path)
|
||||
|
||||
write_info(new_meta.info, new_meta.root)
|
||||
|
||||
# Copy stats and tasks
|
||||
if dataset.meta.stats is not None:
|
||||
# Remove image stats
|
||||
new_stats = {k: v for k, v in dataset.meta.stats.items() if k not in img_keys}
|
||||
write_stats(new_stats, new_meta.root)
|
||||
|
||||
if dataset.meta.tasks is not None:
|
||||
write_tasks(dataset.meta.tasks, new_meta.root)
|
||||
|
||||
finally:
|
||||
# Clean up temporary directory
|
||||
if temp_dir.exists():
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
logging.info(f"Completed converting {dataset.repo_id} to video format")
|
||||
logging.info(f"New dataset saved to: {output_dir}")
|
||||
|
||||
# Return new dataset
|
||||
return LeRobotDataset(repo_id=repo_id, root=output_dir)
|
||||
|
||||
@@ -57,7 +57,9 @@ from lerobot.datasets.utils import (
|
||||
load_info,
|
||||
load_nested_dataset,
|
||||
load_stats,
|
||||
load_subtasks,
|
||||
load_tasks,
|
||||
load_tasks_high_level,
|
||||
update_chunk_file_indices,
|
||||
validate_episode_buffer,
|
||||
validate_frame,
|
||||
@@ -162,6 +164,8 @@ class LeRobotDatasetMetadata:
|
||||
self.info = load_info(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks = load_tasks(self.root)
|
||||
self.tasks_high_level = load_tasks_high_level(self.root)
|
||||
self.subtasks = load_subtasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
|
||||
@@ -518,6 +522,8 @@ class LeRobotDatasetMetadata:
|
||||
_validate_feature_names(features)
|
||||
|
||||
obj.tasks = None
|
||||
obj.tasks_high_level = None
|
||||
obj.subtasks = None
|
||||
obj.episodes = None
|
||||
obj.stats = None
|
||||
obj.info = create_empty_dataset_info(
|
||||
@@ -935,17 +941,30 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
else:
|
||||
return get_hf_features_from_features(self.features)
|
||||
|
||||
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
|
||||
def _get_query_indices(
|
||||
self, abs_idx: int, ep_idx: int
|
||||
) -> tuple[dict[str, list[int]], dict[str, torch.Tensor]]:
|
||||
"""Compute query indices for delta timestamps.
|
||||
|
||||
Args:
|
||||
abs_idx: The absolute index in the full dataset (not the relative index in filtered episodes).
|
||||
ep_idx: The episode index.
|
||||
|
||||
Returns:
|
||||
A tuple of (query_indices, padding) where:
|
||||
- query_indices: Dict mapping keys to lists of absolute indices to query
|
||||
- padding: Dict mapping "{key}_is_pad" to boolean tensors indicating padded positions
|
||||
"""
|
||||
ep = self.meta.episodes[ep_idx]
|
||||
ep_start = ep["dataset_from_index"]
|
||||
ep_end = ep["dataset_to_index"]
|
||||
query_indices = {
|
||||
key: [max(ep_start, min(ep_end - 1, idx + delta)) for delta in delta_idx]
|
||||
key: [max(ep_start, min(ep_end - 1, abs_idx + delta)) for delta in delta_idx]
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
padding = { # Pad values outside of current episode range
|
||||
f"{key}_is_pad": torch.BoolTensor(
|
||||
[(idx + delta < ep_start) | (idx + delta >= ep_end) for delta in delta_idx]
|
||||
[(abs_idx + delta < ep_start) | (abs_idx + delta >= ep_end) for delta in delta_idx]
|
||||
)
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
@@ -1037,10 +1056,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self._ensure_hf_dataset_loaded()
|
||||
item = self.hf_dataset[idx]
|
||||
ep_idx = item["episode_index"].item()
|
||||
# Use the absolute index from the dataset for delta timestamp calculations
|
||||
abs_idx = item["index"].item()
|
||||
|
||||
query_indices = None
|
||||
if self.delta_indices is not None:
|
||||
query_indices, padding = self._get_query_indices(idx, ep_idx)
|
||||
query_indices, padding = self._get_query_indices(abs_idx, ep_idx)
|
||||
query_result = self._query_hf_dataset(query_indices)
|
||||
item = {**item, **padding}
|
||||
for key, val in query_result.items():
|
||||
@@ -1049,7 +1070,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if len(self.meta.video_keys) > 0:
|
||||
current_ts = item["timestamp"].item()
|
||||
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
|
||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||
try:
|
||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||
except Exception as e:
|
||||
print("\n" + "=" * 120)
|
||||
print("[VIDEO DECODE FAILURE]")
|
||||
print(f"item={item}")
|
||||
print(f"query_indices={query_indices}")
|
||||
print(f"query_timestamps={query_timestamps}")
|
||||
print(f"ep_idx={ep_idx}")
|
||||
print("=" * 120 + "\n")
|
||||
raise
|
||||
item = {**video_frames, **item}
|
||||
|
||||
if self.image_transforms is not None:
|
||||
@@ -1060,6 +1091,20 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Add task as a string
|
||||
task_idx = item["task_index"].item()
|
||||
item["task"] = self.meta.tasks.iloc[task_idx].name
|
||||
|
||||
# optionally add high level task index
|
||||
if "task_index_high_level" in self.features:
|
||||
high_level_task_idx = item["task_index_high_level"].item()
|
||||
item["robot_utterance"] = self.meta.tasks_high_level.iloc[high_level_task_idx]["robot_utterance"]
|
||||
item["user_prompt"] = self.meta.tasks_high_level.iloc[high_level_task_idx]["user_prompt"]
|
||||
|
||||
|
||||
|
||||
# add subtask information if available
|
||||
if "subtask_index" in self.features and self.meta.subtasks is not None:
|
||||
subtask_idx = item["subtask_index"].item()
|
||||
item["subtask"] = self.meta.subtasks.iloc[subtask_idx].name
|
||||
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
@@ -1498,7 +1543,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
episode_index = self.episode_buffer["episode_index"]
|
||||
if isinstance(episode_index, np.ndarray):
|
||||
episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0]
|
||||
for cam_key in self.meta.camera_keys:
|
||||
for cam_key in self.meta.image_keys:
|
||||
img_dir = self._get_image_file_dir(episode_index, cam_key)
|
||||
if img_dir.is_dir():
|
||||
shutil.rmtree(img_dir)
|
||||
|
||||
@@ -216,16 +216,17 @@ class ImageTransformsConfig:
|
||||
|
||||
|
||||
def make_transform_from_config(cfg: ImageTransformConfig):
|
||||
if cfg.type == "Identity":
|
||||
return v2.Identity(**cfg.kwargs)
|
||||
elif cfg.type == "ColorJitter":
|
||||
return v2.ColorJitter(**cfg.kwargs)
|
||||
elif cfg.type == "SharpnessJitter":
|
||||
if cfg.type == "SharpnessJitter":
|
||||
return SharpnessJitter(**cfg.kwargs)
|
||||
elif cfg.type == "RandomAffine":
|
||||
return v2.RandomAffine(**cfg.kwargs)
|
||||
else:
|
||||
raise ValueError(f"Transform '{cfg.type}' is not valid.")
|
||||
|
||||
transform_cls = getattr(v2, cfg.type, None)
|
||||
if isinstance(transform_cls, type) and issubclass(transform_cls, Transform):
|
||||
return transform_cls(**cfg.kwargs)
|
||||
|
||||
raise ValueError(
|
||||
f"Transform '{cfg.type}' is not valid. It must be a class in "
|
||||
f"torchvision.transforms.v2 or 'SharpnessJitter'."
|
||||
)
|
||||
|
||||
|
||||
class ImageTransforms(Transform):
|
||||
|
||||
@@ -60,7 +60,10 @@ VIDEO_DIR = "videos"
|
||||
|
||||
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
|
||||
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
|
||||
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
|
||||
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_TASKS_HIGH_LEVEL_PATH = "meta/tasks_high_level.parquet"
|
||||
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
|
||||
@@ -352,6 +355,28 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
|
||||
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
|
||||
return tasks
|
||||
|
||||
def load_tasks_high_level(local_dir: Path) -> pandas.DataFrame | None:
|
||||
"""Load high-level tasks from tasks_high_level.parquet if it exists."""
|
||||
tasks_high_level_path = local_dir / DEFAULT_TASKS_HIGH_LEVEL_PATH
|
||||
if tasks_high_level_path.exists():
|
||||
return pd.read_parquet(tasks_high_level_path)
|
||||
return None
|
||||
|
||||
|
||||
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
|
||||
"""Load subtasks from subtasks.parquet if it exists."""
|
||||
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
|
||||
if subtasks_path.exists():
|
||||
return pd.read_parquet(subtasks_path)
|
||||
return None
|
||||
|
||||
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
|
||||
"""Load subtasks from subtasks.parquet if it exists."""
|
||||
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
|
||||
if subtasks_path.exists():
|
||||
return pd.read_parquet(subtasks_path)
|
||||
return None
|
||||
|
||||
|
||||
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
|
||||
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
|
||||
@@ -1172,12 +1197,21 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
|
||||
)
|
||||
|
||||
|
||||
def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path) -> None:
|
||||
def to_parquet_with_hf_images(
|
||||
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
|
||||
) -> None:
|
||||
"""This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
|
||||
This way, it can be loaded by HF dataset and correctly formatted images are returned.
|
||||
|
||||
Args:
|
||||
df: DataFrame to write to parquet.
|
||||
path: Path to write the parquet file.
|
||||
features: Optional HuggingFace Features schema. If provided, ensures image columns
|
||||
are properly typed as Image() in the parquet schema.
|
||||
"""
|
||||
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
|
||||
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
|
||||
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
|
||||
ds.to_parquet(path)
|
||||
|
||||
|
||||
def item_to_torch(item: dict) -> dict:
|
||||
|
||||
@@ -205,6 +205,7 @@ class ObservationConfig:
|
||||
|
||||
add_joint_velocity_to_observation: bool = False
|
||||
add_current_to_observation: bool = False
|
||||
add_ee_pose_to_observation: bool = False
|
||||
display_cameras: bool = False
|
||||
|
||||
|
||||
@@ -260,6 +261,7 @@ class HILSerlRobotEnvConfig(EnvConfig):
|
||||
@dataclass
|
||||
class LiberoEnv(EnvConfig):
|
||||
task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
|
||||
task_ids: list[int] | None = None
|
||||
fps: int = 30
|
||||
episode_length: int | None = None
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
@@ -338,10 +340,10 @@ class LiberoEnv(EnvConfig):
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
}
|
||||
kwargs: dict[str, Any] = {"obs_type": self.obs_type, "render_mode": self.render_mode}
|
||||
if self.task_ids is not None:
|
||||
kwargs["task_ids"] = self.task_ids
|
||||
return kwargs
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("metaworld")
|
||||
|
||||
@@ -293,9 +293,9 @@ class LiberoEnv(gym.Env):
|
||||
def reset(self, seed=None, **kwargs):
|
||||
super().reset(seed=seed)
|
||||
self._env.seed(seed)
|
||||
if self.init_states and self._init_states is not None:
|
||||
self._env.set_init_state(self._init_states[self._init_state_id])
|
||||
raw_obs = self._env.reset()
|
||||
if self.init_states and self._init_states is not None:
|
||||
raw_obs = self._env.set_init_state(self._init_states[self._init_state_id])
|
||||
|
||||
# After reset, objects may be unstable (slightly floating, intersecting, etc.).
|
||||
# Step the simulator with a no-op action for a few frames so everything settles.
|
||||
|
||||
@@ -14,4 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .motors_bus import Motor, MotorCalibration, MotorNormMode, MotorsBus
|
||||
from .motors_bus import (
|
||||
Motor,
|
||||
MotorCalibration,
|
||||
MotorNormMode,
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ from dataclasses import dataclass
|
||||
|
||||
os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1"
|
||||
|
||||
from lerobot.motors import MotorCalibration, MotorsBus
|
||||
from .motors_bus import MotorCalibration, MotorsBus
|
||||
|
||||
BAR_LEN, BAR_THICKNESS = 450, 8
|
||||
HANDLE_R = 10
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .damiao import DamiaoMotorsBus
|
||||
from .tables import *
|
||||
@@ -0,0 +1,833 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Portions of this file are derived from DM_Control_Python by cmjang.
|
||||
# Licensed under the MIT License; see `LICENSE` for the full text:
|
||||
# https://github.com/cmjang/DM_Control_Python
|
||||
|
||||
import logging
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
|
||||
from lerobot.utils.import_utils import _can_available
|
||||
|
||||
if TYPE_CHECKING or _can_available:
|
||||
import can
|
||||
else:
|
||||
|
||||
class can: # noqa: N801
|
||||
Message = object
|
||||
interface = None
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
||||
|
||||
from ..motors_bus import Motor, MotorCalibration, MotorsBusBase, NameOrID, Value
|
||||
from .tables import (
|
||||
AVAILABLE_BAUDRATES,
|
||||
CAN_CMD_DISABLE,
|
||||
CAN_CMD_ENABLE,
|
||||
CAN_CMD_REFRESH,
|
||||
CAN_CMD_SET_ZERO,
|
||||
CAN_PARAM_ID,
|
||||
DEFAULT_BAUDRATE,
|
||||
DEFAULT_TIMEOUT_MS,
|
||||
MIT_KD_RANGE,
|
||||
MIT_KP_RANGE,
|
||||
MOTOR_LIMIT_PARAMS,
|
||||
MotorType,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
LONG_TIMEOUT_SEC = 0.1
|
||||
MEDIUM_TIMEOUT_SEC = 0.01
|
||||
SHORT_TIMEOUT_SEC = 0.001
|
||||
PRECISE_TIMEOUT_SEC = 0.0001
|
||||
|
||||
|
||||
class MotorState(TypedDict):
|
||||
position: float
|
||||
velocity: float
|
||||
torque: float
|
||||
temp_mos: float
|
||||
temp_rotor: float
|
||||
|
||||
|
||||
class DamiaoMotorsBus(MotorsBusBase):
|
||||
"""
|
||||
The Damiao implementation for a MotorsBus using CAN bus communication.
|
||||
|
||||
This class uses python-can for CAN bus communication with Damiao motors.
|
||||
For more info, see:
|
||||
- python-can documentation: https://python-can.readthedocs.io/en/stable/
|
||||
- Seedstudio documentation: https://wiki.seeedstudio.com/damiao_series/
|
||||
- DM_Control_Python repo: https://github.com/cmjang/DM_Control_Python
|
||||
"""
|
||||
|
||||
# CAN-specific settings
|
||||
available_baudrates = deepcopy(AVAILABLE_BAUDRATES)
|
||||
default_baudrate = DEFAULT_BAUDRATE
|
||||
default_timeout = DEFAULT_TIMEOUT_MS
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
port: str,
|
||||
motors: dict[str, Motor],
|
||||
calibration: dict[str, MotorCalibration] | None = None,
|
||||
can_interface: str = "auto",
|
||||
use_can_fd: bool = True,
|
||||
bitrate: int = 1000000,
|
||||
data_bitrate: int | None = 5000000,
|
||||
):
|
||||
"""
|
||||
Initialize the Damiao motors bus.
|
||||
|
||||
Args:
|
||||
port: CAN interface name (e.g., "can0" for Linux, "/dev/cu.usbmodem*" for macOS)
|
||||
motors: Dictionary mapping motor names to Motor objects
|
||||
calibration: Optional calibration data
|
||||
can_interface: CAN interface type - "auto" (default), "socketcan" (Linux), or "slcan" (macOS/serial)
|
||||
use_can_fd: Whether to use CAN FD mode (default: True for OpenArms)
|
||||
bitrate: Nominal bitrate in bps (default: 1000000 = 1 Mbps)
|
||||
data_bitrate: Data bitrate for CAN FD in bps (default: 5000000 = 5 Mbps), ignored if use_can_fd is False
|
||||
"""
|
||||
super().__init__(port, motors, calibration)
|
||||
self.port = port
|
||||
self.can_interface = can_interface
|
||||
self.use_can_fd = use_can_fd
|
||||
self.bitrate = bitrate
|
||||
self.data_bitrate = data_bitrate
|
||||
self.canbus: can.interface.Bus | None = None
|
||||
self._is_connected = False
|
||||
|
||||
# Map motor names to CAN IDs
|
||||
self._motor_can_ids: dict[str, int] = {}
|
||||
self._recv_id_to_motor: dict[int, str] = {}
|
||||
self._motor_types: dict[str, MotorType] = {}
|
||||
|
||||
for name, motor in self.motors.items():
|
||||
if motor.motor_type_str is None:
|
||||
raise ValueError(f"Motor '{name}' is missing required 'motor_type'")
|
||||
self._motor_types[name] = getattr(MotorType, motor.motor_type_str.upper().replace("-", "_"))
|
||||
|
||||
# Map recv_id to motor name for filtering responses
|
||||
if motor.recv_id is not None:
|
||||
self._recv_id_to_motor[motor.recv_id] = name
|
||||
|
||||
# State cache for handling packet drops safely
|
||||
self._last_known_states: dict[str, MotorState] = {
|
||||
name: {
|
||||
"position": 0.0,
|
||||
"velocity": 0.0,
|
||||
"torque": 0.0,
|
||||
"temp_mos": 0.0,
|
||||
"temp_rotor": 0.0,
|
||||
}
|
||||
for name in self.motors
|
||||
}
|
||||
|
||||
# Dynamic gains storage
|
||||
# Defaults: Kp=10.0 (Stiffness), Kd=0.5 (Damping)
|
||||
self._gains: dict[str, dict[str, float]] = {name: {"kp": 10.0, "kd": 0.5} for name in self.motors}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if the CAN bus is connected."""
|
||||
return self._is_connected and self.canbus is not None
|
||||
|
||||
def connect(self, handshake: bool = True) -> None:
|
||||
"""
|
||||
Open the CAN bus and initialize communication.
|
||||
|
||||
Args:
|
||||
handshake: If True, ping all motors to verify they're present
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is already connected."
|
||||
)
|
||||
|
||||
try:
|
||||
# Auto-detect interface type based on port name
|
||||
if self.can_interface == "auto":
|
||||
if self.port.startswith("/dev/"):
|
||||
self.can_interface = "slcan"
|
||||
logger.info(f"Auto-detected slcan interface for port {self.port}")
|
||||
else:
|
||||
self.can_interface = "socketcan"
|
||||
logger.info(f"Auto-detected socketcan interface for port {self.port}")
|
||||
|
||||
# Connect to CAN bus
|
||||
kwargs = {
|
||||
"channel": self.port,
|
||||
"bitrate": self.bitrate,
|
||||
"interface": self.can_interface,
|
||||
}
|
||||
|
||||
if self.can_interface == "socketcan" and self.use_can_fd and self.data_bitrate is not None:
|
||||
kwargs.update({"data_bitrate": self.data_bitrate, "fd": True})
|
||||
logger.info(
|
||||
f"Connected to {self.port} with CAN FD (bitrate={self.bitrate}, data_bitrate={self.data_bitrate})"
|
||||
)
|
||||
else:
|
||||
logger.info(f"Connected to {self.port} with {self.can_interface} (bitrate={self.bitrate})")
|
||||
|
||||
self.canbus = can.interface.Bus(**kwargs)
|
||||
self._is_connected = True
|
||||
|
||||
if handshake:
|
||||
self._handshake()
|
||||
|
||||
logger.debug(f"{self.__class__.__name__} connected via {self.can_interface}.")
|
||||
except Exception as e:
|
||||
self._is_connected = False
|
||||
raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e
|
||||
|
||||
def _handshake(self) -> None:
|
||||
"""
|
||||
Verify all motors are present and populate initial state cache.
|
||||
Raises ConnectionError if any motor fails to respond.
|
||||
"""
|
||||
logger.info("Starting handshake with motors...")
|
||||
|
||||
# Drain any pending messages
|
||||
while self.canbus.recv(timeout=0.01):
|
||||
pass
|
||||
|
||||
missing_motors = []
|
||||
for motor_name in self.motors:
|
||||
motor_id = self._get_motor_id(motor_name)
|
||||
recv_id = self._get_motor_recv_id(motor_name)
|
||||
|
||||
# Send enable command
|
||||
data = [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, CAN_CMD_ENABLE]
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
self.canbus.send(msg)
|
||||
|
||||
# Wait for response with longer timeout
|
||||
response = None
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < 0.1:
|
||||
response = self.canbus.recv(timeout=0.1)
|
||||
if response and response.arbitration_id == recv_id:
|
||||
break
|
||||
response = None
|
||||
|
||||
if response is None:
|
||||
missing_motors.append(motor_name)
|
||||
else:
|
||||
self._process_response(motor_name, msg)
|
||||
time.sleep(MEDIUM_TIMEOUT_SEC)
|
||||
|
||||
if missing_motors:
|
||||
raise ConnectionError(
|
||||
f"Handshake failed. The following motors did not respond: {missing_motors}. "
|
||||
"Check power (24V) and CAN wiring."
|
||||
)
|
||||
logger.info("Handshake successful. All motors ready.")
|
||||
|
||||
def disconnect(self, disable_torque: bool = True) -> None:
|
||||
"""
|
||||
Close the CAN bus connection.
|
||||
|
||||
Args:
|
||||
disable_torque: If True, disable torque on all motors before disconnecting
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self.__class__.__name__}('{self.port}') is not connected.")
|
||||
|
||||
if disable_torque:
|
||||
try:
|
||||
self.disable_torque()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to disable torque during disconnect: {e}")
|
||||
|
||||
if self.canbus:
|
||||
self.canbus.shutdown()
|
||||
self.canbus = None
|
||||
self._is_connected = False
|
||||
logger.debug(f"{self.__class__.__name__} disconnected.")
|
||||
|
||||
def configure_motors(self) -> None:
|
||||
"""Configure all motors with default settings."""
|
||||
# Damiao motors don't require much configuration in MIT mode
|
||||
# Just ensure they're enabled
|
||||
for motor in self.motors:
|
||||
self._send_simple_command(motor, CAN_CMD_ENABLE)
|
||||
time.sleep(MEDIUM_TIMEOUT_SEC)
|
||||
|
||||
def _send_simple_command(self, motor: NameOrID, command_byte: int) -> None:
|
||||
"""Helper to send simple 8-byte commands (Enable, Disable, Zero)."""
|
||||
motor_id = self._get_motor_id(motor)
|
||||
motor_name = self._get_motor_name(motor)
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
data = [0xFF] * 7 + [command_byte]
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
self.canbus.send(msg)
|
||||
if msg := self._recv_motor_response(expected_recv_id=recv_id):
|
||||
self._process_response(motor_name, msg)
|
||||
else:
|
||||
logger.debug(f"No response from {motor_name} after command 0x{command_byte:02X}")
|
||||
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
"""Enable torque on selected motors."""
|
||||
target_motors = self._get_motors_list(motors)
|
||||
for motor in target_motors:
|
||||
for _ in range(num_retry + 1):
|
||||
try:
|
||||
self._send_simple_command(motor, CAN_CMD_ENABLE)
|
||||
break
|
||||
except Exception as e:
|
||||
if _ == num_retry:
|
||||
raise e
|
||||
time.sleep(MEDIUM_TIMEOUT_SEC)
|
||||
|
||||
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
"""Disable torque on selected motors."""
|
||||
target_motors = self._get_motors_list(motors)
|
||||
for motor in target_motors:
|
||||
for _ in range(num_retry + 1):
|
||||
try:
|
||||
self._send_simple_command(motor, CAN_CMD_DISABLE)
|
||||
break
|
||||
except Exception as e:
|
||||
if _ == num_retry:
|
||||
raise e
|
||||
time.sleep(MEDIUM_TIMEOUT_SEC)
|
||||
|
||||
@contextmanager
|
||||
def torque_disabled(self, motors: str | list[str] | None = None):
|
||||
"""
|
||||
Context manager that guarantees torque is re-enabled.
|
||||
|
||||
This helper is useful to temporarily disable torque when configuring motors.
|
||||
"""
|
||||
self.disable_torque(motors)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.enable_torque(motors)
|
||||
|
||||
def set_zero_position(self, motors: str | list[str] | None = None) -> None:
|
||||
"""Set current position as zero for selected motors."""
|
||||
target_motors = self._get_motors_list(motors)
|
||||
for motor in target_motors:
|
||||
self._send_simple_command(motor, CAN_CMD_SET_ZERO)
|
||||
time.sleep(MEDIUM_TIMEOUT_SEC)
|
||||
|
||||
def _refresh_motor(self, motor: NameOrID) -> can.Message | None:
|
||||
"""Refresh motor status and return the response."""
|
||||
motor_id = self._get_motor_id(motor)
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
|
||||
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
self.canbus.send(msg)
|
||||
return self._recv_motor_response(expected_recv_id=recv_id)
|
||||
|
||||
def _recv_motor_response(
|
||||
self, expected_recv_id: int | None = None, timeout: float = 0.001
|
||||
) -> can.Message | None:
|
||||
"""
|
||||
Receive a response from a motor.
|
||||
|
||||
Args:
|
||||
expected_recv_id: If provided, only return messages from this CAN ID
|
||||
timeout: Timeout in seconds (default: 1ms for high-speed operation)
|
||||
Returns:
|
||||
CAN message if received, None otherwise
|
||||
"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
messages_seen = []
|
||||
while time.time() - start_time < timeout:
|
||||
msg = self.canbus.recv(timeout=PRECISE_TIMEOUT_SEC)
|
||||
if msg:
|
||||
messages_seen.append(f"0x{msg.arbitration_id:02X}")
|
||||
if expected_recv_id is None or msg.arbitration_id == expected_recv_id:
|
||||
return msg
|
||||
logger.debug(
|
||||
f"Ignoring message from 0x{msg.arbitration_id:02X}, expected 0x{expected_recv_id:02X}"
|
||||
)
|
||||
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
if messages_seen:
|
||||
logger.debug(
|
||||
f"Received {len(messages_seen)} msgs from {set(messages_seen)}, expected 0x{expected_recv_id:02X}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"No CAN messages received (expected 0x{expected_recv_id:02X})")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to receive CAN message: {e}")
|
||||
return None
|
||||
|
||||
def _recv_all_responses(
|
||||
self, expected_recv_ids: list[int], timeout: float = 0.002
|
||||
) -> dict[int, can.Message]:
|
||||
"""
|
||||
Efficiently receive responses from multiple motors at once.
|
||||
Uses the OpenArms pattern: collect all available messages within timeout.
|
||||
|
||||
Args:
|
||||
expected_recv_ids: List of CAN IDs we expect responses from
|
||||
timeout: Total timeout in seconds (default: 2ms)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping recv_id to CAN message
|
||||
"""
|
||||
responses = {}
|
||||
expected_set = set(expected_recv_ids)
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout:
|
||||
# 100us poll timeout
|
||||
msg = self.canbus.recv(timeout=PRECISE_TIMEOUT_SEC)
|
||||
if msg and msg.arbitration_id in expected_set:
|
||||
responses[msg.arbitration_id] = msg
|
||||
if len(responses) == len(expected_recv_ids):
|
||||
break
|
||||
except Exception as e:
|
||||
logger.debug(f"Error receiving responses: {e}")
|
||||
|
||||
return responses
|
||||
|
||||
def _encode_mit_packet(
|
||||
self,
|
||||
motor_type: MotorType,
|
||||
kp: float,
|
||||
kd: float,
|
||||
position_degrees: float,
|
||||
velocity_deg_per_sec: float,
|
||||
torque: float,
|
||||
) -> list[int]:
|
||||
"""Helper to encode control parameters into 8 bytes for MIT mode."""
|
||||
# Convert degrees to radians
|
||||
position_rad = np.radians(position_degrees)
|
||||
velocity_rad_per_sec = np.radians(velocity_deg_per_sec)
|
||||
|
||||
# Get motor limits
|
||||
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
|
||||
|
||||
# Encode parameters
|
||||
kp_uint = self._float_to_uint(kp, *MIT_KP_RANGE, 12)
|
||||
kd_uint = self._float_to_uint(kd, *MIT_KD_RANGE, 12)
|
||||
q_uint = self._float_to_uint(position_rad, -pmax, pmax, 16)
|
||||
dq_uint = self._float_to_uint(velocity_rad_per_sec, -vmax, vmax, 12)
|
||||
tau_uint = self._float_to_uint(torque, -tmax, tmax, 12)
|
||||
|
||||
# Pack data
|
||||
data = [0] * 8
|
||||
data[0] = (q_uint >> 8) & 0xFF
|
||||
data[1] = q_uint & 0xFF
|
||||
data[2] = dq_uint >> 4
|
||||
data[3] = ((dq_uint & 0xF) << 4) | ((kp_uint >> 8) & 0xF)
|
||||
data[4] = kp_uint & 0xFF
|
||||
data[5] = kd_uint >> 4
|
||||
data[6] = ((kd_uint & 0xF) << 4) | ((tau_uint >> 8) & 0xF)
|
||||
data[7] = tau_uint & 0xFF
|
||||
return data
|
||||
|
||||
def _mit_control(
|
||||
self,
|
||||
motor: NameOrID,
|
||||
kp: float,
|
||||
kd: float,
|
||||
position_degrees: float,
|
||||
velocity_deg_per_sec: float,
|
||||
torque: float,
|
||||
) -> None:
|
||||
"""Send MIT control command to a motor."""
|
||||
motor_id = self._get_motor_id(motor)
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_type = self._motor_types[motor_name]
|
||||
|
||||
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
self.canbus.send(msg)
|
||||
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
if msg := self._recv_motor_response(expected_recv_id=recv_id):
|
||||
self._process_response(motor_name, msg)
|
||||
else:
|
||||
logger.debug(f"No response from {motor_name} after MIT control command")
|
||||
|
||||
def _mit_control_batch(
|
||||
self,
|
||||
commands: dict[NameOrID, tuple[float, float, float, float, float]],
|
||||
) -> None:
|
||||
"""
|
||||
Send MIT control commands to multiple motors in batch.
|
||||
Sends all commands first, then collects responses.
|
||||
|
||||
Args:
|
||||
commands: Dict mapping motor name/ID to (kp, kd, position_deg, velocity_deg/s, torque)
|
||||
Example: {'joint_1': (10.0, 0.5, 45.0, 0.0, 0.0), ...}
|
||||
"""
|
||||
if not commands:
|
||||
return
|
||||
|
||||
recv_id_to_motor: dict[int, str] = {}
|
||||
|
||||
# Step 1: Send all MIT control commands
|
||||
for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items():
|
||||
motor_id = self._get_motor_id(motor)
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_type = self._motor_types[motor_name]
|
||||
|
||||
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
self.canbus.send(msg)
|
||||
|
||||
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
|
||||
|
||||
# Step 2: Collect responses and update state cache
|
||||
responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=SHORT_TIMEOUT_SEC)
|
||||
for recv_id, motor_name in recv_id_to_motor.items():
|
||||
if msg := responses.get(recv_id):
|
||||
self._process_response(motor_name, msg)
|
||||
|
||||
def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
|
||||
"""Convert float to unsigned integer for CAN transmission."""
|
||||
x = max(x_min, min(x_max, x)) # Clamp to range
|
||||
span = x_max - x_min
|
||||
data_norm = (x - x_min) / span
|
||||
return int(data_norm * ((1 << bits) - 1))
|
||||
|
||||
def _uint_to_float(self, x: int, x_min: float, x_max: float, bits: int) -> float:
|
||||
"""Convert unsigned integer from CAN to float."""
|
||||
span = x_max - x_min
|
||||
data_norm = float(x) / ((1 << bits) - 1)
|
||||
return data_norm * span + x_min
|
||||
|
||||
def _decode_motor_state(
|
||||
self, data: bytearray | bytes, motor_type: MotorType
|
||||
) -> tuple[float, float, float, int, int]:
|
||||
"""
|
||||
Decode motor state from CAN data.
|
||||
Returns: (position_deg, velocity_deg_s, torque, temp_mos, temp_rotor)
|
||||
"""
|
||||
if len(data) < 8:
|
||||
raise ValueError("Invalid motor state data")
|
||||
|
||||
# Extract encoded values
|
||||
q_uint = (data[1] << 8) | data[2]
|
||||
dq_uint = (data[3] << 4) | (data[4] >> 4)
|
||||
tau_uint = ((data[4] & 0x0F) << 8) | data[5]
|
||||
t_mos = data[6]
|
||||
t_rotor = data[7]
|
||||
|
||||
# Get motor limits
|
||||
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
|
||||
|
||||
# Decode to physical values
|
||||
position_rad = self._uint_to_float(q_uint, -pmax, pmax, 16)
|
||||
velocity_rad_per_sec = self._uint_to_float(dq_uint, -vmax, vmax, 12)
|
||||
torque = self._uint_to_float(tau_uint, -tmax, tmax, 12)
|
||||
|
||||
return np.degrees(position_rad), np.degrees(velocity_rad_per_sec), torque, t_mos, t_rotor
|
||||
|
||||
def _process_response(self, motor: str, msg: can.Message) -> None:
|
||||
"""Decode a message and update the motor state cache."""
|
||||
try:
|
||||
motor_type = self._motor_types[motor]
|
||||
pos, vel, torque, t_mos, t_rotor = self._decode_motor_state(msg.data, motor_type)
|
||||
|
||||
self._last_known_states[motor] = {
|
||||
"position": pos,
|
||||
"velocity": vel,
|
||||
"torque": torque,
|
||||
"temp_mos": float(t_mos),
|
||||
"temp_rotor": float(t_rotor),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to decode response from {motor}: {e}")
|
||||
|
||||
def read(self, data_name: str, motor: str) -> Value:
|
||||
"""Read a value from a single motor. Positions are always in degrees."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Refresh motor to get latest state
|
||||
msg = self._refresh_motor(motor)
|
||||
if msg is None:
|
||||
motor_id = self._get_motor_id(motor)
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
raise ConnectionError(
|
||||
f"No response from motor '{motor}' (send ID: 0x{motor_id:02X}, recv ID: 0x{recv_id:02X}). "
|
||||
f"Check that: 1) Motor is powered (24V), 2) CAN wiring is correct, "
|
||||
f"3) Motor IDs are configured correctly using Damiao Debugging Tools"
|
||||
)
|
||||
|
||||
self._process_response(motor, msg)
|
||||
return self._get_cached_value(motor, data_name)
|
||||
|
||||
def _get_cached_value(self, motor: str, data_name: str) -> Value:
|
||||
"""Retrieve a specific value from the cache."""
|
||||
state = self._last_known_states[motor]
|
||||
mapping: dict[str, Any] = {
|
||||
"Present_Position": state["position"],
|
||||
"Present_Velocity": state["velocity"],
|
||||
"Present_Torque": state["torque"],
|
||||
"Temperature_MOS": state["temp_mos"],
|
||||
"Temperature_Rotor": state["temp_rotor"],
|
||||
}
|
||||
if data_name not in mapping:
|
||||
raise ValueError(f"Unknown data_name: {data_name}")
|
||||
return mapping[data_name]
|
||||
|
||||
def write(
|
||||
self,
|
||||
data_name: str,
|
||||
motor: str,
|
||||
value: Value,
|
||||
) -> None:
|
||||
"""
|
||||
Write a value to a single motor. Positions are always in degrees.
|
||||
Can write 'Goal_Position', 'Kp', or 'Kd'.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if data_name in ("Kp", "Kd"):
|
||||
self._gains[motor][data_name.lower()] = float(value)
|
||||
elif data_name == "Goal_Position":
|
||||
kp = self._gains[motor]["kp"]
|
||||
kd = self._gains[motor]["kd"]
|
||||
self._mit_control(motor, kp, kd, float(value), 0.0, 0.0)
|
||||
else:
|
||||
raise ValueError(f"Writing {data_name} not supported in MIT mode")
|
||||
|
||||
def sync_read(
|
||||
self,
|
||||
data_name: str,
|
||||
motors: str | list[str] | None = None,
|
||||
) -> dict[str, Value]:
|
||||
"""
|
||||
Read the same value from multiple motors simultaneously.
|
||||
"""
|
||||
target_motors = self._get_motors_list(motors)
|
||||
self._batch_refresh(target_motors)
|
||||
|
||||
result = {}
|
||||
for motor in target_motors:
|
||||
result[motor] = self._get_cached_value(motor, data_name)
|
||||
return result
|
||||
|
||||
def sync_read_all_states(
|
||||
self,
|
||||
motors: str | list[str] | None = None,
|
||||
*,
|
||||
num_retry: int = 0,
|
||||
) -> dict[str, MotorState]:
|
||||
"""
|
||||
Read ALL motor states (position, velocity, torque) from multiple motors in ONE refresh cycle.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping motor names to state dicts with keys: 'position', 'velocity', 'torque'
|
||||
Example: {'joint_1': {'position': 45.2, 'velocity': 1.3, 'torque': 0.5}, ...}
|
||||
"""
|
||||
target_motors = self._get_motors_list(motors)
|
||||
self._batch_refresh(target_motors)
|
||||
|
||||
result = {}
|
||||
for motor in target_motors:
|
||||
result[motor] = self._last_known_states[motor].copy()
|
||||
return result
|
||||
|
||||
def _batch_refresh(self, motors: list[str]) -> None:
|
||||
"""Internal helper to refresh a list of motors and update cache."""
|
||||
# Send refresh commands
|
||||
for motor in motors:
|
||||
motor_id = self._get_motor_id(motor)
|
||||
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
|
||||
msg = can.Message(
|
||||
arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd
|
||||
)
|
||||
self.canbus.send(msg)
|
||||
|
||||
# Collect responses
|
||||
expected_recv_ids = [self._get_motor_recv_id(m) for m in motors]
|
||||
responses = self._recv_all_responses(expected_recv_ids, timeout=MEDIUM_TIMEOUT_SEC)
|
||||
|
||||
# Update cache
|
||||
for motor in motors:
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
msg = responses.get(recv_id)
|
||||
if msg:
|
||||
self._process_response(motor, msg)
|
||||
else:
|
||||
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
|
||||
|
||||
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
|
||||
"""
|
||||
Write values to multiple motors simultaneously. Positions are always in degrees.
|
||||
"""
|
||||
if data_name in ("Kp", "Kd"):
|
||||
key = data_name.lower()
|
||||
for motor, val in values.items():
|
||||
self._gains[motor][key] = float(val)
|
||||
|
||||
elif data_name == "Goal_Position":
|
||||
# Step 1: Send all MIT control commands
|
||||
recv_id_to_motor: dict[int, str] = {}
|
||||
for motor, value_degrees in values.items():
|
||||
motor_id = self._get_motor_id(motor)
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_type = self._motor_types[motor_name]
|
||||
|
||||
kp = self._gains[motor]["kp"]
|
||||
kd = self._gains[motor]["kd"]
|
||||
|
||||
data = self._encode_mit_packet(motor_type, kp, kd, float(value_degrees), 0.0, 0.0)
|
||||
msg = can.Message(
|
||||
arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd
|
||||
)
|
||||
self.canbus.send(msg)
|
||||
precise_sleep(PRECISE_TIMEOUT_SEC)
|
||||
|
||||
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
|
||||
|
||||
# Step 2: Collect responses and update state cache
|
||||
responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=MEDIUM_TIMEOUT_SEC)
|
||||
for recv_id, motor_name in recv_id_to_motor.items():
|
||||
if msg := responses.get(recv_id):
|
||||
self._process_response(motor_name, msg)
|
||||
else:
|
||||
# Fall back to individual writes
|
||||
for motor, value in values.items():
|
||||
self.write(data_name, motor, value)
|
||||
|
||||
def read_calibration(self) -> dict[str, MotorCalibration]:
|
||||
"""Read calibration data from motors."""
|
||||
# Damiao motors don't store calibration internally
|
||||
# Return existing calibration or empty dict
|
||||
return self.calibration if self.calibration else {}
|
||||
|
||||
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
|
||||
"""Write calibration data to motors."""
|
||||
# Damiao motors don't store calibration internally
|
||||
# Just cache it in memory
|
||||
if cache:
|
||||
self.calibration = calibration_dict
|
||||
|
||||
def record_ranges_of_motion(
|
||||
self,
|
||||
motors: NameOrID | list[NameOrID] | None = None,
|
||||
display_values: bool = True,
|
||||
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
|
||||
"""
|
||||
Interactively record the min/max values of each motor in degrees.
|
||||
|
||||
Move the joints by hand (with torque disabled) while the method streams live positions.
|
||||
Press Enter to finish.
|
||||
"""
|
||||
target_motors = self._get_motors_list(motors)
|
||||
|
||||
self.disable_torque(target_motors)
|
||||
time.sleep(LONG_TIMEOUT_SEC)
|
||||
|
||||
start_positions = self.sync_read("Present_Position", target_motors)
|
||||
mins = start_positions.copy()
|
||||
maxes = start_positions.copy()
|
||||
|
||||
print("\nMove joints through their full range of motion. Press ENTER when done.")
|
||||
user_pressed_enter = False
|
||||
|
||||
while not user_pressed_enter:
|
||||
positions = self.sync_read("Present_Position", target_motors)
|
||||
|
||||
for motor in target_motors:
|
||||
if motor in positions:
|
||||
mins[motor] = min(positions[motor], mins.get(motor, positions[motor]))
|
||||
maxes[motor] = max(positions[motor], maxes.get(motor, positions[motor]))
|
||||
|
||||
if display_values:
|
||||
print("\n" + "=" * 50)
|
||||
print(f"{'MOTOR':<20} | {'MIN (deg)':>12} | {'POS (deg)':>12} | {'MAX (deg)':>12}")
|
||||
print("-" * 50)
|
||||
for motor in target_motors:
|
||||
if motor in positions:
|
||||
print(
|
||||
f"{motor:<20} | {mins[motor]:>12.1f} | {positions[motor]:>12.1f} | {maxes[motor]:>12.1f}"
|
||||
)
|
||||
|
||||
if enter_pressed():
|
||||
user_pressed_enter = True
|
||||
|
||||
if display_values and not user_pressed_enter:
|
||||
move_cursor_up(len(target_motors) + 4)
|
||||
|
||||
time.sleep(LONG_TIMEOUT_SEC)
|
||||
|
||||
self.enable_torque(target_motors)
|
||||
|
||||
for motor in target_motors:
|
||||
if (motor in mins) and (motor in maxes) and (int(abs(maxes[motor] - mins[motor])) < 5):
|
||||
raise ValueError(f"Motor {motor} has insufficient range of motion (< 5 degrees)")
|
||||
|
||||
return mins, maxes
|
||||
|
||||
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
|
||||
"""Convert motor specification to list of motor names."""
|
||||
if motors is None:
|
||||
return list(self.motors.keys())
|
||||
elif isinstance(motors, str):
|
||||
return [motors]
|
||||
elif isinstance(motors, list):
|
||||
return motors
|
||||
else:
|
||||
raise TypeError(f"Invalid motors type: {type(motors)}")
|
||||
|
||||
def _get_motor_id(self, motor: NameOrID) -> int:
|
||||
"""Get CAN ID for a motor."""
|
||||
if isinstance(motor, str):
|
||||
if motor in self.motors:
|
||||
return self.motors[motor].id
|
||||
else:
|
||||
raise ValueError(f"Unknown motor: {motor}")
|
||||
else:
|
||||
return motor
|
||||
|
||||
def _get_motor_name(self, motor: NameOrID) -> str:
|
||||
"""Get motor name from name or ID."""
|
||||
if isinstance(motor, str):
|
||||
return motor
|
||||
else:
|
||||
for name, m in self.motors.items():
|
||||
if m.id == motor:
|
||||
return name
|
||||
raise ValueError(f"Unknown motor ID: {motor}")
|
||||
|
||||
def _get_motor_recv_id(self, motor: NameOrID) -> int:
|
||||
"""Get motor recv_id from name or ID."""
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_obj = self.motors.get(motor_name)
|
||||
if motor_obj and motor_obj.recv_id is not None:
|
||||
return motor_obj.recv_id
|
||||
else:
|
||||
raise ValueError(f"Motor {motor_obj} doesn't have a valid recv_id (None).")
|
||||
|
||||
@cached_property
|
||||
def is_calibrated(self) -> bool:
|
||||
"""Check if motors are calibrated."""
|
||||
return bool(self.calibration)
|
||||
@@ -0,0 +1,209 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Configuration tables for Damiao motors."""
|
||||
|
||||
from enum import IntEnum
|
||||
|
||||
|
||||
# Motor type definitions
|
||||
class MotorType(IntEnum):
|
||||
DM3507 = 0
|
||||
DM4310 = 1
|
||||
DM4310_48V = 2
|
||||
DM4340 = 3
|
||||
DM4340_48V = 4
|
||||
DM6006 = 5
|
||||
DM8006 = 6
|
||||
DM8009 = 7
|
||||
DM10010L = 8
|
||||
DM10010 = 9
|
||||
DMH3510 = 10
|
||||
DMH6215 = 11
|
||||
DMG6220 = 12
|
||||
|
||||
|
||||
# Control modes
|
||||
class ControlMode(IntEnum):
|
||||
MIT = 1
|
||||
POS_VEL = 2
|
||||
VEL = 3
|
||||
TORQUE_POS = 4
|
||||
|
||||
|
||||
# Motor variable IDs (RID)
|
||||
class MotorVariable(IntEnum):
|
||||
UV_VALUE = 0
|
||||
KT_VALUE = 1
|
||||
OT_VALUE = 2
|
||||
OC_VALUE = 3
|
||||
ACC = 4
|
||||
DEC = 5
|
||||
MAX_SPD = 6
|
||||
MST_ID = 7
|
||||
ESC_ID = 8
|
||||
TIMEOUT = 9
|
||||
CTRL_MODE = 10
|
||||
DAMP = 11
|
||||
INERTIA = 12
|
||||
HW_VER = 13
|
||||
SW_VER = 14
|
||||
SN = 15
|
||||
NPP = 16
|
||||
RS = 17
|
||||
LS = 18
|
||||
FLUX = 19
|
||||
GR = 20
|
||||
PMAX = 21
|
||||
VMAX = 22
|
||||
TMAX = 23
|
||||
I_BW = 24
|
||||
KP_ASR = 25
|
||||
KI_ASR = 26
|
||||
KP_APR = 27
|
||||
KI_APR = 28
|
||||
OV_VALUE = 29
|
||||
GREF = 30
|
||||
DETA = 31
|
||||
V_BW = 32
|
||||
IQ_C1 = 33
|
||||
VL_C1 = 34
|
||||
CAN_BR = 35
|
||||
SUB_VER = 36
|
||||
U_OFF = 50
|
||||
V_OFF = 51
|
||||
K1 = 52
|
||||
K2 = 53
|
||||
M_OFF = 54
|
||||
DIR = 55
|
||||
P_M = 80
|
||||
XOUT = 81
|
||||
|
||||
|
||||
# Motor limit parameters [PMAX, VMAX, TMAX]
|
||||
# PMAX: Maximum position (rad)
|
||||
# VMAX: Maximum velocity (rad/s)
|
||||
# TMAX: Maximum torque (N·m)
|
||||
MOTOR_LIMIT_PARAMS = {
|
||||
MotorType.DM3507: (12.5, 30, 10),
|
||||
MotorType.DM4310: (12.5, 30, 10),
|
||||
MotorType.DM4310_48V: (12.5, 50, 10),
|
||||
MotorType.DM4340: (12.5, 8, 28),
|
||||
MotorType.DM4340_48V: (12.5, 10, 28),
|
||||
MotorType.DM6006: (12.5, 45, 20),
|
||||
MotorType.DM8006: (12.5, 45, 40),
|
||||
MotorType.DM8009: (12.5, 45, 54),
|
||||
MotorType.DM10010L: (12.5, 25, 200),
|
||||
MotorType.DM10010: (12.5, 20, 200),
|
||||
MotorType.DMH3510: (12.5, 280, 1),
|
||||
MotorType.DMH6215: (12.5, 45, 10),
|
||||
MotorType.DMG6220: (12.5, 45, 10),
|
||||
}
|
||||
|
||||
# Motor model names
|
||||
MODEL_NAMES = {
|
||||
MotorType.DM3507: "dm3507",
|
||||
MotorType.DM4310: "dm4310",
|
||||
MotorType.DM4310_48V: "dm4310_48v",
|
||||
MotorType.DM4340: "dm4340",
|
||||
MotorType.DM4340_48V: "dm4340_48v",
|
||||
MotorType.DM6006: "dm6006",
|
||||
MotorType.DM8006: "dm8006",
|
||||
MotorType.DM8009: "dm8009",
|
||||
MotorType.DM10010L: "dm10010l",
|
||||
MotorType.DM10010: "dm10010",
|
||||
MotorType.DMH3510: "dmh3510",
|
||||
MotorType.DMH6215: "dmh6215",
|
||||
MotorType.DMG6220: "dmg6220",
|
||||
}
|
||||
|
||||
# Motor resolution table (encoder counts per revolution)
|
||||
MODEL_RESOLUTION = {
|
||||
"dm3507": 65536,
|
||||
"dm4310": 65536,
|
||||
"dm4310_48v": 65536,
|
||||
"dm4340": 65536,
|
||||
"dm4340_48v": 65536,
|
||||
"dm6006": 65536,
|
||||
"dm8006": 65536,
|
||||
"dm8009": 65536,
|
||||
"dm10010l": 65536,
|
||||
"dm10010": 65536,
|
||||
"dmh3510": 65536,
|
||||
"dmh6215": 65536,
|
||||
"dmg6220": 65536,
|
||||
}
|
||||
|
||||
# CAN baudrates supported by Damiao motors
|
||||
AVAILABLE_BAUDRATES = [
|
||||
125000, # 0: 125 kbps
|
||||
200000, # 1: 200 kbps
|
||||
250000, # 2: 250 kbps
|
||||
500000, # 3: 500 kbps
|
||||
1000000, # 4: 1 mbps (default for OpenArms)
|
||||
2000000, # 5: 2 mbps
|
||||
2500000, # 6: 2.5 mbps
|
||||
3200000, # 7: 3.2 mbps
|
||||
4000000, # 8: 4 mbps
|
||||
5000000, # 9: 5 mbps
|
||||
]
|
||||
DEFAULT_BAUDRATE = 1000000 # 1 Mbps is standard for OpenArms
|
||||
|
||||
# Default timeout in milliseconds
|
||||
DEFAULT_TIMEOUT_MS = 1000
|
||||
|
||||
# OpenArms specific configurations
|
||||
# Based on: https://docs.openarm.dev/software/setup/configure-test
|
||||
# OpenArms has 7 DOF per arm (14 total for dual arm)
|
||||
OPENARMS_ARM_MOTOR_IDS = {
|
||||
"joint_1": {"send": 0x01, "recv": 0x11}, # J1 - Shoulder pan
|
||||
"joint_2": {"send": 0x02, "recv": 0x12}, # J2 - Shoulder lift
|
||||
"joint_3": {"send": 0x03, "recv": 0x13}, # J3 - Elbow flex
|
||||
"joint_4": {"send": 0x04, "recv": 0x14}, # J4 - Wrist flex
|
||||
"joint_5": {"send": 0x05, "recv": 0x15}, # J5 - Wrist roll
|
||||
"joint_6": {"send": 0x06, "recv": 0x16}, # J6 - Wrist pitch
|
||||
"joint_7": {"send": 0x07, "recv": 0x17}, # J7 - Wrist rotation
|
||||
}
|
||||
|
||||
OPENARMS_GRIPPER_MOTOR_IDS = {
|
||||
"gripper": {"send": 0x08, "recv": 0x18}, # J8 - Gripper
|
||||
}
|
||||
|
||||
# Default motor types for OpenArms
|
||||
OPENARMS_DEFAULT_MOTOR_TYPES = {
|
||||
"joint_1": MotorType.DM8009, # Shoulder pan - high torque
|
||||
"joint_2": MotorType.DM8009, # Shoulder lift - high torque
|
||||
"joint_3": MotorType.DM4340, # Shoulder rotation
|
||||
"joint_4": MotorType.DM4340, # Elbow flex
|
||||
"joint_5": MotorType.DM4310, # Wrist roll
|
||||
"joint_6": MotorType.DM4310, # Wrist pitch
|
||||
"joint_7": MotorType.DM4310, # Wrist rotation
|
||||
"gripper": MotorType.DM4310, # Gripper
|
||||
}
|
||||
|
||||
# MIT control parameter ranges
|
||||
MIT_KP_RANGE = (0.0, 500.0)
|
||||
MIT_KD_RANGE = (0.0, 5.0)
|
||||
|
||||
# CAN frame command IDs
|
||||
CAN_CMD_ENABLE = 0xFC
|
||||
CAN_CMD_DISABLE = 0xFD
|
||||
CAN_CMD_SET_ZERO = 0xFE
|
||||
CAN_CMD_REFRESH = 0xCC
|
||||
CAN_CMD_QUERY_PARAM = 0x33
|
||||
CAN_CMD_WRITE_PARAM = 0x55
|
||||
CAN_CMD_SAVE_PARAM = 0xAA
|
||||
|
||||
# CAN ID for parameter operations
|
||||
CAN_PARAM_ID = 0x7FF
|
||||
@@ -22,9 +22,8 @@ import logging
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
|
||||
from lerobot.motors.encoding_utils import decode_twos_complement, encode_twos_complement
|
||||
|
||||
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
|
||||
from ..encoding_utils import decode_twos_complement, encode_twos_complement
|
||||
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
|
||||
from .tables import (
|
||||
AVAILABLE_BAUDRATES,
|
||||
MODEL_BAUDRATE_TABLE,
|
||||
@@ -100,7 +99,7 @@ def _split_into_byte_chunks(value: int, length: int) -> list[int]:
|
||||
return data
|
||||
|
||||
|
||||
class DynamixelMotorsBus(MotorsBus):
|
||||
class DynamixelMotorsBus(SerialMotorsBus):
|
||||
"""
|
||||
The Dynamixel implementation for a MotorsBus. It relies on the python dynamixel sdk to communicate with
|
||||
the motors. For more info, see the Dynamixel SDK Documentation:
|
||||
@@ -203,9 +202,9 @@ class DynamixelMotorsBus(MotorsBus):
|
||||
for motor in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
|
||||
def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None:
|
||||
def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None:
|
||||
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
|
||||
self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for motor in self._get_motors_list(motors):
|
||||
|
||||
@@ -17,9 +17,8 @@ from copy import deepcopy
|
||||
from enum import Enum
|
||||
from pprint import pformat
|
||||
|
||||
from lerobot.motors.encoding_utils import decode_sign_magnitude, encode_sign_magnitude
|
||||
|
||||
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
|
||||
from ..encoding_utils import decode_sign_magnitude, encode_sign_magnitude
|
||||
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
|
||||
from .tables import (
|
||||
FIRMWARE_MAJOR_VERSION,
|
||||
FIRMWARE_MINOR_VERSION,
|
||||
@@ -96,7 +95,7 @@ def patch_setPacketTimeout(self, packet_length): # noqa: N802
|
||||
self.packet_timeout = (self.tx_time_per_byte * packet_length) + (self.tx_time_per_byte * 3.0) + 50
|
||||
|
||||
|
||||
class FeetechMotorsBus(MotorsBus):
|
||||
class FeetechMotorsBus(SerialMotorsBus):
|
||||
"""
|
||||
The FeetechMotorsBus class allows to efficiently read and write to the attached motors. It relies on the
|
||||
python feetech sdk to communicate with the motors, which is itself based on the dynamixel sdk.
|
||||
@@ -298,11 +297,11 @@ class FeetechMotorsBus(MotorsBus):
|
||||
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
self.write("Lock", motor, 0, num_retry=num_retry)
|
||||
|
||||
def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None:
|
||||
def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None:
|
||||
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
|
||||
self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
addr, length = get_address(self.model_ctrl_table, model, "Lock")
|
||||
self._write(addr, length, motor_id, 0, num_retry=num_retry)
|
||||
self._write(addr, length, motor, 0, num_retry=num_retry)
|
||||
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for motor in self._get_motors_list(motors):
|
||||
|
||||
@@ -205,6 +205,7 @@ MODEL_BAUDRATE_TABLE = {
|
||||
|
||||
# Sign-Magnitude encoding bits
|
||||
STS_SMS_SERIES_ENCODINGS_TABLE = {
|
||||
"Present_Load": 10,
|
||||
"Homing_Offset": 11,
|
||||
"Goal_Position": 15,
|
||||
"Goal_Velocity": 15,
|
||||
|
||||
@@ -19,6 +19,8 @@
|
||||
# TODO(aliberts): Add block noqa when feature below is available
|
||||
# https://github.com/astral-sh/ruff/issues/3711
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
@@ -32,7 +34,7 @@ import serial
|
||||
from deepdiff import DeepDiff
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
||||
|
||||
NameOrID: TypeAlias = str | int
|
||||
@@ -41,6 +43,81 @@ Value: TypeAlias = int | float
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MotorsBusBase(abc.ABC):
|
||||
"""
|
||||
Base class for all motor bus implementations.
|
||||
|
||||
This is a minimal interface that all motor buses must implement, regardless of their
|
||||
communication protocol (serial, CAN, etc.).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
port: str,
|
||||
motors: dict[str, Motor],
|
||||
calibration: dict[str, MotorCalibration] | None = None,
|
||||
):
|
||||
self.port = port
|
||||
self.motors = motors
|
||||
self.calibration = calibration if calibration else {}
|
||||
|
||||
@abc.abstractmethod
|
||||
def connect(self, handshake: bool = True) -> None:
|
||||
"""Establish connection to the motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def disconnect(self, disable_torque: bool = True) -> None:
|
||||
"""Disconnect from the motors."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if connected to the motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def read(self, data_name: str, motor: str) -> Value:
|
||||
"""Read a value from a single motor."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def write(self, data_name: str, motor: str, value: Value) -> None:
|
||||
"""Write a value to a single motor."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def sync_read(self, data_name: str, motors: str | list[str] | None = None) -> dict[str, Value]:
|
||||
"""Read a value from multiple motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
|
||||
"""Write values to multiple motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
"""Enable torque on selected motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
"""Disable torque on selected motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def read_calibration(self) -> dict[str, MotorCalibration]:
|
||||
"""Read calibration parameters from the motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
|
||||
"""Write calibration parameters to the motors."""
|
||||
pass
|
||||
|
||||
|
||||
def get_ctrl_table(model_ctrl_table: dict[str, dict], model: str) -> dict[str, tuple[int, int]]:
|
||||
ctrl_table = model_ctrl_table.get(model)
|
||||
if ctrl_table is None:
|
||||
@@ -97,6 +174,8 @@ class Motor:
|
||||
id: int
|
||||
model: str
|
||||
norm_mode: MotorNormMode
|
||||
motor_type_str: str | None = None
|
||||
recv_id: int | None = None
|
||||
|
||||
|
||||
class PortHandler(Protocol):
|
||||
@@ -203,15 +282,15 @@ class GroupSyncWrite(Protocol):
|
||||
def txPacket(self): ...
|
||||
|
||||
|
||||
class MotorsBus(abc.ABC):
|
||||
class SerialMotorsBus(MotorsBusBase):
|
||||
"""
|
||||
A MotorsBus allows to efficiently read and write to the attached motors.
|
||||
A SerialMotorsBus allows to efficiently read and write to motors connected via serial communication.
|
||||
It represents several motors daisy-chained together and connected through a serial port.
|
||||
There are currently two implementations of this abstract class:
|
||||
There are currently two implementations of this class:
|
||||
- DynamixelMotorsBus
|
||||
- FeetechMotorsBus
|
||||
|
||||
Note: This class may evolve in the future should we add support for other types of bus.
|
||||
This class is specifically for serial-based motor protocols (Dynamixel, Feetech, etc.).
|
||||
|
||||
A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
|
||||
To find the port, you can run our utility script:
|
||||
@@ -260,9 +339,7 @@ class MotorsBus(abc.ABC):
|
||||
motors: dict[str, Motor],
|
||||
calibration: dict[str, MotorCalibration] | None = None,
|
||||
):
|
||||
self.port = port
|
||||
self.motors = motors
|
||||
self.calibration = calibration if calibration else {}
|
||||
super().__init__(port, motors, calibration)
|
||||
|
||||
self.port_handler: PortHandler
|
||||
self.packet_handler: PacketHandler
|
||||
@@ -411,6 +488,7 @@ class MotorsBus(abc.ABC):
|
||||
"""bool: `True` if the underlying serial port is open."""
|
||||
return self.port_handler.is_open
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, handshake: bool = True) -> None:
|
||||
"""Open the serial port and initialise communication.
|
||||
|
||||
@@ -422,10 +500,6 @@ class MotorsBus(abc.ABC):
|
||||
DeviceAlreadyConnectedError: The port is already open.
|
||||
ConnectionError: The underlying SDK failed to open the port or the handshake did not succeed.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is already connected. Do not call `{self.__class__.__name__}.connect()` twice."
|
||||
)
|
||||
|
||||
self._connect(handshake)
|
||||
self.set_timeout()
|
||||
@@ -447,6 +521,7 @@ class MotorsBus(abc.ABC):
|
||||
def _handshake(self) -> None:
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self, disable_torque: bool = True) -> None:
|
||||
"""Close the serial port (optionally disabling torque first).
|
||||
|
||||
@@ -455,10 +530,6 @@ class MotorsBus(abc.ABC):
|
||||
closing the port. This can prevent damaging motors if they are left applying resisting torque
|
||||
after disconnect.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. Try running `{self.__class__.__name__}.connect()` first."
|
||||
)
|
||||
|
||||
if disable_torque:
|
||||
self.port_handler.clearPort()
|
||||
@@ -538,7 +609,7 @@ class MotorsBus(abc.ABC):
|
||||
self.set_baudrate(self.default_baudrate)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _find_single_motor(self, motor: str, initial_baudrate: int | None) -> tuple[int, int]:
|
||||
def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -551,13 +622,13 @@ class MotorsBus(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
"""Disable torque on selected motors.
|
||||
|
||||
Disabling Torque allows to write to the motors' permanent memory area (EPROM/EEPROM).
|
||||
|
||||
Args:
|
||||
motors (int | str | list[str] | None, optional): Target motors. Accepts a motor name, an ID, a
|
||||
motors ( str | list[str] | None, optional): Target motors. Accepts a motor name, an ID, a
|
||||
list of names or `None` to affect every registered motor. Defaults to `None`.
|
||||
num_retry (int, optional): Number of additional retry attempts on communication failure.
|
||||
Defaults to 0.
|
||||
@@ -907,6 +978,7 @@ class MotorsBus(abc.ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
def read(
|
||||
self,
|
||||
data_name: str,
|
||||
@@ -927,10 +999,6 @@ class MotorsBus(abc.ABC):
|
||||
Returns:
|
||||
Value: Raw or normalised value depending on *normalize*.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||
)
|
||||
|
||||
id_ = self.motors[motor].id
|
||||
model = self.motors[motor].model
|
||||
@@ -981,6 +1049,7 @@ class MotorsBus(abc.ABC):
|
||||
|
||||
return value, comm, error
|
||||
|
||||
@check_if_not_connected
|
||||
def write(
|
||||
self, data_name: str, motor: str, value: Value, *, normalize: bool = True, num_retry: int = 0
|
||||
) -> None:
|
||||
@@ -999,10 +1068,6 @@ class MotorsBus(abc.ABC):
|
||||
normalize (bool, optional): Enable or disable normalisation. Defaults to `True`.
|
||||
num_retry (int, optional): Retry attempts. Defaults to `0`.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||
)
|
||||
|
||||
id_ = self.motors[motor].id
|
||||
model = self.motors[motor].model
|
||||
@@ -1044,6 +1109,7 @@ class MotorsBus(abc.ABC):
|
||||
|
||||
return comm, error
|
||||
|
||||
@check_if_not_connected
|
||||
def sync_read(
|
||||
self,
|
||||
data_name: str,
|
||||
@@ -1063,10 +1129,6 @@ class MotorsBus(abc.ABC):
|
||||
Returns:
|
||||
dict[str, Value]: Mapping *motor name → value*.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||
)
|
||||
|
||||
self._assert_protocol_is_compatible("sync_read")
|
||||
|
||||
@@ -1139,6 +1201,7 @@ class MotorsBus(abc.ABC):
|
||||
# for id_ in motor_ids:
|
||||
# value = self.sync_reader.getData(id_, address, length)
|
||||
|
||||
@check_if_not_connected
|
||||
def sync_write(
|
||||
self,
|
||||
data_name: str,
|
||||
@@ -1160,10 +1223,6 @@ class MotorsBus(abc.ABC):
|
||||
normalize (bool, optional): If `True` (default) convert values from the user range to raw units.
|
||||
num_retry (int, optional): Retry attempts. Defaults to `0`.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||
)
|
||||
|
||||
ids_values = self._get_ids_values_dict(values)
|
||||
models = [self._id_to_model(id_) for id_ in ids_values]
|
||||
@@ -1212,3 +1271,7 @@ class MotorsBus(abc.ABC):
|
||||
for id_, value in ids_values.items():
|
||||
data = self._serialize_data(value, length)
|
||||
self.sync_writer.addParam(id_, data)
|
||||
|
||||
|
||||
# Backward compatibility alias
|
||||
MotorsBus: TypeAlias = SerialMotorsBus
|
||||
|
||||
@@ -28,7 +28,7 @@ class ACTConfig(PreTrainedConfig):
|
||||
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_shapes` and 'output_shapes`.
|
||||
Those are: `input_features` and `output_features`.
|
||||
|
||||
Notes on the inputs and outputs:
|
||||
- Either:
|
||||
@@ -48,21 +48,12 @@ class ACTConfig(PreTrainedConfig):
|
||||
This should be no greater than the chunk size. For example, if the chunk size size 100, you may
|
||||
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
|
||||
environment, and throws the other 50 out.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||
include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
|
||||
`None` means no pretrained weights.
|
||||
|
||||
@@ -30,7 +30,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_shapes` and `output_shapes`.
|
||||
Those are: `input_features` and `output_features`.
|
||||
|
||||
Notes on the inputs and outputs:
|
||||
- "observation.state" is required as an input key.
|
||||
@@ -48,21 +48,12 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
|
||||
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
||||
See `DiffusionPolicy.select_action` for more details.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||
include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||
within the image size. If None, no cropping is done.
|
||||
@@ -73,7 +64,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
||||
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
|
||||
use_separate_rgb_encoders_per_camera: Whether to use a separate RGB encoder for each camera view.
|
||||
use_separate_rgb_encoder_per_camera: Whether to use a separate RGB encoder for each camera view.
|
||||
down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
|
||||
You may provide a variable number of dimensions, therefore also controlling the degree of
|
||||
downsampling.
|
||||
|
||||
@@ -34,6 +34,7 @@ from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pi05_full.configuration_pi05 import PI05FullConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
@@ -390,6 +391,13 @@ def make_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
elif isinstance(policy_cfg, PI05FullConfig):
|
||||
from lerobot.policies.pi05_full.processor_pi05 import make_pi05_full_pre_post_processors
|
||||
|
||||
processors = make_pi05_full_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
|
||||
@@ -32,16 +32,22 @@ Notes:
|
||||
from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import os
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TypeVar
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||
from lerobot.policies.groot.groot_n1 import GR00TN15
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||
|
||||
T = TypeVar("T", bound="GrootPolicy")
|
||||
|
||||
|
||||
class GrootPolicy(PreTrainedPolicy):
|
||||
@@ -90,6 +96,129 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
"""Reset policy state when environment resets."""
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: builtins.type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
config: GrootConfig | None = None,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
strict: bool = True,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
"""Load Groot policy from pretrained model.
|
||||
|
||||
Handles two cases:
|
||||
1. Base GR00T models (e.g., 'nvidia/GR00T-N1.5-3B') - loads the raw model
|
||||
2. Fine-tuned LeRobot checkpoints - loads config and weights from safetensors
|
||||
|
||||
Args:
|
||||
pretrained_name_or_path: Path to the GR00T model or fine-tuned checkpoint
|
||||
config: Optional GrootConfig. If None, loads from checkpoint or creates default
|
||||
force_download: Force download even if cached
|
||||
resume_download: Resume interrupted download
|
||||
proxies: Proxy settings
|
||||
token: HuggingFace authentication token
|
||||
cache_dir: Cache directory path
|
||||
local_files_only: Only use local files
|
||||
revision: Specific model revision
|
||||
strict: Strict state dict loading
|
||||
**kwargs: Additional arguments (passed to config)
|
||||
|
||||
Returns:
|
||||
Initialized GrootPolicy instance with loaded model
|
||||
"""
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
print(
|
||||
"The Groot policy is a wrapper around Nvidia's GR00T N1.5 model.\n"
|
||||
f"Loading pretrained model from: {pretrained_name_or_path}"
|
||||
)
|
||||
|
||||
model_id = str(pretrained_name_or_path)
|
||||
is_finetuned_checkpoint = False
|
||||
|
||||
# Check if this is a fine-tuned LeRobot checkpoint (has model.safetensors)
|
||||
try:
|
||||
if os.path.isdir(model_id):
|
||||
is_finetuned_checkpoint = os.path.exists(os.path.join(model_id, SAFETENSORS_SINGLE_FILE))
|
||||
else:
|
||||
# Try to download the safetensors file to check if it exists
|
||||
try:
|
||||
hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=SAFETENSORS_SINGLE_FILE,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=False, # Just check, don't force download
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
is_finetuned_checkpoint = True
|
||||
except HfHubHTTPError:
|
||||
is_finetuned_checkpoint = False
|
||||
except Exception:
|
||||
is_finetuned_checkpoint = False
|
||||
|
||||
if is_finetuned_checkpoint:
|
||||
# This is a fine-tuned LeRobot checkpoint - use parent class loading
|
||||
print("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
|
||||
return super().from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
config=config,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
strict=strict,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# This is a base GR00T model - load it fresh
|
||||
print("Detected base GR00T model, loading from HuggingFace...")
|
||||
|
||||
if config is None:
|
||||
# Create default config with the pretrained path
|
||||
config = GrootConfig(base_model_path=str(pretrained_name_or_path))
|
||||
|
||||
# Add minimal visual feature required for validation
|
||||
# validate_features() will automatically add state and action features
|
||||
# These are placeholders - actual robot features come from the preprocessor
|
||||
if not config.input_features:
|
||||
config.input_features = {
|
||||
f"{OBS_IMAGES}.camera": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 224, 224), # Default image size from config
|
||||
),
|
||||
}
|
||||
else:
|
||||
# Override the base_model_path with the provided path
|
||||
config.base_model_path = str(pretrained_name_or_path)
|
||||
|
||||
# Pass through any additional config overrides from kwargs
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(config, key):
|
||||
setattr(config, key, value)
|
||||
|
||||
# Create a fresh policy instance - this will automatically load the GR00T model
|
||||
# in __init__ via _create_groot_model()
|
||||
policy = cls(config)
|
||||
|
||||
policy.eval()
|
||||
return policy
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
|
||||
@@ -1297,3 +1297,14 @@ class PI0Policy(PreTrainedPolicy):
|
||||
loss = losses.mean()
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
|
||||
def _get_default_peft_targets(self) -> dict[str, any]:
|
||||
"""Return default PEFT target modules for PI0 fine-tuning."""
|
||||
common_projections = (
|
||||
"state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
|
||||
)
|
||||
target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))"
|
||||
return {
|
||||
"target_modules": target_modules,
|
||||
"modules_to_save": [],
|
||||
}
|
||||
|
||||
@@ -1270,3 +1270,14 @@ class PI05Policy(PreTrainedPolicy):
|
||||
loss = losses.mean()
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
|
||||
def _get_default_peft_targets(self) -> dict[str, any]:
|
||||
"""Return default PEFT target modules for PI0.5 fine-tuning."""
|
||||
common_projections = (
|
||||
"state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
|
||||
)
|
||||
target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))"
|
||||
return {
|
||||
"target_modules": target_modules,
|
||||
"modules_to_save": [],
|
||||
}
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
# π₀.₅ (pi05)
|
||||
|
||||
This repository contains the Hugging Face port of **π₀.₅**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by the Physical Intelligence.
|
||||
It is designed as a **Vision-Language-Action model with open-world generalization**.
|
||||
|
||||
---
|
||||
|
||||
## Model Overview
|
||||
|
||||
| Feature | π₀ | π₀.₅ |
|
||||
| -------------------- | ------------------------------------------------------ | ----------------------------------------- |
|
||||
| Time Conditioning | Concatenates time with actions via `action_time_mlp_*` | Uses `time_mlp_*` for AdaRMS conditioning |
|
||||
| AdaRMS | Not used | Used in action expert |
|
||||
| Tokenizer Length | 48 tokens | 200 tokens |
|
||||
| Discrete State Input | False (Uses `state_proj` layer) | True |
|
||||
| Parameter Count | Higher (includes state embedding) | Lower (no state embedding) |
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this work, please cite both **OpenPI** and the π₀.₅ paper:
|
||||
|
||||
```bibtex
|
||||
@misc{openpi2024,
|
||||
author = {Physical Intelligence Lab},
|
||||
title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
|
||||
year = {2024},
|
||||
publisher = {GitHub},
|
||||
howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
|
||||
license = {Apache-2.0}
|
||||
}
|
||||
|
||||
@misc{intelligence2025pi05visionlanguageactionmodelopenworld,
|
||||
title = {π₀.₅: a Vision-Language-Action Model with Open-World Generalization},
|
||||
author = {Physical Intelligence and Kevin Black and Noah Brown and James Darpinian and Karan Dhabalia and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Manuel Y. Galliker and Dibya Ghosh and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Devin LeBlanc and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Allen Z. Ren and Lucy Xiaoyang Shi and Laura Smith and Jost Tobias Springenberg and Kyle Stachowicz and James Tanner and Quan Vuong and Homer Walke and Anna Walling and Haohuan Wang and Lili Yu and Ury Zhilinsky},
|
||||
year = {2025},
|
||||
eprint = {2504.16054},
|
||||
archivePrefix= {arXiv},
|
||||
primaryClass = {cs.LG},
|
||||
url = {https://arxiv.org/abs/2504.16054},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
|
||||
@@ -0,0 +1,21 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_pi05 import PI05FullConfig
|
||||
from .modeling_pi05 import PI05FullPolicy
|
||||
from .processor_pi05 import make_pi05_full_pre_post_processors
|
||||
|
||||
__all__ = ["PI05FullConfig", "PI05FullPolicy", "make_pi05_full_pre_post_processors"]
|
||||
@@ -0,0 +1,50 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Example script to run synthetic data generation with Qwen VLM
|
||||
# This generates user prompts and robot utterances for hierarchical policy training
|
||||
|
||||
# Configuration
|
||||
REPO_ID="lerobot/libero_10"
|
||||
MODEL="Qwen/Qwen3-VL-30B-A3B-Instruct"
|
||||
# or: MODEL="Qwen/Qwen2-VL-7B-Instruct"
|
||||
|
||||
|
||||
OUTPUT_DIR="/fsx/jade_choghari/outputs/libero-10-annotate-high"
|
||||
|
||||
BATCH_SIZE=16
|
||||
TEMPERATURE=0.9
|
||||
SAMPLE_INTERVAL=5.0 # generate dialogue every 1 second (all episodes processed)
|
||||
|
||||
# Run subtask annotation
|
||||
# python /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05_full/annotate/subtask_annotate.py \
|
||||
# --repo-id "$REPO_ID" \
|
||||
# --video-key observation.images.image \
|
||||
# --output-dir "$OUTPUT_DIR" \
|
||||
# --skip-existing \
|
||||
# --output-repo-id "jadechoghari/libero10-annotate" \
|
||||
# --batch-size "$BATCH_SIZE" \
|
||||
# run synthetic data generation (all episodes processed)
|
||||
# python examples/dataset/annotate_pgen.py \
|
||||
# --repo-id "$REPO_ID" \
|
||||
# --model "$MODEL" \
|
||||
# --output-dir "$OUTPUT_DIR" \
|
||||
# --temperature "$TEMPERATURE" \
|
||||
# --batch-size "$BATCH_SIZE" \
|
||||
# --sample-interval "$SAMPLE_INTERVAL" \
|
||||
# --image-key observation.images.base \
|
||||
# --num-image-views-per-sample 1
|
||||
|
||||
# for faster testing, increase sample interval:
|
||||
# --sample-interval 5.0 # Samples every 5 seconds (much faster)
|
||||
|
||||
# to push to hub after generation:
|
||||
# add --push-to-hub flag
|
||||
|
||||
# efficient batch processing: 4 episodes at once
|
||||
python /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05_full/annotate/high_level_annotate.py \
|
||||
--data-dir "/fsx/jade_choghari/outputs/libero-10-annotate" \
|
||||
--output-dir "$OUTPUT_DIR" \
|
||||
--video-mode \
|
||||
--video-key observation.images.image \
|
||||
--video-batch-size "$BATCH_SIZE" \
|
||||
--sample-interval 5.0
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,52 @@
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
import lerobot
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
# /fsx/jade_choghari/data/libero_10_subtasks_kw_converted
|
||||
dataset = LeRobotDataset(repo_id="lerobot/libero_10_image_subtask")
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
batch_size=2,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained(
|
||||
pretrained_name_or_path="/fsx/jade_choghari/models/pi05-base",
|
||||
)
|
||||
cfg.dtype = "bfloat16"
|
||||
|
||||
pre_processor, post_processor = make_pre_post_processors(
|
||||
policy_cfg=cfg,
|
||||
pretrained_path="/fsx/jade_choghari/models/pi05-base",
|
||||
)
|
||||
batch = next(iter(dataloader))
|
||||
breakpoint()
|
||||
batch1 = pre_processor(batch)
|
||||
breakpoint()
|
||||
print(batch.keys())
|
||||
# print(batch['task_index_high_level'].shape)
|
||||
# print(batch['task_index_high_level'])
|
||||
# print(batch['user_prompt'][0])
|
||||
# print(batch['robot_utterance'][0])
|
||||
# print(batch['task'][0])
|
||||
|
||||
valid_episode_list = []
|
||||
for episode_idx in range(len(dataset.meta.episodes)):
|
||||
subtask_index = dataset[episode_idx]["subtask_index"]
|
||||
valid_episode_list.append(episode_idx)
|
||||
|
||||
print(len(valid_episode_list))
|
||||
|
||||
# read this parquet /fsx/jade_choghari/outputs/pgen_annotations1/meta/tasks.parquett
|
||||
# import pandas as pd
|
||||
# tasks_df = pd.read_parquet('/fsx/jade_choghari/outputs/pgen_annotations1/meta/tasks.parquet')
|
||||
|
||||
# # print all
|
||||
# print(tasks_df.columns)
|
||||
# breakpoint()
|
||||
@@ -0,0 +1,49 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Example script to run synthetic data generation with Qwen VLM
|
||||
# This generates user prompts and robot utterances for hierarchical policy training
|
||||
|
||||
# Configuration
|
||||
REPO_ID="jadechoghari/collect-data"
|
||||
MODEL="Qwen/Qwen3-VL-30B-A3B-Instruct"
|
||||
# or: MODEL="Qwen/Qwen2-VL-7B-Instruct"
|
||||
|
||||
|
||||
OUTPUT_DIR="/fsx/jade_choghari/outputs/collect-data-pgen_new"
|
||||
|
||||
BATCH_SIZE=32
|
||||
TEMPERATURE=0.9
|
||||
SAMPLE_INTERVAL=5.0 # generate dialogue every 1 second (all episodes processed)
|
||||
|
||||
# Run subtask annotation
|
||||
python /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05_full/annotate/subtask_annotate.py \
|
||||
--repo-id "$REPO_ID" \
|
||||
--video-key observation.images.base \
|
||||
--output-dir "$OUTPUT_DIR" \
|
||||
--output-repo-id "jadechoghari/collect-data-with-subtasks"
|
||||
# run synthetic data generation (all episodes processed)
|
||||
# python examples/dataset/annotate_pgen.py \
|
||||
# --repo-id "$REPO_ID" \
|
||||
# --model "$MODEL" \
|
||||
# --output-dir "$OUTPUT_DIR" \
|
||||
# --temperature "$TEMPERATURE" \
|
||||
# --batch-size "$BATCH_SIZE" \
|
||||
# --sample-interval "$SAMPLE_INTERVAL" \
|
||||
# --image-key observation.images.base \
|
||||
# --num-image-views-per-sample 1
|
||||
|
||||
# for faster testing, increase sample interval:
|
||||
# --sample-interval 5.0 # Samples every 5 seconds (much faster)
|
||||
|
||||
# to push to hub after generation:
|
||||
# add --push-to-hub flag
|
||||
|
||||
# efficient batch processing: 4 episodes at once
|
||||
# python examples/dataset/annotate_pgen.py \
|
||||
# --repo-id "$REPO_ID" \
|
||||
# --model "$MODEL" \
|
||||
# --output-dir "$OUTPUT_DIR" \
|
||||
# --video-mode \
|
||||
# --video-key observation.images.up \
|
||||
# --video-batch-size "$BATCH_SIZE" \
|
||||
# --sample-interval 1.0
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,183 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
DEFAULT_IMAGE_SIZE = 224
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi05_full")
|
||||
@dataclass
|
||||
class PI05FullConfig(PreTrainedConfig):
|
||||
paligemma_variant: str = "gemma_2b"
|
||||
action_expert_variant: str = "gemma_300m"
|
||||
dtype: str = "float32" # Options: "bfloat16", "float32"
|
||||
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
|
||||
n_action_steps: int = 50 # Number of action steps to execute
|
||||
|
||||
# Shorter state and action vectors will be padded to these dimensions
|
||||
max_state_dim: int = 32
|
||||
max_action_dim: int = 32
|
||||
|
||||
# Flow matching parameters: see openpi `PI0Pytorch`
|
||||
num_inference_steps: int = 10
|
||||
time_sampling_beta_alpha: float = 1.5
|
||||
time_sampling_beta_beta: float = 1.0
|
||||
time_sampling_scale: float = 0.999
|
||||
time_sampling_offset: float = 0.001
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
image_resolution: tuple[int, int] = (
|
||||
DEFAULT_IMAGE_SIZE,
|
||||
DEFAULT_IMAGE_SIZE,
|
||||
) # see openpi `preprocessing_pytorch.py`
|
||||
|
||||
# Add empty images. Used to add empty cameras when no image features are present.
|
||||
empty_cameras: int = 0
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD, # Pi0.5 uses quantiles for state
|
||||
"ACTION": NormalizationMode.MEAN_STD, # Pi0.5 uses quantiles for action
|
||||
}
|
||||
)
|
||||
|
||||
action_tokenizer_name: str = "physical-intelligence/fast"
|
||||
text_tokenizer_name: str = "google/paligemma-3b-pt-224"
|
||||
max_action_tokens: int = 256
|
||||
fast_skip_tokens: int = 128
|
||||
|
||||
# subtask stuff
|
||||
max_decoding_steps: int = 200
|
||||
temperature: float = 0.0
|
||||
subtask_regeneration_interval: float = 1.0 # Regenerate subtask tokens every N seconds (0 = every call)
|
||||
|
||||
# Training settings
|
||||
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
|
||||
compile_model: bool = False # Whether to use torch.compile for model optimization
|
||||
compile_mode: str = "max-autotune" # Torch compile mode
|
||||
device: str | None = None # Device to use for the model (None = auto-detect)
|
||||
|
||||
# Finetuning settings
|
||||
freeze_vision_encoder: bool = False # Freeze only the vision encoder
|
||||
train_expert_only: bool = False # Freeze entire VLM, train only action expert and projections
|
||||
knowledge_insulation: bool = True # Enable knowledge insulation in attention (blocks gradients from action to VLM K/V)
|
||||
|
||||
# Loss weights (used when knowledge_insulation is enabled)
|
||||
loss_weight_flow: float = 1.0 # Weight for flow matching MSE loss (continuous actions)
|
||||
loss_weight_action_ce: float = 1.0 # Weight for FAST action token cross-entropy loss
|
||||
loss_weight_subtask_ce: float = 1.0 # Weight for subtask token cross-entropy loss
|
||||
|
||||
# Optimizer settings: see openpi `AdamW`
|
||||
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 0.01
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
|
||||
# Scheduler settings: see openpi `CosineDecaySchedule`
|
||||
# Note: These will auto-scale if --steps < scheduler_decay_steps
|
||||
# For example, --steps=3000 will scale warmup to 100 and decay to 3000
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
tokenizer_max_length: int = 48 # see openpi `__post_init__`
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
# Validate configuration
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
|
||||
)
|
||||
|
||||
if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
|
||||
raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}")
|
||||
|
||||
if self.action_expert_variant not in ["gemma_300m", "gemma_2b"]:
|
||||
raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}")
|
||||
|
||||
if self.dtype not in ["bfloat16", "float32"]:
|
||||
raise ValueError(f"Invalid dtype: {self.dtype}")
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up input/output features."""
|
||||
for i in range(self.empty_cameras):
|
||||
key = OBS_IMAGES + f".empty_camera_{i}"
|
||||
empty_camera = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, *self.image_resolution), # Use configured image resolution
|
||||
)
|
||||
self.input_features[key] = empty_camera
|
||||
|
||||
if OBS_STATE not in self.input_features:
|
||||
state_feature = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.max_state_dim,), # Padded to max_state_dim
|
||||
)
|
||||
self.input_features[OBS_STATE] = state_feature
|
||||
|
||||
if ACTION not in self.output_features:
|
||||
action_feature = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.max_action_dim,), # Padded to max_action_dim
|
||||
)
|
||||
self.output_features[ACTION] = action_feature
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
@@ -0,0 +1,92 @@
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
import lerobot
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
# import make_pre_post_processors
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.factory import make_policy, make_policy_config
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained(
|
||||
pretrained_name_or_path="/fsx/jade_choghari/models/pi05-base",
|
||||
)
|
||||
cfg.dtype = "bfloat16"
|
||||
|
||||
pre_processor, post_processor = make_pre_post_processors(
|
||||
policy_cfg=cfg,
|
||||
pretrained_path="/fsx/jade_choghari/models/pi05-base",
|
||||
)
|
||||
|
||||
delta_timestamps = {'action': [0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333, 0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3, 0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335, 0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6, 0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333, 0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9, 0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334, 1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2, 1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333, 1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5, 1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333]}
|
||||
|
||||
dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1", delta_timestamps=delta_timestamps)
|
||||
|
||||
# rename map --rename_map='{
|
||||
# "observation.images.side": "observation.images.base_0_rgb",
|
||||
# "observation.images.up": "observation.images.left_wrist_0_rgb"
|
||||
# }'
|
||||
rename_map = {
|
||||
"observation.images.side": "observation.images.base_0_rgb",
|
||||
"observation.images.up": "observation.images.left_wrist_0_rgb"
|
||||
}
|
||||
policy = make_policy(
|
||||
cfg=cfg,
|
||||
ds_meta=dataset.meta,
|
||||
rename_map=rename_map,
|
||||
)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
batch_size=4,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
breakpoint()
|
||||
batch = pre_processor(batch)
|
||||
policy.train()
|
||||
# run inference
|
||||
# action = policy.select_action(batch)
|
||||
loss, loss_dict = policy.forward(batch)
|
||||
breakpoint()
|
||||
# import requests
|
||||
# from PIL import Image
|
||||
# from transformers import AutoProcessor
|
||||
# model = policy.model.paligemma_with_expert.paligemma
|
||||
# model = model.to(device="cuda", dtype=torch.bfloat16)
|
||||
# model.eval()
|
||||
# prompt = "Describe this image."
|
||||
# url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
||||
# image = Image.open(requests.get(url, stream=True).raw)
|
||||
# processor = AutoProcessor.from_pretrained(
|
||||
# "google/paligemma-3b-pt-224",
|
||||
# )
|
||||
# inputs = processor(image, prompt, return_tensors="pt").to(model.device)
|
||||
# print("generating...")
|
||||
# output = model.generate(
|
||||
# **inputs,
|
||||
# max_new_tokens=50,
|
||||
# use_cache=True, # default dynamic cache
|
||||
# )
|
||||
# print(processor.decode(output[0], skip_special_tokens=True))
|
||||
|
||||
|
||||
# # other model
|
||||
# from transformers import PaliGemmaForConditionalGeneration
|
||||
# model = PaliGemmaForConditionalGeneration.from_pretrained(
|
||||
# "google/paligemma2-3b-pt-224",
|
||||
# torch_dtype=torch.bfloat16,
|
||||
# device_map="auto",
|
||||
# )
|
||||
# model.eval()
|
||||
# print("generating...")
|
||||
# output = model.generate(
|
||||
# **inputs,
|
||||
# max_new_tokens=100,
|
||||
# use_cache=True, # default dynamic cache
|
||||
# )
|
||||
# print("Model 2 output:")
|
||||
# print(processor.decode(output[0], skip_special_tokens=True))
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,194 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.pi05_full.configuration_pi05 import PI05FullConfig
|
||||
from lerobot.policies.pi05_full.modeling_pi05 import pad_vector
|
||||
from lerobot.processor import (
|
||||
ActionTokenizerProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_STATE,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="pi05_full_prepare_state_tokenizer_processor_step")
|
||||
@dataclass
|
||||
class Pi05FullPrepareStateTokenizerProcessorStep(ProcessorStep):
|
||||
"""
|
||||
Processor step to prepare the state and tokenize the language input.
|
||||
"""
|
||||
|
||||
max_state_dim: int = 32
|
||||
task_key: str = "task"
|
||||
subtask_key: str = "subtask"
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
transition = transition.copy()
|
||||
|
||||
state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE)
|
||||
if state is None:
|
||||
raise ValueError("State is required for PI05")
|
||||
user_prompts = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key)
|
||||
if user_prompts is None:
|
||||
raise ValueError("No user prompts found in complementary data")
|
||||
commands = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.subtask_key)
|
||||
|
||||
# TODO: check if this necessary
|
||||
state = deepcopy(state)
|
||||
|
||||
# Prepare state (pad to max_state_dim)
|
||||
state = pad_vector(state, self.max_state_dim)
|
||||
|
||||
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
state_np = state.cpu().numpy()
|
||||
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||
|
||||
full_prompts = []
|
||||
for i, user_prompt in enumerate(user_prompts):
|
||||
cleaned_text = user_prompt.strip().replace("_", " ").replace("\n", " ")
|
||||
cleaned_text = cleaned_text.lower() # all lowercase # NOTE: added by (jadechoghari)
|
||||
state_str = " ".join(map(str, discretized_states[i]))
|
||||
full_prompt = f"Task: {cleaned_text}, State: {state_str};\n"
|
||||
full_prompts.append(full_prompt)
|
||||
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
|
||||
|
||||
# process commands (optional)
|
||||
if commands is not None:
|
||||
full_commands = []
|
||||
for i, command in enumerate(commands):
|
||||
cleaned_text = command.strip().replace("_", " ").replace("\n", " ")
|
||||
cleaned_text = cleaned_text.lower() # all lowercase # NOTE: added by (jadechoghari)
|
||||
full_command = f"Subtask: {cleaned_text};\n"
|
||||
full_commands.append(full_command)
|
||||
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.subtask_key] = full_commands
|
||||
|
||||
# note: action tokens will be processed in the ActionTokenizerProcessorStep
|
||||
# Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!)
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
return transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
This step does not alter the feature definitions.
|
||||
"""
|
||||
return features
|
||||
|
||||
|
||||
def make_pi05_full_pre_post_processors(
|
||||
config: PI05FullConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the PI0 policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features to match pretrained configurations.
|
||||
2. Normalizing input and output features based on dataset statistics.
|
||||
3. Adding a batch dimension.
|
||||
4. Appending a newline character to the task description for tokenizer compatibility.
|
||||
5. Tokenizing the text prompt using the PaliGemma tokenizer.
|
||||
6. Moving all data to the specified device.
|
||||
|
||||
The post-processing pipeline handles the model's output by:
|
||||
1. Moving data to the CPU.
|
||||
2. Unnormalizing the output features to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the PI0 policy.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
|
||||
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
|
||||
# Add remaining processors
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
AddBatchDimensionProcessorStep(),
|
||||
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
|
||||
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
Pi05FullPrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name=config.text_tokenizer_name,
|
||||
max_length=config.tokenizer_max_length,
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
),
|
||||
ActionTokenizerProcessorStep(
|
||||
action_tokenizer_name=config.action_tokenizer_name,
|
||||
max_action_tokens=config.max_action_tokens,
|
||||
fast_skip_tokens=config.fast_skip_tokens,
|
||||
paligemma_tokenizer_name=config.text_tokenizer_name,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
output_steps: list[ProcessorStep] = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
import abc
|
||||
import builtins
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
from importlib.resources import files
|
||||
@@ -265,3 +266,166 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
card = ModelCard.from_template(card_data, template_str=template_card)
|
||||
card.validate()
|
||||
return card
|
||||
|
||||
def wrap_with_peft(
|
||||
self,
|
||||
peft_config=None,
|
||||
peft_cli_overrides: dict | None = None,
|
||||
) -> "PreTrainedPolicy":
|
||||
"""
|
||||
Wrap this policy with PEFT adapters for parameter-efficient fine-tuning.
|
||||
|
||||
This method is the single entry point for PEFT integration. Subclasses should
|
||||
override `_get_default_peft_targets()` to provide default target modules, and
|
||||
`_validate_peft_config()` for policy-specific validation.
|
||||
|
||||
Args:
|
||||
peft_config: Optional PEFT adapter configuration (e.g., LoraConfig).
|
||||
If provided, used directly (with CLI overrides applied).
|
||||
peft_cli_overrides: Optional dict of CLI overrides (method_type, target_modules, r, etc.)
|
||||
These are merged with policy defaults to build the final config.
|
||||
"""
|
||||
from peft import get_peft_model
|
||||
|
||||
# If user provided a complete config, use it directly (with overrides)
|
||||
if peft_config is not None:
|
||||
final_config = peft_config
|
||||
if peft_cli_overrides:
|
||||
final_config = self._apply_peft_cli_overrides(final_config, peft_cli_overrides)
|
||||
else:
|
||||
# Build config from defaults + CLI overrides
|
||||
final_config = self._build_peft_config(peft_cli_overrides or {})
|
||||
|
||||
# Validate the configuration
|
||||
self._validate_peft_config(final_config)
|
||||
|
||||
# Freeze base parameters, only adapter params will be trained
|
||||
for p in self.parameters():
|
||||
p.requires_grad_(False)
|
||||
|
||||
# Store pretrained path for PEFT's base_model_name_or_path
|
||||
if self.config.pretrained_path:
|
||||
self.name_or_path = str(self.config.pretrained_path)
|
||||
|
||||
# Wrap with PEFT
|
||||
peft_model = get_peft_model(self, final_config)
|
||||
|
||||
# Mark config as using PEFT for proper loading later
|
||||
peft_model.config.use_peft = True
|
||||
|
||||
logging.info(f"Wrapped {self.name} with PEFT ({type(final_config).__name__})")
|
||||
return peft_model
|
||||
|
||||
def _get_default_peft_targets(self) -> dict[str, any] | None:
|
||||
"""
|
||||
Return default PEFT target modules for this policy.
|
||||
|
||||
Override this in subclasses to provide policy-specific defaults. These defaults
|
||||
are PEFT-method agnostic - they only specify which modules to target.
|
||||
|
||||
"""
|
||||
return None
|
||||
|
||||
def _validate_peft_config(self, peft_config) -> None:
|
||||
"""
|
||||
Validate the PEFT configuration for this policy.
|
||||
|
||||
Override this in subclasses to add policy-specific validation or warnings.
|
||||
The default implementation checks that a pretrained_path exists.
|
||||
|
||||
Args:
|
||||
peft_config: The PEFT configuration to validate.
|
||||
|
||||
Raises:
|
||||
ValueError: If the configuration is invalid.
|
||||
"""
|
||||
if not self.config.pretrained_path:
|
||||
raise ValueError(
|
||||
"Training from scratch using PEFT is unlikely to yield good results. "
|
||||
"Supply a `policy.pretrained_path` to fine-tune an existing model."
|
||||
)
|
||||
|
||||
def _preprocess_peft_cli_overrides(self, cli_overrides: dict, peft_method_type) -> dict:
|
||||
"""
|
||||
Preprocess CLI overrides: rename keys and handle method-specific init_type.
|
||||
|
||||
Args:
|
||||
cli_overrides: Dict of CLI options (will be copied, not mutated).
|
||||
peft_method_type: The PeftType enum value for the PEFT method.
|
||||
|
||||
Returns:
|
||||
Preprocessed dict with renamed keys and init_type mapped to method-specific key.
|
||||
"""
|
||||
from peft import PeftType
|
||||
|
||||
cli_overrides = cli_overrides.copy()
|
||||
|
||||
# Handle the full_training_modules -> modules_to_save rename
|
||||
if "full_training_modules" in cli_overrides:
|
||||
cli_overrides["modules_to_save"] = cli_overrides.pop("full_training_modules")
|
||||
|
||||
# Remove method_type as it's handled separately
|
||||
cli_overrides.pop("method_type", None)
|
||||
|
||||
# Handle init_type specially based on PEFT method
|
||||
init_type = cli_overrides.pop("init_type", None)
|
||||
if init_type is not None:
|
||||
if peft_method_type == PeftType.LORA:
|
||||
cli_overrides["init_lora_weights"] = init_type
|
||||
elif peft_method_type == PeftType.MISS:
|
||||
cli_overrides["init_weights"] = init_type
|
||||
else:
|
||||
raise ValueError(f"Init type '{init_type}' unknown for PEFT method {peft_method_type}.")
|
||||
|
||||
return cli_overrides
|
||||
|
||||
def _build_peft_config(self, cli_overrides: dict):
|
||||
"""Build a PEFT config from policy defaults and CLI overrides."""
|
||||
from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType
|
||||
|
||||
# Determine PEFT method type (default to LORA)
|
||||
method_type_str = cli_overrides.get("method_type") or "lora"
|
||||
peft_method_type = PeftType[method_type_str.upper()]
|
||||
peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type]
|
||||
|
||||
# Preprocess CLI overrides
|
||||
cli_overrides = self._preprocess_peft_cli_overrides(cli_overrides, peft_method_type)
|
||||
|
||||
# Start with policy defaults, apply CLI overrides
|
||||
config_dict = dict(self._get_default_peft_targets() or {})
|
||||
for key, value in cli_overrides.items():
|
||||
if value is not None:
|
||||
config_dict[key] = value
|
||||
|
||||
# Ensure we have target_modules
|
||||
if not config_dict.get("target_modules"):
|
||||
raise ValueError(
|
||||
f"Policy '{self.name}' does not define default target_modules. "
|
||||
"Please pass --peft.target_modules explicitly."
|
||||
)
|
||||
|
||||
return peft_config_cls(**config_dict)
|
||||
|
||||
def _apply_peft_cli_overrides(self, peft_config, cli_overrides: dict):
|
||||
"""Apply CLI overrides to an existing PEFT config."""
|
||||
from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType
|
||||
|
||||
# Get method type from existing config or CLI override
|
||||
method_type_str = cli_overrides.get("method_type")
|
||||
if method_type_str:
|
||||
peft_method_type = PeftType[method_type_str.upper()]
|
||||
peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type]
|
||||
else:
|
||||
peft_method_type = PeftType(peft_config.peft_type)
|
||||
peft_config_cls = type(peft_config)
|
||||
|
||||
# Preprocess CLI overrides
|
||||
cli_overrides = self._preprocess_peft_cli_overrides(cli_overrides, peft_method_type)
|
||||
|
||||
# Start with existing config, apply CLI overrides
|
||||
config_dict = {k: v for k, v in dataclasses.asdict(peft_config).items() if not k.startswith("_")}
|
||||
for key, value in cli_overrides.items():
|
||||
if value is not None:
|
||||
config_dict[key] = value
|
||||
|
||||
return peft_config_cls(**config_dict)
|
||||
|
||||
@@ -239,8 +239,10 @@ class SACPolicy(
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
def update_temperature(self):
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
@property
|
||||
def temperature(self) -> float:
|
||||
"""Return the current temperature value, always in sync with log_alpha."""
|
||||
return self.log_alpha.exp().item()
|
||||
|
||||
def compute_loss_critic(
|
||||
self,
|
||||
@@ -457,11 +459,10 @@ class SACPolicy(
|
||||
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
|
||||
self.target_entropy = -np.prod(dim) / 2
|
||||
|
||||
def _init_temperature(self):
|
||||
"""Set up temperature parameter and initial log_alpha."""
|
||||
def _init_temperature(self) -> None:
|
||||
"""Set up temperature parameter (log_alpha)."""
|
||||
temp_init = self.config.temperature_init
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
|
||||
|
||||
class SACObservationEncoder(nn.Module):
|
||||
|
||||
@@ -480,6 +480,28 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
|
||||
return actions
|
||||
|
||||
def _get_default_peft_targets(self) -> dict[str, any]:
|
||||
"""Return default PEFT target modules for SmolVLA fine-tuning."""
|
||||
common_projections = (
|
||||
"state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
|
||||
)
|
||||
target_modules = rf"(model\.vlm_with_expert\.lm_expert\..*\.(q|v)_proj|model\.({common_projections}))"
|
||||
return {
|
||||
"target_modules": target_modules,
|
||||
"modules_to_save": [],
|
||||
}
|
||||
|
||||
def _validate_peft_config(self, peft_config) -> None:
|
||||
"""Validate PEFT configuration for SmolVLA."""
|
||||
super()._validate_peft_config(peft_config)
|
||||
if not self.config.load_vlm_weights:
|
||||
import logging
|
||||
|
||||
logging.warning(
|
||||
"Training SmolVLA from scratch using PEFT. This is unlikely to yield good results. "
|
||||
"Set `load_vlm_weights=True` to fine-tune the existing policy."
|
||||
)
|
||||
|
||||
|
||||
def pad_tensor(tensor, max_len, pad_value=0):
|
||||
"""
|
||||
|
||||
@@ -30,7 +30,7 @@ class TDMPCConfig(PreTrainedConfig):
|
||||
camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`.
|
||||
Those are: `input_features`, `output_features`, and perhaps `max_random_shift_ratio`.
|
||||
|
||||
Args:
|
||||
n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
|
||||
@@ -40,24 +40,12 @@ class TDMPCConfig(PreTrainedConfig):
|
||||
is an alternative to using action repeats. If this is set to more than 1, then we require
|
||||
`n_action_repeats == 1`, `use_mpc == True` and `n_action_steps <= horizon`. Note that this
|
||||
approach of using multiple steps from the plan is not in the original implementation.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||
include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range. Note that here this defaults to None meaning inputs are not normalized. This is to
|
||||
match the original implementation.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets. NOTE: Clipping
|
||||
to [-1, +1] is used during MPPI/CEM. Therefore, it is recommended that you stick with "min_max"
|
||||
normalization mode here.
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
image_encoder_hidden_dim: Number of channels for the convolutional layers used for image encoding.
|
||||
state_encoder_hidden_dim: Hidden dimension for MLP used for state vector encoding.
|
||||
latent_dim: Observation's latent embedding dimension.
|
||||
|
||||
@@ -32,7 +32,7 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_shapes` and `output_shapes`.
|
||||
Those are: `input_features` and `output_features`.
|
||||
|
||||
Notes on the inputs and outputs:
|
||||
- "observation.state" is required as an input key.
|
||||
@@ -46,21 +46,12 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
current step and additional steps going back).
|
||||
n_action_pred_token: Total number of current token and future tokens that VQ-BeT predicts.
|
||||
action_chunk_size: Action chunk size of each action prediction token.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
||||
The key represents the input data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "observation.image" refers to an input from
|
||||
a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
|
||||
Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
||||
The key represents the output data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
||||
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||
within the image size. If None, no cropping is done.
|
||||
|
||||
@@ -168,11 +168,14 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
|
||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
||||
user_prompt_key = {"user_prompt": batch["user_prompt"]} if "user_prompt" in batch else {}
|
||||
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
|
||||
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
|
||||
|
||||
return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key}
|
||||
return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key, **user_prompt_key, **subtask_key}
|
||||
|
||||
|
||||
def create_transition(
|
||||
|
||||
@@ -18,16 +18,18 @@
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol, TypeVar, runtime_checkable
|
||||
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, runtime_checkable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as F # noqa: N812
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
|
||||
from .core import EnvTransition, PolicyAction, TransitionKey
|
||||
from .pipeline import (
|
||||
ComplementaryDataProcessorStep,
|
||||
@@ -69,10 +71,10 @@ class HasTeleopEvents(Protocol):
|
||||
|
||||
|
||||
# Type variable constrained to Teleoperator subclasses that also implement events
|
||||
TeleopWithEvents = TypeVar("TeleopWithEvents", bound=Teleoperator)
|
||||
TeleopWithEvents = TypeVar("TeleopWithEvents", bound="Teleoperator")
|
||||
|
||||
|
||||
def _check_teleop_with_events(teleop: Teleoperator) -> None:
|
||||
def _check_teleop_with_events(teleop: "Teleoperator") -> None:
|
||||
"""
|
||||
Runtime check that a teleoperator implements the `HasTeleopEvents` protocol.
|
||||
|
||||
@@ -103,7 +105,7 @@ class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
|
||||
teleop_device: The teleoperator instance to get the action from.
|
||||
"""
|
||||
|
||||
teleop_device: Teleoperator
|
||||
teleop_device: "Teleoperator"
|
||||
|
||||
def complementary_data(self, complementary_data: dict) -> dict:
|
||||
"""
|
||||
@@ -312,7 +314,7 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||
class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
||||
class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
"""
|
||||
Applies a penalty for inefficient gripper usage.
|
||||
|
||||
@@ -327,26 +329,27 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
||||
penalty: float = -0.01
|
||||
max_gripper_pos: float = 30.0
|
||||
|
||||
def complementary_data(self, complementary_data: dict) -> dict:
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
Calculates the gripper penalty and adds it to the complementary data.
|
||||
|
||||
Args:
|
||||
complementary_data: The incoming complementary data, which should contain
|
||||
raw joint positions.
|
||||
transition: The incoming environment transition.
|
||||
|
||||
Returns:
|
||||
A new complementary data dictionary with the `discrete_penalty` key added.
|
||||
The modified transition with the penalty added to complementary data.
|
||||
"""
|
||||
action = self.transition.get(TransitionKey.ACTION)
|
||||
new_transition = transition.copy()
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
|
||||
raw_joint_positions = complementary_data.get("raw_joint_positions")
|
||||
if raw_joint_positions is None:
|
||||
return complementary_data
|
||||
return new_transition
|
||||
|
||||
current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY, None)
|
||||
if current_gripper_pos is None:
|
||||
return complementary_data
|
||||
return new_transition
|
||||
|
||||
# Gripper action is a PolicyAction at this stage
|
||||
gripper_action = action[-1].item()
|
||||
@@ -362,11 +365,12 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
||||
|
||||
gripper_penalty = self.penalty * int(gripper_penalty_bool)
|
||||
|
||||
# Create new complementary data with penalty info
|
||||
# Update complementary data with penalty info
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
|
||||
return new_complementary_data
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
|
||||
@@ -34,7 +34,12 @@ from lerobot.utils.constants import (
|
||||
ACTION_TOKEN_MASK,
|
||||
ACTION_TOKENS,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_SUBTASK_TOKENS,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_LANGUAGE_USER_PROMPT,
|
||||
OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_USER_PROMPT_TOKENS,
|
||||
)
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
@@ -139,18 +144,70 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
|
||||
return None
|
||||
|
||||
def get_user_prompt(self, transition: EnvTransition) -> list[str] | None:
|
||||
"""
|
||||
Extracts the user_prompt from the transition's complementary data.
|
||||
|
||||
Args:
|
||||
transition: The environment transition.
|
||||
|
||||
Returns:
|
||||
A list of user_prompt strings, or None if the user_prompt key is not found or the value is None.
|
||||
"""
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data is None:
|
||||
return None
|
||||
|
||||
user_prompt = complementary_data.get("user_prompt")
|
||||
if user_prompt is None:
|
||||
return None
|
||||
|
||||
# Standardize to a list of strings for the tokenizer
|
||||
if isinstance(user_prompt, str):
|
||||
return [user_prompt]
|
||||
elif isinstance(user_prompt, list) and all(isinstance(t, str) for t in user_prompt):
|
||||
return user_prompt
|
||||
|
||||
return None
|
||||
|
||||
def get_subtask(self, transition: EnvTransition) -> list[str] | None:
|
||||
"""
|
||||
Extracts the subtask from the transition's complementary data.
|
||||
|
||||
Args:
|
||||
transition: The environment transition.
|
||||
|
||||
Returns:
|
||||
A list of subtask strings, or None if the subtask key is not found or the value is None.
|
||||
"""
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data is None:
|
||||
return None
|
||||
|
||||
subtask = complementary_data.get("subtask")
|
||||
if subtask is None:
|
||||
return None
|
||||
|
||||
# Standardize to a list of strings for the tokenizer
|
||||
if isinstance(subtask, str):
|
||||
return [subtask]
|
||||
elif isinstance(subtask, list) and all(isinstance(t, str) for t in subtask):
|
||||
return subtask
|
||||
|
||||
return None
|
||||
|
||||
def observation(self, observation: RobotObservation) -> RobotObservation:
|
||||
"""
|
||||
Tokenizes the task description and adds it to the observation dictionary.
|
||||
Tokenizes the task description and user_prompt (if available) and adds them to the observation dictionary.
|
||||
|
||||
This method retrieves the task, tokenizes it, moves the resulting tensors to the
|
||||
This method retrieves the task and user_prompt, tokenizes them, moves the resulting tensors to the
|
||||
same device as other data in the transition, and updates the observation.
|
||||
|
||||
Args:
|
||||
observation: The original observation dictionary.
|
||||
|
||||
Returns:
|
||||
The updated observation dictionary including token IDs and an attention mask.
|
||||
The updated observation dictionary including token IDs and attention masks.
|
||||
"""
|
||||
task = self.get_task(self.transition)
|
||||
if task is None:
|
||||
@@ -176,6 +233,58 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
# Tokenize user_prompt if available
|
||||
user_prompt = self.get_user_prompt(self.transition)
|
||||
if user_prompt is not None:
|
||||
tokenized_user_prompt = self._tokenize_text(user_prompt)
|
||||
|
||||
# Move new tokenized tensors to the detected device
|
||||
if target_device is not None:
|
||||
tokenized_user_prompt = {
|
||||
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in tokenized_user_prompt.items()
|
||||
}
|
||||
|
||||
# Add tokenized user_prompt to the observation
|
||||
new_observation[OBS_LANGUAGE_USER_PROMPT_TOKENS] = tokenized_user_prompt["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK] = tokenized_user_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
# Tokenize subtask if available
|
||||
subtask = self.get_subtask(self.transition)
|
||||
if subtask is not None:
|
||||
tokenized_subtask = self._tokenize_text(subtask)
|
||||
|
||||
# Add EOS token at the end of each subtask sequence (before padding)
|
||||
eos_token_id = self.input_tokenizer.eos_token_id
|
||||
input_ids = tokenized_subtask["input_ids"]
|
||||
attention_mask = tokenized_subtask["attention_mask"]
|
||||
for i in range(input_ids.size(0)):
|
||||
# Find the length of actual tokens (sum of attention mask)
|
||||
seq_len = attention_mask[i].sum().item()
|
||||
|
||||
max_len = input_ids.size(1)
|
||||
if seq_len >= max_len:
|
||||
raise ValueError(
|
||||
f"No room to append EOS: seq_len={seq_len} equals max_length={max_len}. "
|
||||
"Increase max_length or tokenize with padding=False then pad after adding EOS."
|
||||
)
|
||||
# Add EOS token at the end
|
||||
input_ids[i, seq_len] = eos_token_id
|
||||
attention_mask[i, seq_len] = 1
|
||||
|
||||
# Move new tokenized tensors to the detected device
|
||||
if target_device is not None:
|
||||
tokenized_subtask = {
|
||||
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in tokenized_subtask.items()
|
||||
}
|
||||
|
||||
# Add tokenized subtask to the observation
|
||||
new_observation[OBS_LANGUAGE_SUBTASK_TOKENS] = tokenized_subtask["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] = tokenized_subtask["attention_mask"].to(
|
||||
dtype=torch.bool
|
||||
)
|
||||
|
||||
return new_observation
|
||||
|
||||
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
|
||||
@@ -274,6 +383,28 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
# Add features for user_prompt tokens and attention mask if they don't already exist
|
||||
if OBS_LANGUAGE_USER_PROMPT_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_USER_PROMPT_TOKENS] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
if OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
# Add features for subtask tokens and attention mask if they don't already exist
|
||||
if OBS_LANGUAGE_SUBTASK_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
if OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@@ -527,4 +658,4 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
Returns:
|
||||
The updated dictionary of policy features.
|
||||
"""
|
||||
return features
|
||||
return features
|
||||
@@ -412,7 +412,10 @@ def make_processors(
|
||||
if cfg.processor.observation.add_current_to_observation:
|
||||
env_pipeline_steps.append(MotorCurrentProcessorStep(robot=env.robot))
|
||||
|
||||
if kinematics_solver is not None:
|
||||
add_ee_pose = (
|
||||
cfg.processor.observation is not None and cfg.processor.observation.add_ee_pose_to_observation
|
||||
)
|
||||
if kinematics_solver is not None and add_ee_pose:
|
||||
env_pipeline_steps.append(
|
||||
ForwardKinematicsJointsToEEObservation(
|
||||
kinematics=kinematics_solver,
|
||||
@@ -435,7 +438,12 @@ def make_processors(
|
||||
)
|
||||
|
||||
# Add gripper penalty processor if gripper config exists and enabled
|
||||
if cfg.processor.gripper is not None and cfg.processor.gripper.use_gripper:
|
||||
# Only add if max_gripper_pos is explicitly configured (required for normalization)
|
||||
if (
|
||||
cfg.processor.gripper is not None
|
||||
and cfg.processor.gripper.use_gripper
|
||||
and cfg.processor.max_gripper_pos is not None
|
||||
):
|
||||
env_pipeline_steps.append(
|
||||
GripperPenaltyProcessorStep(
|
||||
penalty=cfg.processor.gripper.gripper_penalty,
|
||||
|
||||
@@ -545,9 +545,6 @@ def add_actor_information_and_train(
|
||||
training_infos["temperature_grad_norm"] = temp_grad_norm
|
||||
training_infos["temperature"] = policy.temperature
|
||||
|
||||
# Update temperature
|
||||
policy.update_temperature()
|
||||
|
||||
# Push policy to actors if needed
|
||||
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
|
||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
||||
|
||||
@@ -26,8 +26,21 @@ from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.utils.constants import PRETRAINED_MODEL_DIR
|
||||
|
||||
|
||||
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
|
||||
def cfg_to_group(
|
||||
cfg: TrainPipelineConfig, return_list: bool = False, truncate_tags: bool = False, max_tag_length: int = 64
|
||||
) -> list[str] | str:
|
||||
"""Return a group name for logging. Optionally returns group name as list."""
|
||||
|
||||
def _maybe_truncate(tag: str) -> str:
|
||||
"""Truncate tag to max_tag_length characters if required.
|
||||
|
||||
wandb rejects tags longer than 64 characters.
|
||||
See: https://github.com/wandb/wandb/blob/main/wandb/sdk/wandb_settings.py
|
||||
"""
|
||||
if len(tag) <= max_tag_length:
|
||||
return tag
|
||||
return tag[:max_tag_length]
|
||||
|
||||
lst = [
|
||||
f"policy:{cfg.policy.type}",
|
||||
f"seed:{cfg.seed}",
|
||||
@@ -36,6 +49,8 @@ def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[st
|
||||
lst.append(f"dataset:{cfg.dataset.repo_id}")
|
||||
if cfg.env is not None:
|
||||
lst.append(f"env:{cfg.env.type}")
|
||||
if truncate_tags:
|
||||
lst = [_maybe_truncate(tag) for tag in lst]
|
||||
return lst if return_list else "-".join(lst)
|
||||
|
||||
|
||||
@@ -83,7 +98,7 @@ class WandBLogger:
|
||||
entity=self.cfg.entity,
|
||||
name=self.job_name,
|
||||
notes=self.cfg.notes,
|
||||
tags=cfg_to_group(cfg, return_list=True),
|
||||
tags=cfg_to_group(cfg, return_list=True, truncate_tags=True),
|
||||
dir=self.log_dir,
|
||||
config=cfg.to_dict(),
|
||||
# TODO(rcadene): try set to True
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .bi_openarm_follower import BiOpenArmFollower
|
||||
from .config_bi_openarm_follower import BiOpenArmFollowerConfig
|
||||
|
||||
__all__ = ["BiOpenArmFollower", "BiOpenArmFollowerConfig"]
|
||||
@@ -0,0 +1,175 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_bi_openarm_follower import BiOpenArmFollowerConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BiOpenArmFollower(Robot):
|
||||
"""
|
||||
Bimanual OpenArm Follower Arms
|
||||
"""
|
||||
|
||||
config_class = BiOpenArmFollowerConfig
|
||||
name = "bi_openarm_follower"
|
||||
|
||||
def __init__(self, config: BiOpenArmFollowerConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
left_arm_config = OpenArmFollowerConfig(
|
||||
id=f"{config.id}_left" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.left_arm_config.port,
|
||||
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
||||
max_relative_target=config.left_arm_config.max_relative_target,
|
||||
cameras=config.left_arm_config.cameras,
|
||||
side=config.left_arm_config.side,
|
||||
can_interface=config.left_arm_config.can_interface,
|
||||
use_can_fd=config.left_arm_config.use_can_fd,
|
||||
can_bitrate=config.left_arm_config.can_bitrate,
|
||||
can_data_bitrate=config.left_arm_config.can_data_bitrate,
|
||||
motor_config=config.left_arm_config.motor_config,
|
||||
position_kd=config.left_arm_config.position_kd,
|
||||
position_kp=config.left_arm_config.position_kp,
|
||||
joint_limits=config.left_arm_config.joint_limits,
|
||||
)
|
||||
|
||||
right_arm_config = OpenArmFollowerConfig(
|
||||
id=f"{config.id}_right" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.right_arm_config.port,
|
||||
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
|
||||
max_relative_target=config.right_arm_config.max_relative_target,
|
||||
cameras=config.right_arm_config.cameras,
|
||||
side=config.right_arm_config.side,
|
||||
can_interface=config.right_arm_config.can_interface,
|
||||
use_can_fd=config.right_arm_config.use_can_fd,
|
||||
can_bitrate=config.right_arm_config.can_bitrate,
|
||||
can_data_bitrate=config.right_arm_config.can_data_bitrate,
|
||||
motor_config=config.right_arm_config.motor_config,
|
||||
position_kd=config.right_arm_config.position_kd,
|
||||
position_kp=config.right_arm_config.position_kp,
|
||||
joint_limits=config.right_arm_config.joint_limits,
|
||||
)
|
||||
|
||||
self.left_arm = OpenArmFollower(left_arm_config)
|
||||
self.right_arm = OpenArmFollower(right_arm_config)
|
||||
|
||||
# Only for compatibility with other parts of the codebase that expect a `robot.cameras` attribute
|
||||
self.cameras = {**self.left_arm.cameras, **self.right_arm.cameras}
|
||||
|
||||
@property
|
||||
def _motors_ft(self) -> dict[str, type]:
|
||||
left_arm_motors_ft = self.left_arm._motors_ft
|
||||
right_arm_motors_ft = self.right_arm._motors_ft
|
||||
|
||||
return {
|
||||
**{f"left_{k}": v for k, v in left_arm_motors_ft.items()},
|
||||
**{f"right_{k}": v for k, v in right_arm_motors_ft.items()},
|
||||
}
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
left_arm_cameras_ft = self.left_arm._cameras_ft
|
||||
right_arm_cameras_ft = self.right_arm._cameras_ft
|
||||
|
||||
return {
|
||||
**{f"left_{k}": v for k, v in left_arm_cameras_ft.items()},
|
||||
**{f"right_{k}": v for k, v in right_arm_cameras_ft.items()},
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.left_arm.calibrate()
|
||||
self.right_arm.calibrate()
|
||||
|
||||
def configure(self) -> None:
|
||||
self.left_arm.configure()
|
||||
self.right_arm.configure()
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||
)
|
||||
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict = {}
|
||||
|
||||
# Add "left_" prefix
|
||||
left_obs = self.left_arm.get_observation()
|
||||
obs_dict.update({f"left_{key}": value for key, value in left_obs.items()})
|
||||
|
||||
# Add "right_" prefix
|
||||
right_obs = self.right_arm.get_observation()
|
||||
obs_dict.update({f"right_{key}": value for key, value in right_obs.items()})
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(
|
||||
self,
|
||||
action: RobotAction,
|
||||
custom_kp: dict[str, float] | None = None,
|
||||
custom_kd: dict[str, float] | None = None,
|
||||
) -> RobotAction:
|
||||
# Remove "left_" prefix
|
||||
left_action = {
|
||||
key.removeprefix("left_"): value for key, value in action.items() if key.startswith("left_")
|
||||
}
|
||||
# Remove "right_" prefix
|
||||
right_action = {
|
||||
key.removeprefix("right_"): value for key, value in action.items() if key.startswith("right_")
|
||||
}
|
||||
|
||||
sent_action_left = self.left_arm.send_action(left_action, custom_kp, custom_kd)
|
||||
sent_action_right = self.right_arm.send_action(right_action, custom_kp, custom_kd)
|
||||
|
||||
# Add prefixes back
|
||||
prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()}
|
||||
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
|
||||
|
||||
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
|
||||
|
||||
def disconnect(self):
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
@@ -0,0 +1,30 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.robots.openarm_follower import OpenArmFollowerConfigBase
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("bi_openarm_follower")
|
||||
@dataclass
|
||||
class BiOpenArmFollowerConfig(RobotConfig):
|
||||
"""Configuration class for Bi OpenArm Follower robots."""
|
||||
|
||||
left_arm_config: OpenArmFollowerConfigBase
|
||||
right_arm_config: OpenArmFollowerConfigBase
|
||||
@@ -24,7 +24,8 @@ import numpy as np
|
||||
import requests
|
||||
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_earthrover_mini_plus import EarthRoverMiniPlusConfig
|
||||
@@ -99,6 +100,7 @@ class EarthRoverMiniPlus(Robot):
|
||||
"""Check if robot is connected to SDK."""
|
||||
return self._is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""Connect to robot via Frodobots SDK.
|
||||
|
||||
@@ -109,8 +111,6 @@ class EarthRoverMiniPlus(Robot):
|
||||
DeviceAlreadyConnectedError: If robot is already connected
|
||||
DeviceNotConnectedError: If cannot connect to SDK server
|
||||
"""
|
||||
if self._is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self.name} is already connected")
|
||||
|
||||
# Verify SDK is running and accessible
|
||||
try:
|
||||
@@ -197,6 +197,7 @@ class EarthRoverMiniPlus(Robot):
|
||||
ACTION_ANGULAR_VEL: float,
|
||||
}
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
"""Get current robot observation from SDK.
|
||||
|
||||
@@ -223,8 +224,6 @@ class EarthRoverMiniPlus(Robot):
|
||||
Robot telemetry is retrieved from /data endpoint.
|
||||
All SDK values are normalized to appropriate ranges for dataset recording.
|
||||
"""
|
||||
if not self._is_connected:
|
||||
raise DeviceNotConnectedError(f"{self.name} is not connected")
|
||||
|
||||
observation = {}
|
||||
|
||||
@@ -255,6 +254,7 @@ class EarthRoverMiniPlus(Robot):
|
||||
|
||||
return observation
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Send action to robot via SDK.
|
||||
|
||||
@@ -272,8 +272,6 @@ class EarthRoverMiniPlus(Robot):
|
||||
Actions are sent to SDK via POST /control endpoint.
|
||||
SDK expects commands in range [-1, 1].
|
||||
"""
|
||||
if not self._is_connected:
|
||||
raise DeviceNotConnectedError(f"{self.name} is not connected")
|
||||
|
||||
# Extract action values and convert to float
|
||||
linear = float(action.get(ACTION_LINEAR_VEL, 0.0))
|
||||
@@ -291,6 +289,7 @@ class EarthRoverMiniPlus(Robot):
|
||||
ACTION_ANGULAR_VEL: angular,
|
||||
}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from robot.
|
||||
|
||||
@@ -299,8 +298,6 @@ class EarthRoverMiniPlus(Robot):
|
||||
Raises:
|
||||
DeviceNotConnectedError: If robot is not connected
|
||||
"""
|
||||
if not self._is_connected:
|
||||
raise DeviceNotConnectedError(f"{self.name} is not connected")
|
||||
|
||||
# Stop the robot before disconnecting
|
||||
try:
|
||||
|
||||
@@ -25,7 +25,7 @@ from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
@@ -82,13 +82,12 @@ class HopeJrArm(Robot):
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""
|
||||
We assume that at connection time, arm is in a rest position,
|
||||
and torque can be safely disabled to run calibration.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.bus.connect(handshake=False)
|
||||
if not self.is_calibrated and calibrate:
|
||||
@@ -128,10 +127,8 @@ class HopeJrArm(Robot):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict = self.bus.sync_read("Present_Position", self.other_motors)
|
||||
@@ -149,10 +146,8 @@ class HopeJrArm(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
|
||||
# Cap goal position when too far away from present position.
|
||||
@@ -165,10 +160,8 @@ class HopeJrArm(Robot):
|
||||
self.bus.sync_write("Goal_Position", goal_pos)
|
||||
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
@@ -25,7 +25,7 @@ from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_hope_jr import HopeJrHandConfig
|
||||
@@ -118,10 +118,8 @@ class HopeJrHand(Robot):
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.bus.connect()
|
||||
if not self.is_calibrated and calibrate:
|
||||
self.calibrate()
|
||||
@@ -159,10 +157,8 @@ class HopeJrHand(Robot):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
obs_dict = {}
|
||||
|
||||
# Read hand position
|
||||
@@ -181,18 +177,14 @@ class HopeJrHand(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
self.bus.sync_write("Goal_Position", goal_pos)
|
||||
return action
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
@@ -25,7 +25,7 @@ from lerobot.motors.dynamixel import (
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
@@ -84,13 +84,12 @@ class KochFollower(Robot):
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""
|
||||
We assume that at connection time, arm is in a rest position,
|
||||
and torque can be safely disabled to run calibration.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.bus.connect()
|
||||
if not self.is_calibrated and calibrate:
|
||||
@@ -182,10 +181,8 @@ class KochFollower(Robot):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict = self.bus.sync_read("Present_Position")
|
||||
@@ -202,6 +199,7 @@ class KochFollower(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command arm to move to a target joint configuration.
|
||||
|
||||
@@ -215,8 +213,6 @@ class KochFollower(Robot):
|
||||
Returns:
|
||||
RobotAction: The action sent to the motors, potentially clipped.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
|
||||
@@ -231,10 +227,8 @@ class KochFollower(Robot):
|
||||
self.bus.sync_write("Goal_Position", goal_pos)
|
||||
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user