diff --git a/src/lerobot/datasets/dataset_metadata.py b/src/lerobot/datasets/dataset_metadata.py index 4d4584ba4..b404ddb18 100644 --- a/src/lerobot/datasets/dataset_metadata.py +++ b/src/lerobot/datasets/dataset_metadata.py @@ -205,11 +205,12 @@ class LeRobotDatasetMetadata: List of sorted episode indices that satisfy the predicate. """ self.ensure_readable() - ep_table = self.episodes if candidates is not None: candidate_set = set(candidates) - ep_table = ep_table.filter(lambda ep: ep["episode_index"] in candidate_set) - filtered = ep_table.filter(predicate) + combined = lambda ep: ep["episode_index"] in candidate_set and predicate(ep) # noqa: E731 + else: + combined = predicate + filtered = self.episodes.filter(combined, keep_in_memory=True, load_from_cache_file=False) return sorted(int(idx) for idx in filtered["episode_index"]) def _pull_from_repo(