Added docstrings to aggregate, fix test_policies.py

This commit is contained in:
Michel Aractingi
2025-07-04 11:27:00 +02:00
parent 830a3b9f27
commit 3dbc3e60fb
3 changed files with 157 additions and 28 deletions
+139 -12
View File
@@ -1,3 +1,20 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging import logging
import shutil import shutil
from pathlib import Path from pathlib import Path
@@ -26,6 +43,21 @@ from lerobot.datasets.utils import (
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]): def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
"""Validates that all dataset metadata have consistent properties.
Ensures all datasets have the same fps, robot_type, and features to guarantee
compatibility when aggregating them into a single dataset.
Args:
all_metadata: List of LeRobotDatasetMetadata objects to validate.
Returns:
tuple: A tuple containing (fps, robot_type, features) from the first metadata.
Raises:
ValueError: If any metadata has different fps, robot_type, or features
than the first metadata in the list.
"""
# validate same fps, robot_type, features # validate same fps, robot_type, features
fps = all_metadata[0].fps fps = all_metadata[0].fps
@@ -48,6 +80,20 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
def update_data_df(df, src_meta, dst_meta): def update_data_df(df, src_meta, dst_meta):
"""Updates a data DataFrame with new indices and task mappings for aggregation.
Adjusts episode indices, frame indices, and task indices to account for
previously aggregated data in the destination dataset.
Args:
df: DataFrame containing the data to be updated.
src_meta: Source dataset metadata.
dst_meta: Destination dataset metadata.
Returns:
pd.DataFrame: Updated DataFrame with adjusted indices.
"""
def _update(row): def _update(row):
row["episode_index"] = row["episode_index"] + dst_meta.info["total_episodes"] row["episode_index"] = row["episode_index"] + dst_meta.info["total_episodes"]
row["index"] = row["index"] + dst_meta.info["total_frames"] row["index"] = row["index"] + dst_meta.info["total_frames"]
@@ -65,6 +111,22 @@ def update_meta_data(
data_idx, data_idx,
videos_idx, videos_idx,
): ):
"""Updates metadata DataFrame with new chunk, file, and timestamp indices.
Adjusts all indices and timestamps to account for previously aggregated
data and videos in the destination dataset.
Args:
df: DataFrame containing the metadata to be updated.
dst_meta: Destination dataset metadata.
meta_idx: Dictionary containing current metadata chunk and file indices.
data_idx: Dictionary containing current data chunk and file indices.
videos_idx: Dictionary containing current video indices and timestamps.
Returns:
pd.DataFrame: Updated DataFrame with adjusted indices and timestamps.
"""
def _update(row): def _update(row):
row["meta/episodes/chunk_index"] = row["meta/episodes/chunk_index"] + meta_idx["chunk"] row["meta/episodes/chunk_index"] = row["meta/episodes/chunk_index"] + meta_idx["chunk"]
row["meta/episodes/file_index"] = row["meta/episodes/file_index"] + meta_idx["file"] row["meta/episodes/file_index"] = row["meta/episodes/file_index"] + meta_idx["file"]
@@ -88,6 +150,20 @@ def update_meta_data(
def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path] = None, aggr_root=None): def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path] = None, aggr_root=None):
"""Aggregates multiple LeRobot datasets into a single unified dataset.
This is the main function that orchestrates the aggregation process by:
1. Loading and validating all source dataset metadata
2. Creating a new destination dataset with unified tasks
3. Aggregating videos, data, and metadata from all source datasets
4. Finalizing the aggregated dataset with proper statistics
Args:
repo_ids: List of repository IDs for the datasets to aggregate.
aggr_repo_id: Repository ID for the aggregated output dataset.
roots: Optional list of root paths for the source datasets.
aggr_root: Optional root path for the aggregated dataset.
"""
logging.info("Start aggregate_datasets") logging.info("Start aggregate_datasets")
# Load metadata # Load metadata
@@ -144,8 +220,18 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
def aggregate_videos(src_meta, dst_meta, videos_idx): def aggregate_videos(src_meta, dst_meta, videos_idx):
""" """Aggregates video chunks from a source dataset into the destination dataset.
Aggregates video chunks from a dataset into the aggregated dataset folder.
Handles video file concatenation and rotation based on file size limits.
Creates new video files when size limits are exceeded.
Args:
src_meta: Source dataset metadata.
dst_meta: Destination dataset metadata.
videos_idx: Dictionary tracking video chunk and file indices.
Returns:
dict: Updated videos_idx with current chunk and file indices.
""" """
for key, video_idx in videos_idx.items(): for key, video_idx in videos_idx.items():
# Get unique (chunk, file) combinations # Get unique (chunk, file) combinations
@@ -213,6 +299,19 @@ def aggregate_videos(src_meta, dst_meta, videos_idx):
def aggregate_data(src_meta, dst_meta, data_idx): def aggregate_data(src_meta, dst_meta, data_idx):
"""Aggregates data chunks from a source dataset into the destination dataset.
Reads source data files, updates indices to match the aggregated dataset,
and writes them to the destination with proper file rotation.
Args:
src_meta: Source dataset metadata.
dst_meta: Destination dataset metadata.
data_idx: Dictionary tracking data chunk and file indices.
Returns:
dict: Updated data_idx with current chunk and file indices.
"""
unique_chunk_file_ids = { unique_chunk_file_ids = {
(c, f) (c, f)
for c, f in zip( for c, f in zip(
@@ -241,6 +340,21 @@ def aggregate_data(src_meta, dst_meta, data_idx):
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx): def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
"""Aggregates metadata from a source dataset into the destination dataset.
Reads source metadata files, updates all indices and timestamps,
and writes them to the destination with proper file rotation.
Args:
src_meta: Source dataset metadata.
dst_meta: Destination dataset metadata.
meta_idx: Dictionary tracking metadata chunk and file indices.
data_idx: Dictionary tracking data chunk and file indices.
videos_idx: Dictionary tracking video indices and timestamps.
Returns:
dict: Updated meta_idx with current chunk and file indices.
"""
chunk_file_ids = { chunk_file_ids = {
(c, f) (c, f)
for c, f in zip( for c, f in zip(
@@ -288,19 +402,23 @@ def append_or_create_parquet_file(
contains_images: bool = False, contains_images: bool = False,
aggr_root: Path = None, aggr_root: Path = None,
): ):
""" """Appends data to an existing parquet file or creates a new one based on size constraints.
Safely appends or creates a Parquet file at dst_path based on size constraints.
Parameters: Manages file rotation when size limits are exceeded to prevent individual files
df (pd.DataFrame): Data to write. from becoming too large. Handles both regular parquet files and those containing images.
src_path (Path): Path to source file (used to get size).
idx (dict): Dictionary containing 'chunk' and 'file' indices. Args:
max_mb (float): Maximum allowed file size in MB. df: DataFrame to write to the parquet file.
chunk_size (int): Maximum number of files per chunk. src_path: Path to the source file (used for size estimation).
default_path (str): Format string for generating a new file path. idx: Dictionary containing current 'chunk' and 'file' indices.
max_mb: Maximum allowed file size in MB before rotation.
chunk_size: Maximum number of files per chunk before incrementing chunk index.
default_path: Format string for generating file paths.
contains_images: Whether the data contains images requiring special handling.
aggr_root: Root path for the aggregated dataset.
Returns: Returns:
dict: Updated index dictionary. dict: Updated index dictionary with current chunk and file indices.
""" """
# Initial destination path - use the correct default_path parameter # Initial destination path - use the correct default_path parameter
dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"]) dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
@@ -340,6 +458,15 @@ def append_or_create_parquet_file(
def finalize_aggregation(aggr_meta, all_metadata): def finalize_aggregation(aggr_meta, all_metadata):
"""Finalizes the dataset aggregation by writing summary files and statistics.
Writes the tasks file, info file with total counts and splits, and
aggregated statistics from all source datasets.
Args:
aggr_meta: Aggregated dataset metadata.
all_metadata: List of all source dataset metadata objects.
"""
logging.info("write tasks") logging.info("write tasks")
write_tasks(aggr_meta.tasks, aggr_meta.root) write_tasks(aggr_meta.tasks, aggr_meta.root)
+3 -7
View File
@@ -511,23 +511,19 @@ def test_backward_compatibility(repo_id):
) )
# test2 first frames of first episode # test2 first frames of first episode
i = dataset.meta.episodes["dataset_from_index"][0].item() i = dataset.meta.episodes[0]["dataset_from_index"]
load_and_compare(i) load_and_compare(i)
load_and_compare(i + 1) load_and_compare(i + 1)
# test 2 frames at the middle of first episode # test 2 frames at the middle of first episode
i = int( i = int(
( (dataset.meta.episodes[0]["dataset_to_index"] - dataset.meta.episodes[0]["dataset_from_index"]) / 2
dataset.meta.episodes["dataset_to_index"][0].item()
- dataset.meta.episodes["dataset_from_index"][0].item()
)
/ 2
) )
load_and_compare(i) load_and_compare(i)
load_and_compare(i + 1) load_and_compare(i + 1)
# test 2 last frames of first episode # test 2 last frames of first episode
i = dataset.meta.episodes["dataset_to_index"][0].item() i = dataset.meta.episodes[0]["dataset_to_index"]
load_and_compare(i - 2) load_and_compare(i - 2)
load_and_compare(i - 1) load_and_compare(i - 1)
+15 -9
View File
@@ -20,6 +20,7 @@ from pathlib import Path
import einops import einops
import pytest import pytest
import torch import torch
from packaging import version
from safetensors.torch import load_file from safetensors.torch import load_file
from lerobot import available_policies from lerobot import available_policies
@@ -68,11 +69,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
}, },
} }
info = info_factory( info = info_factory(
total_episodes=1, total_episodes=1, total_frames=1, camera_features=camera_features, motor_features=motor_features
total_frames=1,
total_tasks=1,
camera_features=camera_features,
motor_features=motor_features,
) )
ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info) ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info)
return ds_meta return ds_meta
@@ -141,14 +138,13 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive, Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
and for now we add tests as we see fit. and for now we add tests as we see fit.
""" """
policy_kwargs["device"] = DEVICE
train_cfg = TrainPipelineConfig( train_cfg = TrainPipelineConfig(
# TODO(rcadene, aliberts): remove dataset download # TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]), dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
policy=make_policy_config(policy_name, **policy_kwargs), policy=make_policy_config(policy_name, push_to_hub=False, **policy_kwargs),
env=make_env_config(env_name, **env_kwargs), env=make_env_config(env_name, **env_kwargs),
) )
train_cfg.validate()
# Check that we can make the policy object. # Check that we can make the policy object.
dataset = make_dataset(train_cfg) dataset = make_dataset(train_cfg)
@@ -217,7 +213,7 @@ def test_act_backbone_lr():
cfg = TrainPipelineConfig( cfg = TrainPipelineConfig(
# TODO(rcadene, aliberts): remove dataset download # TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]), dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001), policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001, push_to_hub=False),
) )
cfg.validate() # Needed for auto-setting some parameters cfg.validate() # Needed for auto-setting some parameters
@@ -413,7 +409,17 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
4. Check that this test now passes. 4. Check that this test now passes.
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state. 5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
6. Remember to stage and commit the resulting changes to `tests/artifacts`. 6. Remember to stage and commit the resulting changes to `tests/artifacts`.
NOTE: If the test does not pass, and you don't change the policy, it is likely that the test artifact
is out of date. For example, some PyTorch versions have different randomness, see this PR:
https://github.com/huggingface/lerobot/pull/1127.
""" """
# NOTE: ACT policy has different randomness, after PyTorch 2.7.0
if policy_name == "act" and version.parse(torch.__version__) < version.parse("2.7.0"):
pytest.skip(f"Skipping act policy test with PyTorch {torch.__version__}. Requires PyTorch >= 2.7.0")
ds_name = ds_repo_id.split("/")[-1] ds_name = ds_repo_id.split("/")[-1]
artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}" artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
saved_output_dict = load_file(artifact_dir / "output_dict.safetensors") saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")