mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
add trim_episode_start dataset edit operation
Add a dataset edit operation that trims the first N seconds from episodes while rebuilding frame indices, episode indices, and metadata consistently. Episodes that are too short to trim are skipped (left out of the output) and the remaining episodes are reindexed. Focused tests cover operation parsing and the metadata invariants.

Made-with: Cursor
This commit is contained in:
@@ -29,6 +29,7 @@ from lerobot.datasets.dataset_tools import (
|
||||
modify_tasks,
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
trim_episode_start,
|
||||
)
|
||||
from lerobot.scripts.lerobot_edit_dataset import convert_image_to_video_dataset
|
||||
|
||||
@@ -142,6 +143,104 @@ def test_delete_empty_list(sample_dataset, tmp_path):
|
||||
)
|
||||
|
||||
|
||||
def test_trim_episode_start_updates_indices(sample_dataset, tmp_path):
    """Check that trimming episode starts leaves indices, timestamps and metadata consistent.

    Trims 0.1 s from every episode of ``sample_dataset`` and verifies that:

    * the episode count is unchanged and the frame total matches the shortened episodes,
    * the global ``index`` column is contiguous starting at 0,
    * each episode's ``frame_index`` column restarts at 0 and its timestamps start at 0.0,
    * the per-episode metadata (``length``, ``dataset_from_index``, ``dataset_to_index``)
      lines up with the shortened episodes.

    NOTE(review): the expected length is derived from a hard-coded 10, so this assumes every
    episode of the ``sample_dataset`` fixture has 10 frames -- confirm against the fixture.
    """
    trimmed_root = tmp_path / "trimmed"
    seconds_to_drop = 0.1  # 3 frames at 30 FPS
    frames_to_drop = int(seconds_to_drop * sample_dataset.meta.fps)

    # Both helpers are replaced with mocks while the trimmed dataset is produced
    # (presumably so that building the returned dataset stays offline -- TODO confirm).
    version_patch = patch("lerobot.datasets.lerobot_dataset.get_safe_version")
    download_patch = patch("lerobot.datasets.lerobot_dataset.snapshot_download")
    with version_patch as fake_get_safe_version, download_patch as fake_snapshot_download:
        fake_get_safe_version.return_value = "v3.0"
        fake_snapshot_download.return_value = str(trimmed_root)

        trimmed = trim_episode_start(sample_dataset, seconds=seconds_to_drop, output_dir=trimmed_root)

    frames_per_episode = 10 - frames_to_drop
    episode_count = sample_dataset.meta.total_episodes
    assert trimmed.meta.total_episodes == episode_count
    assert trimmed.meta.total_frames == episode_count * frames_per_episode

    def read_column(name, cast):
        """Return one column of the trimmed dataset as a list of plain Python scalars."""
        return [cast(cell.item()) for cell in trimmed.hf_dataset[name]]

    # The global frame index must be contiguous over the whole dataset.
    assert read_column("index", int) == list(range(trimmed.meta.total_frames))

    episode_column = read_column("episode_index", int)
    frame_column = read_column("frame_index", int)
    time_column = read_column("timestamp", float)

    for episode in range(episode_count):
        own_frames = [
            frame for owner, frame in zip(episode_column, frame_column, strict=False) if owner == episode
        ]
        own_times = [
            stamp for owner, stamp in zip(episode_column, time_column, strict=False) if owner == episode
        ]

        # Frame indices restart at zero inside every episode.
        assert len(own_frames) == frames_per_episode
        assert own_frames == list(range(frames_per_episode))

        # Timestamps are re-based so the first kept frame sits at t = 0.
        first_time, last_time = own_times[0], own_times[-1]
        assert first_time == pytest.approx(0.0)
        assert last_time == pytest.approx((frames_per_episode - 1) / sample_dataset.meta.fps)

        # Episode metadata must describe back-to-back, equally sized row ranges.
        record = trimmed.meta.episodes[episode]
        first_row = episode * frames_per_episode
        assert int(record["length"]) == frames_per_episode
        assert int(record["dataset_from_index"]) == first_row
        assert int(record["dataset_to_index"]) == first_row + frames_per_episode
