test(aggregate): extending aggregation tests to depth frames

This commit is contained in:
CarolinePascal
2026-05-26 17:00:34 +02:00
parent 05d2a6062d
commit 5498bac1b0
2 changed files with 84 additions and 8 deletions
+79 -8
View File
@@ -29,7 +29,13 @@ from lerobot.configs import VIDEO_ENCODER_INFO_KEYS
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.datasets.feature_utils import features_equal_for_merge
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from tests.fixtures.constants import DUMMY_REPO_ID
from tests.fixtures.constants import (
DUMMY_CAMERA_FEATURES,
DUMMY_DEPTH_CAMERA_FEATURES,
DUMMY_REPO_ID,
)
CAMERA_FEATURES_WITH_DEPTH = {**DUMMY_CAMERA_FEATURES, **DUMMY_DEPTH_CAMERA_FEATURES}
def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames):
@@ -191,6 +197,28 @@ def assert_dataset_iteration_works(aggr_ds):
pass
def assert_depth_keys_preserved(aggr_ds, ds_0, ds_1):
    """Check that aggregation keeps the depth keys of its source datasets.

    Three conditions are asserted, in order:

    1. ``ds_0`` and ``ds_1`` report the same set of ``meta.depth_keys``
       (otherwise the test setup itself is inconsistent).
    2. ``aggr_ds`` reports exactly that same set of ``meta.depth_keys``.
    3. For every such key, the ``"info"`` mapping of the feature entry in
       ``aggr_ds.meta.info.features`` still has ``is_depth_map`` set to
       ``True``. A missing or empty ``"info"`` entry counts as a lost marker.
    """
    source_keys = set(ds_0.meta.depth_keys)
    other_source_keys = set(ds_1.meta.depth_keys)
    assert source_keys == other_source_keys, (
        "Source datasets disagree on depth_keys; test setup is inconsistent"
    )

    merged_keys = set(aggr_ds.meta.depth_keys)
    assert merged_keys == source_keys, f"Expected depth_keys {source_keys}, got {merged_keys}"

    for depth_key in source_keys:
        # ``"info"`` may be absent or None on a feature; treat both as "no marker".
        feature_info = aggr_ds.meta.info.features[depth_key].get("info")
        if not feature_info:
            feature_info = {}
        marker = feature_info.get("is_depth_map")
        # Identity check on purpose: a merely truthy value does not count.
        assert marker is True, f"Depth marker lost on feature {depth_key!r} after aggregation"
