mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-30 06:37:15 +00:00
feat(aggregate audio): adding support for audio in dataset aggregation functions
This commit is contained in:
@@ -26,6 +26,8 @@ import tqdm
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_AUDIO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_AUDIO_PATH,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
@@ -110,6 +112,7 @@ def update_meta_data(
|
||||
meta_idx,
|
||||
data_idx,
|
||||
videos_idx,
|
||||
audios_idx,
|
||||
):
|
||||
"""Updates metadata DataFrame with new chunk, file, and timestamp indices.
|
||||
|
||||
@@ -122,7 +125,7 @@ def update_meta_data(
|
||||
meta_idx: Dictionary containing current metadata chunk and file indices.
|
||||
data_idx: Dictionary containing current data chunk and file indices.
|
||||
videos_idx: Dictionary containing current video indices and timestamps.
|
||||
|
||||
audios_idx: Dictionary containing current audio indices and timestamps.
|
||||
Returns:
|
||||
pd.DataFrame: Updated DataFrame with adjusted indices and timestamps.
|
||||
"""
|
||||
@@ -180,6 +183,36 @@ def update_meta_data(
|
||||
# Clean up temporary columns
|
||||
df = df.drop(columns=["_orig_chunk", "_orig_file"])
|
||||
|
||||
for key, audio_idx in audios_idx.items():
|
||||
# Store original audio file indices before updating
|
||||
orig_chunk_col = f"audio/{key}/chunk_index"
|
||||
orig_file_col = f"audio/{key}/file_index"
|
||||
df["_orig_chunk"] = df[orig_chunk_col].copy()
|
||||
df["_orig_file"] = df[orig_file_col].copy()
|
||||
|
||||
# Update chunk and file indices to point to destination
|
||||
df[orig_chunk_col] = audio_idx["chunk"]
|
||||
df[orig_file_col] = audio_idx["file"]
|
||||
|
||||
# Apply per-source-file timestamp offsets
|
||||
src_to_offset = audio_idx.get("src_to_offset", {})
|
||||
if src_to_offset:
|
||||
# Apply offset based on original source file
|
||||
for idx in df.index:
|
||||
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
|
||||
offset = src_to_offset.get(src_key, 0)
|
||||
df.at[idx, f"audio/{key}/from_timestamp"] += offset
|
||||
df.at[idx, f"audio/{key}/to_timestamp"] += offset
|
||||
else:
|
||||
# Fallback to simple offset (for backward compatibility)
|
||||
df[f"audio/{key}/from_timestamp"] = (
|
||||
df[f"audio/{key}/from_timestamp"] + audio_idx["latest_duration"]
|
||||
)
|
||||
df[f"audio/{key}/to_timestamp"] = df[f"audio/{key}/to_timestamp"] + audio_idx["latest_duration"]
|
||||
|
||||
# Clean up temporary columns
|
||||
df = df.drop(columns=["_orig_chunk", "_orig_file"])
|
||||
|
||||
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
|
||||
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
|
||||
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
|
||||
@@ -194,6 +227,7 @@ def aggregate_datasets(
|
||||
aggr_root: Path | None = None,
|
||||
data_files_size_in_mb: float | None = None,
|
||||
video_files_size_in_mb: float | None = None,
|
||||
audio_files_size_in_mb: float | None = None,
|
||||
chunk_size: int | None = None,
|
||||
):
|
||||
"""Aggregates multiple LeRobot datasets into a single unified dataset.
|
||||
@@ -211,6 +245,7 @@ def aggregate_datasets(
|
||||
aggr_root: Optional root path for the aggregated dataset.
|
||||
data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
|
||||
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||
audio_files_size_in_mb: Maximum size for audio files in MB (defaults to DEFAULT_AUDIO_FILE_SIZE_IN_MB)
|
||||
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
||||
"""
|
||||
logging.info("Start aggregate_datasets")
|
||||
@@ -219,6 +254,8 @@ def aggregate_datasets(
|
||||
data_files_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
|
||||
if video_files_size_in_mb is None:
|
||||
video_files_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||
if audio_files_size_in_mb is None:
|
||||
audio_files_size_in_mb = DEFAULT_AUDIO_FILE_SIZE_IN_MB
|
||||
if chunk_size is None:
|
||||
chunk_size = DEFAULT_CHUNK_SIZE
|
||||
|
||||
@@ -231,6 +268,7 @@ def aggregate_datasets(
|
||||
)
|
||||
fps, robot_type, features = validate_all_metadata(all_metadata)
|
||||
video_keys = [key for key in features if features[key]["dtype"] == "video"]
|
||||
audio_keys = [key for key in features if features[key]["dtype"] == "audio"]
|
||||
|
||||
dst_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=aggr_repo_id,
|
||||
@@ -242,6 +280,7 @@ def aggregate_datasets(
|
||||
chunks_size=chunk_size,
|
||||
data_files_size_in_mb=data_files_size_in_mb,
|
||||
video_files_size_in_mb=video_files_size_in_mb,
|
||||
audio_files_size_in_mb=audio_files_size_in_mb,
|
||||
)
|
||||
|
||||
logging.info("Find all tasks")
|
||||
@@ -253,14 +292,18 @@ def aggregate_datasets(
|
||||
videos_idx = {
|
||||
key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in video_keys
|
||||
}
|
||||
audios_idx = {
|
||||
key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in audio_keys
|
||||
}
|
||||
|
||||
dst_meta.episodes = {}
|
||||
|
||||
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
|
||||
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
|
||||
audios_idx = aggregate_audio(src_meta, dst_meta, audios_idx, audio_files_size_in_mb, chunk_size)
|
||||
data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)
|
||||
|
||||
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
|
||||
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, audios_idx)
|
||||
|
||||
dst_meta.info["total_episodes"] += src_meta.total_episodes
|
||||
dst_meta.info["total_frames"] += src_meta.total_frames
|
||||
@@ -382,6 +425,101 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
return videos_idx
|
||||
|
||||
|
||||
def aggregate_audio(src_meta, dst_meta, audios_idx, audio_files_size_in_mb, chunk_size):
|
||||
"""Aggregates audio files from a source dataset into the destination dataset.
|
||||
|
||||
Handles audio file concatenation and rotation based on file size limits.
|
||||
Creates new audio files when size limits are exceeded.
|
||||
|
||||
Args:
|
||||
src_meta: Source dataset metadata.
|
||||
dst_meta: Destination dataset metadata.
|
||||
audio_idx: Dictionary tracking audio chunk and file indices.
|
||||
audio_files_size_in_mb: Maximum size for audio files in MB (defaults to DEFAULT_AUDIO_FILE_SIZE_IN_MB)
|
||||
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
||||
|
||||
Returns:
|
||||
dict: Updated audio_idx with current chunk and file indices.
|
||||
"""
|
||||
for key in audios_idx:
|
||||
audios_idx[key]["episode_duration"] = 0
|
||||
# Track offset for each source (chunk, file) pair
|
||||
audios_idx[key]["src_to_offset"] = {}
|
||||
|
||||
for key, audio_idx in audios_idx.items():
|
||||
unique_chunk_file_pairs = {
|
||||
(chunk, file)
|
||||
for chunk, file in zip(
|
||||
src_meta.episodes[f"audio/{key}/chunk_index"],
|
||||
src_meta.episodes[f"audio/{key}/file_index"],
|
||||
strict=False,
|
||||
)
|
||||
}
|
||||
unique_chunk_file_pairs = sorted(unique_chunk_file_pairs)
|
||||
|
||||
chunk_idx = audio_idx["chunk"]
|
||||
file_idx = audio_idx["file"]
|
||||
current_offset = audio_idx["latest_duration"]
|
||||
|
||||
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
|
||||
src_path = src_meta.root / DEFAULT_AUDIO_PATH.format(
|
||||
audio_key=key,
|
||||
chunk_index=src_chunk_idx,
|
||||
file_index=src_file_idx,
|
||||
)
|
||||
|
||||
dst_path = dst_meta.root / DEFAULT_AUDIO_PATH.format(
|
||||
audio_key=key,
|
||||
chunk_index=chunk_idx,
|
||||
file_index=file_idx,
|
||||
)
|
||||
|
||||
src_duration = get_media_duration_in_s(src_path, media_type="audio")
|
||||
|
||||
if not dst_path.exists():
|
||||
# Store offset before incrementing
|
||||
audios_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(str(src_path), str(dst_path))
|
||||
audios_idx[key]["episode_duration"] += src_duration
|
||||
current_offset += src_duration
|
||||
continue
|
||||
|
||||
# Check file sizes before appending
|
||||
src_size = get_file_size_in_mb(src_path)
|
||||
dst_size = get_file_size_in_mb(dst_path)
|
||||
|
||||
if dst_size + src_size >= audio_files_size_in_mb:
|
||||
# Rotate to a new file, this source becomes start of new destination
|
||||
# So its offset should be 0
|
||||
audios_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
|
||||
dst_path = dst_meta.root / DEFAULT_AUDIO_PATH.format(
|
||||
audio_key=key,
|
||||
chunk_index=chunk_idx,
|
||||
file_index=file_idx,
|
||||
)
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(str(src_path), str(dst_path))
|
||||
# Reset offset for next file
|
||||
current_offset = src_duration
|
||||
else:
|
||||
# Append to existing video file - use current accumulated offset
|
||||
audios_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
|
||||
concatenate_media_files(
|
||||
[dst_path, src_path],
|
||||
dst_path,
|
||||
)
|
||||
current_offset += src_duration
|
||||
|
||||
audios_idx[key]["episode_duration"] += src_duration
|
||||
|
||||
audios_idx[key]["chunk"] = chunk_idx
|
||||
audios_idx[key]["file"] = file_idx
|
||||
|
||||
return audios_idx
|
||||
|
||||
|
||||
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size):
|
||||
"""Aggregates data chunks from a source dataset into the destination dataset.
|
||||
|
||||
@@ -436,7 +574,7 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
return data_idx
|
||||
|
||||
|
||||
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, audios_idx):
|
||||
"""Aggregates metadata from a source dataset into the destination dataset.
|
||||
|
||||
Reads source metadata files, updates all indices and timestamps,
|
||||
@@ -448,6 +586,7 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
meta_idx: Dictionary tracking metadata chunk and file indices.
|
||||
data_idx: Dictionary tracking data chunk and file indices.
|
||||
videos_idx: Dictionary tracking video indices and timestamps.
|
||||
audios_idx: Dictionary tracking audio indices and timestamps.
|
||||
|
||||
Returns:
|
||||
dict: Updated meta_idx with current chunk and file indices.
|
||||
@@ -471,6 +610,7 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
meta_idx,
|
||||
data_idx,
|
||||
videos_idx,
|
||||
audios_idx,
|
||||
)
|
||||
|
||||
meta_idx = append_or_create_parquet_file(
|
||||
@@ -487,7 +627,8 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
# Increment latest_duration by the total duration added from this source dataset
|
||||
for k in videos_idx:
|
||||
videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]
|
||||
|
||||
for k in audios_idx:
|
||||
audios_idx[k]["latest_duration"] += audios_idx[k]["episode_duration"]
|
||||
return meta_idx
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user