diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 52fe8737a..b0ad6a50b 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -461,13 +461,13 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): if eval_dataset is not None: eval_ds = eval_dataset if cfg.max_eval_samples > 0 and hasattr(eval_dataset, "hf_dataset"): - task_indices = eval_dataset.hf_dataset["task_index"] - unique_tasks = sorted(set(task_indices)) + task_arr = eval_dataset.hf_dataset.data.column("task_index").to_numpy() + unique_tasks = sorted(set(task_arr.tolist())) per_task = max(1, cfg.max_eval_samples // len(unique_tasks)) selected: list[int] = [] for t in unique_tasks: - frames = [i for i, ti in enumerate(task_indices) if ti == t][:per_task] - selected.extend(frames) + frames = (task_arr == t).nonzero()[0][:per_task] + selected.extend(frames.tolist()) eval_ds = torch.utils.data.Subset(eval_dataset, selected) eval_collate_fn = lerobot_collate_fn if dataset.meta.has_language_columns else None @@ -479,6 +479,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): pin_memory=device.type == "cuda", drop_last=False, collate_fn=eval_collate_fn, + prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None, + persistent_workers=cfg.persistent_workers and cfg.num_workers > 0, ) # Prepare everything with accelerator