fix(datasets): Move the remapping into EpisodeAwareSampler via absolute_to_relative_idx

This commit is contained in:
Khalil Meftah
2026-06-16 18:32:48 +02:00
parent 2b83956eb5
commit 9449e68725
3 changed files with 19 additions and 3 deletions
+12 -2
View File
@@ -370,6 +370,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.reader.load_and_activate()
return self.reader.hf_dataset
@property
def absolute_to_relative_idx(self) -> dict[int, int] | None:
"""Mapping from absolute frame indices to HF dataset row positions.
Non-None only for episode-filtered datasets where absolute indices
(from metadata) differ from row positions in the loaded HF dataset.
"""
reader = self._ensure_reader()
if reader.hf_dataset is None:
reader.load_and_activate()
return reader._absolute_to_relative_idx
# ── Writer-delegated methods ──────────────────────────────────────
def add_frame(self, frame: dict) -> None:
@@ -474,8 +486,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
if reader.hf_dataset is None:
# One-shot load after finalize()
reader.load_and_activate()
if reader._absolute_to_relative_idx is not None and idx in reader._absolute_to_relative_idx:
idx = reader._absolute_to_relative_idx[idx]
return reader.get_item(idx)
def select_columns(self, column_names: str | list[str]):
+6 -1
View File
@@ -53,6 +53,7 @@ class EpisodeAwareSampler:
drop_n_last_frames: int = 0,
shuffle: bool = False,
seed: int = 0,
absolute_to_relative_idx: dict[int, int] | None = None,
):
"""
Args:
@@ -107,6 +108,7 @@ class EpisodeAwareSampler:
self.seed = seed
self._epoch = 0
self._start_index = 0
self._absolute_to_relative = absolute_to_relative_idx
@property
def indices(self) -> list[int]:
@@ -132,7 +134,10 @@ class EpisodeAwareSampler:
def _frame_index(self, position: int) -> int:
    """Translate a flat sampler position into the index handed to the dataset.

    The position is first resolved to an episode by a right-sided binary
    search over ``self._cum_lengths``, then turned into an absolute frame
    index (``self._starts[episode]`` plus the offset inside that episode).
    When an absolute→relative mapping was supplied at construction time, the
    absolute index is translated through it; otherwise it is returned as is.

    Args:
        position: Flat position across the sampled episodes (presumably
            ``0 <= position < self._cum_lengths[-1]`` — confirm with caller).

    Returns:
        The mapped (relative) index when a mapping is set, else the absolute
        frame index.

    Raises:
        KeyError: If a mapping is set but does not contain the absolute index.
    """
    episode = int(np.searchsorted(self._cum_lengths, position, side="right"))
    # Frames belonging to earlier episodes; zero for the first episode.
    frames_before = int(self._cum_lengths[episode - 1]) if episode > 0 else 0
    absolute_idx = int(self._starts[episode]) + (position - frames_before)

    if self._absolute_to_relative is None:
        return absolute_idx

    # The remapping must run before returning: an unconditional early return
    # of the absolute index would leave the mapping unused.
    try:
        return self._absolute_to_relative[absolute_idx]
    except KeyError:
        raise KeyError(
            f"Absolute frame index {absolute_idx} (episode position {episode}) "
            "is missing from absolute_to_relative_idx"
        ) from None
def __iter__(self) -> Iterator[int]:
# Advance epoch state eagerly, not on first consumption of the generator.
+1
View File
@@ -407,6 +407,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0),
shuffle=True,
seed=cfg.seed if cfg.seed is not None else 0,
absolute_to_relative_idx=dataset.absolute_to_relative_idx,
)
if cfg.resume and step > 0:
# The resume offset depends on the (num_processes, batch_size) that produced `step`, so