mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 00:29:52 +00:00
Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b2d3186011 | |||
| 5865170d36 | |||
| 2dd366436e | |||
| 5f15232271 | |||
| bc38261321 | |||
| aaf3707058 | |||
| 89bd58a9a2 | |||
| b22e0315b0 | |||
| fcbf550952 | |||
| af036ce57e | |||
| 1c388c0002 | |||
| 51d3822d75 |
+42
-42
@@ -28,9 +28,9 @@ We don't expect the same optimal settings for a dataset of images from a simulat
|
||||
For these reasons, we run this benchmark on four representative datasets:
|
||||
|
||||
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
|
||||
- `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
|
||||
- `aliberts/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
|
||||
- `aliberts/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
|
||||
- `lerobot/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
|
||||
- `lerobot/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
|
||||
- `lerobot/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
|
||||
|
||||
Note: The datasets used for this benchmark need to be image datasets, not video datasets.
|
||||
|
||||
@@ -179,7 +179,7 @@ python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
aliberts/aloha_mobile_shrimp_image \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
--vcodec libx264 libx265 \
|
||||
--pix-fmt yuv444p yuv420p \
|
||||
--g 2 20 None \
|
||||
@@ -203,9 +203,9 @@ python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
aliberts/aloha_mobile_shrimp_image \
|
||||
aliberts/paris_street \
|
||||
aliberts/kitchen \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
lerobot/paris_street \
|
||||
lerobot/kitchen \
|
||||
--vcodec libx264 libx265 \
|
||||
--pix-fmt yuv444p yuv420p \
|
||||
--g 1 2 3 4 5 6 10 15 20 40 None \
|
||||
@@ -221,9 +221,9 @@ python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
aliberts/aloha_mobile_shrimp_image \
|
||||
aliberts/paris_street \
|
||||
aliberts/kitchen \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
lerobot/paris_street \
|
||||
lerobot/kitchen \
|
||||
--vcodec libsvtav1 \
|
||||
--pix-fmt yuv420p \
|
||||
--g 1 2 3 4 5 6 10 15 20 40 None \
|
||||
@@ -252,37 +252,37 @@ Since we're using av1 encoding, we're choosing the `pyav` decoder as `video_read
|
||||
|
||||
These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_frames` and `backend=pyav`
|
||||
|
||||
| video_images_size_ratio | vcodec | pix_fmt | | | |
|
||||
| ---------------------------------- | ---------- | ------- | --------- | --------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
|
||||
| aliberts/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
|
||||
| aliberts/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
|
||||
| aliberts/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
|
||||
| video_images_size_ratio | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | ---------- | ------- | --------- | --------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
|
||||
| lerobot/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
|
||||
| lerobot/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
|
||||
| lerobot/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
|
||||
|
||||
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
|
||||
| ---------------------------------- | ------- | ------- | -------- | ------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
|
||||
| aliberts/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
|
||||
| aliberts/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
|
||||
| aliberts/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
|
||||
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | ------- | ------- | -------- | ------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
|
||||
| lerobot/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
|
||||
| lerobot/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
|
||||
| lerobot/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
|
||||
|
||||
| | | vcodec | pix_fmt | | | |
|
||||
| ---------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
|
||||
| | | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
|
||||
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
|
||||
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
|
||||
| aliberts/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
|
||||
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
|
||||
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
|
||||
| aliberts/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
|
||||
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
|
||||
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
|
||||
| aliberts/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
|
||||
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
|
||||
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
|
||||
| | | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
|
||||
| | | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
|
||||
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
|
||||
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
|
||||
| lerobot/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
|
||||
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
|
||||
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
|
||||
| lerobot/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
|
||||
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
|
||||
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
|
||||
| lerobot/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
|
||||
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
|
||||
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
|
||||
|
||||
@@ -185,7 +185,7 @@ echo $HF_USER
|
||||
Use the standard recording command:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_record.py \
|
||||
lerobot-record \
|
||||
--robot.type=earthrover_mini_plus \
|
||||
--teleop.type=keyboard_rover \
|
||||
--dataset.repo_id=your_username/dataset_name \
|
||||
|
||||
@@ -224,7 +224,7 @@ lerobot-record \
|
||||
--teleop.port=/dev/tty.usbmodem1201 \
|
||||
--teleop.id=right \
|
||||
--teleop.side=right \
|
||||
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
|
||||
--dataset.repo_id=<USER>/hand_record_test_with_video_data \
|
||||
--dataset.single_task="Hand recording test with video data" \
|
||||
--dataset.num_episodes=1 \
|
||||
--dataset.episode_time_s=5 \
|
||||
@@ -241,7 +241,7 @@ lerobot-replay \
|
||||
--robot.port=/dev/tty.usbmodem58760432281 \
|
||||
--robot.id=right \
|
||||
--robot.side=right \
|
||||
--dataset.repo_id=nepyope/hand_record_test_with_camera \
|
||||
--dataset.repo_id=<USER>/hand_record_test_with_camera \
|
||||
--dataset.episode=0
|
||||
```
|
||||
|
||||
@@ -249,13 +249,13 @@ lerobot-replay \
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
|
||||
--dataset.repo_id=<USER>/hand_record_test_with_video_data \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/hopejr_hand \
|
||||
--job_name=hopejr \
|
||||
--policy.device=mps \
|
||||
--wandb.enable=true \
|
||||
--policy.repo_id=nepyope/hand_test_policy
|
||||
--policy.repo_id=<USER>/hand_test_policy
|
||||
```
|
||||
|
||||
### Evaluate
|
||||
@@ -270,7 +270,7 @@ lerobot-record \
|
||||
--robot.side=right \
|
||||
--robot.cameras='{"main": {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}}' \
|
||||
--display_data=false \
|
||||
--dataset.repo_id=nepyope/eval_hopejr \
|
||||
--dataset.repo_id=<USER>/eval_hopejr \
|
||||
--dataset.single_task="Evaluate hopejr hand policy" \
|
||||
--dataset.num_episodes=10 \
|
||||
--policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model
|
||||
|
||||
+1
-1
@@ -60,7 +60,7 @@ policy.type=pi0
|
||||
For training π₀, you can use the standard LeRobot training script with the appropriate configuration:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=pi0 \
|
||||
--output_dir=./outputs/pi0_training \
|
||||
|
||||
@@ -56,7 +56,7 @@ policy.type=pi05
|
||||
Here's a complete training command for finetuning the base π₀.₅ model on your own dataset:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py\
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=pi05 \
|
||||
--output_dir=./outputs/pi05_training \
|
||||
|
||||
@@ -269,7 +269,7 @@ This generates visualizations showing video frames with subtask boundaries overl
|
||||
Train with **no annotations** - uses linear progress from 0 to 1:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=sarm \
|
||||
--policy.annotation_mode=single_stage \
|
||||
@@ -288,7 +288,7 @@ python src/lerobot/scripts/lerobot_train.py \
|
||||
Train with **dense annotations only** (sparse auto-generated):
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=sarm \
|
||||
--policy.annotation_mode=dense_only \
|
||||
@@ -307,7 +307,7 @@ python src/lerobot/scripts/lerobot_train.py \
|
||||
Train with **both sparse and dense annotations**:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=sarm \
|
||||
--policy.annotation_mode=dual \
|
||||
@@ -468,7 +468,7 @@ This script:
|
||||
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=pi0 \
|
||||
--use_rabc=true \
|
||||
|
||||
@@ -216,7 +216,7 @@ lerobot-teleoperate \
|
||||
### Record Dataset in Simulation
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.lerobot_record \
|
||||
lerobot-record \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=true \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
@@ -266,7 +266,7 @@ lerobot-teleoperate \
|
||||
### Record Dataset on Real Robot
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.lerobot_record \
|
||||
lerobot-record \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=false \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
|
||||
@@ -12,6 +12,7 @@ LeRobot provides several utilities for manipulating datasets:
|
||||
4. **Add Features** - Add new features to a dataset
|
||||
5. **Remove Features** - Remove features from a dataset
|
||||
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage
|
||||
7. **Show the Info of Datasets** - Show the summary of datasets information such as number of episode etc.
|
||||
|
||||
The core implementation is in `lerobot.datasets.dataset_tools`.
|
||||
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
|
||||
@@ -156,6 +157,30 @@ lerobot-edit-dataset \
|
||||
|
||||
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved.
|
||||
|
||||
### Show the information of datasets
|
||||
|
||||
Show dataset information such as the number of episodes, number of frames, file size, and so on.
|
||||
No change will be made to the dataset
|
||||
|
||||
```bash
|
||||
|
||||
# Show dataset information without feature details
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type info \
|
||||
|
||||
# Show dataset information with feature details
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type info \
|
||||
--operation.show_features true
|
||||
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
|
||||
- `--operation.show_features`: Whether to also display dataset feature details (default: false).
|
||||
|
||||
### Push to Hub
|
||||
|
||||
Add the `--push_to_hub true` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
|
||||
|
||||
@@ -45,7 +45,7 @@ policy.type=wall_x
|
||||
For training WallX, you can use the standard LeRobot training script with the appropriate configuration:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=wall_x \
|
||||
--output_dir=./outputs/wallx_training \
|
||||
|
||||
@@ -154,7 +154,7 @@ lerobot-train \
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=pepijn223/bimanual-so100-handover-cube \
|
||||
--dataset.repo_id=<USER>/bimanual-so100-handover-cube \
|
||||
--output_dir=./outputs/xvla_bimanual \
|
||||
--job_name=xvla_so101_training \
|
||||
--policy.path="lerobot/xvla-base" \
|
||||
|
||||
@@ -22,7 +22,7 @@ lerobot-replay \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=black \
|
||||
--dataset.repo_id=aliberts/record-test \
|
||||
--dataset.repo_id=<USER>/record-test \
|
||||
--dataset.episode=2
|
||||
```
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,726 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Mirror a bimanual dataset in parallel with DataTrove + SLURM, then double it.
|
||||
|
||||
Workflow:
|
||||
1) Split source episodes across `num_shards` ranks and mirror each shard in parallel.
|
||||
2) Aggregate mirrored shards into one mirrored dataset.
|
||||
3) Aggregate [original, mirrored] into a final doubled dataset.
|
||||
|
||||
Example:
|
||||
python examples/port_datasets/slurm_mirror_dataset.py \
|
||||
--repo-id=pepijn/openarm_bimanual \
|
||||
--output-repo-id=pepijn/openarm_bimanual_doubled \
|
||||
--partition=hopper-cpu \
|
||||
--num-shards=256 \
|
||||
--workers=64 \
|
||||
--cpus-per-task=8 \
|
||||
--mem-per-cpu=4G
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import DEFAULT_FEATURES
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
OPENARM_MIRRORING_MASK = {
|
||||
"joint_1": -1,
|
||||
"joint_2": -1,
|
||||
"joint_3": -1,
|
||||
"joint_4": 1,
|
||||
"joint_5": -1,
|
||||
"joint_6": -1,
|
||||
"joint_7": -1,
|
||||
"gripper": 1,
|
||||
}
|
||||
|
||||
|
||||
def get_mirroring_mask(robot_type: str | None) -> dict[str, int]:
|
||||
if robot_type in ["bi_openarm_follower", "openarm_follower", "bi_openarms_follower", "openarms_follower"]:
|
||||
return OPENARM_MIRRORING_MASK
|
||||
raise ValueError(f"Unknown robot type: {robot_type}. Add a mirroring mask for this robot.")
|
||||
|
||||
|
||||
def swap_left_right_name(name: str) -> str:
    """Swap every ``left_`` substring with ``right_`` (and vice versa) in `name`.

    A temporary placeholder prevents the two replacements from cascading into
    each other (otherwise ``left_`` -> ``right_`` -> ``left_`` again).
    """
    sentinel = "LEFT_PLACEHOLDER_"
    staged = name.replace("left_", sentinel)
    staged = staged.replace("right_", "left_")
    return staged.replace(sentinel, "right_")