|
||||
|
||||
|
||||
def test_trim_episode_start_skips_too_short_episodes(tmp_path, empty_lerobot_dataset_factory):
    """Check that too-short episodes are left out and the surviving episodes are reindexed.

    Builds a three-episode dataset with lengths 10, 2 and 10, trims 0.1 s from the start of
    each episode, and expects only the two 10-frame episodes to remain, each 7 frames long
    and numbered 0 and 1.
    """
    source_root = tmp_path / "source"
    trimmed_root = tmp_path / "trimmed"

    feature_spec = {
        "action": {"dtype": "float32", "shape": (2,), "names": None},
        "observation.state": {"dtype": "float32", "shape": (2,), "names": None},
        "observation.images.top": {"dtype": "image", "shape": (32, 32, 3), "names": None},
    }
    dataset = empty_lerobot_dataset_factory(root=source_root, features=feature_spec)

    # The middle episode (2 frames) is shorter than the 3-frame trim below.
    for episode_length in (10, 2, 10):
        for _ in range(episode_length):
            frame = {
                "action": np.random.randn(2).astype(np.float32),
                "observation.state": np.random.randn(2).astype(np.float32),
                "observation.images.top": np.random.randint(0, 255, size=(32, 32, 3), dtype=np.uint8),
                "task": "task",
            }
            dataset.add_frame(frame)
        dataset.save_episode()
    dataset.finalize()

    seconds_to_drop = 0.1  # 3 frames at 30 FPS
    version_patch = patch("lerobot.datasets.lerobot_dataset.get_safe_version")
    download_patch = patch("lerobot.datasets.lerobot_dataset.snapshot_download")
    with version_patch as fake_get_safe_version, download_patch as fake_snapshot_download:
        fake_get_safe_version.return_value = "v3.0"
        fake_snapshot_download.return_value = str(trimmed_root)

        trimmed = trim_episode_start(dataset, seconds=seconds_to_drop, output_dir=trimmed_root)

    # Episode 1 is too short and gets skipped. Remaining episodes are trimmed and reindexed.
    assert trimmed.meta.total_episodes == 2
    assert trimmed.meta.total_frames == 14

    surviving_ids = {int(cell.item()) for cell in trimmed.hf_dataset["episode_index"]}
    assert sorted(surviving_ids) == [0, 1]

    lengths = [int(record["length"]) for record in trimmed.meta.episodes]
    assert lengths == [7, 7]
|
||||
|
||||
|
||||
def test_trim_episode_start_rejects_when_all_selected_are_too_short(sample_dataset, tmp_path):
    """Check that trimming raises ``ValueError`` when every selected episode would be skipped."""
    seconds_to_drop = 1.0  # 30 frames > 10-frame episodes
    destination = tmp_path / "trimmed"
    expected_message = "All episodes selected for trimming are too short"

    with pytest.raises(ValueError, match=expected_message):
        trim_episode_start(sample_dataset, seconds=seconds_to_drop, output_dir=destination)
|
||||
|
||||
|
||||
def test_split_by_episodes(sample_dataset, tmp_path):
|
||||
"""Test splitting dataset by specific episode indices."""
|
||||
splits = {
|
||||
|
||||
@@ -26,6 +26,7 @@ from lerobot.scripts.lerobot_edit_dataset import (
|
||||
OperationConfig,
|
||||
RemoveFeatureConfig,
|
||||
SplitConfig,
|
||||
TrimEpisodeStartConfig,
|
||||
)
|
||||
|
||||
|
||||
@@ -45,6 +46,7 @@ class TestOperationTypeParsing:
|
||||
("merge", MergeConfig),
|
||||
("remove_feature", RemoveFeatureConfig),
|
||||
("modify_tasks", ModifyTasksConfig),
|
||||
("trim_episode_start", TrimEpisodeStartConfig),
|
||||
("convert_image_to_video", ConvertImageToVideoConfig),
|
||||
],
|
||||
)
|
||||
@@ -62,6 +64,7 @@ class TestOperationTypeParsing:
|
||||
("merge", MergeConfig),
|
||||
("remove_feature", RemoveFeatureConfig),
|
||||
("modify_tasks", ModifyTasksConfig),
|
||||
("trim_episode_start", TrimEpisodeStartConfig),
|
||||
("convert_image_to_video", ConvertImageToVideoConfig),
|
||||
],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user