From 9e1fb1c2dda78cb5fd2692e8489a74923a832f7c Mon Sep 17 00:00:00 2001
From: CarolinePascal
Date: Thu, 7 May 2026 15:52:08 +0200
Subject: [PATCH] feat(episode filtering): adding support for episodes
 filtering at initialization time in LeRobotDataset

---
 src/lerobot/datasets/dataset_metadata.py | 23 +++++++++++++++++++++++
 src/lerobot/datasets/lerobot_dataset.py  | 16 ++++++++++++++++
 2 files changed, 39 insertions(+)

diff --git a/src/lerobot/datasets/dataset_metadata.py b/src/lerobot/datasets/dataset_metadata.py
index 4f89ba2a4..4d4584ba4 100644
--- a/src/lerobot/datasets/dataset_metadata.py
+++ b/src/lerobot/datasets/dataset_metadata.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
+from collections.abc import Callable
 from pathlib import Path
 
 import numpy as np
@@ -189,6 +190,28 @@ class LeRobotDatasetMetadata:
         if self.episodes is None:
             self._load_metadata()
 
+    def filter_episodes(
+        self,
+        predicate: Callable[[dict], bool],
+        candidates: list[int] | None = None,
+    ) -> list[int]:
+        """Filter episodes whose metadata satisfies a given predicate.
+
+        Args:
+            predicate: Predicate over per-episode metadata rows used to select episodes.
+            candidates: Optional list of episode indices to restrict evaluation to.
+
+        Returns:
+            List of sorted episode indices that satisfy the predicate.
+        """
+        self.ensure_readable()
+        ep_table = self.episodes
+        if candidates is not None:
+            candidate_set = set(candidates)
+            ep_table = ep_table.filter(lambda ep: ep["episode_index"] in candidate_set)
+        filtered = ep_table.filter(predicate)
+        return sorted(int(idx) for idx in filtered["episode_index"])
+
     def _pull_from_repo(
         self,
         allow_patterns: list[str] | str | None = None,
diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py
index b6ab0f5f0..82b17ed0c 100644
--- a/src/lerobot/datasets/lerobot_dataset.py
+++ b/src/lerobot/datasets/lerobot_dataset.py
@@ -49,6 +49,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         repo_id: str,
         root: str | Path | None = None,
         episodes: list[int] | None = None,
+        episode_filter: Callable[[dict], bool] | None = None,
         image_transforms: Callable | None = None,
         delta_timestamps: dict[str, list[float]] | None = None,
         tolerance_s: float = 1e-4,
@@ -153,6 +154,10 @@
                 ``$HF_LEROBOT_HOME/hub``.
             episodes (list[int] | None, optional): If specified, this will only load episodes
                 specified by their episode_index in this list. Defaults to None.
+            episode_filter (Callable[[dict], bool] | None, optional): Predicate over per-episode
+                metadata rows used to select episodes; evaluated against ``meta/`` only, so
+                non-matches skip ``data/`` and ``videos/`` downloads. Intersected with ``episodes``
+                when both are set. Example: ``lambda ep: ep["length"] >= 100``. Defaults to None.
             image_transforms (Callable | None, optional): Transform applied to visual modalities
                 inside `__getitem__` after image decoding / tensor conversion.
                 This works for both image-backed and video-backed observations and can later be
@@ -218,6 +223,17 @@
         self.root = self.meta.root
         self.revision = self.meta.revision
 
+        if episode_filter is not None:
+            resolved = self.meta.filter_episodes(episode_filter, candidates=episodes)
+            pool = len(episodes) if episodes is not None else self.meta.total_episodes
+            if not resolved:
+                raise ValueError(f"episode_filter did not match any episode over {pool} episode(s).")
+            logger.info(
+                f"episode_filter matched {len(resolved)} episode(s) over {pool} episode(s)."
+            )
+            episodes = resolved
+            self.episodes = resolved
+
         # Create reader (hf_dataset loaded below)
         self.reader = DatasetReader(
             meta=self.meta,