Compare commits

...

2 Commits

Author SHA1 Message Date
Khalil Meftah 2b83956eb5 fix(train): vectorize eval subset selection for max_eval_samples 2026-06-16 16:22:55 +02:00
Khalil Meftah 7309790d56 fix(datasets): remap absolute indices in __getitem__ for filtered datasets 2026-06-16 15:15:11 +02:00
2 changed files with 8 additions and 6 deletions
+2
View File
@@ -474,6 +474,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
if reader.hf_dataset is None:
# One-shot load after finalize()
reader.load_and_activate()
if reader._absolute_to_relative_idx is not None and idx in reader._absolute_to_relative_idx:
idx = reader._absolute_to_relative_idx[idx]
return reader.get_item(idx)
def select_columns(self, column_names: str | list[str]):
+6 -6
View File
@@ -435,8 +435,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
f"Resuming data order at epoch {sampler_state['epoch']}, "
f"sample {sampler_state['start_index']}"
)
if dataset.reader._absolute_to_relative_idx is not None:
sampler.indices = [dataset.reader._absolute_to_relative_idx[i] for i in sampler.indices]
else:
shuffle = True
sampler = None
@@ -463,13 +461,13 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
if eval_dataset is not None:
eval_ds = eval_dataset
if cfg.max_eval_samples > 0 and hasattr(eval_dataset, "hf_dataset"):
task_indices = eval_dataset.hf_dataset["task_index"]
unique_tasks = sorted(set(task_indices))
task_arr = eval_dataset.hf_dataset.data.column("task_index").to_numpy()
unique_tasks = sorted(set(task_arr.tolist()))
per_task = max(1, cfg.max_eval_samples // len(unique_tasks))
selected: list[int] = []
for t in unique_tasks:
frames = [i for i, ti in enumerate(task_indices) if ti == t][:per_task]
selected.extend(frames)
frames = (task_arr == t).nonzero()[0][:per_task]
selected.extend(frames.tolist())
eval_ds = torch.utils.data.Subset(eval_dataset, selected)
eval_collate_fn = lerobot_collate_fn if dataset.meta.has_language_columns else None
@@ -481,6 +479,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
pin_memory=device.type == "cuda",
drop_last=False,
collate_fn=eval_collate_fn,
prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
)
# Prepare everything with accelerator