mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-26 20:57:28 +00:00
test(tests): adding tests
This commit is contained in:
@@ -1691,3 +1691,57 @@ def test_delta_timestamps_query_returns_correct_values(tmp_path, empty_lerobot_d
|
||||
# Previous frame is outside episode, so it's clamped to first frame and marked as padded
|
||||
assert state_values == [10.0, 10.0], f"Expected [10.0, 10.0], got {state_values}"
|
||||
assert is_pad == [True, False], f"Expected [True, False], got {is_pad}"
|
||||
|
||||
|
||||
def test_episode_filter_filters_dataset(tmp_path, lerobot_dataset_factory):
|
||||
"""episode_filter on LeRobotDataset narrows the loaded dataset to matching episodes."""
|
||||
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=8, total_frames=200)
|
||||
lengths = dataset.meta.episodes["length"]
|
||||
threshold = sorted(lengths)[len(lengths) // 2]
|
||||
expected_eps = [i for i, length in enumerate(lengths) if length >= threshold]
|
||||
expected_frames = sum(lengths[i] for i in expected_eps)
|
||||
|
||||
filtered = LeRobotDataset(
|
||||
dataset.repo_id,
|
||||
root=dataset.root,
|
||||
episode_filter=lambda ep: ep["length"] >= threshold,
|
||||
)
|
||||
|
||||
assert filtered.num_episodes == len(expected_eps)
|
||||
assert filtered.num_frames == expected_frames
|
||||
seen_eps = {filtered[i]["episode_index"].item() for i in range(len(filtered))}
|
||||
assert seen_eps == set(expected_eps)
|
||||
|
||||
|
||||
def test_episode_filter_intersects_with_episodes(tmp_path, lerobot_dataset_factory):
|
||||
"""When both episodes and episode_filter are given to LeRobotDataset, the result is their intersection."""
|
||||
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=8, total_frames=200)
|
||||
lengths = dataset.meta.episodes["length"]
|
||||
threshold = sorted(lengths)[len(lengths) // 2]
|
||||
candidates = [0, 2, 4, 6]
|
||||
expected_eps = [i for i in candidates if lengths[i] >= threshold]
|
||||
if not expected_eps:
|
||||
pytest.skip("multinomial draw produced no candidate above threshold")
|
||||
|
||||
filtered = LeRobotDataset(
|
||||
dataset.repo_id,
|
||||
root=dataset.root,
|
||||
episodes=candidates,
|
||||
episode_filter=lambda ep: ep["length"] >= threshold,
|
||||
)
|
||||
|
||||
assert filtered.num_episodes == len(expected_eps)
|
||||
seen_eps = {filtered[i]["episode_index"].item() for i in range(len(filtered))}
|
||||
assert seen_eps == set(expected_eps)
|
||||
|
||||
|
||||
def test_episode_filter_no_match_raises(tmp_path, lerobot_dataset_factory):
|
||||
"""An empty match in LeRobotDataset's episode_filter raises a ValueError rather than silently returning an empty dataset."""
|
||||
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=4, total_frames=100)
|
||||
|
||||
with pytest.raises(ValueError, match=r"episode_filter did not match any episode"):
|
||||
LeRobotDataset(
|
||||
dataset.repo_id,
|
||||
root=dataset.root,
|
||||
episode_filter=lambda ep: ep["length"] < 0,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user