add trim_episode_start dataset edit operation

Add a dataset edit operation to trim the first N seconds from episodes while rebuilding frame and episode indices and metadata consistently. Skip episodes that are too short to trim and cover parsing plus metadata invariants with focused tests.

Made-with: Cursor
This commit is contained in:
pepijn
2026-03-06 13:58:57 +00:00
parent 0394fae446
commit 6bbc24a991
4 changed files with 466 additions and 0 deletions
+99
View File
@@ -29,6 +29,7 @@ from lerobot.datasets.dataset_tools import (
modify_tasks,
remove_feature,
split_dataset,
trim_episode_start,
)
from lerobot.scripts.lerobot_edit_dataset import convert_image_to_video_dataset
@@ -142,6 +143,104 @@ def test_delete_empty_list(sample_dataset, tmp_path):
)
def test_trim_episode_start_updates_indices(sample_dataset, tmp_path):
    """Check that trimming episode starts leaves indices, timestamps and metadata coherent.

    Every episode of ``sample_dataset`` is trimmed by 0.1 s. The test then verifies that:
    the episode count is unchanged, the global ``index`` column is a contiguous range,
    each episode's ``frame_index`` restarts at 0, each episode's timestamps run from 0.0
    to ``(length - 1) / fps``, and the per-episode metadata (``length``,
    ``dataset_from_index``, ``dataset_to_index``) matches the trimmed layout.
    """
    output_dir = tmp_path / "trimmed"
    trim_seconds = 0.1  # 3 frames at 30 FPS
    trim_frames = int(trim_seconds * sample_dataset.meta.fps)

    # Hub access is stubbed out so that loading the trimmed dataset stays local.
    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(output_dir)

        new_dataset = trim_episode_start(sample_dataset, seconds=trim_seconds, output_dir=output_dir)

    def read_column(name, cast):
        # Materialise one dataset column as a list of plain Python scalars.
        return [cast(cell.item()) for cell in new_dataset.hf_dataset[name]]

    # The test expects every source episode to hold 10 frames before trimming.
    expected_length = 10 - trim_frames

    assert new_dataset.meta.total_episodes == sample_dataset.meta.total_episodes
    assert new_dataset.meta.total_frames == sample_dataset.meta.total_episodes * expected_length

    # The global index must be gap-free after the rebuild.
    indices = read_column("index", int)
    assert indices == list(range(new_dataset.meta.total_frames))

    episode_indices = read_column("episode_index", int)
    frame_indices = read_column("frame_index", int)
    timestamps = read_column("timestamp", float)

    for ep_idx in range(sample_dataset.meta.total_episodes):
        frames_in_episode = [
            frame for episode, frame in zip(episode_indices, frame_indices, strict=False) if episode == ep_idx
        ]
        stamps_in_episode = [
            stamp for episode, stamp in zip(episode_indices, timestamps, strict=False) if episode == ep_idx
        ]

        # Frame indices restart from zero inside each trimmed episode.
        assert len(frames_in_episode) == expected_length
        assert frames_in_episode == list(range(expected_length))

        # Timestamps are re-based so the first kept frame sits at t = 0.
        assert stamps_in_episode[0] == pytest.approx(0.0)
        assert stamps_in_episode[-1] == pytest.approx((expected_length - 1) / sample_dataset.meta.fps)

        # Episode metadata must describe the same contiguous slices.
        ep_meta = new_dataset.meta.episodes[ep_idx]
        assert int(ep_meta["length"]) == expected_length
        assert int(ep_meta["dataset_from_index"]) == ep_idx * expected_length
        assert int(ep_meta["dataset_to_index"]) == (ep_idx + 1) * expected_length
def test_trim_episode_start_skips_too_short_episodes(tmp_path, empty_lerobot_dataset_factory):
    """Check that episodes too short to trim are skipped and the rest are reindexed.

    A source dataset with episode lengths 10, 2 and 10 is built, then trimmed by 0.1 s.
    The test expects two output episodes of 7 frames each, numbered 0 and 1.
    """
    features = {
        "action": {"dtype": "float32", "shape": (2,), "names": None},
        "observation.state": {"dtype": "float32", "shape": (2,), "names": None},
        "observation.images.top": {"dtype": "image", "shape": (32, 32, 3), "names": None},
    }
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "source", features=features)

    # Three episodes; the middle one is deliberately shorter than the trim window.
    episode_lengths = [10, 2, 10]
    for ep_len in episode_lengths:
        for _ in range(ep_len):
            frame = {
                "action": np.random.randn(2).astype(np.float32),
                "observation.state": np.random.randn(2).astype(np.float32),
                "observation.images.top": np.random.randint(0, 255, size=(32, 32, 3), dtype=np.uint8),
                "task": "task",
            }
            dataset.add_frame(frame)
        dataset.save_episode()
    dataset.finalize()

    trim_seconds = 0.1  # 3 frames at 30 FPS
    trimmed_root = tmp_path / "trimmed"

    # Hub access is stubbed out so that loading the trimmed dataset stays local.
    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(trimmed_root)

        new_dataset = trim_episode_start(dataset, seconds=trim_seconds, output_dir=trimmed_root)

    # Episode 1 is too short and gets skipped. Remaining episodes are trimmed and reindexed.
    assert new_dataset.meta.total_episodes == 2
    assert new_dataset.meta.total_frames == 14

    seen_episode_ids = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}
    assert sorted(seen_episode_ids) == [0, 1]

    episode_lengths_after = [int(ep["length"]) for ep in new_dataset.meta.episodes]
    assert episode_lengths_after == [7, 7]
def test_trim_episode_start_rejects_when_all_selected_are_too_short(sample_dataset, tmp_path):
    """Check that trimming raises ``ValueError`` when no episode is long enough to trim."""
    trim_seconds = 1.0  # 30 frames > 10-frame episodes
    output_dir = tmp_path / "trimmed"

    with pytest.raises(ValueError, match="All episodes selected for trimming are too short"):
        trim_episode_start(sample_dataset, seconds=trim_seconds, output_dir=output_dir)
def test_split_by_episodes(sample_dataset, tmp_path):
"""Test splitting dataset by specific episode indices."""
splits = {
@@ -26,6 +26,7 @@ from lerobot.scripts.lerobot_edit_dataset import (
OperationConfig,
RemoveFeatureConfig,
SplitConfig,
TrimEpisodeStartConfig,
)
@@ -45,6 +46,7 @@ class TestOperationTypeParsing:
("merge", MergeConfig),
("remove_feature", RemoveFeatureConfig),
("modify_tasks", ModifyTasksConfig),
("trim_episode_start", TrimEpisodeStartConfig),
("convert_image_to_video", ConvertImageToVideoConfig),
],
)
@@ -62,6 +64,7 @@ class TestOperationTypeParsing:
("merge", MergeConfig),
("remove_feature", RemoveFeatureConfig),
("modify_tasks", ModifyTasksConfig),
("trim_episode_start", TrimEpisodeStartConfig),
("convert_image_to_video", ConvertImageToVideoConfig),
],
)