Compare commits

..

10 Commits

Author SHA1 Message Date
Martino Russi ee24f64ae5 add motion imitation 2025-12-17 16:00:43 +01:00
Martino Russi 123b9f7851 add motion imitation 2025-12-17 15:59:56 +01:00
Martino Russi a6c3a0fa09 Feat/add mj env (#2613)
* add sim support

* close fix threading issues
2025-12-15 16:22:27 +01:00
Woojin Wie c2fb644613 feat(robot): Add support for OMX robot (#2614)
* upload

* feat(omx): simplify motor initialization and remove default calibration files

* feat(omx): read motor positions without normalization for improved accuracy

* update calibration method for return factory value

Signed-off-by: Junha Cha <ckwnsgk1@gachon.ac.kr>

* change the drive mode

* refactor: clean up code by removing unnecessary blank lines in omx_follower and omx_leader modules

* feat(omx): update calibration method to set drive modes for motors

* feat(pyproject): add 'ROBOTIS' to extend-ignore-identifiers-re list

* feat(omx): enhance calibration method to write default drive modes to motors

* Update src/lerobot/robots/omx_follower/__init__.py

Add informations about the robot

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Signed-off-by: Woojin Wie <dnldnwls1123@gmail.com>

---------

Signed-off-by: Junha Cha <ckwnsgk1@gachon.ac.kr>
Signed-off-by: Woojin Wie <dnldnwls1123@gmail.com>
Co-authored-by: Junha02 <chajunha2023@naver.com>
Co-authored-by: Junha Cha <ckwnsgk1@gachon.ac.kr>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-12-15 15:50:29 +01:00
Jade Choghari 1d07a4aefd add auto in docs (#2645)
Signed-off-by: Jade Choghari <chogharijade@gmail.com>
2025-12-13 17:11:19 +01:00
Michel Aractingi ce348a3460 enable variable image sizes to pi0/pi0.5 (#2609)
* enable variable image sizes to pi0/pi0.5

* add square image assertion
2025-12-10 19:41:11 +01:00
Jade Choghari cb920235c4 docs: update X-VLA training strategies/commands (#2611) 2025-12-09 19:08:09 +01:00
Jade Choghari 7f40b3bf82 feat(dataset): add tool to convert images to video datasets (#2560)
* add video encoding tool

* style

* make it work

* more fixes
2025-12-08 18:50:21 +01:00
Michel Aractingi 2e9c9fd832 Replay while loop in sample actions with for loops (#2600) 2025-12-08 14:47:54 +01:00
Steven Palma f9cb5e659c chore(ci): skip workflows if not lerobot repository (#2601)
Co-authored-by: Alex Tyshka <atyshka15@gmail.com>
2025-12-08 12:44:36 +01:00
99 changed files with 3696 additions and 14266 deletions
@@ -31,7 +31,8 @@ jobs:
name: Upload Preview and Comment
if: >
github.event.workflow_run.event == 'pull_request' &&
github.event.workflow_run.conclusion == 'success'
github.event.workflow_run.conclusion == 'success' &&
github.repository == 'huggingface/lerobot'
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
with:
package_name: lerobot
+4 -2
View File
@@ -42,7 +42,9 @@ jobs:
# This job builds and deploys the official documentation.
build_main_docs:
name: Build Main Docs
if: github.event_name == 'push' || github.event_name == 'workflow_dispatch'
if: >
(github.event_name == 'push' || github.event_name == 'workflow_dispatch') &&
github.repository == 'huggingface/lerobot'
permissions:
contents: read
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
@@ -58,7 +60,7 @@ jobs:
# The result of this job triggers the 'Upload PR Documentation' workflow.
build_pr_docs:
name: Build PR Docs
if: github.event_name == 'pull_request'
if: github.event_name == 'pull_request' && github.repository == 'huggingface/lerobot'
permissions:
contents: read
pull-requests: write
-1
View File
@@ -45,7 +45,6 @@ permissions:
env:
UV_VERSION: "0.8.0"
PYTHON_VERSION: "3.10"
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu
# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
concurrency:
+2
View File
@@ -43,6 +43,7 @@ jobs:
name: Build CPU Docker for Nightly
runs-on:
group: aws-general-8-plus
if: github.repository == 'huggingface/lerobot'
outputs:
image_tag: ${{ env.DOCKER_IMAGE_NAME_CPU }}
steps:
@@ -77,6 +78,7 @@ jobs:
name: Build GPU Docker for Nightly
runs-on:
group: aws-general-8-plus
if: github.repository == 'huggingface/lerobot'
outputs:
image_tag: ${{ env.DOCKER_IMAGE_NAME_GPU }}
steps:
+1
View File
@@ -29,6 +29,7 @@ jobs:
build-and-publish:
name: Build and publish Python distributions
runs-on: ubuntu-latest
if: github.repository == 'huggingface/lerobot'
outputs:
version: ${{ steps.extract_info.outputs.tag_version }}
permissions:
+1
View File
@@ -45,6 +45,7 @@ jobs:
stale:
name: Close Stale Issues and PRs
runs-on: ubuntu-latest
if: github.repository == 'huggingface/lerobot'
permissions:
actions: write
contents: write # only for delete-branch option
+1
View File
@@ -43,6 +43,7 @@ jobs:
full-tests:
name: Full Unbound Tests
runs-on: ubuntu-latest
if: github.repository == 'huggingface/lerobot'
env:
MUJOCO_GL: egl
HF_HOME: /mnt/cache/.cache/huggingface
BIN
View File
Binary file not shown.

Before

Width:  |  Height:  |  Size: 51 KiB

+6 -1
View File
@@ -4,11 +4,12 @@ This guide covers the complete setup process for the Unitree G1 humanoid, from i
## About the Unitree G1
We offer support for both 29 and 23 DOF G1. In this first PR we introduce:
We offer support for both 29 and 23 DOF G1. We introduce:
- **`unitree g1` robot class, handling low level communication with the humanoid**
- **ZMQ socket bridge** for remote communication over WiFi, allowing one to deploy policies remotely instead of over ethernet or directly on the Orin
- **GR00T locomotion policy** for bipedal walking and balance
- **MuJoCo simulation mode** for testing policies without the physical robot
---
@@ -191,6 +192,10 @@ Press `Ctrl+C` to stop the policy.
---
## Extra: Running in Simulation Mode (MuJoCo)
You can now test and develop policies without a physical robot using MuJoCo. to do so set `is_simulation=True` in config.
## Additional Resources
- [Unitree SDK Documentation](https://github.com/unitreerobotics/unitree_sdk2_python)
+66 -3
View File
@@ -11,13 +11,14 @@ LeRobot provides several utilities for manipulating datasets:
3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids`
4. **Add Features** - Add new features to a dataset
5. **Remove Features** - Remove features from a dataset
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage
The core implementation is in `lerobot.datasets.dataset_tools`.
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
## Command-Line Tool: lerobot-edit-dataset
`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, and remove features.
`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, remove features, and convert image datasets to video format.
Run `lerobot-edit-dataset --help` for more information on the configuration of each operation.
@@ -86,9 +87,71 @@ lerobot-edit-dataset \
--operation.feature_names "['observation.images.top']"
```
#### Convert to Video
Convert an image-based dataset to video format, creating a new LeRobotDataset where images are stored as videos. This is useful for reducing storage requirements and improving data loading performance. The new dataset will have the exact same structure as the original, but with images encoded as MP4 videos in the proper LeRobot format.
```bash
# Local-only: Save to a custom output directory (no hub push)
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_to_video \
--operation.output_dir /path/to/output/pusht_video
# Save with new repo_id (local storage)
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--new_repo_id lerobot/pusht_video \
--operation.type convert_to_video
# Convert and push to Hugging Face Hub
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--new_repo_id lerobot/pusht_video \
--operation.type convert_to_video \
--push_to_hub true
# Convert with custom video codec and quality settings
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_to_video \
--operation.output_dir outputs/pusht_video \
--operation.vcodec libsvtav1 \
--operation.pix_fmt yuv420p \
--operation.g 2 \
--operation.crf 30
# Convert only specific episodes
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_to_video \
--operation.output_dir outputs/pusht_video \
--operation.episode_indices "[0, 1, 2, 5, 10]"
# Convert with multiple workers for parallel processing
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_to_video \
--operation.output_dir outputs/pusht_video \
--operation.num_workers 8
```
**Parameters:**
- `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`)
- `vcodec`: Video codec to use - options: `h264`, `hevc`, `libsvtav1` (default: `libsvtav1`)
- `pix_fmt`: Pixel format - options: `yuv420p`, `yuv444p` (default: `yuv420p`)
- `g`: Group of pictures (GOP) size - lower values give better quality but larger files (default: 2)
- `crf`: Constant rate factor - lower values give better quality but larger files, 0 is lossless (default: 30)
- `fast_decode`: Fast decode tuning option (default: 0)
- `episode_indices`: List of specific episodes to convert (default: all episodes)
- `num_workers`: Number of parallel workers for processing (default: 4)
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved.
### Push to Hub
Add the `--push_to_hub` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
Add the `--push_to_hub true` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
```bash
lerobot-edit-dataset \
@@ -96,7 +159,7 @@ lerobot-edit-dataset \
--new_repo_id lerobot/pusht_after_deletion \
--operation.type delete_episodes \
--operation.episode_indices "[0, 2, 5]" \
--push_to_hub
--push_to_hub true
```
There is also a tool for adding features to a dataset that is not yet covered in `lerobot-edit-dataset`.
+42 -84
View File
@@ -24,7 +24,7 @@ Built from pure Transformer encoders, X-VLA scales naturally with model size and
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-architecture2.png"
alt="XVLA Architecture 2"
style="width: 32%; max-width: 450px; height: auto;"
style="width: 60%; height: auto;"
/>
</p>
@@ -120,7 +120,7 @@ Adapted for Google Robot platforms.
### Recommended Training Configuration
When fine-tuning X-VLA for a new embodiment or task, we recommend the following freezing strategy:
When fine-tuning X-VLA for a new embodiment or task, we recommend not freezing the VLM, and also setting the `policy.dtype=bfloat16` to not hit OOM errors.
```bash
lerobot-train \
@@ -129,25 +129,26 @@ lerobot-train \
--job_name=xvla_training \
--policy.path="lerobot/xvla-base" \
--policy.repo_id="HF_USER/xvla-your-robot" \
--steps=3000 \
--policy.dtype=bfloat16 \
--policy.action_mode=auto \
--steps=20000 \
--policy.device=cuda \
--policy.freeze_vision_encoder=True \
--policy.freeze_language_encoder=True \
--policy.train_policy_transformer=True \
--policy.train_soft_prompts=True \
--policy.action_mode=YOUR_ACTION_MODE
--policy.freeze_vision_encoder=false \
--policy.freeze_language_encoder=false \
--policy.train_policy_transformer=true \
--policy.train_soft_prompts=true \
```
### Training Parameters Explained
| Parameter | Default | Description |
| -------------------------- | ------- | ---------------------------------------- |
| `freeze_vision_encoder` | `True` | Freeze the VLM vision encoder weights |
| `freeze_language_encoder` | `True` | Freeze the VLM language encoder weights |
| `train_policy_transformer` | `True` | Allow policy transformer layers to train |
| `train_soft_prompts` | `True` | Allow soft prompts to train |
| Parameter | Default | Description |
| -------------------------- | ------- | ---------------------------------------------- |
| `freeze_vision_encoder` | `false` | Do not freeze the VLM vision encoder weights |
| `freeze_language_encoder` | `false` | Do not freeze the VLM language encoder weights |
| `train_policy_transformer` | `true` | Allow policy transformer layers to train |
| `train_soft_prompts` | `true` | Allow soft prompts to train |
**💡 Best Practice**: For Phase II adaptation to new embodiments, freeze the VLM encoders and only train the policy transformer and soft prompts. This provides excellent sample efficiency with minimal compute.
**💡 Best Practice**: For Phase II adaptation to new embodiments, do not freeze the VLM encoders and also train the policy transformer and soft prompts.
### Example: Training on Bimanual Robot
@@ -157,14 +158,15 @@ lerobot-train \
--output_dir=./outputs/xvla_bimanual \
--job_name=xvla_so101_training \
--policy.path="lerobot/xvla-base" \
--policy.dtype=bfloat16 \
--policy.repo_id="YOUR_USERNAME/xvla-biso101" \
--steps=3000 \
--policy.device=cuda \
--policy.action_mode=so101_bimanual \
--policy.freeze_vision_encoder=True \
--policy.freeze_language_encoder=True \
--policy.train_policy_transformer=True \
--policy.train_soft_prompts=True
--policy.freeze_vision_encoder=false \
--policy.freeze_language_encoder=false \
--policy.train_policy_transformer=true \
--policy.train_soft_prompts=true
```
💡 **Best Performance:** If you have sufficient computational resources and want to achieve best X-VLA finetuning performance, you should follow the official finetuning strategy:
@@ -172,71 +174,7 @@ lerobot-train \
**🔥 Full-finetune all components with a custom learning-rate scheme**
To ensure stable optimization, the Vision-Language Model (VLM) must be trained with only 1/10 of the base learning rate, while all other components use the full LR.
This LR ratio is crucial for achieving strong and stable finetuning performance.
To enable this behavior, you must:
1. Implement a custom optimizer and register it in your training config
```
from dataclasses import dataclass, asdict
from lerobot.optim.optimizers import OptimizerConfig
import torch
@OptimizerConfig.register_subclass("xvla-adamw")
@dataclass
class XVLAAdamW(OptimizerConfig):
lr: float = 1e-4
betas: tuple[float, float] = (0.9, 0.99)
eps: float = 1e-8
weight_decay: float = 0.0
grad_clip_norm: float = 10.0
def build(self, params: dict) -> torch.optim.Optimizer:
"""
Expect `named_parameters()` as input.
Apply lr = lr / 10 for all VLM-related parameters.
"""
assert isinstance(params, dict), \
"Custom LR optimizer requires `named_parameters()` as inputs."
kwargs = asdict(self)
kwargs.pop("grad_clip_norm")
vlm_group, other_group = [], []
for name, p in params.items():
if not p.requires_grad:
continue
if "vlm" in name.lower():
vlm_group.append(p)
else:
other_group.append(p)
param_groups = [
{"params": vlm_group, "lr": self.lr * 0.1, "weight_decay": self.weight_decay * 0.1},
{"params": other_group, "lr": self.lr, "weight_decay": self.weight_decay},
]
return torch.optim.AdamW(param_groups, **kwargs)
```
2. Modify X-VLAs get_optim_params to return named parameters
Replace:
```
def get_optim_params(self) -> dict:
"""Return only trainable parameters for optimization."""
return filter(lambda p: p.requires_grad, self.parameters())
```
with:
```
def get_optim_params(self):
"""Return trainable named parameters."""
return filter(lambda kv: kv[1].requires_grad, self.named_parameters())
```
This ensures the optimizer receives a dict of named parameters, allowing it to correctly detect VLM modules and apply the 1/10 LR rule.
This LR ratio is crucial for achieving strong and stable finetuning performance. This is already done for you by default.
❕Note
Completely matching the official reported performance may require an additional warm-up LR schedule for soft-prompts, which can bring minor improvements.
@@ -326,6 +264,26 @@ domain_id = 3
The domain_id is automatically added to observations by the `XVLAAddDomainIdProcessorStep` in the preprocessing pipeline.
The `lerobot/xvla-base` model has been trained on the following domain IDs. It is recommended to choose one that most resembles your robot/configuration:
#### Fine-tuning Datasets
| Dataset Name | Domain ID |
| ---------------- | --------- |
| Bridge | 0 |
| RT1 | 1 |
| Calvin | 2 |
| libero | 3 |
| widowx-air | 4 |
| AIR-AGILEX-HQ | 5 |
| robotwin2_abs_ee | 6 |
| robotwin2_clean | 6 |
| robocasa-human | 7 |
| VLABench | 8 |
| AGIBOT-challenge | 9 |
| AIR-AGILEX | 10 |
| AIRBOT | 18 |
### 3. Processor Steps
X-VLA requires specific preprocessing and postprocessing steps for proper operation.
-243
View File
@@ -1,243 +0,0 @@
# Synthetic Data Generation Script - Summary
## ✅ What Was Created
### Main Script: `annotate_pgen.py` (717 lines)
A production-ready script implementing the Hi-Robot synthetic data generation pipeline.
**Key Features:**
- ✅ Loads LeRobot datasets with skill annotations
- ✅ Generates synthetic user prompts and robot utterances using Qwen VLM
- ✅ **Temporal sampling** - generates dialogue every N seconds (default: 1s)
- ✅ Adds `task_index_high_level` feature to dataset parquets
- ✅ Saves high-level tasks to `meta/tasks_high_level.parquet`
- ✅ Exports debug JSONL for quality analysis
- ✅ Supports both Qwen2-VL and Qwen3-VL models
- ✅ Multi-view camera support
- ✅ Episode-aware processing with automatic first-frame sampling
- ✅ Modular architecture for easy extension
### Supporting Files Created
1. **`run_pgen.sh`** - Convenience script with sensible defaults
2. **`README_PGEN.md`** - Comprehensive documentation with examples
3. **`example_pgen_usage.md`** - Practical examples and performance estimates
4. **`SAMPLING_DIAGRAM.md`** - Visual explanation of temporal sampling strategy
5. **`PGEN_SUMMARY.md`** - This file
## 🚀 Key Innovation: Temporal Sampling
The script processes **ALL episodes** in the dataset efficiently via `--sample-interval`:
```bash
# Instead of calling VLM for every frame (expensive):
# 15,000 frames × VLM call = ~5 hours
# Generate dialogue every 1 second (efficient):
python annotate_pgen.py --repo-id dataset --model qwen --sample-interval 1.0
# 15,000 frames processed, only ~500 VLM calls (30x speedup!)
```
**How it works:**
- Process ALL frames in ALL episodes (complete coverage)
- Generate dialogue at sampled timepoints (e.g., every 1 second)
- Propagate task indices to intermediate frames
- Always sample first frame of each episode
- All frames get labeled, but VLM is only called for samples
- No dummy values or skipped episodes
**Benefits:**
- 30-100x speedup depending on interval
- Maintains temporal coherence
- Reduces cost without losing quality
- Configurable based on skill duration
## 📊 Efficiency Comparison
For a typical 15,000 frame dataset at 30 fps:
| Method | VLM Calls | Time | Cost |
|--------|-----------|------|------|
| Every frame | 15,000 | ~5 hours | $$$$ |
| Every 0.5s | 1,000 | ~20 min | $$$ |
| **Every 1s** (default) | **500** | **~10 min** | **$$** |
| Every 2s | 250 | ~5 min | $ |
## 🎯 Usage
### Quick Test (5s sampling for fast iteration)
```bash
python examples/dataset/annotate_pgen.py \
--data-dir /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace \
--model Qwen/Qwen2-VL-7B-Instruct \
--sample-interval 5.0 \
--output-dir ./outputs/test_quick
```
### Production Run (Recommended Settings)
```bash
python examples/dataset/annotate_pgen.py \
--data-dir /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace \
--model Qwen/Qwen2-VL-7B-Instruct \
--sample-interval 1.0 \
--output-dir ./outputs/full_pgen
```
### High-Quality with Qwen3
```bash
python examples/dataset/annotate_pgen.py \
--data-dir /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace \
--model Qwen/Qwen3-VL-30B-A3B-Instruct \
--sample-interval 0.5 \
--temperature 0.6 \
--output-dir ./outputs/high_quality
```
## 📦 Output Structure
After running, you'll have:
```
dataset_root/
├── meta/
│ ├── tasks_high_level.parquet # High-level tasks with prompts/utterances
│ └── syn_annotations.jsonl # Debug: full context for each sample
└── data/
└── chunk-000/
└── file-000.parquet # Updated with task_index_high_level
```
**New feature added to all parquet files:**
- `task_index_high_level` (int64): Links to tasks_high_level.parquet
## 🔧 All Parameters
| Parameter | Default | Description |
|-----------|---------|-------------|
| `--repo-id` / `--data-dir` | - | Dataset source |
| `--model` | Qwen/Qwen2-VL-7B-Instruct | VLM model |
| `--device` | cuda | Device to use |
| `--dtype` | bfloat16 | Model precision |
| `--temperature` | 0.7 | Sampling temperature |
| **`--sample-interval`** | **1.0** | **Generate every N seconds (all episodes processed)** |
| `--num-image-views-per-sample` | 1 | Number of cameras |
| `--batch-size` | 1 | Batch size (currently unused) |
| `--output-dir` | None | Output directory |
| `--push-to-hub` | False | Push to HuggingFace |
## 🎨 Generated Data Format
Each sampled frame produces:
```json
{
"scenario_type": "specific_object",
"response_type": "confirmation",
"user_prompt": "Can you pick up the pink brick?",
"robot_utterance": "Sure, I'll grab the pink lego brick.",
"skill": "robot arm picks up pink lego brick",
"episode_id": 0,
"frame_index": 45,
"timestamp": 1.5,
"skill_history": ["robot arm moves towards pink lego brick"],
"task_description": "pink lego brick into the transparent box"
}
```
**Scenario Types:**
- specific_object, negative_task, situated_correction, implicit_request, constraint_based
**Response Types:**
- confirmation, clarification, acknowledgment, constraint_acknowledgment
## 🔬 Code Architecture
```python
# Main components (modular design)
class QwenPgen:
"""VLM wrapper supporting Qwen2/3"""
def call_qwen(images, prompt) -> dict
def construct_prompt(task, history, skill) -> str:
"""Build contextual prompt with history"""
def annotate_sample(pgen, images, ...) -> dict:
"""Generate dialogue for one sample"""
def generate_synthetic_data(dataset, pgen, ...) -> tuple:
"""Process entire dataset with temporal sampling"""
# Core sampling logic:
# - Track last_sample_timestamp per episode
# - Sample if time_elapsed >= sample_interval
# - Always sample first frame of episodes
# - Propagate task_index to intermediate frames
def main():
"""CLI entrypoint with argparse"""
```
## ✨ Next Steps
1. **Quick test with large interval:**
```bash
# Fast iteration - samples every 5 seconds
python examples/dataset/annotate_pgen.py \
--data-dir /path/to/dataset \
--model Qwen/Qwen2-VL-7B-Instruct \
--sample-interval 5.0 \
--output-dir ./outputs/quick_test
```
2. **Verify output quality:**
```bash
head outputs/quick_test/meta/syn_annotations.jsonl
```
3. **Production run:**
```bash
# Standard 1 second sampling for production
bash examples/dataset/run_pgen.sh
```
4. **Use in training:**
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
ds = LeRobotDataset(repo_id="...", root="outputs/pgen_annotations")
# Access high-level task for each frame
frame = ds[100]
task_idx = frame["task_index_high_level"].item()
```
## 📚 Documentation Files
- **`README_PGEN.md`**: Full API reference and troubleshooting
- **`example_pgen_usage.md`**: Practical examples with performance estimates
- **`SAMPLING_DIAGRAM.md`**: Visual explanation of temporal sampling
- **`PGEN_SUMMARY.md`**: This overview document
## 🎯 Success Criteria
✅ Script generates synthetic dialogue using Qwen VLM
✅ Adds `task_index_high_level` feature to dataset
✅ Saves tasks to `tasks_high_level.parquet`
✅ Implements efficient temporal sampling (30-100x speedup)
✅ Handles episode boundaries correctly
✅ Produces diverse interaction types (scenarios + responses)
✅ Maintains temporal coherence within episodes
✅ Includes comprehensive documentation and examples
✅ Ready for production use on real datasets
## 💡 Key Takeaway
**The script processes ALL episodes with intelligent sampling:**
- `--sample-interval` controls how often VLM is called (default: 1.0s)
- ALL frames in ALL episodes get labeled (complete coverage)
- Intermediate frames inherit from most recent sample (temporal coherence)
- Achieves 30-100x speedup while maintaining quality
- Adjust interval based on use case: 5.0s for testing, 1.0s for production, 0.5s for fine detail
This makes the synthetic data generation **practical, scalable, and complete** for real-world datasets!
-243
View File
@@ -1,243 +0,0 @@
# Synthetic Data Generation for Hierarchical Robot Policies
This directory contains `annotate_pgen.py`, a script for generating synthetic user prompts and robot utterances for hierarchical policy training using Vision-Language Models (VLMs).
## Overview
The script implements the synthetic data generation pipeline described in the Hi-Robot paper:
1. **Load** a LeRobot dataset with skill annotations (from `annotate.py`)
2. **Generate** synthetic dialogue using Qwen VLM:
- User prompts (_t): Natural requests that lead to specific skills
- Robot utterances (u_t): Acknowledgments and clarifications
3. **Save** results as a new dataset feature `task_index_high_level`
## Prerequisites
1. First, annotate your dataset with skills using `annotate.py`:
```bash
python examples/dataset/annotate.py \
--repo-id lerobot/svla_so101_pickplace \
--video-key observation.images.base \
--model Qwen/Qwen2-VL-7B-Instruct
```
This creates `meta/skills.json` with skill segmentation for each episode.
## Usage
### Basic Usage
```bash
python examples/dataset/annotate_pgen.py \
--repo-id lerobot/svla_so101_pickplace \
--model Qwen/Qwen2-VL-7B-Instruct \
--sample-interval 1.0 \
--output-dir ./outputs/pgen_dataset
```
**Note**: The script processes **all episodes** in the dataset. It generates dialogue every 1 second (`--sample-interval 1.0`) using temporal sampling. Frames between samples reuse the last generated dialogue. This makes the process efficient while ensuring complete dataset coverage.
### Advanced Options
```bash
python examples/dataset/annotate_pgen.py \
--repo-id lerobot/svla_so101_pickplace \
--model Qwen/Qwen3-VL-30B-A3B-Instruct \
--temperature 0.8 \
--sample-interval 0.5 \
--num-image-views-per-sample 2 \
--output-dir ./outputs/pgen_dataset \
--push-to-hub
```
This example uses a more powerful model and samples every 0.5 seconds for finer granularity.
### Fast Testing (larger interval)
```bash
python examples/dataset/annotate_pgen.py \
--repo-id lerobot/svla_so101_pickplace \
--model Qwen/Qwen2-VL-7B-Instruct \
--sample-interval 5.0 \
--output-dir ./outputs/pgen_quick_test
```
Use a larger interval (5.0 seconds) for rapid iteration during development. All episodes are still processed.
### Using Local Dataset
```bash
python examples/dataset/annotate_pgen.py \
--data-dir /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace \
--model Qwen/Qwen2-VL-7B-Instruct \
--output-dir ./outputs/pgen_dataset
```
## Output Files
The script produces several outputs:
1. **`meta/tasks_high_level.parquet`**: High-level tasks with user prompts and robot utterances
- Columns: task_index, user_prompt, robot_utterance, skill, scenario_type, response_type
2. **`meta/syn_annotations.jsonl`**: Debug file with all generated dialogues
- One JSON object per line with full context for each frame
3. **Modified dataset**: New dataset with `task_index_high_level` feature added to all parquet files
## Scenario and Response Types
The generator produces diverse interaction types:
### Scenario Types
- **specific_object**: Direct specification of objects/actions
- **negative_task**: Instructions about what NOT to do
- **situated_correction**: Adjustments based on current state
- **implicit_request**: Implied needs without direct commands
- **constraint_based**: Specific constraints or preferences
### Response Types
- **confirmation**: Simple acknowledgment ("OK, I'll do X")
- **clarification**: Seeking confirmation ("Just to confirm...")
- **acknowledgment**: Action acknowledgment ("Got it, doing X")
- **constraint_acknowledgment**: Acknowledging constraints ("Sure, I'll X while Y")
## Example Generated Data
```json
{
"episode_id": 0,
"frame_index": 45,
"timestamp": 2.5,
"skill_current": "robot arm picks up pink lego brick",
"skill_history": ["robot arm moves towards pink lego brick"],
"task_description": "pink lego brick into the transparent box",
"scenario_type": "specific_object",
"response_type": "confirmation",
"user_prompt": "Can you grab the pink brick?",
"robot_utterance": "Sure, I'll pick up the pink lego brick."
}
```
## Accessing the Data
After running the script, access the synthetic data in your code:
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
import pandas as pd
# Load modified dataset
dataset = LeRobotDataset(repo_id="lerobot/svla_so101_pickplace_with_high_level_tasks")
# Access frame with high-level task
frame = dataset[100]
high_level_task_idx = frame["task_index_high_level"].item()
# Load high-level tasks
tasks_df = pd.read_parquet(dataset.root / "meta" / "tasks_high_level.parquet")
task_info = tasks_df.iloc[high_level_task_idx]
print(f"User prompt: {task_info['user_prompt']}")
print(f"Robot utterance: {task_info['robot_utterance']}")
print(f"Skill: {task_info['skill']}")
```
## Architecture
The script is modular and extensible:
```python
# Core components
class QwenPgen:
"""VLM wrapper for generation"""
def call_qwen(images, prompt) -> dict
def construct_prompt(task, history, skill) -> str
"""Build prompt for VLM"""
def annotate_sample(pgen, images, ...) -> dict
"""Generate dialogue for one sample"""
def generate_synthetic_data(dataset, pgen, ...) -> tuple
"""Process entire dataset"""
```
## Parameters
| Parameter | Default | Description |
|-----------|---------|-------------|
| `--repo-id` | - | HuggingFace dataset ID |
| `--data-dir` | - | Local dataset path |
| `--model` | Qwen/Qwen2-VL-7B-Instruct | VLM model name |
| `--device` | cuda | Device (cuda/cpu) |
| `--dtype` | bfloat16 | Model precision |
| `--temperature` | 0.7 | Sampling temperature |
| `--sample-interval` | 1.0 | Generate dialogue every N seconds (all episodes processed) |
| `--num-image-views-per-sample` | 1 | Number of cameras |
| `--output-dir` | None | Output directory |
| `--push-to-hub` | False | Push to HuggingFace Hub |
## Sampling Strategy
The script uses **temporal sampling** to efficiently generate dialogue:
- **Default**: Generate dialogue every 1 second (`--sample-interval 1.0`)
- **Efficiency**: If a dataset runs at 30fps, this samples ~3% of frames
- **Propagation**: Frames between samples reuse the last generated task_index
- **Episode-aware**: Always samples the first frame of each episode
### Example with 30 fps dataset:
```bash
# Sample every 1 second (every 30 frames)
--sample-interval 1.0 # ~3,000 generations for a 100 episode dataset (3 sec/episode)
# Sample every 0.5 seconds (every 15 frames)
--sample-interval 0.5 # ~6,000 generations (more granular)
# Sample every 2 seconds (every 60 frames)
--sample-interval 2.0 # ~1,500 generations (more efficient)
```
### Why sampling works:
- Skills typically last 1-3 seconds
- Dialogue doesn't need to change every frame
- Reduces computational cost by 30-100x
- Still provides good coverage for training
## Tips
1. **Quick testing**: Use larger `--sample-interval` (e.g., 5.0 or 10.0) for rapid iteration
2. **Monitor GPU**: VLM inference is memory-intensive
3. **Check outputs**: Review `syn_annotations.jsonl` for quality
4. **Adjust temperature**: Higher = more diverse, lower = more consistent
5. **Multiple views**: Use `--num-image-views-per-sample 2+` for better context
6. **Tune sampling**: Start with 1.0s, increase for speed (testing), decrease for granularity (production)
## Troubleshooting
### No skills.json found
Run `annotate.py` first to generate skill annotations.
### Out of memory
- Reduce batch size to 1
- Use smaller model (Qwen2-VL-7B instead of Qwen3-VL-30B)
- Process fewer samples at a time
### Poor quality generations
- Adjust temperature (try 0.6-0.9)
- Check that skills.json has good annotations
- Ensure images are loading correctly
## Citation
Based on the Hi-Robot paper's synthetic data generation approach:
```
@article{hirobot2024,
title={Hi-Robot: Hierarchical Robot Learning with Vision-Language Models},
year={2024}
}
```
-141
View File
@@ -1,141 +0,0 @@
# Temporal Sampling Strategy Visualization
## How `--sample-interval` Works
### Example: 30 fps dataset, `--sample-interval 1.0` (1 second)
```
Timeline (seconds): 0.0 0.5 1.0 1.5 2.0 2.5 3.0
│ │ │ │ │ │ │
Frames: 0───15───30───45───60───75───90───105──120──135──150
│ │ │ │ │ │ │
▼ ▼ ▼ ▼
Sampled: YES NO YES NO YES NO YES
│ │ │ │
Task Index: [0]──────────────>[1]──────────────>[2]──────────────>[3]
│ │ │ │
VLM Called: ✓ Gen ✓ Gen ✓ Gen ✓ Gen
dialogue dialogue dialogue dialogue
│ │ │ │
Frames 0-29 ─────┘ │ │ │
get task 0 │ │ │
│ │ │
Frames 30-59 ────────────────────────┘ │ │
get task 1 │ │
│ │
Frames 60-89 ──────────────────────────────────────────┘ │
get task 2 │
Frames 90-119 ────────────────────────────────────────────────────────────┘
get task 3
```
## Comparison: Different Sampling Intervals
### `--sample-interval 2.0` (every 2 seconds)
```
Timeline: 0.0 1.0 2.0 3.0 4.0 5.0 6.0
│ │ │ │ │ │ │
Sampled: YES NO YES NO YES NO YES
│ │ │ │
Tasks: [0]───────────────>[1]───────────────>[2]───────────────>[3]
VLM Calls: 4 (fewer calls, faster but less granular)
```
### `--sample-interval 1.0` (every 1 second) - **DEFAULT**
```
Timeline: 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0 4.5 5.0 5.5 6.0
│ │ │ │ │ │ │ │ │ │ │ │ │
Sampled: YES NO YES NO YES NO YES NO YES NO YES NO YES
│ │ │ │ │ │ │
Tasks: [0]─────────>[1]─────────>[2]─────────>[3]─────────>[4]─────────>[5]─────>[6]
VLM Calls: 7 (balanced coverage and speed)
```
### `--sample-interval 0.5` (every 0.5 seconds)
```
Timeline: 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0 4.5 5.0 5.5 6.0
│ │ │ │ │ │ │ │ │ │ │ │ │
Sampled: YES YES YES YES YES YES YES YES YES YES YES YES YES
│ │ │ │ │ │ │ │ │ │ │ │ │
Tasks: [0]─>[1]─>[2]─>[3]─>[4]─>[5]─>[6]─>[7]─>[8]─>[9]─>[10]>[11]>[12]
VLM Calls: 13 (high granularity, slower but more detailed)
```
## Episode Boundaries
The script always samples the **first frame** of each episode:
```
Episode 0 Episode 1 Episode 2
├─────────────────────────────────┤├─────────────────────────────────┤├──────...
│ ││ ││
Frame: 0 30 60 90 120 130 160 190 220 250 260 290 320
Time: 0.0 1.0 2.0 3.0 4.0 0.0 1.0 2.0 3.0 4.0 0.0 1.0 2.0
│ │ │ │ │ │ │ │ │ │ │ │ │
▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼
Sample:YES YES YES YES YES YES YES YES YES YES YES YES YES
│ │ │ │ │ │ │ │ │ │ │ │ │
Task: 0────1─────2─────3────4 5─────6─────7─────8────9 10────11───12
Note: Frames 0, 130, 260 are ALWAYS sampled (episode starts)
Even if they're within the sample-interval window
```
## Real-World Example: svla_so101_pickplace Dataset
Typical stats:
- **Total episodes**: 50
- **Avg episode length**: 300 frames (10 seconds at 30 fps)
- **Total frames**: 15,000
### Without Sampling (every frame)
```
Frames processed: 15,000
VLM calls: 15,000
Time estimate: ~5 hours
Unique tasks: ~12,000 (lots of duplicates)
```
### With `--sample-interval 1.0` (every 1 second)
```
Frames processed: 15,000 ✓
VLM calls: 500
Time estimate: ~10 minutes
Unique tasks: ~450 (meaningful variety)
Efficiency gain: 30x faster
```
### With `--sample-interval 2.0` (every 2 seconds)
```
Frames processed: 15,000 ✓
VLM calls: 250
Time estimate: ~5 minutes
Unique tasks: ~220
Efficiency gain: 60x faster
```
## Key Points
1. **All frames get labeled**: Every frame gets a `task_index_high_level`
2. **Only sampled frames call VLM**: Huge efficiency gain
3. **Temporal coherence**: Nearby frames share the same task
4. **Episode-aware**: Always samples episode starts
5. **Configurable**: Adjust `--sample-interval` based on your needs
## Choosing Your Sampling Interval
| Use Case | Recommended Interval | Why |
|----------|---------------------|-----|
| Quick testing | 2.0s | Fastest iteration |
| Standard training | 1.0s | Good balance |
| High-quality dataset | 0.5s | Better coverage |
| Fine-grained control | 0.33s | Very detailed |
| Dense annotations | 0.1s | Nearly every frame |
**Rule of thumb**: Match your sampling interval to your typical skill duration.
If skills last 1-3 seconds, sampling every 1 second captures each skill multiple times.
@@ -1,138 +0,0 @@
#!/usr/bin/env python
"""
Example demonstrating how to use the ActionTokenizerProcessorStep to tokenize actions.
This example shows how to:
1. Load a dataset with action data
2. Apply the action tokenizer processor to tokenize actions with proper padding/truncation
3. Access both the tokenized actions and the attention mask
4. Decode tokenized actions back to their original form
"""
import torch
from transformers import AutoProcessor
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.processor.tokenizer_processor import ActionTokenizerProcessorStep
from lerobot.utils.constants import ACTION_TOKEN_MASK
# Define delta timestamps for the dataset
delta_timestamps = {
'action': [
0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333,
0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3,
0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335,
0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6,
0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333,
0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9,
0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334,
1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2,
1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333,
1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5,
1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333
]
}
# Load the dataset
print("Loading dataset...")
dataset = LeRobotDataset(
repo_id="local",
root="/fsx/jade_choghari/outputs/pgen_annotations1",
delta_timestamps=delta_timestamps
)
# Create a dataloader
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
batch_size=4,
shuffle=True,
)
# Get a batch of data
batch = next(iter(dataloader))
action_data = batch["action"] # Shape: (batch_size, action_horizon, action_dim)
print(f"\nOriginal action shape: {action_data.shape}")
print(f"Original action data (first sample, first timestep):\n{action_data[0, 0]}")
# Method 1: Using the tokenizer directly (as in fast_tokenize.py)
print("\n" + "="*80)
print("Method 1: Direct tokenizer usage")
print("="*80)
tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
# Tokenize directly
tokens = tokenizer(action_data)
print(f"\nDirect tokenization result type: {type(tokens)}")
print(f"Tokens shape/length: {tokens.shape if isinstance(tokens, torch.Tensor) else len(tokens)}")
# Decode
decoded_actions = tokenizer.decode(tokens)
print(f"Decoded actions shape: {decoded_actions.shape}")
reconstruction_error = torch.abs(action_data - decoded_actions).mean()
print(f"Mean absolute reconstruction error: {reconstruction_error.item():.6f}")
# Method 2: Using the ActionTokenizerProcessorStep with proper padding/truncation
print("\n" + "="*80)
print("Method 2: Using ActionTokenizerProcessorStep (with padding & mask)")
print("="*80)
# Create the action tokenizer processor step
action_tokenizer_processor = ActionTokenizerProcessorStep(
tokenizer_name="physical-intelligence/fast",
trust_remote_code=True,
max_action_tokens=32, # Maximum number of tokens per action
)
# Create a transition with the action data
transition = {
TransitionKey.ACTION: action_data,
TransitionKey.OBSERVATION: {}, # Empty for this example
}
# Apply the processor
processed_transition = action_tokenizer_processor(transition)
# Extract tokenized actions and mask
tokenized_actions = processed_transition[TransitionKey.ACTION]
complementary_data = processed_transition[TransitionKey.COMPLEMENTARY_DATA]
action_mask = complementary_data[ACTION_TOKEN_MASK]
print(f"\nTokenized actions shape: {tokenized_actions.shape}") # (batch_size, max_action_tokens)
print(f"Action mask shape: {action_mask.shape}") # (batch_size, max_action_tokens)
print(f"Tokenized actions dtype: {tokenized_actions.dtype}")
print(f"Action mask dtype: {action_mask.dtype}")
# Show token statistics
print(f"\nFirst sample tokens: {tokenized_actions[0]}")
print(f"First sample mask: {action_mask[0]}")
num_real_tokens = action_mask[0].sum().item()
print(f"Number of real tokens (non-padding): {num_real_tokens}")
print(f"Number of padding tokens: {action_mask.shape[1] - num_real_tokens}")
# Decode using the mask
print("\nDecoding tokenized actions...")
decoded_with_processor = tokenizer.decode(tokenized_actions)
print(f"Decoded actions shape: {decoded_with_processor.shape}")
# Calculate reconstruction error
reconstruction_error_processor = torch.abs(action_data - decoded_with_processor).mean()
print(f"Mean absolute reconstruction error: {reconstruction_error_processor.item():.6f}")
# Show that masking works correctly
print("\n" + "="*80)
print("Mask demonstration")
print("="*80)
for i in range(min(4, tokenized_actions.shape[0])):
mask_i = action_mask[i]
num_real = mask_i.sum().item()
print(f"Sample {i}: {num_real} real tokens, {len(mask_i) - num_real} padding tokens")
print("\n" + "="*80)
print("Action tokenization example completed successfully!")
print("="*80)
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
-143
View File
@@ -1,143 +0,0 @@
# Example: Synthetic Data Generation with Sampling
## Quick Start
### 1. Test with 100 frames and 1 second sampling
```bash
python examples/dataset/annotate_pgen.py \
--data-dir /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace \
--model Qwen/Qwen2-VL-7B-Instruct \
--num-samples 100 \
--sample-interval 1.0 \
--output-dir ./outputs/test_pgen
```
**Expected behavior** (assuming 30 fps):
- Total frames: 100
- Frames sampled: ~4 (every 30 frames = 1 second)
- Efficiency: 96% fewer VLM calls
- Output: All 100 frames get `task_index_high_level`, but only 4 unique dialogues generated
### 2. Process full dataset with different sampling rates
#### Conservative (every 2 seconds)
```bash
python examples/dataset/annotate_pgen.py \
--data-dir /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace \
--model Qwen/Qwen2-VL-7B-Instruct \
--sample-interval 2.0 \
--output-dir ./outputs/pgen_2s
```
#### Standard (every 1 second) - **RECOMMENDED**
```bash
python examples/dataset/annotate_pgen.py \
--data-dir /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace \
--model Qwen/Qwen2-VL-7B-Instruct \
--sample-interval 1.0 \
--output-dir ./outputs/pgen_1s
```
#### Fine-grained (every 0.5 seconds)
```bash
python examples/dataset/annotate_pgen.py \
--data-dir /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace \
--model Qwen/Qwen2-VL-7B-Instruct \
--sample-interval 0.5 \
--output-dir ./outputs/pgen_0.5s
```
## Performance Estimates
For a dataset with:
- 100 episodes
- 10 seconds per episode (average)
- 30 fps
- Total frames: 30,000
| Sampling Interval | Frames Sampled | % Sampled | Speedup | Time Estimate |
|-------------------|----------------|-----------|---------|---------------|
| Every frame (0.033s) | 30,000 | 100% | 1x | ~10 hours |
| 0.5 seconds | 2,000 | 6.7% | 15x | ~40 min |
| **1.0 seconds** | **1,000** | **3.3%** | **30x** | **~20 min** |
| 2.0 seconds | 500 | 1.7% | 60x | ~10 min |
*Note: Times are approximate and depend on GPU, model size, and generation speed*
## Understanding the Output
### Console Output Example
```
[cyan]Generating synthetic data for 30000 frames...[/cyan]
[cyan]Sampling interval: 1.0s (fps: 30)[/cyan]
Generating synthetic dialogue: 100%|████████| 30000/30000 [20:15<00:00, 24.68it/s]
[green]✓ Sampled 1000 frames out of 30000 (3.3%)[/green]
[green]✓ Generated 450 unique high-level tasks[/green]
```
### What happens:
1. **Frame 0 (t=0.0s)**: Generate dialogue → Task index 0
2. **Frames 1-29 (t=0.033s-0.967s)**: Reuse task index 0
3. **Frame 30 (t=1.0s)**: Generate new dialogue → Task index 1
4. **Frames 31-59 (t=1.033s-1.967s)**: Reuse task index 1
5. And so on...
### Result:
- Every frame has a `task_index_high_level`
- Only sampled frames have unique dialogues generated
- Intermediate frames inherit from the most recent sample
- Maintains temporal coherence within episodes
## Checking Your Results
After running, verify the output:
```bash
# Check the generated tasks
python -c "
import pandas as pd
from pathlib import Path
tasks = pd.read_parquet('outputs/test_pgen/meta/tasks_high_level.parquet')
print(f'Total unique tasks: {len(tasks)}')
print(f'Sample tasks:')
print(tasks[['user_prompt', 'robot_utterance', 'skill']].head())
"
# Check debug output
head outputs/test_pgen/meta/syn_annotations.jsonl
# Load and verify dataset
python -c "
from lerobot.datasets.lerobot_dataset import LeRobotDataset
ds = LeRobotDataset(repo_id='local_with_high_level_tasks',
root='outputs/test_pgen')
print(f'Dataset has {len(ds)} frames')
print(f'Features: {list(ds.features.keys())}')
assert 'task_index_high_level' in ds.features
print('✓ task_index_high_level feature added successfully!')
"
```
## Common Use Cases
### Development/Testing
```bash
--sample-interval 2.0 # Fast iteration
--num-samples 500 # Small subset
```
### Production Training
```bash
--sample-interval 1.0 # Good coverage
# Process all samples (no --num-samples)
```
### High-Quality Dataset
```bash
--sample-interval 0.5 # Fine-grained
--temperature 0.6 # More consistent
--model Qwen/Qwen3-VL-30B-A3B-Instruct # Larger model
```
-25
View File
@@ -1,25 +0,0 @@
import numpy as np
from transformers import AutoProcessor
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
delta_timestamps = {'action': [0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333, 0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3, 0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335, 0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6, 0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333, 0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9, 0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334, 1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2, 1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333, 1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5, 1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333]}
dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1", delta_timestamps=delta_timestamps)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
batch_size=4,
shuffle=True,
)
batch = next(iter(dataloader))
# Load the tokenizer from the Hugging Face hub
tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
# Tokenize & decode action chunks (we use dummy data here)
action_data = batch["action"] # one batch of action chunks
tokens = tokenizer(action_data) # tokens = list[int]
decoded_actions = tokenizer.decode(tokens)
print("tokenized actions: ", tokens)
-17
View File
@@ -1,17 +0,0 @@
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
model_id = "google/paligemma-3b-pt-224"
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)
breakpoint()
prefix_output = model.language_model.forward(
inputs_embeds=inputs_embeds[0],
attention_mask=attention_mask,
position_ids=position_ids,
adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
)
prefix_past_key_values = prefix_output.past_key_values
# prefix_output to be used for the language head
# shape: [batch_size, seq_len, hidden_size] with hidden_size = 2048
prefix_output = prefix_output.last_hidden_state
-91
View File
@@ -1,91 +0,0 @@
import torch
from huggingface_hub import HfApi
import lerobot
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
# import make_pre_post_processors
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.factory import make_policy, make_policy_config
from lerobot.configs.policies import PreTrainedConfig
cfg = PreTrainedConfig.from_pretrained(
pretrained_name_or_path="/fsx/jade_choghari/outputs/pi0_training/checkpoints/last/pretrained_model",
)
cfg.dtype = "bfloat16"
pre_processor, post_processor = make_pre_post_processors(
policy_cfg=cfg,
pretrained_path="/fsx/jade_choghari/outputs/pi0_training/checkpoints/last/pretrained_model",
)
delta_timestamps = {'action': [0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333, 0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3, 0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335, 0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6, 0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333, 0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9, 0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334, 1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2, 1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333, 1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5, 1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333]}
dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1", delta_timestamps=delta_timestamps)
# rename map --rename_map='{
# "observation.images.side": "observation.images.base_0_rgb",
# "observation.images.up": "observation.images.left_wrist_0_rgb"
# }'
rename_map = {
"observation.images.side": "observation.images.base_0_rgb",
"observation.images.up": "observation.images.left_wrist_0_rgb"
}
policy = make_policy(
cfg=cfg,
ds_meta=dataset.meta,
rename_map=rename_map,
)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
batch_size=4,
shuffle=True,
)
batch = next(iter(dataloader))
batch = pre_processor(batch)
policy.train()
# run inference
# action = policy.select_action(batch)
loss, loss_dict = policy.forward(batch)
breakpoint()
# import requests
# from PIL import Image
# from transformers import AutoProcessor
# model = policy.model.paligemma_with_expert.paligemma
# model = model.to(device="cuda", dtype=torch.bfloat16)
# model.eval()
# prompt = "Describe this image."
# url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
# image = Image.open(requests.get(url, stream=True).raw)
# processor = AutoProcessor.from_pretrained(
# "google/paligemma-3b-pt-224",
# )
# inputs = processor(image, prompt, return_tensors="pt").to(model.device)
# print("generating...")
# output = model.generate(
# **inputs,
# max_new_tokens=50,
# use_cache=True, # default dynamic cache
# )
# print(processor.decode(output[0], skip_special_tokens=True))
# # other model
# from transformers import PaliGemmaForConditionalGeneration
# model = PaliGemmaForConditionalGeneration.from_pretrained(
# "google/paligemma2-3b-pt-224",
# torch_dtype=torch.bfloat16,
# device_map="auto",
# )
# model.eval()
# print("generating...")
# output = model.generate(
# **inputs,
# max_new_tokens=100,
# use_cache=True, # default dynamic cache
# )
# print("Model 2 output:")
# print(processor.decode(output[0], skip_special_tokens=True))
-23
View File
@@ -1,23 +0,0 @@
import torch
from huggingface_hub import HfApi
import lerobot
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1")
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
batch_size=32,
shuffle=True,
)
batch = next(iter(dataloader))
print(batch.keys())
print(batch['task_index_high_level'].shape)
print(batch['task_index_high_level'])
print(batch['user_prompt'][0])
print(batch['robot_utterance'][0])
print(batch['task'][0])
breakpoint()
-18
View File
@@ -1,18 +0,0 @@
import torch
from huggingface_hub import HfApi
import lerobot
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
dataset = LeRobotDataset(repo_id="lerobot/libero")
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
batch_size=4,
shuffle=True,
)
batch = next(iter(dataloader))
print(batch.keys())
breakpoint()
-159
View File
@@ -1,159 +0,0 @@
## One-sentence answer
> `make_att_2d_masks(prefix_pad_masks, prefix_att_masks)` builds the **actual 2D attention mask** `[B, L, L]` that tells the transformer **which token positions may attend to which others**, combining **padding** and **causality**.
Everything else youve seen so far was just metadata.
---
## What goes in
### Inputs
```python
prefix_pad_masks # shape [B, L]
prefix_att_masks # shape [B, L]
```
Where:
* `prefix_pad_masks[b, i] = True`
→ token `i` exists (not padding)
* `prefix_att_masks[b, i] = False`
→ token `i` is **bidirectional**
* `prefix_att_masks[b, i] = True`
→ token `i` is **causal (autoregressive)**
---
## What comes out
```python
att_2d_prefix # shape [B, L, L]
```
Each entry:
```text
att_2d_prefix[b, i, j] = True
```
means:
> “In batch `b`, **token i (query)** is allowed to attend to **token j (key)**.”
---
## How it is constructed (conceptually)
For **each batch b**, **each query position i**, **each key position j**:
```python
if not prefix_pad_masks[b, j]:
att[b, i, j] = False # cannot attend to padding
else if not prefix_att_masks[b, i]:
att[b, i, j] = True # bidirectional token → can see all real tokens
else:
att[b, i, j] = (j <= i) # causal token → can see only past + itself
```
Thats it.
---
## Tiny concrete example (exactly matching your code)
Suppose:
```python
prefix_pad_masks[0] = [T, T, T, T, T, F]
prefix_att_masks[0] = [F, F, F, T, T, T]
```
Tokens:
```
0: IMG
1: IMG
2: LANG
3: SUB0
4: SUB1
5: PAD
```
---
### Resulting `att_2d_prefix[0]`
`✓ = True, ✗ = False`
| Q \ K | 0 | 1 | 2 | 3 | 4 | 5 |
| ---------- | - | - | - | - | - | - |
| 0 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 1 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 2 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 3 (causal) | ✓ | ✓ | ✓ | ✓ | ✗ | ✗ |
| 4 (causal) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 5 (pad) | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ |
---
## Why this matters for your training code
This line:
```python
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix)
```
Converts `[B, L, L] → [B, 1, L, L]` and possibly flips True/False to `0/-inf`.
This is **exactly what Paligemma uses inside self-attention**.
---
## Key implications (VERY important)
### 1️⃣ This mask does **not isolate token groups**
* Bidirectional tokens can attend to **everything**
* Causal tokens only restrict *their own row*
So **flow/action tokens must be blocked separately**.
---
### 2️⃣ This is why your AR subtask prediction works
* Subtask tokens are causal
* Output at position `i` predicts token `i+1`
* Padding is fully ignored
---
### 3️⃣ Inference behavior
When `subtask_tokens = None`:
* `prefix_att_masks` contains only `False`
* `att_2d_prefix` becomes **fully bidirectional**
* No AR behavior remains
Exactly what you want.
---
## One-sentence takeaway (commit this)
> `make_att_2d_masks` fuses **padding** and **causality** into a concrete `[B, L, L]` attention matrix that the transformer actually uses.
If you want next, I can:
* inspect `make_att_2d_masks()` source with you
* show how to block **flow → subtask** attention
* explain how this changes when suffix tokens are added
* help you refactor this into a cleaner “grouped attention” API
Youre now at the point where the models behavior should feel *predictable*, not magical.
-334
View File
@@ -1,334 +0,0 @@
Generate annotate_pgen.py using Qwen for synthetic data generation
You are writing a Python script called annotate_pgen.py.
This script generates synthetic user prompts (_t) and robot utterances (u_t) for Hi Robotstyle hierarchical policy training, using Qwen 3vl as the generator model (pgen).
SCRIPT PURPOSE
The script must:
Load Dlabeled which is a LeRobot Dataset that has been annotate using the annotate.py script, which contains:
images: list of image paths at time t
skill_current: the annotated skill label (ℓ̂_t)
skill_history: list of previous skill labels (ℓ̂₀ … ℓ̂_{t1}), those where annotated, and you can find details on them stored in teh dataset inside the the DATA_PATH/meta/skills.json
you will find something like
{
"coarse_description": "pink lego brick into the transparent box",
"skill_to_task_index": {
"robot arm picks up pink lego brick": 19,
"robot arm approaches transparent box": 3,
"robot arm retracts from transparent box": 28,
"robot arm moves towards pink lego brick": 12,
"robot arm releases red lego brick into box": 26,
"robot arm releases red lego brick into transparent box": 27,
"robot arm closes gripper to pick up the pink lego brick": 5,
"robot arm lifts the pink lego brick": 7,
etc..
},
"episodes": {
"0": {
"episode_index": 0,
"description": "pink lego brick into the transparent box",
"skills": [
{
"name": "robot arm moves towards pink lego brick",
"start": 0.0,
"end": 1.8
},
{
"name": "robot arm picks up pink lego brick",
"start": 1.8,
"end": 3.1
},
{
"name": "robot arm moves towards transparent box",
"start": 3.1,
"end": 5.5
},
{
"name": "robot arm releases pink lego brick into transparent box",
"start": 5.5,
"end": 7.0
},
{
"name": "robot arm retracts from transparent box",
"start": 7.0,
"end": 10.1
}
]
},
"1": {
"episode_index": 1,
"description": "pink lego brick into the transparent box",
"skills": [
{
"name": "robot arm moves towards red lego brick",
"start": 0.0,
"end": 1.2
},
{
"name": "robot arm picks up red lego brick",
"start": 1.2,
"end": 2.0
},
{
"name": "robot arm moves towards transparent box",
"start": 2.0,
"end": 3.8
},
{
"name": "robot arm places red lego brick into transparent box",
"start": 3.8,
"end": 5.0
},
{
"name": "robot arm moves away from transparent box",
"start": 5.0,
"end": 8.9
}
]
},
notice how task_description: is a high-level description (e.g., "make a sandwich") stored in description for each episode
For each sample, call Qwen VLM to generate:
synthetic user prompt _t
synthetic robot response u_t
Save results to D_syn in Parquet format insdie DATA_PATH/meta/tasks.parquet ; note tasks.parquet already contains the other tasks, so you need to update
Should be modular, clean, easy to extend, with:
a PGEN_PROMPT_TEMPLATE
a construct_prompt() method
a call_qwen() method
a annotate_sample() method
a CLI entrypoint (if __name__ == "__main__":)
📦 INPUT FORMAT (Dlabeled)
The script should expect Dlabeled as a .jsonl file where each line has:
{
"episode_id": "ep_001",
"t": 37,
"images": ["path/to/cam0_t.jpg", "path/to/cam1_t.jpg"],
"skill_current": "pick up the KitKat",
"skill_history": ["open fridge", "pick up lettuce", "place lettuce"],
"task_description": "making a sandwich"
}
📤 OUTPUT FORMAT (D_syn)
Each line of synthetically generated data should be:
{
"episode_id": "ep_001",
"t": 37,
"images": ["path/to/cam0_t.jpg", "path/to/cam1_t.jpg"],
"skill_current": "pick up the KitKat",
"skill_history": [...],
"user_prompt": "Can you grab me something sweet?",
"robot_utterance": "Sure, I can pick up the KitKat.",
"task_description": "making a sandwich"
}
Store as syn_annotations.jsonl. for debugging
🧠 pgen MODEL (Qwen) REQUIREMENTS
Use HuggingFace Transformers:
Qwen/Qwen2-VL-7B-Instruct (or any Qwen2-VL Vision-Language model available)
Use the image + text chat interface
Vision inputs should be loaded with PIL
Use a single forward pass that outputs BOTH _t and u_t in a structured JSON
📝 PROMPT FORMAT FOR pgen
Create a template like:
You are a robot-assistant dialogue generator for hierarchical robot policies.
You will receive:
- A list of images showing the current robot scene.
- The high-level task: {task_description}
- Previous skill steps completed: {skill_history}
- The next skill to be performed by the robot: {skill_current}
Generate two things in JSON:
1. "user_prompt": a natural-sounding user request that logically leads to the robot performing the skill "{skill_current}" given the task and history.
2. "robot_utterance": a natural robot reply acknowledging or clarifying the request.
The responses must be grounded in the visual scene, the task, and the skill history.
Respond ONLY in JSON:
{
"user_prompt": "...",
"robot_utterance": "..."
}
This resposne will have a corresponsing task_index, and the task will be saved in task.parqeut and you must update each dataset parquet in for example /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace/data/chunk-000/
file-000.parquet to include this new feature called task_index_high_level consider udpatign the metadata in info.json as well
📌 LOGIC REQUIRED
construct_prompt(sample)
Loads sample dict
Inserts:
task_description
skill_history
skill_current
Returns a full text prompt string
call_qwen(images, prompt)
Loads images into Qwen-VL multimodal input format
Calls model.generate
Parses JSON output
annotate_sample(sample)
Builds prompt
Calls Qwen
Returns augmented sample with user_prompt + robot_utterance
🚀 CLI Usage
The script should run as:
python annotate_pgen.py \
--output-dir PATH \
--model Qwen/Qwen2-VL-7B-Instruct \
--repo-id lerobot/svla_so101_pickplace \
--model Qwen/Qwen3-VL-30B-A3B-Instruct \
--batch-size 1
Include arguments via argparse.
🔧 OTHER REQUIREMENTS
Use tqdm for progress bars
Log errors gracefully and continue
Support GPU acceleration (device="cuda")
Cache model loading so it's not reloaded every call
Make the prompt deterministic but allow temperature parameter
Add a flag --num-image-views-per-sample
Add automatic JSON parsing with helpful error messages
🎯 FINAL DELIVERABLE
Cursor must now generate:
A full Python file named annotate_pgen.py implementing the above functionality end-to-end.
It should be production-ready, runnable on real data, cleanly structured, and easy to modify.
from the paper:
Next, we use a large vision-language model (VLM) pgen
to produce synthetic user prompts and interjections t,
and corresponding robot utterance ut. Given Dlabeled, we
prompt pgen with both the visual context I1
t ,...,In
t and the
skill labelˆ
t (e.g., pick up the lettuce). pgen then imag-
ines an appropriate interaction that might have led toˆ
t in a
real user interaction: it generates possible user prompts t
(e.g., “Can you add some lettuce for me?”) along with the
robots verbal responses and clarifications ut. We detail the
A. Synthetic Data Generation
A.1. Scenario and Response Categorization
To ensure the quality and diversity of the synthetic data,
we incorporate structured scenario classification and re-
sponse categorization into the prompt design for pgen, fol-
lowing (Stephan et al., 2024). Specifically, we classify
interactions into different scenario types, such as nega-
tive task (where the user instructs the robot what not to
do), situated correction (where the user adjusts an earlier
command based on the evolving task state), and specific
constraint (where the user specifies particular constraints,
such as dietary preferences). In addition, we categorize
the robots responses into types such as simple confirma-
tions, clarifications, and error handling. These classifica-
tions guide the generation process to ensure a broad range
of user-robot interactions.
A.2. Prompt Construction for Contextual Grounding
In prompt P, we include a detailed description of the task
(e.g., bussing a table, making a sandwich, grocery shop-
ping) and instruct the model to ground responses in visual
observations and prior context. A key advantage of lever-
aging large pretrained VLMs is their ability to incorporate
world knowledge when generating interactions. For in-
stance, the model can infer dietary constraints when gener-
ating prompts for sandwich-making, producing user com-
mands such as “Can you make a sandwich for me? Im
lactose intolerant” and an appropriate robot response like
“Sure, I wont put cheese on it.” Similarly, it can reason
over ambiguous or implicit requests, such as inferring that
“I want something sweet” in a grocery shopping scenario
should lead to suggestions like chocolate or candy.
To maintain consistency in multi-step tasks, we condition
pgen on prior skill labels within an episodeˆ
ˆ
0,...,
t1,
allowing it to generate coherent user commands that
account for past actions. For instance, if the robot
has already placed lettuce and tomato on a sandwich,
the generated user prompt might request additional in-
gredients that logically follow. This ensures that the
synthetic interactions reflect realistic task progression
rather than isolated commands. As such, we leverage
ˆ
ˆ
ˆ
pgen(t,ut|I1
t ,...,In
t ,
0,...,
t1,
t,P) to produce a richer,
more diverse synthetic dataset Dsyn that provides mean-
ingful supervision for training our high-level policy.
While in this work we generate a separate Dsyn and train
a separate high-level policy for each task (e.g., sandwich
making vs. table cleaning) for clarity and ease of bench-
marking, the architecture is readily amenable to a unified
multi-task formulation. In principle, the same hierarchical
approach could be used to train a single high-level policy
across a multitude of tasks, facilitating knowledge transfer
The result should be a new LeRobotDataset with a new feature called task_index_high_level inside each dataset parquet
-11
View File
@@ -1,11 +0,0 @@
python examples/dataset/annotate.py \
--repo-id jadechoghari/collect-data \
--video-key observation.images.base \
--model Qwen/Qwen3-VL-30B-A3B-Instruct \
--episodes 16 22
# python examples/dataset/annotate.py \
# --repo-id lerobot/svla_so101_pickplace \
# --video-key observation.images.side \
# --model Qwen/Qwen3-VL-30B-A3B-Instruct \
# --episodes 5
-43
View File
@@ -1,43 +0,0 @@
#!/bin/bash
# Example script to run synthetic data generation with Qwen VLM
# This generates user prompts and robot utterances for hierarchical policy training
# Configuration
REPO_ID="jadechoghari/collect-data"
MODEL="Qwen/Qwen3-VL-30B-A3B-Instruct"
# Alternative: MODEL="Qwen/Qwen2-VL-7B-Instruct"
OUTPUT_DIR="/fsx/jade_choghari/outputs/collect-data-pgen"
BATCH_SIZE=32
TEMPERATURE=0.9
SAMPLE_INTERVAL=5.0 # Generate dialogue every 1 second (all episodes processed)
# Run synthetic data generation (processes ALL episodes)
python examples/dataset/annotate_pgen.py \
--repo-id "$REPO_ID" \
--model "$MODEL" \
--output-dir "$OUTPUT_DIR" \
--temperature "$TEMPERATURE" \
--batch-size "$BATCH_SIZE" \
--sample-interval "$SAMPLE_INTERVAL" \
--image-key observation.images.base \
--num-image-views-per-sample 1
# For faster testing, increase sample interval:
# --sample-interval 5.0 # Samples every 5 seconds (much faster)
# To push to hub after generation:
# Add --push-to-hub flag
# Efficient batch processing: 4 episodes at once
# python examples/dataset/annotate_pgen.py \
# --repo-id "$REPO_ID" \
# --model "$MODEL" \
# --output-dir "$OUTPUT_DIR" \
# --video-mode \
# --video-key observation.images.up \
# --video-batch-size "$BATCH_SIZE" \
# --sample-interval 1.0
-802
View File
@@ -1,802 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
SARM Subtask Annotation using local GPU (Qwen3-VL).
This script implements the annotation approach from the SARM paper using local GPU inference:
"SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation"
Paper: https://arxiv.org/pdf/2509.25358
What it does:
1. Takes videos from a LeRobot dataset
2. Uses Qwen3-VL running locally on GPU to identify when subtasks occur
3. Saves subtask timestamps to the dataset metadata
4. Optionally pushes the annotated dataset to HuggingFace Hub
SARM trains reward models that predict:
- Stage: Which subtask is currently being executed (discrete classification)
- Progress: How far along the subtask we are (continuous 0-1)
Supports three annotation modes:
1. No annotations (no args): Auto-creates single sparse "task" stage covering full episode.
Use with SARM config annotation_mode="single_stage" for simple tasks.
2. Dense-only (--dense-only --dense-subtasks): Dense annotations from VLM, auto-generated
single sparse "task" stage. Use with annotation_mode="dense_only".
3. Dual mode (--sparse-subtasks + --dense-subtasks): Both sparse and dense annotations
from VLM. Use with annotation_mode="dual".
Requirements:
- GPU with sufficient VRAM (16GB+ recommended for 30B model)
- `pip install transformers, torch, qwen-vl-utils`
Run with:
```bash
python examples/dataset_annotation/subtask_annotation.py \
--repo-id your-username/your-dataset \
--sparse-subtasks "Do ..." \
--dense-subtasks "Do task 1, Do task 2, Do task 3" \
--video-key observation.images.base \
--push-to-hub
```
"""
import argparse
import json
import multiprocessing as mp
import re
import subprocess
import tempfile
import textwrap
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
import cv2
import pandas as pd
import torch
from qwen_vl_utils import process_vision_info
from rich.console import Console
from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.sarm.sarm_utils import (
Subtask,
SubtaskAnnotation,
Timestamp,
compute_temporal_proportions,
)
def create_sarm_prompt(subtask_list: list[str]) -> str:
subtask_str = "\n".join([f" - {name}" for name in subtask_list])
return textwrap.dedent(f"""\
# Role
You are a Robotics Vision System specializing in temporal action localization for robot manipulation. Your job is to segment a single demonstration video into distinct, non-overlapping atomic actions from a fixed subtask list.
# Subtask Label Set (Closed Vocabulary)
You must strictly identify the video segments using ONLY the following labels. Do not create new labels or modify existing ones:
[
{subtask_str}
]
The video shows one successful execution of all subtasks in a logical order.
# Ground-Truth Semantics (Very Important)
Use **visual state changes** to define when a subtask starts and ends. Do NOT assume equal durations for the subtasks.
- A subtask **starts** at the first frame where the robot's motion clearly initiates that subtask.
- A subtask **ends** at the first frame where that specific action is visually completed and the manipulated object reaches a temporary, stable configuration.
If there are short pauses or micro-motions that don't clearly correspond to a new subtask, they belong to the **current** subtask.
# Hard Constraints & Logic
1. **Continuous Coverage (No Gaps):**
- The entire video duration from "00:00" to the final timestamp must be covered by subtasks.
- There can be no gaps between subtasks.
- If there is any idle or ambiguous time between clear actions, extend the *preceding* subtask to cover it.
2. **Boundary Consistency:**
- The `"end"` timestamp of one subtask must be exactly equal to the `"start"` timestamp of the next subtask.
- Boundaries must coincide with a real visual state transition, not just a convenient time split.
3. **Chronological Order, One Occurrence Each:**
- This is a single successful demonstration.
- Each subtask from the vocabulary appears **exactly once**, in the correct logical order.
- **Durations may be very different** between subtasks. Never assume they are similar lengths. Base all boundaries only on the video.
4. **Reject Uniform Segmentation (Important):**
- Do NOT simply divide the video into equal or nearly equal time chunks.
- If your boundaries would result in subtasks with similar durations (e.g. all around 5 seconds), treat this as evidence that your segmentation is wrong and refine the boundaries.
- Only use nearly equal durations if the video truly shows each subtask taking the same amount of time (this is very rare).
5. **Timestamps:**
- Timestamps must be in `"MM:SS"` format.
- The first subtask always starts at `"00:00"`.
- The last subtask ends at the final visible frame of the video.
# Step 1 — Textual Timeline (must do this first)
First, write a extensive and detailed textual timeline describing what happens in the video with approximate timestamps.
For each subtask, include:
- its name
- an approximate start and end time,
- an description of the visual event at the boundary (e.g. "shirt fully folded to the left", "robot rotates folded shirt 90 degrees").
Format this as a bullet list.
# Step 2 — JSON Output (final answer)
After the textual timeline, output **only** valid JSON with this structure.
The JSON **must** be consistent with the textual timeline above:
{{
"subtasks": [
{{
"name": "EXACT_NAME_FROM_LIST",
"timestamps": {{
"start": "MM:SS",
"end": "MM:SS"
}}
}},
{{
"name": "EXACT_NAME_FROM_LIST",
"timestamps": {{
"start": "MM:SS",
"end": "MM:SS"
}}
}}
]
}}
Do not add any extra keys to the JSON.
""")
class VideoAnnotator:
"""Annotates robot manipulation videos using local Qwen3-VL model on GPU"""
def __init__(
self,
subtask_list: list[str],
model_name: str = "Qwen/Qwen3-VL-30B-A3B-Instruct",
device: str = "cuda",
torch_dtype: torch.dtype = torch.bfloat16,
model: "Qwen3VLMoeForConditionalGeneration | None" = None,
processor: "AutoProcessor | None" = None,
):
"""
Initialize the video annotator with local model.
Args:
subtask_list: List of allowed subtask names (for consistency)
model_name: Hugging Face model name (default: Qwen/Qwen3-VL-30B-A3B-Instruct)
device: Device to use (cuda, cpu)
torch_dtype: Data type for model (bfloat16, float16, float32)
model: Pre-loaded model instance (optional, to share between annotators)
processor: Pre-loaded processor instance (optional, to share between annotators)
"""
self.subtask_list = subtask_list
self.prompt = create_sarm_prompt(subtask_list)
self.console = Console()
self.device = device
# Use provided model/processor or load new ones
if model is not None and processor is not None:
self.model = model
self.processor = processor
self.console.print(f"[green]✓ Using shared model on {device}[/green]")
else:
self.console.print(f"[cyan]Loading model: {model_name}...[/cyan]")
self.model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
)
self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")
def extract_episode_segment(
self, file_path: Path, start_timestamp: float, end_timestamp: float, target_fps: int = 1
) -> Path:
"""
Extract a specific episode segment from concatenated video.
Uses minimal compression to preserve quality for local inference.
Args:
file_path: Path to the concatenated video file
start_timestamp: Starting timestamp in seconds (within this video file)
end_timestamp: Ending timestamp in seconds (within this video file)
target_fps: Target FPS (default: 1 for faster processing)
Returns:
Path to extracted video file
"""
# Create temporary file for extracted video
tmp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
tmp_path = Path(tmp_file.name)
tmp_file.close()
try:
# Check if ffmpeg is available
subprocess.run(
["ffmpeg", "-version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True
)
except (subprocess.CalledProcessError, FileNotFoundError):
raise RuntimeError("ffmpeg not found, cannot extract episode segment") from e
try:
# Calculate duration
duration = end_timestamp - start_timestamp
self.console.print(
f"[cyan]Extracting episode: {start_timestamp:.1f}s-{end_timestamp:.1f}s ({duration:.1f}s)[/cyan]"
)
# Use ffmpeg to extract segment with minimal quality loss
cmd = [
"ffmpeg",
"-i",
str(file_path),
"-ss",
str(start_timestamp),
"-t",
str(duration),
"-r",
str(target_fps),
"-c:v",
"libx264",
"-preset",
"ultrafast",
"-crf",
"23",
"-an",
"-y",
str(tmp_path),
]
subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
# Verify the output file was created and is not empty
if not tmp_path.exists() or tmp_path.stat().st_size == 0:
self.console.print("[red]✗ Video extraction failed (0 bytes) - skipping episode[/red]")
if tmp_path.exists():
tmp_path.unlink()
raise RuntimeError("FFmpeg produced empty video file")
# Show extraction results
file_size_mb = tmp_path.stat().st_size / (1024 * 1024)
# Fail if file is too small (< 100KB likely means extraction failed)
if file_size_mb < 0.1:
self.console.print(
f"[red]✗ Extracted video too small ({file_size_mb:.2f}MB) - skipping episode[/red]"
)
tmp_path.unlink()
raise RuntimeError(f"Video extraction produced invalid file ({file_size_mb:.2f}MB)")
self.console.print(f"[green]✓ Extracted: {file_size_mb:.1f}MB ({target_fps} FPS)[/green]")
return tmp_path
except subprocess.CalledProcessError as e:
raise RuntimeError(f"ffmpeg failed ({e})") from e
def annotate(
self,
file_path: str | Path,
fps: int,
start_timestamp: float = 0.0,
end_timestamp: float | None = None,
max_retries: int = 3,
) -> SubtaskAnnotation:
"""Annotate a video segment using local GPU."""
file_path = Path(file_path)
if end_timestamp is None:
cap = cv2.VideoCapture(str(file_path))
end_timestamp = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) / (cap.get(cv2.CAP_PROP_FPS) or 1)
cap.release()
duration = end_timestamp - start_timestamp
duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
extracted_path = self.extract_episode_segment(file_path, start_timestamp, end_timestamp, 1)
is_extracted = extracted_path != file_path
try:
messages = [
{"role": "system", "content": [{"type": "text", "text": self.prompt}]},
{
"role": "user",
"content": [
{"type": "video", "video": str(extracted_path), "fps": 1.0},
{
"type": "text",
"text": f"Video is {duration_str} (~{duration:.1f}s). Follow instructions.",
},
],
},
]
for attempt in range(max_retries):
try:
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
).to(self.device)
with torch.no_grad():
generated_ids = self.model.generate(
**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
)
response = self.processor.batch_decode(
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids)],
skip_special_tokens=True,
)[0].strip()
# Extract JSON
if "```json" in response:
response = response.split("```json")[1].split("```")[0]
elif "```" in response:
response = response.split("```")[1].split("```")[0]
try:
return SubtaskAnnotation.model_validate(json.loads(response))
except json.JSONDecodeError:
match = re.search(r"\{.*\}", response, re.DOTALL)
if match:
return SubtaskAnnotation.model_validate(json.loads(match.group()))
raise ValueError("No JSON found")
except Exception as e:
if attempt == max_retries - 1:
raise RuntimeError(f"Failed after {max_retries} attempts") from e
time.sleep(1)
finally:
if is_extracted and extracted_path.exists():
extracted_path.unlink()
def display_annotation(
annotation: SubtaskAnnotation, console: Console, episode_idx: int, fps: int, prefix: str = ""
):
"""Display annotation summary."""
subtask_summary = ", ".join(
f"{s.name}({s.timestamps.start}-{s.timestamps.end})" for s in annotation.subtasks
)
console.print(
f"[green]Episode {episode_idx} {prefix}: {len(annotation.subtasks)} subtasks - {subtask_summary}[/green]"
)
def timestamp_to_seconds(timestamp: str) -> float:
"""Convert MM:SS or SS timestamp to seconds"""
parts = timestamp.split(":")
if len(parts) == 2:
return int(parts[0]) * 60 + int(parts[1])
else:
return int(parts[0])
def save_annotations_to_dataset(
dataset_path: Path, annotations: dict[int, SubtaskAnnotation], fps: int, prefix: str = "sparse"
):
"""Save annotations to LeRobot dataset parquet format."""
from lerobot.datasets.utils import DEFAULT_EPISODES_PATH, load_episodes
episodes_dataset = load_episodes(dataset_path)
if not episodes_dataset or len(episodes_dataset) == 0:
return
episodes_df = episodes_dataset.to_pandas()
cols = [
f"{prefix}_{c}"
for c in [
"subtask_names",
"subtask_start_times",
"subtask_end_times",
"subtask_start_frames",
"subtask_end_frames",
]
]
for col in cols:
episodes_df[col] = None
for ep_idx, ann in annotations.items():
if ep_idx >= len(episodes_df):
continue
names, starts, ends, start_frames, end_frames = [], [], [], [], []
for s in ann.subtasks:
names.append(s.name)
st, et = timestamp_to_seconds(s.timestamps.start), timestamp_to_seconds(s.timestamps.end)
starts.append(st)
ends.append(et)
start_frames.append(int(st * fps))
end_frames.append(int(et * fps))
episodes_df.at[ep_idx, cols[0]] = names
episodes_df.at[ep_idx, cols[1]] = starts
episodes_df.at[ep_idx, cols[2]] = ends
episodes_df.at[ep_idx, cols[3]] = start_frames
episodes_df.at[ep_idx, cols[4]] = end_frames
# Group by file and write
for ep_idx in episodes_df.index:
key = (
episodes_df.loc[ep_idx, "meta/episodes/chunk_index"],
episodes_df.loc[ep_idx, "meta/episodes/file_index"],
)
path = dataset_path / DEFAULT_EPISODES_PATH.format(chunk_index=key[0], file_index=key[1])
if path.exists():
file_df = pd.read_parquet(path)
for col in cols + (
[
"subtask_names",
"subtask_start_times",
"subtask_end_times",
"subtask_start_frames",
"subtask_end_frames",
]
if prefix == "sparse"
else []
):
if col not in file_df.columns:
file_df[col] = None
if ep_idx in annotations:
for col in cols:
file_df.at[ep_idx, col] = episodes_df.loc[ep_idx, col]
if prefix == "sparse": # Legacy columns
for i, legacy in enumerate(
[
"subtask_names",
"subtask_start_times",
"subtask_end_times",
"subtask_start_frames",
"subtask_end_frames",
]
):
file_df.at[ep_idx, legacy] = episodes_df.loc[ep_idx, cols[i]]
file_df.to_parquet(path, engine="pyarrow", compression="snappy")
def generate_auto_sparse_annotations(
dataset: LeRobotDataset, episode_indices: list[int], video_key: str
) -> dict[int, SubtaskAnnotation]:
"""Auto-generate single 'task' stage annotations for all episodes."""
annotations = {}
for ep_idx in episode_indices:
start = float(dataset.meta.episodes[f"videos/{video_key}/from_timestamp"][ep_idx])
end = float(dataset.meta.episodes[f"videos/{video_key}/to_timestamp"][ep_idx])
duration = end - start
end_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
annotations[ep_idx] = SubtaskAnnotation(
subtasks=[Subtask(name="task", timestamps=Timestamp(start="00:00", end=end_str))]
)
return annotations
def load_annotations_from_dataset(dataset_path: Path, prefix: str = "sparse") -> dict[int, SubtaskAnnotation]:
"""Load annotations from LeRobot dataset parquet files."""
from lerobot.datasets.utils import load_episodes
episodes_dataset = load_episodes(dataset_path)
if not episodes_dataset or len(episodes_dataset) == 0:
return {}
col_names = f"{prefix}_subtask_names"
col_start = f"{prefix}_subtask_start_times"
col_end = f"{prefix}_subtask_end_times"
# Fall back to legacy columns for sparse
if col_names not in episodes_dataset.column_names:
if prefix == "sparse" and "subtask_names" in episodes_dataset.column_names:
col_names, col_start, col_end = "subtask_names", "subtask_start_times", "subtask_end_times"
else:
return {}
df = episodes_dataset.to_pandas()
annotations = {}
for ep_idx in df.index:
names = df.loc[ep_idx, col_names]
if names is None or (isinstance(names, float) and pd.isna(names)):
continue
starts, ends = df.loc[ep_idx, col_start], df.loc[ep_idx, col_end]
annotations[int(ep_idx)] = SubtaskAnnotation(
subtasks=[
Subtask(
name=n,
timestamps=Timestamp(
start=f"{int(s) // 60:02d}:{int(s) % 60:02d}",
end=f"{int(e) // 60:02d}:{int(e) % 60:02d}",
),
)
for n, s, e in zip(names, starts, ends)
]
)
return annotations
def process_single_episode(
ep_idx: int,
dataset_root: Path,
dataset_meta,
video_key: str,
fps: int,
annotator: VideoAnnotator,
console: Console,
) -> tuple[int, SubtaskAnnotation | None, str | None]:
"""Process a single episode annotation."""
try:
video_path = dataset_root / dataset_meta.get_video_file_path(ep_idx, video_key)
if not video_path.exists():
return ep_idx, None, f"Video not found: {video_path}"
start = float(dataset_meta.episodes[f"videos/{video_key}/from_timestamp"][ep_idx])
end = float(dataset_meta.episodes[f"videos/{video_key}/to_timestamp"][ep_idx])
return ep_idx, annotator.annotate(video_path, fps, start, end), None
except Exception as e:
return ep_idx, None, str(e)
def worker_process_episodes(
worker_id: int,
gpu_id: int,
episode_indices: list[int],
repo_id: str,
video_key: str,
sparse_subtask_list: list[str],
dense_subtask_list: list[str] | None,
model_name: str,
torch_dtype: torch.dtype,
) -> tuple[dict, dict | None]:
"""Worker for parallel processing across GPUs."""
device = f"cuda:{gpu_id}"
console = Console()
dataset = LeRobotDataset(repo_id, download_videos=False)
sparse_annotator = VideoAnnotator(sparse_subtask_list, model_name, device, torch_dtype)
dense_annotator = (
VideoAnnotator(
dense_subtask_list,
model_name,
device,
torch_dtype,
sparse_annotator.model,
sparse_annotator.processor,
)
if dense_subtask_list
else None
)
sparse_annotations, dense_annotations = {}, {} if dense_subtask_list else None
for ep_idx in episode_indices:
_, sparse_ann, err = process_single_episode(
ep_idx, dataset.root, dataset.meta, video_key, dataset.fps, sparse_annotator, console
)
if sparse_ann:
sparse_annotations[ep_idx] = sparse_ann
if dense_annotator:
_, dense_ann, _ = process_single_episode(
ep_idx, dataset.root, dataset.meta, video_key, dataset.fps, dense_annotator, console
)
if dense_ann:
dense_annotations[ep_idx] = dense_ann
return sparse_annotations, dense_annotations
def main():
parser = argparse.ArgumentParser(description="SARM-style subtask annotation using local GPU (Qwen3-VL)")
parser.add_argument("--repo-id", type=str, required=True, help="HuggingFace dataset repository ID")
parser.add_argument(
"--sparse-subtasks", type=str, default=None, help="Comma-separated sparse subtask names"
)
parser.add_argument(
"--dense-subtasks", type=str, default=None, help="Comma-separated dense subtask names"
)
parser.add_argument(
"--dense-only", action="store_true", help="Dense-only mode with auto-generated sparse 'task' stage"
)
parser.add_argument("--episodes", type=int, nargs="+", default=None, help="Episode indices to annotate")
parser.add_argument("--model", type=str, default="Qwen/Qwen3-VL-30B-A3B-Instruct", help="VLM model")
parser.add_argument("--skip-existing", action="store_true", help="Skip already annotated episodes")
parser.add_argument("--video-key", type=str, default=None, help="Video key (default: first available)")
parser.add_argument("--push-to-hub", action="store_true", help="Push to HuggingFace Hub")
parser.add_argument("--output-repo-id", type=str, default=None, help="Output repo ID for push")
parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)")
parser.add_argument("--dtype", type=str, default="bfloat16", choices=["bfloat16", "float16", "float32"])
parser.add_argument("--num-workers", type=int, default=1, help="Parallel workers for multi-GPU")
parser.add_argument("--gpu-ids", type=int, nargs="+", default=None, help="GPU IDs to use")
args = parser.parse_args()
console = Console()
# Validate arguments
if args.dense_only and not args.dense_subtasks:
return console.print("[red]Error: --dense-only requires --dense-subtasks[/red]")
if args.dense_subtasks and not args.sparse_subtasks and not args.dense_only:
return console.print("[red]Error: --dense-subtasks requires --sparse-subtasks or --dense-only[/red]")
sparse_subtask_list = (
[s.strip() for s in args.sparse_subtasks.split(",")] if args.sparse_subtasks else None
)
dense_subtask_list = [s.strip() for s in args.dense_subtasks.split(",")] if args.dense_subtasks else None
auto_sparse = sparse_subtask_list is None
dense_mode = dense_subtask_list is not None
torch_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[args.dtype]
console.print(f"[cyan]Loading dataset: {args.repo_id}[/cyan]")
dataset = LeRobotDataset(args.repo_id, download_videos=True)
fps = dataset.fps
if not dataset.meta.video_keys:
raise ValueError("No video keys found")
video_key = (
args.video_key if args.video_key in (dataset.meta.video_keys or []) else dataset.meta.video_keys[0]
)
console.print(f"[cyan]Using camera: {video_key}, FPS: {fps}[/cyan]")
# Determine episodes
episode_indices = args.episodes or list(range(dataset.meta.total_episodes))
existing_annotations = load_annotations_from_dataset(dataset.root, prefix="sparse")
if args.skip_existing:
episode_indices = [ep for ep in episode_indices if ep not in existing_annotations]
if not episode_indices:
return console.print("[green]All episodes already annotated![/green]")
console.print(f"[cyan]Annotating {len(episode_indices)} episodes[/cyan]")
# GPU setup
gpu_ids = args.gpu_ids or list(
range(min(args.num_workers, torch.cuda.device_count() if torch.cuda.is_available() else 1))
)
args.num_workers = len(gpu_ids)
sparse_annotations = existing_annotations.copy()
dense_annotations = {} if dense_mode else None
# Auto-sparse mode
if auto_sparse:
sparse_annotations.update(generate_auto_sparse_annotations(dataset, episode_indices, video_key))
save_annotations_to_dataset(dataset.root, sparse_annotations, fps, prefix="sparse")
console.print(f"[green]Auto-generated {len(episode_indices)} sparse 'task' annotations[/green]")
# VLM annotation (for sparse if not auto, and for dense)
need_vlm = (not auto_sparse) or dense_mode
if need_vlm:
if args.num_workers > 1 and not auto_sparse:
# Parallel processing
console.print(f"[cyan]Parallel processing with {args.num_workers} workers[/cyan]")
episodes_per_worker = [[] for _ in range(args.num_workers)]
for i, ep_idx in enumerate(episode_indices):
episodes_per_worker[i % args.num_workers].append(ep_idx)
with ProcessPoolExecutor(
max_workers=args.num_workers, mp_context=mp.get_context("spawn")
) as executor:
futures = [
executor.submit(
worker_process_episodes,
w,
gpu_ids[w],
episodes_per_worker[w],
args.repo_id,
video_key,
sparse_subtask_list,
dense_subtask_list,
args.model,
torch_dtype,
)
for w in range(args.num_workers)
if episodes_per_worker[w]
]
for future in as_completed(futures):
try:
worker_sparse, worker_dense = future.result()
sparse_annotations.update(worker_sparse)
if dense_mode and worker_dense:
dense_annotations.update(worker_dense)
save_annotations_to_dataset(dataset.root, sparse_annotations, fps, prefix="sparse")
if dense_mode:
save_annotations_to_dataset(dataset.root, dense_annotations, fps, prefix="dense")
except Exception as e:
raise RuntimeError(f"Worker failed: {e}") from e
else:
# Sequential processing
sparse_annotator = (
VideoAnnotator(sparse_subtask_list, args.model, args.device, torch_dtype)
if not auto_sparse and sparse_subtask_list
else None
)
dense_annotator = (
VideoAnnotator(
dense_subtask_list,
args.model,
args.device,
torch_dtype,
sparse_annotator.model if sparse_annotator else None,
sparse_annotator.processor if sparse_annotator else None,
)
if dense_mode
else None
)
for i, ep_idx in enumerate(episode_indices):
console.print(f"[cyan]Episode {ep_idx} ({i + 1}/{len(episode_indices)})[/cyan]")
if sparse_annotator:
_, sparse_ann, err = process_single_episode(
ep_idx, dataset.root, dataset.meta, video_key, fps, sparse_annotator, console
)
if sparse_ann:
sparse_annotations[ep_idx] = sparse_ann
save_annotations_to_dataset(dataset.root, sparse_annotations, fps, prefix="sparse")
elif err:
console.print(f"[red]Sparse failed: {err}[/red]")
if dense_annotator:
_, dense_ann, err = process_single_episode(
ep_idx, dataset.root, dataset.meta, video_key, fps, dense_annotator, console
)
if dense_ann:
dense_annotations[ep_idx] = dense_ann
save_annotations_to_dataset(dataset.root, dense_annotations, fps, prefix="dense")
elif err:
console.print(f"[red]Dense failed: {err}[/red]")
# Save temporal proportions
def save_proportions(annotations, prefix, is_auto=False):
props: dict[str, float] = {"task": 1.0} if is_auto else compute_temporal_proportions(annotations, fps)
path = dataset.root / "meta" / f"temporal_proportions_{prefix}.json"
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "w") as f:
json.dump(props, f, indent=2)
console.print(f"[green]Saved {prefix} temporal proportions[/green]")
save_proportions(sparse_annotations, "sparse", auto_sparse)
if dense_mode and dense_annotations:
save_proportions(dense_annotations, "dense")
console.print(
f"\n[bold green]Complete! {len(sparse_annotations)} sparse, {len(dense_annotations or {})} dense annotations[/bold green]"
)
if args.push_to_hub:
try:
dataset.push_to_hub(push_videos=True)
console.print(f"[green]Pushed to {args.output_repo_id or args.repo_id}[/green]")
except Exception as e:
console.print(f"[red]Push failed: {e}[/red]")
if __name__ == "__main__":
main()
-1
View File
@@ -1 +0,0 @@
srun --time 12:00:00 --qos=high --gres=gpu:1 --mem=24G --partition=hopper-prod --container-image /fsx/michel_aractingi/docker_images/huggingface+lerobot-gpu+dev.sqsh --container-mounts /fsx/jade_choghari
-44
View File
@@ -1,44 +0,0 @@
#!/bin/bash
# Quick test to verify the fix for task_indices length mismatch
# This should now work correctly even with --num-samples < full dataset length
echo "Testing annotate_pgen.py with --num-samples=100 on full dataset..."
python examples/dataset/annotate_pgen.py \
--data-dir /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace \
--model Qwen/Qwen3-VL-30B-A3B-Instruct \
--num-samples 100 \
--sample-interval 1.0 \
--output-dir /fsx/jade_choghari/outputs/pgen_test_fixed
if [ $? -eq 0 ]; then
echo "✓ SUCCESS: Script completed without errors!"
echo ""
echo "Verifying output..."
# Check that all frames have task_index_high_level
python -c "
from lerobot.datasets.lerobot_dataset import LeRobotDataset
import numpy as np
ds = LeRobotDataset(repo_id='local_test', root='/fsx/jade_choghari/outputs/pgen_test_fixed')
print(f'Dataset has {len(ds)} frames')
print(f'Features: {list(ds.features.keys())}')
# Check that task_index_high_level exists
assert 'task_index_high_level' in ds.features, 'task_index_high_level not in features!'
# Sample some frames
for idx in [0, 50, 99, 100, 500, 1000, 11938]:
if idx < len(ds):
frame = ds[idx]
task_idx = frame['task_index_high_level'].item()
print(f'Frame {idx}: task_index_high_level = {task_idx}')
print('✓ All checks passed!')
"
else
echo "✗ FAILED: Script exited with error code $?"
fi
+454
View File
@@ -0,0 +1,454 @@
#!/usr/bin/env python3
"""
WBT (Whole Body Tracking) Dance Policy for Unitree G1
Uses ONNX model with motion data baked in.
Pattern matches gr00t_locomotion.py - uses UnitreeG1 robot class.
Usage:
python examples/unitree_g1/dance.py
"""
import argparse
import json
import logging
import threading
import time
from xml.etree import ElementTree
import numpy as np
import onnx
import onnxruntime as ort
import pinocchio as pin
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# =============================================================================
# CONFIGURATION
# =============================================================================
DANCE_ONNX_PATH = "examples/unitree_g1/fastsac_g1_29dof_dancing.onnx"
CONTROL_DT = 0.02 # 50 Hz
NUM_DOFS = 29
# Default joint positions (holosoma training defaults)
DEFAULT_DOF_POS = np.array([
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # Left leg (6)
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # Right leg (6)
0.0, 0.0, 0.0, # Waist (3)
0.2, 0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # Left arm (7)
0.2, -0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # Right arm (7)
], dtype=np.float32)
# Stiff hold KP/KD (for initialization)
STIFF_KP = np.array([
150, 150, 200, 200, 40, 40,
150, 150, 200, 200, 40, 40,
200, 200, 100,
100, 100, 100, 100, 50, 50, 50,
100, 100, 100, 100, 50, 50, 50,
], dtype=np.float32)
STIFF_KD = np.array([
2.5, 2.5, 2.5, 2.5, 2.5, 2.5,
2.5, 2.5, 2.5, 2.5, 2.5, 2.5,
5.0, 5.0, 5.0,
2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5,
2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5,
], dtype=np.float32)
# Joints to freeze at 0 with high KP
FROZEN_JOINTS = [13, 14, 20, 21, 27, 28]
FROZEN_KP = 500.0
FROZEN_KD = 5.0
# =============================================================================
# QUATERNION UTILITIES
# =============================================================================
def quat_inverse(q):
return np.concatenate((q[:, 0:1], -q[:, 1:]), axis=1)
def quat_mul(a, b):
a, b = a.reshape(-1, 4), b.reshape(-1, 4)
w1, x1, y1, z1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
w2, x2, y2, z2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
ww = (z1 + x1) * (x2 + y2)
yy = (w1 - y1) * (w2 + z2)
zz = (w1 + y1) * (w2 - z2)
xx = ww + yy + zz
qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))
w = qq - ww + (z1 - y1) * (y2 - z2)
x = qq - xx + (x1 + w1) * (x2 + w2)
y = qq - yy + (w1 - x1) * (y2 + z2)
z = qq - zz + (z1 + y1) * (w2 - x2)
return np.stack([w, x, y, z]).T.reshape(a.shape)
def subtract_frame_transforms(q01, q02):
return quat_mul(quat_inverse(q01), q02)
def matrix_from_quat(q):
r, i, j, k = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
two_s = 2.0 / (q * q).sum(-1)
o = np.stack((
1 - two_s * (j*j + k*k), two_s * (i*j - k*r), two_s * (i*k + j*r),
two_s * (i*j + k*r), 1 - two_s * (i*i + k*k), two_s * (j*k - i*r),
two_s * (i*k - j*r), two_s * (j*k + i*r), 1 - two_s * (i*i + j*j),
), -1)
return o.reshape(q.shape[:-1] + (3, 3))
def xyzw_to_wxyz(xyzw):
return np.concatenate([xyzw[:, -1:], xyzw[:, :3]], axis=1)
def quat_to_rpy(q):
w, x, y, z = q
roll = np.arctan2(2*(w*x + y*z), 1 - 2*(x**2 + y**2))
pitch = np.arcsin(np.clip(2*(w*y - z*x), -1, 1))
yaw = np.arctan2(2*(w*z + x*y), 1 - 2*(y**2 + z**2))
return roll, pitch, yaw
def rpy_to_quat(rpy):
roll, pitch, yaw = rpy
cy, sy = np.cos(yaw*0.5), np.sin(yaw*0.5)
cp, sp = np.cos(pitch*0.5), np.sin(pitch*0.5)
cr, sr = np.cos(roll*0.5), np.sin(roll*0.5)
return np.array([cr*cp*cy + sr*sp*sy, sr*cp*cy - cr*sp*sy,
cr*sp*cy + sr*cp*sy, cr*cp*sy - sr*sp*cy])
# =============================================================================
# PINOCCHIO FK
# =============================================================================
DOF_NAMES = (
"left_hip_pitch_joint", "left_hip_roll_joint", "left_hip_yaw_joint",
"left_knee_joint", "left_ankle_pitch_joint", "left_ankle_roll_joint",
"right_hip_pitch_joint", "right_hip_roll_joint", "right_hip_yaw_joint",
"right_knee_joint", "right_ankle_pitch_joint", "right_ankle_roll_joint",
"waist_yaw_joint", "waist_roll_joint", "waist_pitch_joint",
"left_shoulder_pitch_joint", "left_shoulder_roll_joint", "left_shoulder_yaw_joint", "left_elbow_joint",
"left_wrist_roll_joint", "left_wrist_pitch_joint", "left_wrist_yaw_joint",
"right_shoulder_pitch_joint", "right_shoulder_roll_joint", "right_shoulder_yaw_joint", "right_elbow_joint",
"right_wrist_roll_joint", "right_wrist_pitch_joint", "right_wrist_yaw_joint",
)
class PinocchioFK:
"""Pinocchio forward kinematics for torso_link orientation."""
def __init__(self, urdf_text: str):
root = ElementTree.fromstring(urdf_text)
for parent in root.iter():
for child in list(parent):
if child.tag.split("}")[-1] in {"visual", "collision"}:
parent.remove(child)
xml_text = '<?xml version="1.0"?>\n' + ElementTree.tostring(root, encoding="unicode")
self.model = pin.buildModelFromXML(xml_text, pin.JointModelFreeFlyer())
self.data = self.model.createData()
pin_names = [n for n in self.model.names if n not in ["universe", "root_joint"]]
self.idx_map = np.array([DOF_NAMES.index(n) for n in pin_names])
self.ref_frame_id = self.model.getFrameId("torso_link")
logger.info(f"Pinocchio FK: {len(pin_names)} joints, torso_link frame={self.ref_frame_id}")
def get_torso_quat(self, pos, quat_wxyz, dof_pos):
"""Get torso_link orientation in world frame."""
quat_xyzw = np.array([quat_wxyz[1], quat_wxyz[2], quat_wxyz[3], quat_wxyz[0]])
config = np.concatenate([pos, quat_xyzw, dof_pos[self.idx_map]])
pin.framesForwardKinematics(self.model, self.data, config)
coeffs = pin.Quaternion(self.data.oMf[self.ref_frame_id].rotation).coeffs()
return np.array([coeffs[3], coeffs[0], coeffs[1], coeffs[2]]).reshape(1, 4)
# =============================================================================
# DANCE CONTROLLER
# =============================================================================
class DanceController:
"""
Handles WBT dance policy for the Unitree G1 robot.
This controller manages:
- 29-joint observation processing
- Pinocchio FK for torso orientation
- Policy inference with motion data from ONNX
"""
def __init__(self, policy, robot, pinocchio_fk, motor_kp, motor_kd, action_scale):
self.policy = policy
self.robot = robot
self.pinocchio_fk = pinocchio_fk
self.motor_kp = motor_kp
self.motor_kd = motor_kd
self.action_scale = action_scale
self.obs_dim = policy.get_inputs()[0].shape[1]
self.last_action = np.zeros((1, NUM_DOFS), dtype=np.float32)
self.motion_command = None
self.ref_quat_xyzw = None
self.timestep = 0
self.yaw_offset = 0.0
# Get initial motion data from ONNX
dummy = np.zeros((1, self.obs_dim), dtype=np.float32)
outs = self.policy.run(["joint_pos", "joint_vel", "ref_quat_xyzw"],
{"obs": dummy, "time_step": np.array([[0]], dtype=np.float32)})
self.motion_command = np.concatenate(outs[0:2], axis=1)
self.ref_quat_xyzw = outs[2]
self.motion_start_pose = outs[0].flatten()
# Thread management
self.dance_running = False
self.dance_thread = None
logger.info(f"DanceController: obs_dim={self.obs_dim}, action_scale={action_scale}")
def capture_yaw_offset(self):
"""Capture robot's current yaw for relative tracking."""
robot_state = self.robot.lowstate_buffer.get_data()
if robot_state and self.pinocchio_fk:
quat = np.array(robot_state.imu_state.quaternion, dtype=np.float32)
dof = np.array([robot_state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
torso_q = self.pinocchio_fk.get_torso_quat(np.zeros(3), quat, dof)
_, _, self.yaw_offset = quat_to_rpy(torso_q.flatten())
logger.info(f"Captured yaw offset: {np.degrees(self.yaw_offset):.1f}°")
def _remove_yaw_offset(self, quat_wxyz):
"""Remove stored yaw offset from orientation."""
if abs(self.yaw_offset) < 1e-6:
return quat_wxyz
yaw_q = rpy_to_quat((0, 0, -self.yaw_offset)).reshape(1, 4)
return quat_mul(yaw_q, quat_wxyz)
def run_step(self):
"""Single dance step - reads state, runs policy, sends commands."""
robot_state = self.robot.lowstate_buffer.get_data()
if robot_state is None:
return
# Read robot state
quat = np.array(robot_state.imu_state.quaternion, dtype=np.float32)
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
dof_pos = np.array([robot_state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
dof_vel = np.array([robot_state.motor_state[i].dq for i in range(NUM_DOFS)], dtype=np.float32)
# Compute motion_ref_ori_b using FK
if self.pinocchio_fk:
torso_q = self.pinocchio_fk.get_torso_quat(np.zeros(3), quat, dof_pos)
torso_q = self._remove_yaw_offset(torso_q)
motion_ori = xyzw_to_wxyz(self.ref_quat_xyzw)
rel_quat = subtract_frame_transforms(torso_q, motion_ori)
ori_b = matrix_from_quat(rel_quat)[..., :2].reshape(1, -1)
else:
ori_b = np.zeros((1, 6), dtype=np.float32)
dof_rel = (dof_pos - DEFAULT_DOF_POS).reshape(1, -1)
# Build observation (alphabetical order)
obs_dict = {
"actions": self.last_action,
"base_ang_vel": ang_vel.reshape(1, 3),
"dof_pos": dof_rel,
"dof_vel": dof_vel.reshape(1, -1),
"motion_command": self.motion_command,
"motion_ref_ori_b": ori_b,
}
obs = np.concatenate([obs_dict[k].astype(np.float32) for k in sorted(obs_dict.keys())], axis=1)
obs = np.clip(obs, -100, 100)
# Run policy
outs = self.policy.run(["actions", "joint_pos", "joint_vel", "ref_quat_xyzw"],
{"obs": obs, "time_step": np.array([[self.timestep]], dtype=np.float32)})
action = np.clip(outs[0], -100, 100)
self.motion_command = np.concatenate(outs[1:3], axis=1)
self.ref_quat_xyzw = outs[3]
self.last_action = action.copy()
# Compute target positions
target_pos = DEFAULT_DOF_POS + action.flatten() * self.action_scale
# Send commands
for i in range(NUM_DOFS):
if i in FROZEN_JOINTS:
self.robot.msg.motor_cmd[i].q = 0.0
self.robot.msg.motor_cmd[i].kp = FROZEN_KP
self.robot.msg.motor_cmd[i].kd = FROZEN_KD
else:
self.robot.msg.motor_cmd[i].q = float(target_pos[i])
self.robot.msg.motor_cmd[i].kp = self.motor_kp[i]
self.robot.msg.motor_cmd[i].kd = self.motor_kd[i]
self.robot.msg.motor_cmd[i].qd = 0
self.robot.msg.motor_cmd[i].tau = 0
self.robot.send_action(self.robot.msg)
self.timestep += 1
def _dance_thread_loop(self):
"""Background thread that runs the dance policy."""
logger.info("Dance thread started")
while self.dance_running:
start_time = time.time()
try:
self.run_step()
except Exception as e:
logger.error(f"Error in dance loop: {e}")
import traceback
traceback.print_exc()
elapsed = time.time() - start_time
sleep_time = max(0, CONTROL_DT - elapsed)
time.sleep(sleep_time)
logger.info("Dance thread stopped")
def start_dance_thread(self):
"""Start the dance control thread."""
if self.dance_running:
logger.warning("Dance thread already running")
return
# Reset state for fresh start
self.timestep = 0
self.last_action.fill(0)
# Re-get initial motion data
dummy = np.zeros((1, self.obs_dim), dtype=np.float32)
outs = self.policy.run(["joint_pos", "joint_vel", "ref_quat_xyzw"],
{"obs": dummy, "time_step": np.array([[0]], dtype=np.float32)})
self.motion_command = np.concatenate(outs[0:2], axis=1)
self.ref_quat_xyzw = outs[2]
self.capture_yaw_offset()
logger.info("Starting dance control thread...")
self.dance_running = True
self.dance_thread = threading.Thread(target=self._dance_thread_loop, daemon=True)
self.dance_thread.start()
def stop_dance_thread(self):
"""Stop the dance control thread."""
if not self.dance_running:
return
logger.info("Stopping dance control thread...")
self.dance_running = False
if self.dance_thread:
self.dance_thread.join(timeout=2.0)
logger.info("Dance control thread stopped")
def reset_to_motion_pose(self, duration: float = 3.0):
"""Move robot to initial motion pose over given duration."""
logger.info(f"Moving to dance start pose ({duration}s)...")
robot_state = self.robot.lowstate_buffer.get_data()
init_pos = np.array([robot_state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
target_pos = self.motion_start_pose
num_steps = int(duration / CONTROL_DT)
for step in range(num_steps):
alpha = step / num_steps
interp = init_pos * (1 - alpha) + target_pos * alpha
for i in range(NUM_DOFS):
if i in FROZEN_JOINTS:
self.robot.msg.motor_cmd[i].q = 0.0
self.robot.msg.motor_cmd[i].kp = FROZEN_KP
self.robot.msg.motor_cmd[i].kd = FROZEN_KD
else:
self.robot.msg.motor_cmd[i].q = float(interp[i])
self.robot.msg.motor_cmd[i].kp = STIFF_KP[i]
self.robot.msg.motor_cmd[i].kd = STIFF_KD[i]
self.robot.msg.motor_cmd[i].qd = 0
self.robot.msg.motor_cmd[i].tau = 0
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
self.robot.lowcmd_publisher.Write(self.robot.msg)
time.sleep(CONTROL_DT)
logger.info("At dance start pose!")
# =============================================================================
# MAIN
# =============================================================================
def load_dance_policy(onnx_path: str):
"""Load dance policy and extract metadata."""
logger.info(f"Loading dance policy: {onnx_path}")
policy = ort.InferenceSession(onnx_path)
model = onnx.load(onnx_path)
metadata = {p.key: json.loads(p.value) for p in model.metadata_props}
motor_kp = np.array(metadata.get("kp", STIFF_KP), dtype=np.float32)
motor_kd = np.array(metadata.get("kd", STIFF_KD), dtype=np.float32)
action_scale = float(metadata.get("action_scale", 1.0))
urdf_text = metadata.get("robot_urdf", None)
logger.info(f" Obs dim: {policy.get_inputs()[0].shape[1]}")
logger.info(f" Action scale: {action_scale}")
logger.info(f" KP range: [{motor_kp.min():.1f}, {motor_kp.max():.1f}]")
# Build Pinocchio FK if URDF available
pinocchio_fk = None
if urdf_text:
logger.info(" Building Pinocchio FK from URDF...")
pinocchio_fk = PinocchioFK(urdf_text)
else:
logger.warning(" No URDF in metadata - FK will not work!")
return policy, pinocchio_fk, motor_kp, motor_kd, action_scale
def main():
parser = argparse.ArgumentParser(description="WBT Dance Policy for Unitree G1")
parser.add_argument("--onnx", type=str, default=DANCE_ONNX_PATH, help="Path to dance ONNX model")
parser.add_argument("--sim", action="store_true", help="Run in simulation mode")
args = parser.parse_args()
print("=" * 70)
print("💃 WBT DANCE POLICY")
print("=" * 70)
# Load policy
policy, pinocchio_fk, motor_kp, motor_kd, action_scale = load_dance_policy(args.onnx)
# Initialize robot
logger.info("Initializing robot...")
config = UnitreeG1Config()
robot = UnitreeG1(config)
logger.info("Robot connected!")
# Create controller
controller = DanceController(policy, robot, pinocchio_fk, motor_kp, motor_kd, action_scale)
try:
# Move to start pose
controller.reset_to_motion_pose(duration=3.0)
# Start dancing
controller.start_dance_thread()
logger.info("Dancing! Press Ctrl+C to stop.")
print("-" * 70)
# Log status periodically
while True:
time.sleep(2.0)
logger.info(f"timestep={controller.timestep}")
except KeyboardInterrupt:
print("\n\nStopping...")
finally:
controller.stop_dance_thread()
robot.disconnect()
print("\nDone!")
if __name__ == "__main__":
main()
Binary file not shown.
+479
View File
@@ -0,0 +1,479 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example: Holosoma Whole-Body Locomotion (23-DOF and 29-DOF)
This example demonstrates loading Holosoma whole-body locomotion policies
and running them on the Unitree G1 robot.
Supports both:
- 23-DOF native policies (82D observations, 23D actions)
- 29-DOF policies (100D observations, 29D actions)
"""
import argparse
import logging
import threading
import time
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# =============================================================================
# 29-DOF Configuration
# =============================================================================
# fmt: off
HOLOSOMA_29DOF_DEFAULT_ANGLES = np.array([
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # left leg
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # right leg
0.0, 0.0, 0.0, # waist (yaw, roll, pitch)
0.2, 0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # left arm
0.2, -0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # right arm
], dtype=np.float32)
HOLOSOMA_29DOF_KP = np.array([
40.179238471, 99.098427777, 40.179238471, 99.098427777, 28.501246196, 28.501246196, # left leg
40.179238471, 99.098427777, 40.179238471, 99.098427777, 28.501246196, 28.501246196, # right leg
40.179238471, 28.501246196, 28.501246196, # waist
14.250623098, 14.250623098, 14.250623098, 14.250623098, 14.250623098, 16.778327481, 16.778327481, # left arm
14.250623098, 14.250623098, 14.250623098, 14.250623098, 14.250623098, 16.778327481, 16.778327481, # right arm
], dtype=np.float32)
HOLOSOMA_29DOF_KD = np.array([
2.557889765, 6.308801854, 2.557889765, 6.308801854, 1.814445687, 1.814445687, # left leg
2.557889765, 6.308801854, 2.557889765, 6.308801854, 1.814445687, 1.814445687, # right leg
2.557889765, 1.814445687, 1.814445687, # waist
0.907222843, 0.907222843, 0.907222843, 0.907222843, 0.907222843, 1.068141502, 1.068141502, # left arm
0.907222843, 0.907222843, 0.907222843, 0.907222843, 0.907222843, 1.068141502, 1.068141502, # right arm
], dtype=np.float32)
# =============================================================================
# 23-DOF Configuration (native G1-23: no waist_roll/pitch, no wrist_pitch/yaw)
# Derived from 29-DOF Holosoma values
# =============================================================================
# Joint order: 6 left leg, 6 right leg, 1 waist_yaw, 5 left arm, 5 right arm
HOLOSOMA_23DOF_DEFAULT_ANGLES = np.array([
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # left leg (from 29-DOF)
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # right leg (from 29-DOF)
0.0, # waist_yaw only (from 29-DOF)
0.2, 0.2, 0.0, 0.6, 0.0, # left arm first 5 joints (from 29-DOF)
0.2, -0.2, 0.0, 0.6, 0.0, # right arm first 5 joints (from 29-DOF)
], dtype=np.float32)
HOLOSOMA_23DOF_KP = np.array([
40.179238471, 99.098427777, 40.179238471, 99.098427777, 28.501246196, 28.501246196, # left leg
40.179238471, 99.098427777, 40.179238471, 99.098427777, 28.501246196, 28.501246196, # right leg
40.179238471, # waist_yaw
14.250623098, 14.250623098, 14.250623098, 14.250623098, 14.250623098, # left arm
14.250623098, 14.250623098, 14.250623098, 14.250623098, 14.250623098, # right arm
], dtype=np.float32)
HOLOSOMA_23DOF_KD = np.array([
2.557889765, 6.308801854, 2.557889765, 6.308801854, 1.814445687, 1.814445687, # left leg
2.557889765, 6.308801854, 2.557889765, 6.308801854, 1.814445687, 1.814445687, # right leg
2.557889765, # waist_yaw
0.907222843, 0.907222843, 0.907222843, 0.907222843, 0.907222843, # left arm
0.907222843, 0.907222843, 0.907222843, 0.907222843, 0.907222843, # right arm
], dtype=np.float32)
# Maps 23-DOF policy index → 29-DOF motor index
# 23-DOF: legs(0-11), waist_yaw(12), L_arm(13-17), R_arm(18-22)
# 29-DOF: legs(0-11), waist(12-14), L_arm(15-21), R_arm(22-28)
DOF_23_TO_MOTOR_MAP = [
0, 1, 2, 3, 4, 5, # left leg → motor 0-5
6, 7, 8, 9, 10, 11, # right leg → motor 6-11
12, # waist_yaw → motor 12
15, 16, 17, 18, 19, # left arm (skip wrist_pitch/yaw) → motor 15-19
22, 23, 24, 25, 26, # right arm (skip wrist_pitch/yaw) → motor 22-26
]
# fmt: on
# Control parameters
LOCOMOTION_CONTROL_DT = 0.02 # 50Hz
LOCOMOTION_ACTION_SCALE = 0.25
ANG_VEL_SCALE = 0.25
DOF_POS_SCALE = 1.0
DOF_VEL_SCALE = 0.05
GAIT_PERIOD = 1.0
DEFAULT_HOLOSOMA_REPO_ID = "nepyope/holosoma_locomotion"
def load_holosoma_policy(
repo_id: str = DEFAULT_HOLOSOMA_REPO_ID,
policy_name: str = "fastsac",
local_path: str | None = None,
) -> tuple[ort.InferenceSession, int]:
"""Load Holosoma policy and detect observation dimension.
Returns:
(policy, obs_dim) tuple where obs_dim is 82 (23-DOF) or 100 (29-DOF)
"""
if local_path is not None:
logger.info(f"Loading policy from local path: {local_path}")
policy_path = local_path
else:
logger.info(f"Loading policy from Hugging Face Hub: {repo_id}")
policy_path = hf_hub_download(repo_id=repo_id, filename=f"{policy_name}_g1_29dof.onnx")
policy = ort.InferenceSession(policy_path)
# Detect observation dimension from model input shape
input_shape = policy.get_inputs()[0].shape
obs_dim = input_shape[1] if len(input_shape) > 1 else input_shape[0]
logger.info(f"Policy loaded successfully")
logger.info(f" Input: {policy.get_inputs()[0].name}, shape: {input_shape} → obs_dim={obs_dim}")
logger.info(f" Output: {policy.get_outputs()[0].name}, shape: {policy.get_outputs()[0].shape}")
return policy, obs_dim
class HolosomaLocomotionController:
"""
Handles Holosoma whole-body locomotion for Unitree G1.
Supports both 23-DOF (82D obs) and 29-DOF (100D obs) policies.
"""
def __init__(self, policy, robot, config, obs_dim: int = 100):
self.policy = policy
self.robot = robot
self.config = config
self.obs_dim = obs_dim
# Detect policy type from observation dimension
self.is_23dof = (obs_dim == 82)
self.num_dof = 23 if self.is_23dof else 29
# Velocity commands
self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)
# State variables sized for policy type
self.qj = np.zeros(self.num_dof, dtype=np.float32)
self.dqj = np.zeros(self.num_dof, dtype=np.float32)
self.locomotion_action = np.zeros(self.num_dof, dtype=np.float32)
self.locomotion_obs = np.zeros(obs_dim, dtype=np.float32)
self.last_unscaled_action = np.zeros(self.num_dof, dtype=np.float32)
# Select config based on DOF
if self.is_23dof:
self.default_angles = HOLOSOMA_23DOF_DEFAULT_ANGLES
self.kp = HOLOSOMA_23DOF_KP
self.kd = HOLOSOMA_23DOF_KD
self.motor_map = DOF_23_TO_MOTOR_MAP
else:
self.default_angles = HOLOSOMA_29DOF_DEFAULT_ANGLES
self.kp = HOLOSOMA_29DOF_KP
self.kd = HOLOSOMA_29DOF_KD
self.motor_map = list(range(29)) # Identity map for 29-DOF
# Phase state for gait
self.phase = np.zeros((1, 2), dtype=np.float32)
self.phase[0, 0] = 0.0
self.phase[0, 1] = np.pi
self.phase_dt = 2 * np.pi / (50.0 * GAIT_PERIOD)
self.is_standing = False
self.counter = 0
self.locomotion_running = False
self.locomotion_thread = None
logger.info(f"HolosomaLocomotionController initialized")
logger.info(f" Mode: {'23-DOF (82D obs)' if self.is_23dof else '29-DOF (100D obs)'}")
logger.info(f" Action dim: {self.num_dof}")
def holosoma_locomotion_run(self):
"""Main locomotion loop - handles both 23-DOF and 29-DOF."""
self.counter += 1
if self.counter == 1:
print("\n" + "=" * 60)
print(f"🚀 RUNNING HOLOSOMA {self.num_dof}-DOF LOCOMOTION POLICY")
print(f" {self.obs_dim}D observations → {self.num_dof}D actions")
print("=" * 60 + "\n")
robot_state = self.robot.get_observation()
if robot_state is None:
return
# Remote controller
if robot_state.wireless_remote is not None:
self.robot.remote_controller.set(robot_state.wireless_remote)
else:
self.robot.remote_controller.lx = 0.0
self.robot.remote_controller.ly = 0.0
self.robot.remote_controller.rx = 0.0
self.robot.remote_controller.ry = 0.0
# Deadzone
ly = self.robot.remote_controller.ly if abs(self.robot.remote_controller.ly) > 0.1 else 0.0
lx = self.robot.remote_controller.lx if abs(self.robot.remote_controller.lx) > 0.1 else 0.0
rx = self.robot.remote_controller.rx if abs(self.robot.remote_controller.rx) > 0.1 else 0.0
self.locomotion_cmd[0] = ly
self.locomotion_cmd[1] = -lx
self.locomotion_cmd[2] = -rx
# Read joint states using motor map
for i in range(self.num_dof):
motor_idx = self.motor_map[i]
self.qj[i] = robot_state.motor_state[motor_idx].q
self.dqj[i] = robot_state.motor_state[motor_idx].dq
# IMU
quat = robot_state.imu_state.quaternion
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
gravity_orientation = self.robot.get_gravity_orientation(quat)
# Scale observations
qj_obs = (self.qj - self.default_angles) * DOF_POS_SCALE
dqj_obs = self.dqj * DOF_VEL_SCALE
ang_vel_scaled = ang_vel * ANG_VEL_SCALE
# Phase update
cmd_norm = np.linalg.norm(self.locomotion_cmd[:2])
ang_cmd_norm = np.abs(self.locomotion_cmd[2])
if cmd_norm < 0.01 and ang_cmd_norm < 0.01:
self.phase[0, :] = np.pi * np.ones(2)
self.is_standing = True
elif self.is_standing:
self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
self.is_standing = False
else:
phase_tp1 = self.phase + self.phase_dt
self.phase = np.fmod(phase_tp1 + np.pi, 2 * np.pi) - np.pi
sin_phase = np.sin(self.phase[0, :])
cos_phase = np.cos(self.phase[0, :])
# Build observation (format depends on DOF)
if self.is_23dof:
# 82D: [23 actions, 3 ang_vel, 1 cmd_yaw, 2 cmd_lin, 2 cos, 23 pos, 23 vel, 3 grav, 2 sin]
self.locomotion_obs[0:23] = self.last_unscaled_action
self.locomotion_obs[23:26] = ang_vel_scaled
self.locomotion_obs[26] = self.locomotion_cmd[2]
self.locomotion_obs[27:29] = self.locomotion_cmd[:2]
self.locomotion_obs[29:31] = cos_phase
self.locomotion_obs[31:54] = qj_obs
self.locomotion_obs[54:77] = dqj_obs
self.locomotion_obs[77:80] = gravity_orientation
self.locomotion_obs[80:82] = sin_phase
else:
# 100D: [29 actions, 3 ang_vel, 1 cmd_yaw, 2 cmd_lin, 2 cos, 29 pos, 29 vel, 3 grav, 2 sin]
self.locomotion_obs[0:29] = self.last_unscaled_action
self.locomotion_obs[29:32] = ang_vel_scaled
self.locomotion_obs[32] = self.locomotion_cmd[2]
self.locomotion_obs[33:35] = self.locomotion_cmd[:2]
self.locomotion_obs[35:37] = cos_phase
self.locomotion_obs[37:66] = qj_obs
self.locomotion_obs[66:95] = dqj_obs
self.locomotion_obs[95:98] = gravity_orientation
self.locomotion_obs[98:100] = sin_phase
# Policy inference
obs_input = self.locomotion_obs.reshape(1, -1).astype(np.float32)
ort_inputs = {self.policy.get_inputs()[0].name: obs_input}
ort_outs = self.policy.run(None, ort_inputs)
raw_action = ort_outs[0].squeeze()
clipped_action = np.clip(raw_action, -100.0, 100.0)
self.last_unscaled_action = clipped_action.copy()
self.locomotion_action = clipped_action * LOCOMOTION_ACTION_SCALE
# Debug
if self.counter <= 3:
print(f"\n[Holosoma Debug #{self.counter}]")
print(f" Phase: ({self.phase[0, 0]:.3f}, {self.phase[0, 1]:.3f})")
print(f" Cmd: ({self.locomotion_cmd[0]:.2f}, {self.locomotion_cmd[1]:.2f}, {self.locomotion_cmd[2]:.2f})")
print(f" Action range: [{raw_action.min():.3f}, {raw_action.max():.3f}]")
# Compute target positions
target_dof_pos = self.default_angles + self.locomotion_action
# Send commands to motors via motor map
for i in range(self.num_dof):
motor_idx = self.motor_map[i]
self.robot.msg.motor_cmd[motor_idx].q = target_dof_pos[i]
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = self.kp[i]
self.robot.msg.motor_cmd[motor_idx].kd = self.kd[i]
self.robot.msg.motor_cmd[motor_idx].tau = 0
# For 23-DOF: zero out missing joints (waist_roll/pitch, wrist_pitch/yaw)
if self.is_23dof:
missing_motors = [13, 14, 20, 21, 27, 28] # waist_roll, waist_pitch, wrist_pitch/yaw
for motor_idx in missing_motors:
self.robot.msg.motor_cmd[motor_idx].q = 0.0
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = 40.0
self.robot.msg.motor_cmd[motor_idx].kd = 2.0
self.robot.msg.motor_cmd[motor_idx].tau = 0
self.robot.send_action(self.robot.msg)
def _locomotion_thread_loop(self):
logger.info("Locomotion thread started")
while self.locomotion_running:
start_time = time.time()
try:
self.holosoma_locomotion_run()
except Exception as e:
logger.error(f"Error in locomotion loop: {e}")
import traceback
traceback.print_exc()
elapsed = time.time() - start_time
sleep_time = max(0, LOCOMOTION_CONTROL_DT - elapsed)
time.sleep(sleep_time)
logger.info("Locomotion thread stopped")
def start_locomotion_thread(self):
if self.locomotion_running:
logger.warning("Locomotion thread already running")
return
logger.info("Starting locomotion control thread...")
self.locomotion_running = True
self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True)
self.locomotion_thread.start()
logger.info("Locomotion control thread started!")
def stop_locomotion_thread(self):
if not self.locomotion_running:
return
logger.info("Stopping locomotion control thread...")
self.locomotion_running = False
if self.locomotion_thread:
self.locomotion_thread.join(timeout=2.0)
logger.info("Locomotion control thread stopped")
def reset_robot(self):
"""Move joints to default position."""
logger.info(f"Moving {self.num_dof} joints to default position...")
total_time = 3.0
num_step = int(total_time / self.robot.control_dt)
robot_state = self.robot.get_observation()
# Record current positions
init_dof_pos = np.zeros(self.num_dof, dtype=np.float32)
for i in range(self.num_dof):
motor_idx = self.motor_map[i]
init_dof_pos[i] = robot_state.motor_state[motor_idx].q
# Interpolate to target
for step in range(num_step):
alpha = step / num_step
for i in range(self.num_dof):
motor_idx = self.motor_map[i]
target = self.default_angles[i]
self.robot.msg.motor_cmd[motor_idx].q = init_dof_pos[i] * (1 - alpha) + target * alpha
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = self.kp[i]
self.robot.msg.motor_cmd[motor_idx].kd = self.kd[i]
self.robot.msg.motor_cmd[motor_idx].tau = 0
# Zero missing joints for 23-DOF
if self.is_23dof:
for motor_idx in [13, 14, 20, 21, 27, 28]:
self.robot.msg.motor_cmd[motor_idx].q = 0.0
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = 40.0
self.robot.msg.motor_cmd[motor_idx].kd = 2.0
self.robot.msg.motor_cmd[motor_idx].tau = 0
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
self.robot.lowcmd_publisher.Write(self.robot.msg)
time.sleep(self.robot.control_dt)
logger.info(f"Reached default position ({self.num_dof} joints)")
# Hold for 2 seconds
logger.info("Holding default position for 2 seconds...")
hold_steps = int(2.0 / self.robot.control_dt)
for _ in range(hold_steps):
for i in range(self.num_dof):
motor_idx = self.motor_map[i]
self.robot.msg.motor_cmd[motor_idx].q = self.default_angles[i]
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = self.kp[i]
self.robot.msg.motor_cmd[motor_idx].kd = self.kd[i]
self.robot.msg.motor_cmd[motor_idx].tau = 0
if self.is_23dof:
for motor_idx in [13, 14, 20, 21, 27, 28]:
self.robot.msg.motor_cmd[motor_idx].q = 0.0
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = 40.0
self.robot.msg.motor_cmd[motor_idx].kd = 2.0
self.robot.msg.motor_cmd[motor_idx].tau = 0
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
self.robot.lowcmd_publisher.Write(self.robot.msg)
time.sleep(self.robot.control_dt)
logger.info("Ready to start locomotion!")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Holosoma Locomotion Controller for Unitree G1")
parser.add_argument("--repo-id", type=str, default=DEFAULT_HOLOSOMA_REPO_ID)
parser.add_argument("--policy", type=str, default="fastsac", choices=["fastsac", "ppo"])
parser.add_argument("--local-path", type=str, default=None, help="Path to local ONNX file")
args = parser.parse_args()
# Load policy and detect dimensions
policy, obs_dim = load_holosoma_policy(
repo_id=args.repo_id,
policy_name=args.policy,
local_path=args.local_path,
)
# Initialize robot
config = UnitreeG1Config()
robot = UnitreeG1(config)
# Initialize controller with detected obs_dim
controller = HolosomaLocomotionController(
policy=policy,
robot=robot,
config=config,
obs_dim=obs_dim,
)
try:
#controller.reset_robot()
controller.start_locomotion_thread()
logger.info(f"Robot initialized with Holosoma {'23-DOF' if obs_dim == 82 else '29-DOF'} policy")
logger.info("Use remote controller: LY=fwd/back, LX=left/right, RX=rotate")
logger.info("Press Ctrl+C to stop")
while True:
time.sleep(1.0)
except KeyboardInterrupt:
print("\nStopping locomotion...")
controller.stop_locomotion_thread()
print("Done!")
+607
View File
@@ -0,0 +1,607 @@
#!/usr/bin/env python3
"""
Locomotion Dance Toggle for Unitree G1
Press Enter to instantly switch between locomotion and dance modes.
- Starts in LOCOMOTION mode (joystick control)
- Press Enter DANCE mode (resets to frame 0)
- Press Enter LOCOMOTION mode
- Repeat...
Auto-recovery feature:
- If robot tilts beyond threshold during dance, auto-switches to locomotion
- When robot recovers (tilt below recovery threshold), resumes dance from where it left off
Usage:
python examples/unitree_g1/locomotion_to_dance.py
python examples/unitree_g1/locomotion_to_dance.py --tilt-threshold 25 --recovery-threshold 10
"""
import argparse
import json
import logging
import select
import sys
import threading
import time
from xml.etree import ElementTree
import numpy as np
import onnx
import onnxruntime as ort
import pinocchio as pin
from huggingface_hub import hf_hub_download
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# =============================================================================
# CONFIGURATION
# =============================================================================
NUM_DOFS = 29
CONTROL_DT = 0.02 # 50Hz
# Locomotion config
DEFAULT_HOLOSOMA_REPO_ID = "nepyope/holosoma_locomotion"
LOCOMOTION_ACTION_SCALE = 0.25
ANG_VEL_SCALE = 0.25
DOF_POS_SCALE = 1.0
DOF_VEL_SCALE = 0.05
GAIT_PERIOD = 1.0
# Dance config
DANCE_ONNX_PATH = "examples/unitree_g1/fastsac_g1_29dof_dancing.onnx"
FROZEN_JOINTS = [13, 14, 20, 21, 27, 28]
FROZEN_KP = 500.0
FROZEN_KD = 5.0
# fmt: off
# 29-DOF defaults (holosoma training)
DEFAULT_29DOF_ANGLES = np.array([
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # left leg
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # right leg
0.0, 0.0, 0.0, # waist
0.2, 0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # left arm
0.2, -0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # right arm
], dtype=np.float32)
DEFAULT_29DOF_KP = np.array([
40.179, 99.098, 40.179, 99.098, 28.501, 28.501,
40.179, 99.098, 40.179, 99.098, 28.501, 28.501,
40.179, 28.501, 28.501,
14.251, 14.251, 14.251, 14.251, 14.251, 16.778, 16.778,
14.251, 14.251, 14.251, 14.251, 14.251, 16.778, 16.778,
], dtype=np.float32)
DEFAULT_29DOF_KD = np.array([
2.558, 6.309, 2.558, 6.309, 1.814, 1.814,
2.558, 6.309, 2.558, 6.309, 1.814, 1.814,
2.558, 1.814, 1.814,
0.907, 0.907, 0.907, 0.907, 0.907, 1.068, 1.068,
0.907, 0.907, 0.907, 0.907, 0.907, 1.068, 1.068,
], dtype=np.float32)
# 23-DOF config (no waist_roll/pitch, no wrist_pitch/yaw)
DEFAULT_23DOF_ANGLES = np.array([
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # left leg
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # right leg
0.0, # waist_yaw only
0.2, 0.2, 0.0, 0.6, 0.0, # left arm (5 joints)
0.2, -0.2, 0.0, 0.6, 0.0, # right arm (5 joints)
], dtype=np.float32)
DEFAULT_23DOF_KP = np.array([
40.179, 99.098, 40.179, 99.098, 28.501, 28.501,
40.179, 99.098, 40.179, 99.098, 28.501, 28.501,
40.179,
14.251, 14.251, 14.251, 14.251, 14.251,
14.251, 14.251, 14.251, 14.251, 14.251,
], dtype=np.float32)
DEFAULT_23DOF_KD = np.array([
2.558, 6.309, 2.558, 6.309, 1.814, 1.814,
2.558, 6.309, 2.558, 6.309, 1.814, 1.814,
2.558,
0.907, 0.907, 0.907, 0.907, 0.907,
0.907, 0.907, 0.907, 0.907, 0.907,
], dtype=np.float32)
# 23-DOF policy index → 29-DOF motor index
DOF_23_TO_MOTOR = [
0, 1, 2, 3, 4, 5, # left leg
6, 7, 8, 9, 10, 11, # right leg
12, # waist_yaw
15, 16, 17, 18, 19, # left arm (skip wrist_pitch/yaw)
22, 23, 24, 25, 26, # right arm (skip wrist_pitch/yaw)
]
MISSING_23DOF_MOTORS = [13, 14, 20, 21, 27, 28]
# fmt: on
# =============================================================================
# QUATERNION UTILITIES
# =============================================================================
def quat_inverse(q):
return np.concatenate((q[:, 0:1], -q[:, 1:]), axis=1)
def quat_mul(a, b):
a, b = a.reshape(-1, 4), b.reshape(-1, 4)
w1, x1, y1, z1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
w2, x2, y2, z2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
ww = (z1 + x1) * (x2 + y2)
yy = (w1 - y1) * (w2 + z2)
zz = (w1 + y1) * (w2 - z2)
xx = ww + yy + zz
qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))
w = qq - ww + (z1 - y1) * (y2 - z2)
x = qq - xx + (x1 + w1) * (x2 + w2)
y = qq - yy + (w1 - x1) * (y2 + z2)
z = qq - zz + (z1 + y1) * (w2 - x2)
return np.stack([w, x, y, z]).T.reshape(a.shape)
def subtract_frame_transforms(q01, q02):
return quat_mul(quat_inverse(q01), q02)
def matrix_from_quat(q):
r, i, j, k = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
two_s = 2.0 / (q * q).sum(-1)
o = np.stack((
1 - two_s * (j*j + k*k), two_s * (i*j - k*r), two_s * (i*k + j*r),
two_s * (i*j + k*r), 1 - two_s * (i*i + k*k), two_s * (j*k - i*r),
two_s * (i*k - j*r), two_s * (j*k + i*r), 1 - two_s * (i*i + j*j),
), -1)
return o.reshape(q.shape[:-1] + (3, 3))
def xyzw_to_wxyz(xyzw):
return np.concatenate([xyzw[:, -1:], xyzw[:, :3]], axis=1)
def quat_to_rpy(q):
w, x, y, z = q
roll = np.arctan2(2*(w*x + y*z), 1 - 2*(x**2 + y**2))
pitch = np.arcsin(np.clip(2*(w*y - z*x), -1, 1))
yaw = np.arctan2(2*(w*z + x*y), 1 - 2*(y**2 + z**2))
return roll, pitch, yaw
def rpy_to_quat(rpy):
roll, pitch, yaw = rpy
cy, sy = np.cos(yaw*0.5), np.sin(yaw*0.5)
cp, sp = np.cos(pitch*0.5), np.sin(pitch*0.5)
cr, sr = np.cos(roll*0.5), np.sin(roll*0.5)
return np.array([cr*cp*cy + sr*sp*sy, sr*cp*cy - cr*sp*sy,
cr*sp*cy + sr*cp*sy, cr*cp*sy - sr*sp*cy])
# =============================================================================
# PINOCCHIO FK
# =============================================================================
DOF_NAMES = (
"left_hip_pitch_joint", "left_hip_roll_joint", "left_hip_yaw_joint",
"left_knee_joint", "left_ankle_pitch_joint", "left_ankle_roll_joint",
"right_hip_pitch_joint", "right_hip_roll_joint", "right_hip_yaw_joint",
"right_knee_joint", "right_ankle_pitch_joint", "right_ankle_roll_joint",
"waist_yaw_joint", "waist_roll_joint", "waist_pitch_joint",
"left_shoulder_pitch_joint", "left_shoulder_roll_joint", "left_shoulder_yaw_joint", "left_elbow_joint",
"left_wrist_roll_joint", "left_wrist_pitch_joint", "left_wrist_yaw_joint",
"right_shoulder_pitch_joint", "right_shoulder_roll_joint", "right_shoulder_yaw_joint", "right_elbow_joint",
"right_wrist_roll_joint", "right_wrist_pitch_joint", "right_wrist_yaw_joint",
)
class PinocchioFK:
def __init__(self, urdf_text: str):
root = ElementTree.fromstring(urdf_text)
for parent in root.iter():
for child in list(parent):
if child.tag.split("}")[-1] in {"visual", "collision"}:
parent.remove(child)
xml_text = '<?xml version="1.0"?>\n' + ElementTree.tostring(root, encoding="unicode")
self.model = pin.buildModelFromXML(xml_text, pin.JointModelFreeFlyer())
self.data = self.model.createData()
pin_names = [n for n in self.model.names if n not in ["universe", "root_joint"]]
self.idx_map = np.array([DOF_NAMES.index(n) for n in pin_names])
self.ref_frame_id = self.model.getFrameId("torso_link")
def get_torso_quat(self, pos, quat_wxyz, dof_pos):
quat_xyzw = np.array([quat_wxyz[1], quat_wxyz[2], quat_wxyz[3], quat_wxyz[0]])
config = np.concatenate([pos, quat_xyzw, dof_pos[self.idx_map]])
pin.framesForwardKinematics(self.model, self.data, config)
coeffs = pin.Quaternion(self.data.oMf[self.ref_frame_id].rotation).coeffs()
return np.array([coeffs[3], coeffs[0], coeffs[1], coeffs[2]]).reshape(1, 4)
def get_torso_tilt(self, pos, quat_wxyz, dof_pos):
"""Get torso tilt angle from upright (degrees). Uses roll and pitch."""
torso_q = self.get_torso_quat(pos, quat_wxyz, dof_pos)
roll, pitch, _ = quat_to_rpy(torso_q.flatten())
# Tilt is the angle from vertical - combine roll and pitch
tilt_rad = np.sqrt(roll**2 + pitch**2)
return np.degrees(tilt_rad), np.degrees(roll), np.degrees(pitch)
# =============================================================================
# LOCOMOTION CONTROLLER
# =============================================================================
class LocomotionController:
"""Holosoma whole-body locomotion (23-DOF or 29-DOF)."""
def __init__(self, policy, robot, obs_dim: int):
self.policy = policy
self.robot = robot
self.obs_dim = obs_dim
# Detect DOF mode
self.is_23dof = (obs_dim == 82)
self.num_dof = 23 if self.is_23dof else 29
if self.is_23dof:
self.default_angles = DEFAULT_23DOF_ANGLES
self.kp = DEFAULT_23DOF_KP
self.kd = DEFAULT_23DOF_KD
self.motor_map = DOF_23_TO_MOTOR
logger.info("Locomotion: 23-DOF (82D obs)")
else:
self.default_angles = DEFAULT_29DOF_ANGLES
self.kp = DEFAULT_29DOF_KP
self.kd = DEFAULT_29DOF_KD
self.motor_map = list(range(29))
logger.info("Locomotion: 29-DOF (100D obs)")
self.cmd = np.zeros(3, dtype=np.float32)
self.qj = np.zeros(self.num_dof, dtype=np.float32)
self.dqj = np.zeros(self.num_dof, dtype=np.float32)
self.obs = np.zeros(obs_dim, dtype=np.float32)
self.last_action = np.zeros(self.num_dof, dtype=np.float32)
self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
self.phase_dt = 2 * np.pi / (50.0 * GAIT_PERIOD)
self.is_standing = True
def run_step(self):
"""Single locomotion step."""
state = self.robot.lowstate_buffer.get_data()
if state is None:
return
# Joystick
if state.wireless_remote is not None:
self.robot.remote_controller.set(state.wireless_remote)
ly = self.robot.remote_controller.ly if abs(self.robot.remote_controller.ly) > 0.1 else 0.0
lx = self.robot.remote_controller.lx if abs(self.robot.remote_controller.lx) > 0.1 else 0.0
rx = self.robot.remote_controller.rx if abs(self.robot.remote_controller.rx) > 0.1 else 0.0
self.cmd[0], self.cmd[1], self.cmd[2] = ly, -lx, -rx
# Read joints via motor map
for i in range(self.num_dof):
self.qj[i] = state.motor_state[self.motor_map[i]].q
self.dqj[i] = state.motor_state[self.motor_map[i]].dq
# IMU
quat = state.imu_state.quaternion
ang_vel = np.array(state.imu_state.gyroscope, dtype=np.float32)
gravity = self.robot.get_gravity_orientation(quat)
# Scale
qj_obs = (self.qj - self.default_angles) * DOF_POS_SCALE
dqj_obs = self.dqj * DOF_VEL_SCALE
ang_vel_s = ang_vel * ANG_VEL_SCALE
# Phase
cmd_mag = np.linalg.norm(self.cmd[:2])
ang_mag = abs(self.cmd[2])
if cmd_mag < 0.01 and ang_mag < 0.01:
self.phase[0, :] = np.pi
self.is_standing = True
elif self.is_standing:
self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
self.is_standing = False
else:
self.phase = np.fmod(self.phase + self.phase_dt + np.pi, 2*np.pi) - np.pi
sin_ph, cos_ph = np.sin(self.phase[0]), np.cos(self.phase[0])
# Build obs
if self.is_23dof:
self.obs[0:23] = self.last_action
self.obs[23:26] = ang_vel_s
self.obs[26] = self.cmd[2]
self.obs[27:29] = self.cmd[:2]
self.obs[29:31] = cos_ph
self.obs[31:54] = qj_obs
self.obs[54:77] = dqj_obs
self.obs[77:80] = gravity
self.obs[80:82] = sin_ph
else:
self.obs[0:29] = self.last_action
self.obs[29:32] = ang_vel_s
self.obs[32] = self.cmd[2]
self.obs[33:35] = self.cmd[:2]
self.obs[35:37] = cos_ph
self.obs[37:66] = qj_obs
self.obs[66:95] = dqj_obs
self.obs[95:98] = gravity
self.obs[98:100] = sin_ph
# Inference
obs_in = self.obs.reshape(1, -1).astype(np.float32)
ort_in = {self.policy.get_inputs()[0].name: obs_in}
raw_action = self.policy.run(None, ort_in)[0].squeeze()
clipped = np.clip(raw_action, -100.0, 100.0)
self.last_action = clipped.copy()
scaled = clipped * LOCOMOTION_ACTION_SCALE
target = self.default_angles + scaled
# Send commands
for i in range(self.num_dof):
motor_idx = self.motor_map[i]
self.robot.msg.motor_cmd[motor_idx].q = float(target[i])
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = self.kp[i]
self.robot.msg.motor_cmd[motor_idx].kd = self.kd[i]
self.robot.msg.motor_cmd[motor_idx].tau = 0
# Zero missing joints for 23-DOF
if self.is_23dof:
for idx in MISSING_23DOF_MOTORS:
self.robot.msg.motor_cmd[idx].q = 0.0
self.robot.msg.motor_cmd[idx].qd = 0
self.robot.msg.motor_cmd[idx].kp = 40.0
self.robot.msg.motor_cmd[idx].kd = 2.0
self.robot.msg.motor_cmd[idx].tau = 0
self.robot.send_action(self.robot.msg)
def reset(self):
"""Reset state for fresh start."""
self.last_action.fill(0)
self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
self.is_standing = True
# =============================================================================
# DANCE CONTROLLER
# =============================================================================
class DanceController:
"""WBT dance policy with FK for torso tracking."""
def __init__(self, policy, robot, pinocchio_fk, motor_kp, motor_kd, action_scale):
self.policy = policy
self.robot = robot
self.pinocchio_fk = pinocchio_fk
self.motor_kp = motor_kp
self.motor_kd = motor_kd
self.action_scale = action_scale
self.obs_dim = policy.get_inputs()[0].shape[1]
self.last_action = np.zeros((1, NUM_DOFS), dtype=np.float32)
self.motion_command = None
self.ref_quat_xyzw = None
self.timestep = 0
self.yaw_offset = 0.0
logger.info(f"Dance: obs_dim={self.obs_dim}, action_scale={action_scale}")
def initialize(self, reset_to_frame_0: bool = True):
"""Initialize dance. If reset_to_frame_0=True, starts from frame 0. Otherwise resumes."""
if reset_to_frame_0:
self.timestep = 0
self.last_action.fill(0)
# Get initial motion data at frame 0
dummy = np.zeros((1, self.obs_dim), dtype=np.float32)
outs = self.policy.run(["joint_pos", "joint_vel", "ref_quat_xyzw"],
{"obs": dummy, "time_step": np.array([[0]], dtype=np.float32)})
self.motion_command = np.concatenate(outs[0:2], axis=1)
self.ref_quat_xyzw = outs[2]
logger.info("Dance: reset to frame 0")
else:
# Resume from current timestep - just update motion command for current frame
dummy = np.zeros((1, self.obs_dim), dtype=np.float32)
outs = self.policy.run(["joint_pos", "joint_vel", "ref_quat_xyzw"],
{"obs": dummy, "time_step": np.array([[self.timestep]], dtype=np.float32)})
self.motion_command = np.concatenate(outs[0:2], axis=1)
self.ref_quat_xyzw = outs[2]
logger.info(f"Dance: resuming from frame {self.timestep}")
# Capture yaw offset
state = self.robot.lowstate_buffer.get_data()
if state and self.pinocchio_fk:
quat = np.array(state.imu_state.quaternion, dtype=np.float32)
dof = np.array([state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
torso_q = self.pinocchio_fk.get_torso_quat(np.zeros(3), quat, dof)
_, _, self.yaw_offset = quat_to_rpy(torso_q.flatten())
logger.info(f"Dance yaw offset: {np.degrees(self.yaw_offset):.1f}°")
def _remove_yaw_offset(self, quat_wxyz):
if abs(self.yaw_offset) < 1e-6:
return quat_wxyz
yaw_q = rpy_to_quat((0, 0, -self.yaw_offset)).reshape(1, 4)
return quat_mul(yaw_q, quat_wxyz)
def run_step(self):
"""Single dance step."""
state = self.robot.lowstate_buffer.get_data()
if state is None:
return
quat = np.array(state.imu_state.quaternion, dtype=np.float32)
ang_vel = np.array(state.imu_state.gyroscope, dtype=np.float32)
dof_pos = np.array([state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
dof_vel = np.array([state.motor_state[i].dq for i in range(NUM_DOFS)], dtype=np.float32)
# FK for torso orientation
if self.pinocchio_fk:
torso_q = self.pinocchio_fk.get_torso_quat(np.zeros(3), quat, dof_pos)
torso_q = self._remove_yaw_offset(torso_q)
motion_ori = xyzw_to_wxyz(self.ref_quat_xyzw)
rel_quat = subtract_frame_transforms(torso_q, motion_ori)
ori_b = matrix_from_quat(rel_quat)[..., :2].reshape(1, -1)
else:
ori_b = np.zeros((1, 6), dtype=np.float32)
dof_rel = (dof_pos - DEFAULT_29DOF_ANGLES).reshape(1, -1)
# Build obs (alphabetical)
obs_dict = {
"actions": self.last_action,
"base_ang_vel": ang_vel.reshape(1, 3),
"dof_pos": dof_rel,
"dof_vel": dof_vel.reshape(1, -1),
"motion_command": self.motion_command,
"motion_ref_ori_b": ori_b,
}
obs = np.concatenate([obs_dict[k].astype(np.float32) for k in sorted(obs_dict.keys())], axis=1)
obs = np.clip(obs, -100, 100)
# Inference
outs = self.policy.run(["actions", "joint_pos", "joint_vel", "ref_quat_xyzw"],
{"obs": obs, "time_step": np.array([[self.timestep]], dtype=np.float32)})
action = np.clip(outs[0], -100, 100)
self.motion_command = np.concatenate(outs[1:3], axis=1)
self.ref_quat_xyzw = outs[3]
self.last_action = action.copy()
target = DEFAULT_29DOF_ANGLES + action.flatten() * self.action_scale
# Send commands
for i in range(NUM_DOFS):
if i in FROZEN_JOINTS:
self.robot.msg.motor_cmd[i].q = 0.0
self.robot.msg.motor_cmd[i].kp = FROZEN_KP
self.robot.msg.motor_cmd[i].kd = FROZEN_KD
else:
self.robot.msg.motor_cmd[i].q = float(target[i])
self.robot.msg.motor_cmd[i].kp = self.motor_kp[i]
self.robot.msg.motor_cmd[i].kd = self.motor_kd[i]
self.robot.msg.motor_cmd[i].qd = 0
self.robot.msg.motor_cmd[i].tau = 0
self.robot.send_action(self.robot.msg)
self.timestep += 1
# =============================================================================
# MAIN
# =============================================================================
def main():
parser = argparse.ArgumentParser(description="Locomotion ↔ Dance Toggle")
parser.add_argument("--loco-repo", type=str, default=DEFAULT_HOLOSOMA_REPO_ID)
parser.add_argument("--dance-onnx", type=str, default=DANCE_ONNX_PATH)
args = parser.parse_args()
print("=" * 70)
print("🚶 LOCOMOTION ↔ 💃 DANCE")
print("=" * 70)
print("Press ENTER to toggle between modes")
print("=" * 70)
# Load locomotion policy
logger.info("Loading locomotion policy...")
loco_path = hf_hub_download(repo_id=args.loco_repo, filename="fastsac_g1_29dof.onnx")
loco_policy = ort.InferenceSession(loco_path)
loco_obs_dim = loco_policy.get_inputs()[0].shape[1]
logger.info(f"Locomotion: {loco_obs_dim}D obs")
# Load dance policy
logger.info("Loading dance policy...")
dance_policy = ort.InferenceSession(args.dance_onnx)
dance_model = onnx.load(args.dance_onnx)
dance_meta = {p.key: json.loads(p.value) for p in dance_model.metadata_props}
dance_kp = np.array(dance_meta.get("kp", DEFAULT_29DOF_KP), dtype=np.float32)
dance_kd = np.array(dance_meta.get("kd", DEFAULT_29DOF_KD), dtype=np.float32)
dance_action_scale = float(dance_meta.get("action_scale", 1.0))
logger.info(f"Dance: {dance_policy.get_inputs()[0].shape[1]}D obs, scale={dance_action_scale}")
# Build Pinocchio FK
pinocchio_fk = None
if "robot_urdf" in dance_meta:
logger.info("Building Pinocchio FK...")
pinocchio_fk = PinocchioFK(dance_meta["robot_urdf"])
# Initialize robot
logger.info("Initializing robot...")
config = UnitreeG1Config()
robot = UnitreeG1(config)
logger.info("Robot connected!")
# Create controllers
loco_ctrl = LocomotionController(loco_policy, robot, loco_obs_dim)
dance_ctrl = DanceController(dance_policy, robot, pinocchio_fk, dance_kp, dance_kd, dance_action_scale)
# State
mode = "locomotion"
toggle_event = threading.Event()
shutdown = threading.Event()
# Input thread
def input_loop():
while not shutdown.is_set():
if select.select([sys.stdin], [], [], 0.1)[0]:
sys.stdin.readline()
toggle_event.set()
input_thread = threading.Thread(target=input_loop, daemon=True)
input_thread.start()
print("\n🚶 LOCOMOTION MODE - Use joystick to walk")
print(" Press ENTER to switch to DANCE")
print("-" * 70)
step = 0
try:
while not shutdown.is_set():
t0 = time.time()
# Check toggle
if toggle_event.is_set():
toggle_event.clear()
if mode == "locomotion":
mode = "dance"
dance_ctrl.initialize()
print("\n" + "=" * 70)
print("💃 DANCE MODE (frame 0)")
print(" Press ENTER to switch to LOCOMOTION")
print("=" * 70)
else:
mode = "locomotion"
loco_ctrl.reset()
print("\n" + "=" * 70)
print("🚶 LOCOMOTION MODE")
print(" Press ENTER to switch to DANCE")
print("=" * 70)
# Run controller
if mode == "locomotion":
loco_ctrl.run_step()
else:
dance_ctrl.run_step()
# Log
if step % 100 == 0:
if mode == "locomotion":
print(f"[LOCO ] step={step:5d} cmd=[{loco_ctrl.cmd[0]:.2f},{loco_ctrl.cmd[1]:.2f},{loco_ctrl.cmd[2]:.2f}]")
else:
print(f"[DANCE] step={step:5d} timestep={dance_ctrl.timestep}")
step += 1
elapsed = time.time() - t0
if elapsed < CONTROL_DT:
time.sleep(CONTROL_DT - elapsed)
except KeyboardInterrupt:
print("\n\nStopping...")
finally:
shutdown.set()
robot.disconnect()
print("Done!")
if __name__ == "__main__":
main()
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,447 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example: Unitree RL 12-DOF Legs-Only Locomotion (TorchScript)
This example demonstrates loading a 12-DOF legs-only locomotion policy
(TorchScript .pt format) and running it on the Unitree G1 robot.
Key characteristics:
- Single TorchScript policy (.pt)
- 47D observations, 12D actions (legs only)
- Phase-based gait timing
- Arms and waist held at fixed positions
"""
import argparse
import logging
import threading
import time
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from scipy.spatial.transform import Rotation as R
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 12-DOF leg joint configuration
# Joint order: [L_hip_pitch, L_hip_roll, L_hip_yaw, L_knee, L_ankle_pitch, L_ankle_roll,
# R_hip_pitch, R_hip_roll, R_hip_yaw, R_knee, R_ankle_pitch, R_ankle_roll]
LEG_JOINT_INDICES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
# Default leg angles for standing
DEFAULT_LEG_ANGLES = np.array([
-0.1, 0.0, 0.0, 0.3, -0.2, 0.0, # left leg
-0.1, 0.0, 0.0, 0.3, -0.2, 0.0, # right leg
], dtype=np.float32)
# KP/KD for leg joints
LEG_KPS = np.array([150, 150, 150, 300, 40, 40, 150, 150, 150, 300, 40, 40], dtype=np.float32)
LEG_KDS = np.array([6, 6, 6, 4, 2, 2, 6, 6, 6, 4, 2, 2], dtype=np.float32)
# Waist configuration (held at zero)
WAIST_JOINT_INDICES = [12, 13, 14] # yaw, roll, pitch
WAIST_KPS = np.array([250, 250, 250], dtype=np.float32)
WAIST_KDS = np.array([5, 5, 5], dtype=np.float32)
# Arm configuration (indices 15-28, held at initial position)
ARM_JOINT_INDICES = list(range(15, 29))
ARM_KPS = np.array([80, 80, 80, 80, 40, 40, 40, # left arm (shoulder + wrist)
80, 80, 80, 80, 40, 40, 40], dtype=np.float32) # right arm
ARM_KDS = np.array([3, 3, 3, 3, 1.5, 1.5, 1.5,
3, 3, 3, 3, 1.5, 1.5, 1.5], dtype=np.float32)
# Control parameters
LOCOMOTION_CONTROL_DT = 0.02 # 50Hz control rate
LOCOMOTION_ACTION_SCALE = 0.25
ANG_VEL_SCALE = 0.25
DOF_POS_SCALE = 1.0
DOF_VEL_SCALE = 0.05
CMD_SCALE = np.array([2.0, 2.0, 0.25], dtype=np.float32)
MAX_CMD = np.array([0.8, 0.5, 1.57], dtype=np.float32) # max vx, vy, yaw_rate
# Gait parameters
GAIT_PERIOD = 0.8 # seconds
DEFAULT_REPO_ID = "nepyope/unitree_rl_locomotion"
def load_torchscript_policy(
repo_id: str = DEFAULT_REPO_ID,
filename: str = "motion.pt",
) -> torch.jit.ScriptModule:
"""Load TorchScript locomotion policy from Hugging Face Hub.
Args:
repo_id: Hugging Face Hub repository ID containing the policy.
filename: Policy filename (default: motion.pt).
"""
logger.info(f"Loading TorchScript policy from Hugging Face Hub ({repo_id}/{filename})...")
policy_path = hf_hub_download(
repo_id=repo_id,
filename=filename,
)
policy = torch.jit.load(policy_path)
policy.eval()
logger.info("TorchScript policy loaded successfully")
return policy
class UnitreeRLLocomotionController:
"""
Handles 12-DOF legs-only locomotion control for the Unitree G1 robot.
This controller manages:
- Single TorchScript policy
- 47D observations (single frame)
- 12D action output (legs only)
- Arms and waist held at fixed positions
- Phase-based gait timing
"""
def __init__(self, policy, robot, config):
self.policy = policy
self.robot = robot
self.config = config
# Velocity commands (vx, vy, yaw_rate)
self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)
# State variables (12 DOF legs)
self.qj = np.zeros(12, dtype=np.float32)
self.dqj = np.zeros(12, dtype=np.float32)
self.locomotion_action = np.zeros(12, dtype=np.float32)
self.locomotion_obs = np.zeros(47, dtype=np.float32)
# Initial arm positions (captured on reset)
self.initial_arm_positions = np.zeros(14, dtype=np.float32)
# Counter for phase calculation
self.counter = 0
# Thread management
self.locomotion_running = False
self.locomotion_thread = None
logger.info("UnitreeRLLocomotionController initialized")
logger.info(" Observation dim: 47, Action dim: 12 (legs only)")
def locomotion_run(self):
"""12-DOF legs-only locomotion policy loop."""
self.counter += 1
if self.counter == 1:
print("\n" + "=" * 60)
print("🚀 RUNNING UNITREE RL 12-DOF LOCOMOTION POLICY")
print(" 47D observations → 12D actions (legs only)")
print(" Arms and waist held at fixed positions")
print("=" * 60 + "\n")
# Get current observation
robot_state = self.robot.get_observation()
if robot_state is None:
return
# Get command from remote controller
if robot_state.wireless_remote is not None:
self.robot.remote_controller.set(robot_state.wireless_remote)
else:
self.robot.remote_controller.lx = 0.0
self.robot.remote_controller.ly = 0.0
self.robot.remote_controller.rx = 0.0
self.robot.remote_controller.ry = 0.0
self.locomotion_cmd[0] = self.robot.remote_controller.ly # forward/backward
self.locomotion_cmd[1] = self.robot.remote_controller.lx * -1 # left/right (inverted)
self.locomotion_cmd[2] = self.robot.remote_controller.rx * -1 # yaw (inverted)
# Get leg joint positions and velocities (12 DOF)
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
self.qj[i] = robot_state.motor_state[motor_idx].q
self.dqj[i] = robot_state.motor_state[motor_idx].dq
# Get IMU data
quat = robot_state.imu_state.quaternion
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
# Scale observations
gravity_orientation = self.robot.get_gravity_orientation(quat)
qj_obs = (self.qj - DEFAULT_LEG_ANGLES) * DOF_POS_SCALE
dqj_obs = self.dqj * DOF_VEL_SCALE
ang_vel_scaled = ang_vel * ANG_VEL_SCALE
# Calculate phase
count = self.counter * LOCOMOTION_CONTROL_DT
phase = (count % GAIT_PERIOD) / GAIT_PERIOD
sin_phase = np.sin(2 * np.pi * phase)
cos_phase = np.cos(2 * np.pi * phase)
# Build 47D observation vector
# [0:3] - angular velocity (scaled)
# [3:6] - gravity orientation
# [6:9] - velocity command (scaled)
# [9:21] - joint positions (12D, relative to default)
# [21:33] - joint velocities (12D, scaled)
# [33:45] - previous actions (12D)
# [45] - sin_phase
# [46] - cos_phase
self.locomotion_obs[0:3] = ang_vel_scaled
self.locomotion_obs[3:6] = gravity_orientation
self.locomotion_obs[6:9] = self.locomotion_cmd * CMD_SCALE * MAX_CMD
self.locomotion_obs[9:21] = qj_obs
self.locomotion_obs[21:33] = dqj_obs
self.locomotion_obs[33:45] = self.locomotion_action
self.locomotion_obs[45] = sin_phase
self.locomotion_obs[46] = cos_phase
# Run policy inference (TorchScript)
obs_tensor = torch.from_numpy(self.locomotion_obs).unsqueeze(0).float()
with torch.no_grad():
action_tensor = self.policy(obs_tensor)
self.locomotion_action = action_tensor.squeeze().numpy()
# Transform action to target joint positions
target_leg_pos = DEFAULT_LEG_ANGLES + self.locomotion_action * LOCOMOTION_ACTION_SCALE
# Debug logging (first 3 iterations)
if self.counter <= 3:
print(f"\n[Unitree RL Debug #{self.counter}]")
print(f" Phase: {phase:.3f} (sin={sin_phase:.3f}, cos={cos_phase:.3f})")
print(f" Cmd (vx, vy, yaw): ({self.locomotion_cmd[0]:.2f}, {self.locomotion_cmd[1]:.2f}, {self.locomotion_cmd[2]:.2f})")
print(f" Action range: [{self.locomotion_action.min():.3f}, {self.locomotion_action.max():.3f}]")
# Send commands to LEG motors (0-11)
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
self.robot.msg.motor_cmd[motor_idx].q = target_leg_pos[i]
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = LEG_KPS[i]
self.robot.msg.motor_cmd[motor_idx].kd = LEG_KDS[i]
self.robot.msg.motor_cmd[motor_idx].tau = 0
# Hold WAIST motors at zero (12, 13, 14)
for i, motor_idx in enumerate(WAIST_JOINT_INDICES):
self.robot.msg.motor_cmd[motor_idx].q = 0.0
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = WAIST_KPS[i]
self.robot.msg.motor_cmd[motor_idx].kd = WAIST_KDS[i]
self.robot.msg.motor_cmd[motor_idx].tau = 0
# Hold ARM motors at initial position (15-28)
for i, motor_idx in enumerate(ARM_JOINT_INDICES):
self.robot.msg.motor_cmd[motor_idx].q = self.initial_arm_positions[i]
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = ARM_KPS[i]
self.robot.msg.motor_cmd[motor_idx].kd = ARM_KDS[i]
self.robot.msg.motor_cmd[motor_idx].tau = 0
# Send command
self.robot.send_action(self.robot.msg)
def _locomotion_thread_loop(self):
"""Background thread that runs the locomotion policy at specified rate."""
logger.info("Locomotion thread started")
while self.locomotion_running:
start_time = time.time()
try:
self.locomotion_run()
except Exception as e:
logger.error(f"Error in locomotion loop: {e}")
import traceback
traceback.print_exc()
# Sleep to maintain control rate
elapsed = time.time() - start_time
sleep_time = max(0, LOCOMOTION_CONTROL_DT - elapsed)
time.sleep(sleep_time)
logger.info("Locomotion thread stopped")
def start_locomotion_thread(self):
if self.locomotion_running:
logger.warning("Locomotion thread already running")
return
logger.info("Starting locomotion control thread...")
self.locomotion_running = True
self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True)
self.locomotion_thread.start()
logger.info("Locomotion control thread started!")
def stop_locomotion_thread(self):
if not self.locomotion_running:
return
logger.info("Stopping locomotion control thread...")
self.locomotion_running = False
if self.locomotion_thread:
self.locomotion_thread.join(timeout=2.0)
logger.info("Locomotion control thread stopped")
def reset_robot(self):
"""Move legs to default standing position over 2 seconds (arms are captured and held)."""
logger.info("Moving legs to default position...")
total_time = 2.0
num_step = int(total_time / self.robot.control_dt)
# Get current state
robot_state = self.robot.get_observation()
# Capture initial arm positions (to hold during locomotion)
for i, motor_idx in enumerate(ARM_JOINT_INDICES):
self.initial_arm_positions[i] = robot_state.motor_state[motor_idx].q
logger.info(f"Captured initial arm positions: {self.initial_arm_positions[:4]}...")
# Record current leg positions
init_leg_pos = np.zeros(12, dtype=np.float32)
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
init_leg_pos[i] = robot_state.motor_state[motor_idx].q
# Interpolate legs to default position
for step in range(num_step):
alpha = step / num_step
# Interpolate leg positions
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
target_pos = DEFAULT_LEG_ANGLES[i]
self.robot.msg.motor_cmd[motor_idx].q = (
init_leg_pos[i] * (1 - alpha) + target_pos * alpha
)
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = LEG_KPS[i]
self.robot.msg.motor_cmd[motor_idx].kd = LEG_KDS[i]
self.robot.msg.motor_cmd[motor_idx].tau = 0
# Hold waist at zero
for i, motor_idx in enumerate(WAIST_JOINT_INDICES):
self.robot.msg.motor_cmd[motor_idx].q = 0.0
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = WAIST_KPS[i]
self.robot.msg.motor_cmd[motor_idx].kd = WAIST_KDS[i]
self.robot.msg.motor_cmd[motor_idx].tau = 0
# Hold arms at initial position
for i, motor_idx in enumerate(ARM_JOINT_INDICES):
self.robot.msg.motor_cmd[motor_idx].q = self.initial_arm_positions[i]
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = ARM_KPS[i]
self.robot.msg.motor_cmd[motor_idx].kd = ARM_KDS[i]
self.robot.msg.motor_cmd[motor_idx].tau = 0
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
self.robot.lowcmd_publisher.Write(self.robot.msg)
time.sleep(self.robot.control_dt)
logger.info("Reached default leg position")
# Hold position for 2 seconds
logger.info("Holding default position for 2 seconds...")
hold_time = 2.0
num_hold_steps = int(hold_time / self.robot.control_dt)
for _ in range(num_hold_steps):
# Hold legs at default
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
self.robot.msg.motor_cmd[motor_idx].q = DEFAULT_LEG_ANGLES[i]
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = LEG_KPS[i]
self.robot.msg.motor_cmd[motor_idx].kd = LEG_KDS[i]
self.robot.msg.motor_cmd[motor_idx].tau = 0
# Hold waist at zero
for i, motor_idx in enumerate(WAIST_JOINT_INDICES):
self.robot.msg.motor_cmd[motor_idx].q = 0.0
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = WAIST_KPS[i]
self.robot.msg.motor_cmd[motor_idx].kd = WAIST_KDS[i]
self.robot.msg.motor_cmd[motor_idx].tau = 0
# Hold arms at initial position
for i, motor_idx in enumerate(ARM_JOINT_INDICES):
self.robot.msg.motor_cmd[motor_idx].q = self.initial_arm_positions[i]
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = ARM_KPS[i]
self.robot.msg.motor_cmd[motor_idx].kd = ARM_KDS[i]
self.robot.msg.motor_cmd[motor_idx].tau = 0
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
self.robot.lowcmd_publisher.Write(self.robot.msg)
time.sleep(self.robot.control_dt)
logger.info("Ready to start locomotion!")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Unitree RL 12-DOF Locomotion Controller for Unitree G1")
parser.add_argument(
"--repo-id",
type=str,
default=DEFAULT_REPO_ID,
help=f"Hugging Face Hub repo ID for policy (default: {DEFAULT_REPO_ID})",
)
parser.add_argument(
"--filename",
type=str,
default="motion.pt",
help="Policy filename (default: motion.pt)",
)
args = parser.parse_args()
# Load policy
policy = load_torchscript_policy(repo_id=args.repo_id, filename=args.filename)
# Initialize robot
config = UnitreeG1Config()
robot = UnitreeG1(config)
# Initialize locomotion controller
locomotion_controller = UnitreeRLLocomotionController(
policy=policy,
robot=robot,
config=config,
)
# Reset robot and start locomotion thread
try:
locomotion_controller.reset_robot()
locomotion_controller.start_locomotion_thread()
# Log status
logger.info("Robot initialized with Unitree RL locomotion policy")
logger.info("Locomotion controller running in background thread")
logger.info("Use remote controller to command velocity:")
logger.info(" Left stick Y: forward/backward")
logger.info(" Left stick X: left/right")
logger.info(" Right stick X: rotate")
logger.info("Press Ctrl+C to stop")
# Keep robot alive
while True:
time.sleep(1.0)
except KeyboardInterrupt:
print("\nStopping locomotion...")
locomotion_controller.stop_locomotion_thread()
print("Done!")
-47
View File
@@ -1,47 +0,0 @@
# Voice Assistant Examples
Voice-enabled robot assistant examples using speech-to-text (STT), and text-to-speech (TTS).
## Overview
These examples demonstrate how to build a voice interface for robot control:
1. **Hold SPACE** → Push-to-talk recording starts
2. **Release SPACE** → Recording stops
3. **STT (Whisper)** → Converts speech to text (high-level task prompt)
4. **Pi0.5** → Generates robot response/utterance
5. **TTS (Kokoro)** → Speaks the response back
## Requirements
```bash
pip install torch transformers sounddevice numpy pynput kokoro>=0.9.2
```
## Usage
### With Pi0.5 Model
```bash
python examples/voice_assistant/voice_assistant_pi05.py \
--pretrained_path path/to/pi05/checkpoint
```
## How It Works
### Pi0.5 Voice Integration
Pi0.5 can generate robot utterances as part of its subtask prediction. The flow:
1. **High-level prompt**: User voice command is transcribed and formatted as a task prompt
2. **Subtask generation**: Pi0.5 autoregressively generates a response
3. **Utterance extraction**: If the response contains `<utterance>...</utterance>` tags, the content is extracted
4. **TTS output**: The response is spoken back to the user
## Configuration Options
| Option | Default | Description |
|--------|---------|-------------|
| `--pretrained_path` | None | Path to Pi0.5 checkpoint |
| `--record_seconds` | 5.0 | Audio recording duration |
| `--max_response_tokens` | 100 | Max tokens in generated response |
@@ -1,336 +0,0 @@
#!/usr/bin/env python
"""
Voice Assistant with Pi0.5: Microphone STT Pi0.5 TTS Speaker
This example demonstrates how to use Pi0.5 as a conversational robot assistant:
1. Hold SPACE to record your voice command
2. Speech-to-text (Whisper) converts speech to text
3. Text is fed as a high-level prompt to Pi0.5
4. Pi0.5 generates a response (robot utterance)
5. Text-to-speech (Kokoro) speaks the response back
Requirements:
pip install torch transformers sounddevice numpy pynput kokoro>=0.9.2
Usage:
python examples/voice_assistant/voice_assistant_pi05.py \
--pretrained_path lerobot/pi0.5-base
"""
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import argparse
import re
import subprocess
import threading
import time
import numpy as np
import sounddevice as sd
import torch
from pynput import keyboard
from transformers import AutoTokenizer, WhisperForConditionalGeneration, WhisperProcessor
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pi05.modeling_pi05 import PI05Pytorch
SAMPLE_RATE = 16000
def get_device():
if torch.cuda.is_available():
return torch.device("cuda")
elif torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")
class Pi05VoiceAssistant:
"""Voice assistant using Pi0.5 for generating robot utterances."""
def __init__(
self,
pretrained_path: str | None = None,
max_response_tokens: int = 100,
max_record_seconds: float = 30.0,
):
self.device = get_device()
self.dtype = torch.float32 if self.device.type == "mps" else torch.bfloat16
self.max_response_tokens = max_response_tokens
self.max_record_seconds = max_record_seconds
# Push-to-talk state
self._recording = False
self._audio_chunks: list[np.ndarray] = []
self._stream: sd.InputStream | None = None
print(f"Using device: {self.device}")
self._load_models(pretrained_path)
def _load_models(self, pretrained_path: str | None):
print("Loading STT (Whisper tiny)...")
self.stt_processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
self.stt_model = WhisperForConditionalGeneration.from_pretrained(
"openai/whisper-tiny.en", torch_dtype=self.dtype
).to(self.device)
print("Loading Pi0.5 model...")
self._load_pi05(pretrained_path)
print("Loading tokenizer...")
self.tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
self._load_tts()
print("Ready!\n")
def _load_pi05(self, pretrained_path: str | None):
"""Load Pi0.5 model for utterance generation."""
config = PI05Config()
config.dtype = "float32" if self.device.type == "mps" else "bfloat16"
self.pi05_model = PI05Pytorch(config)
if pretrained_path:
try:
from safetensors.torch import load_file
state_dict = load_file(f"{pretrained_path}/model.safetensors")
self.pi05_model.load_state_dict(state_dict, strict=False)
print(f"✓ Loaded Pi0.5 weights from {pretrained_path}")
except Exception as e:
print(f"Warning: Could not load pretrained weights: {e}")
print("Using randomly initialized model for demo purposes")
self.pi05_model = self.pi05_model.to(self.device)
self.pi05_model.eval()
def _load_tts(self):
try:
print("Loading TTS (Kokoro 82M)...")
from kokoro import KPipeline
self.tts_pipeline = KPipeline(lang_code="a") # American English
self.tts_voice = "af_heart"
self.tts_type = "kokoro"
print("Kokoro loaded!")
except Exception as e:
print(f"Kokoro not available ({e})")
print("Using macOS `say` for TTS")
self.tts_pipeline = None
self.tts_type = "system"
def _audio_callback(self, indata, frames, time_info, status):
"""Callback for audio stream - collects chunks while recording."""
if self._recording:
self._audio_chunks.append(indata.copy())
def _start_recording(self):
"""Start recording audio."""
if self._recording:
return
self._recording = True
self._audio_chunks = []
print("🎤 Recording... (release SPACE to stop)")
def _stop_recording(self) -> np.ndarray | None:
"""Stop recording and return the audio."""
if not self._recording:
return None
self._recording = False
if not self._audio_chunks:
return None
audio = np.concatenate(self._audio_chunks, axis=0).flatten()
duration = len(audio) / SAMPLE_RATE
volume = np.abs(audio).max()
print(f"Recorded {duration:.1f}s, volume: {volume:.4f}")
if volume < 0.001:
print("⚠️ Very low audio - check microphone permissions!")
return None
return audio
def wait_for_spacebar(self) -> np.ndarray | None:
"""Wait for spacebar press, record while held, return audio on release."""
audio_result = None
recording_done = threading.Event()
def on_press(key):
if key == keyboard.Key.space:
self._start_recording()
def on_release(key):
nonlocal audio_result
if key == keyboard.Key.space and self._recording:
audio_result = self._stop_recording()
recording_done.set()
return False # Stop listener
# Start audio stream
self._stream = sd.InputStream(
samplerate=SAMPLE_RATE,
channels=1,
dtype="float32",
callback=self._audio_callback,
blocksize=int(SAMPLE_RATE * 0.1), # 100ms blocks
)
with self._stream:
print("\n⏳ Press and hold SPACE to speak...")
with keyboard.Listener(on_press=on_press, on_release=on_release) as listener:
# Wait for recording to complete or timeout
recording_done.wait(timeout=self.max_record_seconds)
if self._recording:
audio_result = self._stop_recording()
return audio_result
def transcribe(self, audio: np.ndarray) -> str:
start = time.perf_counter()
inputs = self.stt_processor(audio, sampling_rate=SAMPLE_RATE, return_tensors="pt")
input_features = inputs.input_features.to(self.device, dtype=self.dtype)
tokens = self.stt_model.generate(input_features)
text = self.stt_processor.batch_decode(tokens, skip_special_tokens=True)[0]
print(f"STT: {time.perf_counter() - start:.2f}s")
return text.strip()
def _create_dummy_images(self, batch_size: int = 1) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
"""Create placeholder images for Pi0.5 when no camera is available."""
image_shape = (batch_size, 3, 224, 224)
dummy_image = torch.zeros(image_shape, dtype=torch.float32, device=self.device)
dummy_mask = torch.ones(batch_size, dtype=torch.bool, device=self.device)
return [dummy_image], [dummy_mask]
def _tokenize_prompt(self, text: str) -> tuple[torch.Tensor, torch.Tensor]:
"""Tokenize the user prompt for Pi0.5."""
prompt = f"User request: {text}\nRobot response:"
tokenized = self.tokenizer(
[prompt],
max_length=200,
truncation=True,
padding="max_length",
return_tensors="pt",
)
tokens = tokenized["input_ids"].to(self.device)
masks = tokenized["attention_mask"].to(self.device, dtype=torch.bool)
return tokens, masks
def generate_response(self, user_text: str) -> str:
"""Generate robot utterance using Pi0.5's language generation."""
start = time.perf_counter()
images, img_masks = self._create_dummy_images()
tokens, masks = self._tokenize_prompt(user_text)
with torch.no_grad():
generated_tokens = self.pi05_model._generate_subtask_tokens(
images=images,
img_masks=img_masks,
tokens=tokens,
masks=masks,
tokenizer=self.tokenizer,
max_length=self.max_response_tokens,
device=self.device,
)
# Decode generated tokens
valid_tokens = generated_tokens[0][generated_tokens[0] != 0]
response = self.tokenizer.decode(valid_tokens, skip_special_tokens=True)
# Extract utterance if marked with special tokens
response = self._extract_utterance(response)
print(f"Pi0.5: {time.perf_counter() - start:.2f}s")
return response.strip()
def _extract_utterance(self, text: str) -> str:
"""Extract utterance from between <utterance> tokens if present."""
pattern = r"<utterance>(.*?)</utterance>"
match = re.search(pattern, text, re.DOTALL)
if match:
return match.group(1).strip()
return text
def speak(self, text: str):
start = time.perf_counter()
if self.tts_type == "kokoro":
generator = self.tts_pipeline(text, voice=self.tts_voice)
audio_chunks = [audio for _, _, audio in generator]
if audio_chunks:
audio = np.concatenate(audio_chunks)
sd.play(audio, 24000)
sd.wait()
else:
subprocess.run(["say", text], check=True)
print(f"TTS: {time.perf_counter() - start:.2f}s")
def run(self):
print("=" * 50)
print("Pi0.5 Voice Assistant")
print("=" * 50)
print("• Hold SPACE to record your voice command")
print("• Release SPACE when done speaking")
print("• Press Ctrl+C to exit")
print("=" * 50)
while True:
try:
audio = self.wait_for_spacebar()
if audio is None:
print("(no audio captured)\n")
continue
user_text = self.transcribe(audio)
if not user_text:
print("(no speech detected)\n")
continue
print(f"You: {user_text}")
response = self.generate_response(user_text)
print(f"Robot: {response}\n")
self.speak(response)
except KeyboardInterrupt:
print("\nGoodbye!")
break
def main():
parser = argparse.ArgumentParser(description="Pi0.5 Voice Assistant")
parser.add_argument(
"--pretrained_path",
type=str,
default=None,
help="Path to pretrained Pi0.5 model (optional)",
)
parser.add_argument(
"--max_response_tokens",
type=int,
default=100,
help="Maximum tokens in generated response",
)
parser.add_argument(
"--max_record_seconds",
type=float,
default=30.0,
help="Maximum recording duration in seconds",
)
args = parser.parse_args()
assistant = Pi05VoiceAssistant(
pretrained_path=args.pretrained_path,
max_response_tokens=args.max_response_tokens,
max_record_seconds=args.max_record_seconds,
)
assistant.run()
if __name__ == "__main__":
main()
-27
View File
@@ -1,27 +0,0 @@
{
"repo_id": "local",
"vocab_size": 1024,
"scale": 10.0,
"encoded_dims": "0:7",
"encoded_dim_ranges": [
[
0,
7
]
],
"total_encoded_dims": 7,
"delta_dims": null,
"delta_dim_list": null,
"use_delta_transform": false,
"state_key": "observation.state",
"normalization_mode": "QUANTILES",
"action_horizon": 10,
"num_training_chunks": 25065,
"compression_stats": {
"compression_ratio": 3.464660463274599,
"mean_token_length": 20.204,
"p99_token_length": 36.00999999999999,
"min_token_length": 5.0,
"max_token_length": 38.0
}
}
@@ -1,158 +0,0 @@
import logging
from typing import ClassVar
import numpy as np
from scipy.fft import dct
from scipy.fft import idct
from tokenizers import ByteLevelBPETokenizer
from tokenizers.trainers import BpeTrainer
from transformers import PreTrainedTokenizerFast
from transformers.processing_utils import ProcessorMixin
class UniversalActionProcessor(ProcessorMixin):
attributes: ClassVar[list[str]] = ["bpe_tokenizer"]
bpe_tokenizer_class: str = "AutoTokenizer"
def __init__(
self,
bpe_tokenizer: PreTrainedTokenizerFast,
scale: float = 10,
vocab_size: int = 1024,
min_token: int = 0,
*,
action_dim: int | None = None,
time_horizon: int | None = None,
):
self.scale = scale
self.vocab_size = vocab_size
self.min_token = min_token
# Action horizon and dimension needed during decoding. These can be specified
# in three ways (in order of priority):
# 1. passed in as kwargs to decode()
# 2. in the constructor
# 3. cached from the last time decode() was called
self.time_horizon = time_horizon
self.action_dim = action_dim
self.called_time_horizon = time_horizon
self.called_action_dim = action_dim
super().__init__(bpe_tokenizer)
def __call__(self, action_chunk: np.array) -> np.array:
assert action_chunk.ndim <= 3, "Only 3 dimensions supported: [batch, timesteps, action_dim]"
if action_chunk.ndim == 2:
action_chunk = action_chunk[None, ...]
# Cache the time horizon and action dimension for decoding
self.called_time_horizon = action_chunk.shape[-2]
self.called_action_dim = action_chunk.shape[-1]
dct_coeff = dct(action_chunk, axis=1, norm="ortho")
dct_coeff = np.around(dct_coeff * self.scale)
tokens = []
for elem in dct_coeff:
token_str = "".join(map(chr, np.maximum(elem.flatten() - self.min_token, 0).astype(int)))
tokens.append(self.bpe_tokenizer(token_str)["input_ids"])
return tokens
def decode(
self,
tokens: list[list[int]],
*,
time_horizon: int | None = None,
action_dim: int | None = None,
) -> np.array:
self.time_horizon = time_horizon or self.time_horizon or self.called_time_horizon
self.action_dim = action_dim or self.action_dim or self.called_action_dim
# Cache the time horizon and action dimension for the next call
self.called_time_horizon = self.time_horizon
self.called_action_dim = self.action_dim
assert (
self.time_horizon is not None and self.action_dim is not None
), "Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
decoded_actions = []
for token in tokens:
try:
decoded_tokens = self.bpe_tokenizer.decode(token)
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.min_token
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
assert (
decoded_dct_coeff.shape
== (
self.time_horizon,
self.action_dim,
)
), f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
except Exception as e:
print(f"Error decoding tokens: {e}")
print(f"Tokens: {token}")
decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
decoded_actions.append(idct(decoded_dct_coeff / self.scale, axis=0, norm="ortho"))
return np.stack(decoded_actions)
@classmethod
def fit(
cls,
action_data: list[np.array],
scale: float = 10,
vocab_size: int = 1024,
*,
time_horizon: int | None = None,
action_dim: int | None = None,
) -> "UniversalActionProcessor":
# Run DCT over all inputs
dct_tokens = [dct(a, axis=0, norm="ortho").flatten() for a in action_data]
# Quantize and find min token
max_token = int(np.around(np.concatenate(dct_tokens) * scale).max())
min_token = int(np.around(np.concatenate(dct_tokens) * scale).min())
min_vocab_size = max_token - min_token
assert (
min_vocab_size <= vocab_size
), f"Vocab size {vocab_size} is too small for the range of tokens {min_vocab_size}"
if min_vocab_size + 100 > vocab_size:
logging.warning(
f"Initial alphabet size {min_vocab_size} is almost as large as the vocab"
f"size {vocab_size}, consider increasing vocab size"
)
# Make token iterator for BPE training
def _token_iter():
for tokens in dct_tokens:
rounded_tokens = np.around(tokens * scale) - min_token
rounded_tokens = rounded_tokens.astype(int)
string = "".join(map(chr, rounded_tokens))
yield string
# Train BPE tokenizer
bpe = ByteLevelBPETokenizer()
# Set up the entire range of possible tokens as the initial alphabet
alphabet = [chr(i) for i in range(max_token - min_token + 1)]
trainer = BpeTrainer(
vocab_size=vocab_size,
min_frequency=2,
show_progress=True,
special_tokens=[],
initial_alphabet=alphabet,
max_token_length=10000,
)
# Train the inner tokenizer (don't use ByteLevelBPETokenizer.train_from_iterator()
# because it doesn't support custom alphabets)
bpe._tokenizer.train_from_iterator(_token_iter(), trainer=trainer)
return cls(
PreTrainedTokenizerFast(tokenizer_object=bpe, clean_up_tokenization_spaces=False),
scale=scale,
vocab_size=vocab_size,
min_token=min_token,
time_horizon=time_horizon,
action_dim=action_dim,
)
@@ -1,11 +0,0 @@
{
"action_dim": 7,
"auto_map": {
"AutoProcessor": "processing_action_tokenizer.UniversalActionProcessor"
},
"min_token": -32,
"processor_class": "UniversalActionProcessor",
"scale": 10.0,
"time_horizon": 10,
"vocab_size": 1024
}
@@ -1 +0,0 @@
{}
File diff suppressed because it is too large Load Diff
@@ -1,11 +0,0 @@
{
"added_tokens_decoder": {},
"auto_map": {
"AutoProcessor": "processing_action_tokenizer.UniversalActionProcessor"
},
"clean_up_tokenization_spaces": false,
"extra_special_tokens": {},
"model_max_length": 1000000000000000019884624838656,
"processor_class": "UniversalActionProcessor",
"tokenizer_class": "PreTrainedTokenizerFast"
}
+1
View File
@@ -263,6 +263,7 @@ default.extend-ignore-identifiers-re = [
"ein",
"thw",
"inpt",
"ROBOTIS",
]
# TODO: Uncomment when ready to use
+1 -1
View File
@@ -26,4 +26,4 @@ DEFAULT_OBS_QUEUE_TIMEOUT = 2
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"]
# TODO: Add all other robots
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so100_follower"]
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so100_follower", "omx_follower"]
@@ -54,6 +54,7 @@ from lerobot.robots import ( # noqa: F401
bi_so100_follower,
koch_follower,
make_robot_from_config,
omx_follower,
so100_follower,
so101_follower,
)
-8
View File
@@ -58,7 +58,6 @@ from lerobot.datasets.utils import (
load_nested_dataset,
load_stats,
load_tasks,
load_tasks_high_level,
update_chunk_file_indices,
validate_episode_buffer,
validate_frame,
@@ -162,7 +161,6 @@ class LeRobotDatasetMetadata:
self.info = load_info(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
self.tasks = load_tasks(self.root)
# self.tasks_high_level = load_tasks_high_level(self.root)
self.episodes = load_episodes(self.root)
self.stats = load_stats(self.root)
@@ -1052,12 +1050,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Add task as a string
task_idx = item["task_index"].item()
item["task"] = self.meta.tasks.iloc[task_idx].name
# Optionally add high level task index
if "task_index_high_level" in self.features:
high_level_task_idx = item["task_index_high_level"].item()
item["robot_utterance"] = self.meta.tasks_high_level.iloc[high_level_task_idx]["robot_utterance"]
item["user_prompt"] = self.meta.tasks_high_level.iloc[high_level_task_idx]["user_prompt"]
return item
def __repr__(self):
-4
View File
@@ -60,7 +60,6 @@ VIDEO_DIR = "videos"
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
DEFAULT_TASKS_HIGH_LEVEL_PATH = "meta/tasks_high_level.parquet"
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
@@ -353,9 +352,6 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
return tasks
def load_tasks_high_level(local_dir: Path) -> pandas.DataFrame:
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_HIGH_LEVEL_PATH)
return tasks
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
@@ -23,6 +23,8 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.constants import OBS_IMAGES
DEFAULT_IMAGE_SIZE = 224
@PreTrainedConfig.register_subclass("pi0")
@dataclass
@@ -51,7 +53,10 @@ class PI0Config(PreTrainedConfig):
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
image_resolution: tuple[int, int] = (
DEFAULT_IMAGE_SIZE,
DEFAULT_IMAGE_SIZE,
) # see openpi `preprocessing_pytorch.py`
# Add empty images. Used to add empty cameras when no image features are present.
empty_cameras: int = 0
+14 -13
View File
@@ -41,7 +41,7 @@ else:
PaliGemmaForConditionalGeneration = None
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.utils.constants import (
@@ -337,6 +337,7 @@ class PaliGemmaWithExpertModel(
action_expert_config,
use_adarms=None,
precision: Literal["bfloat16", "float32"] = "bfloat16",
image_size: int = DEFAULT_IMAGE_SIZE,
):
if use_adarms is None:
use_adarms = [False, False]
@@ -356,6 +357,7 @@ class PaliGemmaWithExpertModel(
vlm_config_hf.text_config.vocab_size = 257152
vlm_config_hf.text_config.use_adarms = use_adarms[0]
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
vlm_config_hf.vision_config.image_size = image_size
vlm_config_hf.vision_config.intermediate_size = 4304
vlm_config_hf.vision_config.projection_dim = 2048
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
@@ -519,11 +521,17 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
paligemma_config = get_gemma_config(config.paligemma_variant)
action_expert_config = get_gemma_config(config.action_expert_variant)
if config.image_resolution[0] != config.image_resolution[1]:
raise ValueError(
f"PaliGemma expects square image resolution, invalid resolution: {config.image_resolution}"
)
self.paligemma_with_expert = PaliGemmaWithExpertModel(
paligemma_config,
action_expert_config,
use_adarms=[False, False],
precision=config.dtype,
image_size=config.image_resolution[0],
)
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
@@ -812,16 +820,13 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
)
dt = -1.0 / num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)
x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
for step in range(num_steps):
time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
state=state,
prefix_pad_masks=prefix_pad_masks,
@@ -846,15 +851,11 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
else:
v_t = denoise_step_partial_call(x_t)
# Euler step
x_t += dt * v_t
x_t = x_t + dt * v_t
# Record x_t and v_t after Euler step
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
time += dt
return x_t
def denoise_step(
@@ -1,196 +0,0 @@
# FAST Tokenizer Training for LeRobotDataset
This directory contains tools for training a FAST (Factorized Action Sequence Tokenizer) on LeRobot datasets.
## Files
- **`train_fast_tokenizer.py`**: Main training script (refactored for LeRobotDataset)
- **`train_fast_tokenizer_example.md`**: Usage examples and parameter documentation
- **`MIGRATION_NOTES.md`**: Migration guide from B1K to LeRobotDataset
## Quick Start
```bash
# Basic usage
python train_fast_tokenizer.py \
--repo_id "lerobot/aloha_sim_insertion_human" \
--action_horizon 10 \
--encoded_dims "0:14"
# With delta transform
python train_fast_tokenizer.py \
--repo_id "lerobot/aloha_sim_insertion_human" \
--action_horizon 10 \
--encoded_dims "0:14" \
--delta_dims "0,1,2,3,4,5,6,7,8,9,10,11,12,13" \
--state_key "observation.state" \
--vocab_size 1024
```
## What is FAST?
FAST is a tokenizer for robotic action sequences that:
1. Applies DCT (Discrete Cosine Transform) to action chunks
2. Quantizes DCT coefficients
3. Uses BPE (Byte-Pair Encoding) to compress the quantized sequence
4. Achieves high compression ratios (e.g., 10-20x) while maintaining accuracy
This enables efficient storage and processing of long action sequences in vision-language-action models.
## Requirements
- Python 3.10+
- LeRobot dataset (either local or from HuggingFace Hub)
- transformers (for AutoProcessor)
- numpy
- torch
- tyro
## Workflow
```
LeRobotDataset → Extract Episodes → Apply Delta Transform
Select Dimensions → Normalize (q01, q99) → Create Chunks
Train FAST Tokenizer → Compute Stats → Save
```
## Parameters Guide
### Essential Parameters
- **`repo_id`**: HuggingFace dataset repository ID
- Example: `"lerobot/aloha_sim_insertion_human"`
- **`action_horizon`**: Length of action sequences to tokenize
- Typical: 10-16 steps
- **`encoded_dims`**: Which action dimensions to encode
- Format: `"start:end,start:end"`
- Example: `"0:7"` = dimensions 0-6
- Example: `"0:3,7:10"` = dimensions 0-2 and 7-9
### Optional Parameters
- **`delta_dims`**: Apply delta transform (action - state) to these dimensions
- Format: `"0,1,2,3,4,5"`
- Use for position-based actions
- **`state_key`**: Dataset key containing state observations
- Default: `"observation.state"`
- **`vocab_size`**: BPE vocabulary size
- Default: 1024
- Larger = better compression but more memory
- **`scale`**: DCT quantization scale
- Default: 10.0
- Smaller = finer quantization, larger = coarser
- **`sample_fraction`**: Fraction of action chunks to use per episode
- Default: 0.1 (10%)
- Increase for small datasets, decrease for large datasets
## Output
The script creates a directory (default: `./fast_tokenizer_{repo_id}`) containing:
1. **Tokenizer files**: Can be loaded with `AutoProcessor.from_pretrained()`
2. **`metadata.json`**: Contains:
- Training configuration
- Compression statistics
- Dataset information
## Example Output
```
Loading dataset: lerobot/aloha_sim_insertion_human
Dataset loaded: 50 episodes, 5000 frames
Encoding 14 dimensions: 0:14
Delta dimensions: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
Action horizon: 10
Processing 50 episodes...
Collected 4500 action chunks
Extracted 14 encoded dimensions
Before normalization - overall stats:
Min: -2.3451, Max: 3.1234, Mean: 0.0234, Std: 0.8765
Applied quantile normalization [q01, q99] → [-1, 1]
After normalization - overall stats:
Min: -1.0000, Max: 1.0000, Mean: 0.0156, Std: 0.4321
Training FAST tokenizer on 4500 action chunks...
Action chunk shape: (4500, 10, 14)
Vocab size: 1024
DCT scale: 10.0
✓ Tokenizer training complete!
Compression Statistics:
Average compression ratio: 14.23x
Mean token length: 9.8
P99 token length: 15
Min token length: 6
Max token length: 18
✅ Saved FAST tokenizer to ./fast_tokenizer_lerobot_aloha_sim_insertion_human
```
## Using the Trained Tokenizer
```python
from transformers import AutoProcessor
# Load tokenizer
tokenizer = AutoProcessor.from_pretrained(
"./fast_tokenizer_lerobot_aloha_sim_insertion_human",
trust_remote_code=True
)
# Encode action chunk [horizon, action_dim]
action_chunk = np.random.randn(10, 14) # Example
tokens = tokenizer(action_chunk[None])[0] # Returns token IDs
# Decode tokens back to actions
reconstructed = tokenizer.decode(tokens)
```
## Tips
1. **Start Small**: Use `--max_episodes 10` for initial testing
2. **Check Dimensions**: Verify encoded dimensions match your robot's action space
3. **Delta Transform**: Use for position-based actions, not velocity-based
4. **Normalization**: Ensure dataset has proper statistics computed
5. **Compression Ratio**: Aim for 10-20x for good balance of compression and accuracy
## Troubleshooting
**Issue**: "No normalization stats found"
- **Solution**: Compute dataset statistics first, or use raw actions
**Issue**: "Episode too short for action horizon"
- **Solution**: Reduce `--action_horizon` or filter short episodes
**Issue**: "State key not found"
- **Solution**: Check dataset features and use correct `--state_key`
**Issue**: Memory error with large datasets
- **Solution**: Reduce `--sample_fraction` or `--max_episodes`
## Citation
If you use FAST in your research, please cite:
```bibtex
@article{black2023fast,
title={FAST: Factorized Action Sequence Tokenizer for Vision-Language-Action Models},
author={Black, Kevin and others},
journal={arXiv preprint},
year={2023}
}
```
@@ -22,6 +22,8 @@ from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
DEFAULT_IMAGE_SIZE = 224
@PreTrainedConfig.register_subclass("pi05")
@dataclass
@@ -37,11 +39,6 @@ class PI05Config(PreTrainedConfig):
# Shorter state and action vectors will be padded to these dimensions
max_state_dim: int = 32
max_action_dim: int = 32
max_action_tokens: int = 32
fast_vocab_size: int = 2048
# FAST-only mode: train with only discrete action token prediction (no flow matching, no subtask)
fast_only: bool = False
# Flow matching parameters: see openpi `PI0Pytorch`
num_inference_steps: int = 10
@@ -55,7 +52,10 @@ class PI05Config(PreTrainedConfig):
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
image_resolution: tuple[int, int] = (
DEFAULT_IMAGE_SIZE,
DEFAULT_IMAGE_SIZE,
) # see openpi `preprocessing_pytorch.py`
# Add empty images. Used to add empty cameras when no image features are present.
empty_cameras: int = 0
@@ -65,8 +65,8 @@ class PI05Config(PreTrainedConfig):
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD, # Pi0.5 uses quantiles for state
"ACTION": NormalizationMode.MEAN_STD, # Pi0.5 uses quantiles for action
"STATE": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for state
"ACTION": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for action
}
)
-21
View File
@@ -1,21 +0,0 @@
lerobot-train \
--dataset.repo_id=lerobot \
--dataset.root=/fsx/jade_choghari/outputs/collect-data-pgen \
--output_dir=/fsx/jade_choghari/outputs/pi0test1 \
--job_name=pi0_training \
--policy.repo_id=jade_choghari/pi0-base \
--policy.path=/fsx/jade_choghari/outputs/pi0_fast_fruit1/checkpoints/last/pretrained_model \
--policy.dtype=bfloat16 \
--steps=3000 \
--save_freq=1000 \
--rename_map='{
"observation.images.base": "observation.images.base_0_rgb",
"observation.images.left_wrist": "observation.images.left_wrist_0_rgb",
"observation.images.right_wrist": "observation.images.right_wrist_0_rgb",
}' \
--batch_size=4 \
--policy.device=cuda \
# --wandb.enable=true \
# --wandb.disable_artifact=true \
# --wandb.project=pi05hi-training \
File diff suppressed because it is too large Load Diff
+11 -44
View File
@@ -33,7 +33,6 @@ from lerobot.processor import (
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
ActionTokenizerProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
@@ -48,15 +47,13 @@ from lerobot.utils.constants import (
@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step")
@dataclass
class Pi05PrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
"""
Processor step to prepare the state and tokenize the language input.
"""
max_state_dim: int = 32
task_key: str = "task"
high_level_task_key: str = "user_prompt"
subtask_only_key: str = "subtask"
def __call__(self, transition: EnvTransition) -> EnvTransition:
transition = transition.copy()
@@ -67,8 +64,6 @@ class Pi05PrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key)
if tasks is None:
raise ValueError("No task found in complementary data")
high_level_tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.high_level_task_key)
# TODO: check if this necessary
state = deepcopy(state)
@@ -81,42 +76,16 @@ class Pi05PrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
state_np = state.cpu().numpy()
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
# Clean high level tasks first (if available)
cleaned_high_level_tasks = []
if high_level_tasks is not None:
for high_level_task in high_level_tasks:
cleaned_high_level_tasks.append(high_level_task.strip().replace("_", " ").replace("\n", " "))
# Process low level tasks with state information
low_level_prompts = []
subtask_only_prompts = [] # Store only the subtask text for prediction
full_prompts = []
for i, task in enumerate(tasks):
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
state_str = " ".join(map(str, discretized_states[i]))
# Store only the subtask text (used as prediction target)
subtask_only_prompts.append(cleaned_text)
if cleaned_high_level_tasks:
cleaned_high_level_task = cleaned_high_level_tasks[i]
full_prompt = f"High level task: {cleaned_high_level_task}; State: {state_str}; Subtask: {cleaned_text}"
else:
full_prompt = f"Task: {cleaned_text}, State: {state_str};\n" #remove Action by jade
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
full_prompts.append(full_prompt)
low_level_prompts.append(full_prompt)
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = low_level_prompts
transition[TransitionKey.COMPLEMENTARY_DATA][self.subtask_only_key] = subtask_only_prompts
# Process high level tasks without state information (if available)
if high_level_tasks is not None:
high_level_prompts = []
for i, cleaned_high_level_task in enumerate(cleaned_high_level_tasks):
state_str = " ".join(map(str, discretized_states[i]))
full_prompt = f"High level task: {cleaned_high_level_task}; State: {state_str}; Subtask:"
high_level_prompts.append(full_prompt)
transition[TransitionKey.COMPLEMENTARY_DATA][self.high_level_task_key] = high_level_prompts
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
# Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!)
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
return transition
def transform_features(
@@ -159,27 +128,25 @@ def make_pi05_pre_post_processors(
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
# Add remaining processors
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateAndLanguageTokenizerProcessorStep
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
Pi05PrepareStateAndLanguageTokenizerProcessorStep(max_state_dim=config.max_state_dim),
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
TokenizerProcessorStep(
tokenizer_name="google/paligemma-3b-pt-224",
max_length=config.tokenizer_max_length,
padding_side="right",
padding="max_length",
),
ActionTokenizerProcessorStep(
tokenizer_name="/fsx/jade_choghari/outputs/fast_tokenizer", # TODO: jade put the PI
),
DeviceProcessorStep(device=config.device),
]
@@ -189,7 +156,7 @@ def make_pi05_pre_post_processors(
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
-23
View File
@@ -1,23 +0,0 @@
export CUDA_LAUNCH_BLOCKING=1
lerobot-train \
--dataset.repo_id=local \
--dataset.root=/fsx/jade_choghari/outputs/collect-data-pgen \
--output_dir=/fsx/jade_choghari/outputs/pi0_fast_fruit2 \
--job_name=pi0_training \
--policy.repo_id=jade_choghari/pi0-base1 \
--policy.path=lerobot/pi05_base \
--policy.dtype=bfloat16 \
--steps=200000 \
--save_freq=5000 \
--rename_map='{
"observation.images.base": "observation.images.base_0_rgb",
"observation.images.left_wrist": "observation.images.left_wrist_0_rgb",
"observation.images.right_wrist": "observation.images.right_wrist_0_rgb",
}' \
--batch_size=16 \
--policy.device=cuda \
--policy.fast_only=true \
# --wandb.enable=true \
# --wandb.disable_artifact=true \
# --wandb.project=pi05hi-training \
# /fsx/jade_choghari/.cache/huggingface/lerobot/jadechoghari/collect-data
-13
View File
@@ -1,13 +0,0 @@
rm -rf /fsx/jade_choghari/outputs/pi0_multi_training
lerobot-train \
--dataset.repo_id=local\
--dataset.root=/fsx/jade_choghari/data/libero \
--output_dir=/fsx/jade_choghari/outputs/pi0_multi_training \
--job_name=pi0_multi_training \
--policy.repo_id=jadechoghari/pi0-base1 \
--policy.path=/fsx/jade_choghari/outputs/libero_training_fast_6/checkpoints/last/pretrained_model/ \
--policy.dtype=bfloat16 \
--steps=50000 \
--save_freq=5000 \
--batch_size=4 \
--policy.device=cuda \
-12
View File
@@ -1,12 +0,0 @@
python src/lerobot/policies/pi05/train_fast_tokenizer.py \
--repo_id "local" \
--root /fsx/jade_choghari/data/libero \
--action_horizon 10 \
--encoded_dims "0:7" \
--vocab_size 1024 \
--push_to_hub \
--hub_repo_id jadechoghari/fast-libero-tokenizer-quantiles \
--normalization_mode QUANTILES \
# python train_fast_tokenizer.py --repo_id my_dataset
@@ -1,533 +0,0 @@
"""Train FAST tokenizer for action encoding.
This script:
1. Loads action chunks from LeRobotDataset (with sampling)
2. Applies delta transforms and per-timestamp normalization
3. Trains FAST tokenizer on specified action dimensions
4. Saves tokenizer to assets directory
5. Reports compression statistics
"""
import json
import numpy as np
import tyro
from pathlib import Path
from transformers import AutoProcessor
import torch
from huggingface_hub import HfApi
from lerobot.configs.types import NormalizationMode
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def apply_delta_transform(state: np.ndarray, actions: np.ndarray, delta_dims: list[int] | None) -> np.ndarray:
"""Apply delta transform to specified dimensions.
Args:
state: Current state [D]
actions: Future actions [D]
delta_dims: List of dimension indices to apply delta transform to
Returns:
Transformed actions [D]
"""
if delta_dims is None or len(delta_dims) == 0:
return actions
delta_actions = actions.copy()
for dim in delta_dims:
delta_actions[dim] = actions[dim] - state[dim]
return delta_actions
def apply_normalization(
data: np.ndarray,
stats: dict[str, np.ndarray],
mode: NormalizationMode,
eps: float = 1e-8,
) -> np.ndarray:
"""Apply normalization to data based on the specified mode.
Args:
data: Data to normalize [N, H, D] or [D]
stats: Dictionary of statistics (mean, std, min, max, q01, q99, q10, q90)
mode: Normalization mode to apply
eps: Small epsilon for numerical stability
Returns:
Normalized data with the same shape as input
"""
if mode == NormalizationMode.IDENTITY:
return data
if mode == NormalizationMode.MEAN_STD:
mean = stats.get("mean")
std = stats.get("std")
if mean is None or std is None:
raise ValueError("MEAN_STD mode requires 'mean' and 'std' in stats")
return (data - mean) / np.maximum(std, eps)
if mode == NormalizationMode.MIN_MAX:
min_val = stats.get("min")
max_val = stats.get("max")
if min_val is None or max_val is None:
raise ValueError("MIN_MAX mode requires 'min' and 'max' in stats")
denom = np.maximum(max_val - min_val, eps)
return 2.0 * (data - min_val) / denom - 1.0
if mode == NormalizationMode.QUANTILES:
q01 = stats.get("q01")
q99 = stats.get("q99")
if q01 is None or q99 is None:
raise ValueError("QUANTILES mode requires 'q01' and 'q99' in stats")
denom = np.maximum(q99 - q01, eps)
# Clip to quantile range then normalize to [-1, 1]
clipped = np.clip(data, q01, q99)
return 2.0 * (clipped - q01) / denom - 1.0
if mode == NormalizationMode.QUANTILE10:
q10 = stats.get("q10")
q90 = stats.get("q90")
if q10 is None or q90 is None:
raise ValueError("QUANTILE10 mode requires 'q10' and 'q90' in stats")
denom = np.maximum(q90 - q10, eps)
# Clip to quantile range then normalize to [-1, 1]
clipped = np.clip(data, q10, q90)
return 2.0 * (clipped - q10) / denom - 1.0
raise ValueError(f"Unsupported normalization mode: {mode}")
def process_episode(args):
"""Process single episode and return action chunks."""
dataset, ep_idx, action_horizon, delta_dims, sample_fraction, state_key, use_delta_transform = args
try:
# Get episode info
ep_info = dataset.meta.episodes[ep_idx]
from_idx = ep_info["dataset_from_index"]
to_idx = ep_info["dataset_to_index"]
ep_length = to_idx - from_idx
if ep_length < action_horizon:
return None
# Load all frames in episode
# If dataset has episode filtering, we need to use the mapping
states = []
actions = []
for abs_idx in range(from_idx, to_idx):
# Map absolute index to relative index if needed
if dataset._absolute_to_relative_idx is not None:
if abs_idx not in dataset._absolute_to_relative_idx:
# This episode's frames aren't in the filtered dataset
return None
rel_idx = dataset._absolute_to_relative_idx[abs_idx]
else:
rel_idx = abs_idx
frame = dataset.hf_dataset[rel_idx]
# Get state (could be from observation.state or other state key)
if state_key in frame:
state = frame[state_key].numpy() if torch.is_tensor(frame[state_key]) else np.array(frame[state_key])
else:
# If no state key, use zeros (no delta transform)
state = np.zeros_like(frame["action"].numpy() if torch.is_tensor(frame["action"]) else np.array(frame["action"]))
action = frame["action"].numpy() if torch.is_tensor(frame["action"]) else np.array(frame["action"])
states.append(state)
actions.append(action)
states = np.array(states)
actions = np.array(actions)
# Create action chunks (sliding window)
# All actions in a chunk are relative to the FIRST state in that chunk
action_chunks = []
for i in range(len(states) - action_horizon + 1):
current_state = states[i] # First state in chunk
future_absolute_actions = actions[i:i + action_horizon]
if use_delta_transform:
# Relative actions
delta_chunk = np.zeros_like(future_absolute_actions)
for t in range(action_horizon):
delta_chunk[t] = apply_delta_transform(
current_state,
future_absolute_actions[t],
delta_dims,
)
action_chunks.append(delta_chunk)
else:
# Absolute actions (NO delta)
action_chunks.append(future_absolute_actions)
if len(action_chunks) == 0:
return None
action_chunks = np.array(action_chunks)
# Sample chunks
if sample_fraction < 1.0:
n_chunks = len(action_chunks)
n_samples = max(1, int(n_chunks * sample_fraction))
episode_seed = hash(ep_idx) % (2**31)
rng = np.random.RandomState(episode_seed)
indices = rng.choice(n_chunks, size=n_samples, replace=False)
action_chunks = action_chunks[indices]
return action_chunks
except Exception as e:
print(f"Error processing episode {ep_idx}: {e}")
import traceback
traceback.print_exc()
return None
def train_fast_tokenizer(
action_chunks: np.ndarray,
vocab_size: int = 1024,
scale: float = 10.0,
) -> AutoProcessor:
"""
Train FAST tokenizer (BPE on DCT coefficients) on action chunks.
Uses the .fit() method to train a new tokenizer on the provided data.
Args:
action_chunks: Array of action chunks [N, H, D] where N=num_chunks, H=horizon, D=action_dim
vocab_size: BPE vocabulary size
scale: DCT scaling factor for quantization
Returns:
Trained FAST tokenizer
"""
print(f"Training FAST tokenizer on {len(action_chunks)} action chunks...")
print(f"Action chunk shape: {action_chunks.shape}")
print(f"Vocab size: {vocab_size}")
print(f"DCT scale: {scale}")
# Download the tokenizer source code (not pretrained weights)
# We'll train a new tokenizer on our own data
base_tokenizer = AutoProcessor.from_pretrained(
"physical-intelligence/fast",
trust_remote_code=True
)
# Convert action_chunks array to list of arrays (expected by .fit())
action_data_list = [action_chunks[i] for i in range(len(action_chunks))]
# Train the new tokenizer on our action data using .fit()
# This trains the BPE tokenizer on DCT coefficients
print("Training new tokenizer (this may take a few minutes)...")
tokenizer = base_tokenizer.fit(
action_data_list,
scale=scale,
vocab_size=vocab_size,
time_horizon=action_chunks.shape[1], # action_horizon
action_dim=action_chunks.shape[2], # encoded dimensions
)
print("✓ Tokenizer training complete!")
# Validate it works
sample_chunk = action_chunks[0]
encoded = tokenizer(sample_chunk[None])[0]
if isinstance(encoded, list):
encoded = np.array(encoded)
print(f"Sample encoding: {len(encoded)} tokens for chunk shape {sample_chunk.shape}")
return tokenizer
def compute_compression_stats(tokenizer, action_chunks: np.ndarray):
"""Compute compression statistics."""
print("\nComputing compression statistics...")
# Sample for stats (use max 1000 chunks for speed)
sample_size = min(1000, len(action_chunks))
sample_indices = np.random.RandomState(42).choice(len(action_chunks), size=sample_size, replace=False)
sample_chunks = action_chunks[sample_indices]
token_lengths = []
for chunk in sample_chunks:
encoded = tokenizer(chunk[None])[0]
if isinstance(encoded, list):
token_lengths.append(len(encoded))
else:
token_lengths.append(encoded.shape[0] if hasattr(encoded, 'shape') else len(encoded))
token_lengths = np.array(token_lengths)
# Compression ratio: (H * D) / avg_tokens
input_size = action_chunks.shape[1] * action_chunks.shape[2]
avg_tokens = np.mean(token_lengths)
compression_ratio = input_size / avg_tokens
stats = {
'compression_ratio': float(compression_ratio),
'mean_token_length': float(np.mean(token_lengths)),
'p99_token_length': float(np.percentile(token_lengths, 99)),
'min_token_length': float(np.min(token_lengths)),
'max_token_length': float(np.max(token_lengths)),
}
print(f"Compression Statistics:")
print(f" Average compression ratio: {stats['compression_ratio']:.2f}x")
print(f" Mean token length: {stats['mean_token_length']:.1f}")
print(f" P99 token length: {stats['p99_token_length']:.0f}")
print(f" Min token length: {stats['min_token_length']:.0f}")
print(f" Max token length: {stats['max_token_length']:.0f}")
return stats
def main(
repo_id: str,
root: str | None = None,
action_horizon: int = 10,
max_episodes: int | None = None,
sample_fraction: float = 0.1,
encoded_dims: str = "0:6,7:23",
delta_dims: str | None = None,
use_delta_transform: bool = False,
state_key: str = "observation.state",
normalization_mode: str = "QUANTILES",
vocab_size: int = 1024,
scale: float = 10.0,
output_dir: str | None = None,
push_to_hub: bool = False,
hub_repo_id: str | None = None,
hub_private: bool = False,
):
"""
Train FAST tokenizer for action encoding.
Args:
repo_id: LeRobot dataset repository ID
root: Root directory for dataset (default: ~/.cache/huggingface/lerobot)
action_horizon: Number of future actions in each chunk
max_episodes: Max episodes to use (None = all episodes in dataset)
sample_fraction: Fraction of chunks to sample per episode
encoded_dims: Comma-separated dimension ranges to encode (e.g., "0:6,7:23")
delta_dims: Comma-separated dimension indices for delta transform (e.g., "0,1,2,3,4,5")
use_delta_transform: Whether to apply delta transform (relative actions vs absolute actions)
state_key: Dataset key for state observations (default: "observation.state")
normalization_mode: Normalization mode (MEAN_STD, MIN_MAX, QUANTILES, QUANTILE10, IDENTITY)
vocab_size: FAST vocabulary size (BPE vocab size)
scale: DCT scaling factor (default: 10.0)
output_dir: Directory to save tokenizer (default: ./fast_tokenizer_{repo_id})
push_to_hub: Whether to push the tokenizer to Hugging Face Hub
hub_repo_id: Hub repository ID (e.g., "username/tokenizer-name"). If None, uses output_dir name
hub_private: Whether to create a private repository on the Hub
"""
# Load dataset
print(f"Loading dataset: {repo_id}")
dataset = LeRobotDataset(repo_id=repo_id, root=root)
print(f"Dataset loaded: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
# Parse normalization mode
try:
norm_mode = NormalizationMode(normalization_mode)
except ValueError:
raise ValueError(
f"Invalid normalization_mode: {normalization_mode}. "
f"Must be one of: {', '.join([m.value for m in NormalizationMode])}"
)
print(f"Normalization mode: {norm_mode.value}")
# Parse encoded dimensions
encoded_dim_ranges = []
for range_str in encoded_dims.split(','):
start, end = map(int, range_str.strip().split(':'))
encoded_dim_ranges.append((start, end))
total_encoded_dims = sum(end - start for start, end in encoded_dim_ranges)
print(f"Encoding {total_encoded_dims} dimensions: {encoded_dims}")
# Parse delta dimensions
delta_dim_list = None
if delta_dims is not None and delta_dims.strip():
delta_dim_list = [int(d.strip()) for d in delta_dims.split(',')]
print(f"Delta dimensions: {delta_dim_list}")
else:
print("No delta dimensions specified")
print(f"Use delta transform: {use_delta_transform}")
if use_delta_transform and (delta_dim_list is None or len(delta_dim_list) == 0):
print("Warning: use_delta_transform=True but no delta_dims specified. No delta will be applied.")
print(f"Action horizon: {action_horizon}")
print(f"State key: {state_key}")
# Determine episodes to process
num_episodes = dataset.num_episodes
if max_episodes is not None:
num_episodes = min(max_episodes, num_episodes)
print(f"Processing {num_episodes} episodes...")
# Process episodes sequentially (to avoid pickling issues with dataset)
all_chunks = []
for ep_idx in range(num_episodes):
if ep_idx % 10 == 0:
print(f" Processing episode {ep_idx}/{num_episodes}...")
chunks = process_episode(
(dataset, ep_idx, action_horizon, delta_dim_list, sample_fraction, state_key, use_delta_transform)
)
if chunks is not None:
all_chunks.append(chunks)
# Concatenate all chunks
all_chunks = np.concatenate(all_chunks, axis=0)
print(f"Collected {len(all_chunks)} action chunks")
# Extract only encoded dimensions FIRST (before normalization)
encoded_chunks = []
for start, end in encoded_dim_ranges:
encoded_chunks.append(all_chunks[:, :, start:end])
encoded_chunks = np.concatenate(encoded_chunks, axis=-1) # [N, H, D_encoded]
print(f"Extracted {encoded_chunks.shape[-1]} encoded dimensions")
# Apply normalization to encoded dimensions
print(f"\nBefore normalization - overall stats:")
print(f" Min: {np.min(encoded_chunks):.4f}, Max: {np.max(encoded_chunks):.4f}")
print(f" Mean: {np.mean(encoded_chunks):.4f}, Std: {np.std(encoded_chunks):.4f}")
# Get normalization stats from dataset
norm_stats = dataset.meta.stats
if norm_stats is not None and "action" in norm_stats:
action_stats = norm_stats["action"]
# Build encoded dimension indices
encoded_dim_indices = []
for start, end in encoded_dim_ranges:
encoded_dim_indices.extend(range(start, end))
encoded_dim_indices = np.array(encoded_dim_indices)
# Extract stats for encoded dimensions only
encoded_stats = {}
for stat_name, stat_values in action_stats.items():
if isinstance(stat_values, (list, np.ndarray)):
stat_array = np.array(stat_values)
if len(stat_array) > max(encoded_dim_indices):
encoded_stats[stat_name] = stat_array[encoded_dim_indices]
if encoded_stats:
print(f"\nNormalization stats for encoded dimensions (mode: {norm_mode.value}):")
for stat_name, stat_values in encoded_stats.items():
print(f" {stat_name}: shape={stat_values.shape}, "
f"range=[{np.min(stat_values):.4f}, {np.max(stat_values):.4f}]")
# Apply normalization based on mode
try:
encoded_chunks = apply_normalization(
encoded_chunks,
encoded_stats,
norm_mode,
eps=1e-8
)
print(f"\nApplied {norm_mode.value} normalization")
except ValueError as e:
print(f"Warning: {e}. Using raw actions without normalization.")
print(f"\nAfter normalization - overall stats:")
print(f" Min: {np.min(encoded_chunks):.4f}, Max: {np.max(encoded_chunks):.4f}")
print(f" Mean: {np.mean(encoded_chunks):.4f}, Std: {np.std(encoded_chunks):.4f}")
print(f"\nPer-dimension stats (after normalization):")
for d in range(encoded_chunks.shape[-1]):
dim_data = encoded_chunks[:, :, d]
print(f" Dim {d}: min={np.min(dim_data):7.4f}, max={np.max(dim_data):7.4f}, "
f"mean={np.mean(dim_data):7.4f}, std={np.std(dim_data):7.4f}")
else:
print("Warning: Could not extract stats for encoded dimensions, using raw actions")
else:
print("Warning: No normalization stats found in dataset, using raw actions")
print(f"Encoded chunks shape: {encoded_chunks.shape}")
# Train FAST tokenizer
tokenizer = train_fast_tokenizer(
encoded_chunks,
vocab_size=vocab_size,
scale=scale,
)
# Compute compression statistics
compression_stats = compute_compression_stats(tokenizer, encoded_chunks)
# Save tokenizer
if output_dir is None:
output_dir = f"fast_tokenizer_{repo_id.replace('/', '_')}"
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
tokenizer.save_pretrained(output_path)
# Save metadata
metadata = {
'repo_id': repo_id,
'vocab_size': vocab_size,
'scale': scale,
'encoded_dims': encoded_dims,
'encoded_dim_ranges': encoded_dim_ranges,
'total_encoded_dims': total_encoded_dims,
'delta_dims': delta_dims,
'delta_dim_list': delta_dim_list,
'use_delta_transform': use_delta_transform,
'state_key': state_key,
'normalization_mode': norm_mode.value,
'action_horizon': action_horizon,
'num_training_chunks': len(encoded_chunks),
'compression_stats': compression_stats,
}
with open(output_path / "metadata.json", 'w') as f:
json.dump(metadata, f, indent=2)
print(f"\nSaved FAST tokenizer to {output_path}")
print(f"Metadata: {json.dumps(metadata, indent=2)}")
# Push to Hugging Face Hub if requested
if push_to_hub:
# Determine the hub repository ID
if hub_repo_id is None:
hub_repo_id = output_path.name
print(f"\nNo hub_repo_id provided, using: {hub_repo_id}")
print(f"\nPushing tokenizer to Hugging Face Hub: {hub_repo_id}")
print(f" Private: {hub_private}")
try:
# Use the tokenizer's push_to_hub method
tokenizer.push_to_hub(
repo_id=hub_repo_id,
private=hub_private,
commit_message=f"Upload FAST tokenizer trained on {repo_id}"
)
# Also upload the metadata.json file separately
api = HfApi()
api.upload_file(
path_or_fileobj=str(output_path / "metadata.json"),
path_in_repo="metadata.json",
repo_id=hub_repo_id,
repo_type="model",
commit_message="Upload tokenizer metadata"
)
print(f"Successfully pushed tokenizer to: https://huggingface.co/{hub_repo_id}")
except Exception as e:
print(f"Error pushing to hub: {e}")
print(" Make sure you're logged in with `huggingface-cli login`")
if __name__ == "__main__":
tyro.cli(main)
@@ -1,101 +0,0 @@
# Train FAST Tokenizer - Usage Examples
This script trains a FAST (Factorized Action Sequence Tokenizer) on LeRobotDataset action data.
## Basic Usage
```bash
python src/lerobot/policies/pi05/train_fast_tokenizer.py \
--repo_id "lerobot/aloha_sim_insertion_human" \
--action_horizon 10 \
--encoded_dims "0:7" \
--vocab_size 1024 \
--scale 10.0
```
## Parameters
### Required
- `--repo_id`: LeRobot dataset repository ID (e.g., "lerobot/aloha_sim_insertion_human")
### Optional
- `--root`: Root directory for dataset (default: ~/.cache/huggingface/lerobot)
- `--action_horizon`: Number of future actions in each chunk (default: 10)
- `--max_episodes`: Maximum number of episodes to use (default: None = all)
- `--sample_fraction`: Fraction of chunks to sample per episode (default: 0.1)
- `--encoded_dims`: Comma-separated dimension ranges to encode (default: "0:6,7:23")
- Example: "0:7" encodes dimensions 0-6
- Example: "0:3,6:9" encodes dimensions 0-2 and 6-8
- `--delta_dims`: Comma-separated dimension indices for delta transform (default: None)
- Example: "0,1,2,3,4,5" applies delta transform to first 6 dimensions
- Delta transform: action[i] - state[i] for specified dimensions
- `--state_key`: Dataset key for state observations (default: "observation.state")
- `--vocab_size`: FAST vocabulary size / BPE vocab size (default: 1024)
- `--scale`: DCT scaling factor (default: 10.0)
- `--output_dir`: Directory to save tokenizer (default: ./fast_tokenizer_{repo_id})
## Examples
### Example 1: Train on full action space
```bash
python src/lerobot/policies/pi05/train_fast_tokenizer.py \
--repo_id "lerobot/pusht" \
--action_horizon 16 \
--encoded_dims "0:2" \
--vocab_size 512 \
--max_episodes 100
```
### Example 2: Train with delta transform
```bash
python src/lerobot/policies/pi05/train_fast_tokenizer.py \
--repo_id "lerobot/aloha_sim_insertion_human" \
--action_horizon 10 \
--encoded_dims "0:14" \
--delta_dims "0,1,2,3,4,5,6,7,8,9,10,11,12,13" \
--state_key "observation.state" \
--vocab_size 1024 \
--scale 10.0 \
--sample_fraction 0.2
```
### Example 3: Train on subset of dimensions
```bash
python src/lerobot/policies/pi05/train_fast_tokenizer.py \
--repo_id "lerobot/aloha_sim_insertion_human" \
--action_horizon 10 \
--encoded_dims "0:7" \
--vocab_size 1024 \
--output_dir "./my_tokenizer"
```
## Output
The script saves:
1. **Tokenizer files**: Trained FAST tokenizer (can be loaded with `AutoProcessor.from_pretrained()`)
2. **metadata.json**: Contains:
- Configuration parameters
- Compression statistics (compression ratio, token lengths)
- Training dataset information
## Understanding the Process
1. **Load Dataset**: Loads the LeRobotDataset from HuggingFace
2. **Extract Action Chunks**: Creates sliding windows of actions with specified horizon
3. **Apply Delta Transform**: (Optional) Computes action deltas relative to current state
4. **Select Encoded Dimensions**: Extracts only the dimensions to be encoded
5. **Normalize**: Applies quantile normalization ([q01, q99] → [-1, 1])
6. **Train Tokenizer**: Trains BPE tokenizer on DCT coefficients
7. **Compute Stats**: Reports compression ratio and token length statistics
8. **Save**: Saves tokenizer and metadata
## Notes
- **Normalization**: The script uses quantile normalization (q01, q99) from the dataset's statistics
- **Sampling**: To speed up training, you can sample a fraction of chunks per episode
- **Delta Transform**: Applied per-dimension to make actions relative to current state
- **Compression**: FAST uses DCT + BPE to compress action sequences efficiently
-28
View File
@@ -1,28 +0,0 @@
#!/bin/bash
# FSDP training script for PI05 with aggressive memory optimization
# Use this for large models that OOM with standard DDP
accelerate launch --config_file /admin/home/jade_choghari/lerobot/fsdp_config.yaml \
$(which lerobot-train) \
--dataset.repo_id=local \
--dataset.root=/fsx/jade_choghari/data/libero \
--output_dir=/fsx/jade_choghari/outputs/libero_training_fsdp \
--job_name=libero_training_fsdp \
--policy.repo_id=jade_choghari/pi05-fast-libero-fsdp \
--policy.path=/fsx/jade_choghari/models/libero-pi-fast \
--policy.dtype=bfloat16 \
--steps=100000 \
--save_freq=10 \
--batch_size=8 \
--policy.device=cuda \
--policy.fast_only=true \
--policy.scheduler_warmup_steps=2000 \
--policy.scheduler_decay_steps=60000 \
--policy.scheduler_decay_lr=1e-5 \
--policy.gradient_checkpointing=false \
--wandb.enable=true \
--wandb.disable_artifact=true \
--wandb.project=pi05-libero-training-fsdp
-24
View File
@@ -1,24 +0,0 @@
export CUDA_LAUNCH_BLOCKING=1
lerobot-train \
--dataset.repo_id=local \
--dataset.root=/fsx/jade_choghari/data/libero \
--output_dir=/fsx/jade_choghari/outputs/libero_training_fast_4 \
--job_name=libero_training_fast \
--policy.repo_id=jade_choghari/pi05-fast-libero \
--policy.path=/fsx/jade_choghari/models/pi05-base \
--policy.dtype=bfloat16 \
--steps=100000 \
--save_freq=20000 \
--batch_size=4 \
--policy.device=cuda \
--policy.fast_only=true \
--policy.scheduler_warmup_steps=1000 \
--policy.scheduler_decay_steps=30000 \
--policy.scheduler_decay_lr=1e-5 \
--policy.gradient_checkpointing=true \
--rename_map='{
"observation.images.image1": "observation.images.base_0_rgb",
"observation.images.image2": "observation.images.left_wrist_0_rgb",
}' \
--policy.empty_cameras=1 \
# /fsx/jade_choghari/.cache/huggingface/lerobot/jadechoghari/collect-data
@@ -1,15 +0,0 @@
#!/bin/bash
#SBATCH --job-name=pi05-train
#SBATCH --time=24:00:00
#SBATCH --qos=high
#SBATCH --gres=gpu:8
#SBATCH --mem=256G
#SBATCH --partition=hopper-prod
#SBATCH --output=/fsx/jade_choghari/logs/%x-%j.out
#SBATCH --error=/fsx/jade_choghari/logs/%x-%j.err
srun \
--container-image=/fsx/michel_aractingi/docker_images/huggingface+lerobot-gpu+dev.sqsh \
--container-mounts=/fsx/jade_choghari \
--container-workdir=$HOME/lerobot \
bash /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05/train_multi.sh
-36
View File
@@ -1,36 +0,0 @@
#!/bin/bash
set -euxo pipefail
# Source YOUR Miniforge conda (mounted from FSX)
source /fsx/jade_choghari/miniforge3/etc/profile.d/conda.sh
conda activate lerobot
accelerate launch --mixed_precision=bf16 --multi_gpu --num_processes=8 \
$(which lerobot-train) \
--dataset.repo_id=local \
--dataset.root=/fsx/jade_choghari/data/libero \
--output_dir=/fsx/jade_choghari/outputs/libero_training_fast_mean_1 \
--job_name=libero_training_fast \
--policy.repo_id=jade_choghari/pi05-fast-libero \
--policy.path=/fsx/jade_choghari/models/pi05-base \
--policy.dtype=bfloat16 \
--steps=100000 \
--save_freq=20000 \
--batch_size=4 \
--policy.device=cuda \
--policy.fast_only=true \
--policy.scheduler_warmup_steps=4000 \
--policy.scheduler_decay_steps=100000 \
--policy.scheduler_decay_lr=1e-5 \
--policy.gradient_checkpointing=true \
--policy.chunk_size=10 \
--policy.n_action_steps=10 \
--policy.max_action_tokens=256 \
--rename_map='{
"observation.images.image1": "observation.images.base_0_rgb",
"observation.images.image2": "observation.images.left_wrist_0_rgb",
}' \
--policy.empty_cameras=1 \
--wandb.enable=true \
--wandb.disable_artifact=true \
--wandb.project=pi05-libero-training \
@@ -783,18 +783,15 @@ class VLAFlowMatching(nn.Module):
use_cache=self.config.use_cache,
fill_kv_cache=True,
)
dt = -1.0 / self.config.num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)
num_steps = self.config.num_steps
dt = -1.0 / num_steps
x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
for step in range(num_steps):
time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
x_t=input_x_t,
prefix_pad_masks=prefix_pad_masks,
@@ -818,15 +815,11 @@ class VLAFlowMatching(nn.Module):
else:
v_t = denoise_step_partial_call(x_t)
# Euler step
x_t += dt * v_t
x_t = x_t + dt * v_t
# Record x_t and v_t after Euler step (other params are recorded in rtc_processor.denoise_step)
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
time += dt
return x_t
def denoise_step(
+1 -1
View File
@@ -75,7 +75,7 @@ from .policy_robot_bridge import (
RobotActionToPolicyActionProcessorStep,
)
from .rename_processor import RenameObservationsProcessorStep
from .tokenizer_processor import TokenizerProcessorStep, ActionTokenizerProcessorStep
from .tokenizer_processor import TokenizerProcessorStep
__all__ = [
"ActionProcessorStep",
+1 -3
View File
@@ -168,12 +168,10 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
"""
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
task_key = {"task": batch["task"]} if "task" in batch else {}
user_prompt_key = {"user_prompt": batch["user_prompt"]} if "user_prompt" in batch else {}
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
return {**pad_keys, **task_key, **index_key, **task_index_key, **user_prompt_key, **subtask_key}
return {**pad_keys, **task_key, **index_key, **task_index_key}
def create_transition(
@@ -47,6 +47,7 @@ class RenameObservationsProcessorStep(ObservationProcessorStep):
processed_obs[self.rename_map[key]] = value
else:
processed_obs[key] = value
return processed_obs
def get_config(self) -> dict[str, Any]:
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,21 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# OMX is a fully open-source robot from ROBOTIS.
# More information at: https://ai.robotis.com/omx/introduction_omx.html
from .config_omx_follower import OmxFollowerConfig
from .omx_follower import OmxFollower
@@ -0,0 +1,39 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from ..config import RobotConfig
@RobotConfig.register_subclass("omx_follower")
@dataclass
class OmxFollowerConfig(RobotConfig):
# Port to connect to the arm
port: str
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
# names to the max_relative_target value for that motor.
max_relative_target: float | dict[str, float] | None = None
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)
# Set to `True` for backward compatibility with previous policies/dataset
use_degrees: bool = False
@@ -0,0 +1,225 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from functools import cached_property
from typing import Any
from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.dynamixel import (
DriveMode,
DynamixelMotorsBus,
OperatingMode,
)
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
from ..utils import ensure_safe_goal_position
from .config_omx_follower import OmxFollowerConfig
logger = logging.getLogger(__name__)
class OmxFollower(Robot):
"""
- [OMX](https://github.com/ROBOTIS-GIT/open_manipulator),
expansion, developed by Woojin Wie and Junha Cha from [ROBOTIS](https://ai.robotis.com/)
"""
config_class = OmxFollowerConfig
name = "omx_follower"
def __init__(self, config: OmxFollowerConfig):
super().__init__(config)
self.config = config
norm_mode_body = MotorNormMode.DEGREES if config.use_degrees else MotorNormMode.RANGE_M100_100
self.bus = DynamixelMotorsBus(
port=self.config.port,
motors={
"shoulder_pan": Motor(11, "xl430-w250", norm_mode_body),
"shoulder_lift": Motor(12, "xl430-w250", norm_mode_body),
"elbow_flex": Motor(13, "xl430-w250", norm_mode_body),
"wrist_flex": Motor(14, "xl330-m288", norm_mode_body),
"wrist_roll": Motor(15, "xl330-m288", norm_mode_body),
"gripper": Motor(16, "xl330-m288", MotorNormMode.RANGE_0_100),
},
calibration=self.calibration,
)
self.cameras = make_cameras_from_configs(config.cameras)
@property
def _motors_ft(self) -> dict[str, type]:
return {f"{motor}.pos": float for motor in self.bus.motors}
@property
def _cameras_ft(self) -> dict[str, tuple]:
return {
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
return {**self._motors_ft, **self._cameras_ft}
@cached_property
def action_features(self) -> dict[str, type]:
return self._motors_ft
@property
def is_connected(self) -> bool:
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
def connect(self, calibrate: bool = True) -> None:
"""
For OMX robots that come pre-calibrated:
- If default calibration from package doesn't match motors, read from motors and save
- This allows using pre-calibrated robots without manual calibration
- If no calibration file exists, use factory default values (homing_offset=0, range_min=0, range_max=4095)
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
self.bus.connect()
if not self.is_calibrated and calibrate:
logger.info(
"Mismatch between calibration values in the motor and the calibration file or no calibration file found"
)
self.calibrate()
for cam in self.cameras.values():
cam.connect()
self.configure()
logger.info(f"{self} connected.")
@property
def is_calibrated(self) -> bool:
return self.bus.is_calibrated
def calibrate(self) -> None:
self.bus.disable_torque()
logger.info(f"\nUsing factory default calibration values for {self}")
logger.info(f"\nWriting default configuration of {self} to the motors")
for motor in self.bus.motors:
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
for motor in self.bus.motors:
self.bus.write("Drive_Mode", motor, DriveMode.NON_INVERTED.value)
self.calibration = {}
for motor, m in self.bus.motors.items():
self.calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=0,
homing_offset=0,
range_min=0,
range_max=4095,
)
self.bus.write_calibration(self.calibration)
self._save_calibration()
logger.info(f"Calibration saved to {self.calibration_fpath}")
def configure(self) -> None:
with self.bus.torque_disabled():
self.bus.configure_motors()
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos
# can't rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling
# the arm, you could end up with a servo with a position 0 or 4095 at a crucial point
for motor in self.bus.motors:
if motor != "gripper":
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
# Use 'position control current based' for gripper to be limited by the limit of the current. For
# the follower gripper, it means it can grasp an object without forcing too much even tho, its
# goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch).
# For the leader gripper, it means we can use it as a physical trigger, since we can force with
# our finger to make it move, and it will move back to its original target position when we
# release the force.
self.bus.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value)
# Set better PID values to close the gap between recorded states and actions
# TODO(rcadene): Implement an automatic procedure to set optimal PID values for each motor
self.bus.write("Position_P_Gain", "elbow_flex", 1500)
self.bus.write("Position_I_Gain", "elbow_flex", 0)
self.bus.write("Position_D_Gain", "elbow_flex", 600)
def setup_motors(self) -> None:
for motor in reversed(self.bus.motors):
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
self.bus.setup_motor(motor)
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
def get_observation(self) -> dict[str, Any]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Read arm position
start = time.perf_counter()
obs_dict = self.bus.sync_read("Present_Position")
obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()}
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.async_read()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
return obs_dict
def send_action(self, action: dict[str, float]) -> dict[str, float]:
"""Command arm to move to a target joint configuration.
The relative action magnitude may be clipped depending on the configuration parameter
`max_relative_target`. In this case, the action sent differs from original action.
Thus, this function always returns the action actually sent.
Args:
action (dict[str, float]): The goal positions for the motors.
Returns:
dict[str, float]: The action sent to the motors, potentially clipped.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
# Cap goal position when too far away from present position.
# /!\ Slower fps expected due to reading from the follower.
if self.config.max_relative_target is not None:
present_pos = self.bus.sync_read("Present_Position")
goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()}
goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
# Send goal position to the arm
self.bus.sync_write("Goal_Position", goal_pos)
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
def disconnect(self):
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self.bus.disconnect(self.config.disable_torque_on_disconnect)
for cam in self.cameras.values():
cam.disconnect()
logger.info(f"{self} disconnected.")
@@ -16,6 +16,8 @@
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from ..config import RobotConfig
_GAINS: dict[str, dict[str, list[float]]] = {
@@ -51,5 +53,11 @@ class UnitreeG1Config(RobotConfig):
control_dt: float = 1.0 / 250.0 # 250Hz
# launch mujoco simulation
is_simulation: bool = False
# socket config for ZMQ bridge
robot_ip: str = "192.168.123.164"
robot_ip: str = "172.18.129.215"
# cameras (optional)
cameras: dict[str, CameraConfig] = field(default_factory=dict)
@@ -0,0 +1,302 @@
#!/usr/bin/env python
"""
Standalone keyboard control script for Unitree G1 robot.
This script provides keyboard-based velocity control for the G1 robot's
locomotion system. It can be run alongside the main robot control to
provide manual movement commands.
Usage:
python keyboard_control.py [--robot-ip IP] [--simulation]
Controls:
W/S: Forward/Backward
A/D: Strafe Left/Right
Q/E: Rotate Left/Right
R/F: Raise/Lower Height (GR00T policies only)
Z: Stop (zero all velocity commands)
ESC/Ctrl+C: Exit
"""
import argparse
import sys
import select
import time
import numpy as np
# Terminal handling for non-blocking keyboard input
try:
import termios
import tty
HAS_TERMIOS = True
except ImportError:
HAS_TERMIOS = False
print("Warning: termios not available. Keyboard controls require Linux/macOS.")
class KeyboardController:
"""Handles keyboard input and converts to locomotion commands."""
def __init__(self, callback=None):
"""
Initialize keyboard controller.
Args:
callback: Optional function called when commands change.
Signature: callback(vx, vy, yaw, height)
"""
self.callback = callback
self.running = False
# Locomotion commands
self.vx = 0.0 # Forward/backward velocity
self.vy = 0.0 # Left/right velocity (strafe)
self.yaw = 0.0 # Rotation rate
self.height = 0.74 # Base height (for GR00T policies)
# Command limits
self.vx_limit = (-0.8, 0.8)
self.vy_limit = (-0.5, 0.5)
self.yaw_limit = (-1.0, 1.0)
self.height_limit = (0.50, 1.00)
# Increments per keypress
self.vx_increment = 0.4
self.vy_increment = 0.25
self.yaw_increment = 0.5
self.height_increment = 0.05
self._old_terminal_settings = None
def get_commands(self) -> tuple[float, float, float, float]:
"""Get current command values as tuple (vx, vy, yaw, height)."""
return (self.vx, self.vy, self.yaw, self.height)
def get_commands_array(self) -> np.ndarray:
"""Get velocity commands as numpy array [vx, vy, yaw]."""
return np.array([self.vx, self.vy, self.yaw], dtype=np.float32)
def reset_commands(self):
"""Reset all commands to zero (stop)."""
self.vx = 0.0
self.vy = 0.0
self.yaw = 0.0
self._notify_callback()
def _clamp(self, value: float, limits: tuple[float, float]) -> float:
"""Clamp value to limits."""
return max(limits[0], min(limits[1], value))
def _notify_callback(self):
"""Call callback with current commands if set."""
if self.callback:
self.callback(self.vx, self.vy, self.yaw, self.height)
def process_key(self, key: str) -> bool:
"""
Process a single key press and update commands.
Args:
key: Single character key that was pressed.
Returns:
True if key was handled, False otherwise.
"""
key = key.lower()
handled = True
if key == 'w':
self.vx = self._clamp(self.vx + self.vx_increment, self.vx_limit)
elif key == 's':
self.vx = self._clamp(self.vx - self.vx_increment, self.vx_limit)
elif key == 'a':
self.vy = self._clamp(self.vy + self.vy_increment, self.vy_limit)
elif key == 'd':
self.vy = self._clamp(self.vy - self.vy_increment, self.vy_limit)
elif key == 'q':
self.yaw = self._clamp(self.yaw + self.yaw_increment, self.yaw_limit)
elif key == 'e':
self.yaw = self._clamp(self.yaw - self.yaw_increment, self.yaw_limit)
elif key == 'r':
self.height = self._clamp(self.height + self.height_increment, self.height_limit)
elif key == 'f':
self.height = self._clamp(self.height - self.height_increment, self.height_limit)
elif key == 'z':
self.reset_commands()
return True # Already notified in reset_commands
else:
handled = False
if handled:
self._notify_callback()
return handled
def _setup_terminal(self):
"""Set terminal to raw mode for single character input."""
if HAS_TERMIOS:
self._old_terminal_settings = termios.tcgetattr(sys.stdin)
tty.setcbreak(sys.stdin.fileno())
def _restore_terminal(self):
"""Restore terminal to original settings."""
if HAS_TERMIOS and self._old_terminal_settings is not None:
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, self._old_terminal_settings)
self._old_terminal_settings = None
def run(self):
"""Run the keyboard listener loop (blocking)."""
if not HAS_TERMIOS:
print("Error: Keyboard controls require termios (Linux/macOS)")
return
self.running = True
self._print_controls()
try:
self._setup_terminal()
while self.running:
# Check for keyboard input with timeout
if select.select([sys.stdin], [], [], 0.1)[0]:
key = sys.stdin.read(1)
# Handle escape sequences (arrow keys, etc.)
if key == '\x1b': # ESC
self.running = False
break
if self.process_key(key):
self._print_status()
except KeyboardInterrupt:
print("\nInterrupted by user")
finally:
self._restore_terminal()
print("\nKeyboard controls stopped")
def stop(self):
"""Stop the keyboard listener."""
self.running = False
def _print_controls(self):
"""Print control instructions."""
print("\n" + "=" * 60)
print("KEYBOARD CONTROLS ACTIVE")
print("=" * 60)
print(" W/S: Forward/Backward")
print(" A/D: Strafe Left/Right")
print(" Q/E: Rotate Left/Right")
print(" R/F: Raise/Lower Height (±5cm)")
print(" Z: Stop (zero all commands)")
print(" ESC: Exit")
print("=" * 60 + "\n")
def _print_status(self):
"""Print current command status."""
print(f"[CMD] vx={self.vx:+.2f}, vy={self.vy:+.2f}, yaw={self.yaw:+.2f} | height={self.height:.3f}m")
class RobotKeyboardController(KeyboardController):
"""Keyboard controller that directly updates a robot's locomotion commands."""
def __init__(self, robot):
"""
Initialize with a UnitreeG1 robot instance.
Args:
robot: UnitreeG1 robot instance with locomotion_cmd attribute.
"""
super().__init__()
self.robot = robot
# Initialize from robot's current state if available
if hasattr(robot, 'locomotion_cmd'):
self.vx = robot.locomotion_cmd[0]
self.vy = robot.locomotion_cmd[1]
self.yaw = robot.locomotion_cmd[2]
if hasattr(robot, 'groot_height_cmd'):
self.height = robot.groot_height_cmd
def _notify_callback(self):
"""Update robot's locomotion commands directly."""
if hasattr(self.robot, 'locomotion_cmd'):
self.robot.locomotion_cmd[0] = self.vx
self.robot.locomotion_cmd[1] = self.vy
self.robot.locomotion_cmd[2] = self.yaw
if hasattr(self.robot, 'groot_height_cmd'):
self.robot.groot_height_cmd = self.height
def start_keyboard_control_thread(robot) -> tuple:
"""
Start keyboard controls for a robot in a background thread.
Args:
robot: UnitreeG1 robot instance.
Returns:
Tuple of (controller, thread) for later stopping.
"""
import threading
controller = RobotKeyboardController(robot)
thread = threading.Thread(target=controller.run, daemon=True)
thread.start()
return controller, thread
def stop_keyboard_control_thread(controller, thread, timeout: float = 2.0):
"""
Stop the keyboard control thread.
Args:
controller: KeyboardController instance.
thread: Thread running the controller.
timeout: Max time to wait for thread to stop.
"""
controller.stop()
thread.join(timeout=timeout)
def main():
"""Standalone keyboard control with optional robot connection."""
parser = argparse.ArgumentParser(description="Keyboard control for Unitree G1")
parser.add_argument("--standalone", action="store_true",
help="Run in standalone mode (just print commands, no robot)")
args = parser.parse_args()
if args.standalone:
# Standalone mode - just demonstrate keyboard input
def print_callback(vx, vy, yaw, height):
print(f" → Would send: vx={vx:+.2f}, vy={vy:+.2f}, yaw={yaw:+.2f}, height={height:.3f}")
controller = KeyboardController(callback=print_callback)
print("Running in STANDALONE mode (no robot connection)")
controller.run()
else:
print("To use with a robot, import and use RobotKeyboardController:")
print("")
print(" from lerobot.robots.unitree_g1.keyboard_control import (")
print(" RobotKeyboardController,")
print(" start_keyboard_control_thread,")
print(" stop_keyboard_control_thread")
print(" )")
print("")
print(" # Start keyboard controls")
print(" controller, thread = start_keyboard_control_thread(robot)")
print("")
print(" # ... robot runs ...")
print("")
print(" # Stop keyboard controls")
print(" stop_keyboard_control_thread(controller, thread)")
print("")
print("Or run with --standalone to test keyboard input without a robot.")
if __name__ == "__main__":
main()
+16 -16
View File
@@ -99,11 +99,12 @@ def state_forward_loop(
lowstate_sub: ChannelSubscriber,
lowstate_sock: zmq.Socket,
state_period: float,
shutdown_event: threading.Event,
) -> None:
"""Read observation from DDS and forward to ZMQ clients."""
last_state_time = 0.0
while True:
while not shutdown_event.is_set():
# read from DDS
msg = lowstate_sub.Read()
if msg is None:
@@ -128,7 +129,10 @@ def cmd_forward_loop(
) -> None:
"""Receive commands from ZMQ and forward to DDS."""
while True:
payload = lowcmd_sock.recv()
try:
payload = lowcmd_sock.recv()
except zmq.ContextTerminated:
break
msg_dict = json.loads(payload.decode("utf-8"))
topic = msg_dict.get("topic", "")
@@ -182,30 +186,26 @@ def main() -> None:
lowstate_sock.bind(f"tcp://0.0.0.0:{LOWSTATE_PORT}")
state_period = 0.002 # ~500 hz
shutdown_event = threading.Event()
# start observation forwarding thread
# start observation forwarding in background thread
t_state = threading.Thread(
target=state_forward_loop,
args=(lowstate_sub, lowstate_sock, state_period),
daemon=True,
args=(lowstate_sub, lowstate_sock, state_period, shutdown_event),
)
t_state.start()
# start action forwarding thread
t_cmd = threading.Thread(
target=cmd_forward_loop,
args=(lowcmd_sock, lowcmd_pub_debug, crc),
daemon=True,
)
t_cmd.start()
print("bridge running (lowstate -> zmq, lowcmd -> dds)")
# keep main thread alive so daemon threads don't exit
# run command forwarding in main thread
try:
while True:
time.sleep(1.0)
cmd_forward_loop(lowcmd_sock, lowcmd_pub_debug, crc)
except KeyboardInterrupt:
print("shutting down bridge...")
finally:
shutdown_event.set()
ctx.term() # terminates blocking zmq.recv() calls
t_state.join(timeout=2.0)
if __name__ == "__main__":
+26 -9
View File
@@ -30,12 +30,8 @@ from unitree_sdk2py.idl.unitree_hg.msg.dds_ import (
)
from unitree_sdk2py.utils.crc import CRC
from lerobot.envs.factory import make_env
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
from lerobot.robots.unitree_g1.unitree_sdk2_socket import (
ChannelFactoryInitialize,
ChannelPublisher,
ChannelSubscriber,
)
from ..robot import Robot
from .config_unitree_g1 import UnitreeG1Config
@@ -127,7 +123,21 @@ class UnitreeG1(Robot):
self.control_dt = config.control_dt
if config.is_simulation:
from unitree_sdk2py.core.channel import (
ChannelFactoryInitialize,
ChannelPublisher,
ChannelSubscriber,
)
else:
from lerobot.robots.unitree_g1.unitree_sdk2_socket import (
ChannelFactoryInitialize,
ChannelPublisher,
ChannelSubscriber,
)
# connect robot
self.ChannelFactoryInitialize = ChannelFactoryInitialize
self.connect()
# initialize direct motor control interface
@@ -138,8 +148,8 @@ class UnitreeG1(Robot):
self.lowstate_buffer = DataBuffer()
# initialize subscribe thread to read robot state
self._shutdown_event = threading.Event()
self.subscribe_thread = threading.Thread(target=self._subscribe_motor_state)
self.subscribe_thread.daemon = True
self.subscribe_thread.start()
while not self.is_connected:
@@ -174,7 +184,7 @@ class UnitreeG1(Robot):
self.remote_controller = self.RemoteController()
def _subscribe_motor_state(self): # polls robot state @ 250Hz
while True:
while not self._shutdown_event.is_set():
start_time = time.time()
msg = self.lowstate_subscriber.Read()
if msg is not None:
@@ -218,10 +228,17 @@ class UnitreeG1(Robot):
pass
def connect(self, calibrate: bool = True) -> None: # connect to DDS
ChannelFactoryInitialize(0)
if self.config.is_simulation:
self.ChannelFactoryInitialize(0, "lo")
self.mujoco_env = make_env("lerobot/unitree-g1-mujoco", trust_remote_code=True)
else:
self.ChannelFactoryInitialize(0)
def disconnect(self):
pass
self._shutdown_event.set()
self.subscribe_thread.join(timeout=2.0)
if self.config.is_simulation:
self.mujoco_env["hub_env"][0].envs[0].kill_sim()
def get_observation(self) -> dict[str, Any]:
return self.lowstate_buffer.get_data()
+4
View File
@@ -28,6 +28,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
from .koch_follower import KochFollower
return KochFollower(config)
elif config.type == "omx_follower":
from .omx_follower import OmxFollower
return OmxFollower(config)
elif config.type == "so100_follower":
from .so100_follower import SO100Follower
+2
View File
@@ -40,6 +40,7 @@ from lerobot.robots import ( # noqa: F401
koch_follower,
lekiwi,
make_robot_from_config,
omx_follower,
so100_follower,
so101_follower,
)
@@ -49,6 +50,7 @@ from lerobot.teleoperators import ( # noqa: F401
homunculus,
koch_leader,
make_teleoperator_from_config,
omx_leader,
so100_leader,
so101_leader,
)
+455 -5
View File
@@ -18,7 +18,8 @@
Edit LeRobot datasets using various transformation tools.
This script allows you to delete episodes, split datasets, merge datasets,
and remove features. When new_repo_id is specified, creates a new dataset.
remove features, and convert image datasets to video format.
When new_repo_id is specified, creates a new dataset.
Usage Examples:
@@ -65,6 +66,25 @@ Remove camera feature:
--operation.type remove_feature \
--operation.feature_names "['observation.images.top']"
Convert image dataset to video format (saves locally):
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_to_video \
--operation.output_dir /path/to/output/pusht_video
Convert image dataset and save with new repo_id:
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht_image \
--new_repo_id lerobot/pusht_video \
--operation.type convert_to_video
Convert and push to hub:
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht_image \
--new_repo_id lerobot/pusht_video \
--operation.type convert_to_video \
--push_to_hub true
Using JSON config file:
python -m lerobot.scripts.lerobot_edit_dataset \
--config_path path/to/edit_config.json
@@ -72,9 +92,13 @@ Using JSON config file:
import logging
import shutil
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
import pandas as pd
from tqdm import tqdm
from lerobot.configs import parser
from lerobot.datasets.dataset_tools import (
delete_episodes,
@@ -82,8 +106,10 @@ from lerobot.datasets.dataset_tools import (
remove_feature,
split_dataset,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import HF_LEROBOT_HOME
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import write_stats, write_tasks
from lerobot.datasets.video_utils import encode_video_frames, get_video_info
from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE
from lerobot.utils.utils import init_logging
@@ -111,10 +137,23 @@ class RemoveFeatureConfig:
feature_names: list[str] | None = None
@dataclass
class ConvertToVideoConfig:
type: str = "convert_to_video"
output_dir: str | None = None
vcodec: str = "libsvtav1"
pix_fmt: str = "yuv420p"
g: int = 2
crf: int = 30
fast_decode: int = 0
episode_indices: list[int] | None = None
num_workers: int = 4
@dataclass
class EditDatasetConfig:
repo_id: str
operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig
operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertToVideoConfig
root: str | None = None
new_repo_id: str | None = None
push_to_hub: bool = False
@@ -258,6 +297,415 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
def save_episode_images_for_video(
dataset: LeRobotDataset,
imgs_dir: Path,
img_key: str,
episode_index: int,
num_workers: int = 4,
) -> None:
"""Save images from a specific episode and camera to disk for video encoding.
Args:
dataset: The LeRobot dataset to extract images from
imgs_dir: Directory to save images to
img_key: The image key (camera) to extract
episode_index: Index of the episode to save
num_workers: Number of threads for parallel image saving
"""
# Create directory
imgs_dir.mkdir(parents=True, exist_ok=True)
# Get dataset without torch format for PIL image access
hf_dataset = dataset.hf_dataset.with_format(None)
# Select only this camera's images
imgs_dataset = hf_dataset.select_columns(img_key)
# Get episode start and end indices
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
# Get all items for this episode
episode_dataset = imgs_dataset.select(range(from_idx, to_idx))
# Define function to save a single image
def save_single_image(i_item_tuple):
i, item = i_item_tuple
img = item[img_key]
# Use frame-XXXXXX.png format to match encode_video_frames expectations
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
return i
# Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png)
items = list(enumerate(episode_dataset))
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = [executor.submit(save_single_image, item) for item in items]
for future in as_completed(futures):
future.result() # This will raise any exceptions that occurred
def encode_episode_videos(
dataset: LeRobotDataset,
new_meta: LeRobotDatasetMetadata,
episode_index: int,
vcodec: str,
pix_fmt: str,
g: int,
crf: int,
fast_decode: int,
temp_dir: Path,
num_image_workers: int = 4,
) -> dict[str, dict]:
"""Encode videos for a single episode and return video metadata.
Args:
dataset: Source dataset with images
new_meta: Metadata object for the new video dataset
episode_index: Episode index to process
vcodec: Video codec
pix_fmt: Pixel format
g: Group of pictures size
crf: Constant rate factor
fast_decode: Fast decode tuning
temp_dir: Temporary directory for images
num_image_workers: Number of workers for saving images
Returns:
Dictionary mapping video keys to their metadata (chunk_index, file_index, timestamps)
"""
hf_dataset = dataset.hf_dataset.with_format(None)
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
video_metadata = {}
fps = int(dataset.fps) # Convert to int for PyAV compatibility
episode_length = dataset.meta.episodes["length"][episode_index]
episode_duration = episode_length / dataset.fps # Use original fps for duration calculation
for img_key in img_keys:
# Save images temporarily
imgs_dir = temp_dir / f"episode_{episode_index:06d}" / img_key
save_episode_images_for_video(dataset, imgs_dir, img_key, episode_index, num_image_workers)
# Determine chunk and file indices
# For simplicity, we'll put each episode in its own file
chunk_idx = episode_index // new_meta.chunks_size
file_idx = episode_index % new_meta.chunks_size
# Create video path in the new dataset structure
video_path = new_meta.root / new_meta.video_path.format(
video_key=img_key, chunk_index=chunk_idx, file_index=file_idx
)
video_path.parent.mkdir(parents=True, exist_ok=True)
# Encode video
encode_video_frames(
imgs_dir=imgs_dir,
video_path=video_path,
fps=fps,
vcodec=vcodec,
pix_fmt=pix_fmt,
g=g,
crf=crf,
fast_decode=fast_decode,
overwrite=True,
)
# Clean up temporary images
shutil.rmtree(imgs_dir)
# Store video metadata
video_metadata[img_key] = {
f"videos/{img_key}/chunk_index": chunk_idx,
f"videos/{img_key}/file_index": file_idx,
f"videos/{img_key}/from_timestamp": 0.0,
f"videos/{img_key}/to_timestamp": episode_duration,
}
return video_metadata
def convert_dataset_to_videos(
dataset: LeRobotDataset,
output_dir: Path,
repo_id: str | None = None,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
g: int = 2,
crf: int = 30,
fast_decode: int = 0,
episode_indices: list[int] | None = None,
num_workers: int = 4,
) -> LeRobotDataset:
"""Convert image-based dataset to video-based dataset.
Creates a new LeRobotDataset with videos instead of images, following the proper
LeRobot dataset structure with videos stored in chunked MP4 files.
Args:
dataset: The source LeRobot dataset with images
output_dir: Directory to save the new video dataset
repo_id: Repository ID for the new dataset (default: original_id + "_video")
vcodec: Video codec (default: libsvtav1)
pix_fmt: Pixel format (default: yuv420p)
g: Group of pictures size (default: 2)
crf: Constant rate factor (default: 30)
fast_decode: Fast decode tuning (default: 0)
episode_indices: List of episode indices to convert (None = all episodes)
num_workers: Number of threads for parallel processing (default: 4)
Returns:
New LeRobotDataset with videos
"""
# Check that it's an image dataset
if len(dataset.meta.video_keys) > 0:
raise ValueError(
f"This operation is for image datasets only. Video dataset provided: {dataset.repo_id}"
)
# Get all image keys
hf_dataset = dataset.hf_dataset.with_format(None)
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
if len(img_keys) == 0:
raise ValueError(f"No image keys found in dataset {dataset.repo_id}")
# Determine which episodes to process
if episode_indices is None:
episode_indices = list(range(dataset.meta.total_episodes))
if repo_id is None:
repo_id = f"{dataset.repo_id}_video"
logging.info(
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
)
logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}")
# Create new features dict, converting image features to video features
new_features = {}
for key, value in dataset.meta.features.items():
if key not in img_keys:
new_features[key] = value
else:
# Convert image key to video format
new_features[key] = value.copy()
new_features[key]["dtype"] = "video" # Change dtype from "image" to "video"
# Video info will be updated after episodes are encoded
# Create new metadata for video dataset
new_meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
fps=dataset.meta.fps,
features=new_features,
robot_type=dataset.meta.robot_type,
root=output_dir,
use_videos=True,
chunks_size=dataset.meta.chunks_size,
data_files_size_in_mb=dataset.meta.data_files_size_in_mb,
video_files_size_in_mb=dataset.meta.video_files_size_in_mb,
)
# Create temporary directory for image extraction
temp_dir = output_dir / "temp_images"
temp_dir.mkdir(parents=True, exist_ok=True)
# Process each episode
all_episode_metadata = []
try:
for ep_idx in tqdm(episode_indices, desc="Converting episodes to videos"):
# Get episode metadata from source
src_episode = dataset.meta.episodes[ep_idx]
# Encode videos for this episode
video_metadata = encode_episode_videos(
dataset=dataset,
new_meta=new_meta,
episode_index=ep_idx,
vcodec=vcodec,
pix_fmt=pix_fmt,
g=g,
crf=crf,
fast_decode=fast_decode,
temp_dir=temp_dir,
num_image_workers=num_workers,
)
# Build episode metadata
episode_meta = {
"episode_index": ep_idx,
"length": src_episode["length"],
"dataset_from_index": ep_idx * src_episode["length"],
"dataset_to_index": (ep_idx + 1) * src_episode["length"],
}
# Add video metadata
for img_key in img_keys:
episode_meta.update(video_metadata[img_key])
# Add data chunk/file info (using same structure as source)
if "data/chunk_index" in src_episode:
episode_meta["data/chunk_index"] = src_episode["data/chunk_index"]
episode_meta["data/file_index"] = src_episode["data/file_index"]
all_episode_metadata.append(episode_meta)
# Copy and transform data files (removing image columns)
_copy_data_without_images(dataset, new_meta, episode_indices, img_keys)
# Save episode metadata
episodes_df = pd.DataFrame(all_episode_metadata)
episodes_path = new_meta.root / "meta" / "episodes" / "chunk-000" / "file-000.parquet"
episodes_path.parent.mkdir(parents=True, exist_ok=True)
episodes_df.to_parquet(episodes_path, index=False)
# Update metadata info
new_meta.info["total_episodes"] = len(episode_indices)
new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata)
new_meta.info["total_tasks"] = dataset.meta.total_tasks
new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"}
# Update video info for all image keys (now videos)
# We need to manually set video info since update_video_info() checks video_keys first
for img_key in img_keys:
if not new_meta.features[img_key].get("info", None):
video_path = new_meta.root / new_meta.video_path.format(
video_key=img_key, chunk_index=0, file_index=0
)
new_meta.info["features"][img_key]["info"] = get_video_info(video_path)
from lerobot.datasets.utils import write_info
write_info(new_meta.info, new_meta.root)
# Copy stats and tasks
if dataset.meta.stats is not None:
# Remove image stats
new_stats = {k: v for k, v in dataset.meta.stats.items() if k not in img_keys}
write_stats(new_stats, new_meta.root)
if dataset.meta.tasks is not None:
write_tasks(dataset.meta.tasks, new_meta.root)
finally:
# Clean up temporary directory
if temp_dir.exists():
shutil.rmtree(temp_dir)
logging.info(f"✓ Completed converting {dataset.repo_id} to video format")
logging.info(f"New dataset saved to: {output_dir}")
# Return new dataset
return LeRobotDataset(repo_id=repo_id, root=output_dir)
def _copy_data_without_images(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
episode_indices: list[int],
img_keys: list[str],
) -> None:
"""Copy data files without image columns.
Args:
src_dataset: Source dataset
dst_meta: Destination metadata
episode_indices: Episodes to include
img_keys: Image keys to remove
"""
from lerobot.datasets.utils import DATA_DIR
data_dir = src_dataset.root / DATA_DIR
parquet_files = sorted(data_dir.glob("*/*.parquet"))
if not parquet_files:
raise ValueError(f"No parquet files found in {data_dir}")
episode_set = set(episode_indices)
for src_path in tqdm(parquet_files, desc="Processing data files"):
df = pd.read_parquet(src_path).reset_index(drop=True)
# Filter to only include selected episodes
df = df[df["episode_index"].isin(episode_set)].copy()
if len(df) == 0:
continue
# Remove image columns
columns_to_drop = [col for col in img_keys if col in df.columns]
if columns_to_drop:
df = df.drop(columns=columns_to_drop)
# Get chunk and file indices from path
relative_path = src_path.relative_to(src_dataset.root)
chunk_dir = relative_path.parts[1]
file_name = relative_path.parts[2]
chunk_idx = int(chunk_dir.split("-")[1])
file_idx = int(file_name.split("-")[1].split(".")[0])
# Write to destination without pandas index
dst_path = dst_meta.root / f"data/chunk-{chunk_idx:03d}/file-{file_idx:03d}.parquet"
dst_path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(dst_path, index=False)
def handle_convert_to_video(cfg: EditDatasetConfig) -> None:
# Note: Parser may create any config type with the right fields, so we access fields directly
# instead of checking isinstance()
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
# Determine output directory and repo_id
# Priority: 1) new_repo_id, 2) operation.output_dir, 3) auto-generated name
output_dir_config = getattr(cfg.operation, "output_dir", None)
if cfg.new_repo_id:
# Use new_repo_id for both local storage and hub push
output_repo_id = cfg.new_repo_id
output_dir = Path(cfg.root) / cfg.new_repo_id if cfg.root else HF_LEROBOT_HOME / cfg.new_repo_id
logging.info(f"Saving to new dataset: {cfg.new_repo_id}")
elif output_dir_config:
# Use custom output directory for local-only storage
output_dir = Path(output_dir_config)
# Extract repo name from output_dir for the dataset
output_repo_id = output_dir.name
logging.info(f"Saving to local directory: {output_dir}")
else:
# Auto-generate name: append "_video" to original repo_id
output_repo_id = f"{cfg.repo_id}_video"
output_dir = Path(cfg.root) / output_repo_id if cfg.root else HF_LEROBOT_HOME / output_repo_id
logging.info(f"Saving to auto-generated location: {output_dir}")
logging.info(f"Converting dataset {cfg.repo_id} to video format")
new_dataset = convert_dataset_to_videos(
dataset=dataset,
output_dir=output_dir,
repo_id=output_repo_id,
vcodec=getattr(cfg.operation, "vcodec", "libsvtav1"),
pix_fmt=getattr(cfg.operation, "pix_fmt", "yuv420p"),
g=getattr(cfg.operation, "g", 2),
crf=getattr(cfg.operation, "crf", 30),
fast_decode=getattr(cfg.operation, "fast_decode", 0),
episode_indices=getattr(cfg.operation, "episode_indices", None),
num_workers=getattr(cfg.operation, "num_workers", 4),
)
logging.info("Video dataset created successfully!")
logging.info(f"Location: {output_dir}")
logging.info(f"Episodes: {new_dataset.meta.total_episodes}")
logging.info(f"Frames: {new_dataset.meta.total_frames}")
if cfg.push_to_hub:
logging.info(f"Pushing to hub as {output_repo_id}...")
new_dataset.push_to_hub()
logging.info("✓ Successfully pushed to hub!")
else:
logging.info("Dataset saved locally (not pushed to hub)")
@parser.wrap()
def edit_dataset(cfg: EditDatasetConfig) -> None:
operation_type = cfg.operation.type
@@ -270,10 +718,12 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
handle_merge(cfg)
elif operation_type == "remove_feature":
handle_remove_feature(cfg)
elif operation_type == "convert_to_video":
handle_convert_to_video(cfg)
else:
raise ValueError(
f"Unknown operation type: {operation_type}\n"
f"Available operations: delete_episodes, split, merge, remove_feature"
f"Available operations: delete_episodes, split, merge, remove_feature, convert_to_video"
)
-1
View File
@@ -173,7 +173,6 @@ def rollout(
observation = env_preprocessor(observation)
observation = preprocessor(observation)
with torch.inference_mode():
action = policy.select_action(observation)
action = postprocessor(action)
@@ -46,6 +46,7 @@ from lerobot.robots import ( # noqa: F401
RobotConfig,
koch_follower,
make_robot_from_config,
omx_follower,
so100_follower,
so101_follower,
)
@@ -54,6 +55,7 @@ from lerobot.teleoperators import ( # noqa: F401
gamepad,
koch_leader,
make_teleoperator_from_config,
omx_leader,
so100_leader,
so101_leader,
)
+26 -10
View File
@@ -97,9 +97,11 @@ from lerobot.robots import ( # noqa: F401
hope_jr,
koch_follower,
make_robot_from_config,
omx_follower,
so100_follower,
so101_follower,
)
from lerobot.robots.unitree_g1 import config_unitree_g1 # noqa: F401
from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
@@ -107,6 +109,7 @@ from lerobot.teleoperators import ( # noqa: F401
homunculus,
koch_leader,
make_teleoperator_from_config,
omx_leader,
so100_leader,
so101_leader,
)
@@ -195,9 +198,8 @@ class RecordConfig:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
if self.teleop is None and self.policy is None:
raise ValueError("Choose a policy, a teleoperator or both to control the robot")
# Note: teleop and policy can both be None for robots with built-in control (e.g. unitree_g1)
# This is validated in record() after the robot is instantiated
@classmethod
def __get_path_fields__(cls) -> list[str]:
@@ -270,7 +272,12 @@ def record_loop(
for t in teleop
if isinstance(
t,
(so100_leader.SO100Leader | so101_leader.SO101Leader | koch_leader.KochLeader),
(
so100_leader.SO100Leader
| so101_leader.SO101Leader
| koch_leader.KochLeader
| omx_leader.OmxLeader
),
)
),
None,
@@ -333,6 +340,13 @@ def record_loop(
base_action = robot._from_keyboard_to_base_action(keyboard_action)
act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
act_processed_teleop = teleop_action_processor((act, obs))
elif policy is None and teleop is None and dataset is not None:
# Observation-only recording (robot controls itself, e.g. unitree_g1)
# Record observations, extract action-relevant values (positions) from obs
# Filter obs_processed to only include keys that match action_features
action_keys = set(robot.action_features.keys())
action_values = {k: v for k, v in obs_processed.items() if k in action_keys}
robot_action_to_send = None
else:
logging.info(
"No policy or teleoperator provided, skipping action generation."
@@ -345,15 +359,17 @@ def record_loop(
if policy is not None and act_processed_policy is not None:
action_values = act_processed_policy
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
else:
elif teleop is not None:
action_values = act_processed_teleop
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
# else: observation-only mode, action_values already set above
# Send action to robot
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset. action = postprocessor.process(action)
# TODO(steven, pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
_sent_action = robot.send_action(robot_action_to_send)
# Send action to robot (skip if observation-only mode)
if robot_action_to_send is not None:
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset. action = postprocessor.process(action)
# TODO(steven, pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
_sent_action = robot.send_action(robot_action_to_send)
# Write to dataset
if dataset is not None:
+1
View File
@@ -58,6 +58,7 @@ from lerobot.robots import ( # noqa: F401
hope_jr,
koch_follower,
make_robot_from_config,
omx_follower,
so100_follower,
so101_follower,
)
@@ -33,6 +33,7 @@ from lerobot.robots import ( # noqa: F401
koch_follower,
lekiwi,
make_robot_from_config,
omx_follower,
so100_follower,
so101_follower,
)
@@ -40,6 +41,7 @@ from lerobot.teleoperators import ( # noqa: F401
TeleoperatorConfig,
koch_leader,
make_teleoperator_from_config,
omx_leader,
so100_leader,
so101_leader,
)
@@ -47,6 +49,8 @@ from lerobot.teleoperators import ( # noqa: F401
COMPATIBLE_DEVICES = [
"koch_follower",
"koch_leader",
"omx_follower",
"omx_leader",
"so100_follower",
"so100_leader",
"so101_follower",
@@ -75,6 +75,7 @@ from lerobot.robots import ( # noqa: F401
hope_jr,
koch_follower,
make_robot_from_config,
omx_follower,
so100_follower,
so101_follower,
)
@@ -87,6 +88,7 @@ from lerobot.teleoperators import ( # noqa: F401
keyboard,
koch_leader,
make_teleoperator_from_config,
omx_leader,
so100_leader,
so101_leader,
)
+1 -9
View File
@@ -62,7 +62,6 @@ def update_policy(
accelerator: Accelerator,
lr_scheduler=None,
lock=None,
postprocessor = None,
) -> tuple[MetricsTracker, dict]:
"""
Performs a single training step to update the policy's weights.
@@ -91,10 +90,6 @@ def update_policy(
# Let accelerator handle mixed precision
with accelerator.autocast():
loss, output_dict = policy.forward(batch)
# action = policy.predict_action_chunk(batch)
# if postprocessor is not None:
# action = postprocessor(action)
# breakpoint()
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
# Use accelerator's backward method
@@ -156,7 +151,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
from accelerate.utils import DistributedDataParallelKwargs
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(step_scheduler_with_optimizer=False, gradient_accumulation_steps=4, kwargs_handlers=[ddp_kwargs])
accelerator = Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs])
init_logging(accelerator=accelerator)
@@ -211,7 +206,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
ds_meta=dataset.meta,
rename_map=cfg.rename_map,
)
# Wait for all processes to finish policy creation before continuing
accelerator.wait_for_everyone()
@@ -250,7 +244,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
**postprocessor_kwargs,
)
if is_main_process:
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
@@ -350,7 +343,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
cfg.optimizer.grad_clip_norm,
accelerator=accelerator,
lr_scheduler=lr_scheduler,
postprocessor=postprocessor,
)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -0,0 +1,18 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_omx_leader import OmxLeaderConfig
from .omx_leader import OmxLeader
@@ -0,0 +1,30 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..config import TeleoperatorConfig
@TeleoperatorConfig.register_subclass("omx_leader")
@dataclass
class OmxLeaderConfig(TeleoperatorConfig):
# Port to connect to the arm
port: str
# Sets the arm in torque mode with the gripper motor set to this value. This makes it possible to squeeze
# the gripper and have it spring back to an open position on its own.
gripper_open_pos: float = 37.0
@@ -0,0 +1,165 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.dynamixel import (
DriveMode,
DynamixelMotorsBus,
OperatingMode,
)
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..teleoperator import Teleoperator
from .config_omx_leader import OmxLeaderConfig
logger = logging.getLogger(__name__)
class OmxLeader(Teleoperator):
"""
- [OMX](https://github.com/ROBOTIS-GIT/open_manipulator),
expansion, developed by Woojin Wie and Junha Cha from [ROBOTIS](https://ai.robotis.com/)
"""
config_class = OmxLeaderConfig
name = "omx_leader"
def __init__(self, config: OmxLeaderConfig):
super().__init__(config)
self.config = config
self.bus = DynamixelMotorsBus(
port=self.config.port,
motors={
"shoulder_pan": Motor(1, "xl330-m288", MotorNormMode.RANGE_M100_100),
"shoulder_lift": Motor(2, "xl330-m288", MotorNormMode.RANGE_M100_100),
"elbow_flex": Motor(3, "xl330-m288", MotorNormMode.RANGE_M100_100),
"wrist_flex": Motor(4, "xl330-m288", MotorNormMode.RANGE_M100_100),
"wrist_roll": Motor(5, "xl330-m288", MotorNormMode.RANGE_M100_100),
"gripper": Motor(6, "xl330-m077", MotorNormMode.RANGE_0_100),
},
calibration=self.calibration,
)
@property
def action_features(self) -> dict[str, type]:
return {f"{motor}.pos": float for motor in self.bus.motors}
@property
def feedback_features(self) -> dict[str, type]:
return {}
@property
def is_connected(self) -> bool:
return self.bus.is_connected
def connect(self, calibrate: bool = True) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
self.bus.connect()
if not self.is_calibrated and calibrate:
logger.info(
"Mismatch between calibration values in the motor and the calibration file or no calibration file found"
)
self.calibrate()
self.configure()
logger.info(f"{self} connected.")
@property
def is_calibrated(self) -> bool:
return self.bus.is_calibrated
def calibrate(self) -> None:
self.bus.disable_torque()
logger.info(f"\nUsing factory default calibration values for {self}")
logger.info(f"\nWriting default configuration of {self} to the motors")
for motor in self.bus.motors:
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
for motor in self.bus.motors:
if motor == "gripper":
self.bus.write("Drive_Mode", motor, DriveMode.INVERTED.value)
else:
self.bus.write("Drive_Mode", motor, DriveMode.NON_INVERTED.value)
drive_modes = {motor: 1 if motor == "gripper" else 0 for motor in self.bus.motors}
self.calibration = {}
for motor, m in self.bus.motors.items():
self.calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=drive_modes[motor],
homing_offset=0,
range_min=0,
range_max=4095,
)
self.bus.write_calibration(self.calibration)
self._save_calibration()
logger.info(f"Calibration saved to {self.calibration_fpath}")
def configure(self) -> None:
self.bus.disable_torque()
self.bus.configure_motors()
for motor in self.bus.motors:
if motor != "gripper":
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos
# can't rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while
# assembling the arm, you could end up with a servo with a position 0 or 4095 at a crucial
# point
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
# Use 'position control current based' for gripper to be limited by the limit of the current.
# For the follower gripper, it means it can grasp an object without forcing too much even tho,
# its goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch).
# For the leader gripper, it means we can use it as a physical trigger, since we can force with our finger
# to make it move, and it will move back to its original target position when we release the force.
self.bus.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value)
# Set gripper's goal pos in current position mode so that we can use it as a trigger.
self.bus.enable_torque("gripper")
if self.is_calibrated:
self.bus.write("Goal_Position", "gripper", self.config.gripper_open_pos)
def setup_motors(self) -> None:
for motor in reversed(self.bus.motors):
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
self.bus.setup_motor(motor)
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
def get_action(self) -> dict[str, float]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
start = time.perf_counter()
action = self.bus.sync_read("Present_Position")
action = {f"{motor}.pos": val for motor, val in action.items()}
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
return action
def send_feedback(self, feedback: dict[str, float]) -> None:
# TODO(rcadene, aliberts): Implement force feedback
raise NotImplementedError
def disconnect(self) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self.bus.disconnect()
logger.info(f"{self} disconnected.")
+4
View File
@@ -41,6 +41,10 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
from .koch_leader import KochLeader
return KochLeader(config)
elif config.type == "omx_leader":
from .omx_leader import OmxLeader
return OmxLeader(config)
elif config.type == "so100_leader":
from .so100_leader import SO100Leader
+1 -8
View File
@@ -26,15 +26,8 @@ OBS_IMAGES = OBS_IMAGE + "s"
OBS_LANGUAGE = OBS_STR + ".language"
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
OBS_LANGUAGE_HIGH_LEVEL_TASK = OBS_STR + ".user_prompt"
OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS = OBS_LANGUAGE_HIGH_LEVEL_TASK + ".tokens"
OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK = OBS_LANGUAGE_HIGH_LEVEL_TASK + ".attention_mask"
OBS_LANGUAGE_SUBTASK_ONLY = OBS_STR + ".subtask"
OBS_LANGUAGE_SUBTASK_ONLY_TOKENS = OBS_LANGUAGE_SUBTASK_ONLY + ".tokens"
OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK_ONLY + ".attention_mask"
ACTION = "action"
ACTION_TOKENS = ACTION + ".tokens"
ACTION_TOKEN_MASK = ACTION + ".token_mask"
REWARD = "next.reward"
TRUNCATED = "next.truncated"
DONE = "next.done"
+105
View File
@@ -29,6 +29,7 @@ from lerobot.datasets.dataset_tools import (
remove_feature,
split_dataset,
)
from lerobot.scripts.lerobot_edit_dataset import convert_dataset_to_videos
@pytest.fixture
@@ -1047,3 +1048,107 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path):
assert new_chunk_indices == original_chunk_indices, "Chunk indices should be preserved"
assert new_file_indices == original_file_indices, "File indices should be preserved"
assert "reward" in modified_dataset.meta.features
def test_convert_dataset_to_videos(tmp_path):
"""Test converting lerobot/pusht_image dataset to video format."""
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Load the actual lerobot/pusht_image dataset (only first 2 episodes for speed)
source_dataset = LeRobotDataset("lerobot/pusht_image", episodes=[0, 1])
output_dir = tmp_path / "pusht_video"
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(output_dir)
# Verify source dataset has images, not videos
assert len(source_dataset.meta.video_keys) == 0
assert "observation.image" in source_dataset.meta.features
# Convert to video dataset (only first 2 episodes for speed)
video_dataset = convert_dataset_to_videos(
dataset=source_dataset,
output_dir=output_dir,
repo_id="lerobot/pusht_video",
vcodec="libsvtav1",
pix_fmt="yuv420p",
g=2,
crf=30,
episode_indices=[0, 1],
num_workers=2,
)
# Verify new dataset has videos
assert len(video_dataset.meta.video_keys) > 0
assert "observation.image" in video_dataset.meta.video_keys
# Verify correct number of episodes and frames (2 episodes)
assert video_dataset.meta.total_episodes == 2
# Compare against the actual number of frames in the loaded episodes, not metadata total
assert len(video_dataset) == len(source_dataset)
# Verify video files exist
for ep_idx in range(video_dataset.meta.total_episodes):
for video_key in video_dataset.meta.video_keys:
video_path = video_dataset.root / video_dataset.meta.get_video_file_path(ep_idx, video_key)
assert video_path.exists(), f"Video file should exist: {video_path}"
# Verify we can load the dataset and access it
assert len(video_dataset) == video_dataset.meta.total_frames
# Test that we can actually get an item from the video dataset
item = video_dataset[0]
assert "observation.image" in item
assert "action" in item
# Cleanup
import shutil
if output_dir.exists():
shutil.rmtree(output_dir)
def test_convert_dataset_to_videos_subset_episodes(tmp_path):
"""Test converting only specific episodes from lerobot/pusht_image to video format."""
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Load the actual lerobot/pusht_image dataset (only first 3 episodes)
source_dataset = LeRobotDataset("lerobot/pusht_image", episodes=[0, 1, 2])
output_dir = tmp_path / "pusht_video_subset"
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(output_dir)
# Convert only episode 0 to video (subset of loaded episodes)
episode_indices = [0]
video_dataset = convert_dataset_to_videos(
dataset=source_dataset,
output_dir=output_dir,
repo_id="lerobot/pusht_video_subset",
episode_indices=episode_indices,
num_workers=2,
)
# Verify correct number of episodes
assert video_dataset.meta.total_episodes == len(episode_indices)
# Verify video files exist for selected episodes
assert len(video_dataset.meta.video_keys) > 0
assert "observation.image" in video_dataset.meta.video_keys
# Cleanup
import shutil
if output_dir.exists():
shutil.rmtree(output_dir)
@@ -266,7 +266,7 @@ def create_original_observation_with_openpi_preprocessing(batch):
elif len(tasks) == 1:
tasks = tasks * batch_size
# Use pi05 state and input tokenizer logic (same as Pi05PrepareStateAndLanguageTokenizerProcessorStep)
# Use pi05 state and input tokenizer logic (same as Pi05PrepareStateTokenizerProcessorStep)
state = batch["observation.state"]
state = deepcopy(state)