add: tests forcing new file creation

This commit is contained in:
fracapuano
2025-06-11 14:43:55 +02:00
committed by Michel Aractingi
parent c8a5df963b
commit 4e01f87a6e
+63 -6
View File
@@ -33,13 +33,15 @@ def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames)
def assert_dataset_content_integrity(aggr_ds, ds_0, ds_1):
"""Test that the content of both datasets is preserved correctly in the aggregated dataset."""
keys_to_ignore = ["episode_index", "index", "timestamp"]
# Test first part of dataset corresponds to ds_0, check first item (index 0) matches ds_0[0]
aggr_first_item = aggr_ds[0]
ds_0_first_item = ds_0[0]
# Compare all keys except episode_index and index which should be updated
for key in ds_0_first_item:
if key not in ["episode_index", "index"]:
if key not in keys_to_ignore:
# Handle both tensor and non-tensor data
if torch.is_tensor(aggr_first_item[key]) and torch.is_tensor(ds_0_first_item[key]):
assert torch.allclose(aggr_first_item[key], ds_0_first_item[key], atol=1e-6), (
@@ -55,7 +57,7 @@ def assert_dataset_content_integrity(aggr_ds, ds_0, ds_1):
ds_0_last_item = ds_0[-1]
for key in ds_0_last_item:
if key not in ["episode_index", "index"]:
if key not in keys_to_ignore:
# Handle both tensor and non-tensor data
if torch.is_tensor(aggr_ds_0_last_item[key]) and torch.is_tensor(ds_0_last_item[key]):
assert torch.allclose(aggr_ds_0_last_item[key], ds_0_last_item[key], atol=1e-6), (
@@ -72,7 +74,7 @@ def assert_dataset_content_integrity(aggr_ds, ds_0, ds_1):
ds_1_first_item = ds_1[0]
for key in ds_1_first_item:
if key not in ["episode_index", "index"]:
if key not in keys_to_ignore:
# Handle both tensor and non-tensor data
if torch.is_tensor(aggr_ds_1_first_item[key]) and torch.is_tensor(ds_1_first_item[key]):
assert torch.allclose(aggr_ds_1_first_item[key], ds_1_first_item[key], atol=1e-6), (
@@ -88,7 +90,7 @@ def assert_dataset_content_integrity(aggr_ds, ds_0, ds_1):
ds_1_last_item = ds_1[-1]
for key in ds_1_last_item:
if key not in ["episode_index", "index"]:
if key not in keys_to_ignore:
# Handle both tensor and non-tensor data
if torch.is_tensor(aggr_last_item[key]) and torch.is_tensor(ds_1_last_item[key]):
assert torch.allclose(aggr_last_item[key], ds_1_last_item[key], atol=1e-6), (
@@ -180,9 +182,9 @@ def assert_dataset_iteration_works(aggr_ds):
def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
"""Test basic aggregation functionality with standard parameters."""
ds_0_num_frames = 400
ds_1_num_frames = 400
ds_1_num_frames = 800
ds_0_num_episodes = 10
ds_1_num_episodes = 10
ds_1_num_episodes = 25
# Create two datasets with different number of frames and episodes
ds_0 = lerobot_dataset_factory(
@@ -217,3 +219,58 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
assert_dataset_iteration_works(aggr_ds)
def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
"""Test aggregation with small file size limits to force file rotation/sharding."""
ds_0_num_episodes = ds_1_num_episodes = 10
ds_0_num_frames = ds_1_num_frames = 400
ds_0 = lerobot_dataset_factory(
root=tmp_path / "small_0",
repo_id=f"{DUMMY_REPO_ID}_small_0",
total_episodes=ds_0_num_episodes,
total_frames=ds_0_num_frames,
)
ds_1 = lerobot_dataset_factory(
root=tmp_path / "small_1",
repo_id=f"{DUMMY_REPO_ID}_small_1",
total_episodes=ds_1_num_episodes,
total_frames=ds_1_num_frames,
)
# Use the new configurable parameters to force file rotation
aggregate_datasets(
repo_ids=[ds_0.repo_id, ds_1.repo_id],
roots=[ds_0.root, ds_1.root],
aggr_repo_id=f"{DUMMY_REPO_ID}_small_aggr",
aggr_root=tmp_path / "small_aggr",
# Tiny file size to trigger new file instantiation
data_files_size_in_mb=0.01,
video_files_size_in_mb=0.1,
)
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_small_aggr", root=tmp_path / "small_aggr")
# Verify aggregation worked correctly despite file size constraints
expected_total_episodes = ds_0_num_episodes + ds_1_num_episodes
expected_total_frames = ds_0_num_frames + ds_1_num_frames
assert_episode_and_frame_counts(aggr_ds, expected_total_episodes, expected_total_frames)
assert_dataset_content_integrity(aggr_ds, ds_0, ds_1)
assert_metadata_consistency(aggr_ds, ds_0, ds_1)
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
assert_dataset_iteration_works(aggr_ds)
# Check that multiple files were actually created due to small size limits
data_dir = tmp_path / "small_aggr" / "data"
video_dir = tmp_path / "small_aggr" / "videos"
if data_dir.exists():
parquet_files = list(data_dir.rglob("*.parquet"))
assert len(parquet_files) > 1, "Small file size limits should create multiple parquet files"
if video_dir.exists():
video_files = list(video_dir.rglob("*.mp4"))
assert len(video_files) > 1, "Small file size limits should create multiple video files"