mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 02:59:50 +00:00
feat(episodes): adding support for metadata based episodes filtering (#3530)
* feat(episode filtering): adding support for episodes filtering at initialization time in LeRobotDataset * test(tests): adding tests * chore(format): formatting code * feat(performance): improving implementation for better performances on big datasets * chores(warning): improving warnings and errors for episodes filtering * test(invalid key): adding test for invalid filtering key * chore(format): formatting code
This commit is contained in:
@@ -1691,3 +1691,68 @@ def test_delta_timestamps_query_returns_correct_values(tmp_path, empty_lerobot_d
|
||||
# Previous frame is outside episode, so it's clamped to first frame and marked as padded
|
||||
assert state_values == [10.0, 10.0], f"Expected [10.0, 10.0], got {state_values}"
|
||||
assert is_pad == [True, False], f"Expected [True, False], got {is_pad}"
|
||||
|
||||
|
||||
def test_episode_filter_filters_dataset(tmp_path, lerobot_dataset_factory):
|
||||
"""episode_filter on LeRobotDataset narrows the loaded dataset to matching episodes."""
|
||||
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=8, total_frames=200)
|
||||
lengths = dataset.meta.episodes["length"]
|
||||
threshold = sorted(lengths)[len(lengths) // 2]
|
||||
expected_eps = [i for i, length in enumerate(lengths) if length >= threshold]
|
||||
expected_frames = sum(lengths[i] for i in expected_eps)
|
||||
|
||||
filtered = LeRobotDataset(
|
||||
dataset.repo_id,
|
||||
root=dataset.root,
|
||||
episode_filter=lambda ep: ep["length"] >= threshold,
|
||||
)
|
||||
|
||||
assert filtered.num_episodes == len(expected_eps)
|
||||
assert filtered.num_frames == expected_frames
|
||||
seen_eps = {filtered[i]["episode_index"].item() for i in range(len(filtered))}
|
||||
assert seen_eps == set(expected_eps)
|
||||
|
||||
|
||||
def test_episode_filter_intersects_with_episodes(tmp_path, lerobot_dataset_factory):
|
||||
"""When both episodes and episode_filter are given to LeRobotDataset, the result is their intersection."""
|
||||
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=8, total_frames=200)
|
||||
lengths = dataset.meta.episodes["length"]
|
||||
candidates = [0, 2, 4, 6]
|
||||
candidate_lengths = [lengths[i] for i in candidates]
|
||||
threshold = sorted(candidate_lengths)[len(candidate_lengths) // 2]
|
||||
expected_eps = [i for i in candidates if lengths[i] >= threshold]
|
||||
|
||||
filtered = LeRobotDataset(
|
||||
dataset.repo_id,
|
||||
root=dataset.root,
|
||||
episodes=candidates,
|
||||
episode_filter=lambda ep: ep["length"] >= threshold,
|
||||
)
|
||||
|
||||
assert filtered.num_episodes == len(expected_eps)
|
||||
seen_eps = {filtered[i]["episode_index"].item() for i in range(len(filtered))}
|
||||
assert seen_eps == set(expected_eps)
|
||||
|
||||
|
||||
def test_episode_filter_no_match_raises(tmp_path, lerobot_dataset_factory):
|
||||
"""An empty match in LeRobotDataset's episode_filter raises a ValueError rather than silently returning an empty dataset."""
|
||||
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=4, total_frames=100)
|
||||
|
||||
with pytest.raises(ValueError, match=r"The episode filter did not match any episode"):
|
||||
LeRobotDataset(
|
||||
dataset.repo_id,
|
||||
root=dataset.root,
|
||||
episode_filter=lambda ep: ep["length"] < 0,
|
||||
)
|
||||
|
||||
|
||||
def test_episode_filter_unknown_key_raises(tmp_path, lerobot_dataset_factory):
|
||||
"""A predicate referencing a column absent from meta.episodes surfaces a clear KeyError."""
|
||||
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=4, total_frames=100)
|
||||
|
||||
with pytest.raises(KeyError, match="not_a_real_field"):
|
||||
LeRobotDataset(
|
||||
dataset.repo_id,
|
||||
root=dataset.root,
|
||||
episode_filter=lambda ep: ep["not_a_real_field"] > 0,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user