mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 22:20:06 +00:00
Added docstrings to aggregate, fix test_policies.py
This commit is contained in:
@@ -1,3 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
@@ -26,6 +43,21 @@ from lerobot.datasets.utils import (
|
||||
|
||||
|
||||
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
||||
"""Validates that all dataset metadata have consistent properties.
|
||||
|
||||
Ensures all datasets have the same fps, robot_type, and features to guarantee
|
||||
compatibility when aggregating them into a single dataset.
|
||||
|
||||
Args:
|
||||
all_metadata: List of LeRobotDatasetMetadata objects to validate.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing (fps, robot_type, features) from the first metadata.
|
||||
|
||||
Raises:
|
||||
ValueError: If any metadata has different fps, robot_type, or features
|
||||
than the first metadata in the list.
|
||||
"""
|
||||
# validate same fps, robot_type, features
|
||||
|
||||
fps = all_metadata[0].fps
|
||||
@@ -48,6 +80,20 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
||||
|
||||
|
||||
def update_data_df(df, src_meta, dst_meta):
|
||||
"""Updates a data DataFrame with new indices and task mappings for aggregation.
|
||||
|
||||
Adjusts episode indices, frame indices, and task indices to account for
|
||||
previously aggregated data in the destination dataset.
|
||||
|
||||
Args:
|
||||
df: DataFrame containing the data to be updated.
|
||||
src_meta: Source dataset metadata.
|
||||
dst_meta: Destination dataset metadata.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: Updated DataFrame with adjusted indices.
|
||||
"""
|
||||
|
||||
def _update(row):
|
||||
row["episode_index"] = row["episode_index"] + dst_meta.info["total_episodes"]
|
||||
row["index"] = row["index"] + dst_meta.info["total_frames"]
|
||||
@@ -65,6 +111,22 @@ def update_meta_data(
|
||||
data_idx,
|
||||
videos_idx,
|
||||
):
|
||||
"""Updates metadata DataFrame with new chunk, file, and timestamp indices.
|
||||
|
||||
Adjusts all indices and timestamps to account for previously aggregated
|
||||
data and videos in the destination dataset.
|
||||
|
||||
Args:
|
||||
df: DataFrame containing the metadata to be updated.
|
||||
dst_meta: Destination dataset metadata.
|
||||
meta_idx: Dictionary containing current metadata chunk and file indices.
|
||||
data_idx: Dictionary containing current data chunk and file indices.
|
||||
videos_idx: Dictionary containing current video indices and timestamps.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: Updated DataFrame with adjusted indices and timestamps.
|
||||
"""
|
||||
|
||||
def _update(row):
|
||||
row["meta/episodes/chunk_index"] = row["meta/episodes/chunk_index"] + meta_idx["chunk"]
|
||||
row["meta/episodes/file_index"] = row["meta/episodes/file_index"] + meta_idx["file"]
|
||||
@@ -88,6 +150,20 @@ def update_meta_data(
|
||||
|
||||
|
||||
def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path] = None, aggr_root=None):
|
||||
"""Aggregates multiple LeRobot datasets into a single unified dataset.
|
||||
|
||||
This is the main function that orchestrates the aggregation process by:
|
||||
1. Loading and validating all source dataset metadata
|
||||
2. Creating a new destination dataset with unified tasks
|
||||
3. Aggregating videos, data, and metadata from all source datasets
|
||||
4. Finalizing the aggregated dataset with proper statistics
|
||||
|
||||
Args:
|
||||
repo_ids: List of repository IDs for the datasets to aggregate.
|
||||
aggr_repo_id: Repository ID for the aggregated output dataset.
|
||||
roots: Optional list of root paths for the source datasets.
|
||||
aggr_root: Optional root path for the aggregated dataset.
|
||||
"""
|
||||
logging.info("Start aggregate_datasets")
|
||||
|
||||
# Load metadata
|
||||
@@ -144,8 +220,18 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
|
||||
|
||||
|
||||
def aggregate_videos(src_meta, dst_meta, videos_idx):
|
||||
"""
|
||||
Aggregates video chunks from a dataset into the aggregated dataset folder.
|
||||
"""Aggregates video chunks from a source dataset into the destination dataset.
|
||||
|
||||
Handles video file concatenation and rotation based on file size limits.
|
||||
Creates new video files when size limits are exceeded.
|
||||
|
||||
Args:
|
||||
src_meta: Source dataset metadata.
|
||||
dst_meta: Destination dataset metadata.
|
||||
videos_idx: Dictionary tracking video chunk and file indices.
|
||||
|
||||
Returns:
|
||||
dict: Updated videos_idx with current chunk and file indices.
|
||||
"""
|
||||
for key, video_idx in videos_idx.items():
|
||||
# Get unique (chunk, file) combinations
|
||||
@@ -213,6 +299,19 @@ def aggregate_videos(src_meta, dst_meta, videos_idx):
|
||||
|
||||
|
||||
def aggregate_data(src_meta, dst_meta, data_idx):
|
||||
"""Aggregates data chunks from a source dataset into the destination dataset.
|
||||
|
||||
Reads source data files, updates indices to match the aggregated dataset,
|
||||
and writes them to the destination with proper file rotation.
|
||||
|
||||
Args:
|
||||
src_meta: Source dataset metadata.
|
||||
dst_meta: Destination dataset metadata.
|
||||
data_idx: Dictionary tracking data chunk and file indices.
|
||||
|
||||
Returns:
|
||||
dict: Updated data_idx with current chunk and file indices.
|
||||
"""
|
||||
unique_chunk_file_ids = {
|
||||
(c, f)
|
||||
for c, f in zip(
|
||||
@@ -241,6 +340,21 @@ def aggregate_data(src_meta, dst_meta, data_idx):
|
||||
|
||||
|
||||
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
"""Aggregates metadata from a source dataset into the destination dataset.
|
||||
|
||||
Reads source metadata files, updates all indices and timestamps,
|
||||
and writes them to the destination with proper file rotation.
|
||||
|
||||
Args:
|
||||
src_meta: Source dataset metadata.
|
||||
dst_meta: Destination dataset metadata.
|
||||
meta_idx: Dictionary tracking metadata chunk and file indices.
|
||||
data_idx: Dictionary tracking data chunk and file indices.
|
||||
videos_idx: Dictionary tracking video indices and timestamps.
|
||||
|
||||
Returns:
|
||||
dict: Updated meta_idx with current chunk and file indices.
|
||||
"""
|
||||
chunk_file_ids = {
|
||||
(c, f)
|
||||
for c, f in zip(
|
||||
@@ -288,19 +402,23 @@ def append_or_create_parquet_file(
|
||||
contains_images: bool = False,
|
||||
aggr_root: Path = None,
|
||||
):
|
||||
"""
|
||||
Safely appends or creates a Parquet file at dst_path based on size constraints.
|
||||
"""Appends data to an existing parquet file or creates a new one based on size constraints.
|
||||
|
||||
Parameters:
|
||||
df (pd.DataFrame): Data to write.
|
||||
src_path (Path): Path to source file (used to get size).
|
||||
idx (dict): Dictionary containing 'chunk' and 'file' indices.
|
||||
max_mb (float): Maximum allowed file size in MB.
|
||||
chunk_size (int): Maximum number of files per chunk.
|
||||
default_path (str): Format string for generating a new file path.
|
||||
Manages file rotation when size limits are exceeded to prevent individual files
|
||||
from becoming too large. Handles both regular parquet files and those containing images.
|
||||
|
||||
Args:
|
||||
df: DataFrame to write to the parquet file.
|
||||
src_path: Path to the source file (used for size estimation).
|
||||
idx: Dictionary containing current 'chunk' and 'file' indices.
|
||||
max_mb: Maximum allowed file size in MB before rotation.
|
||||
chunk_size: Maximum number of files per chunk before incrementing chunk index.
|
||||
default_path: Format string for generating file paths.
|
||||
contains_images: Whether the data contains images requiring special handling.
|
||||
aggr_root: Root path for the aggregated dataset.
|
||||
|
||||
Returns:
|
||||
dict: Updated index dictionary.
|
||||
dict: Updated index dictionary with current chunk and file indices.
|
||||
"""
|
||||
# Initial destination path - use the correct default_path parameter
|
||||
dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
|
||||
@@ -340,6 +458,15 @@ def append_or_create_parquet_file(
|
||||
|
||||
|
||||
def finalize_aggregation(aggr_meta, all_metadata):
|
||||
"""Finalizes the dataset aggregation by writing summary files and statistics.
|
||||
|
||||
Writes the tasks file, info file with total counts and splits, and
|
||||
aggregated statistics from all source datasets.
|
||||
|
||||
Args:
|
||||
aggr_meta: Aggregated dataset metadata.
|
||||
all_metadata: List of all source dataset metadata objects.
|
||||
"""
|
||||
logging.info("write tasks")
|
||||
write_tasks(aggr_meta.tasks, aggr_meta.root)
|
||||
|
||||
|
||||
@@ -511,23 +511,19 @@ def test_backward_compatibility(repo_id):
|
||||
)
|
||||
|
||||
# test2 first frames of first episode
|
||||
i = dataset.meta.episodes["dataset_from_index"][0].item()
|
||||
i = dataset.meta.episodes[0]["dataset_from_index"]
|
||||
load_and_compare(i)
|
||||
load_and_compare(i + 1)
|
||||
|
||||
# test 2 frames at the middle of first episode
|
||||
i = int(
|
||||
(
|
||||
dataset.meta.episodes["dataset_to_index"][0].item()
|
||||
- dataset.meta.episodes["dataset_from_index"][0].item()
|
||||
)
|
||||
/ 2
|
||||
(dataset.meta.episodes[0]["dataset_to_index"] - dataset.meta.episodes[0]["dataset_from_index"]) / 2
|
||||
)
|
||||
load_and_compare(i)
|
||||
load_and_compare(i + 1)
|
||||
|
||||
# test 2 last frames of first episode
|
||||
i = dataset.meta.episodes["dataset_to_index"][0].item()
|
||||
i = dataset.meta.episodes[0]["dataset_to_index"]
|
||||
load_and_compare(i - 2)
|
||||
load_and_compare(i - 1)
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from pathlib import Path
|
||||
import einops
|
||||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from lerobot import available_policies
|
||||
@@ -68,11 +69,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
|
||||
},
|
||||
}
|
||||
info = info_factory(
|
||||
total_episodes=1,
|
||||
total_frames=1,
|
||||
total_tasks=1,
|
||||
camera_features=camera_features,
|
||||
motor_features=motor_features,
|
||||
total_episodes=1, total_frames=1, camera_features=camera_features, motor_features=motor_features
|
||||
)
|
||||
ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info)
|
||||
return ds_meta
|
||||
@@ -141,14 +138,13 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
||||
Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
|
||||
and for now we add tests as we see fit.
|
||||
"""
|
||||
policy_kwargs["device"] = DEVICE
|
||||
|
||||
train_cfg = TrainPipelineConfig(
|
||||
# TODO(rcadene, aliberts): remove dataset download
|
||||
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
|
||||
policy=make_policy_config(policy_name, **policy_kwargs),
|
||||
policy=make_policy_config(policy_name, push_to_hub=False, **policy_kwargs),
|
||||
env=make_env_config(env_name, **env_kwargs),
|
||||
)
|
||||
train_cfg.validate()
|
||||
|
||||
# Check that we can make the policy object.
|
||||
dataset = make_dataset(train_cfg)
|
||||
@@ -217,7 +213,7 @@ def test_act_backbone_lr():
|
||||
cfg = TrainPipelineConfig(
|
||||
# TODO(rcadene, aliberts): remove dataset download
|
||||
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
|
||||
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001),
|
||||
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001, push_to_hub=False),
|
||||
)
|
||||
cfg.validate() # Needed for auto-setting some parameters
|
||||
|
||||
@@ -413,7 +409,17 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
|
||||
4. Check that this test now passes.
|
||||
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
|
||||
6. Remember to stage and commit the resulting changes to `tests/artifacts`.
|
||||
|
||||
NOTE: If the test does not pass, and you don't change the policy, it is likely that the test artifact
|
||||
is out of date. For example, some PyTorch versions have different randomness, see this PR:
|
||||
https://github.com/huggingface/lerobot/pull/1127.
|
||||
|
||||
"""
|
||||
|
||||
# NOTE: ACT policy has different randomness, after PyTorch 2.7.0
|
||||
if policy_name == "act" and version.parse(torch.__version__) < version.parse("2.7.0"):
|
||||
pytest.skip(f"Skipping act policy test with PyTorch {torch.__version__}. Requires PyTorch >= 2.7.0")
|
||||
|
||||
ds_name = ds_repo_id.split("/")[-1]
|
||||
artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
|
||||
saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")
|
||||
|
||||
Reference in New Issue
Block a user