mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
refactor(envs): move _LazyAsyncVectorEnv to utils and apply to metaworld
_LazyAsyncVectorEnv lived in libero.py but metaworld had the same OOM problem: all tasks' AsyncVectorEnv workers were spawned eagerly, wasting GPU memory for tasks not yet running. Move the class to envs/utils.py so both environments share it, then apply the same is_async + lazy wrapping pattern in create_metaworld_envs. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -29,6 +29,7 @@ from gymnasium import spaces
|
||||
from libero.libero import benchmark, get_libero_path
|
||||
from libero.libero.envs import OffScreenRenderEnv
|
||||
|
||||
from lerobot.envs.utils import _LazyAsyncVectorEnv
|
||||
from lerobot.types import RobotObservation
|
||||
|
||||
|
||||
@@ -403,62 +404,6 @@ def _make_env_fns(
|
||||
return fns
|
||||
|
||||
|
||||
class _LazyAsyncVectorEnv:
|
||||
"""Wrapper that defers AsyncVectorEnv creation until first use.
|
||||
|
||||
Creating all tasks' AsyncVectorEnvs upfront spawns N_tasks × n_envs worker
|
||||
processes, all of which allocate EGL/GPU resources immediately. Since tasks
|
||||
are evaluated sequentially, only one task's workers need to be alive at a
|
||||
time. This wrapper stores the factory functions and creates the real
|
||||
AsyncVectorEnv on first reset(), keeping peak process count = n_envs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env_fns: list[Callable],
|
||||
observation_space: spaces.Space | None = None,
|
||||
action_space: spaces.Space | None = None,
|
||||
):
|
||||
self._env_fns = env_fns
|
||||
self._env: gym.vector.AsyncVectorEnv | None = None
|
||||
self.num_envs = len(env_fns)
|
||||
if observation_space is not None and action_space is not None:
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
else:
|
||||
tmp = env_fns[0]()
|
||||
self.observation_space = tmp.observation_space
|
||||
self.action_space = tmp.action_space
|
||||
tmp.close()
|
||||
self.single_observation_space = self.observation_space
|
||||
self.single_action_space = self.action_space
|
||||
|
||||
def _ensure(self):
|
||||
if self._env is None:
|
||||
self._env = gym.vector.AsyncVectorEnv(self._env_fns, context="forkserver", shared_memory=True)
|
||||
|
||||
def reset(self, **kwargs):
|
||||
self._ensure()
|
||||
return self._env.reset(**kwargs)
|
||||
|
||||
def step(self, actions):
|
||||
self._ensure()
|
||||
return self._env.step(actions)
|
||||
|
||||
def call(self, name, *args, **kwargs):
|
||||
self._ensure()
|
||||
return self._env.call(name, *args, **kwargs)
|
||||
|
||||
def get_attr(self, name):
|
||||
self._ensure()
|
||||
return self._env.get_attr(name)
|
||||
|
||||
def close(self):
|
||||
if self._env is not None:
|
||||
self._env.close()
|
||||
self._env = None
|
||||
|
||||
|
||||
# ---- Main API ----------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ import metaworld.policies as policies
|
||||
import numpy as np
|
||||
from gymnasium import spaces
|
||||
|
||||
from lerobot.envs.utils import _LazyAsyncVectorEnv
|
||||
from lerobot.types import RobotObservation
|
||||
|
||||
# ---- Load configuration data from the external JSON file ----
|
||||
@@ -306,6 +307,9 @@ def create_metaworld_envs(
|
||||
|
||||
print(f"Creating Meta-World envs | task_groups={task_groups} | n_envs(per task)={n_envs}")
|
||||
|
||||
is_async = env_cls is gym.vector.AsyncVectorEnv
|
||||
cached_obs_space = None
|
||||
cached_act_space = None
|
||||
out: dict[str, dict[int, Any]] = defaultdict(dict)
|
||||
|
||||
for group in task_groups:
|
||||
@@ -318,7 +322,14 @@ def create_metaworld_envs(
|
||||
# build n_envs factories
|
||||
fns = [(lambda tn=task_name: MetaworldEnv(task=tn, **gym_kwargs)) for _ in range(n_envs)]
|
||||
|
||||
out[group][tid] = env_cls(fns)
|
||||
if is_async:
|
||||
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space)
|
||||
if cached_obs_space is None:
|
||||
cached_obs_space = lazy.observation_space
|
||||
cached_act_space = lazy.action_space
|
||||
out[group][tid] = lazy
|
||||
else:
|
||||
out[group][tid] = env_cls(fns)
|
||||
|
||||
# return a plain dict for consistency
|
||||
return {group: dict(task_map) for group, task_map in out.items()}
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
import importlib.util
|
||||
import os
|
||||
import warnings
|
||||
from collections.abc import Mapping, Sequence
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from functools import singledispatch
|
||||
from typing import Any
|
||||
|
||||
@@ -138,6 +138,62 @@ def _sub_env_has_attr(env: gym.vector.VectorEnv, attr: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class _LazyAsyncVectorEnv:
|
||||
"""Defers AsyncVectorEnv creation until first use.
|
||||
|
||||
Creating all tasks' AsyncVectorEnvs upfront spawns N_tasks × n_envs worker
|
||||
processes, all of which allocate EGL/GPU resources immediately. Since tasks
|
||||
are evaluated sequentially, only one task's workers need to be alive at a
|
||||
time. This wrapper stores the factory functions and creates the real
|
||||
AsyncVectorEnv on first reset()/step()/call(), keeping peak process count = n_envs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env_fns: list[Callable],
|
||||
observation_space=None,
|
||||
action_space=None,
|
||||
):
|
||||
self._env_fns = env_fns
|
||||
self._env: gym.vector.AsyncVectorEnv | None = None
|
||||
self.num_envs = len(env_fns)
|
||||
if observation_space is not None and action_space is not None:
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
else:
|
||||
tmp = env_fns[0]()
|
||||
self.observation_space = tmp.observation_space
|
||||
self.action_space = tmp.action_space
|
||||
tmp.close()
|
||||
self.single_observation_space = self.observation_space
|
||||
self.single_action_space = self.action_space
|
||||
|
||||
def _ensure(self) -> None:
|
||||
if self._env is None:
|
||||
self._env = gym.vector.AsyncVectorEnv(self._env_fns, context="forkserver", shared_memory=True)
|
||||
|
||||
def reset(self, **kwargs):
|
||||
self._ensure()
|
||||
return self._env.reset(**kwargs)
|
||||
|
||||
def step(self, actions):
|
||||
self._ensure()
|
||||
return self._env.step(actions)
|
||||
|
||||
def call(self, name, *args, **kwargs):
|
||||
self._ensure()
|
||||
return self._env.call(name, *args, **kwargs)
|
||||
|
||||
def get_attr(self, name):
|
||||
self._ensure()
|
||||
return self._env.get_attr(name)
|
||||
|
||||
def close(self) -> None:
|
||||
if self._env is not None:
|
||||
self._env.close()
|
||||
self._env = None
|
||||
|
||||
|
||||
def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("once", UserWarning)
|
||||
|
||||
Reference in New Issue
Block a user