From 4e01f87a6e011322420bf51d08da2c3b2db1c0c0 Mon Sep 17 00:00:00 2001
From: fracapuano
Date: Wed, 11 Jun 2025 14:43:55 +0200
Subject: [PATCH] add: tests forcing new file creation

---
 tests/datasets/test_aggregate.py | 69 +++++++++++++++++++++++++++++---
 1 file changed, 63 insertions(+), 6 deletions(-)

diff --git a/tests/datasets/test_aggregate.py b/tests/datasets/test_aggregate.py
index 9d75ece38..7b65b234b 100644
--- a/tests/datasets/test_aggregate.py
+++ b/tests/datasets/test_aggregate.py
@@ -33,13 +33,15 @@ def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames)
 
 def assert_dataset_content_integrity(aggr_ds, ds_0, ds_1):
     """Test that the content of both datasets is preserved correctly in the aggregated dataset."""
+    keys_to_ignore = ["episode_index", "index", "timestamp"]
+
     # Test first part of dataset corresponds to ds_0, check first item (index 0) matches ds_0[0]
     aggr_first_item = aggr_ds[0]
     ds_0_first_item = ds_0[0]
 
     # Compare all keys except episode_index and index which should be updated
     for key in ds_0_first_item:
-        if key not in ["episode_index", "index"]:
+        if key not in keys_to_ignore:
             # Handle both tensor and non-tensor data
             if torch.is_tensor(aggr_first_item[key]) and torch.is_tensor(ds_0_first_item[key]):
                 assert torch.allclose(aggr_first_item[key], ds_0_first_item[key], atol=1e-6), (
@@ -55,7 +57,7 @@ def assert_dataset_content_integrity(aggr_ds, ds_0, ds_1):
     ds_0_last_item = ds_0[-1]
 
     for key in ds_0_last_item:
-        if key not in ["episode_index", "index"]:
+        if key not in keys_to_ignore:
             # Handle both tensor and non-tensor data
             if torch.is_tensor(aggr_ds_0_last_item[key]) and torch.is_tensor(ds_0_last_item[key]):
                 assert torch.allclose(aggr_ds_0_last_item[key], ds_0_last_item[key], atol=1e-6), (
@@ -72,7 +74,7 @@ def assert_dataset_content_integrity(aggr_ds, ds_0, ds_1):
     ds_1_first_item = ds_1[0]
 
     for key in ds_1_first_item:
-        if key not in ["episode_index", "index"]:
+        if key not in keys_to_ignore:
             # Handle both tensor and non-tensor data
             if torch.is_tensor(aggr_ds_1_first_item[key]) and torch.is_tensor(ds_1_first_item[key]):
                 assert torch.allclose(aggr_ds_1_first_item[key], ds_1_first_item[key], atol=1e-6), (
@@ -88,7 +90,7 @@ def assert_dataset_content_integrity(aggr_ds, ds_0, ds_1):
     ds_1_last_item = ds_1[-1]
 
     for key in ds_1_last_item:
-        if key not in ["episode_index", "index"]:
+        if key not in keys_to_ignore:
             # Handle both tensor and non-tensor data
             if torch.is_tensor(aggr_last_item[key]) and torch.is_tensor(ds_1_last_item[key]):
                 assert torch.allclose(aggr_last_item[key], ds_1_last_item[key], atol=1e-6), (
@@ -180,9 +182,9 @@ def assert_dataset_iteration_works(aggr_ds):
 def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
     """Test basic aggregation functionality with standard parameters."""
     ds_0_num_frames = 400
-    ds_1_num_frames = 400
+    ds_1_num_frames = 800
     ds_0_num_episodes = 10
-    ds_1_num_episodes = 10
+    ds_1_num_episodes = 25
 
     # Create two datasets with different number of frames and episodes
     ds_0 = lerobot_dataset_factory(
@@ -217,3 +219,58 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
     assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
     assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
     assert_dataset_iteration_works(aggr_ds)
+
+
+def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
+    """Test aggregation with small file size limits to force file rotation/sharding."""
+    ds_0_num_episodes = ds_1_num_episodes = 10
+    ds_0_num_frames = ds_1_num_frames = 400
+
+    ds_0 = lerobot_dataset_factory(
+        root=tmp_path / "small_0",
+        repo_id=f"{DUMMY_REPO_ID}_small_0",
+        total_episodes=ds_0_num_episodes,
+        total_frames=ds_0_num_frames,
+    )
+    ds_1 = lerobot_dataset_factory(
+        root=tmp_path / "small_1",
+        repo_id=f"{DUMMY_REPO_ID}_small_1",
+        total_episodes=ds_1_num_episodes,
+        total_frames=ds_1_num_frames,
+    )
+
+    # Use the new configurable parameters to force file rotation
+    aggregate_datasets(
+        repo_ids=[ds_0.repo_id, ds_1.repo_id],
+        roots=[ds_0.root, ds_1.root],
+        aggr_repo_id=f"{DUMMY_REPO_ID}_small_aggr",
+        aggr_root=tmp_path / "small_aggr",
+        # Tiny file size to trigger new file instantiation
+        data_files_size_in_mb=0.01,
+        video_files_size_in_mb=0.1,
+    )
+
+    aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_small_aggr", root=tmp_path / "small_aggr")
+
+    # Verify aggregation worked correctly despite file size constraints
+    expected_total_episodes = ds_0_num_episodes + ds_1_num_episodes
+    expected_total_frames = ds_0_num_frames + ds_1_num_frames
+
+    assert_episode_and_frame_counts(aggr_ds, expected_total_episodes, expected_total_frames)
+    assert_dataset_content_integrity(aggr_ds, ds_0, ds_1)
+    assert_metadata_consistency(aggr_ds, ds_0, ds_1)
+    assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
+    assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
+    assert_dataset_iteration_works(aggr_ds)
+
+    # Check that multiple files were actually created due to small size limits
+    data_dir = tmp_path / "small_aggr" / "data"
+    video_dir = tmp_path / "small_aggr" / "videos"
+
+    if data_dir.exists():
+        parquet_files = list(data_dir.rglob("*.parquet"))
+        assert len(parquet_files) > 1, "Small file size limits should create multiple parquet files"
+
+    if video_dir.exists():
+        video_files = list(video_dir.rglob("*.mp4"))
+        assert len(video_files) > 1, "Small file size limits should create multiple video files"