mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-14 16:19:45 +00:00
feat(episode filtering): adding support for episodes filtering at initialization time in LeRobotDataset
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -189,6 +190,28 @@ class LeRobotDatasetMetadata:
|
||||
if self.episodes is None:
|
||||
self._load_metadata()
|
||||
|
||||
def filter_episodes(
|
||||
self,
|
||||
predicate: Callable[[dict], bool],
|
||||
candidates: list[int] | None = None,
|
||||
) -> list[int]:
|
||||
"""Filter episodes whose metadata satisfies a given predicate.
|
||||
|
||||
Args:
|
||||
predicate: Predicate over per-episode metadata rows used to select episodes.
|
||||
candidates: Optional list of episode indices to restrict evaluation to.
|
||||
|
||||
Returns:
|
||||
List of sorted episode indices that satisfy the predicate.
|
||||
"""
|
||||
self.ensure_readable()
|
||||
ep_table = self.episodes
|
||||
if candidates is not None:
|
||||
candidate_set = set(candidates)
|
||||
ep_table = ep_table.filter(lambda ep: ep["episode_index"] in candidate_set)
|
||||
filtered = ep_table.filter(predicate)
|
||||
return sorted(int(idx) for idx in filtered["episode_index"])
|
||||
|
||||
def _pull_from_repo(
|
||||
self,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
|
||||
@@ -49,6 +49,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
repo_id: str,
|
||||
root: str | Path | None = None,
|
||||
episodes: list[int] | None = None,
|
||||
episode_filter: Callable[[dict], bool] | None = None,
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[str, list[float]] | None = None,
|
||||
tolerance_s: float = 1e-4,
|
||||
@@ -153,6 +154,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
``$HF_LEROBOT_HOME/hub``.
|
||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||
their episode_index in this list. Defaults to None.
|
||||
episode_filter (Callable[[dict], bool] | None, optional): Predicate over per-episode
|
||||
metadata rows used to select episodes; evaluated against ``meta/`` only, so
|
||||
non-matches skip ``data/`` and ``videos/`` downloads. Intersected with ``episodes``
|
||||
when both are set. Example: ``lambda ep: ep["length"] >= 100``. Defaults to None.
|
||||
image_transforms (Callable | None, optional):
|
||||
Transform applied to visual modalities inside `__getitem__` after image decoding / tensor
|
||||
conversion. This works for both image-backed and video-backed observations and can later be
|
||||
@@ -218,6 +223,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.root = self.meta.root
|
||||
self.revision = self.meta.revision
|
||||
|
||||
if episode_filter is not None:
|
||||
resolved = self.meta.filter_episodes(episode_filter, candidates=episodes)
|
||||
pool = len(episodes) if episodes is not None else self.meta.total_episodes
|
||||
if not resolved:
|
||||
raise ValueError(f"episode_filter did not match any episode over {pool} episode(s).")
|
||||
logger.info(
|
||||
f"episode_filter matched {len(resolved)} episode(s) over {pool} episode(s)."
|
||||
)
|
||||
episodes = resolved
|
||||
self.episodes = resolved
|
||||
|
||||
# Create reader (hf_dataset loaded below)
|
||||
self.reader = DatasetReader(
|
||||
meta=self.meta,
|
||||
|
||||
Reference in New Issue
Block a user