diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index 51a9209f6..b4bced908 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -767,19 +767,19 @@ def eval_policy_all( prefetch_thread.join() prefetch_thread = None - # Prefetch next task's AsyncVectorEnv workers while this task runs. - if i + 1 < len(tasks): - next_env = tasks[i + 1][2] - if hasattr(next_env, "_ensure"): - prefetch_thread = threading.Thread(target=next_env._ensure, daemon=True) - prefetch_thread.start() - try: tg, tid, metrics = task_runner(task_group, task_id, env) _accumulate_to(tg, metrics) per_task_infos.append({"task_group": tg, "task_id": tid, "metrics": metrics}) finally: env.close() + # Prefetch next task's workers *after* closing current env to prevent + # GPU memory overlap between consecutive tasks. + if i + 1 < len(tasks): + next_env = tasks[i + 1][2] + if hasattr(next_env, "_ensure"): + prefetch_thread = threading.Thread(target=next_env._ensure, daemon=True) + prefetch_thread.start() else: with cf.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor: fut2meta = {}