add: tests for aggregation code
@@ -1,20 +1,28 @@
import torch

from lerobot.common.datasets.aggregate import aggregate_datasets
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from tests.fixtures.constants import DUMMY_REPO_ID


def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
    ds_0_num_frames = 400
    ds_1_num_frames = 400
    ds_0_num_episodes = 10
    ds_1_num_episodes = 10

    # Create two source datasets with a known number of frames and episodes
    ds_0 = lerobot_dataset_factory(
        root=tmp_path / "test_0",
        repo_id=f"{DUMMY_REPO_ID}_0",
        total_episodes=ds_0_num_episodes,
        total_frames=ds_0_num_frames,
    )
    ds_1 = lerobot_dataset_factory(
        root=tmp_path / "test_1",
        repo_id=f"{DUMMY_REPO_ID}_1",
        total_episodes=ds_1_num_episodes,
        total_frames=ds_1_num_frames,
    )

    aggregate_datasets(
@@ -25,5 +33,149 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
    )
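    # NOTE: the aggregate_datasets(...) arguments are elided between the two
    # diff hunks above. Judging from the load below, the call is assumed to
    # merge ds_0 and ds_1 into a new dataset with repo_id f"{DUMMY_REPO_ID}_aggr"
    # rooted at tmp_path / "test_aggr".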

    aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr")

    # Test 1: Total number of episodes corresponds
    expected_total_episodes = ds_0.num_episodes + ds_1.num_episodes
    assert aggr_ds.num_episodes == expected_total_episodes, (
        f"Expected {expected_total_episodes} episodes, got {aggr_ds.num_episodes}"
    )

    # Test 2: Total number of frames corresponds
    expected_total_frames = ds_0.num_frames + ds_1.num_frames
    assert aggr_ds.num_frames == expected_total_frames, (
        f"Expected {expected_total_frames} frames, got {aggr_ds.num_frames}"
    )

    # Test 3: First part of dataset corresponds to ds_0
    # Check first item (index 0) matches ds_0[0]
    aggr_first_item = aggr_ds[0]
    ds_0_first_item = ds_0[0]

    # Compare all keys except episode_index and index, which should be updated
    for key in ds_0_first_item:
        if key not in ["episode_index", "index"]:
            # Handle both tensor and non-tensor data
            if torch.is_tensor(aggr_first_item[key]) and torch.is_tensor(ds_0_first_item[key]):
                assert torch.allclose(aggr_first_item[key], ds_0_first_item[key], atol=1e-6), (
                    f"First item key '{key}' doesn't match between aggregated and ds_0"
                )
            else:
                assert aggr_first_item[key] == ds_0_first_item[key], (
                    f"First item key '{key}' doesn't match between aggregated and ds_0"
                )

    # Check last item of ds_0 part (index len(ds_0) - 1) matches ds_0[-1]
    aggr_ds_0_last_item = aggr_ds[len(ds_0) - 1]
    ds_0_last_item = ds_0[-1]

    for key in ds_0_last_item:
        if key not in ["episode_index", "index"]:
            # Handle both tensor and non-tensor data
            if torch.is_tensor(aggr_ds_0_last_item[key]) and torch.is_tensor(ds_0_last_item[key]):
                assert torch.allclose(aggr_ds_0_last_item[key], ds_0_last_item[key], atol=1e-6), (
                    f"Last ds_0 item key '{key}' doesn't match between aggregated and ds_0"
                )
            else:
                assert aggr_ds_0_last_item[key] == ds_0_last_item[key], (
                    f"Last ds_0 item key '{key}' doesn't match between aggregated and ds_0"
                )

    # Test 4: Second part of dataset corresponds to ds_1
    # Check first item of ds_1 part (index len(ds_0)) matches ds_1[0]
    aggr_ds_1_first_item = aggr_ds[len(ds_0)]
    ds_1_first_item = ds_1[0]

    for key in ds_1_first_item:
        if key not in ["episode_index", "index"]:
            # Handle both tensor and non-tensor data
            if torch.is_tensor(aggr_ds_1_first_item[key]) and torch.is_tensor(ds_1_first_item[key]):
                assert torch.allclose(aggr_ds_1_first_item[key], ds_1_first_item[key], atol=1e-6), (
                    f"First ds_1 item key '{key}' doesn't match between aggregated and ds_1"
                )
            else:
                assert aggr_ds_1_first_item[key] == ds_1_first_item[key], (
                    f"First ds_1 item key '{key}' doesn't match between aggregated and ds_1"
                )

    # Check last item matches ds_1[-1]
    aggr_last_item = aggr_ds[-1]
    ds_1_last_item = ds_1[-1]

    for key in ds_1_last_item:
        if key not in ["episode_index", "index"]:
            # Handle both tensor and non-tensor data
            if torch.is_tensor(aggr_last_item[key]) and torch.is_tensor(ds_1_last_item[key]):
                assert torch.allclose(aggr_last_item[key], ds_1_last_item[key], atol=1e-6), (
                    f"Last item key '{key}' doesn't match between aggregated and ds_1"
                )
            else:
                assert aggr_last_item[key] == ds_1_last_item[key], (
                    f"Last item key '{key}' doesn't match between aggregated and ds_1"
                )

    # Test 5: Check metadata aggregation
    # Test basic info
    assert aggr_ds.fps == ds_0.fps == ds_1.fps, "FPS should be the same across all datasets"
    assert aggr_ds.meta.info["robot_type"] == ds_0.meta.info["robot_type"] == ds_1.meta.info["robot_type"], (
        "Robot type should be the same"
    )

    # Test features are the same
    assert aggr_ds.features == ds_0.features == ds_1.features, "Features should be the same"

    # Test tasks aggregation: the aggregated task set is the union of both task sets
    expected_tasks = set(ds_0.meta.tasks.index) | set(ds_1.meta.tasks.index)
    actual_tasks = set(aggr_ds.meta.tasks.index)
    assert actual_tasks == expected_tasks, f"Expected tasks {expected_tasks}, got {actual_tasks}"

    # Test episode indices are correctly updated
    # ds_0 frames should have episode_index in [0, ds_0.num_episodes)
    for i in range(ds_0_num_frames):
        assert aggr_ds[i]["episode_index"] < ds_0.num_episodes, (
            f"Episode index {aggr_ds[i]['episode_index']} at position {i} should be < {ds_0.num_episodes}"
        )

    def ds1_episodes_condition(ep_idx):
        return (ep_idx >= ds_0.num_episodes) and (ep_idx < ds_0.num_episodes + ds_1.num_episodes)

    # ds_1 frames should have episode_index in [ds_0.num_episodes, total_episodes)
    for i in range(ds_0_num_frames, ds_0_num_frames + ds_1_num_frames):
        assert ds1_episodes_condition(aggr_ds[i]["episode_index"]), (
            f"Episode index {aggr_ds[i]['episode_index']} at position {i} should be in "
            f"[{ds_0.num_episodes}, {ds_0.num_episodes + ds_1.num_episodes})"
        )
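    # A stricter variant of the check above (a sketch, not part of this commit)
    # would assert the exact offset: each ds_1 frame's episode_index should be
    # its original value shifted by ds_0.num_episodes, e.g.
    #
    #     for i in range(ds_0_num_frames, ds_0_num_frames + ds_1_num_frames):
    #         shifted = ds_1[i - ds_0_num_frames]["episode_index"] + ds_0.num_episodes
    #         assert aggr_ds[i]["episode_index"] == shifted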

    def visual_frames_equal(frame1, frame2):
        return torch.allclose(frame1, frame2)

    video_keys = list(
        filter(
            lambda key: aggr_ds.meta.info["features"][key]["dtype"] == "video",
            aggr_ds.meta.info["features"].keys(),
        )
    )

    # Test the section corresponding to the first dataset (ds_0)
    for i in range(ds_0_num_frames):
        assert aggr_ds[i]["index"] == i, (
            f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}"
        )
        for key in video_keys:
            assert visual_frames_equal(aggr_ds[i][key], ds_0[i][key]), (
                f"Visual frames at position {i} should be equal between aggregated and ds_0"
            )

    # Test the section corresponding to the second dataset (ds_1)
    for i in range(ds_0_num_frames, ds_0_num_frames + ds_1_num_frames):
        # The frame index in the aggregated dataset should also match its position.
        assert aggr_ds[i]["index"] == i, (
            f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}"
        )
        for key in video_keys:
            assert visual_frames_equal(aggr_ds[i][key], ds_1[i - ds_0_num_frames][key]), (
                f"Visual frames at position {i} should be equal between aggregated and ds_1"
            )

    # Test that we can iterate through the entire dataset without errors
    for _ in aggr_ds:
        pass
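
# A possible consolidation of the four near-identical comparison loops above
# (a sketch only; `assert_items_match` is a helper name introduced here, not
# part of this commit):
def assert_items_match(item_a, item_b, label):
    for key in item_b:
        if key in ("episode_index", "index"):
            continue  # remapped by aggregation, so they differ by design
        # Handle both tensor and non-tensor data
        if torch.is_tensor(item_a[key]) and torch.is_tensor(item_b[key]):
            assert torch.allclose(item_a[key], item_b[key], atol=1e-6), (
                f"{label}: key '{key}' doesn't match"
            )
        else:
            assert item_a[key] == item_b[key], f"{label}: key '{key}' doesn't match"

# To run only this test (the file's path in the repo may differ):
#     pytest -k test_aggregate_datasets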