mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
refactor(envs): move dispatch logic from factory into EnvConfig subclasses
Replace hardcoded if/elif chains in factory.py with create_envs() and get_env_processors() methods on EnvConfig. New benchmarks now only need to register a config subclass — no factory.py edits required. Net -23 lines: factory.py shrinks from ~200 to ~70 lines of logic. Made-with: Cursor
This commit is contained in:
@@ -103,16 +103,15 @@ Each `EnvConfig` subclass declares:
|
|||||||
|
|
||||||
### Checklist
|
### Checklist
|
||||||
|
|
||||||
| File | Required | Description |
|
| File | Required | Description |
|
||||||
| ---------------------------------------- | -------- | ----------------------------------------------------------------------------------- |
|
| ---------------------------------------- | -------- | -------------------------------------------------------------------------------- |
|
||||||
| `src/lerobot/envs/<benchmark>.py` | Yes | `gym.Env` subclass + `create_<benchmark>_envs()` factory |
|
| `src/lerobot/envs/<benchmark>.py` | Yes | `gym.Env` subclass + `create_<benchmark>_envs()` factory |
|
||||||
| `src/lerobot/envs/configs.py` | Yes | `@EnvConfig.register_subclass("<name>")` dataclass |
|
| `src/lerobot/envs/configs.py` | Yes | `@EnvConfig.register_subclass("<name>")` dataclass with `create_envs()` override |
|
||||||
| `src/lerobot/envs/factory.py` | Yes | Add dispatch branch in `make_env()` and optionally `make_env_pre_post_processors()` |
|
| `src/lerobot/processor/env_processor.py` | Optional | `ProcessorStep` subclass for env-specific observation transforms |
|
||||||
| `src/lerobot/processor/env_processor.py` | Optional | `ProcessorStep` subclass for env-specific observation transforms |
|
| `src/lerobot/envs/utils.py` | Optional | Extend `preprocess_observation()` if new raw keys are needed |
|
||||||
| `src/lerobot/envs/utils.py` | Optional | Extend `preprocess_observation()` if new raw keys are needed |
|
| `pyproject.toml` | Yes | Add optional dependency group |
|
||||||
| `pyproject.toml` | Yes | Add optional dependency group |
|
| `docs/source/<benchmark>.mdx` | Yes | User-facing benchmark documentation |
|
||||||
| `docs/source/<benchmark>.mdx` | Yes | User-facing benchmark documentation |
|
| `docs/source/_toctree.yml` | Yes | Add entry under the "Benchmarks" section |
|
||||||
| `docs/source/_toctree.yml` | Yes | Add entry under the "Benchmarks" section |
|
|
||||||
|
|
||||||
### 1. The gym.Env wrapper (`src/lerobot/envs/<benchmark>.py`)
|
### 1. The gym.Env wrapper (`src/lerobot/envs/<benchmark>.py`)
|
||||||
|
|
||||||
@@ -169,7 +168,10 @@ See `create_libero_envs()` in `src/lerobot/envs/libero.py` (multi-suite, multi-t
|
|||||||
|
|
||||||
### 2. The config (`src/lerobot/envs/configs.py`)
|
### 2. The config (`src/lerobot/envs/configs.py`)
|
||||||
|
|
||||||
Register a new config dataclass:
|
Register a new config dataclass. Each config owns its environment creation and processor logic via two methods:
|
||||||
|
|
||||||
|
- **`create_envs(n_envs, use_async_envs)`** — Returns `{suite: {task_id: VectorEnv}}`. The base class default uses `gym.make()` for single-task envs. Multi-task benchmarks override this.
|
||||||
|
- **`get_env_processors()`** — Returns `(preprocessor, postprocessor)`. The base class default returns identity (no-op) pipelines. Override if your benchmark needs observation/action transforms.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@EnvConfig.register_subclass("<benchmark_name>")
|
@EnvConfig.register_subclass("<benchmark_name>")
|
||||||
@@ -196,6 +198,20 @@ class MyBenchmarkEnv(EnvConfig):
|
|||||||
@property
|
@property
|
||||||
def gym_kwargs(self) -> dict:
|
def gym_kwargs(self) -> dict:
|
||||||
return {"obs_type": self.obs_type, "render_mode": self.render_mode}
|
return {"obs_type": self.obs_type, "render_mode": self.render_mode}
|
||||||
|
|
||||||
|
def create_envs(self, n_envs: int, use_async_envs: bool = False):
|
||||||
|
"""Override for multi-task benchmarks or custom env creation."""
|
||||||
|
from lerobot.envs.<benchmark> import create_<benchmark>_envs
|
||||||
|
return create_<benchmark>_envs(task=self.task, n_envs=n_envs, ...)
|
||||||
|
|
||||||
|
def get_env_processors(self):
|
||||||
|
"""Override if your benchmark needs observation/action transforms."""
|
||||||
|
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||||
|
from lerobot.processor.env_processor import MyBenchmarkProcessorStep
|
||||||
|
return (
|
||||||
|
PolicyProcessorPipeline(steps=[MyBenchmarkProcessorStep()]),
|
||||||
|
PolicyProcessorPipeline(steps=[]),
|
||||||
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
Key points:
|
Key points:
|
||||||
@@ -203,36 +219,11 @@ Key points:
|
|||||||
- The `register_subclass` name is what users pass as `--env.type=<name>` on the CLI.
|
- The `register_subclass` name is what users pass as `--env.type=<name>` on the CLI.
|
||||||
- `features` declares what the environment produces (used to configure the policy).
|
- `features` declares what the environment produces (used to configure the policy).
|
||||||
- `features_map` maps raw observation keys to LeRobot convention keys.
|
- `features_map` maps raw observation keys to LeRobot convention keys.
|
||||||
|
- **No changes to `factory.py` needed** — the factory delegates to `cfg.create_envs()` and `cfg.get_env_processors()` automatically.
|
||||||
|
|
||||||
### 3. The factory dispatch (`src/lerobot/envs/factory.py`)
|
### 3. Env processor (optional) (`src/lerobot/processor/env_processor.py`)
|
||||||
|
|
||||||
Add a branch in `make_env()`:
|
If your benchmark needs observation transforms beyond what `preprocess_observation()` handles (e.g., image flipping, coordinate frame conversion), add a `ProcessorStep` and return it from `get_env_processors()` in your config (see step 2):
|
||||||
|
|
||||||
```python
|
|
||||||
elif "<benchmark_name>" in cfg.type:
|
|
||||||
from lerobot.envs.<benchmark> import create_<benchmark>_envs
|
|
||||||
|
|
||||||
if cfg.task is None:
|
|
||||||
raise ValueError("<BenchmarkName> requires a task to be specified")
|
|
||||||
|
|
||||||
return create_<benchmark>_envs(
|
|
||||||
task=cfg.task,
|
|
||||||
n_envs=n_envs,
|
|
||||||
gym_kwargs=cfg.gym_kwargs,
|
|
||||||
env_cls=env_cls,
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
If your benchmark needs an env processor, add it in `make_env_pre_post_processors()`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
if isinstance(env_cfg, MyBenchmarkEnv) or "<benchmark_name>" in env_cfg.type:
|
|
||||||
preprocessor_steps.append(MyBenchmarkProcessorStep())
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4. Env processor (optional) (`src/lerobot/processor/env_processor.py`)
|
|
||||||
|
|
||||||
If your benchmark needs observation transforms beyond what `preprocess_observation()` handles (e.g., image flipping, coordinate frame conversion), add a `ProcessorStep`:
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -253,7 +244,7 @@ class MyBenchmarkProcessorStep(ObservationProcessorStep):
|
|||||||
|
|
||||||
See `LiberoProcessorStep` for a full example (image rotation, quaternion-to-axis-angle conversion).
|
See `LiberoProcessorStep` for a full example (image rotation, quaternion-to-axis-angle conversion).
|
||||||
|
|
||||||
### 5. Dependencies (`pyproject.toml`)
|
### 4. Dependencies (`pyproject.toml`)
|
||||||
|
|
||||||
Add a new optional-dependency group under `[project.optional-dependencies]`:
|
Add a new optional-dependency group under `[project.optional-dependencies]`:
|
||||||
|
|
||||||
@@ -274,11 +265,11 @@ Users install with:
|
|||||||
pip install -e ".[mybenchmark]"
|
pip install -e ".[mybenchmark]"
|
||||||
```
|
```
|
||||||
|
|
||||||
### 6. Documentation (`docs/source/<benchmark>.mdx`)
|
### 5. Documentation (`docs/source/<benchmark>.mdx`)
|
||||||
|
|
||||||
Follow the template below. See `docs/source/libero.mdx` and `docs/source/metaworld.mdx` for full examples.
|
Follow the template below. See `docs/source/libero.mdx` and `docs/source/metaworld.mdx` for full examples.
|
||||||
|
|
||||||
### 7. Table of contents (`docs/source/_toctree.yml`)
|
### 6. Table of contents (`docs/source/_toctree.yml`)
|
||||||
|
|
||||||
Add your benchmark under the "Benchmarks" section:
|
Add your benchmark under the "Benchmarks" section:
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user