mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| eacb638299 | |||
| 927c6ac3c5 | |||
| 13a429e5c7 | |||
| c87fd37736 | |||
| bb5676ee5a | |||
| a4aa316470 |
@@ -15,8 +15,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
import shutil
|
import shutil
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@@ -107,6 +109,7 @@ def update_meta_data(
|
|||||||
dst_meta,
|
dst_meta,
|
||||||
meta_idx,
|
meta_idx,
|
||||||
data_idx,
|
data_idx,
|
||||||
|
data_file_map,
|
||||||
videos_idx,
|
videos_idx,
|
||||||
):
|
):
|
||||||
"""Updates metadata DataFrame with new chunk, file, and timestamp indices.
|
"""Updates metadata DataFrame with new chunk, file, and timestamp indices.
|
||||||
@@ -127,8 +130,25 @@ def update_meta_data(
|
|||||||
|
|
||||||
df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"]
|
df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"]
|
||||||
df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"]
|
df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"]
|
||||||
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
|
# Remap data chunk/file indices per-source-file using the actual destination
|
||||||
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
|
# file chosen during data aggregation. A flat offset is incorrect when
|
||||||
|
# multiple source files are concatenated into a single destination file.
|
||||||
|
if data_file_map:
|
||||||
|
new_data_chunk = []
|
||||||
|
new_data_file = []
|
||||||
|
for idx in df.index:
|
||||||
|
src_chunk = int(df.at[idx, "data/chunk_index"]) # original source file location
|
||||||
|
src_file = int(df.at[idx, "data/file_index"]) # original source file location
|
||||||
|
dst_chunk, dst_file = data_file_map.get(
|
||||||
|
(src_chunk, src_file), (src_chunk + data_idx["chunk"], src_file + data_idx["file"])
|
||||||
|
)
|
||||||
|
new_data_chunk.append(dst_chunk)
|
||||||
|
new_data_file.append(dst_file)
|
||||||
|
df["data/chunk_index"] = new_data_chunk
|
||||||
|
df["data/file_index"] = new_data_file
|
||||||
|
else:
|
||||||
|
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
|
||||||
|
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
|
||||||
for key, video_idx in videos_idx.items():
|
for key, video_idx in videos_idx.items():
|
||||||
# Store original video file indices before updating
|
# Store original video file indices before updating
|
||||||
orig_chunk_col = f"videos/{key}/chunk_index"
|
orig_chunk_col = f"videos/{key}/chunk_index"
|
||||||
@@ -166,7 +186,7 @@ def update_meta_data(
|
|||||||
return df
|
return df
|
||||||
|
|
||||||
|
|
||||||
def aggregate_datasets(
|
def _aggregate_datasets(
|
||||||
repo_ids: list[str],
|
repo_ids: list[str],
|
||||||
aggr_repo_id: str,
|
aggr_repo_id: str,
|
||||||
roots: list[Path] | None = None,
|
roots: list[Path] | None = None,
|
||||||
@@ -175,39 +195,24 @@ def aggregate_datasets(
|
|||||||
video_files_size_in_mb: float | None = None,
|
video_files_size_in_mb: float | None = None,
|
||||||
chunk_size: int | None = None,
|
chunk_size: int | None = None,
|
||||||
):
|
):
|
||||||
"""Aggregates multiple LeRobot datasets into a single unified dataset.
|
"""Serial aggregation kernel: combines datasets into a destination dataset.
|
||||||
|
|
||||||
This is the main function that orchestrates the aggregation process by:
|
This function performs a single-process aggregation. It assumes it is the
|
||||||
1. Loading and validating all source dataset metadata
|
sole writer for its destination `aggr_root`.
|
||||||
2. Creating a new destination dataset with unified tasks
|
|
||||||
3. Aggregating videos, data, and metadata from all source datasets
|
|
||||||
4. Finalizing the aggregated dataset with proper statistics
|
|
||||||
|
|
||||||
Args:
|
|
||||||
repo_ids: List of repository IDs for the datasets to aggregate.
|
|
||||||
aggr_repo_id: Repository ID for the aggregated output dataset.
|
|
||||||
roots: Optional list of root paths for the source datasets.
|
|
||||||
aggr_root: Optional root path for the aggregated dataset.
|
|
||||||
data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
|
|
||||||
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
|
||||||
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
|
||||||
"""
|
"""
|
||||||
logging.info("Start aggregate_datasets")
|
# Build metadata objects, supporting a per-dataset "root" that may be None.
|
||||||
|
# When root is provided we load from the local filesystem, otherwise from Hub cache.
|
||||||
if data_files_size_in_mb is None:
|
if roots is None:
|
||||||
data_files_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
|
all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
|
||||||
if video_files_size_in_mb is None:
|
else:
|
||||||
video_files_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
all_metadata = [
|
||||||
if chunk_size is None:
|
(
|
||||||
chunk_size = DEFAULT_CHUNK_SIZE
|
LeRobotDatasetMetadata(repo_id, root=root)
|
||||||
|
if root is not None
|
||||||
all_metadata = (
|
else LeRobotDatasetMetadata(repo_id)
|
||||||
[LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
|
)
|
||||||
if roots is None
|
for repo_id, root in zip(repo_ids, roots, strict=False)
|
||||||
else [
|
|
||||||
LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots, strict=False)
|
|
||||||
]
|
]
|
||||||
)
|
|
||||||
fps, robot_type, features = validate_all_metadata(all_metadata)
|
fps, robot_type, features = validate_all_metadata(all_metadata)
|
||||||
video_keys = [key for key in features if features[key]["dtype"] == "video"]
|
video_keys = [key for key in features if features[key]["dtype"] == "video"]
|
||||||
|
|
||||||
@@ -237,9 +242,11 @@ def aggregate_datasets(
|
|||||||
|
|
||||||
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
|
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
|
||||||
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
|
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
|
||||||
data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)
|
data_idx, data_file_map = aggregate_data(
|
||||||
|
src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size
|
||||||
|
)
|
||||||
|
|
||||||
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
|
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, data_file_map, videos_idx)
|
||||||
|
|
||||||
dst_meta.info["total_episodes"] += src_meta.total_episodes
|
dst_meta.info["total_episodes"] += src_meta.total_episodes
|
||||||
dst_meta.info["total_frames"] += src_meta.total_frames
|
dst_meta.info["total_frames"] += src_meta.total_frames
|
||||||
@@ -248,6 +255,168 @@ def aggregate_datasets(
|
|||||||
logging.info("Aggregation complete.")
|
logging.info("Aggregation complete.")
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_datasets(
    repo_ids: list[str],
    aggr_repo_id: str,
    roots: list[Path] | None = None,
    aggr_root: Path | None = None,
    data_files_size_in_mb: float | None = None,
    video_files_size_in_mb: float | None = None,
    chunk_size: int | None = None,
    num_workers: int | None = None,
    tmp_root: Path | None = None,
):
    """Aggregates multiple LeRobot datasets into a single unified dataset.

    With ``num_workers`` unset or <= 1 the work is delegated to a single call of
    ``_aggregate_datasets``. With ``num_workers`` > 1 the inputs are merged
    pairwise, level by level, on a thread pool; every intermediate result is
    written to a private scratch directory, and the last remaining dataset is
    aggregated into ``aggr_root`` by one final ``_aggregate_datasets`` call.

    Args:
        repo_ids: List of repository IDs for the datasets to aggregate.
        aggr_repo_id: Repository ID for the aggregated output dataset.
        roots: Optional list of root paths for the source datasets, aligned
            with ``repo_ids``.
        aggr_root: Optional root path for the aggregated dataset.
        data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
        video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
        chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
        num_workers: When > 1, performs a tree-based parallel reduction using a thread pool
        tmp_root: Optional base directory under which the scratch directory for
            intermediate reduction outputs is created. Only the scratch
            directory is deleted afterwards, never ``tmp_root`` itself.

    Raises:
        ValueError: In the parallel path, if ``roots`` is given but its length
            differs from ``repo_ids``.
    """
    # Local import: only the parallel path needs it.
    import tempfile

    logging.info("Start aggregate_datasets")

    if data_files_size_in_mb is None:
        data_files_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
    if video_files_size_in_mb is None:
        video_files_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
    if chunk_size is None:
        chunk_size = DEFAULT_CHUNK_SIZE

    if num_workers is None or num_workers <= 1:
        # Run aggregation sequentially
        _aggregate_datasets(
            repo_ids=repo_ids,
            aggr_repo_id=aggr_repo_id,
            aggr_root=aggr_root,
            roots=roots,
            data_files_size_in_mb=data_files_size_in_mb,
            video_files_size_in_mb=video_files_size_in_mb,
            chunk_size=chunk_size,
        )
        logging.info("Aggregation complete.")
        return

    # Parallel fan-out/fan-in strategy from here on (num_workers > 1).

    # The reduction slices repo_ids and roots in lockstep; a length mismatch
    # would otherwise silently drop or mispair datasets.
    if roots is not None and len(roots) != len(repo_ids):
        raise ValueError(
            f"`roots` must have one entry per repo id: got {len(roots)} roots for {len(repo_ids)} repo ids."
        )

    # Validate across all metadata early to fail fast
    all_metadata_for_validation = (
        [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
        if roots is None
        else [
            LeRobotDatasetMetadata(repo_id, root=root)
            for repo_id, root in zip(repo_ids, roots, strict=False)
        ]
    )
    validate_all_metadata(all_metadata_for_validation)

    # Clamp workers to a sensible upper bound (pairs per round)
    num_workers = min(num_workers, max(1, len(repo_ids) // 2))

    # Choose the parent directory for intermediate merge results.
    if tmp_root is not None:
        tmp_parent = Path(tmp_root)
    elif aggr_root is not None:
        tmp_parent = Path(aggr_root).parent
    else:
        tmp_parent = Path.cwd()
    tmp_parent.mkdir(parents=True, exist_ok=True)

    # Repo ids usually look like "user/name"; flatten the separator so the
    # scratch directory is a single path component. A fresh, uniquely named
    # directory guarantees the final rmtree only removes what this call created.
    safe_name = aggr_repo_id.replace("/", "__")
    base_tmp_root = Path(tempfile.mkdtemp(prefix=f".{safe_name}__tmp_", dir=tmp_parent))

    current_repo_ids: list[str] = list(repo_ids)
    # Always maintain a roots list aligned with repo_ids. Use None for Hub-backed inputs.
    current_roots: list[Path | None] = list(roots) if roots is not None else [None] * len(repo_ids)

    try:
        level = 0
        while len(current_repo_ids) > 1:
            next_repo_ids: list[str] = []
            next_roots: list[Path | None] = []
            futures = []

            with ThreadPoolExecutor(max_workers=num_workers) as executor:
                for group_index, start in enumerate(range(0, len(current_repo_ids), 2)):
                    group_repo_ids = current_repo_ids[start : start + 2]
                    group_roots = current_roots[start : start + 2]

                    if len(group_repo_ids) == 1:
                        # Carry over singleton to next level
                        next_repo_ids.append(group_repo_ids[0])
                        next_roots.append(group_roots[0])
                        continue

                    out_repo_id = f"{aggr_repo_id}__reduce_l{level}_g{group_index}"
                    out_root = base_tmp_root / f"reduce_l{level}_g{group_index}"

                    futures.append(
                        executor.submit(
                            _aggregate_datasets,
                            repo_ids=group_repo_ids,
                            aggr_repo_id=out_repo_id,
                            roots=group_roots,
                            aggr_root=out_root,
                            data_files_size_in_mb=data_files_size_in_mb,
                            video_files_size_in_mb=video_files_size_in_mb,
                            chunk_size=chunk_size,
                        )
                    )
                    next_repo_ids.append(out_repo_id)
                    next_roots.append(out_root)

                try:
                    for future in as_completed(futures):
                        # Bubble up any exception raised inside tasks
                        future.result()
                except BaseException:
                    # Do not start merges that are still queued; the whole
                    # reduction is abandoned once one pair has failed.
                    for future in futures:
                        future.cancel()
                    raise

            # Cleanup previous level temporary outputs that won't be used again
            base_resolved = base_tmp_root.resolve()
            keep_set = {nr.resolve() for nr in next_roots if nr is not None}
            for prev_root in current_roots:
                if prev_root is None:
                    continue
                # Suppress per-iteration to keep cleaning other roots even if one fails
                with contextlib.suppress(Exception):
                    pr = prev_root.resolve()
                    # Only delete directories inside our own scratch space;
                    # caller-provided source roots are never touched.
                    if pr not in keep_set and base_resolved in pr.parents:
                        shutil.rmtree(prev_root, ignore_errors=True)

            current_repo_ids = next_repo_ids
            current_roots = next_roots  # aligned list of Path|None after first level
            level += 1

        # Final copy/aggregation into the desired output
        _aggregate_datasets(
            repo_ids=current_repo_ids,
            aggr_repo_id=aggr_repo_id,
            roots=current_roots,
            aggr_root=aggr_root,
            data_files_size_in_mb=data_files_size_in_mb,
            video_files_size_in_mb=video_files_size_in_mb,
            chunk_size=chunk_size,
        )
    finally:
        # Remove all temporary reduction artifacts (best effort).
        shutil.rmtree(base_tmp_root, ignore_errors=True)

    logging.info("Aggregation complete.")
|
||||||
|
|
||||||
|
|
||||||
def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size):
|
def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size):
|
||||||
"""Aggregates video chunks from a source dataset into the destination dataset.
|
"""Aggregates video chunks from a source dataset into the destination dataset.
|
||||||
|
|
||||||
@@ -366,6 +535,9 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
|||||||
|
|
||||||
unique_chunk_file_ids = sorted(unique_chunk_file_ids)
|
unique_chunk_file_ids = sorted(unique_chunk_file_ids)
|
||||||
|
|
||||||
|
# Map source (chunk,file) -> destination (chunk,file) actually used during write
|
||||||
|
src_to_dst_file: dict[tuple[int, int], tuple[int, int]] = {}
|
||||||
|
|
||||||
for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
|
for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
|
||||||
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
|
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
|
||||||
chunk_index=src_chunk_idx, file_index=src_file_idx
|
chunk_index=src_chunk_idx, file_index=src_file_idx
|
||||||
@@ -373,7 +545,7 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
|||||||
df = pd.read_parquet(src_path)
|
df = pd.read_parquet(src_path)
|
||||||
df = update_data_df(df, src_meta, dst_meta)
|
df = update_data_df(df, src_meta, dst_meta)
|
||||||
|
|
||||||
data_idx = append_or_create_parquet_file(
|
data_idx, used_chunk, used_file = append_or_create_parquet_file(
|
||||||
df,
|
df,
|
||||||
src_path,
|
src_path,
|
||||||
data_idx,
|
data_idx,
|
||||||
@@ -383,11 +555,12 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
|||||||
contains_images=len(dst_meta.image_keys) > 0,
|
contains_images=len(dst_meta.image_keys) > 0,
|
||||||
aggr_root=dst_meta.root,
|
aggr_root=dst_meta.root,
|
||||||
)
|
)
|
||||||
|
src_to_dst_file[(src_chunk_idx, src_file_idx)] = (used_chunk, used_file)
|
||||||
|
|
||||||
return data_idx
|
return data_idx, src_to_dst_file
|
||||||
|
|
||||||
|
|
||||||
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, data_file_map, videos_idx):
|
||||||
"""Aggregates metadata from a source dataset into the destination dataset.
|
"""Aggregates metadata from a source dataset into the destination dataset.
|
||||||
|
|
||||||
Reads source metadata files, updates all indices and timestamps,
|
Reads source metadata files, updates all indices and timestamps,
|
||||||
@@ -421,10 +594,11 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
|||||||
dst_meta,
|
dst_meta,
|
||||||
meta_idx,
|
meta_idx,
|
||||||
data_idx,
|
data_idx,
|
||||||
|
data_file_map,
|
||||||
videos_idx,
|
videos_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
meta_idx = append_or_create_parquet_file(
|
meta_idx, _m_used_chunk, _m_used_file = append_or_create_parquet_file(
|
||||||
df,
|
df,
|
||||||
src_path,
|
src_path,
|
||||||
meta_idx,
|
meta_idx,
|
||||||
@@ -478,7 +652,7 @@ def append_or_create_parquet_file(
|
|||||||
to_parquet_with_hf_images(df, dst_path)
|
to_parquet_with_hf_images(df, dst_path)
|
||||||
else:
|
else:
|
||||||
df.to_parquet(dst_path)
|
df.to_parquet(dst_path)
|
||||||
return idx
|
return idx, idx["chunk"], idx["file"]
|
||||||
|
|
||||||
src_size = get_parquet_file_size_in_mb(src_path)
|
src_size = get_parquet_file_size_in_mb(src_path)
|
||||||
dst_size = get_parquet_file_size_in_mb(dst_path)
|
dst_size = get_parquet_file_size_in_mb(dst_path)
|
||||||
@@ -489,17 +663,19 @@ def append_or_create_parquet_file(
|
|||||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
final_df = df
|
final_df = df
|
||||||
target_path = new_path
|
target_path = new_path
|
||||||
|
used_chunk, used_file = idx["chunk"], idx["file"]
|
||||||
else:
|
else:
|
||||||
existing_df = pd.read_parquet(dst_path)
|
existing_df = pd.read_parquet(dst_path)
|
||||||
final_df = pd.concat([existing_df, df], ignore_index=True)
|
final_df = pd.concat([existing_df, df], ignore_index=True)
|
||||||
target_path = dst_path
|
target_path = dst_path
|
||||||
|
used_chunk, used_file = idx["chunk"], idx["file"]
|
||||||
|
|
||||||
if contains_images:
|
if contains_images:
|
||||||
to_parquet_with_hf_images(final_df, target_path)
|
to_parquet_with_hf_images(final_df, target_path)
|
||||||
else:
|
else:
|
||||||
final_df.to_parquet(target_path)
|
final_df.to_parquet(target_path)
|
||||||
|
|
||||||
return idx
|
return idx, used_chunk, used_file
|
||||||
|
|
||||||
|
|
||||||
def finalize_aggregation(aggr_meta, all_metadata):
|
def finalize_aggregation(aggr_meta, all_metadata):
|
||||||
|
|||||||
@@ -234,6 +234,7 @@ def merge_datasets(
|
|||||||
datasets: list[LeRobotDataset],
|
datasets: list[LeRobotDataset],
|
||||||
output_repo_id: str,
|
output_repo_id: str,
|
||||||
output_dir: str | Path | None = None,
|
output_dir: str | Path | None = None,
|
||||||
|
num_workers: int | None = None,
|
||||||
) -> LeRobotDataset:
|
) -> LeRobotDataset:
|
||||||
"""Merge multiple LeRobotDatasets into a single dataset.
|
"""Merge multiple LeRobotDatasets into a single dataset.
|
||||||
|
|
||||||
@@ -257,6 +258,7 @@ def merge_datasets(
|
|||||||
aggr_repo_id=output_repo_id,
|
aggr_repo_id=output_repo_id,
|
||||||
roots=roots,
|
roots=roots,
|
||||||
aggr_root=output_dir,
|
aggr_root=output_dir,
|
||||||
|
num_workers=num_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
merged_dataset = LeRobotDataset(
|
merged_dataset = LeRobotDataset(
|
||||||
@@ -329,7 +331,7 @@ def modify_features(
|
|||||||
|
|
||||||
if repo_id is None:
|
if repo_id is None:
|
||||||
repo_id = f"{dataset.repo_id}_modified"
|
repo_id = f"{dataset.repo_id}_modified"
|
||||||
output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
|
output_dir = Path(output_dir, exists_ok=True) if output_dir is not None else HF_LEROBOT_HOME / repo_id
|
||||||
|
|
||||||
new_features = dataset.meta.features.copy()
|
new_features = dataset.meta.features.copy()
|
||||||
|
|
||||||
|
|||||||
@@ -940,11 +940,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
return query_timestamps
|
return query_timestamps
|
||||||
|
|
||||||
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
||||||
return {
|
"""
|
||||||
key: torch.stack(self.hf_dataset[q_idx][key])
|
Query dataset for indices across keys, skipping video keys.
|
||||||
for key, q_idx in query_indices.items()
|
|
||||||
if key not in self.meta.video_keys
|
Tries column-first [key][indices] for speed, falls back to row-first.
|
||||||
}
|
|
||||||
|
Args:
|
||||||
|
query_indices: Dict mapping keys to index lists to retrieve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with stacked tensors of queried data (video keys excluded)
|
||||||
|
"""
|
||||||
|
result: dict = {}
|
||||||
|
for key, q_idx in query_indices.items():
|
||||||
|
if key in self.meta.video_keys:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
result[key] = torch.stack(self.hf_dataset[key][q_idx])
|
||||||
|
except (KeyError, TypeError, IndexError):
|
||||||
|
result[key] = torch.stack(self.hf_dataset[q_idx][key])
|
||||||
|
return result
|
||||||
|
|
||||||
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
|
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
|
||||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||||
|
|||||||
@@ -103,6 +103,7 @@ class SplitConfig:
|
|||||||
class MergeConfig:
|
class MergeConfig:
|
||||||
type: str = "merge"
|
type: str = "merge"
|
||||||
repo_ids: list[str] | None = None
|
repo_ids: list[str] | None = None
|
||||||
|
num_workers: int | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -215,6 +216,7 @@ def handle_merge(cfg: EditDatasetConfig) -> None:
|
|||||||
datasets,
|
datasets,
|
||||||
output_repo_id=cfg.repo_id,
|
output_repo_id=cfg.repo_id,
|
||||||
output_dir=output_dir,
|
output_dir=output_dir,
|
||||||
|
num_workers=cfg.operation.num_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info(f"Merged dataset saved to {output_dir}")
|
logging.info(f"Merged dataset saved to {output_dir}")
|
||||||
|
|||||||
Reference in New Issue
Block a user