mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
refactor(envs): move _LazyAsyncVectorEnv to utils and apply to metaworld
_LazyAsyncVectorEnv lived in libero.py but metaworld had the same OOM problem: all tasks' AsyncVectorEnv workers were spawned eagerly, wasting GPU memory for tasks not yet running. Move the class to envs/utils.py so both environments share it, then apply the same is_async + lazy wrapping pattern in create_metaworld_envs. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -29,6 +29,7 @@ from gymnasium import spaces
|
|||||||
from libero.libero import benchmark, get_libero_path
|
from libero.libero import benchmark, get_libero_path
|
||||||
from libero.libero.envs import OffScreenRenderEnv
|
from libero.libero.envs import OffScreenRenderEnv
|
||||||
|
|
||||||
|
from lerobot.envs.utils import _LazyAsyncVectorEnv
|
||||||
from lerobot.types import RobotObservation
|
from lerobot.types import RobotObservation
|
||||||
|
|
||||||
|
|
||||||
@@ -403,62 +404,6 @@ def _make_env_fns(
|
|||||||
return fns
|
return fns
|
||||||
|
|
||||||
|
|
||||||
class _LazyAsyncVectorEnv:
|
|
||||||
"""Wrapper that defers AsyncVectorEnv creation until first use.
|
|
||||||
|
|
||||||
Creating all tasks' AsyncVectorEnvs upfront spawns N_tasks × n_envs worker
|
|
||||||
processes, all of which allocate EGL/GPU resources immediately. Since tasks
|
|
||||||
are evaluated sequentially, only one task's workers need to be alive at a
|
|
||||||
time. This wrapper stores the factory functions and creates the real
|
|
||||||
AsyncVectorEnv on first reset(), keeping peak process count = n_envs.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
env_fns: list[Callable],
|
|
||||||
observation_space: spaces.Space | None = None,
|
|
||||||
action_space: spaces.Space | None = None,
|
|
||||||
):
|
|
||||||
self._env_fns = env_fns
|
|
||||||
self._env: gym.vector.AsyncVectorEnv | None = None
|
|
||||||
self.num_envs = len(env_fns)
|
|
||||||
if observation_space is not None and action_space is not None:
|
|
||||||
self.observation_space = observation_space
|
|
||||||
self.action_space = action_space
|
|
||||||
else:
|
|
||||||
tmp = env_fns[0]()
|
|
||||||
self.observation_space = tmp.observation_space
|
|
||||||
self.action_space = tmp.action_space
|
|
||||||
tmp.close()
|
|
||||||
self.single_observation_space = self.observation_space
|
|
||||||
self.single_action_space = self.action_space
|
|
||||||
|
|
||||||
def _ensure(self):
|
|
||||||
if self._env is None:
|
|
||||||
self._env = gym.vector.AsyncVectorEnv(self._env_fns, context="forkserver", shared_memory=True)
|
|
||||||
|
|
||||||
def reset(self, **kwargs):
|
|
||||||
self._ensure()
|
|
||||||
return self._env.reset(**kwargs)
|
|
||||||
|
|
||||||
def step(self, actions):
|
|
||||||
self._ensure()
|
|
||||||
return self._env.step(actions)
|
|
||||||
|
|
||||||
def call(self, name, *args, **kwargs):
|
|
||||||
self._ensure()
|
|
||||||
return self._env.call(name, *args, **kwargs)
|
|
||||||
|
|
||||||
def get_attr(self, name):
|
|
||||||
self._ensure()
|
|
||||||
return self._env.get_attr(name)
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
if self._env is not None:
|
|
||||||
self._env.close()
|
|
||||||
self._env = None
|
|
||||||
|
|
||||||
|
|
||||||
# ---- Main API ----------------------------------------------------------------
|
# ---- Main API ----------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import metaworld.policies as policies
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from gymnasium import spaces
|
from gymnasium import spaces
|
||||||
|
|
||||||
|
from lerobot.envs.utils import _LazyAsyncVectorEnv
|
||||||
from lerobot.types import RobotObservation
|
from lerobot.types import RobotObservation
|
||||||
|
|
||||||
# ---- Load configuration data from the external JSON file ----
|
# ---- Load configuration data from the external JSON file ----
|
||||||
@@ -306,6 +307,9 @@ def create_metaworld_envs(
|
|||||||
|
|
||||||
print(f"Creating Meta-World envs | task_groups={task_groups} | n_envs(per task)={n_envs}")
|
print(f"Creating Meta-World envs | task_groups={task_groups} | n_envs(per task)={n_envs}")
|
||||||
|
|
||||||
|
is_async = env_cls is gym.vector.AsyncVectorEnv
|
||||||
|
cached_obs_space = None
|
||||||
|
cached_act_space = None
|
||||||
out: dict[str, dict[int, Any]] = defaultdict(dict)
|
out: dict[str, dict[int, Any]] = defaultdict(dict)
|
||||||
|
|
||||||
for group in task_groups:
|
for group in task_groups:
|
||||||
@@ -318,7 +322,14 @@ def create_metaworld_envs(
|
|||||||
# build n_envs factories
|
# build n_envs factories
|
||||||
fns = [(lambda tn=task_name: MetaworldEnv(task=tn, **gym_kwargs)) for _ in range(n_envs)]
|
fns = [(lambda tn=task_name: MetaworldEnv(task=tn, **gym_kwargs)) for _ in range(n_envs)]
|
||||||
|
|
||||||
out[group][tid] = env_cls(fns)
|
if is_async:
|
||||||
|
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space)
|
||||||
|
if cached_obs_space is None:
|
||||||
|
cached_obs_space = lazy.observation_space
|
||||||
|
cached_act_space = lazy.action_space
|
||||||
|
out[group][tid] = lazy
|
||||||
|
else:
|
||||||
|
out[group][tid] = env_cls(fns)
|
||||||
|
|
||||||
# return a plain dict for consistency
|
# return a plain dict for consistency
|
||||||
return {group: dict(task_map) for group, task_map in out.items()}
|
return {group: dict(task_map) for group, task_map in out.items()}
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
import importlib.util
|
import importlib.util
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Callable, Mapping, Sequence
|
||||||
from functools import singledispatch
|
from functools import singledispatch
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -138,6 +138,62 @@ def _sub_env_has_attr(env: gym.vector.VectorEnv, attr: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class _LazyAsyncVectorEnv:
|
||||||
|
"""Defers AsyncVectorEnv creation until first use.
|
||||||
|
|
||||||
|
Creating all tasks' AsyncVectorEnvs upfront spawns N_tasks × n_envs worker
|
||||||
|
processes, all of which allocate EGL/GPU resources immediately. Since tasks
|
||||||
|
are evaluated sequentially, only one task's workers need to be alive at a
|
||||||
|
time. This wrapper stores the factory functions and creates the real
|
||||||
|
AsyncVectorEnv on first reset()/step()/call(), keeping peak process count = n_envs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
env_fns: list[Callable],
|
||||||
|
observation_space=None,
|
||||||
|
action_space=None,
|
||||||
|
):
|
||||||
|
self._env_fns = env_fns
|
||||||
|
self._env: gym.vector.AsyncVectorEnv | None = None
|
||||||
|
self.num_envs = len(env_fns)
|
||||||
|
if observation_space is not None and action_space is not None:
|
||||||
|
self.observation_space = observation_space
|
||||||
|
self.action_space = action_space
|
||||||
|
else:
|
||||||
|
tmp = env_fns[0]()
|
||||||
|
self.observation_space = tmp.observation_space
|
||||||
|
self.action_space = tmp.action_space
|
||||||
|
tmp.close()
|
||||||
|
self.single_observation_space = self.observation_space
|
||||||
|
self.single_action_space = self.action_space
|
||||||
|
|
||||||
|
def _ensure(self) -> None:
|
||||||
|
if self._env is None:
|
||||||
|
self._env = gym.vector.AsyncVectorEnv(self._env_fns, context="forkserver", shared_memory=True)
|
||||||
|
|
||||||
|
def reset(self, **kwargs):
|
||||||
|
self._ensure()
|
||||||
|
return self._env.reset(**kwargs)
|
||||||
|
|
||||||
|
def step(self, actions):
|
||||||
|
self._ensure()
|
||||||
|
return self._env.step(actions)
|
||||||
|
|
||||||
|
def call(self, name, *args, **kwargs):
|
||||||
|
self._ensure()
|
||||||
|
return self._env.call(name, *args, **kwargs)
|
||||||
|
|
||||||
|
def get_attr(self, name):
|
||||||
|
self._ensure()
|
||||||
|
return self._env.get_attr(name)
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
if self._env is not None:
|
||||||
|
self._env.close()
|
||||||
|
self._env = None
|
||||||
|
|
||||||
|
|
||||||
def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
|
def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("once", UserWarning)
|
warnings.simplefilter("once", UserWarning)
|
||||||
|
|||||||
Reference in New Issue
Block a user