mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
chore(format): formatting code
This commit is contained in:
@@ -155,9 +155,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||
their episode_index in this list. Defaults to None.
|
||||
episode_filter (Callable[[dict], bool] | None, optional): Predicate over per-episode
|
||||
metadata rows used to select episodes; evaluated against ``meta/`` only, so
|
||||
non-matches skip ``data/`` and ``videos/`` downloads. Intersected with ``episodes``
|
||||
when both are set. Example: ``lambda ep: ep["length"] >= 100``. Defaults to None.
|
||||
metadata rows used to select episodes. Evaluated against ``meta/`` without ``stats`` keys
|
||||
(e.g.``task_index``, ``episode_index``, ``length``, ``from_timestamp``, ``to_timestamp``).
|
||||
Intersected with ``episodes`` when both are set. Example: ``lambda ep: ep["length"] >= 100``.
|
||||
Defaults to None.
|
||||
image_transforms (Callable | None, optional):
|
||||
Transform applied to visual modalities inside `__getitem__` after image decoding / tensor
|
||||
conversion. This works for both image-backed and video-backed observations and can later be
|
||||
@@ -222,13 +223,19 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.root = self.meta.root
|
||||
self.revision = self.meta.revision
|
||||
|
||||
if episodes is not None and any(episode>=self.meta.total_episodes or episode<0 for episode in episodes):
|
||||
logger.warning(f"Some episodes in the provided episodes list are out of range for this dataset ({self.meta.total_episodes}).")
|
||||
|
||||
if episodes is not None and any(
|
||||
episode >= self.meta.total_episodes or episode < 0 for episode in episodes
|
||||
):
|
||||
logger.warning(
|
||||
f"Some episodes in the provided episodes list are out of range for this dataset ({self.meta.total_episodes})."
|
||||
)
|
||||
|
||||
if episode_filter is not None:
|
||||
resolved = self.meta.filter_episodes(episode_filter, candidates=episodes)
|
||||
if not resolved:
|
||||
raise ValueError(f"The episode filter did not match any episode. Make sure the filter and episodes list are valid and compatible.")
|
||||
raise ValueError(
|
||||
"The episode filter did not match any episode. Make sure the filter and episodes list are valid and compatible."
|
||||
)
|
||||
logger.info(f"The episode filter matched {len(resolved)} episode(s).")
|
||||
episodes = resolved
|
||||
self.episodes = episodes
|
||||
|
||||
@@ -1717,11 +1717,10 @@ def test_episode_filter_intersects_with_episodes(tmp_path, lerobot_dataset_facto
|
||||
"""When both episodes and episode_filter are given to LeRobotDataset, the result is their intersection."""
|
||||
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=8, total_frames=200)
|
||||
lengths = dataset.meta.episodes["length"]
|
||||
threshold = sorted(lengths)[len(lengths) // 2]
|
||||
candidates = [0, 2, 4, 6]
|
||||
candidate_lengths = [lengths[i] for i in candidates]
|
||||
threshold = sorted(candidate_lengths)[len(candidate_lengths) // 2]
|
||||
expected_eps = [i for i in candidates if lengths[i] >= threshold]
|
||||
if not expected_eps:
|
||||
pytest.skip("multinomial draw produced no candidate above threshold")
|
||||
|
||||
filtered = LeRobotDataset(
|
||||
dataset.repo_id,
|
||||
|
||||
Reference in New Issue
Block a user