mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
add safethread support
This commit is contained in:
@@ -32,7 +32,7 @@ MAX_NUM_IMAGES=2
|
|||||||
MAX_IMAGE_DIM=1024
|
MAX_IMAGE_DIM=1024
|
||||||
unset LEROBOT_HOME
|
unset LEROBOT_HOME
|
||||||
unset HF_LEROBOT_HOME
|
unset HF_LEROBOT_HOME
|
||||||
|
export MUJOCO_GL=egl
|
||||||
echo -e "\033[1;33m[WARNING]\033[0m LIBERO is not yet fully supported in this PR!"
|
echo -e "\033[1;33m[WARNING]\033[0m LIBERO is not yet fully supported in this PR!"
|
||||||
|
|
||||||
# launch
|
# launch
|
||||||
@@ -48,6 +48,6 @@ python src/lerobot/scripts/train.py \
|
|||||||
--save_freq=$SAVE_FREQ \
|
--save_freq=$SAVE_FREQ \
|
||||||
--num_workers=$NUM_WORKERS \
|
--num_workers=$NUM_WORKERS \
|
||||||
--policy.repo_id=$VLM_REPO_ID \
|
--policy.repo_id=$VLM_REPO_ID \
|
||||||
--env.multitask_eval=True \
|
--env.multitask_eval=False \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
|
|||||||
@@ -8,12 +8,12 @@ TASK=libero_object
|
|||||||
ENV_TYPE="libero"
|
ENV_TYPE="libero"
|
||||||
BATCH_SIZE=1
|
BATCH_SIZE=1
|
||||||
N_EPISODES=1
|
N_EPISODES=1
|
||||||
|
export MUJOCO_GL=egl
|
||||||
# RUN EVALUATION
|
# RUN EVALUATION
|
||||||
python src/lerobot/scripts/eval.py \
|
python src/lerobot/scripts/eval.py \
|
||||||
--policy.path="$POLICY_PATH" \
|
--policy.path="$POLICY_PATH" \
|
||||||
--env.type="$ENV_TYPE" \
|
--env.type="$ENV_TYPE" \
|
||||||
--eval.batch_size="$BATCH_SIZE" \
|
--eval.batch_size="$BATCH_SIZE" \
|
||||||
--eval.n_episodes="$N_EPISODES" \
|
--eval.n_episodes="$N_EPISODES" \
|
||||||
--env.multitask_eval=True \
|
--env.multitask_eval=False \
|
||||||
--env.task=$TASK \
|
--env.task=$TASK \
|
||||||
|
|||||||
+39
-21
@@ -573,31 +573,49 @@ def eval_policy_multitask(
|
|||||||
"video_paths": task_result.get("video_paths", []),
|
"video_paths": task_result.get("video_paths", []),
|
||||||
}
|
}
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor:
|
task_group_results = {}
|
||||||
future_to_task = {
|
if max_parallel_tasks == 1:
|
||||||
executor.submit(eval_task, task_group, task_id, env): (task_group, task_id)
|
# sequential mode (safe for colab / EGL)
|
||||||
for task_group, tasks in envs.items()
|
for task_group, tasks in envs.items():
|
||||||
for task_id, env in tasks.items()
|
for task_id, env in tasks.items():
|
||||||
}
|
task_result = eval_task(task_group, task_id, env)
|
||||||
|
if task_group not in task_group_results:
|
||||||
|
task_group_results[task_group] = {
|
||||||
|
"sum_rewards": [],
|
||||||
|
"max_rewards": [],
|
||||||
|
"successes": [],
|
||||||
|
"video_paths": [],
|
||||||
|
}
|
||||||
|
task_group_results[task_group]["sum_rewards"].extend(task_result["sum_rewards"])
|
||||||
|
task_group_results[task_group]["max_rewards"].extend(task_result["max_rewards"])
|
||||||
|
task_group_results[task_group]["successes"].extend(task_result["successes"])
|
||||||
|
task_group_results[task_group]["video_paths"].extend(task_result["video_paths"])
|
||||||
|
else:
|
||||||
|
with concurrent.futures.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor:
|
||||||
|
future_to_task = {
|
||||||
|
executor.submit(eval_task, task_group, task_id, env): (task_group, task_id)
|
||||||
|
for task_group, tasks in envs.items()
|
||||||
|
for task_id, env in tasks.items()
|
||||||
|
}
|
||||||
|
|
||||||
task_group_results = {}
|
task_group_results = {}
|
||||||
|
|
||||||
for future in concurrent.futures.as_completed(future_to_task):
|
for future in concurrent.futures.as_completed(future_to_task):
|
||||||
task_result = future.result()
|
task_result = future.result()
|
||||||
task_group = task_result["task_group"]
|
task_group = task_result["task_group"]
|
||||||
|
|
||||||
if task_group not in task_group_results:
|
if task_group not in task_group_results:
|
||||||
task_group_results[task_group] = {
|
task_group_results[task_group] = {
|
||||||
"sum_rewards": [],
|
"sum_rewards": [],
|
||||||
"max_rewards": [],
|
"max_rewards": [],
|
||||||
"successes": [],
|
"successes": [],
|
||||||
"video_paths": [],
|
"video_paths": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
task_group_results[task_group]["sum_rewards"].extend(task_result["sum_rewards"])
|
task_group_results[task_group]["sum_rewards"].extend(task_result["sum_rewards"])
|
||||||
task_group_results[task_group]["max_rewards"].extend(task_result["max_rewards"])
|
task_group_results[task_group]["max_rewards"].extend(task_result["max_rewards"])
|
||||||
task_group_results[task_group]["successes"].extend(task_result["successes"])
|
task_group_results[task_group]["successes"].extend(task_result["successes"])
|
||||||
task_group_results[task_group]["video_paths"].extend(task_result["video_paths"])
|
task_group_results[task_group]["video_paths"].extend(task_result["video_paths"])
|
||||||
|
|
||||||
# Process results per task group
|
# Process results per task group
|
||||||
for task_group, data in task_group_results.items():
|
for task_group, data in task_group_results.items():
|
||||||
|
|||||||
Reference in New Issue
Block a user