mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
test(aggregate): extending aggregation tests to depth frames
This commit is contained in:
@@ -29,7 +29,13 @@ from lerobot.configs import VIDEO_ENCODER_INFO_KEYS
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.datasets.feature_utils import features_equal_for_merge
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
from tests.fixtures.constants import (
|
||||
DUMMY_CAMERA_FEATURES,
|
||||
DUMMY_DEPTH_CAMERA_FEATURES,
|
||||
DUMMY_REPO_ID,
|
||||
)
|
||||
|
||||
CAMERA_FEATURES_WITH_DEPTH = {**DUMMY_CAMERA_FEATURES, **DUMMY_DEPTH_CAMERA_FEATURES}
|
||||
|
||||
|
||||
def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames):
|
||||
@@ -191,6 +197,28 @@ def assert_dataset_iteration_works(aggr_ds):
|
||||
pass
|
||||
|
||||
|
||||
def assert_depth_keys_preserved(aggr_ds, ds_0, ds_1):
|
||||
"""Test that depth keys are correctly preserved after aggregation.
|
||||
|
||||
Ensures that the ``is_depth_map`` marker on visual features survives
|
||||
aggregation, so that downstream consumers (e.g. the dataset reader's
|
||||
depth decoding path) keep working on the merged dataset.
|
||||
"""
|
||||
expected_depth_keys = set(ds_0.meta.depth_keys)
|
||||
assert expected_depth_keys == set(ds_1.meta.depth_keys), (
|
||||
"Source datasets disagree on depth_keys; test setup is inconsistent"
|
||||
)
|
||||
actual_depth_keys = set(aggr_ds.meta.depth_keys)
|
||||
assert actual_depth_keys == expected_depth_keys, (
|
||||
f"Expected depth_keys {expected_depth_keys}, got {actual_depth_keys}"
|
||||
)
|
||||
for key in expected_depth_keys:
|
||||
info = aggr_ds.meta.info.features[key].get("info") or {}
|
||||
assert info.get("is_depth_map") is True, (
|
||||
f"Depth marker lost on feature {key!r} after aggregation"
|
||||
)
|
||||
|
||||
|
||||
def assert_video_timestamps_within_bounds(aggr_ds):
|
||||
"""Test that all video timestamps are within valid bounds for their respective video files.
|
||||
|
||||
@@ -240,7 +268,11 @@ def assert_video_timestamps_within_bounds(aggr_ds):
|
||||
|
||||
|
||||
def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
||||
"""Test basic aggregation functionality with standard parameters."""
|
||||
"""Test basic aggregation functionality with standard parameters.
|
||||
|
||||
Source datasets include both RGB and depth video features so the same
|
||||
aggregation flow is exercised on the ``is_depth_map`` branch.
|
||||
"""
|
||||
ds_0_num_frames = 400
|
||||
ds_1_num_frames = 800
|
||||
ds_0_num_episodes = 10
|
||||
@@ -252,14 +284,21 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
||||
repo_id=f"{DUMMY_REPO_ID}_0",
|
||||
total_episodes=ds_0_num_episodes,
|
||||
total_frames=ds_0_num_frames,
|
||||
camera_features=CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
ds_1 = lerobot_dataset_factory(
|
||||
root=tmp_path / "test_1",
|
||||
repo_id=f"{DUMMY_REPO_ID}_1",
|
||||
total_episodes=ds_1_num_episodes,
|
||||
total_frames=ds_1_num_frames,
|
||||
camera_features=CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
|
||||
# Confirm depth was actually wired into the source datasets so the
|
||||
# rest of the assertions exercise the depth aggregation path.
|
||||
assert len(ds_0.meta.depth_keys) > 0, "ds_0 should expose at least one depth key"
|
||||
assert len(ds_1.meta.depth_keys) > 0, "ds_1 should expose at least one depth key"
|
||||
|
||||
aggregate_datasets(
|
||||
repo_ids=[ds_0.repo_id, ds_1.repo_id],
|
||||
roots=[ds_0.root, ds_1.root],
|
||||
@@ -286,6 +325,7 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
||||
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
|
||||
assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
|
||||
assert_video_timestamps_within_bounds(aggr_ds)
|
||||
assert_depth_keys_preserved(aggr_ds, ds_0, ds_1)
|
||||
assert_dataset_iteration_works(aggr_ds)
|
||||
|
||||
|
||||
@@ -403,7 +443,11 @@ def test_aggregate_incomplete_video_encoder_info_warns_and_nuls_encoders(
|
||||
|
||||
|
||||
def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
|
||||
"""Test aggregation with small file size limits to force file rotation/sharding."""
|
||||
"""Test aggregation with small file size limits to force file rotation/sharding.
|
||||
|
||||
Depth video features are included to verify that file rotation/concat
|
||||
correctly handles depth-marked features alongside regular RGB ones.
|
||||
"""
|
||||
ds_0_num_episodes = ds_1_num_episodes = 10
|
||||
ds_0_num_frames = ds_1_num_frames = 400
|
||||
|
||||
@@ -412,14 +456,19 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
|
||||
repo_id=f"{DUMMY_REPO_ID}_small_0",
|
||||
total_episodes=ds_0_num_episodes,
|
||||
total_frames=ds_0_num_frames,
|
||||
camera_features=CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
ds_1 = lerobot_dataset_factory(
|
||||
root=tmp_path / "small_1",
|
||||
repo_id=f"{DUMMY_REPO_ID}_small_1",
|
||||
total_episodes=ds_1_num_episodes,
|
||||
total_frames=ds_1_num_frames,
|
||||
camera_features=CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
|
||||
assert len(ds_0.meta.depth_keys) > 0, "ds_0 should expose at least one depth key"
|
||||
assert len(ds_1.meta.depth_keys) > 0, "ds_1 should expose at least one depth key"
|
||||
|
||||
# Use the new configurable parameters to force file rotation
|
||||
aggregate_datasets(
|
||||
repo_ids=[ds_0.repo_id, ds_1.repo_id],
|
||||
@@ -450,6 +499,7 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
|
||||
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
|
||||
assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
|
||||
assert_video_timestamps_within_bounds(aggr_ds)
|
||||
assert_depth_keys_preserved(aggr_ds, ds_0, ds_1)
|
||||
assert_dataset_iteration_works(aggr_ds)
|
||||
|
||||
# Check that multiple files were actually created due to small size limits
|
||||
@@ -469,7 +519,8 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
|
||||
"""Regression test for video timestamp bug when merging datasets.
|
||||
|
||||
This test specifically checks that video timestamps are correctly calculated
|
||||
and accumulated when merging multiple datasets.
|
||||
and accumulated when merging multiple datasets. Depth video features are
|
||||
included so depth timestamps are also covered by the regression.
|
||||
"""
|
||||
datasets = []
|
||||
for i in range(3):
|
||||
@@ -478,9 +529,13 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
|
||||
repo_id=f"{DUMMY_REPO_ID}_regression_{i}",
|
||||
total_episodes=2,
|
||||
total_frames=100,
|
||||
camera_features=CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
datasets.append(ds)
|
||||
|
||||
for i, ds in enumerate(datasets):
|
||||
assert len(ds.meta.depth_keys) > 0, f"Dataset {i} should expose at least one depth key"
|
||||
|
||||
aggregate_datasets(
|
||||
repo_ids=[ds.repo_id for ds in datasets],
|
||||
roots=[ds.root for ds in datasets],
|
||||
@@ -497,12 +552,21 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
|
||||
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_regression_aggr", root=tmp_path / "regression_aggr")
|
||||
|
||||
assert_video_timestamps_within_bounds(aggr_ds)
|
||||
# Depth keys must survive the merge for the regression to cover the
|
||||
# ``is_depth_map`` decoding branch.
|
||||
assert set(aggr_ds.meta.depth_keys) == set(datasets[0].meta.depth_keys)
|
||||
|
||||
depth_keys = set(aggr_ds.meta.depth_keys)
|
||||
for i in range(len(aggr_ds)):
|
||||
item = aggr_ds[i]
|
||||
for key in aggr_ds.meta.video_keys:
|
||||
assert key in item, f"Video key {key} missing from item {i}"
|
||||
assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}"
|
||||
# Depth frames are single-channel (1, H, W) after dequantization;
|
||||
# standard RGB frames keep the 3-channel layout.
|
||||
expected_channels = 1 if key in depth_keys else 3
|
||||
assert item[key].shape[0] == expected_channels, (
|
||||
f"Expected {expected_channels} channels for video key {key}, got {item[key].shape}"
|
||||
)
|
||||
|
||||
|
||||
def assert_image_schema_preserved(aggr_ds):
|
||||
@@ -584,25 +648,31 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
|
||||
ds_0_num_episodes = 2
|
||||
ds_1_num_episodes = 3
|
||||
|
||||
# Create two image-based datasets (use_videos=False)
|
||||
# Create two image-based datasets (use_videos=False) with a mix of RGB
|
||||
# and depth-marked cameras so the depth path is exercised in image mode.
|
||||
ds_0 = lerobot_dataset_factory(
|
||||
root=tmp_path / "image_0",
|
||||
repo_id=f"{DUMMY_REPO_ID}_image_0",
|
||||
total_episodes=ds_0_num_episodes,
|
||||
total_frames=ds_0_num_frames,
|
||||
use_videos=False, # Image-based dataset
|
||||
use_videos=False,
|
||||
camera_features=CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
ds_1 = lerobot_dataset_factory(
|
||||
root=tmp_path / "image_1",
|
||||
repo_id=f"{DUMMY_REPO_ID}_image_1",
|
||||
total_episodes=ds_1_num_episodes,
|
||||
total_frames=ds_1_num_frames,
|
||||
use_videos=False, # Image-based dataset
|
||||
use_videos=False,
|
||||
camera_features=CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
|
||||
# Verify source datasets have image keys
|
||||
assert len(ds_0.meta.image_keys) > 0, "ds_0 should have image keys"
|
||||
assert len(ds_1.meta.image_keys) > 0, "ds_1 should have image keys"
|
||||
# And that the depth marker actually made it onto an image feature.
|
||||
assert len(ds_0.meta.depth_keys) > 0, "ds_0 should expose at least one depth key"
|
||||
assert len(ds_1.meta.depth_keys) > 0, "ds_1 should expose at least one depth key"
|
||||
|
||||
# Aggregate the datasets
|
||||
aggregate_datasets(
|
||||
@@ -637,6 +707,7 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
|
||||
# Image-specific assertions
|
||||
assert_image_schema_preserved(aggr_ds)
|
||||
assert_image_frames_integrity(aggr_ds, ds_0, ds_1)
|
||||
assert_depth_keys_preserved(aggr_ds, ds_0, ds_1)
|
||||
|
||||
# Verify images can be accessed and have correct shape
|
||||
sample_item = aggr_ds[0]
|
||||
|
||||
Vendored
+5
@@ -485,10 +485,14 @@ def lerobot_dataset_factory(
|
||||
hf_dataset: datasets.Dataset | None = None,
|
||||
data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
||||
camera_features: dict | None = None,
|
||||
**kwargs,
|
||||
) -> LeRobotDataset:
|
||||
# Instantiate objects
|
||||
if info is None:
|
||||
info_kwargs = {}
|
||||
if camera_features is not None:
|
||||
info_kwargs["camera_features"] = camera_features
|
||||
info = info_factory(
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
@@ -496,6 +500,7 @@ def lerobot_dataset_factory(
|
||||
use_videos=use_videos,
|
||||
data_files_size_in_mb=data_files_size_in_mb,
|
||||
chunks_size=chunks_size,
|
||||
**info_kwargs,
|
||||
)
|
||||
if stats is None:
|
||||
stats = stats_factory(features=info.features)
|
||||
|
||||
Reference in New Issue
Block a user