From bec69f7a9f116202de0054478484790bbdefed62 Mon Sep 17 00:00:00 2001
From: CarolinePascal
Date: Thu, 7 May 2026 15:52:19 +0200
Subject: [PATCH] test(tests): adding tests

Add three tests for the `episode_filter` argument of `LeRobotDataset`.
Each test builds a dataset with `lerobot_dataset_factory` under
`tmp_path / "test"` and reloads it through `LeRobotDataset` using the
same `repo_id` and `root`.

- test_episode_filter_filters_dataset: filters on
  `ep["length"] >= threshold`, where the threshold is
  `sorted(lengths)[len(lengths) // 2]`. Because the threshold is one of
  the episode lengths, at least one episode always matches. Checks
  `num_episodes`, `num_frames` and the set of `episode_index` values
  read back from every frame.

- test_episode_filter_intersects_with_episodes: passes both
  `episodes=[0, 2, 4, 6]` and the same length predicate and expects
  only the candidates that also satisfy the predicate. The test is
  skipped when no candidate reaches the threshold, since the expected
  set would then be empty.

- test_episode_filter_no_match_raises: a predicate that matches nothing
  (`ep["length"] < 0`) must raise
  `ValueError("episode_filter did not match any episode")` rather than
  return an empty dataset.
---
 tests/datasets/test_datasets.py | 54 +++++++++++++++++++++++++++++++++
 1 file changed, 54 insertions(+)

diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py
index 3e1a17a62..38ef79720 100644
--- a/tests/datasets/test_datasets.py
+++ b/tests/datasets/test_datasets.py
@@ -1691,3 +1691,57 @@ def test_delta_timestamps_query_returns_correct_values(tmp_path, empty_lerobot_d
     # Previous frame is outside episode, so it's clamped to first frame and marked as padded
     assert state_values == [10.0, 10.0], f"Expected [10.0, 10.0], got {state_values}"
     assert is_pad == [True, False], f"Expected [True, False], got {is_pad}"
+
+
+def test_episode_filter_filters_dataset(tmp_path, lerobot_dataset_factory):
+    """episode_filter on LeRobotDataset narrows the loaded dataset to matching episodes."""
+    dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=8, total_frames=200)
+    lengths = dataset.meta.episodes["length"]
+    threshold = sorted(lengths)[len(lengths) // 2]
+    expected_eps = [i for i, length in enumerate(lengths) if length >= threshold]
+    expected_frames = sum(lengths[i] for i in expected_eps)
+
+    filtered = LeRobotDataset(
+        dataset.repo_id,
+        root=dataset.root,
+        episode_filter=lambda ep: ep["length"] >= threshold,
+    )
+
+    assert filtered.num_episodes == len(expected_eps)
+    assert filtered.num_frames == expected_frames
+    seen_eps = {filtered[i]["episode_index"].item() for i in range(len(filtered))}
+    assert seen_eps == set(expected_eps)
+
+
+def test_episode_filter_intersects_with_episodes(tmp_path, lerobot_dataset_factory):
+    """When both episodes and episode_filter are given to LeRobotDataset, the result is their intersection."""
+    dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=8, total_frames=200)
+    lengths = dataset.meta.episodes["length"]
+    threshold = sorted(lengths)[len(lengths) // 2]
+    candidates = [0, 2, 4, 6]
+    expected_eps = [i for i in candidates if lengths[i] >= threshold]
+    if not expected_eps:
+        pytest.skip("multinomial draw produced no candidate above threshold")
+
+    filtered = LeRobotDataset(
+        dataset.repo_id,
+        root=dataset.root,
+        episodes=candidates,
+        episode_filter=lambda ep: ep["length"] >= threshold,
+    )
+
+    assert filtered.num_episodes == len(expected_eps)
+    seen_eps = {filtered[i]["episode_index"].item() for i in range(len(filtered))}
+    assert seen_eps == set(expected_eps)
+
+
+def test_episode_filter_no_match_raises(tmp_path, lerobot_dataset_factory):
+    """An empty match in LeRobotDataset's episode_filter raises a ValueError rather than silently returning an empty dataset."""
+    dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=4, total_frames=100)
+
+    with pytest.raises(ValueError, match=r"episode_filter did not match any episode"):
+        LeRobotDataset(
+            dataset.repo_id,
+            root=dataset.root,
+            episode_filter=lambda ep: ep["length"] < 0,
+        )