mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
add: tests forcing new file creation
This commit is contained in:
committed by
Michel Aractingi
parent
c8a5df963b
commit
4e01f87a6e
@@ -33,13 +33,15 @@ def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames)
|
||||
|
||||
def assert_dataset_content_integrity(aggr_ds, ds_0, ds_1):
|
||||
"""Test that the content of both datasets is preserved correctly in the aggregated dataset."""
|
||||
keys_to_ignore = ["episode_index", "index", "timestamp"]
|
||||
|
||||
# Test first part of dataset corresponds to ds_0, check first item (index 0) matches ds_0[0]
|
||||
aggr_first_item = aggr_ds[0]
|
||||
ds_0_first_item = ds_0[0]
|
||||
|
||||
# Compare all keys except episode_index and index which should be updated
|
||||
for key in ds_0_first_item:
|
||||
if key not in ["episode_index", "index"]:
|
||||
if key not in keys_to_ignore:
|
||||
# Handle both tensor and non-tensor data
|
||||
if torch.is_tensor(aggr_first_item[key]) and torch.is_tensor(ds_0_first_item[key]):
|
||||
assert torch.allclose(aggr_first_item[key], ds_0_first_item[key], atol=1e-6), (
|
||||
@@ -55,7 +57,7 @@ def assert_dataset_content_integrity(aggr_ds, ds_0, ds_1):
|
||||
ds_0_last_item = ds_0[-1]
|
||||
|
||||
for key in ds_0_last_item:
|
||||
if key not in ["episode_index", "index"]:
|
||||
if key not in keys_to_ignore:
|
||||
# Handle both tensor and non-tensor data
|
||||
if torch.is_tensor(aggr_ds_0_last_item[key]) and torch.is_tensor(ds_0_last_item[key]):
|
||||
assert torch.allclose(aggr_ds_0_last_item[key], ds_0_last_item[key], atol=1e-6), (
|
||||
@@ -72,7 +74,7 @@ def assert_dataset_content_integrity(aggr_ds, ds_0, ds_1):
|
||||
ds_1_first_item = ds_1[0]
|
||||
|
||||
for key in ds_1_first_item:
|
||||
if key not in ["episode_index", "index"]:
|
||||
if key not in keys_to_ignore:
|
||||
# Handle both tensor and non-tensor data
|
||||
if torch.is_tensor(aggr_ds_1_first_item[key]) and torch.is_tensor(ds_1_first_item[key]):
|
||||
assert torch.allclose(aggr_ds_1_first_item[key], ds_1_first_item[key], atol=1e-6), (
|
||||
@@ -88,7 +90,7 @@ def assert_dataset_content_integrity(aggr_ds, ds_0, ds_1):
|
||||
ds_1_last_item = ds_1[-1]
|
||||
|
||||
for key in ds_1_last_item:
|
||||
if key not in ["episode_index", "index"]:
|
||||
if key not in keys_to_ignore:
|
||||
# Handle both tensor and non-tensor data
|
||||
if torch.is_tensor(aggr_last_item[key]) and torch.is_tensor(ds_1_last_item[key]):
|
||||
assert torch.allclose(aggr_last_item[key], ds_1_last_item[key], atol=1e-6), (
|
||||
@@ -180,9 +182,9 @@ def assert_dataset_iteration_works(aggr_ds):
|
||||
def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
||||
"""Test basic aggregation functionality with standard parameters."""
|
||||
ds_0_num_frames = 400
|
||||
ds_1_num_frames = 400
|
||||
ds_1_num_frames = 800
|
||||
ds_0_num_episodes = 10
|
||||
ds_1_num_episodes = 10
|
||||
ds_1_num_episodes = 25
|
||||
|
||||
# Create two datasets with different number of frames and episodes
|
||||
ds_0 = lerobot_dataset_factory(
|
||||
@@ -217,3 +219,58 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
||||
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
|
||||
assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
|
||||
assert_dataset_iteration_works(aggr_ds)
|
||||
|
||||
|
||||
def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
    """Aggregate two datasets under tiny file-size limits and validate the result.

    Two source datasets (10 episodes and 400 frames each) are merged with
    ``data_files_size_in_mb=0.01`` and ``video_files_size_in_mb=0.1``. The
    aggregated dataset is then reloaded from disk and run through the shared
    assertion helpers, after which the output tree is inspected to confirm that
    more than one parquet file and more than one mp4 file were written.

    NOTE(review): each file-count check only runs when its directory exists, so
    a missing ``data`` or ``videos`` directory does not fail this test.
    """
    episodes_per_source = 10
    frames_per_source = 400
    source_size = {
        "total_episodes": episodes_per_source,
        "total_frames": frames_per_source,
    }

    ds_0 = lerobot_dataset_factory(
        root=tmp_path / "small_0",
        repo_id=f"{DUMMY_REPO_ID}_small_0",
        **source_size,
    )
    ds_1 = lerobot_dataset_factory(
        root=tmp_path / "small_1",
        repo_id=f"{DUMMY_REPO_ID}_small_1",
        **source_size,
    )

    merged_repo_id = f"{DUMMY_REPO_ID}_small_aggr"
    merged_root = tmp_path / "small_aggr"

    # The size limits are deliberately tiny so the writer has to start new files.
    aggregate_datasets(
        repo_ids=[ds_0.repo_id, ds_1.repo_id],
        roots=[ds_0.root, ds_1.root],
        aggr_repo_id=merged_repo_id,
        aggr_root=merged_root,
        data_files_size_in_mb=0.01,
        video_files_size_in_mb=0.1,
    )

    aggr_ds = LeRobotDataset(merged_repo_id, root=merged_root)

    # Splitting the output across files must not change any of the content checks.
    assert_episode_and_frame_counts(
        aggr_ds,
        episodes_per_source + episodes_per_source,
        frames_per_source + frames_per_source,
    )
    assert_dataset_content_integrity(aggr_ds, ds_0, ds_1)
    assert_metadata_consistency(aggr_ds, ds_0, ds_1)
    assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
    assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
    assert_dataset_iteration_works(aggr_ds)

    # Finally, confirm on disk that the small limits really produced several files.
    data_dir = merged_root / "data"
    video_dir = merged_root / "videos"

    if data_dir.exists():
        parquet_count = len(list(data_dir.rglob("*.parquet")))
        assert parquet_count > 1, "Small file size limits should create multiple parquet files"

    if video_dir.exists():
        video_count = len(list(video_dir.rglob("*.mp4")))
        assert video_count > 1, "Small file size limits should create multiple video files"
|
||||
|
||||
Reference in New Issue
Block a user