Files
lerobot/src/lerobot/datasets/aggregate.py
T

656 lines
26 KiB
Python

#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import shutil
from pathlib import Path
import datasets
import pandas as pd
import tqdm
from .compute_stats import aggregate_stats
from .dataset_metadata import LeRobotDatasetMetadata
from .feature_utils import get_hf_features_from_features
from .io_utils import (
get_file_size_in_mb,
get_parquet_file_size_in_mb,
to_parquet_with_hf_images,
write_info,
write_stats,
write_tasks,
)
from .utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
DEFAULT_EPISODES_PATH,
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_PATH,
update_chunk_file_indices,
)
from .video_utils import concatenate_video_files, get_video_duration_in_s
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
    """Validates that all dataset metadata have consistent properties.

    Ensures all datasets have the same fps, robot_type, and features to guarantee
    compatibility when aggregating them into a single dataset. The first entry of
    the list is used as the reference every other entry is compared against.

    Args:
        all_metadata: List of LeRobotDatasetMetadata objects to validate.

    Returns:
        tuple: A tuple containing (fps, robot_type, features) from the first metadata.

    Raises:
        ValueError: If ``all_metadata`` is empty, or if any metadata has different
            fps, robot_type, or features than the first metadata in the list.
    """
    # Fail with an explicit message instead of an opaque IndexError on all_metadata[0].
    if not all_metadata:
        raise ValueError("At least one dataset metadata is required, but got an empty list.")

    reference = all_metadata[0]
    fps = reference.fps
    robot_type = reference.robot_type
    features = reference.features

    for meta in tqdm.tqdm(all_metadata, desc="Validate all meta data"):
        if fps != meta.fps:
            raise ValueError(f"Same fps is expected, but got fps={meta.fps} instead of {fps}.")
        if robot_type != meta.robot_type:
            raise ValueError(
                f"Same robot_type is expected, but got robot_type={meta.robot_type} instead of {robot_type}."
            )
        if features != meta.features:
            raise ValueError(
                f"Same features is expected, but got features={meta.features} instead of {features}."
            )

    return fps, robot_type, features
def update_data_df(df, src_meta, dst_meta):
    """Shift a source data frame so it lines up with the destination dataset.

    Episode and frame indices are offset by the totals already stored in the
    destination, and task indices are translated from the source task table to
    the destination task table by going through the task name.

    Args:
        df: DataFrame containing the data to be updated (modified in place).
        src_meta: Source dataset metadata.
        dst_meta: Destination dataset metadata.

    Returns:
        pd.DataFrame: The same DataFrame with adjusted indices.
    """
    episode_offset = dst_meta.info["total_episodes"]
    frame_offset = dst_meta.info["total_frames"]

    df["episode_index"] = df["episode_index"] + episode_offset
    df["index"] = df["index"] + frame_offset

    # Source task_index -> task name (index of the source task table) -> destination task_index.
    source_task_positions = df["task_index"].to_numpy()
    task_names = src_meta.tasks.index.take(source_task_positions)
    destination_task_indices = dst_meta.tasks.loc[task_names, "task_index"]
    df["task_index"] = destination_task_indices.to_numpy()

    return df
