mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
fix(datasets): Move the remapping into EpisodeAwareSampler via absolute_to_relative_idx
This commit is contained in:
@@ -370,6 +370,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.reader.load_and_activate()
|
||||
return self.reader.hf_dataset
|
||||
|
||||
@property
|
||||
def absolute_to_relative_idx(self) -> dict[int, int] | None:
|
||||
"""Mapping from absolute frame indices to HF dataset row positions.
|
||||
|
||||
Non-None only for episode-filtered datasets where absolute indices
|
||||
(from metadata) differ from row positions in the loaded HF dataset.
|
||||
"""
|
||||
reader = self._ensure_reader()
|
||||
if reader.hf_dataset is None:
|
||||
reader.load_and_activate()
|
||||
return reader._absolute_to_relative_idx
|
||||
|
||||
# ── Writer-delegated methods ──────────────────────────────────────
|
||||
|
||||
def add_frame(self, frame: dict) -> None:
|
||||
@@ -474,8 +486,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if reader.hf_dataset is None:
|
||||
# One-shot load after finalize()
|
||||
reader.load_and_activate()
|
||||
- if reader._absolute_to_relative_idx is not None and idx in reader._absolute_to_relative_idx:
|
||||
- idx = reader._absolute_to_relative_idx[idx]
|
||||
return reader.get_item(idx)
|
||||
|
||||
def select_columns(self, column_names: str | list[str]):
|
||||
|
||||
@@ -53,6 +53,7 @@ class EpisodeAwareSampler:
|
||||
drop_n_last_frames: int = 0,
|
||||
shuffle: bool = False,
|
||||
seed: int = 0,
|
||||
absolute_to_relative_idx: dict[int, int] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -107,6 +108,7 @@ class EpisodeAwareSampler:
|
||||
self.seed = seed
|
||||
self._epoch = 0
|
||||
self._start_index = 0
|
||||
self._absolute_to_relative = absolute_to_relative_idx
|
||||
|
||||
@property
|
||||
def indices(self) -> list[int]:
|
||||
@@ -132,7 +134,10 @@ class EpisodeAwareSampler:
|
||||
def _frame_index(self, position: int) -> int:
|
||||
episode = int(np.searchsorted(self._cum_lengths, position, side="right"))
|
||||
position_in_episode = position - (int(self._cum_lengths[episode - 1]) if episode > 0 else 0)
|
||||
- return int(self._starts[episode]) + position_in_episode
|
||||
+ absolute_idx = int(self._starts[episode]) + position_in_episode
|
||||
+ if self._absolute_to_relative is not None:
|
||||
+ return self._absolute_to_relative[absolute_idx]
|
||||
+ return absolute_idx
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
# Advance epoch state eagerly, not on first consumption of the generator.
|
||||
|
||||
@@ -407,6 +407,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0),
|
||||
shuffle=True,
|
||||
seed=cfg.seed if cfg.seed is not None else 0,
|
||||
absolute_to_relative_idx=dataset.absolute_to_relative_idx,
|
||||
)
|
||||
if cfg.resume and step > 0:
|
||||
# The resume offset depends on the (num_processes, batch_size) that produced `step`, so
|
||||
|
||||
Reference in New Issue
Block a user