mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
Cache lazy async env metadata for eval (#3416)
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
This commit is contained in:
@@ -462,6 +462,7 @@ def create_libero_envs(
|
|||||||
# Probe once and reuse to avoid creating a temp env per task.
|
# Probe once and reuse to avoid creating a temp env per task.
|
||||||
cached_obs_space: spaces.Space | None = None
|
cached_obs_space: spaces.Space | None = None
|
||||||
cached_act_space: spaces.Space | None = None
|
cached_act_space: spaces.Space | None = None
|
||||||
|
cached_metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
for tid in selected:
|
for tid in selected:
|
||||||
fns = _make_env_fns(
|
fns = _make_env_fns(
|
||||||
@@ -477,10 +478,11 @@ def create_libero_envs(
|
|||||||
camera_name_mapping=camera_name_mapping,
|
camera_name_mapping=camera_name_mapping,
|
||||||
)
|
)
|
||||||
if is_async:
|
if is_async:
|
||||||
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space)
|
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space, cached_metadata)
|
||||||
if cached_obs_space is None:
|
if cached_obs_space is None:
|
||||||
cached_obs_space = lazy.observation_space
|
cached_obs_space = lazy.observation_space
|
||||||
cached_act_space = lazy.action_space
|
cached_act_space = lazy.action_space
|
||||||
|
cached_metadata = lazy.metadata
|
||||||
out[suite_name][tid] = lazy
|
out[suite_name][tid] = lazy
|
||||||
else:
|
else:
|
||||||
out[suite_name][tid] = env_cls(fns)
|
out[suite_name][tid] = env_cls(fns)
|
||||||
|
|||||||
@@ -311,6 +311,7 @@ def create_metaworld_envs(
|
|||||||
is_async = env_cls is gym.vector.AsyncVectorEnv
|
is_async = env_cls is gym.vector.AsyncVectorEnv
|
||||||
cached_obs_space = None
|
cached_obs_space = None
|
||||||
cached_act_space = None
|
cached_act_space = None
|
||||||
|
cached_metadata = None
|
||||||
out: dict[str, dict[int, Any]] = defaultdict(dict)
|
out: dict[str, dict[int, Any]] = defaultdict(dict)
|
||||||
|
|
||||||
for group in task_groups:
|
for group in task_groups:
|
||||||
@@ -324,10 +325,11 @@ def create_metaworld_envs(
|
|||||||
fns = [(lambda tn=task_name: MetaworldEnv(task=tn, **gym_kwargs)) for _ in range(n_envs)]
|
fns = [(lambda tn=task_name: MetaworldEnv(task=tn, **gym_kwargs)) for _ in range(n_envs)]
|
||||||
|
|
||||||
if is_async:
|
if is_async:
|
||||||
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space)
|
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space, cached_metadata)
|
||||||
if cached_obs_space is None:
|
if cached_obs_space is None:
|
||||||
cached_obs_space = lazy.observation_space
|
cached_obs_space = lazy.observation_space
|
||||||
cached_act_space = lazy.action_space
|
cached_act_space = lazy.action_space
|
||||||
|
cached_metadata = lazy.metadata
|
||||||
out[group][tid] = lazy
|
out[group][tid] = lazy
|
||||||
else:
|
else:
|
||||||
out[group][tid] = env_cls(fns)
|
out[group][tid] = env_cls(fns)
|
||||||
|
|||||||
@@ -153,17 +153,20 @@ class _LazyAsyncVectorEnv:
|
|||||||
env_fns: list[Callable],
|
env_fns: list[Callable],
|
||||||
observation_space=None,
|
observation_space=None,
|
||||||
action_space=None,
|
action_space=None,
|
||||||
|
metadata=None,
|
||||||
):
|
):
|
||||||
self._env_fns = env_fns
|
self._env_fns = env_fns
|
||||||
self._env: gym.vector.AsyncVectorEnv | None = None
|
self._env: gym.vector.AsyncVectorEnv | None = None
|
||||||
self.num_envs = len(env_fns)
|
self.num_envs = len(env_fns)
|
||||||
if observation_space is not None and action_space is not None:
|
if observation_space is not None and action_space is not None and metadata is not None:
|
||||||
self.observation_space = observation_space
|
self.observation_space = observation_space
|
||||||
self.action_space = action_space
|
self.action_space = action_space
|
||||||
|
self.metadata = metadata
|
||||||
else:
|
else:
|
||||||
tmp = env_fns[0]()
|
tmp = env_fns[0]()
|
||||||
self.observation_space = tmp.observation_space
|
self.observation_space = tmp.observation_space
|
||||||
self.action_space = tmp.action_space
|
self.action_space = tmp.action_space
|
||||||
|
self.metadata = tmp.metadata
|
||||||
tmp.close()
|
tmp.close()
|
||||||
self.single_observation_space = self.observation_space
|
self.single_observation_space = self.observation_space
|
||||||
self.single_action_space = self.action_space
|
self.single_action_space = self.action_space
|
||||||
@@ -172,6 +175,10 @@ class _LazyAsyncVectorEnv:
|
|||||||
if self._env is None:
|
if self._env is None:
|
||||||
self._env = gym.vector.AsyncVectorEnv(self._env_fns, context="forkserver", shared_memory=True)
|
self._env = gym.vector.AsyncVectorEnv(self._env_fns, context="forkserver", shared_memory=True)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def unwrapped(self):
|
||||||
|
return self
|
||||||
|
|
||||||
def reset(self, **kwargs):
|
def reset(self, **kwargs):
|
||||||
self._ensure()
|
self._ensure()
|
||||||
return self._env.reset(**kwargs)
|
return self._env.reset(**kwargs)
|
||||||
|
|||||||
Reference in New Issue
Block a user