add factory

This commit is contained in:
Jade Choghari
2025-08-08 09:34:14 -04:00
parent 4e76c1f88c
commit d2684d41cd
+32 -16
View File
@@ -17,7 +17,7 @@ import importlib
import gymnasium as gym
from lerobot.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv
from lerobot.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, LiberoEnv, PushtEnv, XarmEnv
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
@@ -29,6 +29,8 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
return XarmEnv(**kwargs)
elif env_type == "hil":
return HILEnvConfig(**kwargs)
elif env_type == "libero":
return LiberoEnv(**kwargs)
else:
raise ValueError(f"Policy type '{env_type}' is not available.")
@@ -39,12 +41,12 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
Args:
cfg (EnvConfig): the config of the environment to instantiate.
n_envs (int, optional): The number of parallelized env to return. Defaults to 1.
use_async_envs (bool, optional): Whether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
use_async_envs (bool, optional): Wether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
False.
Raises:
ValueError: if n_envs < 1
ModuleNotFoundError: If the requested env package is not installed
ModuleNotFoundError: If the requested env package is not intalled
Returns:
gym.vector.VectorEnv: The parallelized gym.env instance.
@@ -52,20 +54,34 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
if n_envs < 1:
raise ValueError("`n_envs must be at least 1")
package_name = f"gym_{cfg.type}"
try:
importlib.import_module(package_name)
except ModuleNotFoundError as e:
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
raise e
gym_handle = f"{package_name}/{cfg.task}"
# batched version of the env that returns an observation of shape (b, c)
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
env = env_cls(
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
)
if "libero" in cfg.type:
from lerobot.envs.libero import create_libero_envs
env = create_libero_envs(
task=cfg.task,
n_envs=n_envs,
camera_name=cfg.camera_name,
init_states=cfg.init_states,
gym_kwargs=cfg.gym_kwargs,
env_cls=env_cls,
multitask_eval=cfg.multitask_eval,
)
else:
package_name = f"gym_{cfg.type}"
try:
importlib.import_module(package_name)
except ModuleNotFoundError as e:
print(
f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`"
)
raise e
gym_handle = f"{package_name}/{cfg.task}"
env = env_cls(
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
)
return env