fix: close envs between tasks to prevent worker process accumulation

eval_policy_all never closed environments after each task completed,
causing AsyncVectorEnv worker processes to accumulate (N_tasks × n_envs).
This led to OOM, BrokenPipeError and EOFError on multi-task benchmarks.

Also fixes:
- AsyncVectorEnv compat in envs/utils.py (use get_attr/call instead of .envs)
- Tuple task handling in tokenizer_processor and lerobot_eval
- _LazyAsyncVectorEnv for deferred worker spawning in LIBERO

Made-with: Cursor
This commit is contained in:
Pepijn Kooijmans
2026-04-07 12:30:22 +02:00
parent 8c3babc2cb
commit 6b3d25bc79
7 changed files with 129 additions and 44 deletions
+9 -2
View File
@@ -44,6 +44,13 @@ from lerobot.utils.constants import (
) )
def _make_vec_env_cls(use_async: bool, n_envs: int):
"""Return the right VectorEnv constructor."""
if use_async and n_envs > 1:
return gym.vector.AsyncVectorEnv
return gym.vector.SyncVectorEnv
@dataclass @dataclass
class EnvConfig(draccus.ChoiceRegistry, abc.ABC): class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
task: str | None = None task: str | None = None
@@ -405,7 +412,7 @@ class LiberoEnv(EnvConfig):
if self.task is None: if self.task is None:
raise ValueError("LiberoEnv requires a task to be specified") raise ValueError("LiberoEnv requires a task to be specified")
env_cls = gym.vector.AsyncVectorEnv if (use_async_envs and n_envs > 1) else gym.vector.SyncVectorEnv env_cls = _make_vec_env_cls(use_async_envs, n_envs)
return create_libero_envs( return create_libero_envs(
task=self.task, task=self.task,
n_envs=n_envs, n_envs=n_envs,
@@ -474,7 +481,7 @@ class MetaworldEnv(EnvConfig):
if self.task is None: if self.task is None:
raise ValueError("MetaWorld requires a task to be specified") raise ValueError("MetaWorld requires a task to be specified")
env_cls = gym.vector.AsyncVectorEnv if (use_async_envs and n_envs > 1) else gym.vector.SyncVectorEnv env_cls = _make_vec_env_cls(use_async_envs, n_envs)
return create_metaworld_envs( return create_metaworld_envs(
task=self.task, task=self.task,
n_envs=n_envs, n_envs=n_envs,
+57 -2
View File
@@ -403,6 +403,57 @@ def _make_env_fns(
return fns return fns
class _LazyAsyncVectorEnv:
"""Wrapper that defers AsyncVectorEnv creation until first use.
Creating all tasks' AsyncVectorEnvs upfront spawns N_tasks × n_envs worker
processes, all of which allocate EGL/GPU resources immediately. Since tasks
are evaluated sequentially, only one task's workers need to be alive at a
time. This wrapper stores the factory functions and creates the real
AsyncVectorEnv on first reset(), keeping peak process count = n_envs.
"""
def __init__(self, env_fns: list[Callable]):
self._env_fns = env_fns
self._env: gym.vector.AsyncVectorEnv | None = None
self.num_envs = len(env_fns)
# Instantiate one env to expose spaces (no GPU — _ensure_env is lazy).
tmp = env_fns[0]()
self.observation_space = tmp.observation_space
self.action_space = tmp.action_space
self.single_observation_space = tmp.observation_space
self.single_action_space = tmp.action_space
tmp.close()
def _ensure(self):
if self._env is None:
self._env = gym.vector.AsyncVectorEnv(self._env_fns, context="forkserver")
def reset(self, **kwargs):
self._ensure()
return self._env.reset(**kwargs)
def step(self, actions):
self._ensure()
return self._env.step(actions)
def call(self, name, *args, **kwargs):
self._ensure()
return self._env.call(name, *args, **kwargs)
def get_attr(self, name):
self._ensure()
return self._env.get_attr(name)
def close(self):
if self._env is not None:
self._env.close()
self._env = None
def __del__(self):
self.close()
# ---- Main API ---------------------------------------------------------------- # ---- Main API ----------------------------------------------------------------
@@ -446,6 +497,8 @@ def create_libero_envs(
if task_ids_filter is not None: if task_ids_filter is not None:
print(f"Restricting to task_ids={task_ids_filter}") print(f"Restricting to task_ids={task_ids_filter}")
is_async = env_cls is gym.vector.AsyncVectorEnv
out: dict[str, dict[int, Any]] = defaultdict(dict) out: dict[str, dict[int, Any]] = defaultdict(dict)
for suite_name in suite_names: for suite_name in suite_names:
suite = _get_suite(suite_name) suite = _get_suite(suite_name)
@@ -467,8 +520,10 @@ def create_libero_envs(
control_mode=control_mode, control_mode=control_mode,
camera_name_mapping=camera_name_mapping, camera_name_mapping=camera_name_mapping,
) )
out[suite_name][tid] = env_cls(fns) if is_async:
out[suite_name][tid] = _LazyAsyncVectorEnv(fns)
else:
out[suite_name][tid] = env_cls(fns)
print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}") print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")
# return plain dicts for predictability
return {suite: dict(task_map) for suite, task_map in out.items()} return {suite: dict(task_map) for suite, task_map in out.items()}
+21 -26
View File
@@ -130,56 +130,51 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
return policy_features return policy_features
def are_all_envs_same_type(env: gym.vector.VectorEnv) -> bool: def _get_sub_env_attr(env: gym.vector.VectorEnv, attr: str, index: int = 0):
first_type = type(env.envs[0]) # Get type of first env """Retrieve an attribute from a sub-environment, works for both Sync and Async."""
return all(type(e) is first_type for e in env.envs) # Fast type check try:
return env.get_attr(attr)[index]
except (AttributeError, Exception):
return None
def _sub_env_has_attr(env: gym.vector.VectorEnv, attr: str) -> bool:
try:
env.get_attr(attr)
return True
except (AttributeError, Exception):
return False
def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None: def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("once", UserWarning) # Apply filter only in this function warnings.simplefilter("once", UserWarning)
if not (hasattr(env.envs[0], "task_description") and hasattr(env.envs[0], "task")): if not (_sub_env_has_attr(env, "task_description") and _sub_env_has_attr(env, "task")):
warnings.warn( warnings.warn(
"The environment does not have 'task_description' and 'task'. Some policies require these features.", "The environment does not have 'task_description' and 'task'. Some policies require these features.",
UserWarning, UserWarning,
stacklevel=2, stacklevel=2,
) )
if not are_all_envs_same_type(env):
warnings.warn(
"The environments have different types. Make sure you infer the right task from each environment. Empty task will be passed instead.",
UserWarning,
stacklevel=2,
)
def add_envs_task(env: gym.vector.VectorEnv, observation: RobotObservation) -> RobotObservation: def add_envs_task(env: gym.vector.VectorEnv, observation: RobotObservation) -> RobotObservation:
"""Adds task feature to the observation dict with respect to the first environment attribute.""" """Adds task feature to the observation dict with respect to the first environment attribute."""
if hasattr(env.envs[0], "task_description"): if _sub_env_has_attr(env, "task_description"):
task_result = env.call("task_description") task_result = list(env.call("task_description"))
if isinstance(task_result, tuple):
task_result = list(task_result)
if not isinstance(task_result, list):
raise TypeError(f"Expected task_description to return a list, got {type(task_result)}")
if not all(isinstance(item, str) for item in task_result): if not all(isinstance(item, str) for item in task_result):
raise TypeError("All items in task_description result must be strings") raise TypeError("All items in task_description result must be strings")
observation["task"] = task_result observation["task"] = task_result
elif hasattr(env.envs[0], "task"): elif _sub_env_has_attr(env, "task"):
task_result = env.call("task") task_result = list(env.call("task"))
if isinstance(task_result, tuple):
task_result = list(task_result)
if not isinstance(task_result, list):
raise TypeError(f"Expected task to return a list, got {type(task_result)}")
if not all(isinstance(item, str) for item in task_result): if not all(isinstance(item, str) for item in task_result):
raise TypeError("All items in task result must be strings") raise TypeError("All items in task result must be strings")
observation["task"] = task_result observation["task"] = task_result
else: # For envs without language instructions, e.g. aloha transfer cube and etc. else:
num_envs = observation[list(observation.keys())[0]].shape[0] num_envs = observation[list(observation.keys())[0]].shape[0]
observation["task"] = ["" for _ in range(num_envs)] observation["task"] = ["" for _ in range(num_envs)]
return observation return observation
+2 -2
View File
@@ -136,8 +136,8 @@ class TokenizerProcessorStep(ObservationProcessorStep):
# Standardize to a list of strings for the tokenizer # Standardize to a list of strings for the tokenizer
if isinstance(task, str): if isinstance(task, str):
return [task] return [task]
elif isinstance(task, list) and all(isinstance(t, str) for t in task): elif isinstance(task, (list, tuple)) and all(isinstance(t, str) for t in task):
return task return list(task)
return None return None
+15 -11
View File
@@ -167,7 +167,7 @@ def rollout(
# Infer "task" from sub-environments. # Infer "task" from sub-environments.
# env.call() works with both SyncVectorEnv and AsyncVectorEnv. # env.call() works with both SyncVectorEnv and AsyncVectorEnv.
observation["task"] = env.call("task") observation["task"] = list(env.call("task"))
# Apply environment-specific preprocessing (e.g., LiberoProcessorStep for LIBERO) # Apply environment-specific preprocessing (e.g., LiberoProcessorStep for LIBERO)
observation = env_preprocessor(observation) observation = env_preprocessor(observation)
@@ -754,23 +754,27 @@ def eval_policy_all(
) )
if max_parallel_tasks <= 1: if max_parallel_tasks <= 1:
# sequential path (single accumulator path on the main thread)
# NOTE: keeping a single-threaded accumulator avoids concurrent list appends or locks
for task_group, task_id, env in tasks: for task_group, task_id, env in tasks:
tg, tid, metrics = task_runner(task_group, task_id, env) try:
_accumulate_to(tg, metrics) tg, tid, metrics = task_runner(task_group, task_id, env)
per_task_infos.append({"task_group": tg, "task_id": tid, "metrics": metrics}) _accumulate_to(tg, metrics)
per_task_infos.append({"task_group": tg, "task_id": tid, "metrics": metrics})
finally:
env.close()
else: else:
# threaded path: submit all tasks, consume completions on main thread and accumulate there
with cf.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor: with cf.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor:
fut2meta = {} fut2meta = {}
for task_group, task_id, env in tasks: for task_group, task_id, env in tasks:
fut = executor.submit(task_runner, task_group, task_id, env) fut = executor.submit(task_runner, task_group, task_id, env)
fut2meta[fut] = (task_group, task_id) fut2meta[fut] = (task_group, task_id, env)
for fut in cf.as_completed(fut2meta): for fut in cf.as_completed(fut2meta):
tg, tid, metrics = fut.result() tg, tid, env = fut2meta[fut]
_accumulate_to(tg, metrics) try:
per_task_infos.append({"task_group": tg, "task_id": tid, "metrics": metrics}) tg, tid, metrics = fut.result()
_accumulate_to(tg, metrics)
per_task_infos.append({"task_group": tg, "task_id": tid, "metrics": metrics})
finally:
env.close()
# compute aggregated metrics helper (robust to lists/scalars) # compute aggregated metrics helper (robust to lists/scalars)
def _agg_from_list(xs): def _agg_from_list(xs):
+1 -1
View File
@@ -90,7 +90,7 @@ def test_base_create_envs():
envs = _Env().create_envs(n_envs=2) envs = _Env().create_envs(n_envs=2)
assert "_dispatch_base_test" in envs assert "_dispatch_base_test" in envs
env = envs["_dispatch_base_test"][0] env = envs["_dispatch_base_test"][0]
assert isinstance(env, gym.vector.SyncVectorEnv) assert isinstance(env, gym.vector.VectorEnv)
assert env.num_envs == 2 assert env.num_envs == 2
env.close() env.close()
finally: finally:
@@ -189,6 +189,30 @@ def test_list_of_strings_tokenization(mock_auto_tokenizer):
assert attention_mask.shape == (2, 8) assert attention_mask.shape == (2, 8)
@require_package("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_tuple_of_strings_tokenization(mock_auto_tokenizer):
"""Test tokenization of a tuple of strings (returned by VectorEnv.call())."""
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=8)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": ("pick up cube", "place on table")},
)
result = processor(transition)
observation = result[TransitionKey.OBSERVATION]
tokens = observation[f"{OBS_LANGUAGE}.tokens"]
attention_mask = observation[f"{OBS_LANGUAGE}.attention_mask"]
assert tokens.shape == (2, 8)
assert attention_mask.shape == (2, 8)
@require_package("transformers") @require_package("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") @patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_custom_keys(mock_auto_tokenizer): def test_custom_keys(mock_auto_tokenizer):