chore(format): formatting code

This commit is contained in:
CarolinePascal
2026-05-10 17:03:30 +02:00
parent 2982329e28
commit 0a848acc3b
2 changed files with 16 additions and 10 deletions
+14 -7
View File
@@ -155,9 +155,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
their episode_index in this list. Defaults to None.
episode_filter (Callable[[dict], bool] | None, optional): Predicate over per-episode
metadata rows used to select episodes; evaluated against ``meta/`` only, so
non-matches skip ``data/`` and ``videos/`` downloads. Intersected with ``episodes``
when both are set. Example: ``lambda ep: ep["length"] >= 100``. Defaults to None.
metadata rows used to select episodes. Evaluated against ``meta/`` without ``stats`` keys
(e.g.``task_index``, ``episode_index``, ``length``, ``from_timestamp``, ``to_timestamp``).
Intersected with ``episodes`` when both are set. Example: ``lambda ep: ep["length"] >= 100``.
Defaults to None.
image_transforms (Callable | None, optional):
Transform applied to visual modalities inside `__getitem__` after image decoding / tensor
conversion. This works for both image-backed and video-backed observations and can later be
@@ -222,13 +223,19 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.root = self.meta.root
self.revision = self.meta.revision
if episodes is not None and any(episode>=self.meta.total_episodes or episode<0 for episode in episodes):
logger.warning(f"Some episodes in the provided episodes list are out of range for this dataset ({self.meta.total_episodes}).")
if episodes is not None and any(
episode >= self.meta.total_episodes or episode < 0 for episode in episodes
):
logger.warning(
f"Some episodes in the provided episodes list are out of range for this dataset ({self.meta.total_episodes})."
)
if episode_filter is not None:
resolved = self.meta.filter_episodes(episode_filter, candidates=episodes)
if not resolved:
raise ValueError(f"The episode filter did not match any episode. Make sure the filter and episodes list are valid and compatible.")
raise ValueError(
"The episode filter did not match any episode. Make sure the filter and episodes list are valid and compatible."
)
logger.info(f"The episode filter matched {len(resolved)} episode(s).")
episodes = resolved
self.episodes = episodes
+2 -3
View File
@@ -1717,11 +1717,10 @@ def test_episode_filter_intersects_with_episodes(tmp_path, lerobot_dataset_facto
"""When both episodes and episode_filter are given to LeRobotDataset, the result is their intersection."""
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=8, total_frames=200)
lengths = dataset.meta.episodes["length"]
threshold = sorted(lengths)[len(lengths) // 2]
candidates = [0, 2, 4, 6]
candidate_lengths = [lengths[i] for i in candidates]
threshold = sorted(candidate_lengths)[len(candidate_lengths) // 2]
expected_eps = [i for i in candidates if lengths[i] >= threshold]
if not expected_eps:
pytest.skip("multinomial draw produced no candidate above threshold")
filtered = LeRobotDataset(
dataset.repo_id,