fix(train): vectorize eval subset selection for max_eval_samples

This commit is contained in:
Khalil Meftah
2026-06-16 16:22:55 +02:00
parent 7309790d56
commit 2b83956eb5
+6 -4
View File
@@ -461,13 +461,13 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
if eval_dataset is not None:
eval_ds = eval_dataset
if cfg.max_eval_samples > 0 and hasattr(eval_dataset, "hf_dataset"):
task_indices = eval_dataset.hf_dataset["task_index"]
unique_tasks = sorted(set(task_indices))
task_arr = eval_dataset.hf_dataset.data.column("task_index").to_numpy()
unique_tasks = sorted(set(task_arr.tolist()))
per_task = max(1, cfg.max_eval_samples // len(unique_tasks))
selected: list[int] = []
for t in unique_tasks:
frames = [i for i, ti in enumerate(task_indices) if ti == t][:per_task]
selected.extend(frames)
frames = (task_arr == t).nonzero()[0][:per_task]
selected.extend(frames.tolist())
eval_ds = torch.utils.data.Subset(eval_dataset, selected)
eval_collate_fn = lerobot_collate_fn if dataset.meta.has_language_columns else None
@@ -479,6 +479,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
pin_memory=device.type == "cuda",
drop_last=False,
collate_fn=eval_collate_fn,
prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
)
# Prepare everything with accelerator