mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 02:59:50 +00:00
Dataset tools (#2100)
* feat(dataset-tools): add dataset utilities and example script - Introduced dataset tools for LeRobotDataset, including functions for deleting episodes, splitting datasets, adding/removing features, and merging datasets. - Added an example script demonstrating the usage of these utilities. - Implemented comprehensive tests for all new functionalities to ensure reliability and correctness. * style fixes * move example to dataset dir * missing license * fixes mostly path * clean comments * move tests to functions instead of class based * - fix video editing, decode, delete frames and re-encode video - copy unchanged video and parquet files to avoid recreating the entire dataset * Fortify tooling tests * Fix type issue resulting from saving numpy arrays with shape 3,1,1 * added lerobot_edit_dataset * - revert changes in examples - remove hardcoded split names * update comment * fix comment add lerobot-edit-dataset shortcut * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Michel Aractingi <michel.aractingi@huggingface.co> * style nit after copilot review * fix: bug in dataset root when editing the dataset in place (without setting new_repo_id) * Fix bug in aggregate.py when accumulating video timestamps; add tests to fortify aggregate videos * Added missing output repo id * migrate delete episode to using pyav instead of decoding, writing frames to disk and encoding again. 
Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com> * added modified suffix in case repo_id is not set in delete_episode * adding docs for dataset tools * bump av version and add back time_base assignment * linter * modified push_to_hub logic in lerobot_edit_dataset * fix(progress bar): fixing the progress bar issue in dataset tools * chore(concatenate): removing no longer needed concatenate_datasets usage * fix(file sizes forwarding): forwarding files and chunk sizes in metadata info when splitting and aggregating datasets * style fix * refactor(aggregate): Fix video indexing and timestamp bugs in dataset merging There were three critical bugs in aggregate.py that prevented correct dataset merging: 1. Video file indices: Changed from += to = assignment to correctly reference merged video files 2. Video timestamps: Implemented per-source-file offset tracking to maintain continuous timestamps when merging split datasets (was causing non-monotonic timestamp warnings) 3. File rotation offsets: Store timestamp offsets after rotation decision to prevent out-of-bounds frame access (was causing "Invalid frame index" errors with small file size limits) Changes: - Updated update_meta_data() to apply per-source-file timestamp offsets - Updated aggregate_videos() to track offsets correctly during file rotation - Added get_video_duration_in_s import for duration calculation * Improved docs for split dataset and added a check for the possible case that the split size results in zero episodes * chore(docs): update merge documentation details Signed-off-by: Steven Palma <imstevenpmwork@ieee.org> --------- Co-authored-by: CarolinePascal <caroline8.pascal@gmail.com> Co-authored-by: Jack Vial <vialjack@gmail.com> Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
@@ -181,6 +181,54 @@ def assert_dataset_iteration_works(aggr_ds):
|
||||
pass
|
||||
|
||||
|
||||
def assert_video_timestamps_within_bounds(aggr_ds):
    """Check that every episode's video timestamps map to frames that exist in its video file.

    For each episode and each video key, the episode's ``from_timestamp`` and
    ``to_timestamp`` metadata are converted to frame indices with ``aggr_ds.fps``
    and compared against the frame count reported by ``VideoDecoder`` for the
    corresponding file. This catches bugs where timestamps point to frames
    beyond the actual video length, which would cause "Invalid frame index"
    errors during data loading.

    The whole check is skipped when ``torchcodec`` cannot be imported, and
    video files that do not exist on disk are skipped individually.

    Args:
        aggr_ds: Dataset exposing ``num_episodes``, ``fps``, ``root`` and
            ``meta`` (with ``episodes``, ``video_keys`` and
            ``get_video_file_path``).

    Raises:
        AssertionError: If a video file cannot be opened, or if an episode's
            frame range starts below zero, is empty, or exceeds the number of
            frames in the video file.
    """
    try:
        from torchcodec.decoders import VideoDecoder
    except ImportError:
        # Best-effort check: without torchcodec the frame count cannot be read.
        return

    # Frame count per video file, so a file referenced by several episodes
    # is only opened once.
    frame_counts: dict[str, int] = {}

    for ep_idx in range(aggr_ds.num_episodes):
        ep = aggr_ds.meta.episodes[ep_idx]

        for vid_key in aggr_ds.meta.video_keys:
            from_ts = ep[f"videos/{vid_key}/from_timestamp"]
            to_ts = ep[f"videos/{vid_key}/to_timestamp"]
            video_path = aggr_ds.root / aggr_ds.meta.get_video_file_path(ep_idx, vid_key)

            if not video_path.exists():
                continue

            from_frame_idx = round(from_ts * aggr_ds.fps)
            to_frame_idx = round(to_ts * aggr_ds.fps)

            path_key = str(video_path)
            if path_key not in frame_counts:
                # Only opening/probing the file is guarded; the bounds checks
                # below report their own, more specific messages.
                try:
                    frame_counts[path_key] = len(VideoDecoder(path_key))
                except Exception as e:
                    raise AssertionError(
                        f"Failed to verify timestamps for episode {ep_idx}, {vid_key}: {e}"
                    ) from e
            num_frames = frame_counts[path_key]

            # Verify timestamps don't exceed video bounds
            assert from_frame_idx >= 0, (
                f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) < 0"
            )
            assert from_frame_idx < num_frames, (
                f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) >= video frames ({num_frames})"
            )
            # to_frame_idx is allowed to equal num_frames (exclusive upper bound).
            assert to_frame_idx <= num_frames, (
                f"Episode {ep_idx}, {vid_key}: to_frame_idx ({to_frame_idx}) > video frames ({num_frames})"
            )
            assert from_frame_idx < to_frame_idx, (
                f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) >= to_frame_idx ({to_frame_idx})"
            )
