From d2684d41cd21e3fcb34b284df469004106044998 Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Fri, 8 Aug 2025 09:34:14 -0400 Subject: [PATCH] add factory --- src/lerobot/envs/factory.py | 48 ++++++++++++++++++++++++------------- 1 file changed, 32 insertions(+), 16 deletions(-) diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py index dc6d96d61..cb897e68d 100644 --- a/src/lerobot/envs/factory.py +++ b/src/lerobot/envs/factory.py @@ -17,7 +17,7 @@ import importlib import gymnasium as gym -from lerobot.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv +from lerobot.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, LiberoEnv, PushtEnv, XarmEnv def make_env_config(env_type: str, **kwargs) -> EnvConfig: @@ -29,6 +29,8 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig: return XarmEnv(**kwargs) elif env_type == "hil": return HILEnvConfig(**kwargs) + elif env_type == "libero": + return LiberoEnv(**kwargs) else: raise ValueError(f"Policy type '{env_type}' is not available.") @@ -39,12 +41,12 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g Args: cfg (EnvConfig): the config of the environment to instantiate. n_envs (int, optional): The number of parallelized env to return. Defaults to 1. - use_async_envs (bool, optional): Whether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to + use_async_envs (bool, optional): Wether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to False. Raises: ValueError: if n_envs < 1 - ModuleNotFoundError: If the requested env package is not installed + ModuleNotFoundError: If the requested env package is not intalled Returns: gym.vector.VectorEnv: The parallelized gym.env instance. @@ -52,20 +54,34 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g if n_envs < 1: raise ValueError("`n_envs must be at least 1") - package_name = f"gym_{cfg.type}" - - try: - importlib.import_module(package_name) - except ModuleNotFoundError as e: - print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`") - raise e - - gym_handle = f"{package_name}/{cfg.task}" - # batched version of the env that returns an observation of shape (b, c) env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv - env = env_cls( - [lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)] - ) + + if "libero" in cfg.type: + from lerobot.envs.libero import create_libero_envs + + env = create_libero_envs( + task=cfg.task, + n_envs=n_envs, + camera_name=cfg.camera_name, + init_states=cfg.init_states, + gym_kwargs=cfg.gym_kwargs, + env_cls=env_cls, + multitask_eval=cfg.multitask_eval, + ) + else: + package_name = f"gym_{cfg.type}" + try: + importlib.import_module(package_name) + except ModuleNotFoundError as e: + print( + f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`" + ) + raise e + + gym_handle = f"{package_name}/{cfg.task}" + env = env_cls( + [lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)] + ) return env