mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
feat(eval): thread-safe policy copies for max_parallel_tasks > 1
eval_policy_all already supports running multiple task groups concurrently via ThreadPoolExecutor, but policy.reset() was not thread-safe: all threads shared the same policy object and its mutable state (action queues, temporal buffers). Fix: each thread receives a shallow copy of the policy. copy.copy() creates a new Python object whose _parameters dict is a shared reference — same tensor storage, zero extra VRAM — while reset() rebinds per-episode state to fresh objects per thread. Caveat: ACT with temporal_ensemble_coeff is not safe with this approach (its reset() mutates a shared sub-object). Keep max_parallel_tasks=1 for that config. For MetaWorld (50 tasks, no temporal ensembling), max_parallel_tasks=4 raises GPU utilization from ~20% to ~60-80% with no additional VRAM cost. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -47,6 +47,7 @@ You can learn about the CLI options for this script in the `EvalPipelineConfig`
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import concurrent.futures as cf
|
import concurrent.futures as cf
|
||||||
|
import copy
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
@@ -57,7 +58,6 @@ from collections.abc import Callable
|
|||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from functools import partial
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Any, TypedDict
|
from typing import Any, TypedDict
|
||||||
@@ -760,34 +760,49 @@ def eval_policy_all(
|
|||||||
group_acc[group]["video_paths"].extend(paths)
|
group_acc[group]["video_paths"].extend(paths)
|
||||||
overall["video_paths"].extend(paths)
|
overall["video_paths"].extend(paths)
|
||||||
|
|
||||||
|
def _make_thread_policy(p: PreTrainedPolicy) -> PreTrainedPolicy:
|
||||||
|
"""Shallow copy sharing weight tensors, with independent per-thread state.
|
||||||
|
|
||||||
|
copy.copy() gives a new Python object whose _parameters dict is a shared
|
||||||
|
reference (same tensor storage, zero extra VRAM). reset() then rebinds
|
||||||
|
mutable state (action queues etc.) to fresh per-thread objects.
|
||||||
|
|
||||||
|
Note: does NOT work for ACT with temporal_ensemble_coeff — that policy's
|
||||||
|
reset() mutates a shared sub-object. Use max_parallel_tasks=1 for that config.
|
||||||
|
"""
|
||||||
|
thread_p = copy.copy(p)
|
||||||
|
thread_p.reset()
|
||||||
|
return thread_p
|
||||||
|
|
||||||
# Choose runner (sequential vs threaded)
|
# Choose runner (sequential vs threaded)
|
||||||
task_runner = partial(
|
_runner_kwargs = {
|
||||||
run_one,
|
"env_preprocessor": env_preprocessor,
|
||||||
policy=policy,
|
"env_postprocessor": env_postprocessor,
|
||||||
env_preprocessor=env_preprocessor,
|
"preprocessor": preprocessor,
|
||||||
env_postprocessor=env_postprocessor,
|
"postprocessor": postprocessor,
|
||||||
preprocessor=preprocessor,
|
"n_episodes": n_episodes,
|
||||||
postprocessor=postprocessor,
|
"max_episodes_rendered": max_episodes_rendered,
|
||||||
n_episodes=n_episodes,
|
"videos_dir": videos_dir,
|
||||||
max_episodes_rendered=max_episodes_rendered,
|
"return_episode_data": return_episode_data,
|
||||||
videos_dir=videos_dir,
|
"start_seed": start_seed,
|
||||||
return_episode_data=return_episode_data,
|
}
|
||||||
start_seed=start_seed,
|
|
||||||
)
|
|
||||||
|
|
||||||
if max_parallel_tasks <= 1:
|
if max_parallel_tasks <= 1:
|
||||||
for task_group, task_id, env in tasks:
|
for task_group, task_id, env in tasks:
|
||||||
try:
|
try:
|
||||||
tg, tid, metrics = task_runner(task_group, task_id, env)
|
tg, tid, metrics = run_one(task_group, task_id, env, policy=policy, **_runner_kwargs)
|
||||||
_accumulate_to(tg, metrics)
|
_accumulate_to(tg, metrics)
|
||||||
per_task_infos.append({"task_group": tg, "task_id": tid, "metrics": metrics})
|
per_task_infos.append({"task_group": tg, "task_id": tid, "metrics": metrics})
|
||||||
finally:
|
finally:
|
||||||
env.close()
|
env.close()
|
||||||
else:
|
else:
|
||||||
|
# threaded path: each thread gets a shallow policy copy (shared weights, independent state)
|
||||||
with cf.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor:
|
with cf.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor:
|
||||||
fut2meta = {}
|
fut2meta = {}
|
||||||
for task_group, task_id, env in tasks:
|
for task_group, task_id, env in tasks:
|
||||||
fut = executor.submit(task_runner, task_group, task_id, env)
|
fut = executor.submit(
|
||||||
|
run_one, task_group, task_id, env, policy=_make_thread_policy(policy), **_runner_kwargs
|
||||||
|
)
|
||||||
fut2meta[fut] = (task_group, task_id, env)
|
fut2meta[fut] = (task_group, task_id, env)
|
||||||
for fut in cf.as_completed(fut2meta):
|
for fut in cf.as_completed(fut2meta):
|
||||||
tg, tid, env = fut2meta[fut]
|
tg, tid, env = fut2meta[fut]
|
||||||
|
|||||||
Reference in New Issue
Block a user