mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
refactor(envs): move dispatch logic from factory into EnvConfig subclasses
Replace hardcoded if/elif chains in factory.py with create_envs() and get_env_processors() methods on EnvConfig. New benchmarks now only need to register a config subclass — no factory.py edits required. Net -23 lines: factory.py shrinks from ~200 to ~70 lines of logic. Made-with: Cursor
This commit is contained in:
@@ -103,16 +103,15 @@ Each `EnvConfig` subclass declares:
|
||||
|
||||
### Checklist
|
||||
|
||||
| File | Required | Description |
|
||||
| ---------------------------------------- | -------- | ----------------------------------------------------------------------------------- |
|
||||
| `src/lerobot/envs/<benchmark>.py` | Yes | `gym.Env` subclass + `create_<benchmark>_envs()` factory |
|
||||
| `src/lerobot/envs/configs.py` | Yes | `@EnvConfig.register_subclass("<name>")` dataclass |
|
||||
| `src/lerobot/envs/factory.py` | Yes | Add dispatch branch in `make_env()` and optionally `make_env_pre_post_processors()` |
|
||||
| `src/lerobot/processor/env_processor.py` | Optional | `ProcessorStep` subclass for env-specific observation transforms |
|
||||
| `src/lerobot/envs/utils.py` | Optional | Extend `preprocess_observation()` if new raw keys are needed |
|
||||
| `pyproject.toml` | Yes | Add optional dependency group |
|
||||
| `docs/source/<benchmark>.mdx` | Yes | User-facing benchmark documentation |
|
||||
| `docs/source/_toctree.yml` | Yes | Add entry under the "Benchmarks" section |
|
||||
| File | Required | Description |
|
||||
| ---------------------------------------- | -------- | -------------------------------------------------------------------------------- |
|
||||
| `src/lerobot/envs/<benchmark>.py` | Yes | `gym.Env` subclass + `create_<benchmark>_envs()` factory |
|
||||
| `src/lerobot/envs/configs.py` | Yes | `@EnvConfig.register_subclass("<name>")` dataclass with `create_envs()` override |
|
||||
| `src/lerobot/processor/env_processor.py` | Optional | `ProcessorStep` subclass for env-specific observation transforms |
|
||||
| `src/lerobot/envs/utils.py` | Optional | Extend `preprocess_observation()` if new raw keys are needed |
|
||||
| `pyproject.toml` | Yes | Add optional dependency group |
|
||||
| `docs/source/<benchmark>.mdx` | Yes | User-facing benchmark documentation |
|
||||
| `docs/source/_toctree.yml` | Yes | Add entry under the "Benchmarks" section |
|
||||
|
||||
### 1. The gym.Env wrapper (`src/lerobot/envs/<benchmark>.py`)
|
||||
|
||||
@@ -169,7 +168,10 @@ See `create_libero_envs()` in `src/lerobot/envs/libero.py` (multi-suite, multi-t
|
||||
|
||||
### 2. The config (`src/lerobot/envs/configs.py`)
|
||||
|
||||
Register a new config dataclass:
|
||||
Register a new config dataclass. Each config owns its environment creation and processor logic via two methods:
|
||||
|
||||
- **`create_envs(n_envs, use_async_envs)`** — Returns `{suite: {task_id: VectorEnv}}`. The base class default uses `gym.make()` for single-task envs. Multi-task benchmarks override this.
|
||||
- **`get_env_processors()`** — Returns `(preprocessor, postprocessor)`. The base class default returns identity (no-op) pipelines. Override if your benchmark needs observation/action transforms.
|
||||
|
||||
```python
|
||||
@EnvConfig.register_subclass("<benchmark_name>")
|
||||
@@ -196,6 +198,20 @@ class MyBenchmarkEnv(EnvConfig):
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {"obs_type": self.obs_type, "render_mode": self.render_mode}
|
||||
|
||||
def create_envs(self, n_envs: int, use_async_envs: bool = False):
|
||||
"""Override for multi-task benchmarks or custom env creation."""
|
||||
from lerobot.envs.<benchmark> import create_<benchmark>_envs
|
||||
return create_<benchmark>_envs(task=self.task, n_envs=n_envs, ...)
|
||||
|
||||
def get_env_processors(self):
|
||||
"""Override if your benchmark needs observation/action transforms."""
|
||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||
from lerobot.processor.env_processor import MyBenchmarkProcessorStep
|
||||
return (
|
||||
PolicyProcessorPipeline(steps=[MyBenchmarkProcessorStep()]),
|
||||
PolicyProcessorPipeline(steps=[]),
|
||||
)
|
||||
```
|
||||
|
||||
Key points:
|
||||
@@ -203,36 +219,11 @@ Key points:
|
||||
- The `register_subclass` name is what users pass as `--env.type=<name>` on the CLI.
|
||||
- `features` declares what the environment produces (used to configure the policy).
|
||||
- `features_map` maps raw observation keys to LeRobot convention keys.
|
||||
- **No changes to `factory.py` needed** — the factory delegates to `cfg.create_envs()` and `cfg.get_env_processors()` automatically.
|
||||
|
||||
### 3. The factory dispatch (`src/lerobot/envs/factory.py`)
|
||||
### 3. Env processor (optional) (`src/lerobot/processor/env_processor.py`)
|
||||
|
||||
Add a branch in `make_env()`:
|
||||
|
||||
```python
|
||||
elif "<benchmark_name>" in cfg.type:
|
||||
from lerobot.envs.<benchmark> import create_<benchmark>_envs
|
||||
|
||||
if cfg.task is None:
|
||||
raise ValueError("<BenchmarkName> requires a task to be specified")
|
||||
|
||||
return create_<benchmark>_envs(
|
||||
task=cfg.task,
|
||||
n_envs=n_envs,
|
||||
gym_kwargs=cfg.gym_kwargs,
|
||||
env_cls=env_cls,
|
||||
)
|
||||
```
|
||||
|
||||
If your benchmark needs an env processor, add it in `make_env_pre_post_processors()`:
|
||||
|
||||
```python
|
||||
if isinstance(env_cfg, MyBenchmarkEnv) or "<benchmark_name>" in env_cfg.type:
|
||||
preprocessor_steps.append(MyBenchmarkProcessorStep())
|
||||
```
|
||||
|
||||
### 4. Env processor (optional) (`src/lerobot/processor/env_processor.py`)
|
||||
|
||||
If your benchmark needs observation transforms beyond what `preprocess_observation()` handles (e.g., image flipping, coordinate frame conversion), add a `ProcessorStep`:
|
||||
If your benchmark needs observation transforms beyond what `preprocess_observation()` handles (e.g., image flipping, coordinate frame conversion), add a `ProcessorStep` and return it from `get_env_processors()` in your config (see step 2):
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
@@ -253,7 +244,7 @@ class MyBenchmarkProcessorStep(ObservationProcessorStep):
|
||||
|
||||
See `LiberoProcessorStep` for a full example (image rotation, quaternion-to-axis-angle conversion).
|
||||
|
||||
### 5. Dependencies (`pyproject.toml`)
|
||||
### 4. Dependencies (`pyproject.toml`)
|
||||
|
||||
Add a new optional-dependency group under `[project.optional-dependencies]`:
|
||||
|
||||
@@ -274,11 +265,11 @@ Users install with:
|
||||
pip install -e ".[mybenchmark]"
|
||||
```
|
||||
|
||||
### 6. Documentation (`docs/source/<benchmark>.mdx`)
|
||||
### 5. Documentation (`docs/source/<benchmark>.mdx`)
|
||||
|
||||
Follow the template below. See `docs/source/libero.mdx` and `docs/source/metaworld.mdx` for full examples.
|
||||
|
||||
### 7. Table of contents (`docs/source/_toctree.yml`)
|
||||
### 6. Table of contents (`docs/source/_toctree.yml`)
|
||||
|
||||
Add your benchmark under the "Benchmarks" section:
|
||||
|
||||
|
||||
Reference in New Issue
Block a user