add changes

This commit is contained in:
Jade Choghari
2025-12-08 11:11:38 +01:00
parent 0217e1e3ad
commit 8831b3c47b
4 changed files with 123 additions and 9 deletions
+45 -3
View File
@@ -48,7 +48,8 @@ To make your environment loadable from the Hub, your repository must contain at
**`env.py`** (or custom Python file)
- Must expose a `make_env(n_envs: int, use_async_envs: bool)` function
- Must expose a `make_env(n_envs: int, use_async_envs: bool, **kwargs)` function
- The function should accept `**kwargs` to allow users to pass custom configurations
- This function should return one of:
- A `gym.vector.VectorEnv` (most common)
- A single `gym.Env` (will be automatically wrapped)
@@ -92,6 +93,8 @@ Create an `env.py` file with a `make_env` function:
```python
# env.py
import gymnasium as gym
from pathlib import Path
from typing import Any
def make_env(n_envs: int = 1, use_async_envs: bool = False):
"""
@@ -243,6 +246,44 @@ envs_dict = make_env(
)
```
### Custom Configuration via kwargs
Hub environments can accept custom configurations through keyword arguments. This is useful for parameterizing tasks, loading different objects, or overriding default settings:
```python
from pathlib import Path
# Pass a config file path
envs_dict = make_env(
"nvkartik/isaaclab-arena-envs:envs/microwave_g1.py",
n_envs=4,
trust_remote_code=True,
config_path=Path("/path/to/my_config.yaml"),
)
# Pass config overrides as a dictionary
envs_dict = make_env(
"nvkartik/isaaclab-arena-envs:envs/microwave_g1.py",
n_envs=4,
trust_remote_code=True,
config_overrides={
"scene.object": "microwave",
"sim.dt": 0.01,
},
)
# Combine config path with overrides
envs_dict = make_env(
"username/my-env",
n_envs=4,
trust_remote_code=True,
config_path="configs/gr1_pick_place.yaml",
config_overrides={"scene.table_objects": ["apple", "banana", "cup"]},
)
```
Any keyword arguments you pass will be forwarded to the hub environment's `make_env` function. Check the environment's documentation for supported configuration options.
## URL Format Reference
The hub URL format supports several patterns:
@@ -259,7 +300,7 @@ The hub URL format supports several patterns:
For benchmarks with multiple tasks (like LIBERO), return a nested dictionary:
```python
def make_env(n_envs: int = 1, use_async_envs: bool = False):
def make_env(n_envs: int = 1, use_async_envs: bool = False, **kwargs):
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
# Return dict: {suite_name: {task_id: VectorEnv}}
@@ -381,8 +422,9 @@ pip install gymnasium numpy
Your `env.py` must expose a `make_env` function:
```python
def make_env(n_envs: int, use_async_envs: bool):
def make_env(n_envs: int, use_async_envs: bool, **kwargs):
# Your implementation
# kwargs can include config_path, config_overrides, etc.
pass
```
+5 -2
View File
@@ -85,6 +85,7 @@ def make_env(
use_async_envs: bool = False,
hub_cache_dir: str | None = None,
trust_remote_code: bool = False,
**kwargs,
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
"""Makes a gym vector environment according to the config or Hub reference.
@@ -98,6 +99,8 @@ def make_env(
hub_cache_dir (str | None): Optional cache path for downloaded hub files.
trust_remote_code (bool): **Explicit consent** to execute remote code from the Hub.
Default False — must be set to True to import/exec hub `env.py`.
**kwargs: Additional keyword arguments passed to the hub environment's `make_env` function.
Useful for passing custom configurations like `config_path`, `config_overrides`, etc.
Raises:
ValueError: if n_envs < 1
@@ -119,8 +122,8 @@ def make_env(
# import and surface clear import errors
module = _import_hub_module(local_file, repo_id)
# call the hub-provided make_env
raw_result = _call_make_env(module, n_envs=n_envs, use_async_envs=use_async_envs)
# call the hub-provided make_env with any additional kwargs
raw_result = _call_make_env(module, n_envs=n_envs, use_async_envs=use_async_envs, **kwargs)
# normalize the return into {suite: {task_id: vec_env}}
return _normalize_hub_result(raw_result)
+11 -4
View File
@@ -302,16 +302,23 @@ def _import_hub_module(local_file: str, repo_id: str) -> Any:
return module
def _call_make_env(module: Any, n_envs: int, use_async_envs: bool) -> Any:
def _call_make_env(module: Any, n_envs: int, use_async_envs: bool, **kwargs) -> Any:
"""
Ensure module exposes make_env and call it.
Ensure module exposes make_env and call it with any additional kwargs.
Args:
module: The imported hub module containing make_env.
n_envs: Number of parallel environments.
use_async_envs: Whether to use AsyncVectorEnv or SyncVectorEnv.
**kwargs: Additional keyword arguments to pass to the hub's make_env function.
Common examples include config_path, config_overrides, etc.
"""
if not hasattr(module, "make_env"):
raise AttributeError(
f"The hub module {getattr(module, '__name__', 'hub_module')} must expose `make_env(n_envs=int, use_async_envs=bool)`."
f"The hub module {getattr(module, '__name__', 'hub_module')} must expose `make_env(n_envs=int, use_async_envs=bool, **kwargs)`."
)
entry_fn = module.make_env
return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs)
return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs, **kwargs)
def _normalize_hub_result(result: Any) -> dict[str, dict[int, gym.vector.VectorEnv]]:
+62
View File
@@ -266,3 +266,65 @@ def test_make_env_from_hub_async():
# clean up
env.close()
def test_make_env_from_hub_with_kwargs():
"""Test that kwargs are correctly passed to hub environment's make_env."""
hub_id = "lerobot/dummy-hub-env"
# Test with config_path kwarg
envs_dict = make_env(
hub_id,
n_envs=1,
trust_remote_code=True,
config_path="/path/to/config.yaml",
)
env = envs_dict["cartpole_suite"][0]
assert hasattr(env, "hub_config")
assert env.hub_config["config_path"] == "/path/to/config.yaml"
env.close()
# Test with config_overrides dict
envs_dict = make_env(
hub_id,
n_envs=1,
trust_remote_code=True,
config_overrides={"scene.object": "microwave", "sim.dt": 0.01},
)
env = envs_dict["cartpole_suite"][0]
assert env.hub_config["config_overrides"]["scene.object"] == "microwave"
assert env.hub_config["config_overrides"]["sim.dt"] == 0.01
env.close()
# Test with arbitrary extra kwargs
envs_dict = make_env(
hub_id,
n_envs=1,
trust_remote_code=True,
custom_param="value",
another_param=42,
)
env = envs_dict["cartpole_suite"][0]
assert env.hub_config["extra_kwargs"]["custom_param"] == "value"
assert env.hub_config["extra_kwargs"]["another_param"] == 42
env.close()
# Test combining config_path, config_overrides, and extra kwargs
envs_dict = make_env(
hub_id,
n_envs=2,
trust_remote_code=True,
config_path="my_config.yaml",
config_overrides={"robot": "gr1"},
task_name="pick_and_place",
)
env = envs_dict["cartpole_suite"][0]
assert env.hub_config["config_path"] == "my_config.yaml"
assert env.hub_config["config_overrides"]["robot"] == "gr1"
assert env.hub_config["extra_kwargs"]["task_name"] == "pick_and_place"
assert env.num_envs == 2
env.close()