diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml
index 03f26a792..f9fa02597 100644
--- a/.github/workflows/nightly.yml
+++ b/.github/workflows/nightly.yml
@@ -119,6 +119,7 @@ jobs:
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
container:
image: ${{ needs.build-docker-cpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
+ options: --shm-size "16gb"
credentials:
username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
@@ -158,3 +159,35 @@ jobs:
run: pytest tests -vv --maxfail=10
- name: Run end-to-end tests
run: make test-end-to-end
+
+ # This job runs multi-GPU training tests with 4 GPUs
+ nightly-multi-gpu-tests:
+ name: Nightly Multi-GPU Tests
+ needs: [build-docker-gpu-nightly]
+ runs-on:
+ group: aws-g4dn-12xlarge # Instance with 4 GPUs
+ env:
+ HF_HOME: /home/user_lerobot/.cache/huggingface
+ HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
+ TORCH_HOME: /home/user_lerobot/.cache/torch
+ TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
+ CUDA_VISIBLE_DEVICES: "0,1,2,3"
+ container:
+ image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
+ options: --gpus all --shm-size "16gb"
+ credentials:
+ username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
+ password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
+ defaults:
+ run:
+ shell: bash
+ working-directory: /lerobot
+ steps:
+ - name: Verify GPU availability
+ run: |
+ nvidia-smi
+ python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
+
+ - name: Run multi-GPU training tests
+ run: pytest tests/training/test_multi_gpu.py -vv --maxfail=3
+ timeout-minutes: 10
diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml
index af91c9f58..06fc69fc4 100644
--- a/.github/workflows/stale.yml
+++ b/.github/workflows/stale.yml
@@ -27,15 +27,17 @@ env:
This issue was closed because it has been stalled for 14 days with no activity.
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
CLOSE_PR_MESSAGE: >
- This PR was closed because it has been stalled for 14 days with no activity.
+ This PR was closed because it has been stalled for 21 days with no activity.
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
WARN_ISSUE_MESSAGE: >
This issue has been automatically marked as stale because it has not had
recent activity (6 months). It will be closed if no further activity occurs.
+ Any change, comment or update to this issue will reset this count.
Thank you for your contributions.
WARN_PR_MESSAGE: >
This PR has been automatically marked as stale because it has not had
- recent activity (6 months). It will be closed if no further activity occurs.
+ recent activity (1 year). It will be closed if no further activity occurs.
+ Any change, comment or update to this PR will reset this count.
Thank you for your contributions.
jobs:
@@ -56,10 +58,10 @@ jobs:
stale-pr-label: stale
exempt-issue-labels: never-stale
exempt-pr-labels: never-stale
- days-before-issue-stale: 180 # TODO(Steven): Will modify this to 90 after initial cleanup
+ days-before-issue-stale: 180
days-before-issue-close: 14
- days-before-pr-stale: 180
- days-before-pr-close: 14
+ days-before-pr-stale: 365
+ days-before-pr-close: 21
delete-branch: true
close-issue-message: ${{ env.CLOSE_ISSUE_MESSAGE }}
close-pr-message: ${{ env.CLOSE_PR_MESSAGE }}
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 369af602b..a07596728 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -72,7 +72,6 @@ post it.
Look at our implementations for [datasets](./src/lerobot/datasets/), [policies](./src/lerobot/policies/),
environments ([aloha](https://github.com/huggingface/gym-aloha),
-[xarm](https://github.com/huggingface/gym-xarm),
[pusht](https://github.com/huggingface/gym-pusht))
and follow the same api design.
diff --git a/Makefile b/Makefile
index fbe8a5bae..e02f02403 100644
--- a/Makefile
+++ b/Makefile
@@ -119,10 +119,9 @@ test-tdmpc-ete-train:
--policy.type=tdmpc \
--policy.device=$(DEVICE) \
--policy.push_to_hub=false \
- --env.type=xarm \
- --env.task=XarmLift-v0 \
+ --env.type=pusht \
--env.episode_length=5 \
- --dataset.repo_id=lerobot/xarm_lift_medium \
+ --dataset.repo_id=lerobot/pusht_image \
--dataset.image_transforms.enable=true \
--dataset.episodes="[0]" \
--batch_size=2 \
@@ -140,9 +139,10 @@ test-tdmpc-ete-eval:
lerobot-eval \
--policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
--policy.device=$(DEVICE) \
- --env.type=xarm \
+ --env.type=pusht \
--env.episode_length=5 \
- --env.task=XarmLift-v0 \
+ --env.observation_height=96 \
+ --env.observation_width=96 \
--eval.n_episodes=1 \
--eval.batch_size=1
diff --git a/README.md b/README.md
index 357e62cc1..6409d931d 100644
--- a/README.md
+++ b/README.md
@@ -310,7 +310,7 @@ To upload these to the hub, run the following:
huggingface-cli upload ${hf_user}/${repo_name} path/to/pretrained_model
```
-See [eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/eval.py) for an example of how other people may use your policy.
+See [lerobot_eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_eval.py) for an example of how other people may use your policy.
### Acknowledgment
diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 36eaea165..5e100013a 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -7,8 +7,6 @@
- sections:
- local: il_robots
title: Imitation Learning for Robots
- - local: il_sim
- title: Imitation Learning in Sim
- local: cameras
title: Cameras
- local: integrate_hardware
@@ -19,23 +17,35 @@
title: Train RL in Simulation
- local: async
title: Use Async Inference
+ - local: multi_gpu_training
+ title: Multi GPU training
title: "Tutorials"
- sections:
- local: lerobot-dataset-v3
title: Using LeRobotDataset
- local: porting_datasets_v3
title: Porting Large Datasets
+ - local: using_dataset_tools
+ title: Using the Dataset Tools
title: "Datasets"
- sections:
+ - local: act
+ title: ACT
- local: smolvla
title: SmolVLA
- local: pi0
title: π₀ (Pi0)
- local: pi05
title: π₀.₅ (Pi05)
+ title: "Policies"
+- sections:
+ - local: il_sim
+ title: Imitation Learning in Sim
- local: libero
title: Using Libero
- title: "Policies"
+ - local: metaworld
+ title: Using MetaWorld
+ title: "Simulation"
- sections:
- local: introduction_processors
title: Introduction to Robot Processors
diff --git a/docs/source/act.mdx b/docs/source/act.mdx
new file mode 100644
index 000000000..e3294ca69
--- /dev/null
+++ b/docs/source/act.mdx
@@ -0,0 +1,92 @@
+# ACT (Action Chunking with Transformers)
+
+ACT is a **lightweight and efficient policy for imitation learning**, especially well-suited for fine-grained manipulation tasks. It's the **first model we recommend when you're starting out** with LeRobot due to its fast training time, low computational requirements, and strong performance.
+
+
+
+
+
+_Watch this tutorial from the LeRobot team to learn how ACT works: [LeRobot ACT Tutorial](https://www.youtube.com/watch?v=ft73x0LfGpM)_
+
+## Model Overview
+
+Action Chunking with Transformers (ACT) was introduced in the paper [Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware](https://arxiv.org/abs/2304.13705) by Zhao et al. The policy was designed to enable precise, contact-rich manipulation tasks using affordable hardware and minimal demonstration data.
+
+### Why ACT is Great for Beginners
+
+ACT stands out as an excellent starting point for several reasons:
+
+- **Fast Training**: Trains in a few hours on a single GPU
+- **Lightweight**: Only ~80M parameters, making it efficient and easy to work with
+- **Data Efficient**: Often achieves high success rates with just 50 demonstrations
+
+### Architecture
+
+ACT uses a transformer-based architecture with three main components:
+
+1. **Vision Backbone**: ResNet-18 processes images from multiple camera viewpoints
+2. **Transformer Encoder**: Synthesizes information from camera features, joint positions, and a learned latent variable
+3. **Transformer Decoder**: Generates coherent action sequences using cross-attention
+
+The policy takes as input:
+
+- Multiple RGB images (e.g., from wrist cameras, front/top cameras)
+- Current robot joint positions
+- A latent style variable `z` (learned during training, set to zero during inference)
+
+And outputs a chunk of `k` future action sequences.
+
+## Installation Requirements
+
+1. Install LeRobot by following our [Installation Guide](./installation).
+2. ACT is included in the base LeRobot installation, so no additional dependencies are needed!
+
+## Training ACT
+
+ACT works seamlessly with the standard LeRobot training pipeline. Here's a complete example for training ACT on your dataset:
+
+```bash
+lerobot-train \
+ --dataset.repo_id=${HF_USER}/your_dataset \
+ --policy.type=act \
+ --output_dir=outputs/train/act_your_dataset \
+ --job_name=act_your_dataset \
+ --policy.device=cuda \
+ --wandb.enable=true \
+ --policy.repo_id=${HF_USER}/act_policy
+```
+
+### Training Tips
+
+1. **Start with defaults**: ACT's default hyperparameters work well for most tasks
+2. **Training duration**: Expect a few hours for 100k training steps on a single GPU
+3. **Batch size**: Start with batch size 8 and adjust based on your GPU memory
+
+### Train using Google Colab
+
+If your local computer doesn't have a powerful GPU, you can utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act).
+
+## Evaluating ACT
+
+Once training is complete, you can evaluate your ACT policy using the `lerobot-record` command with your trained policy. This will run inference and record evaluation episodes:
+
+```bash
+lerobot-record \
+ --robot.type=so100_follower \
+ --robot.port=/dev/ttyACM0 \
+ --robot.id=my_robot \
+ --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
+ --display_data=true \
+ --dataset.repo_id=${HF_USER}/eval_act_your_dataset \
+ --dataset.num_episodes=10 \
+ --dataset.single_task="Your task description" \
+ --policy.path=${HF_USER}/act_policy
+```
diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx
index 91df14028..0d8fd56e5 100644
--- a/docs/source/il_robots.mdx
+++ b/docs/source/il_robots.mdx
@@ -513,13 +513,14 @@ from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.policies.act.modeling_act import ACTPolicy
+from lerobot.policies.factory import make_pre_post_processors
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower
+from lerobot.scripts.lerobot_record import record_loop
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
-from lerobot.record import record_loop
-from lerobot.policies.factory import make_processor
+
NUM_EPISODES = 5
FPS = 30
@@ -562,7 +563,7 @@ init_rerun(session_name="recording")
# Connect the robot
robot.connect()
-preprocessor, postprocessor = make_processor(
+preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,
diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx
index 93354c2ee..f5fd09acd 100644
--- a/docs/source/installation.mdx
+++ b/docs/source/installation.mdx
@@ -91,7 +91,7 @@ LeRobot provides optional extras for specific functionalities. Multiple extras c
### Simulations
-Install environment packages: `aloha` ([gym-aloha](https://github.com/huggingface/gym-aloha)), `xarm` ([gym-xarm](https://github.com/huggingface/gym-xarm)), or `pusht` ([gym-pusht](https://github.com/huggingface/gym-pusht))
+Install environment packages: `aloha` ([gym-aloha](https://github.com/huggingface/gym-aloha)), or `pusht` ([gym-pusht](https://github.com/huggingface/gym-pusht))
Example:
```bash
diff --git a/docs/source/integrate_hardware.mdx b/docs/source/integrate_hardware.mdx
index 7e7fe0bff..ed9dc8dd5 100644
--- a/docs/source/integrate_hardware.mdx
+++ b/docs/source/integrate_hardware.mdx
@@ -8,7 +8,7 @@ To that end, we provide the [`Robot`](https://github.com/huggingface/lerobot/blo
- Your own robot which exposes a communication interface (e.g. serial, CAN, TCP)
- A way to read sensor data and send motor commands programmatically, e.g. manufacturer's SDK or API, or your own protocol implementation.
-- LeRobot installed in your environment. Follow our [Installation Guide](./installation.mdx).
+- LeRobot installed in your environment. Follow our [Installation Guide](./installation).
## Choose your motors
@@ -65,7 +65,7 @@ class MyCoolRobotConfig(RobotConfig):
```
-[Cameras tutorial](./cameras.mdx) to understand how to detect and add your camera.
+[Cameras tutorial](./cameras) to understand how to detect and add your camera.
Next, we'll create our actual robot class which inherits from `Robot`. This abstract class defines a contract you must follow for your robot to be usable with the rest of the LeRobot tools.
diff --git a/docs/source/introduction_processors.mdx b/docs/source/introduction_processors.mdx
index 308edbb3b..6f3768615 100644
--- a/docs/source/introduction_processors.mdx
+++ b/docs/source/introduction_processors.mdx
@@ -297,9 +297,9 @@ LeRobot provides many registered processor steps. Here are the most commonly use
### Next Steps
-- **[Implement Your Own Processor](implement_your_own_processor.mdx)** - Create custom processor steps
-- **[Debug Your Pipeline](debug_processor_pipeline.mdx)** - Troubleshoot and optimize pipelines
-- **[Processors for Robots and Teleoperators](processors_robots_teleop.mdx)** - Real-world integration patterns
+- **[Implement Your Own Processor](./implement_your_own_processor)** - Create custom processor steps
+- **[Debug Your Pipeline](./debug_processor_pipeline)** - Troubleshoot and optimize pipelines
+- **[Processors for Robots and Teleoperators](./processors_robots_teleop)** - Real-world integration patterns
## Summary
diff --git a/docs/source/lerobot-dataset-v3.mdx b/docs/source/lerobot-dataset-v3.mdx
index cf1942fdc..3521914f2 100644
--- a/docs/source/lerobot-dataset-v3.mdx
+++ b/docs/source/lerobot-dataset-v3.mdx
@@ -279,3 +279,36 @@ python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id= 2e-4
+accelerate launch --num_processes=2 $(which lerobot-train) \
+ --optimizer.lr=2e-4 \
+ --dataset.repo_id=lerobot/pusht \
+ --policy=act
+```
+
+**Training Steps Scaling:**
+
+Since the effective batch size `bs` increases with multiple GPUs (batch_size × num_gpus), you may want to reduce the number of training steps proportionally:
+
+```bash
+# Example: 2 GPUs with effective batch size 2x larger
+# Original: batch_size=8, steps=100000
+# With 2 GPUs: batch_size=8 (16 in total), steps=50000
+accelerate launch --num_processes=2 $(which lerobot-train) \
+ --batch_size=8 \
+ --steps=50000 \
+ --dataset.repo_id=lerobot/pusht \
+ --policy=act
+```
+
+## Notes
+
+- The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration.
+- Training logs, checkpoints, and hub uploads are only done by the main process to avoid conflicts. Non-main processes have console logging disabled to prevent duplicate output.
+- The effective batch size is `batch_size × num_gpus`. If you use 4 GPUs with `--batch_size=8`, your effective batch size is 32.
+- Learning rate scheduling is handled correctly across multiple processes—LeRobot sets `step_scheduler_with_optimizer=False` to prevent accelerate from adjusting scheduler steps based on the number of processes.
+- When saving or pushing models, LeRobot automatically unwraps the model from accelerate's distributed wrapper to ensure compatibility.
+- WandB integration automatically initializes only on the main process, preventing multiple runs from being created.
+
+For more advanced configurations and troubleshooting, see the [Accelerate documentation](https://huggingface.co/docs/accelerate). If you want to learn more about how to train on a large number of GPUs, checkout this awesome guide: [Ultrascale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook).
diff --git a/docs/source/phone_teleop.mdx b/docs/source/phone_teleop.mdx
index 22159193c..76e3c367c 100644
--- a/docs/source/phone_teleop.mdx
+++ b/docs/source/phone_teleop.mdx
@@ -79,7 +79,7 @@ After running the example:
- Android: after starting the script, open the printed local URL on your phone, tap Start, then press and hold Move.
- iOS: open HEBI Mobile I/O first; B1 enables motion. A3 controls the gripper.
-Additionally you can customize mapping or safety limits by editing the processor steps shown in the examples. You can also remap inputs (e.g., use a different analog input) or adapt the pipeline to other robots (e.g., LeKiwi) by modifying the input and kinematics steps. More about this in the [Processors for Robots and Teleoperators](./processors_robots_teleop.mdx) guide.
+Additionally you can customize mapping or safety limits by editing the processor steps shown in the examples. You can also remap inputs (e.g., use a different analog input) or adapt the pipeline to other robots (e.g., LeKiwi) by modifying the input and kinematics steps. More about this in the [Processors for Robots and Teleoperators](./processors_robots_teleop) guide.
- Run this example to record a dataset, which saves absolute end effector observations and actions:
diff --git a/docs/source/using_dataset_tools.mdx b/docs/source/using_dataset_tools.mdx
new file mode 100644
index 000000000..affca0ee5
--- /dev/null
+++ b/docs/source/using_dataset_tools.mdx
@@ -0,0 +1,102 @@
+# Using Dataset Tools
+
+This guide covers the dataset tools utilities available in LeRobot for modifying and editing existing datasets.
+
+## Overview
+
+LeRobot provides several utilities for manipulating datasets:
+
+1. **Delete Episodes** - Remove specific episodes from a dataset
+2. **Split Dataset** - Divide a dataset into multiple smaller datasets
+3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids`
+4. **Add Features** - Add new features to a dataset
+5. **Remove Features** - Remove features from a dataset
+
+The core implementation is in `lerobot.datasets.dataset_tools`.
+An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
+
+## Command-Line Tool: lerobot-edit-dataset
+
+`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, and remove features.
+
+Run `lerobot-edit-dataset --help` for more information on the configuration of each operation.
+
+### Usage Examples
+
+#### Delete Episodes
+
+Remove specific episodes from a dataset. This is useful for filtering out undesired data.
+
+```bash
+# Delete episodes 0, 2, and 5 (modifies original dataset)
+lerobot-edit-dataset \
+ --repo_id lerobot/pusht \
+ --operation.type delete_episodes \
+ --operation.episode_indices "[0, 2, 5]"
+
+# Delete episodes and save to a new dataset (preserves original dataset)
+lerobot-edit-dataset \
+ --repo_id lerobot/pusht \
+ --new_repo_id lerobot/pusht_after_deletion \
+ --operation.type delete_episodes \
+ --operation.episode_indices "[0, 2, 5]"
+```
+
+#### Split Dataset
+
+Divide a dataset into multiple subsets.
+
+```bash
+# Split by fractions (e.g. 80% train, 20% test, 20% val)
+lerobot-edit-dataset \
+ --repo_id lerobot/pusht \
+ --operation.type split \
+ --operation.splits '{"train": 0.8, "test": 0.2, "val": 0.2}'
+
+# Split by specific episode indices
+lerobot-edit-dataset \
+ --repo_id lerobot/pusht \
+ --operation.type split \
+ --operation.splits '{"task1": [0, 1, 2, 3], "task2": [4, 5]}'
+```
+
+There are no constraints on the split names, they can be determined by the user. Resulting datasets are saved under the repo id with the split name appended, e.g. `lerobot/pusht_train`, `lerobot/pusht_task1`, `lerobot/pusht_task2`.
+
+#### Merge Datasets
+
+Combine multiple datasets into a single dataset.
+
+```bash
+# Merge train and validation splits back into one dataset
+lerobot-edit-dataset \
+ --repo_id lerobot/pusht_merged \
+ --operation.type merge \
+ --operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']"
+```
+
+#### Remove Features
+
+Remove features from a dataset.
+
+```bash
+# Remove a camera feature
+lerobot-edit-dataset \
+ --repo_id lerobot/pusht \
+ --operation.type remove_feature \
+ --operation.feature_names "['observation.images.top']"
+```
+
+### Push to Hub
+
+Add the `--push_to_hub` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
+
+```bash
+lerobot-edit-dataset \
+ --repo_id lerobot/pusht \
+ --new_repo_id lerobot/pusht_after_deletion \
+ --operation.type delete_episodes \
+ --operation.episode_indices "[0, 2, 5]" \
+ --push_to_hub
+```
+
+There is also a tool for adding features to a dataset that is not yet covered in `lerobot-edit-dataset`.
diff --git a/examples/dataset/use_dataset_tools.py b/examples/dataset/use_dataset_tools.py
new file mode 100644
index 000000000..bd7c389bc
--- /dev/null
+++ b/examples/dataset/use_dataset_tools.py
@@ -0,0 +1,124 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Example script demonstrating dataset tools utilities.
+
+This script shows how to:
+1. Delete episodes from a dataset
+2. Split a dataset into train/val sets
+3. Add/remove features
+4. Merge datasets
+
+Usage:
+ python examples/dataset/use_dataset_tools.py
+"""
+
+import numpy as np
+
+from lerobot.datasets.dataset_tools import (
+ add_features,
+ delete_episodes,
+ merge_datasets,
+ modify_features,
+ remove_feature,
+ split_dataset,
+)
+from lerobot.datasets.lerobot_dataset import LeRobotDataset
+
+
+def main():
+ dataset = LeRobotDataset("lerobot/pusht")
+
+ print(f"Original dataset: {dataset.meta.total_episodes} episodes, {dataset.meta.total_frames} frames")
+ print(f"Features: {list(dataset.meta.features.keys())}")
+
+ print("\n1. Deleting episodes 0 and 2...")
+ filtered_dataset = delete_episodes(dataset, episode_indices=[0, 2], repo_id="lerobot/pusht_filtered")
+ print(f"Filtered dataset: {filtered_dataset.meta.total_episodes} episodes")
+
+ print("\n2. Splitting dataset into train/val...")
+ splits = split_dataset(
+ dataset,
+ splits={"train": 0.8, "val": 0.2},
+ )
+ print(f"Train split: {splits['train'].meta.total_episodes} episodes")
+ print(f"Val split: {splits['val'].meta.total_episodes} episodes")
+
+ print("\n3. Adding features...")
+
+ reward_values = np.random.randn(dataset.meta.total_frames).astype(np.float32)
+
+ def compute_success(row_dict, episode_index, frame_index):
+ episode_length = 10
+ return float(frame_index >= episode_length - 10)
+
+ dataset_with_features = add_features(
+ dataset,
+ features={
+ "reward": (
+ reward_values,
+ {"dtype": "float32", "shape": (1,), "names": None},
+ ),
+ "success": (
+ compute_success,
+ {"dtype": "float32", "shape": (1,), "names": None},
+ ),
+ },
+ repo_id="lerobot/pusht_with_features",
+ )
+
+ print(f"New features: {list(dataset_with_features.meta.features.keys())}")
+
+ print("\n4. Removing the success feature...")
+ dataset_cleaned = remove_feature(
+ dataset_with_features, feature_names="success", repo_id="lerobot/pusht_cleaned"
+ )
+ print(f"Features after removal: {list(dataset_cleaned.meta.features.keys())}")
+
+ print("\n5. Using modify_features to add and remove features simultaneously...")
+ dataset_modified = modify_features(
+ dataset_with_features,
+ add_features={
+ "discount": (
+ np.ones(dataset.meta.total_frames, dtype=np.float32) * 0.99,
+ {"dtype": "float32", "shape": (1,), "names": None},
+ ),
+ },
+ remove_features="reward",
+ repo_id="lerobot/pusht_modified",
+ )
+ print(f"Modified features: {list(dataset_modified.meta.features.keys())}")
+
+ print("\n6. Merging train and val splits back together...")
+ merged = merge_datasets([splits["train"], splits["val"]], output_repo_id="lerobot/pusht_merged")
+ print(f"Merged dataset: {merged.meta.total_episodes} episodes")
+
+ print("\n7. Complex workflow example...")
+
+ if len(dataset.meta.camera_keys) > 1:
+ camera_to_remove = dataset.meta.camera_keys[0]
+ print(f"Removing camera: {camera_to_remove}")
+ dataset_no_cam = remove_feature(
+ dataset, feature_names=camera_to_remove, repo_id="pusht_no_first_camera"
+ )
+ print(f"Remaining cameras: {dataset_no_cam.meta.camera_keys}")
+
+ print("\nDone! Check ~/.cache/huggingface/lerobot/ for the created datasets.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/lekiwi/evaluate.py b/examples/lekiwi/evaluate.py
index 8a62d92a9..4501008d0 100644
--- a/examples/lekiwi/evaluate.py
+++ b/examples/lekiwi/evaluate.py
@@ -133,4 +133,6 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say("Stop recording")
robot.disconnect()
listener.stop()
+
+dataset.finalize()
dataset.push_to_hub()
diff --git a/examples/lekiwi/record.py b/examples/lekiwi/record.py
index 9070741bf..491e1c386 100644
--- a/examples/lekiwi/record.py
+++ b/examples/lekiwi/record.py
@@ -130,4 +130,6 @@ robot.disconnect()
leader_arm.disconnect()
keyboard.disconnect()
listener.stop()
+
+dataset.finalize()
dataset.push_to_hub()
diff --git a/examples/phone_to_so100/evaluate.py b/examples/phone_to_so100/evaluate.py
index 0d53f1177..ff8dbddd2 100644
--- a/examples/phone_to_so100/evaluate.py
+++ b/examples/phone_to_so100/evaluate.py
@@ -194,4 +194,6 @@ for episode_idx in range(NUM_EPISODES):
log_say("Stop recording")
robot.disconnect()
listener.stop()
+
+dataset.finalize()
dataset.push_to_hub()
diff --git a/examples/phone_to_so100/record.py b/examples/phone_to_so100/record.py
index d3ef293a7..880f9c9b4 100644
--- a/examples/phone_to_so100/record.py
+++ b/examples/phone_to_so100/record.py
@@ -200,4 +200,6 @@ log_say("Stop recording")
robot.disconnect()
phone.disconnect()
listener.stop()
+
+dataset.finalize()
dataset.push_to_hub()
diff --git a/examples/port_datasets/port_droid.py b/examples/port_datasets/port_droid.py
index 4efb131e4..a1fb50914 100644
--- a/examples/port_datasets/port_droid.py
+++ b/examples/port_datasets/port_droid.py
@@ -362,6 +362,8 @@ def port_droid(
lerobot_dataset.save_episode()
logging.info("Save_episode")
+ lerobot_dataset.finalize()
+
if push_to_hub:
lerobot_dataset.push_to_hub(
# Add openx tag, since it belongs to the openx collection of datasets
diff --git a/examples/so100_to_so100_EE/evaluate.py b/examples/so100_to_so100_EE/evaluate.py
index 53a385442..60489b3cf 100644
--- a/examples/so100_to_so100_EE/evaluate.py
+++ b/examples/so100_to_so100_EE/evaluate.py
@@ -195,4 +195,6 @@ for episode_idx in range(NUM_EPISODES):
log_say("Stop recording")
robot.disconnect()
listener.stop()
+
+dataset.finalize()
dataset.push_to_hub()
diff --git a/examples/so100_to_so100_EE/record.py b/examples/so100_to_so100_EE/record.py
index 9ed6e51a9..5ff1c286f 100644
--- a/examples/so100_to_so100_EE/record.py
+++ b/examples/so100_to_so100_EE/record.py
@@ -199,4 +199,6 @@ log_say("Stop recording")
leader.disconnect()
follower.disconnect()
listener.stop()
+
+dataset.finalize()
dataset.push_to_hub()
diff --git a/pyproject.toml b/pyproject.toml
index c67b481f0..c8879ac36 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -62,24 +62,26 @@ dependencies = [
"datasets>=4.0.0,<4.2.0",
"diffusers>=0.27.2,<0.36.0",
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
+ "accelerate>=1.10.0,<2.0.0",
# Core dependencies
+ "setuptools>=71.0.0,<81.0.0",
"cmake>=3.29.0.1,<4.2.0",
"einops>=0.8.0,<0.9.0",
"opencv-python-headless>=4.9.0,<4.13.0",
- "av>=14.2.0,<16.0.0",
+ "av>=15.0.0,<16.0.0",
"jsonlines>=4.0.0,<5.0.0",
"packaging>=24.2,<26.0",
"pynput>=1.7.7,<1.9.0",
"pyserial>=3.5,<4.0",
- "wandb>=0.20.0,<0.23.0",
+ "wandb>=0.20.0,<0.22.0", # TODO: Bumb dependency (compatible with protobuf)
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
"torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency
"draccus==0.10.0", # TODO: Remove ==
- "gymnasium>=0.29.1,<1.0.0", # TODO: Bumb dependency
+ "gymnasium>=1.0.0",
"rerun-sdk>=0.21.0,<0.23.0", # TODO: Bumb dependency
# Support dependencies
@@ -95,7 +97,7 @@ dependencies = [
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
placo-dep = ["placo>=0.9.6,<0.10.0"]
transformers-dep = ["transformers>=4.53.0,<5.0.0"]
-grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]
+grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bumb dependency (compatible with wandb)
# Motors
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
@@ -132,11 +134,10 @@ test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
# Simulation
-aloha = ["gym-aloha>=0.1.1,<0.2.0"]
+aloha = ["gym-aloha>=0.1.2,<0.2.0"]
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
-xarm = ["gym-xarm>=0.1.1,<0.2.0"]
libero = ["lerobot[transformers-dep]", "libero @ git+https://github.com/huggingface/lerobot-libero.git@main#egg=libero"]
-
+metaworld = ["metaworld>=3.0.0"]
# All
all = [
@@ -156,9 +157,9 @@ all = [
"lerobot[video_benchmark]",
"lerobot[aloha]",
"lerobot[pusht]",
- "lerobot[xarm]",
"lerobot[phone]",
"lerobot[libero]",
+ "lerobot[metaworld]",
]
[project.scripts]
@@ -175,6 +176,7 @@ lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main"
lerobot-info="lerobot.scripts.lerobot_info:main"
lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
+lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
# ---------------- Tool Configurations ----------------
[tool.setuptools.packages.find]
diff --git a/src/lerobot/__init__.py b/src/lerobot/__init__.py
index 9d3ed1893..eec574296 100644
--- a/src/lerobot/__init__.py
+++ b/src/lerobot/__init__.py
@@ -57,7 +57,6 @@ available_tasks_per_env = {
"AlohaTransferCube-v0",
],
"pusht": ["PushT-v0"],
- "xarm": ["XarmLift-v0"],
}
available_envs = list(available_tasks_per_env.keys())
@@ -75,16 +74,6 @@ available_datasets_per_env = {
# TODO(alexander-soare): Add "lerobot/pusht_keypoints". Right now we can't because this is too tightly
# coupled with tests.
"pusht": ["lerobot/pusht", "lerobot/pusht_image"],
- "xarm": [
- "lerobot/xarm_lift_medium",
- "lerobot/xarm_lift_medium_replay",
- "lerobot/xarm_push_medium",
- "lerobot/xarm_push_medium_replay",
- "lerobot/xarm_lift_medium_image",
- "lerobot/xarm_lift_medium_replay_image",
- "lerobot/xarm_push_medium_image",
- "lerobot/xarm_push_medium_replay_image",
- ],
}
available_real_world_datasets = [
@@ -195,7 +184,6 @@ available_motors = [
available_policies_per_env = {
"aloha": ["act"],
"pusht": ["diffusion", "vqbet"],
- "xarm": ["tdmpc"],
"koch_real": ["act_koch_real"],
"aloha_real": ["act_aloha_real"],
}
diff --git a/src/lerobot/async_inference/configs.py b/src/lerobot/async_inference/configs.py
index 24f889df1..d1768a323 100644
--- a/src/lerobot/async_inference/configs.py
+++ b/src/lerobot/async_inference/configs.py
@@ -142,11 +142,6 @@ class RobotClientConfig:
default=False, metadata={"help": "Visualize the action queue size"}
)
- # Verification configuration
- verify_robot_cameras: bool = field(
- default=True, metadata={"help": "Verify that the robot cameras match the policy cameras"}
- )
-
@property
def environment_dt(self) -> float:
"""Environment time step, in seconds"""
diff --git a/src/lerobot/async_inference/helpers.py b/src/lerobot/async_inference/helpers.py
index 54fad8c54..f73cbc1da 100644
--- a/src/lerobot/async_inference/helpers.py
+++ b/src/lerobot/async_inference/helpers.py
@@ -62,15 +62,6 @@ def visualize_action_queue_size(action_queue_size: list[int]) -> None:
plt.show()
-def validate_robot_cameras_for_policy(
- lerobot_observation_features: dict[str, dict], policy_image_features: dict[str, PolicyFeature]
-) -> None:
- image_keys = list(filter(is_image_key, lerobot_observation_features))
- assert set(image_keys) == set(policy_image_features.keys()), (
- f"Policy image features must match robot cameras! Received {list(policy_image_features.keys())} != {image_keys}"
- )
-
-
def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]:
return hw_to_dataset_features(robot.observation_features, OBS_STR, use_video=False)
diff --git a/src/lerobot/async_inference/robot_client.py b/src/lerobot/async_inference/robot_client.py
index 8c4425c6b..f9d70a64e 100644
--- a/src/lerobot/async_inference/robot_client.py
+++ b/src/lerobot/async_inference/robot_client.py
@@ -48,7 +48,6 @@ import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
-from lerobot.configs.policies import PreTrainedConfig
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
@@ -76,7 +75,6 @@ from .helpers import (
TimedObservation,
get_logger,
map_robot_keys_to_lerobot_features,
- validate_robot_cameras_for_policy,
visualize_action_queue_size,
)
@@ -98,14 +96,6 @@ class RobotClient:
lerobot_features = map_robot_keys_to_lerobot_features(self.robot)
- if config.verify_robot_cameras:
- # Load policy config for validation
- policy_config = PreTrainedConfig.from_pretrained(config.pretrained_name_or_path)
- policy_image_features = policy_config.image_features
-
- # The cameras specified for inference must match the one supported by the policy chosen
- validate_robot_cameras_for_policy(lerobot_features, policy_image_features)
-
# Use environment variable if server_address is not provided in config
self.server_address = config.server_address
diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py
index 803645f29..870c9571e 100644
--- a/src/lerobot/datasets/aggregate.py
+++ b/src/lerobot/datasets/aggregate.py
@@ -31,15 +31,15 @@ from lerobot.datasets.utils import (
DEFAULT_EPISODES_PATH,
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_PATH,
+ get_file_size_in_mb,
get_parquet_file_size_in_mb,
- get_video_size_in_mb,
to_parquet_with_hf_images,
update_chunk_file_indices,
write_info,
write_stats,
write_tasks,
)
-from lerobot.datasets.video_utils import concatenate_video_files
+from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
@@ -130,10 +130,34 @@ def update_meta_data(
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
for key, video_idx in videos_idx.items():
- df[f"videos/{key}/chunk_index"] = df[f"videos/{key}/chunk_index"] + video_idx["chunk"]
- df[f"videos/{key}/file_index"] = df[f"videos/{key}/file_index"] + video_idx["file"]
- df[f"videos/{key}/from_timestamp"] = df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
- df[f"videos/{key}/to_timestamp"] = df[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"]
+ # Store original video file indices before updating
+ orig_chunk_col = f"videos/{key}/chunk_index"
+ orig_file_col = f"videos/{key}/file_index"
+ df["_orig_chunk"] = df[orig_chunk_col].copy()
+ df["_orig_file"] = df[orig_file_col].copy()
+
+ # Update chunk and file indices to point to destination
+ df[orig_chunk_col] = video_idx["chunk"]
+ df[orig_file_col] = video_idx["file"]
+
+ # Apply per-source-file timestamp offsets
+ src_to_offset = video_idx.get("src_to_offset", {})
+ if src_to_offset:
+ # Apply offset based on original source file
+ for idx in df.index:
+ src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
+ offset = src_to_offset.get(src_key, 0)
+ df.at[idx, f"videos/{key}/from_timestamp"] += offset
+ df.at[idx, f"videos/{key}/to_timestamp"] += offset
+ else:
+ # Fallback to simple offset (for backward compatibility)
+ df[f"videos/{key}/from_timestamp"] = (
+ df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
+ )
+ df[f"videos/{key}/to_timestamp"] = df[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"]
+
+ # Clean up temporary columns
+ df = df.drop(columns=["_orig_chunk", "_orig_file"])
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
@@ -193,6 +217,10 @@ def aggregate_datasets(
robot_type=robot_type,
features=features,
root=aggr_root,
+ use_videos=len(video_keys) > 0,
+ chunks_size=chunk_size,
+ data_files_size_in_mb=data_files_size_in_mb,
+ video_files_size_in_mb=video_files_size_in_mb,
)
logging.info("Find all tasks")
@@ -236,6 +264,11 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
Returns:
dict: Updated videos_idx with current chunk and file indices.
"""
+ for key in videos_idx:
+ videos_idx[key]["episode_duration"] = 0
+ # Track offset for each source (chunk, file) pair
+ videos_idx[key]["src_to_offset"] = {}
+
for key, video_idx in videos_idx.items():
unique_chunk_file_pairs = {
(chunk, file)
@@ -249,6 +282,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
chunk_idx = video_idx["chunk"]
file_idx = video_idx["file"]
+ current_offset = video_idx["latest_duration"]
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
@@ -263,21 +297,25 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
file_index=file_idx,
)
- # If a new file is created, we don't want to increment the latest_duration
- update_latest_duration = False
+ src_duration = get_video_duration_in_s(src_path)
if not dst_path.exists():
- # First write to this destination file
+ # Store offset before incrementing
+ videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
dst_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(str(src_path), str(dst_path))
- continue # not accumulating further, already copied the file in place
+ videos_idx[key]["episode_duration"] += src_duration
+ current_offset += src_duration
+ continue
# Check file sizes before appending
- src_size = get_video_size_in_mb(src_path)
- dst_size = get_video_size_in_mb(dst_path)
+ src_size = get_file_size_in_mb(src_path)
+ dst_size = get_file_size_in_mb(dst_path)
if dst_size + src_size >= video_files_size_in_mb:
- # Rotate to a new chunk/file
+ # Rotate to a new file, this source becomes start of new destination
+ # So its offset should be 0
+ videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format(
video_key=key,
@@ -286,25 +324,22 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
)
dst_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(str(src_path), str(dst_path))
+ # Reset offset for next file
+ current_offset = src_duration
else:
- # Get the timestamps shift for this video
- timestamps_shift_s = dst_meta.info["total_frames"] / dst_meta.info["fps"]
-
- # Append to existing video file
+ # Append to existing video file - use current accumulated offset
+ videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
concatenate_video_files(
[dst_path, src_path],
dst_path,
)
- # Update the latest_duration when appending (shifts timestamps!)
- update_latest_duration = not update_latest_duration
+ current_offset += src_duration
+
+ videos_idx[key]["episode_duration"] += src_duration
- # Update the videos_idx with the final chunk and file indices for this key
videos_idx[key]["chunk"] = chunk_idx
videos_idx[key]["file"] = file_idx
- if update_latest_duration:
- videos_idx[key]["latest_duration"] += timestamps_shift_s
-
return videos_idx
@@ -389,9 +424,6 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
videos_idx,
)
- for k in videos_idx:
- videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]
-
meta_idx = append_or_create_parquet_file(
df,
src_path,
@@ -403,6 +435,10 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
aggr_root=dst_meta.root,
)
+ # Increment latest_duration by the total duration added from this source dataset
+ for k in videos_idx:
+ videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]
+
return meta_idx
diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py
new file mode 100644
index 000000000..2735ba0a0
--- /dev/null
+++ b/src/lerobot/datasets/dataset_tools.py
@@ -0,0 +1,1089 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Dataset tools utilities for LeRobotDataset.
+
+This module provides utilities for:
+- Deleting episodes from datasets
+- Splitting datasets into multiple smaller datasets
+- Adding/removing features from datasets
+- Merging datasets (wrapper around aggregate functionality)
+"""
+
+import logging
+import shutil
+from collections.abc import Callable
+from pathlib import Path
+
+import datasets
+import numpy as np
+import pandas as pd
+import pyarrow.parquet as pq
+import torch
+from tqdm import tqdm
+
+from lerobot.datasets.aggregate import aggregate_datasets
+from lerobot.datasets.compute_stats import aggregate_stats
+from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
+from lerobot.datasets.utils import (
+ DEFAULT_CHUNK_SIZE,
+ DEFAULT_DATA_FILE_SIZE_IN_MB,
+ DEFAULT_DATA_PATH,
+ DEFAULT_EPISODES_PATH,
+ get_parquet_file_size_in_mb,
+ load_episodes,
+ update_chunk_file_indices,
+ write_info,
+ write_stats,
+ write_tasks,
+)
+from lerobot.utils.constants import HF_LEROBOT_HOME
+
+
+def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict:
+ """Load a single episode's metadata including stats from parquet file.
+
+ Args:
+ src_dataset: Source dataset
+ episode_idx: Episode index to load
+
+ Returns:
+ dict containing episode metadata and stats
+ """
+ ep_meta = src_dataset.meta.episodes[episode_idx]
+ chunk_idx = ep_meta["meta/episodes/chunk_index"]
+ file_idx = ep_meta["meta/episodes/file_index"]
+
+ parquet_path = src_dataset.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
+ df = pd.read_parquet(parquet_path)
+
+ episode_row = df[df["episode_index"] == episode_idx].iloc[0]
+
+ return episode_row.to_dict()
+
+
+def delete_episodes(
+ dataset: LeRobotDataset,
+ episode_indices: list[int],
+ output_dir: str | Path | None = None,
+ repo_id: str | None = None,
+) -> LeRobotDataset:
+ """Delete episodes from a LeRobotDataset and create a new dataset.
+
+ Args:
+ dataset: The source LeRobotDataset.
+ episode_indices: List of episode indices to delete.
+ output_dir: Directory to save the new dataset. If None, uses default location.
+ repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
+ """
+ if not episode_indices:
+ raise ValueError("No episodes to delete")
+
+ valid_indices = set(range(dataset.meta.total_episodes))
+ invalid = set(episode_indices) - valid_indices
+ if invalid:
+ raise ValueError(f"Invalid episode indices: {invalid}")
+
+ logging.info(f"Deleting {len(episode_indices)} episodes from dataset")
+
+ if repo_id is None:
+ repo_id = f"{dataset.repo_id}_modified"
+ output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
+
+ episodes_to_keep = [i for i in range(dataset.meta.total_episodes) if i not in episode_indices]
+ if not episodes_to_keep:
+ raise ValueError("Cannot delete all episodes from dataset")
+
+ new_meta = LeRobotDatasetMetadata.create(
+ repo_id=repo_id,
+ fps=dataset.meta.fps,
+ features=dataset.meta.features,
+ robot_type=dataset.meta.robot_type,
+ root=output_dir,
+ use_videos=len(dataset.meta.video_keys) > 0,
+ )
+
+ episode_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(episodes_to_keep)}
+
+ video_metadata = None
+ if dataset.meta.video_keys:
+ video_metadata = _copy_and_reindex_videos(dataset, new_meta, episode_mapping)
+
+ data_metadata = _copy_and_reindex_data(dataset, new_meta, episode_mapping)
+
+ _copy_and_reindex_episodes_metadata(dataset, new_meta, episode_mapping, data_metadata, video_metadata)
+
+ new_dataset = LeRobotDataset(
+ repo_id=repo_id,
+ root=output_dir,
+ image_transforms=dataset.image_transforms,
+ delta_timestamps=dataset.delta_timestamps,
+ tolerance_s=dataset.tolerance_s,
+ )
+
+ logging.info(f"Created new dataset with {len(episodes_to_keep)} episodes")
+ return new_dataset
+
+
+def split_dataset(
+ dataset: LeRobotDataset,
+ splits: dict[str, float | list[int]],
+ output_dir: str | Path | None = None,
+) -> dict[str, LeRobotDataset]:
+ """Split a LeRobotDataset into multiple smaller datasets.
+
+ Args:
+ dataset: The source LeRobotDataset to split.
+ splits: Either a dict mapping split names to episode indices, or a dict mapping
+ split names to fractions (must sum to <= 1.0).
+ output_dir: Base directory for output datasets. If None, uses default location.
+
+ Examples:
+ Split by specific episodes
+ splits = {"train": [0, 1, 2], "val": [3, 4]}
+ datasets = split_dataset(dataset, splits)
+
+ Split by fractions
+ splits = {"train": 0.8, "val": 0.2}
+ datasets = split_dataset(dataset, splits)
+ """
+ if not splits:
+ raise ValueError("No splits provided")
+
+ if all(isinstance(v, float) for v in splits.values()):
+ splits = _fractions_to_episode_indices(dataset.meta.total_episodes, splits)
+
+ all_episodes = set()
+ for split_name, episodes in splits.items():
+ if not episodes:
+ raise ValueError(f"Split '{split_name}' has no episodes")
+ episode_set = set(episodes)
+ if episode_set & all_episodes:
+ raise ValueError("Episodes cannot appear in multiple splits")
+ all_episodes.update(episode_set)
+
+ valid_indices = set(range(dataset.meta.total_episodes))
+ invalid = all_episodes - valid_indices
+ if invalid:
+ raise ValueError(f"Invalid episode indices: {invalid}")
+
+ if output_dir is not None:
+ output_dir = Path(output_dir)
+
+ result_datasets = {}
+
+ for split_name, episodes in splits.items():
+ logging.info(f"Creating split '{split_name}' with {len(episodes)} episodes")
+
+ split_repo_id = f"{dataset.repo_id}_{split_name}"
+
+ split_output_dir = (
+ output_dir / split_name if output_dir is not None else HF_LEROBOT_HOME / split_repo_id
+ )
+
+ episode_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(sorted(episodes))}
+
+ new_meta = LeRobotDatasetMetadata.create(
+ repo_id=split_repo_id,
+ fps=dataset.meta.fps,
+ features=dataset.meta.features,
+ robot_type=dataset.meta.robot_type,
+ root=split_output_dir,
+ use_videos=len(dataset.meta.video_keys) > 0,
+ chunks_size=dataset.meta.chunks_size,
+ data_files_size_in_mb=dataset.meta.data_files_size_in_mb,
+ video_files_size_in_mb=dataset.meta.video_files_size_in_mb,
+ )
+
+ video_metadata = None
+ if dataset.meta.video_keys:
+ video_metadata = _copy_and_reindex_videos(dataset, new_meta, episode_mapping)
+
+ data_metadata = _copy_and_reindex_data(dataset, new_meta, episode_mapping)
+
+ _copy_and_reindex_episodes_metadata(dataset, new_meta, episode_mapping, data_metadata, video_metadata)
+
+ new_dataset = LeRobotDataset(
+ repo_id=split_repo_id,
+ root=split_output_dir,
+ image_transforms=dataset.image_transforms,
+ delta_timestamps=dataset.delta_timestamps,
+ tolerance_s=dataset.tolerance_s,
+ )
+
+ result_datasets[split_name] = new_dataset
+
+ return result_datasets
+
+
+def merge_datasets(
+ datasets: list[LeRobotDataset],
+ output_repo_id: str,
+ output_dir: str | Path | None = None,
+) -> LeRobotDataset:
+ """Merge multiple LeRobotDatasets into a single dataset.
+
+ This is a wrapper around the aggregate_datasets functionality with a cleaner API.
+
+ Args:
+ datasets: List of LeRobotDatasets to merge.
+ output_repo_id: Repository ID for the merged dataset.
+ output_dir: Directory to save the merged dataset. If None, uses default location.
+ """
+ if not datasets:
+ raise ValueError("No datasets to merge")
+
+ output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / output_repo_id
+
+ repo_ids = [ds.repo_id for ds in datasets]
+ roots = [ds.root for ds in datasets]
+
+ aggregate_datasets(
+ repo_ids=repo_ids,
+ aggr_repo_id=output_repo_id,
+ roots=roots,
+ aggr_root=output_dir,
+ )
+
+ merged_dataset = LeRobotDataset(
+ repo_id=output_repo_id,
+ root=output_dir,
+ image_transforms=datasets[0].image_transforms,
+ delta_timestamps=datasets[0].delta_timestamps,
+ tolerance_s=datasets[0].tolerance_s,
+ )
+
+ return merged_dataset
+
+
+def modify_features(
+ dataset: LeRobotDataset,
+ add_features: dict[str, tuple[np.ndarray | torch.Tensor | Callable, dict]] | None = None,
+ remove_features: str | list[str] | None = None,
+ output_dir: str | Path | None = None,
+ repo_id: str | None = None,
+) -> LeRobotDataset:
+ """Modify a LeRobotDataset by adding and/or removing features in a single pass.
+
+ This is the most efficient way to modify features, as it only copies the dataset once
+ regardless of how many features are being added or removed.
+
+ Args:
+ dataset: The source LeRobotDataset.
+ add_features: Optional dict mapping feature names to (feature_values, feature_info) tuples.
+ remove_features: Optional feature name(s) to remove. Can be a single string or list.
+ output_dir: Directory to save the new dataset. If None, uses default location.
+ repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
+
+ Returns:
+ New dataset with features modified.
+
+ Example:
+ new_dataset = modify_features(
+ dataset,
+ add_features={
+ "reward": (reward_array, {"dtype": "float32", "shape": [1], "names": None}),
+ },
+ remove_features=["old_feature"],
+ output_dir="./output",
+ )
+ """
+ if add_features is None and remove_features is None:
+ raise ValueError("Must specify at least one of add_features or remove_features")
+
+ remove_features_list: list[str] = []
+ if remove_features is not None:
+ remove_features_list = [remove_features] if isinstance(remove_features, str) else remove_features
+
+ if add_features:
+ required_keys = {"dtype", "shape"}
+ for feature_name, (_, feature_info) in add_features.items():
+ if feature_name in dataset.meta.features:
+ raise ValueError(f"Feature '{feature_name}' already exists in dataset")
+
+ if not required_keys.issubset(feature_info.keys()):
+ raise ValueError(f"feature_info for '{feature_name}' must contain keys: {required_keys}")
+
+ if remove_features_list:
+ for name in remove_features_list:
+ if name not in dataset.meta.features:
+ raise ValueError(f"Feature '{name}' not found in dataset")
+
+ required_features = {"timestamp", "frame_index", "episode_index", "index", "task_index"}
+ if any(name in required_features for name in remove_features_list):
+ raise ValueError(f"Cannot remove required features: {required_features}")
+
+ if repo_id is None:
+ repo_id = f"{dataset.repo_id}_modified"
+ output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
+
+ new_features = dataset.meta.features.copy()
+
+ if remove_features_list:
+ for name in remove_features_list:
+ new_features.pop(name, None)
+
+ if add_features:
+ for feature_name, (_, feature_info) in add_features.items():
+ new_features[feature_name] = feature_info
+
+ video_keys_to_remove = [name for name in remove_features_list if name in dataset.meta.video_keys]
+ remaining_video_keys = [k for k in dataset.meta.video_keys if k not in video_keys_to_remove]
+
+ new_meta = LeRobotDatasetMetadata.create(
+ repo_id=repo_id,
+ fps=dataset.meta.fps,
+ features=new_features,
+ robot_type=dataset.meta.robot_type,
+ root=output_dir,
+ use_videos=len(remaining_video_keys) > 0,
+ )
+
+ _copy_data_with_feature_changes(
+ dataset=dataset,
+ new_meta=new_meta,
+ add_features=add_features,
+ remove_features=remove_features_list if remove_features_list else None,
+ )
+
+ if new_meta.video_keys:
+ _copy_videos(dataset, new_meta, exclude_keys=video_keys_to_remove if video_keys_to_remove else None)
+
+ new_dataset = LeRobotDataset(
+ repo_id=repo_id,
+ root=output_dir,
+ image_transforms=dataset.image_transforms,
+ delta_timestamps=dataset.delta_timestamps,
+ tolerance_s=dataset.tolerance_s,
+ )
+
+ return new_dataset
+
+
+def add_features(
+ dataset: LeRobotDataset,
+ features: dict[str, tuple[np.ndarray | torch.Tensor | Callable, dict]],
+ output_dir: str | Path | None = None,
+ repo_id: str | None = None,
+) -> LeRobotDataset:
+ """Add multiple features to a LeRobotDataset in a single pass.
+
+ This is more efficient than calling add_feature() multiple times, as it only
+ copies the dataset once regardless of how many features are being added.
+
+ Args:
+ dataset: The source LeRobotDataset.
+ features: Dictionary mapping feature names to (feature_values, feature_info) tuples.
+ output_dir: Directory to save the new dataset. If None, uses default location.
+ repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
+
+ Returns:
+ New dataset with all features added.
+
+ Example:
+ features = {
+ "task_embedding": (task_emb_array, {"dtype": "float32", "shape": [384], "names": None}),
+ "cam1_embedding": (cam1_emb_array, {"dtype": "float32", "shape": [768], "names": None}),
+ "cam2_embedding": (cam2_emb_array, {"dtype": "float32", "shape": [768], "names": None}),
+ }
+ new_dataset = add_features(dataset, features, output_dir="./output", repo_id="my_dataset")
+ """
+ if not features:
+ raise ValueError("No features provided")
+
+ return modify_features(
+ dataset=dataset,
+ add_features=features,
+ remove_features=None,
+ output_dir=output_dir,
+ repo_id=repo_id,
+ )
+
+
+def remove_feature(
+ dataset: LeRobotDataset,
+ feature_names: str | list[str],
+ output_dir: str | Path | None = None,
+ repo_id: str | None = None,
+) -> LeRobotDataset:
+ """Remove features from a LeRobotDataset.
+
+ Args:
+ dataset: The source LeRobotDataset.
+ feature_names: Name(s) of features to remove. Can be a single string or list.
+ output_dir: Directory to save the new dataset. If None, uses default location.
+ repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
+
+ Returns:
+ New dataset with features removed.
+ """
+ return modify_features(
+ dataset=dataset,
+ add_features=None,
+ remove_features=feature_names,
+ output_dir=output_dir,
+ repo_id=repo_id,
+ )
+
+
+def _fractions_to_episode_indices(
+ total_episodes: int,
+ splits: dict[str, float],
+) -> dict[str, list[int]]:
+ """Convert split fractions to episode indices."""
+ if sum(splits.values()) > 1.0:
+ raise ValueError("Split fractions must sum to <= 1.0")
+
+ indices = list(range(total_episodes))
+ result = {}
+ start_idx = 0
+
+ for split_name, fraction in splits.items():
+ num_episodes = int(total_episodes * fraction)
+ if num_episodes == 0:
+ logging.warning(f"Split '{split_name}' has no episodes, skipping...")
+ continue
+ end_idx = start_idx + num_episodes
+ if split_name == list(splits.keys())[-1]:
+ end_idx = total_episodes
+ result[split_name] = indices[start_idx:end_idx]
+ start_idx = end_idx
+
+ return result
+
+
+def _copy_and_reindex_data(
+ src_dataset: LeRobotDataset,
+ dst_meta: LeRobotDatasetMetadata,
+ episode_mapping: dict[int, int],
+) -> dict[int, dict]:
+ """Copy and filter data files, only modifying files with deleted episodes.
+
+ Args:
+ src_dataset: Source dataset to copy from
+ dst_meta: Destination metadata object
+ episode_mapping: Mapping from old episode indices to new indices
+
+ Returns:
+ dict mapping episode index to its data file metadata (chunk_index, file_index, etc.)
+ """
+ if src_dataset.meta.episodes is None:
+ src_dataset.meta.episodes = load_episodes(src_dataset.meta.root)
+
+ file_to_episodes: dict[Path, set[int]] = {}
+ for old_idx in episode_mapping:
+ file_path = src_dataset.meta.get_data_file_path(old_idx)
+ if file_path not in file_to_episodes:
+ file_to_episodes[file_path] = set()
+ file_to_episodes[file_path].add(old_idx)
+
+ global_index = 0
+ episode_data_metadata: dict[int, dict] = {}
+
+ if dst_meta.tasks is None:
+ all_task_indices = set()
+ for src_path in file_to_episodes:
+ df = pd.read_parquet(src_dataset.root / src_path)
+ mask = df["episode_index"].isin(list(episode_mapping.keys()))
+ task_series: pd.Series = df[mask]["task_index"]
+ all_task_indices.update(task_series.unique().tolist())
+ tasks = [src_dataset.meta.tasks.iloc[idx].name for idx in all_task_indices]
+ dst_meta.save_episode_tasks(list(set(tasks)))
+
+ task_mapping = {}
+ for old_task_idx in range(len(src_dataset.meta.tasks)):
+ task_name = src_dataset.meta.tasks.iloc[old_task_idx].name
+ new_task_idx = dst_meta.get_task_index(task_name)
+ if new_task_idx is not None:
+ task_mapping[old_task_idx] = new_task_idx
+
+ for src_path in tqdm(sorted(file_to_episodes.keys()), desc="Processing data files"):
+ df = pd.read_parquet(src_dataset.root / src_path)
+
+ all_episodes_in_file = set(df["episode_index"].unique())
+ episodes_to_keep = file_to_episodes[src_path]
+
+ if all_episodes_in_file == episodes_to_keep:
+ df["episode_index"] = df["episode_index"].replace(episode_mapping)
+ df["index"] = range(global_index, global_index + len(df))
+ df["task_index"] = df["task_index"].replace(task_mapping)
+
+ first_ep_old_idx = min(episodes_to_keep)
+ src_ep = src_dataset.meta.episodes[first_ep_old_idx]
+ chunk_idx = src_ep["data/chunk_index"]
+ file_idx = src_ep["data/file_index"]
+ else:
+ mask = df["episode_index"].isin(list(episode_mapping.keys()))
+ df = df[mask].copy().reset_index(drop=True)
+
+ if len(df) == 0:
+ continue
+
+ df["episode_index"] = df["episode_index"].replace(episode_mapping)
+ df["index"] = range(global_index, global_index + len(df))
+ df["task_index"] = df["task_index"].replace(task_mapping)
+
+ first_ep_old_idx = min(episodes_to_keep)
+ src_ep = src_dataset.meta.episodes[first_ep_old_idx]
+ chunk_idx = src_ep["data/chunk_index"]
+ file_idx = src_ep["data/file_index"]
+
+ dst_path = dst_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
+ dst_path.parent.mkdir(parents=True, exist_ok=True)
+
+ _write_parquet(df, dst_path, dst_meta)
+
+ for ep_old_idx in episodes_to_keep:
+ ep_new_idx = episode_mapping[ep_old_idx]
+ ep_df = df[df["episode_index"] == ep_new_idx]
+ episode_data_metadata[ep_new_idx] = {
+ "data/chunk_index": chunk_idx,
+ "data/file_index": file_idx,
+ "dataset_from_index": int(ep_df["index"].min()),
+ "dataset_to_index": int(ep_df["index"].max() + 1),
+ }
+
+ global_index += len(df)
+
+ return episode_data_metadata
+
+
+def _keep_episodes_from_video_with_av(
+ input_path: Path,
+ output_path: Path,
+ episodes_to_keep: list[tuple[float, float]],
+ fps: float,
+ vcodec: str = "libsvtav1",
+ pix_fmt: str = "yuv420p",
+) -> None:
+ """Keep only specified episodes from a video file using PyAV.
+
+ This function decodes frames from specified time ranges and re-encodes them with
+ properly reset timestamps to ensure monotonic progression.
+
+ Args:
+ input_path: Source video file path.
+ output_path: Destination video file path.
+ episodes_to_keep: List of (start_time, end_time) tuples for episodes to keep.
+ fps: Frame rate of the video.
+ vcodec: Video codec to use for encoding.
+ pix_fmt: Pixel format for output video.
+ """
+ from fractions import Fraction
+
+ import av
+
+ if not episodes_to_keep:
+ raise ValueError("No episodes to keep")
+
+ in_container = av.open(str(input_path))
+
+ # Check if video stream exists.
+ if not in_container.streams.video:
+ raise ValueError(
+ f"No video streams found in {input_path}. "
+ "The video file may be corrupted or empty. "
+ "Try re-downloading the dataset or checking the video file."
+ )
+
+ v_in = in_container.streams.video[0]
+
+ out = av.open(str(output_path), mode="w")
+
+ # Convert fps to Fraction for PyAV compatibility.
+ fps_fraction = Fraction(fps).limit_denominator(1000)
+ v_out = out.add_stream(vcodec, rate=fps_fraction)
+
+ # PyAV type stubs don't distinguish video streams from audio/subtitle streams.
+ v_out.width = v_in.codec_context.width
+ v_out.height = v_in.codec_context.height
+ v_out.pix_fmt = pix_fmt
+
+ # Set time_base to match the frame rate for proper timestamp handling.
+ v_out.time_base = Fraction(1, int(fps))
+
+ out.start_encoding()
+
+ # Create set of (start, end) ranges for fast lookup.
+ # Convert to a sorted list for efficient checking.
+ time_ranges = sorted(episodes_to_keep)
+
+ # Track frame index for setting PTS and current range being processed.
+ frame_count = 0
+ range_idx = 0
+
+ # Read through entire video once and filter frames.
+ for packet in in_container.demux(v_in):
+ for frame in packet.decode():
+ if frame is None:
+ continue
+
+ # Get frame timestamp.
+ frame_time = float(frame.pts * frame.time_base) if frame.pts is not None else 0.0
+
+ # Check if frame is in any of our desired time ranges.
+ # Skip ranges that have already passed.
+ while range_idx < len(time_ranges) and frame_time >= time_ranges[range_idx][1]:
+ range_idx += 1
+
+ # If we've passed all ranges, stop processing.
+ if range_idx >= len(time_ranges):
+ break
+
+ # Check if frame is in current range.
+ start_ts, end_ts = time_ranges[range_idx]
+ if frame_time < start_ts:
+ continue
+
+ # Frame is in range - create a new frame with reset timestamps.
+ # We need to create a copy to avoid modifying the original.
+ new_frame = frame.reformat(width=v_out.width, height=v_out.height, format=v_out.pix_fmt)
+ new_frame.pts = frame_count
+ new_frame.time_base = Fraction(1, int(fps))
+
+ # Encode and mux the frame.
+ for pkt in v_out.encode(new_frame):
+ out.mux(pkt)
+
+ frame_count += 1
+
+ # Flush encoder.
+ for pkt in v_out.encode():
+ out.mux(pkt)
+
+ out.close()
+ in_container.close()
+
+
+def _copy_and_reindex_videos(
+ src_dataset: LeRobotDataset,
+ dst_meta: LeRobotDatasetMetadata,
+ episode_mapping: dict[int, int],
+ vcodec: str = "libsvtav1",
+ pix_fmt: str = "yuv420p",
+) -> dict[int, dict]:
+ """Copy and filter video files, only re-encoding files with deleted episodes.
+
+ For video files that only contain kept episodes, we copy them directly.
+ For files with mixed kept/deleted episodes, we use PyAV filters to efficiently
+ re-encode only the desired segments.
+
+ Args:
+ src_dataset: Source dataset to copy from
+ dst_meta: Destination metadata object
+ episode_mapping: Mapping from old episode indices to new indices
+
+ Returns:
+ dict mapping episode index to its video metadata (chunk_index, file_index, timestamps)
+ """
+ if src_dataset.meta.episodes is None:
+ src_dataset.meta.episodes = load_episodes(src_dataset.meta.root)
+
+ episodes_video_metadata: dict[int, dict] = {new_idx: {} for new_idx in episode_mapping.values()}
+
+ for video_key in src_dataset.meta.video_keys:
+ logging.info(f"Processing videos for {video_key}")
+
+ if dst_meta.video_path is None:
+ raise ValueError("Destination metadata has no video_path defined")
+
+ file_to_episodes: dict[tuple[int, int], list[int]] = {}
+ for old_idx in episode_mapping:
+ src_ep = src_dataset.meta.episodes[old_idx]
+ chunk_idx = src_ep[f"videos/{video_key}/chunk_index"]
+ file_idx = src_ep[f"videos/{video_key}/file_index"]
+ file_key = (chunk_idx, file_idx)
+ if file_key not in file_to_episodes:
+ file_to_episodes[file_key] = []
+ file_to_episodes[file_key].append(old_idx)
+
+ for (src_chunk_idx, src_file_idx), episodes_in_file in tqdm(
+ sorted(file_to_episodes.items()), desc=f"Processing {video_key} video files"
+ ):
+ all_episodes_in_file = [
+ ep_idx
+ for ep_idx in range(src_dataset.meta.total_episodes)
+ if src_dataset.meta.episodes[ep_idx].get(f"videos/{video_key}/chunk_index") == src_chunk_idx
+ and src_dataset.meta.episodes[ep_idx].get(f"videos/{video_key}/file_index") == src_file_idx
+ ]
+
+ episodes_to_keep_set = set(episodes_in_file)
+ all_in_file_set = set(all_episodes_in_file)
+
+ if all_in_file_set == episodes_to_keep_set:
+ assert src_dataset.meta.video_path is not None
+ src_video_path = src_dataset.root / src_dataset.meta.video_path.format(
+ video_key=video_key, chunk_index=src_chunk_idx, file_index=src_file_idx
+ )
+ dst_video_path = dst_meta.root / dst_meta.video_path.format(
+ video_key=video_key, chunk_index=src_chunk_idx, file_index=src_file_idx
+ )
+ dst_video_path.parent.mkdir(parents=True, exist_ok=True)
+ shutil.copy(src_video_path, dst_video_path)
+
+ for old_idx in episodes_in_file:
+ new_idx = episode_mapping[old_idx]
+ src_ep = src_dataset.meta.episodes[old_idx]
+ episodes_video_metadata[new_idx][f"videos/{video_key}/chunk_index"] = src_chunk_idx
+ episodes_video_metadata[new_idx][f"videos/{video_key}/file_index"] = src_file_idx
+ episodes_video_metadata[new_idx][f"videos/{video_key}/from_timestamp"] = src_ep[
+ f"videos/{video_key}/from_timestamp"
+ ]
+ episodes_video_metadata[new_idx][f"videos/{video_key}/to_timestamp"] = src_ep[
+ f"videos/{video_key}/to_timestamp"
+ ]
+ else:
+ # Build list of time ranges to keep, in sorted order.
+ sorted_keep_episodes = sorted(episodes_in_file, key=lambda x: episode_mapping[x])
+ episodes_to_keep_ranges: list[tuple[float, float]] = []
+
+ for old_idx in sorted_keep_episodes:
+ src_ep = src_dataset.meta.episodes[old_idx]
+ from_ts = src_ep[f"videos/{video_key}/from_timestamp"]
+ to_ts = src_ep[f"videos/{video_key}/to_timestamp"]
+ episodes_to_keep_ranges.append((from_ts, to_ts))
+
+ # Use PyAV filters to efficiently re-encode only the desired segments.
+ assert src_dataset.meta.video_path is not None
+ src_video_path = src_dataset.root / src_dataset.meta.video_path.format(
+ video_key=video_key, chunk_index=src_chunk_idx, file_index=src_file_idx
+ )
+ dst_video_path = dst_meta.root / dst_meta.video_path.format(
+ video_key=video_key, chunk_index=src_chunk_idx, file_index=src_file_idx
+ )
+ dst_video_path.parent.mkdir(parents=True, exist_ok=True)
+
+ logging.info(
+ f"Re-encoding {video_key} (chunk {src_chunk_idx}, file {src_file_idx}) "
+ f"with {len(episodes_to_keep_ranges)} episodes"
+ )
+ _keep_episodes_from_video_with_av(
+ src_video_path,
+ dst_video_path,
+ episodes_to_keep_ranges,
+ src_dataset.meta.fps,
+ vcodec,
+ pix_fmt,
+ )
+
+ cumulative_ts = 0.0
+ for old_idx in sorted_keep_episodes:
+ new_idx = episode_mapping[old_idx]
+ src_ep = src_dataset.meta.episodes[old_idx]
+ ep_length = src_ep["length"]
+ ep_duration = ep_length / src_dataset.meta.fps
+
+ episodes_video_metadata[new_idx][f"videos/{video_key}/chunk_index"] = src_chunk_idx
+ episodes_video_metadata[new_idx][f"videos/{video_key}/file_index"] = src_file_idx
+ episodes_video_metadata[new_idx][f"videos/{video_key}/from_timestamp"] = cumulative_ts
+ episodes_video_metadata[new_idx][f"videos/{video_key}/to_timestamp"] = (
+ cumulative_ts + ep_duration
+ )
+
+ cumulative_ts += ep_duration
+
+ return episodes_video_metadata
+
+
+def _copy_and_reindex_episodes_metadata(
+ src_dataset: LeRobotDataset,
+ dst_meta: LeRobotDatasetMetadata,
+ episode_mapping: dict[int, int],
+ data_metadata: dict[int, dict],
+ video_metadata: dict[int, dict] | None = None,
+) -> None:
+ """Copy and reindex episodes metadata using provided data and video metadata.
+
+ Args:
+ src_dataset: Source dataset to copy from
+ dst_meta: Destination metadata object
+ episode_mapping: Mapping from old episode indices to new indices
+ data_metadata: Dict mapping new episode index to its data file metadata
+ video_metadata: Optional dict mapping new episode index to its video metadata
+ """
+ from lerobot.datasets.utils import flatten_dict
+
+ if src_dataset.meta.episodes is None:
+ src_dataset.meta.episodes = load_episodes(src_dataset.meta.root)
+
+ all_stats = []
+ total_frames = 0
+
+ for old_idx, new_idx in tqdm(
+ sorted(episode_mapping.items(), key=lambda x: x[1]), desc="Processing episodes metadata"
+ ):
+ src_episode_full = _load_episode_with_stats(src_dataset, old_idx)
+
+ src_episode = src_dataset.meta.episodes[old_idx]
+
+ episode_meta = data_metadata[new_idx].copy()
+
+ if video_metadata and new_idx in video_metadata:
+ episode_meta.update(video_metadata[new_idx])
+
+ # Extract episode statistics from parquet metadata.
+ # Note (maractingi): When pandas/pyarrow serializes numpy arrays with shape (3, 1, 1) to parquet,
+ # they are being deserialized as nested object arrays like:
+ # array([array([array([0.])]), array([array([0.])]), array([array([0.])])])
+ # This happens particularly with image/video statistics. We need to detect and flatten
+ # these nested structures back to proper (3, 1, 1) arrays so aggregate_stats can process them.
+ episode_stats = {}
+ for key in src_episode_full:
+ if key.startswith("stats/"):
+ stat_key = key.replace("stats/", "")
+ parts = stat_key.split("/")
+ if len(parts) == 2:
+ feature_name, stat_name = parts
+ if feature_name not in episode_stats:
+ episode_stats[feature_name] = {}
+
+ value = src_episode_full[key]
+
+ if feature_name in src_dataset.meta.features:
+ feature_dtype = src_dataset.meta.features[feature_name]["dtype"]
+ if feature_dtype in ["image", "video"] and stat_name != "count":
+ if isinstance(value, np.ndarray) and value.dtype == object:
+ flat_values = []
+ for item in value:
+ while isinstance(item, np.ndarray):
+ item = item.flatten()[0]
+ flat_values.append(item)
+ value = np.array(flat_values, dtype=np.float64).reshape(3, 1, 1)
+ elif isinstance(value, np.ndarray) and value.shape == (3,):
+ value = value.reshape(3, 1, 1)
+
+ episode_stats[feature_name][stat_name] = value
+
+ all_stats.append(episode_stats)
+
+ episode_dict = {
+ "episode_index": new_idx,
+ "tasks": src_episode["tasks"],
+ "length": src_episode["length"],
+ }
+ episode_dict.update(episode_meta)
+ episode_dict.update(flatten_dict({"stats": episode_stats}))
+ dst_meta._save_episode_metadata(episode_dict)
+
+ total_frames += src_episode["length"]
+
+ dst_meta._close_writer()
+
+ dst_meta.info.update(
+ {
+ "total_episodes": len(episode_mapping),
+ "total_frames": total_frames,
+ "total_tasks": len(dst_meta.tasks) if dst_meta.tasks is not None else 0,
+ "splits": {"train": f"0:{len(episode_mapping)}"},
+ }
+ )
+ write_info(dst_meta.info, dst_meta.root)
+
+ if not all_stats:
+ logging.warning("No statistics found to aggregate")
+ return
+
+ logging.info(f"Aggregating statistics for {len(all_stats)} episodes")
+ aggregated_stats = aggregate_stats(all_stats)
+ filtered_stats = {k: v for k, v in aggregated_stats.items() if k in dst_meta.features}
+ write_stats(filtered_stats, dst_meta.root)
+
+
+def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) -> None:
+ """Write DataFrame to parquet
+
+ This ensures images are properly embedded and the file can be loaded correctly by HF datasets.
+ """
+ from lerobot.datasets.utils import embed_images, get_hf_features_from_features
+
+ hf_features = get_hf_features_from_features(meta.features)
+ ep_dataset = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=hf_features, split="train")
+
+ if len(meta.image_keys) > 0:
+ ep_dataset = embed_images(ep_dataset)
+
+ table = ep_dataset.with_format("arrow")[:]
+ writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy", use_dictionary=True)
+ writer.write_table(table)
+ writer.close()
+
+
+def _save_data_chunk(
+ df: pd.DataFrame,
+ meta: LeRobotDatasetMetadata,
+ chunk_idx: int = 0,
+ file_idx: int = 0,
+) -> tuple[int, int, dict[int, dict]]:
+ """Save a data chunk and return updated indices and episode metadata.
+
+ Returns:
+ tuple: (next_chunk_idx, next_file_idx, episode_metadata_dict)
+ where episode_metadata_dict maps episode_index to its data file metadata
+ """
+ path = meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
+ path.parent.mkdir(parents=True, exist_ok=True)
+
+ _write_parquet(df, path, meta)
+
+ episode_metadata = {}
+ for ep_idx in df["episode_index"].unique():
+ ep_df = df[df["episode_index"] == ep_idx]
+ episode_metadata[ep_idx] = {
+ "data/chunk_index": chunk_idx,
+ "data/file_index": file_idx,
+ "dataset_from_index": int(ep_df["index"].min()),
+ "dataset_to_index": int(ep_df["index"].max() + 1),
+ }
+
+ file_size = get_parquet_file_size_in_mb(path)
+ if file_size >= DEFAULT_DATA_FILE_SIZE_IN_MB * 0.9:
+ chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
+
+ return chunk_idx, file_idx, episode_metadata
+
+
+def _copy_data_with_feature_changes(
+ dataset: LeRobotDataset,
+ new_meta: LeRobotDatasetMetadata,
+ add_features: dict[str, tuple] | None = None,
+ remove_features: list[str] | None = None,
+) -> None:
+ """Copy data while adding or removing features."""
+ if dataset.meta.episodes is None:
+ dataset.meta.episodes = load_episodes(dataset.meta.root)
+
+ # Map file paths to episode indices to extract chunk/file indices
+ file_to_episodes: dict[Path, set[int]] = {}
+ for ep_idx in range(dataset.meta.total_episodes):
+ file_path = dataset.meta.get_data_file_path(ep_idx)
+ if file_path not in file_to_episodes:
+ file_to_episodes[file_path] = set()
+ file_to_episodes[file_path].add(ep_idx)
+
+ frame_idx = 0
+
+ for src_path in tqdm(sorted(file_to_episodes.keys()), desc="Processing data files"):
+ df = pd.read_parquet(dataset.root / src_path).reset_index(drop=True)
+
+ # Get chunk_idx and file_idx from the source file's first episode
+ episodes_in_file = file_to_episodes[src_path]
+ first_ep_idx = min(episodes_in_file)
+ src_ep = dataset.meta.episodes[first_ep_idx]
+ chunk_idx = src_ep["data/chunk_index"]
+ file_idx = src_ep["data/file_index"]
+
+ if remove_features:
+ df = df.drop(columns=remove_features, errors="ignore")
+
+ if add_features:
+ end_idx = frame_idx + len(df)
+ for feature_name, (values, _) in add_features.items():
+ if callable(values):
+ feature_values = []
+ for _, row in df.iterrows():
+ ep_idx = row["episode_index"]
+ frame_in_ep = row["frame_index"]
+ value = values(row.to_dict(), ep_idx, frame_in_ep)
+ if isinstance(value, np.ndarray) and value.size == 1:
+ value = value.item()
+ feature_values.append(value)
+ df[feature_name] = feature_values
+ else:
+ feature_slice = values[frame_idx:end_idx]
+ if len(feature_slice.shape) > 1 and feature_slice.shape[1] == 1:
+ df[feature_name] = feature_slice.flatten()
+ else:
+ df[feature_name] = feature_slice
+ frame_idx = end_idx
+
+ # Write using the preserved chunk_idx and file_idx from source
+ dst_path = new_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
+ dst_path.parent.mkdir(parents=True, exist_ok=True)
+
+ _write_parquet(df, dst_path, new_meta)
+
+ _copy_episodes_metadata_and_stats(dataset, new_meta)
+
+
+def _copy_videos(
+ src_dataset: LeRobotDataset,
+ dst_meta: LeRobotDatasetMetadata,
+ exclude_keys: list[str] | None = None,
+) -> None:
+ """Copy video files, optionally excluding certain keys."""
+ if exclude_keys is None:
+ exclude_keys = []
+
+ for video_key in src_dataset.meta.video_keys:
+ if video_key in exclude_keys:
+ continue
+
+ video_files = set()
+ for ep_idx in range(len(src_dataset.meta.episodes)):
+ try:
+ video_files.add(src_dataset.meta.get_video_file_path(ep_idx, video_key))
+ except KeyError:
+ continue
+
+ for src_path in tqdm(sorted(video_files), desc=f"Copying {video_key} videos"):
+ dst_path = dst_meta.root / src_path
+ dst_path.parent.mkdir(parents=True, exist_ok=True)
+ shutil.copy(src_dataset.root / src_path, dst_path)
+
+
+def _copy_episodes_metadata_and_stats(
+ src_dataset: LeRobotDataset,
+ dst_meta: LeRobotDatasetMetadata,
+) -> None:
+ """Copy episodes metadata and recalculate stats."""
+ if src_dataset.meta.tasks is not None:
+ write_tasks(src_dataset.meta.tasks, dst_meta.root)
+ dst_meta.tasks = src_dataset.meta.tasks.copy()
+
+ episodes_dir = src_dataset.root / "meta/episodes"
+ dst_episodes_dir = dst_meta.root / "meta/episodes"
+ if episodes_dir.exists():
+ shutil.copytree(episodes_dir, dst_episodes_dir, dirs_exist_ok=True)
+
+ dst_meta.info.update(
+ {
+ "total_episodes": src_dataset.meta.total_episodes,
+ "total_frames": src_dataset.meta.total_frames,
+ "total_tasks": src_dataset.meta.total_tasks,
+ "splits": src_dataset.meta.info.get("splits", {"train": f"0:{src_dataset.meta.total_episodes}"}),
+ }
+ )
+
+ if dst_meta.video_keys and src_dataset.meta.video_keys:
+ for key in dst_meta.video_keys:
+ if key in src_dataset.meta.features:
+ dst_meta.info["features"][key]["info"] = src_dataset.meta.info["features"][key].get(
+ "info", {}
+ )
+
+ write_info(dst_meta.info, dst_meta.root)
+
+ if set(dst_meta.features.keys()) != set(src_dataset.meta.features.keys()):
+ logging.info("Recalculating dataset statistics...")
+ if src_dataset.meta.stats:
+ new_stats = {}
+ for key in dst_meta.features:
+ if key in src_dataset.meta.stats:
+ new_stats[key] = src_dataset.meta.stats[key]
+ write_stats(new_stats, dst_meta.root)
+ else:
+ if src_dataset.meta.stats:
+ write_stats(src_dataset.meta.stats, dst_meta.root)
diff --git a/src/lerobot/datasets/image_writer.py b/src/lerobot/datasets/image_writer.py
index 4a4e1ab05..ee10df6e1 100644
--- a/src/lerobot/datasets/image_writer.py
+++ b/src/lerobot/datasets/image_writer.py
@@ -68,7 +68,30 @@ def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True)
return PIL.Image.fromarray(image_array)
-def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
+def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1):
+ """
+ Saves a NumPy array or PIL Image to a file.
+
+ This function handles both NumPy arrays and PIL Image objects, converting
+ the former to a PIL Image before saving. It includes error handling for
+ the save operation.
+
+ Args:
+ image (np.ndarray | PIL.Image.Image): The image data to save.
+ fpath (Path): The destination file path for the image.
+ compress_level (int, optional): The compression level for the saved
+ image, as used by PIL.Image.save(). Defaults to 1.
+ Refer to: https://github.com/huggingface/lerobot/pull/2135
+ for more details on the default value rationale.
+
+ Raises:
+ TypeError: If the input 'image' is not a NumPy array or a
+ PIL.Image.Image object.
+
+ Side Effects:
+ Prints an error message to the console if the image writing process
+ fails for any reason.
+ """
try:
if isinstance(image, np.ndarray):
img = image_array_to_pil_image(image)
@@ -76,7 +99,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
img = image
else:
raise TypeError(f"Unsupported image type: {type(image)}")
- img.save(fpath)
+ img.save(fpath, compress_level=compress_level)
except Exception as e:
print(f"Error writing image {fpath}: {e}")
diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py
index b661b21b0..ae142c1e8 100644
--- a/src/lerobot/datasets/lerobot_dataset.py
+++ b/src/lerobot/datasets/lerobot_dataset.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
-import gc
import logging
import shutil
import tempfile
@@ -26,6 +25,8 @@ import numpy as np
import packaging.version
import pandas as pd
import PIL.Image
+import pyarrow as pa
+import pyarrow.parquet as pq
import torch
import torch.utils
from huggingface_hub import HfApi, snapshot_download
@@ -46,13 +47,9 @@ from lerobot.datasets.utils import (
embed_images,
flatten_dict,
get_delta_indices,
- get_hf_dataset_cache_dir,
- get_hf_dataset_size_in_mb,
+ get_file_size_in_mb,
get_hf_features_from_features,
- get_parquet_file_size_in_mb,
- get_parquet_num_frames,
get_safe_version,
- get_video_size_in_mb,
hf_transform_to_torch,
is_valid_version,
load_episodes,
@@ -60,7 +57,6 @@ from lerobot.datasets.utils import (
load_nested_dataset,
load_stats,
load_tasks,
- to_parquet_with_hf_images,
update_chunk_file_indices,
validate_episode_buffer,
validate_frame,
@@ -90,10 +86,15 @@ class LeRobotDatasetMetadata:
root: str | Path | None = None,
revision: str | None = None,
force_cache_sync: bool = False,
+ metadata_buffer_size: int = 10,
):
self.repo_id = repo_id
self.revision = revision if revision else CODEBASE_VERSION
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
+ self.writer = None
+ self.latest_episode = None
+ self.metadata_buffer: list[dict] = []
+ self.metadata_buffer_size = metadata_buffer_size
try:
if force_cache_sync:
@@ -107,6 +108,54 @@ class LeRobotDatasetMetadata:
self.pull_from_repo(allow_patterns="meta/")
self.load_metadata()
+ def _flush_metadata_buffer(self) -> None:
+ """Write all buffered episode metadata to parquet file."""
+ if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0:
+ return
+
+ combined_dict = {}
+ for episode_dict in self.metadata_buffer:
+ for key, value in episode_dict.items():
+ if key not in combined_dict:
+ combined_dict[key] = []
+ # Extract value and serialize numpy arrays
+ # because PyArrow's from_pydict function doesn't support numpy arrays
+ val = value[0] if isinstance(value, list) else value
+ combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val)
+
+ first_ep = self.metadata_buffer[0]
+ chunk_idx = first_ep["meta/episodes/chunk_index"][0]
+ file_idx = first_ep["meta/episodes/file_index"][0]
+
+ table = pa.Table.from_pydict(combined_dict)
+
+ if not self.writer:
+ path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx))
+ path.parent.mkdir(parents=True, exist_ok=True)
+ self.writer = pq.ParquetWriter(
+ path, schema=table.schema, compression="snappy", use_dictionary=True
+ )
+
+ self.writer.write_table(table)
+
+ self.latest_episode = self.metadata_buffer[-1]
+ self.metadata_buffer.clear()
+
+ def _close_writer(self) -> None:
+ """Close and cleanup the parquet writer if it exists."""
+ self._flush_metadata_buffer()
+
+ writer = getattr(self, "writer", None)
+ if writer is not None:
+ writer.close()
+ self.writer = None
+
+ def __del__(self):
+ """
+ Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor
+ """
+ self._close_writer()
+
def load_metadata(self):
self.info = load_info(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
@@ -138,6 +187,12 @@ class LeRobotDatasetMetadata:
return packaging.version.parse(self.info["codebase_version"])
def get_data_file_path(self, ep_index: int) -> Path:
+ if self.episodes is None:
+ self.episodes = load_episodes(self.root)
+ if ep_index >= len(self.episodes):
+ raise IndexError(
+ f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
+ )
ep = self.episodes[ep_index]
chunk_idx = ep["data/chunk_index"]
file_idx = ep["data/file_index"]
@@ -145,6 +200,12 @@ class LeRobotDatasetMetadata:
return Path(fpath)
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
+ if self.episodes is None:
+ self.episodes = load_episodes(self.root)
+ if ep_index >= len(self.episodes):
+ raise IndexError(
+ f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
+ )
ep = self.episodes[ep_index]
chunk_idx = ep[f"videos/{vid_key}/chunk_index"]
file_idx = ep[f"videos/{vid_key}/file_index"]
@@ -260,72 +321,75 @@ class LeRobotDatasetMetadata:
write_tasks(self.tasks, self.root)
def _save_episode_metadata(self, episode_dict: dict) -> None:
- """Save episode metadata to a parquet file and update the Hugging Face dataset of episodes metadata.
+ """Buffer episode metadata and write to parquet in batches for efficiency.
- This function processes episodes metadata from a dictionary, converts it into a Hugging Face dataset,
- and saves it as a parquet file. It handles both the creation of new parquet files and the
- updating of existing ones based on size constraints. After saving the metadata, it reloads
- the Hugging Face dataset to ensure it is up-to-date.
+ This function accumulates episode metadata in a buffer and flushes it when the buffer
+ reaches the configured size. This reduces I/O overhead by writing multiple episodes
+ at once instead of one row at a time.
Notes: We both need to update parquet files and HF dataset:
- `pandas` loads parquet file in RAM
- `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk,
or loads directly from pyarrow cache.
"""
- # Convert buffer into HF Dataset
+ # Convert to list format for each value
episode_dict = {key: [value] for key, value in episode_dict.items()}
- ep_dataset = datasets.Dataset.from_dict(episode_dict)
- ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
- df = pd.DataFrame(ep_dataset)
num_frames = episode_dict["length"][0]
- if self.episodes is None:
+ if self.latest_episode is None:
# Initialize indices and frame count for a new dataset made of the first episode data
chunk_idx, file_idx = 0, 0
- df["meta/episodes/chunk_index"] = [chunk_idx]
- df["meta/episodes/file_index"] = [file_idx]
- df["dataset_from_index"] = [0]
- df["dataset_to_index"] = [num_frames]
- else:
- # Retrieve information from the latest parquet file
- latest_ep = self.episodes[-1]
- chunk_idx = latest_ep["meta/episodes/chunk_index"]
- file_idx = latest_ep["meta/episodes/file_index"]
+ if self.episodes is not None and len(self.episodes) > 0:
+ # It means we are resuming recording, so we need to load the latest episode
+ # Update the indices to avoid overwriting the latest episode
+ chunk_idx = self.episodes[-1]["meta/episodes/chunk_index"]
+ file_idx = self.episodes[-1]["meta/episodes/file_index"]
+ latest_num_frames = self.episodes[-1]["dataset_to_index"]
+ episode_dict["dataset_from_index"] = [latest_num_frames]
+ episode_dict["dataset_to_index"] = [latest_num_frames + num_frames]
- latest_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
- latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
-
- if latest_size_in_mb + ep_size_in_mb >= self.data_files_size_in_mb:
- # Size limit is reached, prepare new parquet file
+ # When resuming, move to the next file
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
+ else:
+ episode_dict["dataset_from_index"] = [0]
+ episode_dict["dataset_to_index"] = [num_frames]
+
+ episode_dict["meta/episodes/chunk_index"] = [chunk_idx]
+ episode_dict["meta/episodes/file_index"] = [file_idx]
+ else:
+ chunk_idx = self.latest_episode["meta/episodes/chunk_index"][0]
+ file_idx = self.latest_episode["meta/episodes/file_index"][0]
+
+ latest_path = (
+ self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
+ if self.writer is None
+ else self.writer.where
+ )
+
+ if Path(latest_path).exists():
+ latest_size_in_mb = get_file_size_in_mb(Path(latest_path))
+ latest_num_frames = self.latest_episode["episode_index"][0]
+
+ av_size_per_frame = latest_size_in_mb / latest_num_frames if latest_num_frames > 0 else 0.0
+
+ if latest_size_in_mb + av_size_per_frame * num_frames >= self.data_files_size_in_mb:
+ # Size limit is reached, flush buffer and prepare new parquet file
+ self._flush_metadata_buffer()
+ chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
+ self._close_writer()
# Update the existing pandas dataframe with new row
- df["meta/episodes/chunk_index"] = [chunk_idx]
- df["meta/episodes/file_index"] = [file_idx]
- df["dataset_from_index"] = [latest_ep["dataset_to_index"]]
- df["dataset_to_index"] = [latest_ep["dataset_to_index"] + num_frames]
+ episode_dict["meta/episodes/chunk_index"] = [chunk_idx]
+ episode_dict["meta/episodes/file_index"] = [file_idx]
+ episode_dict["dataset_from_index"] = [self.latest_episode["dataset_to_index"][0]]
+ episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames]
- if latest_size_in_mb + ep_size_in_mb < self.data_files_size_in_mb:
- # Size limit wasnt reached, concatenate latest dataframe with new one
- latest_df = pd.read_parquet(latest_path)
- df = pd.concat([latest_df, df], ignore_index=True)
+ # Add to buffer
+ self.metadata_buffer.append(episode_dict)
+ self.latest_episode = episode_dict
- # Memort optimization
- del latest_df
- gc.collect()
-
- # Write the resulting dataframe from RAM to disk
- path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
- path.parent.mkdir(parents=True, exist_ok=True)
- df.to_parquet(path, index=False)
-
- if self.episodes is not None:
- # Remove the episodes cache directory, necessary to avoid cache bloat
- cached_dir = get_hf_dataset_cache_dir(self.episodes)
- if cached_dir is not None:
- shutil.rmtree(cached_dir)
-
- self.episodes = load_episodes(self.root)
+ if len(self.metadata_buffer) >= self.metadata_buffer_size:
+ self._flush_metadata_buffer()
def save_episode(
self,
@@ -438,6 +502,10 @@ class LeRobotDatasetMetadata:
robot_type: str | None = None,
root: str | Path | None = None,
use_videos: bool = True,
+ metadata_buffer_size: int = 10,
+ chunks_size: int | None = None,
+ data_files_size_in_mb: int | None = None,
+ video_files_size_in_mb: int | None = None,
) -> "LeRobotDatasetMetadata":
"""Creates metadata for a LeRobotDataset."""
obj = cls.__new__(cls)
@@ -452,11 +520,24 @@ class LeRobotDatasetMetadata:
obj.tasks = None
obj.episodes = None
obj.stats = None
- obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, features, use_videos, robot_type)
+ obj.info = create_empty_dataset_info(
+ CODEBASE_VERSION,
+ fps,
+ features,
+ use_videos,
+ robot_type,
+ chunks_size,
+ data_files_size_in_mb,
+ video_files_size_in_mb,
+ )
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError()
write_json(obj.info, obj.root / INFO_PATH)
obj.revision = None
+ obj.writer = None
+ obj.latest_episode = None
+ obj.metadata_buffer = []
+ obj.metadata_buffer_size = metadata_buffer_size
return obj
@@ -603,6 +684,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Unused attributes
self.image_writer = None
self.episode_buffer = None
+ self.writer = None
+ self.latest_episode = None
self.root.mkdir(exist_ok=True, parents=True)
@@ -611,6 +694,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
)
+ # Track dataset state for efficient incremental writing
+ self._lazy_loading = False
+ self._recorded_frames = self.meta.total_frames
+ self._writer_closed_for_reading = False
+
# Load actual data
try:
if force_cache_sync:
@@ -629,6 +717,19 @@ class LeRobotDataset(torch.utils.data.Dataset):
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
+ def _close_writer(self) -> None:
+ """Close and cleanup the parquet writer if it exists."""
+ writer = getattr(self, "writer", None)
+ if writer is not None:
+ writer.close()
+ self.writer = None
+
+ def __del__(self):
+ """
+ Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor
+ """
+ self._close_writer()
+
def push_to_hub(
self,
branch: str | None = None,
@@ -769,8 +870,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
@property
def num_frames(self) -> int:
- """Number of frames in selected episodes."""
- return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames
+ """Number of frames in selected episodes.
+
+ Note: When episodes a subset of the full dataset is requested, we must return the
+ actual loaded data length (len(self.hf_dataset)) rather than metadata total_frames.
+ self.meta.total_frames is the total number of frames in the full dataset.
+ """
+ if self.episodes is not None and self.hf_dataset is not None:
+ return len(self.hf_dataset)
+ return self.meta.total_frames
@property
def num_episodes(self) -> int:
@@ -848,10 +956,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
return item
+ def _ensure_hf_dataset_loaded(self):
+ """Lazy load the HF dataset only when needed for reading."""
+ if self._lazy_loading or self.hf_dataset is None:
+ # Close the writer before loading to ensure parquet file is properly finalized
+ if self.writer is not None:
+ self._close_writer()
+ self._writer_closed_for_reading = True
+ self.hf_dataset = self.load_hf_dataset()
+ self._lazy_loading = False
+
def __len__(self):
return self.num_frames
def __getitem__(self, idx) -> dict:
+ # Ensure dataset is loaded when we actually need to read from it
+ self._ensure_hf_dataset_loaded()
item = self.hf_dataset[idx]
ep_idx = item["episode_index"].item()
@@ -890,6 +1010,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
"})',\n"
)
+ def finalize(self):
+ """
+ Close the parquet writers. This function needs to be called after data collection/conversion, else footer metadata won't be written to the parquet files.
+ The dataset won't be valid and can't be loaded as ds = LeRobotDataset(repo_id=repo, root=HF_LEROBOT_HOME.joinpath(repo))
+ """
+ self._close_writer()
+ self.meta._close_writer()
+
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
ep_buffer = {}
@@ -1097,74 +1225,101 @@ class LeRobotDataset(torch.utils.data.Dataset):
ep_dict = {key: episode_buffer[key] for key in self.hf_features}
ep_dataset = datasets.Dataset.from_dict(ep_dict, features=self.hf_features, split="train")
ep_dataset = embed_images(ep_dataset)
- ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
ep_num_frames = len(ep_dataset)
- df = pd.DataFrame(ep_dataset)
- if self.meta.episodes is None:
+ if self.latest_episode is None:
# Initialize indices and frame count for a new dataset made of the first episode data
chunk_idx, file_idx = 0, 0
- latest_num_frames = 0
+ global_frame_index = 0
+ # However, if the episodes already exists
+ # It means we are resuming recording, so we need to load the latest episode
+ # Update the indices to avoid overwriting the latest episode
+ if self.meta.episodes is not None and len(self.meta.episodes) > 0:
+ latest_ep = self.meta.episodes[-1]
+ global_frame_index = latest_ep["dataset_to_index"]
+ chunk_idx = latest_ep["data/chunk_index"]
+ file_idx = latest_ep["data/file_index"]
+
+ # When resuming, move to the next file
+ chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
else:
# Retrieve information from the latest parquet file
- latest_ep = self.meta.episodes[-1]
+ latest_ep = self.latest_episode
chunk_idx = latest_ep["data/chunk_index"]
file_idx = latest_ep["data/file_index"]
+ global_frame_index = latest_ep["index"][-1] + 1
latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
- latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
- latest_num_frames = get_parquet_num_frames(latest_path)
+ latest_size_in_mb = get_file_size_in_mb(latest_path)
+
+ frames_in_current_file = global_frame_index - latest_ep["dataset_from_index"]
+ av_size_per_frame = (
+ latest_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0
+ )
# Determine if a new parquet file is needed
- if latest_size_in_mb + ep_size_in_mb >= self.meta.data_files_size_in_mb:
- # Size limit is reached, prepare new parquet file
+ if (
+ latest_size_in_mb + av_size_per_frame * ep_num_frames >= self.meta.data_files_size_in_mb
+ or self._writer_closed_for_reading
+ ):
+ # Size limit is reached or writer was closed for reading, prepare new parquet file
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
- latest_num_frames = 0
- else:
- # Update the existing parquet file with new rows
- latest_df = pd.read_parquet(latest_path)
- df = pd.concat([latest_df, df], ignore_index=True)
+ self._close_writer()
+ self._writer_closed_for_reading = False
- # Memort optimization
- del latest_df
- gc.collect()
+ ep_dict["data/chunk_index"] = chunk_idx
+ ep_dict["data/file_index"] = file_idx
# Write the resulting dataframe from RAM to disk
path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
path.parent.mkdir(parents=True, exist_ok=True)
- if len(self.meta.image_keys) > 0:
- to_parquet_with_hf_images(df, path)
- else:
- df.to_parquet(path)
- if self.hf_dataset is not None:
- # Remove hf dataset cache directory, necessary to avoid cache bloat
- cached_dir = get_hf_dataset_cache_dir(self.hf_dataset)
- if cached_dir is not None:
- shutil.rmtree(cached_dir)
-
- self.hf_dataset = self.load_hf_dataset()
+ table = ep_dataset.with_format("arrow")[:]
+ if not self.writer:
+ self.writer = pq.ParquetWriter(
+ path, schema=table.schema, compression="snappy", use_dictionary=True
+ )
+ self.writer.write_table(table)
metadata = {
"data/chunk_index": chunk_idx,
"data/file_index": file_idx,
- "dataset_from_index": latest_num_frames,
- "dataset_to_index": latest_num_frames + ep_num_frames,
+ "dataset_from_index": global_frame_index,
+ "dataset_to_index": global_frame_index + ep_num_frames,
}
+
+ # Store metadata with episode data for next episode
+ self.latest_episode = {**ep_dict, **metadata}
+
+ # Mark that the HF dataset needs reloading (lazy loading approach)
+ # This avoids expensive reloading during sequential recording
+ self._lazy_loading = True
+ # Update recorded frames count for efficient length tracking
+ self._recorded_frames += ep_num_frames
+
return metadata
def _save_episode_video(self, video_key: str, episode_index: int) -> dict:
# Encode episode frames into a temporary video
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
- ep_size_in_mb = get_video_size_in_mb(ep_path)
+ ep_size_in_mb = get_file_size_in_mb(ep_path)
ep_duration_in_s = get_video_duration_in_s(ep_path)
- if self.meta.episodes is None or (
- f"videos/{video_key}/chunk_index" not in self.meta.episodes.column_names
- or f"videos/{video_key}/file_index" not in self.meta.episodes.column_names
+ if (
+ episode_index == 0
+ or self.meta.latest_episode is None
+ or f"videos/{video_key}/chunk_index" not in self.meta.latest_episode
):
# Initialize indices for a new dataset made of the first episode data
chunk_idx, file_idx = 0, 0
+ if self.meta.episodes is not None and len(self.meta.episodes) > 0:
+ # It means we are resuming recording, so we need to load the latest episode
+ # Update the indices to avoid overwriting the latest episode
+ old_chunk_idx = self.meta.episodes[-1][f"videos/{video_key}/chunk_index"]
+ old_file_idx = self.meta.episodes[-1][f"videos/{video_key}/file_index"]
+ chunk_idx, file_idx = update_chunk_file_indices(
+ old_chunk_idx, old_file_idx, self.meta.chunks_size
+ )
latest_duration_in_s = 0.0
new_path = self.root / self.meta.video_path.format(
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
@@ -1172,16 +1327,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
new_path.parent.mkdir(parents=True, exist_ok=True)
shutil.move(str(ep_path), str(new_path))
else:
- # Retrieve information from the latest updated video file (possibly several episodes ago)
- latest_ep = self.meta.episodes[episode_index - 1]
- chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"]
- file_idx = latest_ep[f"videos/{video_key}/file_index"]
+ # Retrieve information from the latest updated video file using latest_episode
+ latest_ep = self.meta.latest_episode
+ chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"][0]
+ file_idx = latest_ep[f"videos/{video_key}/file_index"][0]
latest_path = self.root / self.meta.video_path.format(
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
)
- latest_size_in_mb = get_video_size_in_mb(latest_path)
- latest_duration_in_s = get_video_duration_in_s(latest_path)
+ latest_size_in_mb = get_file_size_in_mb(latest_path)
+ latest_duration_in_s = latest_ep[f"videos/{video_key}/to_timestamp"][0]
if latest_size_in_mb + ep_size_in_mb >= self.meta.video_files_size_in_mb:
# Move temporary episode video to a new video file in the dataset
@@ -1315,6 +1470,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.delta_timestamps = None
obj.delta_indices = None
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
+ obj.writer = None
+ obj.latest_episode = None
+ # Initialize tracking for incremental recording
+ obj._lazy_loading = False
+ obj._recorded_frames = 0
+ obj._writer_closed_for_reading = False
return obj
diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py
index a2f285014..37d8432b2 100644
--- a/src/lerobot/datasets/utils.py
+++ b/src/lerobot/datasets/utils.py
@@ -30,7 +30,7 @@ import pandas
import pandas as pd
import pyarrow.parquet as pq
import torch
-from datasets import Dataset, concatenate_datasets
+from datasets import Dataset
from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from huggingface_hub.errors import RevisionNotFoundError
@@ -44,7 +44,7 @@ from lerobot.datasets.backward_compatibility import (
ForwardCompatibilityError,
)
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR
-from lerobot.utils.utils import is_valid_numpy_dtype_string
+from lerobot.utils.utils import SuppressProgressBars, is_valid_numpy_dtype_string
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
@@ -94,12 +94,6 @@ def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int:
return hf_ds.data.nbytes // (1024**2)
-def get_hf_dataset_cache_dir(hf_ds: Dataset) -> Path | None:
- if hf_ds.cache_files is None or len(hf_ds.cache_files) == 0:
- return None
- return Path(hf_ds.cache_files[0]["filename"]).parents[2]
-
-
def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]:
if file_idx == chunks_size - 1:
file_idx = 0
@@ -123,8 +117,9 @@ def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None)
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
# TODO(rcadene): set num_proc to accelerate conversion to pyarrow
- datasets = [Dataset.from_parquet(str(path), features=features) for path in paths]
- return concatenate_datasets(datasets)
+ with SuppressProgressBars():
+ datasets = Dataset.from_parquet([str(path) for path in paths], features=features)
+ return datasets
def get_parquet_num_frames(parquet_path: str | Path) -> int:
@@ -132,10 +127,14 @@ def get_parquet_num_frames(parquet_path: str | Path) -> int:
return metadata.num_rows
-def get_video_size_in_mb(mp4_path: Path) -> float:
- file_size_bytes = mp4_path.stat().st_size
- file_size_mb = file_size_bytes / (1024**2)
- return file_size_mb
+def get_file_size_in_mb(file_path: Path) -> float:
+ """Get file size on disk in megabytes.
+
+ Args:
+ file_path (Path): Path to the file.
+ """
+ file_size_bytes = file_path.stat().st_size
+ return file_size_bytes / (1024**2)
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
diff --git a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py
index 42ab2f642..b8ae29ad6 100644
--- a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py
+++ b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py
@@ -69,9 +69,9 @@ from lerobot.datasets.utils import (
LEGACY_TASKS_PATH,
cast_stats_to_numpy,
flatten_dict,
+ get_file_size_in_mb,
get_parquet_file_size_in_mb,
get_parquet_num_frames,
- get_video_size_in_mb,
load_info,
update_chunk_file_indices,
write_episodes,
@@ -310,7 +310,7 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f
episodes_metadata = []
for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"):
- ep_size_in_mb = get_video_size_in_mb(ep_path)
+ ep_size_in_mb = get_file_size_in_mb(ep_path)
ep_duration_in_s = get_video_duration_in_s(ep_path)
# Check if adding this episode would exceed the limit
diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py
index 1d4f07c76..740cdb602 100644
--- a/src/lerobot/datasets/video_utils.py
+++ b/src/lerobot/datasets/video_utils.py
@@ -452,6 +452,9 @@ def concatenate_video_files(
template=input_stream, opaque=True
)
+ # set the time base to the input stream time base (missing in the codec context)
+ stream_map[input_stream.index].time_base = input_stream.time_base
+
# Demux + remux packets (no re-encode)
for packet in input_container.demux():
# Skip packets from un-mapped streams
@@ -639,6 +642,9 @@ class VideoEncodingManager:
)
self.dataset._batch_save_episode_video(start_ep, end_ep)
+ # Finalize the dataset to properly close all writers
+ self.dataset.finalize()
+
# Clean up episode images if recording was interrupted
if exc_type is not None:
interrupted_episode_index = self.dataset.num_episodes
diff --git a/src/lerobot/envs/__init__.py b/src/lerobot/envs/__init__.py
index 4977d11d9..d767b6e8c 100644
--- a/src/lerobot/envs/__init__.py
+++ b/src/lerobot/envs/__init__.py
@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv # noqa: F401
+from .configs import AlohaEnv, EnvConfig, PushtEnv # noqa: F401
diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py
index 0daaaf9fd..3aa155093 100644
--- a/src/lerobot/envs/configs.py
+++ b/src/lerobot/envs/configs.py
@@ -50,6 +50,8 @@ class AlohaEnv(EnvConfig):
fps: int = 50
episode_length: int = 400
obs_type: str = "pixels_agent_pos"
+ observation_height: int = 480
+ observation_width: int = 640
render_mode: str = "rgb_array"
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
@@ -67,10 +69,14 @@ class AlohaEnv(EnvConfig):
def __post_init__(self):
if self.obs_type == "pixels":
- self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
+ self.features["top"] = PolicyFeature(
+ type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
+ )
elif self.obs_type == "pixels_agent_pos":
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
- self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
+ self.features["pixels/top"] = PolicyFeature(
+ type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
+ )
@property
def gym_kwargs(self) -> dict:
@@ -91,6 +97,8 @@ class PushtEnv(EnvConfig):
render_mode: str = "rgb_array"
visualization_width: int = 384
visualization_height: int = 384
+ observation_height: int = 384
+ observation_width: int = 384
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
@@ -108,7 +116,9 @@ class PushtEnv(EnvConfig):
def __post_init__(self):
if self.obs_type == "pixels_agent_pos":
- self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3))
+ self.features["pixels"] = PolicyFeature(
+ type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
+ )
elif self.obs_type == "environment_state_agent_pos":
self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))
@@ -123,45 +133,6 @@ class PushtEnv(EnvConfig):
}
-@EnvConfig.register_subclass("xarm")
-@dataclass
-class XarmEnv(EnvConfig):
- task: str | None = "XarmLift-v0"
- fps: int = 15
- episode_length: int = 200
- obs_type: str = "pixels_agent_pos"
- render_mode: str = "rgb_array"
- visualization_width: int = 384
- visualization_height: int = 384
- features: dict[str, PolicyFeature] = field(
- default_factory=lambda: {
- ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
- "pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)),
- }
- )
- features_map: dict[str, str] = field(
- default_factory=lambda: {
- ACTION: ACTION,
- "agent_pos": OBS_STATE,
- "pixels": OBS_IMAGE,
- }
- )
-
- def __post_init__(self):
- if self.obs_type == "pixels_agent_pos":
- self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
-
- @property
- def gym_kwargs(self) -> dict:
- return {
- "obs_type": self.obs_type,
- "render_mode": self.render_mode,
- "visualization_width": self.visualization_width,
- "visualization_height": self.visualization_height,
- "max_episode_steps": self.episode_length,
- }
-
-
@dataclass
class ImagePreprocessingConfig:
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
@@ -255,6 +226,8 @@ class LiberoEnv(EnvConfig):
camera_name: str = "agentview_image,robot0_eye_in_hand_image"
init_states: bool = True
camera_name_mapping: dict[str, str] | None = None
+ observation_height: int = 360
+ observation_width: int = 360
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
@@ -272,18 +245,18 @@ class LiberoEnv(EnvConfig):
def __post_init__(self):
if self.obs_type == "pixels":
self.features["pixels/agentview_image"] = PolicyFeature(
- type=FeatureType.VISUAL, shape=(360, 360, 3)
+ type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
- type=FeatureType.VISUAL, shape=(360, 360, 3)
+ type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
elif self.obs_type == "pixels_agent_pos":
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(8,))
self.features["pixels/agentview_image"] = PolicyFeature(
- type=FeatureType.VISUAL, shape=(360, 360, 3)
+ type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
- type=FeatureType.VISUAL, shape=(360, 360, 3)
+ type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
else:
raise ValueError(f"Unsupported obs_type: {self.obs_type}")
@@ -294,3 +267,45 @@ class LiberoEnv(EnvConfig):
"obs_type": self.obs_type,
"render_mode": self.render_mode,
}
+
+
+@EnvConfig.register_subclass("metaworld")
+@dataclass
+class MetaworldEnv(EnvConfig):
+ task: str = "metaworld-push-v2" # add all tasks
+ fps: int = 80
+ episode_length: int = 400
+ obs_type: str = "pixels_agent_pos"
+ render_mode: str = "rgb_array"
+ multitask_eval: bool = True
+ features: dict[str, PolicyFeature] = field(
+ default_factory=lambda: {
+ "action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
+ }
+ )
+ features_map: dict[str, str] = field(
+ default_factory=lambda: {
+ "action": ACTION,
+ "agent_pos": OBS_STATE,
+ "top": f"{OBS_IMAGE}",
+ "pixels/top": f"{OBS_IMAGE}",
+ }
+ )
+
+ def __post_init__(self):
+ if self.obs_type == "pixels":
+ self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 480, 3))
+
+ elif self.obs_type == "pixels_agent_pos":
+ self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
+ self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 480, 3))
+
+ else:
+ raise ValueError(f"Unsupported obs_type: {self.obs_type}")
+
+ @property
+ def gym_kwargs(self) -> dict:
+ return {
+ "obs_type": self.obs_type,
+ "render_mode": self.render_mode,
+ }
diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py
index c27f01b65..059e0e11a 100644
--- a/src/lerobot/envs/factory.py
+++ b/src/lerobot/envs/factory.py
@@ -17,7 +17,7 @@ import importlib
import gymnasium as gym
-from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv, XarmEnv
+from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
@@ -25,8 +25,6 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
return AlohaEnv(**kwargs)
elif env_type == "pusht":
return PushtEnv(**kwargs)
- elif env_type == "xarm":
- return XarmEnv(**kwargs)
elif env_type == "libero":
return LiberoEnv(**kwargs)
else:
@@ -74,7 +72,18 @@ def make_env(
gym_kwargs=cfg.gym_kwargs,
env_cls=env_cls,
)
+ elif "metaworld" in cfg.type:
+ from lerobot.envs.metaworld import create_metaworld_envs
+ if cfg.task is None:
+ raise ValueError("MetaWorld requires a task to be specified")
+
+ return create_metaworld_envs(
+ task=cfg.task,
+ n_envs=n_envs,
+ gym_kwargs=cfg.gym_kwargs,
+ env_cls=env_cls,
+ )
package_name = f"gym_{cfg.type}"
try:
importlib.import_module(package_name)
@@ -87,7 +96,7 @@ def make_env(
def _make_one():
return gym.make(gym_handle, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
- vec = env_cls([_make_one for _ in range(n_envs)])
+ vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)
# normalize to {suite: {task_id: vec_env}} for consistency
suite_name = cfg.type # e.g., "pusht", "aloha"
diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py
index 99ec6712f..94b08e991 100644
--- a/src/lerobot/envs/libero.py
+++ b/src/lerobot/envs/libero.py
@@ -260,19 +260,23 @@ class LiberoEnv(gym.Env):
is_success = self._env.check_success()
terminated = done or is_success
- info["is_success"] = is_success
-
+ info.update(
+ {
+ "task": self.task,
+ "task_id": self.task_id,
+ "done": done,
+ "is_success": is_success,
+ }
+ )
observation = self._format_raw_obs(raw_obs)
- if done:
+ if terminated:
+ info["final_info"] = {
+ "task": self.task,
+ "task_id": self.task_id,
+ "done": bool(done),
+ "is_success": bool(is_success),
+ }
self.reset()
- info.update(
- {
- "task": self.task,
- "task_id": self.task_id,
- "done": done,
- "is_success": is_success,
- }
- )
truncated = False
return observation, reward, terminated, truncated, info
diff --git a/src/lerobot/envs/metaworld.py b/src/lerobot/envs/metaworld.py
new file mode 100644
index 000000000..9190f33ad
--- /dev/null
+++ b/src/lerobot/envs/metaworld.py
@@ -0,0 +1,313 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+from collections import defaultdict
+from collections.abc import Callable, Sequence
+from pathlib import Path
+from typing import Any
+
+import gymnasium as gym
+import metaworld
+import metaworld.policies as policies
+import numpy as np
+from gymnasium import spaces
+
+# ---- Load configuration data from the external JSON file ----
+CONFIG_PATH = Path(__file__).parent / "metaworld_config.json"
+try:
+ with open(CONFIG_PATH) as f:
+ data = json.load(f)
+except FileNotFoundError as err:
+ raise FileNotFoundError(
+ "Could not find 'metaworld_config.json'. "
+ "Please ensure the configuration file is in the same directory as the script."
+ ) from err
+except json.JSONDecodeError as err:
+ raise ValueError(
+ "Failed to decode 'metaworld_config.json'. Please ensure it is a valid JSON file."
+ ) from err
+
+# ---- Process the loaded data ----
+
+# extract and type-check top-level dicts
+task_descriptions_obj = data.get("TASK_DESCRIPTIONS")
+if not isinstance(task_descriptions_obj, dict):
+ raise TypeError("Expected TASK_DESCRIPTIONS to be a dict[str, str]")
+TASK_DESCRIPTIONS: dict[str, str] = task_descriptions_obj
+
+task_name_to_id_obj = data.get("TASK_NAME_TO_ID")
+if not isinstance(task_name_to_id_obj, dict):
+ raise TypeError("Expected TASK_NAME_TO_ID to be a dict[str, int]")
+TASK_NAME_TO_ID: dict[str, int] = task_name_to_id_obj
+
+# difficulty -> tasks mapping
+difficulty_to_tasks = data.get("DIFFICULTY_TO_TASKS")
+if not isinstance(difficulty_to_tasks, dict):
+ raise TypeError("Expected 'DIFFICULTY_TO_TASKS' to be a dict[str, list[str]]")
+DIFFICULTY_TO_TASKS: dict[str, list[str]] = difficulty_to_tasks
+
+# convert policy strings -> actual policy classes
+task_policy_mapping = data.get("TASK_POLICY_MAPPING")
+if not isinstance(task_policy_mapping, dict):
+ raise TypeError("Expected 'TASK_POLICY_MAPPING' to be a dict[str, str]")
+TASK_POLICY_MAPPING: dict[str, Any] = {
+ task_name: getattr(policies, policy_class_name)
+ for task_name, policy_class_name in task_policy_mapping.items()
+}
+ACTION_DIM = 4
+OBS_DIM = 4
+
+
+class MetaworldEnv(gym.Env):
+ metadata = {"render_modes": ["rgb_array"], "render_fps": 80}
+
+ def __init__(
+ self,
+ task,
+ camera_name="corner2",
+ obs_type="pixels",
+ render_mode="rgb_array",
+ observation_width=480,
+ observation_height=480,
+ visualization_width=640,
+ visualization_height=480,
+ ):
+ super().__init__()
+ self.task = task.replace("metaworld-", "")
+ self.obs_type = obs_type
+ self.render_mode = render_mode
+ self.observation_width = observation_width
+ self.observation_height = observation_height
+ self.visualization_width = visualization_width
+ self.visualization_height = visualization_height
+ self.camera_name = camera_name
+
+ self._env = self._make_envs_task(self.task)
+ self._max_episode_steps = self._env.max_path_length
+ self.task_description = TASK_DESCRIPTIONS[self.task]
+
+ self.expert_policy = TASK_POLICY_MAPPING[self.task]()
+
+ if self.obs_type == "state":
+ raise NotImplementedError()
+ elif self.obs_type == "pixels":
+ self.observation_space = spaces.Dict(
+ {
+ "pixels": spaces.Box(
+ low=0,
+ high=255,
+ shape=(self.observation_height, self.observation_width, 3),
+ dtype=np.uint8,
+ )
+ }
+ )
+ elif self.obs_type == "pixels_agent_pos":
+ self.observation_space = spaces.Dict(
+ {
+ "pixels": spaces.Box(
+ low=0,
+ high=255,
+ shape=(self.observation_height, self.observation_width, 3),
+ dtype=np.uint8,
+ ),
+ "agent_pos": spaces.Box(
+ low=-1000.0,
+ high=1000.0,
+ shape=(OBS_DIM,),
+ dtype=np.float64,
+ ),
+ }
+ )
+
+ self.action_space = spaces.Box(low=-1, high=1, shape=(ACTION_DIM,), dtype=np.float32)
+
+ def render(self) -> np.ndarray:
+ """
+ Render the current environment frame.
+
+ Returns:
+ np.ndarray: The rendered RGB image from the environment.
+ """
+ image = self._env.render()
+ if self.camera_name == "corner2":
+ # Images from this camera are flipped — correct them
+ image = np.flip(image, (0, 1))
+ return image
+
+ def _make_envs_task(self, env_name: str):
+ mt1 = metaworld.MT1(env_name, seed=42)
+ env = mt1.train_classes[env_name](render_mode="rgb_array", camera_name=self.camera_name)
+ env.set_task(mt1.train_tasks[0])
+ if self.camera_name == "corner2":
+ env.model.cam_pos[2] = [
+ 0.75,
+ 0.075,
+ 0.7,
+ ] # corner2 position, similar to https://arxiv.org/pdf/2206.14244
+ env.reset()
+ env._freeze_rand_vec = False # otherwise no randomization
+ return env
+
+ def _format_raw_obs(self, raw_obs: np.ndarray) -> dict[str, Any]:
+ image = None
+ if self._env is not None:
+ image = self._env.render()
+ if self.camera_name == "corner2":
+ # NOTE: The "corner2" camera in MetaWorld environments outputs images with both axes inverted.
+ image = np.flip(image, (0, 1))
+ agent_pos = raw_obs[:4]
+ if self.obs_type == "state":
+ raise NotImplementedError(
+ "'state' obs_type not implemented for MetaWorld. Use pixel modes instead."
+ )
+
+ elif self.obs_type in ("pixels", "pixels_agent_pos"):
+ assert image is not None, (
+ "Expected `image` to be rendered before constructing pixel-based observations. "
+ "This likely means `env.render()` returned None or the environment was not provided."
+ )
+
+ if self.obs_type == "pixels":
+ obs = {"pixels": image.copy()}
+
+ else: # pixels_agent_pos
+ obs = {
+ "pixels": image.copy(),
+ "agent_pos": agent_pos,
+ }
+ else:
+ raise ValueError(f"Unknown obs_type: {self.obs_type}")
+ return obs
+
+ def reset(
+ self,
+ seed: int | None = None,
+ **kwargs,
+ ) -> tuple[dict[str, Any], dict[str, Any]]:
+ """
+ Reset the environment to its initial state.
+
+ Args:
+ seed (Optional[int]): Random seed for environment initialization.
+
+ Returns:
+ observation (Dict[str, Any]): The initial formatted observation.
+ info (Dict[str, Any]): Additional info about the reset state.
+ """
+ super().reset(seed=seed)
+
+ raw_obs, info = self._env.reset(seed=seed)
+
+ observation = self._format_raw_obs(raw_obs)
+
+ info = {"is_success": False}
+ return observation, info
+
+ def step(self, action: np.ndarray) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
+ """
+ Perform one environment step.
+
+ Args:
+ action (np.ndarray): The action to execute, must be 1-D with shape (action_dim,).
+
+ Returns:
+ observation (Dict[str, Any]): The formatted observation after the step.
+ reward (float): The scalar reward for this step.
+ terminated (bool): Whether the episode terminated successfully.
+ truncated (bool): Whether the episode was truncated due to a time limit.
+ info (Dict[str, Any]): Additional environment info.
+ """
+ if action.ndim != 1:
+ raise ValueError(
+ f"Expected action to be 1-D (shape (action_dim,)), "
+ f"but got shape {action.shape} with ndim={action.ndim}"
+ )
+ raw_obs, reward, done, truncated, info = self._env.step(action)
+
+ # Determine whether the task was successful
+ is_success = bool(info.get("success", 0))
+ terminated = done or is_success
+ info.update(
+ {
+ "task": self.task,
+ "done": done,
+ "is_success": is_success,
+ }
+ )
+
+ # Format the raw observation into the expected structure
+ observation = self._format_raw_obs(raw_obs)
+ if terminated:
+ info["final_info"] = {
+ "task": self.task,
+ "done": bool(done),
+ "is_success": bool(is_success),
+ }
+ self.reset()
+
+ return observation, reward, terminated, truncated, info
+
+ def close(self):
+ self._env.close()
+
+
+# ---- Main API ----------------------------------------------------------------
+
+
+def create_metaworld_envs(
+ task: str,
+ n_envs: int,
+ gym_kwargs: dict[str, Any] | None = None,
+ env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
+) -> dict[str, dict[int, Any]]:
+ """
+ Create vectorized Meta-World environments with a consistent return shape.
+
+ Returns:
+ dict[task_group][task_id] -> vec_env (env_cls([...]) with exactly n_envs factories)
+ Notes:
+ - n_envs is the number of rollouts *per task* (episode_index = 0..n_envs-1).
+ - `task` can be a single difficulty group (e.g., "easy", "medium", "hard") or a comma-separated list.
+ - If a task name is not in DIFFICULTY_TO_TASKS, we treat it as a single custom task.
+ """
+ if env_cls is None or not callable(env_cls):
+ raise ValueError("env_cls must be a callable that wraps a list of environment factory callables.")
+ if not isinstance(n_envs, int) or n_envs <= 0:
+ raise ValueError(f"n_envs must be a positive int; got {n_envs}.")
+
+ gym_kwargs = dict(gym_kwargs or {})
+ task_groups = [t.strip() for t in task.split(",") if t.strip()]
+ if not task_groups:
+ raise ValueError("`task` must contain at least one Meta-World task or difficulty group.")
+
+ print(f"Creating Meta-World envs | task_groups={task_groups} | n_envs(per task)={n_envs}")
+
+ out: dict[str, dict[int, Any]] = defaultdict(dict)
+
+ for group in task_groups:
+ # if not in difficulty presets, treat it as a single custom task
+ tasks = DIFFICULTY_TO_TASKS.get(group, [group])
+
+ for tid, task_name in enumerate(tasks):
+ print(f"Building vec env | group={group} | task_id={tid} | task={task_name}")
+
+ # build n_envs factories
+ fns = [(lambda tn=task_name: MetaworldEnv(task=tn, **gym_kwargs)) for _ in range(n_envs)]
+
+ out[group][tid] = env_cls(fns)
+
+ # return a plain dict for consistency
+ return {group: dict(task_map) for group, task_map in out.items()}
diff --git a/src/lerobot/envs/metaworld_config.json b/src/lerobot/envs/metaworld_config.json
new file mode 100644
index 000000000..41a417fef
--- /dev/null
+++ b/src/lerobot/envs/metaworld_config.json
@@ -0,0 +1,121 @@
+{
+ "TASK_DESCRIPTIONS": {
+ "assembly-v3": "Pick up a nut and place it onto a peg",
+ "basketball-v3": "Dunk the basketball into the basket",
+ "bin-picking-v3": "Grasp the puck from one bin and place it into another bin",
+ "box-close-v3": "Grasp the cover and close the box with it",
+ "button-press-topdown-v3": "Press a button from the top",
+ "button-press-topdown-wall-v3": "Bypass a wall and press a button from the top",
+ "button-press-v3": "Press a button",
+ "button-press-wall-v3": "Bypass a wall and press a button",
+ "coffee-button-v3": "Push a button on the coffee machine",
+ "coffee-pull-v3": "Pull a mug from a coffee machine",
+ "coffee-push-v3": "Push a mug under a coffee machine",
+ "dial-turn-v3": "Rotate a dial 180 degrees",
+ "disassemble-v3": "Pick a nut out of a peg",
+ "door-close-v3": "Close a door with a revolving joint",
+ "door-lock-v3": "Lock the door by rotating the lock clockwise",
+ "door-open-v3": "Open a door with a revolving joint",
+ "door-unlock-v3": "Unlock the door by rotating the lock counter-clockwise",
+ "hand-insert-v3": "Insert the gripper into a hole",
+ "drawer-close-v3": "Push and close a drawer",
+ "drawer-open-v3": "Open a drawer",
+ "faucet-open-v3": "Rotate the faucet counter-clockwise",
+ "faucet-close-v3": "Rotate the faucet clockwise",
+ "hammer-v3": "Hammer a screw on the wall",
+ "handle-press-side-v3": "Press a handle down sideways",
+ "handle-press-v3": "Press a handle down",
+ "handle-pull-side-v3": "Pull a handle up sideways",
+ "handle-pull-v3": "Pull a handle up",
+ "lever-pull-v3": "Pull a lever down 90 degrees",
+ "peg-insert-side-v3": "Insert a peg sideways",
+ "pick-place-wall-v3": "Pick a puck, bypass a wall and place the puck",
+ "pick-out-of-hole-v3": "Pick up a puck from a hole",
+ "reach-v3": "Reach a goal position",
+ "push-back-v3": "Push the puck to a goal",
+ "push-v3": "Push the puck to a goal",
+ "pick-place-v3": "Pick and place a puck to a goal",
+ "plate-slide-v3": "Slide a plate into a cabinet",
+ "plate-slide-side-v3": "Slide a plate into a cabinet sideways",
+ "plate-slide-back-v3": "Get a plate from the cabinet",
+ "plate-slide-back-side-v3": "Get a plate from the cabinet sideways",
+ "peg-unplug-side-v3": "Unplug a peg sideways",
+ "soccer-v3": "Kick a soccer into the goal",
+ "stick-push-v3": "Grasp a stick and push a box using the stick",
+ "stick-pull-v3": "Grasp a stick and pull a box with the stick",
+ "push-wall-v3": "Bypass a wall and push a puck to a goal",
+ "reach-wall-v3": "Bypass a wall and reach a goal",
+ "shelf-place-v3": "Pick and place a puck onto a shelf",
+ "sweep-into-v3": "Sweep a puck into a hole",
+ "sweep-v3": "Sweep a puck off the table",
+ "window-open-v3": "Push and open a window",
+ "window-close-v3": "Push and close a window"
+ },
+ "TASK_NAME_TO_ID": {
+ "assembly-v3": 0, "basketball-v3": 1, "bin-picking-v3": 2, "box-close-v3": 3,
+ "button-press-topdown-v3": 4, "button-press-topdown-wall-v3": 5, "button-press-v3": 6,
+ "button-press-wall-v3": 7, "coffee-button-v3": 8, "coffee-pull-v3": 9, "coffee-push-v3": 10,
+ "dial-turn-v3": 11, "disassemble-v3": 12, "door-close-v3": 13, "door-lock-v3": 14,
+ "door-open-v3": 15, "door-unlock-v3": 16, "drawer-close-v3": 17, "drawer-open-v3": 18,
+ "faucet-close-v3": 19, "faucet-open-v3": 20, "hammer-v3": 21, "hand-insert-v3": 22,
+ "handle-press-side-v3": 23, "handle-press-v3": 24, "handle-pull-side-v3": 25,
+ "handle-pull-v3": 26, "lever-pull-v3": 27, "peg-insert-side-v3": 28, "peg-unplug-side-v3": 29,
+ "pick-out-of-hole-v3": 30, "pick-place-v3": 31, "pick-place-wall-v3": 32,
+ "plate-slide-back-side-v3": 33, "plate-slide-back-v3": 34, "plate-slide-side-v3": 35,
+ "plate-slide-v3": 36, "push-back-v3": 37, "push-v3": 38, "push-wall-v3": 39, "reach-v3": 40,
+ "reach-wall-v3": 41, "shelf-place-v3": 42, "soccer-v3": 43, "stick-pull-v3": 44,
+ "stick-push-v3": 45, "sweep-into-v3": 46, "sweep-v3": 47, "window-open-v3": 48,
+ "window-close-v3": 49
+ },
+ "DIFFICULTY_TO_TASKS": {
+ "easy": [
+ "button-press-v3", "button-press-topdown-v3", "button-press-topdown-wall-v3",
+ "button-press-wall-v3", "coffee-button-v3", "dial-turn-v3", "door-close-v3",
+ "door-lock-v3", "door-open-v3", "door-unlock-v3", "drawer-close-v3", "drawer-open-v3",
+ "faucet-close-v3", "faucet-open-v3", "handle-press-v3", "handle-press-side-v3",
+ "handle-pull-v3", "handle-pull-side-v3", "lever-pull-v3", "plate-slide-v3",
+ "plate-slide-back-v3", "plate-slide-back-side-v3", "plate-slide-side-v3", "reach-v3",
+ "reach-wall-v3", "window-close-v3", "window-open-v3", "peg-unplug-side-v3"
+ ],
+ "medium": [
+ "basketball-v3", "bin-picking-v3", "box-close-v3", "coffee-pull-v3", "coffee-push-v3",
+ "hammer-v3", "peg-insert-side-v3", "push-wall-v3", "soccer-v3", "sweep-v3", "sweep-into-v3"
+ ],
+ "hard": [
+ "assembly-v3", "hand-insert-v3", "pick-out-of-hole-v3", "pick-place-v3", "push-v3", "push-back-v3"
+ ],
+ "very_hard": [
+ "shelf-place-v3", "disassemble-v3", "stick-pull-v3", "stick-push-v3", "pick-place-wall-v3"
+ ]
+ },
+ "TASK_POLICY_MAPPING": {
+ "assembly-v3": "SawyerAssemblyV3Policy", "basketball-v3": "SawyerBasketballV3Policy",
+ "bin-picking-v3": "SawyerBinPickingV3Policy", "box-close-v3": "SawyerBoxCloseV3Policy",
+ "button-press-topdown-v3": "SawyerButtonPressTopdownV3Policy",
+ "button-press-topdown-wall-v3": "SawyerButtonPressTopdownWallV3Policy",
+ "button-press-v3": "SawyerButtonPressV3Policy", "button-press-wall-v3": "SawyerButtonPressWallV3Policy",
+ "coffee-button-v3": "SawyerCoffeeButtonV3Policy", "coffee-pull-v3": "SawyerCoffeePullV3Policy",
+ "coffee-push-v3": "SawyerCoffeePushV3Policy", "dial-turn-v3": "SawyerDialTurnV3Policy",
+ "disassemble-v3": "SawyerDisassembleV3Policy", "door-close-v3": "SawyerDoorCloseV3Policy",
+ "door-lock-v3": "SawyerDoorLockV3Policy", "door-open-v3": "SawyerDoorOpenV3Policy",
+ "door-unlock-v3": "SawyerDoorUnlockV3Policy", "drawer-close-v3": "SawyerDrawerCloseV3Policy",
+ "drawer-open-v3": "SawyerDrawerOpenV3Policy", "faucet-close-v3": "SawyerFaucetCloseV3Policy",
+ "faucet-open-v3": "SawyerFaucetOpenV3Policy", "hammer-v3": "SawyerHammerV3Policy",
+ "hand-insert-v3": "SawyerHandInsertV3Policy", "handle-press-side-v3": "SawyerHandlePressSideV3Policy",
+ "handle-press-v3": "SawyerHandlePressV3Policy", "handle-pull-side-v3": "SawyerHandlePullSideV3Policy",
+ "handle-pull-v3": "SawyerHandlePullV3Policy", "lever-pull-v3": "SawyerLeverPullV3Policy",
+ "peg-insert-side-v3": "SawyerPegInsertionSideV3Policy", "peg-unplug-side-v3": "SawyerPegUnplugSideV3Policy",
+ "pick-out-of-hole-v3": "SawyerPickOutOfHoleV3Policy", "pick-place-v3": "SawyerPickPlaceV3Policy",
+ "pick-place-wall-v3": "SawyerPickPlaceWallV3Policy",
+ "plate-slide-back-side-v3": "SawyerPlateSlideBackSideV3Policy",
+ "plate-slide-back-v3": "SawyerPlateSlideBackV3Policy",
+ "plate-slide-side-v3": "SawyerPlateSlideSideV3Policy", "plate-slide-v3": "SawyerPlateSlideV3Policy",
+ "push-back-v3": "SawyerPushBackV3Policy", "push-v3": "SawyerPushV3Policy",
+ "push-wall-v3": "SawyerPushWallV3Policy", "reach-v3": "SawyerReachV3Policy",
+ "reach-wall-v3": "SawyerReachWallV3Policy", "shelf-place-v3": "SawyerShelfPlaceV3Policy",
+ "soccer-v3": "SawyerSoccerV3Policy", "stick-pull-v3": "SawyerStickPullV3Policy",
+ "stick-push-v3": "SawyerStickPushV3Policy", "sweep-into-v3": "SawyerSweepIntoV3Policy",
+ "sweep-v3": "SawyerSweepV3Policy", "window-open-v3": "SawyerWindowOpenV3Policy",
+ "window-close-v3": "SawyerWindowCloseV3Policy"
+ }
+}
diff --git a/src/lerobot/optim/schedulers.py b/src/lerobot/optim/schedulers.py
index 55ee62e40..b5d54b396 100644
--- a/src/lerobot/optim/schedulers.py
+++ b/src/lerobot/optim/schedulers.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
+import logging
import math
from dataclasses import asdict, dataclass
from pathlib import Path
@@ -79,7 +80,11 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
@dataclass
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
- """Used by Physical Intelligence to train Pi0"""
+ """Used by Physical Intelligence to train Pi0.
+
+ Automatically scales warmup and decay steps if num_training_steps < num_decay_steps.
+ This ensures the learning rate schedule completes properly even with shorter training runs.
+ """
num_warmup_steps: int
num_decay_steps: int
@@ -87,23 +92,39 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
decay_lr: float
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
- del num_training_steps
+ # Auto-scale scheduler parameters if training steps are shorter than configured decay steps
+ actual_warmup_steps = self.num_warmup_steps
+ actual_decay_steps = self.num_decay_steps
+
+ if num_training_steps < self.num_decay_steps:
+ # Calculate scaling factor to fit the schedule into the available training steps
+ scale_factor = num_training_steps / self.num_decay_steps
+ actual_warmup_steps = int(self.num_warmup_steps * scale_factor)
+ actual_decay_steps = num_training_steps
+
+ logging.info(
+ f"Auto-scaling LR scheduler: "
+ f"num_training_steps ({num_training_steps}) < num_decay_steps ({self.num_decay_steps}). "
+ f"Scaling warmup: {self.num_warmup_steps} → {actual_warmup_steps}, "
+ f"decay: {self.num_decay_steps} → {actual_decay_steps} "
+ f"(scale factor: {scale_factor:.3f})"
+ )
def lr_lambda(current_step):
def linear_warmup_schedule(current_step):
if current_step <= 0:
- return 1 / (self.num_warmup_steps + 1)
- frac = 1 - current_step / self.num_warmup_steps
- return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1
+ return 1 / (actual_warmup_steps + 1)
+ frac = 1 - current_step / actual_warmup_steps
+ return (1 / (actual_warmup_steps + 1) - 1) * frac + 1
def cosine_decay_schedule(current_step):
- step = min(current_step, self.num_decay_steps)
- cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
+ step = min(current_step, actual_decay_steps)
+ cosine_decay = 0.5 * (1 + math.cos(math.pi * step / actual_decay_steps))
alpha = self.decay_lr / self.peak_lr
decayed = (1 - alpha) * cosine_decay + alpha
return decayed
- if current_step < self.num_warmup_steps:
+ if current_step < actual_warmup_steps:
return linear_warmup_schedule(current_step)
return cosine_decay_schedule(current_step)
diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py
index ac76baf9f..cfb550ab2 100644
--- a/src/lerobot/policies/factory.py
+++ b/src/lerobot/policies/factory.py
@@ -31,7 +31,6 @@ from lerobot.envs.utils import env_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
-from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig
@@ -58,7 +57,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
- "vqbet", "pi0", "pi0fast", "sac", "reward_classifier", "smolvla".
+ "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla".
Returns:
The policy class corresponding to the given name.
@@ -82,10 +81,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
return VQBeTPolicy
- elif name == "pi0fast":
- from lerobot.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
-
- return PI0FASTPolicy
elif name == "pi0":
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
@@ -119,7 +114,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
- "diffusion", "act", "vqbet", "pi0", "pi0fast", "sac", "smolvla",
+ "diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
"reward_classifier".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
@@ -137,8 +132,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return ACTConfig(**kwargs)
elif policy_type == "vqbet":
return VQBeTConfig(**kwargs)
- elif policy_type == "pi0fast":
- return PI0FASTConfig(**kwargs)
elif policy_type == "pi0":
return PI0Config(**kwargs)
elif policy_type == "pi05":
@@ -260,14 +253,6 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
- elif isinstance(policy_cfg, PI0FASTConfig):
- from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors
-
- processors = make_pi0fast_pre_post_processors(
- config=policy_cfg,
- dataset_stats=kwargs.get("dataset_stats"),
- )
-
elif isinstance(policy_cfg, PI0Config):
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors
diff --git a/src/lerobot/policies/pi0/configuration_pi0.py b/src/lerobot/policies/pi0/configuration_pi0.py
index cc1cda9d8..d745f4317 100644
--- a/src/lerobot/policies/pi0/configuration_pi0.py
+++ b/src/lerobot/policies/pi0/configuration_pi0.py
@@ -75,6 +75,8 @@ class PI0Config(PreTrainedConfig):
optimizer_grad_clip_norm: float = 1.0
# Scheduler settings: see openpi `CosineDecaySchedule`
+ # Note: These will auto-scale if --steps < scheduler_decay_steps
+ # For example, --steps=3000 will scale warmup to 100 and decay to 3000
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py
index a2dcdaea3..596b273d5 100644
--- a/src/lerobot/policies/pi0/modeling_pi0.py
+++ b/src/lerobot/policies/pi0/modeling_pi0.py
@@ -897,7 +897,7 @@ class PI0Policy(PreTrainedPolicy):
) -> T:
"""Override the from_pretrained method to handle key remapping and display important disclaimer."""
print(
- "The PI05 model is a direct port of the OpenPI implementation. \n"
+ "The PI0 model is a direct port of the OpenPI implementation. \n"
"This implementation follows the original OpenPI structure for compatibility. \n"
"Original implementation: https://github.com/Physical-Intelligence/openpi"
)
diff --git a/src/lerobot/policies/pi05/configuration_pi05.py b/src/lerobot/policies/pi05/configuration_pi05.py
index 7c1e950b0..61346c330 100644
--- a/src/lerobot/policies/pi05/configuration_pi05.py
+++ b/src/lerobot/policies/pi05/configuration_pi05.py
@@ -75,6 +75,8 @@ class PI05Config(PreTrainedConfig):
optimizer_grad_clip_norm: float = 1.0
# Scheduler settings: see openpi `CosineDecaySchedule`
+ # Note: These will auto-scale if --steps < scheduler_decay_steps
+ # For example, --steps=3000 will scale warmup to 100 and decay to 3000
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
diff --git a/src/lerobot/policies/pi0fast/configuration_pi0fast.py b/src/lerobot/policies/pi0fast/configuration_pi0fast.py
deleted file mode 100644
index cefd4e688..000000000
--- a/src/lerobot/policies/pi0fast/configuration_pi0fast.py
+++ /dev/null
@@ -1,153 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from dataclasses import dataclass, field
-
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
-from lerobot.optim.optimizers import AdamWConfig
-from lerobot.optim.schedulers import (
- CosineDecayWithWarmupSchedulerConfig,
-)
-from lerobot.utils.constants import OBS_IMAGES
-
-
-@PreTrainedConfig.register_subclass("pi0fast")
-@dataclass
-class PI0FASTConfig(PreTrainedConfig):
- # Input / output structure.
- n_obs_steps: int = 1
- chunk_size: int = 10
- n_action_steps: int = 5
-
- normalization_mapping: dict[str, NormalizationMode] = field(
- default_factory=lambda: {
- "VISUAL": NormalizationMode.IDENTITY,
- "STATE": NormalizationMode.MEAN_STD,
- "ACTION": NormalizationMode.MEAN_STD,
- }
- )
-
- # Shorter state and action vectors will be padded
- max_state_dim: int = 32 # 32
- max_action_dim: int = 32 # 32
-
- # Image preprocessing
- resize_imgs_with_padding: tuple[int, int] = (224, 224)
- interpolate_like_pi: bool = False
-
- # Add empty images. Used by pi0_aloha_sim which adds the empty
- # left and right wrist cameras in addition to the top camera.
- empty_cameras: int = 0
-
- # Converts the joint and gripper values from the standard Aloha space to
- # the space used by the pi internal runtime which was used to train the base model.
- adapt_to_pi_aloha: bool = False
-
- # Converts joint dimensions to deltas with respect to the current state before passing to the model.
- # Gripper dimensions will remain in absolute values.
- use_delta_joint_actions_aloha: bool = False
-
- # Tokenizer
- tokenizer_max_length: int = 48
-
- # Projector
- proj_width: int = 1024
-
- # Decoding
- max_decoding_steps: int = 256
- fast_skip_tokens: int = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
- max_input_seq_len: int = 256 # 512
-
- # Utils
- use_cache: bool = True
-
- # Frozen parameters
- freeze_vision_encoder: bool = True
- freeze_lm_head: bool = True
-
- # Training presets
- optimizer_lr: float = 1e-4
- optimizer_betas: tuple[float, float] = (0.9, 0.95)
- optimizer_eps: float = 1e-8
- optimizer_weight_decay: float = 1e-5
-
- scheduler_warmup_steps: int = 1_000
- scheduler_decay_steps: int = 30_000
- scheduler_decay_lr: float = 2.5e-6
-
- checkpoint_path: str = None
-
- padding_side: str = "right"
-
- precision: str = "bfloat16"
- grad_clip_norm: float = 1
-
- # Allows padding/truncation of generated action tokens during detokenization to ensure decoding.
- # In the original version, tensors of 0s were generated if shapes didn't match for stable decoding.
- relaxed_action_decoding: bool = True
-
- def __post_init__(self):
- super().__post_init__()
-
- """Input validation (not exhaustive)."""
- if self.n_action_steps > self.chunk_size:
- raise ValueError(
- f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
- f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
- )
- if self.n_obs_steps != 1:
- raise ValueError(
- f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
- )
-
- def validate_features(self) -> None:
- for i in range(self.empty_cameras):
- key = f"{OBS_IMAGES}.empty_camera_{i}"
- empty_camera = PolicyFeature(
- type=FeatureType.VISUAL,
- shape=(3, 480, 640),
- )
- self.input_features[key] = empty_camera
-
- def get_optimizer_preset(self) -> AdamWConfig:
- return AdamWConfig(
- lr=self.optimizer_lr,
- betas=self.optimizer_betas,
- eps=self.optimizer_eps,
- weight_decay=self.optimizer_weight_decay,
- grad_clip_norm=self.grad_clip_norm,
- )
-
- def get_scheduler_preset(self):
- return CosineDecayWithWarmupSchedulerConfig(
- peak_lr=self.optimizer_lr,
- decay_lr=self.scheduler_decay_lr,
- num_warmup_steps=self.scheduler_warmup_steps,
- num_decay_steps=self.scheduler_decay_steps,
- )
-
- @property
- def observation_delta_indices(self) -> None:
- return None
-
- @property
- def action_delta_indices(self) -> list:
- return list(range(self.chunk_size))
-
- @property
- def reward_delta_indices(self) -> None:
- return None
diff --git a/src/lerobot/policies/pi0fast/modeling_pi0fast.py b/src/lerobot/policies/pi0fast/modeling_pi0fast.py
deleted file mode 100644
index 102cfb8fa..000000000
--- a/src/lerobot/policies/pi0fast/modeling_pi0fast.py
+++ /dev/null
@@ -1,980 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-π0+FAST: Efficient Action Tokenization for Vision-Language-Action Models
-
-[Paper](https://huggingface.co/papers/2501.09747)
-[Jax code](https://github.com/Physical-Intelligence/openpi)
-
-Designed by Physical Intelligence. Ported from Jax by Hugging Face.
-Disclaimer: It is not expected to perform as well as the original implementation.
-
-Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`):
-```bash
-lerobot-train \
---policy.path=lerobot/pi0fast_base \
---dataset.repo_id=danaaubakirova/koch_test
-```
-
-Example of training the pi0+FAST neural network with from scratch:
-```bash
-lerobot-train \
---policy.type=pi0fast \
---dataset.repo_id=danaaubakirova/koch_test
-```
-
-Example of using the pi0 pretrained model outside LeRobot training framework:
-```python
-policy = PI0FASTPolicy.from_pretrained("lerobot/pi0fast_base")
-```
-
-"""
-
-from collections import deque
-from functools import partial
-
-import numpy as np
-import torch
-import torch.nn.functional as F # noqa: N812
-from PIL import Image
-from scipy.fft import idct
-from torch import Tensor, nn
-from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGeneration
-from transformers.cache_utils import HybridCache, StaticCache
-from transformers.models.auto import CONFIG_MAPPING
-
-from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
-from lerobot.policies.pretrained import PreTrainedPolicy
-from lerobot.utils.constants import ACTION, OBS_STATE
-
-PRECISION = {
- "float16": torch.float16,
- "float32": torch.float32,
- "bfloat16": torch.bfloat16,
-}
-
-
-def normalize(x, min_val, max_val):
- return (x - min_val) / (max_val - min_val)
-
-
-def unnormalize(x, min_val, max_val):
- return x * (max_val - min_val) + min_val
-
-
-def safe_arcsin(value):
- # This ensures that the input stays within
- # [−1,1] to avoid invalid values for arcsin
- return torch.arcsin(torch.clamp(value, -1.0, 1.0))
-
-
-def aloha_gripper_to_angular(value):
- # Aloha transforms the gripper positions into a linear space. The following code
- # reverses this transformation to be consistent with pi0 which is pretrained in
- # angular space.
- #
- # These values are coming from the Aloha code:
- # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
- value = unnormalize(value, min_val=0.01844, max_val=0.05800)
-
- # This is the inverse of the angular to linear transformation inside the Interbotix code.
- def linear_to_radian(linear_position, arm_length, horn_radius):
- value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
- return safe_arcsin(value)
-
- # The constants are taken from the Interbotix code.
- value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
-
- # Normalize to [0, 1].
- # The values 0.4 and 1.5 were measured on an actual Trossen robot.
- return normalize(value, min_val=0.4, max_val=1.5)
-
-
-def aloha_gripper_from_angular(value):
- # Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
- # Note that the units are still angular but the range is different.
-
- # The values 0.4 and 1.5 were measured on an actual Trossen robot.
- value = unnormalize(value, min_val=0.4, max_val=1.5)
-
- # These values are coming from the Aloha code:
- # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
- return normalize(value, min_val=-0.6213, max_val=1.4910)
-
-
-def aloha_gripper_from_angular_inv(value):
- # Directly inverts the gripper_from_angular function.
- value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
- return normalize(value, min_val=0.4, max_val=1.5)
-
-
-class PI0FASTPolicy(PreTrainedPolicy):
- """Wrapper class around PI0FAST tokenizer and model to train and run inference within LeRobot."""
-
- config_class = PI0FASTConfig
- name = "pi0fast"
-
- def __init__(
- self,
- config: PI0FASTConfig,
- dataset_stats: dict[str, dict[str, Tensor]] | None = None,
- ):
- """
- Args:
- config: Policy configuration class instance or None, in which case the default instantiation of
- the configuration class is used.
- dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
- that they will be passed with a call to `load_state_dict` before the policy is used.
- """
-
- super().__init__(config)
- config.validate_features()
- self.config = config
-
- self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
- self.model = PI0FAST(config)
-
- self.reset()
-
- def reset(self):
- """This should be called whenever the environment is reset."""
- self._action_queue = deque([], maxlen=self.config.n_action_steps)
-
- @classmethod
- def from_pretrained(cls, *args, **kwargs):
- """Override the from_pretrained method to display important disclaimer."""
- print(
- "⚠️ DISCLAIMER: The PI0FAST model is ported from JAX by the Hugging Face team. \n"
- " It is not expected to perform as well as the original implementation. \n"
- " Original implementation: https://github.com/Physical-Intelligence/openpi"
- )
- return super().from_pretrained(*args, **kwargs)
-
- def get_optim_params(self) -> dict:
- return self.parameters()
-
- def _pi_aloha_decode_state(self, state):
- # Flip the joints.
- for motor_idx in [1, 2, 8, 9]:
- state[:, motor_idx] *= -1
- # Reverse the gripper transformation that is being applied by the Aloha runtime.
- for motor_idx in [6, 13]:
- state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
- return state
-
- def _pi_aloha_encode_actions(self, actions):
- # Flip the joints.
- for motor_idx in [1, 2, 8, 9]:
- actions[:, :, motor_idx] *= -1
- # Reverse the gripper transformation that is being applied by the Aloha runtime.
- for motor_idx in [6, 13]:
- actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
- return actions
-
- def _pi_aloha_encode_actions_inv(self, actions):
- # Flip the joints again.
- for motor_idx in [1, 2, 8, 9]:
- actions[:, :, motor_idx] *= -1
- # Reverse the gripper transformation that is being applied by the Aloha runtime.
- for motor_idx in [6, 13]:
- actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
- return actions
-
- @torch.no_grad()
- def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
- """Predict a chunk of actions given environment observations."""
- raise NotImplementedError("Currently not implemented for PI0FAST")
-
- @torch.no_grad()
- def select_action(self, batch: dict[str, Tensor]) -> Tensor:
- """Select a single action given environment observations.
-
- This method wraps `select_actions` in order to return one action at a time for execution in the
- environment. It works by managing the actions in a queue and only calling `select_actions` when the
- queue is empty.
- """
- self.eval()
-
- if self.config.adapt_to_pi_aloha:
- batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
-
- # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
- # querying the policy.
- if len(self._action_queue) == 0:
- actions = self.model.generate_actions(batch)
-
- actions = actions[:, : self.config.n_action_steps]
-
- original_action_dim = self.config.action_feature.shape[
- 0
- ] # self.config.max_action_dim # self.config.action_feature.shape[0]
- actions = actions[:, :, :original_action_dim]
-
- if self.config.adapt_to_pi_aloha:
- actions = self._pi_aloha_encode_actions(actions)
-
- # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
- # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
- self._action_queue.extend(actions.transpose(0, 1))
- return self._action_queue.popleft()
-
- def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
- if self.config.adapt_to_pi_aloha:
- batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
- batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
- loss_dict = self.model.forward(batch)
- return loss_dict["loss"], loss_dict
-
-
-def block_causal_update_causal_mask(
- attention_mask,
- token_type_ids=None,
- past_key_values=None,
- cache_position=None,
- input_tensor=None,
- attn_implementation: str = "eager",
- dtype: torch.dtype = "float32",
-):
- """
- Update the causal mask during training and generation. It can be customized to different attention masks.
- """
- if attn_implementation == "flash_attention_2":
- if attention_mask is not None and 0.0 in attention_mask:
- return attention_mask
- return None
- using_static_cache = isinstance(past_key_values, StaticCache)
- min_dtype = torch.finfo(dtype).min
-
- if input_tensor is None:
- input_tensor = attention_mask
-
- inputs_lead_dim, sequence_length = input_tensor.shape[:2]
-
- if using_static_cache or isinstance(past_key_values, HybridCache):
- target_length = past_key_values.get_max_cache_shape()
- else:
- target_length = (
- attention_mask.shape[-1]
- if isinstance(attention_mask, torch.Tensor)
- else cache_position[0] + sequence_length + 1
- )
-
- # Handle precomputed attention masks
- if attention_mask is not None and attention_mask.dim() == 4:
- return attention_mask
-
- # Causal mask initialization
- causal_mask = torch.full(
- (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
- )
-
- # Standard causal masking (triu ensures tokens can only attend to past)
- if sequence_length != 1:
- causal_mask = torch.triu(causal_mask, diagonal=1)
-
- # Apply block causal mask
- if token_type_ids is not None:
- token_type_ids = token_type_ids.to(causal_mask.device).bool()
- cumsum = torch.cumsum(token_type_ids, dim=1)
- block_causal_mask = cumsum[:, None, :] <= cumsum[:, :, None]
-
- # Combine causal_mask with block-wise attention mask
- causal_mask = torch.where(block_causal_mask, 0.0, causal_mask)
- causal_mask = causal_mask[:, None, :, :]
- else:
- # Apply past cache position constraint
- causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
- -1, 1
- )
- causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
- else:
- # Apply past cache position constraint
- causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
- -1, 1
- )
- causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
-
- if attention_mask is not None:
- causal_mask = causal_mask.clone() # Copy to contiguous memory for in-place edits
- mask_length = attention_mask.shape[-1]
-
- # Apply padding mask
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
- causal_mask.device
- )
- padding_mask = padding_mask == 0
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
- padding_mask, min_dtype
- )
-
- return causal_mask
-
-
-def prepare_inputs_for_generation(
- # self,
- input_ids,
- past_key_values=None,
- inputs_embeds=None,
- cache_position=None,
- position_ids=None,
- pixel_values=None,
- attention_mask=None,
- token_type_ids=None,
- use_cache=True,
- num_logits_to_keep=None,
- labels=None,
- self=None,
- **kwargs,
-):
- # create block causal attention
- if cache_position[0] > 0 and input_ids.shape[1] > 0:
- input_tensor = input_ids[:, -1:]
- new_positions = (
- torch.ones(
- (position_ids.shape[0], input_ids.shape[1]),
- dtype=position_ids.dtype,
- device=position_ids.device,
- ).cumsum(-1)
- + position_ids[:, -1:]
- )
- position_ids = torch.cat([position_ids, new_positions], dim=-1)
- else:
- input_tensor = inputs_embeds
- attention_mask = block_causal_update_causal_mask(
- attention_mask=attention_mask,
- past_key_values=past_key_values,
- cache_position=cache_position,
- input_tensor=input_tensor,
- token_type_ids=token_type_ids,
- dtype=self.dtype,
- attn_implementation=self.config.text_config._attn_implementation,
- )
- # Overwritten -- custom `position_ids` and `pixel_values` handling
- model_inputs = self.language_model.prepare_inputs_for_generation(
- input_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- attention_mask=attention_mask,
- position_ids=position_ids,
- cache_position=cache_position,
- use_cache=use_cache,
- num_logits_to_keep=num_logits_to_keep,
- token_type_ids=token_type_ids,
- **kwargs,
- )
-
- # Position_ids in Paligemma are 1-indexed
- if model_inputs.get("position_ids") is not None:
- model_inputs["position_ids"] += 1
- # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
- # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
- if cache_position[0] == 0:
- model_inputs["pixel_values"] = pixel_values
- is_training = token_type_ids is not None and labels is not None
- if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
- input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
- causal_mask = self._update_causal_mask(
- attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
- )
- model_inputs["attention_mask"] = causal_mask
-
- return model_inputs
-
-
-class PI0FAST(nn.Module):
- def __init__(self, config: PI0FASTConfig):
- super().__init__()
- self.config = config
-
- # TODO: move tokenizers in Policy
- fast_tokenizer_path = "physical-intelligence/fast"
- pi0_paligemma_path = "google/paligemma-3b-pt-224"
- self.paligemma_tokenizer = AutoTokenizer.from_pretrained(pi0_paligemma_path)
- self.processor = AutoProcessor.from_pretrained(pi0_paligemma_path)
- self.fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True)
- self.fast_skip_tokens = self.config.fast_skip_tokens
- self.max_input_seq_len = self.config.max_input_seq_len
- self.action_horizon = self.config.chunk_size
- self.action_dim = self.config.action_feature.shape[
- 0
- ] # self.config.max_action_dim # self.config.action_feature.shape[0]
- precision = config.precision
- torch_precision = PRECISION.get(precision, torch.float32)
- self.pad_token_id = (
- self.paligemma_tokenizer.pad_token_id
- if hasattr(self.paligemma_tokenizer, "pad_token_id")
- else self.paligemma_tokenizer.eos_token_id
- )
-
- paligemma_config = CONFIG_MAPPING["paligemma"](
- transformers_version="4.48.1",
- _vocab_size=257152,
- bos_token_id=2,
- eos_token_id=1,
- hidden_size=2048,
- image_token_index=257152,
- model_type="paligemma",
- pad_token_id=0,
- projection_dim=2048,
- text_config={
- "hidden_activation": "gelu_pytorch_tanh",
- "hidden_size": 2048,
- "intermediate_size": 16384,
- "model_type": "gemma",
- "num_attention_heads": 8,
- "num_hidden_layers": 18,
- "num_image_tokens": 256,
- "num_key_value_heads": 1,
- "torch_dtype": precision,
- "vocab_size": 257152,
- "_attn_implementation": "eager",
- },
- vision_config={
- "hidden_size": 1152,
- "intermediate_size": 4304,
- "model_type": "siglip_vision_model",
- "num_attention_heads": 16,
- "num_hidden_layers": 27,
- "num_image_tokens": 256,
- "patch_size": 14,
- "projection_dim": 2048,
- "projector_hidden_act": "gelu_pytorch_tanh",
- "torch_dtype": precision,
- "vision_use_head": False,
- },
- )
- self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config)
-
- self.pi0_paligemma.prepare_inputs_for_generation = partial(
- prepare_inputs_for_generation, self=self.pi0_paligemma
- )
- # change important stuff in bf16
- params_to_change_dtype = [
- "language_model",
- "vision_tower",
- "multi_modal",
- ]
- for name, param in self.pi0_paligemma.named_parameters():
- if any(selector in name for selector in params_to_change_dtype):
- param.data = param.data.to(dtype=torch_precision)
- self.set_requires_grad()
- self.image_keys = self.config.image_features.keys()
- # TODO: Remove this once we bump transformers to >4.52.0 because the attribute will be removed
- # AttributeError: 'PaliGemmaConfig' object has no attribute 'ignore_index'
- self.ignore_index = self.pi0_paligemma.config.ignore_index
- self.padding_side = self.config.padding_side
-
- def set_requires_grad(self):
- if self.config.freeze_vision_encoder:
- self.pi0_paligemma.vision_tower.eval()
- for params in self.pi0_paligemma.vision_tower.parameters():
- params.requires_grad = False
- # To avoid unused params issue with distributed training
- if self.config.freeze_lm_head:
- for name, params in self.pi0_paligemma.named_parameters():
- if "embed_tokens" in name: # lm heads and embedding layer are tied
- params.requires_grad = False
-
- def embed_tokens(self, tokens: torch.Tensor):
- return self.pi0_paligemma.language_model.model.embed_tokens(tokens)
-
- def prepare_inputs_for_generation(self, *args, **kwargs):
- return self.pi0_paligemma.prepare_inputs_for_generation(*args, **kwargs)
-
- def prepare_images(self, batch):
- """Preprocess LeRobot batch into Pi0 inputs"""
- images = []
- img_masks = []
- present_img_keys = [key for key in self.image_keys if key in batch]
- if len(present_img_keys) == 0:
- raise ValueError(
- f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
- )
-
- # Preprocess image features present in the batch
- num_empty_cameras = 0
- for key in self.image_keys:
- if key in present_img_keys:
- img = batch[key]
-
- if self.config.resize_imgs_with_padding is not None:
- img = resize_with_pad(
- img,
- *self.config.resize_imgs_with_padding,
- pad_value=0,
- interpolate_like_pi=self.config.interpolate_like_pi,
- )
-
- # Normalize from range [0,1] to [-1,1] as expected by siglip
- img = img * 2.0 - 1.0
-
- bsize = img.shape[0]
- device = img.device
- mask = torch.ones(bsize, dtype=torch.bool, device=device)
- else:
- if num_empty_cameras >= self.config.empty_cameras:
- continue
- img = torch.ones_like(img) * -1
- bsize = img.shape[0]
- device = img.device
- mask = torch.ones(bsize, dtype=torch.bool, device=device)
- num_empty_cameras += 1
-
- images.append(img)
- img_masks.append(mask)
- return images, img_masks
-
- def normalize_actions(self, actions: torch.Tensor) -> torch.Tensor:
- mins = actions.amin(dim=(1, 2), keepdim=True) # [0]
- maxs = actions.amax(dim=(1, 2), keepdim=True) # [0]
- return 2 * (actions - mins) / (maxs - mins + 1e-8) - 1
-
- def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
- out = self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
- return out
-
- def fast_tokenizer_wrapper(self, actions_norm):
- """
- A wrapper for self.fast_tokenizer that ensures batch processing,
- conversion to PyTorch tensors, and returns a dictionary without padding.
- """
- batch_tokens = self.fast_tokenizer(actions_norm)
- fast_out = self.processor.tokenizer.pad({"input_ids": batch_tokens}, return_tensors="pt")
-
- return fast_out
-
- def create_token_type_ids(self, padded_mask: torch.Tensor, prefix_len: int) -> torch.Tensor:
- token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool)
- # Compute cumulative sum mask
- cumsum_mask = (padded_mask != 0).cumsum(dim=1)
- # Suffix block (everything after prefix_len)
- suffix_mask = cumsum_mask > prefix_len
- token_type_ids = suffix_mask
- return token_type_ids
-
- def create_input_tokens(self, state, lang_text, actions=None):
- bsize = state.shape[0]
- device = state.device
- bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1]
- discretized = torch.bucketize(state, bins) - 1
- discretized = discretized[:, :32]
-
- prefix_texts = []
- state_text = []
- for txt, disc in zip(lang_text, discretized, strict=False):
- cleaned = txt.lower().strip().replace("_", " ")
- state_str = " ".join(str(val.item()) for val in disc)
- prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n")
- state_text.append(f"State: {state_str};\n")
-
- prefix_out = self.paligemma_tokenizer(
- prefix_texts, add_special_tokens=True, return_tensors="pt", padding="longest", truncation=False
- )
- prefix_ids = prefix_out["input_ids"].to(device)
- prefix_mask = prefix_out["attention_mask"].to(device)
- prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu()
-
- if actions is not None:
- actions_norm = self.normalize_actions(actions)
- actions_pad = F.pad(
- actions_norm, (0, max(0, self.config.max_action_dim - actions_norm.shape[2])), value=0
- )[:, :, : self.config.max_action_dim]
- fast_out = self.fast_tokenizer_wrapper(
- actions_pad.cpu(),
- )
- act_ids = fast_out["input_ids"]
- act_mask = fast_out["attention_mask"].to(device)
-
- act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device)
- # Replace action with 0 to pad tokens
- act_ids = torch.where(
- act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens,
- self.pad_token_id,
- act_ids,
- )
-
- eos_token = torch.tensor(
- [self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device
- ).expand(bsize, -1)
- eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1)
- bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt")
- bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device)
- bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device)
- act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
- act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1)
- act_mask = act_mask.to(device)
- else:
- act_ids = torch.empty(bsize, self.pad_token_id, dtype=torch.long, device=device)
- act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device)
- final_ids = torch.cat([prefix_ids, act_ids], dim=1)
-
- final_mask = torch.cat([prefix_mask, act_mask], dim=1)
- batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()}
-
- # Use tokenizer pad function
- padded_output = self.paligemma_tokenizer.pad(
- batch_inputs, padding="longest", max_length=180, return_tensors="pt"
- )
- padded_mask = padded_output["attention_mask"]
-
- # define tensor of padding lengths
- att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens
-
- token_type_ids = self.create_token_type_ids(padded_mask=padded_mask, prefix_len=prefix_lens)
-
- padded_output["padded_mask"] = padded_output.pop("attention_mask")
- padded_output["attention_mask"] = att_mask
- # loss is computed not on prefix, and not on padding
- padded_output["loss_mask"] = att_mask & padded_output["padded_mask"]
- padded_output["token_type_ids"] = token_type_ids
- return padded_output
-
- def shift_padding_side(
- self,
- tokens: torch.Tensor,
- ar_mask: torch.Tensor,
- padding_mask: torch.Tensor,
- loss_mask: torch.Tensor,
- targets: torch.Tensor,
- token_type_ids: torch.Tensor,
- padding_side: str = "right",
- ) -> tuple[torch.Tensor]:
- if padding_side not in ["right", "left"]:
- return tokens, ar_mask, padding_mask, loss_mask, targets, token_type_ids
-
- new_tokens = torch.empty_like(tokens)
- new_ar_masks = torch.empty_like(ar_mask)
- new_padding_mask = torch.empty_like(padding_mask)
- new_loss_mask = torch.empty_like(loss_mask)
- new_targets = torch.empty_like(targets)
- new_token_type_ids = torch.empty_like(token_type_ids)
- batch_size = tokens.shape[0]
- for i in range(batch_size):
- padding_indices = torch.where(padding_mask[i] == 0)[0]
- non_padding_indices = torch.where(padding_mask[i] == 1)[0]
- if padding_side == "left":
- new_indices = torch.cat((padding_indices, non_padding_indices), dim=0)
- else:
- new_indices = torch.cat((non_padding_indices, padding_indices), dim=0)
- new_tokens[i] = tokens[i].index_select(0, new_indices)
- new_ar_masks[i] = ar_mask[i].index_select(0, new_indices)
- new_padding_mask[i] = padding_mask[i].index_select(0, new_indices)
- new_loss_mask[i] = loss_mask[i].index_select(0, new_indices)
- new_targets[i] = targets[i].index_select(0, new_indices)
- new_token_type_ids[i] = token_type_ids[i].index_select(0, new_indices)
-
- return new_tokens, new_ar_masks, new_padding_mask, new_loss_mask, new_targets, new_token_type_ids
-
- def forward(self, batch: dict[str, Tensor]):
- device = batch[OBS_STATE].device
- # TODO: keep like this or move to the policy .forward
- images, img_masks = self.prepare_images(batch)
-
- padded_outs = self.create_input_tokens(
- state=batch[OBS_STATE],
- lang_text=batch["task"],
- actions=batch[ACTION],
- )
-
- embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs(
- images,
- img_masks,
- padded_outs["input_ids"],
- padded_outs["padded_mask"],
- padded_outs["attention_mask"],
- padded_outs["loss_mask"],
- padded_outs["token_type_ids"],
- padding_side=self.padding_side,
- )
- position_ids = torch.cumsum(pad_masks, dim=1) - 1
- token_type_ids = token_type_ids.to(dtype=torch.int64)
- past_seen_tokens = 0
- cache_position = torch.arange(past_seen_tokens, past_seen_tokens + embs.shape[1], device=embs.device)
- pad_masks = block_causal_update_causal_mask(
- attention_mask=pad_masks,
- past_key_values=None,
- cache_position=cache_position,
- input_tensor=embs,
- token_type_ids=token_type_ids,
- dtype=self.pi0_paligemma.dtype,
- attn_implementation=self.pi0_paligemma.config.text_config._attn_implementation,
- )
- outputs = self.pi0_paligemma.forward(
- input_ids=None,
- token_type_ids=None,
- attention_mask=pad_masks,
- position_ids=position_ids,
- past_key_values=None,
- inputs_embeds=embs,
- use_cache=False,
- labels=None,
- )
-
- logits = outputs.logits
-
- loss_fct = nn.CrossEntropyLoss(reduction="none")
-
- # Shift left for next-step prediction
- logits = logits[:, :-1, :]
- targets = targets[:, 1:].to(device) # Shift targets
- loss_mask = loss_mask[:, 1:].to(device) # Ensure correct shape
-
- # Compute per-token loss
- token_loss = loss_fct(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))
-
- # Apply loss mask
- token_loss = token_loss * loss_mask.reshape(-1)
-
- # Compute final loss
- loss = token_loss.sum() / torch.clamp(loss_mask.sum(), min=1)
-
- # Return loss dictionary
- loss_dict = {"ce_loss": loss.item(), "loss": loss}
- return loss_dict
-
- def decode_actions_with_fast(
- self,
- tokens: list[list[int]],
- *,
- time_horizon: int | None = None,
- action_dim: int | None = None,
- relaxed_decoding: bool = True,
- ) -> np.array:
- """
- Adapt original decoding in FAST to always return actions instead of zeros.
- """
- self.time_horizon = (
- time_horizon or self.fast_tokenizer.time_horizon or self.fast_tokenizer.called_time_horizon
- )
- self.action_dim = (
- action_dim or self.fast_tokenizer.action_dim or self.fast_tokenizer.called_action_dim
- )
-
- # Cache the time horizon and action dimension for the next call
- self.called_time_horizon = self.time_horizon
- self.called_action_dim = self.action_dim
-
- assert self.time_horizon is not None and self.action_dim is not None, (
- "Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
- )
-
- decoded_actions = []
- for token in tokens:
- try:
- decoded_tokens = self.fast_tokenizer.bpe_tokenizer.decode(token)
- decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.fast_tokenizer.min_token
- if relaxed_decoding:
- # Expected sequence length
- expected_seq_len = self.time_horizon * self.action_dim
- diff = expected_seq_len - decoded_dct_coeff.shape[0]
- # Apply truncation if too long
- if diff < 0:
- decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len] # Truncate on the right
- # Apply padding if too short
- elif diff > 0:
- decoded_dct_coeff = np.pad(
- decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
- )
-
- decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
- assert decoded_dct_coeff.shape == (
- self.time_horizon,
- self.action_dim,
- ), (
- f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
- )
- except Exception as e:
- print(f"Error decoding tokens: {e}")
- print(f"Tokens: {token}")
- decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
- decoded_actions.append(idct(decoded_dct_coeff / self.fast_tokenizer.scale, axis=0, norm="ortho"))
- return np.stack(decoded_actions)
-
- def extract_actions(self, tokens: torch.Tensor, action_horizon: int, action_dim: int) -> torch.Tensor:
- """
- Extracts actions from predicted output tokens using the FAST model.
-
- Args:
- tokens (torch.Tensor): The input tensor of tokenized outputs.
- action_horizon (int): The number of timesteps for actions.
- action_dim (int): The dimensionality of each action.
-
- Returns:
- torch.Tensor: The extracted actions as a tensor of shape (action_horizon, action_dim).
- """
- # Decode predicted output tokens
- decoded_tokens = self.paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True)
- cleaned_tokens = [
- tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip()
- for tokens_sequence in decoded_tokens
- ]
- raw_action_tokens = [
- self.processor.tokenizer.encode(sample_tokens, return_tensors="pt", padding=False)
- for sample_tokens in cleaned_tokens
- ] # something like this should be robust #looks good
- action_tokens = [
- self._act_tokens_to_paligemma_tokens(raw_action_token) for raw_action_token in raw_action_tokens
- ]
- # returns the tensor of decoded actions per sample in a list
- decoded_actions = [
- torch.tensor(
- self.decode_actions_with_fast(
- tok.tolist(),
- time_horizon=action_horizon,
- action_dim=action_dim,
- relaxed_decoding=self.config.relaxed_action_decoding,
- ),
- device=tokens.device,
- ).squeeze(0)
- for tok in action_tokens
- ]
-
- return torch.stack(
- decoded_actions,
- dim=0,
- )
-
- def generate_actions(self, batch: dict[str, Tensor]):
- # TODO: keep like this or move to the policy .forward
- images, img_masks = self.prepare_images(batch)
-
- padded_outs = self.create_input_tokens(state=batch[OBS_STATE], lang_text=batch["task"], actions=None)
- embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs(
- images,
- img_masks,
- padded_outs["input_ids"],
- padded_outs["padded_mask"],
- padded_outs["attention_mask"],
- padded_outs["loss_mask"],
- padded_outs["token_type_ids"],
- padding_side="left",
- )
- token_type_ids = token_type_ids.to(dtype=torch.int64)
- prefix_position_ids = torch.cumsum(pad_masks, dim=1) - 1
- output_tokens = self.pi0_paligemma.generate(
- input_ids=None,
- attention_mask=pad_masks,
- position_ids=prefix_position_ids,
- past_key_values=None,
- inputs_embeds=embs,
- use_cache=self.config.use_cache,
- max_new_tokens=self.config.max_decoding_steps,
- do_sample=False,
- num_beams=1,
- token_type_ids=token_type_ids,
- )
- actions = self.extract_actions(output_tokens, self.action_horizon, self.action_dim)
- return actions
-
- def embed_image(self, image: torch.Tensor):
- # Handle different transformers versions
- if hasattr(self.pi0_paligemma, "get_image_features"):
- return self.pi0_paligemma.get_image_features(image)
- else:
- return self.pi0_paligemma.model.get_image_features(image)
-
- def embed_inputs(
- self,
- images,
- img_masks,
- tokens,
- pad_mask,
- ar_mask,
- loss_mask,
- token_type_ids,
- padding_side: str = "right",
- ):
- # TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
- # images are a list of same size
- # vectorizing everything!
- device = images[0].device
- image_embedding_dim = images[0].shape[-1] # TODO should be from self.config
- all_images = torch.stack(images, dim=1).to(device)
- b, n, c, h, w = all_images.shape
- all_images = all_images.view(b * n, c, h, w)
- embedded = self.embed_image(all_images).to(device)
- b_n, p, image_embedding_dim = embedded.shape # Extract current dimensions
- m = b_n // b # Compute the number of images per sample dynamically
-
- # Reshape dynamically
- embedded = embedded.view(b, m, p, image_embedding_dim)
- tokens_embs = self.embed_tokens(tokens.to(device))
-
- img_masks = torch.stack(img_masks, dim=1).unsqueeze(-1).to(device)
- num_img_emb = embedded.shape[2]
- img_pad_masks = img_masks.repeat(1, 1, num_img_emb).view(b, -1)
- img_att_masks = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1)
-
- image_target_tokens = (
- torch.ones((b, n, num_img_emb), dtype=torch.long, device=device) * self.pad_token_id
- ).reshape(b, -1)
- image_loss_mask = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1)
-
- embedded = embedded.reshape(b, n * num_img_emb, image_embedding_dim) # Shape: (B, N*P, D)
-
- embs = torch.cat([embedded, tokens_embs], dim=1).to(device)
- pad_masks = torch.cat([img_pad_masks, pad_mask.to(device)], dim=1)
- att_masks = torch.cat([img_att_masks, ar_mask.to(device)], dim=1)
- loss_masks = torch.cat([image_loss_mask, loss_mask.to(device)], dim=1)
- targets = torch.cat([image_target_tokens, tokens.to(device)], dim=1)
- token_type_ids = torch.cat([img_att_masks, token_type_ids.to(device)], dim=1)
-
- # Shift pad tokens to the left (.generate()) or right (.train())
- embs, att_masks, pad_masks, loss_masks, targets, token_type_ids = self.shift_padding_side(
- embs, att_masks, pad_masks, loss_masks, targets, token_type_ids, padding_side=padding_side
- )
-
- targets = torch.where(targets == self.pad_token_id, self.ignore_index, targets)
- return embs, pad_masks, att_masks, targets, loss_masks, token_type_ids
-
-
-def resize_with_pad(img, width, height, pad_value=0, interpolate_like_pi=True):
- # assume no-op when width height fits already
- if img.ndim != 4:
- raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
-
- cur_height, cur_width = img.shape[2:]
-
- ratio = max(cur_width / width, cur_height / height)
- resized_height = int(cur_height / ratio)
- resized_width = int(cur_width / ratio)
-
- if interpolate_like_pi:
- img = (img * 255.0).to(dtype=torch.uint8)
- img = img.permute(0, 2, 3, 1)
- original_device = img.device
- img = img.to(device="cpu").numpy()
- imgs = []
- for sub_img in img:
- sub_img = Image.fromarray(sub_img)
- resized_img = sub_img.resize((resized_width, resized_height), resample=2)
- resized_img = torch.from_numpy(np.array(resized_img))
- imgs.append(resized_img)
- img = torch.stack(imgs, dim=0)
- img = img.permute(0, 3, 1, 2)
- resized_img = img.to(device=original_device, dtype=torch.float32) / 255.0
- else:
- resized_img = F.interpolate(
- img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
- )
-
- pad_height = max(0, int(height - resized_height))
- pad_width = max(0, int(width - resized_width))
-
- # pad on left and top of image
- padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
- return padded_img
diff --git a/src/lerobot/policies/pi0fast/processor_pi0fast.py b/src/lerobot/policies/pi0fast/processor_pi0fast.py
deleted file mode 100644
index 95b5e541b..000000000
--- a/src/lerobot/policies/pi0fast/processor_pi0fast.py
+++ /dev/null
@@ -1,92 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Any
-
-import torch
-
-from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
-from lerobot.processor import (
- AddBatchDimensionProcessorStep,
- DeviceProcessorStep,
- NormalizerProcessorStep,
- PolicyAction,
- PolicyProcessorPipeline,
- RenameObservationsProcessorStep,
- UnnormalizerProcessorStep,
-)
-from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
-from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
-
-
-def make_pi0fast_pre_post_processors(
- config: PI0FASTConfig,
- dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
-) -> tuple[
- PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
- PolicyProcessorPipeline[PolicyAction, PolicyAction],
-]:
- """
- Constructs pre-processor and post-processor pipelines for the PI0Fast policy.
-
- The pre-processing pipeline prepares input data for the model by:
- 1. Renaming features to match pretrained configurations.
- 2. Normalizing input and output features based on dataset statistics.
- 3. Adding a batch dimension.
- 4. Moving all data to the specified device.
-
- The post-processing pipeline handles the model's output by:
- 1. Moving data to the CPU.
- 2. Unnormalizing the output features to their original scale.
-
- Args:
- config: The configuration object for the PI0Fast policy.
- dataset_stats: A dictionary of statistics for normalization.
- preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
- postprocessor_kwargs: Additional arguments for the post-processor pipeline.
-
- Returns:
- A tuple containing the configured pre-processor and post-processor pipelines.
- """
-
- input_steps = [
- RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
- AddBatchDimensionProcessorStep(),
- DeviceProcessorStep(device=config.device),
- NormalizerProcessorStep(
- features={**config.input_features, **config.output_features},
- norm_map=config.normalization_mapping,
- stats=dataset_stats,
- ),
- ]
- output_steps = [
- UnnormalizerProcessorStep(
- features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
- ),
- DeviceProcessorStep(device="cpu"),
- ]
- return (
- PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
- steps=input_steps,
- name=POLICY_PREPROCESSOR_DEFAULT_NAME,
- ),
- PolicyProcessorPipeline[PolicyAction, PolicyAction](
- steps=output_steps,
- name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
- to_transition=policy_action_to_transition,
- to_output=transition_to_policy_action,
- ),
- )
diff --git a/src/lerobot/policies/utils.py b/src/lerobot/policies/utils.py
index 5a3994cdf..21b39a80e 100644
--- a/src/lerobot/policies/utils.py
+++ b/src/lerobot/policies/utils.py
@@ -16,10 +16,16 @@
import logging
from collections import deque
+from typing import Any
+import numpy as np
import torch
from torch import nn
+from lerobot.datasets.utils import build_dataset_frame
+from lerobot.processor import PolicyAction, RobotAction, RobotObservation
+from lerobot.utils.constants import ACTION, OBS_STR
+
def populate_queues(
queues: dict[str, deque], batch: dict[str, torch.Tensor], exclude_keys: list[str] | None = None
@@ -85,3 +91,110 @@ def log_model_loading_keys(missing_keys: list[str], unexpected_keys: list[str])
logging.warning(f"Missing key(s) when loading model: {missing_keys}")
if unexpected_keys:
logging.warning(f"Unexpected key(s) when loading model: {unexpected_keys}")
+
+
+# TODO(Steven): Move this function to a proper preprocessor step
+def prepare_observation_for_inference(
+ observation: dict[str, np.ndarray],
+ device: torch.device,
+ task: str | None = None,
+ robot_type: str | None = None,
+) -> RobotObservation:
+ """Converts observation data to model-ready PyTorch tensors.
+
+ This function takes a dictionary of NumPy arrays, performs necessary
+ preprocessing, and prepares it for model inference. The steps include:
+ 1. Converting NumPy arrays to PyTorch tensors.
+ 2. Normalizing and permuting image data (if any).
+ 3. Adding a batch dimension to each tensor.
+ 4. Moving all tensors to the specified compute device.
+ 5. Adding task and robot type information to the dictionary.
+
+ Args:
+ observation: A dictionary mapping observation names (str) to NumPy
+ array data. For images, the format is expected to be (H, W, C).
+ device: The PyTorch device (e.g., 'cpu' or 'cuda') to which the
+ tensors will be moved.
+ task: An optional string identifier for the current task.
+ robot_type: An optional string identifier for the robot being used.
+
+ Returns:
+ A dictionary where values are PyTorch tensors preprocessed for
+ inference, residing on the target device. Image tensors are reshaped
+ to (C, H, W) and normalized to a [0, 1] range.
+ """
+ for name in observation:
+ observation[name] = torch.from_numpy(observation[name])
+ if "image" in name:
+ observation[name] = observation[name].type(torch.float32) / 255
+ observation[name] = observation[name].permute(2, 0, 1).contiguous()
+ observation[name] = observation[name].unsqueeze(0)
+ observation[name] = observation[name].to(device)
+
+ observation["task"] = task if task else ""
+ observation["robot_type"] = robot_type if robot_type else ""
+
+ return observation
+
+
+def build_inference_frame(
+ observation: dict[str, Any],
+ device: torch.device,
+ ds_features: dict[str, dict],
+ task: str | None = None,
+ robot_type: str | None = None,
+) -> RobotObservation:
+ """Constructs a model-ready observation tensor dict from a raw observation.
+
+ This utility function orchestrates the process of converting a raw,
+ unstructured observation from an environment into a structured,
+ tensor-based format suitable for passing to a policy model.
+
+ Args:
+ observation: The raw observation dictionary, which may contain
+ superfluous keys.
+ device: The target PyTorch device for the final tensors.
+ ds_features: A configuration dictionary that specifies which features
+ to extract from the raw observation.
+ task: An optional string identifier for the current task.
+ robot_type: An optional string identifier for the robot being used.
+
+ Returns:
+ A dictionary of preprocessed tensors ready for model inference.
+ """
+ # Extracts the correct keys from the incoming raw observation
+ observation = build_dataset_frame(ds_features, observation, prefix=OBS_STR)
+
+ # Performs the necessary conversions to the observation
+ observation = prepare_observation_for_inference(observation, device, task, robot_type)
+
+ return observation
+
+
+def make_robot_action(action_tensor: PolicyAction, ds_features: dict[str, dict]) -> RobotAction:
+ """Converts a policy's output tensor into a dictionary of named actions.
+
+ This function translates the numerical output from a policy model into a
+ human-readable and robot-consumable format, where each dimension of the
+ action tensor is mapped to a named motor or actuator command.
+
+ Args:
+ action_tensor: A PyTorch tensor representing the policy's action,
+ typically with a batch dimension (e.g., shape [1, action_dim]).
+ ds_features: A configuration dictionary containing metadata, including
+ the names corresponding to each index of the action tensor.
+
+ Returns:
+ A dictionary mapping action names (e.g., "joint_1_motor") to their
+ corresponding floating-point values, ready to be sent to a robot
+ controller.
+ """
+ # TODO(Steven): Check if these steps are already in all postprocessor policies
+ action_tensor = action_tensor.squeeze(0)
+ action_tensor = action_tensor.to("cpu")
+
+ action_names = ds_features[ACTION]["names"]
+ act_processed_policy: RobotAction = {
+ f"{name}": float(action_tensor[i]) for i, name in enumerate(action_names)
+ }
+ return act_processed_policy
diff --git a/src/lerobot/rl/buffer.py b/src/lerobot/rl/buffer.py
index 917e4e2cc..81aa29c48 100644
--- a/src/lerobot/rl/buffer.py
+++ b/src/lerobot/rl/buffer.py
@@ -607,6 +607,7 @@ class ReplayBuffer:
lerobot_dataset.save_episode()
lerobot_dataset.stop_image_writer()
+ lerobot_dataset.finalize()
return lerobot_dataset
diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py
index ad36f1b36..f9c9d0d7a 100644
--- a/src/lerobot/rl/gym_manipulator.py
+++ b/src/lerobot/rl/gym_manipulator.py
@@ -696,7 +696,7 @@ def control_loop(
episode_idx += 1
if dataset is not None:
- if transition[TransitionKey.INFO].get("rerecord_episode", False):
+ if transition[TransitionKey.INFO].get(TeleopEvents.RERECORD_EPISODE, False):
logging.info(f"Re-recording episode {episode_idx}")
dataset.clear_episode_buffer()
episode_idx -= 1
diff --git a/src/lerobot/rl/wandb_utils.py b/src/lerobot/rl/wandb_utils.py
index 8c92e7145..7b7f8a57b 100644
--- a/src/lerobot/rl/wandb_utils.py
+++ b/src/lerobot/rl/wandb_utils.py
@@ -99,7 +99,7 @@ class WandBLogger:
cfg.wandb.run_id = run_id
# Handle custom step key for rl asynchronous training.
self._wandb_custom_step_key: set[str] | None = None
- print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
+ logging.info(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb
diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py
new file mode 100644
index 000000000..83ba027bc
--- /dev/null
+++ b/src/lerobot/scripts/lerobot_edit_dataset.py
@@ -0,0 +1,286 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Edit LeRobot datasets using various transformation tools.
+
+This script allows you to delete episodes, split datasets, merge datasets,
+and remove features. When new_repo_id is specified, creates a new dataset.
+
+Usage Examples:
+
+Delete episodes 0, 2, and 5 from a dataset:
+ python -m lerobot.scripts.lerobot_edit_dataset \
+ --repo_id lerobot/pusht \
+ --operation.type delete_episodes \
+ --operation.episode_indices "[0, 2, 5]"
+
+Delete episodes and save to a new dataset:
+ python -m lerobot.scripts.lerobot_edit_dataset \
+ --repo_id lerobot/pusht \
+ --new_repo_id lerobot/pusht_filtered \
+ --operation.type delete_episodes \
+ --operation.episode_indices "[0, 2, 5]"
+
+Split dataset by fractions:
+ python -m lerobot.scripts.lerobot_edit_dataset \
+ --repo_id lerobot/pusht \
+ --operation.type split \
+ --operation.splits '{"train": 0.8, "val": 0.2}'
+
+Split dataset by episode indices:
+ python -m lerobot.scripts.lerobot_edit_dataset \
+ --repo_id lerobot/pusht \
+ --operation.type split \
+ --operation.splits '{"train": [0, 1, 2, 3], "val": [4, 5]}'
+
+Split into more than two splits:
+ python -m lerobot.scripts.lerobot_edit_dataset \
+ --repo_id lerobot/pusht \
+ --operation.type split \
+ --operation.splits '{"train": 0.6, "val": 0.2, "test": 0.2}'
+
+Merge multiple datasets:
+ python -m lerobot.scripts.lerobot_edit_dataset \
+ --repo_id lerobot/pusht_merged \
+ --operation.type merge \
+ --operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']"
+
+Remove camera feature:
+ python -m lerobot.scripts.lerobot_edit_dataset \
+ --repo_id lerobot/pusht \
+ --operation.type remove_feature \
+ --operation.feature_names "['observation.images.top']"
+
+Using JSON config file:
+ python -m lerobot.scripts.lerobot_edit_dataset \
+ --config_path path/to/edit_config.json
+"""
+
+import logging
+import shutil
+from dataclasses import dataclass
+from pathlib import Path
+
+from lerobot.configs import parser
+from lerobot.datasets.dataset_tools import (
+ delete_episodes,
+ merge_datasets,
+ remove_feature,
+ split_dataset,
+)
+from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.utils.constants import HF_LEROBOT_HOME
+from lerobot.utils.utils import init_logging
+
+
+@dataclass
+class DeleteEpisodesConfig:
+ type: str = "delete_episodes"
+ episode_indices: list[int] | None = None
+
+
+@dataclass
+class SplitConfig:
+ type: str = "split"
+ splits: dict[str, float | list[int]] | None = None
+
+
+@dataclass
+class MergeConfig:
+ type: str = "merge"
+ repo_ids: list[str] | None = None
+
+
+@dataclass
+class RemoveFeatureConfig:
+ type: str = "remove_feature"
+ feature_names: list[str] | None = None
+
+
+@dataclass
+class EditDatasetConfig:
+ repo_id: str
+ operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig
+ root: str | None = None
+ new_repo_id: str | None = None
+ push_to_hub: bool = False
+
+
+def get_output_path(repo_id: str, new_repo_id: str | None, root: Path | None) -> tuple[str, Path]:
+ if new_repo_id:
+ output_repo_id = new_repo_id
+ output_dir = root / new_repo_id if root else HF_LEROBOT_HOME / new_repo_id
+ else:
+ output_repo_id = repo_id
+ dataset_path = root / repo_id if root else HF_LEROBOT_HOME / repo_id
+ old_path = Path(str(dataset_path) + "_old")
+
+ if dataset_path.exists():
+ if old_path.exists():
+ shutil.rmtree(old_path)
+ shutil.move(str(dataset_path), str(old_path))
+
+ output_dir = dataset_path
+
+ return output_repo_id, output_dir
+
+
+def handle_delete_episodes(cfg: EditDatasetConfig) -> None:
+ if not isinstance(cfg.operation, DeleteEpisodesConfig):
+ raise ValueError("Operation config must be DeleteEpisodesConfig")
+
+ if not cfg.operation.episode_indices:
+ raise ValueError("episode_indices must be specified for delete_episodes operation")
+
+ dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
+ output_repo_id, output_dir = get_output_path(
+ cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
+ )
+
+ if cfg.new_repo_id is None:
+ dataset.root = Path(str(dataset.root) + "_old")
+
+ logging.info(f"Deleting episodes {cfg.operation.episode_indices} from {cfg.repo_id}")
+ new_dataset = delete_episodes(
+ dataset,
+ episode_indices=cfg.operation.episode_indices,
+ output_dir=output_dir,
+ repo_id=output_repo_id,
+ )
+
+ logging.info(f"Dataset saved to {output_dir}")
+ logging.info(f"Episodes: {new_dataset.meta.total_episodes}, Frames: {new_dataset.meta.total_frames}")
+
+ if cfg.push_to_hub:
+ logging.info(f"Pushing to hub as {output_repo_id}")
+ LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
+
+
+def handle_split(cfg: EditDatasetConfig) -> None:
+ if not isinstance(cfg.operation, SplitConfig):
+ raise ValueError("Operation config must be SplitConfig")
+
+ if not cfg.operation.splits:
+ raise ValueError(
+ "splits dict must be specified with split names as keys and fractions/episode lists as values"
+ )
+
+ dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
+
+ logging.info(f"Splitting dataset {cfg.repo_id} with splits: {cfg.operation.splits}")
+ split_datasets = split_dataset(dataset, splits=cfg.operation.splits)
+
+ for split_name, split_ds in split_datasets.items():
+ split_repo_id = f"{cfg.repo_id}_{split_name}"
+ logging.info(
+ f"{split_name}: {split_ds.meta.total_episodes} episodes, {split_ds.meta.total_frames} frames"
+ )
+
+ if cfg.push_to_hub:
+ logging.info(f"Pushing {split_name} split to hub as {split_repo_id}")
+ LeRobotDataset(split_ds.repo_id, root=split_ds.root).push_to_hub()
+
+
+def handle_merge(cfg: EditDatasetConfig) -> None:
+ if not isinstance(cfg.operation, MergeConfig):
+ raise ValueError("Operation config must be MergeConfig")
+
+ if not cfg.operation.repo_ids:
+ raise ValueError("repo_ids must be specified for merge operation")
+
+ if not cfg.repo_id:
+ raise ValueError("repo_id must be specified as the output repository for merged dataset")
+
+ logging.info(f"Loading {len(cfg.operation.repo_ids)} datasets to merge")
+ datasets = [LeRobotDataset(repo_id, root=cfg.root) for repo_id in cfg.operation.repo_ids]
+
+ output_dir = Path(cfg.root) / cfg.repo_id if cfg.root else HF_LEROBOT_HOME / cfg.repo_id
+
+ logging.info(f"Merging datasets into {cfg.repo_id}")
+ merged_dataset = merge_datasets(
+ datasets,
+ output_repo_id=cfg.repo_id,
+ output_dir=output_dir,
+ )
+
+ logging.info(f"Merged dataset saved to {output_dir}")
+ logging.info(
+ f"Episodes: {merged_dataset.meta.total_episodes}, Frames: {merged_dataset.meta.total_frames}"
+ )
+
+ if cfg.push_to_hub:
+ logging.info(f"Pushing to hub as {cfg.repo_id}")
+ LeRobotDataset(merged_dataset.repo_id, root=output_dir).push_to_hub()
+
+
+def handle_remove_feature(cfg: EditDatasetConfig) -> None:
+ if not isinstance(cfg.operation, RemoveFeatureConfig):
+ raise ValueError("Operation config must be RemoveFeatureConfig")
+
+ if not cfg.operation.feature_names:
+ raise ValueError("feature_names must be specified for remove_feature operation")
+
+ dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
+ output_repo_id, output_dir = get_output_path(
+ cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
+ )
+
+ if cfg.new_repo_id is None:
+ dataset.root = Path(str(dataset.root) + "_old")
+
+ logging.info(f"Removing features {cfg.operation.feature_names} from {cfg.repo_id}")
+ new_dataset = remove_feature(
+ dataset,
+ feature_names=cfg.operation.feature_names,
+ output_dir=output_dir,
+ repo_id=output_repo_id,
+ )
+
+ logging.info(f"Dataset saved to {output_dir}")
+ logging.info(f"Remaining features: {list(new_dataset.meta.features.keys())}")
+
+ if cfg.push_to_hub:
+ logging.info(f"Pushing to hub as {output_repo_id}")
+ LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
+
+
+@parser.wrap()
+def edit_dataset(cfg: EditDatasetConfig) -> None:
+ operation_type = cfg.operation.type
+
+ if operation_type == "delete_episodes":
+ handle_delete_episodes(cfg)
+ elif operation_type == "split":
+ handle_split(cfg)
+ elif operation_type == "merge":
+ handle_merge(cfg)
+ elif operation_type == "remove_feature":
+ handle_remove_feature(cfg)
+ else:
+ raise ValueError(
+ f"Unknown operation type: {operation_type}\n"
+ f"Available operations: delete_episodes, split, merge, remove_feature"
+ )
+
+
+def main() -> None:
+ init_logging()
+ edit_dataset()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py
index d45be5c42..0fdec9286 100644
--- a/src/lerobot/scripts/lerobot_eval.py
+++ b/src/lerobot/scripts/lerobot_eval.py
@@ -27,7 +27,7 @@ lerobot-eval \
--eval.batch_size=10 \
--eval.n_episodes=10 \
--use_amp=false \
- --device=cuda
+ --policy.device=cuda
```
OR, you want to evaluate a model checkpoint from the LeRobot training script for 10 episodes.
@@ -38,7 +38,7 @@ lerobot-eval \
--eval.batch_size=10 \
--eval.n_episodes=10 \
--use_amp=false \
- --device=cuda
+ --policy.device=cuda
```
Note that in both examples, the repo/folder should contain at least `config.json` and `model.safetensors` files.
@@ -180,9 +180,15 @@ def rollout(
render_callback(env)
# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
- # available of none of the envs finished.
+ # available if none of the envs finished.
if "final_info" in info:
- successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
+ final_info = info["final_info"]
+ if not isinstance(final_info, dict):
+ raise RuntimeError(
+ "Unsupported `final_info` format: expected dict (Gymnasium >= 1.0). "
+ "You're likely using an older version of gymnasium (< 1.0). Please upgrade."
+ )
+ successes = final_info["is_success"].tolist()
else:
successes = [False] * env.num_envs
diff --git a/src/lerobot/scripts/lerobot_find_cameras.py b/src/lerobot/scripts/lerobot_find_cameras.py
index e17dca805..0248a2768 100644
--- a/src/lerobot/scripts/lerobot_find_cameras.py
+++ b/src/lerobot/scripts/lerobot_find_cameras.py
@@ -180,7 +180,7 @@ def create_camera_instance(cam_meta: dict[str, Any]) -> dict[str, Any] | None:
if instance:
logger.info(f"Connecting to {cam_type} camera: {cam_id}...")
- instance.connect(warmup=False)
+ instance.connect(warmup=True)
return {"instance": instance, "meta": cam_meta}
except Exception as e:
logger.error(f"Failed to connect or configure {cam_type} camera {cam_id}: {e}")
diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py
index 91cbbfb86..28da73be2 100644
--- a/src/lerobot/scripts/lerobot_record.py
+++ b/src/lerobot/scripts/lerobot_record.py
@@ -79,6 +79,7 @@ from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
+from lerobot.policies.utils import make_robot_action
from lerobot.processor import (
PolicyAction,
PolicyProcessorPipeline,
@@ -324,10 +325,7 @@ def record_loop(
robot_type=robot.robot_type,
)
- action_names = dataset.features[ACTION]["names"]
- act_processed_policy: RobotAction = {
- f"{name}": float(action_values[i]) for i, name in enumerate(action_names)
- }
+ act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
elif policy is None and isinstance(teleop, Teleoperator):
act = teleop.get_action()
diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py
index 7023cb1d0..f4b1118c3 100644
--- a/src/lerobot/scripts/lerobot_train.py
+++ b/src/lerobot/scripts/lerobot_train.py
@@ -21,8 +21,8 @@ from pprint import pformat
from typing import Any
import torch
+from accelerate import Accelerator
from termcolor import colored
-from torch.amp import GradScaler
from torch.optim import Optimizer
from lerobot.configs import parser
@@ -35,7 +35,6 @@ from lerobot.envs.utils import close_envs
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
-from lerobot.policies.utils import get_device_from_parameters
from lerobot.rl.wandb_utils import WandBLogger
from lerobot.scripts.lerobot_eval import eval_policy_all
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
@@ -49,7 +48,6 @@ from lerobot.utils.train_utils import (
)
from lerobot.utils.utils import (
format_big_number,
- get_safe_torch_device,
has_method,
init_logging,
)
@@ -61,16 +59,15 @@ def update_policy(
batch: Any,
optimizer: Optimizer,
grad_clip_norm: float,
- grad_scaler: GradScaler,
+ accelerator: Accelerator,
lr_scheduler=None,
- use_amp: bool = False,
lock=None,
) -> tuple[MetricsTracker, dict]:
"""
Performs a single training step to update the policy's weights.
This function executes the forward and backward passes, clips gradients, and steps the optimizer and
- learning rate scheduler. It also handles mixed-precision training via a GradScaler.
+ learning rate scheduler. Accelerator handles mixed-precision training automatically.
Args:
train_metrics: A MetricsTracker instance to record training statistics.
@@ -78,9 +75,8 @@ def update_policy(
batch: A batch of training data.
optimizer: The optimizer used to update the policy's parameters.
grad_clip_norm: The maximum norm for gradient clipping.
- grad_scaler: The GradScaler for automatic mixed-precision training.
+ accelerator: The Accelerator instance for distributed training and mixed precision.
lr_scheduler: An optional learning rate scheduler.
- use_amp: A boolean indicating whether to use automatic mixed precision.
lock: An optional lock for thread-safe optimizer updates.
Returns:
@@ -89,28 +85,27 @@ def update_policy(
- A dictionary of outputs from the policy's forward pass, for logging purposes.
"""
start_time = time.perf_counter()
- device = get_device_from_parameters(policy)
policy.train()
- with torch.autocast(device_type=device.type) if use_amp else nullcontext():
+
+ # Let accelerator handle mixed precision
+ with accelerator.autocast():
loss, output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
- grad_scaler.scale(loss).backward()
- # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
- grad_scaler.unscale_(optimizer)
+ # Use accelerator's backward method
+ accelerator.backward(loss)
- grad_norm = torch.nn.utils.clip_grad_norm_(
- policy.parameters(),
- grad_clip_norm,
- error_if_nonfinite=False,
- )
+ # Clip gradients if specified
+ if grad_clip_norm > 0:
+ grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
+ else:
+ grad_norm = torch.nn.utils.clip_grad_norm_(
+ policy.parameters(), float("inf"), error_if_nonfinite=False
+ )
- # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
- # although it still skips optimizer.step() if the gradients contain infs or NaNs.
+ # Optimizer step
with lock if lock is not None else nullcontext():
- grad_scaler.step(optimizer)
- # Updates the scale for next iteration.
- grad_scaler.update()
+ optimizer.step()
optimizer.zero_grad()
@@ -118,9 +113,9 @@ def update_policy(
if lr_scheduler is not None:
lr_scheduler.step()
- if has_method(policy, "update"):
- # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
- policy.update()
+ # Update internal buffers if policy has update method
+ if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
+ accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
train_metrics.loss = loss.item()
train_metrics.grad_norm = grad_norm.item()
@@ -209,7 +204,7 @@ def wrap_policy_in_peft_model(cfg, policy):
@parser.wrap()
-def train(cfg: TrainPipelineConfig):
+def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
"""
Main function to train a policy.
@@ -223,36 +218,68 @@ def train(cfg: TrainPipelineConfig):
Args:
cfg: A `TrainPipelineConfig` object containing all training configurations.
+ accelerator: Optional Accelerator instance. If None, one will be created automatically.
"""
cfg.validate()
- logging.info(pformat(cfg.to_dict()))
- if cfg.wandb.enable and cfg.wandb.project:
+ # Create Accelerator if not provided
+ # It will automatically detect if running in distributed mode or single-process mode
+ # We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
+ # We set find_unused_parameters=True to handle models with conditional computation
+ if accelerator is None:
+ from accelerate.utils import DistributedDataParallelKwargs
+
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs])
+
+ init_logging(accelerator=accelerator)
+
+ # Determine if this is the main process (for logging and checkpointing)
+ # When using accelerate, only the main process should log to avoid duplicate outputs
+ is_main_process = accelerator.is_main_process
+
+ # Only log on main process
+ if is_main_process:
+ logging.info(pformat(cfg.to_dict()))
+
+ # Initialize wandb only on main process
+ if cfg.wandb.enable and cfg.wandb.project and is_main_process:
wandb_logger = WandBLogger(cfg)
else:
wandb_logger = None
- logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
+ if is_main_process:
+ logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
if cfg.seed is not None:
- set_seed(cfg.seed)
+ set_seed(cfg.seed, accelerator=accelerator)
- # Check device is available
- device = get_safe_torch_device(cfg.policy.device, log=True)
+ # Use accelerator's device
+ device = accelerator.device
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
- logging.info("Creating dataset")
- dataset = make_dataset(cfg)
+ # Dataset loading synchronization: main process downloads first to avoid race conditions
+ if is_main_process:
+ logging.info("Creating dataset")
+ dataset = make_dataset(cfg)
+
+ accelerator.wait_for_everyone()
+
+ # Now all other processes can safely load the dataset
+ if not is_main_process:
+ dataset = make_dataset(cfg)
# Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py,
# using the eval.py instead, with gym_dora environment and dora-rs.
eval_env = None
if cfg.eval_freq > 0 and cfg.env is not None:
- logging.info("Creating env")
+ if is_main_process:
+ logging.info("Creating env")
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
- logging.info("Creating policy")
+ if is_main_process:
+ logging.info("Creating policy")
policy = make_policy(
cfg=cfg.policy,
ds_meta=dataset.meta,
@@ -262,6 +289,9 @@ def train(cfg: TrainPipelineConfig):
logging.info("Using PEFT! Wrapping model.")
policy = wrap_policy_in_peft_model(cfg, policy)
+ # Wait for all processes to finish policy creation before continuing
+ accelerator.wait_for_everyone()
+
# Create processors - only provide dataset_stats if not resuming from saved processors
processor_kwargs = {}
postprocessor_kwargs = {}
@@ -293,9 +323,9 @@ def train(cfg: TrainPipelineConfig):
**postprocessor_kwargs,
)
- logging.info("Creating optimizer and scheduler")
+ if is_main_process:
+ logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
- grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
step = 0 # number of policy updates (forward + backward + optim)
@@ -305,14 +335,18 @@ def train(cfg: TrainPipelineConfig):
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
- logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
- if cfg.env is not None:
- logging.info(f"{cfg.env.task=}")
- logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
- logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
- logging.info(f"{dataset.num_episodes=}")
- logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
- logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
+ if is_main_process:
+ logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
+ if cfg.env is not None:
+ logging.info(f"{cfg.env.task=}")
+ logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
+ logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
+ logging.info(f"{dataset.num_episodes=}")
+ num_processes = accelerator.num_processes
+ effective_bs = cfg.batch_size * num_processes
+ logging.info(f"Effective batch size: {cfg.batch_size} x {num_processes} = {effective_bs}")
+ logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
+ logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# create dataloader for offline training
if hasattr(cfg.policy, "drop_n_last_frames"):
@@ -335,7 +369,13 @@ def train(cfg: TrainPipelineConfig):
sampler=sampler,
pin_memory=device.type == "cuda",
drop_last=False,
- prefetch_factor=2,
+ prefetch_factor=2 if cfg.num_workers > 0 else None,
+ )
+
+ # Prepare everything with accelerator
+ accelerator.wait_for_everyone()
+ policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
+ policy, optimizer, dataloader, lr_scheduler
)
dl_iter = cycle(dataloader)
@@ -349,11 +389,20 @@ def train(cfg: TrainPipelineConfig):
"dataloading_s": AverageMeter("data_s", ":.3f"),
}
+ # Use effective batch size for proper epoch calculation in distributed training
+ effective_batch_size = cfg.batch_size * accelerator.num_processes
train_tracker = MetricsTracker(
- cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
+ effective_batch_size,
+ dataset.num_frames,
+ dataset.num_episodes,
+ train_metrics,
+ initial_step=step,
+ accelerator=accelerator,
)
- logging.info("Start offline training on a fixed dataset")
+ if is_main_process:
+ logging.info("Start offline training on a fixed dataset")
+
for _ in range(step, cfg.steps):
start_time = time.perf_counter()
batch = next(dl_iter)
@@ -366,16 +415,15 @@ def train(cfg: TrainPipelineConfig):
batch,
optimizer,
cfg.optimizer.grad_clip_norm,
- grad_scaler=grad_scaler,
+ accelerator=accelerator,
lr_scheduler=lr_scheduler,
- use_amp=cfg.policy.use_amp,
)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
# increment `step` here.
step += 1
train_tracker.step()
- is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
+ is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
@@ -389,69 +437,90 @@ def train(cfg: TrainPipelineConfig):
train_tracker.reset_averages()
if cfg.save_checkpoint and is_saving_step:
- logging.info(f"Checkpoint policy after step {step}")
- checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
- save_checkpoint(
- checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor, postprocessor
- )
- update_last_checkpoint(checkpoint_dir)
- if wandb_logger:
- wandb_logger.log_policy(checkpoint_dir)
-
- if cfg.env and is_eval_step:
- step_id = get_step_identifier(step, cfg.steps)
- logging.info(f"Eval policy at step {step}")
- with (
- torch.no_grad(),
- torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
- ):
- eval_info = eval_policy_all(
- envs=eval_env, # dict[suite][task_id] -> vec_env
- policy=policy,
+ if is_main_process:
+ logging.info(f"Checkpoint policy after step {step}")
+ checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
+ save_checkpoint(
+ checkpoint_dir=checkpoint_dir,
+ step=step,
+ cfg=cfg,
+ policy=accelerator.unwrap_model(policy),
+ optimizer=optimizer,
+ scheduler=lr_scheduler,
preprocessor=preprocessor,
postprocessor=postprocessor,
- n_episodes=cfg.eval.n_episodes,
- videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
- max_episodes_rendered=4,
- start_seed=cfg.seed,
- max_parallel_tasks=cfg.env.max_parallel_tasks,
)
- # overall metrics (suite-agnostic)
- aggregated = eval_info["overall"]
+ update_last_checkpoint(checkpoint_dir)
+ if wandb_logger:
+ wandb_logger.log_policy(checkpoint_dir)
- # optional: per-suite logging
- for suite, suite_info in eval_info.items():
- logging.info("Suite %s aggregated: %s", suite, suite_info)
+ accelerator.wait_for_everyone()
- # meters/tracker
- eval_metrics = {
- "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
- "pc_success": AverageMeter("success", ":.1f"),
- "eval_s": AverageMeter("eval_s", ":.3f"),
- }
- eval_tracker = MetricsTracker(
- cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
- )
- eval_tracker.eval_s = aggregated.pop("eval_s")
- eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward")
- eval_tracker.pc_success = aggregated.pop("pc_success")
- if wandb_logger:
- wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
- wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
- wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval")
+ if cfg.env and is_eval_step:
+ if is_main_process:
+ step_id = get_step_identifier(step, cfg.steps)
+ logging.info(f"Eval policy at step {step}")
+ with torch.no_grad(), accelerator.autocast():
+ eval_info = eval_policy_all(
+ envs=eval_env, # dict[suite][task_id] -> vec_env
+ policy=accelerator.unwrap_model(policy),
+ preprocessor=preprocessor,
+ postprocessor=postprocessor,
+ n_episodes=cfg.eval.n_episodes,
+ videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
+ max_episodes_rendered=4,
+ start_seed=cfg.seed,
+ max_parallel_tasks=cfg.env.max_parallel_tasks,
+ )
+ # overall metrics (suite-agnostic)
+ aggregated = eval_info["overall"]
+
+ # optional: per-suite logging
+ for suite, suite_info in eval_info.items():
+ logging.info("Suite %s aggregated: %s", suite, suite_info)
+
+ # meters/tracker
+ eval_metrics = {
+ "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
+ "pc_success": AverageMeter("success", ":.1f"),
+ "eval_s": AverageMeter("eval_s", ":.3f"),
+ }
+ eval_tracker = MetricsTracker(
+ cfg.batch_size,
+ dataset.num_frames,
+ dataset.num_episodes,
+ eval_metrics,
+ initial_step=step,
+ accelerator=accelerator,
+ )
+ eval_tracker.eval_s = aggregated.pop("eval_s")
+ eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward")
+ eval_tracker.pc_success = aggregated.pop("pc_success")
+ if wandb_logger:
+ wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
+ wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
+ wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval")
+
+ accelerator.wait_for_everyone()
if eval_env:
close_envs(eval_env)
- logging.info("End of training")
- if cfg.policy.push_to_hub:
- policy.push_model_to_hub(cfg)
- preprocessor.push_to_hub(cfg.policy.repo_id)
- postprocessor.push_to_hub(cfg.policy.repo_id)
+ if is_main_process:
+ logging.info("End of training")
+
+ if cfg.policy.push_to_hub:
+ unwrapped_policy = accelerator.unwrap_model(policy)
+ unwrapped_policy.push_model_to_hub(cfg)
+ preprocessor.push_to_hub(cfg.policy.repo_id)
+ postprocessor.push_to_hub(cfg.policy.repo_id)
+
+ # Properly clean up the distributed process group
+ accelerator.wait_for_everyone()
+ accelerator.end_training()
def main():
- init_logging()
train()
diff --git a/src/lerobot/teleoperators/homunculus/homunculus_arm.py b/src/lerobot/teleoperators/homunculus/homunculus_arm.py
index 21d73de2e..43116f5c0 100644
--- a/src/lerobot/teleoperators/homunculus/homunculus_arm.py
+++ b/src/lerobot/teleoperators/homunculus/homunculus_arm.py
@@ -270,8 +270,15 @@ class HomunculusArm(Teleoperator):
raw_values = None
with self.serial_lock:
if self.serial.in_waiting > 0:
- self.serial.flush()
- raw_values = self.serial.readline().decode("utf-8").strip().split(" ")
+ lines = []
+ while self.serial.in_waiting > 0:
+ line = self.serial.read_until().decode("utf-8").strip()
+ if line:
+ lines.append(line.split(" "))
+
+ if lines:
+ raw_values = lines[-1]
+
if raw_values is None or len(raw_values) != 21: # 16 raw + 5 angle values
continue
diff --git a/src/lerobot/teleoperators/homunculus/homunculus_glove.py b/src/lerobot/teleoperators/homunculus/homunculus_glove.py
index 251ecf56d..fefeec1e8 100644
--- a/src/lerobot/teleoperators/homunculus/homunculus_glove.py
+++ b/src/lerobot/teleoperators/homunculus/homunculus_glove.py
@@ -304,8 +304,15 @@ class HomunculusGlove(Teleoperator):
positions = None
with self.serial_lock:
if self.serial.in_waiting > 0:
- self.serial.flush()
- positions = self.serial.readline().decode("utf-8").strip().split(" ")
+ lines = []
+ while self.serial.in_waiting > 0:
+ line = self.serial.read_until().decode("utf-8").strip()
+ if line:
+ lines.append(line.split(" "))
+
+ if lines:
+ positions = lines[-1]
+
if positions is None or len(positions) != len(self.joints):
continue
diff --git a/src/lerobot/templates/lerobot_modelcard_template.md b/src/lerobot/templates/lerobot_modelcard_template.md
index 34af282b0..c59cf4183 100644
--- a/src/lerobot/templates/lerobot_modelcard_template.md
+++ b/src/lerobot/templates/lerobot_modelcard_template.md
@@ -19,8 +19,6 @@
[Diffusion Policy](https://huggingface.co/papers/2303.04137) treats visuomotor control as a generative diffusion process, producing smooth, multi-step action trajectories that excel at contact-rich manipulation.
{% elif model_name == "vqbet" %}
[VQ-BET](https://huggingface.co/papers/2403.03181) combines vector-quantised action tokens with Behaviour Transformers to discretise control and achieve data-efficient imitation across diverse skills.
-{% elif model_name == "pi0fast" %}
-[Pi0-Fast](https://huggingface.co/papers/2501.09747) is a variant of Pi0 that uses a new tokenization method called FAST, which enables training of an autoregressive vision-language-action policy for high-frequency robotic tasks with improved performance and reduced training time.
{% elif model_name == "pi0" %}
**π₀ (Pi0)**
diff --git a/src/lerobot/utils/control_utils.py b/src/lerobot/utils/control_utils.py
index 17371921c..7cfe177ef 100644
--- a/src/lerobot/utils/control_utils.py
+++ b/src/lerobot/utils/control_utils.py
@@ -31,6 +31,7 @@ from deepdiff import DeepDiff
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import DEFAULT_FEATURES
from lerobot.policies.pretrained import PreTrainedPolicy
+from lerobot.policies.utils import prepare_observation_for_inference
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.robots import Robot
@@ -102,17 +103,7 @@ def predict_action(
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
- for name in observation:
- observation[name] = torch.from_numpy(observation[name])
- if "image" in name:
- observation[name] = observation[name].type(torch.float32) / 255
- observation[name] = observation[name].permute(2, 0, 1).contiguous()
- observation[name] = observation[name].unsqueeze(0)
- observation[name] = observation[name].to(device)
-
- observation["task"] = task if task else ""
- observation["robot_type"] = robot_type if robot_type else ""
-
+ observation = prepare_observation_for_inference(observation, device, task, robot_type)
observation = preprocessor(observation)
# Compute the next action with the policy
@@ -121,12 +112,6 @@ def predict_action(
action = postprocessor(action)
- # Remove batch dimension
- action = action.squeeze(0)
-
- # Move to cpu, if not already the case
- action = action.to("cpu")
-
return action
diff --git a/src/lerobot/utils/logging_utils.py b/src/lerobot/utils/logging_utils.py
index b6404e66d..c4c1f42e0 100644
--- a/src/lerobot/utils/logging_utils.py
+++ b/src/lerobot/utils/logging_utils.py
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from collections.abc import Callable
from typing import Any
from lerobot.utils.utils import format_big_number
@@ -84,6 +85,7 @@ class MetricsTracker:
"samples",
"episodes",
"epochs",
+ "accelerator",
]
def __init__(
@@ -93,6 +95,7 @@ class MetricsTracker:
num_episodes: int,
metrics: dict[str, AverageMeter],
initial_step: int = 0,
+ accelerator: Callable | None = None,
):
self.__dict__.update(dict.fromkeys(self.__keys__))
self._batch_size = batch_size
@@ -106,6 +109,7 @@ class MetricsTracker:
self.samples = self.steps * self._batch_size
self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames
+ self.accelerator = accelerator
def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any:
if name in self.__dict__:
@@ -128,7 +132,7 @@ class MetricsTracker:
Updates metrics that depend on 'step' for one step.
"""
self.steps += 1
- self.samples += self._batch_size
+ self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1)
self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames
diff --git a/src/lerobot/utils/random_utils.py b/src/lerobot/utils/random_utils.py
index 1bb1f0631..b34d357aa 100644
--- a/src/lerobot/utils/random_utils.py
+++ b/src/lerobot/utils/random_utils.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
-from collections.abc import Generator
+from collections.abc import Callable, Generator
from contextlib import contextmanager
from pathlib import Path
from typing import Any
@@ -164,14 +164,20 @@ def set_rng_state(random_state_dict: dict[str, Any]):
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
-def set_seed(seed) -> None:
+def set_seed(seed, accelerator: Callable | None = None) -> None:
"""Set seed for reproducibility."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
+
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
+ if accelerator:
+ from accelerate.utils import set_seed as _accelerate_set_seed
+
+ _accelerate_set_seed(seed)
+
@contextmanager
def seeded_context(seed: int) -> Generator[None, None, None]:
diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py
index 8777d5a9d..4447a1fcf 100644
--- a/src/lerobot/utils/utils.py
+++ b/src/lerobot/utils/utils.py
@@ -27,6 +27,8 @@ from statistics import mean
import numpy as np
import torch
+from accelerate import Accelerator
+from datasets.utils.logging import disable_progress_bar, enable_progress_bar
def inside_slurm():
@@ -109,36 +111,50 @@ def init_logging(
display_pid: bool = False,
console_level: str = "INFO",
file_level: str = "DEBUG",
+ accelerator: Accelerator | None = None,
):
+ """Initialize logging configuration for LeRobot.
+
+ In multi-GPU training, only the main process logs to console to avoid duplicate output.
+ Non-main processes have console logging suppressed but can still log to file.
+
+ Args:
+ log_file: Optional file path to write logs to
+ display_pid: Include process ID in log messages (useful for debugging multi-process)
+ console_level: Logging level for console output
+ file_level: Logging level for file output
+ accelerator: Optional Accelerator instance (for multi-GPU detection)
+ """
+
def custom_format(record: logging.LogRecord) -> str:
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
fnameline = f"{record.pathname}:{record.lineno}"
-
- # NOTE: Display PID is useful for multi-process logging.
- if display_pid:
- pid_str = f"[PID: {os.getpid()}]"
- message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.getMessage()}"
- else:
- message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.getMessage()}"
- return message
+ pid_str = f"[PID: {os.getpid()}] " if display_pid else ""
+ return f"{record.levelname} {pid_str}{dt} {fnameline[-15:]:>15} {record.getMessage()}"
formatter = logging.Formatter()
formatter.format = custom_format
logger = logging.getLogger()
- logger.setLevel(logging.NOTSET) # Set the logger to the lowest level to capture all messages
+ logger.setLevel(logging.NOTSET)
- # Remove unused default handlers
- for handler in logger.handlers[:]:
- logger.removeHandler(handler)
+ # Clear any existing handlers
+ logger.handlers.clear()
- # Write logs to console
- console_handler = logging.StreamHandler()
- console_handler.setFormatter(formatter)
- console_handler.setLevel(console_level.upper())
- logger.addHandler(console_handler)
+ # Determine if this is a non-main process in distributed training
+ is_main_process = accelerator.is_main_process if accelerator is not None else True
+
+ # Console logging (main process only)
+ if is_main_process:
+ console_handler = logging.StreamHandler()
+ console_handler.setFormatter(formatter)
+ console_handler.setLevel(console_level.upper())
+ logger.addHandler(console_handler)
+ else:
+ # Suppress console output for non-main processes
+ logger.addHandler(logging.NullHandler())
+ logger.setLevel(logging.ERROR)
- # Additionally write logs to file
if log_file is not None:
file_handler = logging.FileHandler(log_file)
file_handler.setFormatter(formatter)
@@ -247,6 +263,25 @@ def get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time_s: float):
return days, hours, minutes, seconds
+class SuppressProgressBars:
+ """
+ Context manager to suppress progress bars.
+
+ Example
+ --------
+ ```python
+ with SuppressProgressBars():
+ # Code that would normally show progress bars
+ ```
+ """
+
+ def __enter__(self):
+ disable_progress_bar()
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ enable_progress_bar()
+
+
class TimerManager:
"""
Lightweight utility to measure elapsed time.
diff --git a/tests/async_inference/test_e2e.py b/tests/async_inference/test_e2e.py
index ebaef2ef1..11941ce32 100644
--- a/tests/async_inference/test_e2e.py
+++ b/tests/async_inference/test_e2e.py
@@ -139,7 +139,6 @@ def test_async_inference_e2e(monkeypatch):
policy_type="test",
pretrained_name_or_path="test",
actions_per_chunk=20,
- verify_robot_cameras=False,
)
client = RobotClient(client_config)
diff --git a/tests/async_inference/test_robot_client.py b/tests/async_inference/test_robot_client.py
index dfdb8ce42..5b138d91b 100644
--- a/tests/async_inference/test_robot_client.py
+++ b/tests/async_inference/test_robot_client.py
@@ -51,7 +51,6 @@ def robot_client():
policy_type="test",
pretrained_name_or_path="test",
actions_per_chunk=20,
- verify_robot_cameras=False,
)
client = RobotClient(test_config)
diff --git a/tests/datasets/test_aggregate.py b/tests/datasets/test_aggregate.py
index 4f316f80e..b710a3a4b 100644
--- a/tests/datasets/test_aggregate.py
+++ b/tests/datasets/test_aggregate.py
@@ -181,6 +181,54 @@ def assert_dataset_iteration_works(aggr_ds):
pass
+def assert_video_timestamps_within_bounds(aggr_ds):
+ """Test that all video timestamps are within valid bounds for their respective video files.
+
+ This catches bugs where timestamps point to frames beyond the actual video length,
+ which would cause "Invalid frame index" errors during data loading.
+ """
+ try:
+ from torchcodec.decoders import VideoDecoder
+ except ImportError:
+ return
+
+ for ep_idx in range(aggr_ds.num_episodes):
+ ep = aggr_ds.meta.episodes[ep_idx]
+
+ for vid_key in aggr_ds.meta.video_keys:
+ from_ts = ep[f"videos/{vid_key}/from_timestamp"]
+ to_ts = ep[f"videos/{vid_key}/to_timestamp"]
+ video_path = aggr_ds.root / aggr_ds.meta.get_video_file_path(ep_idx, vid_key)
+
+ if not video_path.exists():
+ continue
+
+ from_frame_idx = round(from_ts * aggr_ds.fps)
+ to_frame_idx = round(to_ts * aggr_ds.fps)
+
+ try:
+ decoder = VideoDecoder(str(video_path))
+ num_frames = len(decoder)
+
+ # Verify timestamps don't exceed video bounds
+ assert from_frame_idx >= 0, (
+ f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) < 0"
+ )
+ assert from_frame_idx < num_frames, (
+ f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) >= video frames ({num_frames})"
+ )
+ assert to_frame_idx <= num_frames, (
+ f"Episode {ep_idx}, {vid_key}: to_frame_idx ({to_frame_idx}) > video frames ({num_frames})"
+ )
+ assert from_frame_idx < to_frame_idx, (
+ f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) >= to_frame_idx ({to_frame_idx})"
+ )
+ except Exception as e:
+ raise AssertionError(
+ f"Failed to verify timestamps for episode {ep_idx}, {vid_key}: {e}"
+ ) from e
+
+
def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
"""Test basic aggregation functionality with standard parameters."""
ds_0_num_frames = 400
@@ -227,6 +275,7 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
assert_metadata_consistency(aggr_ds, ds_0, ds_1)
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
+ assert_video_timestamps_within_bounds(aggr_ds)
assert_dataset_iteration_works(aggr_ds)
@@ -277,6 +326,7 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
assert_metadata_consistency(aggr_ds, ds_0, ds_1)
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
+ assert_video_timestamps_within_bounds(aggr_ds)
assert_dataset_iteration_works(aggr_ds)
# Check that multiple files were actually created due to small size limits
@@ -290,3 +340,43 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
if video_dir.exists():
video_files = list(video_dir.rglob("*.mp4"))
assert len(video_files) > 1, "Small file size limits should create multiple video files"
+
+
+def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
+ """Regression test for video timestamp bug when merging datasets.
+
+ This test specifically checks that video timestamps are correctly calculated
+ and accumulated when merging multiple datasets.
+ """
+ datasets = []
+ for i in range(3):
+ ds = lerobot_dataset_factory(
+ root=tmp_path / f"regression_{i}",
+ repo_id=f"{DUMMY_REPO_ID}_regression_{i}",
+ total_episodes=2,
+ total_frames=100,
+ )
+ datasets.append(ds)
+
+ aggregate_datasets(
+ repo_ids=[ds.repo_id for ds in datasets],
+ roots=[ds.root for ds in datasets],
+ aggr_repo_id=f"{DUMMY_REPO_ID}_regression_aggr",
+ aggr_root=tmp_path / "regression_aggr",
+ )
+
+ with (
+ patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+ patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+ ):
+ mock_get_safe_version.return_value = "v3.0"
+ mock_snapshot_download.return_value = str(tmp_path / "regression_aggr")
+ aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_regression_aggr", root=tmp_path / "regression_aggr")
+
+ assert_video_timestamps_within_bounds(aggr_ds)
+
+ for i in range(len(aggr_ds)):
+ item = aggr_ds[i]
+ for key in aggr_ds.meta.video_keys:
+ assert key in item, f"Video key {key} missing from item {i}"
+ assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}"
diff --git a/tests/datasets/test_dataset_tools.py b/tests/datasets/test_dataset_tools.py
new file mode 100644
index 000000000..8bc1dbf6b
--- /dev/null
+++ b/tests/datasets/test_dataset_tools.py
@@ -0,0 +1,1049 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for dataset tools utilities."""
+
+from unittest.mock import patch
+
+import numpy as np
+import pytest
+import torch
+
+from lerobot.datasets.dataset_tools import (
+ add_features,
+ delete_episodes,
+ merge_datasets,
+ modify_features,
+ remove_feature,
+ split_dataset,
+)
+
+
+@pytest.fixture
+def sample_dataset(tmp_path, empty_lerobot_dataset_factory):
+ """Create a sample dataset for testing."""
+ features = {
+ "action": {"dtype": "float32", "shape": (6,), "names": None},
+ "observation.state": {"dtype": "float32", "shape": (4,), "names": None},
+ "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
+ }
+
+ dataset = empty_lerobot_dataset_factory(
+ root=tmp_path / "test_dataset",
+ features=features,
+ )
+
+ for ep_idx in range(5):
+ for _ in range(10):
+ frame = {
+ "action": np.random.randn(6).astype(np.float32),
+ "observation.state": np.random.randn(4).astype(np.float32),
+ "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
+ "task": f"task_{ep_idx % 2}",
+ }
+ dataset.add_frame(frame)
+ dataset.save_episode()
+
+ dataset.finalize()
+ return dataset
+
+
+def test_delete_single_episode(sample_dataset, tmp_path):
+ """Test deleting a single episode."""
+ output_dir = tmp_path / "filtered"
+
+ with (
+ patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+ patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+ ):
+ mock_get_safe_version.return_value = "v3.0"
+ mock_snapshot_download.return_value = str(output_dir)
+
+ new_dataset = delete_episodes(
+ sample_dataset,
+ episode_indices=[2],
+ output_dir=output_dir,
+ )
+
+ assert new_dataset.meta.total_episodes == 4
+ assert new_dataset.meta.total_frames == 40
+
+ episode_indices = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}
+ assert episode_indices == {0, 1, 2, 3}
+
+ assert len(new_dataset) == 40
+
+
+def test_delete_multiple_episodes(sample_dataset, tmp_path):
+ """Test deleting multiple episodes."""
+ output_dir = tmp_path / "filtered"
+
+ with (
+ patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+ patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+ ):
+ mock_get_safe_version.return_value = "v3.0"
+ mock_snapshot_download.return_value = str(output_dir)
+
+ new_dataset = delete_episodes(
+ sample_dataset,
+ episode_indices=[1, 3],
+ output_dir=output_dir,
+ )
+
+ assert new_dataset.meta.total_episodes == 3
+ assert new_dataset.meta.total_frames == 30
+
+ episode_indices = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}
+ assert episode_indices == {0, 1, 2}
+
+
+def test_delete_invalid_episodes(sample_dataset, tmp_path):
+ """Test error handling for invalid episode indices."""
+ with pytest.raises(ValueError, match="Invalid episode indices"):
+ delete_episodes(
+ sample_dataset,
+ episode_indices=[10, 20],
+ output_dir=tmp_path / "filtered",
+ )
+
+
+def test_delete_all_episodes(sample_dataset, tmp_path):
+ """Test error when trying to delete all episodes."""
+ with pytest.raises(ValueError, match="Cannot delete all episodes"):
+ delete_episodes(
+ sample_dataset,
+ episode_indices=list(range(5)),
+ output_dir=tmp_path / "filtered",
+ )
+
+
+def test_delete_empty_list(sample_dataset, tmp_path):
+ """Test error when no episodes specified."""
+ with pytest.raises(ValueError, match="No episodes to delete"):
+ delete_episodes(
+ sample_dataset,
+ episode_indices=[],
+ output_dir=tmp_path / "filtered",
+ )
+
+
+def test_split_by_episodes(sample_dataset, tmp_path):
+ """Test splitting dataset by specific episode indices."""
+ splits = {
+ "train": [0, 1, 2],
+ "val": [3, 4],
+ }
+
+ with (
+ patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+ patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+ ):
+ mock_get_safe_version.return_value = "v3.0"
+
+ def mock_snapshot(repo_id, **kwargs):
+ if "train" in repo_id:
+ return str(tmp_path / f"{sample_dataset.repo_id}_train")
+ elif "val" in repo_id:
+ return str(tmp_path / f"{sample_dataset.repo_id}_val")
+ return str(kwargs.get("local_dir", tmp_path))
+
+ mock_snapshot_download.side_effect = mock_snapshot
+
+ result = split_dataset(
+ sample_dataset,
+ splits=splits,
+ output_dir=tmp_path,
+ )
+
+ assert set(result.keys()) == {"train", "val"}
+
+ assert result["train"].meta.total_episodes == 3
+ assert result["train"].meta.total_frames == 30
+
+ assert result["val"].meta.total_episodes == 2
+ assert result["val"].meta.total_frames == 20
+
+ train_episodes = {int(idx.item()) for idx in result["train"].hf_dataset["episode_index"]}
+ assert train_episodes == {0, 1, 2}
+
+ val_episodes = {int(idx.item()) for idx in result["val"].hf_dataset["episode_index"]}
+ assert val_episodes == {0, 1}
+
+
+def test_split_by_fractions(sample_dataset, tmp_path):
+ """Test splitting dataset by fractions."""
+ splits = {
+ "train": 0.6,
+ "val": 0.4,
+ }
+
+ with (
+ patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+ patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+ ):
+ mock_get_safe_version.return_value = "v3.0"
+
+ def mock_snapshot(repo_id, **kwargs):
+ for split_name in splits:
+ if split_name in repo_id:
+ return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}")
+ return str(kwargs.get("local_dir", tmp_path))
+
+ mock_snapshot_download.side_effect = mock_snapshot
+
+ result = split_dataset(
+ sample_dataset,
+ splits=splits,
+ output_dir=tmp_path,
+ )
+
+ assert result["train"].meta.total_episodes == 3
+ assert result["val"].meta.total_episodes == 2
+
+
+def test_split_overlapping_episodes(sample_dataset, tmp_path):
+ """Test error when episodes appear in multiple splits."""
+ splits = {
+ "train": [0, 1, 2],
+ "val": [2, 3, 4],
+ }
+
+ with pytest.raises(ValueError, match="Episodes cannot appear in multiple splits"):
+ split_dataset(sample_dataset, splits=splits, output_dir=tmp_path)
+
+
+def test_split_invalid_fractions(sample_dataset, tmp_path):
+ """Test error when fractions sum to more than 1."""
+ splits = {
+ "train": 0.7,
+ "val": 0.5,
+ }
+
+ with pytest.raises(ValueError, match="Split fractions must sum to <= 1.0"):
+ split_dataset(sample_dataset, splits=splits, output_dir=tmp_path)
+
+
+def test_split_empty(sample_dataset, tmp_path):
+ """Test error with empty splits."""
+ with pytest.raises(ValueError, match="No splits provided"):
+ split_dataset(sample_dataset, splits={}, output_dir=tmp_path)
+
+
+def test_merge_two_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_factory):
+ """Test merging two datasets."""
+ features = {
+ "action": {"dtype": "float32", "shape": (6,), "names": None},
+ "observation.state": {"dtype": "float32", "shape": (4,), "names": None},
+ "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
+ }
+
+ dataset2 = empty_lerobot_dataset_factory(
+ root=tmp_path / "test_dataset2",
+ features=features,
+ )
+
+ for ep_idx in range(3):
+ for _ in range(10):
+ frame = {
+ "action": np.random.randn(6).astype(np.float32),
+ "observation.state": np.random.randn(4).astype(np.float32),
+ "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
+ "task": f"task_{ep_idx % 2}",
+ }
+ dataset2.add_frame(frame)
+ dataset2.save_episode()
+ dataset2.finalize()
+
+ with (
+ patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+ patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+ ):
+ mock_get_safe_version.return_value = "v3.0"
+ mock_snapshot_download.return_value = str(tmp_path / "merged_dataset")
+
+ merged = merge_datasets(
+ [sample_dataset, dataset2],
+ output_repo_id="merged_dataset",
+ output_dir=tmp_path / "merged_dataset",
+ )
+
+ assert merged.meta.total_episodes == 8 # 5 + 3
+ assert merged.meta.total_frames == 80 # 50 + 30
+
+ episode_indices = sorted({int(idx.item()) for idx in merged.hf_dataset["episode_index"]})
+ assert episode_indices == list(range(8))
+
+
+def test_merge_empty_list(tmp_path):
+ """Test error when merging empty list."""
+ with pytest.raises(ValueError, match="No datasets to merge"):
+ merge_datasets([], output_repo_id="merged", output_dir=tmp_path)
+
+
+def test_add_features_with_values(sample_dataset, tmp_path):
+ """Test adding a feature with pre-computed values."""
+ num_frames = sample_dataset.meta.total_frames
+ reward_values = np.random.randn(num_frames, 1).astype(np.float32)
+
+ feature_info = {
+ "dtype": "float32",
+ "shape": (1,),
+ "names": None,
+ }
+ features = {
+ "reward": (reward_values, feature_info),
+ }
+
+ with (
+ patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+ patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+ ):
+ mock_get_safe_version.return_value = "v3.0"
+ mock_snapshot_download.return_value = str(tmp_path / "with_reward")
+
+ new_dataset = add_features(
+ dataset=sample_dataset,
+ features=features,
+ output_dir=tmp_path / "with_reward",
+ )
+
+ assert "reward" in new_dataset.meta.features
+ assert new_dataset.meta.features["reward"] == feature_info
+
+ assert len(new_dataset) == num_frames
+ sample_item = new_dataset[0]
+ assert "reward" in sample_item
+ assert isinstance(sample_item["reward"], torch.Tensor)
+
+
+def test_add_features_with_callable(sample_dataset, tmp_path):
+ """Test adding a feature with a callable."""
+
+ def compute_reward(frame_dict, episode_idx, frame_idx):
+ return float(episode_idx * 10 + frame_idx)
+
+ feature_info = {
+ "dtype": "float32",
+ "shape": (1,),
+ "names": None,
+ }
+ features = {
+ "reward": (compute_reward, feature_info),
+ }
+ with (
+ patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+ patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+ ):
+ mock_get_safe_version.return_value = "v3.0"
+ mock_snapshot_download.return_value = str(tmp_path / "with_reward")
+
+ new_dataset = add_features(
+ dataset=sample_dataset,
+ features=features,
+ output_dir=tmp_path / "with_reward",
+ )
+
+ assert "reward" in new_dataset.meta.features
+
+ items = [new_dataset[i] for i in range(10)]
+ first_episode_items = [item for item in items if item["episode_index"] == 0]
+ assert len(first_episode_items) == 10
+
+ first_frame = first_episode_items[0]
+ assert first_frame["frame_index"] == 0
+ assert float(first_frame["reward"]) == 0.0
+
+
+def test_add_existing_feature(sample_dataset, tmp_path):
+ """Test error when adding an existing feature."""
+ feature_info = {"dtype": "float32", "shape": (1,)}
+ features = {
+ "action": (np.zeros(50), feature_info),
+ }
+
+ with pytest.raises(ValueError, match="Feature 'action' already exists"):
+ add_features(
+ dataset=sample_dataset,
+ features=features,
+ output_dir=tmp_path / "modified",
+ )
+
+
+def test_add_feature_invalid_info(sample_dataset, tmp_path):
+ """Test error with invalid feature info."""
+ with pytest.raises(ValueError, match="feature_info for 'reward' must contain keys"):
+ add_features(
+ dataset=sample_dataset,
+ features={
+ "reward": (np.zeros(50), {"dtype": "float32"}),
+ },
+ output_dir=tmp_path / "modified",
+ )
+
+
+def test_modify_features_add_and_remove(sample_dataset, tmp_path):
+ """Test modifying features by adding and removing simultaneously."""
+ feature_info = {"dtype": "float32", "shape": (1,), "names": None}
+
+ with (
+ patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+ patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+ ):
+ mock_get_safe_version.return_value = "v3.0"
+ mock_snapshot_download.return_value = str(tmp_path / "modified")
+
+ # First add a feature we'll later remove
+ dataset_with_reward = add_features(
+ sample_dataset,
+ features={"reward": (np.random.randn(50, 1).astype(np.float32), feature_info)},
+ output_dir=tmp_path / "with_reward",
+ )
+
+ # Now use modify_features to add "success" and remove "reward" in one pass
+ modified_dataset = modify_features(
+ dataset_with_reward,
+ add_features={
+ "success": (np.random.randn(50, 1).astype(np.float32), feature_info),
+ },
+ remove_features="reward",
+ output_dir=tmp_path / "modified",
+ )
+
+ assert "success" in modified_dataset.meta.features
+ assert "reward" not in modified_dataset.meta.features
+ assert len(modified_dataset) == 50
+
+
+def test_modify_features_only_add(sample_dataset, tmp_path):
+ """Test that modify_features works with only add_features."""
+ feature_info = {"dtype": "float32", "shape": (1,), "names": None}
+
+ with (
+ patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+ patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+ ):
+ mock_get_safe_version.return_value = "v3.0"
+ mock_snapshot_download.return_value = str(tmp_path / "modified")
+
+ modified_dataset = modify_features(
+ sample_dataset,
+ add_features={
+ "reward": (np.random.randn(50, 1).astype(np.float32), feature_info),
+ },
+ output_dir=tmp_path / "modified",
+ )
+
+ assert "reward" in modified_dataset.meta.features
+ assert len(modified_dataset) == 50
+
+
+def test_modify_features_only_remove(sample_dataset, tmp_path):
+ """Test that modify_features works with only remove_features."""
+ feature_info = {"dtype": "float32", "shape": (1,), "names": None}
+
+ with (
+ patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+ patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+ ):
+ mock_get_safe_version.return_value = "v3.0"
+ mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
+
+ dataset_with_reward = add_features(
+ sample_dataset,
+ features={"reward": (np.random.randn(50, 1).astype(np.float32), feature_info)},
+ output_dir=tmp_path / "with_reward",
+ )
+
+ modified_dataset = modify_features(
+ dataset_with_reward,
+ remove_features="reward",
+ output_dir=tmp_path / "modified",
+ )
+
+ assert "reward" not in modified_dataset.meta.features
+
+
+def test_modify_features_no_changes(sample_dataset, tmp_path):
+ """Test error when modify_features is called with no changes."""
+ with pytest.raises(ValueError, match="Must specify at least one of add_features or remove_features"):
+ modify_features(
+ sample_dataset,
+ output_dir=tmp_path / "modified",
+ )
+
+
+def test_remove_single_feature(sample_dataset, tmp_path):
+ """Test removing a single feature."""
+ feature_info = {"dtype": "float32", "shape": (1,), "names": None}
+ features = {
+ "reward": (np.random.randn(50, 1).astype(np.float32), feature_info),
+ }
+ with (
+ patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+ patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+ ):
+ mock_get_safe_version.return_value = "v3.0"
+ mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
+
+ dataset_with_reward = add_features(
+ dataset=sample_dataset,
+ features=features,
+ output_dir=tmp_path / "with_reward",
+ )
+
+ dataset_without_reward = remove_feature(
+ dataset_with_reward,
+ feature_names="reward",
+ output_dir=tmp_path / "without_reward",
+ )
+
+ assert "reward" not in dataset_without_reward.meta.features
+
+ sample_item = dataset_without_reward[0]
+ assert "reward" not in sample_item
+
+
+def test_remove_multiple_features(sample_dataset, tmp_path):
+ """Test removing multiple features at once."""
+ with (
+ patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+ patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+ ):
+ mock_get_safe_version.return_value = "v3.0"
+ mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
+
+ dataset = sample_dataset
+ features = {}
+ for feature_name in ["reward", "success"]:
+ feature_info = {"dtype": "float32", "shape": (1,), "names": None}
+ features[feature_name] = (
+ np.random.randn(dataset.meta.total_frames, 1).astype(np.float32),
+ feature_info,
+ )
+
+ dataset_with_features = add_features(
+ dataset, features=features, output_dir=tmp_path / "with_features"
+ )
+ dataset_clean = remove_feature(
+ dataset_with_features, feature_names=["reward", "success"], output_dir=tmp_path / "clean"
+ )
+
+ assert "reward" not in dataset_clean.meta.features
+ assert "success" not in dataset_clean.meta.features
+
+
+def test_remove_nonexistent_feature(sample_dataset, tmp_path):
+ """Test error when removing non-existent feature."""
+ with pytest.raises(ValueError, match="Feature 'nonexistent' not found"):
+ remove_feature(
+ sample_dataset,
+ feature_names="nonexistent",
+ output_dir=tmp_path / "modified",
+ )
+
+
+def test_remove_required_feature(sample_dataset, tmp_path):
+ """Test error when trying to remove required features."""
+ with pytest.raises(ValueError, match="Cannot remove required features"):
+ remove_feature(
+ sample_dataset,
+ feature_names="timestamp",
+ output_dir=tmp_path / "modified",
+ )
+
+
+def test_remove_camera_feature(sample_dataset, tmp_path):
+ """Test removing a camera feature."""
+ camera_keys = sample_dataset.meta.camera_keys
+ if not camera_keys:
+ pytest.skip("No camera keys in dataset")
+
+ camera_to_remove = camera_keys[0]
+
+ with (
+ patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+ patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+ ):
+ mock_get_safe_version.return_value = "v3.0"
+ mock_snapshot_download.return_value = str(tmp_path / "without_camera")
+
+ dataset_without_camera = remove_feature(
+ sample_dataset,
+ feature_names=camera_to_remove,
+ output_dir=tmp_path / "without_camera",
+ )
+
+ assert camera_to_remove not in dataset_without_camera.meta.features
+ assert camera_to_remove not in dataset_without_camera.meta.camera_keys
+
+ sample_item = dataset_without_camera[0]
+ assert camera_to_remove not in sample_item
+
+
+def test_complex_workflow_integration(sample_dataset, tmp_path):
+ """Test a complex workflow combining multiple operations."""
+ with (
+ patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+ patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+ ):
+ mock_get_safe_version.return_value = "v3.0"
+ mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
+
+ dataset = add_features(
+ sample_dataset,
+ features={
+ "reward": (
+ np.random.randn(50, 1).astype(np.float32),
+ {"dtype": "float32", "shape": (1,), "names": None},
+ )
+ },
+ output_dir=tmp_path / "step1",
+ )
+
+ dataset = delete_episodes(
+ dataset,
+ episode_indices=[2],
+ output_dir=tmp_path / "step2",
+ )
+
+ splits = split_dataset(
+ dataset,
+ splits={"train": 0.75, "val": 0.25},
+ output_dir=tmp_path / "step3",
+ )
+
+ merged = merge_datasets(
+ list(splits.values()),
+ output_repo_id="final_dataset",
+ output_dir=tmp_path / "step4",
+ )
+
+ assert merged.meta.total_episodes == 4
+ assert merged.meta.total_frames == 40
+ assert "reward" in merged.meta.features
+
+ assert len(merged) == 40
+ sample_item = merged[0]
+ assert "reward" in sample_item
+
+
+def test_delete_episodes_preserves_stats(sample_dataset, tmp_path):
+ """Test that deleting episodes preserves statistics correctly."""
+ output_dir = tmp_path / "filtered"
+
+ with (
+ patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+ patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+ ):
+ mock_get_safe_version.return_value = "v3.0"
+ mock_snapshot_download.return_value = str(output_dir)
+
+ new_dataset = delete_episodes(
+ sample_dataset,
+ episode_indices=[2],
+ output_dir=output_dir,
+ )
+
+ assert new_dataset.meta.stats is not None
+ for feature in ["action", "observation.state"]:
+ assert feature in new_dataset.meta.stats
+ assert "mean" in new_dataset.meta.stats[feature]
+ assert "std" in new_dataset.meta.stats[feature]
+
+
+def test_delete_episodes_preserves_tasks(sample_dataset, tmp_path):
+ """Test that tasks are preserved correctly after deletion."""
+ output_dir = tmp_path / "filtered"
+
+ with (
+ patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+ patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+ ):
+ mock_get_safe_version.return_value = "v3.0"
+ mock_snapshot_download.return_value = str(output_dir)
+
+ new_dataset = delete_episodes(
+ sample_dataset,
+ episode_indices=[0],
+ output_dir=output_dir,
+ )
+
+ assert new_dataset.meta.tasks is not None
+ assert len(new_dataset.meta.tasks) == 2
+
+ tasks_in_dataset = {str(item["task"]) for item in new_dataset}
+ assert len(tasks_in_dataset) > 0
+
+
+def test_split_three_ways(sample_dataset, tmp_path):
+ """Test splitting dataset into three splits."""
+ splits = {
+ "train": 0.6,
+ "val": 0.2,
+ "test": 0.2,
+ }
+
+ with (
+ patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+ patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+ ):
+ mock_get_safe_version.return_value = "v3.0"
+
+ def mock_snapshot(repo_id, **kwargs):
+ for split_name in splits:
+ if split_name in repo_id:
+ return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}")
+ return str(kwargs.get("local_dir", tmp_path))
+
+ mock_snapshot_download.side_effect = mock_snapshot
+
+ result = split_dataset(
+ sample_dataset,
+ splits=splits,
+ output_dir=tmp_path,
+ )
+
+ assert set(result.keys()) == {"train", "val", "test"}
+ assert result["train"].meta.total_episodes == 3
+ assert result["val"].meta.total_episodes == 1
+ assert result["test"].meta.total_episodes == 1
+
+ total_frames = sum(ds.meta.total_frames for ds in result.values())
+ assert total_frames == sample_dataset.meta.total_frames
+
+
+def test_split_preserves_stats(sample_dataset, tmp_path):
+ """Test that statistics are preserved when splitting."""
+ splits = {"train": [0, 1, 2], "val": [3, 4]}
+
+ with (
+ patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+ patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+ ):
+ mock_get_safe_version.return_value = "v3.0"
+
+ def mock_snapshot(repo_id, **kwargs):
+ for split_name in splits:
+ if split_name in repo_id:
+ return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}")
+ return str(kwargs.get("local_dir", tmp_path))
+
+ mock_snapshot_download.side_effect = mock_snapshot
+
+ result = split_dataset(
+ sample_dataset,
+ splits=splits,
+ output_dir=tmp_path,
+ )
+
+ for split_ds in result.values():
+ assert split_ds.meta.stats is not None
+ for feature in ["action", "observation.state"]:
+ assert feature in split_ds.meta.stats
+ assert "mean" in split_ds.meta.stats[feature]
+ assert "std" in split_ds.meta.stats[feature]
+
+
+def test_merge_three_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_factory):
+ """Test merging three datasets."""
+ features = {
+ "action": {"dtype": "float32", "shape": (6,), "names": None},
+ "observation.state": {"dtype": "float32", "shape": (4,), "names": None},
+ "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
+ }
+
+ datasets = [sample_dataset]
+
+ for i in range(2):
+ dataset = empty_lerobot_dataset_factory(
+ root=tmp_path / f"test_dataset{i + 2}",
+ features=features,
+ )
+
+ for ep_idx in range(2):
+ for _ in range(10):
+ frame = {
+ "action": np.random.randn(6).astype(np.float32),
+ "observation.state": np.random.randn(4).astype(np.float32),
+ "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
+ "task": f"task_{ep_idx}",
+ }
+ dataset.add_frame(frame)
+ dataset.save_episode()
+ dataset.finalize()
+
+ datasets.append(dataset)
+
+ with (
+ patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+ patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+ ):
+ mock_get_safe_version.return_value = "v3.0"
+ mock_snapshot_download.return_value = str(tmp_path / "merged_dataset")
+
+ merged = merge_datasets(
+ datasets,
+ output_repo_id="merged_dataset",
+ output_dir=tmp_path / "merged_dataset",
+ )
+
+ assert merged.meta.total_episodes == 9
+ assert merged.meta.total_frames == 90
+
+
+def test_merge_preserves_stats(sample_dataset, tmp_path, empty_lerobot_dataset_factory):
+ """Test that statistics are computed for merged datasets."""
+ features = {
+ "action": {"dtype": "float32", "shape": (6,), "names": None},
+ "observation.state": {"dtype": "float32", "shape": (4,), "names": None},
+ "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
+ }
+
+ dataset2 = empty_lerobot_dataset_factory(
+ root=tmp_path / "test_dataset2",
+ features=features,
+ )
+
+ for ep_idx in range(3):
+ for _ in range(10):
+ frame = {
+ "action": np.random.randn(6).astype(np.float32),
+ "observation.state": np.random.randn(4).astype(np.float32),
+ "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
+ "task": f"task_{ep_idx % 2}",
+ }
+ dataset2.add_frame(frame)
+ dataset2.save_episode()
+ dataset2.finalize()
+
+ with (
+ patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+ patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+ ):
+ mock_get_safe_version.return_value = "v3.0"
+ mock_snapshot_download.return_value = str(tmp_path / "merged_dataset")
+
+ merged = merge_datasets(
+ [sample_dataset, dataset2],
+ output_repo_id="merged_dataset",
+ output_dir=tmp_path / "merged_dataset",
+ )
+
+ assert merged.meta.stats is not None
+ for feature in ["action", "observation.state"]:
+ assert feature in merged.meta.stats
+ assert "mean" in merged.meta.stats[feature]
+ assert "std" in merged.meta.stats[feature]
+
+
+def test_add_features_preserves_existing_stats(sample_dataset, tmp_path):
+ """Test that adding a feature preserves existing stats."""
+ num_frames = sample_dataset.meta.total_frames
+ reward_values = np.random.randn(num_frames, 1).astype(np.float32)
+
+ feature_info = {
+ "dtype": "float32",
+ "shape": (1,),
+ "names": None,
+ }
+ features = {
+ "reward": (reward_values, feature_info),
+ }
+
+ with (
+ patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+ patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+ ):
+ mock_get_safe_version.return_value = "v3.0"
+ mock_snapshot_download.return_value = str(tmp_path / "with_reward")
+
+ new_dataset = add_features(
+ dataset=sample_dataset,
+ features=features,
+ output_dir=tmp_path / "with_reward",
+ )
+
+ assert new_dataset.meta.stats is not None
+ for feature in ["action", "observation.state"]:
+ assert feature in new_dataset.meta.stats
+ assert "mean" in new_dataset.meta.stats[feature]
+ assert "std" in new_dataset.meta.stats[feature]
+
+
+def test_remove_feature_updates_stats(sample_dataset, tmp_path):
+ """Test that removing a feature removes it from stats."""
+ feature_info = {"dtype": "float32", "shape": (1,), "names": None}
+
+ with (
+ patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+ patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+ ):
+ mock_get_safe_version.return_value = "v3.0"
+ mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
+
+ dataset_with_reward = add_features(
+ sample_dataset,
+ features={
+ "reward": (np.random.randn(50, 1).astype(np.float32), feature_info),
+ },
+ output_dir=tmp_path / "with_reward",
+ )
+
+ dataset_without_reward = remove_feature(
+ dataset_with_reward,
+ feature_names="reward",
+ output_dir=tmp_path / "without_reward",
+ )
+
+ if dataset_without_reward.meta.stats:
+ assert "reward" not in dataset_without_reward.meta.stats
+
+
+def test_delete_consecutive_episodes(sample_dataset, tmp_path):
+ """Test deleting consecutive episodes."""
+ output_dir = tmp_path / "filtered"
+
+ with (
+ patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+ patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+ ):
+ mock_get_safe_version.return_value = "v3.0"
+ mock_snapshot_download.return_value = str(output_dir)
+
+ new_dataset = delete_episodes(
+ sample_dataset,
+ episode_indices=[1, 2, 3],
+ output_dir=output_dir,
+ )
+
+ assert new_dataset.meta.total_episodes == 2
+ assert new_dataset.meta.total_frames == 20
+
+ episode_indices = sorted({int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]})
+ assert episode_indices == [0, 1]
+
+
+def test_delete_first_and_last_episodes(sample_dataset, tmp_path):
+ """Test deleting first and last episodes."""
+ output_dir = tmp_path / "filtered"
+
+ with (
+ patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+ patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+ ):
+ mock_get_safe_version.return_value = "v3.0"
+ mock_snapshot_download.return_value = str(output_dir)
+
+ new_dataset = delete_episodes(
+ sample_dataset,
+ episode_indices=[0, 4],
+ output_dir=output_dir,
+ )
+
+ assert new_dataset.meta.total_episodes == 3
+ assert new_dataset.meta.total_frames == 30
+
+ episode_indices = sorted({int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]})
+ assert episode_indices == [0, 1, 2]
+
+
+def test_split_all_episodes_assigned(sample_dataset, tmp_path):
+ """Test that all episodes can be explicitly assigned to splits."""
+ splits = {
+ "split1": [0, 1],
+ "split2": [2, 3],
+ "split3": [4],
+ }
+
+ with (
+ patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+ patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+ ):
+ mock_get_safe_version.return_value = "v3.0"
+
+ def mock_snapshot(repo_id, **kwargs):
+ for split_name in splits:
+ if split_name in repo_id:
+ return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}")
+ return str(kwargs.get("local_dir", tmp_path))
+
+ mock_snapshot_download.side_effect = mock_snapshot
+
+ result = split_dataset(
+ sample_dataset,
+ splits=splits,
+ output_dir=tmp_path,
+ )
+
+ total_episodes = sum(ds.meta.total_episodes for ds in result.values())
+ assert total_episodes == sample_dataset.meta.total_episodes
+
+
+def test_modify_features_preserves_file_structure(sample_dataset, tmp_path):
+ """Test that modifying features preserves chunk_idx and file_idx from source dataset."""
+ feature_info = {"dtype": "float32", "shape": (1,), "names": None}
+
+ with (
+ patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+ patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+ ):
+ mock_get_safe_version.return_value = "v3.0"
+
+ def mock_snapshot(repo_id, **kwargs):
+ return str(kwargs.get("local_dir", tmp_path / repo_id.split("/")[-1]))
+
+ mock_snapshot_download.side_effect = mock_snapshot
+
+ # First split the dataset to create a non-zero starting chunk/file structure
+ splits = split_dataset(
+ sample_dataset,
+ splits={"train": [0, 1, 2], "val": [3, 4]},
+ output_dir=tmp_path / "splits",
+ )
+
+ train_dataset = splits["train"]
+
+ # Get original chunk/file indices from first episode
+ if train_dataset.meta.episodes is None:
+ from lerobot.datasets.utils import load_episodes
+
+ train_dataset.meta.episodes = load_episodes(train_dataset.meta.root)
+ original_chunk_indices = [ep["data/chunk_index"] for ep in train_dataset.meta.episodes]
+ original_file_indices = [ep["data/file_index"] for ep in train_dataset.meta.episodes]
+
+ # Now add a feature to the split dataset
+ modified_dataset = add_features(
+ train_dataset,
+ features={
+ "reward": (
+ np.random.randn(train_dataset.meta.total_frames, 1).astype(np.float32),
+ feature_info,
+ ),
+ },
+ output_dir=tmp_path / "modified",
+ )
+
+ # Check that chunk/file indices are preserved
+ if modified_dataset.meta.episodes is None:
+ from lerobot.datasets.utils import load_episodes
+
+ modified_dataset.meta.episodes = load_episodes(modified_dataset.meta.root)
+ new_chunk_indices = [ep["data/chunk_index"] for ep in modified_dataset.meta.episodes]
+ new_file_indices = [ep["data/file_index"] for ep in modified_dataset.meta.episodes]
+
+ assert new_chunk_indices == original_chunk_indices, "Chunk indices should be preserved"
+ assert new_file_indices == original_file_indices, "File indices should be preserved"
+ assert "reward" in modified_dataset.meta.features
diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py
index 2bc3bea43..e174c5789 100644
--- a/tests/datasets/test_datasets.py
+++ b/tests/datasets/test_datasets.py
@@ -806,6 +806,8 @@ def test_episode_index_distribution(tmp_path, empty_lerobot_dataset_factory):
dataset.add_frame({"state": torch.randn(2), "task": f"task_{episode_idx}"})
dataset.save_episode()
+ dataset.finalize()
+
# Load the dataset and check episode indices
loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
@@ -855,6 +857,8 @@ def test_multi_episode_metadata_consistency(tmp_path, empty_lerobot_dataset_fact
dataset.add_frame({"state": torch.randn(3), ACTION: torch.randn(2), "task": tasks[episode_idx]})
dataset.save_episode()
+ dataset.finalize()
+
# Load and validate episode metadata
loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
@@ -893,6 +897,8 @@ def test_data_consistency_across_episodes(tmp_path, empty_lerobot_dataset_factor
dataset.add_frame({"state": torch.randn(1), "task": "consistency_test"})
dataset.save_episode()
+ dataset.finalize()
+
loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
# Check data consistency - no gaps or overlaps
@@ -944,6 +950,8 @@ def test_statistics_metadata_validation(tmp_path, empty_lerobot_dataset_factory)
dataset.add_frame({"state": state_data, ACTION: action_data, "task": "stats_test"})
dataset.save_episode()
+ dataset.finalize()
+
loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
# Check that statistics exist for all features
@@ -989,6 +997,8 @@ def test_episode_boundary_integrity(tmp_path, empty_lerobot_dataset_factory):
dataset.add_frame({"state": torch.tensor([float(frame_idx)]), "task": f"episode_{episode_idx}"})
dataset.save_episode()
+ dataset.finalize()
+
loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
# Test episode boundaries
@@ -1031,6 +1041,8 @@ def test_task_indexing_and_validation(tmp_path, empty_lerobot_dataset_factory):
dataset.add_frame({"state": torch.randn(1), "task": task})
dataset.save_episode()
+ dataset.finalize()
+
loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
# Check that all unique tasks are in the tasks metadata
@@ -1056,3 +1068,134 @@ def test_task_indexing_and_validation(tmp_path, empty_lerobot_dataset_factory):
# Check total number of tasks
assert loaded_dataset.meta.total_tasks == len(unique_tasks)
+
+
+def test_dataset_resume_recording(tmp_path, empty_lerobot_dataset_factory):
+ """Test that resuming dataset recording preserves previously recorded episodes.
+
+ This test validates the critical resume functionality by:
+ 1. Recording initial episodes and finalizing
+ 2. Reopening the dataset
+ 3. Recording additional episodes
+ 4. Verifying all data (old + new) is intact
+
+ This specifically tests the bug fix where parquet files were being overwritten
+ instead of appended to during resume.
+ """
+ features = {
+ "observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]},
+ "action": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]},
+ }
+
+ dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
+
+ initial_episodes = 2
+ frames_per_episode = 3
+
+ for ep_idx in range(initial_episodes):
+ for frame_idx in range(frames_per_episode):
+ dataset.add_frame(
+ {
+ "observation.state": torch.tensor([float(ep_idx), float(frame_idx)]),
+ "action": torch.tensor([0.5, 0.5]),
+ "task": f"task_{ep_idx}",
+ }
+ )
+ dataset.save_episode()
+
+ assert dataset.meta.total_episodes == initial_episodes
+ assert dataset.meta.total_frames == initial_episodes * frames_per_episode
+
+ dataset.finalize()
+ initial_root = dataset.root
+ initial_repo_id = dataset.repo_id
+ del dataset
+
+ dataset_verify = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0")
+ assert dataset_verify.meta.total_episodes == initial_episodes
+ assert dataset_verify.meta.total_frames == initial_episodes * frames_per_episode
+ assert len(dataset_verify.hf_dataset) == initial_episodes * frames_per_episode
+
+ for idx in range(len(dataset_verify.hf_dataset)):
+ item = dataset_verify[idx]
+ expected_ep = idx // frames_per_episode
+ expected_frame = idx % frames_per_episode
+ assert item["episode_index"].item() == expected_ep
+ assert item["frame_index"].item() == expected_frame
+ assert item["index"].item() == idx
+ assert item["observation.state"][0].item() == float(expected_ep)
+ assert item["observation.state"][1].item() == float(expected_frame)
+
+ del dataset_verify
+
+ # Phase 3: Resume recording - add more episodes
+ dataset_resumed = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0")
+
+ assert dataset_resumed.meta.total_episodes == initial_episodes
+ assert dataset_resumed.meta.total_frames == initial_episodes * frames_per_episode
+ assert dataset_resumed.latest_episode is None # Not recording yet
+ assert dataset_resumed.writer is None
+ assert dataset_resumed.meta.writer is None
+
+ additional_episodes = 2
+ for ep_idx in range(initial_episodes, initial_episodes + additional_episodes):
+ for frame_idx in range(frames_per_episode):
+ dataset_resumed.add_frame(
+ {
+ "observation.state": torch.tensor([float(ep_idx), float(frame_idx)]),
+ "action": torch.tensor([0.5, 0.5]),
+ "task": f"task_{ep_idx}",
+ }
+ )
+ dataset_resumed.save_episode()
+
+ total_episodes = initial_episodes + additional_episodes
+ total_frames = total_episodes * frames_per_episode
+ assert dataset_resumed.meta.total_episodes == total_episodes
+ assert dataset_resumed.meta.total_frames == total_frames
+
+ dataset_resumed.finalize()
+ del dataset_resumed
+
+ dataset_final = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0")
+
+ assert dataset_final.meta.total_episodes == total_episodes
+ assert dataset_final.meta.total_frames == total_frames
+ assert len(dataset_final.hf_dataset) == total_frames
+
+ for idx in range(total_frames):
+ item = dataset_final[idx]
+ expected_ep = idx // frames_per_episode
+ expected_frame = idx % frames_per_episode
+
+ assert item["episode_index"].item() == expected_ep, (
+ f"Frame {idx}: wrong episode_index. Expected {expected_ep}, got {item['episode_index'].item()}"
+ )
+ assert item["frame_index"].item() == expected_frame, (
+ f"Frame {idx}: wrong frame_index. Expected {expected_frame}, got {item['frame_index'].item()}"
+ )
+ assert item["index"].item() == idx, (
+ f"Frame {idx}: wrong index. Expected {idx}, got {item['index'].item()}"
+ )
+
+ # Verify data integrity
+ assert item["observation.state"][0].item() == float(expected_ep), (
+ f"Frame {idx}: wrong observation.state[0]. Expected {float(expected_ep)}, "
+ f"got {item['observation.state'][0].item()}"
+ )
+ assert item["observation.state"][1].item() == float(expected_frame), (
+ f"Frame {idx}: wrong observation.state[1]. Expected {float(expected_frame)}, "
+ f"got {item['observation.state'][1].item()}"
+ )
+
+ assert len(dataset_final.meta.episodes) == total_episodes
+ for ep_idx in range(total_episodes):
+ ep_metadata = dataset_final.meta.episodes[ep_idx]
+ assert ep_metadata["episode_index"] == ep_idx
+ assert ep_metadata["length"] == frames_per_episode
+ assert ep_metadata["tasks"] == [f"task_{ep_idx}"]
+
+ expected_from = ep_idx * frames_per_episode
+ expected_to = (ep_idx + 1) * frames_per_episode
+ assert ep_metadata["dataset_from_index"] == expected_from
+ assert ep_metadata["dataset_to_index"] == expected_to
diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py
index 34fa89390..345526d90 100644
--- a/tests/policies/test_policies.py
+++ b/tests/policies/test_policies.py
@@ -95,7 +95,6 @@ def test_get_policy_and_config_classes(policy_name: str):
@pytest.mark.parametrize(
"ds_repo_id,env_name,env_kwargs,policy_name,policy_kwargs",
[
- ("lerobot/xarm_lift_medium", "xarm", {}, "tdmpc", {"use_mpc": True}),
("lerobot/pusht", "pusht", {}, "diffusion", {}),
("lerobot/pusht", "pusht", {}, "vqbet", {}),
("lerobot/pusht", "pusht", {}, "act", {}),
@@ -328,8 +327,6 @@ def test_multikey_construction(multikey: bool):
# TODO(alexander-soare): `policy.use_mpc=false` was previously the default in the config yaml but it
# was changed to true. For some reason, tests would pass locally, but not in CI. So here we override
# to test with `policy.use_mpc=false`.
- ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"),
- # ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"),
# TODO(rcadene): the diffusion model was normalizing the image in mean=0.5 std=0.5 which is a hack supposed to
# to normalize the image at all. In our current codebase we dont normalize at all. But there is still a minor difference
# that fails the test. However, by testing to normalize the image with 0.5 0.5 in the current codebase, the test pass.
diff --git a/tests/training/test_multi_gpu.py b/tests/training/test_multi_gpu.py
new file mode 100644
index 000000000..bb234e2e7
--- /dev/null
+++ b/tests/training/test_multi_gpu.py
@@ -0,0 +1,211 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Multi-GPU Training Tests
+
+This module tests multi-GPU training functionality with accelerate.
+These tests are designed to run on machines with 2+ GPUs and are executed
+in the nightly CI workflow.
+
+The tests automatically generate accelerate configs and launch training
+with subprocess to properly test the distributed training environment.
+"""
+
+import os
+import subprocess
+import tempfile
+from pathlib import Path
+
+import pytest
+import torch
+
+from lerobot.datasets.lerobot_dataset import LeRobotDataset
+
+
+def get_num_available_gpus():
+ """Returns the number of available GPUs."""
+ if not torch.cuda.is_available():
+ return 0
+ return torch.cuda.device_count()
+
+
+def download_dataset(repo_id, episodes):
+ """
+ Pre-download dataset to avoid race conditions in multi-GPU training.
+
+ Args:
+ repo_id: HuggingFace dataset repository ID
+ episodes: List of episode indices to download
+ """
+ # Simply instantiating the dataset will download it
+ _ = LeRobotDataset(repo_id, episodes=episodes)
+ print(f"Dataset {repo_id} downloaded successfully")
+
+
+def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
+ """
+ Helper function to run training with accelerate launch.
+
+ Args:
+ config_args: List of config arguments to pass to lerobot_train.py
+ num_processes: Number of processes (GPUs) to use
+ temp_dir: Temporary directory for outputs
+
+ Returns:
+ subprocess.CompletedProcess result
+ """
+
+ config_path = Path(temp_dir) / "accelerate_config.yaml"
+
+ # Write YAML config
+ with open(config_path, "w") as f:
+ f.write("compute_environment: LOCAL_MACHINE\n")
+ f.write("distributed_type: MULTI_GPU\n")
+ f.write("mixed_precision: 'no'\n")
+ f.write(f"num_processes: {num_processes}\n")
+ f.write("use_cpu: false\n")
+ f.write("gpu_ids: all\n")
+ f.write("downcast_bf16: 'no'\n")
+ f.write("machine_rank: 0\n")
+ f.write("main_training_function: main\n")
+ f.write("num_machines: 1\n")
+ f.write("rdzv_backend: static\n")
+ f.write("same_network: true\n")
+
+ cmd = [
+ "accelerate",
+ "launch",
+ "--config_file",
+ str(config_path),
+ "-m",
+ "lerobot.scripts.lerobot_train",
+ ] + config_args
+
+ result = subprocess.run(
+ cmd,
+ capture_output=True,
+ text=True,
+ env={**os.environ, "CUDA_VISIBLE_DEVICES": ",".join(map(str, range(num_processes)))},
+ )
+
+ return result
+
+
+@pytest.mark.skipif(
+ get_num_available_gpus() < 2,
+ reason="Multi-GPU tests require at least 2 GPUs",
+)
+class TestMultiGPUTraining:
+ """Test suite for multi-GPU training functionality."""
+
+ def test_basic_multi_gpu_training(self):
+ """
+ Test that basic multi-GPU training runs successfully.
+ Verifies that the training completes without errors.
+ """
+ # Pre-download dataset to avoid race conditions
+ download_dataset("lerobot/pusht", episodes=[0])
+
+ with tempfile.TemporaryDirectory() as temp_dir:
+ output_dir = Path(temp_dir) / "outputs"
+
+ config_args = [
+ "--dataset.repo_id=lerobot/pusht",
+ "--dataset.episodes=[0]",
+ "--policy.type=act",
+ "--policy.device=cuda",
+ "--policy.push_to_hub=false",
+ f"--output_dir={output_dir}",
+ "--batch_size=4",
+ "--steps=10",
+ "--eval_freq=-1",
+ "--log_freq=5",
+ "--save_freq=10",
+ "--seed=42",
+ "--num_workers=0",
+ ]
+
+ result = run_accelerate_training(config_args, num_processes=4, temp_dir=temp_dir)
+
+ # Check that training completed successfully
+ assert result.returncode == 0, (
+ f"Multi-GPU training failed with return code {result.returncode}\n"
+ f"STDOUT:\n{result.stdout}\n"
+ f"STDERR:\n{result.stderr}"
+ )
+
+ # Verify checkpoint was saved
+ checkpoints_dir = output_dir / "checkpoints"
+ assert checkpoints_dir.exists(), "Checkpoints directory was not created"
+
+ # Verify that training completed
+ assert "End of training" in result.stdout or "End of training" in result.stderr
+
+ def test_checkpoint_saving_multi_gpu(self):
+ """
+ Test that checkpoints are correctly saved during multi-GPU training.
+ Only the main process (rank 0) should save checkpoints.
+ """
+ # Pre-download dataset to avoid race conditions
+ download_dataset("lerobot/pusht", episodes=[0])
+
+ with tempfile.TemporaryDirectory() as temp_dir:
+ output_dir = Path(temp_dir) / "outputs"
+
+ config_args = [
+ "--dataset.repo_id=lerobot/pusht",
+ "--dataset.episodes=[0]",
+ "--policy.type=act",
+ "--policy.device=cuda",
+ "--policy.push_to_hub=false",
+ f"--output_dir={output_dir}",
+ "--batch_size=4",
+ "--steps=20",
+ "--eval_freq=-1",
+ "--log_freq=5",
+ "--save_freq=10",
+ "--seed=42",
+ "--num_workers=0",
+ ]
+
+ result = run_accelerate_training(config_args, num_processes=2, temp_dir=temp_dir)
+
+ assert result.returncode == 0, (
+ f"Training failed:\nSTDOUT:\n{result.stdout}\n\nSTDERR:\n{result.stderr}"
+ )
+
+ # Verify checkpoint directory exists
+ checkpoints_dir = output_dir / "checkpoints"
+ assert checkpoints_dir.exists(), "Checkpoints directory not created"
+
+ # Count checkpoint directories (should have checkpoint at step 10 and 20)
+ checkpoint_dirs = [d for d in checkpoints_dir.iterdir() if d.is_dir()]
+ assert len(checkpoint_dirs) >= 1, f"Expected at least 1 checkpoint, found {len(checkpoint_dirs)}"
+
+ # Verify checkpoint contents
+ for checkpoint_dir in checkpoint_dirs:
+ # Check for model files
+ model_files = list(checkpoint_dir.rglob("*.safetensors"))
+ assert len(model_files) > 0, f"No model files in checkpoint {checkpoint_dir}"
+
+ # Check for training state
+ training_state_dir = checkpoint_dir / "training_state"
+ assert training_state_dir.exists(), f"No training state in checkpoint {checkpoint_dir}"
+
+ # Verify optimizer state exists
+ optimizer_state = training_state_dir / "optimizer_state.safetensors"
+ assert optimizer_state.exists(), f"No optimizer state in checkpoint {checkpoint_dir}"