From fe05e5095ba7161ed16da819964799536c6d2a7d Mon Sep 17 00:00:00 2001 From: Pepijn Date: Wed, 8 Apr 2026 14:32:51 +0200 Subject: [PATCH] refactor(envs): move _LazyAsyncVectorEnv to utils and apply to metaworld _LazyAsyncVectorEnv lived in libero.py but metaworld had the same OOM problem: all tasks' AsyncVectorEnv workers were spawned eagerly, wasting GPU memory for tasks not yet running. Move the class to envs/utils.py so both environments share it, then apply the same is_async + lazy wrapping pattern in create_metaworld_envs. Co-Authored-By: Claude Sonnet 4.6 --- src/lerobot/envs/libero.py | 57 +--------------------------------- src/lerobot/envs/metaworld.py | 13 +++++++- src/lerobot/envs/utils.py | 58 ++++++++++++++++++++++++++++++++++- 3 files changed, 70 insertions(+), 58 deletions(-) diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py index 7b1e8efe0..1b814db52 100644 --- a/src/lerobot/envs/libero.py +++ b/src/lerobot/envs/libero.py @@ -29,6 +29,7 @@ from gymnasium import spaces from libero.libero import benchmark, get_libero_path from libero.libero.envs import OffScreenRenderEnv +from lerobot.envs.utils import _LazyAsyncVectorEnv from lerobot.types import RobotObservation @@ -403,62 +404,6 @@ def _make_env_fns( return fns -class _LazyAsyncVectorEnv: - """Wrapper that defers AsyncVectorEnv creation until first use. - - Creating all tasks' AsyncVectorEnvs upfront spawns N_tasks × n_envs worker - processes, all of which allocate EGL/GPU resources immediately. Since tasks - are evaluated sequentially, only one task's workers need to be alive at a - time. This wrapper stores the factory functions and creates the real - AsyncVectorEnv on first reset(), keeping peak process count = n_envs. - """ - - def __init__( - self, - env_fns: list[Callable], - observation_space: spaces.Space | None = None, - action_space: spaces.Space | None = None, - ): - self._env_fns = env_fns - self._env: gym.vector.AsyncVectorEnv | None = None - self.num_envs = len(env_fns) - if observation_space is not None and action_space is not None: - self.observation_space = observation_space - self.action_space = action_space - else: - tmp = env_fns[0]() - self.observation_space = tmp.observation_space - self.action_space = tmp.action_space - tmp.close() - self.single_observation_space = self.observation_space - self.single_action_space = self.action_space - - def _ensure(self): - if self._env is None: - self._env = gym.vector.AsyncVectorEnv(self._env_fns, context="forkserver", shared_memory=True) - - def reset(self, **kwargs): - self._ensure() - return self._env.reset(**kwargs) - - def step(self, actions): - self._ensure() - return self._env.step(actions) - - def call(self, name, *args, **kwargs): - self._ensure() - return self._env.call(name, *args, **kwargs) - - def get_attr(self, name): - self._ensure() - return self._env.get_attr(name) - - def close(self): - if self._env is not None: - self._env.close() - self._env = None - - # ---- Main API ---------------------------------------------------------------- diff --git a/src/lerobot/envs/metaworld.py b/src/lerobot/envs/metaworld.py index 273251312..49c775957 100644 --- a/src/lerobot/envs/metaworld.py +++ b/src/lerobot/envs/metaworld.py @@ -25,6 +25,7 @@ import metaworld.policies as policies import numpy as np from gymnasium import spaces +from lerobot.envs.utils import _LazyAsyncVectorEnv from lerobot.types import RobotObservation # ---- Load configuration data from the external JSON file ---- @@ -306,6 +307,9 @@ def create_metaworld_envs( print(f"Creating Meta-World envs | task_groups={task_groups} | n_envs(per task)={n_envs}") + is_async = env_cls is gym.vector.AsyncVectorEnv + cached_obs_space = None + cached_act_space = None out: dict[str, dict[int, Any]] = defaultdict(dict) for group in task_groups: @@ -318,7 +322,14 @@ def create_metaworld_envs( # build n_envs factories fns = [(lambda tn=task_name: MetaworldEnv(task=tn, **gym_kwargs)) for _ in range(n_envs)] - out[group][tid] = env_cls(fns) + if is_async: + lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space) + if cached_obs_space is None: + cached_obs_space = lazy.observation_space + cached_act_space = lazy.action_space + out[group][tid] = lazy + else: + out[group][tid] = env_cls(fns) # return a plain dict for consistency return {group: dict(task_map) for group, task_map in out.items()} diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index dfaaa3e3e..b47146325 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -16,7 +16,7 @@ import importlib.util import os import warnings -from collections.abc import Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence from functools import singledispatch from typing import Any @@ -138,6 +138,62 @@ def _sub_env_has_attr(env: gym.vector.VectorEnv, attr: str) -> bool: return False +class _LazyAsyncVectorEnv: + """Defers AsyncVectorEnv creation until first use. + + Creating all tasks' AsyncVectorEnvs upfront spawns N_tasks × n_envs worker + processes, all of which allocate EGL/GPU resources immediately. Since tasks + are evaluated sequentially, only one task's workers need to be alive at a + time. This wrapper stores the factory functions and creates the real + AsyncVectorEnv on first reset()/step()/call(), keeping peak process count = n_envs. + """ + + def __init__( + self, + env_fns: list[Callable], + observation_space=None, + action_space=None, + ): + self._env_fns = env_fns + self._env: gym.vector.AsyncVectorEnv | None = None + self.num_envs = len(env_fns) + if observation_space is not None and action_space is not None: + self.observation_space = observation_space + self.action_space = action_space + else: + tmp = env_fns[0]() + self.observation_space = tmp.observation_space + self.action_space = tmp.action_space + tmp.close() + self.single_observation_space = self.observation_space + self.single_action_space = self.action_space + + def _ensure(self) -> None: + if self._env is None: + self._env = gym.vector.AsyncVectorEnv(self._env_fns, context="forkserver", shared_memory=True) + + def reset(self, **kwargs): + self._ensure() + return self._env.reset(**kwargs) + + def step(self, actions): + self._ensure() + return self._env.step(actions) + + def call(self, name, *args, **kwargs): + self._ensure() + return self._env.call(name, *args, **kwargs) + + def get_attr(self, name): + self._ensure() + return self._env.get_attr(name) + + def close(self) -> None: + if self._env is not None: + self._env.close() + self._env = None + + def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None: with warnings.catch_warnings(): warnings.simplefilter("once", UserWarning)