|
||||
|
||||
|
||||
def mirror_feature_names(names: list[str]) -> tuple[list[str], dict[int, int]]:
    """Mirror a list of axis names and map each old index to its mirrored index.

    Args:
        names: Axis names in their original order.

    Returns:
        ``(mirrored_names, old_to_new_idx)`` where ``mirrored_names[i]`` is
        ``names[i]`` with left/right swapped, and ``old_to_new_idx[i]`` is the
        index of that swapped name in ``mirrored_names`` (first occurrence,
        matching ``list.index`` semantics for duplicate names).
    """
    mirrored_names = [swap_left_right_name(n) for n in names]
    # Precompute first-occurrence positions once instead of calling
    # list.index() inside the loop: O(n) instead of O(n^2).
    first_pos: dict[str, int] = {}
    for idx, mirrored in enumerate(mirrored_names):
        first_pos.setdefault(mirrored, idx)
    old_to_new_idx = {
        old_idx: first_pos[swap_left_right_name(old_name)]
        for old_idx, old_name in enumerate(names)
    }
    return mirrored_names, old_to_new_idx
|
||||
|
||||
|
||||
def _get_axis_names(feature: dict[str, Any]) -> list[str] | None:
|
||||
names = feature.get("names")
|
||||
if isinstance(names, list):
|
||||
return names
|
||||
if isinstance(names, dict):
|
||||
axes = names.get("axes")
|
||||
if isinstance(axes, list):
|
||||
return axes
|
||||
return None
|
||||
|
||||
|
||||
def _to_numpy(value: Any) -> Any:
|
||||
if isinstance(value, np.ndarray):
|
||||
return value
|
||||
if hasattr(value, "detach"):
|
||||
return value.detach().cpu().numpy()
|
||||
if hasattr(value, "cpu") and hasattr(value, "numpy"):
|
||||
return value.cpu().numpy()
|
||||
if hasattr(value, "numpy"):
|
||||
return value.numpy()
|
||||
return value
|
||||
|
||||
|
||||
def apply_mirroring_mask(value: float, axis_name: str, mirroring_mask: dict[str, int]) -> float:
    """Apply the sign flip registered for `axis_name`'s joint, defaulting to +1.

    Axis names look like ``left_joint_3.pos``: an optional side prefix, the
    joint name, then an optional dotted suffix.
    """
    stripped = axis_name
    for side_prefix in ("left_", "right_"):
        if stripped.startswith(side_prefix):
            stripped = stripped[len(side_prefix):]
            break
    joint, _, _suffix = stripped.partition(".")
    return value * mirroring_mask.get(joint, 1)
|
||||
|
||||
|
||||
def mirror_vector_feature(
    value: Any,
    feature: dict[str, Any],
    mirroring_mask: dict[str, int],
) -> Any:
    """Return `value` permuted and sign-flipped according to its mirrored axis names.

    Falls back to returning the (possibly numpy-converted) input untouched when
    the value is not a 1-D array, or when the feature spec has no axis-name
    list whose length matches the vector.
    """
    array = _to_numpy(value)
    if not (isinstance(array, np.ndarray) and array.ndim == 1):
        return array

    axis_names = _get_axis_names(feature)
    if axis_names is None or len(axis_names) != len(array):
        return array

    mirrored_names, index_mapping = mirror_feature_names(axis_names)
    result = np.zeros_like(array)
    for src_idx, dst_idx in index_mapping.items():
        result[dst_idx] = apply_mirroring_mask(
            array[src_idx], mirrored_names[dst_idx], mirroring_mask
        )
    return result
|
||||
|
||||
|
||||
def flip_horizontal(value: Any, expected_shape: list[int] | tuple[int, ...]) -> Any:
    """Flip an image array along its width axis (horizontal mirror).

    Args:
        value: Image-like value; converted to numpy when possible.
        expected_shape: The feature's declared shape. A 3-D array matching it
            exactly is treated as HWC (width is axis 1).

    Returns:
        The flipped array (as a contiguous copy), or the input unchanged when
        it is not a 3-D array.
    """
    array = _to_numpy(value)
    if not isinstance(array, np.ndarray) or array.ndim != 3:
        return array

    if array.shape == tuple(expected_shape):
        return np.flip(array, axis=1).copy()  # HWC: width is axis 1

    # Any other layout: flip the trailing axis, which is the width for CHW.
    # NOTE: the original code had an explicit CHW check here, but it unpacked
    # `expected_shape` and compared against the identical tuple already tested
    # above, so that branch was unreachable; this fallback covers CHW (and any
    # unexpected layout) with the same behavior.
    return np.flip(array, axis=-1).copy()
