#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import shutil
from pathlib import Path

import datasets
import pandas as pd
import tqdm

from .compute_stats import aggregate_stats
from .dataset_metadata import LeRobotDatasetMetadata
from .feature_utils import get_hf_features_from_features
from .io_utils import (
    get_file_size_in_mb,
    get_parquet_file_size_in_mb,
    to_parquet_with_hf_images,
    write_info,
    write_stats,
    write_tasks,
)
from .utils import (
    DEFAULT_CHUNK_SIZE,
    DEFAULT_DATA_FILE_SIZE_IN_MB,
    DEFAULT_DATA_PATH,
    DEFAULT_EPISODES_PATH,
    DEFAULT_VIDEO_FILE_SIZE_IN_MB,
    DEFAULT_VIDEO_PATH,
    update_chunk_file_indices,
)
from .video_utils import concatenate_video_files, get_video_duration_in_s

def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
    """Validates that all dataset metadata have consistent properties.

    Ensures all datasets have the same fps, robot_type, and features to guarantee
    compatibility when aggregating them into a single dataset.

    Args:
        all_metadata: List of LeRobotDatasetMetadata objects to validate.

    Returns:
        tuple: A tuple containing (fps, robot_type, features) from the first metadata.

    Raises:
        ValueError: If any metadata has different fps, robot_type, or features
            than the first metadata in the list.
    """

    fps = all_metadata[0].fps
    robot_type = all_metadata[0].robot_type
    features = all_metadata[0].features

    for meta in tqdm.tqdm(all_metadata, desc="Validate all metadata"):
        if fps != meta.fps:
            raise ValueError(f"Same fps is expected, but got fps={meta.fps} instead of {fps}.")
        if robot_type != meta.robot_type:
            raise ValueError(
                f"Same robot_type is expected, but got robot_type={meta.robot_type} instead of {robot_type}."
            )
        if features != meta.features:
            raise ValueError(
                f"Same features is expected, but got features={meta.features} instead of {features}."
            )

    return fps, robot_type, features

def update_data_df(df, src_meta, dst_meta):
    """Updates a data DataFrame with new indices and task mappings for aggregation.

    Adjusts episode indices, frame indices, and task indices to account for
    previously aggregated data in the destination dataset.

    Args:
        df: DataFrame containing the data to be updated.
        src_meta: Source dataset metadata.
        dst_meta: Destination dataset metadata.

    Returns:
        pd.DataFrame: Updated DataFrame with adjusted indices.
    """

    df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
    df["index"] = df["index"] + dst_meta.info["total_frames"]

    src_task_names = src_meta.tasks.index.take(df["task_index"].to_numpy())
    df["task_index"] = dst_meta.tasks.loc[src_task_names, "task_index"].to_numpy()

    return df

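# Illustrative sketch only (not called anywhere in this module): how the index
# shifting and task remapping in update_data_df behave on a toy frame. The
# DataFrames, task names, and offsets below are made up for the example.
def _example_update_data_df_shift() -> pd.DataFrame:
    df = pd.DataFrame({"episode_index": [0, 0, 1], "index": [0, 1, 2], "task_index": [0, 1, 1]})
    src_tasks = pd.DataFrame({"task_index": [0, 1]}, index=pd.Index(["pick", "place"], name="task"))
    dst_tasks = pd.DataFrame({"task_index": [0, 1, 2]}, index=pd.Index(["push", "pick", "place"], name="task"))
    # Shift episode and frame indices past what the destination already contains
    # (here: 2 episodes and 100 frames).
    df["episode_index"] += 2  # -> [2, 2, 3]
    df["index"] += 100  # -> [100, 101, 102]
    # Remap task_index through task names, mirroring update_data_df.
    src_task_names = src_tasks.index.take(df["task_index"].to_numpy())
    df["task_index"] = dst_tasks.loc[src_task_names, "task_index"].to_numpy()  # -> [1, 2, 2]
    return df
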
def update_meta_data(
    df,
    dst_meta,
    meta_idx,
    data_idx,
    videos_idx,
):
    """Updates metadata DataFrame with new chunk, file, and timestamp indices.

    Adjusts all indices and timestamps to account for previously aggregated
    data and videos in the destination dataset.

    For data file indices, uses the 'src_to_dst' mapping from aggregate_data()
    to correctly map source file indices to their destination locations.

    Args:
        df: DataFrame containing the metadata to be updated.
        dst_meta: Destination dataset metadata.
        meta_idx: Dictionary containing current metadata chunk and file indices.
        data_idx: Dictionary containing current data chunk and file indices.
        videos_idx: Dictionary containing current video indices and timestamps.

    Returns:
        pd.DataFrame: Updated DataFrame with adjusted indices and timestamps.
    """

    df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"]
    df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"]

    # Update data file indices using source-to-destination mapping
    # This is critical for handling datasets that are already results of a merge
    data_src_to_dst = data_idx.get("src_to_dst", {})
    if data_src_to_dst:
        # Store original indices for lookup
        df["_orig_data_chunk"] = df["data/chunk_index"].copy()
        df["_orig_data_file"] = df["data/file_index"].copy()

        # Vectorized mapping from (src_chunk, src_file) to (dst_chunk, dst_file)
        # This is much faster than per-row iteration for large metadata tables
        mapping_index = pd.MultiIndex.from_tuples(
            list(data_src_to_dst.keys()),
            names=["chunk_index", "file_index"],
        )
        mapping_values = list(data_src_to_dst.values())
        mapping_df = pd.DataFrame(
            mapping_values,
            index=mapping_index,
            columns=["dst_chunk", "dst_file"],
        )

        # Construct a MultiIndex for each row based on original data indices
        row_index = pd.MultiIndex.from_arrays(
            [df["_orig_data_chunk"], df["_orig_data_file"]],
            names=["chunk_index", "file_index"],
        )

        # Align mapping to rows; missing keys fall back to the default destination
        reindexed = mapping_df.reindex(row_index)
        reindexed[["dst_chunk", "dst_file"]] = reindexed[["dst_chunk", "dst_file"]].fillna(
            {"dst_chunk": data_idx["chunk"], "dst_file": data_idx["file"]}
        )

        # Assign mapped destination indices back to the DataFrame
        df["data/chunk_index"] = reindexed["dst_chunk"].to_numpy()
        df["data/file_index"] = reindexed["dst_file"].to_numpy()

        # Clean up temporary columns
        df = df.drop(columns=["_orig_data_chunk", "_orig_data_file"])
    else:
        # Fallback to simple offset (backward compatibility for single-file sources)
        df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
        df["data/file_index"] = df["data/file_index"] + data_idx["file"]

    for key, video_idx in videos_idx.items():
        # Store original video file indices before updating
        orig_chunk_col = f"videos/{key}/chunk_index"
        orig_file_col = f"videos/{key}/file_index"
        df["_orig_chunk"] = df[orig_chunk_col].copy()
        df["_orig_file"] = df[orig_file_col].copy()

        # Get mappings for this video key
        src_to_offset = video_idx.get("src_to_offset", {})
        src_to_dst = video_idx.get("src_to_dst", {})

        # Apply per-source-file mappings
        if src_to_dst:
            # Map each episode to its correct destination file and apply offset
            for idx in df.index:
                src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])

                # Get destination chunk/file for this source file
                dst_chunk, dst_file = src_to_dst.get(src_key, (video_idx["chunk"], video_idx["file"]))
                df.at[idx, orig_chunk_col] = dst_chunk
                df.at[idx, orig_file_col] = dst_file

                # Apply timestamp offset
                offset = src_to_offset.get(src_key, 0)
                df.at[idx, f"videos/{key}/from_timestamp"] += offset
                df.at[idx, f"videos/{key}/to_timestamp"] += offset
        elif src_to_offset:
            # Fallback: use same destination for all, but apply per-file offsets
            df[orig_chunk_col] = video_idx["chunk"]
            df[orig_file_col] = video_idx["file"]
            for idx in df.index:
                src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
                offset = src_to_offset.get(src_key, 0)
                df.at[idx, f"videos/{key}/from_timestamp"] += offset
                df.at[idx, f"videos/{key}/to_timestamp"] += offset
        else:
            # Fallback to simple offset (for backward compatibility)
            df[orig_chunk_col] = video_idx["chunk"]
            df[orig_file_col] = video_idx["file"]
            df[f"videos/{key}/from_timestamp"] = (
                df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
            )
            df[f"videos/{key}/to_timestamp"] = df[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"]

        # Clean up temporary columns
        df = df.drop(columns=["_orig_chunk", "_orig_file"])

    df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
    df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
    df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]

    return df

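# Minimal sketch (illustrative only, never called) of the vectorized (chunk, file)
# remapping that update_meta_data applies to "data/chunk_index" / "data/file_index".
# The mapping and the episode rows below are invented for the example.
def _example_chunk_file_remap() -> pd.DataFrame:
    src_to_dst = {(0, 0): (0, 3), (0, 1): (0, 4)}
    df = pd.DataFrame({"data/chunk_index": [0, 0, 0], "data/file_index": [0, 1, 1]})
    mapping_df = pd.DataFrame(
        list(src_to_dst.values()),
        index=pd.MultiIndex.from_tuples(list(src_to_dst.keys()), names=["chunk_index", "file_index"]),
        columns=["dst_chunk", "dst_file"],
    )
    row_index = pd.MultiIndex.from_arrays(
        [df["data/chunk_index"], df["data/file_index"]], names=["chunk_index", "file_index"]
    )
    remapped = mapping_df.reindex(row_index)
    df["data/chunk_index"] = remapped["dst_chunk"].to_numpy()  # -> [0, 0, 0]
    df["data/file_index"] = remapped["dst_file"].to_numpy()  # -> [3, 4, 4]
    return df
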
def aggregate_datasets(
    repo_ids: list[str],
    aggr_repo_id: str,
    roots: list[Path] | None = None,
    aggr_root: Path | None = None,
    data_files_size_in_mb: float | None = None,
    video_files_size_in_mb: float | None = None,
    chunk_size: int | None = None,
):
    """Aggregates multiple LeRobot datasets into a single unified dataset.

    This is the main function that orchestrates the aggregation process by:
    1. Loading and validating all source dataset metadata
    2. Creating a new destination dataset with unified tasks
    3. Aggregating videos, data, and metadata from all source datasets
    4. Finalizing the aggregated dataset with proper statistics

    Args:
        repo_ids: List of repository IDs for the datasets to aggregate.
        aggr_repo_id: Repository ID for the aggregated output dataset.
        roots: Optional list of root paths for the source datasets.
        aggr_root: Optional root path for the aggregated dataset.
        data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB).
        video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB).
        chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE).
    """
    logging.info("Start aggregate_datasets")

    if data_files_size_in_mb is None:
        data_files_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
    if video_files_size_in_mb is None:
        video_files_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
    if chunk_size is None:
        chunk_size = DEFAULT_CHUNK_SIZE

    all_metadata = (
        [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
        if roots is None
        else [
            LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots, strict=False)
        ]
    )
    fps, robot_type, features = validate_all_metadata(all_metadata)
    video_keys = [key for key in features if features[key]["dtype"] == "video"]

    dst_meta = LeRobotDatasetMetadata.create(
        repo_id=aggr_repo_id,
        fps=fps,
        robot_type=robot_type,
        features=features,
        root=aggr_root,
        use_videos=len(video_keys) > 0,
        chunks_size=chunk_size,
        data_files_size_in_mb=data_files_size_in_mb,
        video_files_size_in_mb=video_files_size_in_mb,
    )

    logging.info("Find all tasks")
    unique_tasks = pd.concat([m.tasks for m in all_metadata]).index.unique()
    dst_meta.tasks = pd.DataFrame(
        {"task_index": range(len(unique_tasks))}, index=pd.Index(unique_tasks, name="task")
    )

    meta_idx = {"chunk": 0, "file": 0}
    data_idx = {"chunk": 0, "file": 0}
    videos_idx = {
        key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in video_keys
    }

    dst_meta.episodes = {}

    for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
        videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
        data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)

        meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)

        # Clear the src_to_dst mapping after processing each source dataset
        # to avoid interference between different source datasets
        data_idx.pop("src_to_dst", None)

        dst_meta.info["total_episodes"] += src_meta.total_episodes
        dst_meta.info["total_frames"] += src_meta.total_frames

    finalize_aggregation(dst_meta, all_metadata)
    logging.info("Aggregation complete.")

def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size):
    """Aggregates video chunks from a source dataset into the destination dataset.

    Handles video file concatenation and rotation based on file size limits.
    Creates new video files when size limits are exceeded.

    Args:
        src_meta: Source dataset metadata.
        dst_meta: Destination dataset metadata.
        videos_idx: Dictionary tracking video chunk and file indices.
        video_files_size_in_mb: Maximum size for video files in MB.
        chunk_size: Maximum number of files per chunk.

    Returns:
        dict: Updated videos_idx with current chunk and file indices.
    """
    for key in videos_idx:
        videos_idx[key]["episode_duration"] = 0
        # Track offset for each source (chunk, file) pair
        videos_idx[key]["src_to_offset"] = {}
        # Track destination (chunk, file) for each source (chunk, file) pair
        videos_idx[key]["src_to_dst"] = {}
        # Initialize dst_file_durations if not present
        # dst_file_durations tracks duration of each destination file
        if "dst_file_durations" not in videos_idx[key]:
            videos_idx[key]["dst_file_durations"] = {}

    for key, video_idx in videos_idx.items():
        unique_chunk_file_pairs = {
            (chunk, file)
            for chunk, file in zip(
                src_meta.episodes[f"videos/{key}/chunk_index"],
                src_meta.episodes[f"videos/{key}/file_index"],
                strict=False,
            )
        }
        unique_chunk_file_pairs = sorted(unique_chunk_file_pairs)

        chunk_idx = video_idx["chunk"]
        file_idx = video_idx["file"]
        dst_file_durations = video_idx["dst_file_durations"]

        for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
            src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
                video_key=key,
                chunk_index=src_chunk_idx,
                file_index=src_file_idx,
            )

            dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format(
                video_key=key,
                chunk_index=chunk_idx,
                file_index=file_idx,
            )

            src_duration = get_video_duration_in_s(src_path)
            dst_key = (chunk_idx, file_idx)

            if not dst_path.exists():
                # New destination file: offset is 0
                videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
                videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
                dst_path.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy(str(src_path), str(dst_path))
                # Track duration of this destination file
                dst_file_durations[dst_key] = src_duration
                videos_idx[key]["episode_duration"] += src_duration
                continue

            # Check file sizes before appending
            src_size = get_file_size_in_mb(src_path)
            dst_size = get_file_size_in_mb(dst_path)

            if dst_size + src_size >= video_files_size_in_mb:
                # Rotate to a new file - offset is 0
                chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
                dst_key = (chunk_idx, file_idx)
                videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
                videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
                dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format(
                    video_key=key,
                    chunk_index=chunk_idx,
                    file_index=file_idx,
                )
                dst_path.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy(str(src_path), str(dst_path))
                # Track duration of this new destination file
                dst_file_durations[dst_key] = src_duration
            else:
                # Append to existing destination file
                # Offset is the current duration of this destination file
                current_dst_duration = dst_file_durations.get(dst_key, 0)
                videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_dst_duration
                videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
                concatenate_video_files(
                    [dst_path, src_path],
                    dst_path,
                )
                # Update duration of this destination file
                dst_file_durations[dst_key] = current_dst_duration + src_duration

            videos_idx[key]["episode_duration"] += src_duration

        videos_idx[key]["chunk"] = chunk_idx
        videos_idx[key]["file"] = file_idx

    return videos_idx

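# Toy walk-through (illustration only, never executed by the library) of the duration
# bookkeeping used in aggregate_videos: when two source files land in the same
# destination file, the second one gets a timestamp offset equal to the duration
# already written to that destination. The durations are made-up numbers.
def _example_video_offset_bookkeeping() -> dict:
    dst_file_durations: dict[tuple[int, int], float] = {}
    src_to_offset: dict[tuple[int, int], float] = {}
    dst_key = (0, 0)
    for src_key, src_duration in [((0, 0), 12.0), ((0, 1), 8.0)]:
        # The offset for this source file is whatever has already been written
        # to the destination file it is appended to.
        src_to_offset[src_key] = dst_file_durations.get(dst_key, 0.0)
        dst_file_durations[dst_key] = dst_file_durations.get(dst_key, 0.0) + src_duration
    # -> src_to_offset == {(0, 0): 0.0, (0, 1): 12.0}
    # -> dst_file_durations == {(0, 0): 20.0}
    return {"src_to_offset": src_to_offset, "dst_file_durations": dst_file_durations}
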
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size):
    """Aggregates data chunks from a source dataset into the destination dataset.

    Reads source data files, updates indices to match the aggregated dataset,
    and writes them to the destination with proper file rotation.

    Tracks a `src_to_dst` mapping from source (chunk, file) to destination (chunk, file)
    which is critical for correctly updating episode metadata when source datasets
    have multiple data files (e.g., from a previous merge operation).

    Args:
        src_meta: Source dataset metadata.
        dst_meta: Destination dataset metadata.
        data_idx: Dictionary tracking data chunk and file indices.
        data_files_size_in_mb: Maximum size for data files in MB.
        chunk_size: Maximum number of files per chunk.

    Returns:
        dict: Updated data_idx with current chunk and file indices.
    """
    unique_chunk_file_ids = {
        (c, f)
        for c, f in zip(
            src_meta.episodes["data/chunk_index"], src_meta.episodes["data/file_index"], strict=False
        )
    }

    unique_chunk_file_ids = sorted(unique_chunk_file_ids)
    contains_images = len(dst_meta.image_keys) > 0

    # retrieve features schema for proper image typing in parquet
    hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None

    # Track source to destination file mapping for metadata update
    # This is critical for handling datasets that are already results of a merge
    src_to_dst: dict[tuple[int, int], tuple[int, int]] = {}

    for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
        src_path = src_meta.root / DEFAULT_DATA_PATH.format(
            chunk_index=src_chunk_idx, file_index=src_file_idx
        )
        if contains_images:
            # Use HuggingFace datasets to read source data to preserve image format
            src_ds = datasets.Dataset.from_parquet(str(src_path))
            df = src_ds.to_pandas()
        else:
            df = pd.read_parquet(src_path)
        df = update_data_df(df, src_meta, dst_meta)

        # Write data and get the actual destination file it was written to
        # This avoids duplicating the rotation logic here
        data_idx, (dst_chunk, dst_file) = append_or_create_parquet_file(
            df,
            src_path,
            data_idx,
            data_files_size_in_mb,
            chunk_size,
            DEFAULT_DATA_PATH,
            contains_images=contains_images,
            aggr_root=dst_meta.root,
            hf_features=hf_features,
        )

        # Record the mapping from source to actual destination
        src_to_dst[(src_chunk_idx, src_file_idx)] = (dst_chunk, dst_file)

    # Add the mapping to data_idx for use in metadata update
    data_idx["src_to_dst"] = src_to_dst

    return data_idx

def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
    """Aggregates metadata from a source dataset into the destination dataset.

    Reads source metadata files, updates all indices and timestamps,
    and writes them to the destination with proper file rotation.

    Args:
        src_meta: Source dataset metadata.
        dst_meta: Destination dataset metadata.
        meta_idx: Dictionary tracking metadata chunk and file indices.
        data_idx: Dictionary tracking data chunk and file indices.
        videos_idx: Dictionary tracking video indices and timestamps.

    Returns:
        dict: Updated meta_idx with current chunk and file indices.
    """
    chunk_file_ids = {
        (c, f)
        for c, f in zip(
            src_meta.episodes["meta/episodes/chunk_index"],
            src_meta.episodes["meta/episodes/file_index"],
            strict=False,
        )
    }

    chunk_file_ids = sorted(chunk_file_ids)
    for chunk_idx, file_idx in chunk_file_ids:
        src_path = src_meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
        df = pd.read_parquet(src_path)
        df = update_meta_data(
            df,
            dst_meta,
            meta_idx,
            data_idx,
            videos_idx,
        )

        meta_idx, _ = append_or_create_parquet_file(
            df,
            src_path,
            meta_idx,
            DEFAULT_DATA_FILE_SIZE_IN_MB,
            DEFAULT_CHUNK_SIZE,
            DEFAULT_EPISODES_PATH,
            contains_images=False,
            aggr_root=dst_meta.root,
        )

    # Increment latest_duration by the total duration added from this source dataset
    for k in videos_idx:
        videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]

    return meta_idx

def append_or_create_parquet_file(
    df: pd.DataFrame,
    src_path: Path,
    idx: dict[str, int],
    max_mb: float,
    chunk_size: int,
    default_path: str,
    contains_images: bool = False,
    aggr_root: Path | None = None,
    hf_features: datasets.Features | None = None,
) -> tuple[dict[str, int], tuple[int, int]]:
    """Appends data to an existing parquet file or creates a new one based on size constraints.

    Manages file rotation when size limits are exceeded to prevent individual files
    from becoming too large. Handles both regular parquet files and those containing images.

    Args:
        df: DataFrame to write to the parquet file.
        src_path: Path to the source file (used for size estimation).
        idx: Dictionary containing current 'chunk' and 'file' indices.
        max_mb: Maximum allowed file size in MB before rotation.
        chunk_size: Maximum number of files per chunk before incrementing chunk index.
        default_path: Format string for generating file paths.
        contains_images: Whether the data contains images requiring special handling.
        aggr_root: Root path for the aggregated dataset.
        hf_features: Optional HuggingFace Features schema for proper image typing.

    Returns:
        tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
            and (dst_chunk, dst_file) is the actual destination file the data was written to.
    """
    dst_chunk, dst_file = idx["chunk"], idx["file"]
    dst_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)

    if not dst_path.exists():
        dst_path.parent.mkdir(parents=True, exist_ok=True)
        if contains_images:
            to_parquet_with_hf_images(df, dst_path, features=hf_features)
        else:
            df.to_parquet(dst_path)
        return idx, (dst_chunk, dst_file)

    src_size = get_parquet_file_size_in_mb(src_path)
    dst_size = get_parquet_file_size_in_mb(dst_path)

    if dst_size + src_size >= max_mb:
        idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
        dst_chunk, dst_file = idx["chunk"], idx["file"]
        new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
        new_path.parent.mkdir(parents=True, exist_ok=True)
        final_df = df
        target_path = new_path
    else:
        if contains_images:
            # Use HuggingFace datasets to read existing data to preserve image format
            existing_ds = datasets.Dataset.from_parquet(str(dst_path))
            existing_df = existing_ds.to_pandas()
        else:
            existing_df = pd.read_parquet(dst_path)
        final_df = pd.concat([existing_df, df], ignore_index=True)
        target_path = dst_path

    if contains_images:
        to_parquet_with_hf_images(final_df, target_path, features=hf_features)
    else:
        final_df.to_parquet(target_path)

    return idx, (dst_chunk, dst_file)

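# Hedged sketch (illustration only) of the rotation rule applied above: incoming rows
# are appended to the current parquet file unless the combined size would reach the
# limit, in which case they start a fresh (chunk, file) slot. The sizes here are
# invented; the real code measures them on disk with get_parquet_file_size_in_mb.
def _example_rotation_decision(
    dst_size_mb: float = 90.0, src_size_mb: float = 25.0, max_mb: float = 100.0
) -> str:
    if dst_size_mb + src_size_mb >= max_mb:
        return "rotate: write the incoming rows into a new file"
    return "append: concatenate with the existing file and rewrite it"
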
def finalize_aggregation(aggr_meta, all_metadata):
    """Finalizes the dataset aggregation by writing summary files and statistics.

    Writes the tasks file, info file with total counts and splits, and
    aggregated statistics from all source datasets.

    Args:
        aggr_meta: Aggregated dataset metadata.
        all_metadata: List of all source dataset metadata objects.
    """
    logging.info("write tasks")
    write_tasks(aggr_meta.tasks, aggr_meta.root)

    logging.info("write info")
    aggr_meta.info.update(
        {
            "total_tasks": len(aggr_meta.tasks),
            "total_episodes": sum(m.total_episodes for m in all_metadata),
            "total_frames": sum(m.total_frames for m in all_metadata),
            "splits": {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"},
        }
    )
    write_info(aggr_meta.info, aggr_meta.root)

    logging.info("write stats")
    aggr_meta.stats = aggregate_stats([m.stats for m in all_metadata])
    write_stats(aggr_meta.stats, aggr_meta.root)
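# Hypothetical usage sketch: the repo ids and local paths below are placeholders,
# not real datasets. It only illustrates how the aggregate_datasets entry point
# defined above is meant to be called.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    aggregate_datasets(
        repo_ids=["user/dataset_a", "user/dataset_b"],  # placeholder source repos
        aggr_repo_id="user/dataset_merged",  # placeholder output repo
        roots=[Path("data/dataset_a"), Path("data/dataset_b")],  # optional local roots
        aggr_root=Path("data/dataset_merged"),
    )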