fix: modularize tests to improve readability

Author: fracapuano
Date: 2025-06-10 14:44:24 +02:00
Committed by: Michel Aractingi
Parent: f1fbc29819
Commit: 8ad085d882
Diffstat: +70 -48
@@ -5,49 +5,19 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from tests.fixtures.constants import DUMMY_REPO_ID
 
-def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
-    ds_0_num_frames = 400
-    ds_1_num_frames = 400
-    ds_0_num_episodes = 10
-    ds_1_num_episodes = 10
-
-    # Create two datasets with different number of frames and episodes
-    ds_0 = lerobot_dataset_factory(
-        root=tmp_path / "test_0",
-        repo_id=f"{DUMMY_REPO_ID}_0",
-        total_episodes=ds_0_num_episodes,
-        total_frames=ds_0_num_frames,
-    )
-    ds_1 = lerobot_dataset_factory(
-        root=tmp_path / "test_1",
-        repo_id=f"{DUMMY_REPO_ID}_1",
-        total_episodes=ds_1_num_episodes,
-        total_frames=ds_1_num_frames,
-    )
-
-    aggregate_datasets(
-        repo_ids=[ds_0.repo_id, ds_1.repo_id],
-        roots=[ds_0.root, ds_1.root],
-        aggr_repo_id=f"{DUMMY_REPO_ID}_aggr",
-        aggr_root=tmp_path / "test_aggr",
-    )
-
-    aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr")
-
-    # Test 1: Total number of episodes corresponds
-    expected_total_episodes = ds_0.num_episodes + ds_1.num_episodes
-    assert aggr_ds.num_episodes == expected_total_episodes, (
-        f"Expected {expected_total_episodes} episodes, got {aggr_ds.num_episodes}"
-    )
-
-    # Test 2: Total number of frames corresponds
-    expected_total_frames = ds_0.num_frames + ds_1.num_frames
-    assert aggr_ds.num_frames == expected_total_frames, (
-        f"Expected {expected_total_frames} frames, got {aggr_ds.num_frames}"
-    )
-
-    # Test 3: First part of dataset corresponds to ds_0
-    # Check first item (index 0) matches ds_0[0]
+def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames):
+    """Test that total number of episodes and frames are correctly aggregated."""
+    assert aggr_ds.num_episodes == expected_episodes, (
+        f"Expected {expected_episodes} episodes, got {aggr_ds.num_episodes}"
+    )
+    assert aggr_ds.num_frames == expected_frames, (
+        f"Expected {expected_frames} frames, got {aggr_ds.num_frames}"
+    )
+
+
+def assert_dataset_content_integrity(aggr_ds, ds_0, ds_1):
+    """Test that the content of both datasets is preserved correctly in the aggregated dataset."""
+    # Test first part of dataset corresponds to ds_0, check first item (index 0) matches ds_0[0]
     aggr_first_item = aggr_ds[0]
     ds_0_first_item = ds_0[0]
@@ -80,7 +50,7 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
             f"Last ds_0 item key '{key}' doesn't match between aggregated and ds_0"
         )
 
-    # Test 4: Second part of dataset corresponds to ds_1
+    # Test second part of dataset corresponds to ds_1
     # Check first item of ds_1 part (index len(ds_0)) matches ds_1[0]
     aggr_ds_1_first_item = aggr_ds[len(ds_0)]
     ds_1_first_item = ds_1[0]
@@ -113,7 +83,9 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
             f"Last item key '{key}' doesn't match between aggregated and ds_1"
         )
 
-    # Test 5: Check metadata aggregation
+
+def assert_metadata_consistency(aggr_ds, ds_0, ds_1):
+    """Test that metadata is correctly aggregated."""
     # Test basic info
     assert aggr_ds.fps == ds_0.fps == ds_1.fps, "FPS should be the same across all datasets"
     assert aggr_ds.meta.info["robot_type"] == ds_0.meta.info["robot_type"] == ds_1.meta.info["robot_type"], (
@@ -128,9 +100,11 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
     actual_tasks = set(aggr_ds.meta.tasks.index)
     assert actual_tasks == expected_tasks, f"Expected tasks {expected_tasks}, got {actual_tasks}"
 
-    # Test episode indices are correctly updated
+
+def assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1):
+    """Test that episode indices are correctly updated after aggregation."""
     # ds_0 episodes should have episode_index 0 to ds_0.num_episodes-1
-    for i in range(ds_0_num_frames):
+    for i in range(len(ds_0)):
         assert aggr_ds[i]["episode_index"] < ds_0.num_episodes, (
             f"Episode index {aggr_ds[i]['episode_index']} at position {i} should be < {ds_0.num_episodes}"
         )
@@ -139,12 +113,16 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
         return (ep_idx >= ds_0.num_episodes) and (ep_idx < ds_0.num_episodes + ds_1.num_episodes)
 
     # ds_1 episodes should have episode_index ds_0.num_episodes to total_episodes-1
-    for i in range(ds_0_num_frames, ds_0_num_frames + ds_1_num_frames):
+    for i in range(len(ds_0), len(ds_0) + len(ds_1)):
         expected_min_episode_idx = ds_0.num_episodes
         assert ds1_episodes_condition(aggr_ds[i]["episode_index"]), (
             f"Episode index {aggr_ds[i]['episode_index']} at position {i} should be >= {expected_min_episode_idx}"
        )
 
+
+def assert_video_frames_integrity(aggr_ds, ds_0, ds_1):
+    """Test that video frames are correctly preserved and frame indices are updated."""
+
     def visual_frames_equal(frame1, frame2):
         return torch.allclose(frame1, frame2)
@@ -156,7 +134,7 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
         )
 
     # Test the section corresponding to the first dataset (ds_0)
-    for i in range(ds_0_num_frames):
+    for i in range(len(ds_0)):
         assert aggr_ds[i]["index"] == i, (
             f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}"
         )
@@ -166,16 +144,60 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
         )
 
     # Test the section corresponding to the second dataset (ds_1)
-    for i in range(ds_0_num_frames, ds_0_num_frames + ds_1_num_frames):
+    for i in range(len(ds_0), len(ds_0) + len(ds_1)):
         # The frame index in the aggregated dataset should also match its position.
         assert aggr_ds[i]["index"] == i, (
             f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}"
         )
         for key in video_keys:
-            assert visual_frames_equal(aggr_ds[i][key], ds_1[i - ds_0_num_frames][key]), (
+            assert visual_frames_equal(aggr_ds[i][key], ds_1[i - len(ds_0)][key]), (
                 f"Visual frames at position {i} should be equal between aggregated and ds_1"
             )
 
-    # Test that we can iterate through the entire dataset without errors
+
+def assert_dataset_iteration_works(aggr_ds):
+    """Test that we can iterate through the entire dataset without errors."""
     for _ in aggr_ds:
         pass
+
+
+def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
+    """Test basic aggregation functionality with standard parameters."""
+    ds_0_num_frames = 400
+    ds_1_num_frames = 400
+    ds_0_num_episodes = 10
+    ds_1_num_episodes = 10
+
+    # Create two datasets with different number of frames and episodes
+    ds_0 = lerobot_dataset_factory(
+        root=tmp_path / "test_0",
+        repo_id=f"{DUMMY_REPO_ID}_0",
+        total_episodes=ds_0_num_episodes,
+        total_frames=ds_0_num_frames,
+    )
+    ds_1 = lerobot_dataset_factory(
+        root=tmp_path / "test_1",
+        repo_id=f"{DUMMY_REPO_ID}_1",
+        total_episodes=ds_1_num_episodes,
+        total_frames=ds_1_num_frames,
+    )
+
+    aggregate_datasets(
+        repo_ids=[ds_0.repo_id, ds_1.repo_id],
+        roots=[ds_0.root, ds_1.root],
+        aggr_repo_id=f"{DUMMY_REPO_ID}_aggr",
+        aggr_root=tmp_path / "test_aggr",
+    )
+
+    aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr")
+
+    # Run all assertion functions
+    expected_total_episodes = ds_0.num_episodes + ds_1.num_episodes
+    expected_total_frames = ds_0.num_frames + ds_1.num_frames
+
+    assert_episode_and_frame_counts(aggr_ds, expected_total_episodes, expected_total_frames)
+    assert_dataset_content_integrity(aggr_ds, ds_0, ds_1)
+    assert_metadata_consistency(aggr_ds, ds_0, ds_1)
+    assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
+    assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
+    assert_dataset_iteration_works(aggr_ds)
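Side note on reuse: once the assertions live in standalone helpers, the same flow is easy to parametrize over dataset sizes. A minimal sketch of what that could look like, assuming it sits in the same test module as the helpers above (so the existing imports and the lerobot_dataset_factory fixture apply); the size pairs are illustrative and not part of this commit:

import pytest


@pytest.mark.parametrize("num_episodes, num_frames", [(10, 400), (5, 200)])
def test_aggregate_datasets_parametrized(tmp_path, lerobot_dataset_factory, num_episodes, num_frames):
    """Same flow as test_aggregate_datasets, but sized via parameters (sketch only)."""
    # Build two equally sized source datasets from the factory fixture.
    ds_0, ds_1 = (
        lerobot_dataset_factory(
            root=tmp_path / f"test_{i}",
            repo_id=f"{DUMMY_REPO_ID}_{i}",
            total_episodes=num_episodes,
            total_frames=num_frames,
        )
        for i in range(2)
    )
    aggregate_datasets(
        repo_ids=[ds_0.repo_id, ds_1.repo_id],
        roots=[ds_0.root, ds_1.root],
        aggr_repo_id=f"{DUMMY_REPO_ID}_aggr",
        aggr_root=tmp_path / "test_aggr",
    )
    aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr")

    # Each helper takes the aggregated dataset (plus the sources where needed),
    # so the whole assertion battery composes without duplication.
    assert_episode_and_frame_counts(aggr_ds, 2 * num_episodes, 2 * num_frames)
    assert_dataset_content_integrity(aggr_ds, ds_0, ds_1)
    assert_metadata_consistency(aggr_ds, ds_0, ds_1)
    assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
    assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
    assert_dataset_iteration_works(aggr_ds)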