add safethread support

This commit is contained in:
Jade Choghari (jchoghar)
2025-08-25 14:52:35 -04:00
parent 4be3942cbc
commit ff861ba869
3 changed files with 43 additions and 25 deletions
+2 -2
View File
@@ -32,7 +32,7 @@ MAX_NUM_IMAGES=2
MAX_IMAGE_DIM=1024 MAX_IMAGE_DIM=1024
unset LEROBOT_HOME unset LEROBOT_HOME
unset HF_LEROBOT_HOME unset HF_LEROBOT_HOME
export MUJOCO_GL=egl
echo -e "\033[1;33m[WARNING]\033[0m LIBERO is not yet fully supported in this PR!" echo -e "\033[1;33m[WARNING]\033[0m LIBERO is not yet fully supported in this PR!"
# launch # launch
@@ -48,6 +48,6 @@ python src/lerobot/scripts/train.py \
--save_freq=$SAVE_FREQ \ --save_freq=$SAVE_FREQ \
--num_workers=$NUM_WORKERS \ --num_workers=$NUM_WORKERS \
--policy.repo_id=$VLM_REPO_ID \ --policy.repo_id=$VLM_REPO_ID \
--env.multitask_eval=True \ --env.multitask_eval=False \
--eval.batch_size=1 \ --eval.batch_size=1 \
--eval.n_episodes=1 \ --eval.n_episodes=1 \
+2 -2
View File
@@ -8,12 +8,12 @@ TASK=libero_object
ENV_TYPE="libero" ENV_TYPE="libero"
BATCH_SIZE=1 BATCH_SIZE=1
N_EPISODES=1 N_EPISODES=1
export MUJOCO_GL=egl
# RUN EVALUATION # RUN EVALUATION
python src/lerobot/scripts/eval.py \ python src/lerobot/scripts/eval.py \
--policy.path="$POLICY_PATH" \ --policy.path="$POLICY_PATH" \
--env.type="$ENV_TYPE" \ --env.type="$ENV_TYPE" \
--eval.batch_size="$BATCH_SIZE" \ --eval.batch_size="$BATCH_SIZE" \
--eval.n_episodes="$N_EPISODES" \ --eval.n_episodes="$N_EPISODES" \
--env.multitask_eval=True \ --env.multitask_eval=False \
--env.task=$TASK \ --env.task=$TASK \
+39 -21
View File
@@ -573,31 +573,49 @@ def eval_policy_multitask(
"video_paths": task_result.get("video_paths", []), "video_paths": task_result.get("video_paths", []),
} }
with concurrent.futures.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor: task_group_results = {}
future_to_task = { if max_parallel_tasks == 1:
executor.submit(eval_task, task_group, task_id, env): (task_group, task_id) # sequential mode (safe for colab / EGL)
for task_group, tasks in envs.items() for task_group, tasks in envs.items():
for task_id, env in tasks.items() for task_id, env in tasks.items():
} task_result = eval_task(task_group, task_id, env)
if task_group not in task_group_results:
task_group_results[task_group] = {
"sum_rewards": [],
"max_rewards": [],
"successes": [],
"video_paths": [],
}
task_group_results[task_group]["sum_rewards"].extend(task_result["sum_rewards"])
task_group_results[task_group]["max_rewards"].extend(task_result["max_rewards"])
task_group_results[task_group]["successes"].extend(task_result["successes"])
task_group_results[task_group]["video_paths"].extend(task_result["video_paths"])
else:
with concurrent.futures.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor:
future_to_task = {
executor.submit(eval_task, task_group, task_id, env): (task_group, task_id)
for task_group, tasks in envs.items()
for task_id, env in tasks.items()
}
task_group_results = {} task_group_results = {}
for future in concurrent.futures.as_completed(future_to_task): for future in concurrent.futures.as_completed(future_to_task):
task_result = future.result() task_result = future.result()
task_group = task_result["task_group"] task_group = task_result["task_group"]
if task_group not in task_group_results: if task_group not in task_group_results:
task_group_results[task_group] = { task_group_results[task_group] = {
"sum_rewards": [], "sum_rewards": [],
"max_rewards": [], "max_rewards": [],
"successes": [], "successes": [],
"video_paths": [], "video_paths": [],
} }
task_group_results[task_group]["sum_rewards"].extend(task_result["sum_rewards"]) task_group_results[task_group]["sum_rewards"].extend(task_result["sum_rewards"])
task_group_results[task_group]["max_rewards"].extend(task_result["max_rewards"]) task_group_results[task_group]["max_rewards"].extend(task_result["max_rewards"])
task_group_results[task_group]["successes"].extend(task_result["successes"]) task_group_results[task_group]["successes"].extend(task_result["successes"])
task_group_results[task_group]["video_paths"].extend(task_result["video_paths"]) task_group_results[task_group]["video_paths"].extend(task_result["video_paths"])
# Process results per task group # Process results per task group
for task_group, data in task_group_results.items(): for task_group, data in task_group_results.items():