mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
fix(train): vectorize eval subset selection for max_eval_samples
This commit is contained in:
@@ -461,13 +461,13 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
if eval_dataset is not None:
|
||||
eval_ds = eval_dataset
|
||||
if cfg.max_eval_samples > 0 and hasattr(eval_dataset, "hf_dataset"):
|
||||
task_indices = eval_dataset.hf_dataset["task_index"]
|
||||
unique_tasks = sorted(set(task_indices))
|
||||
task_arr = eval_dataset.hf_dataset.data.column("task_index").to_numpy()
|
||||
unique_tasks = sorted(set(task_arr.tolist()))
|
||||
per_task = max(1, cfg.max_eval_samples // len(unique_tasks))
|
||||
selected: list[int] = []
|
||||
for t in unique_tasks:
|
||||
frames = [i for i, ti in enumerate(task_indices) if ti == t][:per_task]
|
||||
selected.extend(frames)
|
||||
frames = (task_arr == t).nonzero()[0][:per_task]
|
||||
selected.extend(frames.tolist())
|
||||
eval_ds = torch.utils.data.Subset(eval_dataset, selected)
|
||||
|
||||
eval_collate_fn = lerobot_collate_fn if dataset.meta.has_language_columns else None
|
||||
@@ -479,6 +479,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
pin_memory=device.type == "cuda",
|
||||
drop_last=False,
|
||||
collate_fn=eval_collate_fn,
|
||||
prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
|
||||
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
|
||||
)
|
||||
|
||||
# Prepare everything with accelerator
|
||||
|
||||
Reference in New Issue
Block a user