mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
feat(episode filtering): adding support for episodes filtering at initialization time in LeRobotDataset
This commit is contained in:
@@ -14,6 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import contextlib
|
import contextlib
|
||||||
|
from collections.abc import Callable
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -189,6 +190,28 @@ class LeRobotDatasetMetadata:
|
|||||||
if self.episodes is None:
|
if self.episodes is None:
|
||||||
self._load_metadata()
|
self._load_metadata()
|
||||||
|
|
||||||
|
def filter_episodes(
|
||||||
|
self,
|
||||||
|
predicate: Callable[[dict], bool],
|
||||||
|
candidates: list[int] | None = None,
|
||||||
|
) -> list[int]:
|
||||||
|
"""Filter episodes whose metadata satisfies a given predicate.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
predicate: Predicate over per-episode metadata rows used to select episodes.
|
||||||
|
candidates: Optional list of episode indices to restrict evaluation to.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of sorted episode indices that satisfy the predicate.
|
||||||
|
"""
|
||||||
|
self.ensure_readable()
|
||||||
|
ep_table = self.episodes
|
||||||
|
if candidates is not None:
|
||||||
|
candidate_set = set(candidates)
|
||||||
|
ep_table = ep_table.filter(lambda ep: ep["episode_index"] in candidate_set)
|
||||||
|
filtered = ep_table.filter(predicate)
|
||||||
|
return sorted(int(idx) for idx in filtered["episode_index"])
|
||||||
|
|
||||||
def _pull_from_repo(
|
def _pull_from_repo(
|
||||||
self,
|
self,
|
||||||
allow_patterns: list[str] | str | None = None,
|
allow_patterns: list[str] | str | None = None,
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
repo_id: str,
|
repo_id: str,
|
||||||
root: str | Path | None = None,
|
root: str | Path | None = None,
|
||||||
episodes: list[int] | None = None,
|
episodes: list[int] | None = None,
|
||||||
|
episode_filter: Callable[[dict], bool] | None = None,
|
||||||
image_transforms: Callable | None = None,
|
image_transforms: Callable | None = None,
|
||||||
delta_timestamps: dict[str, list[float]] | None = None,
|
delta_timestamps: dict[str, list[float]] | None = None,
|
||||||
tolerance_s: float = 1e-4,
|
tolerance_s: float = 1e-4,
|
||||||
@@ -153,6 +154,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
``$HF_LEROBOT_HOME/hub``.
|
``$HF_LEROBOT_HOME/hub``.
|
||||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||||
their episode_index in this list. Defaults to None.
|
their episode_index in this list. Defaults to None.
|
||||||
|
episode_filter (Callable[[dict], bool] | None, optional): Predicate over per-episode
|
||||||
|
metadata rows used to select episodes; evaluated against ``meta/`` only, so
|
||||||
|
non-matches skip ``data/`` and ``videos/`` downloads. Intersected with ``episodes``
|
||||||
|
when both are set. Example: ``lambda ep: ep["length"] >= 100``. Defaults to None.
|
||||||
image_transforms (Callable | None, optional):
|
image_transforms (Callable | None, optional):
|
||||||
Transform applied to visual modalities inside `__getitem__` after image decoding / tensor
|
Transform applied to visual modalities inside `__getitem__` after image decoding / tensor
|
||||||
conversion. This works for both image-backed and video-backed observations and can later be
|
conversion. This works for both image-backed and video-backed observations and can later be
|
||||||
@@ -218,6 +223,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
self.root = self.meta.root
|
self.root = self.meta.root
|
||||||
self.revision = self.meta.revision
|
self.revision = self.meta.revision
|
||||||
|
|
||||||
|
if episode_filter is not None:
|
||||||
|
resolved = self.meta.filter_episodes(episode_filter, candidates=episodes)
|
||||||
|
pool = len(episodes) if episodes is not None else self.meta.total_episodes
|
||||||
|
if not resolved:
|
||||||
|
raise ValueError(f"episode_filter did not match any episode over {pool} episode(s).")
|
||||||
|
logger.info(
|
||||||
|
f"episode_filter matched {len(resolved)} episode(s) over {pool} episode(s)."
|
||||||
|
)
|
||||||
|
episodes = resolved
|
||||||
|
self.episodes = resolved
|
||||||
|
|
||||||
# Create reader (hf_dataset loaded below)
|
# Create reader (hf_dataset loaded below)
|
||||||
self.reader = DatasetReader(
|
self.reader = DatasetReader(
|
||||||
meta=self.meta,
|
meta=self.meta,
|
||||||
|
|||||||
Reference in New Issue
Block a user