|
||||
|
||||
|
||||
def build_mirrored_features(features: dict[str, dict[str, Any]]) -> dict[str, dict[str, Any]]:
    """Deep-copy a feature spec with every left/right marker swapped.

    Both the feature keys and any per-axis name lists (flat ``names`` list or
    nested ``names["axes"]`` list) are renamed; everything else is preserved.
    """
    mirrored: dict[str, dict[str, Any]] = {}
    for key, feature in features.items():
        spec = copy.deepcopy(feature)
        names = spec.get("names")
        if isinstance(names, list):
            spec["names"] = [swap_left_right_name(n) for n in names]
        elif isinstance(names, dict) and isinstance(names.get("axes"), list):
            spec["names"]["axes"] = [swap_left_right_name(n) for n in names["axes"]]
        mirrored[swap_left_right_name(key)] = spec
    return mirrored
|
||||
|
||||
|
||||
def build_mirrored_frame(
    item: dict[str, Any],
    source_features: dict[str, dict[str, Any]],
    mirroring_mask: dict[str, int],
) -> dict[str, Any]:
    """Produce the mirrored counterpart of a single dataset frame.

    State/action vectors are permuted and sign-flipped, camera frames are
    flipped horizontally, all other values are converted to numpy unchanged,
    and every key gets its left/right marker swapped. Bookkeeping keys listed
    in ``DEFAULT_FEATURES`` are skipped (the dataset writer regenerates them).
    """
    vector_keys = {"action", "observation.state"}
    frame: dict[str, Any] = {}
    for key, feature in source_features.items():
        if key in DEFAULT_FEATURES:
            continue

        raw = item[key]
        if key in vector_keys:
            converted = mirror_vector_feature(raw, feature, mirroring_mask)
        elif feature["dtype"] in ("video", "image"):
            converted = flip_horizontal(raw, feature["shape"])
        else:
            converted = _to_numpy(raw)

        frame[swap_left_right_name(key)] = converted

    # The task string is side-agnostic and copied verbatim.
    frame["task"] = item["task"]
    if "timestamp" in item:
        ts = _to_numpy(item["timestamp"])
        # Collapse 0-d arrays / numpy scalars to a plain Python float.
        frame["timestamp"] = float(ts.item() if hasattr(ts, "item") else ts)
    return frame
|
||||
|
||||
|
||||
def _resolve_source_root(repo_id: str, root: Path | None) -> Path:
    """Resolve the on-disk root of the source dataset by loading its metadata."""
    return LeRobotDatasetMetadata(repo_id=repo_id, root=root).root
|
||||
|
||||
|
||||
def _get_work_dir(output_repo_id: str, work_dir: Path | None) -> Path:
    """Return the scratch directory for intermediate shards.

    When no explicit directory is given, derive one under the LeRobot cache
    from the output repo id ("/" is not filesystem-safe, so it becomes "__").
    """
    if work_dir is None:
        safe_name = output_repo_id.replace("/", "__")
        return HF_LEROBOT_HOME / "_mirror_work" / safe_name
    return work_dir
|
||||
|
||||
|
||||
def _get_shard_root(work_dir: Path, world_size: int, rank: int) -> Path:
|
||||
return work_dir / "mirrored_shards" / f"world_{world_size}_rank_{rank}"
|
||||
|
||||
|
||||
def _is_valid_dataset_root(root: Path) -> bool:
|
||||
return (root / "meta" / "info.json").exists()
|
||||
|
||||
|
||||
def mirror_shard(
    repo_id: str,
    source_root: Path,
    mirrored_repo_id: str,
    shard_root: Path,
    rank: int,
    world_size: int,
    vcodec: str,
    overwrite: bool,
) -> None:
    """Mirror this rank's slice of episodes into a standalone shard dataset.

    Episodes are assigned round-robin: rank r processes episodes r,
    r + world_size, r + 2*world_size, ... Each frame is left/right mirrored
    (joint signs flipped, images flipped horizontally, left_/right_ names
    swapped) and written to a new dataset rooted at `shard_root`. The shard
    is later merged with its siblings by the aggregation step.

    Args:
        repo_id: Source dataset repo id.
        source_root: On-disk root of the source dataset.
        mirrored_repo_id: Base repo id for the mirrored output; the shard's
            repo name is suffixed with the world size and rank.
        shard_root: Destination directory for this rank's shard.
        rank: Index of this worker in [0, world_size).
        world_size: Total number of parallel workers.
        vcodec: Video codec to use when encoding camera streams.
        overwrite: If True, delete any pre-existing shard and recreate it.

    Raises:
        RuntimeError: if `shard_root` exists but is not a valid dataset and
            `overwrite` is False.
    """
    source_dataset = LeRobotDataset(repo_id=repo_id, root=source_root)
    # Round-robin episode assignment across ranks.
    selected_episodes = list(range(rank, source_dataset.meta.total_episodes, world_size))

    if len(selected_episodes) == 0:
        # More ranks than episodes: nothing to do for this rank.
        logger.info("Rank %s has no episodes assigned. Skipping.", rank)
        return

    if shard_root.exists():
        if overwrite:
            shutil.rmtree(shard_root)
        elif _is_valid_dataset_root(shard_root):
            # Idempotent re-runs: a previously completed shard is left as-is.
            logger.info("Rank %s shard already exists at %s. Skipping.", rank, shard_root)
            return
        else:
            # Partial leftovers from a crashed run: refuse to silently reuse.
            raise RuntimeError(
                f"Shard root {shard_root} exists but is not a valid dataset. Use --overwrite to recreate."
            )

    mirroring_mask = get_mirroring_mask(source_dataset.meta.robot_type)
    # Feature spec with all left_/right_ keys and axis names swapped.
    mirrored_features = build_mirrored_features(source_dataset.meta.features)

    # Unique per-rank repo name so shards can be aggregated later by name.
    shard_repo_name = f"{mirrored_repo_id}_world_{world_size}_rank_{rank}"
    mirrored_dataset = LeRobotDataset.create(
        repo_id=shard_repo_name,
        root=shard_root,
        fps=source_dataset.meta.fps,
        features=mirrored_features,
        robot_type=source_dataset.meta.robot_type,
        use_videos=len(source_dataset.meta.video_keys) > 0,
        vcodec=vcodec,
    )
    # Keep file/chunk sizing identical to the source so aggregation is uniform.
    mirrored_dataset.meta.update_chunk_settings(
        chunks_size=source_dataset.meta.chunks_size,
        data_files_size_in_mb=source_dataset.meta.data_files_size_in_mb,
        video_files_size_in_mb=source_dataset.meta.video_files_size_in_mb,
    )

    logger.info(
        "Rank %s processing %s episodes into shard %s",
        rank,
        len(selected_episodes),
        shard_root,
    )
    for source_ep_idx in selected_episodes:
        episode = source_dataset.meta.episodes[source_ep_idx]
        # Global frame-index range [from, to) of this episode in the source.
        start_idx = int(episode["dataset_from_index"])
        end_idx = int(episode["dataset_to_index"])

        for frame_idx in range(start_idx, end_idx):
            item = source_dataset[frame_idx]
            mirrored_frame = build_mirrored_frame(
                item=item,
                source_features=source_dataset.meta.features,
                mirroring_mask=mirroring_mask,
            )
            mirrored_dataset.add_frame(mirrored_frame)

        # Flush the buffered frames as one complete episode.
        mirrored_dataset.save_episode()

    mirrored_dataset.finalize()
