mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
feat(episodes): adding support for metadata based episodes filtering (#3530)
* feat(episode filtering): adding support for episodes filtering at initialization time in LeRobotDataset * test(tests): adding tests * chore(format): formatting code * feat(performance): improving implementation for better performances on big datasets * chores(warning): improving warnings and errors for episodes filtering * test(invalid key): adding test for invalid filtering key * chore(format): formatting code
This commit is contained in:
@@ -14,6 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import contextlib
|
import contextlib
|
||||||
|
from collections.abc import Callable
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -189,6 +190,29 @@ class LeRobotDatasetMetadata:
|
|||||||
if self.episodes is None:
|
if self.episodes is None:
|
||||||
self._load_metadata()
|
self._load_metadata()
|
||||||
|
|
||||||
|
def filter_episodes(
|
||||||
|
self,
|
||||||
|
predicate: Callable[[dict], bool],
|
||||||
|
candidates: list[int] | None = None,
|
||||||
|
) -> list[int]:
|
||||||
|
"""Filter episodes whose metadata satisfies a given predicate.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
predicate: Predicate over per-episode metadata rows used to select episodes.
|
||||||
|
candidates: Optional list of episode indices to restrict evaluation to.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of sorted episode indices that satisfy the predicate.
|
||||||
|
"""
|
||||||
|
self.ensure_readable()
|
||||||
|
if candidates is not None:
|
||||||
|
candidate_set = set(candidates)
|
||||||
|
combined = lambda ep: ep["episode_index"] in candidate_set and predicate(ep) # noqa: E731
|
||||||
|
else:
|
||||||
|
combined = predicate
|
||||||
|
filtered = self.episodes.filter(combined, keep_in_memory=True, load_from_cache_file=False)
|
||||||
|
return sorted(int(idx) for idx in filtered["episode_index"])
|
||||||
|
|
||||||
def _pull_from_repo(
|
def _pull_from_repo(
|
||||||
self,
|
self,
|
||||||
allow_patterns: list[str] | str | None = None,
|
allow_patterns: list[str] | str | None = None,
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
repo_id: str,
|
repo_id: str,
|
||||||
root: str | Path | None = None,
|
root: str | Path | None = None,
|
||||||
episodes: list[int] | None = None,
|
episodes: list[int] | None = None,
|
||||||
|
episode_filter: Callable[[dict], bool] | None = None,
|
||||||
image_transforms: Callable | None = None,
|
image_transforms: Callable | None = None,
|
||||||
delta_timestamps: dict[str, list[float]] | None = None,
|
delta_timestamps: dict[str, list[float]] | None = None,
|
||||||
tolerance_s: float = 1e-4,
|
tolerance_s: float = 1e-4,
|
||||||
@@ -153,6 +154,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
``$HF_LEROBOT_HOME/hub``.
|
``$HF_LEROBOT_HOME/hub``.
|
||||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||||
their episode_index in this list. Defaults to None.
|
their episode_index in this list. Defaults to None.
|
||||||
|
episode_filter (Callable[[dict], bool] | None, optional): Predicate over per-episode
|
||||||
|
metadata rows used to select episodes. Evaluated against ``meta/`` without ``stats`` keys
|
||||||
|
(e.g.``task_index``, ``episode_index``, ``length``, ``from_timestamp``, ``to_timestamp``).
|
||||||
|
Intersected with ``episodes`` when both are set. Example: ``lambda ep: ep["length"] >= 100``.
|
||||||
|
Defaults to None.
|
||||||
image_transforms (Callable | None, optional):
|
image_transforms (Callable | None, optional):
|
||||||
Transform applied to visual modalities inside `__getitem__` after image decoding / tensor
|
Transform applied to visual modalities inside `__getitem__` after image decoding / tensor
|
||||||
conversion. This works for both image-backed and video-backed observations and can later be
|
conversion. This works for both image-backed and video-backed observations and can later be
|
||||||
@@ -199,7 +205,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
self.reader = None
|
self.reader = None
|
||||||
self.set_image_transforms(image_transforms)
|
self.set_image_transforms(image_transforms)
|
||||||
self.delta_timestamps = delta_timestamps
|
self.delta_timestamps = delta_timestamps
|
||||||
self.episodes = episodes
|
|
||||||
self.tolerance_s = tolerance_s
|
self.tolerance_s = tolerance_s
|
||||||
self.revision = revision if revision else CODEBASE_VERSION
|
self.revision = revision if revision else CODEBASE_VERSION
|
||||||
self._video_backend = video_backend if video_backend else get_safe_default_codec()
|
self._video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||||
@@ -218,6 +223,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
self.root = self.meta.root
|
self.root = self.meta.root
|
||||||
self.revision = self.meta.revision
|
self.revision = self.meta.revision
|
||||||
|
|
||||||
|
if episodes is not None and any(
|
||||||
|
episode >= self.meta.total_episodes or episode < 0 for episode in episodes
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
f"Some episodes in the provided episodes list are out of range for this dataset ({self.meta.total_episodes})."
|
||||||
|
)
|
||||||
|
|
||||||
|
if episode_filter is not None:
|
||||||
|
resolved = self.meta.filter_episodes(episode_filter, candidates=episodes)
|
||||||
|
if not resolved:
|
||||||
|
raise ValueError(
|
||||||
|
"The episode filter did not match any episode. Make sure the filter and episodes list are valid and compatible."
|
||||||
|
)
|
||||||
|
logger.info(f"The episode filter matched {len(resolved)} episode(s).")
|
||||||
|
episodes = resolved
|
||||||
|
self.episodes = episodes
|
||||||
|
|
||||||
# Create reader (hf_dataset loaded below)
|
# Create reader (hf_dataset loaded below)
|
||||||
self.reader = DatasetReader(
|
self.reader = DatasetReader(
|
||||||
meta=self.meta,
|
meta=self.meta,
|
||||||
|
|||||||
@@ -1691,3 +1691,68 @@ def test_delta_timestamps_query_returns_correct_values(tmp_path, empty_lerobot_d
|
|||||||
# Previous frame is outside episode, so it's clamped to first frame and marked as padded
|
# Previous frame is outside episode, so it's clamped to first frame and marked as padded
|
||||||
assert state_values == [10.0, 10.0], f"Expected [10.0, 10.0], got {state_values}"
|
assert state_values == [10.0, 10.0], f"Expected [10.0, 10.0], got {state_values}"
|
||||||
assert is_pad == [True, False], f"Expected [True, False], got {is_pad}"
|
assert is_pad == [True, False], f"Expected [True, False], got {is_pad}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_episode_filter_filters_dataset(tmp_path, lerobot_dataset_factory):
|
||||||
|
"""episode_filter on LeRobotDataset narrows the loaded dataset to matching episodes."""
|
||||||
|
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=8, total_frames=200)
|
||||||
|
lengths = dataset.meta.episodes["length"]
|
||||||
|
threshold = sorted(lengths)[len(lengths) // 2]
|
||||||
|
expected_eps = [i for i, length in enumerate(lengths) if length >= threshold]
|
||||||
|
expected_frames = sum(lengths[i] for i in expected_eps)
|
||||||
|
|
||||||
|
filtered = LeRobotDataset(
|
||||||
|
dataset.repo_id,
|
||||||
|
root=dataset.root,
|
||||||
|
episode_filter=lambda ep: ep["length"] >= threshold,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert filtered.num_episodes == len(expected_eps)
|
||||||
|
assert filtered.num_frames == expected_frames
|
||||||
|
seen_eps = {filtered[i]["episode_index"].item() for i in range(len(filtered))}
|
||||||
|
assert seen_eps == set(expected_eps)
|
||||||
|
|
||||||
|
|
||||||
|
def test_episode_filter_intersects_with_episodes(tmp_path, lerobot_dataset_factory):
|
||||||
|
"""When both episodes and episode_filter are given to LeRobotDataset, the result is their intersection."""
|
||||||
|
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=8, total_frames=200)
|
||||||
|
lengths = dataset.meta.episodes["length"]
|
||||||
|
candidates = [0, 2, 4, 6]
|
||||||
|
candidate_lengths = [lengths[i] for i in candidates]
|
||||||
|
threshold = sorted(candidate_lengths)[len(candidate_lengths) // 2]
|
||||||
|
expected_eps = [i for i in candidates if lengths[i] >= threshold]
|
||||||
|
|
||||||
|
filtered = LeRobotDataset(
|
||||||
|
dataset.repo_id,
|
||||||
|
root=dataset.root,
|
||||||
|
episodes=candidates,
|
||||||
|
episode_filter=lambda ep: ep["length"] >= threshold,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert filtered.num_episodes == len(expected_eps)
|
||||||
|
seen_eps = {filtered[i]["episode_index"].item() for i in range(len(filtered))}
|
||||||
|
assert seen_eps == set(expected_eps)
|
||||||
|
|
||||||
|
|
||||||
|
def test_episode_filter_no_match_raises(tmp_path, lerobot_dataset_factory):
|
||||||
|
"""An empty match in LeRobotDataset's episode_filter raises a ValueError rather than silently returning an empty dataset."""
|
||||||
|
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=4, total_frames=100)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=r"The episode filter did not match any episode"):
|
||||||
|
LeRobotDataset(
|
||||||
|
dataset.repo_id,
|
||||||
|
root=dataset.root,
|
||||||
|
episode_filter=lambda ep: ep["length"] < 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_episode_filter_unknown_key_raises(tmp_path, lerobot_dataset_factory):
|
||||||
|
"""A predicate referencing a column absent from meta.episodes surfaces a clear KeyError."""
|
||||||
|
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=4, total_frames=100)
|
||||||
|
|
||||||
|
with pytest.raises(KeyError, match="not_a_real_field"):
|
||||||
|
LeRobotDataset(
|
||||||
|
dataset.repo_id,
|
||||||
|
root=dataset.root,
|
||||||
|
episode_filter=lambda ep: ep["not_a_real_field"] > 0,
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user