diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 99219783d..a44566269 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -370,6 +370,18 @@ class LeRobotDataset(torch.utils.data.Dataset): self.reader.load_and_activate() return self.reader.hf_dataset + @property + def absolute_to_relative_idx(self) -> dict[int, int] | None: + """Mapping from absolute frame indices to HF dataset row positions. + + Non-None only for episode-filtered datasets where absolute indices + (from metadata) differ from row positions in the loaded HF dataset. + """ + reader = self._ensure_reader() + if reader.hf_dataset is None: + reader.load_and_activate() + return reader._absolute_to_relative_idx + # ── Writer-delegated methods ────────────────────────────────────── def add_frame(self, frame: dict) -> None: @@ -474,8 +486,6 @@ class LeRobotDataset(torch.utils.data.Dataset): if reader.hf_dataset is None: # One-shot load after finalize() reader.load_and_activate() - if reader._absolute_to_relative_idx is not None and idx in reader._absolute_to_relative_idx: - idx = reader._absolute_to_relative_idx[idx] return reader.get_item(idx) def select_columns(self, column_names: str | list[str]): diff --git a/src/lerobot/datasets/sampler.py b/src/lerobot/datasets/sampler.py index af85dff9b..aee6ce46d 100644 --- a/src/lerobot/datasets/sampler.py +++ b/src/lerobot/datasets/sampler.py @@ -53,6 +53,7 @@ class EpisodeAwareSampler: drop_n_last_frames: int = 0, shuffle: bool = False, seed: int = 0, + absolute_to_relative_idx: dict[int, int] | None = None, ): """ Args: @@ -107,6 +108,7 @@ class EpisodeAwareSampler: self.seed = seed self._epoch = 0 self._start_index = 0 + self._absolute_to_relative = absolute_to_relative_idx @property def indices(self) -> list[int]: @@ -132,7 +134,10 @@ class EpisodeAwareSampler: def _frame_index(self, position: int) -> int: episode = int(np.searchsorted(self._cum_lengths, position, side="right")) position_in_episode = position - (int(self._cum_lengths[episode - 1]) if episode > 0 else 0) - return int(self._starts[episode]) + position_in_episode + absolute_idx = int(self._starts[episode]) + position_in_episode + if self._absolute_to_relative is not None: + return self._absolute_to_relative[absolute_idx] + return absolute_idx def __iter__(self) -> Iterator[int]: # Advance epoch state eagerly, not on first consumption of the generator. diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index b0ad6a50b..956bcdae2 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -407,6 +407,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0), shuffle=True, seed=cfg.seed if cfg.seed is not None else 0, + absolute_to_relative_idx=dataset.absolute_to_relative_idx, ) if cfg.resume and step > 0: # The resume offset depends on the (num_processes, batch_size) that produced `step`, so