diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 03f26a792..f9fa02597 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -119,6 +119,7 @@ jobs: TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton container: image: ${{ needs.build-docker-cpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images] + options: --shm-size "16gb" credentials: username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }} password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }} @@ -158,3 +159,35 @@ jobs: run: pytest tests -vv --maxfail=10 - name: Run end-to-end tests run: make test-end-to-end + + # This job runs multi-GPU training tests with 4 GPUs + nightly-multi-gpu-tests: + name: Nightly Multi-GPU Tests + needs: [build-docker-gpu-nightly] + runs-on: + group: aws-g4dn-12xlarge # Instance with 4 GPUs + env: + HF_HOME: /home/user_lerobot/.cache/huggingface + HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot + TORCH_HOME: /home/user_lerobot/.cache/torch + TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton + CUDA_VISIBLE_DEVICES: "0,1,2,3" + container: + image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images] + options: --gpus all --shm-size "16gb" + credentials: + username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }} + password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }} + defaults: + run: + shell: bash + working-directory: /lerobot + steps: + - name: Verify GPU availability + run: | + nvidia-smi + python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')" + + - name: Run multi-GPU training tests + run: pytest tests/training/test_multi_gpu.py -vv --maxfail=3 + timeout-minutes: 10 diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index af91c9f58..06fc69fc4 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -27,15 +27,17 @@ env: This issue was closed because it has been stalled for 14 days with no activity. Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions. CLOSE_PR_MESSAGE: > - This PR was closed because it has been stalled for 14 days with no activity. + This PR was closed because it has been stalled for 21 days with no activity. Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions. WARN_ISSUE_MESSAGE: > This issue has been automatically marked as stale because it has not had recent activity (6 months). It will be closed if no further activity occurs. + Any change, comment or update to this issue will reset this count. Thank you for your contributions. WARN_PR_MESSAGE: > This PR has been automatically marked as stale because it has not had - recent activity (6 months). It will be closed if no further activity occurs. + recent activity (1 year). It will be closed if no further activity occurs. + Any change, comment or update to this PR will reset this count. Thank you for your contributions. jobs: @@ -56,10 +58,10 @@ jobs: stale-pr-label: stale exempt-issue-labels: never-stale exempt-pr-labels: never-stale - days-before-issue-stale: 180 # TODO(Steven): Will modify this to 90 after initial cleanup + days-before-issue-stale: 180 days-before-issue-close: 14 - days-before-pr-stale: 180 - days-before-pr-close: 14 + days-before-pr-stale: 365 + days-before-pr-close: 21 delete-branch: true close-issue-message: ${{ env.CLOSE_ISSUE_MESSAGE }} close-pr-message: ${{ env.CLOSE_PR_MESSAGE }} diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 369af602b..a07596728 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -72,7 +72,6 @@ post it. Look at our implementations for [datasets](./src/lerobot/datasets/), [policies](./src/lerobot/policies/), environments ([aloha](https://github.com/huggingface/gym-aloha), -[xarm](https://github.com/huggingface/gym-xarm), [pusht](https://github.com/huggingface/gym-pusht)) and follow the same api design. diff --git a/Makefile b/Makefile index fbe8a5bae..e02f02403 100644 --- a/Makefile +++ b/Makefile @@ -119,10 +119,9 @@ test-tdmpc-ete-train: --policy.type=tdmpc \ --policy.device=$(DEVICE) \ --policy.push_to_hub=false \ - --env.type=xarm \ - --env.task=XarmLift-v0 \ + --env.type=pusht \ --env.episode_length=5 \ - --dataset.repo_id=lerobot/xarm_lift_medium \ + --dataset.repo_id=lerobot/pusht_image \ --dataset.image_transforms.enable=true \ --dataset.episodes="[0]" \ --batch_size=2 \ @@ -140,9 +139,10 @@ test-tdmpc-ete-eval: lerobot-eval \ --policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \ --policy.device=$(DEVICE) \ - --env.type=xarm \ + --env.type=pusht \ --env.episode_length=5 \ - --env.task=XarmLift-v0 \ + --env.observation_height=96 \ + --env.observation_width=96 \ --eval.n_episodes=1 \ --eval.batch_size=1 diff --git a/README.md b/README.md index 357e62cc1..6409d931d 100644 --- a/README.md +++ b/README.md @@ -310,7 +310,7 @@ To upload these to the hub, run the following: huggingface-cli upload ${hf_user}/${repo_name} path/to/pretrained_model ``` -See [eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/eval.py) for an example of how other people may use your policy. +See [lerobot_eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_eval.py) for an example of how other people may use your policy. ### Acknowledgment diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 36eaea165..5e100013a 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -7,8 +7,6 @@ - sections: - local: il_robots title: Imitation Learning for Robots - - local: il_sim - title: Imitation Learning in Sim - local: cameras title: Cameras - local: integrate_hardware @@ -19,23 +17,35 @@ title: Train RL in Simulation - local: async title: Use Async Inference + - local: multi_gpu_training + title: Multi GPU training title: "Tutorials" - sections: - local: lerobot-dataset-v3 title: Using LeRobotDataset - local: porting_datasets_v3 title: Porting Large Datasets + - local: using_dataset_tools + title: Using the Dataset Tools title: "Datasets" - sections: + - local: act + title: ACT - local: smolvla title: SmolVLA - local: pi0 title: π₀ (Pi0) - local: pi05 title: π₀.₅ (Pi05) + title: "Policies" +- sections: + - local: il_sim + title: Imitation Learning in Sim - local: libero title: Using Libero - title: "Policies" + - local: metaworld + title: Using MetaWorld + title: "Simulation" - sections: - local: introduction_processors title: Introduction to Robot Processors diff --git a/docs/source/act.mdx b/docs/source/act.mdx new file mode 100644 index 000000000..e3294ca69 --- /dev/null +++ b/docs/source/act.mdx @@ -0,0 +1,92 @@ +# ACT (Action Chunking with Transformers) + +ACT is a **lightweight and efficient policy for imitation learning**, especially well-suited for fine-grained manipulation tasks. It's the **first model we recommend when you're starting out** with LeRobot due to its fast training time, low computational requirements, and strong performance. + +
+ +
+ +_Watch this tutorial from the LeRobot team to learn how ACT works: [LeRobot ACT Tutorial](https://www.youtube.com/watch?v=ft73x0LfGpM)_ + +## Model Overview + +Action Chunking with Transformers (ACT) was introduced in the paper [Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware](https://arxiv.org/abs/2304.13705) by Zhao et al. The policy was designed to enable precise, contact-rich manipulation tasks using affordable hardware and minimal demonstration data. + +### Why ACT is Great for Beginners + +ACT stands out as an excellent starting point for several reasons: + +- **Fast Training**: Trains in a few hours on a single GPU +- **Lightweight**: Only ~80M parameters, making it efficient and easy to work with +- **Data Efficient**: Often achieves high success rates with just 50 demonstrations + +### Architecture + +ACT uses a transformer-based architecture with three main components: + +1. **Vision Backbone**: ResNet-18 processes images from multiple camera viewpoints +2. **Transformer Encoder**: Synthesizes information from camera features, joint positions, and a learned latent variable +3. **Transformer Decoder**: Generates coherent action sequences using cross-attention + +The policy takes as input: + +- Multiple RGB images (e.g., from wrist cameras, front/top cameras) +- Current robot joint positions +- A latent style variable `z` (learned during training, set to zero during inference) + +And outputs a chunk of `k` future action sequences. + +## Installation Requirements + +1. Install LeRobot by following our [Installation Guide](./installation). +2. ACT is included in the base LeRobot installation, so no additional dependencies are needed! + +## Training ACT + +ACT works seamlessly with the standard LeRobot training pipeline. Here's a complete example for training ACT on your dataset: + +```bash +lerobot-train \ + --dataset.repo_id=${HF_USER}/your_dataset \ + --policy.type=act \ + --output_dir=outputs/train/act_your_dataset \ + --job_name=act_your_dataset \ + --policy.device=cuda \ + --wandb.enable=true \ + --policy.repo_id=${HF_USER}/act_policy +``` + +### Training Tips + +1. **Start with defaults**: ACT's default hyperparameters work well for most tasks +2. **Training duration**: Expect a few hours for 100k training steps on a single GPU +3. **Batch size**: Start with batch size 8 and adjust based on your GPU memory + +### Train using Google Colab + +If your local computer doesn't have a powerful GPU, you can utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act). + +## Evaluating ACT + +Once training is complete, you can evaluate your ACT policy using the `lerobot-record` command with your trained policy. This will run inference and record evaluation episodes: + +```bash +lerobot-record \ + --robot.type=so100_follower \ + --robot.port=/dev/ttyACM0 \ + --robot.id=my_robot \ + --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \ + --display_data=true \ + --dataset.repo_id=${HF_USER}/eval_act_your_dataset \ + --dataset.num_episodes=10 \ + --dataset.single_task="Your task description" \ + --policy.path=${HF_USER}/act_policy +``` diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index 91df14028..0d8fd56e5 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -513,13 +513,14 @@ from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import hw_to_dataset_features from lerobot.policies.act.modeling_act import ACTPolicy +from lerobot.policies.factory import make_pre_post_processors from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.scripts.lerobot_record import record_loop from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun -from lerobot.record import record_loop -from lerobot.policies.factory import make_processor + NUM_EPISODES = 5 FPS = 30 @@ -562,7 +563,7 @@ init_rerun(session_name="recording") # Connect the robot robot.connect() -preprocessor, postprocessor = make_processor( +preprocessor, postprocessor = make_pre_post_processors( policy_cfg=policy, pretrained_path=HF_MODEL_ID, dataset_stats=dataset.meta.stats, diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 93354c2ee..f5fd09acd 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -91,7 +91,7 @@ LeRobot provides optional extras for specific functionalities. Multiple extras c ### Simulations -Install environment packages: `aloha` ([gym-aloha](https://github.com/huggingface/gym-aloha)), `xarm` ([gym-xarm](https://github.com/huggingface/gym-xarm)), or `pusht` ([gym-pusht](https://github.com/huggingface/gym-pusht)) +Install environment packages: `aloha` ([gym-aloha](https://github.com/huggingface/gym-aloha)), or `pusht` ([gym-pusht](https://github.com/huggingface/gym-pusht)) Example: ```bash diff --git a/docs/source/integrate_hardware.mdx b/docs/source/integrate_hardware.mdx index 7e7fe0bff..ed9dc8dd5 100644 --- a/docs/source/integrate_hardware.mdx +++ b/docs/source/integrate_hardware.mdx @@ -8,7 +8,7 @@ To that end, we provide the [`Robot`](https://github.com/huggingface/lerobot/blo - Your own robot which exposes a communication interface (e.g. serial, CAN, TCP) - A way to read sensor data and send motor commands programmatically, e.g. manufacturer's SDK or API, or your own protocol implementation. -- LeRobot installed in your environment. Follow our [Installation Guide](./installation.mdx). +- LeRobot installed in your environment. Follow our [Installation Guide](./installation). ## Choose your motors @@ -65,7 +65,7 @@ class MyCoolRobotConfig(RobotConfig): ``` -[Cameras tutorial](./cameras.mdx) to understand how to detect and add your camera. +[Cameras tutorial](./cameras) to understand how to detect and add your camera. Next, we'll create our actual robot class which inherits from `Robot`. This abstract class defines a contract you must follow for your robot to be usable with the rest of the LeRobot tools. diff --git a/docs/source/introduction_processors.mdx b/docs/source/introduction_processors.mdx index 308edbb3b..6f3768615 100644 --- a/docs/source/introduction_processors.mdx +++ b/docs/source/introduction_processors.mdx @@ -297,9 +297,9 @@ LeRobot provides many registered processor steps. Here are the most commonly use ### Next Steps -- **[Implement Your Own Processor](implement_your_own_processor.mdx)** - Create custom processor steps -- **[Debug Your Pipeline](debug_processor_pipeline.mdx)** - Troubleshoot and optimize pipelines -- **[Processors for Robots and Teleoperators](processors_robots_teleop.mdx)** - Real-world integration patterns +- **[Implement Your Own Processor](./implement_your_own_processor)** - Create custom processor steps +- **[Debug Your Pipeline](./debug_processor_pipeline)** - Troubleshoot and optimize pipelines +- **[Processors for Robots and Teleoperators](./processors_robots_teleop)** - Real-world integration patterns ## Summary diff --git a/docs/source/lerobot-dataset-v3.mdx b/docs/source/lerobot-dataset-v3.mdx index cf1942fdc..3521914f2 100644 --- a/docs/source/lerobot-dataset-v3.mdx +++ b/docs/source/lerobot-dataset-v3.mdx @@ -279,3 +279,36 @@ python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id= 2e-4 +accelerate launch --num_processes=2 $(which lerobot-train) \ + --optimizer.lr=2e-4 \ + --dataset.repo_id=lerobot/pusht \ + --policy=act +``` + +**Training Steps Scaling:** + +Since the effective batch size `bs` increases with multiple GPUs (batch_size × num_gpus), you may want to reduce the number of training steps proportionally: + +```bash +# Example: 2 GPUs with effective batch size 2x larger +# Original: batch_size=8, steps=100000 +# With 2 GPUs: batch_size=8 (16 in total), steps=50000 +accelerate launch --num_processes=2 $(which lerobot-train) \ + --batch_size=8 \ + --steps=50000 \ + --dataset.repo_id=lerobot/pusht \ + --policy=act +``` + +## Notes + +- The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration. +- Training logs, checkpoints, and hub uploads are only done by the main process to avoid conflicts. Non-main processes have console logging disabled to prevent duplicate output. +- The effective batch size is `batch_size × num_gpus`. If you use 4 GPUs with `--batch_size=8`, your effective batch size is 32. +- Learning rate scheduling is handled correctly across multiple processes—LeRobot sets `step_scheduler_with_optimizer=False` to prevent accelerate from adjusting scheduler steps based on the number of processes. +- When saving or pushing models, LeRobot automatically unwraps the model from accelerate's distributed wrapper to ensure compatibility. +- WandB integration automatically initializes only on the main process, preventing multiple runs from being created. + +For more advanced configurations and troubleshooting, see the [Accelerate documentation](https://huggingface.co/docs/accelerate). If you want to learn more about how to train on a large number of GPUs, checkout this awesome guide: [Ultrascale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook). diff --git a/docs/source/phone_teleop.mdx b/docs/source/phone_teleop.mdx index 22159193c..76e3c367c 100644 --- a/docs/source/phone_teleop.mdx +++ b/docs/source/phone_teleop.mdx @@ -79,7 +79,7 @@ After running the example: - Android: after starting the script, open the printed local URL on your phone, tap Start, then press and hold Move. - iOS: open HEBI Mobile I/O first; B1 enables motion. A3 controls the gripper. -Additionally you can customize mapping or safety limits by editing the processor steps shown in the examples. You can also remap inputs (e.g., use a different analog input) or adapt the pipeline to other robots (e.g., LeKiwi) by modifying the input and kinematics steps. More about this in the [Processors for Robots and Teleoperators](./processors_robots_teleop.mdx) guide. +Additionally you can customize mapping or safety limits by editing the processor steps shown in the examples. You can also remap inputs (e.g., use a different analog input) or adapt the pipeline to other robots (e.g., LeKiwi) by modifying the input and kinematics steps. More about this in the [Processors for Robots and Teleoperators](./processors_robots_teleop) guide. - Run this example to record a dataset, which saves absolute end effector observations and actions: diff --git a/docs/source/using_dataset_tools.mdx b/docs/source/using_dataset_tools.mdx new file mode 100644 index 000000000..affca0ee5 --- /dev/null +++ b/docs/source/using_dataset_tools.mdx @@ -0,0 +1,102 @@ +# Using Dataset Tools + +This guide covers the dataset tools utilities available in LeRobot for modifying and editing existing datasets. + +## Overview + +LeRobot provides several utilities for manipulating datasets: + +1. **Delete Episodes** - Remove specific episodes from a dataset +2. **Split Dataset** - Divide a dataset into multiple smaller datasets +3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids` +4. **Add Features** - Add new features to a dataset +5. **Remove Features** - Remove features from a dataset + +The core implementation is in `lerobot.datasets.dataset_tools`. +An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`. + +## Command-Line Tool: lerobot-edit-dataset + +`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, and remove features. + +Run `lerobot-edit-dataset --help` for more information on the configuration of each operation. + +### Usage Examples + +#### Delete Episodes + +Remove specific episodes from a dataset. This is useful for filtering out undesired data. + +```bash +# Delete episodes 0, 2, and 5 (modifies original dataset) +lerobot-edit-dataset \ + --repo_id lerobot/pusht \ + --operation.type delete_episodes \ + --operation.episode_indices "[0, 2, 5]" + +# Delete episodes and save to a new dataset (preserves original dataset) +lerobot-edit-dataset \ + --repo_id lerobot/pusht \ + --new_repo_id lerobot/pusht_after_deletion \ + --operation.type delete_episodes \ + --operation.episode_indices "[0, 2, 5]" +``` + +#### Split Dataset + +Divide a dataset into multiple subsets. + +```bash +# Split by fractions (e.g. 80% train, 20% test, 20% val) +lerobot-edit-dataset \ + --repo_id lerobot/pusht \ + --operation.type split \ + --operation.splits '{"train": 0.8, "test": 0.2, "val": 0.2}' + +# Split by specific episode indices +lerobot-edit-dataset \ + --repo_id lerobot/pusht \ + --operation.type split \ + --operation.splits '{"task1": [0, 1, 2, 3], "task2": [4, 5]}' +``` + +There are no constraints on the split names, they can be determined by the user. Resulting datasets are saved under the repo id with the split name appended, e.g. `lerobot/pusht_train`, `lerobot/pusht_task1`, `lerobot/pusht_task2`. + +#### Merge Datasets + +Combine multiple datasets into a single dataset. + +```bash +# Merge train and validation splits back into one dataset +lerobot-edit-dataset \ + --repo_id lerobot/pusht_merged \ + --operation.type merge \ + --operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']" +``` + +#### Remove Features + +Remove features from a dataset. + +```bash +# Remove a camera feature +lerobot-edit-dataset \ + --repo_id lerobot/pusht \ + --operation.type remove_feature \ + --operation.feature_names "['observation.images.top']" +``` + +### Push to Hub + +Add the `--push_to_hub` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub: + +```bash +lerobot-edit-dataset \ + --repo_id lerobot/pusht \ + --new_repo_id lerobot/pusht_after_deletion \ + --operation.type delete_episodes \ + --operation.episode_indices "[0, 2, 5]" \ + --push_to_hub +``` + +There is also a tool for adding features to a dataset that is not yet covered in `lerobot-edit-dataset`. diff --git a/examples/dataset/use_dataset_tools.py b/examples/dataset/use_dataset_tools.py new file mode 100644 index 000000000..bd7c389bc --- /dev/null +++ b/examples/dataset/use_dataset_tools.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example script demonstrating dataset tools utilities. + +This script shows how to: +1. Delete episodes from a dataset +2. Split a dataset into train/val sets +3. Add/remove features +4. Merge datasets + +Usage: + python examples/dataset/use_dataset_tools.py +""" + +import numpy as np + +from lerobot.datasets.dataset_tools import ( + add_features, + delete_episodes, + merge_datasets, + modify_features, + remove_feature, + split_dataset, +) +from lerobot.datasets.lerobot_dataset import LeRobotDataset + + +def main(): + dataset = LeRobotDataset("lerobot/pusht") + + print(f"Original dataset: {dataset.meta.total_episodes} episodes, {dataset.meta.total_frames} frames") + print(f"Features: {list(dataset.meta.features.keys())}") + + print("\n1. Deleting episodes 0 and 2...") + filtered_dataset = delete_episodes(dataset, episode_indices=[0, 2], repo_id="lerobot/pusht_filtered") + print(f"Filtered dataset: {filtered_dataset.meta.total_episodes} episodes") + + print("\n2. Splitting dataset into train/val...") + splits = split_dataset( + dataset, + splits={"train": 0.8, "val": 0.2}, + ) + print(f"Train split: {splits['train'].meta.total_episodes} episodes") + print(f"Val split: {splits['val'].meta.total_episodes} episodes") + + print("\n3. Adding features...") + + reward_values = np.random.randn(dataset.meta.total_frames).astype(np.float32) + + def compute_success(row_dict, episode_index, frame_index): + episode_length = 10 + return float(frame_index >= episode_length - 10) + + dataset_with_features = add_features( + dataset, + features={ + "reward": ( + reward_values, + {"dtype": "float32", "shape": (1,), "names": None}, + ), + "success": ( + compute_success, + {"dtype": "float32", "shape": (1,), "names": None}, + ), + }, + repo_id="lerobot/pusht_with_features", + ) + + print(f"New features: {list(dataset_with_features.meta.features.keys())}") + + print("\n4. Removing the success feature...") + dataset_cleaned = remove_feature( + dataset_with_features, feature_names="success", repo_id="lerobot/pusht_cleaned" + ) + print(f"Features after removal: {list(dataset_cleaned.meta.features.keys())}") + + print("\n5. Using modify_features to add and remove features simultaneously...") + dataset_modified = modify_features( + dataset_with_features, + add_features={ + "discount": ( + np.ones(dataset.meta.total_frames, dtype=np.float32) * 0.99, + {"dtype": "float32", "shape": (1,), "names": None}, + ), + }, + remove_features="reward", + repo_id="lerobot/pusht_modified", + ) + print(f"Modified features: {list(dataset_modified.meta.features.keys())}") + + print("\n6. Merging train and val splits back together...") + merged = merge_datasets([splits["train"], splits["val"]], output_repo_id="lerobot/pusht_merged") + print(f"Merged dataset: {merged.meta.total_episodes} episodes") + + print("\n7. Complex workflow example...") + + if len(dataset.meta.camera_keys) > 1: + camera_to_remove = dataset.meta.camera_keys[0] + print(f"Removing camera: {camera_to_remove}") + dataset_no_cam = remove_feature( + dataset, feature_names=camera_to_remove, repo_id="pusht_no_first_camera" + ) + print(f"Remaining cameras: {dataset_no_cam.meta.camera_keys}") + + print("\nDone! Check ~/.cache/huggingface/lerobot/ for the created datasets.") + + +if __name__ == "__main__": + main() diff --git a/examples/lekiwi/evaluate.py b/examples/lekiwi/evaluate.py index 8a62d92a9..4501008d0 100644 --- a/examples/lekiwi/evaluate.py +++ b/examples/lekiwi/evaluate.py @@ -133,4 +133,6 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: log_say("Stop recording") robot.disconnect() listener.stop() + +dataset.finalize() dataset.push_to_hub() diff --git a/examples/lekiwi/record.py b/examples/lekiwi/record.py index 9070741bf..491e1c386 100644 --- a/examples/lekiwi/record.py +++ b/examples/lekiwi/record.py @@ -130,4 +130,6 @@ robot.disconnect() leader_arm.disconnect() keyboard.disconnect() listener.stop() + +dataset.finalize() dataset.push_to_hub() diff --git a/examples/phone_to_so100/evaluate.py b/examples/phone_to_so100/evaluate.py index 0d53f1177..ff8dbddd2 100644 --- a/examples/phone_to_so100/evaluate.py +++ b/examples/phone_to_so100/evaluate.py @@ -194,4 +194,6 @@ for episode_idx in range(NUM_EPISODES): log_say("Stop recording") robot.disconnect() listener.stop() + +dataset.finalize() dataset.push_to_hub() diff --git a/examples/phone_to_so100/record.py b/examples/phone_to_so100/record.py index d3ef293a7..880f9c9b4 100644 --- a/examples/phone_to_so100/record.py +++ b/examples/phone_to_so100/record.py @@ -200,4 +200,6 @@ log_say("Stop recording") robot.disconnect() phone.disconnect() listener.stop() + +dataset.finalize() dataset.push_to_hub() diff --git a/examples/port_datasets/port_droid.py b/examples/port_datasets/port_droid.py index 4efb131e4..a1fb50914 100644 --- a/examples/port_datasets/port_droid.py +++ b/examples/port_datasets/port_droid.py @@ -362,6 +362,8 @@ def port_droid( lerobot_dataset.save_episode() logging.info("Save_episode") + lerobot_dataset.finalize() + if push_to_hub: lerobot_dataset.push_to_hub( # Add openx tag, since it belongs to the openx collection of datasets diff --git a/examples/so100_to_so100_EE/evaluate.py b/examples/so100_to_so100_EE/evaluate.py index 53a385442..60489b3cf 100644 --- a/examples/so100_to_so100_EE/evaluate.py +++ b/examples/so100_to_so100_EE/evaluate.py @@ -195,4 +195,6 @@ for episode_idx in range(NUM_EPISODES): log_say("Stop recording") robot.disconnect() listener.stop() + +dataset.finalize() dataset.push_to_hub() diff --git a/examples/so100_to_so100_EE/record.py b/examples/so100_to_so100_EE/record.py index 9ed6e51a9..5ff1c286f 100644 --- a/examples/so100_to_so100_EE/record.py +++ b/examples/so100_to_so100_EE/record.py @@ -199,4 +199,6 @@ log_say("Stop recording") leader.disconnect() follower.disconnect() listener.stop() + +dataset.finalize() dataset.push_to_hub() diff --git a/pyproject.toml b/pyproject.toml index c67b481f0..c8879ac36 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,24 +62,26 @@ dependencies = [ "datasets>=4.0.0,<4.2.0", "diffusers>=0.27.2,<0.36.0", "huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0", + "accelerate>=1.10.0,<2.0.0", # Core dependencies + "setuptools>=71.0.0,<81.0.0", "cmake>=3.29.0.1,<4.2.0", "einops>=0.8.0,<0.9.0", "opencv-python-headless>=4.9.0,<4.13.0", - "av>=14.2.0,<16.0.0", + "av>=15.0.0,<16.0.0", "jsonlines>=4.0.0,<5.0.0", "packaging>=24.2,<26.0", "pynput>=1.7.7,<1.9.0", "pyserial>=3.5,<4.0", - "wandb>=0.20.0,<0.23.0", + "wandb>=0.20.0,<0.22.0", # TODO: Bumb dependency (compatible with protobuf) "torch>=2.2.1,<2.8.0", # TODO: Bumb dependency "torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency "torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency "draccus==0.10.0", # TODO: Remove == - "gymnasium>=0.29.1,<1.0.0", # TODO: Bumb dependency + "gymnasium>=1.0.0", "rerun-sdk>=0.21.0,<0.23.0", # TODO: Bumb dependency # Support dependencies @@ -95,7 +97,7 @@ dependencies = [ pygame-dep = ["pygame>=2.5.1,<2.7.0"] placo-dep = ["placo>=0.9.6,<0.10.0"] transformers-dep = ["transformers>=4.53.0,<5.0.0"] -grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] +grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bumb dependency (compatible with wandb) # Motors feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"] @@ -132,11 +134,10 @@ test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0 video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"] # Simulation -aloha = ["gym-aloha>=0.1.1,<0.2.0"] +aloha = ["gym-aloha>=0.1.2,<0.2.0"] pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead -xarm = ["gym-xarm>=0.1.1,<0.2.0"] libero = ["lerobot[transformers-dep]", "libero @ git+https://github.com/huggingface/lerobot-libero.git@main#egg=libero"] - +metaworld = ["metaworld>=3.0.0"] # All all = [ @@ -156,9 +157,9 @@ all = [ "lerobot[video_benchmark]", "lerobot[aloha]", "lerobot[pusht]", - "lerobot[xarm]", "lerobot[phone]", "lerobot[libero]", + "lerobot[metaworld]", ] [project.scripts] @@ -175,6 +176,7 @@ lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main" lerobot-info="lerobot.scripts.lerobot_info:main" lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main" lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main" +lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main" # ---------------- Tool Configurations ---------------- [tool.setuptools.packages.find] diff --git a/src/lerobot/__init__.py b/src/lerobot/__init__.py index 9d3ed1893..eec574296 100644 --- a/src/lerobot/__init__.py +++ b/src/lerobot/__init__.py @@ -57,7 +57,6 @@ available_tasks_per_env = { "AlohaTransferCube-v0", ], "pusht": ["PushT-v0"], - "xarm": ["XarmLift-v0"], } available_envs = list(available_tasks_per_env.keys()) @@ -75,16 +74,6 @@ available_datasets_per_env = { # TODO(alexander-soare): Add "lerobot/pusht_keypoints". Right now we can't because this is too tightly # coupled with tests. "pusht": ["lerobot/pusht", "lerobot/pusht_image"], - "xarm": [ - "lerobot/xarm_lift_medium", - "lerobot/xarm_lift_medium_replay", - "lerobot/xarm_push_medium", - "lerobot/xarm_push_medium_replay", - "lerobot/xarm_lift_medium_image", - "lerobot/xarm_lift_medium_replay_image", - "lerobot/xarm_push_medium_image", - "lerobot/xarm_push_medium_replay_image", - ], } available_real_world_datasets = [ @@ -195,7 +184,6 @@ available_motors = [ available_policies_per_env = { "aloha": ["act"], "pusht": ["diffusion", "vqbet"], - "xarm": ["tdmpc"], "koch_real": ["act_koch_real"], "aloha_real": ["act_aloha_real"], } diff --git a/src/lerobot/async_inference/configs.py b/src/lerobot/async_inference/configs.py index 24f889df1..d1768a323 100644 --- a/src/lerobot/async_inference/configs.py +++ b/src/lerobot/async_inference/configs.py @@ -142,11 +142,6 @@ class RobotClientConfig: default=False, metadata={"help": "Visualize the action queue size"} ) - # Verification configuration - verify_robot_cameras: bool = field( - default=True, metadata={"help": "Verify that the robot cameras match the policy cameras"} - ) - @property def environment_dt(self) -> float: """Environment time step, in seconds""" diff --git a/src/lerobot/async_inference/helpers.py b/src/lerobot/async_inference/helpers.py index 54fad8c54..f73cbc1da 100644 --- a/src/lerobot/async_inference/helpers.py +++ b/src/lerobot/async_inference/helpers.py @@ -62,15 +62,6 @@ def visualize_action_queue_size(action_queue_size: list[int]) -> None: plt.show() -def validate_robot_cameras_for_policy( - lerobot_observation_features: dict[str, dict], policy_image_features: dict[str, PolicyFeature] -) -> None: - image_keys = list(filter(is_image_key, lerobot_observation_features)) - assert set(image_keys) == set(policy_image_features.keys()), ( - f"Policy image features must match robot cameras! Received {list(policy_image_features.keys())} != {image_keys}" - ) - - def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]: return hw_to_dataset_features(robot.observation_features, OBS_STR, use_video=False) diff --git a/src/lerobot/async_inference/robot_client.py b/src/lerobot/async_inference/robot_client.py index 8c4425c6b..f9d70a64e 100644 --- a/src/lerobot/async_inference/robot_client.py +++ b/src/lerobot/async_inference/robot_client.py @@ -48,7 +48,6 @@ import torch from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 -from lerobot.configs.policies import PreTrainedConfig from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, @@ -76,7 +75,6 @@ from .helpers import ( TimedObservation, get_logger, map_robot_keys_to_lerobot_features, - validate_robot_cameras_for_policy, visualize_action_queue_size, ) @@ -98,14 +96,6 @@ class RobotClient: lerobot_features = map_robot_keys_to_lerobot_features(self.robot) - if config.verify_robot_cameras: - # Load policy config for validation - policy_config = PreTrainedConfig.from_pretrained(config.pretrained_name_or_path) - policy_image_features = policy_config.image_features - - # The cameras specified for inference must match the one supported by the policy chosen - validate_robot_cameras_for_policy(lerobot_features, policy_image_features) - # Use environment variable if server_address is not provided in config self.server_address = config.server_address diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py index 803645f29..870c9571e 100644 --- a/src/lerobot/datasets/aggregate.py +++ b/src/lerobot/datasets/aggregate.py @@ -31,15 +31,15 @@ from lerobot.datasets.utils import ( DEFAULT_EPISODES_PATH, DEFAULT_VIDEO_FILE_SIZE_IN_MB, DEFAULT_VIDEO_PATH, + get_file_size_in_mb, get_parquet_file_size_in_mb, - get_video_size_in_mb, to_parquet_with_hf_images, update_chunk_file_indices, write_info, write_stats, write_tasks, ) -from lerobot.datasets.video_utils import concatenate_video_files +from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]): @@ -130,10 +130,34 @@ def update_meta_data( df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"] df["data/file_index"] = df["data/file_index"] + data_idx["file"] for key, video_idx in videos_idx.items(): - df[f"videos/{key}/chunk_index"] = df[f"videos/{key}/chunk_index"] + video_idx["chunk"] - df[f"videos/{key}/file_index"] = df[f"videos/{key}/file_index"] + video_idx["file"] - df[f"videos/{key}/from_timestamp"] = df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"] - df[f"videos/{key}/to_timestamp"] = df[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"] + # Store original video file indices before updating + orig_chunk_col = f"videos/{key}/chunk_index" + orig_file_col = f"videos/{key}/file_index" + df["_orig_chunk"] = df[orig_chunk_col].copy() + df["_orig_file"] = df[orig_file_col].copy() + + # Update chunk and file indices to point to destination + df[orig_chunk_col] = video_idx["chunk"] + df[orig_file_col] = video_idx["file"] + + # Apply per-source-file timestamp offsets + src_to_offset = video_idx.get("src_to_offset", {}) + if src_to_offset: + # Apply offset based on original source file + for idx in df.index: + src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"]) + offset = src_to_offset.get(src_key, 0) + df.at[idx, f"videos/{key}/from_timestamp"] += offset + df.at[idx, f"videos/{key}/to_timestamp"] += offset + else: + # Fallback to simple offset (for backward compatibility) + df[f"videos/{key}/from_timestamp"] = ( + df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"] + ) + df[f"videos/{key}/to_timestamp"] = df[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"] + + # Clean up temporary columns + df = df.drop(columns=["_orig_chunk", "_orig_file"]) df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"] df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"] @@ -193,6 +217,10 @@ def aggregate_datasets( robot_type=robot_type, features=features, root=aggr_root, + use_videos=len(video_keys) > 0, + chunks_size=chunk_size, + data_files_size_in_mb=data_files_size_in_mb, + video_files_size_in_mb=video_files_size_in_mb, ) logging.info("Find all tasks") @@ -236,6 +264,11 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu Returns: dict: Updated videos_idx with current chunk and file indices. """ + for key in videos_idx: + videos_idx[key]["episode_duration"] = 0 + # Track offset for each source (chunk, file) pair + videos_idx[key]["src_to_offset"] = {} + for key, video_idx in videos_idx.items(): unique_chunk_file_pairs = { (chunk, file) @@ -249,6 +282,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu chunk_idx = video_idx["chunk"] file_idx = video_idx["file"] + current_offset = video_idx["latest_duration"] for src_chunk_idx, src_file_idx in unique_chunk_file_pairs: src_path = src_meta.root / DEFAULT_VIDEO_PATH.format( @@ -263,21 +297,25 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu file_index=file_idx, ) - # If a new file is created, we don't want to increment the latest_duration - update_latest_duration = False + src_duration = get_video_duration_in_s(src_path) if not dst_path.exists(): - # First write to this destination file + # Store offset before incrementing + videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset dst_path.parent.mkdir(parents=True, exist_ok=True) shutil.copy(str(src_path), str(dst_path)) - continue # not accumulating further, already copied the file in place + videos_idx[key]["episode_duration"] += src_duration + current_offset += src_duration + continue # Check file sizes before appending - src_size = get_video_size_in_mb(src_path) - dst_size = get_video_size_in_mb(dst_path) + src_size = get_file_size_in_mb(src_path) + dst_size = get_file_size_in_mb(dst_path) if dst_size + src_size >= video_files_size_in_mb: - # Rotate to a new chunk/file + # Rotate to a new file, this source becomes start of new destination + # So its offset should be 0 + videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0 chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size) dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format( video_key=key, @@ -286,25 +324,22 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu ) dst_path.parent.mkdir(parents=True, exist_ok=True) shutil.copy(str(src_path), str(dst_path)) + # Reset offset for next file + current_offset = src_duration else: - # Get the timestamps shift for this video - timestamps_shift_s = dst_meta.info["total_frames"] / dst_meta.info["fps"] - - # Append to existing video file + # Append to existing video file - use current accumulated offset + videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset concatenate_video_files( [dst_path, src_path], dst_path, ) - # Update the latest_duration when appending (shifts timestamps!) - update_latest_duration = not update_latest_duration + current_offset += src_duration + + videos_idx[key]["episode_duration"] += src_duration - # Update the videos_idx with the final chunk and file indices for this key videos_idx[key]["chunk"] = chunk_idx videos_idx[key]["file"] = file_idx - if update_latest_duration: - videos_idx[key]["latest_duration"] += timestamps_shift_s - return videos_idx @@ -389,9 +424,6 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx): videos_idx, ) - for k in videos_idx: - videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"] - meta_idx = append_or_create_parquet_file( df, src_path, @@ -403,6 +435,10 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx): aggr_root=dst_meta.root, ) + # Increment latest_duration by the total duration added from this source dataset + for k in videos_idx: + videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"] + return meta_idx diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py new file mode 100644 index 000000000..2735ba0a0 --- /dev/null +++ b/src/lerobot/datasets/dataset_tools.py @@ -0,0 +1,1089 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dataset tools utilities for LeRobotDataset. + +This module provides utilities for: +- Deleting episodes from datasets +- Splitting datasets into multiple smaller datasets +- Adding/removing features from datasets +- Merging datasets (wrapper around aggregate functionality) +""" + +import logging +import shutil +from collections.abc import Callable +from pathlib import Path + +import datasets +import numpy as np +import pandas as pd +import pyarrow.parquet as pq +import torch +from tqdm import tqdm + +from lerobot.datasets.aggregate import aggregate_datasets +from lerobot.datasets.compute_stats import aggregate_stats +from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata +from lerobot.datasets.utils import ( + DEFAULT_CHUNK_SIZE, + DEFAULT_DATA_FILE_SIZE_IN_MB, + DEFAULT_DATA_PATH, + DEFAULT_EPISODES_PATH, + get_parquet_file_size_in_mb, + load_episodes, + update_chunk_file_indices, + write_info, + write_stats, + write_tasks, +) +from lerobot.utils.constants import HF_LEROBOT_HOME + + +def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict: + """Load a single episode's metadata including stats from parquet file. + + Args: + src_dataset: Source dataset + episode_idx: Episode index to load + + Returns: + dict containing episode metadata and stats + """ + ep_meta = src_dataset.meta.episodes[episode_idx] + chunk_idx = ep_meta["meta/episodes/chunk_index"] + file_idx = ep_meta["meta/episodes/file_index"] + + parquet_path = src_dataset.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + df = pd.read_parquet(parquet_path) + + episode_row = df[df["episode_index"] == episode_idx].iloc[0] + + return episode_row.to_dict() + + +def delete_episodes( + dataset: LeRobotDataset, + episode_indices: list[int], + output_dir: str | Path | None = None, + repo_id: str | None = None, +) -> LeRobotDataset: + """Delete episodes from a LeRobotDataset and create a new dataset. + + Args: + dataset: The source LeRobotDataset. + episode_indices: List of episode indices to delete. + output_dir: Directory to save the new dataset. If None, uses default location. + repo_id: Repository ID for the new dataset. If None, appends "_modified" to original. + """ + if not episode_indices: + raise ValueError("No episodes to delete") + + valid_indices = set(range(dataset.meta.total_episodes)) + invalid = set(episode_indices) - valid_indices + if invalid: + raise ValueError(f"Invalid episode indices: {invalid}") + + logging.info(f"Deleting {len(episode_indices)} episodes from dataset") + + if repo_id is None: + repo_id = f"{dataset.repo_id}_modified" + output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id + + episodes_to_keep = [i for i in range(dataset.meta.total_episodes) if i not in episode_indices] + if not episodes_to_keep: + raise ValueError("Cannot delete all episodes from dataset") + + new_meta = LeRobotDatasetMetadata.create( + repo_id=repo_id, + fps=dataset.meta.fps, + features=dataset.meta.features, + robot_type=dataset.meta.robot_type, + root=output_dir, + use_videos=len(dataset.meta.video_keys) > 0, + ) + + episode_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(episodes_to_keep)} + + video_metadata = None + if dataset.meta.video_keys: + video_metadata = _copy_and_reindex_videos(dataset, new_meta, episode_mapping) + + data_metadata = _copy_and_reindex_data(dataset, new_meta, episode_mapping) + + _copy_and_reindex_episodes_metadata(dataset, new_meta, episode_mapping, data_metadata, video_metadata) + + new_dataset = LeRobotDataset( + repo_id=repo_id, + root=output_dir, + image_transforms=dataset.image_transforms, + delta_timestamps=dataset.delta_timestamps, + tolerance_s=dataset.tolerance_s, + ) + + logging.info(f"Created new dataset with {len(episodes_to_keep)} episodes") + return new_dataset + + +def split_dataset( + dataset: LeRobotDataset, + splits: dict[str, float | list[int]], + output_dir: str | Path | None = None, +) -> dict[str, LeRobotDataset]: + """Split a LeRobotDataset into multiple smaller datasets. + + Args: + dataset: The source LeRobotDataset to split. + splits: Either a dict mapping split names to episode indices, or a dict mapping + split names to fractions (must sum to <= 1.0). + output_dir: Base directory for output datasets. If None, uses default location. + + Examples: + Split by specific episodes + splits = {"train": [0, 1, 2], "val": [3, 4]} + datasets = split_dataset(dataset, splits) + + Split by fractions + splits = {"train": 0.8, "val": 0.2} + datasets = split_dataset(dataset, splits) + """ + if not splits: + raise ValueError("No splits provided") + + if all(isinstance(v, float) for v in splits.values()): + splits = _fractions_to_episode_indices(dataset.meta.total_episodes, splits) + + all_episodes = set() + for split_name, episodes in splits.items(): + if not episodes: + raise ValueError(f"Split '{split_name}' has no episodes") + episode_set = set(episodes) + if episode_set & all_episodes: + raise ValueError("Episodes cannot appear in multiple splits") + all_episodes.update(episode_set) + + valid_indices = set(range(dataset.meta.total_episodes)) + invalid = all_episodes - valid_indices + if invalid: + raise ValueError(f"Invalid episode indices: {invalid}") + + if output_dir is not None: + output_dir = Path(output_dir) + + result_datasets = {} + + for split_name, episodes in splits.items(): + logging.info(f"Creating split '{split_name}' with {len(episodes)} episodes") + + split_repo_id = f"{dataset.repo_id}_{split_name}" + + split_output_dir = ( + output_dir / split_name if output_dir is not None else HF_LEROBOT_HOME / split_repo_id + ) + + episode_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(sorted(episodes))} + + new_meta = LeRobotDatasetMetadata.create( + repo_id=split_repo_id, + fps=dataset.meta.fps, + features=dataset.meta.features, + robot_type=dataset.meta.robot_type, + root=split_output_dir, + use_videos=len(dataset.meta.video_keys) > 0, + chunks_size=dataset.meta.chunks_size, + data_files_size_in_mb=dataset.meta.data_files_size_in_mb, + video_files_size_in_mb=dataset.meta.video_files_size_in_mb, + ) + + video_metadata = None + if dataset.meta.video_keys: + video_metadata = _copy_and_reindex_videos(dataset, new_meta, episode_mapping) + + data_metadata = _copy_and_reindex_data(dataset, new_meta, episode_mapping) + + _copy_and_reindex_episodes_metadata(dataset, new_meta, episode_mapping, data_metadata, video_metadata) + + new_dataset = LeRobotDataset( + repo_id=split_repo_id, + root=split_output_dir, + image_transforms=dataset.image_transforms, + delta_timestamps=dataset.delta_timestamps, + tolerance_s=dataset.tolerance_s, + ) + + result_datasets[split_name] = new_dataset + + return result_datasets + + +def merge_datasets( + datasets: list[LeRobotDataset], + output_repo_id: str, + output_dir: str | Path | None = None, +) -> LeRobotDataset: + """Merge multiple LeRobotDatasets into a single dataset. + + This is a wrapper around the aggregate_datasets functionality with a cleaner API. + + Args: + datasets: List of LeRobotDatasets to merge. + output_repo_id: Repository ID for the merged dataset. + output_dir: Directory to save the merged dataset. If None, uses default location. + """ + if not datasets: + raise ValueError("No datasets to merge") + + output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / output_repo_id + + repo_ids = [ds.repo_id for ds in datasets] + roots = [ds.root for ds in datasets] + + aggregate_datasets( + repo_ids=repo_ids, + aggr_repo_id=output_repo_id, + roots=roots, + aggr_root=output_dir, + ) + + merged_dataset = LeRobotDataset( + repo_id=output_repo_id, + root=output_dir, + image_transforms=datasets[0].image_transforms, + delta_timestamps=datasets[0].delta_timestamps, + tolerance_s=datasets[0].tolerance_s, + ) + + return merged_dataset + + +def modify_features( + dataset: LeRobotDataset, + add_features: dict[str, tuple[np.ndarray | torch.Tensor | Callable, dict]] | None = None, + remove_features: str | list[str] | None = None, + output_dir: str | Path | None = None, + repo_id: str | None = None, +) -> LeRobotDataset: + """Modify a LeRobotDataset by adding and/or removing features in a single pass. + + This is the most efficient way to modify features, as it only copies the dataset once + regardless of how many features are being added or removed. + + Args: + dataset: The source LeRobotDataset. + add_features: Optional dict mapping feature names to (feature_values, feature_info) tuples. + remove_features: Optional feature name(s) to remove. Can be a single string or list. + output_dir: Directory to save the new dataset. If None, uses default location. + repo_id: Repository ID for the new dataset. If None, appends "_modified" to original. + + Returns: + New dataset with features modified. + + Example: + new_dataset = modify_features( + dataset, + add_features={ + "reward": (reward_array, {"dtype": "float32", "shape": [1], "names": None}), + }, + remove_features=["old_feature"], + output_dir="./output", + ) + """ + if add_features is None and remove_features is None: + raise ValueError("Must specify at least one of add_features or remove_features") + + remove_features_list: list[str] = [] + if remove_features is not None: + remove_features_list = [remove_features] if isinstance(remove_features, str) else remove_features + + if add_features: + required_keys = {"dtype", "shape"} + for feature_name, (_, feature_info) in add_features.items(): + if feature_name in dataset.meta.features: + raise ValueError(f"Feature '{feature_name}' already exists in dataset") + + if not required_keys.issubset(feature_info.keys()): + raise ValueError(f"feature_info for '{feature_name}' must contain keys: {required_keys}") + + if remove_features_list: + for name in remove_features_list: + if name not in dataset.meta.features: + raise ValueError(f"Feature '{name}' not found in dataset") + + required_features = {"timestamp", "frame_index", "episode_index", "index", "task_index"} + if any(name in required_features for name in remove_features_list): + raise ValueError(f"Cannot remove required features: {required_features}") + + if repo_id is None: + repo_id = f"{dataset.repo_id}_modified" + output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id + + new_features = dataset.meta.features.copy() + + if remove_features_list: + for name in remove_features_list: + new_features.pop(name, None) + + if add_features: + for feature_name, (_, feature_info) in add_features.items(): + new_features[feature_name] = feature_info + + video_keys_to_remove = [name for name in remove_features_list if name in dataset.meta.video_keys] + remaining_video_keys = [k for k in dataset.meta.video_keys if k not in video_keys_to_remove] + + new_meta = LeRobotDatasetMetadata.create( + repo_id=repo_id, + fps=dataset.meta.fps, + features=new_features, + robot_type=dataset.meta.robot_type, + root=output_dir, + use_videos=len(remaining_video_keys) > 0, + ) + + _copy_data_with_feature_changes( + dataset=dataset, + new_meta=new_meta, + add_features=add_features, + remove_features=remove_features_list if remove_features_list else None, + ) + + if new_meta.video_keys: + _copy_videos(dataset, new_meta, exclude_keys=video_keys_to_remove if video_keys_to_remove else None) + + new_dataset = LeRobotDataset( + repo_id=repo_id, + root=output_dir, + image_transforms=dataset.image_transforms, + delta_timestamps=dataset.delta_timestamps, + tolerance_s=dataset.tolerance_s, + ) + + return new_dataset + + +def add_features( + dataset: LeRobotDataset, + features: dict[str, tuple[np.ndarray | torch.Tensor | Callable, dict]], + output_dir: str | Path | None = None, + repo_id: str | None = None, +) -> LeRobotDataset: + """Add multiple features to a LeRobotDataset in a single pass. + + This is more efficient than calling add_feature() multiple times, as it only + copies the dataset once regardless of how many features are being added. + + Args: + dataset: The source LeRobotDataset. + features: Dictionary mapping feature names to (feature_values, feature_info) tuples. + output_dir: Directory to save the new dataset. If None, uses default location. + repo_id: Repository ID for the new dataset. If None, appends "_modified" to original. + + Returns: + New dataset with all features added. + + Example: + features = { + "task_embedding": (task_emb_array, {"dtype": "float32", "shape": [384], "names": None}), + "cam1_embedding": (cam1_emb_array, {"dtype": "float32", "shape": [768], "names": None}), + "cam2_embedding": (cam2_emb_array, {"dtype": "float32", "shape": [768], "names": None}), + } + new_dataset = add_features(dataset, features, output_dir="./output", repo_id="my_dataset") + """ + if not features: + raise ValueError("No features provided") + + return modify_features( + dataset=dataset, + add_features=features, + remove_features=None, + output_dir=output_dir, + repo_id=repo_id, + ) + + +def remove_feature( + dataset: LeRobotDataset, + feature_names: str | list[str], + output_dir: str | Path | None = None, + repo_id: str | None = None, +) -> LeRobotDataset: + """Remove features from a LeRobotDataset. + + Args: + dataset: The source LeRobotDataset. + feature_names: Name(s) of features to remove. Can be a single string or list. + output_dir: Directory to save the new dataset. If None, uses default location. + repo_id: Repository ID for the new dataset. If None, appends "_modified" to original. + + Returns: + New dataset with features removed. + """ + return modify_features( + dataset=dataset, + add_features=None, + remove_features=feature_names, + output_dir=output_dir, + repo_id=repo_id, + ) + + +def _fractions_to_episode_indices( + total_episodes: int, + splits: dict[str, float], +) -> dict[str, list[int]]: + """Convert split fractions to episode indices.""" + if sum(splits.values()) > 1.0: + raise ValueError("Split fractions must sum to <= 1.0") + + indices = list(range(total_episodes)) + result = {} + start_idx = 0 + + for split_name, fraction in splits.items(): + num_episodes = int(total_episodes * fraction) + if num_episodes == 0: + logging.warning(f"Split '{split_name}' has no episodes, skipping...") + continue + end_idx = start_idx + num_episodes + if split_name == list(splits.keys())[-1]: + end_idx = total_episodes + result[split_name] = indices[start_idx:end_idx] + start_idx = end_idx + + return result + + +def _copy_and_reindex_data( + src_dataset: LeRobotDataset, + dst_meta: LeRobotDatasetMetadata, + episode_mapping: dict[int, int], +) -> dict[int, dict]: + """Copy and filter data files, only modifying files with deleted episodes. + + Args: + src_dataset: Source dataset to copy from + dst_meta: Destination metadata object + episode_mapping: Mapping from old episode indices to new indices + + Returns: + dict mapping episode index to its data file metadata (chunk_index, file_index, etc.) + """ + if src_dataset.meta.episodes is None: + src_dataset.meta.episodes = load_episodes(src_dataset.meta.root) + + file_to_episodes: dict[Path, set[int]] = {} + for old_idx in episode_mapping: + file_path = src_dataset.meta.get_data_file_path(old_idx) + if file_path not in file_to_episodes: + file_to_episodes[file_path] = set() + file_to_episodes[file_path].add(old_idx) + + global_index = 0 + episode_data_metadata: dict[int, dict] = {} + + if dst_meta.tasks is None: + all_task_indices = set() + for src_path in file_to_episodes: + df = pd.read_parquet(src_dataset.root / src_path) + mask = df["episode_index"].isin(list(episode_mapping.keys())) + task_series: pd.Series = df[mask]["task_index"] + all_task_indices.update(task_series.unique().tolist()) + tasks = [src_dataset.meta.tasks.iloc[idx].name for idx in all_task_indices] + dst_meta.save_episode_tasks(list(set(tasks))) + + task_mapping = {} + for old_task_idx in range(len(src_dataset.meta.tasks)): + task_name = src_dataset.meta.tasks.iloc[old_task_idx].name + new_task_idx = dst_meta.get_task_index(task_name) + if new_task_idx is not None: + task_mapping[old_task_idx] = new_task_idx + + for src_path in tqdm(sorted(file_to_episodes.keys()), desc="Processing data files"): + df = pd.read_parquet(src_dataset.root / src_path) + + all_episodes_in_file = set(df["episode_index"].unique()) + episodes_to_keep = file_to_episodes[src_path] + + if all_episodes_in_file == episodes_to_keep: + df["episode_index"] = df["episode_index"].replace(episode_mapping) + df["index"] = range(global_index, global_index + len(df)) + df["task_index"] = df["task_index"].replace(task_mapping) + + first_ep_old_idx = min(episodes_to_keep) + src_ep = src_dataset.meta.episodes[first_ep_old_idx] + chunk_idx = src_ep["data/chunk_index"] + file_idx = src_ep["data/file_index"] + else: + mask = df["episode_index"].isin(list(episode_mapping.keys())) + df = df[mask].copy().reset_index(drop=True) + + if len(df) == 0: + continue + + df["episode_index"] = df["episode_index"].replace(episode_mapping) + df["index"] = range(global_index, global_index + len(df)) + df["task_index"] = df["task_index"].replace(task_mapping) + + first_ep_old_idx = min(episodes_to_keep) + src_ep = src_dataset.meta.episodes[first_ep_old_idx] + chunk_idx = src_ep["data/chunk_index"] + file_idx = src_ep["data/file_index"] + + dst_path = dst_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + dst_path.parent.mkdir(parents=True, exist_ok=True) + + _write_parquet(df, dst_path, dst_meta) + + for ep_old_idx in episodes_to_keep: + ep_new_idx = episode_mapping[ep_old_idx] + ep_df = df[df["episode_index"] == ep_new_idx] + episode_data_metadata[ep_new_idx] = { + "data/chunk_index": chunk_idx, + "data/file_index": file_idx, + "dataset_from_index": int(ep_df["index"].min()), + "dataset_to_index": int(ep_df["index"].max() + 1), + } + + global_index += len(df) + + return episode_data_metadata + + +def _keep_episodes_from_video_with_av( + input_path: Path, + output_path: Path, + episodes_to_keep: list[tuple[float, float]], + fps: float, + vcodec: str = "libsvtav1", + pix_fmt: str = "yuv420p", +) -> None: + """Keep only specified episodes from a video file using PyAV. + + This function decodes frames from specified time ranges and re-encodes them with + properly reset timestamps to ensure monotonic progression. + + Args: + input_path: Source video file path. + output_path: Destination video file path. + episodes_to_keep: List of (start_time, end_time) tuples for episodes to keep. + fps: Frame rate of the video. + vcodec: Video codec to use for encoding. + pix_fmt: Pixel format for output video. + """ + from fractions import Fraction + + import av + + if not episodes_to_keep: + raise ValueError("No episodes to keep") + + in_container = av.open(str(input_path)) + + # Check if video stream exists. + if not in_container.streams.video: + raise ValueError( + f"No video streams found in {input_path}. " + "The video file may be corrupted or empty. " + "Try re-downloading the dataset or checking the video file." + ) + + v_in = in_container.streams.video[0] + + out = av.open(str(output_path), mode="w") + + # Convert fps to Fraction for PyAV compatibility. + fps_fraction = Fraction(fps).limit_denominator(1000) + v_out = out.add_stream(vcodec, rate=fps_fraction) + + # PyAV type stubs don't distinguish video streams from audio/subtitle streams. + v_out.width = v_in.codec_context.width + v_out.height = v_in.codec_context.height + v_out.pix_fmt = pix_fmt + + # Set time_base to match the frame rate for proper timestamp handling. + v_out.time_base = Fraction(1, int(fps)) + + out.start_encoding() + + # Create set of (start, end) ranges for fast lookup. + # Convert to a sorted list for efficient checking. + time_ranges = sorted(episodes_to_keep) + + # Track frame index for setting PTS and current range being processed. + frame_count = 0 + range_idx = 0 + + # Read through entire video once and filter frames. + for packet in in_container.demux(v_in): + for frame in packet.decode(): + if frame is None: + continue + + # Get frame timestamp. + frame_time = float(frame.pts * frame.time_base) if frame.pts is not None else 0.0 + + # Check if frame is in any of our desired time ranges. + # Skip ranges that have already passed. + while range_idx < len(time_ranges) and frame_time >= time_ranges[range_idx][1]: + range_idx += 1 + + # If we've passed all ranges, stop processing. + if range_idx >= len(time_ranges): + break + + # Check if frame is in current range. + start_ts, end_ts = time_ranges[range_idx] + if frame_time < start_ts: + continue + + # Frame is in range - create a new frame with reset timestamps. + # We need to create a copy to avoid modifying the original. + new_frame = frame.reformat(width=v_out.width, height=v_out.height, format=v_out.pix_fmt) + new_frame.pts = frame_count + new_frame.time_base = Fraction(1, int(fps)) + + # Encode and mux the frame. + for pkt in v_out.encode(new_frame): + out.mux(pkt) + + frame_count += 1 + + # Flush encoder. + for pkt in v_out.encode(): + out.mux(pkt) + + out.close() + in_container.close() + + +def _copy_and_reindex_videos( + src_dataset: LeRobotDataset, + dst_meta: LeRobotDatasetMetadata, + episode_mapping: dict[int, int], + vcodec: str = "libsvtav1", + pix_fmt: str = "yuv420p", +) -> dict[int, dict]: + """Copy and filter video files, only re-encoding files with deleted episodes. + + For video files that only contain kept episodes, we copy them directly. + For files with mixed kept/deleted episodes, we use PyAV filters to efficiently + re-encode only the desired segments. + + Args: + src_dataset: Source dataset to copy from + dst_meta: Destination metadata object + episode_mapping: Mapping from old episode indices to new indices + + Returns: + dict mapping episode index to its video metadata (chunk_index, file_index, timestamps) + """ + if src_dataset.meta.episodes is None: + src_dataset.meta.episodes = load_episodes(src_dataset.meta.root) + + episodes_video_metadata: dict[int, dict] = {new_idx: {} for new_idx in episode_mapping.values()} + + for video_key in src_dataset.meta.video_keys: + logging.info(f"Processing videos for {video_key}") + + if dst_meta.video_path is None: + raise ValueError("Destination metadata has no video_path defined") + + file_to_episodes: dict[tuple[int, int], list[int]] = {} + for old_idx in episode_mapping: + src_ep = src_dataset.meta.episodes[old_idx] + chunk_idx = src_ep[f"videos/{video_key}/chunk_index"] + file_idx = src_ep[f"videos/{video_key}/file_index"] + file_key = (chunk_idx, file_idx) + if file_key not in file_to_episodes: + file_to_episodes[file_key] = [] + file_to_episodes[file_key].append(old_idx) + + for (src_chunk_idx, src_file_idx), episodes_in_file in tqdm( + sorted(file_to_episodes.items()), desc=f"Processing {video_key} video files" + ): + all_episodes_in_file = [ + ep_idx + for ep_idx in range(src_dataset.meta.total_episodes) + if src_dataset.meta.episodes[ep_idx].get(f"videos/{video_key}/chunk_index") == src_chunk_idx + and src_dataset.meta.episodes[ep_idx].get(f"videos/{video_key}/file_index") == src_file_idx + ] + + episodes_to_keep_set = set(episodes_in_file) + all_in_file_set = set(all_episodes_in_file) + + if all_in_file_set == episodes_to_keep_set: + assert src_dataset.meta.video_path is not None + src_video_path = src_dataset.root / src_dataset.meta.video_path.format( + video_key=video_key, chunk_index=src_chunk_idx, file_index=src_file_idx + ) + dst_video_path = dst_meta.root / dst_meta.video_path.format( + video_key=video_key, chunk_index=src_chunk_idx, file_index=src_file_idx + ) + dst_video_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy(src_video_path, dst_video_path) + + for old_idx in episodes_in_file: + new_idx = episode_mapping[old_idx] + src_ep = src_dataset.meta.episodes[old_idx] + episodes_video_metadata[new_idx][f"videos/{video_key}/chunk_index"] = src_chunk_idx + episodes_video_metadata[new_idx][f"videos/{video_key}/file_index"] = src_file_idx + episodes_video_metadata[new_idx][f"videos/{video_key}/from_timestamp"] = src_ep[ + f"videos/{video_key}/from_timestamp" + ] + episodes_video_metadata[new_idx][f"videos/{video_key}/to_timestamp"] = src_ep[ + f"videos/{video_key}/to_timestamp" + ] + else: + # Build list of time ranges to keep, in sorted order. + sorted_keep_episodes = sorted(episodes_in_file, key=lambda x: episode_mapping[x]) + episodes_to_keep_ranges: list[tuple[float, float]] = [] + + for old_idx in sorted_keep_episodes: + src_ep = src_dataset.meta.episodes[old_idx] + from_ts = src_ep[f"videos/{video_key}/from_timestamp"] + to_ts = src_ep[f"videos/{video_key}/to_timestamp"] + episodes_to_keep_ranges.append((from_ts, to_ts)) + + # Use PyAV filters to efficiently re-encode only the desired segments. + assert src_dataset.meta.video_path is not None + src_video_path = src_dataset.root / src_dataset.meta.video_path.format( + video_key=video_key, chunk_index=src_chunk_idx, file_index=src_file_idx + ) + dst_video_path = dst_meta.root / dst_meta.video_path.format( + video_key=video_key, chunk_index=src_chunk_idx, file_index=src_file_idx + ) + dst_video_path.parent.mkdir(parents=True, exist_ok=True) + + logging.info( + f"Re-encoding {video_key} (chunk {src_chunk_idx}, file {src_file_idx}) " + f"with {len(episodes_to_keep_ranges)} episodes" + ) + _keep_episodes_from_video_with_av( + src_video_path, + dst_video_path, + episodes_to_keep_ranges, + src_dataset.meta.fps, + vcodec, + pix_fmt, + ) + + cumulative_ts = 0.0 + for old_idx in sorted_keep_episodes: + new_idx = episode_mapping[old_idx] + src_ep = src_dataset.meta.episodes[old_idx] + ep_length = src_ep["length"] + ep_duration = ep_length / src_dataset.meta.fps + + episodes_video_metadata[new_idx][f"videos/{video_key}/chunk_index"] = src_chunk_idx + episodes_video_metadata[new_idx][f"videos/{video_key}/file_index"] = src_file_idx + episodes_video_metadata[new_idx][f"videos/{video_key}/from_timestamp"] = cumulative_ts + episodes_video_metadata[new_idx][f"videos/{video_key}/to_timestamp"] = ( + cumulative_ts + ep_duration + ) + + cumulative_ts += ep_duration + + return episodes_video_metadata + + +def _copy_and_reindex_episodes_metadata( + src_dataset: LeRobotDataset, + dst_meta: LeRobotDatasetMetadata, + episode_mapping: dict[int, int], + data_metadata: dict[int, dict], + video_metadata: dict[int, dict] | None = None, +) -> None: + """Copy and reindex episodes metadata using provided data and video metadata. + + Args: + src_dataset: Source dataset to copy from + dst_meta: Destination metadata object + episode_mapping: Mapping from old episode indices to new indices + data_metadata: Dict mapping new episode index to its data file metadata + video_metadata: Optional dict mapping new episode index to its video metadata + """ + from lerobot.datasets.utils import flatten_dict + + if src_dataset.meta.episodes is None: + src_dataset.meta.episodes = load_episodes(src_dataset.meta.root) + + all_stats = [] + total_frames = 0 + + for old_idx, new_idx in tqdm( + sorted(episode_mapping.items(), key=lambda x: x[1]), desc="Processing episodes metadata" + ): + src_episode_full = _load_episode_with_stats(src_dataset, old_idx) + + src_episode = src_dataset.meta.episodes[old_idx] + + episode_meta = data_metadata[new_idx].copy() + + if video_metadata and new_idx in video_metadata: + episode_meta.update(video_metadata[new_idx]) + + # Extract episode statistics from parquet metadata. + # Note (maractingi): When pandas/pyarrow serializes numpy arrays with shape (3, 1, 1) to parquet, + # they are being deserialized as nested object arrays like: + # array([array([array([0.])]), array([array([0.])]), array([array([0.])])]) + # This happens particularly with image/video statistics. We need to detect and flatten + # these nested structures back to proper (3, 1, 1) arrays so aggregate_stats can process them. + episode_stats = {} + for key in src_episode_full: + if key.startswith("stats/"): + stat_key = key.replace("stats/", "") + parts = stat_key.split("/") + if len(parts) == 2: + feature_name, stat_name = parts + if feature_name not in episode_stats: + episode_stats[feature_name] = {} + + value = src_episode_full[key] + + if feature_name in src_dataset.meta.features: + feature_dtype = src_dataset.meta.features[feature_name]["dtype"] + if feature_dtype in ["image", "video"] and stat_name != "count": + if isinstance(value, np.ndarray) and value.dtype == object: + flat_values = [] + for item in value: + while isinstance(item, np.ndarray): + item = item.flatten()[0] + flat_values.append(item) + value = np.array(flat_values, dtype=np.float64).reshape(3, 1, 1) + elif isinstance(value, np.ndarray) and value.shape == (3,): + value = value.reshape(3, 1, 1) + + episode_stats[feature_name][stat_name] = value + + all_stats.append(episode_stats) + + episode_dict = { + "episode_index": new_idx, + "tasks": src_episode["tasks"], + "length": src_episode["length"], + } + episode_dict.update(episode_meta) + episode_dict.update(flatten_dict({"stats": episode_stats})) + dst_meta._save_episode_metadata(episode_dict) + + total_frames += src_episode["length"] + + dst_meta._close_writer() + + dst_meta.info.update( + { + "total_episodes": len(episode_mapping), + "total_frames": total_frames, + "total_tasks": len(dst_meta.tasks) if dst_meta.tasks is not None else 0, + "splits": {"train": f"0:{len(episode_mapping)}"}, + } + ) + write_info(dst_meta.info, dst_meta.root) + + if not all_stats: + logging.warning("No statistics found to aggregate") + return + + logging.info(f"Aggregating statistics for {len(all_stats)} episodes") + aggregated_stats = aggregate_stats(all_stats) + filtered_stats = {k: v for k, v in aggregated_stats.items() if k in dst_meta.features} + write_stats(filtered_stats, dst_meta.root) + + +def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) -> None: + """Write DataFrame to parquet + + This ensures images are properly embedded and the file can be loaded correctly by HF datasets. + """ + from lerobot.datasets.utils import embed_images, get_hf_features_from_features + + hf_features = get_hf_features_from_features(meta.features) + ep_dataset = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=hf_features, split="train") + + if len(meta.image_keys) > 0: + ep_dataset = embed_images(ep_dataset) + + table = ep_dataset.with_format("arrow")[:] + writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy", use_dictionary=True) + writer.write_table(table) + writer.close() + + +def _save_data_chunk( + df: pd.DataFrame, + meta: LeRobotDatasetMetadata, + chunk_idx: int = 0, + file_idx: int = 0, +) -> tuple[int, int, dict[int, dict]]: + """Save a data chunk and return updated indices and episode metadata. + + Returns: + tuple: (next_chunk_idx, next_file_idx, episode_metadata_dict) + where episode_metadata_dict maps episode_index to its data file metadata + """ + path = meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + path.parent.mkdir(parents=True, exist_ok=True) + + _write_parquet(df, path, meta) + + episode_metadata = {} + for ep_idx in df["episode_index"].unique(): + ep_df = df[df["episode_index"] == ep_idx] + episode_metadata[ep_idx] = { + "data/chunk_index": chunk_idx, + "data/file_index": file_idx, + "dataset_from_index": int(ep_df["index"].min()), + "dataset_to_index": int(ep_df["index"].max() + 1), + } + + file_size = get_parquet_file_size_in_mb(path) + if file_size >= DEFAULT_DATA_FILE_SIZE_IN_MB * 0.9: + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE) + + return chunk_idx, file_idx, episode_metadata + + +def _copy_data_with_feature_changes( + dataset: LeRobotDataset, + new_meta: LeRobotDatasetMetadata, + add_features: dict[str, tuple] | None = None, + remove_features: list[str] | None = None, +) -> None: + """Copy data while adding or removing features.""" + if dataset.meta.episodes is None: + dataset.meta.episodes = load_episodes(dataset.meta.root) + + # Map file paths to episode indices to extract chunk/file indices + file_to_episodes: dict[Path, set[int]] = {} + for ep_idx in range(dataset.meta.total_episodes): + file_path = dataset.meta.get_data_file_path(ep_idx) + if file_path not in file_to_episodes: + file_to_episodes[file_path] = set() + file_to_episodes[file_path].add(ep_idx) + + frame_idx = 0 + + for src_path in tqdm(sorted(file_to_episodes.keys()), desc="Processing data files"): + df = pd.read_parquet(dataset.root / src_path).reset_index(drop=True) + + # Get chunk_idx and file_idx from the source file's first episode + episodes_in_file = file_to_episodes[src_path] + first_ep_idx = min(episodes_in_file) + src_ep = dataset.meta.episodes[first_ep_idx] + chunk_idx = src_ep["data/chunk_index"] + file_idx = src_ep["data/file_index"] + + if remove_features: + df = df.drop(columns=remove_features, errors="ignore") + + if add_features: + end_idx = frame_idx + len(df) + for feature_name, (values, _) in add_features.items(): + if callable(values): + feature_values = [] + for _, row in df.iterrows(): + ep_idx = row["episode_index"] + frame_in_ep = row["frame_index"] + value = values(row.to_dict(), ep_idx, frame_in_ep) + if isinstance(value, np.ndarray) and value.size == 1: + value = value.item() + feature_values.append(value) + df[feature_name] = feature_values + else: + feature_slice = values[frame_idx:end_idx] + if len(feature_slice.shape) > 1 and feature_slice.shape[1] == 1: + df[feature_name] = feature_slice.flatten() + else: + df[feature_name] = feature_slice + frame_idx = end_idx + + # Write using the preserved chunk_idx and file_idx from source + dst_path = new_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + dst_path.parent.mkdir(parents=True, exist_ok=True) + + _write_parquet(df, dst_path, new_meta) + + _copy_episodes_metadata_and_stats(dataset, new_meta) + + +def _copy_videos( + src_dataset: LeRobotDataset, + dst_meta: LeRobotDatasetMetadata, + exclude_keys: list[str] | None = None, +) -> None: + """Copy video files, optionally excluding certain keys.""" + if exclude_keys is None: + exclude_keys = [] + + for video_key in src_dataset.meta.video_keys: + if video_key in exclude_keys: + continue + + video_files = set() + for ep_idx in range(len(src_dataset.meta.episodes)): + try: + video_files.add(src_dataset.meta.get_video_file_path(ep_idx, video_key)) + except KeyError: + continue + + for src_path in tqdm(sorted(video_files), desc=f"Copying {video_key} videos"): + dst_path = dst_meta.root / src_path + dst_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy(src_dataset.root / src_path, dst_path) + + +def _copy_episodes_metadata_and_stats( + src_dataset: LeRobotDataset, + dst_meta: LeRobotDatasetMetadata, +) -> None: + """Copy episodes metadata and recalculate stats.""" + if src_dataset.meta.tasks is not None: + write_tasks(src_dataset.meta.tasks, dst_meta.root) + dst_meta.tasks = src_dataset.meta.tasks.copy() + + episodes_dir = src_dataset.root / "meta/episodes" + dst_episodes_dir = dst_meta.root / "meta/episodes" + if episodes_dir.exists(): + shutil.copytree(episodes_dir, dst_episodes_dir, dirs_exist_ok=True) + + dst_meta.info.update( + { + "total_episodes": src_dataset.meta.total_episodes, + "total_frames": src_dataset.meta.total_frames, + "total_tasks": src_dataset.meta.total_tasks, + "splits": src_dataset.meta.info.get("splits", {"train": f"0:{src_dataset.meta.total_episodes}"}), + } + ) + + if dst_meta.video_keys and src_dataset.meta.video_keys: + for key in dst_meta.video_keys: + if key in src_dataset.meta.features: + dst_meta.info["features"][key]["info"] = src_dataset.meta.info["features"][key].get( + "info", {} + ) + + write_info(dst_meta.info, dst_meta.root) + + if set(dst_meta.features.keys()) != set(src_dataset.meta.features.keys()): + logging.info("Recalculating dataset statistics...") + if src_dataset.meta.stats: + new_stats = {} + for key in dst_meta.features: + if key in src_dataset.meta.stats: + new_stats[key] = src_dataset.meta.stats[key] + write_stats(new_stats, dst_meta.root) + else: + if src_dataset.meta.stats: + write_stats(src_dataset.meta.stats, dst_meta.root) diff --git a/src/lerobot/datasets/image_writer.py b/src/lerobot/datasets/image_writer.py index 4a4e1ab05..ee10df6e1 100644 --- a/src/lerobot/datasets/image_writer.py +++ b/src/lerobot/datasets/image_writer.py @@ -68,7 +68,30 @@ def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) return PIL.Image.fromarray(image_array) -def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path): +def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1): + """ + Saves a NumPy array or PIL Image to a file. + + This function handles both NumPy arrays and PIL Image objects, converting + the former to a PIL Image before saving. It includes error handling for + the save operation. + + Args: + image (np.ndarray | PIL.Image.Image): The image data to save. + fpath (Path): The destination file path for the image. + compress_level (int, optional): The compression level for the saved + image, as used by PIL.Image.save(). Defaults to 1. + Refer to: https://github.com/huggingface/lerobot/pull/2135 + for more details on the default value rationale. + + Raises: + TypeError: If the input 'image' is not a NumPy array or a + PIL.Image.Image object. + + Side Effects: + Prints an error message to the console if the image writing process + fails for any reason. + """ try: if isinstance(image, np.ndarray): img = image_array_to_pil_image(image) @@ -76,7 +99,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path): img = image else: raise TypeError(f"Unsupported image type: {type(image)}") - img.save(fpath) + img.save(fpath, compress_level=compress_level) except Exception as e: print(f"Error writing image {fpath}: {e}") diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index b661b21b0..ae142c1e8 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import contextlib -import gc import logging import shutil import tempfile @@ -26,6 +25,8 @@ import numpy as np import packaging.version import pandas as pd import PIL.Image +import pyarrow as pa +import pyarrow.parquet as pq import torch import torch.utils from huggingface_hub import HfApi, snapshot_download @@ -46,13 +47,9 @@ from lerobot.datasets.utils import ( embed_images, flatten_dict, get_delta_indices, - get_hf_dataset_cache_dir, - get_hf_dataset_size_in_mb, + get_file_size_in_mb, get_hf_features_from_features, - get_parquet_file_size_in_mb, - get_parquet_num_frames, get_safe_version, - get_video_size_in_mb, hf_transform_to_torch, is_valid_version, load_episodes, @@ -60,7 +57,6 @@ from lerobot.datasets.utils import ( load_nested_dataset, load_stats, load_tasks, - to_parquet_with_hf_images, update_chunk_file_indices, validate_episode_buffer, validate_frame, @@ -90,10 +86,15 @@ class LeRobotDatasetMetadata: root: str | Path | None = None, revision: str | None = None, force_cache_sync: bool = False, + metadata_buffer_size: int = 10, ): self.repo_id = repo_id self.revision = revision if revision else CODEBASE_VERSION self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id + self.writer = None + self.latest_episode = None + self.metadata_buffer: list[dict] = [] + self.metadata_buffer_size = metadata_buffer_size try: if force_cache_sync: @@ -107,6 +108,54 @@ class LeRobotDatasetMetadata: self.pull_from_repo(allow_patterns="meta/") self.load_metadata() + def _flush_metadata_buffer(self) -> None: + """Write all buffered episode metadata to parquet file.""" + if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0: + return + + combined_dict = {} + for episode_dict in self.metadata_buffer: + for key, value in episode_dict.items(): + if key not in combined_dict: + combined_dict[key] = [] + # Extract value and serialize numpy arrays + # because PyArrow's from_pydict function doesn't support numpy arrays + val = value[0] if isinstance(value, list) else value + combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val) + + first_ep = self.metadata_buffer[0] + chunk_idx = first_ep["meta/episodes/chunk_index"][0] + file_idx = first_ep["meta/episodes/file_index"][0] + + table = pa.Table.from_pydict(combined_dict) + + if not self.writer: + path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)) + path.parent.mkdir(parents=True, exist_ok=True) + self.writer = pq.ParquetWriter( + path, schema=table.schema, compression="snappy", use_dictionary=True + ) + + self.writer.write_table(table) + + self.latest_episode = self.metadata_buffer[-1] + self.metadata_buffer.clear() + + def _close_writer(self) -> None: + """Close and cleanup the parquet writer if it exists.""" + self._flush_metadata_buffer() + + writer = getattr(self, "writer", None) + if writer is not None: + writer.close() + self.writer = None + + def __del__(self): + """ + Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor + """ + self._close_writer() + def load_metadata(self): self.info = load_info(self.root) check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) @@ -138,6 +187,12 @@ class LeRobotDatasetMetadata: return packaging.version.parse(self.info["codebase_version"]) def get_data_file_path(self, ep_index: int) -> Path: + if self.episodes is None: + self.episodes = load_episodes(self.root) + if ep_index >= len(self.episodes): + raise IndexError( + f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}" + ) ep = self.episodes[ep_index] chunk_idx = ep["data/chunk_index"] file_idx = ep["data/file_index"] @@ -145,6 +200,12 @@ class LeRobotDatasetMetadata: return Path(fpath) def get_video_file_path(self, ep_index: int, vid_key: str) -> Path: + if self.episodes is None: + self.episodes = load_episodes(self.root) + if ep_index >= len(self.episodes): + raise IndexError( + f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}" + ) ep = self.episodes[ep_index] chunk_idx = ep[f"videos/{vid_key}/chunk_index"] file_idx = ep[f"videos/{vid_key}/file_index"] @@ -260,72 +321,75 @@ class LeRobotDatasetMetadata: write_tasks(self.tasks, self.root) def _save_episode_metadata(self, episode_dict: dict) -> None: - """Save episode metadata to a parquet file and update the Hugging Face dataset of episodes metadata. + """Buffer episode metadata and write to parquet in batches for efficiency. - This function processes episodes metadata from a dictionary, converts it into a Hugging Face dataset, - and saves it as a parquet file. It handles both the creation of new parquet files and the - updating of existing ones based on size constraints. After saving the metadata, it reloads - the Hugging Face dataset to ensure it is up-to-date. + This function accumulates episode metadata in a buffer and flushes it when the buffer + reaches the configured size. This reduces I/O overhead by writing multiple episodes + at once instead of one row at a time. Notes: We both need to update parquet files and HF dataset: - `pandas` loads parquet file in RAM - `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk, or loads directly from pyarrow cache. """ - # Convert buffer into HF Dataset + # Convert to list format for each value episode_dict = {key: [value] for key, value in episode_dict.items()} - ep_dataset = datasets.Dataset.from_dict(episode_dict) - ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset) - df = pd.DataFrame(ep_dataset) num_frames = episode_dict["length"][0] - if self.episodes is None: + if self.latest_episode is None: # Initialize indices and frame count for a new dataset made of the first episode data chunk_idx, file_idx = 0, 0 - df["meta/episodes/chunk_index"] = [chunk_idx] - df["meta/episodes/file_index"] = [file_idx] - df["dataset_from_index"] = [0] - df["dataset_to_index"] = [num_frames] - else: - # Retrieve information from the latest parquet file - latest_ep = self.episodes[-1] - chunk_idx = latest_ep["meta/episodes/chunk_index"] - file_idx = latest_ep["meta/episodes/file_index"] + if self.episodes is not None and len(self.episodes) > 0: + # It means we are resuming recording, so we need to load the latest episode + # Update the indices to avoid overwriting the latest episode + chunk_idx = self.episodes[-1]["meta/episodes/chunk_index"] + file_idx = self.episodes[-1]["meta/episodes/file_index"] + latest_num_frames = self.episodes[-1]["dataset_to_index"] + episode_dict["dataset_from_index"] = [latest_num_frames] + episode_dict["dataset_to_index"] = [latest_num_frames + num_frames] - latest_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) - latest_size_in_mb = get_parquet_file_size_in_mb(latest_path) - - if latest_size_in_mb + ep_size_in_mb >= self.data_files_size_in_mb: - # Size limit is reached, prepare new parquet file + # When resuming, move to the next file chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size) + else: + episode_dict["dataset_from_index"] = [0] + episode_dict["dataset_to_index"] = [num_frames] + + episode_dict["meta/episodes/chunk_index"] = [chunk_idx] + episode_dict["meta/episodes/file_index"] = [file_idx] + else: + chunk_idx = self.latest_episode["meta/episodes/chunk_index"][0] + file_idx = self.latest_episode["meta/episodes/file_index"][0] + + latest_path = ( + self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + if self.writer is None + else self.writer.where + ) + + if Path(latest_path).exists(): + latest_size_in_mb = get_file_size_in_mb(Path(latest_path)) + latest_num_frames = self.latest_episode["episode_index"][0] + + av_size_per_frame = latest_size_in_mb / latest_num_frames if latest_num_frames > 0 else 0.0 + + if latest_size_in_mb + av_size_per_frame * num_frames >= self.data_files_size_in_mb: + # Size limit is reached, flush buffer and prepare new parquet file + self._flush_metadata_buffer() + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size) + self._close_writer() # Update the existing pandas dataframe with new row - df["meta/episodes/chunk_index"] = [chunk_idx] - df["meta/episodes/file_index"] = [file_idx] - df["dataset_from_index"] = [latest_ep["dataset_to_index"]] - df["dataset_to_index"] = [latest_ep["dataset_to_index"] + num_frames] + episode_dict["meta/episodes/chunk_index"] = [chunk_idx] + episode_dict["meta/episodes/file_index"] = [file_idx] + episode_dict["dataset_from_index"] = [self.latest_episode["dataset_to_index"][0]] + episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames] - if latest_size_in_mb + ep_size_in_mb < self.data_files_size_in_mb: - # Size limit wasnt reached, concatenate latest dataframe with new one - latest_df = pd.read_parquet(latest_path) - df = pd.concat([latest_df, df], ignore_index=True) + # Add to buffer + self.metadata_buffer.append(episode_dict) + self.latest_episode = episode_dict - # Memort optimization - del latest_df - gc.collect() - - # Write the resulting dataframe from RAM to disk - path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) - path.parent.mkdir(parents=True, exist_ok=True) - df.to_parquet(path, index=False) - - if self.episodes is not None: - # Remove the episodes cache directory, necessary to avoid cache bloat - cached_dir = get_hf_dataset_cache_dir(self.episodes) - if cached_dir is not None: - shutil.rmtree(cached_dir) - - self.episodes = load_episodes(self.root) + if len(self.metadata_buffer) >= self.metadata_buffer_size: + self._flush_metadata_buffer() def save_episode( self, @@ -438,6 +502,10 @@ class LeRobotDatasetMetadata: robot_type: str | None = None, root: str | Path | None = None, use_videos: bool = True, + metadata_buffer_size: int = 10, + chunks_size: int | None = None, + data_files_size_in_mb: int | None = None, + video_files_size_in_mb: int | None = None, ) -> "LeRobotDatasetMetadata": """Creates metadata for a LeRobotDataset.""" obj = cls.__new__(cls) @@ -452,11 +520,24 @@ class LeRobotDatasetMetadata: obj.tasks = None obj.episodes = None obj.stats = None - obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, features, use_videos, robot_type) + obj.info = create_empty_dataset_info( + CODEBASE_VERSION, + fps, + features, + use_videos, + robot_type, + chunks_size, + data_files_size_in_mb, + video_files_size_in_mb, + ) if len(obj.video_keys) > 0 and not use_videos: raise ValueError() write_json(obj.info, obj.root / INFO_PATH) obj.revision = None + obj.writer = None + obj.latest_episode = None + obj.metadata_buffer = [] + obj.metadata_buffer_size = metadata_buffer_size return obj @@ -603,6 +684,8 @@ class LeRobotDataset(torch.utils.data.Dataset): # Unused attributes self.image_writer = None self.episode_buffer = None + self.writer = None + self.latest_episode = None self.root.mkdir(exist_ok=True, parents=True) @@ -611,6 +694,11 @@ class LeRobotDataset(torch.utils.data.Dataset): self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync ) + # Track dataset state for efficient incremental writing + self._lazy_loading = False + self._recorded_frames = self.meta.total_frames + self._writer_closed_for_reading = False + # Load actual data try: if force_cache_sync: @@ -629,6 +717,19 @@ class LeRobotDataset(torch.utils.data.Dataset): check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s) self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps) + def _close_writer(self) -> None: + """Close and cleanup the parquet writer if it exists.""" + writer = getattr(self, "writer", None) + if writer is not None: + writer.close() + self.writer = None + + def __del__(self): + """ + Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor + """ + self._close_writer() + def push_to_hub( self, branch: str | None = None, @@ -769,8 +870,15 @@ class LeRobotDataset(torch.utils.data.Dataset): @property def num_frames(self) -> int: - """Number of frames in selected episodes.""" - return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames + """Number of frames in selected episodes. + + Note: When episodes a subset of the full dataset is requested, we must return the + actual loaded data length (len(self.hf_dataset)) rather than metadata total_frames. + self.meta.total_frames is the total number of frames in the full dataset. + """ + if self.episodes is not None and self.hf_dataset is not None: + return len(self.hf_dataset) + return self.meta.total_frames @property def num_episodes(self) -> int: @@ -848,10 +956,22 @@ class LeRobotDataset(torch.utils.data.Dataset): return item + def _ensure_hf_dataset_loaded(self): + """Lazy load the HF dataset only when needed for reading.""" + if self._lazy_loading or self.hf_dataset is None: + # Close the writer before loading to ensure parquet file is properly finalized + if self.writer is not None: + self._close_writer() + self._writer_closed_for_reading = True + self.hf_dataset = self.load_hf_dataset() + self._lazy_loading = False + def __len__(self): return self.num_frames def __getitem__(self, idx) -> dict: + # Ensure dataset is loaded when we actually need to read from it + self._ensure_hf_dataset_loaded() item = self.hf_dataset[idx] ep_idx = item["episode_index"].item() @@ -890,6 +1010,14 @@ class LeRobotDataset(torch.utils.data.Dataset): "})',\n" ) + def finalize(self): + """ + Close the parquet writers. This function needs to be called after data collection/conversion, else footer metadata won't be written to the parquet files. + The dataset won't be valid and can't be loaded as ds = LeRobotDataset(repo_id=repo, root=HF_LEROBOT_HOME.joinpath(repo)) + """ + self._close_writer() + self.meta._close_writer() + def create_episode_buffer(self, episode_index: int | None = None) -> dict: current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index ep_buffer = {} @@ -1097,74 +1225,101 @@ class LeRobotDataset(torch.utils.data.Dataset): ep_dict = {key: episode_buffer[key] for key in self.hf_features} ep_dataset = datasets.Dataset.from_dict(ep_dict, features=self.hf_features, split="train") ep_dataset = embed_images(ep_dataset) - ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset) ep_num_frames = len(ep_dataset) - df = pd.DataFrame(ep_dataset) - if self.meta.episodes is None: + if self.latest_episode is None: # Initialize indices and frame count for a new dataset made of the first episode data chunk_idx, file_idx = 0, 0 - latest_num_frames = 0 + global_frame_index = 0 + # However, if the episodes already exists + # It means we are resuming recording, so we need to load the latest episode + # Update the indices to avoid overwriting the latest episode + if self.meta.episodes is not None and len(self.meta.episodes) > 0: + latest_ep = self.meta.episodes[-1] + global_frame_index = latest_ep["dataset_to_index"] + chunk_idx = latest_ep["data/chunk_index"] + file_idx = latest_ep["data/file_index"] + + # When resuming, move to the next file + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size) else: # Retrieve information from the latest parquet file - latest_ep = self.meta.episodes[-1] + latest_ep = self.latest_episode chunk_idx = latest_ep["data/chunk_index"] file_idx = latest_ep["data/file_index"] + global_frame_index = latest_ep["index"][-1] + 1 latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx) - latest_size_in_mb = get_parquet_file_size_in_mb(latest_path) - latest_num_frames = get_parquet_num_frames(latest_path) + latest_size_in_mb = get_file_size_in_mb(latest_path) + + frames_in_current_file = global_frame_index - latest_ep["dataset_from_index"] + av_size_per_frame = ( + latest_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0 + ) # Determine if a new parquet file is needed - if latest_size_in_mb + ep_size_in_mb >= self.meta.data_files_size_in_mb: - # Size limit is reached, prepare new parquet file + if ( + latest_size_in_mb + av_size_per_frame * ep_num_frames >= self.meta.data_files_size_in_mb + or self._writer_closed_for_reading + ): + # Size limit is reached or writer was closed for reading, prepare new parquet file chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size) - latest_num_frames = 0 - else: - # Update the existing parquet file with new rows - latest_df = pd.read_parquet(latest_path) - df = pd.concat([latest_df, df], ignore_index=True) + self._close_writer() + self._writer_closed_for_reading = False - # Memort optimization - del latest_df - gc.collect() + ep_dict["data/chunk_index"] = chunk_idx + ep_dict["data/file_index"] = file_idx # Write the resulting dataframe from RAM to disk path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx) path.parent.mkdir(parents=True, exist_ok=True) - if len(self.meta.image_keys) > 0: - to_parquet_with_hf_images(df, path) - else: - df.to_parquet(path) - if self.hf_dataset is not None: - # Remove hf dataset cache directory, necessary to avoid cache bloat - cached_dir = get_hf_dataset_cache_dir(self.hf_dataset) - if cached_dir is not None: - shutil.rmtree(cached_dir) - - self.hf_dataset = self.load_hf_dataset() + table = ep_dataset.with_format("arrow")[:] + if not self.writer: + self.writer = pq.ParquetWriter( + path, schema=table.schema, compression="snappy", use_dictionary=True + ) + self.writer.write_table(table) metadata = { "data/chunk_index": chunk_idx, "data/file_index": file_idx, - "dataset_from_index": latest_num_frames, - "dataset_to_index": latest_num_frames + ep_num_frames, + "dataset_from_index": global_frame_index, + "dataset_to_index": global_frame_index + ep_num_frames, } + + # Store metadata with episode data for next episode + self.latest_episode = {**ep_dict, **metadata} + + # Mark that the HF dataset needs reloading (lazy loading approach) + # This avoids expensive reloading during sequential recording + self._lazy_loading = True + # Update recorded frames count for efficient length tracking + self._recorded_frames += ep_num_frames + return metadata def _save_episode_video(self, video_key: str, episode_index: int) -> dict: # Encode episode frames into a temporary video ep_path = self._encode_temporary_episode_video(video_key, episode_index) - ep_size_in_mb = get_video_size_in_mb(ep_path) + ep_size_in_mb = get_file_size_in_mb(ep_path) ep_duration_in_s = get_video_duration_in_s(ep_path) - if self.meta.episodes is None or ( - f"videos/{video_key}/chunk_index" not in self.meta.episodes.column_names - or f"videos/{video_key}/file_index" not in self.meta.episodes.column_names + if ( + episode_index == 0 + or self.meta.latest_episode is None + or f"videos/{video_key}/chunk_index" not in self.meta.latest_episode ): # Initialize indices for a new dataset made of the first episode data chunk_idx, file_idx = 0, 0 + if self.meta.episodes is not None and len(self.meta.episodes) > 0: + # It means we are resuming recording, so we need to load the latest episode + # Update the indices to avoid overwriting the latest episode + old_chunk_idx = self.meta.episodes[-1][f"videos/{video_key}/chunk_index"] + old_file_idx = self.meta.episodes[-1][f"videos/{video_key}/file_index"] + chunk_idx, file_idx = update_chunk_file_indices( + old_chunk_idx, old_file_idx, self.meta.chunks_size + ) latest_duration_in_s = 0.0 new_path = self.root / self.meta.video_path.format( video_key=video_key, chunk_index=chunk_idx, file_index=file_idx @@ -1172,16 +1327,16 @@ class LeRobotDataset(torch.utils.data.Dataset): new_path.parent.mkdir(parents=True, exist_ok=True) shutil.move(str(ep_path), str(new_path)) else: - # Retrieve information from the latest updated video file (possibly several episodes ago) - latest_ep = self.meta.episodes[episode_index - 1] - chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"] - file_idx = latest_ep[f"videos/{video_key}/file_index"] + # Retrieve information from the latest updated video file using latest_episode + latest_ep = self.meta.latest_episode + chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"][0] + file_idx = latest_ep[f"videos/{video_key}/file_index"][0] latest_path = self.root / self.meta.video_path.format( video_key=video_key, chunk_index=chunk_idx, file_index=file_idx ) - latest_size_in_mb = get_video_size_in_mb(latest_path) - latest_duration_in_s = get_video_duration_in_s(latest_path) + latest_size_in_mb = get_file_size_in_mb(latest_path) + latest_duration_in_s = latest_ep[f"videos/{video_key}/to_timestamp"][0] if latest_size_in_mb + ep_size_in_mb >= self.meta.video_files_size_in_mb: # Move temporary episode video to a new video file in the dataset @@ -1315,6 +1470,12 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.delta_timestamps = None obj.delta_indices = None obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec() + obj.writer = None + obj.latest_episode = None + # Initialize tracking for incremental recording + obj._lazy_loading = False + obj._recorded_frames = 0 + obj._writer_closed_for_reading = False return obj diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index a2f285014..37d8432b2 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -30,7 +30,7 @@ import pandas import pandas as pd import pyarrow.parquet as pq import torch -from datasets import Dataset, concatenate_datasets +from datasets import Dataset from datasets.table import embed_table_storage from huggingface_hub import DatasetCard, DatasetCardData, HfApi from huggingface_hub.errors import RevisionNotFoundError @@ -44,7 +44,7 @@ from lerobot.datasets.backward_compatibility import ( ForwardCompatibilityError, ) from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR -from lerobot.utils.utils import is_valid_numpy_dtype_string +from lerobot.utils.utils import SuppressProgressBars, is_valid_numpy_dtype_string DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file @@ -94,12 +94,6 @@ def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int: return hf_ds.data.nbytes // (1024**2) -def get_hf_dataset_cache_dir(hf_ds: Dataset) -> Path | None: - if hf_ds.cache_files is None or len(hf_ds.cache_files) == 0: - return None - return Path(hf_ds.cache_files[0]["filename"]).parents[2] - - def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]: if file_idx == chunks_size - 1: file_idx = 0 @@ -123,8 +117,9 @@ def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None) raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}") # TODO(rcadene): set num_proc to accelerate conversion to pyarrow - datasets = [Dataset.from_parquet(str(path), features=features) for path in paths] - return concatenate_datasets(datasets) + with SuppressProgressBars(): + datasets = Dataset.from_parquet([str(path) for path in paths], features=features) + return datasets def get_parquet_num_frames(parquet_path: str | Path) -> int: @@ -132,10 +127,14 @@ def get_parquet_num_frames(parquet_path: str | Path) -> int: return metadata.num_rows -def get_video_size_in_mb(mp4_path: Path) -> float: - file_size_bytes = mp4_path.stat().st_size - file_size_mb = file_size_bytes / (1024**2) - return file_size_mb +def get_file_size_in_mb(file_path: Path) -> float: + """Get file size on disk in megabytes. + + Args: + file_path (Path): Path to the file. + """ + file_size_bytes = file_path.stat().st_size + return file_size_bytes / (1024**2) def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict: diff --git a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py index 42ab2f642..b8ae29ad6 100644 --- a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py +++ b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py @@ -69,9 +69,9 @@ from lerobot.datasets.utils import ( LEGACY_TASKS_PATH, cast_stats_to_numpy, flatten_dict, + get_file_size_in_mb, get_parquet_file_size_in_mb, get_parquet_num_frames, - get_video_size_in_mb, load_info, update_chunk_file_indices, write_episodes, @@ -310,7 +310,7 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f episodes_metadata = [] for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"): - ep_size_in_mb = get_video_size_in_mb(ep_path) + ep_size_in_mb = get_file_size_in_mb(ep_path) ep_duration_in_s = get_video_duration_in_s(ep_path) # Check if adding this episode would exceed the limit diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 1d4f07c76..740cdb602 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -452,6 +452,9 @@ def concatenate_video_files( template=input_stream, opaque=True ) + # set the time base to the input stream time base (missing in the codec context) + stream_map[input_stream.index].time_base = input_stream.time_base + # Demux + remux packets (no re-encode) for packet in input_container.demux(): # Skip packets from un-mapped streams @@ -639,6 +642,9 @@ class VideoEncodingManager: ) self.dataset._batch_save_episode_video(start_ep, end_ep) + # Finalize the dataset to properly close all writers + self.dataset.finalize() + # Clean up episode images if recording was interrupted if exc_type is not None: interrupted_episode_index = self.dataset.num_episodes diff --git a/src/lerobot/envs/__init__.py b/src/lerobot/envs/__init__.py index 4977d11d9..d767b6e8c 100644 --- a/src/lerobot/envs/__init__.py +++ b/src/lerobot/envs/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv # noqa: F401 +from .configs import AlohaEnv, EnvConfig, PushtEnv # noqa: F401 diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index 0daaaf9fd..3aa155093 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -50,6 +50,8 @@ class AlohaEnv(EnvConfig): fps: int = 50 episode_length: int = 400 obs_type: str = "pixels_agent_pos" + observation_height: int = 480 + observation_width: int = 640 render_mode: str = "rgb_array" features: dict[str, PolicyFeature] = field( default_factory=lambda: { @@ -67,10 +69,14 @@ class AlohaEnv(EnvConfig): def __post_init__(self): if self.obs_type == "pixels": - self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3)) + self.features["top"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) + ) elif self.obs_type == "pixels_agent_pos": self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,)) - self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3)) + self.features["pixels/top"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) + ) @property def gym_kwargs(self) -> dict: @@ -91,6 +97,8 @@ class PushtEnv(EnvConfig): render_mode: str = "rgb_array" visualization_width: int = 384 visualization_height: int = 384 + observation_height: int = 384 + observation_width: int = 384 features: dict[str, PolicyFeature] = field( default_factory=lambda: { ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)), @@ -108,7 +116,9 @@ class PushtEnv(EnvConfig): def __post_init__(self): if self.obs_type == "pixels_agent_pos": - self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3)) + self.features["pixels"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) + ) elif self.obs_type == "environment_state_agent_pos": self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,)) @@ -123,45 +133,6 @@ class PushtEnv(EnvConfig): } -@EnvConfig.register_subclass("xarm") -@dataclass -class XarmEnv(EnvConfig): - task: str | None = "XarmLift-v0" - fps: int = 15 - episode_length: int = 200 - obs_type: str = "pixels_agent_pos" - render_mode: str = "rgb_array" - visualization_width: int = 384 - visualization_height: int = 384 - features: dict[str, PolicyFeature] = field( - default_factory=lambda: { - ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,)), - "pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)), - } - ) - features_map: dict[str, str] = field( - default_factory=lambda: { - ACTION: ACTION, - "agent_pos": OBS_STATE, - "pixels": OBS_IMAGE, - } - ) - - def __post_init__(self): - if self.obs_type == "pixels_agent_pos": - self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,)) - - @property - def gym_kwargs(self) -> dict: - return { - "obs_type": self.obs_type, - "render_mode": self.render_mode, - "visualization_width": self.visualization_width, - "visualization_height": self.visualization_height, - "max_episode_steps": self.episode_length, - } - - @dataclass class ImagePreprocessingConfig: crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None @@ -255,6 +226,8 @@ class LiberoEnv(EnvConfig): camera_name: str = "agentview_image,robot0_eye_in_hand_image" init_states: bool = True camera_name_mapping: dict[str, str] | None = None + observation_height: int = 360 + observation_width: int = 360 features: dict[str, PolicyFeature] = field( default_factory=lambda: { ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)), @@ -272,18 +245,18 @@ class LiberoEnv(EnvConfig): def __post_init__(self): if self.obs_type == "pixels": self.features["pixels/agentview_image"] = PolicyFeature( - type=FeatureType.VISUAL, shape=(360, 360, 3) + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) ) self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature( - type=FeatureType.VISUAL, shape=(360, 360, 3) + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) ) elif self.obs_type == "pixels_agent_pos": self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(8,)) self.features["pixels/agentview_image"] = PolicyFeature( - type=FeatureType.VISUAL, shape=(360, 360, 3) + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) ) self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature( - type=FeatureType.VISUAL, shape=(360, 360, 3) + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) ) else: raise ValueError(f"Unsupported obs_type: {self.obs_type}") @@ -294,3 +267,45 @@ class LiberoEnv(EnvConfig): "obs_type": self.obs_type, "render_mode": self.render_mode, } + + +@EnvConfig.register_subclass("metaworld") +@dataclass +class MetaworldEnv(EnvConfig): + task: str = "metaworld-push-v2" # add all tasks + fps: int = 80 + episode_length: int = 400 + obs_type: str = "pixels_agent_pos" + render_mode: str = "rgb_array" + multitask_eval: bool = True + features: dict[str, PolicyFeature] = field( + default_factory=lambda: { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)), + } + ) + features_map: dict[str, str] = field( + default_factory=lambda: { + "action": ACTION, + "agent_pos": OBS_STATE, + "top": f"{OBS_IMAGE}", + "pixels/top": f"{OBS_IMAGE}", + } + ) + + def __post_init__(self): + if self.obs_type == "pixels": + self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 480, 3)) + + elif self.obs_type == "pixels_agent_pos": + self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,)) + self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 480, 3)) + + else: + raise ValueError(f"Unsupported obs_type: {self.obs_type}") + + @property + def gym_kwargs(self) -> dict: + return { + "obs_type": self.obs_type, + "render_mode": self.render_mode, + } diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py index c27f01b65..059e0e11a 100644 --- a/src/lerobot/envs/factory.py +++ b/src/lerobot/envs/factory.py @@ -17,7 +17,7 @@ import importlib import gymnasium as gym -from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv, XarmEnv +from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv def make_env_config(env_type: str, **kwargs) -> EnvConfig: @@ -25,8 +25,6 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig: return AlohaEnv(**kwargs) elif env_type == "pusht": return PushtEnv(**kwargs) - elif env_type == "xarm": - return XarmEnv(**kwargs) elif env_type == "libero": return LiberoEnv(**kwargs) else: @@ -74,7 +72,18 @@ def make_env( gym_kwargs=cfg.gym_kwargs, env_cls=env_cls, ) + elif "metaworld" in cfg.type: + from lerobot.envs.metaworld import create_metaworld_envs + if cfg.task is None: + raise ValueError("MetaWorld requires a task to be specified") + + return create_metaworld_envs( + task=cfg.task, + n_envs=n_envs, + gym_kwargs=cfg.gym_kwargs, + env_cls=env_cls, + ) package_name = f"gym_{cfg.type}" try: importlib.import_module(package_name) @@ -87,7 +96,7 @@ def make_env( def _make_one(): return gym.make(gym_handle, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {})) - vec = env_cls([_make_one for _ in range(n_envs)]) + vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP) # normalize to {suite: {task_id: vec_env}} for consistency suite_name = cfg.type # e.g., "pusht", "aloha" diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py index 99ec6712f..94b08e991 100644 --- a/src/lerobot/envs/libero.py +++ b/src/lerobot/envs/libero.py @@ -260,19 +260,23 @@ class LiberoEnv(gym.Env): is_success = self._env.check_success() terminated = done or is_success - info["is_success"] = is_success - + info.update( + { + "task": self.task, + "task_id": self.task_id, + "done": done, + "is_success": is_success, + } + ) observation = self._format_raw_obs(raw_obs) - if done: + if terminated: + info["final_info"] = { + "task": self.task, + "task_id": self.task_id, + "done": bool(done), + "is_success": bool(is_success), + } self.reset() - info.update( - { - "task": self.task, - "task_id": self.task_id, - "done": done, - "is_success": is_success, - } - ) truncated = False return observation, reward, terminated, truncated, info diff --git a/src/lerobot/envs/metaworld.py b/src/lerobot/envs/metaworld.py new file mode 100644 index 000000000..9190f33ad --- /dev/null +++ b/src/lerobot/envs/metaworld.py @@ -0,0 +1,313 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from collections import defaultdict +from collections.abc import Callable, Sequence +from pathlib import Path +from typing import Any + +import gymnasium as gym +import metaworld +import metaworld.policies as policies +import numpy as np +from gymnasium import spaces + +# ---- Load configuration data from the external JSON file ---- +CONFIG_PATH = Path(__file__).parent / "metaworld_config.json" +try: + with open(CONFIG_PATH) as f: + data = json.load(f) +except FileNotFoundError as err: + raise FileNotFoundError( + "Could not find 'metaworld_config.json'. " + "Please ensure the configuration file is in the same directory as the script." + ) from err +except json.JSONDecodeError as err: + raise ValueError( + "Failed to decode 'metaworld_config.json'. Please ensure it is a valid JSON file." + ) from err + +# ---- Process the loaded data ---- + +# extract and type-check top-level dicts +task_descriptions_obj = data.get("TASK_DESCRIPTIONS") +if not isinstance(task_descriptions_obj, dict): + raise TypeError("Expected TASK_DESCRIPTIONS to be a dict[str, str]") +TASK_DESCRIPTIONS: dict[str, str] = task_descriptions_obj + +task_name_to_id_obj = data.get("TASK_NAME_TO_ID") +if not isinstance(task_name_to_id_obj, dict): + raise TypeError("Expected TASK_NAME_TO_ID to be a dict[str, int]") +TASK_NAME_TO_ID: dict[str, int] = task_name_to_id_obj + +# difficulty -> tasks mapping +difficulty_to_tasks = data.get("DIFFICULTY_TO_TASKS") +if not isinstance(difficulty_to_tasks, dict): + raise TypeError("Expected 'DIFFICULTY_TO_TASKS' to be a dict[str, list[str]]") +DIFFICULTY_TO_TASKS: dict[str, list[str]] = difficulty_to_tasks + +# convert policy strings -> actual policy classes +task_policy_mapping = data.get("TASK_POLICY_MAPPING") +if not isinstance(task_policy_mapping, dict): + raise TypeError("Expected 'TASK_POLICY_MAPPING' to be a dict[str, str]") +TASK_POLICY_MAPPING: dict[str, Any] = { + task_name: getattr(policies, policy_class_name) + for task_name, policy_class_name in task_policy_mapping.items() +} +ACTION_DIM = 4 +OBS_DIM = 4 + + +class MetaworldEnv(gym.Env): + metadata = {"render_modes": ["rgb_array"], "render_fps": 80} + + def __init__( + self, + task, + camera_name="corner2", + obs_type="pixels", + render_mode="rgb_array", + observation_width=480, + observation_height=480, + visualization_width=640, + visualization_height=480, + ): + super().__init__() + self.task = task.replace("metaworld-", "") + self.obs_type = obs_type + self.render_mode = render_mode + self.observation_width = observation_width + self.observation_height = observation_height + self.visualization_width = visualization_width + self.visualization_height = visualization_height + self.camera_name = camera_name + + self._env = self._make_envs_task(self.task) + self._max_episode_steps = self._env.max_path_length + self.task_description = TASK_DESCRIPTIONS[self.task] + + self.expert_policy = TASK_POLICY_MAPPING[self.task]() + + if self.obs_type == "state": + raise NotImplementedError() + elif self.obs_type == "pixels": + self.observation_space = spaces.Dict( + { + "pixels": spaces.Box( + low=0, + high=255, + shape=(self.observation_height, self.observation_width, 3), + dtype=np.uint8, + ) + } + ) + elif self.obs_type == "pixels_agent_pos": + self.observation_space = spaces.Dict( + { + "pixels": spaces.Box( + low=0, + high=255, + shape=(self.observation_height, self.observation_width, 3), + dtype=np.uint8, + ), + "agent_pos": spaces.Box( + low=-1000.0, + high=1000.0, + shape=(OBS_DIM,), + dtype=np.float64, + ), + } + ) + + self.action_space = spaces.Box(low=-1, high=1, shape=(ACTION_DIM,), dtype=np.float32) + + def render(self) -> np.ndarray: + """ + Render the current environment frame. + + Returns: + np.ndarray: The rendered RGB image from the environment. + """ + image = self._env.render() + if self.camera_name == "corner2": + # Images from this camera are flipped — correct them + image = np.flip(image, (0, 1)) + return image + + def _make_envs_task(self, env_name: str): + mt1 = metaworld.MT1(env_name, seed=42) + env = mt1.train_classes[env_name](render_mode="rgb_array", camera_name=self.camera_name) + env.set_task(mt1.train_tasks[0]) + if self.camera_name == "corner2": + env.model.cam_pos[2] = [ + 0.75, + 0.075, + 0.7, + ] # corner2 position, similar to https://arxiv.org/pdf/2206.14244 + env.reset() + env._freeze_rand_vec = False # otherwise no randomization + return env + + def _format_raw_obs(self, raw_obs: np.ndarray) -> dict[str, Any]: + image = None + if self._env is not None: + image = self._env.render() + if self.camera_name == "corner2": + # NOTE: The "corner2" camera in MetaWorld environments outputs images with both axes inverted. + image = np.flip(image, (0, 1)) + agent_pos = raw_obs[:4] + if self.obs_type == "state": + raise NotImplementedError( + "'state' obs_type not implemented for MetaWorld. Use pixel modes instead." + ) + + elif self.obs_type in ("pixels", "pixels_agent_pos"): + assert image is not None, ( + "Expected `image` to be rendered before constructing pixel-based observations. " + "This likely means `env.render()` returned None or the environment was not provided." + ) + + if self.obs_type == "pixels": + obs = {"pixels": image.copy()} + + else: # pixels_agent_pos + obs = { + "pixels": image.copy(), + "agent_pos": agent_pos, + } + else: + raise ValueError(f"Unknown obs_type: {self.obs_type}") + return obs + + def reset( + self, + seed: int | None = None, + **kwargs, + ) -> tuple[dict[str, Any], dict[str, Any]]: + """ + Reset the environment to its initial state. + + Args: + seed (Optional[int]): Random seed for environment initialization. + + Returns: + observation (Dict[str, Any]): The initial formatted observation. + info (Dict[str, Any]): Additional info about the reset state. + """ + super().reset(seed=seed) + + raw_obs, info = self._env.reset(seed=seed) + + observation = self._format_raw_obs(raw_obs) + + info = {"is_success": False} + return observation, info + + def step(self, action: np.ndarray) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]: + """ + Perform one environment step. + + Args: + action (np.ndarray): The action to execute, must be 1-D with shape (action_dim,). + + Returns: + observation (Dict[str, Any]): The formatted observation after the step. + reward (float): The scalar reward for this step. + terminated (bool): Whether the episode terminated successfully. + truncated (bool): Whether the episode was truncated due to a time limit. + info (Dict[str, Any]): Additional environment info. + """ + if action.ndim != 1: + raise ValueError( + f"Expected action to be 1-D (shape (action_dim,)), " + f"but got shape {action.shape} with ndim={action.ndim}" + ) + raw_obs, reward, done, truncated, info = self._env.step(action) + + # Determine whether the task was successful + is_success = bool(info.get("success", 0)) + terminated = done or is_success + info.update( + { + "task": self.task, + "done": done, + "is_success": is_success, + } + ) + + # Format the raw observation into the expected structure + observation = self._format_raw_obs(raw_obs) + if terminated: + info["final_info"] = { + "task": self.task, + "done": bool(done), + "is_success": bool(is_success), + } + self.reset() + + return observation, reward, terminated, truncated, info + + def close(self): + self._env.close() + + +# ---- Main API ---------------------------------------------------------------- + + +def create_metaworld_envs( + task: str, + n_envs: int, + gym_kwargs: dict[str, Any] | None = None, + env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None, +) -> dict[str, dict[int, Any]]: + """ + Create vectorized Meta-World environments with a consistent return shape. + + Returns: + dict[task_group][task_id] -> vec_env (env_cls([...]) with exactly n_envs factories) + Notes: + - n_envs is the number of rollouts *per task* (episode_index = 0..n_envs-1). + - `task` can be a single difficulty group (e.g., "easy", "medium", "hard") or a comma-separated list. + - If a task name is not in DIFFICULTY_TO_TASKS, we treat it as a single custom task. + """ + if env_cls is None or not callable(env_cls): + raise ValueError("env_cls must be a callable that wraps a list of environment factory callables.") + if not isinstance(n_envs, int) or n_envs <= 0: + raise ValueError(f"n_envs must be a positive int; got {n_envs}.") + + gym_kwargs = dict(gym_kwargs or {}) + task_groups = [t.strip() for t in task.split(",") if t.strip()] + if not task_groups: + raise ValueError("`task` must contain at least one Meta-World task or difficulty group.") + + print(f"Creating Meta-World envs | task_groups={task_groups} | n_envs(per task)={n_envs}") + + out: dict[str, dict[int, Any]] = defaultdict(dict) + + for group in task_groups: + # if not in difficulty presets, treat it as a single custom task + tasks = DIFFICULTY_TO_TASKS.get(group, [group]) + + for tid, task_name in enumerate(tasks): + print(f"Building vec env | group={group} | task_id={tid} | task={task_name}") + + # build n_envs factories + fns = [(lambda tn=task_name: MetaworldEnv(task=tn, **gym_kwargs)) for _ in range(n_envs)] + + out[group][tid] = env_cls(fns) + + # return a plain dict for consistency + return {group: dict(task_map) for group, task_map in out.items()} diff --git a/src/lerobot/envs/metaworld_config.json b/src/lerobot/envs/metaworld_config.json new file mode 100644 index 000000000..41a417fef --- /dev/null +++ b/src/lerobot/envs/metaworld_config.json @@ -0,0 +1,121 @@ +{ + "TASK_DESCRIPTIONS": { + "assembly-v3": "Pick up a nut and place it onto a peg", + "basketball-v3": "Dunk the basketball into the basket", + "bin-picking-v3": "Grasp the puck from one bin and place it into another bin", + "box-close-v3": "Grasp the cover and close the box with it", + "button-press-topdown-v3": "Press a button from the top", + "button-press-topdown-wall-v3": "Bypass a wall and press a button from the top", + "button-press-v3": "Press a button", + "button-press-wall-v3": "Bypass a wall and press a button", + "coffee-button-v3": "Push a button on the coffee machine", + "coffee-pull-v3": "Pull a mug from a coffee machine", + "coffee-push-v3": "Push a mug under a coffee machine", + "dial-turn-v3": "Rotate a dial 180 degrees", + "disassemble-v3": "Pick a nut out of a peg", + "door-close-v3": "Close a door with a revolving joint", + "door-lock-v3": "Lock the door by rotating the lock clockwise", + "door-open-v3": "Open a door with a revolving joint", + "door-unlock-v3": "Unlock the door by rotating the lock counter-clockwise", + "hand-insert-v3": "Insert the gripper into a hole", + "drawer-close-v3": "Push and close a drawer", + "drawer-open-v3": "Open a drawer", + "faucet-open-v3": "Rotate the faucet counter-clockwise", + "faucet-close-v3": "Rotate the faucet clockwise", + "hammer-v3": "Hammer a screw on the wall", + "handle-press-side-v3": "Press a handle down sideways", + "handle-press-v3": "Press a handle down", + "handle-pull-side-v3": "Pull a handle up sideways", + "handle-pull-v3": "Pull a handle up", + "lever-pull-v3": "Pull a lever down 90 degrees", + "peg-insert-side-v3": "Insert a peg sideways", + "pick-place-wall-v3": "Pick a puck, bypass a wall and place the puck", + "pick-out-of-hole-v3": "Pick up a puck from a hole", + "reach-v3": "Reach a goal position", + "push-back-v3": "Push the puck to a goal", + "push-v3": "Push the puck to a goal", + "pick-place-v3": "Pick and place a puck to a goal", + "plate-slide-v3": "Slide a plate into a cabinet", + "plate-slide-side-v3": "Slide a plate into a cabinet sideways", + "plate-slide-back-v3": "Get a plate from the cabinet", + "plate-slide-back-side-v3": "Get a plate from the cabinet sideways", + "peg-unplug-side-v3": "Unplug a peg sideways", + "soccer-v3": "Kick a soccer into the goal", + "stick-push-v3": "Grasp a stick and push a box using the stick", + "stick-pull-v3": "Grasp a stick and pull a box with the stick", + "push-wall-v3": "Bypass a wall and push a puck to a goal", + "reach-wall-v3": "Bypass a wall and reach a goal", + "shelf-place-v3": "Pick and place a puck onto a shelf", + "sweep-into-v3": "Sweep a puck into a hole", + "sweep-v3": "Sweep a puck off the table", + "window-open-v3": "Push and open a window", + "window-close-v3": "Push and close a window" + }, + "TASK_NAME_TO_ID": { + "assembly-v3": 0, "basketball-v3": 1, "bin-picking-v3": 2, "box-close-v3": 3, + "button-press-topdown-v3": 4, "button-press-topdown-wall-v3": 5, "button-press-v3": 6, + "button-press-wall-v3": 7, "coffee-button-v3": 8, "coffee-pull-v3": 9, "coffee-push-v3": 10, + "dial-turn-v3": 11, "disassemble-v3": 12, "door-close-v3": 13, "door-lock-v3": 14, + "door-open-v3": 15, "door-unlock-v3": 16, "drawer-close-v3": 17, "drawer-open-v3": 18, + "faucet-close-v3": 19, "faucet-open-v3": 20, "hammer-v3": 21, "hand-insert-v3": 22, + "handle-press-side-v3": 23, "handle-press-v3": 24, "handle-pull-side-v3": 25, + "handle-pull-v3": 26, "lever-pull-v3": 27, "peg-insert-side-v3": 28, "peg-unplug-side-v3": 29, + "pick-out-of-hole-v3": 30, "pick-place-v3": 31, "pick-place-wall-v3": 32, + "plate-slide-back-side-v3": 33, "plate-slide-back-v3": 34, "plate-slide-side-v3": 35, + "plate-slide-v3": 36, "push-back-v3": 37, "push-v3": 38, "push-wall-v3": 39, "reach-v3": 40, + "reach-wall-v3": 41, "shelf-place-v3": 42, "soccer-v3": 43, "stick-pull-v3": 44, + "stick-push-v3": 45, "sweep-into-v3": 46, "sweep-v3": 47, "window-open-v3": 48, + "window-close-v3": 49 + }, + "DIFFICULTY_TO_TASKS": { + "easy": [ + "button-press-v3", "button-press-topdown-v3", "button-press-topdown-wall-v3", + "button-press-wall-v3", "coffee-button-v3", "dial-turn-v3", "door-close-v3", + "door-lock-v3", "door-open-v3", "door-unlock-v3", "drawer-close-v3", "drawer-open-v3", + "faucet-close-v3", "faucet-open-v3", "handle-press-v3", "handle-press-side-v3", + "handle-pull-v3", "handle-pull-side-v3", "lever-pull-v3", "plate-slide-v3", + "plate-slide-back-v3", "plate-slide-back-side-v3", "plate-slide-side-v3", "reach-v3", + "reach-wall-v3", "window-close-v3", "window-open-v3", "peg-unplug-side-v3" + ], + "medium": [ + "basketball-v3", "bin-picking-v3", "box-close-v3", "coffee-pull-v3", "coffee-push-v3", + "hammer-v3", "peg-insert-side-v3", "push-wall-v3", "soccer-v3", "sweep-v3", "sweep-into-v3" + ], + "hard": [ + "assembly-v3", "hand-insert-v3", "pick-out-of-hole-v3", "pick-place-v3", "push-v3", "push-back-v3" + ], + "very_hard": [ + "shelf-place-v3", "disassemble-v3", "stick-pull-v3", "stick-push-v3", "pick-place-wall-v3" + ] + }, + "TASK_POLICY_MAPPING": { + "assembly-v3": "SawyerAssemblyV3Policy", "basketball-v3": "SawyerBasketballV3Policy", + "bin-picking-v3": "SawyerBinPickingV3Policy", "box-close-v3": "SawyerBoxCloseV3Policy", + "button-press-topdown-v3": "SawyerButtonPressTopdownV3Policy", + "button-press-topdown-wall-v3": "SawyerButtonPressTopdownWallV3Policy", + "button-press-v3": "SawyerButtonPressV3Policy", "button-press-wall-v3": "SawyerButtonPressWallV3Policy", + "coffee-button-v3": "SawyerCoffeeButtonV3Policy", "coffee-pull-v3": "SawyerCoffeePullV3Policy", + "coffee-push-v3": "SawyerCoffeePushV3Policy", "dial-turn-v3": "SawyerDialTurnV3Policy", + "disassemble-v3": "SawyerDisassembleV3Policy", "door-close-v3": "SawyerDoorCloseV3Policy", + "door-lock-v3": "SawyerDoorLockV3Policy", "door-open-v3": "SawyerDoorOpenV3Policy", + "door-unlock-v3": "SawyerDoorUnlockV3Policy", "drawer-close-v3": "SawyerDrawerCloseV3Policy", + "drawer-open-v3": "SawyerDrawerOpenV3Policy", "faucet-close-v3": "SawyerFaucetCloseV3Policy", + "faucet-open-v3": "SawyerFaucetOpenV3Policy", "hammer-v3": "SawyerHammerV3Policy", + "hand-insert-v3": "SawyerHandInsertV3Policy", "handle-press-side-v3": "SawyerHandlePressSideV3Policy", + "handle-press-v3": "SawyerHandlePressV3Policy", "handle-pull-side-v3": "SawyerHandlePullSideV3Policy", + "handle-pull-v3": "SawyerHandlePullV3Policy", "lever-pull-v3": "SawyerLeverPullV3Policy", + "peg-insert-side-v3": "SawyerPegInsertionSideV3Policy", "peg-unplug-side-v3": "SawyerPegUnplugSideV3Policy", + "pick-out-of-hole-v3": "SawyerPickOutOfHoleV3Policy", "pick-place-v3": "SawyerPickPlaceV3Policy", + "pick-place-wall-v3": "SawyerPickPlaceWallV3Policy", + "plate-slide-back-side-v3": "SawyerPlateSlideBackSideV3Policy", + "plate-slide-back-v3": "SawyerPlateSlideBackV3Policy", + "plate-slide-side-v3": "SawyerPlateSlideSideV3Policy", "plate-slide-v3": "SawyerPlateSlideV3Policy", + "push-back-v3": "SawyerPushBackV3Policy", "push-v3": "SawyerPushV3Policy", + "push-wall-v3": "SawyerPushWallV3Policy", "reach-v3": "SawyerReachV3Policy", + "reach-wall-v3": "SawyerReachWallV3Policy", "shelf-place-v3": "SawyerShelfPlaceV3Policy", + "soccer-v3": "SawyerSoccerV3Policy", "stick-pull-v3": "SawyerStickPullV3Policy", + "stick-push-v3": "SawyerStickPushV3Policy", "sweep-into-v3": "SawyerSweepIntoV3Policy", + "sweep-v3": "SawyerSweepV3Policy", "window-open-v3": "SawyerWindowOpenV3Policy", + "window-close-v3": "SawyerWindowCloseV3Policy" + } +} diff --git a/src/lerobot/optim/schedulers.py b/src/lerobot/optim/schedulers.py index 55ee62e40..b5d54b396 100644 --- a/src/lerobot/optim/schedulers.py +++ b/src/lerobot/optim/schedulers.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc +import logging import math from dataclasses import asdict, dataclass from pathlib import Path @@ -79,7 +80,11 @@ class VQBeTSchedulerConfig(LRSchedulerConfig): @LRSchedulerConfig.register_subclass("cosine_decay_with_warmup") @dataclass class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig): - """Used by Physical Intelligence to train Pi0""" + """Used by Physical Intelligence to train Pi0. + + Automatically scales warmup and decay steps if num_training_steps < num_decay_steps. + This ensures the learning rate schedule completes properly even with shorter training runs. + """ num_warmup_steps: int num_decay_steps: int @@ -87,23 +92,39 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig): decay_lr: float def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR: - del num_training_steps + # Auto-scale scheduler parameters if training steps are shorter than configured decay steps + actual_warmup_steps = self.num_warmup_steps + actual_decay_steps = self.num_decay_steps + + if num_training_steps < self.num_decay_steps: + # Calculate scaling factor to fit the schedule into the available training steps + scale_factor = num_training_steps / self.num_decay_steps + actual_warmup_steps = int(self.num_warmup_steps * scale_factor) + actual_decay_steps = num_training_steps + + logging.info( + f"Auto-scaling LR scheduler: " + f"num_training_steps ({num_training_steps}) < num_decay_steps ({self.num_decay_steps}). " + f"Scaling warmup: {self.num_warmup_steps} → {actual_warmup_steps}, " + f"decay: {self.num_decay_steps} → {actual_decay_steps} " + f"(scale factor: {scale_factor:.3f})" + ) def lr_lambda(current_step): def linear_warmup_schedule(current_step): if current_step <= 0: - return 1 / (self.num_warmup_steps + 1) - frac = 1 - current_step / self.num_warmup_steps - return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1 + return 1 / (actual_warmup_steps + 1) + frac = 1 - current_step / actual_warmup_steps + return (1 / (actual_warmup_steps + 1) - 1) * frac + 1 def cosine_decay_schedule(current_step): - step = min(current_step, self.num_decay_steps) - cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps)) + step = min(current_step, actual_decay_steps) + cosine_decay = 0.5 * (1 + math.cos(math.pi * step / actual_decay_steps)) alpha = self.decay_lr / self.peak_lr decayed = (1 - alpha) * cosine_decay + alpha return decayed - if current_step < self.num_warmup_steps: + if current_step < actual_warmup_steps: return linear_warmup_schedule(current_step) return cosine_decay_schedule(current_step) diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index ac76baf9f..cfb550ab2 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -31,7 +31,6 @@ from lerobot.envs.utils import env_to_policy_features from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.policies.pi0.configuration_pi0 import PI0Config -from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig from lerobot.policies.pi05.configuration_pi05 import PI05Config from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.sac.configuration_sac import SACConfig @@ -58,7 +57,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: Args: name: The name of the policy. Supported names are "tdmpc", "diffusion", "act", - "vqbet", "pi0", "pi0fast", "sac", "reward_classifier", "smolvla". + "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla". Returns: The policy class corresponding to the given name. @@ -82,10 +81,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy return VQBeTPolicy - elif name == "pi0fast": - from lerobot.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy - - return PI0FASTPolicy elif name == "pi0": from lerobot.policies.pi0.modeling_pi0 import PI0Policy @@ -119,7 +114,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: Args: policy_type: The type of the policy. Supported types include "tdmpc", - "diffusion", "act", "vqbet", "pi0", "pi0fast", "sac", "smolvla", + "diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla", "reward_classifier". **kwargs: Keyword arguments to be passed to the configuration class constructor. @@ -137,8 +132,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return ACTConfig(**kwargs) elif policy_type == "vqbet": return VQBeTConfig(**kwargs) - elif policy_type == "pi0fast": - return PI0FASTConfig(**kwargs) elif policy_type == "pi0": return PI0Config(**kwargs) elif policy_type == "pi05": @@ -260,14 +253,6 @@ def make_pre_post_processors( dataset_stats=kwargs.get("dataset_stats"), ) - elif isinstance(policy_cfg, PI0FASTConfig): - from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors - - processors = make_pi0fast_pre_post_processors( - config=policy_cfg, - dataset_stats=kwargs.get("dataset_stats"), - ) - elif isinstance(policy_cfg, PI0Config): from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors diff --git a/src/lerobot/policies/pi0/configuration_pi0.py b/src/lerobot/policies/pi0/configuration_pi0.py index cc1cda9d8..d745f4317 100644 --- a/src/lerobot/policies/pi0/configuration_pi0.py +++ b/src/lerobot/policies/pi0/configuration_pi0.py @@ -75,6 +75,8 @@ class PI0Config(PreTrainedConfig): optimizer_grad_clip_norm: float = 1.0 # Scheduler settings: see openpi `CosineDecaySchedule` + # Note: These will auto-scale if --steps < scheduler_decay_steps + # For example, --steps=3000 will scale warmup to 100 and decay to 3000 scheduler_warmup_steps: int = 1_000 scheduler_decay_steps: int = 30_000 scheduler_decay_lr: float = 2.5e-6 diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index a2dcdaea3..596b273d5 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -897,7 +897,7 @@ class PI0Policy(PreTrainedPolicy): ) -> T: """Override the from_pretrained method to handle key remapping and display important disclaimer.""" print( - "The PI05 model is a direct port of the OpenPI implementation. \n" + "The PI0 model is a direct port of the OpenPI implementation. \n" "This implementation follows the original OpenPI structure for compatibility. \n" "Original implementation: https://github.com/Physical-Intelligence/openpi" ) diff --git a/src/lerobot/policies/pi05/configuration_pi05.py b/src/lerobot/policies/pi05/configuration_pi05.py index 7c1e950b0..61346c330 100644 --- a/src/lerobot/policies/pi05/configuration_pi05.py +++ b/src/lerobot/policies/pi05/configuration_pi05.py @@ -75,6 +75,8 @@ class PI05Config(PreTrainedConfig): optimizer_grad_clip_norm: float = 1.0 # Scheduler settings: see openpi `CosineDecaySchedule` + # Note: These will auto-scale if --steps < scheduler_decay_steps + # For example, --steps=3000 will scale warmup to 100 and decay to 3000 scheduler_warmup_steps: int = 1_000 scheduler_decay_steps: int = 30_000 scheduler_decay_lr: float = 2.5e-6 diff --git a/src/lerobot/policies/pi0fast/configuration_pi0fast.py b/src/lerobot/policies/pi0fast/configuration_pi0fast.py deleted file mode 100644 index cefd4e688..000000000 --- a/src/lerobot/policies/pi0fast/configuration_pi0fast.py +++ /dev/null @@ -1,153 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass, field - -from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.optim.optimizers import AdamWConfig -from lerobot.optim.schedulers import ( - CosineDecayWithWarmupSchedulerConfig, -) -from lerobot.utils.constants import OBS_IMAGES - - -@PreTrainedConfig.register_subclass("pi0fast") -@dataclass -class PI0FASTConfig(PreTrainedConfig): - # Input / output structure. - n_obs_steps: int = 1 - chunk_size: int = 10 - n_action_steps: int = 5 - - normalization_mapping: dict[str, NormalizationMode] = field( - default_factory=lambda: { - "VISUAL": NormalizationMode.IDENTITY, - "STATE": NormalizationMode.MEAN_STD, - "ACTION": NormalizationMode.MEAN_STD, - } - ) - - # Shorter state and action vectors will be padded - max_state_dim: int = 32 # 32 - max_action_dim: int = 32 # 32 - - # Image preprocessing - resize_imgs_with_padding: tuple[int, int] = (224, 224) - interpolate_like_pi: bool = False - - # Add empty images. Used by pi0_aloha_sim which adds the empty - # left and right wrist cameras in addition to the top camera. - empty_cameras: int = 0 - - # Converts the joint and gripper values from the standard Aloha space to - # the space used by the pi internal runtime which was used to train the base model. - adapt_to_pi_aloha: bool = False - - # Converts joint dimensions to deltas with respect to the current state before passing to the model. - # Gripper dimensions will remain in absolute values. - use_delta_joint_actions_aloha: bool = False - - # Tokenizer - tokenizer_max_length: int = 48 - - # Projector - proj_width: int = 1024 - - # Decoding - max_decoding_steps: int = 256 - fast_skip_tokens: int = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens - max_input_seq_len: int = 256 # 512 - - # Utils - use_cache: bool = True - - # Frozen parameters - freeze_vision_encoder: bool = True - freeze_lm_head: bool = True - - # Training presets - optimizer_lr: float = 1e-4 - optimizer_betas: tuple[float, float] = (0.9, 0.95) - optimizer_eps: float = 1e-8 - optimizer_weight_decay: float = 1e-5 - - scheduler_warmup_steps: int = 1_000 - scheduler_decay_steps: int = 30_000 - scheduler_decay_lr: float = 2.5e-6 - - checkpoint_path: str = None - - padding_side: str = "right" - - precision: str = "bfloat16" - grad_clip_norm: float = 1 - - # Allows padding/truncation of generated action tokens during detokenization to ensure decoding. - # In the original version, tensors of 0s were generated if shapes didn't match for stable decoding. - relaxed_action_decoding: bool = True - - def __post_init__(self): - super().__post_init__() - - """Input validation (not exhaustive).""" - if self.n_action_steps > self.chunk_size: - raise ValueError( - f"The chunk size is the upper bound for the number of action steps per model invocation. Got " - f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`." - ) - if self.n_obs_steps != 1: - raise ValueError( - f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`" - ) - - def validate_features(self) -> None: - for i in range(self.empty_cameras): - key = f"{OBS_IMAGES}.empty_camera_{i}" - empty_camera = PolicyFeature( - type=FeatureType.VISUAL, - shape=(3, 480, 640), - ) - self.input_features[key] = empty_camera - - def get_optimizer_preset(self) -> AdamWConfig: - return AdamWConfig( - lr=self.optimizer_lr, - betas=self.optimizer_betas, - eps=self.optimizer_eps, - weight_decay=self.optimizer_weight_decay, - grad_clip_norm=self.grad_clip_norm, - ) - - def get_scheduler_preset(self): - return CosineDecayWithWarmupSchedulerConfig( - peak_lr=self.optimizer_lr, - decay_lr=self.scheduler_decay_lr, - num_warmup_steps=self.scheduler_warmup_steps, - num_decay_steps=self.scheduler_decay_steps, - ) - - @property - def observation_delta_indices(self) -> None: - return None - - @property - def action_delta_indices(self) -> list: - return list(range(self.chunk_size)) - - @property - def reward_delta_indices(self) -> None: - return None diff --git a/src/lerobot/policies/pi0fast/modeling_pi0fast.py b/src/lerobot/policies/pi0fast/modeling_pi0fast.py deleted file mode 100644 index 102cfb8fa..000000000 --- a/src/lerobot/policies/pi0fast/modeling_pi0fast.py +++ /dev/null @@ -1,980 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -π0+FAST: Efficient Action Tokenization for Vision-Language-Action Models - -[Paper](https://huggingface.co/papers/2501.09747) -[Jax code](https://github.com/Physical-Intelligence/openpi) - -Designed by Physical Intelligence. Ported from Jax by Hugging Face. -Disclaimer: It is not expected to perform as well as the original implementation. - -Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`): -```bash -lerobot-train \ ---policy.path=lerobot/pi0fast_base \ ---dataset.repo_id=danaaubakirova/koch_test -``` - -Example of training the pi0+FAST neural network with from scratch: -```bash -lerobot-train \ ---policy.type=pi0fast \ ---dataset.repo_id=danaaubakirova/koch_test -``` - -Example of using the pi0 pretrained model outside LeRobot training framework: -```python -policy = PI0FASTPolicy.from_pretrained("lerobot/pi0fast_base") -``` - -""" - -from collections import deque -from functools import partial - -import numpy as np -import torch -import torch.nn.functional as F # noqa: N812 -from PIL import Image -from scipy.fft import idct -from torch import Tensor, nn -from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGeneration -from transformers.cache_utils import HybridCache, StaticCache -from transformers.models.auto import CONFIG_MAPPING - -from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig -from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.utils.constants import ACTION, OBS_STATE - -PRECISION = { - "float16": torch.float16, - "float32": torch.float32, - "bfloat16": torch.bfloat16, -} - - -def normalize(x, min_val, max_val): - return (x - min_val) / (max_val - min_val) - - -def unnormalize(x, min_val, max_val): - return x * (max_val - min_val) + min_val - - -def safe_arcsin(value): - # This ensures that the input stays within - # [−1,1] to avoid invalid values for arcsin - return torch.arcsin(torch.clamp(value, -1.0, 1.0)) - - -def aloha_gripper_to_angular(value): - # Aloha transforms the gripper positions into a linear space. The following code - # reverses this transformation to be consistent with pi0 which is pretrained in - # angular space. - # - # These values are coming from the Aloha code: - # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED - value = unnormalize(value, min_val=0.01844, max_val=0.05800) - - # This is the inverse of the angular to linear transformation inside the Interbotix code. - def linear_to_radian(linear_position, arm_length, horn_radius): - value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position) - return safe_arcsin(value) - - # The constants are taken from the Interbotix code. - value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022) - - # Normalize to [0, 1]. - # The values 0.4 and 1.5 were measured on an actual Trossen robot. - return normalize(value, min_val=0.4, max_val=1.5) - - -def aloha_gripper_from_angular(value): - # Convert from the gripper position used by pi0 to the gripper position that is used by Aloha. - # Note that the units are still angular but the range is different. - - # The values 0.4 and 1.5 were measured on an actual Trossen robot. - value = unnormalize(value, min_val=0.4, max_val=1.5) - - # These values are coming from the Aloha code: - # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE - return normalize(value, min_val=-0.6213, max_val=1.4910) - - -def aloha_gripper_from_angular_inv(value): - # Directly inverts the gripper_from_angular function. - value = unnormalize(value, min_val=-0.6213, max_val=1.4910) - return normalize(value, min_val=0.4, max_val=1.5) - - -class PI0FASTPolicy(PreTrainedPolicy): - """Wrapper class around PI0FAST tokenizer and model to train and run inference within LeRobot.""" - - config_class = PI0FASTConfig - name = "pi0fast" - - def __init__( - self, - config: PI0FASTConfig, - dataset_stats: dict[str, dict[str, Tensor]] | None = None, - ): - """ - Args: - config: Policy configuration class instance or None, in which case the default instantiation of - the configuration class is used. - dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected - that they will be passed with a call to `load_state_dict` before the policy is used. - """ - - super().__init__(config) - config.validate_features() - self.config = config - - self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224") - self.model = PI0FAST(config) - - self.reset() - - def reset(self): - """This should be called whenever the environment is reset.""" - self._action_queue = deque([], maxlen=self.config.n_action_steps) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - """Override the from_pretrained method to display important disclaimer.""" - print( - "⚠️ DISCLAIMER: The PI0FAST model is ported from JAX by the Hugging Face team. \n" - " It is not expected to perform as well as the original implementation. \n" - " Original implementation: https://github.com/Physical-Intelligence/openpi" - ) - return super().from_pretrained(*args, **kwargs) - - def get_optim_params(self) -> dict: - return self.parameters() - - def _pi_aloha_decode_state(self, state): - # Flip the joints. - for motor_idx in [1, 2, 8, 9]: - state[:, motor_idx] *= -1 - # Reverse the gripper transformation that is being applied by the Aloha runtime. - for motor_idx in [6, 13]: - state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx]) - return state - - def _pi_aloha_encode_actions(self, actions): - # Flip the joints. - for motor_idx in [1, 2, 8, 9]: - actions[:, :, motor_idx] *= -1 - # Reverse the gripper transformation that is being applied by the Aloha runtime. - for motor_idx in [6, 13]: - actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx]) - return actions - - def _pi_aloha_encode_actions_inv(self, actions): - # Flip the joints again. - for motor_idx in [1, 2, 8, 9]: - actions[:, :, motor_idx] *= -1 - # Reverse the gripper transformation that is being applied by the Aloha runtime. - for motor_idx in [6, 13]: - actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx]) - return actions - - @torch.no_grad() - def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: - """Predict a chunk of actions given environment observations.""" - raise NotImplementedError("Currently not implemented for PI0FAST") - - @torch.no_grad() - def select_action(self, batch: dict[str, Tensor]) -> Tensor: - """Select a single action given environment observations. - - This method wraps `select_actions` in order to return one action at a time for execution in the - environment. It works by managing the actions in a queue and only calling `select_actions` when the - queue is empty. - """ - self.eval() - - if self.config.adapt_to_pi_aloha: - batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) - - # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by - # querying the policy. - if len(self._action_queue) == 0: - actions = self.model.generate_actions(batch) - - actions = actions[:, : self.config.n_action_steps] - - original_action_dim = self.config.action_feature.shape[ - 0 - ] # self.config.max_action_dim # self.config.action_feature.shape[0] - actions = actions[:, :, :original_action_dim] - - if self.config.adapt_to_pi_aloha: - actions = self._pi_aloha_encode_actions(actions) - - # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue - # effectively has shape (n_action_steps, batch_size, *), hence the transpose. - self._action_queue.extend(actions.transpose(0, 1)) - return self._action_queue.popleft() - - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - if self.config.adapt_to_pi_aloha: - batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) - batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) - loss_dict = self.model.forward(batch) - return loss_dict["loss"], loss_dict - - -def block_causal_update_causal_mask( - attention_mask, - token_type_ids=None, - past_key_values=None, - cache_position=None, - input_tensor=None, - attn_implementation: str = "eager", - dtype: torch.dtype = "float32", -): - """ - Update the causal mask during training and generation. It can be customized to different attention masks. - """ - if attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - using_static_cache = isinstance(past_key_values, StaticCache) - min_dtype = torch.finfo(dtype).min - - if input_tensor is None: - input_tensor = attention_mask - - inputs_lead_dim, sequence_length = input_tensor.shape[:2] - - if using_static_cache or isinstance(past_key_values, HybridCache): - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else cache_position[0] + sequence_length + 1 - ) - - # Handle precomputed attention masks - if attention_mask is not None and attention_mask.dim() == 4: - return attention_mask - - # Causal mask initialization - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - - # Standard causal masking (triu ensures tokens can only attend to past) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - - # Apply block causal mask - if token_type_ids is not None: - token_type_ids = token_type_ids.to(causal_mask.device).bool() - cumsum = torch.cumsum(token_type_ids, dim=1) - block_causal_mask = cumsum[:, None, :] <= cumsum[:, :, None] - - # Combine causal_mask with block-wise attention mask - causal_mask = torch.where(block_causal_mask, 0.0, causal_mask) - causal_mask = causal_mask[:, None, :, :] - else: - # Apply past cache position constraint - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape( - -1, 1 - ) - causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) - else: - # Apply past cache position constraint - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape( - -1, 1 - ) - causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) - - if attention_mask is not None: - causal_mask = causal_mask.clone() # Copy to contiguous memory for in-place edits - mask_length = attention_mask.shape[-1] - - # Apply padding mask - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - -def prepare_inputs_for_generation( - # self, - input_ids, - past_key_values=None, - inputs_embeds=None, - cache_position=None, - position_ids=None, - pixel_values=None, - attention_mask=None, - token_type_ids=None, - use_cache=True, - num_logits_to_keep=None, - labels=None, - self=None, - **kwargs, -): - # create block causal attention - if cache_position[0] > 0 and input_ids.shape[1] > 0: - input_tensor = input_ids[:, -1:] - new_positions = ( - torch.ones( - (position_ids.shape[0], input_ids.shape[1]), - dtype=position_ids.dtype, - device=position_ids.device, - ).cumsum(-1) - + position_ids[:, -1:] - ) - position_ids = torch.cat([position_ids, new_positions], dim=-1) - else: - input_tensor = inputs_embeds - attention_mask = block_causal_update_causal_mask( - attention_mask=attention_mask, - past_key_values=past_key_values, - cache_position=cache_position, - input_tensor=input_tensor, - token_type_ids=token_type_ids, - dtype=self.dtype, - attn_implementation=self.config.text_config._attn_implementation, - ) - # Overwritten -- custom `position_ids` and `pixel_values` handling - model_inputs = self.language_model.prepare_inputs_for_generation( - input_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - position_ids=position_ids, - cache_position=cache_position, - use_cache=use_cache, - num_logits_to_keep=num_logits_to_keep, - token_type_ids=token_type_ids, - **kwargs, - ) - - # Position_ids in Paligemma are 1-indexed - if model_inputs.get("position_ids") is not None: - model_inputs["position_ids"] += 1 - # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore - # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always - if cache_position[0] == 0: - model_inputs["pixel_values"] = pixel_values - is_training = token_type_ids is not None and labels is not None - if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): - input_tensor = inputs_embeds if inputs_embeds is not None else input_ids - causal_mask = self._update_causal_mask( - attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training - ) - model_inputs["attention_mask"] = causal_mask - - return model_inputs - - -class PI0FAST(nn.Module): - def __init__(self, config: PI0FASTConfig): - super().__init__() - self.config = config - - # TODO: move tokenizers in Policy - fast_tokenizer_path = "physical-intelligence/fast" - pi0_paligemma_path = "google/paligemma-3b-pt-224" - self.paligemma_tokenizer = AutoTokenizer.from_pretrained(pi0_paligemma_path) - self.processor = AutoProcessor.from_pretrained(pi0_paligemma_path) - self.fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True) - self.fast_skip_tokens = self.config.fast_skip_tokens - self.max_input_seq_len = self.config.max_input_seq_len - self.action_horizon = self.config.chunk_size - self.action_dim = self.config.action_feature.shape[ - 0 - ] # self.config.max_action_dim # self.config.action_feature.shape[0] - precision = config.precision - torch_precision = PRECISION.get(precision, torch.float32) - self.pad_token_id = ( - self.paligemma_tokenizer.pad_token_id - if hasattr(self.paligemma_tokenizer, "pad_token_id") - else self.paligemma_tokenizer.eos_token_id - ) - - paligemma_config = CONFIG_MAPPING["paligemma"]( - transformers_version="4.48.1", - _vocab_size=257152, - bos_token_id=2, - eos_token_id=1, - hidden_size=2048, - image_token_index=257152, - model_type="paligemma", - pad_token_id=0, - projection_dim=2048, - text_config={ - "hidden_activation": "gelu_pytorch_tanh", - "hidden_size": 2048, - "intermediate_size": 16384, - "model_type": "gemma", - "num_attention_heads": 8, - "num_hidden_layers": 18, - "num_image_tokens": 256, - "num_key_value_heads": 1, - "torch_dtype": precision, - "vocab_size": 257152, - "_attn_implementation": "eager", - }, - vision_config={ - "hidden_size": 1152, - "intermediate_size": 4304, - "model_type": "siglip_vision_model", - "num_attention_heads": 16, - "num_hidden_layers": 27, - "num_image_tokens": 256, - "patch_size": 14, - "projection_dim": 2048, - "projector_hidden_act": "gelu_pytorch_tanh", - "torch_dtype": precision, - "vision_use_head": False, - }, - ) - self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config) - - self.pi0_paligemma.prepare_inputs_for_generation = partial( - prepare_inputs_for_generation, self=self.pi0_paligemma - ) - # change important stuff in bf16 - params_to_change_dtype = [ - "language_model", - "vision_tower", - "multi_modal", - ] - for name, param in self.pi0_paligemma.named_parameters(): - if any(selector in name for selector in params_to_change_dtype): - param.data = param.data.to(dtype=torch_precision) - self.set_requires_grad() - self.image_keys = self.config.image_features.keys() - # TODO: Remove this once we bump transformers to >4.52.0 because the attribute will be removed - # AttributeError: 'PaliGemmaConfig' object has no attribute 'ignore_index' - self.ignore_index = self.pi0_paligemma.config.ignore_index - self.padding_side = self.config.padding_side - - def set_requires_grad(self): - if self.config.freeze_vision_encoder: - self.pi0_paligemma.vision_tower.eval() - for params in self.pi0_paligemma.vision_tower.parameters(): - params.requires_grad = False - # To avoid unused params issue with distributed training - if self.config.freeze_lm_head: - for name, params in self.pi0_paligemma.named_parameters(): - if "embed_tokens" in name: # lm heads and embedding layer are tied - params.requires_grad = False - - def embed_tokens(self, tokens: torch.Tensor): - return self.pi0_paligemma.language_model.model.embed_tokens(tokens) - - def prepare_inputs_for_generation(self, *args, **kwargs): - return self.pi0_paligemma.prepare_inputs_for_generation(*args, **kwargs) - - def prepare_images(self, batch): - """Preprocess LeRobot batch into Pi0 inputs""" - images = [] - img_masks = [] - present_img_keys = [key for key in self.image_keys if key in batch] - if len(present_img_keys) == 0: - raise ValueError( - f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})" - ) - - # Preprocess image features present in the batch - num_empty_cameras = 0 - for key in self.image_keys: - if key in present_img_keys: - img = batch[key] - - if self.config.resize_imgs_with_padding is not None: - img = resize_with_pad( - img, - *self.config.resize_imgs_with_padding, - pad_value=0, - interpolate_like_pi=self.config.interpolate_like_pi, - ) - - # Normalize from range [0,1] to [-1,1] as expected by siglip - img = img * 2.0 - 1.0 - - bsize = img.shape[0] - device = img.device - mask = torch.ones(bsize, dtype=torch.bool, device=device) - else: - if num_empty_cameras >= self.config.empty_cameras: - continue - img = torch.ones_like(img) * -1 - bsize = img.shape[0] - device = img.device - mask = torch.ones(bsize, dtype=torch.bool, device=device) - num_empty_cameras += 1 - - images.append(img) - img_masks.append(mask) - return images, img_masks - - def normalize_actions(self, actions: torch.Tensor) -> torch.Tensor: - mins = actions.amin(dim=(1, 2), keepdim=True) # [0] - maxs = actions.amax(dim=(1, 2), keepdim=True) # [0] - return 2 * (actions - mins) / (maxs - mins + 1e-8) - 1 - - def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor: - out = self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens - return out - - def fast_tokenizer_wrapper(self, actions_norm): - """ - A wrapper for self.fast_tokenizer that ensures batch processing, - conversion to PyTorch tensors, and returns a dictionary without padding. - """ - batch_tokens = self.fast_tokenizer(actions_norm) - fast_out = self.processor.tokenizer.pad({"input_ids": batch_tokens}, return_tensors="pt") - - return fast_out - - def create_token_type_ids(self, padded_mask: torch.Tensor, prefix_len: int) -> torch.Tensor: - token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool) - # Compute cumulative sum mask - cumsum_mask = (padded_mask != 0).cumsum(dim=1) - # Suffix block (everything after prefix_len) - suffix_mask = cumsum_mask > prefix_len - token_type_ids = suffix_mask - return token_type_ids - - def create_input_tokens(self, state, lang_text, actions=None): - bsize = state.shape[0] - device = state.device - bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1] - discretized = torch.bucketize(state, bins) - 1 - discretized = discretized[:, :32] - - prefix_texts = [] - state_text = [] - for txt, disc in zip(lang_text, discretized, strict=False): - cleaned = txt.lower().strip().replace("_", " ") - state_str = " ".join(str(val.item()) for val in disc) - prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n") - state_text.append(f"State: {state_str};\n") - - prefix_out = self.paligemma_tokenizer( - prefix_texts, add_special_tokens=True, return_tensors="pt", padding="longest", truncation=False - ) - prefix_ids = prefix_out["input_ids"].to(device) - prefix_mask = prefix_out["attention_mask"].to(device) - prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu() - - if actions is not None: - actions_norm = self.normalize_actions(actions) - actions_pad = F.pad( - actions_norm, (0, max(0, self.config.max_action_dim - actions_norm.shape[2])), value=0 - )[:, :, : self.config.max_action_dim] - fast_out = self.fast_tokenizer_wrapper( - actions_pad.cpu(), - ) - act_ids = fast_out["input_ids"] - act_mask = fast_out["attention_mask"].to(device) - - act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device) - # Replace action with 0 to pad tokens - act_ids = torch.where( - act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens, - self.pad_token_id, - act_ids, - ) - - eos_token = torch.tensor( - [self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device - ).expand(bsize, -1) - eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1) - bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt") - bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device) - bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device) - act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1) - act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1) - act_mask = act_mask.to(device) - else: - act_ids = torch.empty(bsize, self.pad_token_id, dtype=torch.long, device=device) - act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device) - final_ids = torch.cat([prefix_ids, act_ids], dim=1) - - final_mask = torch.cat([prefix_mask, act_mask], dim=1) - batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()} - - # Use tokenizer pad function - padded_output = self.paligemma_tokenizer.pad( - batch_inputs, padding="longest", max_length=180, return_tensors="pt" - ) - padded_mask = padded_output["attention_mask"] - - # define tensor of padding lengths - att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens - - token_type_ids = self.create_token_type_ids(padded_mask=padded_mask, prefix_len=prefix_lens) - - padded_output["padded_mask"] = padded_output.pop("attention_mask") - padded_output["attention_mask"] = att_mask - # loss is computed not on prefix, and not on padding - padded_output["loss_mask"] = att_mask & padded_output["padded_mask"] - padded_output["token_type_ids"] = token_type_ids - return padded_output - - def shift_padding_side( - self, - tokens: torch.Tensor, - ar_mask: torch.Tensor, - padding_mask: torch.Tensor, - loss_mask: torch.Tensor, - targets: torch.Tensor, - token_type_ids: torch.Tensor, - padding_side: str = "right", - ) -> tuple[torch.Tensor]: - if padding_side not in ["right", "left"]: - return tokens, ar_mask, padding_mask, loss_mask, targets, token_type_ids - - new_tokens = torch.empty_like(tokens) - new_ar_masks = torch.empty_like(ar_mask) - new_padding_mask = torch.empty_like(padding_mask) - new_loss_mask = torch.empty_like(loss_mask) - new_targets = torch.empty_like(targets) - new_token_type_ids = torch.empty_like(token_type_ids) - batch_size = tokens.shape[0] - for i in range(batch_size): - padding_indices = torch.where(padding_mask[i] == 0)[0] - non_padding_indices = torch.where(padding_mask[i] == 1)[0] - if padding_side == "left": - new_indices = torch.cat((padding_indices, non_padding_indices), dim=0) - else: - new_indices = torch.cat((non_padding_indices, padding_indices), dim=0) - new_tokens[i] = tokens[i].index_select(0, new_indices) - new_ar_masks[i] = ar_mask[i].index_select(0, new_indices) - new_padding_mask[i] = padding_mask[i].index_select(0, new_indices) - new_loss_mask[i] = loss_mask[i].index_select(0, new_indices) - new_targets[i] = targets[i].index_select(0, new_indices) - new_token_type_ids[i] = token_type_ids[i].index_select(0, new_indices) - - return new_tokens, new_ar_masks, new_padding_mask, new_loss_mask, new_targets, new_token_type_ids - - def forward(self, batch: dict[str, Tensor]): - device = batch[OBS_STATE].device - # TODO: keep like this or move to the policy .forward - images, img_masks = self.prepare_images(batch) - - padded_outs = self.create_input_tokens( - state=batch[OBS_STATE], - lang_text=batch["task"], - actions=batch[ACTION], - ) - - embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs( - images, - img_masks, - padded_outs["input_ids"], - padded_outs["padded_mask"], - padded_outs["attention_mask"], - padded_outs["loss_mask"], - padded_outs["token_type_ids"], - padding_side=self.padding_side, - ) - position_ids = torch.cumsum(pad_masks, dim=1) - 1 - token_type_ids = token_type_ids.to(dtype=torch.int64) - past_seen_tokens = 0 - cache_position = torch.arange(past_seen_tokens, past_seen_tokens + embs.shape[1], device=embs.device) - pad_masks = block_causal_update_causal_mask( - attention_mask=pad_masks, - past_key_values=None, - cache_position=cache_position, - input_tensor=embs, - token_type_ids=token_type_ids, - dtype=self.pi0_paligemma.dtype, - attn_implementation=self.pi0_paligemma.config.text_config._attn_implementation, - ) - outputs = self.pi0_paligemma.forward( - input_ids=None, - token_type_ids=None, - attention_mask=pad_masks, - position_ids=position_ids, - past_key_values=None, - inputs_embeds=embs, - use_cache=False, - labels=None, - ) - - logits = outputs.logits - - loss_fct = nn.CrossEntropyLoss(reduction="none") - - # Shift left for next-step prediction - logits = logits[:, :-1, :] - targets = targets[:, 1:].to(device) # Shift targets - loss_mask = loss_mask[:, 1:].to(device) # Ensure correct shape - - # Compute per-token loss - token_loss = loss_fct(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1)) - - # Apply loss mask - token_loss = token_loss * loss_mask.reshape(-1) - - # Compute final loss - loss = token_loss.sum() / torch.clamp(loss_mask.sum(), min=1) - - # Return loss dictionary - loss_dict = {"ce_loss": loss.item(), "loss": loss} - return loss_dict - - def decode_actions_with_fast( - self, - tokens: list[list[int]], - *, - time_horizon: int | None = None, - action_dim: int | None = None, - relaxed_decoding: bool = True, - ) -> np.array: - """ - Adapt original decoding in FAST to always return actions instead of zeros. - """ - self.time_horizon = ( - time_horizon or self.fast_tokenizer.time_horizon or self.fast_tokenizer.called_time_horizon - ) - self.action_dim = ( - action_dim or self.fast_tokenizer.action_dim or self.fast_tokenizer.called_action_dim - ) - - # Cache the time horizon and action dimension for the next call - self.called_time_horizon = self.time_horizon - self.called_action_dim = self.action_dim - - assert self.time_horizon is not None and self.action_dim is not None, ( - "Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim." - ) - - decoded_actions = [] - for token in tokens: - try: - decoded_tokens = self.fast_tokenizer.bpe_tokenizer.decode(token) - decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.fast_tokenizer.min_token - if relaxed_decoding: - # Expected sequence length - expected_seq_len = self.time_horizon * self.action_dim - diff = expected_seq_len - decoded_dct_coeff.shape[0] - # Apply truncation if too long - if diff < 0: - decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len] # Truncate on the right - # Apply padding if too short - elif diff > 0: - decoded_dct_coeff = np.pad( - decoded_dct_coeff, (0, diff), mode="constant", constant_values=0 - ) - - decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim) - assert decoded_dct_coeff.shape == ( - self.time_horizon, - self.action_dim, - ), ( - f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})" - ) - except Exception as e: - print(f"Error decoding tokens: {e}") - print(f"Tokens: {token}") - decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim)) - decoded_actions.append(idct(decoded_dct_coeff / self.fast_tokenizer.scale, axis=0, norm="ortho")) - return np.stack(decoded_actions) - - def extract_actions(self, tokens: torch.Tensor, action_horizon: int, action_dim: int) -> torch.Tensor: - """ - Extracts actions from predicted output tokens using the FAST model. - - Args: - tokens (torch.Tensor): The input tensor of tokenized outputs. - action_horizon (int): The number of timesteps for actions. - action_dim (int): The dimensionality of each action. - - Returns: - torch.Tensor: The extracted actions as a tensor of shape (action_horizon, action_dim). - """ - # Decode predicted output tokens - decoded_tokens = self.paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True) - cleaned_tokens = [ - tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip() - for tokens_sequence in decoded_tokens - ] - raw_action_tokens = [ - self.processor.tokenizer.encode(sample_tokens, return_tensors="pt", padding=False) - for sample_tokens in cleaned_tokens - ] # something like this should be robust #looks good - action_tokens = [ - self._act_tokens_to_paligemma_tokens(raw_action_token) for raw_action_token in raw_action_tokens - ] - # returns the tensor of decoded actions per sample in a list - decoded_actions = [ - torch.tensor( - self.decode_actions_with_fast( - tok.tolist(), - time_horizon=action_horizon, - action_dim=action_dim, - relaxed_decoding=self.config.relaxed_action_decoding, - ), - device=tokens.device, - ).squeeze(0) - for tok in action_tokens - ] - - return torch.stack( - decoded_actions, - dim=0, - ) - - def generate_actions(self, batch: dict[str, Tensor]): - # TODO: keep like this or move to the policy .forward - images, img_masks = self.prepare_images(batch) - - padded_outs = self.create_input_tokens(state=batch[OBS_STATE], lang_text=batch["task"], actions=None) - embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs( - images, - img_masks, - padded_outs["input_ids"], - padded_outs["padded_mask"], - padded_outs["attention_mask"], - padded_outs["loss_mask"], - padded_outs["token_type_ids"], - padding_side="left", - ) - token_type_ids = token_type_ids.to(dtype=torch.int64) - prefix_position_ids = torch.cumsum(pad_masks, dim=1) - 1 - output_tokens = self.pi0_paligemma.generate( - input_ids=None, - attention_mask=pad_masks, - position_ids=prefix_position_ids, - past_key_values=None, - inputs_embeds=embs, - use_cache=self.config.use_cache, - max_new_tokens=self.config.max_decoding_steps, - do_sample=False, - num_beams=1, - token_type_ids=token_type_ids, - ) - actions = self.extract_actions(output_tokens, self.action_horizon, self.action_dim) - return actions - - def embed_image(self, image: torch.Tensor): - # Handle different transformers versions - if hasattr(self.pi0_paligemma, "get_image_features"): - return self.pi0_paligemma.get_image_features(image) - else: - return self.pi0_paligemma.model.get_image_features(image) - - def embed_inputs( - self, - images, - img_masks, - tokens, - pad_mask, - ar_mask, - loss_mask, - token_type_ids, - padding_side: str = "right", - ): - # TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty - # images are a list of same size - # vectorizing everything! - device = images[0].device - image_embedding_dim = images[0].shape[-1] # TODO should be from self.config - all_images = torch.stack(images, dim=1).to(device) - b, n, c, h, w = all_images.shape - all_images = all_images.view(b * n, c, h, w) - embedded = self.embed_image(all_images).to(device) - b_n, p, image_embedding_dim = embedded.shape # Extract current dimensions - m = b_n // b # Compute the number of images per sample dynamically - - # Reshape dynamically - embedded = embedded.view(b, m, p, image_embedding_dim) - tokens_embs = self.embed_tokens(tokens.to(device)) - - img_masks = torch.stack(img_masks, dim=1).unsqueeze(-1).to(device) - num_img_emb = embedded.shape[2] - img_pad_masks = img_masks.repeat(1, 1, num_img_emb).view(b, -1) - img_att_masks = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1) - - image_target_tokens = ( - torch.ones((b, n, num_img_emb), dtype=torch.long, device=device) * self.pad_token_id - ).reshape(b, -1) - image_loss_mask = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1) - - embedded = embedded.reshape(b, n * num_img_emb, image_embedding_dim) # Shape: (B, N*P, D) - - embs = torch.cat([embedded, tokens_embs], dim=1).to(device) - pad_masks = torch.cat([img_pad_masks, pad_mask.to(device)], dim=1) - att_masks = torch.cat([img_att_masks, ar_mask.to(device)], dim=1) - loss_masks = torch.cat([image_loss_mask, loss_mask.to(device)], dim=1) - targets = torch.cat([image_target_tokens, tokens.to(device)], dim=1) - token_type_ids = torch.cat([img_att_masks, token_type_ids.to(device)], dim=1) - - # Shift pad tokens to the left (.generate()) or right (.train()) - embs, att_masks, pad_masks, loss_masks, targets, token_type_ids = self.shift_padding_side( - embs, att_masks, pad_masks, loss_masks, targets, token_type_ids, padding_side=padding_side - ) - - targets = torch.where(targets == self.pad_token_id, self.ignore_index, targets) - return embs, pad_masks, att_masks, targets, loss_masks, token_type_ids - - -def resize_with_pad(img, width, height, pad_value=0, interpolate_like_pi=True): - # assume no-op when width height fits already - if img.ndim != 4: - raise ValueError(f"(b,c,h,w) expected, but {img.shape}") - - cur_height, cur_width = img.shape[2:] - - ratio = max(cur_width / width, cur_height / height) - resized_height = int(cur_height / ratio) - resized_width = int(cur_width / ratio) - - if interpolate_like_pi: - img = (img * 255.0).to(dtype=torch.uint8) - img = img.permute(0, 2, 3, 1) - original_device = img.device - img = img.to(device="cpu").numpy() - imgs = [] - for sub_img in img: - sub_img = Image.fromarray(sub_img) - resized_img = sub_img.resize((resized_width, resized_height), resample=2) - resized_img = torch.from_numpy(np.array(resized_img)) - imgs.append(resized_img) - img = torch.stack(imgs, dim=0) - img = img.permute(0, 3, 1, 2) - resized_img = img.to(device=original_device, dtype=torch.float32) / 255.0 - else: - resized_img = F.interpolate( - img, size=(resized_height, resized_width), mode="bilinear", align_corners=False - ) - - pad_height = max(0, int(height - resized_height)) - pad_width = max(0, int(width - resized_width)) - - # pad on left and top of image - padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value) - return padded_img diff --git a/src/lerobot/policies/pi0fast/processor_pi0fast.py b/src/lerobot/policies/pi0fast/processor_pi0fast.py deleted file mode 100644 index 95b5e541b..000000000 --- a/src/lerobot/policies/pi0fast/processor_pi0fast.py +++ /dev/null @@ -1,92 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any - -import torch - -from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig -from lerobot.processor import ( - AddBatchDimensionProcessorStep, - DeviceProcessorStep, - NormalizerProcessorStep, - PolicyAction, - PolicyProcessorPipeline, - RenameObservationsProcessorStep, - UnnormalizerProcessorStep, -) -from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action -from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME - - -def make_pi0fast_pre_post_processors( - config: PI0FASTConfig, - dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, -) -> tuple[ - PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], - PolicyProcessorPipeline[PolicyAction, PolicyAction], -]: - """ - Constructs pre-processor and post-processor pipelines for the PI0Fast policy. - - The pre-processing pipeline prepares input data for the model by: - 1. Renaming features to match pretrained configurations. - 2. Normalizing input and output features based on dataset statistics. - 3. Adding a batch dimension. - 4. Moving all data to the specified device. - - The post-processing pipeline handles the model's output by: - 1. Moving data to the CPU. - 2. Unnormalizing the output features to their original scale. - - Args: - config: The configuration object for the PI0Fast policy. - dataset_stats: A dictionary of statistics for normalization. - preprocessor_kwargs: Additional arguments for the pre-processor pipeline. - postprocessor_kwargs: Additional arguments for the post-processor pipeline. - - Returns: - A tuple containing the configured pre-processor and post-processor pipelines. - """ - - input_steps = [ - RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one - AddBatchDimensionProcessorStep(), - DeviceProcessorStep(device=config.device), - NormalizerProcessorStep( - features={**config.input_features, **config.output_features}, - norm_map=config.normalization_mapping, - stats=dataset_stats, - ), - ] - output_steps = [ - UnnormalizerProcessorStep( - features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats - ), - DeviceProcessorStep(device="cpu"), - ] - return ( - PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( - steps=input_steps, - name=POLICY_PREPROCESSOR_DEFAULT_NAME, - ), - PolicyProcessorPipeline[PolicyAction, PolicyAction]( - steps=output_steps, - name=POLICY_POSTPROCESSOR_DEFAULT_NAME, - to_transition=policy_action_to_transition, - to_output=transition_to_policy_action, - ), - ) diff --git a/src/lerobot/policies/utils.py b/src/lerobot/policies/utils.py index 5a3994cdf..21b39a80e 100644 --- a/src/lerobot/policies/utils.py +++ b/src/lerobot/policies/utils.py @@ -16,10 +16,16 @@ import logging from collections import deque +from typing import Any +import numpy as np import torch from torch import nn +from lerobot.datasets.utils import build_dataset_frame +from lerobot.processor import PolicyAction, RobotAction, RobotObservation +from lerobot.utils.constants import ACTION, OBS_STR + def populate_queues( queues: dict[str, deque], batch: dict[str, torch.Tensor], exclude_keys: list[str] | None = None @@ -85,3 +91,110 @@ def log_model_loading_keys(missing_keys: list[str], unexpected_keys: list[str]) logging.warning(f"Missing key(s) when loading model: {missing_keys}") if unexpected_keys: logging.warning(f"Unexpected key(s) when loading model: {unexpected_keys}") + + +# TODO(Steven): Move this function to a proper preprocessor step +def prepare_observation_for_inference( + observation: dict[str, np.ndarray], + device: torch.device, + task: str | None = None, + robot_type: str | None = None, +) -> RobotObservation: + """Converts observation data to model-ready PyTorch tensors. + + This function takes a dictionary of NumPy arrays, performs necessary + preprocessing, and prepares it for model inference. The steps include: + 1. Converting NumPy arrays to PyTorch tensors. + 2. Normalizing and permuting image data (if any). + 3. Adding a batch dimension to each tensor. + 4. Moving all tensors to the specified compute device. + 5. Adding task and robot type information to the dictionary. + + Args: + observation: A dictionary mapping observation names (str) to NumPy + array data. For images, the format is expected to be (H, W, C). + device: The PyTorch device (e.g., 'cpu' or 'cuda') to which the + tensors will be moved. + task: An optional string identifier for the current task. + robot_type: An optional string identifier for the robot being used. + + Returns: + A dictionary where values are PyTorch tensors preprocessed for + inference, residing on the target device. Image tensors are reshaped + to (C, H, W) and normalized to a [0, 1] range. + """ + for name in observation: + observation[name] = torch.from_numpy(observation[name]) + if "image" in name: + observation[name] = observation[name].type(torch.float32) / 255 + observation[name] = observation[name].permute(2, 0, 1).contiguous() + observation[name] = observation[name].unsqueeze(0) + observation[name] = observation[name].to(device) + + observation["task"] = task if task else "" + observation["robot_type"] = robot_type if robot_type else "" + + return observation + + +def build_inference_frame( + observation: dict[str, Any], + device: torch.device, + ds_features: dict[str, dict], + task: str | None = None, + robot_type: str | None = None, +) -> RobotObservation: + """Constructs a model-ready observation tensor dict from a raw observation. + + This utility function orchestrates the process of converting a raw, + unstructured observation from an environment into a structured, + tensor-based format suitable for passing to a policy model. + + Args: + observation: The raw observation dictionary, which may contain + superfluous keys. + device: The target PyTorch device for the final tensors. + ds_features: A configuration dictionary that specifies which features + to extract from the raw observation. + task: An optional string identifier for the current task. + robot_type: An optional string identifier for the robot being used. + + Returns: + A dictionary of preprocessed tensors ready for model inference. + """ + # Extracts the correct keys from the incoming raw observation + observation = build_dataset_frame(ds_features, observation, prefix=OBS_STR) + + # Performs the necessary conversions to the observation + observation = prepare_observation_for_inference(observation, device, task, robot_type) + + return observation + + +def make_robot_action(action_tensor: PolicyAction, ds_features: dict[str, dict]) -> RobotAction: + """Converts a policy's output tensor into a dictionary of named actions. + + This function translates the numerical output from a policy model into a + human-readable and robot-consumable format, where each dimension of the + action tensor is mapped to a named motor or actuator command. + + Args: + action_tensor: A PyTorch tensor representing the policy's action, + typically with a batch dimension (e.g., shape [1, action_dim]). + ds_features: A configuration dictionary containing metadata, including + the names corresponding to each index of the action tensor. + + Returns: + A dictionary mapping action names (e.g., "joint_1_motor") to their + corresponding floating-point values, ready to be sent to a robot + controller. + """ + # TODO(Steven): Check if these steps are already in all postprocessor policies + action_tensor = action_tensor.squeeze(0) + action_tensor = action_tensor.to("cpu") + + action_names = ds_features[ACTION]["names"] + act_processed_policy: RobotAction = { + f"{name}": float(action_tensor[i]) for i, name in enumerate(action_names) + } + return act_processed_policy diff --git a/src/lerobot/rl/buffer.py b/src/lerobot/rl/buffer.py index 917e4e2cc..81aa29c48 100644 --- a/src/lerobot/rl/buffer.py +++ b/src/lerobot/rl/buffer.py @@ -607,6 +607,7 @@ class ReplayBuffer: lerobot_dataset.save_episode() lerobot_dataset.stop_image_writer() + lerobot_dataset.finalize() return lerobot_dataset diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py index ad36f1b36..f9c9d0d7a 100644 --- a/src/lerobot/rl/gym_manipulator.py +++ b/src/lerobot/rl/gym_manipulator.py @@ -696,7 +696,7 @@ def control_loop( episode_idx += 1 if dataset is not None: - if transition[TransitionKey.INFO].get("rerecord_episode", False): + if transition[TransitionKey.INFO].get(TeleopEvents.RERECORD_EPISODE, False): logging.info(f"Re-recording episode {episode_idx}") dataset.clear_episode_buffer() episode_idx -= 1 diff --git a/src/lerobot/rl/wandb_utils.py b/src/lerobot/rl/wandb_utils.py index 8c92e7145..7b7f8a57b 100644 --- a/src/lerobot/rl/wandb_utils.py +++ b/src/lerobot/rl/wandb_utils.py @@ -99,7 +99,7 @@ class WandBLogger: cfg.wandb.run_id = run_id # Handle custom step key for rl asynchronous training. self._wandb_custom_step_key: set[str] | None = None - print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"])) + logging.info(colored("Logs will be synced with wandb.", "blue", attrs=["bold"])) logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}") self._wandb = wandb diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py new file mode 100644 index 000000000..83ba027bc --- /dev/null +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Edit LeRobot datasets using various transformation tools. + +This script allows you to delete episodes, split datasets, merge datasets, +and remove features. When new_repo_id is specified, creates a new dataset. + +Usage Examples: + +Delete episodes 0, 2, and 5 from a dataset: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type delete_episodes \ + --operation.episode_indices "[0, 2, 5]" + +Delete episodes and save to a new dataset: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --new_repo_id lerobot/pusht_filtered \ + --operation.type delete_episodes \ + --operation.episode_indices "[0, 2, 5]" + +Split dataset by fractions: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type split \ + --operation.splits '{"train": 0.8, "val": 0.2}' + +Split dataset by episode indices: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type split \ + --operation.splits '{"train": [0, 1, 2, 3], "val": [4, 5]}' + +Split into more than two splits: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type split \ + --operation.splits '{"train": 0.6, "val": 0.2, "test": 0.2}' + +Merge multiple datasets: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht_merged \ + --operation.type merge \ + --operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']" + +Remove camera feature: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type remove_feature \ + --operation.feature_names "['observation.images.top']" + +Using JSON config file: + python -m lerobot.scripts.lerobot_edit_dataset \ + --config_path path/to/edit_config.json +""" + +import logging +import shutil +from dataclasses import dataclass +from pathlib import Path + +from lerobot.configs import parser +from lerobot.datasets.dataset_tools import ( + delete_episodes, + merge_datasets, + remove_feature, + split_dataset, +) +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.utils.constants import HF_LEROBOT_HOME +from lerobot.utils.utils import init_logging + + +@dataclass +class DeleteEpisodesConfig: + type: str = "delete_episodes" + episode_indices: list[int] | None = None + + +@dataclass +class SplitConfig: + type: str = "split" + splits: dict[str, float | list[int]] | None = None + + +@dataclass +class MergeConfig: + type: str = "merge" + repo_ids: list[str] | None = None + + +@dataclass +class RemoveFeatureConfig: + type: str = "remove_feature" + feature_names: list[str] | None = None + + +@dataclass +class EditDatasetConfig: + repo_id: str + operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig + root: str | None = None + new_repo_id: str | None = None + push_to_hub: bool = False + + +def get_output_path(repo_id: str, new_repo_id: str | None, root: Path | None) -> tuple[str, Path]: + if new_repo_id: + output_repo_id = new_repo_id + output_dir = root / new_repo_id if root else HF_LEROBOT_HOME / new_repo_id + else: + output_repo_id = repo_id + dataset_path = root / repo_id if root else HF_LEROBOT_HOME / repo_id + old_path = Path(str(dataset_path) + "_old") + + if dataset_path.exists(): + if old_path.exists(): + shutil.rmtree(old_path) + shutil.move(str(dataset_path), str(old_path)) + + output_dir = dataset_path + + return output_repo_id, output_dir + + +def handle_delete_episodes(cfg: EditDatasetConfig) -> None: + if not isinstance(cfg.operation, DeleteEpisodesConfig): + raise ValueError("Operation config must be DeleteEpisodesConfig") + + if not cfg.operation.episode_indices: + raise ValueError("episode_indices must be specified for delete_episodes operation") + + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) + output_repo_id, output_dir = get_output_path( + cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None + ) + + if cfg.new_repo_id is None: + dataset.root = Path(str(dataset.root) + "_old") + + logging.info(f"Deleting episodes {cfg.operation.episode_indices} from {cfg.repo_id}") + new_dataset = delete_episodes( + dataset, + episode_indices=cfg.operation.episode_indices, + output_dir=output_dir, + repo_id=output_repo_id, + ) + + logging.info(f"Dataset saved to {output_dir}") + logging.info(f"Episodes: {new_dataset.meta.total_episodes}, Frames: {new_dataset.meta.total_frames}") + + if cfg.push_to_hub: + logging.info(f"Pushing to hub as {output_repo_id}") + LeRobotDataset(output_repo_id, root=output_dir).push_to_hub() + + +def handle_split(cfg: EditDatasetConfig) -> None: + if not isinstance(cfg.operation, SplitConfig): + raise ValueError("Operation config must be SplitConfig") + + if not cfg.operation.splits: + raise ValueError( + "splits dict must be specified with split names as keys and fractions/episode lists as values" + ) + + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) + + logging.info(f"Splitting dataset {cfg.repo_id} with splits: {cfg.operation.splits}") + split_datasets = split_dataset(dataset, splits=cfg.operation.splits) + + for split_name, split_ds in split_datasets.items(): + split_repo_id = f"{cfg.repo_id}_{split_name}" + logging.info( + f"{split_name}: {split_ds.meta.total_episodes} episodes, {split_ds.meta.total_frames} frames" + ) + + if cfg.push_to_hub: + logging.info(f"Pushing {split_name} split to hub as {split_repo_id}") + LeRobotDataset(split_ds.repo_id, root=split_ds.root).push_to_hub() + + +def handle_merge(cfg: EditDatasetConfig) -> None: + if not isinstance(cfg.operation, MergeConfig): + raise ValueError("Operation config must be MergeConfig") + + if not cfg.operation.repo_ids: + raise ValueError("repo_ids must be specified for merge operation") + + if not cfg.repo_id: + raise ValueError("repo_id must be specified as the output repository for merged dataset") + + logging.info(f"Loading {len(cfg.operation.repo_ids)} datasets to merge") + datasets = [LeRobotDataset(repo_id, root=cfg.root) for repo_id in cfg.operation.repo_ids] + + output_dir = Path(cfg.root) / cfg.repo_id if cfg.root else HF_LEROBOT_HOME / cfg.repo_id + + logging.info(f"Merging datasets into {cfg.repo_id}") + merged_dataset = merge_datasets( + datasets, + output_repo_id=cfg.repo_id, + output_dir=output_dir, + ) + + logging.info(f"Merged dataset saved to {output_dir}") + logging.info( + f"Episodes: {merged_dataset.meta.total_episodes}, Frames: {merged_dataset.meta.total_frames}" + ) + + if cfg.push_to_hub: + logging.info(f"Pushing to hub as {cfg.repo_id}") + LeRobotDataset(merged_dataset.repo_id, root=output_dir).push_to_hub() + + +def handle_remove_feature(cfg: EditDatasetConfig) -> None: + if not isinstance(cfg.operation, RemoveFeatureConfig): + raise ValueError("Operation config must be RemoveFeatureConfig") + + if not cfg.operation.feature_names: + raise ValueError("feature_names must be specified for remove_feature operation") + + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) + output_repo_id, output_dir = get_output_path( + cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None + ) + + if cfg.new_repo_id is None: + dataset.root = Path(str(dataset.root) + "_old") + + logging.info(f"Removing features {cfg.operation.feature_names} from {cfg.repo_id}") + new_dataset = remove_feature( + dataset, + feature_names=cfg.operation.feature_names, + output_dir=output_dir, + repo_id=output_repo_id, + ) + + logging.info(f"Dataset saved to {output_dir}") + logging.info(f"Remaining features: {list(new_dataset.meta.features.keys())}") + + if cfg.push_to_hub: + logging.info(f"Pushing to hub as {output_repo_id}") + LeRobotDataset(output_repo_id, root=output_dir).push_to_hub() + + +@parser.wrap() +def edit_dataset(cfg: EditDatasetConfig) -> None: + operation_type = cfg.operation.type + + if operation_type == "delete_episodes": + handle_delete_episodes(cfg) + elif operation_type == "split": + handle_split(cfg) + elif operation_type == "merge": + handle_merge(cfg) + elif operation_type == "remove_feature": + handle_remove_feature(cfg) + else: + raise ValueError( + f"Unknown operation type: {operation_type}\n" + f"Available operations: delete_episodes, split, merge, remove_feature" + ) + + +def main() -> None: + init_logging() + edit_dataset() + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index d45be5c42..0fdec9286 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -27,7 +27,7 @@ lerobot-eval \ --eval.batch_size=10 \ --eval.n_episodes=10 \ --use_amp=false \ - --device=cuda + --policy.device=cuda ``` OR, you want to evaluate a model checkpoint from the LeRobot training script for 10 episodes. @@ -38,7 +38,7 @@ lerobot-eval \ --eval.batch_size=10 \ --eval.n_episodes=10 \ --use_amp=false \ - --device=cuda + --policy.device=cuda ``` Note that in both examples, the repo/folder should contain at least `config.json` and `model.safetensors` files. @@ -180,9 +180,15 @@ def rollout( render_callback(env) # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't - # available of none of the envs finished. + # available if none of the envs finished. if "final_info" in info: - successes = [info["is_success"] if info is not None else False for info in info["final_info"]] + final_info = info["final_info"] + if not isinstance(final_info, dict): + raise RuntimeError( + "Unsupported `final_info` format: expected dict (Gymnasium >= 1.0). " + "You're likely using an older version of gymnasium (< 1.0). Please upgrade." + ) + successes = final_info["is_success"].tolist() else: successes = [False] * env.num_envs diff --git a/src/lerobot/scripts/lerobot_find_cameras.py b/src/lerobot/scripts/lerobot_find_cameras.py index e17dca805..0248a2768 100644 --- a/src/lerobot/scripts/lerobot_find_cameras.py +++ b/src/lerobot/scripts/lerobot_find_cameras.py @@ -180,7 +180,7 @@ def create_camera_instance(cam_meta: dict[str, Any]) -> dict[str, Any] | None: if instance: logger.info(f"Connecting to {cam_type} camera: {cam_id}...") - instance.connect(warmup=False) + instance.connect(warmup=True) return {"instance": instance, "meta": cam_meta} except Exception as e: logger.error(f"Failed to connect or configure {cam_type} camera {cam_id}: {e}") diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 91cbbfb86..28da73be2 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -79,6 +79,7 @@ from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts from lerobot.datasets.video_utils import VideoEncodingManager from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.utils import make_robot_action from lerobot.processor import ( PolicyAction, PolicyProcessorPipeline, @@ -324,10 +325,7 @@ def record_loop( robot_type=robot.robot_type, ) - action_names = dataset.features[ACTION]["names"] - act_processed_policy: RobotAction = { - f"{name}": float(action_values[i]) for i, name in enumerate(action_names) - } + act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features) elif policy is None and isinstance(teleop, Teleoperator): act = teleop.get_action() diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 7023cb1d0..f4b1118c3 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -21,8 +21,8 @@ from pprint import pformat from typing import Any import torch +from accelerate import Accelerator from termcolor import colored -from torch.amp import GradScaler from torch.optim import Optimizer from lerobot.configs import parser @@ -35,7 +35,6 @@ from lerobot.envs.utils import close_envs from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.utils import get_device_from_parameters from lerobot.rl.wandb_utils import WandBLogger from lerobot.scripts.lerobot_eval import eval_policy_all from lerobot.utils.logging_utils import AverageMeter, MetricsTracker @@ -49,7 +48,6 @@ from lerobot.utils.train_utils import ( ) from lerobot.utils.utils import ( format_big_number, - get_safe_torch_device, has_method, init_logging, ) @@ -61,16 +59,15 @@ def update_policy( batch: Any, optimizer: Optimizer, grad_clip_norm: float, - grad_scaler: GradScaler, + accelerator: Accelerator, lr_scheduler=None, - use_amp: bool = False, lock=None, ) -> tuple[MetricsTracker, dict]: """ Performs a single training step to update the policy's weights. This function executes the forward and backward passes, clips gradients, and steps the optimizer and - learning rate scheduler. It also handles mixed-precision training via a GradScaler. + learning rate scheduler. Accelerator handles mixed-precision training automatically. Args: train_metrics: A MetricsTracker instance to record training statistics. @@ -78,9 +75,8 @@ def update_policy( batch: A batch of training data. optimizer: The optimizer used to update the policy's parameters. grad_clip_norm: The maximum norm for gradient clipping. - grad_scaler: The GradScaler for automatic mixed-precision training. + accelerator: The Accelerator instance for distributed training and mixed precision. lr_scheduler: An optional learning rate scheduler. - use_amp: A boolean indicating whether to use automatic mixed precision. lock: An optional lock for thread-safe optimizer updates. Returns: @@ -89,28 +85,27 @@ def update_policy( - A dictionary of outputs from the policy's forward pass, for logging purposes. """ start_time = time.perf_counter() - device = get_device_from_parameters(policy) policy.train() - with torch.autocast(device_type=device.type) if use_amp else nullcontext(): + + # Let accelerator handle mixed precision + with accelerator.autocast(): loss, output_dict = policy.forward(batch) # TODO(rcadene): policy.unnormalize_outputs(out_dict) - grad_scaler.scale(loss).backward() - # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**. - grad_scaler.unscale_(optimizer) + # Use accelerator's backward method + accelerator.backward(loss) - grad_norm = torch.nn.utils.clip_grad_norm_( - policy.parameters(), - grad_clip_norm, - error_if_nonfinite=False, - ) + # Clip gradients if specified + if grad_clip_norm > 0: + grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm) + else: + grad_norm = torch.nn.utils.clip_grad_norm_( + policy.parameters(), float("inf"), error_if_nonfinite=False + ) - # Optimizer's gradients are already unscaled, so scaler.step does not unscale them, - # although it still skips optimizer.step() if the gradients contain infs or NaNs. + # Optimizer step with lock if lock is not None else nullcontext(): - grad_scaler.step(optimizer) - # Updates the scale for next iteration. - grad_scaler.update() + optimizer.step() optimizer.zero_grad() @@ -118,9 +113,9 @@ def update_policy( if lr_scheduler is not None: lr_scheduler.step() - if has_method(policy, "update"): - # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC). - policy.update() + # Update internal buffers if policy has update method + if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"): + accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update() train_metrics.loss = loss.item() train_metrics.grad_norm = grad_norm.item() @@ -209,7 +204,7 @@ def wrap_policy_in_peft_model(cfg, policy): @parser.wrap() -def train(cfg: TrainPipelineConfig): +def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): """ Main function to train a policy. @@ -223,36 +218,68 @@ def train(cfg: TrainPipelineConfig): Args: cfg: A `TrainPipelineConfig` object containing all training configurations. + accelerator: Optional Accelerator instance. If None, one will be created automatically. """ cfg.validate() - logging.info(pformat(cfg.to_dict())) - if cfg.wandb.enable and cfg.wandb.project: + # Create Accelerator if not provided + # It will automatically detect if running in distributed mode or single-process mode + # We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes + # We set find_unused_parameters=True to handle models with conditional computation + if accelerator is None: + from accelerate.utils import DistributedDataParallelKwargs + + ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs]) + + init_logging(accelerator=accelerator) + + # Determine if this is the main process (for logging and checkpointing) + # When using accelerate, only the main process should log to avoid duplicate outputs + is_main_process = accelerator.is_main_process + + # Only log on main process + if is_main_process: + logging.info(pformat(cfg.to_dict())) + + # Initialize wandb only on main process + if cfg.wandb.enable and cfg.wandb.project and is_main_process: wandb_logger = WandBLogger(cfg) else: wandb_logger = None - logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) + if is_main_process: + logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) if cfg.seed is not None: - set_seed(cfg.seed) + set_seed(cfg.seed, accelerator=accelerator) - # Check device is available - device = get_safe_torch_device(cfg.policy.device, log=True) + # Use accelerator's device + device = accelerator.device torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True - logging.info("Creating dataset") - dataset = make_dataset(cfg) + # Dataset loading synchronization: main process downloads first to avoid race conditions + if is_main_process: + logging.info("Creating dataset") + dataset = make_dataset(cfg) + + accelerator.wait_for_everyone() + + # Now all other processes can safely load the dataset + if not is_main_process: + dataset = make_dataset(cfg) # Create environment used for evaluating checkpoints during training on simulation data. # On real-world data, no need to create an environment as evaluations are done outside train.py, # using the eval.py instead, with gym_dora environment and dora-rs. eval_env = None if cfg.eval_freq > 0 and cfg.env is not None: - logging.info("Creating env") + if is_main_process: + logging.info("Creating env") eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) - logging.info("Creating policy") + if is_main_process: + logging.info("Creating policy") policy = make_policy( cfg=cfg.policy, ds_meta=dataset.meta, @@ -262,6 +289,9 @@ def train(cfg: TrainPipelineConfig): logging.info("Using PEFT! Wrapping model.") policy = wrap_policy_in_peft_model(cfg, policy) + # Wait for all processes to finish policy creation before continuing + accelerator.wait_for_everyone() + # Create processors - only provide dataset_stats if not resuming from saved processors processor_kwargs = {} postprocessor_kwargs = {} @@ -293,9 +323,9 @@ def train(cfg: TrainPipelineConfig): **postprocessor_kwargs, ) - logging.info("Creating optimizer and scheduler") + if is_main_process: + logging.info("Creating optimizer and scheduler") optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) - grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp) step = 0 # number of policy updates (forward + backward + optim) @@ -305,14 +335,18 @@ def train(cfg: TrainPipelineConfig): num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_total_params = sum(p.numel() for p in policy.parameters()) - logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") - if cfg.env is not None: - logging.info(f"{cfg.env.task=}") - logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})") - logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})") - logging.info(f"{dataset.num_episodes=}") - logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") - logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") + if is_main_process: + logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") + if cfg.env is not None: + logging.info(f"{cfg.env.task=}") + logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})") + logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})") + logging.info(f"{dataset.num_episodes=}") + num_processes = accelerator.num_processes + effective_bs = cfg.batch_size * num_processes + logging.info(f"Effective batch size: {cfg.batch_size} x {num_processes} = {effective_bs}") + logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") + logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") # create dataloader for offline training if hasattr(cfg.policy, "drop_n_last_frames"): @@ -335,7 +369,13 @@ def train(cfg: TrainPipelineConfig): sampler=sampler, pin_memory=device.type == "cuda", drop_last=False, - prefetch_factor=2, + prefetch_factor=2 if cfg.num_workers > 0 else None, + ) + + # Prepare everything with accelerator + accelerator.wait_for_everyone() + policy, optimizer, dataloader, lr_scheduler = accelerator.prepare( + policy, optimizer, dataloader, lr_scheduler ) dl_iter = cycle(dataloader) @@ -349,11 +389,20 @@ def train(cfg: TrainPipelineConfig): "dataloading_s": AverageMeter("data_s", ":.3f"), } + # Use effective batch size for proper epoch calculation in distributed training + effective_batch_size = cfg.batch_size * accelerator.num_processes train_tracker = MetricsTracker( - cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step + effective_batch_size, + dataset.num_frames, + dataset.num_episodes, + train_metrics, + initial_step=step, + accelerator=accelerator, ) - logging.info("Start offline training on a fixed dataset") + if is_main_process: + logging.info("Start offline training on a fixed dataset") + for _ in range(step, cfg.steps): start_time = time.perf_counter() batch = next(dl_iter) @@ -366,16 +415,15 @@ def train(cfg: TrainPipelineConfig): batch, optimizer, cfg.optimizer.grad_clip_norm, - grad_scaler=grad_scaler, + accelerator=accelerator, lr_scheduler=lr_scheduler, - use_amp=cfg.policy.use_amp, ) # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we # increment `step` here. step += 1 train_tracker.step() - is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 + is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 @@ -389,69 +437,90 @@ def train(cfg: TrainPipelineConfig): train_tracker.reset_averages() if cfg.save_checkpoint and is_saving_step: - logging.info(f"Checkpoint policy after step {step}") - checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step) - save_checkpoint( - checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor, postprocessor - ) - update_last_checkpoint(checkpoint_dir) - if wandb_logger: - wandb_logger.log_policy(checkpoint_dir) - - if cfg.env and is_eval_step: - step_id = get_step_identifier(step, cfg.steps) - logging.info(f"Eval policy at step {step}") - with ( - torch.no_grad(), - torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(), - ): - eval_info = eval_policy_all( - envs=eval_env, # dict[suite][task_id] -> vec_env - policy=policy, + if is_main_process: + logging.info(f"Checkpoint policy after step {step}") + checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step) + save_checkpoint( + checkpoint_dir=checkpoint_dir, + step=step, + cfg=cfg, + policy=accelerator.unwrap_model(policy), + optimizer=optimizer, + scheduler=lr_scheduler, preprocessor=preprocessor, postprocessor=postprocessor, - n_episodes=cfg.eval.n_episodes, - videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}", - max_episodes_rendered=4, - start_seed=cfg.seed, - max_parallel_tasks=cfg.env.max_parallel_tasks, ) - # overall metrics (suite-agnostic) - aggregated = eval_info["overall"] + update_last_checkpoint(checkpoint_dir) + if wandb_logger: + wandb_logger.log_policy(checkpoint_dir) - # optional: per-suite logging - for suite, suite_info in eval_info.items(): - logging.info("Suite %s aggregated: %s", suite, suite_info) + accelerator.wait_for_everyone() - # meters/tracker - eval_metrics = { - "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"), - "pc_success": AverageMeter("success", ":.1f"), - "eval_s": AverageMeter("eval_s", ":.3f"), - } - eval_tracker = MetricsTracker( - cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step - ) - eval_tracker.eval_s = aggregated.pop("eval_s") - eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward") - eval_tracker.pc_success = aggregated.pop("pc_success") - if wandb_logger: - wandb_log_dict = {**eval_tracker.to_dict(), **eval_info} - wandb_logger.log_dict(wandb_log_dict, step, mode="eval") - wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval") + if cfg.env and is_eval_step: + if is_main_process: + step_id = get_step_identifier(step, cfg.steps) + logging.info(f"Eval policy at step {step}") + with torch.no_grad(), accelerator.autocast(): + eval_info = eval_policy_all( + envs=eval_env, # dict[suite][task_id] -> vec_env + policy=accelerator.unwrap_model(policy), + preprocessor=preprocessor, + postprocessor=postprocessor, + n_episodes=cfg.eval.n_episodes, + videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}", + max_episodes_rendered=4, + start_seed=cfg.seed, + max_parallel_tasks=cfg.env.max_parallel_tasks, + ) + # overall metrics (suite-agnostic) + aggregated = eval_info["overall"] + + # optional: per-suite logging + for suite, suite_info in eval_info.items(): + logging.info("Suite %s aggregated: %s", suite, suite_info) + + # meters/tracker + eval_metrics = { + "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"), + "pc_success": AverageMeter("success", ":.1f"), + "eval_s": AverageMeter("eval_s", ":.3f"), + } + eval_tracker = MetricsTracker( + cfg.batch_size, + dataset.num_frames, + dataset.num_episodes, + eval_metrics, + initial_step=step, + accelerator=accelerator, + ) + eval_tracker.eval_s = aggregated.pop("eval_s") + eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward") + eval_tracker.pc_success = aggregated.pop("pc_success") + if wandb_logger: + wandb_log_dict = {**eval_tracker.to_dict(), **eval_info} + wandb_logger.log_dict(wandb_log_dict, step, mode="eval") + wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval") + + accelerator.wait_for_everyone() if eval_env: close_envs(eval_env) - logging.info("End of training") - if cfg.policy.push_to_hub: - policy.push_model_to_hub(cfg) - preprocessor.push_to_hub(cfg.policy.repo_id) - postprocessor.push_to_hub(cfg.policy.repo_id) + if is_main_process: + logging.info("End of training") + + if cfg.policy.push_to_hub: + unwrapped_policy = accelerator.unwrap_model(policy) + unwrapped_policy.push_model_to_hub(cfg) + preprocessor.push_to_hub(cfg.policy.repo_id) + postprocessor.push_to_hub(cfg.policy.repo_id) + + # Properly clean up the distributed process group + accelerator.wait_for_everyone() + accelerator.end_training() def main(): - init_logging() train() diff --git a/src/lerobot/teleoperators/homunculus/homunculus_arm.py b/src/lerobot/teleoperators/homunculus/homunculus_arm.py index 21d73de2e..43116f5c0 100644 --- a/src/lerobot/teleoperators/homunculus/homunculus_arm.py +++ b/src/lerobot/teleoperators/homunculus/homunculus_arm.py @@ -270,8 +270,15 @@ class HomunculusArm(Teleoperator): raw_values = None with self.serial_lock: if self.serial.in_waiting > 0: - self.serial.flush() - raw_values = self.serial.readline().decode("utf-8").strip().split(" ") + lines = [] + while self.serial.in_waiting > 0: + line = self.serial.read_until().decode("utf-8").strip() + if line: + lines.append(line.split(" ")) + + if lines: + raw_values = lines[-1] + if raw_values is None or len(raw_values) != 21: # 16 raw + 5 angle values continue diff --git a/src/lerobot/teleoperators/homunculus/homunculus_glove.py b/src/lerobot/teleoperators/homunculus/homunculus_glove.py index 251ecf56d..fefeec1e8 100644 --- a/src/lerobot/teleoperators/homunculus/homunculus_glove.py +++ b/src/lerobot/teleoperators/homunculus/homunculus_glove.py @@ -304,8 +304,15 @@ class HomunculusGlove(Teleoperator): positions = None with self.serial_lock: if self.serial.in_waiting > 0: - self.serial.flush() - positions = self.serial.readline().decode("utf-8").strip().split(" ") + lines = [] + while self.serial.in_waiting > 0: + line = self.serial.read_until().decode("utf-8").strip() + if line: + lines.append(line.split(" ")) + + if lines: + positions = lines[-1] + if positions is None or len(positions) != len(self.joints): continue diff --git a/src/lerobot/templates/lerobot_modelcard_template.md b/src/lerobot/templates/lerobot_modelcard_template.md index 34af282b0..c59cf4183 100644 --- a/src/lerobot/templates/lerobot_modelcard_template.md +++ b/src/lerobot/templates/lerobot_modelcard_template.md @@ -19,8 +19,6 @@ [Diffusion Policy](https://huggingface.co/papers/2303.04137) treats visuomotor control as a generative diffusion process, producing smooth, multi-step action trajectories that excel at contact-rich manipulation. {% elif model_name == "vqbet" %} [VQ-BET](https://huggingface.co/papers/2403.03181) combines vector-quantised action tokens with Behaviour Transformers to discretise control and achieve data-efficient imitation across diverse skills. -{% elif model_name == "pi0fast" %} -[Pi0-Fast](https://huggingface.co/papers/2501.09747) is a variant of Pi0 that uses a new tokenization method called FAST, which enables training of an autoregressive vision-language-action policy for high-frequency robotic tasks with improved performance and reduced training time. {% elif model_name == "pi0" %} **π₀ (Pi0)** diff --git a/src/lerobot/utils/control_utils.py b/src/lerobot/utils/control_utils.py index 17371921c..7cfe177ef 100644 --- a/src/lerobot/utils/control_utils.py +++ b/src/lerobot/utils/control_utils.py @@ -31,6 +31,7 @@ from deepdiff import DeepDiff from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import DEFAULT_FEATURES from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.utils import prepare_observation_for_inference from lerobot.processor import PolicyAction, PolicyProcessorPipeline from lerobot.robots import Robot @@ -102,17 +103,7 @@ def predict_action( torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), ): # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension - for name in observation: - observation[name] = torch.from_numpy(observation[name]) - if "image" in name: - observation[name] = observation[name].type(torch.float32) / 255 - observation[name] = observation[name].permute(2, 0, 1).contiguous() - observation[name] = observation[name].unsqueeze(0) - observation[name] = observation[name].to(device) - - observation["task"] = task if task else "" - observation["robot_type"] = robot_type if robot_type else "" - + observation = prepare_observation_for_inference(observation, device, task, robot_type) observation = preprocessor(observation) # Compute the next action with the policy @@ -121,12 +112,6 @@ def predict_action( action = postprocessor(action) - # Remove batch dimension - action = action.squeeze(0) - - # Move to cpu, if not already the case - action = action.to("cpu") - return action diff --git a/src/lerobot/utils/logging_utils.py b/src/lerobot/utils/logging_utils.py index b6404e66d..c4c1f42e0 100644 --- a/src/lerobot/utils/logging_utils.py +++ b/src/lerobot/utils/logging_utils.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Callable from typing import Any from lerobot.utils.utils import format_big_number @@ -84,6 +85,7 @@ class MetricsTracker: "samples", "episodes", "epochs", + "accelerator", ] def __init__( @@ -93,6 +95,7 @@ class MetricsTracker: num_episodes: int, metrics: dict[str, AverageMeter], initial_step: int = 0, + accelerator: Callable | None = None, ): self.__dict__.update(dict.fromkeys(self.__keys__)) self._batch_size = batch_size @@ -106,6 +109,7 @@ class MetricsTracker: self.samples = self.steps * self._batch_size self.episodes = self.samples / self._avg_samples_per_ep self.epochs = self.samples / self._num_frames + self.accelerator = accelerator def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any: if name in self.__dict__: @@ -128,7 +132,7 @@ class MetricsTracker: Updates metrics that depend on 'step' for one step. """ self.steps += 1 - self.samples += self._batch_size + self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1) self.episodes = self.samples / self._avg_samples_per_ep self.epochs = self.samples / self._num_frames diff --git a/src/lerobot/utils/random_utils.py b/src/lerobot/utils/random_utils.py index 1bb1f0631..b34d357aa 100644 --- a/src/lerobot/utils/random_utils.py +++ b/src/lerobot/utils/random_utils.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import random -from collections.abc import Generator +from collections.abc import Callable, Generator from contextlib import contextmanager from pathlib import Path from typing import Any @@ -164,14 +164,20 @@ def set_rng_state(random_state_dict: dict[str, Any]): torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"]) -def set_seed(seed) -> None: +def set_seed(seed, accelerator: Callable | None = None) -> None: """Set seed for reproducibility.""" random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) + if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) + if accelerator: + from accelerate.utils import set_seed as _accelerate_set_seed + + _accelerate_set_seed(seed) + @contextmanager def seeded_context(seed: int) -> Generator[None, None, None]: diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py index 8777d5a9d..4447a1fcf 100644 --- a/src/lerobot/utils/utils.py +++ b/src/lerobot/utils/utils.py @@ -27,6 +27,8 @@ from statistics import mean import numpy as np import torch +from accelerate import Accelerator +from datasets.utils.logging import disable_progress_bar, enable_progress_bar def inside_slurm(): @@ -109,36 +111,50 @@ def init_logging( display_pid: bool = False, console_level: str = "INFO", file_level: str = "DEBUG", + accelerator: Accelerator | None = None, ): + """Initialize logging configuration for LeRobot. + + In multi-GPU training, only the main process logs to console to avoid duplicate output. + Non-main processes have console logging suppressed but can still log to file. + + Args: + log_file: Optional file path to write logs to + display_pid: Include process ID in log messages (useful for debugging multi-process) + console_level: Logging level for console output + file_level: Logging level for file output + accelerator: Optional Accelerator instance (for multi-GPU detection) + """ + def custom_format(record: logging.LogRecord) -> str: dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S") fnameline = f"{record.pathname}:{record.lineno}" - - # NOTE: Display PID is useful for multi-process logging. - if display_pid: - pid_str = f"[PID: {os.getpid()}]" - message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.getMessage()}" - else: - message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.getMessage()}" - return message + pid_str = f"[PID: {os.getpid()}] " if display_pid else "" + return f"{record.levelname} {pid_str}{dt} {fnameline[-15:]:>15} {record.getMessage()}" formatter = logging.Formatter() formatter.format = custom_format logger = logging.getLogger() - logger.setLevel(logging.NOTSET) # Set the logger to the lowest level to capture all messages + logger.setLevel(logging.NOTSET) - # Remove unused default handlers - for handler in logger.handlers[:]: - logger.removeHandler(handler) + # Clear any existing handlers + logger.handlers.clear() - # Write logs to console - console_handler = logging.StreamHandler() - console_handler.setFormatter(formatter) - console_handler.setLevel(console_level.upper()) - logger.addHandler(console_handler) + # Determine if this is a non-main process in distributed training + is_main_process = accelerator.is_main_process if accelerator is not None else True + + # Console logging (main process only) + if is_main_process: + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + console_handler.setLevel(console_level.upper()) + logger.addHandler(console_handler) + else: + # Suppress console output for non-main processes + logger.addHandler(logging.NullHandler()) + logger.setLevel(logging.ERROR) - # Additionally write logs to file if log_file is not None: file_handler = logging.FileHandler(log_file) file_handler.setFormatter(formatter) @@ -247,6 +263,25 @@ def get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time_s: float): return days, hours, minutes, seconds +class SuppressProgressBars: + """ + Context manager to suppress progress bars. + + Example + -------- + ```python + with SuppressProgressBars(): + # Code that would normally show progress bars + ``` + """ + + def __enter__(self): + disable_progress_bar() + + def __exit__(self, exc_type, exc_val, exc_tb): + enable_progress_bar() + + class TimerManager: """ Lightweight utility to measure elapsed time. diff --git a/tests/async_inference/test_e2e.py b/tests/async_inference/test_e2e.py index ebaef2ef1..11941ce32 100644 --- a/tests/async_inference/test_e2e.py +++ b/tests/async_inference/test_e2e.py @@ -139,7 +139,6 @@ def test_async_inference_e2e(monkeypatch): policy_type="test", pretrained_name_or_path="test", actions_per_chunk=20, - verify_robot_cameras=False, ) client = RobotClient(client_config) diff --git a/tests/async_inference/test_robot_client.py b/tests/async_inference/test_robot_client.py index dfdb8ce42..5b138d91b 100644 --- a/tests/async_inference/test_robot_client.py +++ b/tests/async_inference/test_robot_client.py @@ -51,7 +51,6 @@ def robot_client(): policy_type="test", pretrained_name_or_path="test", actions_per_chunk=20, - verify_robot_cameras=False, ) client = RobotClient(test_config) diff --git a/tests/datasets/test_aggregate.py b/tests/datasets/test_aggregate.py index 4f316f80e..b710a3a4b 100644 --- a/tests/datasets/test_aggregate.py +++ b/tests/datasets/test_aggregate.py @@ -181,6 +181,54 @@ def assert_dataset_iteration_works(aggr_ds): pass +def assert_video_timestamps_within_bounds(aggr_ds): + """Test that all video timestamps are within valid bounds for their respective video files. + + This catches bugs where timestamps point to frames beyond the actual video length, + which would cause "Invalid frame index" errors during data loading. + """ + try: + from torchcodec.decoders import VideoDecoder + except ImportError: + return + + for ep_idx in range(aggr_ds.num_episodes): + ep = aggr_ds.meta.episodes[ep_idx] + + for vid_key in aggr_ds.meta.video_keys: + from_ts = ep[f"videos/{vid_key}/from_timestamp"] + to_ts = ep[f"videos/{vid_key}/to_timestamp"] + video_path = aggr_ds.root / aggr_ds.meta.get_video_file_path(ep_idx, vid_key) + + if not video_path.exists(): + continue + + from_frame_idx = round(from_ts * aggr_ds.fps) + to_frame_idx = round(to_ts * aggr_ds.fps) + + try: + decoder = VideoDecoder(str(video_path)) + num_frames = len(decoder) + + # Verify timestamps don't exceed video bounds + assert from_frame_idx >= 0, ( + f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) < 0" + ) + assert from_frame_idx < num_frames, ( + f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) >= video frames ({num_frames})" + ) + assert to_frame_idx <= num_frames, ( + f"Episode {ep_idx}, {vid_key}: to_frame_idx ({to_frame_idx}) > video frames ({num_frames})" + ) + assert from_frame_idx < to_frame_idx, ( + f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) >= to_frame_idx ({to_frame_idx})" + ) + except Exception as e: + raise AssertionError( + f"Failed to verify timestamps for episode {ep_idx}, {vid_key}: {e}" + ) from e + + def test_aggregate_datasets(tmp_path, lerobot_dataset_factory): """Test basic aggregation functionality with standard parameters.""" ds_0_num_frames = 400 @@ -227,6 +275,7 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory): assert_metadata_consistency(aggr_ds, ds_0, ds_1) assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1) assert_video_frames_integrity(aggr_ds, ds_0, ds_1) + assert_video_timestamps_within_bounds(aggr_ds) assert_dataset_iteration_works(aggr_ds) @@ -277,6 +326,7 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory): assert_metadata_consistency(aggr_ds, ds_0, ds_1) assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1) assert_video_frames_integrity(aggr_ds, ds_0, ds_1) + assert_video_timestamps_within_bounds(aggr_ds) assert_dataset_iteration_works(aggr_ds) # Check that multiple files were actually created due to small size limits @@ -290,3 +340,43 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory): if video_dir.exists(): video_files = list(video_dir.rglob("*.mp4")) assert len(video_files) > 1, "Small file size limits should create multiple video files" + + +def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory): + """Regression test for video timestamp bug when merging datasets. + + This test specifically checks that video timestamps are correctly calculated + and accumulated when merging multiple datasets. + """ + datasets = [] + for i in range(3): + ds = lerobot_dataset_factory( + root=tmp_path / f"regression_{i}", + repo_id=f"{DUMMY_REPO_ID}_regression_{i}", + total_episodes=2, + total_frames=100, + ) + datasets.append(ds) + + aggregate_datasets( + repo_ids=[ds.repo_id for ds in datasets], + roots=[ds.root for ds in datasets], + aggr_repo_id=f"{DUMMY_REPO_ID}_regression_aggr", + aggr_root=tmp_path / "regression_aggr", + ) + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "regression_aggr") + aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_regression_aggr", root=tmp_path / "regression_aggr") + + assert_video_timestamps_within_bounds(aggr_ds) + + for i in range(len(aggr_ds)): + item = aggr_ds[i] + for key in aggr_ds.meta.video_keys: + assert key in item, f"Video key {key} missing from item {i}" + assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}" diff --git a/tests/datasets/test_dataset_tools.py b/tests/datasets/test_dataset_tools.py new file mode 100644 index 000000000..8bc1dbf6b --- /dev/null +++ b/tests/datasets/test_dataset_tools.py @@ -0,0 +1,1049 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for dataset tools utilities.""" + +from unittest.mock import patch + +import numpy as np +import pytest +import torch + +from lerobot.datasets.dataset_tools import ( + add_features, + delete_episodes, + merge_datasets, + modify_features, + remove_feature, + split_dataset, +) + + +@pytest.fixture +def sample_dataset(tmp_path, empty_lerobot_dataset_factory): + """Create a sample dataset for testing.""" + features = { + "action": {"dtype": "float32", "shape": (6,), "names": None}, + "observation.state": {"dtype": "float32", "shape": (4,), "names": None}, + "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None}, + } + + dataset = empty_lerobot_dataset_factory( + root=tmp_path / "test_dataset", + features=features, + ) + + for ep_idx in range(5): + for _ in range(10): + frame = { + "action": np.random.randn(6).astype(np.float32), + "observation.state": np.random.randn(4).astype(np.float32), + "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8), + "task": f"task_{ep_idx % 2}", + } + dataset.add_frame(frame) + dataset.save_episode() + + dataset.finalize() + return dataset + + +def test_delete_single_episode(sample_dataset, tmp_path): + """Test deleting a single episode.""" + output_dir = tmp_path / "filtered" + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(output_dir) + + new_dataset = delete_episodes( + sample_dataset, + episode_indices=[2], + output_dir=output_dir, + ) + + assert new_dataset.meta.total_episodes == 4 + assert new_dataset.meta.total_frames == 40 + + episode_indices = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]} + assert episode_indices == {0, 1, 2, 3} + + assert len(new_dataset) == 40 + + +def test_delete_multiple_episodes(sample_dataset, tmp_path): + """Test deleting multiple episodes.""" + output_dir = tmp_path / "filtered" + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(output_dir) + + new_dataset = delete_episodes( + sample_dataset, + episode_indices=[1, 3], + output_dir=output_dir, + ) + + assert new_dataset.meta.total_episodes == 3 + assert new_dataset.meta.total_frames == 30 + + episode_indices = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]} + assert episode_indices == {0, 1, 2} + + +def test_delete_invalid_episodes(sample_dataset, tmp_path): + """Test error handling for invalid episode indices.""" + with pytest.raises(ValueError, match="Invalid episode indices"): + delete_episodes( + sample_dataset, + episode_indices=[10, 20], + output_dir=tmp_path / "filtered", + ) + + +def test_delete_all_episodes(sample_dataset, tmp_path): + """Test error when trying to delete all episodes.""" + with pytest.raises(ValueError, match="Cannot delete all episodes"): + delete_episodes( + sample_dataset, + episode_indices=list(range(5)), + output_dir=tmp_path / "filtered", + ) + + +def test_delete_empty_list(sample_dataset, tmp_path): + """Test error when no episodes specified.""" + with pytest.raises(ValueError, match="No episodes to delete"): + delete_episodes( + sample_dataset, + episode_indices=[], + output_dir=tmp_path / "filtered", + ) + + +def test_split_by_episodes(sample_dataset, tmp_path): + """Test splitting dataset by specific episode indices.""" + splits = { + "train": [0, 1, 2], + "val": [3, 4], + } + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + + def mock_snapshot(repo_id, **kwargs): + if "train" in repo_id: + return str(tmp_path / f"{sample_dataset.repo_id}_train") + elif "val" in repo_id: + return str(tmp_path / f"{sample_dataset.repo_id}_val") + return str(kwargs.get("local_dir", tmp_path)) + + mock_snapshot_download.side_effect = mock_snapshot + + result = split_dataset( + sample_dataset, + splits=splits, + output_dir=tmp_path, + ) + + assert set(result.keys()) == {"train", "val"} + + assert result["train"].meta.total_episodes == 3 + assert result["train"].meta.total_frames == 30 + + assert result["val"].meta.total_episodes == 2 + assert result["val"].meta.total_frames == 20 + + train_episodes = {int(idx.item()) for idx in result["train"].hf_dataset["episode_index"]} + assert train_episodes == {0, 1, 2} + + val_episodes = {int(idx.item()) for idx in result["val"].hf_dataset["episode_index"]} + assert val_episodes == {0, 1} + + +def test_split_by_fractions(sample_dataset, tmp_path): + """Test splitting dataset by fractions.""" + splits = { + "train": 0.6, + "val": 0.4, + } + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + + def mock_snapshot(repo_id, **kwargs): + for split_name in splits: + if split_name in repo_id: + return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}") + return str(kwargs.get("local_dir", tmp_path)) + + mock_snapshot_download.side_effect = mock_snapshot + + result = split_dataset( + sample_dataset, + splits=splits, + output_dir=tmp_path, + ) + + assert result["train"].meta.total_episodes == 3 + assert result["val"].meta.total_episodes == 2 + + +def test_split_overlapping_episodes(sample_dataset, tmp_path): + """Test error when episodes appear in multiple splits.""" + splits = { + "train": [0, 1, 2], + "val": [2, 3, 4], + } + + with pytest.raises(ValueError, match="Episodes cannot appear in multiple splits"): + split_dataset(sample_dataset, splits=splits, output_dir=tmp_path) + + +def test_split_invalid_fractions(sample_dataset, tmp_path): + """Test error when fractions sum to more than 1.""" + splits = { + "train": 0.7, + "val": 0.5, + } + + with pytest.raises(ValueError, match="Split fractions must sum to <= 1.0"): + split_dataset(sample_dataset, splits=splits, output_dir=tmp_path) + + +def test_split_empty(sample_dataset, tmp_path): + """Test error with empty splits.""" + with pytest.raises(ValueError, match="No splits provided"): + split_dataset(sample_dataset, splits={}, output_dir=tmp_path) + + +def test_merge_two_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_factory): + """Test merging two datasets.""" + features = { + "action": {"dtype": "float32", "shape": (6,), "names": None}, + "observation.state": {"dtype": "float32", "shape": (4,), "names": None}, + "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None}, + } + + dataset2 = empty_lerobot_dataset_factory( + root=tmp_path / "test_dataset2", + features=features, + ) + + for ep_idx in range(3): + for _ in range(10): + frame = { + "action": np.random.randn(6).astype(np.float32), + "observation.state": np.random.randn(4).astype(np.float32), + "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8), + "task": f"task_{ep_idx % 2}", + } + dataset2.add_frame(frame) + dataset2.save_episode() + dataset2.finalize() + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "merged_dataset") + + merged = merge_datasets( + [sample_dataset, dataset2], + output_repo_id="merged_dataset", + output_dir=tmp_path / "merged_dataset", + ) + + assert merged.meta.total_episodes == 8 # 5 + 3 + assert merged.meta.total_frames == 80 # 50 + 30 + + episode_indices = sorted({int(idx.item()) for idx in merged.hf_dataset["episode_index"]}) + assert episode_indices == list(range(8)) + + +def test_merge_empty_list(tmp_path): + """Test error when merging empty list.""" + with pytest.raises(ValueError, match="No datasets to merge"): + merge_datasets([], output_repo_id="merged", output_dir=tmp_path) + + +def test_add_features_with_values(sample_dataset, tmp_path): + """Test adding a feature with pre-computed values.""" + num_frames = sample_dataset.meta.total_frames + reward_values = np.random.randn(num_frames, 1).astype(np.float32) + + feature_info = { + "dtype": "float32", + "shape": (1,), + "names": None, + } + features = { + "reward": (reward_values, feature_info), + } + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "with_reward") + + new_dataset = add_features( + dataset=sample_dataset, + features=features, + output_dir=tmp_path / "with_reward", + ) + + assert "reward" in new_dataset.meta.features + assert new_dataset.meta.features["reward"] == feature_info + + assert len(new_dataset) == num_frames + sample_item = new_dataset[0] + assert "reward" in sample_item + assert isinstance(sample_item["reward"], torch.Tensor) + + +def test_add_features_with_callable(sample_dataset, tmp_path): + """Test adding a feature with a callable.""" + + def compute_reward(frame_dict, episode_idx, frame_idx): + return float(episode_idx * 10 + frame_idx) + + feature_info = { + "dtype": "float32", + "shape": (1,), + "names": None, + } + features = { + "reward": (compute_reward, feature_info), + } + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "with_reward") + + new_dataset = add_features( + dataset=sample_dataset, + features=features, + output_dir=tmp_path / "with_reward", + ) + + assert "reward" in new_dataset.meta.features + + items = [new_dataset[i] for i in range(10)] + first_episode_items = [item for item in items if item["episode_index"] == 0] + assert len(first_episode_items) == 10 + + first_frame = first_episode_items[0] + assert first_frame["frame_index"] == 0 + assert float(first_frame["reward"]) == 0.0 + + +def test_add_existing_feature(sample_dataset, tmp_path): + """Test error when adding an existing feature.""" + feature_info = {"dtype": "float32", "shape": (1,)} + features = { + "action": (np.zeros(50), feature_info), + } + + with pytest.raises(ValueError, match="Feature 'action' already exists"): + add_features( + dataset=sample_dataset, + features=features, + output_dir=tmp_path / "modified", + ) + + +def test_add_feature_invalid_info(sample_dataset, tmp_path): + """Test error with invalid feature info.""" + with pytest.raises(ValueError, match="feature_info for 'reward' must contain keys"): + add_features( + dataset=sample_dataset, + features={ + "reward": (np.zeros(50), {"dtype": "float32"}), + }, + output_dir=tmp_path / "modified", + ) + + +def test_modify_features_add_and_remove(sample_dataset, tmp_path): + """Test modifying features by adding and removing simultaneously.""" + feature_info = {"dtype": "float32", "shape": (1,), "names": None} + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "modified") + + # First add a feature we'll later remove + dataset_with_reward = add_features( + sample_dataset, + features={"reward": (np.random.randn(50, 1).astype(np.float32), feature_info)}, + output_dir=tmp_path / "with_reward", + ) + + # Now use modify_features to add "success" and remove "reward" in one pass + modified_dataset = modify_features( + dataset_with_reward, + add_features={ + "success": (np.random.randn(50, 1).astype(np.float32), feature_info), + }, + remove_features="reward", + output_dir=tmp_path / "modified", + ) + + assert "success" in modified_dataset.meta.features + assert "reward" not in modified_dataset.meta.features + assert len(modified_dataset) == 50 + + +def test_modify_features_only_add(sample_dataset, tmp_path): + """Test that modify_features works with only add_features.""" + feature_info = {"dtype": "float32", "shape": (1,), "names": None} + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "modified") + + modified_dataset = modify_features( + sample_dataset, + add_features={ + "reward": (np.random.randn(50, 1).astype(np.float32), feature_info), + }, + output_dir=tmp_path / "modified", + ) + + assert "reward" in modified_dataset.meta.features + assert len(modified_dataset) == 50 + + +def test_modify_features_only_remove(sample_dataset, tmp_path): + """Test that modify_features works with only remove_features.""" + feature_info = {"dtype": "float32", "shape": (1,), "names": None} + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) + + dataset_with_reward = add_features( + sample_dataset, + features={"reward": (np.random.randn(50, 1).astype(np.float32), feature_info)}, + output_dir=tmp_path / "with_reward", + ) + + modified_dataset = modify_features( + dataset_with_reward, + remove_features="reward", + output_dir=tmp_path / "modified", + ) + + assert "reward" not in modified_dataset.meta.features + + +def test_modify_features_no_changes(sample_dataset, tmp_path): + """Test error when modify_features is called with no changes.""" + with pytest.raises(ValueError, match="Must specify at least one of add_features or remove_features"): + modify_features( + sample_dataset, + output_dir=tmp_path / "modified", + ) + + +def test_remove_single_feature(sample_dataset, tmp_path): + """Test removing a single feature.""" + feature_info = {"dtype": "float32", "shape": (1,), "names": None} + features = { + "reward": (np.random.randn(50, 1).astype(np.float32), feature_info), + } + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) + + dataset_with_reward = add_features( + dataset=sample_dataset, + features=features, + output_dir=tmp_path / "with_reward", + ) + + dataset_without_reward = remove_feature( + dataset_with_reward, + feature_names="reward", + output_dir=tmp_path / "without_reward", + ) + + assert "reward" not in dataset_without_reward.meta.features + + sample_item = dataset_without_reward[0] + assert "reward" not in sample_item + + +def test_remove_multiple_features(sample_dataset, tmp_path): + """Test removing multiple features at once.""" + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) + + dataset = sample_dataset + features = {} + for feature_name in ["reward", "success"]: + feature_info = {"dtype": "float32", "shape": (1,), "names": None} + features[feature_name] = ( + np.random.randn(dataset.meta.total_frames, 1).astype(np.float32), + feature_info, + ) + + dataset_with_features = add_features( + dataset, features=features, output_dir=tmp_path / "with_features" + ) + dataset_clean = remove_feature( + dataset_with_features, feature_names=["reward", "success"], output_dir=tmp_path / "clean" + ) + + assert "reward" not in dataset_clean.meta.features + assert "success" not in dataset_clean.meta.features + + +def test_remove_nonexistent_feature(sample_dataset, tmp_path): + """Test error when removing non-existent feature.""" + with pytest.raises(ValueError, match="Feature 'nonexistent' not found"): + remove_feature( + sample_dataset, + feature_names="nonexistent", + output_dir=tmp_path / "modified", + ) + + +def test_remove_required_feature(sample_dataset, tmp_path): + """Test error when trying to remove required features.""" + with pytest.raises(ValueError, match="Cannot remove required features"): + remove_feature( + sample_dataset, + feature_names="timestamp", + output_dir=tmp_path / "modified", + ) + + +def test_remove_camera_feature(sample_dataset, tmp_path): + """Test removing a camera feature.""" + camera_keys = sample_dataset.meta.camera_keys + if not camera_keys: + pytest.skip("No camera keys in dataset") + + camera_to_remove = camera_keys[0] + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "without_camera") + + dataset_without_camera = remove_feature( + sample_dataset, + feature_names=camera_to_remove, + output_dir=tmp_path / "without_camera", + ) + + assert camera_to_remove not in dataset_without_camera.meta.features + assert camera_to_remove not in dataset_without_camera.meta.camera_keys + + sample_item = dataset_without_camera[0] + assert camera_to_remove not in sample_item + + +def test_complex_workflow_integration(sample_dataset, tmp_path): + """Test a complex workflow combining multiple operations.""" + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) + + dataset = add_features( + sample_dataset, + features={ + "reward": ( + np.random.randn(50, 1).astype(np.float32), + {"dtype": "float32", "shape": (1,), "names": None}, + ) + }, + output_dir=tmp_path / "step1", + ) + + dataset = delete_episodes( + dataset, + episode_indices=[2], + output_dir=tmp_path / "step2", + ) + + splits = split_dataset( + dataset, + splits={"train": 0.75, "val": 0.25}, + output_dir=tmp_path / "step3", + ) + + merged = merge_datasets( + list(splits.values()), + output_repo_id="final_dataset", + output_dir=tmp_path / "step4", + ) + + assert merged.meta.total_episodes == 4 + assert merged.meta.total_frames == 40 + assert "reward" in merged.meta.features + + assert len(merged) == 40 + sample_item = merged[0] + assert "reward" in sample_item + + +def test_delete_episodes_preserves_stats(sample_dataset, tmp_path): + """Test that deleting episodes preserves statistics correctly.""" + output_dir = tmp_path / "filtered" + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(output_dir) + + new_dataset = delete_episodes( + sample_dataset, + episode_indices=[2], + output_dir=output_dir, + ) + + assert new_dataset.meta.stats is not None + for feature in ["action", "observation.state"]: + assert feature in new_dataset.meta.stats + assert "mean" in new_dataset.meta.stats[feature] + assert "std" in new_dataset.meta.stats[feature] + + +def test_delete_episodes_preserves_tasks(sample_dataset, tmp_path): + """Test that tasks are preserved correctly after deletion.""" + output_dir = tmp_path / "filtered" + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(output_dir) + + new_dataset = delete_episodes( + sample_dataset, + episode_indices=[0], + output_dir=output_dir, + ) + + assert new_dataset.meta.tasks is not None + assert len(new_dataset.meta.tasks) == 2 + + tasks_in_dataset = {str(item["task"]) for item in new_dataset} + assert len(tasks_in_dataset) > 0 + + +def test_split_three_ways(sample_dataset, tmp_path): + """Test splitting dataset into three splits.""" + splits = { + "train": 0.6, + "val": 0.2, + "test": 0.2, + } + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + + def mock_snapshot(repo_id, **kwargs): + for split_name in splits: + if split_name in repo_id: + return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}") + return str(kwargs.get("local_dir", tmp_path)) + + mock_snapshot_download.side_effect = mock_snapshot + + result = split_dataset( + sample_dataset, + splits=splits, + output_dir=tmp_path, + ) + + assert set(result.keys()) == {"train", "val", "test"} + assert result["train"].meta.total_episodes == 3 + assert result["val"].meta.total_episodes == 1 + assert result["test"].meta.total_episodes == 1 + + total_frames = sum(ds.meta.total_frames for ds in result.values()) + assert total_frames == sample_dataset.meta.total_frames + + +def test_split_preserves_stats(sample_dataset, tmp_path): + """Test that statistics are preserved when splitting.""" + splits = {"train": [0, 1, 2], "val": [3, 4]} + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + + def mock_snapshot(repo_id, **kwargs): + for split_name in splits: + if split_name in repo_id: + return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}") + return str(kwargs.get("local_dir", tmp_path)) + + mock_snapshot_download.side_effect = mock_snapshot + + result = split_dataset( + sample_dataset, + splits=splits, + output_dir=tmp_path, + ) + + for split_ds in result.values(): + assert split_ds.meta.stats is not None + for feature in ["action", "observation.state"]: + assert feature in split_ds.meta.stats + assert "mean" in split_ds.meta.stats[feature] + assert "std" in split_ds.meta.stats[feature] + + +def test_merge_three_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_factory): + """Test merging three datasets.""" + features = { + "action": {"dtype": "float32", "shape": (6,), "names": None}, + "observation.state": {"dtype": "float32", "shape": (4,), "names": None}, + "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None}, + } + + datasets = [sample_dataset] + + for i in range(2): + dataset = empty_lerobot_dataset_factory( + root=tmp_path / f"test_dataset{i + 2}", + features=features, + ) + + for ep_idx in range(2): + for _ in range(10): + frame = { + "action": np.random.randn(6).astype(np.float32), + "observation.state": np.random.randn(4).astype(np.float32), + "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8), + "task": f"task_{ep_idx}", + } + dataset.add_frame(frame) + dataset.save_episode() + dataset.finalize() + + datasets.append(dataset) + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "merged_dataset") + + merged = merge_datasets( + datasets, + output_repo_id="merged_dataset", + output_dir=tmp_path / "merged_dataset", + ) + + assert merged.meta.total_episodes == 9 + assert merged.meta.total_frames == 90 + + +def test_merge_preserves_stats(sample_dataset, tmp_path, empty_lerobot_dataset_factory): + """Test that statistics are computed for merged datasets.""" + features = { + "action": {"dtype": "float32", "shape": (6,), "names": None}, + "observation.state": {"dtype": "float32", "shape": (4,), "names": None}, + "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None}, + } + + dataset2 = empty_lerobot_dataset_factory( + root=tmp_path / "test_dataset2", + features=features, + ) + + for ep_idx in range(3): + for _ in range(10): + frame = { + "action": np.random.randn(6).astype(np.float32), + "observation.state": np.random.randn(4).astype(np.float32), + "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8), + "task": f"task_{ep_idx % 2}", + } + dataset2.add_frame(frame) + dataset2.save_episode() + dataset2.finalize() + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "merged_dataset") + + merged = merge_datasets( + [sample_dataset, dataset2], + output_repo_id="merged_dataset", + output_dir=tmp_path / "merged_dataset", + ) + + assert merged.meta.stats is not None + for feature in ["action", "observation.state"]: + assert feature in merged.meta.stats + assert "mean" in merged.meta.stats[feature] + assert "std" in merged.meta.stats[feature] + + +def test_add_features_preserves_existing_stats(sample_dataset, tmp_path): + """Test that adding a feature preserves existing stats.""" + num_frames = sample_dataset.meta.total_frames + reward_values = np.random.randn(num_frames, 1).astype(np.float32) + + feature_info = { + "dtype": "float32", + "shape": (1,), + "names": None, + } + features = { + "reward": (reward_values, feature_info), + } + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "with_reward") + + new_dataset = add_features( + dataset=sample_dataset, + features=features, + output_dir=tmp_path / "with_reward", + ) + + assert new_dataset.meta.stats is not None + for feature in ["action", "observation.state"]: + assert feature in new_dataset.meta.stats + assert "mean" in new_dataset.meta.stats[feature] + assert "std" in new_dataset.meta.stats[feature] + + +def test_remove_feature_updates_stats(sample_dataset, tmp_path): + """Test that removing a feature removes it from stats.""" + feature_info = {"dtype": "float32", "shape": (1,), "names": None} + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) + + dataset_with_reward = add_features( + sample_dataset, + features={ + "reward": (np.random.randn(50, 1).astype(np.float32), feature_info), + }, + output_dir=tmp_path / "with_reward", + ) + + dataset_without_reward = remove_feature( + dataset_with_reward, + feature_names="reward", + output_dir=tmp_path / "without_reward", + ) + + if dataset_without_reward.meta.stats: + assert "reward" not in dataset_without_reward.meta.stats + + +def test_delete_consecutive_episodes(sample_dataset, tmp_path): + """Test deleting consecutive episodes.""" + output_dir = tmp_path / "filtered" + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(output_dir) + + new_dataset = delete_episodes( + sample_dataset, + episode_indices=[1, 2, 3], + output_dir=output_dir, + ) + + assert new_dataset.meta.total_episodes == 2 + assert new_dataset.meta.total_frames == 20 + + episode_indices = sorted({int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}) + assert episode_indices == [0, 1] + + +def test_delete_first_and_last_episodes(sample_dataset, tmp_path): + """Test deleting first and last episodes.""" + output_dir = tmp_path / "filtered" + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(output_dir) + + new_dataset = delete_episodes( + sample_dataset, + episode_indices=[0, 4], + output_dir=output_dir, + ) + + assert new_dataset.meta.total_episodes == 3 + assert new_dataset.meta.total_frames == 30 + + episode_indices = sorted({int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}) + assert episode_indices == [0, 1, 2] + + +def test_split_all_episodes_assigned(sample_dataset, tmp_path): + """Test that all episodes can be explicitly assigned to splits.""" + splits = { + "split1": [0, 1], + "split2": [2, 3], + "split3": [4], + } + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + + def mock_snapshot(repo_id, **kwargs): + for split_name in splits: + if split_name in repo_id: + return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}") + return str(kwargs.get("local_dir", tmp_path)) + + mock_snapshot_download.side_effect = mock_snapshot + + result = split_dataset( + sample_dataset, + splits=splits, + output_dir=tmp_path, + ) + + total_episodes = sum(ds.meta.total_episodes for ds in result.values()) + assert total_episodes == sample_dataset.meta.total_episodes + + +def test_modify_features_preserves_file_structure(sample_dataset, tmp_path): + """Test that modifying features preserves chunk_idx and file_idx from source dataset.""" + feature_info = {"dtype": "float32", "shape": (1,), "names": None} + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + + def mock_snapshot(repo_id, **kwargs): + return str(kwargs.get("local_dir", tmp_path / repo_id.split("/")[-1])) + + mock_snapshot_download.side_effect = mock_snapshot + + # First split the dataset to create a non-zero starting chunk/file structure + splits = split_dataset( + sample_dataset, + splits={"train": [0, 1, 2], "val": [3, 4]}, + output_dir=tmp_path / "splits", + ) + + train_dataset = splits["train"] + + # Get original chunk/file indices from first episode + if train_dataset.meta.episodes is None: + from lerobot.datasets.utils import load_episodes + + train_dataset.meta.episodes = load_episodes(train_dataset.meta.root) + original_chunk_indices = [ep["data/chunk_index"] for ep in train_dataset.meta.episodes] + original_file_indices = [ep["data/file_index"] for ep in train_dataset.meta.episodes] + + # Now add a feature to the split dataset + modified_dataset = add_features( + train_dataset, + features={ + "reward": ( + np.random.randn(train_dataset.meta.total_frames, 1).astype(np.float32), + feature_info, + ), + }, + output_dir=tmp_path / "modified", + ) + + # Check that chunk/file indices are preserved + if modified_dataset.meta.episodes is None: + from lerobot.datasets.utils import load_episodes + + modified_dataset.meta.episodes = load_episodes(modified_dataset.meta.root) + new_chunk_indices = [ep["data/chunk_index"] for ep in modified_dataset.meta.episodes] + new_file_indices = [ep["data/file_index"] for ep in modified_dataset.meta.episodes] + + assert new_chunk_indices == original_chunk_indices, "Chunk indices should be preserved" + assert new_file_indices == original_file_indices, "File indices should be preserved" + assert "reward" in modified_dataset.meta.features diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 2bc3bea43..e174c5789 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -806,6 +806,8 @@ def test_episode_index_distribution(tmp_path, empty_lerobot_dataset_factory): dataset.add_frame({"state": torch.randn(2), "task": f"task_{episode_idx}"}) dataset.save_episode() + dataset.finalize() + # Load the dataset and check episode indices loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) @@ -855,6 +857,8 @@ def test_multi_episode_metadata_consistency(tmp_path, empty_lerobot_dataset_fact dataset.add_frame({"state": torch.randn(3), ACTION: torch.randn(2), "task": tasks[episode_idx]}) dataset.save_episode() + dataset.finalize() + # Load and validate episode metadata loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) @@ -893,6 +897,8 @@ def test_data_consistency_across_episodes(tmp_path, empty_lerobot_dataset_factor dataset.add_frame({"state": torch.randn(1), "task": "consistency_test"}) dataset.save_episode() + dataset.finalize() + loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) # Check data consistency - no gaps or overlaps @@ -944,6 +950,8 @@ def test_statistics_metadata_validation(tmp_path, empty_lerobot_dataset_factory) dataset.add_frame({"state": state_data, ACTION: action_data, "task": "stats_test"}) dataset.save_episode() + dataset.finalize() + loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) # Check that statistics exist for all features @@ -989,6 +997,8 @@ def test_episode_boundary_integrity(tmp_path, empty_lerobot_dataset_factory): dataset.add_frame({"state": torch.tensor([float(frame_idx)]), "task": f"episode_{episode_idx}"}) dataset.save_episode() + dataset.finalize() + loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) # Test episode boundaries @@ -1031,6 +1041,8 @@ def test_task_indexing_and_validation(tmp_path, empty_lerobot_dataset_factory): dataset.add_frame({"state": torch.randn(1), "task": task}) dataset.save_episode() + dataset.finalize() + loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) # Check that all unique tasks are in the tasks metadata @@ -1056,3 +1068,134 @@ def test_task_indexing_and_validation(tmp_path, empty_lerobot_dataset_factory): # Check total number of tasks assert loaded_dataset.meta.total_tasks == len(unique_tasks) + + +def test_dataset_resume_recording(tmp_path, empty_lerobot_dataset_factory): + """Test that resuming dataset recording preserves previously recorded episodes. + + This test validates the critical resume functionality by: + 1. Recording initial episodes and finalizing + 2. Reopening the dataset + 3. Recording additional episodes + 4. Verifying all data (old + new) is intact + + This specifically tests the bug fix where parquet files were being overwritten + instead of appended to during resume. + """ + features = { + "observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}, + "action": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}, + } + + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False) + + initial_episodes = 2 + frames_per_episode = 3 + + for ep_idx in range(initial_episodes): + for frame_idx in range(frames_per_episode): + dataset.add_frame( + { + "observation.state": torch.tensor([float(ep_idx), float(frame_idx)]), + "action": torch.tensor([0.5, 0.5]), + "task": f"task_{ep_idx}", + } + ) + dataset.save_episode() + + assert dataset.meta.total_episodes == initial_episodes + assert dataset.meta.total_frames == initial_episodes * frames_per_episode + + dataset.finalize() + initial_root = dataset.root + initial_repo_id = dataset.repo_id + del dataset + + dataset_verify = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0") + assert dataset_verify.meta.total_episodes == initial_episodes + assert dataset_verify.meta.total_frames == initial_episodes * frames_per_episode + assert len(dataset_verify.hf_dataset) == initial_episodes * frames_per_episode + + for idx in range(len(dataset_verify.hf_dataset)): + item = dataset_verify[idx] + expected_ep = idx // frames_per_episode + expected_frame = idx % frames_per_episode + assert item["episode_index"].item() == expected_ep + assert item["frame_index"].item() == expected_frame + assert item["index"].item() == idx + assert item["observation.state"][0].item() == float(expected_ep) + assert item["observation.state"][1].item() == float(expected_frame) + + del dataset_verify + + # Phase 3: Resume recording - add more episodes + dataset_resumed = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0") + + assert dataset_resumed.meta.total_episodes == initial_episodes + assert dataset_resumed.meta.total_frames == initial_episodes * frames_per_episode + assert dataset_resumed.latest_episode is None # Not recording yet + assert dataset_resumed.writer is None + assert dataset_resumed.meta.writer is None + + additional_episodes = 2 + for ep_idx in range(initial_episodes, initial_episodes + additional_episodes): + for frame_idx in range(frames_per_episode): + dataset_resumed.add_frame( + { + "observation.state": torch.tensor([float(ep_idx), float(frame_idx)]), + "action": torch.tensor([0.5, 0.5]), + "task": f"task_{ep_idx}", + } + ) + dataset_resumed.save_episode() + + total_episodes = initial_episodes + additional_episodes + total_frames = total_episodes * frames_per_episode + assert dataset_resumed.meta.total_episodes == total_episodes + assert dataset_resumed.meta.total_frames == total_frames + + dataset_resumed.finalize() + del dataset_resumed + + dataset_final = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0") + + assert dataset_final.meta.total_episodes == total_episodes + assert dataset_final.meta.total_frames == total_frames + assert len(dataset_final.hf_dataset) == total_frames + + for idx in range(total_frames): + item = dataset_final[idx] + expected_ep = idx // frames_per_episode + expected_frame = idx % frames_per_episode + + assert item["episode_index"].item() == expected_ep, ( + f"Frame {idx}: wrong episode_index. Expected {expected_ep}, got {item['episode_index'].item()}" + ) + assert item["frame_index"].item() == expected_frame, ( + f"Frame {idx}: wrong frame_index. Expected {expected_frame}, got {item['frame_index'].item()}" + ) + assert item["index"].item() == idx, ( + f"Frame {idx}: wrong index. Expected {idx}, got {item['index'].item()}" + ) + + # Verify data integrity + assert item["observation.state"][0].item() == float(expected_ep), ( + f"Frame {idx}: wrong observation.state[0]. Expected {float(expected_ep)}, " + f"got {item['observation.state'][0].item()}" + ) + assert item["observation.state"][1].item() == float(expected_frame), ( + f"Frame {idx}: wrong observation.state[1]. Expected {float(expected_frame)}, " + f"got {item['observation.state'][1].item()}" + ) + + assert len(dataset_final.meta.episodes) == total_episodes + for ep_idx in range(total_episodes): + ep_metadata = dataset_final.meta.episodes[ep_idx] + assert ep_metadata["episode_index"] == ep_idx + assert ep_metadata["length"] == frames_per_episode + assert ep_metadata["tasks"] == [f"task_{ep_idx}"] + + expected_from = ep_idx * frames_per_episode + expected_to = (ep_idx + 1) * frames_per_episode + assert ep_metadata["dataset_from_index"] == expected_from + assert ep_metadata["dataset_to_index"] == expected_to diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index 34fa89390..345526d90 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -95,7 +95,6 @@ def test_get_policy_and_config_classes(policy_name: str): @pytest.mark.parametrize( "ds_repo_id,env_name,env_kwargs,policy_name,policy_kwargs", [ - ("lerobot/xarm_lift_medium", "xarm", {}, "tdmpc", {"use_mpc": True}), ("lerobot/pusht", "pusht", {}, "diffusion", {}), ("lerobot/pusht", "pusht", {}, "vqbet", {}), ("lerobot/pusht", "pusht", {}, "act", {}), @@ -328,8 +327,6 @@ def test_multikey_construction(multikey: bool): # TODO(alexander-soare): `policy.use_mpc=false` was previously the default in the config yaml but it # was changed to true. For some reason, tests would pass locally, but not in CI. So here we override # to test with `policy.use_mpc=false`. - ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"), - # ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"), # TODO(rcadene): the diffusion model was normalizing the image in mean=0.5 std=0.5 which is a hack supposed to # to normalize the image at all. In our current codebase we dont normalize at all. But there is still a minor difference # that fails the test. However, by testing to normalize the image with 0.5 0.5 in the current codebase, the test pass. diff --git a/tests/training/test_multi_gpu.py b/tests/training/test_multi_gpu.py new file mode 100644 index 000000000..bb234e2e7 --- /dev/null +++ b/tests/training/test_multi_gpu.py @@ -0,0 +1,211 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Multi-GPU Training Tests + +This module tests multi-GPU training functionality with accelerate. +These tests are designed to run on machines with 2+ GPUs and are executed +in the nightly CI workflow. + +The tests automatically generate accelerate configs and launch training +with subprocess to properly test the distributed training environment. +""" + +import os +import subprocess +import tempfile +from pathlib import Path + +import pytest +import torch + +from lerobot.datasets.lerobot_dataset import LeRobotDataset + + +def get_num_available_gpus(): + """Returns the number of available GPUs.""" + if not torch.cuda.is_available(): + return 0 + return torch.cuda.device_count() + + +def download_dataset(repo_id, episodes): + """ + Pre-download dataset to avoid race conditions in multi-GPU training. + + Args: + repo_id: HuggingFace dataset repository ID + episodes: List of episode indices to download + """ + # Simply instantiating the dataset will download it + _ = LeRobotDataset(repo_id, episodes=episodes) + print(f"Dataset {repo_id} downloaded successfully") + + +def run_accelerate_training(config_args, num_processes=4, temp_dir=None): + """ + Helper function to run training with accelerate launch. + + Args: + config_args: List of config arguments to pass to lerobot_train.py + num_processes: Number of processes (GPUs) to use + temp_dir: Temporary directory for outputs + + Returns: + subprocess.CompletedProcess result + """ + + config_path = Path(temp_dir) / "accelerate_config.yaml" + + # Write YAML config + with open(config_path, "w") as f: + f.write("compute_environment: LOCAL_MACHINE\n") + f.write("distributed_type: MULTI_GPU\n") + f.write("mixed_precision: 'no'\n") + f.write(f"num_processes: {num_processes}\n") + f.write("use_cpu: false\n") + f.write("gpu_ids: all\n") + f.write("downcast_bf16: 'no'\n") + f.write("machine_rank: 0\n") + f.write("main_training_function: main\n") + f.write("num_machines: 1\n") + f.write("rdzv_backend: static\n") + f.write("same_network: true\n") + + cmd = [ + "accelerate", + "launch", + "--config_file", + str(config_path), + "-m", + "lerobot.scripts.lerobot_train", + ] + config_args + + result = subprocess.run( + cmd, + capture_output=True, + text=True, + env={**os.environ, "CUDA_VISIBLE_DEVICES": ",".join(map(str, range(num_processes)))}, + ) + + return result + + +@pytest.mark.skipif( + get_num_available_gpus() < 2, + reason="Multi-GPU tests require at least 2 GPUs", +) +class TestMultiGPUTraining: + """Test suite for multi-GPU training functionality.""" + + def test_basic_multi_gpu_training(self): + """ + Test that basic multi-GPU training runs successfully. + Verifies that the training completes without errors. + """ + # Pre-download dataset to avoid race conditions + download_dataset("lerobot/pusht", episodes=[0]) + + with tempfile.TemporaryDirectory() as temp_dir: + output_dir = Path(temp_dir) / "outputs" + + config_args = [ + "--dataset.repo_id=lerobot/pusht", + "--dataset.episodes=[0]", + "--policy.type=act", + "--policy.device=cuda", + "--policy.push_to_hub=false", + f"--output_dir={output_dir}", + "--batch_size=4", + "--steps=10", + "--eval_freq=-1", + "--log_freq=5", + "--save_freq=10", + "--seed=42", + "--num_workers=0", + ] + + result = run_accelerate_training(config_args, num_processes=4, temp_dir=temp_dir) + + # Check that training completed successfully + assert result.returncode == 0, ( + f"Multi-GPU training failed with return code {result.returncode}\n" + f"STDOUT:\n{result.stdout}\n" + f"STDERR:\n{result.stderr}" + ) + + # Verify checkpoint was saved + checkpoints_dir = output_dir / "checkpoints" + assert checkpoints_dir.exists(), "Checkpoints directory was not created" + + # Verify that training completed + assert "End of training" in result.stdout or "End of training" in result.stderr + + def test_checkpoint_saving_multi_gpu(self): + """ + Test that checkpoints are correctly saved during multi-GPU training. + Only the main process (rank 0) should save checkpoints. + """ + # Pre-download dataset to avoid race conditions + download_dataset("lerobot/pusht", episodes=[0]) + + with tempfile.TemporaryDirectory() as temp_dir: + output_dir = Path(temp_dir) / "outputs" + + config_args = [ + "--dataset.repo_id=lerobot/pusht", + "--dataset.episodes=[0]", + "--policy.type=act", + "--policy.device=cuda", + "--policy.push_to_hub=false", + f"--output_dir={output_dir}", + "--batch_size=4", + "--steps=20", + "--eval_freq=-1", + "--log_freq=5", + "--save_freq=10", + "--seed=42", + "--num_workers=0", + ] + + result = run_accelerate_training(config_args, num_processes=2, temp_dir=temp_dir) + + assert result.returncode == 0, ( + f"Training failed:\nSTDOUT:\n{result.stdout}\n\nSTDERR:\n{result.stderr}" + ) + + # Verify checkpoint directory exists + checkpoints_dir = output_dir / "checkpoints" + assert checkpoints_dir.exists(), "Checkpoints directory not created" + + # Count checkpoint directories (should have checkpoint at step 10 and 20) + checkpoint_dirs = [d for d in checkpoints_dir.iterdir() if d.is_dir()] + assert len(checkpoint_dirs) >= 1, f"Expected at least 1 checkpoint, found {len(checkpoint_dirs)}" + + # Verify checkpoint contents + for checkpoint_dir in checkpoint_dirs: + # Check for model files + model_files = list(checkpoint_dir.rglob("*.safetensors")) + assert len(model_files) > 0, f"No model files in checkpoint {checkpoint_dir}" + + # Check for training state + training_state_dir = checkpoint_dir / "training_state" + assert training_state_dir.exists(), f"No training state in checkpoint {checkpoint_dir}" + + # Verify optimizer state exists + optimizer_state = training_state_dir / "optimizer_state.safetensors" + assert optimizer_state.exists(), f"No optimizer state in checkpoint {checkpoint_dir}"