mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
refactor(envs): move dispatch logic from factory into EnvConfig subclasses
Replace hardcoded if/elif chains in factory.py with create_envs() and get_env_processors() methods on EnvConfig. New benchmarks now only need to register a config subclass — no factory.py edits required. Net -23 lines: factory.py shrinks from ~200 to ~70 lines of logic. Made-with: Cursor
This commit is contained in:
@@ -103,16 +103,15 @@ Each `EnvConfig` subclass declares:
|
|||||||
|
|
||||||
### Checklist
|
### Checklist
|
||||||
|
|
||||||
| File | Required | Description |
|
| File | Required | Description |
|
||||||
| ---------------------------------------- | -------- | ----------------------------------------------------------------------------------- |
|
| ---------------------------------------- | -------- | -------------------------------------------------------------------------------- |
|
||||||
| `src/lerobot/envs/<benchmark>.py` | Yes | `gym.Env` subclass + `create_<benchmark>_envs()` factory |
|
| `src/lerobot/envs/<benchmark>.py` | Yes | `gym.Env` subclass + `create_<benchmark>_envs()` factory |
|
||||||
| `src/lerobot/envs/configs.py` | Yes | `@EnvConfig.register_subclass("<name>")` dataclass |
|
| `src/lerobot/envs/configs.py` | Yes | `@EnvConfig.register_subclass("<name>")` dataclass with `create_envs()` override |
|
||||||
| `src/lerobot/envs/factory.py` | Yes | Add dispatch branch in `make_env()` and optionally `make_env_pre_post_processors()` |
|
| `src/lerobot/processor/env_processor.py` | Optional | `ProcessorStep` subclass for env-specific observation transforms |
|
||||||
| `src/lerobot/processor/env_processor.py` | Optional | `ProcessorStep` subclass for env-specific observation transforms |
|
| `src/lerobot/envs/utils.py` | Optional | Extend `preprocess_observation()` if new raw keys are needed |
|
||||||
| `src/lerobot/envs/utils.py` | Optional | Extend `preprocess_observation()` if new raw keys are needed |
|
| `pyproject.toml` | Yes | Add optional dependency group |
|
||||||
| `pyproject.toml` | Yes | Add optional dependency group |
|
| `docs/source/<benchmark>.mdx` | Yes | User-facing benchmark documentation |
|
||||||
| `docs/source/<benchmark>.mdx` | Yes | User-facing benchmark documentation |
|
| `docs/source/_toctree.yml` | Yes | Add entry under the "Benchmarks" section |
|
||||||
| `docs/source/_toctree.yml` | Yes | Add entry under the "Benchmarks" section |
|
|
||||||
|
|
||||||
### 1. The gym.Env wrapper (`src/lerobot/envs/<benchmark>.py`)
|
### 1. The gym.Env wrapper (`src/lerobot/envs/<benchmark>.py`)
|
||||||
|
|
||||||
@@ -169,7 +168,10 @@ See `create_libero_envs()` in `src/lerobot/envs/libero.py` (multi-suite, multi-t
|
|||||||
|
|
||||||
### 2. The config (`src/lerobot/envs/configs.py`)
|
### 2. The config (`src/lerobot/envs/configs.py`)
|
||||||
|
|
||||||
Register a new config dataclass:
|
Register a new config dataclass. Each config owns its environment creation and processor logic via two methods:
|
||||||
|
|
||||||
|
- **`create_envs(n_envs, use_async_envs)`** — Returns `{suite: {task_id: VectorEnv}}`. The base class default uses `gym.make()` for single-task envs. Multi-task benchmarks override this.
|
||||||
|
- **`get_env_processors()`** — Returns `(preprocessor, postprocessor)`. The base class default returns identity (no-op) pipelines. Override if your benchmark needs observation/action transforms.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@EnvConfig.register_subclass("<benchmark_name>")
|
@EnvConfig.register_subclass("<benchmark_name>")
|
||||||
@@ -196,6 +198,20 @@ class MyBenchmarkEnv(EnvConfig):
|
|||||||
@property
|
@property
|
||||||
def gym_kwargs(self) -> dict:
|
def gym_kwargs(self) -> dict:
|
||||||
return {"obs_type": self.obs_type, "render_mode": self.render_mode}
|
return {"obs_type": self.obs_type, "render_mode": self.render_mode}
|
||||||
|
|
||||||
|
def create_envs(self, n_envs: int, use_async_envs: bool = False):
|
||||||
|
"""Override for multi-task benchmarks or custom env creation."""
|
||||||
|
from lerobot.envs.<benchmark> import create_<benchmark>_envs
|
||||||
|
return create_<benchmark>_envs(task=self.task, n_envs=n_envs, ...)
|
||||||
|
|
||||||
|
def get_env_processors(self):
|
||||||
|
"""Override if your benchmark needs observation/action transforms."""
|
||||||
|
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||||
|
from lerobot.processor.env_processor import MyBenchmarkProcessorStep
|
||||||
|
return (
|
||||||
|
PolicyProcessorPipeline(steps=[MyBenchmarkProcessorStep()]),
|
||||||
|
PolicyProcessorPipeline(steps=[]),
|
||||||
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
Key points:
|
Key points:
|
||||||
@@ -203,36 +219,11 @@ Key points:
|
|||||||
- The `register_subclass` name is what users pass as `--env.type=<name>` on the CLI.
|
- The `register_subclass` name is what users pass as `--env.type=<name>` on the CLI.
|
||||||
- `features` declares what the environment produces (used to configure the policy).
|
- `features` declares what the environment produces (used to configure the policy).
|
||||||
- `features_map` maps raw observation keys to LeRobot convention keys.
|
- `features_map` maps raw observation keys to LeRobot convention keys.
|
||||||
|
- **No changes to `factory.py` needed** — the factory delegates to `cfg.create_envs()` and `cfg.get_env_processors()` automatically.
|
||||||
|
|
||||||
### 3. The factory dispatch (`src/lerobot/envs/factory.py`)
|
### 3. Env processor (optional) (`src/lerobot/processor/env_processor.py`)
|
||||||
|
|
||||||
Add a branch in `make_env()`:
|
If your benchmark needs observation transforms beyond what `preprocess_observation()` handles (e.g., image flipping, coordinate frame conversion), add a `ProcessorStep` and return it from `get_env_processors()` in your config (see step 2):
|
||||||
|
|
||||||
```python
|
|
||||||
elif "<benchmark_name>" in cfg.type:
|
|
||||||
from lerobot.envs.<benchmark> import create_<benchmark>_envs
|
|
||||||
|
|
||||||
if cfg.task is None:
|
|
||||||
raise ValueError("<BenchmarkName> requires a task to be specified")
|
|
||||||
|
|
||||||
return create_<benchmark>_envs(
|
|
||||||
task=cfg.task,
|
|
||||||
n_envs=n_envs,
|
|
||||||
gym_kwargs=cfg.gym_kwargs,
|
|
||||||
env_cls=env_cls,
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
If your benchmark needs an env processor, add it in `make_env_pre_post_processors()`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
if isinstance(env_cfg, MyBenchmarkEnv) or "<benchmark_name>" in env_cfg.type:
|
|
||||||
preprocessor_steps.append(MyBenchmarkProcessorStep())
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4. Env processor (optional) (`src/lerobot/processor/env_processor.py`)
|
|
||||||
|
|
||||||
If your benchmark needs observation transforms beyond what `preprocess_observation()` handles (e.g., image flipping, coordinate frame conversion), add a `ProcessorStep`:
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -253,7 +244,7 @@ class MyBenchmarkProcessorStep(ObservationProcessorStep):
|
|||||||
|
|
||||||
See `LiberoProcessorStep` for a full example (image rotation, quaternion-to-axis-angle conversion).
|
See `LiberoProcessorStep` for a full example (image rotation, quaternion-to-axis-angle conversion).
|
||||||
|
|
||||||
### 5. Dependencies (`pyproject.toml`)
|
### 4. Dependencies (`pyproject.toml`)
|
||||||
|
|
||||||
Add a new optional-dependency group under `[project.optional-dependencies]`:
|
Add a new optional-dependency group under `[project.optional-dependencies]`:
|
||||||
|
|
||||||
@@ -274,11 +265,11 @@ Users install with:
|
|||||||
pip install -e ".[mybenchmark]"
|
pip install -e ".[mybenchmark]"
|
||||||
```
|
```
|
||||||
|
|
||||||
### 6. Documentation (`docs/source/<benchmark>.mdx`)
|
### 5. Documentation (`docs/source/<benchmark>.mdx`)
|
||||||
|
|
||||||
Follow the template below. See `docs/source/libero.mdx` and `docs/source/metaworld.mdx` for full examples.
|
Follow the template below. See `docs/source/libero.mdx` and `docs/source/metaworld.mdx` for full examples.
|
||||||
|
|
||||||
### 7. Table of contents (`docs/source/_toctree.yml`)
|
### 6. Table of contents (`docs/source/_toctree.yml`)
|
||||||
|
|
||||||
Add your benchmark under the "Benchmarks" section:
|
Add your benchmark under the "Benchmarks" section:
|
||||||
|
|
||||||
|
|||||||
@@ -151,7 +151,7 @@ observation = {
|
|||||||
|
|
||||||
### Factory Function
|
### Factory Function
|
||||||
|
|
||||||
The `make_env_pre_post_processors` function follows the same pattern as `make_pre_post_processors` for policies:
|
The `make_env_pre_post_processors` function delegates to `env_cfg.get_env_processors()`:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from lerobot.envs.factory import make_env_pre_post_processors
|
from lerobot.envs.factory import make_env_pre_post_processors
|
||||||
@@ -159,47 +159,31 @@ from lerobot.envs.configs import LiberoEnv, PushtEnv
|
|||||||
|
|
||||||
# For LIBERO: Returns LiberoProcessorStep in preprocessor
|
# For LIBERO: Returns LiberoProcessorStep in preprocessor
|
||||||
libero_cfg = LiberoEnv(task="libero_spatial", camera_name=["agentview"])
|
libero_cfg = LiberoEnv(task="libero_spatial", camera_name=["agentview"])
|
||||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(libero_cfg)
|
env_preprocessor, env_postprocessor = make_env_pre_post_processors(libero_cfg, policy_cfg)
|
||||||
|
|
||||||
# For other environments: Returns identity processors (no-op)
|
# For other environments: Returns identity processors (no-op)
|
||||||
pusht_cfg = PushtEnv()
|
pusht_cfg = PushtEnv()
|
||||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(pusht_cfg)
|
env_preprocessor, env_postprocessor = make_env_pre_post_processors(pusht_cfg, policy_cfg)
|
||||||
```
|
```
|
||||||
|
|
||||||
### Implementation in `envs/factory.py`
|
### How It Works
|
||||||
|
|
||||||
|
Each `EnvConfig` subclass can override `get_env_processors()` to return benchmark-specific
|
||||||
|
processor pipelines. The base class returns identity (no-op) processors by default.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def make_env_pre_post_processors(
|
# In your EnvConfig subclass:
|
||||||
env_cfg: EnvConfig,
|
def get_env_processors(self):
|
||||||
) -> tuple[
|
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
return (
|
||||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
PolicyProcessorPipeline(steps=[MyProcessorStep()]),
|
||||||
]:
|
PolicyProcessorPipeline(steps=[]),
|
||||||
"""
|
)
|
||||||
Create preprocessor and postprocessor pipelines for environment observations.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
env_cfg: The configuration of the environment.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple containing:
|
|
||||||
- preprocessor: Pipeline that processes environment observations
|
|
||||||
- postprocessor: Pipeline that processes environment outputs
|
|
||||||
"""
|
|
||||||
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
|
|
||||||
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
|
|
||||||
preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
|
|
||||||
else:
|
|
||||||
# For all other environments, return an identity preprocessor
|
|
||||||
preprocessor = PolicyProcessorPipeline(steps=[])
|
|
||||||
|
|
||||||
# Postprocessor is currently identity for all environments
|
|
||||||
# Future: Could add environment-specific action transformations
|
|
||||||
postprocessor = PolicyProcessorPipeline(steps=[])
|
|
||||||
|
|
||||||
return preprocessor, postprocessor
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
The factory function `make_env_pre_post_processors` simply delegates to this method,
|
||||||
|
with a special case for `XVLAConfig` policies which override the env processors entirely.
|
||||||
|
|
||||||
### Integration in Evaluation
|
### Integration in Evaluation
|
||||||
|
|
||||||
In `lerobot_eval.py`, the environment processors are created once and used throughout:
|
In `lerobot_eval.py`, the environment processors are created once and used throughout:
|
||||||
|
|||||||
@@ -12,11 +12,16 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
|
import importlib
|
||||||
from dataclasses import dataclass, field, fields
|
from dataclasses import dataclass, field, fields
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import draccus
|
import draccus
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium.envs.registration import registry as gym_registry
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
from lerobot.robots import RobotConfig
|
from lerobot.robots import RobotConfig
|
||||||
@@ -67,6 +72,44 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
|||||||
def gym_kwargs(self) -> dict:
|
def gym_kwargs(self) -> dict:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def create_envs(
|
||||||
|
self,
|
||||||
|
n_envs: int,
|
||||||
|
use_async_envs: bool = False,
|
||||||
|
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
|
||||||
|
"""Create {suite: {task_id: VectorEnv}}.
|
||||||
|
|
||||||
|
Default: single-task env via gym.make(). Multi-task benchmarks override.
|
||||||
|
"""
|
||||||
|
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
||||||
|
|
||||||
|
if self.gym_id not in gym_registry:
|
||||||
|
print(f"gym id '{self.gym_id}' not found, attempting to import '{self.package_name}'...")
|
||||||
|
try:
|
||||||
|
importlib.import_module(self.package_name)
|
||||||
|
except ModuleNotFoundError as e:
|
||||||
|
raise ModuleNotFoundError(
|
||||||
|
f"Package '{self.package_name}' required for env '{self.type}' not found. "
|
||||||
|
f"Please install it or check PYTHONPATH."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
if self.gym_id not in gym_registry:
|
||||||
|
raise gym.error.NameNotFound(
|
||||||
|
f"Environment '{self.gym_id}' not registered even after importing '{self.package_name}'."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _make_one():
|
||||||
|
return gym.make(self.gym_id, disable_env_checker=self.disable_env_checker, **self.gym_kwargs)
|
||||||
|
|
||||||
|
vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)
|
||||||
|
return {self.type: {0: vec}}
|
||||||
|
|
||||||
|
def get_env_processors(self):
|
||||||
|
"""Return (preprocessor, postprocessor) for this env. Default: identity."""
|
||||||
|
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||||
|
|
||||||
|
return PolicyProcessorPipeline(steps=[]), PolicyProcessorPipeline(steps=[])
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class HubEnvConfig(EnvConfig):
|
class HubEnvConfig(EnvConfig):
|
||||||
@@ -345,6 +388,32 @@ class LiberoEnv(EnvConfig):
|
|||||||
kwargs["task_ids"] = self.task_ids
|
kwargs["task_ids"] = self.task_ids
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
|
def create_envs(self, n_envs: int, use_async_envs: bool = False):
|
||||||
|
from lerobot.envs.libero import create_libero_envs
|
||||||
|
|
||||||
|
if self.task is None:
|
||||||
|
raise ValueError("LiberoEnv requires a task to be specified")
|
||||||
|
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
||||||
|
return create_libero_envs(
|
||||||
|
task=self.task,
|
||||||
|
n_envs=n_envs,
|
||||||
|
camera_name=self.camera_name,
|
||||||
|
init_states=self.init_states,
|
||||||
|
gym_kwargs=self.gym_kwargs,
|
||||||
|
env_cls=env_cls,
|
||||||
|
control_mode=self.control_mode,
|
||||||
|
episode_length=self.episode_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_env_processors(self):
|
||||||
|
from lerobot.processor.env_processor import LiberoProcessorStep
|
||||||
|
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||||
|
|
||||||
|
return (
|
||||||
|
PolicyProcessorPipeline(steps=[LiberoProcessorStep()]),
|
||||||
|
PolicyProcessorPipeline(steps=[]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@EnvConfig.register_subclass("metaworld")
|
@EnvConfig.register_subclass("metaworld")
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -387,6 +456,19 @@ class MetaworldEnv(EnvConfig):
|
|||||||
"render_mode": self.render_mode,
|
"render_mode": self.render_mode,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def create_envs(self, n_envs: int, use_async_envs: bool = False):
|
||||||
|
from lerobot.envs.metaworld import create_metaworld_envs
|
||||||
|
|
||||||
|
if self.task is None:
|
||||||
|
raise ValueError("MetaWorld requires a task to be specified")
|
||||||
|
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
||||||
|
return create_metaworld_envs(
|
||||||
|
task=self.task,
|
||||||
|
n_envs=n_envs,
|
||||||
|
gym_kwargs=self.gym_kwargs,
|
||||||
|
env_cls=env_cls,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@EnvConfig.register_subclass("isaaclab_arena")
|
@EnvConfig.register_subclass("isaaclab_arena")
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -454,3 +536,18 @@ class IsaaclabArenaEnv(HubEnvConfig):
|
|||||||
@property
|
@property
|
||||||
def gym_kwargs(self) -> dict:
|
def gym_kwargs(self) -> dict:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
def get_env_processors(self):
|
||||||
|
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep
|
||||||
|
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||||
|
|
||||||
|
state_keys = tuple(k.strip() for k in (self.state_keys or "").split(",") if k.strip())
|
||||||
|
camera_keys = tuple(k.strip() for k in (self.camera_keys or "").split(",") if k.strip())
|
||||||
|
if not state_keys and not camera_keys:
|
||||||
|
raise ValueError("At least one of state_keys or camera_keys must be specified.")
|
||||||
|
return (
|
||||||
|
PolicyProcessorPipeline(
|
||||||
|
steps=[IsaaclabArenaProcessorStep(state_keys=state_keys, camera_keys=camera_keys)]
|
||||||
|
),
|
||||||
|
PolicyProcessorPipeline(steps=[]),
|
||||||
|
)
|
||||||
|
|||||||
+20
-117
@@ -13,90 +13,46 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import importlib
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.envs.registration import registry as gym_registry
|
|
||||||
|
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.envs.configs import EnvConfig, HubEnvConfig
|
||||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, HubEnvConfig, IsaaclabArenaEnv, LiberoEnv, PushtEnv
|
|
||||||
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
|
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
|
||||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
|
||||||
from lerobot.processor import ProcessorStep
|
|
||||||
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep
|
|
||||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
|
||||||
|
|
||||||
|
|
||||||
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||||
if env_type == "aloha":
|
try:
|
||||||
return AlohaEnv(**kwargs)
|
cls = EnvConfig.get_choice_class(env_type)
|
||||||
elif env_type == "pusht":
|
except KeyError as err:
|
||||||
return PushtEnv(**kwargs)
|
raise ValueError(
|
||||||
elif env_type == "libero":
|
f"Environment type '{env_type}' is not registered. "
|
||||||
return LiberoEnv(**kwargs)
|
f"Available: {list(EnvConfig.get_known_choices().keys())}"
|
||||||
else:
|
) from err
|
||||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
return cls(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
def make_env_pre_post_processors(
|
def make_env_pre_post_processors(
|
||||||
env_cfg: EnvConfig,
|
env_cfg: EnvConfig,
|
||||||
policy_cfg: PreTrainedConfig,
|
policy_cfg: Any,
|
||||||
) -> tuple[
|
) -> tuple[Any, Any]:
|
||||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
|
||||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
|
||||||
]:
|
|
||||||
"""
|
"""
|
||||||
Create preprocessor and postprocessor pipelines for environment observations.
|
Create preprocessor and postprocessor pipelines for environment observations.
|
||||||
|
|
||||||
This function creates processor pipelines that transform raw environment
|
Returns a tuple of (preprocessor, postprocessor). By default, delegates to
|
||||||
observations and actions. By default, it returns identity processors that do nothing.
|
``env_cfg.get_env_processors()``. The XVLAConfig policy-specific override
|
||||||
For specific environments like LIBERO, it adds environment-specific processing steps.
|
stays here because it depends on the *policy* config, not the env config.
|
||||||
|
|
||||||
Args:
|
|
||||||
env_cfg: The configuration of the environment.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple containing:
|
|
||||||
- preprocessor: Pipeline that processes environment observations
|
|
||||||
- postprocessor: Pipeline that processes environment outputs (currently identity)
|
|
||||||
"""
|
"""
|
||||||
# Preprocessor and Postprocessor steps are Identity for most environments
|
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||||
preprocessor_steps: list[ProcessorStep] = []
|
|
||||||
postprocessor_steps: list[ProcessorStep] = []
|
|
||||||
if isinstance(policy_cfg, XVLAConfig):
|
if isinstance(policy_cfg, XVLAConfig):
|
||||||
from lerobot.policies.xvla.processor_xvla import make_xvla_libero_pre_post_processors
|
from lerobot.policies.xvla.processor_xvla import make_xvla_libero_pre_post_processors
|
||||||
|
|
||||||
return make_xvla_libero_pre_post_processors()
|
return make_xvla_libero_pre_post_processors()
|
||||||
|
|
||||||
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
|
return env_cfg.get_env_processors()
|
||||||
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
|
|
||||||
preprocessor_steps.append(LiberoProcessorStep())
|
|
||||||
|
|
||||||
# For Isaaclab Arena environments, add the IsaaclabArenaProcessorStep
|
|
||||||
if isinstance(env_cfg, IsaaclabArenaEnv) or "isaaclab_arena" in env_cfg.type:
|
|
||||||
# Parse comma-separated keys (handle None for state-based policies)
|
|
||||||
if env_cfg.state_keys:
|
|
||||||
state_keys = tuple(k.strip() for k in env_cfg.state_keys.split(",") if k.strip())
|
|
||||||
else:
|
|
||||||
state_keys = ()
|
|
||||||
if env_cfg.camera_keys:
|
|
||||||
camera_keys = tuple(k.strip() for k in env_cfg.camera_keys.split(",") if k.strip())
|
|
||||||
else:
|
|
||||||
camera_keys = ()
|
|
||||||
if not state_keys and not camera_keys:
|
|
||||||
raise ValueError("At least one of state_keys or camera_keys must be specified.")
|
|
||||||
preprocessor_steps.append(
|
|
||||||
IsaaclabArenaProcessorStep(
|
|
||||||
state_keys=state_keys,
|
|
||||||
camera_keys=camera_keys,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps)
|
|
||||||
postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps)
|
|
||||||
|
|
||||||
return preprocessor, postprocessor
|
|
||||||
|
|
||||||
|
|
||||||
def make_env(
|
def make_env(
|
||||||
@@ -163,57 +119,4 @@ def make_env(
|
|||||||
if n_envs < 1:
|
if n_envs < 1:
|
||||||
raise ValueError("`n_envs` must be at least 1")
|
raise ValueError("`n_envs` must be at least 1")
|
||||||
|
|
||||||
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
return cfg.create_envs(n_envs=n_envs, use_async_envs=use_async_envs)
|
||||||
|
|
||||||
if "libero" in cfg.type:
|
|
||||||
from lerobot.envs.libero import create_libero_envs
|
|
||||||
|
|
||||||
if cfg.task is None:
|
|
||||||
raise ValueError("LiberoEnv requires a task to be specified")
|
|
||||||
|
|
||||||
return create_libero_envs(
|
|
||||||
task=cfg.task,
|
|
||||||
n_envs=n_envs,
|
|
||||||
camera_name=cfg.camera_name,
|
|
||||||
init_states=cfg.init_states,
|
|
||||||
gym_kwargs=cfg.gym_kwargs,
|
|
||||||
env_cls=env_cls,
|
|
||||||
control_mode=cfg.control_mode,
|
|
||||||
episode_length=cfg.episode_length,
|
|
||||||
)
|
|
||||||
elif "metaworld" in cfg.type:
|
|
||||||
from lerobot.envs.metaworld import create_metaworld_envs
|
|
||||||
|
|
||||||
if cfg.task is None:
|
|
||||||
raise ValueError("MetaWorld requires a task to be specified")
|
|
||||||
|
|
||||||
return create_metaworld_envs(
|
|
||||||
task=cfg.task,
|
|
||||||
n_envs=n_envs,
|
|
||||||
gym_kwargs=cfg.gym_kwargs,
|
|
||||||
env_cls=env_cls,
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.gym_id not in gym_registry:
|
|
||||||
print(f"gym id '{cfg.gym_id}' not found, attempting to import '{cfg.package_name}'...")
|
|
||||||
try:
|
|
||||||
importlib.import_module(cfg.package_name)
|
|
||||||
except ModuleNotFoundError as e:
|
|
||||||
raise ModuleNotFoundError(
|
|
||||||
f"Package '{cfg.package_name}' required for env '{cfg.type}' not found. "
|
|
||||||
f"Please install it or check PYTHONPATH."
|
|
||||||
) from e
|
|
||||||
|
|
||||||
if cfg.gym_id not in gym_registry:
|
|
||||||
raise gym.error.NameNotFound(
|
|
||||||
f"Environment '{cfg.gym_id}' not registered even after importing '{cfg.package_name}'."
|
|
||||||
)
|
|
||||||
|
|
||||||
def _make_one():
|
|
||||||
return gym.make(cfg.gym_id, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
|
|
||||||
|
|
||||||
vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)
|
|
||||||
|
|
||||||
# normalize to {suite: {task_id: vec_env}} for consistency
|
|
||||||
suite_name = cfg.type # e.g., "pusht", "aloha"
|
|
||||||
return {suite_name: {0: vec}}
|
|
||||||
|
|||||||
@@ -0,0 +1,143 @@
|
|||||||
|
"""Tests for the benchmark dispatch refactor (create_envs / get_env_processors on EnvConfig)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
import pytest
|
||||||
|
from gymnasium.envs.registration import register, registry as gym_registry
|
||||||
|
|
||||||
|
from lerobot.configs.types import PolicyFeature
|
||||||
|
from lerobot.envs.configs import EnvConfig
|
||||||
|
from lerobot.envs.factory import make_env, make_env_config, make_env_pre_post_processors
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def test_registry_all_types():
|
||||||
|
"""make_env_config should resolve every registered EnvConfig subclass via the registry."""
|
||||||
|
known = list(EnvConfig.get_known_choices().keys())
|
||||||
|
assert len(known) >= 6
|
||||||
|
for t in known:
|
||||||
|
cfg = make_env_config(t)
|
||||||
|
assert cfg.type == t
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_type():
|
||||||
|
with pytest.raises(ValueError, match="not registered"):
|
||||||
|
make_env_config("nonexistent")
|
||||||
|
|
||||||
|
|
||||||
|
def test_identity_processors():
|
||||||
|
"""Base class get_env_processors() returns identity pipelines."""
|
||||||
|
cfg = make_env_config("aloha")
|
||||||
|
pre, post = cfg.get_env_processors()
|
||||||
|
assert len(pre.steps) == 0 and len(post.steps) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_delegation():
|
||||||
|
"""make_env() should call cfg.create_envs(), not use if/elif dispatch."""
|
||||||
|
sentinel = {"delegated": {0: "marker"}}
|
||||||
|
fake = type(
|
||||||
|
"Fake",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"hub_path": None,
|
||||||
|
"create_envs": lambda self, n_envs, use_async_envs=False: sentinel,
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
result = make_env(fake, n_envs=1)
|
||||||
|
assert result is sentinel
|
||||||
|
|
||||||
|
|
||||||
|
def test_processors_delegation():
|
||||||
|
"""make_env_pre_post_processors delegates to cfg.get_env_processors()."""
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
|
||||||
|
cfg = make_env_config("aloha")
|
||||||
|
pre, post = make_env_pre_post_processors(cfg, PreTrainedConfig())
|
||||||
|
assert len(pre.steps) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_base_create_envs():
|
||||||
|
"""Base class create_envs() should build a single-task VectorEnv via gym.make()."""
|
||||||
|
gym_id = "_dispatch_test/CartPole-v99"
|
||||||
|
if gym_id not in gym_registry:
|
||||||
|
register(id=gym_id, entry_point="gymnasium.envs.classic_control:CartPoleEnv")
|
||||||
|
|
||||||
|
@EnvConfig.register_subclass("_dispatch_base_test")
|
||||||
|
@dataclass
|
||||||
|
class _Env(EnvConfig):
|
||||||
|
task: str = "CartPole-v99"
|
||||||
|
fps: int = 10
|
||||||
|
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def package_name(self):
|
||||||
|
return "_dispatch_test"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def gym_id(self):
|
||||||
|
return gym_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def gym_kwargs(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
envs = _Env().create_envs(n_envs=2)
|
||||||
|
assert "_dispatch_base_test" in envs
|
||||||
|
env = envs["_dispatch_base_test"][0]
|
||||||
|
assert isinstance(env, gym.vector.SyncVectorEnv)
|
||||||
|
assert env.num_envs == 2
|
||||||
|
env.close()
|
||||||
|
finally:
|
||||||
|
if gym_id in gym_registry:
|
||||||
|
del gym_registry[gym_id]
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_create_envs_override():
|
||||||
|
"""A custom EnvConfig subclass can override create_envs()."""
|
||||||
|
mock_vec = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")])
|
||||||
|
|
||||||
|
@EnvConfig.register_subclass("_dispatch_custom_test")
|
||||||
|
@dataclass
|
||||||
|
class _Env(EnvConfig):
|
||||||
|
task: str = "x"
|
||||||
|
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def gym_kwargs(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def create_envs(self, n_envs, use_async_envs=False):
|
||||||
|
return {"custom_suite": {0: mock_vec}}
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = make_env(_Env(), n_envs=1)
|
||||||
|
assert "custom_suite" in result
|
||||||
|
finally:
|
||||||
|
mock_vec.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_get_env_processors_override():
|
||||||
|
"""A custom EnvConfig subclass can override get_env_processors()."""
|
||||||
|
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||||
|
|
||||||
|
@EnvConfig.register_subclass("_dispatch_proc_test")
|
||||||
|
@dataclass
|
||||||
|
class _Env(EnvConfig):
|
||||||
|
task: str = "x"
|
||||||
|
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def gym_kwargs(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def get_env_processors(self):
|
||||||
|
return PolicyProcessorPipeline(steps=[]), PolicyProcessorPipeline(steps=[])
|
||||||
|
|
||||||
|
pre, post = _Env().get_env_processors()
|
||||||
|
assert isinstance(pre, PolicyProcessorPipeline)
|
||||||
Reference in New Issue
Block a user