chore(format): formatting code

This commit is contained in:
CarolinePascal
2026-05-10 17:03:30 +02:00
parent 2982329e28
commit 0a848acc3b
2 changed files with 16 additions and 10 deletions
+13 -6
View File
@@ -155,9 +155,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
episodes (list[int] | None, optional): If specified, this will only load episodes specified by episodes (list[int] | None, optional): If specified, this will only load episodes specified by
their episode_index in this list. Defaults to None. their episode_index in this list. Defaults to None.
episode_filter (Callable[[dict], bool] | None, optional): Predicate over per-episode episode_filter (Callable[[dict], bool] | None, optional): Predicate over per-episode
metadata rows used to select episodes; evaluated against ``meta/`` only, so metadata rows used to select episodes. Evaluated against ``meta/`` without ``stats`` keys
non-matches skip ``data/`` and ``videos/`` downloads. Intersected with ``episodes`` (e.g.``task_index``, ``episode_index``, ``length``, ``from_timestamp``, ``to_timestamp``).
when both are set. Example: ``lambda ep: ep["length"] >= 100``. Defaults to None. Intersected with ``episodes`` when both are set. Example: ``lambda ep: ep["length"] >= 100``.
Defaults to None.
image_transforms (Callable | None, optional): image_transforms (Callable | None, optional):
Transform applied to visual modalities inside `__getitem__` after image decoding / tensor Transform applied to visual modalities inside `__getitem__` after image decoding / tensor
conversion. This works for both image-backed and video-backed observations and can later be conversion. This works for both image-backed and video-backed observations and can later be
@@ -222,13 +223,19 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.root = self.meta.root self.root = self.meta.root
self.revision = self.meta.revision self.revision = self.meta.revision
if episodes is not None and any(episode>=self.meta.total_episodes or episode<0 for episode in episodes): if episodes is not None and any(
logger.warning(f"Some episodes in the provided episodes list are out of range for this dataset ({self.meta.total_episodes}).") episode >= self.meta.total_episodes or episode < 0 for episode in episodes
):
logger.warning(
f"Some episodes in the provided episodes list are out of range for this dataset ({self.meta.total_episodes})."
)
if episode_filter is not None: if episode_filter is not None:
resolved = self.meta.filter_episodes(episode_filter, candidates=episodes) resolved = self.meta.filter_episodes(episode_filter, candidates=episodes)
if not resolved: if not resolved:
raise ValueError(f"The episode filter did not match any episode. Make sure the filter and episodes list are valid and compatible.") raise ValueError(
"The episode filter did not match any episode. Make sure the filter and episodes list are valid and compatible."
)
logger.info(f"The episode filter matched {len(resolved)} episode(s).") logger.info(f"The episode filter matched {len(resolved)} episode(s).")
episodes = resolved episodes = resolved
self.episodes = episodes self.episodes = episodes
+2 -3
View File
@@ -1717,11 +1717,10 @@ def test_episode_filter_intersects_with_episodes(tmp_path, lerobot_dataset_facto
"""When both episodes and episode_filter are given to LeRobotDataset, the result is their intersection.""" """When both episodes and episode_filter are given to LeRobotDataset, the result is their intersection."""
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=8, total_frames=200) dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=8, total_frames=200)
lengths = dataset.meta.episodes["length"] lengths = dataset.meta.episodes["length"]
threshold = sorted(lengths)[len(lengths) // 2]
candidates = [0, 2, 4, 6] candidates = [0, 2, 4, 6]
candidate_lengths = [lengths[i] for i in candidates]
threshold = sorted(candidate_lengths)[len(candidate_lengths) // 2]
expected_eps = [i for i in candidates if lengths[i] >= threshold] expected_eps = [i for i in candidates if lengths[i] >= threshold]
if not expected_eps:
pytest.skip("multinomial draw produced no candidate above threshold")
filtered = LeRobotDataset( filtered = LeRobotDataset(
dataset.repo_id, dataset.repo_id,