From 9a49e57c728f09d29876759f23b71e4d553b95c9 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 8 Oct 2025 20:06:56 +0200 Subject: [PATCH 01/24] refactor(datasets): add compress_level parameter to write_image() and set it to 1 (#2135) * refactor(datasets): add compress_level parameter to write_image() and set it to 1 * docs(dataset): add docs to write_image() --- src/lerobot/datasets/image_writer.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/src/lerobot/datasets/image_writer.py b/src/lerobot/datasets/image_writer.py index 4a4e1ab05..ee10df6e1 100644 --- a/src/lerobot/datasets/image_writer.py +++ b/src/lerobot/datasets/image_writer.py @@ -68,7 +68,30 @@ def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) return PIL.Image.fromarray(image_array) -def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path): +def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1): + """ + Saves a NumPy array or PIL Image to a file. + + This function handles both NumPy arrays and PIL Image objects, converting + the former to a PIL Image before saving. It includes error handling for + the save operation. + + Args: + image (np.ndarray | PIL.Image.Image): The image data to save. + fpath (Path): The destination file path for the image. + compress_level (int, optional): The compression level for the saved + image, as used by PIL.Image.save(). Defaults to 1. + Refer to: https://github.com/huggingface/lerobot/pull/2135 + for more details on the default value rationale. + + Raises: + TypeError: If the input 'image' is not a NumPy array or a + PIL.Image.Image object. + + Side Effects: + Prints an error message to the console if the image writing process + fails for any reason. + """ try: if isinstance(image, np.ndarray): img = image_array_to_pil_image(image) @@ -76,7 +99,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path): img = image else: raise TypeError(f"Unsupported image type: {type(image)}") - img.save(fpath) + img.save(fpath, compress_level=compress_level) except Exception as e: print(f"Error writing image {fpath}: {e}") From 4ccf28437a785e453888de8e4b415dc9d35ac4e0 Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Wed, 8 Oct 2025 20:07:14 +0200 Subject: [PATCH 02/24] Add act documentation (#2139) * Add act documentation * remove citation as we link the paper * simplify docs * fix pre commit --- docs/source/_toctree.yml | 2 + docs/source/act.mdx | 92 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+) create mode 100644 docs/source/act.mdx diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 36eaea165..3b6cccc95 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -27,6 +27,8 @@ title: Porting Large Datasets title: "Datasets" - sections: + - local: act + title: ACT - local: smolvla title: SmolVLA - local: pi0 diff --git a/docs/source/act.mdx b/docs/source/act.mdx new file mode 100644 index 000000000..e3294ca69 --- /dev/null +++ b/docs/source/act.mdx @@ -0,0 +1,92 @@ +# ACT (Action Chunking with Transformers) + +ACT is a **lightweight and efficient policy for imitation learning**, especially well-suited for fine-grained manipulation tasks. It's the **first model we recommend when you're starting out** with LeRobot due to its fast training time, low computational requirements, and strong performance. + +
+ +
+ +_Watch this tutorial from the LeRobot team to learn how ACT works: [LeRobot ACT Tutorial](https://www.youtube.com/watch?v=ft73x0LfGpM)_ + +## Model Overview + +Action Chunking with Transformers (ACT) was introduced in the paper [Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware](https://arxiv.org/abs/2304.13705) by Zhao et al. The policy was designed to enable precise, contact-rich manipulation tasks using affordable hardware and minimal demonstration data. + +### Why ACT is Great for Beginners + +ACT stands out as an excellent starting point for several reasons: + +- **Fast Training**: Trains in a few hours on a single GPU +- **Lightweight**: Only ~80M parameters, making it efficient and easy to work with +- **Data Efficient**: Often achieves high success rates with just 50 demonstrations + +### Architecture + +ACT uses a transformer-based architecture with three main components: + +1. **Vision Backbone**: ResNet-18 processes images from multiple camera viewpoints +2. **Transformer Encoder**: Synthesizes information from camera features, joint positions, and a learned latent variable +3. **Transformer Decoder**: Generates coherent action sequences using cross-attention + +The policy takes as input: + +- Multiple RGB images (e.g., from wrist cameras, front/top cameras) +- Current robot joint positions +- A latent style variable `z` (learned during training, set to zero during inference) + +And outputs a chunk of `k` future action sequences. + +## Installation Requirements + +1. Install LeRobot by following our [Installation Guide](./installation). +2. ACT is included in the base LeRobot installation, so no additional dependencies are needed! + +## Training ACT + +ACT works seamlessly with the standard LeRobot training pipeline. Here's a complete example for training ACT on your dataset: + +```bash +lerobot-train \ + --dataset.repo_id=${HF_USER}/your_dataset \ + --policy.type=act \ + --output_dir=outputs/train/act_your_dataset \ + --job_name=act_your_dataset \ + --policy.device=cuda \ + --wandb.enable=true \ + --policy.repo_id=${HF_USER}/act_policy +``` + +### Training Tips + +1. **Start with defaults**: ACT's default hyperparameters work well for most tasks +2. **Training duration**: Expect a few hours for 100k training steps on a single GPU +3. **Batch size**: Start with batch size 8 and adjust based on your GPU memory + +### Train using Google Colab + +If your local computer doesn't have a powerful GPU, you can utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act). + +## Evaluating ACT + +Once training is complete, you can evaluate your ACT policy using the `lerobot-record` command with your trained policy. 
This will run inference and record evaluation episodes: + +```bash +lerobot-record \ + --robot.type=so100_follower \ + --robot.port=/dev/ttyACM0 \ + --robot.id=my_robot \ + --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \ + --display_data=true \ + --dataset.repo_id=${HF_USER}/eval_act_your_dataset \ + --dataset.num_episodes=10 \ + --dataset.single_task="Your task description" \ + --policy.path=${HF_USER}/act_policy +``` From 829d2d1ad9bc0acc20fbf64f22027c615055385e Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 9 Oct 2025 15:20:07 +0200 Subject: [PATCH 03/24] fic(docs): local docs links (#2149) --- docs/source/integrate_hardware.mdx | 4 ++-- docs/source/introduction_processors.mdx | 6 +++--- docs/source/phone_teleop.mdx | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/source/integrate_hardware.mdx b/docs/source/integrate_hardware.mdx index 7e7fe0bff..ed9dc8dd5 100644 --- a/docs/source/integrate_hardware.mdx +++ b/docs/source/integrate_hardware.mdx @@ -8,7 +8,7 @@ To that end, we provide the [`Robot`](https://github.com/huggingface/lerobot/blo - Your own robot which exposes a communication interface (e.g. serial, CAN, TCP) - A way to read sensor data and send motor commands programmatically, e.g. manufacturer's SDK or API, or your own protocol implementation. -- LeRobot installed in your environment. Follow our [Installation Guide](./installation.mdx). +- LeRobot installed in your environment. Follow our [Installation Guide](./installation). ## Choose your motors @@ -65,7 +65,7 @@ class MyCoolRobotConfig(RobotConfig): ``` -[Cameras tutorial](./cameras.mdx) to understand how to detect and add your camera. +[Cameras tutorial](./cameras) to understand how to detect and add your camera. Next, we'll create our actual robot class which inherits from `Robot`. This abstract class defines a contract you must follow for your robot to be usable with the rest of the LeRobot tools. diff --git a/docs/source/introduction_processors.mdx b/docs/source/introduction_processors.mdx index 308edbb3b..6f3768615 100644 --- a/docs/source/introduction_processors.mdx +++ b/docs/source/introduction_processors.mdx @@ -297,9 +297,9 @@ LeRobot provides many registered processor steps. Here are the most commonly use ### Next Steps -- **[Implement Your Own Processor](implement_your_own_processor.mdx)** - Create custom processor steps -- **[Debug Your Pipeline](debug_processor_pipeline.mdx)** - Troubleshoot and optimize pipelines -- **[Processors for Robots and Teleoperators](processors_robots_teleop.mdx)** - Real-world integration patterns +- **[Implement Your Own Processor](./implement_your_own_processor)** - Create custom processor steps +- **[Debug Your Pipeline](./debug_processor_pipeline)** - Troubleshoot and optimize pipelines +- **[Processors for Robots and Teleoperators](./processors_robots_teleop)** - Real-world integration patterns ## Summary diff --git a/docs/source/phone_teleop.mdx b/docs/source/phone_teleop.mdx index 22159193c..76e3c367c 100644 --- a/docs/source/phone_teleop.mdx +++ b/docs/source/phone_teleop.mdx @@ -79,7 +79,7 @@ After running the example: - Android: after starting the script, open the printed local URL on your phone, tap Start, then press and hold Move. - iOS: open HEBI Mobile I/O first; B1 enables motion. A3 controls the gripper. -Additionally you can customize mapping or safety limits by editing the processor steps shown in the examples. 
You can also remap inputs (e.g., use a different analog input) or adapt the pipeline to other robots (e.g., LeKiwi) by modifying the input and kinematics steps. More about this in the [Processors for Robots and Teleoperators](./processors_robots_teleop.mdx) guide. +Additionally you can customize mapping or safety limits by editing the processor steps shown in the examples. You can also remap inputs (e.g., use a different analog input) or adapt the pipeline to other robots (e.g., LeKiwi) by modifying the input and kinematics steps. More about this in the [Processors for Robots and Teleoperators](./processors_robots_teleop) guide. - Run this example to record a dataset, which saves absolute end effector observations and actions: From 656fc0f05956d5192f12b70fe4f0bbc25b17fc2e Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Fri, 10 Oct 2025 11:34:21 +0200 Subject: [PATCH 04/24] Remove validate_robot_cameras_for_policy (#2150) * Remove validate_robot_cameras_for_policy as with rename processor the image keys can be renamed an mapped * fix precommit --- src/lerobot/async_inference/configs.py | 5 ----- src/lerobot/async_inference/helpers.py | 9 --------- src/lerobot/async_inference/robot_client.py | 10 ---------- tests/async_inference/test_e2e.py | 1 - tests/async_inference/test_robot_client.py | 1 - 5 files changed, 26 deletions(-) diff --git a/src/lerobot/async_inference/configs.py b/src/lerobot/async_inference/configs.py index 24f889df1..d1768a323 100644 --- a/src/lerobot/async_inference/configs.py +++ b/src/lerobot/async_inference/configs.py @@ -142,11 +142,6 @@ class RobotClientConfig: default=False, metadata={"help": "Visualize the action queue size"} ) - # Verification configuration - verify_robot_cameras: bool = field( - default=True, metadata={"help": "Verify that the robot cameras match the policy cameras"} - ) - @property def environment_dt(self) -> float: """Environment time step, in seconds""" diff --git a/src/lerobot/async_inference/helpers.py b/src/lerobot/async_inference/helpers.py index 54fad8c54..f73cbc1da 100644 --- a/src/lerobot/async_inference/helpers.py +++ b/src/lerobot/async_inference/helpers.py @@ -62,15 +62,6 @@ def visualize_action_queue_size(action_queue_size: list[int]) -> None: plt.show() -def validate_robot_cameras_for_policy( - lerobot_observation_features: dict[str, dict], policy_image_features: dict[str, PolicyFeature] -) -> None: - image_keys = list(filter(is_image_key, lerobot_observation_features)) - assert set(image_keys) == set(policy_image_features.keys()), ( - f"Policy image features must match robot cameras! 
Received {list(policy_image_features.keys())} != {image_keys}" - ) - - def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]: return hw_to_dataset_features(robot.observation_features, OBS_STR, use_video=False) diff --git a/src/lerobot/async_inference/robot_client.py b/src/lerobot/async_inference/robot_client.py index 8c4425c6b..f9d70a64e 100644 --- a/src/lerobot/async_inference/robot_client.py +++ b/src/lerobot/async_inference/robot_client.py @@ -48,7 +48,6 @@ import torch from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 -from lerobot.configs.policies import PreTrainedConfig from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, @@ -76,7 +75,6 @@ from .helpers import ( TimedObservation, get_logger, map_robot_keys_to_lerobot_features, - validate_robot_cameras_for_policy, visualize_action_queue_size, ) @@ -98,14 +96,6 @@ class RobotClient: lerobot_features = map_robot_keys_to_lerobot_features(self.robot) - if config.verify_robot_cameras: - # Load policy config for validation - policy_config = PreTrainedConfig.from_pretrained(config.pretrained_name_or_path) - policy_image_features = policy_config.image_features - - # The cameras specified for inference must match the one supported by the policy chosen - validate_robot_cameras_for_policy(lerobot_features, policy_image_features) - # Use environment variable if server_address is not provided in config self.server_address = config.server_address diff --git a/tests/async_inference/test_e2e.py b/tests/async_inference/test_e2e.py index ebaef2ef1..11941ce32 100644 --- a/tests/async_inference/test_e2e.py +++ b/tests/async_inference/test_e2e.py @@ -139,7 +139,6 @@ def test_async_inference_e2e(monkeypatch): policy_type="test", pretrained_name_or_path="test", actions_per_chunk=20, - verify_robot_cameras=False, ) client = RobotClient(client_config) diff --git a/tests/async_inference/test_robot_client.py b/tests/async_inference/test_robot_client.py index dfdb8ce42..5b138d91b 100644 --- a/tests/async_inference/test_robot_client.py +++ b/tests/async_inference/test_robot_client.py @@ -51,7 +51,6 @@ def robot_client(): policy_type="test", pretrained_name_or_path="test", actions_per_chunk=20, - verify_robot_cameras=False, ) client = RobotClient(test_config) From b8f7e401d42a17d1ac90355f39a1ee7171afb58f Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Fri, 10 Oct 2025 12:32:07 +0200 Subject: [PATCH 05/24] Dataset tools (#2100) * feat(dataset-tools): add dataset utilities and example script - Introduced dataset tools for LeRobotDataset, including functions for deleting episodes, splitting datasets, adding/removing features, and merging datasets. - Added an example script demonstrating the usage of these utilities. - Implemented comprehensive tests for all new functionalities to ensure reliability and correctness. 
* style fixes
* move example to dataset dir
* missing license
* fixes mostly path
* clean comments
* move tests to functions instead of class based
* - fix video editing: decode, delete frames and re-encode video
  - copy unchanged video and parquet files to avoid recreating the entire dataset
* Fortify tooling tests
* Fix type issue resulting from saving numpy arrays with shape 3,1,1
* added lerobot_edit_dataset
* - revert changes in examples
  - remove hardcoded split names
* update comment
* fix comment, add lerobot-edit-dataset shortcut
* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Michel Aractingi

* style nit after copilot review
* fix: bug in dataset root when editing the dataset in place (without setting new_repo_id)
* Fix bug in aggregate.py when accumulating video timestamps; add tests to fortify aggregate videos
* Added missing output repo id
* migrate delete episode to using pyav instead of decoding, writing frames to disk and encoding again.

Co-authored-by: Caroline Pascal

* added modified suffix in case repo_id is not set in delete_episode
* adding docs for dataset tools
* bump av version and add back time_base assignment
* linter
* modified push_to_hub logic in lerobot_edit_dataset
* fix(progress bar): fixing the progress bar issue in dataset tools
* chore(concatenate): removing no longer needed concatenate_datasets usage
* fix(file sizes forwarding): forwarding files and chunk sizes in metadata info when splitting and aggregating datasets
* style fix
* refactor(aggregate): Fix video indexing and timestamp bugs in dataset merging

There were three critical bugs in aggregate.py that prevented correct dataset merging:

1. Video file indices: Changed from += to = assignment to correctly reference merged video files
2. Video timestamps: Implemented per-source-file offset tracking to maintain continuous timestamps when merging split datasets (was causing non-monotonic timestamp warnings)
3. 
File rotation offsets: Store timestamp offsets after rotation decision to prevent out-of-bounds frame access (was causing "Invalid frame index" errors with small file size limits) Changes: - Updated update_meta_data() to apply per-source-file timestamp offsets - Updated aggregate_videos() to track offsets correctly during file rotation - Added get_video_duration_in_s import for duration calculation * Improved docs for split dataset and added a check for the possible case that the split size results in zero episodes * chore(docs): update merge documentation details Signed-off-by: Steven Palma --------- Co-authored-by: CarolinePascal Co-authored-by: Jack Vial Co-authored-by: Steven Palma --- docs/source/_toctree.yml | 2 + docs/source/using_dataset_tools.mdx | 102 ++ examples/dataset/use_dataset_tools.py | 117 +++ pyproject.toml | 3 +- src/lerobot/datasets/aggregate.py | 82 +- src/lerobot/datasets/dataset_tools.py | 1004 +++++++++++++++++++ src/lerobot/datasets/lerobot_dataset.py | 14 +- src/lerobot/datasets/utils.py | 9 +- src/lerobot/datasets/video_utils.py | 3 + src/lerobot/scripts/lerobot_edit_dataset.py | 286 ++++++ src/lerobot/utils/utils.py | 20 + tests/datasets/test_aggregate.py | 90 ++ tests/datasets/test_dataset_tools.py | 891 ++++++++++++++++ 13 files changed, 2593 insertions(+), 30 deletions(-) create mode 100644 docs/source/using_dataset_tools.mdx create mode 100644 examples/dataset/use_dataset_tools.py create mode 100644 src/lerobot/datasets/dataset_tools.py create mode 100644 src/lerobot/scripts/lerobot_edit_dataset.py create mode 100644 tests/datasets/test_dataset_tools.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 3b6cccc95..568bd6380 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -25,6 +25,8 @@ title: Using LeRobotDataset - local: porting_datasets_v3 title: Porting Large Datasets + - local: using_dataset_tools + title: Using the Dataset Tools title: "Datasets" - sections: - local: act diff --git a/docs/source/using_dataset_tools.mdx b/docs/source/using_dataset_tools.mdx new file mode 100644 index 000000000..affca0ee5 --- /dev/null +++ b/docs/source/using_dataset_tools.mdx @@ -0,0 +1,102 @@ +# Using Dataset Tools + +This guide covers the dataset tools utilities available in LeRobot for modifying and editing existing datasets. + +## Overview + +LeRobot provides several utilities for manipulating datasets: + +1. **Delete Episodes** - Remove specific episodes from a dataset +2. **Split Dataset** - Divide a dataset into multiple smaller datasets +3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids` +4. **Add Features** - Add new features to a dataset +5. **Remove Features** - Remove features from a dataset + +The core implementation is in `lerobot.datasets.dataset_tools`. +An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`. + +## Command-Line Tool: lerobot-edit-dataset + +`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, and remove features. + +Run `lerobot-edit-dataset --help` for more information on the configuration of each operation. + +### Usage Examples + +#### Delete Episodes + +Remove specific episodes from a dataset. This is useful for filtering out undesired data. 
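The same edit can also be done programmatically with the `delete_episodes` helper from `lerobot.datasets.dataset_tools`. A minimal sketch (the repo IDs below are placeholders):

```python
from lerobot.datasets.dataset_tools import delete_episodes
from lerobot.datasets.lerobot_dataset import LeRobotDataset

# Load the source dataset and write a filtered copy without episodes 0, 2 and 5.
dataset = LeRobotDataset("lerobot/pusht")
filtered = delete_episodes(
    dataset, episode_indices=[0, 2, 5], repo_id="lerobot/pusht_after_deletion"
)
print(f"{filtered.meta.total_episodes} episodes remain")
```

With `lerobot-edit-dataset`, the equivalent operations look like this: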
```bash
# Delete episodes 0, 2, and 5 (modifies original dataset)
lerobot-edit-dataset \
    --repo_id lerobot/pusht \
    --operation.type delete_episodes \
    --operation.episode_indices "[0, 2, 5]"

# Delete episodes and save to a new dataset (preserves original dataset)
lerobot-edit-dataset \
    --repo_id lerobot/pusht \
    --new_repo_id lerobot/pusht_after_deletion \
    --operation.type delete_episodes \
    --operation.episode_indices "[0, 2, 5]"
```

#### Split Dataset

Divide a dataset into multiple subsets.

```bash
# Split by fractions (e.g. 80% train, 10% test, 10% val; fractions must sum to at most 1.0)
lerobot-edit-dataset \
    --repo_id lerobot/pusht \
    --operation.type split \
    --operation.splits '{"train": 0.8, "test": 0.1, "val": 0.1}'

# Split by specific episode indices
lerobot-edit-dataset \
    --repo_id lerobot/pusht \
    --operation.type split \
    --operation.splits '{"task1": [0, 1, 2, 3], "task2": [4, 5]}'
```

There are no constraints on the split names; they can be chosen freely by the user. The resulting datasets are saved under the repo id with the split name appended, e.g. `lerobot/pusht_train`, `lerobot/pusht_task1`, `lerobot/pusht_task2`.

#### Merge Datasets

Combine multiple datasets into a single dataset.

```bash
# Merge train and validation splits back into one dataset
lerobot-edit-dataset \
    --repo_id lerobot/pusht_merged \
    --operation.type merge \
    --operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']"
```

#### Remove Features

Remove features from a dataset.

```bash
# Remove a camera feature
lerobot-edit-dataset \
    --repo_id lerobot/pusht \
    --operation.type remove_feature \
    --operation.feature_names "['observation.images.top']"
```

### Push to Hub

Add the `--push_to_hub` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:

```bash
lerobot-edit-dataset \
    --repo_id lerobot/pusht \
    --new_repo_id lerobot/pusht_after_deletion \
    --operation.type delete_episodes \
    --operation.episode_indices "[0, 2, 5]" \
    --push_to_hub
```

There is also a tool for adding features to a dataset, which is not yet covered by `lerobot-edit-dataset`.
diff --git a/examples/dataset/use_dataset_tools.py b/examples/dataset/use_dataset_tools.py
new file mode 100644
index 000000000..244259872
--- /dev/null
+++ b/examples/dataset/use_dataset_tools.py
@@ -0,0 +1,117 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Example script demonstrating dataset tools utilities.

This script shows how to:
1. Delete episodes from a dataset
2. Split a dataset into train/val sets
3. Add/remove features
4. 
Merge datasets + +Usage: + python examples/dataset/use_dataset_tools.py +""" + +import numpy as np + +from lerobot.datasets.dataset_tools import ( + add_feature, + delete_episodes, + merge_datasets, + remove_feature, + split_dataset, +) +from lerobot.datasets.lerobot_dataset import LeRobotDataset + + +def main(): + dataset = LeRobotDataset("lerobot/pusht") + + print(f"Original dataset: {dataset.meta.total_episodes} episodes, {dataset.meta.total_frames} frames") + print(f"Features: {list(dataset.meta.features.keys())}") + + print("\n1. Deleting episodes 0 and 2...") + filtered_dataset = delete_episodes(dataset, episode_indices=[0, 2], repo_id="lerobot/pusht_filtered") + print(f"Filtered dataset: {filtered_dataset.meta.total_episodes} episodes") + + print("\n2. Splitting dataset into train/val...") + splits = split_dataset( + dataset, + splits={"train": 0.8, "val": 0.2}, + ) + print(f"Train split: {splits['train'].meta.total_episodes} episodes") + print(f"Val split: {splits['val'].meta.total_episodes} episodes") + + print("\n3. Adding a reward feature...") + + reward_values = np.random.randn(dataset.meta.total_frames).astype(np.float32) + dataset_with_reward = add_feature( + dataset, + feature_name="reward", + feature_values=reward_values, + feature_info={ + "dtype": "float32", + "shape": (1,), + "names": None, + }, + repo_id="lerobot/pusht_with_reward", + ) + + def compute_success(row_dict, episode_index, frame_index): + episode_length = 10 + return float(frame_index >= episode_length - 10) + + dataset_with_success = add_feature( + dataset_with_reward, + feature_name="success", + feature_values=compute_success, + feature_info={ + "dtype": "float32", + "shape": (1,), + "names": None, + }, + repo_id="lerobot/pusht_with_reward_and_success", + ) + + print(f"New features: {list(dataset_with_success.meta.features.keys())}") + + print("\n4. Removing the success feature...") + dataset_cleaned = remove_feature( + dataset_with_success, feature_names="success", repo_id="lerobot/pusht_cleaned" + ) + print(f"Features after removal: {list(dataset_cleaned.meta.features.keys())}") + + print("\n5. Merging train and val splits back together...") + merged = merge_datasets([splits["train"], splits["val"]], output_repo_id="lerobot/pusht_merged") + print(f"Merged dataset: {merged.meta.total_episodes} episodes") + + print("\n6. Complex workflow example...") + + if len(dataset.meta.camera_keys) > 1: + camera_to_remove = dataset.meta.camera_keys[0] + print(f"Removing camera: {camera_to_remove}") + dataset_no_cam = remove_feature( + dataset, feature_names=camera_to_remove, repo_id="pusht_no_first_camera" + ) + print(f"Remaining cameras: {dataset_no_cam.meta.camera_keys}") + + print("\nDone! 
Check ~/.cache/huggingface/lerobot/ for the created datasets.") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index c67b481f0..a70208cb2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,7 +67,7 @@ dependencies = [ "cmake>=3.29.0.1,<4.2.0", "einops>=0.8.0,<0.9.0", "opencv-python-headless>=4.9.0,<4.13.0", - "av>=14.2.0,<16.0.0", + "av>=15.0.0,<16.0.0", "jsonlines>=4.0.0,<5.0.0", "packaging>=24.2,<26.0", "pynput>=1.7.7,<1.9.0", @@ -175,6 +175,7 @@ lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main" lerobot-info="lerobot.scripts.lerobot_info:main" lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main" lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main" +lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main" # ---------------- Tool Configurations ---------------- [tool.setuptools.packages.find] diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py index 803645f29..e7ea59ed0 100644 --- a/src/lerobot/datasets/aggregate.py +++ b/src/lerobot/datasets/aggregate.py @@ -39,7 +39,7 @@ from lerobot.datasets.utils import ( write_stats, write_tasks, ) -from lerobot.datasets.video_utils import concatenate_video_files +from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]): @@ -130,10 +130,34 @@ def update_meta_data( df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"] df["data/file_index"] = df["data/file_index"] + data_idx["file"] for key, video_idx in videos_idx.items(): - df[f"videos/{key}/chunk_index"] = df[f"videos/{key}/chunk_index"] + video_idx["chunk"] - df[f"videos/{key}/file_index"] = df[f"videos/{key}/file_index"] + video_idx["file"] - df[f"videos/{key}/from_timestamp"] = df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"] - df[f"videos/{key}/to_timestamp"] = df[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"] + # Store original video file indices before updating + orig_chunk_col = f"videos/{key}/chunk_index" + orig_file_col = f"videos/{key}/file_index" + df["_orig_chunk"] = df[orig_chunk_col].copy() + df["_orig_file"] = df[orig_file_col].copy() + + # Update chunk and file indices to point to destination + df[orig_chunk_col] = video_idx["chunk"] + df[orig_file_col] = video_idx["file"] + + # Apply per-source-file timestamp offsets + src_to_offset = video_idx.get("src_to_offset", {}) + if src_to_offset: + # Apply offset based on original source file + for idx in df.index: + src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"]) + offset = src_to_offset.get(src_key, 0) + df.at[idx, f"videos/{key}/from_timestamp"] += offset + df.at[idx, f"videos/{key}/to_timestamp"] += offset + else: + # Fallback to simple offset (for backward compatibility) + df[f"videos/{key}/from_timestamp"] = ( + df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"] + ) + df[f"videos/{key}/to_timestamp"] = df[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"] + + # Clean up temporary columns + df = df.drop(columns=["_orig_chunk", "_orig_file"]) df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"] df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"] @@ -193,6 +217,9 @@ def aggregate_datasets( robot_type=robot_type, features=features, root=aggr_root, + chunks_size=chunk_size, + data_files_size_in_mb=data_files_size_in_mb, + 
video_files_size_in_mb=video_files_size_in_mb, ) logging.info("Find all tasks") @@ -236,6 +263,11 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu Returns: dict: Updated videos_idx with current chunk and file indices. """ + for key in videos_idx: + videos_idx[key]["episode_duration"] = 0 + # Track offset for each source (chunk, file) pair + videos_idx[key]["src_to_offset"] = {} + for key, video_idx in videos_idx.items(): unique_chunk_file_pairs = { (chunk, file) @@ -249,6 +281,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu chunk_idx = video_idx["chunk"] file_idx = video_idx["file"] + current_offset = video_idx["latest_duration"] for src_chunk_idx, src_file_idx in unique_chunk_file_pairs: src_path = src_meta.root / DEFAULT_VIDEO_PATH.format( @@ -263,21 +296,24 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu file_index=file_idx, ) - # If a new file is created, we don't want to increment the latest_duration - update_latest_duration = False + src_duration = get_video_duration_in_s(src_path) if not dst_path.exists(): - # First write to this destination file + # Store offset before incrementing + videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset dst_path.parent.mkdir(parents=True, exist_ok=True) shutil.copy(str(src_path), str(dst_path)) - continue # not accumulating further, already copied the file in place + videos_idx[key]["episode_duration"] += src_duration + current_offset += src_duration + continue - # Check file sizes before appending src_size = get_video_size_in_mb(src_path) dst_size = get_video_size_in_mb(dst_path) if dst_size + src_size >= video_files_size_in_mb: - # Rotate to a new chunk/file + # Rotate to a new file, this source becomes start of new destination + # So its offset should be 0 + videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0 chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size) dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format( video_key=key, @@ -286,25 +322,22 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu ) dst_path.parent.mkdir(parents=True, exist_ok=True) shutil.copy(str(src_path), str(dst_path)) + # Reset offset for next file + current_offset = src_duration else: - # Get the timestamps shift for this video - timestamps_shift_s = dst_meta.info["total_frames"] / dst_meta.info["fps"] - - # Append to existing video file + # Append to existing video file - use current accumulated offset + videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset concatenate_video_files( [dst_path, src_path], dst_path, ) - # Update the latest_duration when appending (shifts timestamps!) 
- update_latest_duration = not update_latest_duration + current_offset += src_duration + + videos_idx[key]["episode_duration"] += src_duration - # Update the videos_idx with the final chunk and file indices for this key videos_idx[key]["chunk"] = chunk_idx videos_idx[key]["file"] = file_idx - if update_latest_duration: - videos_idx[key]["latest_duration"] += timestamps_shift_s - return videos_idx @@ -389,9 +422,6 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx): videos_idx, ) - for k in videos_idx: - videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"] - meta_idx = append_or_create_parquet_file( df, src_path, @@ -403,6 +433,10 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx): aggr_root=dst_meta.root, ) + # Increment latest_duration by the total duration added from this source dataset + for k in videos_idx: + videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"] + return meta_idx diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py new file mode 100644 index 000000000..fdeb24a72 --- /dev/null +++ b/src/lerobot/datasets/dataset_tools.py @@ -0,0 +1,1004 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dataset tools utilities for LeRobotDataset. + +This module provides utilities for: +- Deleting episodes from datasets +- Splitting datasets into multiple smaller datasets +- Adding/removing features from datasets +- Merging datasets (wrapper around aggregate functionality) +""" + +import logging +import shutil +from collections.abc import Callable +from pathlib import Path + +import numpy as np +import pandas as pd +import torch +from tqdm import tqdm + +from lerobot.datasets.aggregate import aggregate_datasets +from lerobot.datasets.compute_stats import aggregate_stats +from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata +from lerobot.datasets.utils import ( + DEFAULT_CHUNK_SIZE, + DEFAULT_DATA_FILE_SIZE_IN_MB, + DEFAULT_DATA_PATH, + DEFAULT_EPISODES_PATH, + get_parquet_file_size_in_mb, + to_parquet_with_hf_images, + update_chunk_file_indices, + write_info, + write_stats, + write_tasks, +) +from lerobot.utils.constants import HF_LEROBOT_HOME + + +def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict: + """Load a single episode's metadata including stats from parquet file. 
+ + Args: + src_dataset: Source dataset + episode_idx: Episode index to load + + Returns: + dict containing episode metadata and stats + """ + ep_meta = src_dataset.meta.episodes[episode_idx] + chunk_idx = ep_meta["meta/episodes/chunk_index"] + file_idx = ep_meta["meta/episodes/file_index"] + + parquet_path = src_dataset.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + df = pd.read_parquet(parquet_path) + + episode_row = df[df["episode_index"] == episode_idx].iloc[0] + + return episode_row.to_dict() + + +def delete_episodes( + dataset: LeRobotDataset, + episode_indices: list[int], + output_dir: str | Path | None = None, + repo_id: str | None = None, +) -> LeRobotDataset: + """Delete episodes from a LeRobotDataset and create a new dataset. + + Args: + dataset: The source LeRobotDataset. + episode_indices: List of episode indices to delete. + output_dir: Directory to save the new dataset. If None, uses default location. + repo_id: Repository ID for the new dataset. If None, appends "_modified" to original. + """ + if not episode_indices: + raise ValueError("No episodes to delete") + + valid_indices = set(range(dataset.meta.total_episodes)) + invalid = set(episode_indices) - valid_indices + if invalid: + raise ValueError(f"Invalid episode indices: {invalid}") + + logging.info(f"Deleting {len(episode_indices)} episodes from dataset") + + if repo_id is None: + repo_id = f"{dataset.repo_id}_modified" + output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id + + episodes_to_keep = [i for i in range(dataset.meta.total_episodes) if i not in episode_indices] + if not episodes_to_keep: + raise ValueError("Cannot delete all episodes from dataset") + + new_meta = LeRobotDatasetMetadata.create( + repo_id=repo_id, + fps=dataset.meta.fps, + features=dataset.meta.features, + robot_type=dataset.meta.robot_type, + root=output_dir, + use_videos=len(dataset.meta.video_keys) > 0, + ) + + episode_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(episodes_to_keep)} + + video_metadata = None + if dataset.meta.video_keys: + video_metadata = _copy_and_reindex_videos(dataset, new_meta, episode_mapping) + + data_metadata = _copy_and_reindex_data(dataset, new_meta, episode_mapping) + + _copy_and_reindex_episodes_metadata(dataset, new_meta, episode_mapping, data_metadata, video_metadata) + + new_dataset = LeRobotDataset( + repo_id=repo_id, + root=output_dir, + image_transforms=dataset.image_transforms, + delta_timestamps=dataset.delta_timestamps, + tolerance_s=dataset.tolerance_s, + ) + + logging.info(f"Created new dataset with {len(episodes_to_keep)} episodes") + return new_dataset + + +def split_dataset( + dataset: LeRobotDataset, + splits: dict[str, float | list[int]], + output_dir: str | Path | None = None, +) -> dict[str, LeRobotDataset]: + """Split a LeRobotDataset into multiple smaller datasets. + + Args: + dataset: The source LeRobotDataset to split. + splits: Either a dict mapping split names to episode indices, or a dict mapping + split names to fractions (must sum to <= 1.0). + output_dir: Base directory for output datasets. If None, uses default location. 
+ + Examples: + Split by specific episodes + splits = {"train": [0, 1, 2], "val": [3, 4]} + datasets = split_dataset(dataset, splits) + + Split by fractions + splits = {"train": 0.8, "val": 0.2} + datasets = split_dataset(dataset, splits) + """ + if not splits: + raise ValueError("No splits provided") + + if all(isinstance(v, float) for v in splits.values()): + splits = _fractions_to_episode_indices(dataset.meta.total_episodes, splits) + + all_episodes = set() + for split_name, episodes in splits.items(): + if not episodes: + raise ValueError(f"Split '{split_name}' has no episodes") + episode_set = set(episodes) + if episode_set & all_episodes: + raise ValueError("Episodes cannot appear in multiple splits") + all_episodes.update(episode_set) + + valid_indices = set(range(dataset.meta.total_episodes)) + invalid = all_episodes - valid_indices + if invalid: + raise ValueError(f"Invalid episode indices: {invalid}") + + if output_dir is not None: + output_dir = Path(output_dir) + + result_datasets = {} + + for split_name, episodes in splits.items(): + logging.info(f"Creating split '{split_name}' with {len(episodes)} episodes") + + split_repo_id = f"{dataset.repo_id}_{split_name}" + + split_output_dir = ( + output_dir / split_name if output_dir is not None else HF_LEROBOT_HOME / split_repo_id + ) + + episode_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(sorted(episodes))} + + new_meta = LeRobotDatasetMetadata.create( + repo_id=split_repo_id, + fps=dataset.meta.fps, + features=dataset.meta.features, + robot_type=dataset.meta.robot_type, + root=split_output_dir, + use_videos=len(dataset.meta.video_keys) > 0, + chunks_size=dataset.meta.chunks_size, + data_files_size_in_mb=dataset.meta.data_files_size_in_mb, + video_files_size_in_mb=dataset.meta.video_files_size_in_mb, + ) + + video_metadata = None + if dataset.meta.video_keys: + video_metadata = _copy_and_reindex_videos(dataset, new_meta, episode_mapping) + + data_metadata = _copy_and_reindex_data(dataset, new_meta, episode_mapping) + + _copy_and_reindex_episodes_metadata(dataset, new_meta, episode_mapping, data_metadata, video_metadata) + + new_dataset = LeRobotDataset( + repo_id=split_repo_id, + root=split_output_dir, + image_transforms=dataset.image_transforms, + delta_timestamps=dataset.delta_timestamps, + tolerance_s=dataset.tolerance_s, + ) + + result_datasets[split_name] = new_dataset + + return result_datasets + + +def merge_datasets( + datasets: list[LeRobotDataset], + output_repo_id: str, + output_dir: str | Path | None = None, +) -> LeRobotDataset: + """Merge multiple LeRobotDatasets into a single dataset. + + This is a wrapper around the aggregate_datasets functionality with a cleaner API. + + Args: + datasets: List of LeRobotDatasets to merge. + output_repo_id: Repository ID for the merged dataset. + output_dir: Directory to save the merged dataset. If None, uses default location. 
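+
+    Example:
+        merged = merge_datasets([train_ds, val_ds], output_repo_id="lerobot/pusht_merged")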
+ """ + if not datasets: + raise ValueError("No datasets to merge") + + output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / output_repo_id + + repo_ids = [ds.repo_id for ds in datasets] + roots = [ds.root for ds in datasets] + + aggregate_datasets( + repo_ids=repo_ids, + aggr_repo_id=output_repo_id, + roots=roots, + aggr_root=output_dir, + ) + + merged_dataset = LeRobotDataset( + repo_id=output_repo_id, + root=output_dir, + image_transforms=datasets[0].image_transforms, + delta_timestamps=datasets[0].delta_timestamps, + tolerance_s=datasets[0].tolerance_s, + ) + + return merged_dataset + + +def add_feature( + dataset: LeRobotDataset, + feature_name: str, + feature_values: np.ndarray | torch.Tensor | Callable, + feature_info: dict, + output_dir: str | Path | None = None, + repo_id: str | None = None, +) -> LeRobotDataset: + """Add a new feature to a LeRobotDataset. + + Args: + dataset: The source LeRobotDataset. + feature_name: Name of the new feature. + feature_values: Either: + - Array/tensor of shape (num_frames, ...) with values for each frame + - Callable that takes (frame_dict, episode_index, frame_index) and returns feature value + feature_info: Dictionary with feature metadata (dtype, shape, names). + output_dir: Directory to save the new dataset. If None, uses default location. + repo_id: Repository ID for the new dataset. If None, appends "_modified" to original. + """ + if feature_name in dataset.meta.features: + raise ValueError(f"Feature '{feature_name}' already exists in dataset") + + if repo_id is None: + repo_id = f"{dataset.repo_id}_modified" + output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id + + required_keys = {"dtype", "shape"} + if not required_keys.issubset(feature_info.keys()): + raise ValueError(f"feature_info must contain keys: {required_keys}") + + new_features = dataset.meta.features.copy() + new_features[feature_name] = feature_info + + new_meta = LeRobotDatasetMetadata.create( + repo_id=repo_id, + fps=dataset.meta.fps, + features=new_features, + robot_type=dataset.meta.robot_type, + root=output_dir, + use_videos=len(dataset.meta.video_keys) > 0, + ) + + _copy_data_with_feature_changes( + dataset=dataset, + new_meta=new_meta, + add_features={feature_name: (feature_values, feature_info)}, + ) + + if dataset.meta.video_keys: + _copy_videos(dataset, new_meta) + + new_dataset = LeRobotDataset( + repo_id=repo_id, + root=output_dir, + image_transforms=dataset.image_transforms, + delta_timestamps=dataset.delta_timestamps, + tolerance_s=dataset.tolerance_s, + ) + + return new_dataset + + +def remove_feature( + dataset: LeRobotDataset, + feature_names: str | list[str], + output_dir: str | Path | None = None, + repo_id: str | None = None, +) -> LeRobotDataset: + """Remove features from a LeRobotDataset. + + Args: + dataset: The source LeRobotDataset. + feature_names: Name(s) of features to remove. Can be a single string or list. + output_dir: Directory to save the new dataset. If None, uses default location. + repo_id: Repository ID for the new dataset. If None, appends "_modified" to original. 
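+
+    Example:
+        cleaned = remove_feature(dataset, feature_names="success", repo_id="lerobot/pusht_cleaned")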
+ + """ + if isinstance(feature_names, str): + feature_names = [feature_names] + + for name in feature_names: + if name not in dataset.meta.features: + raise ValueError(f"Feature '{name}' not found in dataset") + + required_features = {"timestamp", "frame_index", "episode_index", "index", "task_index"} + if any(name in required_features for name in feature_names): + raise ValueError(f"Cannot remove required features: {required_features}") + + if repo_id is None: + repo_id = f"{dataset.repo_id}_modified" + output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id + + new_features = {k: v for k, v in dataset.meta.features.items() if k not in feature_names} + + video_keys_to_remove = [name for name in feature_names if name in dataset.meta.video_keys] + + remaining_video_keys = [k for k in dataset.meta.video_keys if k not in video_keys_to_remove] + + new_meta = LeRobotDatasetMetadata.create( + repo_id=repo_id, + fps=dataset.meta.fps, + features=new_features, + robot_type=dataset.meta.robot_type, + root=output_dir, + use_videos=len(remaining_video_keys) > 0, + ) + + _copy_data_with_feature_changes( + dataset=dataset, + new_meta=new_meta, + remove_features=feature_names, + ) + + if new_meta.video_keys: + _copy_videos(dataset, new_meta, exclude_keys=video_keys_to_remove) + + new_dataset = LeRobotDataset( + repo_id=repo_id, + root=output_dir, + image_transforms=dataset.image_transforms, + delta_timestamps=dataset.delta_timestamps, + tolerance_s=dataset.tolerance_s, + ) + + return new_dataset + + +def _fractions_to_episode_indices( + total_episodes: int, + splits: dict[str, float], +) -> dict[str, list[int]]: + """Convert split fractions to episode indices.""" + if sum(splits.values()) > 1.0: + raise ValueError("Split fractions must sum to <= 1.0") + + indices = list(range(total_episodes)) + result = {} + start_idx = 0 + + for split_name, fraction in splits.items(): + num_episodes = int(total_episodes * fraction) + if num_episodes == 0: + logging.warning(f"Split '{split_name}' has no episodes, skipping...") + continue + end_idx = start_idx + num_episodes + if split_name == list(splits.keys())[-1]: + end_idx = total_episodes + result[split_name] = indices[start_idx:end_idx] + start_idx = end_idx + + return result + + +def _copy_and_reindex_data( + src_dataset: LeRobotDataset, + dst_meta: LeRobotDatasetMetadata, + episode_mapping: dict[int, int], +) -> dict[int, dict]: + """Copy and filter data files, only modifying files with deleted episodes. + + Args: + src_dataset: Source dataset to copy from + dst_meta: Destination metadata object + episode_mapping: Mapping from old episode indices to new indices + + Returns: + dict mapping episode index to its data file metadata (chunk_index, file_index, etc.) 
+ """ + file_to_episodes: dict[Path, set[int]] = {} + for old_idx in episode_mapping: + file_path = src_dataset.meta.get_data_file_path(old_idx) + if file_path not in file_to_episodes: + file_to_episodes[file_path] = set() + file_to_episodes[file_path].add(old_idx) + + global_index = 0 + episode_data_metadata: dict[int, dict] = {} + + if dst_meta.tasks is None: + all_task_indices = set() + for src_path in file_to_episodes: + df = pd.read_parquet(src_dataset.root / src_path) + mask = df["episode_index"].isin(list(episode_mapping.keys())) + task_series: pd.Series = df[mask]["task_index"] + all_task_indices.update(task_series.unique().tolist()) + tasks = [src_dataset.meta.tasks.iloc[idx].name for idx in all_task_indices] + dst_meta.save_episode_tasks(list(set(tasks))) + + task_mapping = {} + for old_task_idx in range(len(src_dataset.meta.tasks)): + task_name = src_dataset.meta.tasks.iloc[old_task_idx].name + new_task_idx = dst_meta.get_task_index(task_name) + if new_task_idx is not None: + task_mapping[old_task_idx] = new_task_idx + + for src_path in tqdm(sorted(file_to_episodes.keys()), desc="Processing data files"): + df = pd.read_parquet(src_dataset.root / src_path) + + all_episodes_in_file = set(df["episode_index"].unique()) + episodes_to_keep = file_to_episodes[src_path] + + if all_episodes_in_file == episodes_to_keep: + df["episode_index"] = df["episode_index"].replace(episode_mapping) + df["index"] = range(global_index, global_index + len(df)) + df["task_index"] = df["task_index"].replace(task_mapping) + + first_ep_old_idx = min(episodes_to_keep) + src_ep = src_dataset.meta.episodes[first_ep_old_idx] + chunk_idx = src_ep["data/chunk_index"] + file_idx = src_ep["data/file_index"] + else: + mask = df["episode_index"].isin(list(episode_mapping.keys())) + df = df[mask].copy().reset_index(drop=True) + + if len(df) == 0: + continue + + df["episode_index"] = df["episode_index"].replace(episode_mapping) + df["index"] = range(global_index, global_index + len(df)) + df["task_index"] = df["task_index"].replace(task_mapping) + + first_ep_old_idx = min(episodes_to_keep) + src_ep = src_dataset.meta.episodes[first_ep_old_idx] + chunk_idx = src_ep["data/chunk_index"] + file_idx = src_ep["data/file_index"] + + dst_path = dst_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + dst_path.parent.mkdir(parents=True, exist_ok=True) + + if len(dst_meta.image_keys) > 0: + to_parquet_with_hf_images(df, dst_path) + else: + df.to_parquet(dst_path, index=False) + + for ep_old_idx in episodes_to_keep: + ep_new_idx = episode_mapping[ep_old_idx] + ep_df = df[df["episode_index"] == ep_new_idx] + episode_data_metadata[ep_new_idx] = { + "data/chunk_index": chunk_idx, + "data/file_index": file_idx, + "dataset_from_index": int(ep_df["index"].min()), + "dataset_to_index": int(ep_df["index"].max() + 1), + } + + global_index += len(df) + + return episode_data_metadata + + +def _keep_episodes_from_video_with_av( + input_path: Path, + output_path: Path, + episodes_to_keep: list[tuple[float, float]], + fps: float, + vcodec: str = "libsvtav1", + pix_fmt: str = "yuv420p", +) -> None: + """Keep only specified episodes from a video file using PyAV. + + This function decodes frames from specified time ranges and re-encodes them with + properly reset timestamps to ensure monotonic progression. + + Args: + input_path: Source video file path. + output_path: Destination video file path. + episodes_to_keep: List of (start_time, end_time) tuples for episodes to keep. + fps: Frame rate of the video. 
+ vcodec: Video codec to use for encoding. + pix_fmt: Pixel format for output video. + """ + from fractions import Fraction + + import av + + if not episodes_to_keep: + raise ValueError("No episodes to keep") + + in_container = av.open(str(input_path)) + + # Check if video stream exists. + if not in_container.streams.video: + raise ValueError( + f"No video streams found in {input_path}. " + "The video file may be corrupted or empty. " + "Try re-downloading the dataset or checking the video file." + ) + + v_in = in_container.streams.video[0] + + out = av.open(str(output_path), mode="w") + + # Convert fps to Fraction for PyAV compatibility. + fps_fraction = Fraction(fps).limit_denominator(1000) + v_out = out.add_stream(vcodec, rate=fps_fraction) + + # PyAV type stubs don't distinguish video streams from audio/subtitle streams. + v_out.width = v_in.codec_context.width + v_out.height = v_in.codec_context.height + v_out.pix_fmt = pix_fmt + + # Set time_base to match the frame rate for proper timestamp handling. + v_out.time_base = Fraction(1, int(fps)) + + out.start_encoding() + + # Create set of (start, end) ranges for fast lookup. + # Convert to a sorted list for efficient checking. + time_ranges = sorted(episodes_to_keep) + + # Track frame index for setting PTS and current range being processed. + frame_count = 0 + range_idx = 0 + + # Read through entire video once and filter frames. + for packet in in_container.demux(v_in): + for frame in packet.decode(): + if frame is None: + continue + + # Get frame timestamp. + frame_time = float(frame.pts * frame.time_base) if frame.pts is not None else 0.0 + + # Check if frame is in any of our desired time ranges. + # Skip ranges that have already passed. + while range_idx < len(time_ranges) and frame_time >= time_ranges[range_idx][1]: + range_idx += 1 + + # If we've passed all ranges, stop processing. + if range_idx >= len(time_ranges): + break + + # Check if frame is in current range. + start_ts, end_ts = time_ranges[range_idx] + if frame_time < start_ts: + continue + + # Frame is in range - create a new frame with reset timestamps. + # We need to create a copy to avoid modifying the original. + new_frame = frame.reformat(width=v_out.width, height=v_out.height, format=v_out.pix_fmt) + new_frame.pts = frame_count + new_frame.time_base = Fraction(1, int(fps)) + + # Encode and mux the frame. + for pkt in v_out.encode(new_frame): + out.mux(pkt) + + frame_count += 1 + + # Flush encoder. + for pkt in v_out.encode(): + out.mux(pkt) + + out.close() + in_container.close() + + +def _copy_and_reindex_videos( + src_dataset: LeRobotDataset, + dst_meta: LeRobotDatasetMetadata, + episode_mapping: dict[int, int], + vcodec: str = "libsvtav1", + pix_fmt: str = "yuv420p", +) -> dict[int, dict]: + """Copy and filter video files, only re-encoding files with deleted episodes. + + For video files that only contain kept episodes, we copy them directly. + For files with mixed kept/deleted episodes, we use PyAV filters to efficiently + re-encode only the desired segments. 
+ + Args: + src_dataset: Source dataset to copy from + dst_meta: Destination metadata object + episode_mapping: Mapping from old episode indices to new indices + + Returns: + dict mapping episode index to its video metadata (chunk_index, file_index, timestamps) + """ + + episodes_video_metadata: dict[int, dict] = {new_idx: {} for new_idx in episode_mapping.values()} + + for video_key in src_dataset.meta.video_keys: + logging.info(f"Processing videos for {video_key}") + + if dst_meta.video_path is None: + raise ValueError("Destination metadata has no video_path defined") + + file_to_episodes: dict[tuple[int, int], list[int]] = {} + for old_idx in episode_mapping: + src_ep = src_dataset.meta.episodes[old_idx] + chunk_idx = src_ep[f"videos/{video_key}/chunk_index"] + file_idx = src_ep[f"videos/{video_key}/file_index"] + file_key = (chunk_idx, file_idx) + if file_key not in file_to_episodes: + file_to_episodes[file_key] = [] + file_to_episodes[file_key].append(old_idx) + + for (src_chunk_idx, src_file_idx), episodes_in_file in tqdm( + sorted(file_to_episodes.items()), desc=f"Processing {video_key} video files" + ): + all_episodes_in_file = [ + ep_idx + for ep_idx in range(src_dataset.meta.total_episodes) + if src_dataset.meta.episodes[ep_idx].get(f"videos/{video_key}/chunk_index") == src_chunk_idx + and src_dataset.meta.episodes[ep_idx].get(f"videos/{video_key}/file_index") == src_file_idx + ] + + episodes_to_keep_set = set(episodes_in_file) + all_in_file_set = set(all_episodes_in_file) + + if all_in_file_set == episodes_to_keep_set: + assert src_dataset.meta.video_path is not None + src_video_path = src_dataset.root / src_dataset.meta.video_path.format( + video_key=video_key, chunk_index=src_chunk_idx, file_index=src_file_idx + ) + dst_video_path = dst_meta.root / dst_meta.video_path.format( + video_key=video_key, chunk_index=src_chunk_idx, file_index=src_file_idx + ) + dst_video_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy(src_video_path, dst_video_path) + + for old_idx in episodes_in_file: + new_idx = episode_mapping[old_idx] + src_ep = src_dataset.meta.episodes[old_idx] + episodes_video_metadata[new_idx][f"videos/{video_key}/chunk_index"] = src_chunk_idx + episodes_video_metadata[new_idx][f"videos/{video_key}/file_index"] = src_file_idx + episodes_video_metadata[new_idx][f"videos/{video_key}/from_timestamp"] = src_ep[ + f"videos/{video_key}/from_timestamp" + ] + episodes_video_metadata[new_idx][f"videos/{video_key}/to_timestamp"] = src_ep[ + f"videos/{video_key}/to_timestamp" + ] + else: + # Build list of time ranges to keep, in sorted order. + sorted_keep_episodes = sorted(episodes_in_file, key=lambda x: episode_mapping[x]) + episodes_to_keep_ranges: list[tuple[float, float]] = [] + + for old_idx in sorted_keep_episodes: + src_ep = src_dataset.meta.episodes[old_idx] + from_ts = src_ep[f"videos/{video_key}/from_timestamp"] + to_ts = src_ep[f"videos/{video_key}/to_timestamp"] + episodes_to_keep_ranges.append((from_ts, to_ts)) + + # Use PyAV filters to efficiently re-encode only the desired segments. 
+ assert src_dataset.meta.video_path is not None + src_video_path = src_dataset.root / src_dataset.meta.video_path.format( + video_key=video_key, chunk_index=src_chunk_idx, file_index=src_file_idx + ) + dst_video_path = dst_meta.root / dst_meta.video_path.format( + video_key=video_key, chunk_index=src_chunk_idx, file_index=src_file_idx + ) + dst_video_path.parent.mkdir(parents=True, exist_ok=True) + + logging.info( + f"Re-encoding {video_key} (chunk {src_chunk_idx}, file {src_file_idx}) " + f"with {len(episodes_to_keep_ranges)} episodes" + ) + _keep_episodes_from_video_with_av( + src_video_path, + dst_video_path, + episodes_to_keep_ranges, + src_dataset.meta.fps, + vcodec, + pix_fmt, + ) + + cumulative_ts = 0.0 + for old_idx in sorted_keep_episodes: + new_idx = episode_mapping[old_idx] + src_ep = src_dataset.meta.episodes[old_idx] + ep_length = src_ep["length"] + ep_duration = ep_length / src_dataset.meta.fps + + episodes_video_metadata[new_idx][f"videos/{video_key}/chunk_index"] = src_chunk_idx + episodes_video_metadata[new_idx][f"videos/{video_key}/file_index"] = src_file_idx + episodes_video_metadata[new_idx][f"videos/{video_key}/from_timestamp"] = cumulative_ts + episodes_video_metadata[new_idx][f"videos/{video_key}/to_timestamp"] = ( + cumulative_ts + ep_duration + ) + + cumulative_ts += ep_duration + + return episodes_video_metadata + + +def _copy_and_reindex_episodes_metadata( + src_dataset: LeRobotDataset, + dst_meta: LeRobotDatasetMetadata, + episode_mapping: dict[int, int], + data_metadata: dict[int, dict], + video_metadata: dict[int, dict] | None = None, +) -> None: + """Copy and reindex episodes metadata using provided data and video metadata. + + Args: + src_dataset: Source dataset to copy from + dst_meta: Destination metadata object + episode_mapping: Mapping from old episode indices to new indices + data_metadata: Dict mapping new episode index to its data file metadata + video_metadata: Optional dict mapping new episode index to its video metadata + """ + from lerobot.datasets.utils import flatten_dict + + all_stats = [] + total_frames = 0 + + for old_idx, new_idx in tqdm( + sorted(episode_mapping.items(), key=lambda x: x[1]), desc="Processing episodes metadata" + ): + src_episode_full = _load_episode_with_stats(src_dataset, old_idx) + + src_episode = src_dataset.meta.episodes[old_idx] + + episode_meta = data_metadata[new_idx].copy() + + if video_metadata and new_idx in video_metadata: + episode_meta.update(video_metadata[new_idx]) + + # Extract episode statistics from parquet metadata. + # Note (maractingi): When pandas/pyarrow serializes numpy arrays with shape (3, 1, 1) to parquet, + # they are being deserialized as nested object arrays like: + # array([array([array([0.])]), array([array([0.])]), array([array([0.])])]) + # This happens particularly with image/video statistics. We need to detect and flatten + # these nested structures back to proper (3, 1, 1) arrays so aggregate_stats can process them. 
+ episode_stats = {} + for key in src_episode_full: + if key.startswith("stats/"): + stat_key = key.replace("stats/", "") + parts = stat_key.split("/") + if len(parts) == 2: + feature_name, stat_name = parts + if feature_name not in episode_stats: + episode_stats[feature_name] = {} + + value = src_episode_full[key] + + if feature_name in src_dataset.meta.features: + feature_dtype = src_dataset.meta.features[feature_name]["dtype"] + if feature_dtype in ["image", "video"] and stat_name != "count": + if isinstance(value, np.ndarray) and value.dtype == object: + flat_values = [] + for item in value: + while isinstance(item, np.ndarray): + item = item.flatten()[0] + flat_values.append(item) + value = np.array(flat_values, dtype=np.float64).reshape(3, 1, 1) + elif isinstance(value, np.ndarray) and value.shape == (3,): + value = value.reshape(3, 1, 1) + + episode_stats[feature_name][stat_name] = value + + all_stats.append(episode_stats) + + episode_dict = { + "episode_index": new_idx, + "tasks": src_episode["tasks"], + "length": src_episode["length"], + } + episode_dict.update(episode_meta) + episode_dict.update(flatten_dict({"stats": episode_stats})) + dst_meta._save_episode_metadata(episode_dict) + + total_frames += src_episode["length"] + + dst_meta.info.update( + { + "total_episodes": len(episode_mapping), + "total_frames": total_frames, + "total_tasks": len(dst_meta.tasks) if dst_meta.tasks is not None else 0, + "splits": {"train": f"0:{len(episode_mapping)}"}, + } + ) + write_info(dst_meta.info, dst_meta.root) + + if not all_stats: + logging.warning("No statistics found to aggregate") + return + + logging.info(f"Aggregating statistics for {len(all_stats)} episodes") + aggregated_stats = aggregate_stats(all_stats) + filtered_stats = {k: v for k, v in aggregated_stats.items() if k in dst_meta.features} + write_stats(filtered_stats, dst_meta.root) + + +def _save_data_chunk( + df: pd.DataFrame, + meta: LeRobotDatasetMetadata, + chunk_idx: int = 0, + file_idx: int = 0, +) -> tuple[int, int, dict[int, dict]]: + """Save a data chunk and return updated indices and episode metadata. 
+ + Returns: + tuple: (next_chunk_idx, next_file_idx, episode_metadata_dict) + where episode_metadata_dict maps episode_index to its data file metadata + """ + path = meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + path.parent.mkdir(parents=True, exist_ok=True) + + if len(meta.image_keys) > 0: + to_parquet_with_hf_images(df, path) + else: + df.to_parquet(path, index=False) + + episode_metadata = {} + for ep_idx in df["episode_index"].unique(): + ep_df = df[df["episode_index"] == ep_idx] + episode_metadata[ep_idx] = { + "data/chunk_index": chunk_idx, + "data/file_index": file_idx, + "dataset_from_index": int(ep_df["index"].min()), + "dataset_to_index": int(ep_df["index"].max() + 1), + } + + file_size = get_parquet_file_size_in_mb(path) + if file_size >= DEFAULT_DATA_FILE_SIZE_IN_MB * 0.9: + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE) + + return chunk_idx, file_idx, episode_metadata + + +def _copy_data_with_feature_changes( + dataset: LeRobotDataset, + new_meta: LeRobotDatasetMetadata, + add_features: dict[str, tuple] | None = None, + remove_features: list[str] | None = None, +) -> None: + """Copy data while adding or removing features.""" + file_paths = set() + for ep_idx in range(dataset.meta.total_episodes): + file_paths.add(dataset.meta.get_data_file_path(ep_idx)) + + frame_idx = 0 + + for src_path in tqdm(sorted(file_paths), desc="Processing data files"): + df = pd.read_parquet(dataset.root / src_path).reset_index(drop=True) + + if remove_features: + df = df.drop(columns=remove_features, errors="ignore") + + if add_features: + for feature_name, (values, _) in add_features.items(): + if callable(values): + feature_values = [] + for _, row in df.iterrows(): + ep_idx = row["episode_index"] + frame_in_ep = row["frame_index"] + value = values(row.to_dict(), ep_idx, frame_in_ep) + if isinstance(value, np.ndarray) and value.size == 1: + value = value.item() + feature_values.append(value) + df[feature_name] = feature_values + else: + end_idx = frame_idx + len(df) + feature_slice = values[frame_idx:end_idx] + if len(feature_slice.shape) > 1 and feature_slice.shape[1] == 1: + df[feature_name] = feature_slice.flatten() + else: + df[feature_name] = feature_slice + frame_idx = end_idx + + _save_data_chunk(df, new_meta) + + _copy_episodes_metadata_and_stats(dataset, new_meta) + + +def _copy_videos( + src_dataset: LeRobotDataset, + dst_meta: LeRobotDatasetMetadata, + exclude_keys: list[str] | None = None, +) -> None: + """Copy video files, optionally excluding certain keys.""" + if exclude_keys is None: + exclude_keys = [] + + for video_key in src_dataset.meta.video_keys: + if video_key in exclude_keys: + continue + + video_files = set() + for ep_idx in range(len(src_dataset.meta.episodes)): + try: + video_files.add(src_dataset.meta.get_video_file_path(ep_idx, video_key)) + except KeyError: + continue + + for src_path in tqdm(sorted(video_files), desc=f"Copying {video_key} videos"): + dst_path = dst_meta.root / src_path + dst_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy(src_dataset.root / src_path, dst_path) + + +def _copy_episodes_metadata_and_stats( + src_dataset: LeRobotDataset, + dst_meta: LeRobotDatasetMetadata, +) -> None: + """Copy episodes metadata and recalculate stats.""" + if src_dataset.meta.tasks is not None: + write_tasks(src_dataset.meta.tasks, dst_meta.root) + dst_meta.tasks = src_dataset.meta.tasks.copy() + + episodes_dir = src_dataset.root / "meta/episodes" + dst_episodes_dir = 
dst_meta.root / "meta/episodes" + if episodes_dir.exists(): + shutil.copytree(episodes_dir, dst_episodes_dir, dirs_exist_ok=True) + + dst_meta.info.update( + { + "total_episodes": src_dataset.meta.total_episodes, + "total_frames": src_dataset.meta.total_frames, + "total_tasks": src_dataset.meta.total_tasks, + "splits": src_dataset.meta.info.get("splits", {"train": f"0:{src_dataset.meta.total_episodes}"}), + } + ) + + if dst_meta.video_keys and src_dataset.meta.video_keys: + for key in dst_meta.video_keys: + if key in src_dataset.meta.features: + dst_meta.info["features"][key]["info"] = src_dataset.meta.info["features"][key].get( + "info", {} + ) + + write_info(dst_meta.info, dst_meta.root) + + if set(dst_meta.features.keys()) != set(src_dataset.meta.features.keys()): + logging.info("Recalculating dataset statistics...") + if src_dataset.meta.stats: + new_stats = {} + for key in dst_meta.features: + if key in src_dataset.meta.stats: + new_stats[key] = src_dataset.meta.stats[key] + write_stats(new_stats, dst_meta.root) + else: + if src_dataset.meta.stats: + write_stats(src_dataset.meta.stats, dst_meta.root) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index b661b21b0..229d37641 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -438,6 +438,9 @@ class LeRobotDatasetMetadata: robot_type: str | None = None, root: str | Path | None = None, use_videos: bool = True, + chunks_size: int | None = None, + data_files_size_in_mb: int | None = None, + video_files_size_in_mb: int | None = None, ) -> "LeRobotDatasetMetadata": """Creates metadata for a LeRobotDataset.""" obj = cls.__new__(cls) @@ -452,7 +455,16 @@ class LeRobotDatasetMetadata: obj.tasks = None obj.episodes = None obj.stats = None - obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, features, use_videos, robot_type) + obj.info = create_empty_dataset_info( + CODEBASE_VERSION, + fps, + features, + use_videos, + robot_type, + chunks_size, + data_files_size_in_mb, + video_files_size_in_mb, + ) if len(obj.video_keys) > 0 and not use_videos: raise ValueError() write_json(obj.info, obj.root / INFO_PATH) diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index a2f285014..422a7010a 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -30,7 +30,7 @@ import pandas import pandas as pd import pyarrow.parquet as pq import torch -from datasets import Dataset, concatenate_datasets +from datasets import Dataset from datasets.table import embed_table_storage from huggingface_hub import DatasetCard, DatasetCardData, HfApi from huggingface_hub.errors import RevisionNotFoundError @@ -44,7 +44,7 @@ from lerobot.datasets.backward_compatibility import ( ForwardCompatibilityError, ) from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR -from lerobot.utils.utils import is_valid_numpy_dtype_string +from lerobot.utils.utils import SuppressProgressBars, is_valid_numpy_dtype_string DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file @@ -123,8 +123,9 @@ def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None) raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}") # TODO(rcadene): set num_proc to accelerate conversion to pyarrow - datasets = [Dataset.from_parquet(str(path), features=features) for path in paths] - return concatenate_datasets(datasets) + with SuppressProgressBars(): 
+ datasets = Dataset.from_parquet([str(path) for path in paths], features=features) + return datasets def get_parquet_num_frames(parquet_path: str | Path) -> int: diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 1d4f07c76..620ba863a 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -452,6 +452,9 @@ def concatenate_video_files( template=input_stream, opaque=True ) + # set the time base to the input stream time base (missing in the codec context) + stream_map[input_stream.index].time_base = input_stream.time_base + # Demux + remux packets (no re-encode) for packet in input_container.demux(): # Skip packets from un-mapped streams diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py new file mode 100644 index 000000000..83ba027bc --- /dev/null +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Edit LeRobot datasets using various transformation tools. + +This script allows you to delete episodes, split datasets, merge datasets, +and remove features. When new_repo_id is specified, creates a new dataset. 
+ +Usage Examples: + +Delete episodes 0, 2, and 5 from a dataset: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type delete_episodes \ + --operation.episode_indices "[0, 2, 5]" + +Delete episodes and save to a new dataset: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --new_repo_id lerobot/pusht_filtered \ + --operation.type delete_episodes \ + --operation.episode_indices "[0, 2, 5]" + +Split dataset by fractions: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type split \ + --operation.splits '{"train": 0.8, "val": 0.2}' + +Split dataset by episode indices: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type split \ + --operation.splits '{"train": [0, 1, 2, 3], "val": [4, 5]}' + +Split into more than two splits: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type split \ + --operation.splits '{"train": 0.6, "val": 0.2, "test": 0.2}' + +Merge multiple datasets: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht_merged \ + --operation.type merge \ + --operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']" + +Remove camera feature: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type remove_feature \ + --operation.feature_names "['observation.images.top']" + +Using JSON config file: + python -m lerobot.scripts.lerobot_edit_dataset \ + --config_path path/to/edit_config.json +""" + +import logging +import shutil +from dataclasses import dataclass +from pathlib import Path + +from lerobot.configs import parser +from lerobot.datasets.dataset_tools import ( + delete_episodes, + merge_datasets, + remove_feature, + split_dataset, +) +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.utils.constants import HF_LEROBOT_HOME +from lerobot.utils.utils import init_logging + + +@dataclass +class DeleteEpisodesConfig: + type: str = "delete_episodes" + episode_indices: list[int] | None = None + + +@dataclass +class SplitConfig: + type: str = "split" + splits: dict[str, float | list[int]] | None = None + + +@dataclass +class MergeConfig: + type: str = "merge" + repo_ids: list[str] | None = None + + +@dataclass +class RemoveFeatureConfig: + type: str = "remove_feature" + feature_names: list[str] | None = None + + +@dataclass +class EditDatasetConfig: + repo_id: str + operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig + root: str | None = None + new_repo_id: str | None = None + push_to_hub: bool = False + + +def get_output_path(repo_id: str, new_repo_id: str | None, root: Path | None) -> tuple[str, Path]: + if new_repo_id: + output_repo_id = new_repo_id + output_dir = root / new_repo_id if root else HF_LEROBOT_HOME / new_repo_id + else: + output_repo_id = repo_id + dataset_path = root / repo_id if root else HF_LEROBOT_HOME / repo_id + old_path = Path(str(dataset_path) + "_old") + + if dataset_path.exists(): + if old_path.exists(): + shutil.rmtree(old_path) + shutil.move(str(dataset_path), str(old_path)) + + output_dir = dataset_path + + return output_repo_id, output_dir + + +def handle_delete_episodes(cfg: EditDatasetConfig) -> None: + if not isinstance(cfg.operation, DeleteEpisodesConfig): + raise ValueError("Operation config must be DeleteEpisodesConfig") + + if not cfg.operation.episode_indices: + raise ValueError("episode_indices must be specified for 
delete_episodes operation") + + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) + output_repo_id, output_dir = get_output_path( + cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None + ) + + if cfg.new_repo_id is None: + dataset.root = Path(str(dataset.root) + "_old") + + logging.info(f"Deleting episodes {cfg.operation.episode_indices} from {cfg.repo_id}") + new_dataset = delete_episodes( + dataset, + episode_indices=cfg.operation.episode_indices, + output_dir=output_dir, + repo_id=output_repo_id, + ) + + logging.info(f"Dataset saved to {output_dir}") + logging.info(f"Episodes: {new_dataset.meta.total_episodes}, Frames: {new_dataset.meta.total_frames}") + + if cfg.push_to_hub: + logging.info(f"Pushing to hub as {output_repo_id}") + LeRobotDataset(output_repo_id, root=output_dir).push_to_hub() + + +def handle_split(cfg: EditDatasetConfig) -> None: + if not isinstance(cfg.operation, SplitConfig): + raise ValueError("Operation config must be SplitConfig") + + if not cfg.operation.splits: + raise ValueError( + "splits dict must be specified with split names as keys and fractions/episode lists as values" + ) + + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) + + logging.info(f"Splitting dataset {cfg.repo_id} with splits: {cfg.operation.splits}") + split_datasets = split_dataset(dataset, splits=cfg.operation.splits) + + for split_name, split_ds in split_datasets.items(): + split_repo_id = f"{cfg.repo_id}_{split_name}" + logging.info( + f"{split_name}: {split_ds.meta.total_episodes} episodes, {split_ds.meta.total_frames} frames" + ) + + if cfg.push_to_hub: + logging.info(f"Pushing {split_name} split to hub as {split_repo_id}") + LeRobotDataset(split_ds.repo_id, root=split_ds.root).push_to_hub() + + +def handle_merge(cfg: EditDatasetConfig) -> None: + if not isinstance(cfg.operation, MergeConfig): + raise ValueError("Operation config must be MergeConfig") + + if not cfg.operation.repo_ids: + raise ValueError("repo_ids must be specified for merge operation") + + if not cfg.repo_id: + raise ValueError("repo_id must be specified as the output repository for merged dataset") + + logging.info(f"Loading {len(cfg.operation.repo_ids)} datasets to merge") + datasets = [LeRobotDataset(repo_id, root=cfg.root) for repo_id in cfg.operation.repo_ids] + + output_dir = Path(cfg.root) / cfg.repo_id if cfg.root else HF_LEROBOT_HOME / cfg.repo_id + + logging.info(f"Merging datasets into {cfg.repo_id}") + merged_dataset = merge_datasets( + datasets, + output_repo_id=cfg.repo_id, + output_dir=output_dir, + ) + + logging.info(f"Merged dataset saved to {output_dir}") + logging.info( + f"Episodes: {merged_dataset.meta.total_episodes}, Frames: {merged_dataset.meta.total_frames}" + ) + + if cfg.push_to_hub: + logging.info(f"Pushing to hub as {cfg.repo_id}") + LeRobotDataset(merged_dataset.repo_id, root=output_dir).push_to_hub() + + +def handle_remove_feature(cfg: EditDatasetConfig) -> None: + if not isinstance(cfg.operation, RemoveFeatureConfig): + raise ValueError("Operation config must be RemoveFeatureConfig") + + if not cfg.operation.feature_names: + raise ValueError("feature_names must be specified for remove_feature operation") + + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) + output_repo_id, output_dir = get_output_path( + cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None + ) + + if cfg.new_repo_id is None: + dataset.root = Path(str(dataset.root) + "_old") + + logging.info(f"Removing features {cfg.operation.feature_names} from {cfg.repo_id}") + new_dataset = 
remove_feature( + dataset, + feature_names=cfg.operation.feature_names, + output_dir=output_dir, + repo_id=output_repo_id, + ) + + logging.info(f"Dataset saved to {output_dir}") + logging.info(f"Remaining features: {list(new_dataset.meta.features.keys())}") + + if cfg.push_to_hub: + logging.info(f"Pushing to hub as {output_repo_id}") + LeRobotDataset(output_repo_id, root=output_dir).push_to_hub() + + +@parser.wrap() +def edit_dataset(cfg: EditDatasetConfig) -> None: + operation_type = cfg.operation.type + + if operation_type == "delete_episodes": + handle_delete_episodes(cfg) + elif operation_type == "split": + handle_split(cfg) + elif operation_type == "merge": + handle_merge(cfg) + elif operation_type == "remove_feature": + handle_remove_feature(cfg) + else: + raise ValueError( + f"Unknown operation type: {operation_type}\n" + f"Available operations: delete_episodes, split, merge, remove_feature" + ) + + +def main() -> None: + init_logging() + edit_dataset() + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py index 8777d5a9d..dfcd4a6b1 100644 --- a/src/lerobot/utils/utils.py +++ b/src/lerobot/utils/utils.py @@ -27,6 +27,7 @@ from statistics import mean import numpy as np import torch +from datasets.utils.logging import disable_progress_bar, enable_progress_bar def inside_slurm(): @@ -247,6 +248,25 @@ def get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time_s: float): return days, hours, minutes, seconds +class SuppressProgressBars: + """ + Context manager to suppress progress bars. + + Example + -------- + ```python + with SuppressProgressBars(): + # Code that would normally show progress bars + ``` + """ + + def __enter__(self): + disable_progress_bar() + + def __exit__(self, exc_type, exc_val, exc_tb): + enable_progress_bar() + + class TimerManager: """ Lightweight utility to measure elapsed time. diff --git a/tests/datasets/test_aggregate.py b/tests/datasets/test_aggregate.py index 4f316f80e..b710a3a4b 100644 --- a/tests/datasets/test_aggregate.py +++ b/tests/datasets/test_aggregate.py @@ -181,6 +181,54 @@ def assert_dataset_iteration_works(aggr_ds): pass +def assert_video_timestamps_within_bounds(aggr_ds): + """Test that all video timestamps are within valid bounds for their respective video files. + + This catches bugs where timestamps point to frames beyond the actual video length, + which would cause "Invalid frame index" errors during data loading. 
+ """ + try: + from torchcodec.decoders import VideoDecoder + except ImportError: + return + + for ep_idx in range(aggr_ds.num_episodes): + ep = aggr_ds.meta.episodes[ep_idx] + + for vid_key in aggr_ds.meta.video_keys: + from_ts = ep[f"videos/{vid_key}/from_timestamp"] + to_ts = ep[f"videos/{vid_key}/to_timestamp"] + video_path = aggr_ds.root / aggr_ds.meta.get_video_file_path(ep_idx, vid_key) + + if not video_path.exists(): + continue + + from_frame_idx = round(from_ts * aggr_ds.fps) + to_frame_idx = round(to_ts * aggr_ds.fps) + + try: + decoder = VideoDecoder(str(video_path)) + num_frames = len(decoder) + + # Verify timestamps don't exceed video bounds + assert from_frame_idx >= 0, ( + f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) < 0" + ) + assert from_frame_idx < num_frames, ( + f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) >= video frames ({num_frames})" + ) + assert to_frame_idx <= num_frames, ( + f"Episode {ep_idx}, {vid_key}: to_frame_idx ({to_frame_idx}) > video frames ({num_frames})" + ) + assert from_frame_idx < to_frame_idx, ( + f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) >= to_frame_idx ({to_frame_idx})" + ) + except Exception as e: + raise AssertionError( + f"Failed to verify timestamps for episode {ep_idx}, {vid_key}: {e}" + ) from e + + def test_aggregate_datasets(tmp_path, lerobot_dataset_factory): """Test basic aggregation functionality with standard parameters.""" ds_0_num_frames = 400 @@ -227,6 +275,7 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory): assert_metadata_consistency(aggr_ds, ds_0, ds_1) assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1) assert_video_frames_integrity(aggr_ds, ds_0, ds_1) + assert_video_timestamps_within_bounds(aggr_ds) assert_dataset_iteration_works(aggr_ds) @@ -277,6 +326,7 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory): assert_metadata_consistency(aggr_ds, ds_0, ds_1) assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1) assert_video_frames_integrity(aggr_ds, ds_0, ds_1) + assert_video_timestamps_within_bounds(aggr_ds) assert_dataset_iteration_works(aggr_ds) # Check that multiple files were actually created due to small size limits @@ -290,3 +340,43 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory): if video_dir.exists(): video_files = list(video_dir.rglob("*.mp4")) assert len(video_files) > 1, "Small file size limits should create multiple video files" + + +def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory): + """Regression test for video timestamp bug when merging datasets. + + This test specifically checks that video timestamps are correctly calculated + and accumulated when merging multiple datasets. 
+ """ + datasets = [] + for i in range(3): + ds = lerobot_dataset_factory( + root=tmp_path / f"regression_{i}", + repo_id=f"{DUMMY_REPO_ID}_regression_{i}", + total_episodes=2, + total_frames=100, + ) + datasets.append(ds) + + aggregate_datasets( + repo_ids=[ds.repo_id for ds in datasets], + roots=[ds.root for ds in datasets], + aggr_repo_id=f"{DUMMY_REPO_ID}_regression_aggr", + aggr_root=tmp_path / "regression_aggr", + ) + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "regression_aggr") + aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_regression_aggr", root=tmp_path / "regression_aggr") + + assert_video_timestamps_within_bounds(aggr_ds) + + for i in range(len(aggr_ds)): + item = aggr_ds[i] + for key in aggr_ds.meta.video_keys: + assert key in item, f"Video key {key} missing from item {i}" + assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}" diff --git a/tests/datasets/test_dataset_tools.py b/tests/datasets/test_dataset_tools.py new file mode 100644 index 000000000..fe117b35b --- /dev/null +++ b/tests/datasets/test_dataset_tools.py @@ -0,0 +1,891 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for dataset tools utilities.""" + +from unittest.mock import patch + +import numpy as np +import pytest +import torch + +from lerobot.datasets.dataset_tools import ( + add_feature, + delete_episodes, + merge_datasets, + remove_feature, + split_dataset, +) + + +@pytest.fixture +def sample_dataset(tmp_path, empty_lerobot_dataset_factory): + """Create a sample dataset for testing.""" + features = { + "action": {"dtype": "float32", "shape": (6,), "names": None}, + "observation.state": {"dtype": "float32", "shape": (4,), "names": None}, + "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None}, + } + + dataset = empty_lerobot_dataset_factory( + root=tmp_path / "test_dataset", + features=features, + ) + + for ep_idx in range(5): + for _ in range(10): + frame = { + "action": np.random.randn(6).astype(np.float32), + "observation.state": np.random.randn(4).astype(np.float32), + "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8), + "task": f"task_{ep_idx % 2}", + } + dataset.add_frame(frame) + dataset.save_episode() + + return dataset + + +def test_delete_single_episode(sample_dataset, tmp_path): + """Test deleting a single episode.""" + output_dir = tmp_path / "filtered" + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(output_dir) + + new_dataset = delete_episodes( + sample_dataset, + episode_indices=[2], + output_dir=output_dir, + ) + + assert new_dataset.meta.total_episodes == 4 + assert new_dataset.meta.total_frames == 40 + + episode_indices = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]} + assert episode_indices == {0, 1, 2, 3} + + assert len(new_dataset) == 40 + + +def test_delete_multiple_episodes(sample_dataset, tmp_path): + """Test deleting multiple episodes.""" + output_dir = tmp_path / "filtered" + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(output_dir) + + new_dataset = delete_episodes( + sample_dataset, + episode_indices=[1, 3], + output_dir=output_dir, + ) + + assert new_dataset.meta.total_episodes == 3 + assert new_dataset.meta.total_frames == 30 + + episode_indices = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]} + assert episode_indices == {0, 1, 2} + + +def test_delete_invalid_episodes(sample_dataset, tmp_path): + """Test error handling for invalid episode indices.""" + with pytest.raises(ValueError, match="Invalid episode indices"): + delete_episodes( + sample_dataset, + episode_indices=[10, 20], + output_dir=tmp_path / "filtered", + ) + + +def test_delete_all_episodes(sample_dataset, tmp_path): + """Test error when trying to delete all episodes.""" + with pytest.raises(ValueError, match="Cannot delete all episodes"): + delete_episodes( + sample_dataset, + episode_indices=list(range(5)), + output_dir=tmp_path / "filtered", + ) + + +def test_delete_empty_list(sample_dataset, tmp_path): + """Test error when no episodes specified.""" + with pytest.raises(ValueError, match="No episodes to delete"): + delete_episodes( + sample_dataset, + episode_indices=[], + output_dir=tmp_path / "filtered", + ) + + +def 
test_split_by_episodes(sample_dataset, tmp_path): + """Test splitting dataset by specific episode indices.""" + splits = { + "train": [0, 1, 2], + "val": [3, 4], + } + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + + def mock_snapshot(repo_id, **kwargs): + if "train" in repo_id: + return str(tmp_path / f"{sample_dataset.repo_id}_train") + elif "val" in repo_id: + return str(tmp_path / f"{sample_dataset.repo_id}_val") + return str(kwargs.get("local_dir", tmp_path)) + + mock_snapshot_download.side_effect = mock_snapshot + + result = split_dataset( + sample_dataset, + splits=splits, + output_dir=tmp_path, + ) + + assert set(result.keys()) == {"train", "val"} + + assert result["train"].meta.total_episodes == 3 + assert result["train"].meta.total_frames == 30 + + assert result["val"].meta.total_episodes == 2 + assert result["val"].meta.total_frames == 20 + + train_episodes = {int(idx.item()) for idx in result["train"].hf_dataset["episode_index"]} + assert train_episodes == {0, 1, 2} + + val_episodes = {int(idx.item()) for idx in result["val"].hf_dataset["episode_index"]} + assert val_episodes == {0, 1} + + +def test_split_by_fractions(sample_dataset, tmp_path): + """Test splitting dataset by fractions.""" + splits = { + "train": 0.6, + "val": 0.4, + } + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + + def mock_snapshot(repo_id, **kwargs): + for split_name in splits: + if split_name in repo_id: + return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}") + return str(kwargs.get("local_dir", tmp_path)) + + mock_snapshot_download.side_effect = mock_snapshot + + result = split_dataset( + sample_dataset, + splits=splits, + output_dir=tmp_path, + ) + + assert result["train"].meta.total_episodes == 3 + assert result["val"].meta.total_episodes == 2 + + +def test_split_overlapping_episodes(sample_dataset, tmp_path): + """Test error when episodes appear in multiple splits.""" + splits = { + "train": [0, 1, 2], + "val": [2, 3, 4], + } + + with pytest.raises(ValueError, match="Episodes cannot appear in multiple splits"): + split_dataset(sample_dataset, splits=splits, output_dir=tmp_path) + + +def test_split_invalid_fractions(sample_dataset, tmp_path): + """Test error when fractions sum to more than 1.""" + splits = { + "train": 0.7, + "val": 0.5, + } + + with pytest.raises(ValueError, match="Split fractions must sum to <= 1.0"): + split_dataset(sample_dataset, splits=splits, output_dir=tmp_path) + + +def test_split_empty(sample_dataset, tmp_path): + """Test error with empty splits.""" + with pytest.raises(ValueError, match="No splits provided"): + split_dataset(sample_dataset, splits={}, output_dir=tmp_path) + + +def test_merge_two_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_factory): + """Test merging two datasets.""" + features = { + "action": {"dtype": "float32", "shape": (6,), "names": None}, + "observation.state": {"dtype": "float32", "shape": (4,), "names": None}, + "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None}, + } + + dataset2 = empty_lerobot_dataset_factory( + root=tmp_path / "test_dataset2", + features=features, + ) + + for ep_idx in range(3): + for _ in range(10): + 
frame = { + "action": np.random.randn(6).astype(np.float32), + "observation.state": np.random.randn(4).astype(np.float32), + "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8), + "task": f"task_{ep_idx % 2}", + } + dataset2.add_frame(frame) + dataset2.save_episode() + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "merged_dataset") + + merged = merge_datasets( + [sample_dataset, dataset2], + output_repo_id="merged_dataset", + output_dir=tmp_path / "merged_dataset", + ) + + assert merged.meta.total_episodes == 8 # 5 + 3 + assert merged.meta.total_frames == 80 # 50 + 30 + + episode_indices = sorted({int(idx.item()) for idx in merged.hf_dataset["episode_index"]}) + assert episode_indices == list(range(8)) + + +def test_merge_empty_list(tmp_path): + """Test error when merging empty list.""" + with pytest.raises(ValueError, match="No datasets to merge"): + merge_datasets([], output_repo_id="merged", output_dir=tmp_path) + + +def test_add_feature_with_values(sample_dataset, tmp_path): + """Test adding a feature with pre-computed values.""" + num_frames = sample_dataset.meta.total_frames + reward_values = np.random.randn(num_frames, 1).astype(np.float32) + + feature_info = { + "dtype": "float32", + "shape": (1,), + "names": None, + } + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "with_reward") + + new_dataset = add_feature( + sample_dataset, + feature_name="reward", + feature_values=reward_values, + feature_info=feature_info, + output_dir=tmp_path / "with_reward", + ) + + assert "reward" in new_dataset.meta.features + assert new_dataset.meta.features["reward"] == feature_info + + assert len(new_dataset) == num_frames + sample_item = new_dataset[0] + assert "reward" in sample_item + assert isinstance(sample_item["reward"], torch.Tensor) + + +def test_add_feature_with_callable(sample_dataset, tmp_path): + """Test adding a feature with a callable.""" + + def compute_reward(frame_dict, episode_idx, frame_idx): + return float(episode_idx * 10 + frame_idx) + + feature_info = { + "dtype": "float32", + "shape": (1,), + "names": None, + } + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "with_reward") + + new_dataset = add_feature( + sample_dataset, + feature_name="reward", + feature_values=compute_reward, + feature_info=feature_info, + output_dir=tmp_path / "with_reward", + ) + + assert "reward" in new_dataset.meta.features + + items = [new_dataset[i] for i in range(10)] + first_episode_items = [item for item in items if item["episode_index"] == 0] + assert len(first_episode_items) == 10 + + first_frame = first_episode_items[0] + assert first_frame["frame_index"] == 0 + assert float(first_frame["reward"]) == 0.0 + + +def test_add_existing_feature(sample_dataset, tmp_path): + """Test error when adding an existing feature.""" + feature_info = 
{"dtype": "float32", "shape": (1,)} + + with pytest.raises(ValueError, match="Feature 'action' already exists"): + add_feature( + sample_dataset, + feature_name="action", + feature_values=np.zeros(50), + feature_info=feature_info, + output_dir=tmp_path / "modified", + ) + + +def test_add_feature_invalid_info(sample_dataset, tmp_path): + """Test error with invalid feature info.""" + with pytest.raises(ValueError, match="feature_info must contain keys"): + add_feature( + sample_dataset, + feature_name="reward", + feature_values=np.zeros(50), + feature_info={"dtype": "float32"}, + output_dir=tmp_path / "modified", + ) + + +def test_remove_single_feature(sample_dataset, tmp_path): + """Test removing a single feature.""" + feature_info = {"dtype": "float32", "shape": (1,), "names": None} + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) + + dataset_with_reward = add_feature( + sample_dataset, + feature_name="reward", + feature_values=np.random.randn(50, 1).astype(np.float32), + feature_info=feature_info, + output_dir=tmp_path / "with_reward", + ) + + dataset_without_reward = remove_feature( + dataset_with_reward, + feature_names="reward", + output_dir=tmp_path / "without_reward", + ) + + assert "reward" not in dataset_without_reward.meta.features + + sample_item = dataset_without_reward[0] + assert "reward" not in sample_item + + +def test_remove_multiple_features(sample_dataset, tmp_path): + """Test removing multiple features at once.""" + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) + + dataset = sample_dataset + for feature_name in ["reward", "success"]: + feature_info = {"dtype": "float32", "shape": (1,), "names": None} + dataset = add_feature( + dataset, + feature_name=feature_name, + feature_values=np.random.randn(dataset.meta.total_frames, 1).astype(np.float32), + feature_info=feature_info, + output_dir=tmp_path / f"with_{feature_name}", + ) + + dataset_clean = remove_feature( + dataset, + feature_names=["reward", "success"], + output_dir=tmp_path / "clean", + ) + + assert "reward" not in dataset_clean.meta.features + assert "success" not in dataset_clean.meta.features + + +def test_remove_nonexistent_feature(sample_dataset, tmp_path): + """Test error when removing non-existent feature.""" + with pytest.raises(ValueError, match="Feature 'nonexistent' not found"): + remove_feature( + sample_dataset, + feature_names="nonexistent", + output_dir=tmp_path / "modified", + ) + + +def test_remove_required_feature(sample_dataset, tmp_path): + """Test error when trying to remove required features.""" + with pytest.raises(ValueError, match="Cannot remove required features"): + remove_feature( + sample_dataset, + feature_names="timestamp", + output_dir=tmp_path / "modified", + ) + + +def test_remove_camera_feature(sample_dataset, tmp_path): + """Test removing a camera feature.""" + camera_keys = sample_dataset.meta.camera_keys + if not camera_keys: + pytest.skip("No camera keys in dataset") + + camera_to_remove = 
camera_keys[0] + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "without_camera") + + dataset_without_camera = remove_feature( + sample_dataset, + feature_names=camera_to_remove, + output_dir=tmp_path / "without_camera", + ) + + assert camera_to_remove not in dataset_without_camera.meta.features + assert camera_to_remove not in dataset_without_camera.meta.camera_keys + + sample_item = dataset_without_camera[0] + assert camera_to_remove not in sample_item + + +def test_complex_workflow_integration(sample_dataset, tmp_path): + """Test a complex workflow combining multiple operations.""" + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) + + dataset = add_feature( + sample_dataset, + feature_name="reward", + feature_values=np.random.randn(50, 1).astype(np.float32), + feature_info={"dtype": "float32", "shape": (1,), "names": None}, + output_dir=tmp_path / "step1", + ) + + dataset = delete_episodes( + dataset, + episode_indices=[2], + output_dir=tmp_path / "step2", + ) + + splits = split_dataset( + dataset, + splits={"train": 0.75, "val": 0.25}, + output_dir=tmp_path / "step3", + ) + + merged = merge_datasets( + list(splits.values()), + output_repo_id="final_dataset", + output_dir=tmp_path / "step4", + ) + + assert merged.meta.total_episodes == 4 + assert merged.meta.total_frames == 40 + assert "reward" in merged.meta.features + + assert len(merged) == 40 + sample_item = merged[0] + assert "reward" in sample_item + + +def test_delete_episodes_preserves_stats(sample_dataset, tmp_path): + """Test that deleting episodes preserves statistics correctly.""" + output_dir = tmp_path / "filtered" + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(output_dir) + + new_dataset = delete_episodes( + sample_dataset, + episode_indices=[2], + output_dir=output_dir, + ) + + assert new_dataset.meta.stats is not None + for feature in ["action", "observation.state"]: + assert feature in new_dataset.meta.stats + assert "mean" in new_dataset.meta.stats[feature] + assert "std" in new_dataset.meta.stats[feature] + + +def test_delete_episodes_preserves_tasks(sample_dataset, tmp_path): + """Test that tasks are preserved correctly after deletion.""" + output_dir = tmp_path / "filtered" + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(output_dir) + + new_dataset = delete_episodes( + sample_dataset, + episode_indices=[0], + output_dir=output_dir, + ) + + assert new_dataset.meta.tasks is not None + assert len(new_dataset.meta.tasks) == 2 + + tasks_in_dataset = {str(item["task"]) for item in new_dataset} + assert len(tasks_in_dataset) > 0 + + +def 
test_split_three_ways(sample_dataset, tmp_path): + """Test splitting dataset into three splits.""" + splits = { + "train": 0.6, + "val": 0.2, + "test": 0.2, + } + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + + def mock_snapshot(repo_id, **kwargs): + for split_name in splits: + if split_name in repo_id: + return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}") + return str(kwargs.get("local_dir", tmp_path)) + + mock_snapshot_download.side_effect = mock_snapshot + + result = split_dataset( + sample_dataset, + splits=splits, + output_dir=tmp_path, + ) + + assert set(result.keys()) == {"train", "val", "test"} + assert result["train"].meta.total_episodes == 3 + assert result["val"].meta.total_episodes == 1 + assert result["test"].meta.total_episodes == 1 + + total_frames = sum(ds.meta.total_frames for ds in result.values()) + assert total_frames == sample_dataset.meta.total_frames + + +def test_split_preserves_stats(sample_dataset, tmp_path): + """Test that statistics are preserved when splitting.""" + splits = {"train": [0, 1, 2], "val": [3, 4]} + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + + def mock_snapshot(repo_id, **kwargs): + for split_name in splits: + if split_name in repo_id: + return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}") + return str(kwargs.get("local_dir", tmp_path)) + + mock_snapshot_download.side_effect = mock_snapshot + + result = split_dataset( + sample_dataset, + splits=splits, + output_dir=tmp_path, + ) + + for split_ds in result.values(): + assert split_ds.meta.stats is not None + for feature in ["action", "observation.state"]: + assert feature in split_ds.meta.stats + assert "mean" in split_ds.meta.stats[feature] + assert "std" in split_ds.meta.stats[feature] + + +def test_merge_three_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_factory): + """Test merging three datasets.""" + features = { + "action": {"dtype": "float32", "shape": (6,), "names": None}, + "observation.state": {"dtype": "float32", "shape": (4,), "names": None}, + "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None}, + } + + datasets = [sample_dataset] + + for i in range(2): + dataset = empty_lerobot_dataset_factory( + root=tmp_path / f"test_dataset{i + 2}", + features=features, + ) + + for ep_idx in range(2): + for _ in range(10): + frame = { + "action": np.random.randn(6).astype(np.float32), + "observation.state": np.random.randn(4).astype(np.float32), + "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8), + "task": f"task_{ep_idx}", + } + dataset.add_frame(frame) + dataset.save_episode() + + datasets.append(dataset) + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "merged_dataset") + + merged = merge_datasets( + datasets, + output_repo_id="merged_dataset", + output_dir=tmp_path / "merged_dataset", + ) + + assert merged.meta.total_episodes == 9 + assert merged.meta.total_frames == 90 + 
+ +def test_merge_preserves_stats(sample_dataset, tmp_path, empty_lerobot_dataset_factory): + """Test that statistics are computed for merged datasets.""" + features = { + "action": {"dtype": "float32", "shape": (6,), "names": None}, + "observation.state": {"dtype": "float32", "shape": (4,), "names": None}, + "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None}, + } + + dataset2 = empty_lerobot_dataset_factory( + root=tmp_path / "test_dataset2", + features=features, + ) + + for ep_idx in range(3): + for _ in range(10): + frame = { + "action": np.random.randn(6).astype(np.float32), + "observation.state": np.random.randn(4).astype(np.float32), + "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8), + "task": f"task_{ep_idx % 2}", + } + dataset2.add_frame(frame) + dataset2.save_episode() + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "merged_dataset") + + merged = merge_datasets( + [sample_dataset, dataset2], + output_repo_id="merged_dataset", + output_dir=tmp_path / "merged_dataset", + ) + + assert merged.meta.stats is not None + for feature in ["action", "observation.state"]: + assert feature in merged.meta.stats + assert "mean" in merged.meta.stats[feature] + assert "std" in merged.meta.stats[feature] + + +def test_add_feature_preserves_existing_stats(sample_dataset, tmp_path): + """Test that adding a feature preserves existing stats.""" + num_frames = sample_dataset.meta.total_frames + reward_values = np.random.randn(num_frames, 1).astype(np.float32) + + feature_info = { + "dtype": "float32", + "shape": (1,), + "names": None, + } + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "with_reward") + + new_dataset = add_feature( + sample_dataset, + feature_name="reward", + feature_values=reward_values, + feature_info=feature_info, + output_dir=tmp_path / "with_reward", + ) + + assert new_dataset.meta.stats is not None + for feature in ["action", "observation.state"]: + assert feature in new_dataset.meta.stats + assert "mean" in new_dataset.meta.stats[feature] + assert "std" in new_dataset.meta.stats[feature] + + +def test_remove_feature_updates_stats(sample_dataset, tmp_path): + """Test that removing a feature removes it from stats.""" + feature_info = {"dtype": "float32", "shape": (1,), "names": None} + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) + + dataset_with_reward = add_feature( + sample_dataset, + feature_name="reward", + feature_values=np.random.randn(50, 1).astype(np.float32), + feature_info=feature_info, + output_dir=tmp_path / "with_reward", + ) + + dataset_without_reward = remove_feature( + dataset_with_reward, + feature_names="reward", + output_dir=tmp_path / "without_reward", + ) + + if dataset_without_reward.meta.stats: + assert "reward" not 
in dataset_without_reward.meta.stats + + +def test_delete_consecutive_episodes(sample_dataset, tmp_path): + """Test deleting consecutive episodes.""" + output_dir = tmp_path / "filtered" + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(output_dir) + + new_dataset = delete_episodes( + sample_dataset, + episode_indices=[1, 2, 3], + output_dir=output_dir, + ) + + assert new_dataset.meta.total_episodes == 2 + assert new_dataset.meta.total_frames == 20 + + episode_indices = sorted({int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}) + assert episode_indices == [0, 1] + + +def test_delete_first_and_last_episodes(sample_dataset, tmp_path): + """Test deleting first and last episodes.""" + output_dir = tmp_path / "filtered" + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(output_dir) + + new_dataset = delete_episodes( + sample_dataset, + episode_indices=[0, 4], + output_dir=output_dir, + ) + + assert new_dataset.meta.total_episodes == 3 + assert new_dataset.meta.total_frames == 30 + + episode_indices = sorted({int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}) + assert episode_indices == [0, 1, 2] + + +def test_split_all_episodes_assigned(sample_dataset, tmp_path): + """Test that all episodes can be explicitly assigned to splits.""" + splits = { + "split1": [0, 1], + "split2": [2, 3], + "split3": [4], + } + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + + def mock_snapshot(repo_id, **kwargs): + for split_name in splits: + if split_name in repo_id: + return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}") + return str(kwargs.get("local_dir", tmp_path)) + + mock_snapshot_download.side_effect = mock_snapshot + + result = split_dataset( + sample_dataset, + splits=splits, + output_dir=tmp_path, + ) + + total_episodes = sum(ds.meta.total_episodes for ds in result.values()) + assert total_episodes == sample_dataset.meta.total_episodes From 0699b46d87ded6e2394f4144dd9c92b2a5e4f1b8 Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Fri, 10 Oct 2025 20:41:37 +0200 Subject: [PATCH 06/24] refactor(envs): add custom-observation-size (#2167) --- src/lerobot/envs/configs.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index 0daaaf9fd..7a979b864 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -50,6 +50,8 @@ class AlohaEnv(EnvConfig): fps: int = 50 episode_length: int = 400 obs_type: str = "pixels_agent_pos" + observation_height: int = 480 + observation_width: int = 640 render_mode: str = "rgb_array" features: dict[str, PolicyFeature] = field( default_factory=lambda: { @@ -67,10 +69,14 @@ class AlohaEnv(EnvConfig): def __post_init__(self): if self.obs_type == "pixels": - self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3)) + self.features["top"] = PolicyFeature( + 
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) + ) elif self.obs_type == "pixels_agent_pos": self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,)) - self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3)) + self.features["pixels/top"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) + ) @property def gym_kwargs(self) -> dict: @@ -91,6 +97,8 @@ class PushtEnv(EnvConfig): render_mode: str = "rgb_array" visualization_width: int = 384 visualization_height: int = 384 + observation_height: int = 384 + observation_width: int = 384 features: dict[str, PolicyFeature] = field( default_factory=lambda: { ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)), @@ -108,7 +116,9 @@ class PushtEnv(EnvConfig): def __post_init__(self): if self.obs_type == "pixels_agent_pos": - self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3)) + self.features["pixels"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) + ) elif self.obs_type == "environment_state_agent_pos": self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,)) @@ -255,6 +265,8 @@ class LiberoEnv(EnvConfig): camera_name: str = "agentview_image,robot0_eye_in_hand_image" init_states: bool = True camera_name_mapping: dict[str, str] | None = None + observation_height: int = 360 + observation_width: int = 360 features: dict[str, PolicyFeature] = field( default_factory=lambda: { ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)), @@ -272,18 +284,18 @@ class LiberoEnv(EnvConfig): def __post_init__(self): if self.obs_type == "pixels": self.features["pixels/agentview_image"] = PolicyFeature( - type=FeatureType.VISUAL, shape=(360, 360, 3) + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) ) self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature( - type=FeatureType.VISUAL, shape=(360, 360, 3) + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) ) elif self.obs_type == "pixels_agent_pos": self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(8,)) self.features["pixels/agentview_image"] = PolicyFeature( - type=FeatureType.VISUAL, shape=(360, 360, 3) + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) ) self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature( - type=FeatureType.VISUAL, shape=(360, 360, 3) + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) ) else: raise ValueError(f"Unsupported obs_type: {self.obs_type}") From 25f60c301b6201b0eeb7bff2787c299f79a0dc40 Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Sat, 11 Oct 2025 00:15:42 +0200 Subject: [PATCH 07/24] use TeleopEvents.RERECORD_EPISODE in gym_manipulator (#2165) Co-authored-by: Michel Aractingi --- src/lerobot/rl/gym_manipulator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py index ad36f1b36..f9c9d0d7a 100644 --- a/src/lerobot/rl/gym_manipulator.py +++ b/src/lerobot/rl/gym_manipulator.py @@ -696,7 +696,7 @@ def control_loop( episode_idx += 1 if dataset is not None: - if transition[TransitionKey.INFO].get("rerecord_episode", False): + if transition[TransitionKey.INFO].get(TeleopEvents.RERECORD_EPISODE, False): logging.info(f"Re-recording episode {episode_idx}") 
dataset.clear_episode_buffer() episode_idx -= 1 From f2ff370459a9027319a8ab405fbe0d7c019a327e Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Sat, 11 Oct 2025 11:01:30 +0200 Subject: [PATCH 08/24] Incremental parquet writing (#1903) * incremental parquet writing * add .finalise() and a backup __del__ for stopping writers * fix missing import * precommit fixes added back the use of embed images * added lazy loading for hf_Dataset to avoid frequently reloading the dataset during recording * fix bug in video timestamps * Added proper closing of parquet file before reading * Added rigorous testing to validate the consistency of the meta data after creation of a new dataset * fix bug in episode index during clear_episode_buffer * fix(empty concat): check for empty paths list before data files concatenation * fix(v3.0 message): updating v3.0 backward compatibility message. * added fixes for the resume logic * answering co-pilot review * reverting some changes and style nits * removed unused functions * fix chunk_id and file_id when resuming * - fix parquet loading when resuming - add test to verify the parquet file integrity when resuming so that data files are now overwritten * added general function get_file_size_in_mb and removed the one for video * fix table size value when resuming * Remove unnecessary reloading of the parquet file when resuming record. Write to a new parquet file when resuming record * added back reading parquet file for image datasets only * - respond to Qlhoest comments - Use pyarrows `from_pydict` function - Add buffer for episode metadata to write to the parquet file in batches to improve efficiency - Remove the use of `to_parquet_with_hf_images` * fix(dataset_tools) with the new logic using proper finalize bug in finding the latest path of the metdata that was pointing to the data files added check for the metadata size in the case the metadatabuffer was not written yet * nit in flush_metadata_buffer * fix(lerobot_dataset) return the right dataset len when a subset of the dataset is requested --------- Co-authored-by: Harsimrat Sandhawalia --- src/lerobot/datasets/aggregate.py | 8 +- src/lerobot/datasets/dataset_tools.py | 11 + src/lerobot/datasets/lerobot_dataset.py | 343 +++++++++++++++++------- src/lerobot/datasets/utils.py | 18 +- src/lerobot/datasets/video_utils.py | 3 + src/lerobot/rl/buffer.py | 1 + tests/datasets/test_dataset_tools.py | 4 + tests/datasets/test_datasets.py | 143 ++++++++++ 8 files changed, 421 insertions(+), 110 deletions(-) diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py index e7ea59ed0..870c9571e 100644 --- a/src/lerobot/datasets/aggregate.py +++ b/src/lerobot/datasets/aggregate.py @@ -31,8 +31,8 @@ from lerobot.datasets.utils import ( DEFAULT_EPISODES_PATH, DEFAULT_VIDEO_FILE_SIZE_IN_MB, DEFAULT_VIDEO_PATH, + get_file_size_in_mb, get_parquet_file_size_in_mb, - get_video_size_in_mb, to_parquet_with_hf_images, update_chunk_file_indices, write_info, @@ -217,6 +217,7 @@ def aggregate_datasets( robot_type=robot_type, features=features, root=aggr_root, + use_videos=len(video_keys) > 0, chunks_size=chunk_size, data_files_size_in_mb=data_files_size_in_mb, video_files_size_in_mb=video_files_size_in_mb, @@ -307,8 +308,9 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu current_offset += src_duration continue - src_size = get_video_size_in_mb(src_path) - dst_size = get_video_size_in_mb(dst_path) + # Check file sizes before appending + src_size = get_file_size_in_mb(src_path) + 
dst_size = get_file_size_in_mb(dst_path) if dst_size + src_size >= video_files_size_in_mb: # Rotate to a new file, this source becomes start of new destination diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py index fdeb24a72..8ebc4a59d 100644 --- a/src/lerobot/datasets/dataset_tools.py +++ b/src/lerobot/datasets/dataset_tools.py @@ -42,6 +42,7 @@ from lerobot.datasets.utils import ( DEFAULT_DATA_PATH, DEFAULT_EPISODES_PATH, get_parquet_file_size_in_mb, + load_episodes, to_parquet_with_hf_images, update_chunk_file_indices, write_info, @@ -436,6 +437,9 @@ def _copy_and_reindex_data( Returns: dict mapping episode index to its data file metadata (chunk_index, file_index, etc.) """ + if src_dataset.meta.episodes is None: + src_dataset.meta.episodes = load_episodes(src_dataset.meta.root) + file_to_episodes: dict[Path, set[int]] = {} for old_idx in episode_mapping: file_path = src_dataset.meta.get_data_file_path(old_idx) @@ -645,6 +649,8 @@ def _copy_and_reindex_videos( Returns: dict mapping episode index to its video metadata (chunk_index, file_index, timestamps) """ + if src_dataset.meta.episodes is None: + src_dataset.meta.episodes = load_episodes(src_dataset.meta.root) episodes_video_metadata: dict[int, dict] = {new_idx: {} for new_idx in episode_mapping.values()} @@ -770,6 +776,9 @@ def _copy_and_reindex_episodes_metadata( """ from lerobot.datasets.utils import flatten_dict + if src_dataset.meta.episodes is None: + src_dataset.meta.episodes = load_episodes(src_dataset.meta.root) + all_stats = [] total_frames = 0 @@ -831,6 +840,8 @@ def _copy_and_reindex_episodes_metadata( total_frames += src_episode["length"] + dst_meta._close_writer() + dst_meta.info.update( { "total_episodes": len(episode_mapping), diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 229d37641..ae142c1e8 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import contextlib -import gc import logging import shutil import tempfile @@ -26,6 +25,8 @@ import numpy as np import packaging.version import pandas as pd import PIL.Image +import pyarrow as pa +import pyarrow.parquet as pq import torch import torch.utils from huggingface_hub import HfApi, snapshot_download @@ -46,13 +47,9 @@ from lerobot.datasets.utils import ( embed_images, flatten_dict, get_delta_indices, - get_hf_dataset_cache_dir, - get_hf_dataset_size_in_mb, + get_file_size_in_mb, get_hf_features_from_features, - get_parquet_file_size_in_mb, - get_parquet_num_frames, get_safe_version, - get_video_size_in_mb, hf_transform_to_torch, is_valid_version, load_episodes, @@ -60,7 +57,6 @@ from lerobot.datasets.utils import ( load_nested_dataset, load_stats, load_tasks, - to_parquet_with_hf_images, update_chunk_file_indices, validate_episode_buffer, validate_frame, @@ -90,10 +86,15 @@ class LeRobotDatasetMetadata: root: str | Path | None = None, revision: str | None = None, force_cache_sync: bool = False, + metadata_buffer_size: int = 10, ): self.repo_id = repo_id self.revision = revision if revision else CODEBASE_VERSION self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id + self.writer = None + self.latest_episode = None + self.metadata_buffer: list[dict] = [] + self.metadata_buffer_size = metadata_buffer_size try: if force_cache_sync: @@ -107,6 +108,54 @@ class LeRobotDatasetMetadata: self.pull_from_repo(allow_patterns="meta/") self.load_metadata() + def _flush_metadata_buffer(self) -> None: + """Write all buffered episode metadata to parquet file.""" + if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0: + return + + combined_dict = {} + for episode_dict in self.metadata_buffer: + for key, value in episode_dict.items(): + if key not in combined_dict: + combined_dict[key] = [] + # Extract value and serialize numpy arrays + # because PyArrow's from_pydict function doesn't support numpy arrays + val = value[0] if isinstance(value, list) else value + combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val) + + first_ep = self.metadata_buffer[0] + chunk_idx = first_ep["meta/episodes/chunk_index"][0] + file_idx = first_ep["meta/episodes/file_index"][0] + + table = pa.Table.from_pydict(combined_dict) + + if not self.writer: + path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)) + path.parent.mkdir(parents=True, exist_ok=True) + self.writer = pq.ParquetWriter( + path, schema=table.schema, compression="snappy", use_dictionary=True + ) + + self.writer.write_table(table) + + self.latest_episode = self.metadata_buffer[-1] + self.metadata_buffer.clear() + + def _close_writer(self) -> None: + """Close and cleanup the parquet writer if it exists.""" + self._flush_metadata_buffer() + + writer = getattr(self, "writer", None) + if writer is not None: + writer.close() + self.writer = None + + def __del__(self): + """ + Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor + """ + self._close_writer() + def load_metadata(self): self.info = load_info(self.root) check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) @@ -138,6 +187,12 @@ class LeRobotDatasetMetadata: return packaging.version.parse(self.info["codebase_version"]) def get_data_file_path(self, ep_index: int) -> Path: + if self.episodes is None: + self.episodes = load_episodes(self.root) + if ep_index >= len(self.episodes): + raise IndexError( + 
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}" + ) ep = self.episodes[ep_index] chunk_idx = ep["data/chunk_index"] file_idx = ep["data/file_index"] @@ -145,6 +200,12 @@ class LeRobotDatasetMetadata: return Path(fpath) def get_video_file_path(self, ep_index: int, vid_key: str) -> Path: + if self.episodes is None: + self.episodes = load_episodes(self.root) + if ep_index >= len(self.episodes): + raise IndexError( + f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}" + ) ep = self.episodes[ep_index] chunk_idx = ep[f"videos/{vid_key}/chunk_index"] file_idx = ep[f"videos/{vid_key}/file_index"] @@ -260,72 +321,75 @@ class LeRobotDatasetMetadata: write_tasks(self.tasks, self.root) def _save_episode_metadata(self, episode_dict: dict) -> None: - """Save episode metadata to a parquet file and update the Hugging Face dataset of episodes metadata. + """Buffer episode metadata and write to parquet in batches for efficiency. - This function processes episodes metadata from a dictionary, converts it into a Hugging Face dataset, - and saves it as a parquet file. It handles both the creation of new parquet files and the - updating of existing ones based on size constraints. After saving the metadata, it reloads - the Hugging Face dataset to ensure it is up-to-date. + This function accumulates episode metadata in a buffer and flushes it when the buffer + reaches the configured size. This reduces I/O overhead by writing multiple episodes + at once instead of one row at a time. Notes: We both need to update parquet files and HF dataset: - `pandas` loads parquet file in RAM - `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk, or loads directly from pyarrow cache. 
""" - # Convert buffer into HF Dataset + # Convert to list format for each value episode_dict = {key: [value] for key, value in episode_dict.items()} - ep_dataset = datasets.Dataset.from_dict(episode_dict) - ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset) - df = pd.DataFrame(ep_dataset) num_frames = episode_dict["length"][0] - if self.episodes is None: + if self.latest_episode is None: # Initialize indices and frame count for a new dataset made of the first episode data chunk_idx, file_idx = 0, 0 - df["meta/episodes/chunk_index"] = [chunk_idx] - df["meta/episodes/file_index"] = [file_idx] - df["dataset_from_index"] = [0] - df["dataset_to_index"] = [num_frames] - else: - # Retrieve information from the latest parquet file - latest_ep = self.episodes[-1] - chunk_idx = latest_ep["meta/episodes/chunk_index"] - file_idx = latest_ep["meta/episodes/file_index"] + if self.episodes is not None and len(self.episodes) > 0: + # It means we are resuming recording, so we need to load the latest episode + # Update the indices to avoid overwriting the latest episode + chunk_idx = self.episodes[-1]["meta/episodes/chunk_index"] + file_idx = self.episodes[-1]["meta/episodes/file_index"] + latest_num_frames = self.episodes[-1]["dataset_to_index"] + episode_dict["dataset_from_index"] = [latest_num_frames] + episode_dict["dataset_to_index"] = [latest_num_frames + num_frames] - latest_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) - latest_size_in_mb = get_parquet_file_size_in_mb(latest_path) - - if latest_size_in_mb + ep_size_in_mb >= self.data_files_size_in_mb: - # Size limit is reached, prepare new parquet file + # When resuming, move to the next file chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size) + else: + episode_dict["dataset_from_index"] = [0] + episode_dict["dataset_to_index"] = [num_frames] + + episode_dict["meta/episodes/chunk_index"] = [chunk_idx] + episode_dict["meta/episodes/file_index"] = [file_idx] + else: + chunk_idx = self.latest_episode["meta/episodes/chunk_index"][0] + file_idx = self.latest_episode["meta/episodes/file_index"][0] + + latest_path = ( + self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + if self.writer is None + else self.writer.where + ) + + if Path(latest_path).exists(): + latest_size_in_mb = get_file_size_in_mb(Path(latest_path)) + latest_num_frames = self.latest_episode["episode_index"][0] + + av_size_per_frame = latest_size_in_mb / latest_num_frames if latest_num_frames > 0 else 0.0 + + if latest_size_in_mb + av_size_per_frame * num_frames >= self.data_files_size_in_mb: + # Size limit is reached, flush buffer and prepare new parquet file + self._flush_metadata_buffer() + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size) + self._close_writer() # Update the existing pandas dataframe with new row - df["meta/episodes/chunk_index"] = [chunk_idx] - df["meta/episodes/file_index"] = [file_idx] - df["dataset_from_index"] = [latest_ep["dataset_to_index"]] - df["dataset_to_index"] = [latest_ep["dataset_to_index"] + num_frames] + episode_dict["meta/episodes/chunk_index"] = [chunk_idx] + episode_dict["meta/episodes/file_index"] = [file_idx] + episode_dict["dataset_from_index"] = [self.latest_episode["dataset_to_index"][0]] + episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames] - if latest_size_in_mb + ep_size_in_mb < self.data_files_size_in_mb: - # Size limit wasnt reached, concatenate 
latest dataframe with new one - latest_df = pd.read_parquet(latest_path) - df = pd.concat([latest_df, df], ignore_index=True) + # Add to buffer + self.metadata_buffer.append(episode_dict) + self.latest_episode = episode_dict - # Memort optimization - del latest_df - gc.collect() - - # Write the resulting dataframe from RAM to disk - path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) - path.parent.mkdir(parents=True, exist_ok=True) - df.to_parquet(path, index=False) - - if self.episodes is not None: - # Remove the episodes cache directory, necessary to avoid cache bloat - cached_dir = get_hf_dataset_cache_dir(self.episodes) - if cached_dir is not None: - shutil.rmtree(cached_dir) - - self.episodes = load_episodes(self.root) + if len(self.metadata_buffer) >= self.metadata_buffer_size: + self._flush_metadata_buffer() def save_episode( self, @@ -438,6 +502,7 @@ class LeRobotDatasetMetadata: robot_type: str | None = None, root: str | Path | None = None, use_videos: bool = True, + metadata_buffer_size: int = 10, chunks_size: int | None = None, data_files_size_in_mb: int | None = None, video_files_size_in_mb: int | None = None, @@ -469,6 +534,10 @@ class LeRobotDatasetMetadata: raise ValueError() write_json(obj.info, obj.root / INFO_PATH) obj.revision = None + obj.writer = None + obj.latest_episode = None + obj.metadata_buffer = [] + obj.metadata_buffer_size = metadata_buffer_size return obj @@ -615,6 +684,8 @@ class LeRobotDataset(torch.utils.data.Dataset): # Unused attributes self.image_writer = None self.episode_buffer = None + self.writer = None + self.latest_episode = None self.root.mkdir(exist_ok=True, parents=True) @@ -623,6 +694,11 @@ class LeRobotDataset(torch.utils.data.Dataset): self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync ) + # Track dataset state for efficient incremental writing + self._lazy_loading = False + self._recorded_frames = self.meta.total_frames + self._writer_closed_for_reading = False + # Load actual data try: if force_cache_sync: @@ -641,6 +717,19 @@ class LeRobotDataset(torch.utils.data.Dataset): check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s) self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps) + def _close_writer(self) -> None: + """Close and cleanup the parquet writer if it exists.""" + writer = getattr(self, "writer", None) + if writer is not None: + writer.close() + self.writer = None + + def __del__(self): + """ + Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor + """ + self._close_writer() + def push_to_hub( self, branch: str | None = None, @@ -781,8 +870,15 @@ class LeRobotDataset(torch.utils.data.Dataset): @property def num_frames(self) -> int: - """Number of frames in selected episodes.""" - return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames + """Number of frames in selected episodes. + + Note: When episodes a subset of the full dataset is requested, we must return the + actual loaded data length (len(self.hf_dataset)) rather than metadata total_frames. + self.meta.total_frames is the total number of frames in the full dataset. 
+ """ + if self.episodes is not None and self.hf_dataset is not None: + return len(self.hf_dataset) + return self.meta.total_frames @property def num_episodes(self) -> int: @@ -860,10 +956,22 @@ class LeRobotDataset(torch.utils.data.Dataset): return item + def _ensure_hf_dataset_loaded(self): + """Lazy load the HF dataset only when needed for reading.""" + if self._lazy_loading or self.hf_dataset is None: + # Close the writer before loading to ensure parquet file is properly finalized + if self.writer is not None: + self._close_writer() + self._writer_closed_for_reading = True + self.hf_dataset = self.load_hf_dataset() + self._lazy_loading = False + def __len__(self): return self.num_frames def __getitem__(self, idx) -> dict: + # Ensure dataset is loaded when we actually need to read from it + self._ensure_hf_dataset_loaded() item = self.hf_dataset[idx] ep_idx = item["episode_index"].item() @@ -902,6 +1010,14 @@ class LeRobotDataset(torch.utils.data.Dataset): "})',\n" ) + def finalize(self): + """ + Close the parquet writers. This function needs to be called after data collection/conversion, else footer metadata won't be written to the parquet files. + The dataset won't be valid and can't be loaded as ds = LeRobotDataset(repo_id=repo, root=HF_LEROBOT_HOME.joinpath(repo)) + """ + self._close_writer() + self.meta._close_writer() + def create_episode_buffer(self, episode_index: int | None = None) -> dict: current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index ep_buffer = {} @@ -1109,74 +1225,101 @@ class LeRobotDataset(torch.utils.data.Dataset): ep_dict = {key: episode_buffer[key] for key in self.hf_features} ep_dataset = datasets.Dataset.from_dict(ep_dict, features=self.hf_features, split="train") ep_dataset = embed_images(ep_dataset) - ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset) ep_num_frames = len(ep_dataset) - df = pd.DataFrame(ep_dataset) - if self.meta.episodes is None: + if self.latest_episode is None: # Initialize indices and frame count for a new dataset made of the first episode data chunk_idx, file_idx = 0, 0 - latest_num_frames = 0 + global_frame_index = 0 + # However, if the episodes already exists + # It means we are resuming recording, so we need to load the latest episode + # Update the indices to avoid overwriting the latest episode + if self.meta.episodes is not None and len(self.meta.episodes) > 0: + latest_ep = self.meta.episodes[-1] + global_frame_index = latest_ep["dataset_to_index"] + chunk_idx = latest_ep["data/chunk_index"] + file_idx = latest_ep["data/file_index"] + + # When resuming, move to the next file + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size) else: # Retrieve information from the latest parquet file - latest_ep = self.meta.episodes[-1] + latest_ep = self.latest_episode chunk_idx = latest_ep["data/chunk_index"] file_idx = latest_ep["data/file_index"] + global_frame_index = latest_ep["index"][-1] + 1 latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx) - latest_size_in_mb = get_parquet_file_size_in_mb(latest_path) - latest_num_frames = get_parquet_num_frames(latest_path) + latest_size_in_mb = get_file_size_in_mb(latest_path) + + frames_in_current_file = global_frame_index - latest_ep["dataset_from_index"] + av_size_per_frame = ( + latest_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0 + ) # Determine if a new parquet file is needed - if latest_size_in_mb + ep_size_in_mb >= self.meta.data_files_size_in_mb: 
- # Size limit is reached, prepare new parquet file + if ( + latest_size_in_mb + av_size_per_frame * ep_num_frames >= self.meta.data_files_size_in_mb + or self._writer_closed_for_reading + ): + # Size limit is reached or writer was closed for reading, prepare new parquet file chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size) - latest_num_frames = 0 - else: - # Update the existing parquet file with new rows - latest_df = pd.read_parquet(latest_path) - df = pd.concat([latest_df, df], ignore_index=True) + self._close_writer() + self._writer_closed_for_reading = False - # Memort optimization - del latest_df - gc.collect() + ep_dict["data/chunk_index"] = chunk_idx + ep_dict["data/file_index"] = file_idx # Write the resulting dataframe from RAM to disk path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx) path.parent.mkdir(parents=True, exist_ok=True) - if len(self.meta.image_keys) > 0: - to_parquet_with_hf_images(df, path) - else: - df.to_parquet(path) - if self.hf_dataset is not None: - # Remove hf dataset cache directory, necessary to avoid cache bloat - cached_dir = get_hf_dataset_cache_dir(self.hf_dataset) - if cached_dir is not None: - shutil.rmtree(cached_dir) - - self.hf_dataset = self.load_hf_dataset() + table = ep_dataset.with_format("arrow")[:] + if not self.writer: + self.writer = pq.ParquetWriter( + path, schema=table.schema, compression="snappy", use_dictionary=True + ) + self.writer.write_table(table) metadata = { "data/chunk_index": chunk_idx, "data/file_index": file_idx, - "dataset_from_index": latest_num_frames, - "dataset_to_index": latest_num_frames + ep_num_frames, + "dataset_from_index": global_frame_index, + "dataset_to_index": global_frame_index + ep_num_frames, } + + # Store metadata with episode data for next episode + self.latest_episode = {**ep_dict, **metadata} + + # Mark that the HF dataset needs reloading (lazy loading approach) + # This avoids expensive reloading during sequential recording + self._lazy_loading = True + # Update recorded frames count for efficient length tracking + self._recorded_frames += ep_num_frames + return metadata def _save_episode_video(self, video_key: str, episode_index: int) -> dict: # Encode episode frames into a temporary video ep_path = self._encode_temporary_episode_video(video_key, episode_index) - ep_size_in_mb = get_video_size_in_mb(ep_path) + ep_size_in_mb = get_file_size_in_mb(ep_path) ep_duration_in_s = get_video_duration_in_s(ep_path) - if self.meta.episodes is None or ( - f"videos/{video_key}/chunk_index" not in self.meta.episodes.column_names - or f"videos/{video_key}/file_index" not in self.meta.episodes.column_names + if ( + episode_index == 0 + or self.meta.latest_episode is None + or f"videos/{video_key}/chunk_index" not in self.meta.latest_episode ): # Initialize indices for a new dataset made of the first episode data chunk_idx, file_idx = 0, 0 + if self.meta.episodes is not None and len(self.meta.episodes) > 0: + # It means we are resuming recording, so we need to load the latest episode + # Update the indices to avoid overwriting the latest episode + old_chunk_idx = self.meta.episodes[-1][f"videos/{video_key}/chunk_index"] + old_file_idx = self.meta.episodes[-1][f"videos/{video_key}/file_index"] + chunk_idx, file_idx = update_chunk_file_indices( + old_chunk_idx, old_file_idx, self.meta.chunks_size + ) latest_duration_in_s = 0.0 new_path = self.root / self.meta.video_path.format( video_key=video_key, chunk_index=chunk_idx, 
file_index=file_idx @@ -1184,16 +1327,16 @@ class LeRobotDataset(torch.utils.data.Dataset): new_path.parent.mkdir(parents=True, exist_ok=True) shutil.move(str(ep_path), str(new_path)) else: - # Retrieve information from the latest updated video file (possibly several episodes ago) - latest_ep = self.meta.episodes[episode_index - 1] - chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"] - file_idx = latest_ep[f"videos/{video_key}/file_index"] + # Retrieve information from the latest updated video file using latest_episode + latest_ep = self.meta.latest_episode + chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"][0] + file_idx = latest_ep[f"videos/{video_key}/file_index"][0] latest_path = self.root / self.meta.video_path.format( video_key=video_key, chunk_index=chunk_idx, file_index=file_idx ) - latest_size_in_mb = get_video_size_in_mb(latest_path) - latest_duration_in_s = get_video_duration_in_s(latest_path) + latest_size_in_mb = get_file_size_in_mb(latest_path) + latest_duration_in_s = latest_ep[f"videos/{video_key}/to_timestamp"][0] if latest_size_in_mb + ep_size_in_mb >= self.meta.video_files_size_in_mb: # Move temporary episode video to a new video file in the dataset @@ -1327,6 +1470,12 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.delta_timestamps = None obj.delta_indices = None obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec() + obj.writer = None + obj.latest_episode = None + # Initialize tracking for incremental recording + obj._lazy_loading = False + obj._recorded_frames = 0 + obj._writer_closed_for_reading = False return obj diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 422a7010a..37d8432b2 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -94,12 +94,6 @@ def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int: return hf_ds.data.nbytes // (1024**2) -def get_hf_dataset_cache_dir(hf_ds: Dataset) -> Path | None: - if hf_ds.cache_files is None or len(hf_ds.cache_files) == 0: - return None - return Path(hf_ds.cache_files[0]["filename"]).parents[2] - - def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]: if file_idx == chunks_size - 1: file_idx = 0 @@ -133,10 +127,14 @@ def get_parquet_num_frames(parquet_path: str | Path) -> int: return metadata.num_rows -def get_video_size_in_mb(mp4_path: Path) -> float: - file_size_bytes = mp4_path.stat().st_size - file_size_mb = file_size_bytes / (1024**2) - return file_size_mb +def get_file_size_in_mb(file_path: Path) -> float: + """Get file size on disk in megabytes. + + Args: + file_path (Path): Path to the file. 
+ """ + file_size_bytes = file_path.stat().st_size + return file_size_bytes / (1024**2) def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict: diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 620ba863a..740cdb602 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -642,6 +642,9 @@ class VideoEncodingManager: ) self.dataset._batch_save_episode_video(start_ep, end_ep) + # Finalize the dataset to properly close all writers + self.dataset.finalize() + # Clean up episode images if recording was interrupted if exc_type is not None: interrupted_episode_index = self.dataset.num_episodes diff --git a/src/lerobot/rl/buffer.py b/src/lerobot/rl/buffer.py index 917e4e2cc..81aa29c48 100644 --- a/src/lerobot/rl/buffer.py +++ b/src/lerobot/rl/buffer.py @@ -607,6 +607,7 @@ class ReplayBuffer: lerobot_dataset.save_episode() lerobot_dataset.stop_image_writer() + lerobot_dataset.finalize() return lerobot_dataset diff --git a/tests/datasets/test_dataset_tools.py b/tests/datasets/test_dataset_tools.py index fe117b35b..a9c04d6f2 100644 --- a/tests/datasets/test_dataset_tools.py +++ b/tests/datasets/test_dataset_tools.py @@ -55,6 +55,7 @@ def sample_dataset(tmp_path, empty_lerobot_dataset_factory): dataset.add_frame(frame) dataset.save_episode() + dataset.finalize() return dataset @@ -263,6 +264,7 @@ def test_merge_two_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_fact } dataset2.add_frame(frame) dataset2.save_episode() + dataset2.finalize() with ( patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, @@ -685,6 +687,7 @@ def test_merge_three_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_fa } dataset.add_frame(frame) dataset.save_episode() + dataset.finalize() datasets.append(dataset) @@ -728,6 +731,7 @@ def test_merge_preserves_stats(sample_dataset, tmp_path, empty_lerobot_dataset_f } dataset2.add_frame(frame) dataset2.save_episode() + dataset2.finalize() with ( patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 2bc3bea43..e174c5789 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -806,6 +806,8 @@ def test_episode_index_distribution(tmp_path, empty_lerobot_dataset_factory): dataset.add_frame({"state": torch.randn(2), "task": f"task_{episode_idx}"}) dataset.save_episode() + dataset.finalize() + # Load the dataset and check episode indices loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) @@ -855,6 +857,8 @@ def test_multi_episode_metadata_consistency(tmp_path, empty_lerobot_dataset_fact dataset.add_frame({"state": torch.randn(3), ACTION: torch.randn(2), "task": tasks[episode_idx]}) dataset.save_episode() + dataset.finalize() + # Load and validate episode metadata loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) @@ -893,6 +897,8 @@ def test_data_consistency_across_episodes(tmp_path, empty_lerobot_dataset_factor dataset.add_frame({"state": torch.randn(1), "task": "consistency_test"}) dataset.save_episode() + dataset.finalize() + loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) # Check data consistency - no gaps or overlaps @@ -944,6 +950,8 @@ def test_statistics_metadata_validation(tmp_path, empty_lerobot_dataset_factory) dataset.add_frame({"state": state_data, ACTION: action_data, "task": "stats_test"}) dataset.save_episode() + dataset.finalize() + 
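A condensed sketch of the record → finalize → resume flow that these tests exercise end to end. The repo id, feature spec, and frame values are made up for illustration, and the `LeRobotDataset.create(...)` keywords are assumed to match the factories used in the tests; only the requirement to call `finalize()` is taken directly from this patch.

```python
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset

features = {"state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}}
dataset = LeRobotDataset.create("user/demo", fps=30, features=features, use_videos=False)

# Record one short episode (frame values are random placeholders).
for _ in range(5):
    dataset.add_frame({"state": torch.randn(2), "task": "demo"})
dataset.save_episode()

# With the incremental parquet writer, finalize() writes the parquet footers;
# skipping it leaves files that cannot be reloaded later.
dataset.finalize()

# Reopening afterwards resumes recording by appending, not overwriting.
resumed = LeRobotDataset("user/demo", root=dataset.root)
```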
loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) # Check that statistics exist for all features @@ -989,6 +997,8 @@ def test_episode_boundary_integrity(tmp_path, empty_lerobot_dataset_factory): dataset.add_frame({"state": torch.tensor([float(frame_idx)]), "task": f"episode_{episode_idx}"}) dataset.save_episode() + dataset.finalize() + loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) # Test episode boundaries @@ -1031,6 +1041,8 @@ def test_task_indexing_and_validation(tmp_path, empty_lerobot_dataset_factory): dataset.add_frame({"state": torch.randn(1), "task": task}) dataset.save_episode() + dataset.finalize() + loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) # Check that all unique tasks are in the tasks metadata @@ -1056,3 +1068,134 @@ def test_task_indexing_and_validation(tmp_path, empty_lerobot_dataset_factory): # Check total number of tasks assert loaded_dataset.meta.total_tasks == len(unique_tasks) + + +def test_dataset_resume_recording(tmp_path, empty_lerobot_dataset_factory): + """Test that resuming dataset recording preserves previously recorded episodes. + + This test validates the critical resume functionality by: + 1. Recording initial episodes and finalizing + 2. Reopening the dataset + 3. Recording additional episodes + 4. Verifying all data (old + new) is intact + + This specifically tests the bug fix where parquet files were being overwritten + instead of appended to during resume. + """ + features = { + "observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}, + "action": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}, + } + + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False) + + initial_episodes = 2 + frames_per_episode = 3 + + for ep_idx in range(initial_episodes): + for frame_idx in range(frames_per_episode): + dataset.add_frame( + { + "observation.state": torch.tensor([float(ep_idx), float(frame_idx)]), + "action": torch.tensor([0.5, 0.5]), + "task": f"task_{ep_idx}", + } + ) + dataset.save_episode() + + assert dataset.meta.total_episodes == initial_episodes + assert dataset.meta.total_frames == initial_episodes * frames_per_episode + + dataset.finalize() + initial_root = dataset.root + initial_repo_id = dataset.repo_id + del dataset + + dataset_verify = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0") + assert dataset_verify.meta.total_episodes == initial_episodes + assert dataset_verify.meta.total_frames == initial_episodes * frames_per_episode + assert len(dataset_verify.hf_dataset) == initial_episodes * frames_per_episode + + for idx in range(len(dataset_verify.hf_dataset)): + item = dataset_verify[idx] + expected_ep = idx // frames_per_episode + expected_frame = idx % frames_per_episode + assert item["episode_index"].item() == expected_ep + assert item["frame_index"].item() == expected_frame + assert item["index"].item() == idx + assert item["observation.state"][0].item() == float(expected_ep) + assert item["observation.state"][1].item() == float(expected_frame) + + del dataset_verify + + # Phase 3: Resume recording - add more episodes + dataset_resumed = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0") + + assert dataset_resumed.meta.total_episodes == initial_episodes + assert dataset_resumed.meta.total_frames == initial_episodes * frames_per_episode + assert dataset_resumed.latest_episode is None # Not recording yet + assert dataset_resumed.writer is None + assert 
dataset_resumed.meta.writer is None + + additional_episodes = 2 + for ep_idx in range(initial_episodes, initial_episodes + additional_episodes): + for frame_idx in range(frames_per_episode): + dataset_resumed.add_frame( + { + "observation.state": torch.tensor([float(ep_idx), float(frame_idx)]), + "action": torch.tensor([0.5, 0.5]), + "task": f"task_{ep_idx}", + } + ) + dataset_resumed.save_episode() + + total_episodes = initial_episodes + additional_episodes + total_frames = total_episodes * frames_per_episode + assert dataset_resumed.meta.total_episodes == total_episodes + assert dataset_resumed.meta.total_frames == total_frames + + dataset_resumed.finalize() + del dataset_resumed + + dataset_final = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0") + + assert dataset_final.meta.total_episodes == total_episodes + assert dataset_final.meta.total_frames == total_frames + assert len(dataset_final.hf_dataset) == total_frames + + for idx in range(total_frames): + item = dataset_final[idx] + expected_ep = idx // frames_per_episode + expected_frame = idx % frames_per_episode + + assert item["episode_index"].item() == expected_ep, ( + f"Frame {idx}: wrong episode_index. Expected {expected_ep}, got {item['episode_index'].item()}" + ) + assert item["frame_index"].item() == expected_frame, ( + f"Frame {idx}: wrong frame_index. Expected {expected_frame}, got {item['frame_index'].item()}" + ) + assert item["index"].item() == idx, ( + f"Frame {idx}: wrong index. Expected {idx}, got {item['index'].item()}" + ) + + # Verify data integrity + assert item["observation.state"][0].item() == float(expected_ep), ( + f"Frame {idx}: wrong observation.state[0]. Expected {float(expected_ep)}, " + f"got {item['observation.state'][0].item()}" + ) + assert item["observation.state"][1].item() == float(expected_frame), ( + f"Frame {idx}: wrong observation.state[1]. 
Expected {float(expected_frame)}, " + f"got {item['observation.state'][1].item()}" + ) + + assert len(dataset_final.meta.episodes) == total_episodes + for ep_idx in range(total_episodes): + ep_metadata = dataset_final.meta.episodes[ep_idx] + assert ep_metadata["episode_index"] == ep_idx + assert ep_metadata["length"] == frames_per_episode + assert ep_metadata["tasks"] == [f"task_{ep_idx}"] + + expected_from = ep_idx * frames_per_episode + expected_to = (ep_idx + 1) * frames_per_episode + assert ep_metadata["dataset_from_index"] == expected_from + assert ep_metadata["dataset_to_index"] == expected_to From 0c79cf8f4ed4baa98db878ee7d2d091df447d878 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Sat, 11 Oct 2025 21:15:43 +0200 Subject: [PATCH 09/24] Add missing finalize calls in example (#2175) - add missing calls to dataset.finalize in the example recording scripts - add section in the dataset docs on calling dataset.finalize --- docs/source/lerobot-dataset-v3.mdx | 33 ++++++++++++++++++++++++++ examples/lekiwi/evaluate.py | 2 ++ examples/lekiwi/record.py | 2 ++ examples/phone_to_so100/evaluate.py | 2 ++ examples/phone_to_so100/record.py | 2 ++ examples/port_datasets/port_droid.py | 2 ++ examples/so100_to_so100_EE/evaluate.py | 2 ++ examples/so100_to_so100_EE/record.py | 2 ++ 8 files changed, 47 insertions(+) diff --git a/docs/source/lerobot-dataset-v3.mdx b/docs/source/lerobot-dataset-v3.mdx index cf1942fdc..3521914f2 100644 --- a/docs/source/lerobot-dataset-v3.mdx +++ b/docs/source/lerobot-dataset-v3.mdx @@ -279,3 +279,36 @@ python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id= Date: Mon, 13 Oct 2025 10:44:53 +0200 Subject: [PATCH 10/24] fix: very minor fix but hey devil is in details (#2168) Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> --- src/lerobot/policies/pi0/modeling_pi0.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index a2dcdaea3..596b273d5 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -897,7 +897,7 @@ class PI0Policy(PreTrainedPolicy): ) -> T: """Override the from_pretrained method to handle key remapping and display important disclaimer.""" print( - "The PI05 model is a direct port of the OpenPI implementation. \n" + "The PI0 model is a direct port of the OpenPI implementation. \n" "This implementation follows the original OpenPI structure for compatibility. 
\n" "Original implementation: https://github.com/Physical-Intelligence/openpi" ) From 6f5bb4d4a49fbdb47acfeaa2c190b5fa125f645a Mon Sep 17 00:00:00 2001 From: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> Date: Mon, 13 Oct 2025 16:43:23 +0200 Subject: [PATCH 11/24] fix outdated example in docs (#2182) * fix outdated example Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> * Update docs/source/il_robots.mdx Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> --------- Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- docs/source/il_robots.mdx | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index 91df14028..0d8fd56e5 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -513,13 +513,14 @@ from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import hw_to_dataset_features from lerobot.policies.act.modeling_act import ACTPolicy +from lerobot.policies.factory import make_pre_post_processors from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.scripts.lerobot_record import record_loop from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun -from lerobot.record import record_loop -from lerobot.policies.factory import make_processor + NUM_EPISODES = 5 FPS = 30 @@ -562,7 +563,7 @@ init_rerun(session_name="recording") # Connect the robot robot.connect() -preprocessor, postprocessor = make_processor( +preprocessor, postprocessor = make_pre_post_processors( policy_cfg=policy, pretrained_path=HF_MODEL_ID, dataset_stats=dataset.meta.stats, From 3ce5bcf24ddd39ff3978ff21e1741d19d96cf5aa Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 14 Oct 2025 14:00:52 +0200 Subject: [PATCH 12/24] feat(deps): add setuptools dependency (#2187) --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index a70208cb2..44ca596b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ dependencies = [ "huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0", # Core dependencies + "setuptools>=71.0.0,<81.0.0", "cmake>=3.29.0.1,<4.2.0", "einops>=0.8.0,<0.9.0", "opencv-python-headless>=4.9.0,<4.13.0", From bf6ac5e110d1cc8f0e0d0a6909c0c2f95e310bf0 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 14 Oct 2025 14:36:32 +0200 Subject: [PATCH 13/24] fix(datasets): conversion script function naming (#2199) Co-authored-by: gagalo123 --- src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py index 42ab2f642..b8ae29ad6 100644 --- a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py +++ b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py @@ -69,9 +69,9 @@ from lerobot.datasets.utils import ( LEGACY_TASKS_PATH, cast_stats_to_numpy, flatten_dict, + get_file_size_in_mb, get_parquet_file_size_in_mb, get_parquet_num_frames, - 
get_video_size_in_mb, load_info, update_chunk_file_indices, write_episodes, @@ -310,7 +310,7 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f episodes_metadata = [] for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"): - ep_size_in_mb = get_video_size_in_mb(ep_path) + ep_size_in_mb = get_file_size_in_mb(ep_path) ep_duration_in_s = get_video_duration_in_s(ep_path) # Check if adding this episode would exceed the limit From 723013c71bf5c1c05471e7617d859478b9d72d6c Mon Sep 17 00:00:00 2001 From: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> Date: Tue, 14 Oct 2025 15:47:32 +0200 Subject: [PATCH 14/24] feat(scripts): Introduce `build_inference_frame`/`make_robot_action` util to easily allow API-based Inference (#2143) * fix: expose a function explicitly building a frame for inference * fix: first make dataset frame, then make ready for inference * fix: reducing reliance on lerobot record for policy's ouptuts too * fix: encapsulating squeezing out + device handling from predict action * fix: remove duplicated call to build_inference_frame and add a function to only perform data type handling (whole conversion is: keys matching + data type conversion) * fix(policies): right utils signature + docstrings (#2198) --------- Co-authored-by: Steven Palma --- src/lerobot/policies/utils.py | 113 ++++++++++++++++++++++++++ src/lerobot/scripts/lerobot_record.py | 6 +- src/lerobot/utils/control_utils.py | 19 +---- 3 files changed, 117 insertions(+), 21 deletions(-) diff --git a/src/lerobot/policies/utils.py b/src/lerobot/policies/utils.py index 5a3994cdf..21b39a80e 100644 --- a/src/lerobot/policies/utils.py +++ b/src/lerobot/policies/utils.py @@ -16,10 +16,16 @@ import logging from collections import deque +from typing import Any +import numpy as np import torch from torch import nn +from lerobot.datasets.utils import build_dataset_frame +from lerobot.processor import PolicyAction, RobotAction, RobotObservation +from lerobot.utils.constants import ACTION, OBS_STR + def populate_queues( queues: dict[str, deque], batch: dict[str, torch.Tensor], exclude_keys: list[str] | None = None @@ -85,3 +91,110 @@ def log_model_loading_keys(missing_keys: list[str], unexpected_keys: list[str]) logging.warning(f"Missing key(s) when loading model: {missing_keys}") if unexpected_keys: logging.warning(f"Unexpected key(s) when loading model: {unexpected_keys}") + + +# TODO(Steven): Move this function to a proper preprocessor step +def prepare_observation_for_inference( + observation: dict[str, np.ndarray], + device: torch.device, + task: str | None = None, + robot_type: str | None = None, +) -> RobotObservation: + """Converts observation data to model-ready PyTorch tensors. + + This function takes a dictionary of NumPy arrays, performs necessary + preprocessing, and prepares it for model inference. The steps include: + 1. Converting NumPy arrays to PyTorch tensors. + 2. Normalizing and permuting image data (if any). + 3. Adding a batch dimension to each tensor. + 4. Moving all tensors to the specified compute device. + 5. Adding task and robot type information to the dictionary. + + Args: + observation: A dictionary mapping observation names (str) to NumPy + array data. For images, the format is expected to be (H, W, C). + device: The PyTorch device (e.g., 'cpu' or 'cuda') to which the + tensors will be moved. + task: An optional string identifier for the current task. + robot_type: An optional string identifier for the robot being used. 
+ + Returns: + A dictionary where values are PyTorch tensors preprocessed for + inference, residing on the target device. Image tensors are reshaped + to (C, H, W) and normalized to a [0, 1] range. + """ + for name in observation: + observation[name] = torch.from_numpy(observation[name]) + if "image" in name: + observation[name] = observation[name].type(torch.float32) / 255 + observation[name] = observation[name].permute(2, 0, 1).contiguous() + observation[name] = observation[name].unsqueeze(0) + observation[name] = observation[name].to(device) + + observation["task"] = task if task else "" + observation["robot_type"] = robot_type if robot_type else "" + + return observation + + +def build_inference_frame( + observation: dict[str, Any], + device: torch.device, + ds_features: dict[str, dict], + task: str | None = None, + robot_type: str | None = None, +) -> RobotObservation: + """Constructs a model-ready observation tensor dict from a raw observation. + + This utility function orchestrates the process of converting a raw, + unstructured observation from an environment into a structured, + tensor-based format suitable for passing to a policy model. + + Args: + observation: The raw observation dictionary, which may contain + superfluous keys. + device: The target PyTorch device for the final tensors. + ds_features: A configuration dictionary that specifies which features + to extract from the raw observation. + task: An optional string identifier for the current task. + robot_type: An optional string identifier for the robot being used. + + Returns: + A dictionary of preprocessed tensors ready for model inference. + """ + # Extracts the correct keys from the incoming raw observation + observation = build_dataset_frame(ds_features, observation, prefix=OBS_STR) + + # Performs the necessary conversions to the observation + observation = prepare_observation_for_inference(observation, device, task, robot_type) + + return observation + + +def make_robot_action(action_tensor: PolicyAction, ds_features: dict[str, dict]) -> RobotAction: + """Converts a policy's output tensor into a dictionary of named actions. + + This function translates the numerical output from a policy model into a + human-readable and robot-consumable format, where each dimension of the + action tensor is mapped to a named motor or actuator command. + + Args: + action_tensor: A PyTorch tensor representing the policy's action, + typically with a batch dimension (e.g., shape [1, action_dim]). + ds_features: A configuration dictionary containing metadata, including + the names corresponding to each index of the action tensor. + + Returns: + A dictionary mapping action names (e.g., "joint_1_motor") to their + corresponding floating-point values, ready to be sent to a robot + controller. 
+ """ + # TODO(Steven): Check if these steps are already in all postprocessor policies + action_tensor = action_tensor.squeeze(0) + action_tensor = action_tensor.to("cpu") + + action_names = ds_features[ACTION]["names"] + act_processed_policy: RobotAction = { + f"{name}": float(action_tensor[i]) for i, name in enumerate(action_names) + } + return act_processed_policy diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 55846ff63..6df92d893 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -79,6 +79,7 @@ from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts from lerobot.datasets.video_utils import VideoEncodingManager from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.utils import make_robot_action from lerobot.processor import ( PolicyAction, PolicyProcessorPipeline, @@ -316,10 +317,7 @@ def record_loop( robot_type=robot.robot_type, ) - action_names = dataset.features[ACTION]["names"] - act_processed_policy: RobotAction = { - f"{name}": float(action_values[i]) for i, name in enumerate(action_names) - } + act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features) elif policy is None and isinstance(teleop, Teleoperator): act = teleop.get_action() diff --git a/src/lerobot/utils/control_utils.py b/src/lerobot/utils/control_utils.py index 17371921c..7cfe177ef 100644 --- a/src/lerobot/utils/control_utils.py +++ b/src/lerobot/utils/control_utils.py @@ -31,6 +31,7 @@ from deepdiff import DeepDiff from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import DEFAULT_FEATURES from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.utils import prepare_observation_for_inference from lerobot.processor import PolicyAction, PolicyProcessorPipeline from lerobot.robots import Robot @@ -102,17 +103,7 @@ def predict_action( torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), ): # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension - for name in observation: - observation[name] = torch.from_numpy(observation[name]) - if "image" in name: - observation[name] = observation[name].type(torch.float32) / 255 - observation[name] = observation[name].permute(2, 0, 1).contiguous() - observation[name] = observation[name].unsqueeze(0) - observation[name] = observation[name].to(device) - - observation["task"] = task if task else "" - observation["robot_type"] = robot_type if robot_type else "" - + observation = prepare_observation_for_inference(observation, device, task, robot_type) observation = preprocessor(observation) # Compute the next action with the policy @@ -121,12 +112,6 @@ def predict_action( action = postprocessor(action) - # Remove batch dimension - action = action.squeeze(0) - - # Move to cpu, if not already the case - action = action.to("cpu") - return action From 6e8be57eb2a93418ecfed81a14bf48b937d42247 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 14 Oct 2025 16:00:42 +0200 Subject: [PATCH 15/24] chore(policies): deprecate pi0fast (#2203) --- src/lerobot/policies/factory.py | 19 +- .../policies/pi0fast/configuration_pi0fast.py | 153 --- .../policies/pi0fast/modeling_pi0fast.py | 980 ------------------ .../policies/pi0fast/processor_pi0fast.py | 92 -- .../templates/lerobot_modelcard_template.md | 2 - 5 files changed, 2 
insertions(+), 1244 deletions(-) delete mode 100644 src/lerobot/policies/pi0fast/configuration_pi0fast.py delete mode 100644 src/lerobot/policies/pi0fast/modeling_pi0fast.py delete mode 100644 src/lerobot/policies/pi0fast/processor_pi0fast.py diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index ac76baf9f..cfb550ab2 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -31,7 +31,6 @@ from lerobot.envs.utils import env_to_policy_features from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.policies.pi0.configuration_pi0 import PI0Config -from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig from lerobot.policies.pi05.configuration_pi05 import PI05Config from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.sac.configuration_sac import SACConfig @@ -58,7 +57,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: Args: name: The name of the policy. Supported names are "tdmpc", "diffusion", "act", - "vqbet", "pi0", "pi0fast", "sac", "reward_classifier", "smolvla". + "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla". Returns: The policy class corresponding to the given name. @@ -82,10 +81,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy return VQBeTPolicy - elif name == "pi0fast": - from lerobot.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy - - return PI0FASTPolicy elif name == "pi0": from lerobot.policies.pi0.modeling_pi0 import PI0Policy @@ -119,7 +114,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: Args: policy_type: The type of the policy. Supported types include "tdmpc", - "diffusion", "act", "vqbet", "pi0", "pi0fast", "sac", "smolvla", + "diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla", "reward_classifier". **kwargs: Keyword arguments to be passed to the configuration class constructor. @@ -137,8 +132,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return ACTConfig(**kwargs) elif policy_type == "vqbet": return VQBeTConfig(**kwargs) - elif policy_type == "pi0fast": - return PI0FASTConfig(**kwargs) elif policy_type == "pi0": return PI0Config(**kwargs) elif policy_type == "pi05": @@ -260,14 +253,6 @@ def make_pre_post_processors( dataset_stats=kwargs.get("dataset_stats"), ) - elif isinstance(policy_cfg, PI0FASTConfig): - from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors - - processors = make_pi0fast_pre_post_processors( - config=policy_cfg, - dataset_stats=kwargs.get("dataset_stats"), - ) - elif isinstance(policy_cfg, PI0Config): from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors diff --git a/src/lerobot/policies/pi0fast/configuration_pi0fast.py b/src/lerobot/policies/pi0fast/configuration_pi0fast.py deleted file mode 100644 index cefd4e688..000000000 --- a/src/lerobot/policies/pi0fast/configuration_pi0fast.py +++ /dev/null @@ -1,153 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass, field - -from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.optim.optimizers import AdamWConfig -from lerobot.optim.schedulers import ( - CosineDecayWithWarmupSchedulerConfig, -) -from lerobot.utils.constants import OBS_IMAGES - - -@PreTrainedConfig.register_subclass("pi0fast") -@dataclass -class PI0FASTConfig(PreTrainedConfig): - # Input / output structure. - n_obs_steps: int = 1 - chunk_size: int = 10 - n_action_steps: int = 5 - - normalization_mapping: dict[str, NormalizationMode] = field( - default_factory=lambda: { - "VISUAL": NormalizationMode.IDENTITY, - "STATE": NormalizationMode.MEAN_STD, - "ACTION": NormalizationMode.MEAN_STD, - } - ) - - # Shorter state and action vectors will be padded - max_state_dim: int = 32 # 32 - max_action_dim: int = 32 # 32 - - # Image preprocessing - resize_imgs_with_padding: tuple[int, int] = (224, 224) - interpolate_like_pi: bool = False - - # Add empty images. Used by pi0_aloha_sim which adds the empty - # left and right wrist cameras in addition to the top camera. - empty_cameras: int = 0 - - # Converts the joint and gripper values from the standard Aloha space to - # the space used by the pi internal runtime which was used to train the base model. - adapt_to_pi_aloha: bool = False - - # Converts joint dimensions to deltas with respect to the current state before passing to the model. - # Gripper dimensions will remain in absolute values. - use_delta_joint_actions_aloha: bool = False - - # Tokenizer - tokenizer_max_length: int = 48 - - # Projector - proj_width: int = 1024 - - # Decoding - max_decoding_steps: int = 256 - fast_skip_tokens: int = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens - max_input_seq_len: int = 256 # 512 - - # Utils - use_cache: bool = True - - # Frozen parameters - freeze_vision_encoder: bool = True - freeze_lm_head: bool = True - - # Training presets - optimizer_lr: float = 1e-4 - optimizer_betas: tuple[float, float] = (0.9, 0.95) - optimizer_eps: float = 1e-8 - optimizer_weight_decay: float = 1e-5 - - scheduler_warmup_steps: int = 1_000 - scheduler_decay_steps: int = 30_000 - scheduler_decay_lr: float = 2.5e-6 - - checkpoint_path: str = None - - padding_side: str = "right" - - precision: str = "bfloat16" - grad_clip_norm: float = 1 - - # Allows padding/truncation of generated action tokens during detokenization to ensure decoding. - # In the original version, tensors of 0s were generated if shapes didn't match for stable decoding. - relaxed_action_decoding: bool = True - - def __post_init__(self): - super().__post_init__() - - """Input validation (not exhaustive).""" - if self.n_action_steps > self.chunk_size: - raise ValueError( - f"The chunk size is the upper bound for the number of action steps per model invocation. Got " - f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`." - ) - if self.n_obs_steps != 1: - raise ValueError( - f"Multiple observation steps not handled yet. 
Got `nobs_steps={self.n_obs_steps}`" - ) - - def validate_features(self) -> None: - for i in range(self.empty_cameras): - key = f"{OBS_IMAGES}.empty_camera_{i}" - empty_camera = PolicyFeature( - type=FeatureType.VISUAL, - shape=(3, 480, 640), - ) - self.input_features[key] = empty_camera - - def get_optimizer_preset(self) -> AdamWConfig: - return AdamWConfig( - lr=self.optimizer_lr, - betas=self.optimizer_betas, - eps=self.optimizer_eps, - weight_decay=self.optimizer_weight_decay, - grad_clip_norm=self.grad_clip_norm, - ) - - def get_scheduler_preset(self): - return CosineDecayWithWarmupSchedulerConfig( - peak_lr=self.optimizer_lr, - decay_lr=self.scheduler_decay_lr, - num_warmup_steps=self.scheduler_warmup_steps, - num_decay_steps=self.scheduler_decay_steps, - ) - - @property - def observation_delta_indices(self) -> None: - return None - - @property - def action_delta_indices(self) -> list: - return list(range(self.chunk_size)) - - @property - def reward_delta_indices(self) -> None: - return None diff --git a/src/lerobot/policies/pi0fast/modeling_pi0fast.py b/src/lerobot/policies/pi0fast/modeling_pi0fast.py deleted file mode 100644 index 102cfb8fa..000000000 --- a/src/lerobot/policies/pi0fast/modeling_pi0fast.py +++ /dev/null @@ -1,980 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -π0+FAST: Efficient Action Tokenization for Vision-Language-Action Models - -[Paper](https://huggingface.co/papers/2501.09747) -[Jax code](https://github.com/Physical-Intelligence/openpi) - -Designed by Physical Intelligence. Ported from Jax by Hugging Face. -Disclaimer: It is not expected to perform as well as the original implementation. 
- -Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`): -```bash -lerobot-train \ ---policy.path=lerobot/pi0fast_base \ ---dataset.repo_id=danaaubakirova/koch_test -``` - -Example of training the pi0+FAST neural network with from scratch: -```bash -lerobot-train \ ---policy.type=pi0fast \ ---dataset.repo_id=danaaubakirova/koch_test -``` - -Example of using the pi0 pretrained model outside LeRobot training framework: -```python -policy = PI0FASTPolicy.from_pretrained("lerobot/pi0fast_base") -``` - -""" - -from collections import deque -from functools import partial - -import numpy as np -import torch -import torch.nn.functional as F # noqa: N812 -from PIL import Image -from scipy.fft import idct -from torch import Tensor, nn -from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGeneration -from transformers.cache_utils import HybridCache, StaticCache -from transformers.models.auto import CONFIG_MAPPING - -from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig -from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.utils.constants import ACTION, OBS_STATE - -PRECISION = { - "float16": torch.float16, - "float32": torch.float32, - "bfloat16": torch.bfloat16, -} - - -def normalize(x, min_val, max_val): - return (x - min_val) / (max_val - min_val) - - -def unnormalize(x, min_val, max_val): - return x * (max_val - min_val) + min_val - - -def safe_arcsin(value): - # This ensures that the input stays within - # [−1,1] to avoid invalid values for arcsin - return torch.arcsin(torch.clamp(value, -1.0, 1.0)) - - -def aloha_gripper_to_angular(value): - # Aloha transforms the gripper positions into a linear space. The following code - # reverses this transformation to be consistent with pi0 which is pretrained in - # angular space. - # - # These values are coming from the Aloha code: - # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED - value = unnormalize(value, min_val=0.01844, max_val=0.05800) - - # This is the inverse of the angular to linear transformation inside the Interbotix code. - def linear_to_radian(linear_position, arm_length, horn_radius): - value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position) - return safe_arcsin(value) - - # The constants are taken from the Interbotix code. - value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022) - - # Normalize to [0, 1]. - # The values 0.4 and 1.5 were measured on an actual Trossen robot. - return normalize(value, min_val=0.4, max_val=1.5) - - -def aloha_gripper_from_angular(value): - # Convert from the gripper position used by pi0 to the gripper position that is used by Aloha. - # Note that the units are still angular but the range is different. - - # The values 0.4 and 1.5 were measured on an actual Trossen robot. - value = unnormalize(value, min_val=0.4, max_val=1.5) - - # These values are coming from the Aloha code: - # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE - return normalize(value, min_val=-0.6213, max_val=1.4910) - - -def aloha_gripper_from_angular_inv(value): - # Directly inverts the gripper_from_angular function. 
- value = unnormalize(value, min_val=-0.6213, max_val=1.4910) - return normalize(value, min_val=0.4, max_val=1.5) - - -class PI0FASTPolicy(PreTrainedPolicy): - """Wrapper class around PI0FAST tokenizer and model to train and run inference within LeRobot.""" - - config_class = PI0FASTConfig - name = "pi0fast" - - def __init__( - self, - config: PI0FASTConfig, - dataset_stats: dict[str, dict[str, Tensor]] | None = None, - ): - """ - Args: - config: Policy configuration class instance or None, in which case the default instantiation of - the configuration class is used. - dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected - that they will be passed with a call to `load_state_dict` before the policy is used. - """ - - super().__init__(config) - config.validate_features() - self.config = config - - self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224") - self.model = PI0FAST(config) - - self.reset() - - def reset(self): - """This should be called whenever the environment is reset.""" - self._action_queue = deque([], maxlen=self.config.n_action_steps) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - """Override the from_pretrained method to display important disclaimer.""" - print( - "⚠️ DISCLAIMER: The PI0FAST model is ported from JAX by the Hugging Face team. \n" - " It is not expected to perform as well as the original implementation. \n" - " Original implementation: https://github.com/Physical-Intelligence/openpi" - ) - return super().from_pretrained(*args, **kwargs) - - def get_optim_params(self) -> dict: - return self.parameters() - - def _pi_aloha_decode_state(self, state): - # Flip the joints. - for motor_idx in [1, 2, 8, 9]: - state[:, motor_idx] *= -1 - # Reverse the gripper transformation that is being applied by the Aloha runtime. - for motor_idx in [6, 13]: - state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx]) - return state - - def _pi_aloha_encode_actions(self, actions): - # Flip the joints. - for motor_idx in [1, 2, 8, 9]: - actions[:, :, motor_idx] *= -1 - # Reverse the gripper transformation that is being applied by the Aloha runtime. - for motor_idx in [6, 13]: - actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx]) - return actions - - def _pi_aloha_encode_actions_inv(self, actions): - # Flip the joints again. - for motor_idx in [1, 2, 8, 9]: - actions[:, :, motor_idx] *= -1 - # Reverse the gripper transformation that is being applied by the Aloha runtime. - for motor_idx in [6, 13]: - actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx]) - return actions - - @torch.no_grad() - def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: - """Predict a chunk of actions given environment observations.""" - raise NotImplementedError("Currently not implemented for PI0FAST") - - @torch.no_grad() - def select_action(self, batch: dict[str, Tensor]) -> Tensor: - """Select a single action given environment observations. - - This method wraps `select_actions` in order to return one action at a time for execution in the - environment. It works by managing the actions in a queue and only calling `select_actions` when the - queue is empty. - """ - self.eval() - - if self.config.adapt_to_pi_aloha: - batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) - - # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by - # querying the policy. 
- if len(self._action_queue) == 0: - actions = self.model.generate_actions(batch) - - actions = actions[:, : self.config.n_action_steps] - - original_action_dim = self.config.action_feature.shape[ - 0 - ] # self.config.max_action_dim # self.config.action_feature.shape[0] - actions = actions[:, :, :original_action_dim] - - if self.config.adapt_to_pi_aloha: - actions = self._pi_aloha_encode_actions(actions) - - # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue - # effectively has shape (n_action_steps, batch_size, *), hence the transpose. - self._action_queue.extend(actions.transpose(0, 1)) - return self._action_queue.popleft() - - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - if self.config.adapt_to_pi_aloha: - batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) - batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) - loss_dict = self.model.forward(batch) - return loss_dict["loss"], loss_dict - - -def block_causal_update_causal_mask( - attention_mask, - token_type_ids=None, - past_key_values=None, - cache_position=None, - input_tensor=None, - attn_implementation: str = "eager", - dtype: torch.dtype = "float32", -): - """ - Update the causal mask during training and generation. It can be customized to different attention masks. - """ - if attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - using_static_cache = isinstance(past_key_values, StaticCache) - min_dtype = torch.finfo(dtype).min - - if input_tensor is None: - input_tensor = attention_mask - - inputs_lead_dim, sequence_length = input_tensor.shape[:2] - - if using_static_cache or isinstance(past_key_values, HybridCache): - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else cache_position[0] + sequence_length + 1 - ) - - # Handle precomputed attention masks - if attention_mask is not None and attention_mask.dim() == 4: - return attention_mask - - # Causal mask initialization - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - - # Standard causal masking (triu ensures tokens can only attend to past) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - - # Apply block causal mask - if token_type_ids is not None: - token_type_ids = token_type_ids.to(causal_mask.device).bool() - cumsum = torch.cumsum(token_type_ids, dim=1) - block_causal_mask = cumsum[:, None, :] <= cumsum[:, :, None] - - # Combine causal_mask with block-wise attention mask - causal_mask = torch.where(block_causal_mask, 0.0, causal_mask) - causal_mask = causal_mask[:, None, :, :] - else: - # Apply past cache position constraint - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape( - -1, 1 - ) - causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) - else: - # Apply past cache position constraint - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape( - -1, 1 - ) - causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) - - if attention_mask is not None: - causal_mask = causal_mask.clone() # Copy to contiguous memory for in-place edits - mask_length = attention_mask.shape[-1] - - # Apply padding mask - padding_mask = causal_mask[:, :, :, 
:mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - -def prepare_inputs_for_generation( - # self, - input_ids, - past_key_values=None, - inputs_embeds=None, - cache_position=None, - position_ids=None, - pixel_values=None, - attention_mask=None, - token_type_ids=None, - use_cache=True, - num_logits_to_keep=None, - labels=None, - self=None, - **kwargs, -): - # create block causal attention - if cache_position[0] > 0 and input_ids.shape[1] > 0: - input_tensor = input_ids[:, -1:] - new_positions = ( - torch.ones( - (position_ids.shape[0], input_ids.shape[1]), - dtype=position_ids.dtype, - device=position_ids.device, - ).cumsum(-1) - + position_ids[:, -1:] - ) - position_ids = torch.cat([position_ids, new_positions], dim=-1) - else: - input_tensor = inputs_embeds - attention_mask = block_causal_update_causal_mask( - attention_mask=attention_mask, - past_key_values=past_key_values, - cache_position=cache_position, - input_tensor=input_tensor, - token_type_ids=token_type_ids, - dtype=self.dtype, - attn_implementation=self.config.text_config._attn_implementation, - ) - # Overwritten -- custom `position_ids` and `pixel_values` handling - model_inputs = self.language_model.prepare_inputs_for_generation( - input_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - position_ids=position_ids, - cache_position=cache_position, - use_cache=use_cache, - num_logits_to_keep=num_logits_to_keep, - token_type_ids=token_type_ids, - **kwargs, - ) - - # Position_ids in Paligemma are 1-indexed - if model_inputs.get("position_ids") is not None: - model_inputs["position_ids"] += 1 - # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore - # Otherwise we need pixel values to be passed to model. 
NOTE: use_cache=False needs pixel_values always - if cache_position[0] == 0: - model_inputs["pixel_values"] = pixel_values - is_training = token_type_ids is not None and labels is not None - if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): - input_tensor = inputs_embeds if inputs_embeds is not None else input_ids - causal_mask = self._update_causal_mask( - attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training - ) - model_inputs["attention_mask"] = causal_mask - - return model_inputs - - -class PI0FAST(nn.Module): - def __init__(self, config: PI0FASTConfig): - super().__init__() - self.config = config - - # TODO: move tokenizers in Policy - fast_tokenizer_path = "physical-intelligence/fast" - pi0_paligemma_path = "google/paligemma-3b-pt-224" - self.paligemma_tokenizer = AutoTokenizer.from_pretrained(pi0_paligemma_path) - self.processor = AutoProcessor.from_pretrained(pi0_paligemma_path) - self.fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True) - self.fast_skip_tokens = self.config.fast_skip_tokens - self.max_input_seq_len = self.config.max_input_seq_len - self.action_horizon = self.config.chunk_size - self.action_dim = self.config.action_feature.shape[ - 0 - ] # self.config.max_action_dim # self.config.action_feature.shape[0] - precision = config.precision - torch_precision = PRECISION.get(precision, torch.float32) - self.pad_token_id = ( - self.paligemma_tokenizer.pad_token_id - if hasattr(self.paligemma_tokenizer, "pad_token_id") - else self.paligemma_tokenizer.eos_token_id - ) - - paligemma_config = CONFIG_MAPPING["paligemma"]( - transformers_version="4.48.1", - _vocab_size=257152, - bos_token_id=2, - eos_token_id=1, - hidden_size=2048, - image_token_index=257152, - model_type="paligemma", - pad_token_id=0, - projection_dim=2048, - text_config={ - "hidden_activation": "gelu_pytorch_tanh", - "hidden_size": 2048, - "intermediate_size": 16384, - "model_type": "gemma", - "num_attention_heads": 8, - "num_hidden_layers": 18, - "num_image_tokens": 256, - "num_key_value_heads": 1, - "torch_dtype": precision, - "vocab_size": 257152, - "_attn_implementation": "eager", - }, - vision_config={ - "hidden_size": 1152, - "intermediate_size": 4304, - "model_type": "siglip_vision_model", - "num_attention_heads": 16, - "num_hidden_layers": 27, - "num_image_tokens": 256, - "patch_size": 14, - "projection_dim": 2048, - "projector_hidden_act": "gelu_pytorch_tanh", - "torch_dtype": precision, - "vision_use_head": False, - }, - ) - self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config) - - self.pi0_paligemma.prepare_inputs_for_generation = partial( - prepare_inputs_for_generation, self=self.pi0_paligemma - ) - # change important stuff in bf16 - params_to_change_dtype = [ - "language_model", - "vision_tower", - "multi_modal", - ] - for name, param in self.pi0_paligemma.named_parameters(): - if any(selector in name for selector in params_to_change_dtype): - param.data = param.data.to(dtype=torch_precision) - self.set_requires_grad() - self.image_keys = self.config.image_features.keys() - # TODO: Remove this once we bump transformers to >4.52.0 because the attribute will be removed - # AttributeError: 'PaliGemmaConfig' object has no attribute 'ignore_index' - self.ignore_index = self.pi0_paligemma.config.ignore_index - self.padding_side = self.config.padding_side - - def set_requires_grad(self): - if self.config.freeze_vision_encoder: - self.pi0_paligemma.vision_tower.eval() - for params in 
self.pi0_paligemma.vision_tower.parameters(): - params.requires_grad = False - # To avoid unused params issue with distributed training - if self.config.freeze_lm_head: - for name, params in self.pi0_paligemma.named_parameters(): - if "embed_tokens" in name: # lm heads and embedding layer are tied - params.requires_grad = False - - def embed_tokens(self, tokens: torch.Tensor): - return self.pi0_paligemma.language_model.model.embed_tokens(tokens) - - def prepare_inputs_for_generation(self, *args, **kwargs): - return self.pi0_paligemma.prepare_inputs_for_generation(*args, **kwargs) - - def prepare_images(self, batch): - """Preprocess LeRobot batch into Pi0 inputs""" - images = [] - img_masks = [] - present_img_keys = [key for key in self.image_keys if key in batch] - if len(present_img_keys) == 0: - raise ValueError( - f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})" - ) - - # Preprocess image features present in the batch - num_empty_cameras = 0 - for key in self.image_keys: - if key in present_img_keys: - img = batch[key] - - if self.config.resize_imgs_with_padding is not None: - img = resize_with_pad( - img, - *self.config.resize_imgs_with_padding, - pad_value=0, - interpolate_like_pi=self.config.interpolate_like_pi, - ) - - # Normalize from range [0,1] to [-1,1] as expected by siglip - img = img * 2.0 - 1.0 - - bsize = img.shape[0] - device = img.device - mask = torch.ones(bsize, dtype=torch.bool, device=device) - else: - if num_empty_cameras >= self.config.empty_cameras: - continue - img = torch.ones_like(img) * -1 - bsize = img.shape[0] - device = img.device - mask = torch.ones(bsize, dtype=torch.bool, device=device) - num_empty_cameras += 1 - - images.append(img) - img_masks.append(mask) - return images, img_masks - - def normalize_actions(self, actions: torch.Tensor) -> torch.Tensor: - mins = actions.amin(dim=(1, 2), keepdim=True) # [0] - maxs = actions.amax(dim=(1, 2), keepdim=True) # [0] - return 2 * (actions - mins) / (maxs - mins + 1e-8) - 1 - - def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor: - out = self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens - return out - - def fast_tokenizer_wrapper(self, actions_norm): - """ - A wrapper for self.fast_tokenizer that ensures batch processing, - conversion to PyTorch tensors, and returns a dictionary without padding. 
- """ - batch_tokens = self.fast_tokenizer(actions_norm) - fast_out = self.processor.tokenizer.pad({"input_ids": batch_tokens}, return_tensors="pt") - - return fast_out - - def create_token_type_ids(self, padded_mask: torch.Tensor, prefix_len: int) -> torch.Tensor: - token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool) - # Compute cumulative sum mask - cumsum_mask = (padded_mask != 0).cumsum(dim=1) - # Suffix block (everything after prefix_len) - suffix_mask = cumsum_mask > prefix_len - token_type_ids = suffix_mask - return token_type_ids - - def create_input_tokens(self, state, lang_text, actions=None): - bsize = state.shape[0] - device = state.device - bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1] - discretized = torch.bucketize(state, bins) - 1 - discretized = discretized[:, :32] - - prefix_texts = [] - state_text = [] - for txt, disc in zip(lang_text, discretized, strict=False): - cleaned = txt.lower().strip().replace("_", " ") - state_str = " ".join(str(val.item()) for val in disc) - prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n") - state_text.append(f"State: {state_str};\n") - - prefix_out = self.paligemma_tokenizer( - prefix_texts, add_special_tokens=True, return_tensors="pt", padding="longest", truncation=False - ) - prefix_ids = prefix_out["input_ids"].to(device) - prefix_mask = prefix_out["attention_mask"].to(device) - prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu() - - if actions is not None: - actions_norm = self.normalize_actions(actions) - actions_pad = F.pad( - actions_norm, (0, max(0, self.config.max_action_dim - actions_norm.shape[2])), value=0 - )[:, :, : self.config.max_action_dim] - fast_out = self.fast_tokenizer_wrapper( - actions_pad.cpu(), - ) - act_ids = fast_out["input_ids"] - act_mask = fast_out["attention_mask"].to(device) - - act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device) - # Replace action with 0 to pad tokens - act_ids = torch.where( - act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens, - self.pad_token_id, - act_ids, - ) - - eos_token = torch.tensor( - [self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device - ).expand(bsize, -1) - eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1) - bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt") - bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device) - bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device) - act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1) - act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1) - act_mask = act_mask.to(device) - else: - act_ids = torch.empty(bsize, self.pad_token_id, dtype=torch.long, device=device) - act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device) - final_ids = torch.cat([prefix_ids, act_ids], dim=1) - - final_mask = torch.cat([prefix_mask, act_mask], dim=1) - batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()} - - # Use tokenizer pad function - padded_output = self.paligemma_tokenizer.pad( - batch_inputs, padding="longest", max_length=180, return_tensors="pt" - ) - padded_mask = padded_output["attention_mask"] - - # define tensor of padding lengths - att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens - - token_type_ids = self.create_token_type_ids(padded_mask=padded_mask, prefix_len=prefix_lens) - - padded_output["padded_mask"] = padded_output.pop("attention_mask") - padded_output["attention_mask"] = att_mask 
- # loss is computed not on prefix, and not on padding - padded_output["loss_mask"] = att_mask & padded_output["padded_mask"] - padded_output["token_type_ids"] = token_type_ids - return padded_output - - def shift_padding_side( - self, - tokens: torch.Tensor, - ar_mask: torch.Tensor, - padding_mask: torch.Tensor, - loss_mask: torch.Tensor, - targets: torch.Tensor, - token_type_ids: torch.Tensor, - padding_side: str = "right", - ) -> tuple[torch.Tensor]: - if padding_side not in ["right", "left"]: - return tokens, ar_mask, padding_mask, loss_mask, targets, token_type_ids - - new_tokens = torch.empty_like(tokens) - new_ar_masks = torch.empty_like(ar_mask) - new_padding_mask = torch.empty_like(padding_mask) - new_loss_mask = torch.empty_like(loss_mask) - new_targets = torch.empty_like(targets) - new_token_type_ids = torch.empty_like(token_type_ids) - batch_size = tokens.shape[0] - for i in range(batch_size): - padding_indices = torch.where(padding_mask[i] == 0)[0] - non_padding_indices = torch.where(padding_mask[i] == 1)[0] - if padding_side == "left": - new_indices = torch.cat((padding_indices, non_padding_indices), dim=0) - else: - new_indices = torch.cat((non_padding_indices, padding_indices), dim=0) - new_tokens[i] = tokens[i].index_select(0, new_indices) - new_ar_masks[i] = ar_mask[i].index_select(0, new_indices) - new_padding_mask[i] = padding_mask[i].index_select(0, new_indices) - new_loss_mask[i] = loss_mask[i].index_select(0, new_indices) - new_targets[i] = targets[i].index_select(0, new_indices) - new_token_type_ids[i] = token_type_ids[i].index_select(0, new_indices) - - return new_tokens, new_ar_masks, new_padding_mask, new_loss_mask, new_targets, new_token_type_ids - - def forward(self, batch: dict[str, Tensor]): - device = batch[OBS_STATE].device - # TODO: keep like this or move to the policy .forward - images, img_masks = self.prepare_images(batch) - - padded_outs = self.create_input_tokens( - state=batch[OBS_STATE], - lang_text=batch["task"], - actions=batch[ACTION], - ) - - embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs( - images, - img_masks, - padded_outs["input_ids"], - padded_outs["padded_mask"], - padded_outs["attention_mask"], - padded_outs["loss_mask"], - padded_outs["token_type_ids"], - padding_side=self.padding_side, - ) - position_ids = torch.cumsum(pad_masks, dim=1) - 1 - token_type_ids = token_type_ids.to(dtype=torch.int64) - past_seen_tokens = 0 - cache_position = torch.arange(past_seen_tokens, past_seen_tokens + embs.shape[1], device=embs.device) - pad_masks = block_causal_update_causal_mask( - attention_mask=pad_masks, - past_key_values=None, - cache_position=cache_position, - input_tensor=embs, - token_type_ids=token_type_ids, - dtype=self.pi0_paligemma.dtype, - attn_implementation=self.pi0_paligemma.config.text_config._attn_implementation, - ) - outputs = self.pi0_paligemma.forward( - input_ids=None, - token_type_ids=None, - attention_mask=pad_masks, - position_ids=position_ids, - past_key_values=None, - inputs_embeds=embs, - use_cache=False, - labels=None, - ) - - logits = outputs.logits - - loss_fct = nn.CrossEntropyLoss(reduction="none") - - # Shift left for next-step prediction - logits = logits[:, :-1, :] - targets = targets[:, 1:].to(device) # Shift targets - loss_mask = loss_mask[:, 1:].to(device) # Ensure correct shape - - # Compute per-token loss - token_loss = loss_fct(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1)) - - # Apply loss mask - token_loss = token_loss * loss_mask.reshape(-1) - - # Compute final loss 
- loss = token_loss.sum() / torch.clamp(loss_mask.sum(), min=1) - - # Return loss dictionary - loss_dict = {"ce_loss": loss.item(), "loss": loss} - return loss_dict - - def decode_actions_with_fast( - self, - tokens: list[list[int]], - *, - time_horizon: int | None = None, - action_dim: int | None = None, - relaxed_decoding: bool = True, - ) -> np.array: - """ - Adapt original decoding in FAST to always return actions instead of zeros. - """ - self.time_horizon = ( - time_horizon or self.fast_tokenizer.time_horizon or self.fast_tokenizer.called_time_horizon - ) - self.action_dim = ( - action_dim or self.fast_tokenizer.action_dim or self.fast_tokenizer.called_action_dim - ) - - # Cache the time horizon and action dimension for the next call - self.called_time_horizon = self.time_horizon - self.called_action_dim = self.action_dim - - assert self.time_horizon is not None and self.action_dim is not None, ( - "Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim." - ) - - decoded_actions = [] - for token in tokens: - try: - decoded_tokens = self.fast_tokenizer.bpe_tokenizer.decode(token) - decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.fast_tokenizer.min_token - if relaxed_decoding: - # Expected sequence length - expected_seq_len = self.time_horizon * self.action_dim - diff = expected_seq_len - decoded_dct_coeff.shape[0] - # Apply truncation if too long - if diff < 0: - decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len] # Truncate on the right - # Apply padding if too short - elif diff > 0: - decoded_dct_coeff = np.pad( - decoded_dct_coeff, (0, diff), mode="constant", constant_values=0 - ) - - decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim) - assert decoded_dct_coeff.shape == ( - self.time_horizon, - self.action_dim, - ), ( - f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})" - ) - except Exception as e: - print(f"Error decoding tokens: {e}") - print(f"Tokens: {token}") - decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim)) - decoded_actions.append(idct(decoded_dct_coeff / self.fast_tokenizer.scale, axis=0, norm="ortho")) - return np.stack(decoded_actions) - - def extract_actions(self, tokens: torch.Tensor, action_horizon: int, action_dim: int) -> torch.Tensor: - """ - Extracts actions from predicted output tokens using the FAST model. - - Args: - tokens (torch.Tensor): The input tensor of tokenized outputs. - action_horizon (int): The number of timesteps for actions. - action_dim (int): The dimensionality of each action. - - Returns: - torch.Tensor: The extracted actions as a tensor of shape (action_horizon, action_dim). 
- """ - # Decode predicted output tokens - decoded_tokens = self.paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True) - cleaned_tokens = [ - tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip() - for tokens_sequence in decoded_tokens - ] - raw_action_tokens = [ - self.processor.tokenizer.encode(sample_tokens, return_tensors="pt", padding=False) - for sample_tokens in cleaned_tokens - ] # something like this should be robust #looks good - action_tokens = [ - self._act_tokens_to_paligemma_tokens(raw_action_token) for raw_action_token in raw_action_tokens - ] - # returns the tensor of decoded actions per sample in a list - decoded_actions = [ - torch.tensor( - self.decode_actions_with_fast( - tok.tolist(), - time_horizon=action_horizon, - action_dim=action_dim, - relaxed_decoding=self.config.relaxed_action_decoding, - ), - device=tokens.device, - ).squeeze(0) - for tok in action_tokens - ] - - return torch.stack( - decoded_actions, - dim=0, - ) - - def generate_actions(self, batch: dict[str, Tensor]): - # TODO: keep like this or move to the policy .forward - images, img_masks = self.prepare_images(batch) - - padded_outs = self.create_input_tokens(state=batch[OBS_STATE], lang_text=batch["task"], actions=None) - embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs( - images, - img_masks, - padded_outs["input_ids"], - padded_outs["padded_mask"], - padded_outs["attention_mask"], - padded_outs["loss_mask"], - padded_outs["token_type_ids"], - padding_side="left", - ) - token_type_ids = token_type_ids.to(dtype=torch.int64) - prefix_position_ids = torch.cumsum(pad_masks, dim=1) - 1 - output_tokens = self.pi0_paligemma.generate( - input_ids=None, - attention_mask=pad_masks, - position_ids=prefix_position_ids, - past_key_values=None, - inputs_embeds=embs, - use_cache=self.config.use_cache, - max_new_tokens=self.config.max_decoding_steps, - do_sample=False, - num_beams=1, - token_type_ids=token_type_ids, - ) - actions = self.extract_actions(output_tokens, self.action_horizon, self.action_dim) - return actions - - def embed_image(self, image: torch.Tensor): - # Handle different transformers versions - if hasattr(self.pi0_paligemma, "get_image_features"): - return self.pi0_paligemma.get_image_features(image) - else: - return self.pi0_paligemma.model.get_image_features(image) - - def embed_inputs( - self, - images, - img_masks, - tokens, - pad_mask, - ar_mask, - loss_mask, - token_type_ids, - padding_side: str = "right", - ): - # TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty - # images are a list of same size - # vectorizing everything! 
- device = images[0].device - image_embedding_dim = images[0].shape[-1] # TODO should be from self.config - all_images = torch.stack(images, dim=1).to(device) - b, n, c, h, w = all_images.shape - all_images = all_images.view(b * n, c, h, w) - embedded = self.embed_image(all_images).to(device) - b_n, p, image_embedding_dim = embedded.shape # Extract current dimensions - m = b_n // b # Compute the number of images per sample dynamically - - # Reshape dynamically - embedded = embedded.view(b, m, p, image_embedding_dim) - tokens_embs = self.embed_tokens(tokens.to(device)) - - img_masks = torch.stack(img_masks, dim=1).unsqueeze(-1).to(device) - num_img_emb = embedded.shape[2] - img_pad_masks = img_masks.repeat(1, 1, num_img_emb).view(b, -1) - img_att_masks = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1) - - image_target_tokens = ( - torch.ones((b, n, num_img_emb), dtype=torch.long, device=device) * self.pad_token_id - ).reshape(b, -1) - image_loss_mask = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1) - - embedded = embedded.reshape(b, n * num_img_emb, image_embedding_dim) # Shape: (B, N*P, D) - - embs = torch.cat([embedded, tokens_embs], dim=1).to(device) - pad_masks = torch.cat([img_pad_masks, pad_mask.to(device)], dim=1) - att_masks = torch.cat([img_att_masks, ar_mask.to(device)], dim=1) - loss_masks = torch.cat([image_loss_mask, loss_mask.to(device)], dim=1) - targets = torch.cat([image_target_tokens, tokens.to(device)], dim=1) - token_type_ids = torch.cat([img_att_masks, token_type_ids.to(device)], dim=1) - - # Shift pad tokens to the left (.generate()) or right (.train()) - embs, att_masks, pad_masks, loss_masks, targets, token_type_ids = self.shift_padding_side( - embs, att_masks, pad_masks, loss_masks, targets, token_type_ids, padding_side=padding_side - ) - - targets = torch.where(targets == self.pad_token_id, self.ignore_index, targets) - return embs, pad_masks, att_masks, targets, loss_masks, token_type_ids - - -def resize_with_pad(img, width, height, pad_value=0, interpolate_like_pi=True): - # assume no-op when width height fits already - if img.ndim != 4: - raise ValueError(f"(b,c,h,w) expected, but {img.shape}") - - cur_height, cur_width = img.shape[2:] - - ratio = max(cur_width / width, cur_height / height) - resized_height = int(cur_height / ratio) - resized_width = int(cur_width / ratio) - - if interpolate_like_pi: - img = (img * 255.0).to(dtype=torch.uint8) - img = img.permute(0, 2, 3, 1) - original_device = img.device - img = img.to(device="cpu").numpy() - imgs = [] - for sub_img in img: - sub_img = Image.fromarray(sub_img) - resized_img = sub_img.resize((resized_width, resized_height), resample=2) - resized_img = torch.from_numpy(np.array(resized_img)) - imgs.append(resized_img) - img = torch.stack(imgs, dim=0) - img = img.permute(0, 3, 1, 2) - resized_img = img.to(device=original_device, dtype=torch.float32) / 255.0 - else: - resized_img = F.interpolate( - img, size=(resized_height, resized_width), mode="bilinear", align_corners=False - ) - - pad_height = max(0, int(height - resized_height)) - pad_width = max(0, int(width - resized_width)) - - # pad on left and top of image - padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value) - return padded_img diff --git a/src/lerobot/policies/pi0fast/processor_pi0fast.py b/src/lerobot/policies/pi0fast/processor_pi0fast.py deleted file mode 100644 index 95b5e541b..000000000 --- a/src/lerobot/policies/pi0fast/processor_pi0fast.py +++ /dev/null 
@@ -1,92 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any - -import torch - -from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig -from lerobot.processor import ( - AddBatchDimensionProcessorStep, - DeviceProcessorStep, - NormalizerProcessorStep, - PolicyAction, - PolicyProcessorPipeline, - RenameObservationsProcessorStep, - UnnormalizerProcessorStep, -) -from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action -from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME - - -def make_pi0fast_pre_post_processors( - config: PI0FASTConfig, - dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, -) -> tuple[ - PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], - PolicyProcessorPipeline[PolicyAction, PolicyAction], -]: - """ - Constructs pre-processor and post-processor pipelines for the PI0Fast policy. - - The pre-processing pipeline prepares input data for the model by: - 1. Renaming features to match pretrained configurations. - 2. Normalizing input and output features based on dataset statistics. - 3. Adding a batch dimension. - 4. Moving all data to the specified device. - - The post-processing pipeline handles the model's output by: - 1. Moving data to the CPU. - 2. Unnormalizing the output features to their original scale. - - Args: - config: The configuration object for the PI0Fast policy. - dataset_stats: A dictionary of statistics for normalization. - preprocessor_kwargs: Additional arguments for the pre-processor pipeline. - postprocessor_kwargs: Additional arguments for the post-processor pipeline. - - Returns: - A tuple containing the configured pre-processor and post-processor pipelines. 
- """ - - input_steps = [ - RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one - AddBatchDimensionProcessorStep(), - DeviceProcessorStep(device=config.device), - NormalizerProcessorStep( - features={**config.input_features, **config.output_features}, - norm_map=config.normalization_mapping, - stats=dataset_stats, - ), - ] - output_steps = [ - UnnormalizerProcessorStep( - features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats - ), - DeviceProcessorStep(device="cpu"), - ] - return ( - PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( - steps=input_steps, - name=POLICY_PREPROCESSOR_DEFAULT_NAME, - ), - PolicyProcessorPipeline[PolicyAction, PolicyAction]( - steps=output_steps, - name=POLICY_POSTPROCESSOR_DEFAULT_NAME, - to_transition=policy_action_to_transition, - to_output=transition_to_policy_action, - ), - ) diff --git a/src/lerobot/templates/lerobot_modelcard_template.md b/src/lerobot/templates/lerobot_modelcard_template.md index 34af282b0..c59cf4183 100644 --- a/src/lerobot/templates/lerobot_modelcard_template.md +++ b/src/lerobot/templates/lerobot_modelcard_template.md @@ -19,8 +19,6 @@ [Diffusion Policy](https://huggingface.co/papers/2303.04137) treats visuomotor control as a generative diffusion process, producing smooth, multi-step action trajectories that excel at contact-rich manipulation. {% elif model_name == "vqbet" %} [VQ-BET](https://huggingface.co/papers/2403.03181) combines vector-quantised action tokens with Behaviour Transformers to discretise control and achieve data-efficient imitation across diverse skills. -{% elif model_name == "pi0fast" %} -[Pi0-Fast](https://huggingface.co/papers/2501.09747) is a variant of Pi0 that uses a new tokenization method called FAST, which enables training of an autoregressive vision-language-action policy for high-frequency robotic tasks with improved performance and reduced training time. {% elif model_name == "pi0" %} **π₀ (Pi0)** From 8e940bf361d8534414ac10f356e9ad5101d4ba2e Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Tue, 14 Oct 2025 16:19:50 +0200 Subject: [PATCH 16/24] Feat/expand add features (#2202) * make add_feature take multiple features at a time and rename to add_features * - New function: modify_features that was a combination of remove features and add features. - This function is important for when we want to add a feature and remove another so we can do it in one time to avoid copying and creating the dataset multiple times --- examples/dataset/use_dataset_tools.py | 59 +++--- src/lerobot/datasets/dataset_tools.py | 234 +++++++++++++++-------- tests/datasets/test_dataset_tools.py | 262 ++++++++++++++++++++------ 3 files changed, 395 insertions(+), 160 deletions(-) diff --git a/examples/dataset/use_dataset_tools.py b/examples/dataset/use_dataset_tools.py index 244259872..bd7c389bc 100644 --- a/examples/dataset/use_dataset_tools.py +++ b/examples/dataset/use_dataset_tools.py @@ -30,9 +30,10 @@ Usage: import numpy as np from lerobot.datasets.dataset_tools import ( - add_feature, + add_features, delete_episodes, merge_datasets, + modify_features, remove_feature, split_dataset, ) @@ -57,50 +58,56 @@ def main(): print(f"Train split: {splits['train'].meta.total_episodes} episodes") print(f"Val split: {splits['val'].meta.total_episodes} episodes") - print("\n3. Adding a reward feature...") + print("\n3. 
Adding features...") reward_values = np.random.randn(dataset.meta.total_frames).astype(np.float32) - dataset_with_reward = add_feature( - dataset, - feature_name="reward", - feature_values=reward_values, - feature_info={ - "dtype": "float32", - "shape": (1,), - "names": None, - }, - repo_id="lerobot/pusht_with_reward", - ) def compute_success(row_dict, episode_index, frame_index): episode_length = 10 return float(frame_index >= episode_length - 10) - dataset_with_success = add_feature( - dataset_with_reward, - feature_name="success", - feature_values=compute_success, - feature_info={ - "dtype": "float32", - "shape": (1,), - "names": None, + dataset_with_features = add_features( + dataset, + features={ + "reward": ( + reward_values, + {"dtype": "float32", "shape": (1,), "names": None}, + ), + "success": ( + compute_success, + {"dtype": "float32", "shape": (1,), "names": None}, + ), }, - repo_id="lerobot/pusht_with_reward_and_success", + repo_id="lerobot/pusht_with_features", ) - print(f"New features: {list(dataset_with_success.meta.features.keys())}") + print(f"New features: {list(dataset_with_features.meta.features.keys())}") print("\n4. Removing the success feature...") dataset_cleaned = remove_feature( - dataset_with_success, feature_names="success", repo_id="lerobot/pusht_cleaned" + dataset_with_features, feature_names="success", repo_id="lerobot/pusht_cleaned" ) print(f"Features after removal: {list(dataset_cleaned.meta.features.keys())}") - print("\n5. Merging train and val splits back together...") + print("\n5. Using modify_features to add and remove features simultaneously...") + dataset_modified = modify_features( + dataset_with_features, + add_features={ + "discount": ( + np.ones(dataset.meta.total_frames, dtype=np.float32) * 0.99, + {"dtype": "float32", "shape": (1,), "names": None}, + ), + }, + remove_features="reward", + repo_id="lerobot/pusht_modified", + ) + print(f"Modified features: {list(dataset_modified.meta.features.keys())}") + + print("\n6. Merging train and val splits back together...") merged = merge_datasets([splits["train"], splits["val"]], output_repo_id="lerobot/pusht_merged") print(f"Merged dataset: {merged.meta.total_episodes} episodes") - print("\n6. Complex workflow example...") + print("\n7. Complex workflow example...") if len(dataset.meta.camera_keys) > 1: camera_to_remove = dataset.meta.camera_keys[0] diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py index 8ebc4a59d..2735ba0a0 100644 --- a/src/lerobot/datasets/dataset_tools.py +++ b/src/lerobot/datasets/dataset_tools.py @@ -28,8 +28,10 @@ import shutil from collections.abc import Callable from pathlib import Path +import datasets import numpy as np import pandas as pd +import pyarrow.parquet as pq import torch from tqdm import tqdm @@ -43,7 +45,6 @@ from lerobot.datasets.utils import ( DEFAULT_EPISODES_PATH, get_parquet_file_size_in_mb, load_episodes, - to_parquet_with_hf_images, update_chunk_file_indices, write_info, write_stats, @@ -268,39 +269,79 @@ def merge_datasets( return merged_dataset -def add_feature( +def modify_features( dataset: LeRobotDataset, - feature_name: str, - feature_values: np.ndarray | torch.Tensor | Callable, - feature_info: dict, + add_features: dict[str, tuple[np.ndarray | torch.Tensor | Callable, dict]] | None = None, + remove_features: str | list[str] | None = None, output_dir: str | Path | None = None, repo_id: str | None = None, ) -> LeRobotDataset: - """Add a new feature to a LeRobotDataset. 
+ """Modify a LeRobotDataset by adding and/or removing features in a single pass. + + This is the most efficient way to modify features, as it only copies the dataset once + regardless of how many features are being added or removed. Args: dataset: The source LeRobotDataset. - feature_name: Name of the new feature. - feature_values: Either: - - Array/tensor of shape (num_frames, ...) with values for each frame - - Callable that takes (frame_dict, episode_index, frame_index) and returns feature value - feature_info: Dictionary with feature metadata (dtype, shape, names). + add_features: Optional dict mapping feature names to (feature_values, feature_info) tuples. + remove_features: Optional feature name(s) to remove. Can be a single string or list. output_dir: Directory to save the new dataset. If None, uses default location. repo_id: Repository ID for the new dataset. If None, appends "_modified" to original. + + Returns: + New dataset with features modified. + + Example: + new_dataset = modify_features( + dataset, + add_features={ + "reward": (reward_array, {"dtype": "float32", "shape": [1], "names": None}), + }, + remove_features=["old_feature"], + output_dir="./output", + ) """ - if feature_name in dataset.meta.features: - raise ValueError(f"Feature '{feature_name}' already exists in dataset") + if add_features is None and remove_features is None: + raise ValueError("Must specify at least one of add_features or remove_features") + + remove_features_list: list[str] = [] + if remove_features is not None: + remove_features_list = [remove_features] if isinstance(remove_features, str) else remove_features + + if add_features: + required_keys = {"dtype", "shape"} + for feature_name, (_, feature_info) in add_features.items(): + if feature_name in dataset.meta.features: + raise ValueError(f"Feature '{feature_name}' already exists in dataset") + + if not required_keys.issubset(feature_info.keys()): + raise ValueError(f"feature_info for '{feature_name}' must contain keys: {required_keys}") + + if remove_features_list: + for name in remove_features_list: + if name not in dataset.meta.features: + raise ValueError(f"Feature '{name}' not found in dataset") + + required_features = {"timestamp", "frame_index", "episode_index", "index", "task_index"} + if any(name in required_features for name in remove_features_list): + raise ValueError(f"Cannot remove required features: {required_features}") if repo_id is None: repo_id = f"{dataset.repo_id}_modified" output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id - required_keys = {"dtype", "shape"} - if not required_keys.issubset(feature_info.keys()): - raise ValueError(f"feature_info must contain keys: {required_keys}") - new_features = dataset.meta.features.copy() - new_features[feature_name] = feature_info + + if remove_features_list: + for name in remove_features_list: + new_features.pop(name, None) + + if add_features: + for feature_name, (_, feature_info) in add_features.items(): + new_features[feature_name] = feature_info + + video_keys_to_remove = [name for name in remove_features_list if name in dataset.meta.video_keys] + remaining_video_keys = [k for k in dataset.meta.video_keys if k not in video_keys_to_remove] new_meta = LeRobotDatasetMetadata.create( repo_id=repo_id, @@ -308,17 +349,18 @@ def add_feature( features=new_features, robot_type=dataset.meta.robot_type, root=output_dir, - use_videos=len(dataset.meta.video_keys) > 0, + use_videos=len(remaining_video_keys) > 0, ) _copy_data_with_feature_changes( 
dataset=dataset, new_meta=new_meta, - add_features={feature_name: (feature_values, feature_info)}, + add_features=add_features, + remove_features=remove_features_list if remove_features_list else None, ) - if dataset.meta.video_keys: - _copy_videos(dataset, new_meta) + if new_meta.video_keys: + _copy_videos(dataset, new_meta, exclude_keys=video_keys_to_remove if video_keys_to_remove else None) new_dataset = LeRobotDataset( repo_id=repo_id, @@ -331,6 +373,46 @@ def add_feature( return new_dataset +def add_features( + dataset: LeRobotDataset, + features: dict[str, tuple[np.ndarray | torch.Tensor | Callable, dict]], + output_dir: str | Path | None = None, + repo_id: str | None = None, +) -> LeRobotDataset: + """Add multiple features to a LeRobotDataset in a single pass. + + This is more efficient than calling add_feature() multiple times, as it only + copies the dataset once regardless of how many features are being added. + + Args: + dataset: The source LeRobotDataset. + features: Dictionary mapping feature names to (feature_values, feature_info) tuples. + output_dir: Directory to save the new dataset. If None, uses default location. + repo_id: Repository ID for the new dataset. If None, appends "_modified" to original. + + Returns: + New dataset with all features added. + + Example: + features = { + "task_embedding": (task_emb_array, {"dtype": "float32", "shape": [384], "names": None}), + "cam1_embedding": (cam1_emb_array, {"dtype": "float32", "shape": [768], "names": None}), + "cam2_embedding": (cam2_emb_array, {"dtype": "float32", "shape": [768], "names": None}), + } + new_dataset = add_features(dataset, features, output_dir="./output", repo_id="my_dataset") + """ + if not features: + raise ValueError("No features provided") + + return modify_features( + dataset=dataset, + add_features=features, + remove_features=None, + output_dir=output_dir, + repo_id=repo_id, + ) + + def remove_feature( dataset: LeRobotDataset, feature_names: str | list[str], @@ -345,56 +427,17 @@ def remove_feature( output_dir: Directory to save the new dataset. If None, uses default location. repo_id: Repository ID for the new dataset. If None, appends "_modified" to original. + Returns: + New dataset with features removed. 
""" - if isinstance(feature_names, str): - feature_names = [feature_names] - - for name in feature_names: - if name not in dataset.meta.features: - raise ValueError(f"Feature '{name}' not found in dataset") - - required_features = {"timestamp", "frame_index", "episode_index", "index", "task_index"} - if any(name in required_features for name in feature_names): - raise ValueError(f"Cannot remove required features: {required_features}") - - if repo_id is None: - repo_id = f"{dataset.repo_id}_modified" - output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id - - new_features = {k: v for k, v in dataset.meta.features.items() if k not in feature_names} - - video_keys_to_remove = [name for name in feature_names if name in dataset.meta.video_keys] - - remaining_video_keys = [k for k in dataset.meta.video_keys if k not in video_keys_to_remove] - - new_meta = LeRobotDatasetMetadata.create( - repo_id=repo_id, - fps=dataset.meta.fps, - features=new_features, - robot_type=dataset.meta.robot_type, - root=output_dir, - use_videos=len(remaining_video_keys) > 0, - ) - - _copy_data_with_feature_changes( + return modify_features( dataset=dataset, - new_meta=new_meta, + add_features=None, remove_features=feature_names, - ) - - if new_meta.video_keys: - _copy_videos(dataset, new_meta, exclude_keys=video_keys_to_remove) - - new_dataset = LeRobotDataset( + output_dir=output_dir, repo_id=repo_id, - root=output_dir, - image_transforms=dataset.image_transforms, - delta_timestamps=dataset.delta_timestamps, - tolerance_s=dataset.tolerance_s, ) - return new_dataset - def _fractions_to_episode_indices( total_episodes: int, @@ -501,10 +544,7 @@ def _copy_and_reindex_data( dst_path = dst_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx) dst_path.parent.mkdir(parents=True, exist_ok=True) - if len(dst_meta.image_keys) > 0: - to_parquet_with_hf_images(df, dst_path) - else: - df.to_parquet(dst_path, index=False) + _write_parquet(df, dst_path, dst_meta) for ep_old_idx in episodes_to_keep: ep_new_idx = episode_mapping[ep_old_idx] @@ -862,6 +902,25 @@ def _copy_and_reindex_episodes_metadata( write_stats(filtered_stats, dst_meta.root) +def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) -> None: + """Write DataFrame to parquet + + This ensures images are properly embedded and the file can be loaded correctly by HF datasets. 
+ """ + from lerobot.datasets.utils import embed_images, get_hf_features_from_features + + hf_features = get_hf_features_from_features(meta.features) + ep_dataset = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=hf_features, split="train") + + if len(meta.image_keys) > 0: + ep_dataset = embed_images(ep_dataset) + + table = ep_dataset.with_format("arrow")[:] + writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy", use_dictionary=True) + writer.write_table(table) + writer.close() + + def _save_data_chunk( df: pd.DataFrame, meta: LeRobotDatasetMetadata, @@ -877,10 +936,7 @@ def _save_data_chunk( path = meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx) path.parent.mkdir(parents=True, exist_ok=True) - if len(meta.image_keys) > 0: - to_parquet_with_hf_images(df, path) - else: - df.to_parquet(path, index=False) + _write_parquet(df, path, meta) episode_metadata = {} for ep_idx in df["episode_index"].unique(): @@ -906,19 +962,34 @@ def _copy_data_with_feature_changes( remove_features: list[str] | None = None, ) -> None: """Copy data while adding or removing features.""" - file_paths = set() + if dataset.meta.episodes is None: + dataset.meta.episodes = load_episodes(dataset.meta.root) + + # Map file paths to episode indices to extract chunk/file indices + file_to_episodes: dict[Path, set[int]] = {} for ep_idx in range(dataset.meta.total_episodes): - file_paths.add(dataset.meta.get_data_file_path(ep_idx)) + file_path = dataset.meta.get_data_file_path(ep_idx) + if file_path not in file_to_episodes: + file_to_episodes[file_path] = set() + file_to_episodes[file_path].add(ep_idx) frame_idx = 0 - for src_path in tqdm(sorted(file_paths), desc="Processing data files"): + for src_path in tqdm(sorted(file_to_episodes.keys()), desc="Processing data files"): df = pd.read_parquet(dataset.root / src_path).reset_index(drop=True) + # Get chunk_idx and file_idx from the source file's first episode + episodes_in_file = file_to_episodes[src_path] + first_ep_idx = min(episodes_in_file) + src_ep = dataset.meta.episodes[first_ep_idx] + chunk_idx = src_ep["data/chunk_index"] + file_idx = src_ep["data/file_index"] + if remove_features: df = df.drop(columns=remove_features, errors="ignore") if add_features: + end_idx = frame_idx + len(df) for feature_name, (values, _) in add_features.items(): if callable(values): feature_values = [] @@ -931,15 +1002,18 @@ def _copy_data_with_feature_changes( feature_values.append(value) df[feature_name] = feature_values else: - end_idx = frame_idx + len(df) feature_slice = values[frame_idx:end_idx] if len(feature_slice.shape) > 1 and feature_slice.shape[1] == 1: df[feature_name] = feature_slice.flatten() else: df[feature_name] = feature_slice - frame_idx = end_idx + frame_idx = end_idx - _save_data_chunk(df, new_meta) + # Write using the preserved chunk_idx and file_idx from source + dst_path = new_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + dst_path.parent.mkdir(parents=True, exist_ok=True) + + _write_parquet(df, dst_path, new_meta) _copy_episodes_metadata_and_stats(dataset, new_meta) diff --git a/tests/datasets/test_dataset_tools.py b/tests/datasets/test_dataset_tools.py index a9c04d6f2..8bc1dbf6b 100644 --- a/tests/datasets/test_dataset_tools.py +++ b/tests/datasets/test_dataset_tools.py @@ -22,9 +22,10 @@ import pytest import torch from lerobot.datasets.dataset_tools import ( - add_feature, + add_features, delete_episodes, merge_datasets, + modify_features, remove_feature, 
split_dataset, ) @@ -292,7 +293,7 @@ def test_merge_empty_list(tmp_path): merge_datasets([], output_repo_id="merged", output_dir=tmp_path) -def test_add_feature_with_values(sample_dataset, tmp_path): +def test_add_features_with_values(sample_dataset, tmp_path): """Test adding a feature with pre-computed values.""" num_frames = sample_dataset.meta.total_frames reward_values = np.random.randn(num_frames, 1).astype(np.float32) @@ -302,6 +303,9 @@ def test_add_feature_with_values(sample_dataset, tmp_path): "shape": (1,), "names": None, } + features = { + "reward": (reward_values, feature_info), + } with ( patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, @@ -310,11 +314,9 @@ def test_add_feature_with_values(sample_dataset, tmp_path): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "with_reward") - new_dataset = add_feature( - sample_dataset, - feature_name="reward", - feature_values=reward_values, - feature_info=feature_info, + new_dataset = add_features( + dataset=sample_dataset, + features=features, output_dir=tmp_path / "with_reward", ) @@ -327,7 +329,7 @@ def test_add_feature_with_values(sample_dataset, tmp_path): assert isinstance(sample_item["reward"], torch.Tensor) -def test_add_feature_with_callable(sample_dataset, tmp_path): +def test_add_features_with_callable(sample_dataset, tmp_path): """Test adding a feature with a callable.""" def compute_reward(frame_dict, episode_idx, frame_idx): @@ -338,7 +340,9 @@ def test_add_feature_with_callable(sample_dataset, tmp_path): "shape": (1,), "names": None, } - + features = { + "reward": (compute_reward, feature_info), + } with ( patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, @@ -346,11 +350,9 @@ def test_add_feature_with_callable(sample_dataset, tmp_path): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "with_reward") - new_dataset = add_feature( - sample_dataset, - feature_name="reward", - feature_values=compute_reward, - feature_info=feature_info, + new_dataset = add_features( + dataset=sample_dataset, + features=features, output_dir=tmp_path / "with_reward", ) @@ -368,31 +370,88 @@ def test_add_feature_with_callable(sample_dataset, tmp_path): def test_add_existing_feature(sample_dataset, tmp_path): """Test error when adding an existing feature.""" feature_info = {"dtype": "float32", "shape": (1,)} + features = { + "action": (np.zeros(50), feature_info), + } with pytest.raises(ValueError, match="Feature 'action' already exists"): - add_feature( - sample_dataset, - feature_name="action", - feature_values=np.zeros(50), - feature_info=feature_info, + add_features( + dataset=sample_dataset, + features=features, output_dir=tmp_path / "modified", ) def test_add_feature_invalid_info(sample_dataset, tmp_path): """Test error with invalid feature info.""" - with pytest.raises(ValueError, match="feature_info must contain keys"): - add_feature( - sample_dataset, - feature_name="reward", - feature_values=np.zeros(50), - feature_info={"dtype": "float32"}, + with pytest.raises(ValueError, match="feature_info for 'reward' must contain keys"): + add_features( + dataset=sample_dataset, + features={ + "reward": (np.zeros(50), {"dtype": "float32"}), + }, output_dir=tmp_path / "modified", ) -def test_remove_single_feature(sample_dataset, tmp_path): - """Test removing a single feature.""" +def 
test_modify_features_add_and_remove(sample_dataset, tmp_path): + """Test modifying features by adding and removing simultaneously.""" + feature_info = {"dtype": "float32", "shape": (1,), "names": None} + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "modified") + + # First add a feature we'll later remove + dataset_with_reward = add_features( + sample_dataset, + features={"reward": (np.random.randn(50, 1).astype(np.float32), feature_info)}, + output_dir=tmp_path / "with_reward", + ) + + # Now use modify_features to add "success" and remove "reward" in one pass + modified_dataset = modify_features( + dataset_with_reward, + add_features={ + "success": (np.random.randn(50, 1).astype(np.float32), feature_info), + }, + remove_features="reward", + output_dir=tmp_path / "modified", + ) + + assert "success" in modified_dataset.meta.features + assert "reward" not in modified_dataset.meta.features + assert len(modified_dataset) == 50 + + +def test_modify_features_only_add(sample_dataset, tmp_path): + """Test that modify_features works with only add_features.""" + feature_info = {"dtype": "float32", "shape": (1,), "names": None} + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "modified") + + modified_dataset = modify_features( + sample_dataset, + add_features={ + "reward": (np.random.randn(50, 1).astype(np.float32), feature_info), + }, + output_dir=tmp_path / "modified", + ) + + assert "reward" in modified_dataset.meta.features + assert len(modified_dataset) == 50 + + +def test_modify_features_only_remove(sample_dataset, tmp_path): + """Test that modify_features works with only remove_features.""" feature_info = {"dtype": "float32", "shape": (1,), "names": None} with ( @@ -402,11 +461,46 @@ def test_remove_single_feature(sample_dataset, tmp_path): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) - dataset_with_reward = add_feature( + dataset_with_reward = add_features( sample_dataset, - feature_name="reward", - feature_values=np.random.randn(50, 1).astype(np.float32), - feature_info=feature_info, + features={"reward": (np.random.randn(50, 1).astype(np.float32), feature_info)}, + output_dir=tmp_path / "with_reward", + ) + + modified_dataset = modify_features( + dataset_with_reward, + remove_features="reward", + output_dir=tmp_path / "modified", + ) + + assert "reward" not in modified_dataset.meta.features + + +def test_modify_features_no_changes(sample_dataset, tmp_path): + """Test error when modify_features is called with no changes.""" + with pytest.raises(ValueError, match="Must specify at least one of add_features or remove_features"): + modify_features( + sample_dataset, + output_dir=tmp_path / "modified", + ) + + +def test_remove_single_feature(sample_dataset, tmp_path): + """Test removing a single feature.""" + feature_info = {"dtype": "float32", "shape": (1,), "names": None} + features = { + "reward": (np.random.randn(50, 1).astype(np.float32), feature_info), + } + with ( + 
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) + + dataset_with_reward = add_features( + dataset=sample_dataset, + features=features, output_dir=tmp_path / "with_reward", ) @@ -432,20 +526,19 @@ def test_remove_multiple_features(sample_dataset, tmp_path): mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) dataset = sample_dataset + features = {} for feature_name in ["reward", "success"]: feature_info = {"dtype": "float32", "shape": (1,), "names": None} - dataset = add_feature( - dataset, - feature_name=feature_name, - feature_values=np.random.randn(dataset.meta.total_frames, 1).astype(np.float32), - feature_info=feature_info, - output_dir=tmp_path / f"with_{feature_name}", + features[feature_name] = ( + np.random.randn(dataset.meta.total_frames, 1).astype(np.float32), + feature_info, ) + dataset_with_features = add_features( + dataset, features=features, output_dir=tmp_path / "with_features" + ) dataset_clean = remove_feature( - dataset, - feature_names=["reward", "success"], - output_dir=tmp_path / "clean", + dataset_with_features, feature_names=["reward", "success"], output_dir=tmp_path / "clean" ) assert "reward" not in dataset_clean.meta.features @@ -509,11 +602,14 @@ def test_complex_workflow_integration(sample_dataset, tmp_path): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) - dataset = add_feature( + dataset = add_features( sample_dataset, - feature_name="reward", - feature_values=np.random.randn(50, 1).astype(np.float32), - feature_info={"dtype": "float32", "shape": (1,), "names": None}, + features={ + "reward": ( + np.random.randn(50, 1).astype(np.float32), + {"dtype": "float32", "shape": (1,), "names": None}, + ) + }, output_dir=tmp_path / "step1", ) @@ -753,7 +849,7 @@ def test_merge_preserves_stats(sample_dataset, tmp_path, empty_lerobot_dataset_f assert "std" in merged.meta.stats[feature] -def test_add_feature_preserves_existing_stats(sample_dataset, tmp_path): +def test_add_features_preserves_existing_stats(sample_dataset, tmp_path): """Test that adding a feature preserves existing stats.""" num_frames = sample_dataset.meta.total_frames reward_values = np.random.randn(num_frames, 1).astype(np.float32) @@ -763,6 +859,9 @@ def test_add_feature_preserves_existing_stats(sample_dataset, tmp_path): "shape": (1,), "names": None, } + features = { + "reward": (reward_values, feature_info), + } with ( patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, @@ -771,11 +870,9 @@ def test_add_feature_preserves_existing_stats(sample_dataset, tmp_path): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "with_reward") - new_dataset = add_feature( - sample_dataset, - feature_name="reward", - feature_values=reward_values, - feature_info=feature_info, + new_dataset = add_features( + dataset=sample_dataset, + features=features, output_dir=tmp_path / "with_reward", ) @@ -797,11 +894,11 @@ def test_remove_feature_updates_stats(sample_dataset, tmp_path): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) - 
dataset_with_reward = add_feature( + dataset_with_reward = add_features( sample_dataset, - feature_name="reward", - feature_values=np.random.randn(50, 1).astype(np.float32), - feature_info=feature_info, + features={ + "reward": (np.random.randn(50, 1).astype(np.float32), feature_info), + }, output_dir=tmp_path / "with_reward", ) @@ -893,3 +990,60 @@ def test_split_all_episodes_assigned(sample_dataset, tmp_path): total_episodes = sum(ds.meta.total_episodes for ds in result.values()) assert total_episodes == sample_dataset.meta.total_episodes + + +def test_modify_features_preserves_file_structure(sample_dataset, tmp_path): + """Test that modifying features preserves chunk_idx and file_idx from source dataset.""" + feature_info = {"dtype": "float32", "shape": (1,), "names": None} + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + + def mock_snapshot(repo_id, **kwargs): + return str(kwargs.get("local_dir", tmp_path / repo_id.split("/")[-1])) + + mock_snapshot_download.side_effect = mock_snapshot + + # First split the dataset to create a non-zero starting chunk/file structure + splits = split_dataset( + sample_dataset, + splits={"train": [0, 1, 2], "val": [3, 4]}, + output_dir=tmp_path / "splits", + ) + + train_dataset = splits["train"] + + # Get original chunk/file indices from first episode + if train_dataset.meta.episodes is None: + from lerobot.datasets.utils import load_episodes + + train_dataset.meta.episodes = load_episodes(train_dataset.meta.root) + original_chunk_indices = [ep["data/chunk_index"] for ep in train_dataset.meta.episodes] + original_file_indices = [ep["data/file_index"] for ep in train_dataset.meta.episodes] + + # Now add a feature to the split dataset + modified_dataset = add_features( + train_dataset, + features={ + "reward": ( + np.random.randn(train_dataset.meta.total_frames, 1).astype(np.float32), + feature_info, + ), + }, + output_dir=tmp_path / "modified", + ) + + # Check that chunk/file indices are preserved + if modified_dataset.meta.episodes is None: + from lerobot.datasets.utils import load_episodes + + modified_dataset.meta.episodes = load_episodes(modified_dataset.meta.root) + new_chunk_indices = [ep["data/chunk_index"] for ep in modified_dataset.meta.episodes] + new_file_indices = [ep["data/file_index"] for ep in modified_dataset.meta.episodes] + + assert new_chunk_indices == original_chunk_indices, "Chunk indices should be preserved" + assert new_file_indices == original_file_indices, "File indices should be preserved" + assert "reward" in modified_dataset.meta.features From 271d92dcaae084a0ab48763e7a2efd37b9a27fe6 Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Tue, 14 Oct 2025 17:21:18 +0200 Subject: [PATCH 17/24] feat(sim): add metaworld env (#2088) * add metaworld * smol update Signed-off-by: Jade Choghari * update design * Update src/lerobot/envs/metaworld.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Jade Choghari * update * small changes * iterate on review * small fix * small fix * add docs * update doc * add better gif * smol doc fix * updage gymnasium * add note * depreciate gym-xarm * more changes * update doc * comply with mypy * more fixes * update readme * precommit * update pusht * add pusht instead * changes * style * add changes * update * revert * update v2 * chore(envs): move metaworld config to its own file + 
remove comments + simplify _format_raw_obs (#2200) * update final changes --------- Signed-off-by: Jade Choghari Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Steven Palma --- CONTRIBUTING.md | 1 - Makefile | 10 +- docs/source/_toctree.yml | 10 +- docs/source/installation.mdx | 2 +- docs/source/libero.mdx | 2 +- docs/source/metaworld.mdx | 80 +++++++ pyproject.toml | 11 +- src/lerobot/__init__.py | 12 - src/lerobot/envs/__init__.py | 2 +- src/lerobot/envs/configs.py | 81 ++++--- src/lerobot/envs/factory.py | 17 +- src/lerobot/envs/libero.py | 26 +- src/lerobot/envs/metaworld.py | 313 +++++++++++++++++++++++++ src/lerobot/envs/metaworld_config.json | 121 ++++++++++ src/lerobot/scripts/lerobot_eval.py | 10 +- tests/policies/test_policies.py | 3 - 16 files changed, 612 insertions(+), 89 deletions(-) create mode 100644 docs/source/metaworld.mdx create mode 100644 src/lerobot/envs/metaworld.py create mode 100644 src/lerobot/envs/metaworld_config.json diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 369af602b..a07596728 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -72,7 +72,6 @@ post it. Look at our implementations for [datasets](./src/lerobot/datasets/), [policies](./src/lerobot/policies/), environments ([aloha](https://github.com/huggingface/gym-aloha), -[xarm](https://github.com/huggingface/gym-xarm), [pusht](https://github.com/huggingface/gym-pusht)) and follow the same api design. diff --git a/Makefile b/Makefile index fbe8a5bae..e02f02403 100644 --- a/Makefile +++ b/Makefile @@ -119,10 +119,9 @@ test-tdmpc-ete-train: --policy.type=tdmpc \ --policy.device=$(DEVICE) \ --policy.push_to_hub=false \ - --env.type=xarm \ - --env.task=XarmLift-v0 \ + --env.type=pusht \ --env.episode_length=5 \ - --dataset.repo_id=lerobot/xarm_lift_medium \ + --dataset.repo_id=lerobot/pusht_image \ --dataset.image_transforms.enable=true \ --dataset.episodes="[0]" \ --batch_size=2 \ @@ -140,9 +139,10 @@ test-tdmpc-ete-eval: lerobot-eval \ --policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \ --policy.device=$(DEVICE) \ - --env.type=xarm \ + --env.type=pusht \ --env.episode_length=5 \ - --env.task=XarmLift-v0 \ + --env.observation_height=96 \ + --env.observation_width=96 \ --eval.n_episodes=1 \ --eval.batch_size=1 diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 568bd6380..b7e71e010 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -7,8 +7,6 @@ - sections: - local: il_robots title: Imitation Learning for Robots - - local: il_sim - title: Imitation Learning in Sim - local: cameras title: Cameras - local: integrate_hardware @@ -37,9 +35,15 @@ title: π₀ (Pi0) - local: pi05 title: π₀.₅ (Pi05) + title: "Policies" +- sections: + - local: il_sim + title: Imitation Learning in Sim - local: libero title: Using Libero - title: "Policies" + - local: metaworld + title: Using MetaWorld + title: "Simulation" - sections: - local: introduction_processors title: Introduction to Robot Processors diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 93354c2ee..f5fd09acd 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -91,7 +91,7 @@ LeRobot provides optional extras for specific functionalities. 
Multiple extras c ### Simulations -Install environment packages: `aloha` ([gym-aloha](https://github.com/huggingface/gym-aloha)), `xarm` ([gym-xarm](https://github.com/huggingface/gym-xarm)), or `pusht` ([gym-pusht](https://github.com/huggingface/gym-pusht)) +Install environment packages: `aloha` ([gym-aloha](https://github.com/huggingface/gym-aloha)), or `pusht` ([gym-pusht](https://github.com/huggingface/gym-pusht)) Example: ```bash diff --git a/docs/source/libero.mdx b/docs/source/libero.mdx index 3f2b92406..14f51ef3b 100644 --- a/docs/source/libero.mdx +++ b/docs/source/libero.mdx @@ -137,7 +137,7 @@ The finetuned model can be found here: We then evaluate the finetuned model using the LeRobot LIBERO implementation, by running the following command: ```bash -python src/lerobot/scripts/eval.py \ +lerobot-eval \ --output_dir=/logs/ \ --env.type=libero \ --env.task=libero_spatial,libero_object,libero_goal,libero_10 \ diff --git a/docs/source/metaworld.mdx b/docs/source/metaworld.mdx new file mode 100644 index 000000000..da90bd51d --- /dev/null +++ b/docs/source/metaworld.mdx @@ -0,0 +1,80 @@ +# Meta-World + +Meta-World is a well-designed, open-source simulation benchmark for multi-task and meta reinforcement learning in continuous-control robotic manipulation. It gives researchers a shared, realistic playground to test whether algorithms can _learn many different tasks_ and _generalize quickly to new ones_ — two central challenges for real-world robotics. + +- 📄 [MetaWorld paper](https://arxiv.org/pdf/1910.10897) +- 💻 [Original MetaWorld repo](https://github.com/Farama-Foundation/Metaworld) + +![MetaWorld MT10 demo](https://meta-world.github.io/figures/ml45.gif) + +## Why Meta-World matters + +- **Diverse, realistic tasks.** Meta-World bundles a large suite of simulated manipulation tasks (50 in the MT50 suite) using everyday objects and a common tabletop Sawyer arm. This diversity exposes algorithms to a wide variety of dynamics, contacts and goal specifications while keeping a consistent control and observation structure. +- **Focus on generalization and multi-task learning.** By evaluating across task distributions that share structure but differ in goals and objects, Meta-World reveals whether an agent truly learns transferable skills rather than overfitting to a narrow task. +- **Standardized evaluation protocol.** It provides clear evaluation modes and difficulty splits, so different methods can be compared fairly across easy, medium, hard and very-hard regimes. +- **Empirical insight.** Past evaluations on Meta-World show impressive progress on some fronts, but also highlight that current multi-task and meta-RL methods still struggle with large, diverse task sets. That gap points to important research directions. + +## What it enables in LeRobot + +In LeRobot, you can evaluate any policy or vision-language-action (VLA) model on Meta-World tasks and get a clear success-rate measure. The integration is designed to be straightforward: + +- We provide a LeRobot-ready dataset for Meta-World (MT50) on the HF Hub: `https://huggingface.co/datasets/lerobot/metaworld_mt50`. + - This dataset is formatted for the MT50 evaluation that uses all 50 tasks (the most challenging multi-task setting). + - MT50 gives the policy a one-hot task vector and uses fixed object/goal positions for consistency. + +- Task descriptions and the exact keys required for evaluation are available in the repo/dataset — use these to ensure your policy outputs the right success signals. 
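The one-hot task conditioning mentioned above can be constructed directly from the `TASK_NAME_TO_ID` mapping this patch ships in `metaworld_config.json` and exposes from `lerobot.envs.metaworld`. A minimal sketch (illustration only; how a given policy actually consumes the vector depends on its own config):

```python
# Builds an MT50-style one-hot task vector from the mapping added by this patch.
import numpy as np

from lerobot.envs.metaworld import TASK_NAME_TO_ID  # loaded from metaworld_config.json

one_hot = np.zeros(len(TASK_NAME_TO_ID), dtype=np.float32)  # 50 MT50 tasks
one_hot[TASK_NAME_TO_ID["door-open-v3"]] = 1.0
```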
+ +## Quick start, train a SmolVLA policy on Meta-World + +Example command to train a SmolVLA policy on a subset of tasks: + +```bash +lerobot-train \ + --policy.type=smolvla \ + --policy.repo_id=${HF_USER}/metaworld-test \ + --policy.load_vlm_weights=true \ + --dataset.repo_id=lerobot/metaworld_mt50 \ + --env.type=metaworld \ + --env.task=assembly-v3,dial-turn-v3,handle-press-side-v3 \ + --output_dir=./outputs/ \ + --steps=100000 \ + --batch_size=4 \ + --eval.batch_size=1 \ + --eval.n_episodes=1 \ + --eval_freq=1000 +``` + +Notes: + +- `--env.task` accepts explicit task lists (comma separated) or difficulty groups (e.g., `env.task="hard"`). +- Adjust `batch_size`, `steps`, and `eval_freq` to match your compute budget. +- **Gymnasium Assertion Error**: if you encounter an error like + `AssertionError: ['human', 'rgb_array', 'depth_array']` when running MetaWorld environments, this comes from a mismatch between MetaWorld and your Gymnasium version. + We recommend using: + +```bash + pip install "gymnasium==1.1.0" +``` + +to ensure proper compatibility. + +## Quick start — evaluate a trained policy + +To evaluate a trained policy on the Meta-World medium difficulty split: + +```bash +lerobot-eval \ + --policy.path="your-policy-id" \ + --env.type=metaworld \ + --env.task=medium \ + --eval.batch_size=1 \ + --eval.n_episodes=2 +``` + +This will run episodes and return per-task success rates using the standard Meta-World evaluation keys. + +## Practical tips + +- If you care about generalization, run on the full MT50 suite — it’s intentionally challenging and reveals strengths/weaknesses better than a few narrow tasks. +- Use the one-hot task conditioning for multi-task training (MT10 / MT50 conventions) so policies have explicit task context. +- Inspect the dataset task descriptions and the `info["is_success"]` keys when writing post-processing or logging so your success metrics line up with the benchmark. 
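`lerobot-eval` is the supported entry point, but the same success signal can be inspected interactively. A minimal sketch, assuming the `metaworld` extra and a working MuJoCo renderer, that drives the `MetaworldEnv` wrapper added later in this patch with random actions and reads `info["is_success"]`:

```python
# Illustrative rollout with random actions; a trained policy would replace sample().
from lerobot.envs.metaworld import MetaworldEnv

env = MetaworldEnv(task="door-open-v3", obs_type="pixels_agent_pos")
obs, info = env.reset(seed=0)
for _ in range(env._max_episode_steps):
    action = env.action_space.sample()  # shape (4,), values in [-1, 1]
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        break
print("is_success:", info["is_success"])
env.close()
```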
diff --git a/pyproject.toml b/pyproject.toml index 44ca596b1..6d43c33df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,7 +80,7 @@ dependencies = [ "torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency "draccus==0.10.0", # TODO: Remove == - "gymnasium>=0.29.1,<1.0.0", # TODO: Bumb dependency + "gymnasium>=1.0.0", "rerun-sdk>=0.21.0,<0.23.0", # TODO: Bumb dependency # Support dependencies @@ -133,11 +133,10 @@ test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0 video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"] # Simulation -aloha = ["gym-aloha>=0.1.1,<0.2.0"] +aloha = ["gym-aloha>=0.1.2,<0.2.0"] pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead -xarm = ["gym-xarm>=0.1.1,<0.2.0"] -libero = ["lerobot[transformers-dep]", "libero @ git+https://github.com/huggingface/lerobot-libero.git@main#egg=libero"] - +libero = ["lerobot[transformers-dep]", "libero @ git+https://github.com/huggingface/lerobot-libero.git@upgrade-dep#egg=libero"] +metaworld = ["metaworld>=3.0.0"] # All all = [ @@ -157,9 +156,9 @@ all = [ "lerobot[video_benchmark]", "lerobot[aloha]", "lerobot[pusht]", - "lerobot[xarm]", "lerobot[phone]", "lerobot[libero]", + "lerobot[metaworld]", ] [project.scripts] diff --git a/src/lerobot/__init__.py b/src/lerobot/__init__.py index 9d3ed1893..eec574296 100644 --- a/src/lerobot/__init__.py +++ b/src/lerobot/__init__.py @@ -57,7 +57,6 @@ available_tasks_per_env = { "AlohaTransferCube-v0", ], "pusht": ["PushT-v0"], - "xarm": ["XarmLift-v0"], } available_envs = list(available_tasks_per_env.keys()) @@ -75,16 +74,6 @@ available_datasets_per_env = { # TODO(alexander-soare): Add "lerobot/pusht_keypoints". Right now we can't because this is too tightly # coupled with tests. "pusht": ["lerobot/pusht", "lerobot/pusht_image"], - "xarm": [ - "lerobot/xarm_lift_medium", - "lerobot/xarm_lift_medium_replay", - "lerobot/xarm_push_medium", - "lerobot/xarm_push_medium_replay", - "lerobot/xarm_lift_medium_image", - "lerobot/xarm_lift_medium_replay_image", - "lerobot/xarm_push_medium_image", - "lerobot/xarm_push_medium_replay_image", - ], } available_real_world_datasets = [ @@ -195,7 +184,6 @@ available_motors = [ available_policies_per_env = { "aloha": ["act"], "pusht": ["diffusion", "vqbet"], - "xarm": ["tdmpc"], "koch_real": ["act_koch_real"], "aloha_real": ["act_aloha_real"], } diff --git a/src/lerobot/envs/__init__.py b/src/lerobot/envs/__init__.py index 4977d11d9..d767b6e8c 100644 --- a/src/lerobot/envs/__init__.py +++ b/src/lerobot/envs/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv # noqa: F401 +from .configs import AlohaEnv, EnvConfig, PushtEnv # noqa: F401 diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index 7a979b864..3aa155093 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -133,45 +133,6 @@ class PushtEnv(EnvConfig): } -@EnvConfig.register_subclass("xarm") -@dataclass -class XarmEnv(EnvConfig): - task: str | None = "XarmLift-v0" - fps: int = 15 - episode_length: int = 200 - obs_type: str = "pixels_agent_pos" - render_mode: str = "rgb_array" - visualization_width: int = 384 - visualization_height: int = 384 - features: dict[str, PolicyFeature] = field( - default_factory=lambda: { - ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,)), - "pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)), - } - ) - features_map: dict[str, str] = field( - default_factory=lambda: { - ACTION: ACTION, - "agent_pos": OBS_STATE, - "pixels": OBS_IMAGE, - } - ) - - def __post_init__(self): - if self.obs_type == "pixels_agent_pos": - self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,)) - - @property - def gym_kwargs(self) -> dict: - return { - "obs_type": self.obs_type, - "render_mode": self.render_mode, - "visualization_width": self.visualization_width, - "visualization_height": self.visualization_height, - "max_episode_steps": self.episode_length, - } - - @dataclass class ImagePreprocessingConfig: crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None @@ -306,3 +267,45 @@ class LiberoEnv(EnvConfig): "obs_type": self.obs_type, "render_mode": self.render_mode, } + + +@EnvConfig.register_subclass("metaworld") +@dataclass +class MetaworldEnv(EnvConfig): + task: str = "metaworld-push-v2" # add all tasks + fps: int = 80 + episode_length: int = 400 + obs_type: str = "pixels_agent_pos" + render_mode: str = "rgb_array" + multitask_eval: bool = True + features: dict[str, PolicyFeature] = field( + default_factory=lambda: { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)), + } + ) + features_map: dict[str, str] = field( + default_factory=lambda: { + "action": ACTION, + "agent_pos": OBS_STATE, + "top": f"{OBS_IMAGE}", + "pixels/top": f"{OBS_IMAGE}", + } + ) + + def __post_init__(self): + if self.obs_type == "pixels": + self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 480, 3)) + + elif self.obs_type == "pixels_agent_pos": + self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,)) + self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 480, 3)) + + else: + raise ValueError(f"Unsupported obs_type: {self.obs_type}") + + @property + def gym_kwargs(self) -> dict: + return { + "obs_type": self.obs_type, + "render_mode": self.render_mode, + } diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py index c27f01b65..059e0e11a 100644 --- a/src/lerobot/envs/factory.py +++ b/src/lerobot/envs/factory.py @@ -17,7 +17,7 @@ import importlib import gymnasium as gym -from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv, XarmEnv +from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv def make_env_config(env_type: str, **kwargs) -> EnvConfig: @@ -25,8 +25,6 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig: return AlohaEnv(**kwargs) elif env_type == "pusht": return PushtEnv(**kwargs) - elif env_type == "xarm": - return XarmEnv(**kwargs) elif env_type == "libero": return LiberoEnv(**kwargs) else: 
@@ -74,7 +72,18 @@ def make_env( gym_kwargs=cfg.gym_kwargs, env_cls=env_cls, ) + elif "metaworld" in cfg.type: + from lerobot.envs.metaworld import create_metaworld_envs + if cfg.task is None: + raise ValueError("MetaWorld requires a task to be specified") + + return create_metaworld_envs( + task=cfg.task, + n_envs=n_envs, + gym_kwargs=cfg.gym_kwargs, + env_cls=env_cls, + ) package_name = f"gym_{cfg.type}" try: importlib.import_module(package_name) @@ -87,7 +96,7 @@ def make_env( def _make_one(): return gym.make(gym_handle, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {})) - vec = env_cls([_make_one for _ in range(n_envs)]) + vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP) # normalize to {suite: {task_id: vec_env}} for consistency suite_name = cfg.type # e.g., "pusht", "aloha" diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py index 99ec6712f..94b08e991 100644 --- a/src/lerobot/envs/libero.py +++ b/src/lerobot/envs/libero.py @@ -260,19 +260,23 @@ class LiberoEnv(gym.Env): is_success = self._env.check_success() terminated = done or is_success - info["is_success"] = is_success - + info.update( + { + "task": self.task, + "task_id": self.task_id, + "done": done, + "is_success": is_success, + } + ) observation = self._format_raw_obs(raw_obs) - if done: + if terminated: + info["final_info"] = { + "task": self.task, + "task_id": self.task_id, + "done": bool(done), + "is_success": bool(is_success), + } self.reset() - info.update( - { - "task": self.task, - "task_id": self.task_id, - "done": done, - "is_success": is_success, - } - ) truncated = False return observation, reward, terminated, truncated, info diff --git a/src/lerobot/envs/metaworld.py b/src/lerobot/envs/metaworld.py new file mode 100644 index 000000000..9190f33ad --- /dev/null +++ b/src/lerobot/envs/metaworld.py @@ -0,0 +1,313 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from collections import defaultdict +from collections.abc import Callable, Sequence +from pathlib import Path +from typing import Any + +import gymnasium as gym +import metaworld +import metaworld.policies as policies +import numpy as np +from gymnasium import spaces + +# ---- Load configuration data from the external JSON file ---- +CONFIG_PATH = Path(__file__).parent / "metaworld_config.json" +try: + with open(CONFIG_PATH) as f: + data = json.load(f) +except FileNotFoundError as err: + raise FileNotFoundError( + "Could not find 'metaworld_config.json'. " + "Please ensure the configuration file is in the same directory as the script." + ) from err +except json.JSONDecodeError as err: + raise ValueError( + "Failed to decode 'metaworld_config.json'. Please ensure it is a valid JSON file." 
+ ) from err + +# ---- Process the loaded data ---- + +# extract and type-check top-level dicts +task_descriptions_obj = data.get("TASK_DESCRIPTIONS") +if not isinstance(task_descriptions_obj, dict): + raise TypeError("Expected TASK_DESCRIPTIONS to be a dict[str, str]") +TASK_DESCRIPTIONS: dict[str, str] = task_descriptions_obj + +task_name_to_id_obj = data.get("TASK_NAME_TO_ID") +if not isinstance(task_name_to_id_obj, dict): + raise TypeError("Expected TASK_NAME_TO_ID to be a dict[str, int]") +TASK_NAME_TO_ID: dict[str, int] = task_name_to_id_obj + +# difficulty -> tasks mapping +difficulty_to_tasks = data.get("DIFFICULTY_TO_TASKS") +if not isinstance(difficulty_to_tasks, dict): + raise TypeError("Expected 'DIFFICULTY_TO_TASKS' to be a dict[str, list[str]]") +DIFFICULTY_TO_TASKS: dict[str, list[str]] = difficulty_to_tasks + +# convert policy strings -> actual policy classes +task_policy_mapping = data.get("TASK_POLICY_MAPPING") +if not isinstance(task_policy_mapping, dict): + raise TypeError("Expected 'TASK_POLICY_MAPPING' to be a dict[str, str]") +TASK_POLICY_MAPPING: dict[str, Any] = { + task_name: getattr(policies, policy_class_name) + for task_name, policy_class_name in task_policy_mapping.items() +} +ACTION_DIM = 4 +OBS_DIM = 4 + + +class MetaworldEnv(gym.Env): + metadata = {"render_modes": ["rgb_array"], "render_fps": 80} + + def __init__( + self, + task, + camera_name="corner2", + obs_type="pixels", + render_mode="rgb_array", + observation_width=480, + observation_height=480, + visualization_width=640, + visualization_height=480, + ): + super().__init__() + self.task = task.replace("metaworld-", "") + self.obs_type = obs_type + self.render_mode = render_mode + self.observation_width = observation_width + self.observation_height = observation_height + self.visualization_width = visualization_width + self.visualization_height = visualization_height + self.camera_name = camera_name + + self._env = self._make_envs_task(self.task) + self._max_episode_steps = self._env.max_path_length + self.task_description = TASK_DESCRIPTIONS[self.task] + + self.expert_policy = TASK_POLICY_MAPPING[self.task]() + + if self.obs_type == "state": + raise NotImplementedError() + elif self.obs_type == "pixels": + self.observation_space = spaces.Dict( + { + "pixels": spaces.Box( + low=0, + high=255, + shape=(self.observation_height, self.observation_width, 3), + dtype=np.uint8, + ) + } + ) + elif self.obs_type == "pixels_agent_pos": + self.observation_space = spaces.Dict( + { + "pixels": spaces.Box( + low=0, + high=255, + shape=(self.observation_height, self.observation_width, 3), + dtype=np.uint8, + ), + "agent_pos": spaces.Box( + low=-1000.0, + high=1000.0, + shape=(OBS_DIM,), + dtype=np.float64, + ), + } + ) + + self.action_space = spaces.Box(low=-1, high=1, shape=(ACTION_DIM,), dtype=np.float32) + + def render(self) -> np.ndarray: + """ + Render the current environment frame. + + Returns: + np.ndarray: The rendered RGB image from the environment. 
+ """ + image = self._env.render() + if self.camera_name == "corner2": + # Images from this camera are flipped — correct them + image = np.flip(image, (0, 1)) + return image + + def _make_envs_task(self, env_name: str): + mt1 = metaworld.MT1(env_name, seed=42) + env = mt1.train_classes[env_name](render_mode="rgb_array", camera_name=self.camera_name) + env.set_task(mt1.train_tasks[0]) + if self.camera_name == "corner2": + env.model.cam_pos[2] = [ + 0.75, + 0.075, + 0.7, + ] # corner2 position, similar to https://arxiv.org/pdf/2206.14244 + env.reset() + env._freeze_rand_vec = False # otherwise no randomization + return env + + def _format_raw_obs(self, raw_obs: np.ndarray) -> dict[str, Any]: + image = None + if self._env is not None: + image = self._env.render() + if self.camera_name == "corner2": + # NOTE: The "corner2" camera in MetaWorld environments outputs images with both axes inverted. + image = np.flip(image, (0, 1)) + agent_pos = raw_obs[:4] + if self.obs_type == "state": + raise NotImplementedError( + "'state' obs_type not implemented for MetaWorld. Use pixel modes instead." + ) + + elif self.obs_type in ("pixels", "pixels_agent_pos"): + assert image is not None, ( + "Expected `image` to be rendered before constructing pixel-based observations. " + "This likely means `env.render()` returned None or the environment was not provided." + ) + + if self.obs_type == "pixels": + obs = {"pixels": image.copy()} + + else: # pixels_agent_pos + obs = { + "pixels": image.copy(), + "agent_pos": agent_pos, + } + else: + raise ValueError(f"Unknown obs_type: {self.obs_type}") + return obs + + def reset( + self, + seed: int | None = None, + **kwargs, + ) -> tuple[dict[str, Any], dict[str, Any]]: + """ + Reset the environment to its initial state. + + Args: + seed (Optional[int]): Random seed for environment initialization. + + Returns: + observation (Dict[str, Any]): The initial formatted observation. + info (Dict[str, Any]): Additional info about the reset state. + """ + super().reset(seed=seed) + + raw_obs, info = self._env.reset(seed=seed) + + observation = self._format_raw_obs(raw_obs) + + info = {"is_success": False} + return observation, info + + def step(self, action: np.ndarray) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]: + """ + Perform one environment step. + + Args: + action (np.ndarray): The action to execute, must be 1-D with shape (action_dim,). + + Returns: + observation (Dict[str, Any]): The formatted observation after the step. + reward (float): The scalar reward for this step. + terminated (bool): Whether the episode terminated successfully. + truncated (bool): Whether the episode was truncated due to a time limit. + info (Dict[str, Any]): Additional environment info. 
+ """ + if action.ndim != 1: + raise ValueError( + f"Expected action to be 1-D (shape (action_dim,)), " + f"but got shape {action.shape} with ndim={action.ndim}" + ) + raw_obs, reward, done, truncated, info = self._env.step(action) + + # Determine whether the task was successful + is_success = bool(info.get("success", 0)) + terminated = done or is_success + info.update( + { + "task": self.task, + "done": done, + "is_success": is_success, + } + ) + + # Format the raw observation into the expected structure + observation = self._format_raw_obs(raw_obs) + if terminated: + info["final_info"] = { + "task": self.task, + "done": bool(done), + "is_success": bool(is_success), + } + self.reset() + + return observation, reward, terminated, truncated, info + + def close(self): + self._env.close() + + +# ---- Main API ---------------------------------------------------------------- + + +def create_metaworld_envs( + task: str, + n_envs: int, + gym_kwargs: dict[str, Any] | None = None, + env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None, +) -> dict[str, dict[int, Any]]: + """ + Create vectorized Meta-World environments with a consistent return shape. + + Returns: + dict[task_group][task_id] -> vec_env (env_cls([...]) with exactly n_envs factories) + Notes: + - n_envs is the number of rollouts *per task* (episode_index = 0..n_envs-1). + - `task` can be a single difficulty group (e.g., "easy", "medium", "hard") or a comma-separated list. + - If a task name is not in DIFFICULTY_TO_TASKS, we treat it as a single custom task. + """ + if env_cls is None or not callable(env_cls): + raise ValueError("env_cls must be a callable that wraps a list of environment factory callables.") + if not isinstance(n_envs, int) or n_envs <= 0: + raise ValueError(f"n_envs must be a positive int; got {n_envs}.") + + gym_kwargs = dict(gym_kwargs or {}) + task_groups = [t.strip() for t in task.split(",") if t.strip()] + if not task_groups: + raise ValueError("`task` must contain at least one Meta-World task or difficulty group.") + + print(f"Creating Meta-World envs | task_groups={task_groups} | n_envs(per task)={n_envs}") + + out: dict[str, dict[int, Any]] = defaultdict(dict) + + for group in task_groups: + # if not in difficulty presets, treat it as a single custom task + tasks = DIFFICULTY_TO_TASKS.get(group, [group]) + + for tid, task_name in enumerate(tasks): + print(f"Building vec env | group={group} | task_id={tid} | task={task_name}") + + # build n_envs factories + fns = [(lambda tn=task_name: MetaworldEnv(task=tn, **gym_kwargs)) for _ in range(n_envs)] + + out[group][tid] = env_cls(fns) + + # return a plain dict for consistency + return {group: dict(task_map) for group, task_map in out.items()} diff --git a/src/lerobot/envs/metaworld_config.json b/src/lerobot/envs/metaworld_config.json new file mode 100644 index 000000000..41a417fef --- /dev/null +++ b/src/lerobot/envs/metaworld_config.json @@ -0,0 +1,121 @@ +{ + "TASK_DESCRIPTIONS": { + "assembly-v3": "Pick up a nut and place it onto a peg", + "basketball-v3": "Dunk the basketball into the basket", + "bin-picking-v3": "Grasp the puck from one bin and place it into another bin", + "box-close-v3": "Grasp the cover and close the box with it", + "button-press-topdown-v3": "Press a button from the top", + "button-press-topdown-wall-v3": "Bypass a wall and press a button from the top", + "button-press-v3": "Press a button", + "button-press-wall-v3": "Bypass a wall and press a button", + "coffee-button-v3": "Push a button on the coffee machine", + 
"coffee-pull-v3": "Pull a mug from a coffee machine", + "coffee-push-v3": "Push a mug under a coffee machine", + "dial-turn-v3": "Rotate a dial 180 degrees", + "disassemble-v3": "Pick a nut out of a peg", + "door-close-v3": "Close a door with a revolving joint", + "door-lock-v3": "Lock the door by rotating the lock clockwise", + "door-open-v3": "Open a door with a revolving joint", + "door-unlock-v3": "Unlock the door by rotating the lock counter-clockwise", + "hand-insert-v3": "Insert the gripper into a hole", + "drawer-close-v3": "Push and close a drawer", + "drawer-open-v3": "Open a drawer", + "faucet-open-v3": "Rotate the faucet counter-clockwise", + "faucet-close-v3": "Rotate the faucet clockwise", + "hammer-v3": "Hammer a screw on the wall", + "handle-press-side-v3": "Press a handle down sideways", + "handle-press-v3": "Press a handle down", + "handle-pull-side-v3": "Pull a handle up sideways", + "handle-pull-v3": "Pull a handle up", + "lever-pull-v3": "Pull a lever down 90 degrees", + "peg-insert-side-v3": "Insert a peg sideways", + "pick-place-wall-v3": "Pick a puck, bypass a wall and place the puck", + "pick-out-of-hole-v3": "Pick up a puck from a hole", + "reach-v3": "Reach a goal position", + "push-back-v3": "Push the puck to a goal", + "push-v3": "Push the puck to a goal", + "pick-place-v3": "Pick and place a puck to a goal", + "plate-slide-v3": "Slide a plate into a cabinet", + "plate-slide-side-v3": "Slide a plate into a cabinet sideways", + "plate-slide-back-v3": "Get a plate from the cabinet", + "plate-slide-back-side-v3": "Get a plate from the cabinet sideways", + "peg-unplug-side-v3": "Unplug a peg sideways", + "soccer-v3": "Kick a soccer into the goal", + "stick-push-v3": "Grasp a stick and push a box using the stick", + "stick-pull-v3": "Grasp a stick and pull a box with the stick", + "push-wall-v3": "Bypass a wall and push a puck to a goal", + "reach-wall-v3": "Bypass a wall and reach a goal", + "shelf-place-v3": "Pick and place a puck onto a shelf", + "sweep-into-v3": "Sweep a puck into a hole", + "sweep-v3": "Sweep a puck off the table", + "window-open-v3": "Push and open a window", + "window-close-v3": "Push and close a window" + }, + "TASK_NAME_TO_ID": { + "assembly-v3": 0, "basketball-v3": 1, "bin-picking-v3": 2, "box-close-v3": 3, + "button-press-topdown-v3": 4, "button-press-topdown-wall-v3": 5, "button-press-v3": 6, + "button-press-wall-v3": 7, "coffee-button-v3": 8, "coffee-pull-v3": 9, "coffee-push-v3": 10, + "dial-turn-v3": 11, "disassemble-v3": 12, "door-close-v3": 13, "door-lock-v3": 14, + "door-open-v3": 15, "door-unlock-v3": 16, "drawer-close-v3": 17, "drawer-open-v3": 18, + "faucet-close-v3": 19, "faucet-open-v3": 20, "hammer-v3": 21, "hand-insert-v3": 22, + "handle-press-side-v3": 23, "handle-press-v3": 24, "handle-pull-side-v3": 25, + "handle-pull-v3": 26, "lever-pull-v3": 27, "peg-insert-side-v3": 28, "peg-unplug-side-v3": 29, + "pick-out-of-hole-v3": 30, "pick-place-v3": 31, "pick-place-wall-v3": 32, + "plate-slide-back-side-v3": 33, "plate-slide-back-v3": 34, "plate-slide-side-v3": 35, + "plate-slide-v3": 36, "push-back-v3": 37, "push-v3": 38, "push-wall-v3": 39, "reach-v3": 40, + "reach-wall-v3": 41, "shelf-place-v3": 42, "soccer-v3": 43, "stick-pull-v3": 44, + "stick-push-v3": 45, "sweep-into-v3": 46, "sweep-v3": 47, "window-open-v3": 48, + "window-close-v3": 49 + }, + "DIFFICULTY_TO_TASKS": { + "easy": [ + "button-press-v3", "button-press-topdown-v3", "button-press-topdown-wall-v3", + "button-press-wall-v3", "coffee-button-v3", "dial-turn-v3", 
"door-close-v3", + "door-lock-v3", "door-open-v3", "door-unlock-v3", "drawer-close-v3", "drawer-open-v3", + "faucet-close-v3", "faucet-open-v3", "handle-press-v3", "handle-press-side-v3", + "handle-pull-v3", "handle-pull-side-v3", "lever-pull-v3", "plate-slide-v3", + "plate-slide-back-v3", "plate-slide-back-side-v3", "plate-slide-side-v3", "reach-v3", + "reach-wall-v3", "window-close-v3", "window-open-v3", "peg-unplug-side-v3" + ], + "medium": [ + "basketball-v3", "bin-picking-v3", "box-close-v3", "coffee-pull-v3", "coffee-push-v3", + "hammer-v3", "peg-insert-side-v3", "push-wall-v3", "soccer-v3", "sweep-v3", "sweep-into-v3" + ], + "hard": [ + "assembly-v3", "hand-insert-v3", "pick-out-of-hole-v3", "pick-place-v3", "push-v3", "push-back-v3" + ], + "very_hard": [ + "shelf-place-v3", "disassemble-v3", "stick-pull-v3", "stick-push-v3", "pick-place-wall-v3" + ] + }, + "TASK_POLICY_MAPPING": { + "assembly-v3": "SawyerAssemblyV3Policy", "basketball-v3": "SawyerBasketballV3Policy", + "bin-picking-v3": "SawyerBinPickingV3Policy", "box-close-v3": "SawyerBoxCloseV3Policy", + "button-press-topdown-v3": "SawyerButtonPressTopdownV3Policy", + "button-press-topdown-wall-v3": "SawyerButtonPressTopdownWallV3Policy", + "button-press-v3": "SawyerButtonPressV3Policy", "button-press-wall-v3": "SawyerButtonPressWallV3Policy", + "coffee-button-v3": "SawyerCoffeeButtonV3Policy", "coffee-pull-v3": "SawyerCoffeePullV3Policy", + "coffee-push-v3": "SawyerCoffeePushV3Policy", "dial-turn-v3": "SawyerDialTurnV3Policy", + "disassemble-v3": "SawyerDisassembleV3Policy", "door-close-v3": "SawyerDoorCloseV3Policy", + "door-lock-v3": "SawyerDoorLockV3Policy", "door-open-v3": "SawyerDoorOpenV3Policy", + "door-unlock-v3": "SawyerDoorUnlockV3Policy", "drawer-close-v3": "SawyerDrawerCloseV3Policy", + "drawer-open-v3": "SawyerDrawerOpenV3Policy", "faucet-close-v3": "SawyerFaucetCloseV3Policy", + "faucet-open-v3": "SawyerFaucetOpenV3Policy", "hammer-v3": "SawyerHammerV3Policy", + "hand-insert-v3": "SawyerHandInsertV3Policy", "handle-press-side-v3": "SawyerHandlePressSideV3Policy", + "handle-press-v3": "SawyerHandlePressV3Policy", "handle-pull-side-v3": "SawyerHandlePullSideV3Policy", + "handle-pull-v3": "SawyerHandlePullV3Policy", "lever-pull-v3": "SawyerLeverPullV3Policy", + "peg-insert-side-v3": "SawyerPegInsertionSideV3Policy", "peg-unplug-side-v3": "SawyerPegUnplugSideV3Policy", + "pick-out-of-hole-v3": "SawyerPickOutOfHoleV3Policy", "pick-place-v3": "SawyerPickPlaceV3Policy", + "pick-place-wall-v3": "SawyerPickPlaceWallV3Policy", + "plate-slide-back-side-v3": "SawyerPlateSlideBackSideV3Policy", + "plate-slide-back-v3": "SawyerPlateSlideBackV3Policy", + "plate-slide-side-v3": "SawyerPlateSlideSideV3Policy", "plate-slide-v3": "SawyerPlateSlideV3Policy", + "push-back-v3": "SawyerPushBackV3Policy", "push-v3": "SawyerPushV3Policy", + "push-wall-v3": "SawyerPushWallV3Policy", "reach-v3": "SawyerReachV3Policy", + "reach-wall-v3": "SawyerReachWallV3Policy", "shelf-place-v3": "SawyerShelfPlaceV3Policy", + "soccer-v3": "SawyerSoccerV3Policy", "stick-pull-v3": "SawyerStickPullV3Policy", + "stick-push-v3": "SawyerStickPushV3Policy", "sweep-into-v3": "SawyerSweepIntoV3Policy", + "sweep-v3": "SawyerSweepV3Policy", "window-open-v3": "SawyerWindowOpenV3Policy", + "window-close-v3": "SawyerWindowCloseV3Policy" + } +} diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index d45be5c42..aed7d32e3 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -180,9 +180,15 @@ def 
rollout( render_callback(env) # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't - # available of none of the envs finished. + # available if none of the envs finished. if "final_info" in info: - successes = [info["is_success"] if info is not None else False for info in info["final_info"]] + final_info = info["final_info"] + if not isinstance(final_info, dict): + raise RuntimeError( + "Unsupported `final_info` format: expected dict (Gymnasium >= 1.0). " + "You're likely using an older version of gymnasium (< 1.0). Please upgrade." + ) + successes = final_info["is_success"].tolist() else: successes = [False] * env.num_envs diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index 34fa89390..345526d90 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -95,7 +95,6 @@ def test_get_policy_and_config_classes(policy_name: str): @pytest.mark.parametrize( "ds_repo_id,env_name,env_kwargs,policy_name,policy_kwargs", [ - ("lerobot/xarm_lift_medium", "xarm", {}, "tdmpc", {"use_mpc": True}), ("lerobot/pusht", "pusht", {}, "diffusion", {}), ("lerobot/pusht", "pusht", {}, "vqbet", {}), ("lerobot/pusht", "pusht", {}, "act", {}), @@ -328,8 +327,6 @@ def test_multikey_construction(multikey: bool): # TODO(alexander-soare): `policy.use_mpc=false` was previously the default in the config yaml but it # was changed to true. For some reason, tests would pass locally, but not in CI. So here we override # to test with `policy.use_mpc=false`. - ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"), - # ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"), # TODO(rcadene): the diffusion model was normalizing the image in mean=0.5 std=0.5 which is a hack supposed to # to normalize the image at all. In our current codebase we dont normalize at all. But there is still a minor difference # that fails the test. However, by testing to normalize the image with 0.5 0.5 in the current codebase, the test pass. 
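Before moving on to the dependency and teleoperator patches, a quick note on the factory introduced above: `create_metaworld_envs` returns a nested mapping rather than a flat vector env. A hedged sketch of consuming it (`SyncVectorEnv` is an assumed stand-in for whatever `env_cls` the caller passes):

```python
# Sketch: {task_group: {task_id: vec_env}}, one vectorized env per task,
# each holding exactly n_envs rollout workers.
import gymnasium as gym

from lerobot.envs.metaworld import create_metaworld_envs

envs = create_metaworld_envs(
    task="hard",                                  # preset -> 6 tasks (see DIFFICULTY_TO_TASKS)
    n_envs=2,                                     # rollouts per task
    gym_kwargs={"obs_type": "pixels_agent_pos"},
    env_cls=gym.vector.SyncVectorEnv,
)
for task_id, vec_env in envs["hard"].items():
    print(task_id, vec_env.num_envs)              # task ids 0..5, each with 2 envs
    vec_env.close()
```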
From a6ff3cfebb0304f2c378515dd30ea06fff8f473f Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 14 Oct 2025 18:19:49 +0200 Subject: [PATCH 18/24] chore(deps): libero dep pointing to main (#2201) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6d43c33df..e7727700c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -135,7 +135,7 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"] # Simulation aloha = ["gym-aloha>=0.1.2,<0.2.0"] pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead -libero = ["lerobot[transformers-dep]", "libero @ git+https://github.com/huggingface/lerobot-libero.git@upgrade-dep#egg=libero"] +libero = ["lerobot[transformers-dep]", "libero @ git+https://github.com/huggingface/lerobot-libero.git@main#egg=libero"] metaworld = ["metaworld>=3.0.0"] # All From 845b359d390b4c4f53c0560a8009e6d3ce7b88e1 Mon Sep 17 00:00:00 2001 From: Ryan Pennings Date: Thu, 16 Oct 2025 20:39:05 +1100 Subject: [PATCH 19/24] Fix homunculus teleoperator input lag (#2196) Removes input lag by making changes to the serial reading loop - remove serial flush as this only clears output buffer - read all data in the input buffer in per loop and use the latest line as the state to clear the input buffer previously was only reading one line per loop, which in combination with teleoperator script loop busy_wait function (which is slowing the _read_loops down) was causing a backlog in input buffer Co-authored-by: Martino Russi <77496684+nepyope@users.noreply.github.com> --- .../teleoperators/homunculus/homunculus_arm.py | 11 +++++++++-- .../teleoperators/homunculus/homunculus_glove.py | 11 +++++++++-- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/lerobot/teleoperators/homunculus/homunculus_arm.py b/src/lerobot/teleoperators/homunculus/homunculus_arm.py index 21d73de2e..43116f5c0 100644 --- a/src/lerobot/teleoperators/homunculus/homunculus_arm.py +++ b/src/lerobot/teleoperators/homunculus/homunculus_arm.py @@ -270,8 +270,15 @@ class HomunculusArm(Teleoperator): raw_values = None with self.serial_lock: if self.serial.in_waiting > 0: - self.serial.flush() - raw_values = self.serial.readline().decode("utf-8").strip().split(" ") + lines = [] + while self.serial.in_waiting > 0: + line = self.serial.read_until().decode("utf-8").strip() + if line: + lines.append(line.split(" ")) + + if lines: + raw_values = lines[-1] + if raw_values is None or len(raw_values) != 21: # 16 raw + 5 angle values continue diff --git a/src/lerobot/teleoperators/homunculus/homunculus_glove.py b/src/lerobot/teleoperators/homunculus/homunculus_glove.py index 251ecf56d..fefeec1e8 100644 --- a/src/lerobot/teleoperators/homunculus/homunculus_glove.py +++ b/src/lerobot/teleoperators/homunculus/homunculus_glove.py @@ -304,8 +304,15 @@ class HomunculusGlove(Teleoperator): positions = None with self.serial_lock: if self.serial.in_waiting > 0: - self.serial.flush() - positions = self.serial.readline().decode("utf-8").strip().split(" ") + lines = [] + while self.serial.in_waiting > 0: + line = self.serial.read_until().decode("utf-8").strip() + if line: + lines.append(line.split(" ")) + + if lines: + positions = lines[-1] + if positions is None or len(positions) != len(self.joints): continue From e82e7a02e901e21f512e4d1e9fd252b1d3e8ce6d Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Thu, 16 Oct 2025 17:41:55 +0200 Subject: [PATCH 20/24] 
feat(train): add accelerate for multi gpu training (#2154) * Enhance training and logging functionality with accelerator support - Added support for multi-GPU training by introducing an `accelerator` parameter in training functions. - Updated `update_policy` to handle gradient updates based on the presence of an accelerator. - Modified logging to prevent duplicate messages in non-main processes. - Enhanced `set_seed` and `get_safe_torch_device` functions to accommodate accelerator usage. - Updated `MetricsTracker` to account for the number of processes when calculating metrics. - Introduced a new feature in `pyproject.toml` for the `accelerate` library dependency. * Initialize logging in training script for both main and non-main processes - Added `init_logging` calls to ensure proper logging setup when using the accelerator and in standard training mode. - This change enhances the clarity and consistency of logging during training sessions. * add docs and only push model once * Place logging under accelerate and update docs * fix pre commit * only log in main process * main logging * try with local rank * add tests * change runner * fix test * dont push to hub in multi gpu tests * pre download dataset in tests * small fixes * fix path optimizer state * update docs, and small improvements in train * simplify accelerate main process detection * small improvements in train * fix OOM bug * change accelerate detection * add some debugging * always use accelerate * cleanup update method * cleanup * fix bug * scale lr decay if we reduce steps * cleanup logging * fix formatting * encorperate feedback pr * add min memory to cpu tests * use accelerate to determin logging * fix precommit and fix tests * chore: minor details --------- Co-authored-by: AdilZouitine Co-authored-by: Steven Palma --- .github/workflows/nightly.yml | 33 +++ docs/source/_toctree.yml | 2 + docs/source/multi_gpu_training.mdx | 125 ++++++++ pyproject.toml | 1 + src/lerobot/optim/schedulers.py | 37 ++- src/lerobot/policies/pi0/configuration_pi0.py | 2 + .../policies/pi05/configuration_pi05.py | 2 + src/lerobot/rl/wandb_utils.py | 2 +- src/lerobot/scripts/lerobot_train.py | 277 +++++++++++------- src/lerobot/utils/logging_utils.py | 6 +- src/lerobot/utils/random_utils.py | 10 +- src/lerobot/utils/utils.py | 51 ++-- tests/training/test_multi_gpu.py | 211 +++++++++++++ 13 files changed, 625 insertions(+), 134 deletions(-) create mode 100644 docs/source/multi_gpu_training.mdx create mode 100644 tests/training/test_multi_gpu.py diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 03f26a792..f9fa02597 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -119,6 +119,7 @@ jobs: TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton container: image: ${{ needs.build-docker-cpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images] + options: --shm-size "16gb" credentials: username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }} password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }} @@ -158,3 +159,35 @@ jobs: run: pytest tests -vv --maxfail=10 - name: Run end-to-end tests run: make test-end-to-end + + # This job runs multi-GPU training tests with 4 GPUs + nightly-multi-gpu-tests: + name: Nightly Multi-GPU Tests + needs: [build-docker-gpu-nightly] + runs-on: + group: aws-g4dn-12xlarge # Instance with 4 GPUs + env: + HF_HOME: /home/user_lerobot/.cache/huggingface + HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot + TORCH_HOME: /home/user_lerobot/.cache/torch + 
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton + CUDA_VISIBLE_DEVICES: "0,1,2,3" + container: + image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images] + options: --gpus all --shm-size "16gb" + credentials: + username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }} + password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }} + defaults: + run: + shell: bash + working-directory: /lerobot + steps: + - name: Verify GPU availability + run: | + nvidia-smi + python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')" + + - name: Run multi-GPU training tests + run: pytest tests/training/test_multi_gpu.py -vv --maxfail=3 + timeout-minutes: 10 diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index b7e71e010..5e100013a 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -17,6 +17,8 @@ title: Train RL in Simulation - local: async title: Use Async Inference + - local: multi_gpu_training + title: Multi GPU training title: "Tutorials" - sections: - local: lerobot-dataset-v3 diff --git a/docs/source/multi_gpu_training.mdx b/docs/source/multi_gpu_training.mdx new file mode 100644 index 000000000..122670f69 --- /dev/null +++ b/docs/source/multi_gpu_training.mdx @@ -0,0 +1,125 @@ +# Multi-GPU Training + +This guide shows you how to train policies on multiple GPUs using [Hugging Face Accelerate](https://huggingface.co/docs/accelerate). + +## Installation + +First, ensure you have accelerate installed: + +```bash +pip install accelerate +``` + +## Training with Multiple GPUs + +You can launch training in two ways: + +### Option 1: Without config (specify parameters directly) + +You can specify all parameters directly in the command without running `accelerate config`: + +```bash +accelerate launch \ + --multi_gpu \ + --num_processes=2 \ + $(which lerobot-train) \ + --dataset.repo_id=${HF_USER}/my_dataset \ + --policy.type=act \ + --policy.repo_id=${HF_USER}/my_trained_policy \ + --output_dir=outputs/train/act_multi_gpu \ + --job_name=act_multi_gpu \ + --wandb.enable=true +``` + +**Key accelerate parameters:** + +- `--multi_gpu`: Enable multi-GPU training +- `--num_processes=2`: Number of GPUs to use +- `--mixed_precision=fp16`: Use fp16 mixed precision (or `bf16` if supported) + +### Option 2: Using accelerate config + +If you prefer to save your configuration, you can optionally configure accelerate for your hardware setup by running: + +```bash +accelerate config +``` + +This interactive setup will ask you questions about your training environment (number of GPUs, mixed precision settings, etc.) and saves the configuration for future use. For a simple multi-GPU setup on a single machine, you can use these recommended settings: + +- Compute environment: This machine +- Number of machines: 1 +- Number of processes: (number of GPUs you want to use) +- GPU ids to use: (leave empty to use all) +- Mixed precision: fp16 or bf16 (recommended for faster training) + +Then launch training with: + +```bash +accelerate launch $(which lerobot-train) \ + --dataset.repo_id=${HF_USER}/my_dataset \ + --policy.type=act \ + --policy.repo_id=${HF_USER}/my_trained_policy \ + --output_dir=outputs/train/act_multi_gpu \ + --job_name=act_multi_gpu \ + --wandb.enable=true +``` + +## How It Works + +When you launch training with accelerate: + +1. **Automatic detection**: LeRobot automatically detects if it's running under accelerate +2. 
**Data distribution**: Your batch is automatically split across GPUs +3. **Gradient synchronization**: Gradients are synchronized across GPUs during backpropagation +4. **Single process logging**: Only the main process logs to wandb and saves checkpoints + +## Learning Rate and Training Steps Scaling + +**Important:** LeRobot does **NOT** automatically scale learning rates or training steps based on the number of GPUs. This gives you full control over your training hyperparameters. + +### Why No Automatic Scaling? + +Many distributed training frameworks automatically scale the learning rate by the number of GPUs (e.g., `lr = base_lr × num_gpus`). +However, LeRobot keeps the learning rate exactly as you specify it. + +### When and How to Scale + +If you want to scale your hyperparameters when using multiple GPUs, you should do it manually: + +**Learning Rate Scaling:** + +```bash +# Example: 2 GPUs with linear LR scaling +# Base LR: 1e-4, with 2 GPUs -> 2e-4 +accelerate launch --num_processes=2 $(which lerobot-train) \ + --optimizer.lr=2e-4 \ + --dataset.repo_id=lerobot/pusht \ + --policy=act +``` + +**Training Steps Scaling:** + +Since the effective batch size increases with multiple GPUs (batch_size × num_gpus), you may want to reduce the number of training steps proportionally: + +```bash +# Example: 2 GPUs with effective batch size 2x larger +# Original: batch_size=8, steps=100000 +# With 2 GPUs: batch_size=8 (16 in total), steps=50000 +accelerate launch --num_processes=2 $(which lerobot-train) \ + --batch_size=8 \ + --steps=50000 \ + --dataset.repo_id=lerobot/pusht \ + --policy=act +``` + +## Notes + +- The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration. +- Training logs, checkpoints, and hub uploads are only done by the main process to avoid conflicts. Non-main processes have console logging disabled to prevent duplicate output. +- The effective batch size is `batch_size × num_gpus`. If you use 4 GPUs with `--batch_size=8`, your effective batch size is 32. +- Learning rate scheduling is handled correctly across multiple processes: LeRobot sets `step_scheduler_with_optimizer=False` to prevent accelerate from adjusting scheduler steps based on the number of processes. +- When saving or pushing models, LeRobot automatically unwraps the model from accelerate's distributed wrapper to ensure compatibility. +- WandB integration automatically initializes only on the main process, preventing multiple runs from being created. + +For more advanced configurations and troubleshooting, see the [Accelerate documentation](https://huggingface.co/docs/accelerate). If you want to learn more about how to train on a large number of GPUs, check out this guide: [Ultrascale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook).
diff --git a/pyproject.toml b/pyproject.toml index e7727700c..d0e03e35a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ dependencies = [ "datasets>=4.0.0,<4.2.0", "diffusers>=0.27.2,<0.36.0", "huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0", + "accelerate>=1.10.0,<2.0.0", # Core dependencies "setuptools>=71.0.0,<81.0.0", diff --git a/src/lerobot/optim/schedulers.py b/src/lerobot/optim/schedulers.py index 55ee62e40..b5d54b396 100644 --- a/src/lerobot/optim/schedulers.py +++ b/src/lerobot/optim/schedulers.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc +import logging import math from dataclasses import asdict, dataclass from pathlib import Path @@ -79,7 +80,11 @@ class VQBeTSchedulerConfig(LRSchedulerConfig): @LRSchedulerConfig.register_subclass("cosine_decay_with_warmup") @dataclass class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig): - """Used by Physical Intelligence to train Pi0""" + """Used by Physical Intelligence to train Pi0. + + Automatically scales warmup and decay steps if num_training_steps < num_decay_steps. + This ensures the learning rate schedule completes properly even with shorter training runs. + """ num_warmup_steps: int num_decay_steps: int @@ -87,23 +92,39 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig): decay_lr: float def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR: - del num_training_steps + # Auto-scale scheduler parameters if training steps are shorter than configured decay steps + actual_warmup_steps = self.num_warmup_steps + actual_decay_steps = self.num_decay_steps + + if num_training_steps < self.num_decay_steps: + # Calculate scaling factor to fit the schedule into the available training steps + scale_factor = num_training_steps / self.num_decay_steps + actual_warmup_steps = int(self.num_warmup_steps * scale_factor) + actual_decay_steps = num_training_steps + + logging.info( + f"Auto-scaling LR scheduler: " + f"num_training_steps ({num_training_steps}) < num_decay_steps ({self.num_decay_steps}). 
" + f"Scaling warmup: {self.num_warmup_steps} → {actual_warmup_steps}, " + f"decay: {self.num_decay_steps} → {actual_decay_steps} " + f"(scale factor: {scale_factor:.3f})" + ) def lr_lambda(current_step): def linear_warmup_schedule(current_step): if current_step <= 0: - return 1 / (self.num_warmup_steps + 1) - frac = 1 - current_step / self.num_warmup_steps - return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1 + return 1 / (actual_warmup_steps + 1) + frac = 1 - current_step / actual_warmup_steps + return (1 / (actual_warmup_steps + 1) - 1) * frac + 1 def cosine_decay_schedule(current_step): - step = min(current_step, self.num_decay_steps) - cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps)) + step = min(current_step, actual_decay_steps) + cosine_decay = 0.5 * (1 + math.cos(math.pi * step / actual_decay_steps)) alpha = self.decay_lr / self.peak_lr decayed = (1 - alpha) * cosine_decay + alpha return decayed - if current_step < self.num_warmup_steps: + if current_step < actual_warmup_steps: return linear_warmup_schedule(current_step) return cosine_decay_schedule(current_step) diff --git a/src/lerobot/policies/pi0/configuration_pi0.py b/src/lerobot/policies/pi0/configuration_pi0.py index cc1cda9d8..d745f4317 100644 --- a/src/lerobot/policies/pi0/configuration_pi0.py +++ b/src/lerobot/policies/pi0/configuration_pi0.py @@ -75,6 +75,8 @@ class PI0Config(PreTrainedConfig): optimizer_grad_clip_norm: float = 1.0 # Scheduler settings: see openpi `CosineDecaySchedule` + # Note: These will auto-scale if --steps < scheduler_decay_steps + # For example, --steps=3000 will scale warmup to 100 and decay to 3000 scheduler_warmup_steps: int = 1_000 scheduler_decay_steps: int = 30_000 scheduler_decay_lr: float = 2.5e-6 diff --git a/src/lerobot/policies/pi05/configuration_pi05.py b/src/lerobot/policies/pi05/configuration_pi05.py index 7c1e950b0..61346c330 100644 --- a/src/lerobot/policies/pi05/configuration_pi05.py +++ b/src/lerobot/policies/pi05/configuration_pi05.py @@ -75,6 +75,8 @@ class PI05Config(PreTrainedConfig): optimizer_grad_clip_norm: float = 1.0 # Scheduler settings: see openpi `CosineDecaySchedule` + # Note: These will auto-scale if --steps < scheduler_decay_steps + # For example, --steps=3000 will scale warmup to 100 and decay to 3000 scheduler_warmup_steps: int = 1_000 scheduler_decay_steps: int = 30_000 scheduler_decay_lr: float = 2.5e-6 diff --git a/src/lerobot/rl/wandb_utils.py b/src/lerobot/rl/wandb_utils.py index 01cef9487..1537b3783 100644 --- a/src/lerobot/rl/wandb_utils.py +++ b/src/lerobot/rl/wandb_utils.py @@ -99,7 +99,7 @@ class WandBLogger: cfg.wandb.run_id = run_id # Handle custom step key for rl asynchronous training. 
self._wandb_custom_step_key: set[str] | None = None - print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"])) + logging.info(colored("Logs will be synced with wandb.", "blue", attrs=["bold"])) logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}") self._wandb = wandb diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index bc66618ca..84eb81ad4 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -20,8 +20,8 @@ from pprint import pformat from typing import Any import torch +from accelerate import Accelerator from termcolor import colored -from torch.amp import GradScaler from torch.optim import Optimizer from lerobot.configs import parser @@ -34,7 +34,6 @@ from lerobot.envs.utils import close_envs from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.utils import get_device_from_parameters from lerobot.rl.wandb_utils import WandBLogger from lerobot.scripts.lerobot_eval import eval_policy_all from lerobot.utils.logging_utils import AverageMeter, MetricsTracker @@ -48,7 +47,6 @@ from lerobot.utils.train_utils import ( ) from lerobot.utils.utils import ( format_big_number, - get_safe_torch_device, has_method, init_logging, ) @@ -60,16 +58,15 @@ def update_policy( batch: Any, optimizer: Optimizer, grad_clip_norm: float, - grad_scaler: GradScaler, + accelerator: Accelerator, lr_scheduler=None, - use_amp: bool = False, lock=None, ) -> tuple[MetricsTracker, dict]: """ Performs a single training step to update the policy's weights. This function executes the forward and backward passes, clips gradients, and steps the optimizer and - learning rate scheduler. It also handles mixed-precision training via a GradScaler. + learning rate scheduler. Accelerator handles mixed-precision training automatically. Args: train_metrics: A MetricsTracker instance to record training statistics. @@ -77,9 +74,8 @@ def update_policy( batch: A batch of training data. optimizer: The optimizer used to update the policy's parameters. grad_clip_norm: The maximum norm for gradient clipping. - grad_scaler: The GradScaler for automatic mixed-precision training. + accelerator: The Accelerator instance for distributed training and mixed precision. lr_scheduler: An optional learning rate scheduler. - use_amp: A boolean indicating whether to use automatic mixed precision. lock: An optional lock for thread-safe optimizer updates. Returns: @@ -88,28 +84,27 @@ def update_policy( - A dictionary of outputs from the policy's forward pass, for logging purposes. """ start_time = time.perf_counter() - device = get_device_from_parameters(policy) policy.train() - with torch.autocast(device_type=device.type) if use_amp else nullcontext(): + + # Let accelerator handle mixed precision + with accelerator.autocast(): loss, output_dict = policy.forward(batch) # TODO(rcadene): policy.unnormalize_outputs(out_dict) - grad_scaler.scale(loss).backward() - # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**. 
- grad_scaler.unscale_(optimizer) + # Use accelerator's backward method + accelerator.backward(loss) - grad_norm = torch.nn.utils.clip_grad_norm_( - policy.parameters(), - grad_clip_norm, - error_if_nonfinite=False, - ) + # Clip gradients if specified + if grad_clip_norm > 0: + grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm) + else: + grad_norm = torch.nn.utils.clip_grad_norm_( + policy.parameters(), float("inf"), error_if_nonfinite=False + ) - # Optimizer's gradients are already unscaled, so scaler.step does not unscale them, - # although it still skips optimizer.step() if the gradients contain infs or NaNs. + # Optimizer step with lock if lock is not None else nullcontext(): - grad_scaler.step(optimizer) - # Updates the scale for next iteration. - grad_scaler.update() + optimizer.step() optimizer.zero_grad() @@ -117,9 +112,9 @@ def update_policy( if lr_scheduler is not None: lr_scheduler.step() - if has_method(policy, "update"): - # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC). - policy.update() + # Update internal buffers if policy has update method + if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"): + accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update() train_metrics.loss = loss.item() train_metrics.grad_norm = grad_norm.item() @@ -129,7 +124,7 @@ def update_policy( @parser.wrap() -def train(cfg: TrainPipelineConfig): +def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): """ Main function to train a policy. @@ -143,41 +138,76 @@ def train(cfg: TrainPipelineConfig): Args: cfg: A `TrainPipelineConfig` object containing all training configurations. + accelerator: Optional Accelerator instance. If None, one will be created automatically. 
""" cfg.validate() - logging.info(pformat(cfg.to_dict())) - if cfg.wandb.enable and cfg.wandb.project: + # Create Accelerator if not provided + # It will automatically detect if running in distributed mode or single-process mode + # We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes + # We set find_unused_parameters=True to handle models with conditional computation + if accelerator is None: + from accelerate.utils import DistributedDataParallelKwargs + + ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs]) + + init_logging(accelerator=accelerator) + + # Determine if this is the main process (for logging and checkpointing) + # When using accelerate, only the main process should log to avoid duplicate outputs + is_main_process = accelerator.is_main_process + + # Only log on main process + if is_main_process: + logging.info(pformat(cfg.to_dict())) + + # Initialize wandb only on main process + if cfg.wandb.enable and cfg.wandb.project and is_main_process: wandb_logger = WandBLogger(cfg) else: wandb_logger = None - logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) + if is_main_process: + logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) if cfg.seed is not None: - set_seed(cfg.seed) + set_seed(cfg.seed, accelerator=accelerator) - # Check device is available - device = get_safe_torch_device(cfg.policy.device, log=True) + # Use accelerator's device + device = accelerator.device torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True - logging.info("Creating dataset") - dataset = make_dataset(cfg) + # Dataset loading synchronization: main process downloads first to avoid race conditions + if is_main_process: + logging.info("Creating dataset") + dataset = make_dataset(cfg) + + accelerator.wait_for_everyone() + + # Now all other processes can safely load the dataset + if not is_main_process: + dataset = make_dataset(cfg) # Create environment used for evaluating checkpoints during training on simulation data. # On real-world data, no need to create an environment as evaluations are done outside train.py, # using the eval.py instead, with gym_dora environment and dora-rs. 
eval_env = None if cfg.eval_freq > 0 and cfg.env is not None: - logging.info("Creating env") + if is_main_process: + logging.info("Creating env") eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) - logging.info("Creating policy") + if is_main_process: + logging.info("Creating policy") policy = make_policy( cfg=cfg.policy, ds_meta=dataset.meta, ) + # Wait for all processes to finish policy creation before continuing + accelerator.wait_for_everyone() + # Create processors - only provide dataset_stats if not resuming from saved processors processor_kwargs = {} postprocessor_kwargs = {} @@ -209,9 +239,9 @@ def train(cfg: TrainPipelineConfig): **postprocessor_kwargs, ) - logging.info("Creating optimizer and scheduler") + if is_main_process: + logging.info("Creating optimizer and scheduler") optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) - grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp) step = 0 # number of policy updates (forward + backward + optim) @@ -221,14 +251,18 @@ def train(cfg: TrainPipelineConfig): num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_total_params = sum(p.numel() for p in policy.parameters()) - logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") - if cfg.env is not None: - logging.info(f"{cfg.env.task=}") - logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})") - logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})") - logging.info(f"{dataset.num_episodes=}") - logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") - logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") + if is_main_process: + logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") + if cfg.env is not None: + logging.info(f"{cfg.env.task=}") + logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})") + logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})") + logging.info(f"{dataset.num_episodes=}") + num_processes = accelerator.num_processes + effective_bs = cfg.batch_size * num_processes + logging.info(f"Effective batch size: {cfg.batch_size} x {num_processes} = {effective_bs}") + logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") + logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") # create dataloader for offline training if hasattr(cfg.policy, "drop_n_last_frames"): @@ -251,7 +285,13 @@ def train(cfg: TrainPipelineConfig): sampler=sampler, pin_memory=device.type == "cuda", drop_last=False, - prefetch_factor=2, + prefetch_factor=2 if cfg.num_workers > 0 else None, + ) + + # Prepare everything with accelerator + accelerator.wait_for_everyone() + policy, optimizer, dataloader, lr_scheduler = accelerator.prepare( + policy, optimizer, dataloader, lr_scheduler ) dl_iter = cycle(dataloader) @@ -265,11 +305,20 @@ def train(cfg: TrainPipelineConfig): "dataloading_s": AverageMeter("data_s", ":.3f"), } + # Use effective batch size for proper epoch calculation in distributed training + effective_batch_size = cfg.batch_size * accelerator.num_processes train_tracker = MetricsTracker( - cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step + effective_batch_size, + dataset.num_frames, + dataset.num_episodes, + train_metrics, + initial_step=step, + accelerator=accelerator, ) - logging.info("Start offline training on a 
fixed dataset") + if is_main_process: + logging.info("Start offline training on a fixed dataset") + for _ in range(step, cfg.steps): start_time = time.perf_counter() batch = next(dl_iter) @@ -282,16 +331,15 @@ def train(cfg: TrainPipelineConfig): batch, optimizer, cfg.optimizer.grad_clip_norm, - grad_scaler=grad_scaler, + accelerator=accelerator, lr_scheduler=lr_scheduler, - use_amp=cfg.policy.use_amp, ) # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we # increment `step` here. step += 1 train_tracker.step() - is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 + is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 @@ -305,69 +353,90 @@ def train(cfg: TrainPipelineConfig): train_tracker.reset_averages() if cfg.save_checkpoint and is_saving_step: - logging.info(f"Checkpoint policy after step {step}") - checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step) - save_checkpoint( - checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor, postprocessor - ) - update_last_checkpoint(checkpoint_dir) - if wandb_logger: - wandb_logger.log_policy(checkpoint_dir) - - if cfg.env and is_eval_step: - step_id = get_step_identifier(step, cfg.steps) - logging.info(f"Eval policy at step {step}") - with ( - torch.no_grad(), - torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(), - ): - eval_info = eval_policy_all( - envs=eval_env, # dict[suite][task_id] -> vec_env - policy=policy, + if is_main_process: + logging.info(f"Checkpoint policy after step {step}") + checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step) + save_checkpoint( + checkpoint_dir=checkpoint_dir, + step=step, + cfg=cfg, + policy=accelerator.unwrap_model(policy), + optimizer=optimizer, + scheduler=lr_scheduler, preprocessor=preprocessor, postprocessor=postprocessor, - n_episodes=cfg.eval.n_episodes, - videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}", - max_episodes_rendered=4, - start_seed=cfg.seed, - max_parallel_tasks=cfg.env.max_parallel_tasks, ) - # overall metrics (suite-agnostic) - aggregated = eval_info["overall"] + update_last_checkpoint(checkpoint_dir) + if wandb_logger: + wandb_logger.log_policy(checkpoint_dir) - # optional: per-suite logging - for suite, suite_info in eval_info.items(): - logging.info("Suite %s aggregated: %s", suite, suite_info) + accelerator.wait_for_everyone() - # meters/tracker - eval_metrics = { - "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"), - "pc_success": AverageMeter("success", ":.1f"), - "eval_s": AverageMeter("eval_s", ":.3f"), - } - eval_tracker = MetricsTracker( - cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step - ) - eval_tracker.eval_s = aggregated.pop("eval_s") - eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward") - eval_tracker.pc_success = aggregated.pop("pc_success") - if wandb_logger: - wandb_log_dict = {**eval_tracker.to_dict(), **eval_info} - wandb_logger.log_dict(wandb_log_dict, step, mode="eval") - wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval") + if cfg.env and is_eval_step: + if is_main_process: + step_id = get_step_identifier(step, cfg.steps) + logging.info(f"Eval policy at step {step}") + with torch.no_grad(), accelerator.autocast(): + eval_info = eval_policy_all( + envs=eval_env, # dict[suite][task_id] -> 
vec_env + policy=accelerator.unwrap_model(policy), + preprocessor=preprocessor, + postprocessor=postprocessor, + n_episodes=cfg.eval.n_episodes, + videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}", + max_episodes_rendered=4, + start_seed=cfg.seed, + max_parallel_tasks=cfg.env.max_parallel_tasks, + ) + # overall metrics (suite-agnostic) + aggregated = eval_info["overall"] + + # optional: per-suite logging + for suite, suite_info in eval_info.items(): + logging.info("Suite %s aggregated: %s", suite, suite_info) + + # meters/tracker + eval_metrics = { + "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"), + "pc_success": AverageMeter("success", ":.1f"), + "eval_s": AverageMeter("eval_s", ":.3f"), + } + eval_tracker = MetricsTracker( + cfg.batch_size, + dataset.num_frames, + dataset.num_episodes, + eval_metrics, + initial_step=step, + accelerator=accelerator, + ) + eval_tracker.eval_s = aggregated.pop("eval_s") + eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward") + eval_tracker.pc_success = aggregated.pop("pc_success") + if wandb_logger: + wandb_log_dict = {**eval_tracker.to_dict(), **eval_info} + wandb_logger.log_dict(wandb_log_dict, step, mode="eval") + wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval") + + accelerator.wait_for_everyone() if eval_env: close_envs(eval_env) - logging.info("End of training") - if cfg.policy.push_to_hub: - policy.push_model_to_hub(cfg) - preprocessor.push_to_hub(cfg.policy.repo_id) - postprocessor.push_to_hub(cfg.policy.repo_id) + if is_main_process: + logging.info("End of training") + + if cfg.policy.push_to_hub: + unwrapped_policy = accelerator.unwrap_model(policy) + unwrapped_policy.push_model_to_hub(cfg) + preprocessor.push_to_hub(cfg.policy.repo_id) + postprocessor.push_to_hub(cfg.policy.repo_id) + + # Properly clean up the distributed process group + accelerator.wait_for_everyone() + accelerator.end_training() def main(): - init_logging() train() diff --git a/src/lerobot/utils/logging_utils.py b/src/lerobot/utils/logging_utils.py index b6404e66d..c4c1f42e0 100644 --- a/src/lerobot/utils/logging_utils.py +++ b/src/lerobot/utils/logging_utils.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Callable from typing import Any from lerobot.utils.utils import format_big_number @@ -84,6 +85,7 @@ class MetricsTracker: "samples", "episodes", "epochs", + "accelerator", ] def __init__( @@ -93,6 +95,7 @@ class MetricsTracker: num_episodes: int, metrics: dict[str, AverageMeter], initial_step: int = 0, + accelerator: Callable | None = None, ): self.__dict__.update(dict.fromkeys(self.__keys__)) self._batch_size = batch_size @@ -106,6 +109,7 @@ class MetricsTracker: self.samples = self.steps * self._batch_size self.episodes = self.samples / self._avg_samples_per_ep self.epochs = self.samples / self._num_frames + self.accelerator = accelerator def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any: if name in self.__dict__: @@ -128,7 +132,7 @@ class MetricsTracker: Updates metrics that depend on 'step' for one step. 
""" self.steps += 1 - self.samples += self._batch_size + self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1) self.episodes = self.samples / self._avg_samples_per_ep self.epochs = self.samples / self._num_frames diff --git a/src/lerobot/utils/random_utils.py b/src/lerobot/utils/random_utils.py index 1bb1f0631..b34d357aa 100644 --- a/src/lerobot/utils/random_utils.py +++ b/src/lerobot/utils/random_utils.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import random -from collections.abc import Generator +from collections.abc import Callable, Generator from contextlib import contextmanager from pathlib import Path from typing import Any @@ -164,14 +164,20 @@ def set_rng_state(random_state_dict: dict[str, Any]): torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"]) -def set_seed(seed) -> None: +def set_seed(seed, accelerator: Callable | None = None) -> None: """Set seed for reproducibility.""" random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) + if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) + if accelerator: + from accelerate.utils import set_seed as _accelerate_set_seed + + _accelerate_set_seed(seed) + @contextmanager def seeded_context(seed: int) -> Generator[None, None, None]: diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py index dfcd4a6b1..4447a1fcf 100644 --- a/src/lerobot/utils/utils.py +++ b/src/lerobot/utils/utils.py @@ -27,6 +27,7 @@ from statistics import mean import numpy as np import torch +from accelerate import Accelerator from datasets.utils.logging import disable_progress_bar, enable_progress_bar @@ -110,36 +111,50 @@ def init_logging( display_pid: bool = False, console_level: str = "INFO", file_level: str = "DEBUG", + accelerator: Accelerator | None = None, ): + """Initialize logging configuration for LeRobot. + + In multi-GPU training, only the main process logs to console to avoid duplicate output. + Non-main processes have console logging suppressed but can still log to file. + + Args: + log_file: Optional file path to write logs to + display_pid: Include process ID in log messages (useful for debugging multi-process) + console_level: Logging level for console output + file_level: Logging level for file output + accelerator: Optional Accelerator instance (for multi-GPU detection) + """ + def custom_format(record: logging.LogRecord) -> str: dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S") fnameline = f"{record.pathname}:{record.lineno}" - - # NOTE: Display PID is useful for multi-process logging. 
- if display_pid: - pid_str = f"[PID: {os.getpid()}]" - message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.getMessage()}" - else: - message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.getMessage()}" - return message + pid_str = f"[PID: {os.getpid()}] " if display_pid else "" + return f"{record.levelname} {pid_str}{dt} {fnameline[-15:]:>15} {record.getMessage()}" formatter = logging.Formatter() formatter.format = custom_format logger = logging.getLogger() - logger.setLevel(logging.NOTSET) # Set the logger to the lowest level to capture all messages + logger.setLevel(logging.NOTSET) - # Remove unused default handlers - for handler in logger.handlers[:]: - logger.removeHandler(handler) + # Clear any existing handlers + logger.handlers.clear() - # Write logs to console - console_handler = logging.StreamHandler() - console_handler.setFormatter(formatter) - console_handler.setLevel(console_level.upper()) - logger.addHandler(console_handler) + # Determine if this is a non-main process in distributed training + is_main_process = accelerator.is_main_process if accelerator is not None else True + + # Console logging (main process only) + if is_main_process: + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + console_handler.setLevel(console_level.upper()) + logger.addHandler(console_handler) + else: + # Suppress console output for non-main processes + logger.addHandler(logging.NullHandler()) + logger.setLevel(logging.ERROR) - # Additionally write logs to file if log_file is not None: file_handler = logging.FileHandler(log_file) file_handler.setFormatter(formatter) diff --git a/tests/training/test_multi_gpu.py b/tests/training/test_multi_gpu.py new file mode 100644 index 000000000..bb234e2e7 --- /dev/null +++ b/tests/training/test_multi_gpu.py @@ -0,0 +1,211 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Multi-GPU Training Tests + +This module tests multi-GPU training functionality with accelerate. +These tests are designed to run on machines with 2+ GPUs and are executed +in the nightly CI workflow. + +The tests automatically generate accelerate configs and launch training +with subprocess to properly test the distributed training environment. +""" + +import os +import subprocess +import tempfile +from pathlib import Path + +import pytest +import torch + +from lerobot.datasets.lerobot_dataset import LeRobotDataset + + +def get_num_available_gpus(): + """Returns the number of available GPUs.""" + if not torch.cuda.is_available(): + return 0 + return torch.cuda.device_count() + + +def download_dataset(repo_id, episodes): + """ + Pre-download dataset to avoid race conditions in multi-GPU training. 
+ + Args: + repo_id: HuggingFace dataset repository ID + episodes: List of episode indices to download + """ + # Simply instantiating the dataset will download it + _ = LeRobotDataset(repo_id, episodes=episodes) + print(f"Dataset {repo_id} downloaded successfully") + + +def run_accelerate_training(config_args, num_processes=4, temp_dir=None): + """ + Helper function to run training with accelerate launch. + + Args: + config_args: List of config arguments to pass to lerobot_train.py + num_processes: Number of processes (GPUs) to use + temp_dir: Temporary directory for outputs + + Returns: + subprocess.CompletedProcess result + """ + + config_path = Path(temp_dir) / "accelerate_config.yaml" + + # Write YAML config + with open(config_path, "w") as f: + f.write("compute_environment: LOCAL_MACHINE\n") + f.write("distributed_type: MULTI_GPU\n") + f.write("mixed_precision: 'no'\n") + f.write(f"num_processes: {num_processes}\n") + f.write("use_cpu: false\n") + f.write("gpu_ids: all\n") + f.write("downcast_bf16: 'no'\n") + f.write("machine_rank: 0\n") + f.write("main_training_function: main\n") + f.write("num_machines: 1\n") + f.write("rdzv_backend: static\n") + f.write("same_network: true\n") + + cmd = [ + "accelerate", + "launch", + "--config_file", + str(config_path), + "-m", + "lerobot.scripts.lerobot_train", + ] + config_args + + result = subprocess.run( + cmd, + capture_output=True, + text=True, + env={**os.environ, "CUDA_VISIBLE_DEVICES": ",".join(map(str, range(num_processes)))}, + ) + + return result + + +@pytest.mark.skipif( + get_num_available_gpus() < 2, + reason="Multi-GPU tests require at least 2 GPUs", +) +class TestMultiGPUTraining: + """Test suite for multi-GPU training functionality.""" + + def test_basic_multi_gpu_training(self): + """ + Test that basic multi-GPU training runs successfully. + Verifies that the training completes without errors. + """ + # Pre-download dataset to avoid race conditions + download_dataset("lerobot/pusht", episodes=[0]) + + with tempfile.TemporaryDirectory() as temp_dir: + output_dir = Path(temp_dir) / "outputs" + + config_args = [ + "--dataset.repo_id=lerobot/pusht", + "--dataset.episodes=[0]", + "--policy.type=act", + "--policy.device=cuda", + "--policy.push_to_hub=false", + f"--output_dir={output_dir}", + "--batch_size=4", + "--steps=10", + "--eval_freq=-1", + "--log_freq=5", + "--save_freq=10", + "--seed=42", + "--num_workers=0", + ] + + result = run_accelerate_training(config_args, num_processes=4, temp_dir=temp_dir) + + # Check that training completed successfully + assert result.returncode == 0, ( + f"Multi-GPU training failed with return code {result.returncode}\n" + f"STDOUT:\n{result.stdout}\n" + f"STDERR:\n{result.stderr}" + ) + + # Verify checkpoint was saved + checkpoints_dir = output_dir / "checkpoints" + assert checkpoints_dir.exists(), "Checkpoints directory was not created" + + # Verify that training completed + assert "End of training" in result.stdout or "End of training" in result.stderr + + def test_checkpoint_saving_multi_gpu(self): + """ + Test that checkpoints are correctly saved during multi-GPU training. + Only the main process (rank 0) should save checkpoints. 
+ """ + # Pre-download dataset to avoid race conditions + download_dataset("lerobot/pusht", episodes=[0]) + + with tempfile.TemporaryDirectory() as temp_dir: + output_dir = Path(temp_dir) / "outputs" + + config_args = [ + "--dataset.repo_id=lerobot/pusht", + "--dataset.episodes=[0]", + "--policy.type=act", + "--policy.device=cuda", + "--policy.push_to_hub=false", + f"--output_dir={output_dir}", + "--batch_size=4", + "--steps=20", + "--eval_freq=-1", + "--log_freq=5", + "--save_freq=10", + "--seed=42", + "--num_workers=0", + ] + + result = run_accelerate_training(config_args, num_processes=2, temp_dir=temp_dir) + + assert result.returncode == 0, ( + f"Training failed:\nSTDOUT:\n{result.stdout}\n\nSTDERR:\n{result.stderr}" + ) + + # Verify checkpoint directory exists + checkpoints_dir = output_dir / "checkpoints" + assert checkpoints_dir.exists(), "Checkpoints directory not created" + + # Count checkpoint directories (should have checkpoint at step 10 and 20) + checkpoint_dirs = [d for d in checkpoints_dir.iterdir() if d.is_dir()] + assert len(checkpoint_dirs) >= 1, f"Expected at least 1 checkpoint, found {len(checkpoint_dirs)}" + + # Verify checkpoint contents + for checkpoint_dir in checkpoint_dirs: + # Check for model files + model_files = list(checkpoint_dir.rglob("*.safetensors")) + assert len(model_files) > 0, f"No model files in checkpoint {checkpoint_dir}" + + # Check for training state + training_state_dir = checkpoint_dir / "training_state" + assert training_state_dir.exists(), f"No training state in checkpoint {checkpoint_dir}" + + # Verify optimizer state exists + optimizer_state = training_state_dir / "optimizer_state.safetensors" + assert optimizer_state.exists(), f"No optimizer state in checkpoint {checkpoint_dir}" From 8bd0aec618413a760005c7ebbc09f1b5c9169b0e Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 16 Oct 2025 17:44:50 +0200 Subject: [PATCH 21/24] chore(ci): relax stale bot for PRs (#2222) --- .github/workflows/stale.yml | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index af91c9f58..06fc69fc4 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -27,15 +27,17 @@ env: This issue was closed because it has been stalled for 14 days with no activity. Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions. CLOSE_PR_MESSAGE: > - This PR was closed because it has been stalled for 14 days with no activity. + This PR was closed because it has been stalled for 21 days with no activity. Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions. WARN_ISSUE_MESSAGE: > This issue has been automatically marked as stale because it has not had recent activity (6 months). It will be closed if no further activity occurs. + Any change, comment or update to this issue will reset this count. Thank you for your contributions. WARN_PR_MESSAGE: > This PR has been automatically marked as stale because it has not had - recent activity (6 months). It will be closed if no further activity occurs. + recent activity (1 year). It will be closed if no further activity occurs. + Any change, comment or update to this PR will reset this count. Thank you for your contributions. 
jobs: @@ -56,10 +58,10 @@ jobs: stale-pr-label: stale exempt-issue-labels: never-stale exempt-pr-labels: never-stale - days-before-issue-stale: 180 # TODO(Steven): Will modify this to 90 after initial cleanup + days-before-issue-stale: 180 days-before-issue-close: 14 - days-before-pr-stale: 180 - days-before-pr-close: 14 + days-before-pr-stale: 365 + days-before-pr-close: 21 delete-branch: true close-issue-message: ${{ env.CLOSE_ISSUE_MESSAGE }} close-pr-message: ${{ env.CLOSE_PR_MESSAGE }} From 96c664e09f4601925d6f9f7e8c477bb6cb95ad87 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 17 Oct 2025 13:59:10 +0200 Subject: [PATCH 22/24] fix(scripts): warmup in find cameras script (#2229) --- src/lerobot/scripts/lerobot_find_cameras.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lerobot/scripts/lerobot_find_cameras.py b/src/lerobot/scripts/lerobot_find_cameras.py index e17dca805..0248a2768 100644 --- a/src/lerobot/scripts/lerobot_find_cameras.py +++ b/src/lerobot/scripts/lerobot_find_cameras.py @@ -180,7 +180,7 @@ def create_camera_instance(cam_meta: dict[str, Any]) -> dict[str, Any] | None: if instance: logger.info(f"Connecting to {cam_type} camera: {cam_id}...") - instance.connect(warmup=False) + instance.connect(warmup=True) return {"instance": instance, "meta": cam_meta} except Exception as e: logger.error(f"Failed to connect or configure {cam_type} camera {cam_id}: {e}") From 4afb2538251f28bff0f1f95203ca550accc6508b Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 17 Oct 2025 13:59:31 +0200 Subject: [PATCH 23/24] fix(dependencies): wandb > 0.22.0 uses a different version of protobuf (#2230) --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d0e03e35a..c8879ac36 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,7 @@ dependencies = [ "packaging>=24.2,<26.0", "pynput>=1.7.7,<1.9.0", "pyserial>=3.5,<4.0", - "wandb>=0.20.0,<0.23.0", + "wandb>=0.20.0,<0.22.0", # TODO: Bump dependency (compatible with protobuf) "torch>=2.2.1,<2.8.0", # TODO: Bump dependency "torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency @@ -97,7 +97,7 @@ dependencies = [ pygame-dep = ["pygame>=2.5.1,<2.7.0"] placo-dep = ["placo>=0.9.6,<0.10.0"] transformers-dep = ["transformers>=4.53.0,<5.0.0"] -grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] +grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bump dependency (compatible with wandb) # Motors feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"] From 52455d03a734bed9fd481ad3511db1379c55a653 Mon Sep 17 00:00:00 2001 From: Infinity4B <54246104+Infinity4B@users.noreply.github.com> Date: Fri, 17 Oct 2025 20:34:21 +0800 Subject: [PATCH 24/24] fix eval-related doc errors (#2183) Signed-off-by: Steven Palma Co-authored-by: Steven Palma --- README.md | 2 +- src/lerobot/scripts/lerobot_eval.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 357e62cc1..6409d931d 100644 --- a/README.md +++ b/README.md @@ -310,7 +310,7 @@ To upload these to the hub, run the following: huggingface-cli upload ${hf_user}/${repo_name} path/to/pretrained_model ``` -See [eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/eval.py) for an example of how other people may use your policy.
+See [lerobot_eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_eval.py) for an example of how other people may use your policy. ### Acknowledgment diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index aed7d32e3..0fdec9286 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -27,7 +27,7 @@ lerobot-eval \ --eval.batch_size=10 \ --eval.n_episodes=10 \ --use_amp=false \ - --device=cuda + --policy.device=cuda ``` OR, you want to evaluate a model checkpoint from the LeRobot training script for 10 episodes. @@ -38,7 +38,7 @@ lerobot-eval \ --eval.batch_size=10 \ --eval.n_episodes=10 \ --use_amp=false \ - --device=cuda + --policy.device=cuda ``` Note that in both examples, the repo/folder should contain at least `config.json` and `model.safetensors` files.