refactor(envs): move _LazyAsyncVectorEnv to utils and apply to metaworld

_LazyAsyncVectorEnv lived in libero.py but metaworld had the same OOM
problem: all tasks' AsyncVectorEnv workers were spawned eagerly, wasting
GPU memory for tasks not yet running.

Move the class to envs/utils.py so both environments share it, then apply
the same is_async + lazy wrapping pattern in create_metaworld_envs.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-04-08 14:32:51 +02:00
parent 93b99e4c5d
commit fe05e5095b
3 changed files with 70 additions and 58 deletions
+1 -56
View File
@@ -29,6 +29,7 @@ from gymnasium import spaces
from libero.libero import benchmark, get_libero_path from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv from libero.libero.envs import OffScreenRenderEnv
from lerobot.envs.utils import _LazyAsyncVectorEnv
from lerobot.types import RobotObservation from lerobot.types import RobotObservation
@@ -403,62 +404,6 @@ def _make_env_fns(
return fns return fns
class _LazyAsyncVectorEnv:
"""Wrapper that defers AsyncVectorEnv creation until first use.
Creating all tasks' AsyncVectorEnvs upfront spawns N_tasks × n_envs worker
processes, all of which allocate EGL/GPU resources immediately. Since tasks
are evaluated sequentially, only one task's workers need to be alive at a
time. This wrapper stores the factory functions and creates the real
AsyncVectorEnv on first reset(), keeping peak process count = n_envs.
"""
def __init__(
self,
env_fns: list[Callable],
observation_space: spaces.Space | None = None,
action_space: spaces.Space | None = None,
):
self._env_fns = env_fns
self._env: gym.vector.AsyncVectorEnv | None = None
self.num_envs = len(env_fns)
if observation_space is not None and action_space is not None:
self.observation_space = observation_space
self.action_space = action_space
else:
tmp = env_fns[0]()
self.observation_space = tmp.observation_space
self.action_space = tmp.action_space
tmp.close()
self.single_observation_space = self.observation_space
self.single_action_space = self.action_space
def _ensure(self):
if self._env is None:
self._env = gym.vector.AsyncVectorEnv(self._env_fns, context="forkserver", shared_memory=True)
def reset(self, **kwargs):
self._ensure()
return self._env.reset(**kwargs)
def step(self, actions):
self._ensure()
return self._env.step(actions)
def call(self, name, *args, **kwargs):
self._ensure()
return self._env.call(name, *args, **kwargs)
def get_attr(self, name):
self._ensure()
return self._env.get_attr(name)
def close(self):
if self._env is not None:
self._env.close()
self._env = None
# ---- Main API ---------------------------------------------------------------- # ---- Main API ----------------------------------------------------------------
+12 -1
View File
@@ -25,6 +25,7 @@ import metaworld.policies as policies
import numpy as np import numpy as np
from gymnasium import spaces from gymnasium import spaces
from lerobot.envs.utils import _LazyAsyncVectorEnv
from lerobot.types import RobotObservation from lerobot.types import RobotObservation
# ---- Load configuration data from the external JSON file ---- # ---- Load configuration data from the external JSON file ----
@@ -306,6 +307,9 @@ def create_metaworld_envs(
print(f"Creating Meta-World envs | task_groups={task_groups} | n_envs(per task)={n_envs}") print(f"Creating Meta-World envs | task_groups={task_groups} | n_envs(per task)={n_envs}")
is_async = env_cls is gym.vector.AsyncVectorEnv
cached_obs_space = None
cached_act_space = None
out: dict[str, dict[int, Any]] = defaultdict(dict) out: dict[str, dict[int, Any]] = defaultdict(dict)
for group in task_groups: for group in task_groups:
@@ -318,7 +322,14 @@ def create_metaworld_envs(
# build n_envs factories # build n_envs factories
fns = [(lambda tn=task_name: MetaworldEnv(task=tn, **gym_kwargs)) for _ in range(n_envs)] fns = [(lambda tn=task_name: MetaworldEnv(task=tn, **gym_kwargs)) for _ in range(n_envs)]
out[group][tid] = env_cls(fns) if is_async:
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space)
if cached_obs_space is None:
cached_obs_space = lazy.observation_space
cached_act_space = lazy.action_space
out[group][tid] = lazy
else:
out[group][tid] = env_cls(fns)
# return a plain dict for consistency # return a plain dict for consistency
return {group: dict(task_map) for group, task_map in out.items()} return {group: dict(task_map) for group, task_map in out.items()}
+57 -1
View File
@@ -16,7 +16,7 @@
import importlib.util import importlib.util
import os import os
import warnings import warnings
from collections.abc import Mapping, Sequence from collections.abc import Callable, Mapping, Sequence
from functools import singledispatch from functools import singledispatch
from typing import Any from typing import Any
@@ -138,6 +138,62 @@ def _sub_env_has_attr(env: gym.vector.VectorEnv, attr: str) -> bool:
return False return False
class _LazyAsyncVectorEnv:
"""Defers AsyncVectorEnv creation until first use.
Creating all tasks' AsyncVectorEnvs upfront spawns N_tasks × n_envs worker
processes, all of which allocate EGL/GPU resources immediately. Since tasks
are evaluated sequentially, only one task's workers need to be alive at a
time. This wrapper stores the factory functions and creates the real
AsyncVectorEnv on first reset()/step()/call(), keeping peak process count = n_envs.
"""
def __init__(
self,
env_fns: list[Callable],
observation_space=None,
action_space=None,
):
self._env_fns = env_fns
self._env: gym.vector.AsyncVectorEnv | None = None
self.num_envs = len(env_fns)
if observation_space is not None and action_space is not None:
self.observation_space = observation_space
self.action_space = action_space
else:
tmp = env_fns[0]()
self.observation_space = tmp.observation_space
self.action_space = tmp.action_space
tmp.close()
self.single_observation_space = self.observation_space
self.single_action_space = self.action_space
def _ensure(self) -> None:
if self._env is None:
self._env = gym.vector.AsyncVectorEnv(self._env_fns, context="forkserver", shared_memory=True)
def reset(self, **kwargs):
self._ensure()
return self._env.reset(**kwargs)
def step(self, actions):
self._ensure()
return self._env.step(actions)
def call(self, name, *args, **kwargs):
self._ensure()
return self._env.call(name, *args, **kwargs)
def get_attr(self, name):
self._ensure()
return self._env.get_attr(name)
def close(self) -> None:
if self._env is not None:
self._env.close()
self._env = None
def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None: def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("once", UserWarning) warnings.simplefilter("once", UserWarning)