def assert_video_timestamps_within_bounds(aggr_ds):
"""Test that all video timestamps are within valid bounds for their respective video files.
@@ -240,7 +268,11 @@ def assert_video_timestamps_within_bounds(aggr_ds):
def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
"""Test basic aggregation functionality with standard parameters."""
"""Test basic aggregation functionality with standard parameters.
Source datasets include both RGB and depth video features so the same
aggregation flow is exercised on the ``is_depth_map`` branch.
"""
ds_0_num_frames = 400
ds_1_num_frames = 800
ds_0_num_episodes = 10
@@ -252,14 +284,21 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
repo_id=f"{DUMMY_REPO_ID}_0",
total_episodes=ds_0_num_episodes,
total_frames=ds_0_num_frames,
camera_features=CAMERA_FEATURES_WITH_DEPTH,
)
ds_1 = lerobot_dataset_factory(
root=tmp_path / "test_1",
repo_id=f"{DUMMY_REPO_ID}_1",
total_episodes=ds_1_num_episodes,
total_frames=ds_1_num_frames,
camera_features=CAMERA_FEATURES_WITH_DEPTH,
)
# Confirm depth was actually wired into the source datasets so the
# rest of the assertions exercise the depth aggregation path.
assert len(ds_0.meta.depth_keys) > 0, "ds_0 should expose at least one depth key"
assert len(ds_1.meta.depth_keys) > 0, "ds_1 should expose at least one depth key"
aggregate_datasets(
repo_ids=[ds_0.repo_id, ds_1.repo_id],
roots=[ds_0.root, ds_1.root],
@@ -286,6 +325,7 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
assert_video_timestamps_within_bounds(aggr_ds)
assert_depth_keys_preserved(aggr_ds, ds_0, ds_1)
assert_dataset_iteration_works(aggr_ds)
@@ -403,7 +443,11 @@ def test_aggregate_incomplete_video_encoder_info_warns_and_nuls_encoders(
def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
"""Test aggregation with small file size limits to force file rotation/sharding."""
"""Test aggregation with small file size limits to force file rotation/sharding.
Depth video features are included to verify that file rotation/concat
correctly handles depth-marked features alongside regular RGB ones.
"""
ds_0_num_episodes = ds_1_num_episodes = 10
ds_0_num_frames = ds_1_num_frames = 400
@@ -412,14 +456,19 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
repo_id=f"{DUMMY_REPO_ID}_small_0",
total_episodes=ds_0_num_episodes,
total_frames=ds_0_num_frames,
camera_features=CAMERA_FEATURES_WITH_DEPTH,
)
ds_1 = lerobot_dataset_factory(
root=tmp_path / "small_1",
repo_id=f"{DUMMY_REPO_ID}_small_1",
total_episodes=ds_1_num_episodes,
total_frames=ds_1_num_frames,
camera_features=CAMERA_FEATURES_WITH_DEPTH,
)
assert len(ds_0.meta.depth_keys) > 0, "ds_0 should expose at least one depth key"
assert len(ds_1.meta.depth_keys) > 0, "ds_1 should expose at least one depth key"
# Use the new configurable parameters to force file rotation
aggregate_datasets(
repo_ids=[ds_0.repo_id, ds_1.repo_id],
@@ -450,6 +499,7 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
assert_video_timestamps_within_bounds(aggr_ds)
assert_depth_keys_preserved(aggr_ds, ds_0, ds_1)
assert_dataset_iteration_works(aggr_ds)
# Check that multiple files were actually created due to small size limits
@@ -469,7 +519,8 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
"""Regression test for video timestamp bug when merging datasets.
This test specifically checks that video timestamps are correctly calculated
and accumulated when merging multiple datasets.
and accumulated when merging multiple datasets. Depth video features are
included so depth timestamps are also covered by the regression.
"""
datasets = []
for i in range(3):
@@ -478,9 +529,13 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
repo_id=f"{DUMMY_REPO_ID}_regression_{i}",
total_episodes=2,
total_frames=100,
camera_features=CAMERA_FEATURES_WITH_DEPTH,
)
datasets.append(ds)
for i, ds in enumerate(datasets):
assert len(ds.meta.depth_keys) > 0, f"Dataset {i} should expose at least one depth key"
aggregate_datasets(
repo_ids=[ds.repo_id for ds in datasets],
roots=[ds.root for ds in datasets],
@@ -497,12 +552,21 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_regression_aggr", root=tmp_path / "regression_aggr")
assert_video_timestamps_within_bounds(aggr_ds)
# Depth keys must survive the merge for the regression to cover the
# ``is_depth_map`` decoding branch.
assert set(aggr_ds.meta.depth_keys) == set(datasets[0].meta.depth_keys)
depth_keys = set(aggr_ds.meta.depth_keys)
for i in range(len(aggr_ds)):
item = aggr_ds[i]
for key in aggr_ds.meta.video_keys:
assert key in item, f"Video key {key} missing from item {i}"
assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}"
# Depth frames are single-channel (1, H, W) after dequantization;
# standard RGB frames keep the 3-channel layout.
expected_channels = 1 if key in depth_keys else 3
assert item[key].shape[0] == expected_channels, (
f"Expected {expected_channels} channels for video key {key}, got {item[key].shape}"
)
def assert_image_schema_preserved(aggr_ds):
@@ -584,25 +648,31 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
ds_0_num_episodes = 2
ds_1_num_episodes = 3
# Create two image-based datasets (use_videos=False)
# Create two image-based datasets (use_videos=False) with a mix of RGB
# and depth-marked cameras so the depth path is exercised in image mode.
ds_0 = lerobot_dataset_factory(
root=tmp_path / "image_0",
repo_id=f"{DUMMY_REPO_ID}_image_0",
total_episodes=ds_0_num_episodes,
total_frames=ds_0_num_frames,
use_videos=False, # Image-based dataset
use_videos=False,
camera_features=CAMERA_FEATURES_WITH_DEPTH,
)
ds_1 = lerobot_dataset_factory(
root=tmp_path / "image_1",
repo_id=f"{DUMMY_REPO_ID}_image_1",
total_episodes=ds_1_num_episodes,
total_frames=ds_1_num_frames,
use_videos=False, # Image-based dataset
use_videos=False,
camera_features=CAMERA_FEATURES_WITH_DEPTH,
)
# Verify source datasets have image keys
assert len(ds_0.meta.image_keys) > 0, "ds_0 should have image keys"
assert len(ds_1.meta.image_keys) > 0, "ds_1 should have image keys"
# And that the depth marker actually made it onto an image feature.
assert len(ds_0.meta.depth_keys) > 0, "ds_0 should expose at least one depth key"
assert len(ds_1.meta.depth_keys) > 0, "ds_1 should expose at least one depth key"
# Aggregate the datasets
aggregate_datasets(
@@ -637,6 +707,7 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
# Image-specific assertions
assert_image_schema_preserved(aggr_ds)
assert_image_frames_integrity(aggr_ds, ds_0, ds_1)
assert_depth_keys_preserved(aggr_ds, ds_0, ds_1)
# Verify images can be accessed and have correct shape
sample_item = aggr_ds[0]
+5
View File
@@ -485,10 +485,14 @@ def lerobot_dataset_factory(
hf_dataset: datasets.Dataset | None = None,
data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB,
chunks_size: int = DEFAULT_CHUNK_SIZE,
camera_features: dict | None = None,
**kwargs,
) -> LeRobotDataset:
# Instantiate objects
if info is None:
info_kwargs = {}
if camera_features is not None:
info_kwargs["camera_features"] = camera_features
info = info_factory(
total_episodes=total_episodes,
total_frames=total_frames,
@@ -496,6 +500,7 @@ def lerobot_dataset_factory(
use_videos=use_videos,
data_files_size_in_mb=data_files_size_in_mb,
chunks_size=chunks_size,
**info_kwargs,
)
if stats is None:
stats = stats_factory(features=info.features)