|
||||
|
||||
|
||||
class MirrorDatasetShards(PipelineStep):
    """DataTrove pipeline step: each rank mirrors its slice of episodes into its own shard."""

    def __init__(
        self,
        repo_id: str,
        source_root: Path,
        mirrored_repo_id: str,
        work_dir: Path,
        vcodec: str,
        overwrite: bool,
    ):
        super().__init__()
        self.repo_id = repo_id
        self.mirrored_repo_id = mirrored_repo_id
        self.source_root = source_root
        self.work_dir = work_dir
        self.overwrite = overwrite
        self.vcodec = vcodec

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Executor entry point, invoked once per rank."""
        init_logging()
        mirror_shard(
            repo_id=self.repo_id,
            source_root=self.source_root,
            mirrored_repo_id=self.mirrored_repo_id,
            # Each rank writes to its own directory under the work dir.
            shard_root=_get_shard_root(self.work_dir, world_size, rank),
            rank=rank,
            world_size=world_size,
            vcodec=self.vcodec,
            overwrite=self.overwrite,
        )
|
||||
|
||||
|
||||
def make_mirror_executor(
    repo_id: str,
    source_root: Path,
    mirrored_repo_id: str,
    work_dir: Path,
    logs_dir: Path,
    job_name: str,
    num_shards: int,
    workers: int,
    partition: str,
    cpus_per_task: int,
    mem_per_cpu: str,
    time_limit: str,
    vcodec: str,
    overwrite: bool,
    slurm: bool,
):
    """Build the executor that fans the mirroring step out over `num_shards` tasks.

    Returns a SlurmPipelineExecutor when `slurm` is set, otherwise a
    LocalPipelineExecutor that runs the tasks with a single worker.

    Raises:
        ValueError: if `slurm` is set but no `partition` was provided.
    """
    step = MirrorDatasetShards(
        repo_id=repo_id,
        source_root=source_root,
        mirrored_repo_id=mirrored_repo_id,
        work_dir=work_dir,
        vcodec=vcodec,
        overwrite=overwrite,
    )
    common_kwargs = {
        "pipeline": [step],
        "logging_dir": str(logs_dir / job_name),
    }

    if not slurm:
        return LocalPipelineExecutor(tasks=num_shards, workers=1, **common_kwargs)

    if partition is None:
        raise ValueError("`--partition` is required when `--slurm 1`.")
    return SlurmPipelineExecutor(
        job_name=job_name,
        tasks=num_shards,
        workers=workers,
        time=time_limit,
        partition=partition,
        cpus_per_task=cpus_per_task,
        sbatch_args={"mem-per-cpu": mem_per_cpu},
        **common_kwargs,
    )
|
||||
|
||||
|
||||
class AggregateMirroredShardsStep(PipelineStep):
    """DataTrove pipeline step: rank 0 merges all mirrored shards into one dataset."""

    def __init__(
        self,
        mirrored_repo_id: str,
        mirrored_root: Path,
        work_dir: Path,
        num_shards: int,
        overwrite: bool,
    ):
        super().__init__()
        self.mirrored_repo_id = mirrored_repo_id
        self.work_dir = work_dir
        self.mirrored_root = mirrored_root
        self.overwrite = overwrite
        self.num_shards = num_shards

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Executor entry point; aggregation is single-process, so only rank 0 works."""
        init_logging()
        if rank != 0:
            logger.info("Skipping rank %s for aggregate mirrored step", rank)
            return
        aggregate_mirrored_shards(
            mirrored_repo_id=self.mirrored_repo_id,
            mirrored_root=self.mirrored_root,
            work_dir=self.work_dir,
            num_shards=self.num_shards,
            overwrite=self.overwrite,
        )
|
||||
|
||||
|
||||
class BuildDoubledDatasetStep(PipelineStep):
    """DataTrove pipeline step: rank 0 concatenates the original and mirrored datasets."""

    def __init__(
        self,
        source_repo_id: str,
        source_root: Path,
        mirrored_repo_id: str,
        mirrored_root: Path,
        output_repo_id: str,
        output_root: Path,
        overwrite: bool,
    ):
        super().__init__()
        self.source_repo_id = source_repo_id
        self.mirrored_repo_id = mirrored_repo_id
        self.output_repo_id = output_repo_id
        self.source_root = source_root
        self.mirrored_root = mirrored_root
        self.output_root = output_root
        self.overwrite = overwrite

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Executor entry point; final aggregation is single-process, so only rank 0 works."""
        init_logging()
        if rank != 0:
            logger.info("Skipping rank %s for build doubled step", rank)
            return
        build_doubled_dataset(
            source_repo_id=self.source_repo_id,
            source_root=self.source_root,
            mirrored_repo_id=self.mirrored_repo_id,
            mirrored_root=self.mirrored_root,
            output_repo_id=self.output_repo_id,
            output_root=self.output_root,
            overwrite=self.overwrite,
        )
|
||||
|
||||
|
||||
class PushDoubledDatasetStep(PipelineStep):
    """DataTrove pipeline step: rank 0 uploads the final doubled dataset to the Hub."""

    def __init__(
        self,
        output_repo_id: str,
        output_root: Path,
    ):
        super().__init__()
        self.output_root = output_root
        self.output_repo_id = output_repo_id

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Executor entry point; the upload is single-process, so only rank 0 works."""
        init_logging()
        if rank != 0:
            logger.info("Skipping rank %s for push step", rank)
            return
        logger.info("Pushing doubled dataset to hub: %s", self.output_repo_id)
        dataset = LeRobotDataset(self.output_repo_id, root=self.output_root)
        dataset.push_to_hub()
|
||||
|
||||
|
||||
def make_single_task_executor(
    step: PipelineStep,
    logs_dir: Path,
    job_name: str,
    partition: str | None,
    cpus_per_task: int,
    mem_per_cpu: str,
    time_limit: str,
    slurm: bool,
    depends: SlurmPipelineExecutor | None = None,
):
    """Wrap `step` in a one-task, one-worker executor (SLURM job or local run).

    `depends` chains this job after a previous SLURM executor (SLURM mode only).

    Raises:
        ValueError: if `slurm` is set but no `partition` was provided.
    """
    common_kwargs = {
        "pipeline": [step],
        "logging_dir": str(logs_dir / job_name),
    }

    if not slurm:
        return LocalPipelineExecutor(tasks=1, workers=1, **common_kwargs)

    if partition is None:
        raise ValueError("`--partition` is required when `--slurm 1`.")
    return SlurmPipelineExecutor(
        job_name=job_name,
        tasks=1,
        workers=1,
        time=time_limit,
        partition=partition,
        cpus_per_task=cpus_per_task,
        sbatch_args={"mem-per-cpu": mem_per_cpu},
        depends=depends,
        **common_kwargs,
    )
|
||||
|
||||
|
||||
def aggregate_mirrored_shards(
    mirrored_repo_id: str,
    mirrored_root: Path,
    work_dir: Path,
    num_shards: int,
    overwrite: bool,
):
    """Merge per-rank mirrored shard datasets into one aggregate dataset.

    Scans `work_dir` for valid shard roots produced by the mirror step and
    aggregates them into `mirrored_root`. Skips the work if a valid aggregate
    already exists (unless `overwrite`), and raises `RuntimeError` when the
    target exists but is invalid, or when no shards were produced at all.
    """
    if mirrored_root.exists():
        if overwrite:
            shutil.rmtree(mirrored_root)
        elif _is_valid_dataset_root(mirrored_root):
            logger.info("Mirrored dataset already exists at %s. Skipping aggregation.", mirrored_root)
            return
        else:
            raise RuntimeError(
                f"Mirrored root {mirrored_root} exists but is not a valid dataset. Use --overwrite to recreate."
            )

    # Collect only the shards that were actually produced and are valid;
    # a rank may legitimately have nothing to mirror.
    shard_ids: list[str] = []
    shard_paths: list[Path] = []
    for rank in range(num_shards):
        candidate = _get_shard_root(work_dir, num_shards, rank)
        if not _is_valid_dataset_root(candidate):
            continue
        shard_ids.append(f"{mirrored_repo_id}_world_{num_shards}_rank_{rank}")
        shard_paths.append(candidate)

    if not shard_ids:
        raise RuntimeError("No mirrored shards were produced. Nothing to aggregate.")

    logger.info("Aggregating %s mirrored shards into %s", len(shard_ids), mirrored_root)
    aggregate_datasets(
        repo_ids=shard_ids,
        roots=shard_paths,
        aggr_repo_id=mirrored_repo_id,
        aggr_root=mirrored_root,
    )
|
||||
|
||||
|
||||
def build_doubled_dataset(
    source_repo_id: str,
    source_root: Path,
    mirrored_repo_id: str,
    mirrored_root: Path,
    output_repo_id: str,
    output_root: Path,
    overwrite: bool,
):
    """Aggregate the source dataset with its mirrored copy into the doubled output.

    If a valid doubled dataset already exists at `output_root`, the step is
    skipped (unless `overwrite` removes it first). Raises `RuntimeError` when
    `output_root` exists but is not a valid dataset.
    """
    already_there = output_root.exists()
    if already_there and overwrite:
        # Explicit overwrite request: clear the stale output and rebuild.
        shutil.rmtree(output_root)
    elif already_there:
        if _is_valid_dataset_root(output_root):
            logger.info("Doubled dataset already exists at %s. Skipping final aggregation.", output_root)
            return
        raise RuntimeError(
            f"Output root {output_root} exists but is not a valid dataset. Use --overwrite to recreate."
        )

    logger.info("Aggregating source + mirrored into doubled dataset at %s", output_root)
    aggregate_datasets(
        repo_ids=[source_repo_id, mirrored_repo_id],
        roots=[source_root, mirrored_root],
        aggr_repo_id=output_repo_id,
        aggr_root=output_root,
    )