def update_meta_data(
df,
dst_meta,
meta_idx,
data_idx,
videos_idx,
):
"""Updates metadata DataFrame with new chunk, file, and timestamp indices.
Adjusts all indices and timestamps to account for previously aggregated
data and videos in the destination dataset.
For data file indices, uses the 'src_to_dst' mapping from aggregate_data()
to correctly map source file indices to their destination locations.
Args:
df: DataFrame containing the metadata to be updated.
dst_meta: Destination dataset metadata.
meta_idx: Dictionary containing current metadata chunk and file indices.
data_idx: Dictionary containing current data chunk and file indices.
videos_idx: Dictionary containing current video indices and timestamps.
Returns:
pd.DataFrame: Updated DataFrame with adjusted indices and timestamps.
"""
df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"]
df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"]
# Update data file indices using source-to-destination mapping
# This is critical for handling datasets that are already results of a merge
data_src_to_dst = data_idx.get("src_to_dst", {})
if data_src_to_dst:
# Store original indices for lookup
df["_orig_data_chunk"] = df["data/chunk_index"].copy()
df["_orig_data_file"] = df["data/file_index"].copy()
# Vectorized mapping from (src_chunk, src_file) to (dst_chunk, dst_file)
# This is much faster than per-row iteration for large metadata tables
mapping_index = pd.MultiIndex.from_tuples(
list(data_src_to_dst.keys()),
names=["chunk_index", "file_index"],
)
mapping_values = list(data_src_to_dst.values())
mapping_df = pd.DataFrame(
mapping_values,
index=mapping_index,
columns=["dst_chunk", "dst_file"],
)
# Construct a MultiIndex for each row based on original data indices
row_index = pd.MultiIndex.from_arrays(
[df["_orig_data_chunk"], df["_orig_data_file"]],
names=["chunk_index", "file_index"],
)
# Align mapping to rows; missing keys fall back to the default destination
reindexed = mapping_df.reindex(row_index)
reindexed[["dst_chunk", "dst_file"]] = reindexed[["dst_chunk", "dst_file"]].fillna(
{"dst_chunk": data_idx["chunk"], "dst_file": data_idx["file"]}
)
# Assign mapped destination indices back to the DataFrame
df["data/chunk_index"] = reindexed["dst_chunk"].to_numpy()
df["data/file_index"] = reindexed["dst_file"].to_numpy()
# Clean up temporary columns
df = df.drop(columns=["_orig_data_chunk", "_orig_data_file"])
else:
# Fallback to simple offset (backward compatibility for single-file sources)
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
for key, video_idx in videos_idx.items():
# Store original video file indices before updating
orig_chunk_col = f"videos/{key}/chunk_index"
orig_file_col = f"videos/{key}/file_index"
df["_orig_chunk"] = df[orig_chunk_col].copy()
df["_orig_file"] = df[orig_file_col].copy()
# Get mappings for this video key
src_to_offset = video_idx.get("src_to_offset", {})
src_to_dst = video_idx.get("src_to_dst", {})
# Apply per-source-file mappings
if src_to_dst:
# Map each episode to its correct destination file and apply offset
for idx in df.index:
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
# Get destination chunk/file for this source file
dst_chunk, dst_file = src_to_dst.get(src_key, (video_idx["chunk"], video_idx["file"]))
df.at[idx, orig_chunk_col] = dst_chunk
df.at[idx, orig_file_col] = dst_file
# Apply timestamp offset
offset = src_to_offset.get(src_key, 0)
df.at[idx, f"videos/{key}/from_timestamp"] += offset
df.at[idx, f"videos/{key}/to_timestamp"] += offset
elif src_to_offset:
# Fallback: use same destination for all, but apply per-file offsets
df[orig_chunk_col] = video_idx["chunk"]
df[orig_file_col] = video_idx["file"]
for idx in df.index:
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
offset = src_to_offset.get(src_key, 0)
df.at[idx, f"videos/{key}/from_timestamp"] += offset
df.at[idx, f"videos/{key}/to_timestamp"] += offset
else:
# Fallback to simple offset (for backward compatibility)
df[orig_chunk_col] = video_idx["chunk"]
df[orig_file_col] = video_idx["file"]
df[f"videos/{key}/from_timestamp"] = (
df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
)
df[f"videos/{key}/to_timestamp"] = df[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"]
# Clean up temporary columns
df = df.drop(columns=["_orig_chunk", "_orig_file"])
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
return df
def aggregate_datasets(
    repo_ids: list[str],
    aggr_repo_id: str,
    roots: list[Path] | None = None,
    aggr_root: Path | None = None,
    data_files_size_in_mb: float | None = None,
    video_files_size_in_mb: float | None = None,
    chunk_size: int | None = None,
):
    """Aggregates multiple LeRobot datasets into a single unified dataset.

    This is the main function that orchestrates the aggregation process by:
    1. Loading and validating all source dataset metadata
    2. Creating a new destination dataset with unified tasks
    3. Aggregating videos, data, and metadata from all source datasets
    4. Finalizing the aggregated dataset with proper statistics

    Args:
        repo_ids: List of repository IDs for the datasets to aggregate.
        aggr_repo_id: Repository ID for the aggregated output dataset.
        roots: Optional list of root paths for the source datasets. When given,
            it must contain exactly one entry per repo_id.
        aggr_root: Optional root path for the aggregated dataset.
        data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
        video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
        chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)

    Raises:
        ValueError: If ``repo_ids`` is empty, or if ``roots`` is given with a
            different length than ``repo_ids``.
    """
    # Validate inputs up front: a shorter `roots` list would otherwise silently drop datasets.
    if not repo_ids:
        raise ValueError("At least one repo_id is required to aggregate datasets.")
    if roots is not None and len(roots) != len(repo_ids):
        raise ValueError(
            f"Expected one root per repo_id, but got {len(roots)} roots for {len(repo_ids)} repo_ids."
        )

    logging.info("Start aggregate_datasets")

    if data_files_size_in_mb is None:
        data_files_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
    if video_files_size_in_mb is None:
        video_files_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
    if chunk_size is None:
        chunk_size = DEFAULT_CHUNK_SIZE

    all_metadata = (
        [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
        if roots is None
        else [
            LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots, strict=True)
        ]
    )
    fps, robot_type, features = validate_all_metadata(all_metadata)
    video_keys = [key for key in features if features[key]["dtype"] == "video"]

    dst_meta = LeRobotDatasetMetadata.create(
        repo_id=aggr_repo_id,
        fps=fps,
        robot_type=robot_type,
        features=features,
        root=aggr_root,
        use_videos=len(video_keys) > 0,
        chunks_size=chunk_size,
        data_files_size_in_mb=data_files_size_in_mb,
        video_files_size_in_mb=video_files_size_in_mb,
    )

    logging.info("Find all tasks")
    # The destination task table is the de-duplicated union of all source task names.
    unique_tasks = pd.concat([m.tasks for m in all_metadata]).index.unique()
    dst_meta.tasks = pd.DataFrame(
        {"task_index": range(len(unique_tasks))}, index=pd.Index(unique_tasks, name="task")
    )

    meta_idx = {"chunk": 0, "file": 0}
    data_idx = {"chunk": 0, "file": 0}
    videos_idx = {
        key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in video_keys
    }

    dst_meta.episodes = {}

    for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
        videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
        data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)
        meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)

        # Clear the src_to_dst mapping after processing each source dataset
        # to avoid interference between different source datasets
        data_idx.pop("src_to_dst", None)

        # Running totals are the offsets applied to the next source dataset.
        dst_meta.info["total_episodes"] += src_meta.total_episodes
        dst_meta.info["total_frames"] += src_meta.total_frames

    finalize_aggregation(dst_meta, all_metadata)
    logging.info("Aggregation complete.")