|
||||
|
||||
|
||||
def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
||||
"""Test basic aggregation functionality with standard parameters."""
|
||||
ds_0_num_frames = 400
|
||||
@@ -227,6 +275,7 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
||||
assert_metadata_consistency(aggr_ds, ds_0, ds_1)
|
||||
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
|
||||
assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
|
||||
assert_video_timestamps_within_bounds(aggr_ds)
|
||||
assert_dataset_iteration_works(aggr_ds)
|
||||
|
||||
|
||||
@@ -277,6 +326,7 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
|
||||
assert_metadata_consistency(aggr_ds, ds_0, ds_1)
|
||||
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
|
||||
assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
|
||||
assert_video_timestamps_within_bounds(aggr_ds)
|
||||
assert_dataset_iteration_works(aggr_ds)
|
||||
|
||||
# Check that multiple files were actually created due to small size limits
|
||||
@@ -290,3 +340,43 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
|
||||
if video_dir.exists():
|
||||
video_files = list(video_dir.rglob("*.mp4"))
|
||||
assert len(video_files) > 1, "Small file size limits should create multiple video files"
|
||||
|
||||
|
||||
def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
    """Guard against the video timestamp bug seen when merging datasets.

    Three small source datasets are created and aggregated, the result is
    reloaded, and the test then checks that video timestamps stay within the
    bounds of their files and that every item carries each video key with a
    leading dimension of 3.
    """
    merged_repo_id = f"{DUMMY_REPO_ID}_regression_aggr"
    merged_root = tmp_path / "regression_aggr"

    sources = [
        lerobot_dataset_factory(
            root=tmp_path / f"regression_{idx}",
            repo_id=f"{DUMMY_REPO_ID}_regression_{idx}",
            total_episodes=2,
            total_frames=100,
        )
        for idx in range(3)
    ]

    aggregate_datasets(
        repo_ids=[src.repo_id for src in sources],
        roots=[src.root for src in sources],
        aggr_repo_id=merged_repo_id,
        aggr_root=merged_root,
    )

    # Stub the version lookup and the snapshot download while the merged
    # dataset is loaded from its local root.
    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version", return_value="v3.0"),
        patch("lerobot.datasets.lerobot_dataset.snapshot_download", return_value=str(merged_root)),
    ):
        aggr_ds = LeRobotDataset(merged_repo_id, root=merged_root)

    assert_video_timestamps_within_bounds(aggr_ds)

    # Every item must be retrievable and expose each video stream.
    for sample_idx in range(len(aggr_ds)):
        sample = aggr_ds[sample_idx]
        for cam_key in aggr_ds.meta.video_keys:
            assert cam_key in sample, f"Video key {cam_key} missing from item {sample_idx}"
            assert sample[cam_key].shape[0] == 3, f"Expected 3 channels for video key {cam_key}"
|
||||
|
||||
Reference in New Issue
Block a user