diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py index 809789ce2..63b7bfb4c 100644 --- a/src/lerobot/datasets/aggregate.py +++ b/src/lerobot/datasets/aggregate.py @@ -1,3 +1,20 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import logging import shutil from pathlib import Path @@ -26,6 +43,21 @@ from lerobot.datasets.utils import ( def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]): + """Validates that all dataset metadata have consistent properties. + + Ensures all datasets have the same fps, robot_type, and features to guarantee + compatibility when aggregating them into a single dataset. + + Args: + all_metadata: List of LeRobotDatasetMetadata objects to validate. + + Returns: + tuple: A tuple containing (fps, robot_type, features) from the first metadata. + + Raises: + ValueError: If any metadata has different fps, robot_type, or features + than the first metadata in the list. + """ # validate same fps, robot_type, features fps = all_metadata[0].fps @@ -48,6 +80,20 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]): def update_data_df(df, src_meta, dst_meta): + """Updates a data DataFrame with new indices and task mappings for aggregation. + + Adjusts episode indices, frame indices, and task indices to account for + previously aggregated data in the destination dataset. 
+ + Args: + df: DataFrame containing the data to be updated. + src_meta: Source dataset metadata. + dst_meta: Destination dataset metadata. + + Returns: + pd.DataFrame: Updated DataFrame with adjusted indices. + """ + def _update(row): row["episode_index"] = row["episode_index"] + dst_meta.info["total_episodes"] row["index"] = row["index"] + dst_meta.info["total_frames"] @@ -65,6 +111,22 @@ def update_meta_data( data_idx, videos_idx, ): + """Updates metadata DataFrame with new chunk, file, and timestamp indices. + + Adjusts all indices and timestamps to account for previously aggregated + data and videos in the destination dataset. + + Args: + df: DataFrame containing the metadata to be updated. + dst_meta: Destination dataset metadata. + meta_idx: Dictionary containing current metadata chunk and file indices. + data_idx: Dictionary containing current data chunk and file indices. + videos_idx: Dictionary containing current video indices and timestamps. + + Returns: + pd.DataFrame: Updated DataFrame with adjusted indices and timestamps. + """ + def _update(row): row["meta/episodes/chunk_index"] = row["meta/episodes/chunk_index"] + meta_idx["chunk"] row["meta/episodes/file_index"] = row["meta/episodes/file_index"] + meta_idx["file"] @@ -88,6 +150,20 @@ def update_meta_data( def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path] = None, aggr_root=None): + """Aggregates multiple LeRobot datasets into a single unified dataset. + + This is the main function that orchestrates the aggregation process by: + 1. Loading and validating all source dataset metadata + 2. Creating a new destination dataset with unified tasks + 3. Aggregating videos, data, and metadata from all source datasets + 4. Finalizing the aggregated dataset with proper statistics + + Args: + repo_ids: List of repository IDs for the datasets to aggregate. + aggr_repo_id: Repository ID for the aggregated output dataset. + roots: Optional list of root paths for the source datasets. 
+ aggr_root: Optional root path for the aggregated dataset. + """ logging.info("Start aggregate_datasets") # Load metadata @@ -144,8 +220,18 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path] def aggregate_videos(src_meta, dst_meta, videos_idx): - """ - Aggregates video chunks from a dataset into the aggregated dataset folder. + """Aggregates video chunks from a source dataset into the destination dataset. + + Handles video file concatenation and rotation based on file size limits. + Creates new video files when size limits are exceeded. + + Args: + src_meta: Source dataset metadata. + dst_meta: Destination dataset metadata. + videos_idx: Dictionary tracking video chunk and file indices. + + Returns: + dict: Updated videos_idx with current chunk and file indices. """ for key, video_idx in videos_idx.items(): # Get unique (chunk, file) combinations @@ -213,6 +299,19 @@ def aggregate_videos(src_meta, dst_meta, videos_idx): def aggregate_data(src_meta, dst_meta, data_idx): + """Aggregates data chunks from a source dataset into the destination dataset. + + Reads source data files, updates indices to match the aggregated dataset, + and writes them to the destination with proper file rotation. + + Args: + src_meta: Source dataset metadata. + dst_meta: Destination dataset metadata. + data_idx: Dictionary tracking data chunk and file indices. + + Returns: + dict: Updated data_idx with current chunk and file indices. + """ unique_chunk_file_ids = { (c, f) for c, f in zip( @@ -241,6 +340,21 @@ def aggregate_data(src_meta, dst_meta, data_idx): def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx): + """Aggregates metadata from a source dataset into the destination dataset. + + Reads source metadata files, updates all indices and timestamps, + and writes them to the destination with proper file rotation. + + Args: + src_meta: Source dataset metadata. + dst_meta: Destination dataset metadata. 
+ meta_idx: Dictionary tracking metadata chunk and file indices. + data_idx: Dictionary tracking data chunk and file indices. + videos_idx: Dictionary tracking video indices and timestamps. + + Returns: + dict: Updated meta_idx with current chunk and file indices. + """ chunk_file_ids = { (c, f) for c, f in zip( @@ -288,19 +402,23 @@ def append_or_create_parquet_file( contains_images: bool = False, aggr_root: Path = None, ): - """ - Safely appends or creates a Parquet file at dst_path based on size constraints. + """Appends data to an existing parquet file or creates a new one based on size constraints. - Parameters: - df (pd.DataFrame): Data to write. - src_path (Path): Path to source file (used to get size). - idx (dict): Dictionary containing 'chunk' and 'file' indices. - max_mb (float): Maximum allowed file size in MB. - chunk_size (int): Maximum number of files per chunk. - default_path (str): Format string for generating a new file path. + Manages file rotation when size limits are exceeded to prevent individual files + from becoming too large. Handles both regular parquet files and those containing images. + + Args: + df: DataFrame to write to the parquet file. + src_path: Path to the source file (used for size estimation). + idx: Dictionary containing current 'chunk' and 'file' indices. + max_mb: Maximum allowed file size in MB before rotation. + chunk_size: Maximum number of files per chunk before incrementing chunk index. + default_path: Format string for generating file paths. + contains_images: Whether the data contains images requiring special handling. + aggr_root: Root path for the aggregated dataset. Returns: - dict: Updated index dictionary. + dict: Updated index dictionary with current chunk and file indices. 
""" # Initial destination path - use the correct default_path parameter dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"]) @@ -340,6 +458,15 @@ def append_or_create_parquet_file( def finalize_aggregation(aggr_meta, all_metadata): + """Finalizes the dataset aggregation by writing summary files and statistics. + + Writes the tasks file, info file with total counts and splits, and + aggregated statistics from all source datasets. + + Args: + aggr_meta: Aggregated dataset metadata. + all_metadata: List of all source dataset metadata objects. + """ logging.info("write tasks") write_tasks(aggr_meta.tasks, aggr_meta.root) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index fe0a1b290..33c92e344 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -511,23 +511,19 @@ def test_backward_compatibility(repo_id): ) # test2 first frames of first episode - i = dataset.meta.episodes["dataset_from_index"][0].item() + i = dataset.meta.episodes[0]["dataset_from_index"] load_and_compare(i) load_and_compare(i + 1) # test 2 frames at the middle of first episode i = int( - ( - dataset.meta.episodes["dataset_to_index"][0].item() - - dataset.meta.episodes["dataset_from_index"][0].item() - ) - / 2 + (dataset.meta.episodes[0]["dataset_to_index"] - dataset.meta.episodes[0]["dataset_from_index"]) / 2 ) load_and_compare(i) load_and_compare(i + 1) # test 2 last frames of first episode - i = dataset.meta.episodes["dataset_to_index"][0].item() + i = dataset.meta.episodes[0]["dataset_to_index"] load_and_compare(i - 2) load_and_compare(i - 1) diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index 3a8c6a224..3243713cb 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -20,6 +20,7 @@ from pathlib import Path import einops import pytest import torch +from packaging import version from safetensors.torch import load_file from lerobot 
import available_policies @@ -68,11 +69,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p }, } info = info_factory( - total_episodes=1, - total_frames=1, - total_tasks=1, - camera_features=camera_features, - motor_features=motor_features, + total_episodes=1, total_frames=1, camera_features=camera_features, motor_features=motor_features ) ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info) return ds_meta @@ -141,14 +138,13 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs): Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive, and for now we add tests as we see fit. """ - policy_kwargs["device"] = DEVICE - train_cfg = TrainPipelineConfig( # TODO(rcadene, aliberts): remove dataset download dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]), - policy=make_policy_config(policy_name, **policy_kwargs), + policy=make_policy_config(policy_name, push_to_hub=False, **policy_kwargs), env=make_env_config(env_name, **env_kwargs), ) + train_cfg.validate() # Check that we can make the policy object. dataset = make_dataset(train_cfg) @@ -217,7 +213,7 @@ def test_act_backbone_lr(): cfg = TrainPipelineConfig( # TODO(rcadene, aliberts): remove dataset download dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]), - policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001), + policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001, push_to_hub=False), ) cfg.validate() # Needed for auto-setting some parameters @@ -413,7 +409,17 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs 4. Check that this test now passes. 5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state. 6. Remember to stage and commit the resulting changes to `tests/artifacts`. 
+ + NOTE: If the test does not pass and you have not changed the policy, the test artifact is likely + out of date. For example, some PyTorch versions produce different random numbers; see this PR: + https://github.com/huggingface/lerobot/pull/1127. + """ + + # NOTE: ACT policy randomness differs before PyTorch 2.7.0, so this test is skipped on older versions + if policy_name == "act" and version.parse(torch.__version__) < version.parse("2.7.0"): + pytest.skip(f"Skipping act policy test with PyTorch {torch.__version__}. Requires PyTorch >= 2.7.0") + ds_name = ds_repo_id.split("/")[-1] artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}" saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")