From 01d0b7b10263bf0691b3caace13853ccd7601283 Mon Sep 17 00:00:00 2001
From: fracapuano
Date: Tue, 10 Jun 2025 14:44:24 +0200
Subject: [PATCH] fix: modularize tests to improve readability

---
 tests/datasets/test_aggregate.py | 118 ++++++++++++++++++------------
 1 file changed, 70 insertions(+), 48 deletions(-)

diff --git a/tests/datasets/test_aggregate.py b/tests/datasets/test_aggregate.py
index f5a12661b..b67ed001e 100644
--- a/tests/datasets/test_aggregate.py
+++ b/tests/datasets/test_aggregate.py
@@ -5,49 +5,19 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from tests.fixtures.constants import DUMMY_REPO_ID
 
 
-def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
-    ds_0_num_frames = 400
-    ds_1_num_frames = 400
-    ds_0_num_episodes = 10
-    ds_1_num_episodes = 10
-
-    # Create two datasets with different number of frames and episodes
-    ds_0 = lerobot_dataset_factory(
-        root=tmp_path / "test_0",
-        repo_id=f"{DUMMY_REPO_ID}_0",
-        total_episodes=ds_0_num_episodes,
-        total_frames=ds_0_num_frames,
+def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames):
+    """Test that the total numbers of episodes and frames are correctly aggregated."""
+    assert aggr_ds.num_episodes == expected_episodes, (
+        f"Expected {expected_episodes} episodes, got {aggr_ds.num_episodes}"
     )
-    ds_1 = lerobot_dataset_factory(
-        root=tmp_path / "test_1",
-        repo_id=f"{DUMMY_REPO_ID}_1",
-        total_episodes=ds_1_num_episodes,
-        total_frames=ds_1_num_frames,
+    assert aggr_ds.num_frames == expected_frames, (
+        f"Expected {expected_frames} frames, got {aggr_ds.num_frames}"
     )
-    aggregate_datasets(
-        repo_ids=[ds_0.repo_id, ds_1.repo_id],
-        roots=[ds_0.root, ds_1.root],
-        aggr_repo_id=f"{DUMMY_REPO_ID}_aggr",
-        aggr_root=tmp_path / "test_aggr",
-    )
-    aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr")
-
-    # Test 1: Total number of episodes corresponds
-    expected_total_episodes = ds_0.num_episodes + ds_1.num_episodes
-    assert aggr_ds.num_episodes == expected_total_episodes, (
-        f"Expected {expected_total_episodes} episodes, got {aggr_ds.num_episodes}"
-    )
-
-    # Test 2: Total number of frames corresponds
-    expected_total_frames = ds_0.num_frames + ds_1.num_frames
-    assert aggr_ds.num_frames == expected_total_frames, (
-        f"Expected {expected_total_frames} frames, got {aggr_ds.num_frames}"
-    )
-
-    # Test 3: First part of dataset corresponds to ds_0
-    # Check first item (index 0) matches ds_0[0]
+
+
+def assert_dataset_content_integrity(aggr_ds, ds_0, ds_1):
+    """Test that the content of both datasets is preserved correctly in the aggregated dataset."""
+    # Test first part of dataset corresponds to ds_0, check first item (index 0) matches ds_0[0]
     aggr_first_item = aggr_ds[0]
     ds_0_first_item = ds_0[0]
 
@@ -80,7 +50,7 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
             f"Last ds_0 item key '{key}' doesn't match between aggregated and ds_0"
         )
 
-    # Test 4: Second part of dataset corresponds to ds_1
+    # Test second part of dataset corresponds to ds_1
     # Check first item of ds_1 part (index len(ds_0)) matches ds_1[0]
     aggr_ds_1_first_item = aggr_ds[len(ds_0)]
     ds_1_first_item = ds_1[0]
@@ -113,7 +83,9 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
             f"Last item key '{key}' doesn't match between aggregated and ds_1"
         )
 
-    # Test 5: Check metadata aggregation
+
+def assert_metadata_consistency(aggr_ds, ds_0, ds_1):
+    """Test that metadata is correctly aggregated."""
     # Test basic info
     assert aggr_ds.fps == ds_0.fps == ds_1.fps, "FPS should be the same across all datasets"
     assert aggr_ds.meta.info["robot_type"] == ds_0.meta.info["robot_type"] == ds_1.meta.info["robot_type"], (
@@ -128,9 +100,11 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
     actual_tasks = set(aggr_ds.meta.tasks.index)
     assert actual_tasks == expected_tasks, f"Expected tasks {expected_tasks}, got {actual_tasks}"
 
-    # Test episode indices are correctly updated
+
+def assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1):
+    """Test that episode indices are correctly updated after aggregation."""
     # ds_0 episodes should have episode_index 0 to ds_0.num_episodes-1
-    for i in range(ds_0_num_frames):
+    for i in range(len(ds_0)):
         assert aggr_ds[i]["episode_index"] < ds_0.num_episodes, (
             f"Episode index {aggr_ds[i]['episode_index']} at position {i} should be < {ds_0.num_episodes}"
         )
@@ -139,12 +113,16 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
         return (ep_idx >= ds_0.num_episodes) and (ep_idx < ds_0.num_episodes + ds_1.num_episodes)
 
     # ds_1 episodes should have episode_index ds_0.num_episodes to total_episodes-1
-    for i in range(ds_0_num_frames, ds_0_num_frames + ds_1_num_frames):
+    for i in range(len(ds_0), len(ds_0) + len(ds_1)):
         expected_min_episode_idx = ds_0.num_episodes
         assert ds1_episodes_condition(aggr_ds[i]["episode_index"]), (
             f"Episode index {aggr_ds[i]['episode_index']} at position {i} should be >= {expected_min_episode_idx}"
         )
 
+
+def assert_video_frames_integrity(aggr_ds, ds_0, ds_1):
+    """Test that video frames are correctly preserved and frame indices are updated."""
+
     def visual_frames_equal(frame1, frame2):
         return torch.allclose(frame1, frame2)
 
@@ -156,7 +134,7 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
         )
 
     # Test the section corresponding to the first dataset (ds_0)
-    for i in range(ds_0_num_frames):
+    for i in range(len(ds_0)):
         assert aggr_ds[i]["index"] == i, (
             f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}"
         )
@@ -166,16 +144,60 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
         )
 
     # Test the section corresponding to the second dataset (ds_1)
-    for i in range(ds_0_num_frames, ds_0_num_frames + ds_1_num_frames):
+    for i in range(len(ds_0), len(ds_0) + len(ds_1)):
         # The frame index in the aggregated dataset should also match its position.
         assert aggr_ds[i]["index"] == i, (
             f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}"
         )
         for key in video_keys:
-            assert visual_frames_equal(aggr_ds[i][key], ds_1[i - ds_0_num_frames][key]), (
+            assert visual_frames_equal(aggr_ds[i][key], ds_1[i - len(ds_0)][key]), (
                 f"Visual frames at position {i} should be equal between aggregated and ds_1"
             )
 
-    # Test that we can iterate through the entire dataset without errors
+
+def assert_dataset_iteration_works(aggr_ds):
+    """Test that we can iterate through the entire dataset without errors."""
     for _ in aggr_ds:
         pass
+
+
+def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
+    """Test basic aggregation functionality with standard parameters."""
+    ds_0_num_frames = 400
+    ds_1_num_frames = 400
+    ds_0_num_episodes = 10
+    ds_1_num_episodes = 10
+
+    # Create two source datasets with fixed numbers of frames and episodes
+    ds_0 = lerobot_dataset_factory(
+        root=tmp_path / "test_0",
+        repo_id=f"{DUMMY_REPO_ID}_0",
+        total_episodes=ds_0_num_episodes,
+        total_frames=ds_0_num_frames,
+    )
+    ds_1 = lerobot_dataset_factory(
+        root=tmp_path / "test_1",
+        repo_id=f"{DUMMY_REPO_ID}_1",
+        total_episodes=ds_1_num_episodes,
+        total_frames=ds_1_num_frames,
+    )
+
+    aggregate_datasets(
+        repo_ids=[ds_0.repo_id, ds_1.repo_id],
+        roots=[ds_0.root, ds_1.root],
+        aggr_repo_id=f"{DUMMY_REPO_ID}_aggr",
+        aggr_root=tmp_path / "test_aggr",
+    )
+
+    aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr")
+
+    # Run all assertion functions
+    expected_total_episodes = ds_0.num_episodes + ds_1.num_episodes
+    expected_total_frames = ds_0.num_frames + ds_1.num_frames
+
+    assert_episode_and_frame_counts(aggr_ds, expected_total_episodes, expected_total_frames)
+    assert_dataset_content_integrity(aggr_ds, ds_0, ds_1)
+    assert_metadata_consistency(aggr_ds, ds_0, ds_1)
+    assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
+    assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
+    assert_dataset_iteration_works(aggr_ds)
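Note for reviewers (illustrative only, not part of the patch): since the checks now live in standalone assert_* helpers, a follow-up test can reuse them wholesale. A minimal sketch under assumptions: it reuses the same lerobot_dataset_factory fixture, DUMMY_REPO_ID constant, and imports as the file above, and the test name and unbalanced sizes are hypothetical.

    def test_aggregate_datasets_unbalanced(tmp_path, lerobot_dataset_factory):
        """Hypothetical sketch: aggregate two datasets of different sizes."""
        # Two source datasets with deliberately different episode/frame counts.
        ds_0 = lerobot_dataset_factory(
            root=tmp_path / "test_0",
            repo_id=f"{DUMMY_REPO_ID}_0",
            total_episodes=5,
            total_frames=200,
        )
        ds_1 = lerobot_dataset_factory(
            root=tmp_path / "test_1",
            repo_id=f"{DUMMY_REPO_ID}_1",
            total_episodes=10,
            total_frames=400,
        )
        aggregate_datasets(
            repo_ids=[ds_0.repo_id, ds_1.repo_id],
            roots=[ds_0.root, ds_1.root],
            aggr_repo_id=f"{DUMMY_REPO_ID}_aggr",
            aggr_root=tmp_path / "test_aggr",
        )
        aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr")

        # The extracted helpers cover the unbalanced case without modification:
        # 5 + 10 episodes and 200 + 400 frames are expected after aggregation.
        assert_episode_and_frame_counts(aggr_ds, expected_episodes=15, expected_frames=600)
        assert_dataset_content_integrity(aggr_ds, ds_0, ds_1)
        assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)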