feat(episode filtering): adding support for episodes filtering at initialization time in LeRobotDataset

This commit is contained in:
CarolinePascal
2026-05-07 15:52:08 +02:00
parent 1f7b03f5f2
commit 9e1fb1c2dd
2 changed files with 39 additions and 0 deletions
+23
View File
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from collections.abc import Callable
from pathlib import Path
import numpy as np
@@ -189,6 +190,28 @@ class LeRobotDatasetMetadata:
if self.episodes is None:
self._load_metadata()
def filter_episodes(
self,
predicate: Callable[[dict], bool],
candidates: list[int] | None = None,
) -> list[int]:
"""Filter episodes whose metadata satisfies a given predicate.
Args:
predicate: Predicate over per-episode metadata rows used to select episodes.
candidates: Optional list of episode indices to restrict evaluation to.
Returns:
List of sorted episode indices that satisfy the predicate.
"""
self.ensure_readable()
ep_table = self.episodes
if candidates is not None:
candidate_set = set(candidates)
ep_table = ep_table.filter(lambda ep: ep["episode_index"] in candidate_set)
filtered = ep_table.filter(predicate)
return sorted(int(idx) for idx in filtered["episode_index"])
def _pull_from_repo(
self,
allow_patterns: list[str] | str | None = None,
+16
View File
@@ -49,6 +49,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
repo_id: str,
root: str | Path | None = None,
episodes: list[int] | None = None,
episode_filter: Callable[[dict], bool] | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[str, list[float]] | None = None,
tolerance_s: float = 1e-4,
@@ -153,6 +154,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
``$HF_LEROBOT_HOME/hub``.
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
their episode_index in this list. Defaults to None.
episode_filter (Callable[[dict], bool] | None, optional): Predicate over per-episode
metadata rows used to select episodes; evaluated against ``meta/`` only, so
non-matches skip ``data/`` and ``videos/`` downloads. Intersected with ``episodes``
when both are set. Example: ``lambda ep: ep["length"] >= 100``. Defaults to None.
image_transforms (Callable | None, optional):
Transform applied to visual modalities inside `__getitem__` after image decoding / tensor
conversion. This works for both image-backed and video-backed observations and can later be
@@ -218,6 +223,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.root = self.meta.root
self.revision = self.meta.revision
if episode_filter is not None:
resolved = self.meta.filter_episodes(episode_filter, candidates=episodes)
pool = len(episodes) if episodes is not None else self.meta.total_episodes
if not resolved:
raise ValueError(f"episode_filter did not match any episode over {pool} episode(s).")
logger.info(
f"episode_filter matched {len(resolved)} episode(s) over {pool} episode(s)."
)
episodes = resolved
self.episodes = resolved
# Create reader (hf_dataset loaded below)
self.reader = DatasetReader(
meta=self.meta,