From 0a848acc3b3850fd2fc555135b8d6b9ab7655284 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Sun, 10 May 2026 17:03:30 +0200 Subject: [PATCH] chore(format): formatting code --- src/lerobot/datasets/lerobot_dataset.py | 21 ++++++++++++++------- tests/datasets/test_datasets.py | 5 ++--- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 65a7ed500..ab55aa9f8 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -155,9 +155,10 @@ class LeRobotDataset(torch.utils.data.Dataset): episodes (list[int] | None, optional): If specified, this will only load episodes specified by their episode_index in this list. Defaults to None. episode_filter (Callable[[dict], bool] | None, optional): Predicate over per-episode - metadata rows used to select episodes; evaluated against ``meta/`` only, so - non-matches skip ``data/`` and ``videos/`` downloads. Intersected with ``episodes`` - when both are set. Example: ``lambda ep: ep["length"] >= 100``. Defaults to None. + metadata rows used to select episodes. Evaluated against ``meta/`` without ``stats`` keys + (e.g.``task_index``, ``episode_index``, ``length``, ``from_timestamp``, ``to_timestamp``). + Intersected with ``episodes`` when both are set. Example: ``lambda ep: ep["length"] >= 100``. + Defaults to None. image_transforms (Callable | None, optional): Transform applied to visual modalities inside `__getitem__` after image decoding / tensor conversion. This works for both image-backed and video-backed observations and can later be @@ -222,13 +223,19 @@ class LeRobotDataset(torch.utils.data.Dataset): self.root = self.meta.root self.revision = self.meta.revision - if episodes is not None and any(episode>=self.meta.total_episodes or episode<0 for episode in episodes): - logger.warning(f"Some episodes in the provided episodes list are out of range for this dataset ({self.meta.total_episodes}).") - + if episodes is not None and any( + episode >= self.meta.total_episodes or episode < 0 for episode in episodes + ): + logger.warning( + f"Some episodes in the provided episodes list are out of range for this dataset ({self.meta.total_episodes})." + ) + if episode_filter is not None: resolved = self.meta.filter_episodes(episode_filter, candidates=episodes) if not resolved: - raise ValueError(f"The episode filter did not match any episode. Make sure the filter and episodes list are valid and compatible.") + raise ValueError( + "The episode filter did not match any episode. Make sure the filter and episodes list are valid and compatible." + ) logger.info(f"The episode filter matched {len(resolved)} episode(s).") episodes = resolved self.episodes = episodes diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 9d6553393..654f8cdf1 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -1717,11 +1717,10 @@ def test_episode_filter_intersects_with_episodes(tmp_path, lerobot_dataset_facto """When both episodes and episode_filter are given to LeRobotDataset, the result is their intersection.""" dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=8, total_frames=200) lengths = dataset.meta.episodes["length"] - threshold = sorted(lengths)[len(lengths) // 2] candidates = [0, 2, 4, 6] + candidate_lengths = [lengths[i] for i in candidates] + threshold = sorted(candidate_lengths)[len(candidate_lengths) // 2] expected_eps = [i for i in candidates if lengths[i] >= threshold] - if not expected_eps: - pytest.skip("multinomial draw produced no candidate above threshold") filtered = LeRobotDataset( dataset.repo_id,