From 5498bac1b025dc487748619d2ecfb0572da64116 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Tue, 26 May 2026 17:00:34 +0200 Subject: [PATCH] test(aggregate): extending aggregation tests to depth frames --- tests/datasets/test_aggregate.py | 87 ++++++++++++++++++++++++++--- tests/fixtures/dataset_factories.py | 5 ++ 2 files changed, 84 insertions(+), 8 deletions(-) diff --git a/tests/datasets/test_aggregate.py b/tests/datasets/test_aggregate.py index f3edc3af8..bdf344e2a 100644 --- a/tests/datasets/test_aggregate.py +++ b/tests/datasets/test_aggregate.py @@ -29,7 +29,13 @@ from lerobot.configs import VIDEO_ENCODER_INFO_KEYS from lerobot.datasets.aggregate import aggregate_datasets from lerobot.datasets.feature_utils import features_equal_for_merge from lerobot.datasets.lerobot_dataset import LeRobotDataset -from tests.fixtures.constants import DUMMY_REPO_ID +from tests.fixtures.constants import ( + DUMMY_CAMERA_FEATURES, + DUMMY_DEPTH_CAMERA_FEATURES, + DUMMY_REPO_ID, +) + +CAMERA_FEATURES_WITH_DEPTH = {**DUMMY_CAMERA_FEATURES, **DUMMY_DEPTH_CAMERA_FEATURES} def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames): @@ -191,6 +197,28 @@ def assert_dataset_iteration_works(aggr_ds): pass +def assert_depth_keys_preserved(aggr_ds, ds_0, ds_1): + """Test that depth keys are correctly preserved after aggregation. + + Ensures that the ``is_depth_map`` marker on visual features survives + aggregation, so that downstream consumers (e.g. the dataset reader's + depth decoding path) keep working on the merged dataset. + """ + expected_depth_keys = set(ds_0.meta.depth_keys) + assert expected_depth_keys == set(ds_1.meta.depth_keys), ( + "Source datasets disagree on depth_keys; test setup is inconsistent" + ) + actual_depth_keys = set(aggr_ds.meta.depth_keys) + assert actual_depth_keys == expected_depth_keys, ( + f"Expected depth_keys {expected_depth_keys}, got {actual_depth_keys}" + ) + for key in expected_depth_keys: + info = aggr_ds.meta.info.features[key].get("info") or {} + assert info.get("is_depth_map") is True, ( + f"Depth marker lost on feature {key!r} after aggregation" + ) + + def assert_video_timestamps_within_bounds(aggr_ds): """Test that all video timestamps are within valid bounds for their respective video files. @@ -240,7 +268,11 @@ def assert_video_timestamps_within_bounds(aggr_ds): def test_aggregate_datasets(tmp_path, lerobot_dataset_factory): - """Test basic aggregation functionality with standard parameters.""" + """Test basic aggregation functionality with standard parameters. + + Source datasets include both RGB and depth video features so the same + aggregation flow is exercised on the ``is_depth_map`` branch. + """ ds_0_num_frames = 400 ds_1_num_frames = 800 ds_0_num_episodes = 10 @@ -252,14 +284,21 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory): repo_id=f"{DUMMY_REPO_ID}_0", total_episodes=ds_0_num_episodes, total_frames=ds_0_num_frames, + camera_features=CAMERA_FEATURES_WITH_DEPTH, ) ds_1 = lerobot_dataset_factory( root=tmp_path / "test_1", repo_id=f"{DUMMY_REPO_ID}_1", total_episodes=ds_1_num_episodes, total_frames=ds_1_num_frames, + camera_features=CAMERA_FEATURES_WITH_DEPTH, ) + # Confirm depth was actually wired into the source datasets so the + # rest of the assertions exercise the depth aggregation path. + assert len(ds_0.meta.depth_keys) > 0, "ds_0 should expose at least one depth key" + assert len(ds_1.meta.depth_keys) > 0, "ds_1 should expose at least one depth key" + aggregate_datasets( repo_ids=[ds_0.repo_id, ds_1.repo_id], roots=[ds_0.root, ds_1.root], @@ -286,6 +325,7 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory): assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1) assert_video_frames_integrity(aggr_ds, ds_0, ds_1) assert_video_timestamps_within_bounds(aggr_ds) + assert_depth_keys_preserved(aggr_ds, ds_0, ds_1) assert_dataset_iteration_works(aggr_ds) @@ -403,7 +443,11 @@ def test_aggregate_incomplete_video_encoder_info_warns_and_nuls_encoders( def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory): - """Test aggregation with small file size limits to force file rotation/sharding.""" + """Test aggregation with small file size limits to force file rotation/sharding. + + Depth video features are included to verify that file rotation/concat + correctly handles depth-marked features alongside regular RGB ones. + """ ds_0_num_episodes = ds_1_num_episodes = 10 ds_0_num_frames = ds_1_num_frames = 400 @@ -412,14 +456,19 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory): repo_id=f"{DUMMY_REPO_ID}_small_0", total_episodes=ds_0_num_episodes, total_frames=ds_0_num_frames, + camera_features=CAMERA_FEATURES_WITH_DEPTH, ) ds_1 = lerobot_dataset_factory( root=tmp_path / "small_1", repo_id=f"{DUMMY_REPO_ID}_small_1", total_episodes=ds_1_num_episodes, total_frames=ds_1_num_frames, + camera_features=CAMERA_FEATURES_WITH_DEPTH, ) + assert len(ds_0.meta.depth_keys) > 0, "ds_0 should expose at least one depth key" + assert len(ds_1.meta.depth_keys) > 0, "ds_1 should expose at least one depth key" + # Use the new configurable parameters to force file rotation aggregate_datasets( repo_ids=[ds_0.repo_id, ds_1.repo_id], @@ -450,6 +499,7 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory): assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1) assert_video_frames_integrity(aggr_ds, ds_0, ds_1) assert_video_timestamps_within_bounds(aggr_ds) + assert_depth_keys_preserved(aggr_ds, ds_0, ds_1) assert_dataset_iteration_works(aggr_ds) # Check that multiple files were actually created due to small size limits @@ -469,7 +519,8 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory): """Regression test for video timestamp bug when merging datasets. This test specifically checks that video timestamps are correctly calculated - and accumulated when merging multiple datasets. + and accumulated when merging multiple datasets. Depth video features are + included so depth timestamps are also covered by the regression. """ datasets = [] for i in range(3): @@ -478,9 +529,13 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory): repo_id=f"{DUMMY_REPO_ID}_regression_{i}", total_episodes=2, total_frames=100, + camera_features=CAMERA_FEATURES_WITH_DEPTH, ) datasets.append(ds) + for i, ds in enumerate(datasets): + assert len(ds.meta.depth_keys) > 0, f"Dataset {i} should expose at least one depth key" + aggregate_datasets( repo_ids=[ds.repo_id for ds in datasets], roots=[ds.root for ds in datasets], @@ -497,12 +552,21 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory): aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_regression_aggr", root=tmp_path / "regression_aggr") assert_video_timestamps_within_bounds(aggr_ds) + # Depth keys must survive the merge for the regression to cover the + # ``is_depth_map`` decoding branch. + assert set(aggr_ds.meta.depth_keys) == set(datasets[0].meta.depth_keys) + depth_keys = set(aggr_ds.meta.depth_keys) for i in range(len(aggr_ds)): item = aggr_ds[i] for key in aggr_ds.meta.video_keys: assert key in item, f"Video key {key} missing from item {i}" - assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}" + # Depth frames are single-channel (1, H, W) after dequantization; + # standard RGB frames keep the 3-channel layout. + expected_channels = 1 if key in depth_keys else 3 + assert item[key].shape[0] == expected_channels, ( + f"Expected {expected_channels} channels for video key {key}, got {item[key].shape}" + ) def assert_image_schema_preserved(aggr_ds): @@ -584,25 +648,31 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory): ds_0_num_episodes = 2 ds_1_num_episodes = 3 - # Create two image-based datasets (use_videos=False) + # Create two image-based datasets (use_videos=False) with a mix of RGB + # and depth-marked cameras so the depth path is exercised in image mode. ds_0 = lerobot_dataset_factory( root=tmp_path / "image_0", repo_id=f"{DUMMY_REPO_ID}_image_0", total_episodes=ds_0_num_episodes, total_frames=ds_0_num_frames, - use_videos=False, # Image-based dataset + use_videos=False, + camera_features=CAMERA_FEATURES_WITH_DEPTH, ) ds_1 = lerobot_dataset_factory( root=tmp_path / "image_1", repo_id=f"{DUMMY_REPO_ID}_image_1", total_episodes=ds_1_num_episodes, total_frames=ds_1_num_frames, - use_videos=False, # Image-based dataset + use_videos=False, + camera_features=CAMERA_FEATURES_WITH_DEPTH, ) # Verify source datasets have image keys assert len(ds_0.meta.image_keys) > 0, "ds_0 should have image keys" assert len(ds_1.meta.image_keys) > 0, "ds_1 should have image keys" + # And that the depth marker actually made it onto an image feature. + assert len(ds_0.meta.depth_keys) > 0, "ds_0 should expose at least one depth key" + assert len(ds_1.meta.depth_keys) > 0, "ds_1 should expose at least one depth key" # Aggregate the datasets aggregate_datasets( @@ -637,6 +707,7 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory): # Image-specific assertions assert_image_schema_preserved(aggr_ds) assert_image_frames_integrity(aggr_ds, ds_0, ds_1) + assert_depth_keys_preserved(aggr_ds, ds_0, ds_1) # Verify images can be accessed and have correct shape sample_item = aggr_ds[0] diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index 2f4d41ff8..5cdf473d9 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -485,10 +485,14 @@ def lerobot_dataset_factory( hf_dataset: datasets.Dataset | None = None, data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB, chunks_size: int = DEFAULT_CHUNK_SIZE, + camera_features: dict | None = None, **kwargs, ) -> LeRobotDataset: # Instantiate objects if info is None: + info_kwargs = {} + if camera_features is not None: + info_kwargs["camera_features"] = camera_features info = info_factory( total_episodes=total_episodes, total_frames=total_frames, @@ -496,6 +500,7 @@ def lerobot_dataset_factory( use_videos=use_videos, data_files_size_in_mb=data_files_size_in_mb, chunks_size=chunks_size, + **info_kwargs, ) if stats is None: stats = stats_factory(features=info.features)