mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
refactor(envs): move benchmark dispatch into EnvConfig subclasses (#3272)
* docs(benchmarks): add benchmark integration guide and standardize benchmark docs Add a comprehensive guide for adding new benchmarks to LeRobot, and refactor the existing LIBERO and Meta-World docs to follow the new standardized template. * refactor(envs): move dispatch logic from factory into EnvConfig subclasses Replace hardcoded if/elif chains in factory.py with create_envs() and get_env_processors() methods on EnvConfig. New benchmarks now only need to register a config subclass — no factory.py edits required. Net -23 lines: factory.py shrinks from ~200 to ~70 lines of logic. * docs(benchmarks): clean up adding-benchmarks guide for clarity Rewrite for simpler language, better structure, and easier navigation. Move quick-reference table to the top, fold eval explanation into architecture section, condense the doc template to a bulleted outline. * fix link * fix task count * fix(tests): fix 3 failing dispatch tests - test_registry_all_types: skip non-EnvConfig stubs (e.g. TestPluginConfig) - test_processors_delegation: use None instead of abstract PreTrainedConfig - test_custom_get_env_processors_override: use DataProcessorPipeline for isinstance check (PolicyProcessorPipeline is a subscripted generic) * fix: enable SmolVLA eval on LIBERO with custom camera mappings - Thread camera_name_mapping from LiberoEnv config through to gym envs - Sync features_map with camera_name_mapping in LiberoEnv.__post_init__ - Fix render() to use first available camera instead of hardcoded "image" - Handle non-dict final_info in rollout by falling back to info["is_success"] - Add use_peft legacy field to SmolVLAConfig for checkpoint compat - Add defaults to GR00TN15Config init=False fields for transformers 5.3 Made-with: Cursor * fix: use direct AutoresetMode import for gymnasium compat Made-with: Cursor * fix: handle gymnasium < 1.0 without AutoresetMode Made-with: Cursor * refactor: revert policy changes, keep env-only camera mapping fixes - Revert GR00T N1.5 default_factory/default changes (transformers compat) - Revert SmolVLA use_peft legacy field - Apply ruff formatting fixes - camera_name_mapping stays entirely in env/eval layer (no policy changes) Made-with: Cursor * Update docs/source/env_processor.mdx Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co> Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com> * Update docs/source/env_processor.mdx Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co> Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com> * Update docs/source/env_processor.mdx Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co> Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com> * fix(eval): raise RuntimeError for unsupported final_info format (Gymnasium < 1.0) Made-with: Cursor * style: fix markdown code fences in env_processor.mdx Made-with: Cursor * docs: remove duplicate code blocks in env_processor.mdx Made-with: Cursor * style: revert quadruple backticks to triple (prettier compat) * docs(env_processor): add EnvConfig subclass step and policy_cfg examples - Add missing '### 2. Update Your EnvConfig Subclass' section with get_env_processors() snippet - Update factory usage example to show policy_cfg parameter and keyword-argument style for both SmolVLA and ACT cases * docs(env_processor): rename step 2 and fix policy_cfg examples - Rename '### 2. Update the Factory' → '### 2. Update Your EnvConfig Subclass' - Update factory usage examples to use keyword-argument style with policy_cfg parameter for both SmolVLA and ACT cases --------- Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
This commit is contained in:
@@ -115,23 +115,22 @@ Each `EnvConfig` subclass declares two dicts that tell the policy what to expect
|
||||
## Step by step
|
||||
|
||||
<Tip>
|
||||
At minimum, you need three files: a **gym.Env wrapper**, an **EnvConfig
|
||||
subclass**, and a **factory dispatch branch**. Everything else is optional or
|
||||
documentation.
|
||||
At minimum, you need two files: a **gym.Env wrapper** and an **EnvConfig
|
||||
subclass** with a `create_envs()` override. Everything else is optional or
|
||||
documentation. No changes to `factory.py` are needed.
|
||||
</Tip>
|
||||
|
||||
### Checklist
|
||||
|
||||
| File | Required | Why |
|
||||
| ---------------------------------------- | -------- | ----------------------------------------- |
|
||||
| `src/lerobot/envs/<benchmark>.py` | Yes | Wraps the simulator as a standard gym.Env |
|
||||
| `src/lerobot/envs/configs.py` | Yes | Registers your benchmark for the CLI |
|
||||
| `src/lerobot/envs/factory.py` | Yes | Tells `make_env()` how to build your envs |
|
||||
| `src/lerobot/processor/env_processor.py` | Optional | Custom observation/action transforms |
|
||||
| `src/lerobot/envs/utils.py` | Optional | Only if you need new raw observation keys |
|
||||
| `pyproject.toml` | Yes | Declares benchmark-specific dependencies |
|
||||
| `docs/source/<benchmark>.mdx` | Yes | User-facing documentation page |
|
||||
| `docs/source/_toctree.yml` | Yes | Adds your page to the docs sidebar |
|
||||
| File | Required | Why |
|
||||
| ---------------------------------------- | -------- | ------------------------------------------------------------ |
|
||||
| `src/lerobot/envs/<benchmark>.py` | Yes | Wraps the simulator as a standard gym.Env |
|
||||
| `src/lerobot/envs/configs.py` | Yes | Registers your benchmark and its `create_envs()` for the CLI |
|
||||
| `src/lerobot/processor/env_processor.py` | Optional | Custom observation/action transforms |
|
||||
| `src/lerobot/envs/utils.py` | Optional | Only if you need new raw observation keys |
|
||||
| `pyproject.toml` | Yes | Declares benchmark-specific dependencies |
|
||||
| `docs/source/<benchmark>.mdx` | Yes | User-facing documentation page |
|
||||
| `docs/source/_toctree.yml` | Yes | Adds your page to the docs sidebar |
|
||||
|
||||
### 1. The gym.Env wrapper (`src/lerobot/envs/<benchmark>.py`)
|
||||
|
||||
@@ -179,7 +178,10 @@ See `create_libero_envs()` (multi-suite, multi-task) and `create_metaworld_envs(
|
||||
|
||||
### 2. The config (`src/lerobot/envs/configs.py`)
|
||||
|
||||
Register a config dataclass so users can select your benchmark with `--env.type=<name>`:
|
||||
Register a config dataclass so users can select your benchmark with `--env.type=<name>`. Each config owns its environment creation and processor logic via two methods:
|
||||
|
||||
- **`create_envs(n_envs, use_async_envs)`** — Returns `{suite: {task_id: VectorEnv}}`. The base class default uses `gym.make()` for single-task envs. Multi-task benchmarks override this.
|
||||
- **`get_env_processors()`** — Returns `(preprocessor, postprocessor)`. The base class default returns identity (no-op) pipelines. Override if your benchmark needs observation/action transforms.
|
||||
|
||||
```python
|
||||
@EnvConfig.register_subclass("<benchmark_name>")
|
||||
@@ -204,6 +206,20 @@ class MyBenchmarkEnvConfig(EnvConfig):
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {"obs_type": self.obs_type, "render_mode": self.render_mode}
|
||||
|
||||
def create_envs(self, n_envs: int, use_async_envs: bool = False):
|
||||
"""Override for multi-task benchmarks or custom env creation."""
|
||||
from lerobot.envs.<benchmark> import create_<benchmark>_envs
|
||||
return create_<benchmark>_envs(task=self.task, n_envs=n_envs, ...)
|
||||
|
||||
def get_env_processors(self):
|
||||
"""Override if your benchmark needs observation/action transforms."""
|
||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||
from lerobot.processor.env_processor import MyBenchmarkProcessorStep
|
||||
return (
|
||||
PolicyProcessorPipeline(steps=[MyBenchmarkProcessorStep()]),
|
||||
PolicyProcessorPipeline(steps=[]),
|
||||
)
|
||||
```
|
||||
|
||||
Key points:
|
||||
@@ -211,36 +227,11 @@ Key points:
|
||||
- The `register_subclass` name is what users pass on the CLI (`--env.type=<name>`).
|
||||
- `features` tells the policy what the environment produces.
|
||||
- `features_map` maps raw observation keys to LeRobot convention keys.
|
||||
- **No changes to `factory.py` needed** — the factory delegates to `cfg.create_envs()` and `cfg.get_env_processors()` automatically.
|
||||
|
||||
### 3. The factory dispatch (`src/lerobot/envs/factory.py`)
|
||||
### 3. Env processor (optional — `src/lerobot/processor/env_processor.py`)
|
||||
|
||||
Add a branch in `make_env()` to call your factory function:
|
||||
|
||||
```python
|
||||
elif "<benchmark_name>" in cfg.type:
|
||||
from lerobot.envs.<benchmark> import create_<benchmark>_envs
|
||||
|
||||
if cfg.task is None:
|
||||
raise ValueError("<BenchmarkName> requires a task to be specified")
|
||||
|
||||
return create_<benchmark>_envs(
|
||||
task=cfg.task,
|
||||
n_envs=n_envs,
|
||||
gym_kwargs=cfg.gym_kwargs,
|
||||
env_cls=env_cls,
|
||||
)
|
||||
```
|
||||
|
||||
If your benchmark needs an env processor, add it in `make_env_pre_post_processors()`:
|
||||
|
||||
```python
|
||||
if isinstance(env_cfg, MyBenchmarkEnvConfig) or "<benchmark_name>" in env_cfg.type:
|
||||
preprocessor_steps.append(MyBenchmarkProcessorStep())
|
||||
```
|
||||
|
||||
### 4. Env processor (optional — `src/lerobot/processor/env_processor.py`)
|
||||
|
||||
Only needed if your benchmark requires observation transforms beyond what `preprocess_observation()` handles (e.g. image flipping, coordinate conversion):
|
||||
Only needed if your benchmark requires observation transforms beyond what `preprocess_observation()` handles (e.g. image flipping, coordinate conversion). Define the processor step here and return it from `get_env_processors()` in your config (see step 2):
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
@@ -260,7 +251,7 @@ class MyBenchmarkProcessorStep(ObservationProcessorStep):
|
||||
|
||||
See `LiberoProcessorStep` for a full example (image rotation, quaternion-to-axis-angle conversion).
|
||||
|
||||
### 5. Dependencies (`pyproject.toml`)
|
||||
### 4. Dependencies (`pyproject.toml`)
|
||||
|
||||
Add a new optional-dependency group:
|
||||
|
||||
@@ -281,11 +272,11 @@ Users install with:
|
||||
pip install -e ".[mybenchmark]"
|
||||
```
|
||||
|
||||
### 6. Documentation (`docs/source/<benchmark>.mdx`)
|
||||
### 5. Documentation (`docs/source/<benchmark>.mdx`)
|
||||
|
||||
Write a user-facing page following the template in the next section. See `docs/source/libero.mdx` and `docs/source/metaworld.mdx` for full examples.
|
||||
|
||||
### 7. Table of contents (`docs/source/_toctree.yml`)
|
||||
### 6. Table of contents (`docs/source/_toctree.yml`)
|
||||
|
||||
Add your benchmark to the "Benchmarks" section:
|
||||
|
||||
|
||||
@@ -90,11 +90,17 @@ The same policy can work with different environment processors, and the same env
|
||||
|
||||
```python
|
||||
# Use SmolVLA policy with LIBERO environment
|
||||
libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(libero_cfg)
|
||||
# Use SmolVLA policy with LIBERO environment
|
||||
libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(
|
||||
env_cfg=libero_cfg,
|
||||
policy_cfg=smolvla_cfg,
|
||||
)
|
||||
smolvla_preprocessor, smolvla_postprocessor = make_pre_post_processors(smolvla_cfg)
|
||||
|
||||
# Or use ACT policy with the same LIBERO environment
|
||||
libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(libero_cfg)
|
||||
libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(
|
||||
env_cfg=libero_cfg,
|
||||
policy_cfg=act_cfg,
|
||||
)
|
||||
act_preprocessor, act_postprocessor = make_pre_post_processors(act_cfg)
|
||||
```
|
||||
|
||||
@@ -151,7 +157,7 @@ observation = {
|
||||
|
||||
### Factory Function
|
||||
|
||||
The `make_env_pre_post_processors` function follows the same pattern as `make_pre_post_processors` for policies:
|
||||
The `make_env_pre_post_processors` function delegates to `env_cfg.get_env_processors()`:
|
||||
|
||||
```python
|
||||
from lerobot.envs.factory import make_env_pre_post_processors
|
||||
@@ -159,47 +165,31 @@ from lerobot.envs.configs import LiberoEnv, PushtEnv
|
||||
|
||||
# For LIBERO: Returns LiberoProcessorStep in preprocessor
|
||||
libero_cfg = LiberoEnv(task="libero_spatial", camera_name=["agentview"])
|
||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(libero_cfg)
|
||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(libero_cfg, policy_cfg)
|
||||
|
||||
# For other environments: Returns identity processors (no-op)
|
||||
pusht_cfg = PushtEnv()
|
||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(pusht_cfg)
|
||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(pusht_cfg, policy_cfg)
|
||||
```
|
||||
|
||||
### Implementation in `envs/factory.py`
|
||||
### How It Works
|
||||
|
||||
Each `EnvConfig` subclass can override `get_env_processors()` to return benchmark-specific
|
||||
processor pipelines. The base class returns identity (no-op) processors by default.
|
||||
|
||||
```python
|
||||
def make_env_pre_post_processors(
|
||||
env_cfg: EnvConfig,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
]:
|
||||
"""
|
||||
Create preprocessor and postprocessor pipelines for environment observations.
|
||||
|
||||
Args:
|
||||
env_cfg: The configuration of the environment.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- preprocessor: Pipeline that processes environment observations
|
||||
- postprocessor: Pipeline that processes environment outputs
|
||||
"""
|
||||
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
|
||||
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
|
||||
preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
|
||||
else:
|
||||
# For all other environments, return an identity preprocessor
|
||||
preprocessor = PolicyProcessorPipeline(steps=[])
|
||||
|
||||
# Postprocessor is currently identity for all environments
|
||||
# Future: Could add environment-specific action transformations
|
||||
postprocessor = PolicyProcessorPipeline(steps=[])
|
||||
|
||||
return preprocessor, postprocessor
|
||||
# In your EnvConfig subclass:
|
||||
def get_env_processors(self):
|
||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||
return (
|
||||
PolicyProcessorPipeline(steps=[MyProcessorStep()]),
|
||||
PolicyProcessorPipeline(steps=[]),
|
||||
)
|
||||
```
|
||||
|
||||
The factory function `make_env_pre_post_processors` simply delegates to this method,
|
||||
with a special case for `XVLAConfig` policies which override the env processors entirely.
|
||||
|
||||
### Integration in Evaluation
|
||||
|
||||
In `lerobot_eval.py`, the environment processors are created once and used throughout:
|
||||
@@ -219,7 +209,10 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
)
|
||||
|
||||
# Create environment processors (NEW!)
|
||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
|
||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(
|
||||
env_cfg=cfg.env,
|
||||
policy_cfg=cfg.policy,
|
||||
)
|
||||
|
||||
# Run evaluation with both processor types
|
||||
eval_policy_all(
|
||||
@@ -323,21 +316,22 @@ class MyEnvProcessorStep(ObservationProcessorStep):
|
||||
return processed
|
||||
```
|
||||
|
||||
### 2. Update the Factory
|
||||
### 2. Update Your `EnvConfig` Subclass
|
||||
|
||||
```python
|
||||
# In src/lerobot/envs/factory.py
|
||||
# In src/lerobot/envs/configs.py
|
||||
@EnvConfig.register_subclass("myenv")
|
||||
@dataclass
|
||||
class MyEnvConfig(EnvConfig):
|
||||
# ... task/features/gym kwargs ...
|
||||
|
||||
def make_env_pre_post_processors(env_cfg: EnvConfig):
|
||||
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
|
||||
preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
|
||||
elif isinstance(env_cfg, MyEnvConfig) or "myenv" in env_cfg.type:
|
||||
preprocessor = PolicyProcessorPipeline(steps=[MyEnvProcessorStep()])
|
||||
else:
|
||||
preprocessor = PolicyProcessorPipeline(steps=[])
|
||||
def get_env_processors(self):
|
||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||
|
||||
postprocessor = PolicyProcessorPipeline(steps=[])
|
||||
return preprocessor, postprocessor
|
||||
return (
|
||||
PolicyProcessorPipeline(steps=[MyEnvProcessorStep()]),
|
||||
PolicyProcessorPipeline(steps=[]),
|
||||
)
|
||||
```
|
||||
|
||||
### 3. Use in Evaluation
|
||||
|
||||
@@ -12,11 +12,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import importlib
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.registration import registry as gym_registry
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.robots import RobotConfig
|
||||
@@ -67,6 +72,49 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
def gym_kwargs(self) -> dict:
|
||||
raise NotImplementedError()
|
||||
|
||||
def create_envs(
|
||||
self,
|
||||
n_envs: int,
|
||||
use_async_envs: bool = False,
|
||||
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
|
||||
"""Create {suite: {task_id: VectorEnv}}.
|
||||
|
||||
Default: single-task env via gym.make(). Multi-task benchmarks override.
|
||||
"""
|
||||
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
||||
|
||||
if self.gym_id not in gym_registry:
|
||||
print(f"gym id '{self.gym_id}' not found, attempting to import '{self.package_name}'...")
|
||||
try:
|
||||
importlib.import_module(self.package_name)
|
||||
except ModuleNotFoundError as e:
|
||||
raise ModuleNotFoundError(
|
||||
f"Package '{self.package_name}' required for env '{self.type}' not found. "
|
||||
f"Please install it or check PYTHONPATH."
|
||||
) from e
|
||||
|
||||
if self.gym_id not in gym_registry:
|
||||
raise gym.error.NameNotFound(
|
||||
f"Environment '{self.gym_id}' not registered even after importing '{self.package_name}'."
|
||||
)
|
||||
|
||||
def _make_one():
|
||||
return gym.make(self.gym_id, disable_env_checker=self.disable_env_checker, **self.gym_kwargs)
|
||||
|
||||
try:
|
||||
from gymnasium.vector import AutoresetMode
|
||||
|
||||
vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=AutoresetMode.SAME_STEP)
|
||||
except ImportError:
|
||||
vec = env_cls([_make_one for _ in range(n_envs)])
|
||||
return {self.type: {0: vec}}
|
||||
|
||||
def get_env_processors(self):
|
||||
"""Return (preprocessor, postprocessor) for this env. Default: identity."""
|
||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||
|
||||
return PolicyProcessorPipeline(steps=[]), PolicyProcessorPipeline(steps=[])
|
||||
|
||||
|
||||
@dataclass
|
||||
class HubEnvConfig(EnvConfig):
|
||||
@@ -338,6 +386,12 @@ class LiberoEnv(EnvConfig):
|
||||
else:
|
||||
raise ValueError(f"Unsupported obs_type: {self.obs_type}")
|
||||
|
||||
if self.camera_name_mapping is not None:
|
||||
mapped_agentview = self.camera_name_mapping.get("agentview_image", "image")
|
||||
mapped_eye_in_hand = self.camera_name_mapping.get("robot0_eye_in_hand_image", "image2")
|
||||
self.features_map[LIBERO_KEY_PIXELS_AGENTVIEW] = f"{OBS_IMAGES}.{mapped_agentview}"
|
||||
self.features_map[LIBERO_KEY_PIXELS_EYE_IN_HAND] = f"{OBS_IMAGES}.{mapped_eye_in_hand}"
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
kwargs: dict[str, Any] = {"obs_type": self.obs_type, "render_mode": self.render_mode}
|
||||
@@ -345,6 +399,33 @@ class LiberoEnv(EnvConfig):
|
||||
kwargs["task_ids"] = self.task_ids
|
||||
return kwargs
|
||||
|
||||
def create_envs(self, n_envs: int, use_async_envs: bool = False):
|
||||
from lerobot.envs.libero import create_libero_envs
|
||||
|
||||
if self.task is None:
|
||||
raise ValueError("LiberoEnv requires a task to be specified")
|
||||
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
||||
return create_libero_envs(
|
||||
task=self.task,
|
||||
n_envs=n_envs,
|
||||
camera_name=self.camera_name,
|
||||
init_states=self.init_states,
|
||||
gym_kwargs=self.gym_kwargs,
|
||||
env_cls=env_cls,
|
||||
control_mode=self.control_mode,
|
||||
episode_length=self.episode_length,
|
||||
camera_name_mapping=self.camera_name_mapping,
|
||||
)
|
||||
|
||||
def get_env_processors(self):
|
||||
from lerobot.processor.env_processor import LiberoProcessorStep
|
||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline(steps=[LiberoProcessorStep()]),
|
||||
PolicyProcessorPipeline(steps=[]),
|
||||
)
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("metaworld")
|
||||
@dataclass
|
||||
@@ -387,6 +468,19 @@ class MetaworldEnv(EnvConfig):
|
||||
"render_mode": self.render_mode,
|
||||
}
|
||||
|
||||
def create_envs(self, n_envs: int, use_async_envs: bool = False):
|
||||
from lerobot.envs.metaworld import create_metaworld_envs
|
||||
|
||||
if self.task is None:
|
||||
raise ValueError("MetaWorld requires a task to be specified")
|
||||
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
||||
return create_metaworld_envs(
|
||||
task=self.task,
|
||||
n_envs=n_envs,
|
||||
gym_kwargs=self.gym_kwargs,
|
||||
env_cls=env_cls,
|
||||
)
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("isaaclab_arena")
|
||||
@dataclass
|
||||
@@ -454,3 +548,18 @@ class IsaaclabArenaEnv(HubEnvConfig):
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {}
|
||||
|
||||
def get_env_processors(self):
|
||||
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep
|
||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||
|
||||
state_keys = tuple(k.strip() for k in (self.state_keys or "").split(",") if k.strip())
|
||||
camera_keys = tuple(k.strip() for k in (self.camera_keys or "").split(",") if k.strip())
|
||||
if not state_keys and not camera_keys:
|
||||
raise ValueError("At least one of state_keys or camera_keys must be specified.")
|
||||
return (
|
||||
PolicyProcessorPipeline(
|
||||
steps=[IsaaclabArenaProcessorStep(state_keys=state_keys, camera_keys=camera_keys)]
|
||||
),
|
||||
PolicyProcessorPipeline(steps=[]),
|
||||
)
|
||||
|
||||
+20
-117
@@ -13,90 +13,46 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.registration import registry as gym_registry
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, HubEnvConfig, IsaaclabArenaEnv, LiberoEnv, PushtEnv
|
||||
from lerobot.envs.configs import EnvConfig, HubEnvConfig
|
||||
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
|
||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.processor import ProcessorStep
|
||||
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep
|
||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||
|
||||
|
||||
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
if env_type == "aloha":
|
||||
return AlohaEnv(**kwargs)
|
||||
elif env_type == "pusht":
|
||||
return PushtEnv(**kwargs)
|
||||
elif env_type == "libero":
|
||||
return LiberoEnv(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
||||
try:
|
||||
cls = EnvConfig.get_choice_class(env_type)
|
||||
except KeyError as err:
|
||||
raise ValueError(
|
||||
f"Environment type '{env_type}' is not registered. "
|
||||
f"Available: {list(EnvConfig.get_known_choices().keys())}"
|
||||
) from err
|
||||
return cls(**kwargs)
|
||||
|
||||
|
||||
def make_env_pre_post_processors(
|
||||
env_cfg: EnvConfig,
|
||||
policy_cfg: PreTrainedConfig,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
]:
|
||||
policy_cfg: Any,
|
||||
) -> tuple[Any, Any]:
|
||||
"""
|
||||
Create preprocessor and postprocessor pipelines for environment observations.
|
||||
|
||||
This function creates processor pipelines that transform raw environment
|
||||
observations and actions. By default, it returns identity processors that do nothing.
|
||||
For specific environments like LIBERO, it adds environment-specific processing steps.
|
||||
|
||||
Args:
|
||||
env_cfg: The configuration of the environment.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- preprocessor: Pipeline that processes environment observations
|
||||
- postprocessor: Pipeline that processes environment outputs (currently identity)
|
||||
Returns a tuple of (preprocessor, postprocessor). By default, delegates to
|
||||
``env_cfg.get_env_processors()``. The XVLAConfig policy-specific override
|
||||
stays here because it depends on the *policy* config, not the env config.
|
||||
"""
|
||||
# Preprocessor and Postprocessor steps are Identity for most environments
|
||||
preprocessor_steps: list[ProcessorStep] = []
|
||||
postprocessor_steps: list[ProcessorStep] = []
|
||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
|
||||
if isinstance(policy_cfg, XVLAConfig):
|
||||
from lerobot.policies.xvla.processor_xvla import make_xvla_libero_pre_post_processors
|
||||
|
||||
return make_xvla_libero_pre_post_processors()
|
||||
|
||||
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
|
||||
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
|
||||
preprocessor_steps.append(LiberoProcessorStep())
|
||||
|
||||
# For Isaaclab Arena environments, add the IsaaclabArenaProcessorStep
|
||||
if isinstance(env_cfg, IsaaclabArenaEnv) or "isaaclab_arena" in env_cfg.type:
|
||||
# Parse comma-separated keys (handle None for state-based policies)
|
||||
if env_cfg.state_keys:
|
||||
state_keys = tuple(k.strip() for k in env_cfg.state_keys.split(",") if k.strip())
|
||||
else:
|
||||
state_keys = ()
|
||||
if env_cfg.camera_keys:
|
||||
camera_keys = tuple(k.strip() for k in env_cfg.camera_keys.split(",") if k.strip())
|
||||
else:
|
||||
camera_keys = ()
|
||||
if not state_keys and not camera_keys:
|
||||
raise ValueError("At least one of state_keys or camera_keys must be specified.")
|
||||
preprocessor_steps.append(
|
||||
IsaaclabArenaProcessorStep(
|
||||
state_keys=state_keys,
|
||||
camera_keys=camera_keys,
|
||||
)
|
||||
)
|
||||
|
||||
preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps)
|
||||
postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps)
|
||||
|
||||
return preprocessor, postprocessor
|
||||
return env_cfg.get_env_processors()
|
||||
|
||||
|
||||
def make_env(
|
||||
@@ -163,57 +119,4 @@ def make_env(
|
||||
if n_envs < 1:
|
||||
raise ValueError("`n_envs` must be at least 1")
|
||||
|
||||
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
||||
|
||||
if "libero" in cfg.type:
|
||||
from lerobot.envs.libero import create_libero_envs
|
||||
|
||||
if cfg.task is None:
|
||||
raise ValueError("LiberoEnv requires a task to be specified")
|
||||
|
||||
return create_libero_envs(
|
||||
task=cfg.task,
|
||||
n_envs=n_envs,
|
||||
camera_name=cfg.camera_name,
|
||||
init_states=cfg.init_states,
|
||||
gym_kwargs=cfg.gym_kwargs,
|
||||
env_cls=env_cls,
|
||||
control_mode=cfg.control_mode,
|
||||
episode_length=cfg.episode_length,
|
||||
)
|
||||
elif "metaworld" in cfg.type:
|
||||
from lerobot.envs.metaworld import create_metaworld_envs
|
||||
|
||||
if cfg.task is None:
|
||||
raise ValueError("MetaWorld requires a task to be specified")
|
||||
|
||||
return create_metaworld_envs(
|
||||
task=cfg.task,
|
||||
n_envs=n_envs,
|
||||
gym_kwargs=cfg.gym_kwargs,
|
||||
env_cls=env_cls,
|
||||
)
|
||||
|
||||
if cfg.gym_id not in gym_registry:
|
||||
print(f"gym id '{cfg.gym_id}' not found, attempting to import '{cfg.package_name}'...")
|
||||
try:
|
||||
importlib.import_module(cfg.package_name)
|
||||
except ModuleNotFoundError as e:
|
||||
raise ModuleNotFoundError(
|
||||
f"Package '{cfg.package_name}' required for env '{cfg.type}' not found. "
|
||||
f"Please install it or check PYTHONPATH."
|
||||
) from e
|
||||
|
||||
if cfg.gym_id not in gym_registry:
|
||||
raise gym.error.NameNotFound(
|
||||
f"Environment '{cfg.gym_id}' not registered even after importing '{cfg.package_name}'."
|
||||
)
|
||||
|
||||
def _make_one():
|
||||
return gym.make(cfg.gym_id, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
|
||||
|
||||
vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)
|
||||
|
||||
# normalize to {suite: {task_id: vec_env}} for consistency
|
||||
suite_name = cfg.type # e.g., "pusht", "aloha"
|
||||
return {suite_name: {0: vec}}
|
||||
return cfg.create_envs(n_envs=n_envs, use_async_envs=use_async_envs)
|
||||
|
||||
@@ -223,7 +223,8 @@ class LiberoEnv(gym.Env):
|
||||
|
||||
def render(self):
|
||||
raw_obs = self._env.env._get_observations()
|
||||
image = self._format_raw_obs(raw_obs)["pixels"]["image"]
|
||||
pixels = self._format_raw_obs(raw_obs)["pixels"]
|
||||
image = next(iter(pixels.values()))
|
||||
image = image[::-1, ::-1] # flip both H and W for visualization
|
||||
return image
|
||||
|
||||
@@ -339,12 +340,6 @@ class LiberoEnv(gym.Env):
|
||||
)
|
||||
observation = self._format_raw_obs(raw_obs)
|
||||
if terminated:
|
||||
info["final_info"] = {
|
||||
"task": self.task,
|
||||
"task_id": self.task_id,
|
||||
"done": bool(done),
|
||||
"is_success": bool(is_success),
|
||||
}
|
||||
self.reset()
|
||||
truncated = False
|
||||
return observation, reward, terminated, truncated, info
|
||||
@@ -364,6 +359,7 @@ def _make_env_fns(
|
||||
init_states: bool,
|
||||
gym_kwargs: Mapping[str, Any],
|
||||
control_mode: str,
|
||||
camera_name_mapping: dict[str, str] | None = None,
|
||||
) -> list[Callable[[], LiberoEnv]]:
|
||||
"""Build n_envs factory callables for a single (suite, task_id)."""
|
||||
|
||||
@@ -379,6 +375,7 @@ def _make_env_fns(
|
||||
episode_index=episode_index,
|
||||
n_envs=n_envs,
|
||||
control_mode=control_mode,
|
||||
camera_name_mapping=camera_name_mapping,
|
||||
**local_kwargs,
|
||||
)
|
||||
|
||||
@@ -400,6 +397,7 @@ def create_libero_envs(
|
||||
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
|
||||
control_mode: str = "relative",
|
||||
episode_length: int | None = None,
|
||||
camera_name_mapping: dict[str, str] | None = None,
|
||||
) -> dict[str, dict[int, Any]]:
|
||||
"""
|
||||
Create vectorized LIBERO environments with a consistent return shape.
|
||||
@@ -449,6 +447,7 @@ def create_libero_envs(
|
||||
init_states=init_states,
|
||||
gym_kwargs=gym_kwargs,
|
||||
control_mode=control_mode,
|
||||
camera_name_mapping=camera_name_mapping,
|
||||
)
|
||||
out[suite_name][tid] = env_cls(fns)
|
||||
print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")
|
||||
|
||||
@@ -201,6 +201,11 @@ def rollout(
|
||||
"You're likely using an older version of gymnasium (< 1.0). Please upgrade."
|
||||
)
|
||||
successes = final_info["is_success"].tolist()
|
||||
elif "is_success" in info:
|
||||
is_success = info["is_success"]
|
||||
successes = (
|
||||
is_success.tolist() if hasattr(is_success, "tolist") else [bool(is_success)] * env.num_envs
|
||||
)
|
||||
else:
|
||||
successes = [False] * env.num_envs
|
||||
|
||||
|
||||
@@ -0,0 +1,143 @@
|
||||
"""Tests for the benchmark dispatch refactor (create_envs / get_env_processors on EnvConfig)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import gymnasium as gym
|
||||
import pytest
|
||||
from gymnasium.envs.registration import register, registry as gym_registry
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.envs.configs import EnvConfig
|
||||
from lerobot.envs.factory import make_env, make_env_config, make_env_pre_post_processors
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def test_registry_all_types():
|
||||
"""make_env_config should resolve every registered EnvConfig subclass via the registry."""
|
||||
known = list(EnvConfig.get_known_choices().keys())
|
||||
assert len(known) >= 6
|
||||
for t in known:
|
||||
cfg = make_env_config(t)
|
||||
if not isinstance(cfg, EnvConfig):
|
||||
continue
|
||||
assert cfg.type == t
|
||||
|
||||
|
||||
def test_unknown_type():
|
||||
with pytest.raises(ValueError, match="not registered"):
|
||||
make_env_config("nonexistent")
|
||||
|
||||
|
||||
def test_identity_processors():
|
||||
"""Base class get_env_processors() returns identity pipelines."""
|
||||
cfg = make_env_config("aloha")
|
||||
pre, post = cfg.get_env_processors()
|
||||
assert len(pre.steps) == 0 and len(post.steps) == 0
|
||||
|
||||
|
||||
def test_delegation():
|
||||
"""make_env() should call cfg.create_envs(), not use if/elif dispatch."""
|
||||
sentinel = {"delegated": {0: "marker"}}
|
||||
fake = type(
|
||||
"Fake",
|
||||
(),
|
||||
{
|
||||
"hub_path": None,
|
||||
"create_envs": lambda self, n_envs, use_async_envs=False: sentinel,
|
||||
},
|
||||
)()
|
||||
result = make_env(fake, n_envs=1)
|
||||
assert result is sentinel
|
||||
|
||||
|
||||
def test_processors_delegation():
|
||||
"""make_env_pre_post_processors delegates to cfg.get_env_processors()."""
|
||||
cfg = make_env_config("aloha")
|
||||
pre, post = make_env_pre_post_processors(cfg, policy_cfg=None)
|
||||
assert len(pre.steps) == 0
|
||||
|
||||
|
||||
def test_base_create_envs():
|
||||
"""Base class create_envs() should build a single-task VectorEnv via gym.make()."""
|
||||
gym_id = "_dispatch_test/CartPole-v99"
|
||||
if gym_id not in gym_registry:
|
||||
register(id=gym_id, entry_point="gymnasium.envs.classic_control:CartPoleEnv")
|
||||
|
||||
@EnvConfig.register_subclass("_dispatch_base_test")
|
||||
@dataclass
|
||||
class _Env(EnvConfig):
|
||||
task: str = "CartPole-v99"
|
||||
fps: int = 10
|
||||
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def package_name(self):
|
||||
return "_dispatch_test"
|
||||
|
||||
@property
|
||||
def gym_id(self):
|
||||
return gym_id
|
||||
|
||||
@property
|
||||
def gym_kwargs(self):
|
||||
return {}
|
||||
|
||||
try:
|
||||
envs = _Env().create_envs(n_envs=2)
|
||||
assert "_dispatch_base_test" in envs
|
||||
env = envs["_dispatch_base_test"][0]
|
||||
assert isinstance(env, gym.vector.SyncVectorEnv)
|
||||
assert env.num_envs == 2
|
||||
env.close()
|
||||
finally:
|
||||
if gym_id in gym_registry:
|
||||
del gym_registry[gym_id]
|
||||
|
||||
|
||||
def test_custom_create_envs_override():
|
||||
"""A custom EnvConfig subclass can override create_envs()."""
|
||||
mock_vec = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")])
|
||||
|
||||
@EnvConfig.register_subclass("_dispatch_custom_test")
|
||||
@dataclass
|
||||
class _Env(EnvConfig):
|
||||
task: str = "x"
|
||||
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def gym_kwargs(self):
|
||||
return {}
|
||||
|
||||
def create_envs(self, n_envs, use_async_envs=False):
|
||||
return {"custom_suite": {0: mock_vec}}
|
||||
|
||||
try:
|
||||
result = make_env(_Env(), n_envs=1)
|
||||
assert "custom_suite" in result
|
||||
finally:
|
||||
mock_vec.close()
|
||||
|
||||
|
||||
def test_custom_get_env_processors_override():
|
||||
"""A custom EnvConfig subclass can override get_env_processors()."""
|
||||
from lerobot.processor.pipeline import DataProcessorPipeline
|
||||
|
||||
@EnvConfig.register_subclass("_dispatch_proc_test")
|
||||
@dataclass
|
||||
class _Env(EnvConfig):
|
||||
task: str = "x"
|
||||
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def gym_kwargs(self):
|
||||
return {}
|
||||
|
||||
def get_env_processors(self):
|
||||
return DataProcessorPipeline(steps=[]), DataProcessorPipeline(steps=[])
|
||||
|
||||
pre, post = _Env().get_env_processors()
|
||||
assert isinstance(pre, DataProcessorPipeline)
|
||||
Reference in New Issue
Block a user