mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 05:29:55 +00:00
chore(format): formatting code
This commit is contained in:
@@ -155,9 +155,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||||
their episode_index in this list. Defaults to None.
|
their episode_index in this list. Defaults to None.
|
||||||
episode_filter (Callable[[dict], bool] | None, optional): Predicate over per-episode
|
episode_filter (Callable[[dict], bool] | None, optional): Predicate over per-episode
|
||||||
metadata rows used to select episodes; evaluated against ``meta/`` only, so
|
metadata rows used to select episodes. Evaluated against ``meta/`` without ``stats`` keys
|
||||||
non-matches skip ``data/`` and ``videos/`` downloads. Intersected with ``episodes``
|
(e.g.``task_index``, ``episode_index``, ``length``, ``from_timestamp``, ``to_timestamp``).
|
||||||
when both are set. Example: ``lambda ep: ep["length"] >= 100``. Defaults to None.
|
Intersected with ``episodes`` when both are set. Example: ``lambda ep: ep["length"] >= 100``.
|
||||||
|
Defaults to None.
|
||||||
image_transforms (Callable | None, optional):
|
image_transforms (Callable | None, optional):
|
||||||
Transform applied to visual modalities inside `__getitem__` after image decoding / tensor
|
Transform applied to visual modalities inside `__getitem__` after image decoding / tensor
|
||||||
conversion. This works for both image-backed and video-backed observations and can later be
|
conversion. This works for both image-backed and video-backed observations and can later be
|
||||||
@@ -222,13 +223,19 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
self.root = self.meta.root
|
self.root = self.meta.root
|
||||||
self.revision = self.meta.revision
|
self.revision = self.meta.revision
|
||||||
|
|
||||||
if episodes is not None and any(episode>=self.meta.total_episodes or episode<0 for episode in episodes):
|
if episodes is not None and any(
|
||||||
logger.warning(f"Some episodes in the provided episodes list are out of range for this dataset ({self.meta.total_episodes}).")
|
episode >= self.meta.total_episodes or episode < 0 for episode in episodes
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
f"Some episodes in the provided episodes list are out of range for this dataset ({self.meta.total_episodes})."
|
||||||
|
)
|
||||||
|
|
||||||
if episode_filter is not None:
|
if episode_filter is not None:
|
||||||
resolved = self.meta.filter_episodes(episode_filter, candidates=episodes)
|
resolved = self.meta.filter_episodes(episode_filter, candidates=episodes)
|
||||||
if not resolved:
|
if not resolved:
|
||||||
raise ValueError(f"The episode filter did not match any episode. Make sure the filter and episodes list are valid and compatible.")
|
raise ValueError(
|
||||||
|
"The episode filter did not match any episode. Make sure the filter and episodes list are valid and compatible."
|
||||||
|
)
|
||||||
logger.info(f"The episode filter matched {len(resolved)} episode(s).")
|
logger.info(f"The episode filter matched {len(resolved)} episode(s).")
|
||||||
episodes = resolved
|
episodes = resolved
|
||||||
self.episodes = episodes
|
self.episodes = episodes
|
||||||
|
|||||||
@@ -1717,11 +1717,10 @@ def test_episode_filter_intersects_with_episodes(tmp_path, lerobot_dataset_facto
|
|||||||
"""When both episodes and episode_filter are given to LeRobotDataset, the result is their intersection."""
|
"""When both episodes and episode_filter are given to LeRobotDataset, the result is their intersection."""
|
||||||
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=8, total_frames=200)
|
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=8, total_frames=200)
|
||||||
lengths = dataset.meta.episodes["length"]
|
lengths = dataset.meta.episodes["length"]
|
||||||
threshold = sorted(lengths)[len(lengths) // 2]
|
|
||||||
candidates = [0, 2, 4, 6]
|
candidates = [0, 2, 4, 6]
|
||||||
|
candidate_lengths = [lengths[i] for i in candidates]
|
||||||
|
threshold = sorted(candidate_lengths)[len(candidate_lengths) // 2]
|
||||||
expected_eps = [i for i in candidates if lengths[i] >= threshold]
|
expected_eps = [i for i in candidates if lengths[i] >= threshold]
|
||||||
if not expected_eps:
|
|
||||||
pytest.skip("multinomial draw produced no candidate above threshold")
|
|
||||||
|
|
||||||
filtered = LeRobotDataset(
|
filtered = LeRobotDataset(
|
||||||
dataset.repo_id,
|
dataset.repo_id,
|
||||||
|
|||||||
Reference in New Issue
Block a user