From 5ad4c8f7b687630d2ff8128d06cf5723cda0edc8 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Fri, 3 Apr 2026 13:23:44 +0200 Subject: [PATCH] refactor(envs): move dispatch logic from factory into EnvConfig subclasses MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace hardcoded if/elif chains in factory.py with create_envs() and get_env_processors() methods on EnvConfig. New benchmarks now only need to register a config subclass — no factory.py edits required. Net -23 lines: factory.py shrinks from ~200 to ~70 lines of logic. Made-with: Cursor --- docs/source/adding_benchmarks.mdx | 75 ++++++++++++++----------------- 1 file changed, 33 insertions(+), 42 deletions(-) diff --git a/docs/source/adding_benchmarks.mdx b/docs/source/adding_benchmarks.mdx index 3ba606dd2..a70a791a1 100644 --- a/docs/source/adding_benchmarks.mdx +++ b/docs/source/adding_benchmarks.mdx @@ -103,16 +103,15 @@ Each `EnvConfig` subclass declares: ### Checklist -| File | Required | Description | -| ---------------------------------------- | -------- | ----------------------------------------------------------------------------------- | -| `src/lerobot/envs/.py` | Yes | `gym.Env` subclass + `create__envs()` factory | -| `src/lerobot/envs/configs.py` | Yes | `@EnvConfig.register_subclass("")` dataclass | -| `src/lerobot/envs/factory.py` | Yes | Add dispatch branch in `make_env()` and optionally `make_env_pre_post_processors()` | -| `src/lerobot/processor/env_processor.py` | Optional | `ProcessorStep` subclass for env-specific observation transforms | -| `src/lerobot/envs/utils.py` | Optional | Extend `preprocess_observation()` if new raw keys are needed | -| `pyproject.toml` | Yes | Add optional dependency group | -| `docs/source/.mdx` | Yes | User-facing benchmark documentation | -| `docs/source/_toctree.yml` | Yes | Add entry under the "Benchmarks" section | +| File | Required | Description | +| ---------------------------------------- | -------- | -------------------------------------------------------------------------------- | +| `src/lerobot/envs/.py` | Yes | `gym.Env` subclass + `create__envs()` factory | +| `src/lerobot/envs/configs.py` | Yes | `@EnvConfig.register_subclass("")` dataclass with `create_envs()` override | +| `src/lerobot/processor/env_processor.py` | Optional | `ProcessorStep` subclass for env-specific observation transforms | +| `src/lerobot/envs/utils.py` | Optional | Extend `preprocess_observation()` if new raw keys are needed | +| `pyproject.toml` | Yes | Add optional dependency group | +| `docs/source/.mdx` | Yes | User-facing benchmark documentation | +| `docs/source/_toctree.yml` | Yes | Add entry under the "Benchmarks" section | ### 1. The gym.Env wrapper (`src/lerobot/envs/.py`) @@ -169,7 +168,10 @@ See `create_libero_envs()` in `src/lerobot/envs/libero.py` (multi-suite, multi-t ### 2. The config (`src/lerobot/envs/configs.py`) -Register a new config dataclass: +Register a new config dataclass. Each config owns its environment creation and processor logic via two methods: + +- **`create_envs(n_envs, use_async_envs)`** — Returns `{suite: {task_id: VectorEnv}}`. The base class default uses `gym.make()` for single-task envs. Multi-task benchmarks override this. +- **`get_env_processors()`** — Returns `(preprocessor, postprocessor)`. The base class default returns identity (no-op) pipelines. Override if your benchmark needs observation/action transforms. ```python @EnvConfig.register_subclass("") @@ -196,6 +198,20 @@ class MyBenchmarkEnv(EnvConfig): @property def gym_kwargs(self) -> dict: return {"obs_type": self.obs_type, "render_mode": self.render_mode} + + def create_envs(self, n_envs: int, use_async_envs: bool = False): + """Override for multi-task benchmarks or custom env creation.""" + from lerobot.envs. import create__envs + return create__envs(task=self.task, n_envs=n_envs, ...) + + def get_env_processors(self): + """Override if your benchmark needs observation/action transforms.""" + from lerobot.processor.pipeline import PolicyProcessorPipeline + from lerobot.processor.env_processor import MyBenchmarkProcessorStep + return ( + PolicyProcessorPipeline(steps=[MyBenchmarkProcessorStep()]), + PolicyProcessorPipeline(steps=[]), + ) ``` Key points: @@ -203,36 +219,11 @@ Key points: - The `register_subclass` name is what users pass as `--env.type=` on the CLI. - `features` declares what the environment produces (used to configure the policy). - `features_map` maps raw observation keys to LeRobot convention keys. +- **No changes to `factory.py` needed** — the factory delegates to `cfg.create_envs()` and `cfg.get_env_processors()` automatically. -### 3. The factory dispatch (`src/lerobot/envs/factory.py`) +### 3. Env processor (optional) (`src/lerobot/processor/env_processor.py`) -Add a branch in `make_env()`: - -```python -elif "" in cfg.type: - from lerobot.envs. import create__envs - - if cfg.task is None: - raise ValueError(" requires a task to be specified") - - return create__envs( - task=cfg.task, - n_envs=n_envs, - gym_kwargs=cfg.gym_kwargs, - env_cls=env_cls, - ) -``` - -If your benchmark needs an env processor, add it in `make_env_pre_post_processors()`: - -```python -if isinstance(env_cfg, MyBenchmarkEnv) or "" in env_cfg.type: - preprocessor_steps.append(MyBenchmarkProcessorStep()) -``` - -### 4. Env processor (optional) (`src/lerobot/processor/env_processor.py`) - -If your benchmark needs observation transforms beyond what `preprocess_observation()` handles (e.g., image flipping, coordinate frame conversion), add a `ProcessorStep`: +If your benchmark needs observation transforms beyond what `preprocess_observation()` handles (e.g., image flipping, coordinate frame conversion), add a `ProcessorStep` and return it from `get_env_processors()` in your config (see step 2): ```python @dataclass @@ -253,7 +244,7 @@ class MyBenchmarkProcessorStep(ObservationProcessorStep): See `LiberoProcessorStep` for a full example (image rotation, quaternion-to-axis-angle conversion). -### 5. Dependencies (`pyproject.toml`) +### 4. Dependencies (`pyproject.toml`) Add a new optional-dependency group under `[project.optional-dependencies]`: @@ -274,11 +265,11 @@ Users install with: pip install -e ".[mybenchmark]" ``` -### 6. Documentation (`docs/source/.mdx`) +### 5. Documentation (`docs/source/.mdx`) Follow the template below. See `docs/source/libero.mdx` and `docs/source/metaworld.mdx` for full examples. -### 7. Table of contents (`docs/source/_toctree.yml`) +### 6. Table of contents (`docs/source/_toctree.yml`) Add your benchmark under the "Benchmarks" section: