Cache lazy async env metadata for eval (#3416)

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
This commit is contained in:
Haoming Song
2026-04-20 21:33:13 +08:00
committed by GitHub
parent 777b808c70
commit b2765b39b8
3 changed files with 14 additions and 3 deletions
+3 -1
View File
@@ -462,6 +462,7 @@ def create_libero_envs(
# Probe once and reuse to avoid creating a temp env per task. # Probe once and reuse to avoid creating a temp env per task.
cached_obs_space: spaces.Space | None = None cached_obs_space: spaces.Space | None = None
cached_act_space: spaces.Space | None = None cached_act_space: spaces.Space | None = None
cached_metadata: dict[str, Any] | None = None
for tid in selected: for tid in selected:
fns = _make_env_fns( fns = _make_env_fns(
@@ -477,10 +478,11 @@ def create_libero_envs(
camera_name_mapping=camera_name_mapping, camera_name_mapping=camera_name_mapping,
) )
if is_async: if is_async:
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space) lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space, cached_metadata)
if cached_obs_space is None: if cached_obs_space is None:
cached_obs_space = lazy.observation_space cached_obs_space = lazy.observation_space
cached_act_space = lazy.action_space cached_act_space = lazy.action_space
cached_metadata = lazy.metadata
out[suite_name][tid] = lazy out[suite_name][tid] = lazy
else: else:
out[suite_name][tid] = env_cls(fns) out[suite_name][tid] = env_cls(fns)
+3 -1
View File
@@ -311,6 +311,7 @@ def create_metaworld_envs(
is_async = env_cls is gym.vector.AsyncVectorEnv is_async = env_cls is gym.vector.AsyncVectorEnv
cached_obs_space = None cached_obs_space = None
cached_act_space = None cached_act_space = None
cached_metadata = None
out: dict[str, dict[int, Any]] = defaultdict(dict) out: dict[str, dict[int, Any]] = defaultdict(dict)
for group in task_groups: for group in task_groups:
@@ -324,10 +325,11 @@ def create_metaworld_envs(
fns = [(lambda tn=task_name: MetaworldEnv(task=tn, **gym_kwargs)) for _ in range(n_envs)] fns = [(lambda tn=task_name: MetaworldEnv(task=tn, **gym_kwargs)) for _ in range(n_envs)]
if is_async: if is_async:
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space) lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space, cached_metadata)
if cached_obs_space is None: if cached_obs_space is None:
cached_obs_space = lazy.observation_space cached_obs_space = lazy.observation_space
cached_act_space = lazy.action_space cached_act_space = lazy.action_space
cached_metadata = lazy.metadata
out[group][tid] = lazy out[group][tid] = lazy
else: else:
out[group][tid] = env_cls(fns) out[group][tid] = env_cls(fns)
+8 -1
View File
@@ -153,17 +153,20 @@ class _LazyAsyncVectorEnv:
env_fns: list[Callable], env_fns: list[Callable],
observation_space=None, observation_space=None,
action_space=None, action_space=None,
metadata=None,
): ):
self._env_fns = env_fns self._env_fns = env_fns
self._env: gym.vector.AsyncVectorEnv | None = None self._env: gym.vector.AsyncVectorEnv | None = None
self.num_envs = len(env_fns) self.num_envs = len(env_fns)
if observation_space is not None and action_space is not None: if observation_space is not None and action_space is not None and metadata is not None:
self.observation_space = observation_space self.observation_space = observation_space
self.action_space = action_space self.action_space = action_space
self.metadata = metadata
else: else:
tmp = env_fns[0]() tmp = env_fns[0]()
self.observation_space = tmp.observation_space self.observation_space = tmp.observation_space
self.action_space = tmp.action_space self.action_space = tmp.action_space
self.metadata = tmp.metadata
tmp.close() tmp.close()
self.single_observation_space = self.observation_space self.single_observation_space = self.observation_space
self.single_action_space = self.action_space self.single_action_space = self.action_space
@@ -172,6 +175,10 @@ class _LazyAsyncVectorEnv:
if self._env is None: if self._env is None:
self._env = gym.vector.AsyncVectorEnv(self._env_fns, context="forkserver", shared_memory=True) self._env = gym.vector.AsyncVectorEnv(self._env_fns, context="forkserver", shared_memory=True)
@property
def unwrapped(self):
return self
def reset(self, **kwargs): def reset(self, **kwargs):
self._ensure() self._ensure()
return self._env.reset(**kwargs) return self._env.reset(**kwargs)