def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size):
"""Aggregates video chunks from a source dataset into the destination dataset.
Handles video file concatenation and rotation based on file size limits.
Creates new video files when size limits are exceeded.
Args:
src_meta: Source dataset metadata.
dst_meta: Destination dataset metadata.
videos_idx: Dictionary tracking video chunk and file indices.
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
Returns:
dict: Updated videos_idx with current chunk and file indices.
"""
for key in videos_idx:
videos_idx[key]["episode_duration"] = 0
# Track offset for each source (chunk, file) pair
videos_idx[key]["src_to_offset"] = {}
# Track destination (chunk, file) for each source (chunk, file) pair
videos_idx[key]["src_to_dst"] = {}
# Initialize dst_file_durations if not present
# dst_file_durations tracks duration of each destination file
if "dst_file_durations" not in videos_idx[key]:
videos_idx[key]["dst_file_durations"] = {}
for key, video_idx in videos_idx.items():
unique_chunk_file_pairs = {
(chunk, file)
for chunk, file in zip(
src_meta.episodes[f"videos/{key}/chunk_index"],
src_meta.episodes[f"videos/{key}/file_index"],
strict=False,
)
}
unique_chunk_file_pairs = sorted(unique_chunk_file_pairs)
chunk_idx = video_idx["chunk"]
file_idx = video_idx["file"]
dst_file_durations = video_idx["dst_file_durations"]
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
video_key=key,
chunk_index=src_chunk_idx,
file_index=src_file_idx,
)
dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format(
video_key=key,
chunk_index=chunk_idx,
file_index=file_idx,
)
src_duration = get_video_duration_in_s(src_path)
dst_key = (chunk_idx, file_idx)
if not dst_path.exists():
# New destination file: offset is 0
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
dst_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(str(src_path), str(dst_path))
# Track duration of this destination file
dst_file_durations[dst_key] = src_duration
videos_idx[key]["episode_duration"] += src_duration
continue
# Check file sizes before appending
src_size = get_file_size_in_mb(src_path)
dst_size = get_file_size_in_mb(dst_path)
if dst_size + src_size >= video_files_size_in_mb:
# Rotate to a new file - offset is 0
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
dst_key = (chunk_idx, file_idx)
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format(
video_key=key,
chunk_index=chunk_idx,
file_index=file_idx,
)
dst_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(str(src_path), str(dst_path))
# Track duration of this new destination file
dst_file_durations[dst_key] = src_duration
else:
# Append to existing destination file
# Offset is the current duration of this destination file
current_dst_duration = dst_file_durations.get(dst_key, 0)
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_dst_duration
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
concatenate_video_files(
[dst_path, src_path],
dst_path,
)
# Update duration of this destination file
dst_file_durations[dst_key] = current_dst_duration + src_duration
videos_idx[key]["episode_duration"] += src_duration
videos_idx[key]["chunk"] = chunk_idx
videos_idx[key]["file"] = file_idx
return videos_idx
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size):
"""Aggregates data chunks from a source dataset into the destination dataset.
Reads source data files, updates indices to match the aggregated dataset,
and writes them to the destination with proper file rotation.
Tracks a `src_to_dst` mapping from source (chunk, file) to destination (chunk, file)
which is critical for correctly updating episode metadata when source datasets
have multiple data files (e.g., from a previous merge operation).
Args:
src_meta: Source dataset metadata.
dst_meta: Destination dataset metadata.
data_idx: Dictionary tracking data chunk and file indices.
data_files_size_in_mb: Maximum size for data files in MB.
chunk_size: Maximum number of files per chunk.
Returns:
dict: Updated data_idx with current chunk and file indices.
"""
unique_chunk_file_ids = {
(c, f)
for c, f in zip(
src_meta.episodes["data/chunk_index"], src_meta.episodes["data/file_index"], strict=False
)
}
unique_chunk_file_ids = sorted(unique_chunk_file_ids)
contains_images = len(dst_meta.image_keys) > 0
# retrieve features schema for proper image typing in parquet
hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None
# Track source to destination file mapping for metadata update
# This is critical for handling datasets that are already results of a merge
src_to_dst: dict[tuple[int, int], tuple[int, int]] = {}
for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
chunk_index=src_chunk_idx, file_index=src_file_idx
)
if contains_images:
# Use HuggingFace datasets to read source data to preserve image format
src_ds = datasets.Dataset.from_parquet(str(src_path))
df = src_ds.to_pandas()
else:
df = pd.read_parquet(src_path)
df = update_data_df(df, src_meta, dst_meta)
# Write data and get the actual destination file it was written to
# This avoids duplicating the rotation logic here
data_idx, (dst_chunk, dst_file) = append_or_create_parquet_file(
df,
src_path,
data_idx,
data_files_size_in_mb,
chunk_size,
DEFAULT_DATA_PATH,
contains_images=contains_images,
aggr_root=dst_meta.root,
hf_features=hf_features,
)
# Record the mapping from source to actual destination
src_to_dst[(src_chunk_idx, src_file_idx)] = (dst_chunk, dst_file)
# Add the mapping to data_idx for use in metadata update
data_idx["src_to_dst"] = src_to_dst
return data_idx
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
"""Aggregates metadata from a source dataset into the destination dataset.
Reads source metadata files, updates all indices and timestamps,
and writes them to the destination with proper file rotation.
Args:
src_meta: Source dataset metadata.
dst_meta: Destination dataset metadata.
meta_idx: Dictionary tracking metadata chunk and file indices.
data_idx: Dictionary tracking data chunk and file indices.
videos_idx: Dictionary tracking video indices and timestamps.
Returns:
dict: Updated meta_idx with current chunk and file indices.
"""
chunk_file_ids = {
(c, f)
for c, f in zip(
src_meta.episodes["meta/episodes/chunk_index"],
src_meta.episodes["meta/episodes/file_index"],
strict=False,
)
}
chunk_file_ids = sorted(chunk_file_ids)
for chunk_idx, file_idx in chunk_file_ids:
src_path = src_meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
df = pd.read_parquet(src_path)
df = update_meta_data(
df,
dst_meta,
meta_idx,
data_idx,
videos_idx,
)
meta_idx, _ = append_or_create_parquet_file(
df,
src_path,
meta_idx,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_CHUNK_SIZE,
DEFAULT_EPISODES_PATH,
contains_images=False,
aggr_root=dst_meta.root,
)
# Increment latest_duration by the total duration added from this source dataset
for k in videos_idx:
videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]
return meta_idx
def append_or_create_parquet_file(
    df: pd.DataFrame,
    src_path: Path,
    idx: dict[str, int],
    max_mb: float,
    chunk_size: int,
    default_path: str,
    contains_images: bool = False,
    aggr_root: Path = None,
    hf_features: datasets.Features | None = None,
) -> tuple[dict[str, int], tuple[int, int]]:
    """Write ``df`` to the current destination parquet file, rotating when it is full.

    Three outcomes are possible:
    - the current destination file does not exist: ``df`` is written to it;
    - destination size plus source size stays below ``max_mb``: the existing rows
      are read back, ``df`` is concatenated after them and the file is rewritten;
    - otherwise ``idx`` is advanced with ``update_chunk_file_indices`` and ``df``
      is written to the new file.

    Args:
        df: DataFrame to write to the parquet file.
        src_path: Path to the source file (used for size estimation).
        idx: Dictionary containing current 'chunk' and 'file' indices (updated in place on rotation).
        max_mb: Maximum allowed file size in MB before rotation.
        chunk_size: Maximum number of files per chunk before incrementing chunk index.
        default_path: Format string for generating file paths.
        contains_images: Whether the data contains images requiring special handling.
        aggr_root: Root path for the aggregated dataset.
        hf_features: Optional HuggingFace Features schema for proper image typing.

    Returns:
        tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
            and (dst_chunk, dst_file) is the actual destination file the data was written to.
    """

    def _write(frame, path):
        # Image data goes through the HuggingFace-aware writer; everything else uses pandas.
        if contains_images:
            to_parquet_with_hf_images(frame, path, features=hf_features)
        else:
            frame.to_parquet(path)

    dst_chunk, dst_file = idx["chunk"], idx["file"]
    dst_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)

    # Case 1: nothing written at the current location yet.
    if not dst_path.exists():
        dst_path.parent.mkdir(parents=True, exist_ok=True)
        _write(df, dst_path)
        return idx, (dst_chunk, dst_file)

    src_size = get_parquet_file_size_in_mb(src_path)
    dst_size = get_parquet_file_size_in_mb(dst_path)

    # Case 2: still room in the current file, so append by rewriting it.
    if dst_size + src_size < max_mb:
        if contains_images:
            # Read through HuggingFace datasets to preserve the image format.
            existing_df = datasets.Dataset.from_parquet(str(dst_path)).to_pandas()
        else:
            existing_df = pd.read_parquet(dst_path)
        _write(pd.concat([existing_df, df], ignore_index=True), dst_path)
        return idx, (dst_chunk, dst_file)

    # Case 3: the current file is full, so rotate to the next chunk/file slot.
    idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
    dst_chunk, dst_file = idx["chunk"], idx["file"]
    rotated_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
    rotated_path.parent.mkdir(parents=True, exist_ok=True)
    _write(df, rotated_path)
    return idx, (dst_chunk, dst_file)
def finalize_aggregation(aggr_meta, all_metadata):
    """Write the summary files of the aggregated dataset.

    Writes, in order, the tasks file, the info file (with total task, episode and
    frame counts plus a single "train" split spanning every episode), and the
    statistics obtained by aggregating the stats of all source datasets.

    Args:
        aggr_meta: Aggregated dataset metadata.
        all_metadata: List of all source dataset metadata objects.
    """
    logging.info("write tasks")
    write_tasks(aggr_meta.tasks, aggr_meta.root)

    logging.info("write info")
    episode_count = sum(m.total_episodes for m in all_metadata)
    frame_count = sum(m.total_frames for m in all_metadata)
    summary = {
        "total_tasks": len(aggr_meta.tasks),
        "total_episodes": episode_count,
        "total_frames": frame_count,
        "splits": {"train": f"0:{episode_count}"},
    }
    aggr_meta.info.update(summary)
    write_info(aggr_meta.info, aggr_meta.root)

    logging.info("write stats")
    source_stats = [m.stats for m in all_metadata]
    aggr_meta.stats = aggregate_stats(source_stats)
    write_stats(aggr_meta.stats, aggr_meta.root)