mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
Cache lazy async env metadata for eval (#3416)
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
This commit is contained in:
@@ -462,6 +462,7 @@ def create_libero_envs(
|
||||
# Probe once and reuse to avoid creating a temp env per task.
|
||||
cached_obs_space: spaces.Space | None = None
|
||||
cached_act_space: spaces.Space | None = None
|
||||
cached_metadata: dict[str, Any] | None = None
|
||||
|
||||
for tid in selected:
|
||||
fns = _make_env_fns(
|
||||
@@ -477,10 +478,11 @@ def create_libero_envs(
|
||||
camera_name_mapping=camera_name_mapping,
|
||||
)
|
||||
if is_async:
|
||||
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space)
|
||||
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space, cached_metadata)
|
||||
if cached_obs_space is None:
|
||||
cached_obs_space = lazy.observation_space
|
||||
cached_act_space = lazy.action_space
|
||||
cached_metadata = lazy.metadata
|
||||
out[suite_name][tid] = lazy
|
||||
else:
|
||||
out[suite_name][tid] = env_cls(fns)
|
||||
|
||||
@@ -311,6 +311,7 @@ def create_metaworld_envs(
|
||||
is_async = env_cls is gym.vector.AsyncVectorEnv
|
||||
cached_obs_space = None
|
||||
cached_act_space = None
|
||||
cached_metadata = None
|
||||
out: dict[str, dict[int, Any]] = defaultdict(dict)
|
||||
|
||||
for group in task_groups:
|
||||
@@ -324,10 +325,11 @@ def create_metaworld_envs(
|
||||
fns = [(lambda tn=task_name: MetaworldEnv(task=tn, **gym_kwargs)) for _ in range(n_envs)]
|
||||
|
||||
if is_async:
|
||||
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space)
|
||||
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space, cached_metadata)
|
||||
if cached_obs_space is None:
|
||||
cached_obs_space = lazy.observation_space
|
||||
cached_act_space = lazy.action_space
|
||||
cached_metadata = lazy.metadata
|
||||
out[group][tid] = lazy
|
||||
else:
|
||||
out[group][tid] = env_cls(fns)
|
||||
|
||||
@@ -153,17 +153,20 @@ class _LazyAsyncVectorEnv:
|
||||
env_fns: list[Callable],
|
||||
observation_space=None,
|
||||
action_space=None,
|
||||
metadata=None,
|
||||
):
|
||||
self._env_fns = env_fns
|
||||
self._env: gym.vector.AsyncVectorEnv | None = None
|
||||
self.num_envs = len(env_fns)
|
||||
if observation_space is not None and action_space is not None:
|
||||
if observation_space is not None and action_space is not None and metadata is not None:
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
self.metadata = metadata
|
||||
else:
|
||||
tmp = env_fns[0]()
|
||||
self.observation_space = tmp.observation_space
|
||||
self.action_space = tmp.action_space
|
||||
self.metadata = tmp.metadata
|
||||
tmp.close()
|
||||
self.single_observation_space = self.observation_space
|
||||
self.single_action_space = self.action_space
|
||||
@@ -172,6 +175,10 @@ class _LazyAsyncVectorEnv:
|
||||
if self._env is None:
|
||||
self._env = gym.vector.AsyncVectorEnv(self._env_fns, context="forkserver", shared_memory=True)
|
||||
|
||||
@property
|
||||
def unwrapped(self):
|
||||
return self
|
||||
|
||||
def reset(self, **kwargs):
|
||||
self._ensure()
|
||||
return self._env.reset(**kwargs)
|
||||
|
||||
Reference in New Issue
Block a user