|
||||
|
||||
|
||||
def main():
    """CLI entry point: mirror a dataset, then aggregate source + mirror into a doubled dataset.

    Runs a four-stage pipeline (mirror shards -> aggregate shards -> build
    doubled dataset -> optional hub push). With `--slurm 1` the stages are
    submitted as a dependency chain of SLURM jobs; with `--slurm 0` they run
    sequentially in-process.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--repo-id", type=str, required=True, help="Source dataset repo id.")
    parser.add_argument("--output-repo-id", type=str, required=True, help="Final doubled dataset repo id.")
    parser.add_argument("--root", type=Path, default=None, help="Root path of source dataset.")
    parser.add_argument(
        "--output-root",
        type=Path,
        default=None,
        help="Root path where final doubled dataset is written.",
    )
    parser.add_argument(
        "--work-dir",
        type=Path,
        default=None,
        help="Intermediate directory for mirrored shards and mirrored aggregate dataset.",
    )
    parser.add_argument("--logs-dir", type=Path, required=True, help="DataTrove logs path.")
    parser.add_argument("--job-name", type=str, default="mirror_dataset", help="SLURM job name.")
    parser.add_argument("--num-shards", type=int, default=256, help="Number of DataTrove tasks/ranks.")
    parser.add_argument(
        "--workers",
        type=int,
        default=64,
        help="Max concurrent DataTrove workers on SLURM.",
    )
    parser.add_argument("--partition", type=str, default=None, help="SLURM partition (e.g. hopper-cpu).")
    parser.add_argument("--cpus-per-task", type=int, default=8, help="CPU count per SLURM task.")
    parser.add_argument("--mem-per-cpu", type=str, default="4G", help="Memory per CPU for SLURM task.")
    parser.add_argument("--time", type=str, default="24:00:00", help="SLURM time limit.")
    parser.add_argument("--vcodec", type=str, default="libsvtav1", help="Video codec for output videos.")
    parser.add_argument(
        "--slurm",
        type=int,
        default=1,
        help="Use SLURM executor. Set 0 for local sequential debugging.",
    )
    parser.add_argument("--overwrite", action="store_true", help="Delete existing intermediate/final outputs.")
    parser.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Push final doubled dataset to Hugging Face Hub after completion.",
    )
    args = parser.parse_args()

    init_logging()
    slurm = args.slurm == 1

    # Resolve all paths up front. The mirrored intermediate dataset lives
    # under the work dir so it can be discarded once the output is built.
    source_root = _resolve_source_root(args.repo_id, args.root)
    output_root = args.output_root if args.output_root is not None else HF_LEROBOT_HOME / args.output_repo_id
    work_dir = _get_work_dir(args.output_repo_id, args.work_dir)
    mirrored_repo_id = f"{args.output_repo_id}_mirrored"
    mirrored_root = work_dir / "mirrored_aggregate"

    work_dir.mkdir(parents=True, exist_ok=True)
    args.logs_dir.mkdir(parents=True, exist_ok=True)

    # Stage 1: mirror the source dataset across `num_shards` ranks.
    mirror_executor = make_mirror_executor(
        repo_id=args.repo_id,
        source_root=source_root,
        mirrored_repo_id=mirrored_repo_id,
        work_dir=work_dir,
        logs_dir=args.logs_dir,
        job_name=args.job_name,
        num_shards=args.num_shards,
        workers=args.workers,
        partition=args.partition,
        cpus_per_task=args.cpus_per_task,
        mem_per_cpu=args.mem_per_cpu,
        time_limit=args.time,
        vcodec=args.vcodec,
        overwrite=args.overwrite,
        slurm=slurm,
    )
    if slurm:
        # SLURM path: submit stages 2-4 as single-task jobs chained via
        # `depends`, so each starts only after the previous one succeeds.
        aggregate_executor = make_single_task_executor(
            step=AggregateMirroredShardsStep(
                mirrored_repo_id=mirrored_repo_id,
                mirrored_root=mirrored_root,
                work_dir=work_dir,
                num_shards=args.num_shards,
                overwrite=args.overwrite,
            ),
            logs_dir=args.logs_dir,
            job_name=f"{args.job_name}_aggregate_mirrored",
            partition=args.partition,
            cpus_per_task=args.cpus_per_task,
            mem_per_cpu=args.mem_per_cpu,
            time_limit=args.time,
            slurm=True,
            depends=mirror_executor,
        )
        build_executor = make_single_task_executor(
            step=BuildDoubledDatasetStep(
                source_repo_id=args.repo_id,
                source_root=source_root,
                mirrored_repo_id=mirrored_repo_id,
                mirrored_root=mirrored_root,
                output_repo_id=args.output_repo_id,
                output_root=output_root,
                overwrite=args.overwrite,
            ),
            logs_dir=args.logs_dir,
            job_name=f"{args.job_name}_build_doubled",
            partition=args.partition,
            cpus_per_task=args.cpus_per_task,
            mem_per_cpu=args.mem_per_cpu,
            time_limit=args.time,
            slurm=True,
            depends=aggregate_executor,
        )

        # Running the last executor in the chain submits the whole chain.
        final_executor: SlurmPipelineExecutor | LocalPipelineExecutor = build_executor
        push_executor = None
        if args.push_to_hub:
            push_executor = make_single_task_executor(
                step=PushDoubledDatasetStep(
                    output_repo_id=args.output_repo_id,
                    output_root=output_root,
                ),
                logs_dir=args.logs_dir,
                job_name=f"{args.job_name}_push",
                partition=args.partition,
                cpus_per_task=args.cpus_per_task,
                mem_per_cpu=args.mem_per_cpu,
                time_limit=args.time,
                slurm=True,
                depends=build_executor,
            )
            final_executor = push_executor

        final_executor.run()
        logger.info(
            "Submitted SLURM chain. job_ids: mirror=%s aggregate=%s doubled=%s push=%s",
            mirror_executor.job_id,
            aggregate_executor.job_id,
            build_executor.job_id,
            push_executor.job_id if push_executor is not None else None,
        )
        return

    # Local path: run every stage sequentially in this process.
    mirror_executor.run()
    aggregate_mirrored_shards(
        mirrored_repo_id=mirrored_repo_id,
        mirrored_root=mirrored_root,
        work_dir=work_dir,
        num_shards=args.num_shards,
        overwrite=args.overwrite,
    )
    build_doubled_dataset(
        source_repo_id=args.repo_id,
        source_root=source_root,
        mirrored_repo_id=mirrored_repo_id,
        mirrored_root=mirrored_root,
        output_repo_id=args.output_repo_id,
        output_root=output_root,
        overwrite=args.overwrite,
    )
    if args.push_to_hub:
        logger.info("Pushing doubled dataset to hub: %s", args.output_repo_id)
        LeRobotDataset(args.output_repo_id, root=output_root).push_to_hub()
|
||||
|
||||
|
||||
# Script entry point: parse CLI args and run the mirroring pipeline.
if __name__ == "__main__":
    main()
|
||||
@@ -27,8 +27,8 @@ measuring consistency and ground truth alignment.
|
||||
Usage:
|
||||
# Basic usage with smolvla policy
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=<USER>/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=mps \
|
||||
--rtc.max_guidance_weight=10.0 \
|
||||
@@ -58,16 +58,16 @@ Usage:
|
||||
--device=cuda
|
||||
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=lipsop/reuben_pi0 \
|
||||
--dataset.repo_id=ReubenLim/so101_cube_in_cup \
|
||||
--policy.path=<USER>/reuben_pi0 \
|
||||
--dataset.repo_id=<USER>/so101_cube_in_cup \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=cuda
|
||||
|
||||
# With torch.compile for faster inference (PyTorch 2.0+)
|
||||
# Note: CUDA graphs disabled by default due to in-place ops in denoising loop
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=<USER>/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=mps \
|
||||
--use_torch_compile=true \
|
||||
@@ -75,8 +75,8 @@ Usage:
|
||||
|
||||
# With torch.compile on CUDA (CUDA graphs disabled by default)
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=<USER>/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=cuda \
|
||||
--use_torch_compile=true \
|
||||
@@ -84,8 +84,8 @@ Usage:
|
||||
|
||||
# Enable CUDA graphs (advanced - may cause tensor aliasing errors)
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=<USER>/check_rtc \
|
||||
--use_torch_compile=true \
|
||||
--torch_compile_backend=inductor \
|
||||
--torch_compile_mode=max-autotune \
|
||||
|
||||
@@ -28,7 +28,7 @@ For simulation environments, see eval_with_simulation.py
|
||||
Usage:
|
||||
# Run RTC with Real robot with RTC
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
@@ -41,7 +41,7 @@ Usage:
|
||||
|
||||
# Run RTC with Real robot without RTC
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=false \
|
||||
--robot.type=so100_follower \
|
||||
@@ -53,7 +53,7 @@ Usage:
|
||||
|
||||
# Run RTC with Real robot with pi0.5 policy
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=helper2424/pi05_check_rtc \
|
||||
--policy.path=<USER>/pi05_check_rtc \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
|
||||
+4
-4
@@ -59,7 +59,7 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici
|
||||
dependencies = [
|
||||
|
||||
# Hugging Face dependencies
|
||||
"datasets>=4.0.0,<4.2.0",
|
||||
"datasets>=4.0.0,<5.0.0",
|
||||
"diffusers>=0.27.2,<0.36.0",
|
||||
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
|
||||
"accelerate>=1.10.0,<2.0.0",
|
||||
@@ -76,9 +76,9 @@ dependencies = [
|
||||
"pyserial>=3.5,<4.0",
|
||||
"wandb>=0.24.0,<0.25.0",
|
||||
|
||||
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
|
||||
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
|
||||
"torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency
|
||||
"torch>=2.2.1,<2.11.0", # TODO: Bump dependency
|
||||
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency
|
||||
"torchvision>=0.21.0,<0.26.0", # TODO: Bump dependency
|
||||
|
||||
"draccus==0.10.0", # TODO: Remove ==
|
||||
"gymnasium>=1.1.1,<2.0.0",
|
||||
|
||||
@@ -150,7 +150,7 @@ class Camera(abc.ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
|
||||
@@ -530,7 +530,7 @@ class OpenCVCamera(Camera):
|
||||
return frame
|
||||
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
|
||||
@@ -201,7 +201,7 @@ class Reachy2Camera(Camera):
|
||||
return self.read()
|
||||
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
|
||||
@@ -573,7 +573,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
"""Return the most recent (color) frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
|
||||
@@ -656,7 +656,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
|
||||
will be stored under root/repo_id.
|
||||
root (Path | None, optional): Local directory to use for downloading/writing files. You can also
|
||||
set the LEROBOT_HOME environment variable to point to a different location. Defaults to
|
||||
set the HF_LEROBOT_HOME environment variable to point to a different location. Defaults to
|
||||
'~/.cache/huggingface/lerobot'.
|
||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||
their episode_index in this list. Defaults to None.
|
||||
|
||||
@@ -122,19 +122,9 @@ def load_nested_dataset(
|
||||
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
|
||||
|
||||
with SuppressProgressBars():
|
||||
# When no filtering needed, Dataset uses memory-mapped loading for efficiency
|
||||
# PyArrow loads the entire dataset into memory
|
||||
if episodes is None:
|
||||
return Dataset.from_parquet([str(path) for path in paths], features=features)
|
||||
|
||||
arrow_dataset = pa_ds.dataset(paths, format="parquet")
|
||||
filter_expr = pa_ds.field("episode_index").isin(episodes)
|
||||
table = arrow_dataset.to_table(filter=filter_expr)
|
||||
|
||||
if features is not None:
|
||||
table = table.cast(features.arrow_schema)
|
||||
|
||||
return Dataset(table)
|
||||
# We use .from_parquet() memory-mapped loading for efficiency
|
||||
filters = pa_ds.field("episode_index").isin(episodes) if episodes is not None else None
|
||||
return Dataset.from_parquet([str(path) for path in paths], filters=filters, features=features)
|
||||
|
||||
|
||||
def get_parquet_num_frames(parquet_path: str | Path) -> int:
|
||||
|
||||
@@ -529,7 +529,7 @@ if __name__ == "__main__":
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
|
||||
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
||||
"(e.g. `lerobot/pusht`, `<USER>/aloha_sim_insertion_human`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--branch",
|
||||
|
||||
@@ -27,18 +27,18 @@ Usage:
|
||||
# Full RA-BC computation with visualizations
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4
|
||||
--reward-model-path <USER>/sarm_single_uni4
|
||||
|
||||
# Faster computation with stride (compute every 5 frames, interpolate the rest)
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4 \\
|
||||
--reward-model-path <USER>/sarm_single_uni4 \\
|
||||
--stride 5
|
||||
|
||||
# Visualize predictions only (no RA-BC computation)
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4 \\
|
||||
--reward-model-path <USER>/sarm_single_uni4 \\
|
||||
--visualize-only \\
|
||||
--num-visualizations 5
|
||||
|
||||
@@ -714,12 +714,12 @@ Examples:
|
||||
# Full RA-BC computation with visualizations
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4
|
||||
--reward-model-path <USER>/sarm_single_uni4
|
||||
|
||||
# Visualize predictions only (no RA-BC computation)
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4 \\
|
||||
--reward-model-path <USER>/sarm_single_uni4 \\
|
||||
--visualize-only \\
|
||||
--num-visualizations 10
|
||||
""",
|
||||
|
||||
@@ -30,7 +30,7 @@ Example of finetuning the smolvla pretrained model (`smolvla_base`):
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
|
||||
--dataset.repo_id=<USER>/svla_so100_task1_v3 \
|
||||
--batch_size=64 \
|
||||
--steps=200000
|
||||
```
|
||||
@@ -40,7 +40,7 @@ and an action expert.
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.type=smolvla \
|
||||
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
|
||||
--dataset.repo_id=<USER>/svla_so100_task1_v3 \
|
||||
--batch_size=64 \
|
||||
--steps=200000
|
||||
```
|
||||
|
||||
@@ -44,6 +44,7 @@ from .hil_processor import (
|
||||
AddTeleopActionAsComplimentaryDataStep,
|
||||
AddTeleopEventsAsInfoStep,
|
||||
GripperPenaltyProcessorStep,
|
||||
GymHILAdapterProcessorStep,
|
||||
ImageCropResizeProcessorStep,
|
||||
InterventionActionProcessorStep,
|
||||
RewardClassifierProcessorStep,
|
||||
@@ -87,6 +88,7 @@ __all__ = [
|
||||
"DoneProcessorStep",
|
||||
"EnvAction",
|
||||
"EnvTransition",
|
||||
"GymHILAdapterProcessorStep",
|
||||
"GripperPenaltyProcessorStep",
|
||||
"hotswap_stats",
|
||||
"IdentityProcessorStep",
|
||||
|
||||
@@ -20,6 +20,7 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
|
||||
from .converters import to_tensor
|
||||
from .core import EnvAction, EnvTransition, PolicyAction
|
||||
from .hil_processor import TELEOP_ACTION_KEY
|
||||
from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@@ -89,6 +90,13 @@ class Numpy2TorchActionProcessorStep(ProcessorStep):
|
||||
torch_action = to_tensor(action, dtype=None) # Preserve original dtype
|
||||
new_transition[TransitionKey.ACTION] = torch_action
|
||||
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
if TELEOP_ACTION_KEY in complementary_data:
|
||||
teleop_action = complementary_data[TELEOP_ACTION_KEY]
|
||||
if isinstance(teleop_action, EnvAction):
|
||||
complementary_data[TELEOP_ACTION_KEY] = to_tensor(teleop_action)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
|
||||
return new_transition
|
||||
|
||||
def transform_features(
|
||||
|
||||
@@ -312,6 +312,37 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("gym_hil_adapter_processor")
|
||||
class GymHILAdapterProcessorStep(ProcessorStep):
|
||||
"""
|
||||
Adapts the output of the `gym-hil` environment to the format expected by `lerobot` processors.
|
||||
|
||||
This step normalizes the `transition` object by:
|
||||
1. Copying `teleop_action` from `info` to `complementary_data`.
|
||||
2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key).
|
||||
"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
info = transition.get(TransitionKey.INFO, {})
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
|
||||
if TELEOP_ACTION_KEY in info:
|
||||
complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY]
|
||||
|
||||
if "is_intervention" in info:
|
||||
info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"]
|
||||
|
||||
transition[TransitionKey.INFO] = info
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
|
||||
return transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||
class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
|
||||
@@ -36,6 +36,7 @@ from lerobot.processor import (
|
||||
DeviceProcessorStep,
|
||||
EnvTransition,
|
||||
GripperPenaltyProcessorStep,
|
||||
GymHILAdapterProcessorStep,
|
||||
ImageCropResizeProcessorStep,
|
||||
InterventionActionProcessorStep,
|
||||
MapDeltaActionToRobotActionStep,
|
||||
@@ -379,6 +380,7 @@ def make_processors(
|
||||
]
|
||||
|
||||
env_pipeline_steps = [
|
||||
GymHILAdapterProcessorStep(),
|
||||
Numpy2TorchActionProcessorStep(),
|
||||
VanillaObservationProcessorStep(),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
@@ -608,7 +610,14 @@ def control_loop(
|
||||
|
||||
dataset = None
|
||||
if cfg.mode == "record":
|
||||
action_features = teleop_device.action_features
|
||||
if teleop_device:
|
||||
action_features = teleop_device.action_features
|
||||
else:
|
||||
action_features = {
|
||||
"dtype": "float32",
|
||||
"shape": (4,),
|
||||
"names": ["delta_x", "delta_y", "delta_z", "gripper"],
|
||||
}
|
||||
features = {
|
||||
ACTION: action_features,
|
||||
REWARD: {"dtype": "float32", "shape": (1,), "names": None},
|
||||
@@ -656,7 +665,7 @@ def control_loop(
|
||||
# Create a neutral action (no movement)
|
||||
neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
|
||||
if use_gripper:
|
||||
neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay
|
||||
neutral_action = torch.cat([neutral_action, torch.tensor([0.0])]) # Gripper stay
|
||||
|
||||
# Use the new step function
|
||||
transition = step_env_and_process_transition(
|
||||
@@ -725,6 +734,8 @@ def control_loop(
|
||||
precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0))
|
||||
|
||||
if dataset is not None and cfg.dataset.push_to_hub:
|
||||
logging.info("Finalizing dataset before pushing to hub")
|
||||
dataset.finalize()
|
||||
logging.info("Pushing dataset to hub")
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
@@ -140,7 +140,7 @@ class HopeJrArm(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -171,7 +171,7 @@ class HopeJrHand(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -193,7 +193,7 @@ class KochFollower(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -360,7 +360,7 @@ class LeKiwi(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -176,7 +176,7 @@ class OmxFollower(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -241,7 +241,7 @@ class OpenArmFollower(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -180,7 +180,7 @@ class Reachy2Robot(Robot):
|
||||
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
|
||||
return obs_dict
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ class SOFollower(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -324,7 +324,7 @@ class UnitreeG1(Robot):
|
||||
|
||||
# Cameras - read images from ZMQ cameras
|
||||
for cam_name, cam in self._cameras.items():
|
||||
obs[cam_name] = cam.async_read()
|
||||
obs[cam_name] = cam.read_latest()
|
||||
|
||||
return obs
|
||||
|
||||
|
||||
@@ -47,16 +47,14 @@ local$ rerun lerobot_pusht_episode_0.rrd
|
||||
```
|
||||
|
||||
- Visualize data stored on a distant machine through streaming:
|
||||
(You need to forward the websocket port to the distant machine, with
|
||||
`ssh -L 9087:localhost:9087 username@remote-host`)
|
||||
```
|
||||
distant$ lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--episode-index 0 \
|
||||
--mode distant \
|
||||
--ws-port 9087
|
||||
--grpc-port 9876
|
||||
|
||||
local$ rerun ws://localhost:9087
|
||||
local$ rerun rerun+http://IP:GRPC_PORT/proxy
|
||||
```
|
||||
|
||||
"""
|
||||
@@ -75,6 +73,7 @@ import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
|
||||
@@ -93,10 +92,11 @@ def visualize_dataset(
|
||||
num_workers: int = 0,
|
||||
mode: str = "local",
|
||||
web_port: int = 9090,
|
||||
ws_port: int = 9087,
|
||||
grpc_port: int = 9876,
|
||||
save: bool = False,
|
||||
output_dir: Path | None = None,
|
||||
display_compressed_images: bool = False,
|
||||
**kwargs,
|
||||
) -> Path | None:
|
||||
if save:
|
||||
assert output_dir is not None, (
|
||||
@@ -126,7 +126,9 @@ def visualize_dataset(
|
||||
gc.collect()
|
||||
|
||||
if mode == "distant":
|
||||
rr.serve_web_viewer(open_browser=False, web_port=web_port)
|
||||
server_uri = rr.serve_grpc(grpc_port=grpc_port)
|
||||
logging.info(f"Connect to a Rerun Server: rerun rerun+http://IP:{grpc_port}/proxy")
|
||||
rr.serve_web_viewer(open_browser=False, web_port=web_port, connect_to=server_uri)
|
||||
|
||||
logging.info("Logging to Rerun")
|
||||
|
||||
@@ -226,7 +228,7 @@ def main():
|
||||
"Mode of viewing between 'local' or 'distant'. "
|
||||
"'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
|
||||
"'distant' creates a server on the distant machine where the data is stored. "
|
||||
"Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
|
||||
"Visualize the data by connecting to the server with `rerun rerun+http://IP:GRPC_PORT/proxy` on the local machine."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -238,8 +240,13 @@ def main():
|
||||
parser.add_argument(
|
||||
"--ws-port",
|
||||
type=int,
|
||||
default=9087,
|
||||
help="Web socket port for rerun.io when `--mode distant` is set.",
|
||||
help="deprecated, please use --grpc-port instead.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grpc-port",
|
||||
type=int,
|
||||
default=9876,
|
||||
help="gRPC port for rerun.io when `--mode distant` is set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save",
|
||||
@@ -265,9 +272,7 @@ def main():
|
||||
|
||||
parser.add_argument(
|
||||
"--display-compressed-images",
|
||||
type=bool,
|
||||
required=True,
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="If set, display compressed images in Rerun instead of uncompressed ones.",
|
||||
)
|
||||
|
||||
@@ -277,6 +282,14 @@ def main():
|
||||
root = kwargs.pop("root")
|
||||
tolerance_s = kwargs.pop("tolerance_s")
|
||||
|
||||
if kwargs["ws_port"] is not None:
|
||||
logging.warning(
|
||||
"--ws-port is deprecated and will be removed in future versions. Please use --grpc-port instead."
|
||||
)
|
||||
logging.warning("Setting grpc_port to ws_port value.")
|
||||
kwargs["grpc_port"] = kwargs.pop("ws_port")
|
||||
|
||||
init_logging()
|
||||
logging.info("Loading dataset")
|
||||
dataset = LeRobotDataset(repo_id, episodes=[args.episode_index], root=root, tolerance_s=tolerance_s)
|
||||
|
||||
|
||||
@@ -24,94 +24,107 @@ When new_repo_id is specified, creates a new dataset.
|
||||
Usage Examples:
|
||||
|
||||
Delete episodes 0, 2, and 5 from a dataset:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]"
|
||||
|
||||
Delete episodes and save to a new dataset:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--new_repo_id lerobot/pusht_filtered \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]"
|
||||
|
||||
Split dataset by fractions:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type split \
|
||||
--operation.splits '{"train": 0.8, "val": 0.2}'
|
||||
|
||||
Split dataset by episode indices:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type split \
|
||||
--operation.splits '{"train": [0, 1, 2, 3], "val": [4, 5]}'
|
||||
|
||||
Split into more than two splits:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type split \
|
||||
--operation.splits '{"train": 0.6, "val": 0.2, "test": 0.2}'
|
||||
|
||||
Merge multiple datasets:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_merged \
|
||||
--operation.type merge \
|
||||
--operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']"
|
||||
|
||||
Remove camera feature:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type remove_feature \
|
||||
--operation.feature_names "['observation.images.top']"
|
||||
|
||||
Modify tasks - set a single task for all episodes (WARNING: modifies in-place):
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type modify_tasks \
|
||||
--operation.new_task "Pick up the cube and place it"
|
||||
|
||||
Modify tasks - set different tasks for specific episodes (WARNING: modifies in-place):
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type modify_tasks \
|
||||
--operation.episode_tasks '{"0": "Task A", "1": "Task B", "2": "Task A"}'
|
||||
|
||||
Modify tasks - set default task with overrides for specific episodes (WARNING: modifies in-place):
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type modify_tasks \
|
||||
--operation.new_task "Default task" \
|
||||
--operation.episode_tasks '{"5": "Special task for episode 5"}'
|
||||
|
||||
Convert image dataset to video format and save locally:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir /path/to/output/pusht_video
|
||||
|
||||
Convert image dataset to video format and save with new repo_id:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_image_to_video
|
||||
|
||||
Convert image dataset to video format and push to hub:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--push_to_hub true
|
||||
|
||||
Show dataset information:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type info \
|
||||
--operation.show_features true
|
||||
|
||||
Show dataset information without feature details:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type info \
|
||||
--operation.show_features false
|
||||
|
||||
Using JSON config file:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--config_path path/to/edit_config.json
|
||||
"""
|
||||
|
||||
import abc
|
||||
import logging
|
||||
import shutil
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
@@ -184,6 +197,13 @@ class ConvertImageToVideoConfig(OperationConfig):
|
||||
max_frames_per_batch: int | None = None
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("info")
|
||||
@dataclass
|
||||
class InfoConfig(OperationConfig):
|
||||
type: str = "info"
|
||||
show_features: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class EditDatasetConfig:
|
||||
repo_id: str
|
||||
@@ -436,6 +456,49 @@ def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
|
||||
logging.info("Dataset saved locally (not pushed to hub)")
|
||||
|
||||
|
||||
def _get_dataset_size(repo_path):
|
||||
import os
|
||||
|
||||
total = 0
|
||||
with os.scandir(repo_path) as it:
|
||||
for entry in it:
|
||||
if entry.is_file():
|
||||
total += entry.stat().st_size
|
||||
elif entry.is_dir():
|
||||
total += _get_dataset_size(entry.path)
|
||||
return total
|
||||
|
||||
|
||||
def handle_info(cfg: EditDatasetConfig):
|
||||
if not isinstance(cfg.operation, InfoConfig):
|
||||
raise ValueError("Operation config must be InfoConfig")
|
||||
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
sys.stdout.write(f"======Info {dataset.meta.repo_id}\n")
|
||||
sys.stdout.write(f"Repository ID: {dataset.meta.repo_id} \n")
|
||||
sys.stdout.write(f"Total episode: {dataset.meta.total_episodes} \n")
|
||||
sys.stdout.write(f"Total task: {dataset.meta.total_tasks} \n")
|
||||
sys.stdout.write(f"Total frame(Actual Count): {dataset.meta.total_frames}({len(dataset)}) \n")
|
||||
sys.stdout.write(
|
||||
f"Average frame per episode: {dataset.meta.total_frames / dataset.meta.total_episodes:.1f}\n"
|
||||
)
|
||||
sys.stdout.write(
|
||||
f"Average episode time(sec): {(dataset.meta.total_frames / dataset.meta.total_episodes) / dataset.meta.fps:.1f}\n"
|
||||
)
|
||||
sys.stdout.write(f"FPS: {dataset.meta.fps}\n")
|
||||
|
||||
total_file_size = _get_dataset_size(dataset.root)
|
||||
sys.stdout.write(f"Size: {total_file_size / (1024 * 1024):.1f} MB\n")
|
||||
if cfg.operation.show_features:
|
||||
import json
|
||||
|
||||
feature_dump_str = json.dumps(
|
||||
dataset.meta.features, ensure_ascii=False, indent=4, sort_keys=True, separators=(",", ": ")
|
||||
)
|
||||
sys.stdout.write("Features:\n")
|
||||
sys.stdout.write(f"{feature_dump_str}\n")
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
operation_type = cfg.operation.type
|
||||
@@ -452,6 +515,8 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
handle_modify_tasks(cfg)
|
||||
elif operation_type == "convert_image_to_video":
|
||||
handle_convert_image_to_video(cfg)
|
||||
elif operation_type == "info":
|
||||
handle_info(cfg)
|
||||
else:
|
||||
available = ", ".join(OperationConfig.get_known_choices())
|
||||
raise ValueError(f"Unknown operation: {operation_type}\nAvailable operations: {available}")
|
||||
|
||||
@@ -398,7 +398,14 @@ def record_loop(
|
||||
)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
precise_sleep(max(1 / fps - dt_s, 0.0))
|
||||
|
||||
sleep_time_s: float = 1 / fps - dt_s
|
||||
if sleep_time_s < 0:
|
||||
logging.warning(
|
||||
f"Record loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
||||
)
|
||||
|
||||
precise_sleep(max(sleep_time_s, 0.0))
|
||||
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ lerobot-replay \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=black \
|
||||
--dataset.repo_id=aliberts/record-test \
|
||||
--dataset.repo_id=<USER>/record-test \
|
||||
--dataset.episode=0
|
||||
```
|
||||
|
||||
|
||||
@@ -16,14 +16,14 @@ import platform
|
||||
import time
|
||||
|
||||
|
||||
def precise_sleep(seconds: float, spin_threshold: float = 0.010, sleep_margin: float = 0.003):
|
||||
def precise_sleep(seconds: float, spin_threshold: float = 0.010, sleep_margin: float = 0.005):
|
||||
"""
|
||||
Wait for `seconds` with better precision than time.sleep alone at the expense of more CPU usage.
|
||||
|
||||
Parameters:
|
||||
- seconds: duration to wait
|
||||
- spin_threshold: if remaining <= spin_threshold -> spin; otherwise sleep (seconds). Default 10ms
|
||||
- sleep_margin: when sleeping leave this much time before deadline to avoid oversleep. Default 3ms
|
||||
- sleep_margin: when sleeping leave this much time before deadline to avoid oversleep. Default 5ms
|
||||
|
||||
Note:
|
||||
The default parameters are chosen to prioritize timing accuracy over CPU usage for the common 30 FPS use case.
|
||||
|
||||
@@ -11,6 +11,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from packaging.version import Version
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
from lerobot.optim.schedulers import (
|
||||
@@ -38,6 +40,10 @@ def test_diffuser_scheduler(optimizer):
|
||||
"last_epoch": 1,
|
||||
"lr_lambdas": [None],
|
||||
}
|
||||
|
||||
if Version(torch.__version__) >= Version("2.8"):
|
||||
expected_state_dict["_is_initial"] = False
|
||||
|
||||
assert scheduler.state_dict() == expected_state_dict
|
||||
|
||||
|
||||
@@ -56,6 +62,10 @@ def test_vqbet_scheduler(optimizer):
|
||||
"last_epoch": 1,
|
||||
"lr_lambdas": [None],
|
||||
}
|
||||
|
||||
if Version(torch.__version__) >= Version("2.8"):
|
||||
expected_state_dict["_is_initial"] = False
|
||||
|
||||
assert scheduler.state_dict() == expected_state_dict
|
||||
|
||||
|
||||
@@ -76,6 +86,10 @@ def test_cosine_decay_with_warmup_scheduler(optimizer):
|
||||
"last_epoch": 1,
|
||||
"lr_lambdas": [None],
|
||||
}
|
||||
|
||||
if Version(torch.__version__) >= Version("2.8"):
|
||||
expected_state_dict["_is_initial"] = False
|
||||
|
||||
assert scheduler.state_dict() == expected_state_dict
|
||||
|
||||
|
||||
|
||||
@@ -142,6 +142,7 @@ def _make_reachy2_camera_mock(*args, **kwargs):
|
||||
cam.connect = MagicMock()
|
||||
cam.disconnect = MagicMock()
|
||||
cam.async_read = MagicMock(side_effect=lambda: np.zeros((height, width, 3), dtype=np.uint8))
|
||||
cam.read_latest = MagicMock(side_effect=lambda: np.zeros((height, width, 3), dtype=np.uint8))
|
||||
return cam
|
||||
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from lerobot.scripts.lerobot_edit_dataset import (
|
||||
ConvertImageToVideoConfig,
|
||||
DeleteEpisodesConfig,
|
||||
EditDatasetConfig,
|
||||
InfoConfig,
|
||||
MergeConfig,
|
||||
ModifyTasksConfig,
|
||||
OperationConfig,
|
||||
@@ -46,6 +47,7 @@ class TestOperationTypeParsing:
|
||||
("remove_feature", RemoveFeatureConfig),
|
||||
("modify_tasks", ModifyTasksConfig),
|
||||
("convert_image_to_video", ConvertImageToVideoConfig),
|
||||
("info", InfoConfig),
|
||||
],
|
||||
)
|
||||
def test_operation_type_resolves_correct_class(self, type_name, expected_cls):
|
||||
@@ -63,6 +65,7 @@ class TestOperationTypeParsing:
|
||||
("remove_feature", RemoveFeatureConfig),
|
||||
("modify_tasks", ModifyTasksConfig),
|
||||
("convert_image_to_video", ConvertImageToVideoConfig),
|
||||
("info", InfoConfig),
|
||||
],
|
||||
)
|
||||
def test_get_choice_name_roundtrips(self, type_name, expected_cls):
|
||||
|
||||
Reference in New Issue
Block a user