diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py index 5db3f934d..07da5b039 100644 --- a/src/lerobot/datasets/aggregate.py +++ b/src/lerobot/datasets/aggregate.py @@ -286,6 +286,8 @@ def aggregate_datasets( data_files_size_in_mb: int | None = None, video_files_size_in_mb: int | None = None, chunk_size: int | None = None, + concatenate_videos: bool = True, + concatenate_data: bool = True, ): """Aggregates multiple LeRobot datasets into a single unified dataset. @@ -303,6 +305,8 @@ def aggregate_datasets( data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB) video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB) chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE) + concatenate_videos: When False, keep one mp4 per source file instead of packing into shards. + concatenate_data: When False, keep one parquet per source file instead of packing into shards. """ logging.info("Start aggregate_datasets") @@ -351,8 +355,12 @@ def aggregate_datasets( dst_meta.episodes = {} for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"): - videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size) - data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size) + videos_idx = aggregate_videos( + src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size, concatenate_videos + ) + data_idx = aggregate_data( + src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size, concatenate_data + ) meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx) @@ -367,7 +375,9 @@ def aggregate_datasets( logging.info("Aggregation complete.") -def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size): +def aggregate_videos( + src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size, concatenate_videos=True +): """Aggregates video chunks from a source dataset into the destination dataset. Handles video file concatenation and rotation based on file size limits. @@ -379,6 +389,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu videos_idx: Dictionary tracking video chunk and file indices. video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB) chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE) + concatenate_videos: When False, keep one mp4 per source file instead of packing into shards. Returns: dict: Updated videos_idx with current chunk and file indices. """ @@ -439,7 +450,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu src_size = get_file_size_in_mb(src_path) dst_size = get_file_size_in_mb(dst_path) - if dst_size + src_size >= video_files_size_in_mb: + if not concatenate_videos or dst_size + src_size >= video_files_size_in_mb: # Rotate to a new file - offset is 0 chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size) dst_key = (chunk_idx, file_idx) @@ -477,7 +488,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu return videos_idx -def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size): +def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size, concatenate_data=True): """Aggregates data chunks from a source dataset into the destination dataset. Reads source data files, updates indices to match the aggregated dataset, @@ -493,6 +504,7 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si data_idx: Dictionary tracking data chunk and file indices. data_files_size_in_mb: Maximum size for data files in MB. chunk_size: Maximum number of files per chunk. + concatenate_data: When False, keep one parquet per source file instead of packing into shards. Returns: dict: Updated data_idx with current chunk and file indices. @@ -538,6 +550,7 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si contains_images=contains_images, aggr_root=dst_meta.root, hf_features=hf_features, + concatenate=concatenate_data, ) # Record the mapping from source to actual destination @@ -614,6 +627,7 @@ def append_or_create_parquet_file( contains_images: bool = False, aggr_root: Path = None, hf_features: datasets.Features | None = None, + concatenate: bool = True, ) -> tuple[dict[str, int], tuple[int, int]]: """Appends data to an existing parquet file or creates a new one based on size constraints. @@ -630,6 +644,7 @@ def append_or_create_parquet_file( contains_images: Whether the data contains images requiring special handling. aggr_root: Root path for the aggregated dataset. hf_features: Optional HuggingFace Features schema for proper image typing. + concatenate: When False, always rotate to a new file instead of appending to the current one. Returns: tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict @@ -649,7 +664,7 @@ def append_or_create_parquet_file( src_size = get_parquet_file_size_in_mb(src_path) dst_size = get_parquet_file_size_in_mb(dst_path) - if dst_size + src_size >= max_mb: + if not concatenate or dst_size + src_size >= max_mb: idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size) dst_chunk, dst_file = idx["chunk"], idx["file"] new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file) diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py index adbb841c4..91dc66af2 100644 --- a/src/lerobot/datasets/dataset_tools.py +++ b/src/lerobot/datasets/dataset_tools.py @@ -261,6 +261,8 @@ def merge_datasets( datasets: list[LeRobotDataset], output_repo_id: str, output_dir: str | Path | None = None, + concatenate_videos: bool = True, + concatenate_data: bool = True, ) -> LeRobotDataset: """Merge multiple LeRobotDatasets into a single dataset. @@ -270,6 +272,8 @@ def merge_datasets( datasets: List of LeRobotDatasets to merge. output_repo_id: Merged dataset identifier. output_dir: Root directory where the merged dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/output_repo_id. + concatenate_videos: When False, keep one mp4 per source file instead of packing into shards. + concatenate_data: When False, keep one parquet per source file instead of packing into shards. """ if not datasets: raise ValueError("No datasets to merge") @@ -284,6 +288,8 @@ def merge_datasets( aggr_repo_id=output_repo_id, roots=roots, aggr_root=output_dir, + concatenate_videos=concatenate_videos, + concatenate_data=concatenate_data, ) merged_dataset = LeRobotDataset( diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py index 3c1edbb31..eaadf47de 100644 --- a/src/lerobot/scripts/lerobot_edit_dataset.py +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -94,6 +94,14 @@ Merge multiple datasets from a list of local dataset paths: --operation.repo_ids "['pusht_train', 'pusht_val']" \ --operation.roots "['/path/to/pusht_train', '/path/to/pusht_val']" +Merge multiple datasets while keeping one file per source file (no video/data stitching): + lerobot-edit-dataset \ + --new_repo_id lerobot/pusht_merged \ + --operation.type merge \ + --operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']" \ + --operation.concatenate_videos false \ + --operation.concatenate_data false + Remove camera feature: lerobot-edit-dataset \ --repo_id lerobot/pusht \ @@ -257,6 +265,9 @@ class SplitConfig(OperationConfig): class MergeConfig(OperationConfig): repo_ids: list[str] | None = None roots: list[str] | None = None + # When False, keep one file per source file instead of packing into shards. + concatenate_videos: bool = True + concatenate_data: bool = True @OperationConfig.register_subclass("remove_feature") @@ -461,6 +472,8 @@ def handle_merge(cfg: EditDatasetConfig) -> None: datasets, output_repo_id=cfg.new_repo_id, output_dir=output_dir, + concatenate_videos=cfg.operation.concatenate_videos, + concatenate_data=cfg.operation.concatenate_data, ) logging.info(f"Merged dataset saved to {output_dir}") diff --git a/tests/datasets/test_aggregate.py b/tests/datasets/test_aggregate.py index 80a95aa1f..f3edc3af8 100644 --- a/tests/datasets/test_aggregate.py +++ b/tests/datasets/test_aggregate.py @@ -289,6 +289,52 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory): assert_dataset_iteration_works(aggr_ds) +def test_aggregate_datasets_without_concatenation(tmp_path, lerobot_dataset_factory): + """With concatenation disabled, each source file is kept as its own destination file.""" + ds_0 = lerobot_dataset_factory( + root=tmp_path / "no_stitch_0", + repo_id=f"{DUMMY_REPO_ID}_no_stitch_0", + total_episodes=3, + total_frames=60, + ) + ds_1 = lerobot_dataset_factory( + root=tmp_path / "no_stitch_1", + repo_id=f"{DUMMY_REPO_ID}_no_stitch_1", + total_episodes=4, + total_frames=80, + ) + + aggr_root = tmp_path / "no_stitch_aggr" + aggregate_datasets( + repo_ids=[ds_0.repo_id, ds_1.repo_id], + roots=[ds_0.root, ds_1.root], + aggr_repo_id=f"{DUMMY_REPO_ID}_no_stitch_aggr", + aggr_root=aggr_root, + concatenate_videos=False, + concatenate_data=False, + ) + + with ( + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(aggr_root) + aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_no_stitch_aggr", root=aggr_root) + + assert_episode_and_frame_counts( + aggr_ds, ds_0.num_episodes + ds_1.num_episodes, ds_0.num_frames + ds_1.num_frames + ) + assert_dataset_iteration_works(aggr_ds) + assert_video_timestamps_within_bounds(aggr_ds) + + # Two single-file sources stay as two files each, instead of being packed together. + assert len(list((aggr_root / "data").rglob("*.parquet"))) == 2 + assert aggr_ds.meta.video_keys, "Test fixture should produce at least one video feature" + for key in aggr_ds.meta.video_keys: + assert len(list((aggr_root / "videos" / key).rglob("*.mp4"))) == 2 + + @pytest.mark.parametrize("mutation", ["mismatched_value", "missing_key"]) def test_aggregate_incomplete_video_encoder_info_warns_and_nuls_encoders( tmp_path, lerobot_dataset_factory, caplog, mutation diff --git a/tests/scripts/test_edit_dataset_parsing.py b/tests/scripts/test_edit_dataset_parsing.py index 83ed5a78b..c90cffb38 100644 --- a/tests/scripts/test_edit_dataset_parsing.py +++ b/tests/scripts/test_edit_dataset_parsing.py @@ -66,6 +66,20 @@ class TestOperationTypeParsing: with pytest.raises(ValueError, match="--new_repo_id is required for merge"): _validate_config(cfg) + @pytest.mark.parametrize("flag", ["concatenate_videos", "concatenate_data"]) + def test_merge_concatenate_flag_defaults_true(self, flag): + cfg = parse_cfg(["--new_repo_id", "test/merged", "--operation.type", "merge"]) + assert isinstance(cfg.operation, MergeConfig) + assert getattr(cfg.operation, flag) is True + + @pytest.mark.parametrize("flag", ["concatenate_videos", "concatenate_data"]) + def test_merge_concatenate_flag_can_be_disabled(self, flag): + cfg = parse_cfg( + ["--new_repo_id", "test/merged", "--operation.type", "merge", f"--operation.{flag}", "false"] + ) + assert isinstance(cfg.operation, MergeConfig) + assert getattr(cfg.operation, flag) is False + def test_non_merge_requires_repo_id(self): cfg = parse_cfg(["--operation.type", "delete_episodes"]) with pytest.raises(ValueError, match="--repo_id is required for delete_episodes"):