From ff861ba86978b834ce2ac611e86ec4e185f213f3 Mon Sep 17 00:00:00 2001 From: "Jade Choghari (jchoghar)" Date: Mon, 25 Aug 2025 14:52:35 -0400 Subject: [PATCH] add safethread support --- examples/5_train_libero.sh | 4 +-- examples/6_evaluate_libero.sh | 4 +-- src/lerobot/scripts/eval.py | 60 +++++++++++++++++++++++------------ 3 files changed, 43 insertions(+), 25 deletions(-) diff --git a/examples/5_train_libero.sh b/examples/5_train_libero.sh index bfaf6a331..93f6d76d8 100755 --- a/examples/5_train_libero.sh +++ b/examples/5_train_libero.sh @@ -32,7 +32,7 @@ MAX_NUM_IMAGES=2 MAX_IMAGE_DIM=1024 unset LEROBOT_HOME unset HF_LEROBOT_HOME - +export MUJOCO_GL=egl echo -e "\033[1;33m[WARNING]\033[0m LIBERO is not yet fully supported in this PR!" # launch @@ -48,6 +48,6 @@ python src/lerobot/scripts/train.py \ --save_freq=$SAVE_FREQ \ --num_workers=$NUM_WORKERS \ --policy.repo_id=$VLM_REPO_ID \ - --env.multitask_eval=True \ + --env.multitask_eval=False \ --eval.batch_size=1 \ --eval.n_episodes=1 \ diff --git a/examples/6_evaluate_libero.sh b/examples/6_evaluate_libero.sh index 2552e4602..fe994f645 100644 --- a/examples/6_evaluate_libero.sh +++ b/examples/6_evaluate_libero.sh @@ -8,12 +8,12 @@ TASK=libero_object ENV_TYPE="libero" BATCH_SIZE=1 N_EPISODES=1 - +export MUJOCO_GL=egl # RUN EVALUATION python src/lerobot/scripts/eval.py \ --policy.path="$POLICY_PATH" \ --env.type="$ENV_TYPE" \ --eval.batch_size="$BATCH_SIZE" \ --eval.n_episodes="$N_EPISODES" \ - --env.multitask_eval=True \ + --env.multitask_eval=False \ --env.task=$TASK \ diff --git a/src/lerobot/scripts/eval.py b/src/lerobot/scripts/eval.py index f52c8c70e..c76f1fbec 100644 --- a/src/lerobot/scripts/eval.py +++ b/src/lerobot/scripts/eval.py @@ -573,31 +573,49 @@ def eval_policy_multitask( "video_paths": task_result.get("video_paths", []), } - with concurrent.futures.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor: - future_to_task = { - executor.submit(eval_task, task_group, task_id, env): (task_group, task_id) - for task_group, tasks in envs.items() - for task_id, env in tasks.items() - } + task_group_results = {} + if max_parallel_tasks == 1: + # sequential mode (safe for colab / EGL) + for task_group, tasks in envs.items(): + for task_id, env in tasks.items(): + task_result = eval_task(task_group, task_id, env) + if task_group not in task_group_results: + task_group_results[task_group] = { + "sum_rewards": [], + "max_rewards": [], + "successes": [], + "video_paths": [], + } + task_group_results[task_group]["sum_rewards"].extend(task_result["sum_rewards"]) + task_group_results[task_group]["max_rewards"].extend(task_result["max_rewards"]) + task_group_results[task_group]["successes"].extend(task_result["successes"]) + task_group_results[task_group]["video_paths"].extend(task_result["video_paths"]) + else: + with concurrent.futures.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor: + future_to_task = { + executor.submit(eval_task, task_group, task_id, env): (task_group, task_id) + for task_group, tasks in envs.items() + for task_id, env in tasks.items() + } - task_group_results = {} + task_group_results = {} - for future in concurrent.futures.as_completed(future_to_task): - task_result = future.result() - task_group = task_result["task_group"] + for future in concurrent.futures.as_completed(future_to_task): + task_result = future.result() + task_group = task_result["task_group"] - if task_group not in task_group_results: - task_group_results[task_group] = { - "sum_rewards": [], - "max_rewards": [], - "successes": [], - "video_paths": [], - } + if task_group not in task_group_results: + task_group_results[task_group] = { + "sum_rewards": [], + "max_rewards": [], + "successes": [], + "video_paths": [], + } - task_group_results[task_group]["sum_rewards"].extend(task_result["sum_rewards"]) - task_group_results[task_group]["max_rewards"].extend(task_result["max_rewards"]) - task_group_results[task_group]["successes"].extend(task_result["successes"]) - task_group_results[task_group]["video_paths"].extend(task_result["video_paths"]) + task_group_results[task_group]["sum_rewards"].extend(task_result["sum_rewards"]) + task_group_results[task_group]["max_rewards"].extend(task_result["max_rewards"]) + task_group_results[task_group]["successes"].extend(task_result["successes"]) + task_group_results[task_group]["video_paths"].extend(task_result["video_paths"]) # Process results per task group for task_group, data in task_group_results.items():