test(tests): adding tests

This commit is contained in:
CarolinePascal
2026-05-07 15:52:19 +02:00
parent 9e1fb1c2dd
commit bec69f7a9f
+54
View File
@@ -1691,3 +1691,57 @@ def test_delta_timestamps_query_returns_correct_values(tmp_path, empty_lerobot_d
# Previous frame is outside episode, so it's clamped to first frame and marked as padded
assert state_values == [10.0, 10.0], f"Expected [10.0, 10.0], got {state_values}"
assert is_pad == [True, False], f"Expected [True, False], got {is_pad}"
def test_episode_filter_filters_dataset(tmp_path, lerobot_dataset_factory):
"""episode_filter on LeRobotDataset narrows the loaded dataset to matching episodes."""
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=8, total_frames=200)
lengths = dataset.meta.episodes["length"]
threshold = sorted(lengths)[len(lengths) // 2]
expected_eps = [i for i, length in enumerate(lengths) if length >= threshold]
expected_frames = sum(lengths[i] for i in expected_eps)
filtered = LeRobotDataset(
dataset.repo_id,
root=dataset.root,
episode_filter=lambda ep: ep["length"] >= threshold,
)
assert filtered.num_episodes == len(expected_eps)
assert filtered.num_frames == expected_frames
seen_eps = {filtered[i]["episode_index"].item() for i in range(len(filtered))}
assert seen_eps == set(expected_eps)
def test_episode_filter_intersects_with_episodes(tmp_path, lerobot_dataset_factory):
"""When both episodes and episode_filter are given to LeRobotDataset, the result is their intersection."""
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=8, total_frames=200)
lengths = dataset.meta.episodes["length"]
threshold = sorted(lengths)[len(lengths) // 2]
candidates = [0, 2, 4, 6]
expected_eps = [i for i in candidates if lengths[i] >= threshold]
if not expected_eps:
pytest.skip("multinomial draw produced no candidate above threshold")
filtered = LeRobotDataset(
dataset.repo_id,
root=dataset.root,
episodes=candidates,
episode_filter=lambda ep: ep["length"] >= threshold,
)
assert filtered.num_episodes == len(expected_eps)
seen_eps = {filtered[i]["episode_index"].item() for i in range(len(filtered))}
assert seen_eps == set(expected_eps)
def test_episode_filter_no_match_raises(tmp_path, lerobot_dataset_factory):
"""An empty match in LeRobotDataset's episode_filter raises a ValueError rather than silently returning an empty dataset."""
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=4, total_frames=100)
with pytest.raises(ValueError, match=r"episode_filter did not match any episode"):
LeRobotDataset(
dataset.repo_id,
root=dataset.root,
episode_filter=lambda ep: ep["length"] < 0,
)