From 2b83956eb5c3d0585f5b2492f9a7e6ea2467bbe4 Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Tue, 16 Jun 2026 16:22:55 +0200 Subject: [PATCH] fix(train): vectorize eval subset selection for max_eval_samples --- src/lerobot/scripts/lerobot_train.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 52fe8737a..b0ad6a50b 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -461,13 +461,13 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): if eval_dataset is not None: eval_ds = eval_dataset if cfg.max_eval_samples > 0 and hasattr(eval_dataset, "hf_dataset"): - task_indices = eval_dataset.hf_dataset["task_index"] - unique_tasks = sorted(set(task_indices)) + task_arr = eval_dataset.hf_dataset.data.column("task_index").to_numpy() + unique_tasks = sorted(set(task_arr.tolist())) per_task = max(1, cfg.max_eval_samples // len(unique_tasks)) selected: list[int] = [] for t in unique_tasks: - frames = [i for i, ti in enumerate(task_indices) if ti == t][:per_task] - selected.extend(frames) + frames = (task_arr == t).nonzero()[0][:per_task] + selected.extend(frames.tolist()) eval_ds = torch.utils.data.Subset(eval_dataset, selected) eval_collate_fn = lerobot_collate_fn if dataset.meta.has_language_columns else None @@ -479,6 +479,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): pin_memory=device.type == "cuda", drop_last=False, collate_fn=eval_collate_fn, + prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None, + persistent_workers=cfg.persistent_workers and cfg.num_workers > 0, ) # Prepare everything with accelerator