mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
Added docstrings to aggregate, fix test_policies.py
This commit is contained in:
@@ -1,3 +1,20 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -26,6 +43,21 @@ from lerobot.datasets.utils import (
|
|||||||
|
|
||||||
|
|
||||||
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
||||||
|
"""Validates that all dataset metadata have consistent properties.
|
||||||
|
|
||||||
|
Ensures all datasets have the same fps, robot_type, and features to guarantee
|
||||||
|
compatibility when aggregating them into a single dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
all_metadata: List of LeRobotDatasetMetadata objects to validate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: A tuple containing (fps, robot_type, features) from the first metadata.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If any metadata has different fps, robot_type, or features
|
||||||
|
than the first metadata in the list.
|
||||||
|
"""
|
||||||
# validate same fps, robot_type, features
|
# validate same fps, robot_type, features
|
||||||
|
|
||||||
fps = all_metadata[0].fps
|
fps = all_metadata[0].fps
|
||||||
@@ -48,6 +80,20 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
|||||||
|
|
||||||
|
|
||||||
def update_data_df(df, src_meta, dst_meta):
|
def update_data_df(df, src_meta, dst_meta):
|
||||||
|
"""Updates a data DataFrame with new indices and task mappings for aggregation.
|
||||||
|
|
||||||
|
Adjusts episode indices, frame indices, and task indices to account for
|
||||||
|
previously aggregated data in the destination dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: DataFrame containing the data to be updated.
|
||||||
|
src_meta: Source dataset metadata.
|
||||||
|
dst_meta: Destination dataset metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pd.DataFrame: Updated DataFrame with adjusted indices.
|
||||||
|
"""
|
||||||
|
|
||||||
def _update(row):
|
def _update(row):
|
||||||
row["episode_index"] = row["episode_index"] + dst_meta.info["total_episodes"]
|
row["episode_index"] = row["episode_index"] + dst_meta.info["total_episodes"]
|
||||||
row["index"] = row["index"] + dst_meta.info["total_frames"]
|
row["index"] = row["index"] + dst_meta.info["total_frames"]
|
||||||
@@ -65,6 +111,22 @@ def update_meta_data(
|
|||||||
data_idx,
|
data_idx,
|
||||||
videos_idx,
|
videos_idx,
|
||||||
):
|
):
|
||||||
|
"""Updates metadata DataFrame with new chunk, file, and timestamp indices.
|
||||||
|
|
||||||
|
Adjusts all indices and timestamps to account for previously aggregated
|
||||||
|
data and videos in the destination dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: DataFrame containing the metadata to be updated.
|
||||||
|
dst_meta: Destination dataset metadata.
|
||||||
|
meta_idx: Dictionary containing current metadata chunk and file indices.
|
||||||
|
data_idx: Dictionary containing current data chunk and file indices.
|
||||||
|
videos_idx: Dictionary containing current video indices and timestamps.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pd.DataFrame: Updated DataFrame with adjusted indices and timestamps.
|
||||||
|
"""
|
||||||
|
|
||||||
def _update(row):
|
def _update(row):
|
||||||
row["meta/episodes/chunk_index"] = row["meta/episodes/chunk_index"] + meta_idx["chunk"]
|
row["meta/episodes/chunk_index"] = row["meta/episodes/chunk_index"] + meta_idx["chunk"]
|
||||||
row["meta/episodes/file_index"] = row["meta/episodes/file_index"] + meta_idx["file"]
|
row["meta/episodes/file_index"] = row["meta/episodes/file_index"] + meta_idx["file"]
|
||||||
@@ -88,6 +150,20 @@ def update_meta_data(
|
|||||||
|
|
||||||
|
|
||||||
def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path] = None, aggr_root=None):
|
def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path] = None, aggr_root=None):
|
||||||
|
"""Aggregates multiple LeRobot datasets into a single unified dataset.
|
||||||
|
|
||||||
|
This is the main function that orchestrates the aggregation process by:
|
||||||
|
1. Loading and validating all source dataset metadata
|
||||||
|
2. Creating a new destination dataset with unified tasks
|
||||||
|
3. Aggregating videos, data, and metadata from all source datasets
|
||||||
|
4. Finalizing the aggregated dataset with proper statistics
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo_ids: List of repository IDs for the datasets to aggregate.
|
||||||
|
aggr_repo_id: Repository ID for the aggregated output dataset.
|
||||||
|
roots: Optional list of root paths for the source datasets.
|
||||||
|
aggr_root: Optional root path for the aggregated dataset.
|
||||||
|
"""
|
||||||
logging.info("Start aggregate_datasets")
|
logging.info("Start aggregate_datasets")
|
||||||
|
|
||||||
# Load metadata
|
# Load metadata
|
||||||
@@ -144,8 +220,18 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
|
|||||||
|
|
||||||
|
|
||||||
def aggregate_videos(src_meta, dst_meta, videos_idx):
|
def aggregate_videos(src_meta, dst_meta, videos_idx):
|
||||||
"""
|
"""Aggregates video chunks from a source dataset into the destination dataset.
|
||||||
Aggregates video chunks from a dataset into the aggregated dataset folder.
|
|
||||||
|
Handles video file concatenation and rotation based on file size limits.
|
||||||
|
Creates new video files when size limits are exceeded.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
src_meta: Source dataset metadata.
|
||||||
|
dst_meta: Destination dataset metadata.
|
||||||
|
videos_idx: Dictionary tracking video chunk and file indices.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Updated videos_idx with current chunk and file indices.
|
||||||
"""
|
"""
|
||||||
for key, video_idx in videos_idx.items():
|
for key, video_idx in videos_idx.items():
|
||||||
# Get unique (chunk, file) combinations
|
# Get unique (chunk, file) combinations
|
||||||
@@ -213,6 +299,19 @@ def aggregate_videos(src_meta, dst_meta, videos_idx):
|
|||||||
|
|
||||||
|
|
||||||
def aggregate_data(src_meta, dst_meta, data_idx):
|
def aggregate_data(src_meta, dst_meta, data_idx):
|
||||||
|
"""Aggregates data chunks from a source dataset into the destination dataset.
|
||||||
|
|
||||||
|
Reads source data files, updates indices to match the aggregated dataset,
|
||||||
|
and writes them to the destination with proper file rotation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
src_meta: Source dataset metadata.
|
||||||
|
dst_meta: Destination dataset metadata.
|
||||||
|
data_idx: Dictionary tracking data chunk and file indices.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Updated data_idx with current chunk and file indices.
|
||||||
|
"""
|
||||||
unique_chunk_file_ids = {
|
unique_chunk_file_ids = {
|
||||||
(c, f)
|
(c, f)
|
||||||
for c, f in zip(
|
for c, f in zip(
|
||||||
@@ -241,6 +340,21 @@ def aggregate_data(src_meta, dst_meta, data_idx):
|
|||||||
|
|
||||||
|
|
||||||
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||||
|
"""Aggregates metadata from a source dataset into the destination dataset.
|
||||||
|
|
||||||
|
Reads source metadata files, updates all indices and timestamps,
|
||||||
|
and writes them to the destination with proper file rotation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
src_meta: Source dataset metadata.
|
||||||
|
dst_meta: Destination dataset metadata.
|
||||||
|
meta_idx: Dictionary tracking metadata chunk and file indices.
|
||||||
|
data_idx: Dictionary tracking data chunk and file indices.
|
||||||
|
videos_idx: Dictionary tracking video indices and timestamps.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Updated meta_idx with current chunk and file indices.
|
||||||
|
"""
|
||||||
chunk_file_ids = {
|
chunk_file_ids = {
|
||||||
(c, f)
|
(c, f)
|
||||||
for c, f in zip(
|
for c, f in zip(
|
||||||
@@ -288,19 +402,23 @@ def append_or_create_parquet_file(
|
|||||||
contains_images: bool = False,
|
contains_images: bool = False,
|
||||||
aggr_root: Path = None,
|
aggr_root: Path = None,
|
||||||
):
|
):
|
||||||
"""
|
"""Appends data to an existing parquet file or creates a new one based on size constraints.
|
||||||
Safely appends or creates a Parquet file at dst_path based on size constraints.
|
|
||||||
|
|
||||||
Parameters:
|
Manages file rotation when size limits are exceeded to prevent individual files
|
||||||
df (pd.DataFrame): Data to write.
|
from becoming too large. Handles both regular parquet files and those containing images.
|
||||||
src_path (Path): Path to source file (used to get size).
|
|
||||||
idx (dict): Dictionary containing 'chunk' and 'file' indices.
|
Args:
|
||||||
max_mb (float): Maximum allowed file size in MB.
|
df: DataFrame to write to the parquet file.
|
||||||
chunk_size (int): Maximum number of files per chunk.
|
src_path: Path to the source file (used for size estimation).
|
||||||
default_path (str): Format string for generating a new file path.
|
idx: Dictionary containing current 'chunk' and 'file' indices.
|
||||||
|
max_mb: Maximum allowed file size in MB before rotation.
|
||||||
|
chunk_size: Maximum number of files per chunk before incrementing chunk index.
|
||||||
|
default_path: Format string for generating file paths.
|
||||||
|
contains_images: Whether the data contains images requiring special handling.
|
||||||
|
aggr_root: Root path for the aggregated dataset.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: Updated index dictionary.
|
dict: Updated index dictionary with current chunk and file indices.
|
||||||
"""
|
"""
|
||||||
# Initial destination path - use the correct default_path parameter
|
# Initial destination path - use the correct default_path parameter
|
||||||
dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
|
dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
|
||||||
@@ -340,6 +458,15 @@ def append_or_create_parquet_file(
|
|||||||
|
|
||||||
|
|
||||||
def finalize_aggregation(aggr_meta, all_metadata):
|
def finalize_aggregation(aggr_meta, all_metadata):
|
||||||
|
"""Finalizes the dataset aggregation by writing summary files and statistics.
|
||||||
|
|
||||||
|
Writes the tasks file, info file with total counts and splits, and
|
||||||
|
aggregated statistics from all source datasets.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
aggr_meta: Aggregated dataset metadata.
|
||||||
|
all_metadata: List of all source dataset metadata objects.
|
||||||
|
"""
|
||||||
logging.info("write tasks")
|
logging.info("write tasks")
|
||||||
write_tasks(aggr_meta.tasks, aggr_meta.root)
|
write_tasks(aggr_meta.tasks, aggr_meta.root)
|
||||||
|
|
||||||
|
|||||||
@@ -511,23 +511,19 @@ def test_backward_compatibility(repo_id):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# test 2 first frames of first episode
|
# test 2 first frames of first episode
|
||||||
i = dataset.meta.episodes["dataset_from_index"][0].item()
|
i = dataset.meta.episodes[0]["dataset_from_index"]
|
||||||
load_and_compare(i)
|
load_and_compare(i)
|
||||||
load_and_compare(i + 1)
|
load_and_compare(i + 1)
|
||||||
|
|
||||||
# test 2 frames at the middle of first episode
|
# test 2 frames at the middle of first episode
|
||||||
i = int(
|
i = int(
|
||||||
(
|
(dataset.meta.episodes[0]["dataset_to_index"] - dataset.meta.episodes[0]["dataset_from_index"]) / 2
|
||||||
dataset.meta.episodes["dataset_to_index"][0].item()
|
|
||||||
- dataset.meta.episodes["dataset_from_index"][0].item()
|
|
||||||
)
|
|
||||||
/ 2
|
|
||||||
)
|
)
|
||||||
load_and_compare(i)
|
load_and_compare(i)
|
||||||
load_and_compare(i + 1)
|
load_and_compare(i + 1)
|
||||||
|
|
||||||
# test 2 last frames of first episode
|
# test 2 last frames of first episode
|
||||||
i = dataset.meta.episodes["dataset_to_index"][0].item()
|
i = dataset.meta.episodes[0]["dataset_to_index"]
|
||||||
load_and_compare(i - 2)
|
load_and_compare(i - 2)
|
||||||
load_and_compare(i - 1)
|
load_and_compare(i - 1)
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from pathlib import Path
|
|||||||
import einops
|
import einops
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
from packaging import version
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
from lerobot import available_policies
|
from lerobot import available_policies
|
||||||
@@ -68,11 +69,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
info = info_factory(
|
info = info_factory(
|
||||||
total_episodes=1,
|
total_episodes=1, total_frames=1, camera_features=camera_features, motor_features=motor_features
|
||||||
total_frames=1,
|
|
||||||
total_tasks=1,
|
|
||||||
camera_features=camera_features,
|
|
||||||
motor_features=motor_features,
|
|
||||||
)
|
)
|
||||||
ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info)
|
ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info)
|
||||||
return ds_meta
|
return ds_meta
|
||||||
@@ -141,14 +138,13 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
|||||||
Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
|
Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
|
||||||
and for now we add tests as we see fit.
|
and for now we add tests as we see fit.
|
||||||
"""
|
"""
|
||||||
policy_kwargs["device"] = DEVICE
|
|
||||||
|
|
||||||
train_cfg = TrainPipelineConfig(
|
train_cfg = TrainPipelineConfig(
|
||||||
# TODO(rcadene, aliberts): remove dataset download
|
# TODO(rcadene, aliberts): remove dataset download
|
||||||
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
|
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
|
||||||
policy=make_policy_config(policy_name, **policy_kwargs),
|
policy=make_policy_config(policy_name, push_to_hub=False, **policy_kwargs),
|
||||||
env=make_env_config(env_name, **env_kwargs),
|
env=make_env_config(env_name, **env_kwargs),
|
||||||
)
|
)
|
||||||
|
train_cfg.validate()
|
||||||
|
|
||||||
# Check that we can make the policy object.
|
# Check that we can make the policy object.
|
||||||
dataset = make_dataset(train_cfg)
|
dataset = make_dataset(train_cfg)
|
||||||
@@ -217,7 +213,7 @@ def test_act_backbone_lr():
|
|||||||
cfg = TrainPipelineConfig(
|
cfg = TrainPipelineConfig(
|
||||||
# TODO(rcadene, aliberts): remove dataset download
|
# TODO(rcadene, aliberts): remove dataset download
|
||||||
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
|
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
|
||||||
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001),
|
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001, push_to_hub=False),
|
||||||
)
|
)
|
||||||
cfg.validate() # Needed for auto-setting some parameters
|
cfg.validate() # Needed for auto-setting some parameters
|
||||||
|
|
||||||
@@ -413,7 +409,17 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
|
|||||||
4. Check that this test now passes.
|
4. Check that this test now passes.
|
||||||
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
|
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
|
||||||
6. Remember to stage and commit the resulting changes to `tests/artifacts`.
|
6. Remember to stage and commit the resulting changes to `tests/artifacts`.
|
||||||
|
|
||||||
|
NOTE: If the test does not pass and you did not change the policy, it is likely that the test artifact
|
||||||
|
is out of date. For example, some PyTorch versions have different randomness, see this PR:
|
||||||
|
https://github.com/huggingface/lerobot/pull/1127.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# NOTE: ACT policy randomness differs across PyTorch versions, so this test requires PyTorch >= 2.7.0
|
||||||
|
if policy_name == "act" and version.parse(torch.__version__) < version.parse("2.7.0"):
|
||||||
|
pytest.skip(f"Skipping act policy test with PyTorch {torch.__version__}. Requires PyTorch >= 2.7.0")
|
||||||
|
|
||||||
ds_name = ds_repo_id.split("/")[-1]
|
ds_name = ds_repo_id.split("/")[-1]
|
||||||
artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
|
artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
|
||||||
saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")
|
saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")
|
||||||
|
|||||||
Reference in New Issue
Block a user