diff --git a/src/lerobot/datasets/dataset_metadata.py b/src/lerobot/datasets/dataset_metadata.py index 4f89ba2a4..b404ddb18 100644 --- a/src/lerobot/datasets/dataset_metadata.py +++ b/src/lerobot/datasets/dataset_metadata.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import contextlib +from collections.abc import Callable from pathlib import Path import numpy as np @@ -189,6 +190,29 @@ class LeRobotDatasetMetadata: if self.episodes is None: self._load_metadata() + def filter_episodes( + self, + predicate: Callable[[dict], bool], + candidates: list[int] | None = None, + ) -> list[int]: + """Filter episodes whose metadata satisfies a given predicate. + + Args: + predicate: Predicate over per-episode metadata rows used to select episodes. + candidates: Optional list of episode indices to restrict evaluation to. + + Returns: + List of sorted episode indices that satisfy the predicate. + """ + self.ensure_readable() + if candidates is not None: + candidate_set = set(candidates) + combined = lambda ep: ep["episode_index"] in candidate_set and predicate(ep) # noqa: E731 + else: + combined = predicate + filtered = self.episodes.filter(combined, keep_in_memory=True, load_from_cache_file=False) + return sorted(int(idx) for idx in filtered["episode_index"]) + def _pull_from_repo( self, allow_patterns: list[str] | str | None = None, diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index b6ab0f5f0..ab55aa9f8 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -49,6 +49,7 @@ class LeRobotDataset(torch.utils.data.Dataset): repo_id: str, root: str | Path | None = None, episodes: list[int] | None = None, + episode_filter: Callable[[dict], bool] | None = None, image_transforms: Callable | None = None, delta_timestamps: dict[str, list[float]] | None = None, tolerance_s: float = 1e-4, @@ -153,6 +154,11 @@ class LeRobotDataset(torch.utils.data.Dataset): ``$HF_LEROBOT_HOME/hub``. episodes (list[int] | None, optional): If specified, this will only load episodes specified by their episode_index in this list. Defaults to None. + episode_filter (Callable[[dict], bool] | None, optional): Predicate over per-episode + metadata rows used to select episodes. Evaluated against ``meta/`` without ``stats`` keys + (e.g.``task_index``, ``episode_index``, ``length``, ``from_timestamp``, ``to_timestamp``). + Intersected with ``episodes`` when both are set. Example: ``lambda ep: ep["length"] >= 100``. + Defaults to None. image_transforms (Callable | None, optional): Transform applied to visual modalities inside `__getitem__` after image decoding / tensor conversion. This works for both image-backed and video-backed observations and can later be @@ -199,7 +205,6 @@ class LeRobotDataset(torch.utils.data.Dataset): self.reader = None self.set_image_transforms(image_transforms) self.delta_timestamps = delta_timestamps - self.episodes = episodes self.tolerance_s = tolerance_s self.revision = revision if revision else CODEBASE_VERSION self._video_backend = video_backend if video_backend else get_safe_default_codec() @@ -218,6 +223,23 @@ class LeRobotDataset(torch.utils.data.Dataset): self.root = self.meta.root self.revision = self.meta.revision + if episodes is not None and any( + episode >= self.meta.total_episodes or episode < 0 for episode in episodes + ): + logger.warning( + f"Some episodes in the provided episodes list are out of range for this dataset ({self.meta.total_episodes})." + ) + + if episode_filter is not None: + resolved = self.meta.filter_episodes(episode_filter, candidates=episodes) + if not resolved: + raise ValueError( + "The episode filter did not match any episode. Make sure the filter and episodes list are valid and compatible." + ) + logger.info(f"The episode filter matched {len(resolved)} episode(s).") + episodes = resolved + self.episodes = episodes + # Create reader (hf_dataset loaded below) self.reader = DatasetReader( meta=self.meta, diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 3e1a17a62..654f8cdf1 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -1691,3 +1691,68 @@ def test_delta_timestamps_query_returns_correct_values(tmp_path, empty_lerobot_d # Previous frame is outside episode, so it's clamped to first frame and marked as padded assert state_values == [10.0, 10.0], f"Expected [10.0, 10.0], got {state_values}" assert is_pad == [True, False], f"Expected [True, False], got {is_pad}" + + +def test_episode_filter_filters_dataset(tmp_path, lerobot_dataset_factory): + """episode_filter on LeRobotDataset narrows the loaded dataset to matching episodes.""" + dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=8, total_frames=200) + lengths = dataset.meta.episodes["length"] + threshold = sorted(lengths)[len(lengths) // 2] + expected_eps = [i for i, length in enumerate(lengths) if length >= threshold] + expected_frames = sum(lengths[i] for i in expected_eps) + + filtered = LeRobotDataset( + dataset.repo_id, + root=dataset.root, + episode_filter=lambda ep: ep["length"] >= threshold, + ) + + assert filtered.num_episodes == len(expected_eps) + assert filtered.num_frames == expected_frames + seen_eps = {filtered[i]["episode_index"].item() for i in range(len(filtered))} + assert seen_eps == set(expected_eps) + + +def test_episode_filter_intersects_with_episodes(tmp_path, lerobot_dataset_factory): + """When both episodes and episode_filter are given to LeRobotDataset, the result is their intersection.""" + dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=8, total_frames=200) + lengths = dataset.meta.episodes["length"] + candidates = [0, 2, 4, 6] + candidate_lengths = [lengths[i] for i in candidates] + threshold = sorted(candidate_lengths)[len(candidate_lengths) // 2] + expected_eps = [i for i in candidates if lengths[i] >= threshold] + + filtered = LeRobotDataset( + dataset.repo_id, + root=dataset.root, + episodes=candidates, + episode_filter=lambda ep: ep["length"] >= threshold, + ) + + assert filtered.num_episodes == len(expected_eps) + seen_eps = {filtered[i]["episode_index"].item() for i in range(len(filtered))} + assert seen_eps == set(expected_eps) + + +def test_episode_filter_no_match_raises(tmp_path, lerobot_dataset_factory): + """An empty match in LeRobotDataset's episode_filter raises a ValueError rather than silently returning an empty dataset.""" + dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=4, total_frames=100) + + with pytest.raises(ValueError, match=r"The episode filter did not match any episode"): + LeRobotDataset( + dataset.repo_id, + root=dataset.root, + episode_filter=lambda ep: ep["length"] < 0, + ) + + +def test_episode_filter_unknown_key_raises(tmp_path, lerobot_dataset_factory): + """A predicate referencing a column absent from meta.episodes surfaces a clear KeyError.""" + dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=4, total_frames=100) + + with pytest.raises(KeyError, match="not_a_real_field"): + LeRobotDataset( + dataset.repo_id, + root=dataset.root, + episode_filter=lambda ep: ep["not_a_real_field"] > 0, + )