From 8831b3c47b861c316a8330636a498fd70b3934ba Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Mon, 8 Dec 2025 11:11:38 +0100 Subject: [PATCH] add changes --- docs/source/envhub.mdx | 48 ++++++++++++++++++++++++++-- src/lerobot/envs/factory.py | 7 +++-- src/lerobot/envs/utils.py | 15 ++++++--- tests/envs/test_envs.py | 62 +++++++++++++++++++++++++++++++++++++ 4 files changed, 123 insertions(+), 9 deletions(-) diff --git a/docs/source/envhub.mdx b/docs/source/envhub.mdx index ba6464460..9fc6c6220 100644 --- a/docs/source/envhub.mdx +++ b/docs/source/envhub.mdx @@ -48,7 +48,8 @@ To make your environment loadable from the Hub, your repository must contain at **`env.py`** (or custom Python file) -- Must expose a `make_env(n_envs: int, use_async_envs: bool)` function +- Must expose a `make_env(n_envs: int, use_async_envs: bool, **kwargs)` function +- The function should accept `**kwargs` to allow users to pass custom configurations - This function should return one of: - A `gym.vector.VectorEnv` (most common) - A single `gym.Env` (will be automatically wrapped) @@ -92,6 +93,8 @@ Create an `env.py` file with a `make_env` function: ```python # env.py import gymnasium as gym +from pathlib import Path +from typing import Any def make_env(n_envs: int = 1, use_async_envs: bool = False): """ @@ -243,6 +246,44 @@ envs_dict = make_env( ) ``` +### Custom Configuration via kwargs + +Hub environments can accept custom configurations through keyword arguments. This is useful for parameterizing tasks, loading different objects, or overriding default settings: + +```python +from pathlib import Path + +# Pass a config file path +envs_dict = make_env( + "nvkartik/isaaclab-arena-envs:envs/microwave_g1.py", + n_envs=4, + trust_remote_code=True, + config_path=Path("/path/to/my_config.yaml"), +) + +# Pass config overrides as a dictionary +envs_dict = make_env( + "nvkartik/isaaclab-arena-envs:envs/microwave_g1.py", + n_envs=4, + trust_remote_code=True, + config_overrides={ + "scene.object": "microwave", + "sim.dt": 0.01, + }, +) + +# Combine config path with overrides +envs_dict = make_env( + "username/my-env", + n_envs=4, + trust_remote_code=True, + config_path="configs/gr1_pick_place.yaml", + config_overrides={"scene.table_objects": ["apple", "banana", "cup"]}, +) +``` + +Any keyword arguments you pass will be forwarded to the hub environment's `make_env` function. Check the environment's documentation for supported configuration options. + ## URL Format Reference The hub URL format supports several patterns: @@ -259,7 +300,7 @@ The hub URL format supports several patterns: For benchmarks with multiple tasks (like LIBERO), return a nested dictionary: ```python -def make_env(n_envs: int = 1, use_async_envs: bool = False): +def make_env(n_envs: int = 1, use_async_envs: bool = False, **kwargs): env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv # Return dict: {suite_name: {task_id: VectorEnv}} @@ -381,8 +422,9 @@ pip install gymnasium numpy Your `env.py` must expose a `make_env` function: ```python -def make_env(n_envs: int, use_async_envs: bool): +def make_env(n_envs: int, use_async_envs: bool, **kwargs): # Your implementation + # kwargs can include config_path, config_overrides, etc. pass ``` diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py index b39cfee71..01f92574f 100644 --- a/src/lerobot/envs/factory.py +++ b/src/lerobot/envs/factory.py @@ -85,6 +85,7 @@ def make_env( use_async_envs: bool = False, hub_cache_dir: str | None = None, trust_remote_code: bool = False, + **kwargs, ) -> dict[str, dict[int, gym.vector.VectorEnv]]: """Makes a gym vector environment according to the config or Hub reference. @@ -98,6 +99,8 @@ def make_env( hub_cache_dir (str | None): Optional cache path for downloaded hub files. trust_remote_code (bool): **Explicit consent** to execute remote code from the Hub. Default False — must be set to True to import/exec hub `env.py`. + **kwargs: Additional keyword arguments passed to the hub environment's `make_env` function. + Useful for passing custom configurations like `config_path`, `config_overrides`, etc. Raises: ValueError: if n_envs < 1 @@ -119,8 +122,8 @@ def make_env( # import and surface clear import errors module = _import_hub_module(local_file, repo_id) - # call the hub-provided make_env - raw_result = _call_make_env(module, n_envs=n_envs, use_async_envs=use_async_envs) + # call the hub-provided make_env with any additional kwargs + raw_result = _call_make_env(module, n_envs=n_envs, use_async_envs=use_async_envs, **kwargs) # normalize the return into {suite: {task_id: vec_env}} return _normalize_hub_result(raw_result) diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index 8d0f24922..c86d61543 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -302,16 +302,23 @@ def _import_hub_module(local_file: str, repo_id: str) -> Any: return module -def _call_make_env(module: Any, n_envs: int, use_async_envs: bool) -> Any: +def _call_make_env(module: Any, n_envs: int, use_async_envs: bool, **kwargs) -> Any: """ - Ensure module exposes make_env and call it. + Ensure module exposes make_env and call it with any additional kwargs. + + Args: + module: The imported hub module containing make_env. + n_envs: Number of parallel environments. + use_async_envs: Whether to use AsyncVectorEnv or SyncVectorEnv. + **kwargs: Additional keyword arguments to pass to the hub's make_env function. + Common examples include config_path, config_overrides, etc. """ if not hasattr(module, "make_env"): raise AttributeError( - f"The hub module {getattr(module, '__name__', 'hub_module')} must expose `make_env(n_envs=int, use_async_envs=bool)`." + f"The hub module {getattr(module, '__name__', 'hub_module')} must expose `make_env(n_envs=int, use_async_envs=bool, **kwargs)`." ) entry_fn = module.make_env - return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs) + return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs, **kwargs) def _normalize_hub_result(result: Any) -> dict[str, dict[int, gym.vector.VectorEnv]]: diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py index 910c275eb..49df0b949 100644 --- a/tests/envs/test_envs.py +++ b/tests/envs/test_envs.py @@ -266,3 +266,65 @@ def test_make_env_from_hub_async(): # clean up env.close() + + +def test_make_env_from_hub_with_kwargs(): + """Test that kwargs are correctly passed to hub environment's make_env.""" + hub_id = "lerobot/dummy-hub-env" + + # Test with config_path kwarg + envs_dict = make_env( + hub_id, + n_envs=1, + trust_remote_code=True, + config_path="/path/to/config.yaml", + ) + env = envs_dict["cartpole_suite"][0] + + assert hasattr(env, "hub_config") + assert env.hub_config["config_path"] == "/path/to/config.yaml" + env.close() + + # Test with config_overrides dict + envs_dict = make_env( + hub_id, + n_envs=1, + trust_remote_code=True, + config_overrides={"scene.object": "microwave", "sim.dt": 0.01}, + ) + env = envs_dict["cartpole_suite"][0] + + assert env.hub_config["config_overrides"]["scene.object"] == "microwave" + assert env.hub_config["config_overrides"]["sim.dt"] == 0.01 + env.close() + + # Test with arbitrary extra kwargs + envs_dict = make_env( + hub_id, + n_envs=1, + trust_remote_code=True, + custom_param="value", + another_param=42, + ) + env = envs_dict["cartpole_suite"][0] + + assert env.hub_config["extra_kwargs"]["custom_param"] == "value" + assert env.hub_config["extra_kwargs"]["another_param"] == 42 + env.close() + + # Test combining config_path, config_overrides, and extra kwargs + envs_dict = make_env( + hub_id, + n_envs=2, + trust_remote_code=True, + config_path="my_config.yaml", + config_overrides={"robot": "gr1"}, + task_name="pick_and_place", + ) + env = envs_dict["cartpole_suite"][0] + + assert env.hub_config["config_path"] == "my_config.yaml" + assert env.hub_config["config_overrides"]["robot"] == "gr1" + assert env.hub_config["extra_kwargs"]["task_name"] == "pick_and_place" + assert env.num_envs == 2 + env.close()