mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-14 16:19:45 +00:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 46e9e22b05 | |||
| b43f9ab048 | |||
| 0045f88355 | |||
| 89ce91f69f | |||
| 90e614f6b9 | |||
| ff4f860e5d | |||
| 6f2823bfc4 | |||
| 77415559b8 | |||
| 24d9b74d81 | |||
| 508358749a |
@@ -115,23 +115,22 @@ Each `EnvConfig` subclass declares two dicts that tell the policy what to expect
|
|||||||
## Step by step
|
## Step by step
|
||||||
|
|
||||||
<Tip>
|
<Tip>
|
||||||
At minimum, you need three files: a **gym.Env wrapper**, an **EnvConfig
|
At minimum, you need two files: a **gym.Env wrapper** and an **EnvConfig
|
||||||
subclass**, and a **factory dispatch branch**. Everything else is optional or
|
subclass** with a `create_envs()` override. Everything else is optional or
|
||||||
documentation.
|
documentation. No changes to `factory.py` are needed.
|
||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
### Checklist
|
### Checklist
|
||||||
|
|
||||||
| File | Required | Why |
|
| File | Required | Why |
|
||||||
| ---------------------------------------- | -------- | ----------------------------------------- |
|
| ---------------------------------------- | -------- | ------------------------------------------------------------ |
|
||||||
| `src/lerobot/envs/<benchmark>.py` | Yes | Wraps the simulator as a standard gym.Env |
|
| `src/lerobot/envs/<benchmark>.py` | Yes | Wraps the simulator as a standard gym.Env |
|
||||||
| `src/lerobot/envs/configs.py` | Yes | Registers your benchmark for the CLI |
|
| `src/lerobot/envs/configs.py` | Yes | Registers your benchmark and its `create_envs()` for the CLI |
|
||||||
| `src/lerobot/envs/factory.py` | Yes | Tells `make_env()` how to build your envs |
|
| `src/lerobot/processor/env_processor.py` | Optional | Custom observation/action transforms |
|
||||||
| `src/lerobot/processor/env_processor.py` | Optional | Custom observation/action transforms |
|
| `src/lerobot/envs/utils.py` | Optional | Only if you need new raw observation keys |
|
||||||
| `src/lerobot/envs/utils.py` | Optional | Only if you need new raw observation keys |
|
| `pyproject.toml` | Yes | Declares benchmark-specific dependencies |
|
||||||
| `pyproject.toml` | Yes | Declares benchmark-specific dependencies |
|
| `docs/source/<benchmark>.mdx` | Yes | User-facing documentation page |
|
||||||
| `docs/source/<benchmark>.mdx` | Yes | User-facing documentation page |
|
| `docs/source/_toctree.yml` | Yes | Adds your page to the docs sidebar |
|
||||||
| `docs/source/_toctree.yml` | Yes | Adds your page to the docs sidebar |
|
|
||||||
|
|
||||||
### 1. The gym.Env wrapper (`src/lerobot/envs/<benchmark>.py`)
|
### 1. The gym.Env wrapper (`src/lerobot/envs/<benchmark>.py`)
|
||||||
|
|
||||||
@@ -179,7 +178,10 @@ See `create_libero_envs()` (multi-suite, multi-task) and `create_metaworld_envs(
|
|||||||
|
|
||||||
### 2. The config (`src/lerobot/envs/configs.py`)
|
### 2. The config (`src/lerobot/envs/configs.py`)
|
||||||
|
|
||||||
Register a config dataclass so users can select your benchmark with `--env.type=<name>`:
|
Register a config dataclass so users can select your benchmark with `--env.type=<name>`. Each config owns its environment creation and processor logic via two methods:
|
||||||
|
|
||||||
|
- **`create_envs(n_envs, use_async_envs)`** — Returns `{suite: {task_id: VectorEnv}}`. The base class default uses `gym.make()` for single-task envs. Multi-task benchmarks override this.
|
||||||
|
- **`get_env_processors()`** — Returns `(preprocessor, postprocessor)`. The base class default returns identity (no-op) pipelines. Override if your benchmark needs observation/action transforms.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@EnvConfig.register_subclass("<benchmark_name>")
|
@EnvConfig.register_subclass("<benchmark_name>")
|
||||||
@@ -204,6 +206,20 @@ class MyBenchmarkEnvConfig(EnvConfig):
|
|||||||
@property
|
@property
|
||||||
def gym_kwargs(self) -> dict:
|
def gym_kwargs(self) -> dict:
|
||||||
return {"obs_type": self.obs_type, "render_mode": self.render_mode}
|
return {"obs_type": self.obs_type, "render_mode": self.render_mode}
|
||||||
|
|
||||||
|
def create_envs(self, n_envs: int, use_async_envs: bool = False):
|
||||||
|
"""Override for multi-task benchmarks or custom env creation."""
|
||||||
|
from lerobot.envs.<benchmark> import create_<benchmark>_envs
|
||||||
|
return create_<benchmark>_envs(task=self.task, n_envs=n_envs, ...)
|
||||||
|
|
||||||
|
def get_env_processors(self):
|
||||||
|
"""Override if your benchmark needs observation/action transforms."""
|
||||||
|
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||||
|
from lerobot.processor.env_processor import MyBenchmarkProcessorStep
|
||||||
|
return (
|
||||||
|
PolicyProcessorPipeline(steps=[MyBenchmarkProcessorStep()]),
|
||||||
|
PolicyProcessorPipeline(steps=[]),
|
||||||
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
Key points:
|
Key points:
|
||||||
@@ -211,36 +227,11 @@ Key points:
|
|||||||
- The `register_subclass` name is what users pass on the CLI (`--env.type=<name>`).
|
- The `register_subclass` name is what users pass on the CLI (`--env.type=<name>`).
|
||||||
- `features` tells the policy what the environment produces.
|
- `features` tells the policy what the environment produces.
|
||||||
- `features_map` maps raw observation keys to LeRobot convention keys.
|
- `features_map` maps raw observation keys to LeRobot convention keys.
|
||||||
|
- **No changes to `factory.py` needed** — the factory delegates to `cfg.create_envs()` and `cfg.get_env_processors()` automatically.
|
||||||
|
|
||||||
### 3. The factory dispatch (`src/lerobot/envs/factory.py`)
|
### 3. Env processor (optional — `src/lerobot/processor/env_processor.py`)
|
||||||
|
|
||||||
Add a branch in `make_env()` to call your factory function:
|
Only needed if your benchmark requires observation transforms beyond what `preprocess_observation()` handles (e.g. image flipping, coordinate conversion). Define the processor step here and return it from `get_env_processors()` in your config (see step 2):
|
||||||
|
|
||||||
```python
|
|
||||||
elif "<benchmark_name>" in cfg.type:
|
|
||||||
from lerobot.envs.<benchmark> import create_<benchmark>_envs
|
|
||||||
|
|
||||||
if cfg.task is None:
|
|
||||||
raise ValueError("<BenchmarkName> requires a task to be specified")
|
|
||||||
|
|
||||||
return create_<benchmark>_envs(
|
|
||||||
task=cfg.task,
|
|
||||||
n_envs=n_envs,
|
|
||||||
gym_kwargs=cfg.gym_kwargs,
|
|
||||||
env_cls=env_cls,
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
If your benchmark needs an env processor, add it in `make_env_pre_post_processors()`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
if isinstance(env_cfg, MyBenchmarkEnvConfig) or "<benchmark_name>" in env_cfg.type:
|
|
||||||
preprocessor_steps.append(MyBenchmarkProcessorStep())
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4. Env processor (optional — `src/lerobot/processor/env_processor.py`)
|
|
||||||
|
|
||||||
Only needed if your benchmark requires observation transforms beyond what `preprocess_observation()` handles (e.g. image flipping, coordinate conversion):
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -260,7 +251,7 @@ class MyBenchmarkProcessorStep(ObservationProcessorStep):
|
|||||||
|
|
||||||
See `LiberoProcessorStep` for a full example (image rotation, quaternion-to-axis-angle conversion).
|
See `LiberoProcessorStep` for a full example (image rotation, quaternion-to-axis-angle conversion).
|
||||||
|
|
||||||
### 5. Dependencies (`pyproject.toml`)
|
### 4. Dependencies (`pyproject.toml`)
|
||||||
|
|
||||||
Add a new optional-dependency group:
|
Add a new optional-dependency group:
|
||||||
|
|
||||||
@@ -281,11 +272,11 @@ Users install with:
|
|||||||
pip install -e ".[mybenchmark]"
|
pip install -e ".[mybenchmark]"
|
||||||
```
|
```
|
||||||
|
|
||||||
### 6. Documentation (`docs/source/<benchmark>.mdx`)
|
### 5. Documentation (`docs/source/<benchmark>.mdx`)
|
||||||
|
|
||||||
Write a user-facing page following the template in the next section. See `docs/source/libero.mdx` and `docs/source/metaworld.mdx` for full examples.
|
Write a user-facing page following the template in the next section. See `docs/source/libero.mdx` and `docs/source/metaworld.mdx` for full examples.
|
||||||
|
|
||||||
### 7. Table of contents (`docs/source/_toctree.yml`)
|
### 6. Table of contents (`docs/source/_toctree.yml`)
|
||||||
|
|
||||||
Add your benchmark to the "Benchmarks" section:
|
Add your benchmark to the "Benchmarks" section:
|
||||||
|
|
||||||
|
|||||||
@@ -151,7 +151,7 @@ observation = {
|
|||||||
|
|
||||||
### Factory Function
|
### Factory Function
|
||||||
|
|
||||||
The `make_env_pre_post_processors` function follows the same pattern as `make_pre_post_processors` for policies:
|
The `make_env_pre_post_processors` function delegates to `env_cfg.get_env_processors()`:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from lerobot.envs.factory import make_env_pre_post_processors
|
from lerobot.envs.factory import make_env_pre_post_processors
|
||||||
@@ -159,47 +159,31 @@ from lerobot.envs.configs import LiberoEnv, PushtEnv
|
|||||||
|
|
||||||
# For LIBERO: Returns LiberoProcessorStep in preprocessor
|
# For LIBERO: Returns LiberoProcessorStep in preprocessor
|
||||||
libero_cfg = LiberoEnv(task="libero_spatial", camera_name=["agentview"])
|
libero_cfg = LiberoEnv(task="libero_spatial", camera_name=["agentview"])
|
||||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(libero_cfg)
|
env_preprocessor, env_postprocessor = make_env_pre_post_processors(libero_cfg, policy_cfg)
|
||||||
|
|
||||||
# For other environments: Returns identity processors (no-op)
|
# For other environments: Returns identity processors (no-op)
|
||||||
pusht_cfg = PushtEnv()
|
pusht_cfg = PushtEnv()
|
||||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(pusht_cfg)
|
env_preprocessor, env_postprocessor = make_env_pre_post_processors(pusht_cfg, policy_cfg)
|
||||||
```
|
```
|
||||||
|
|
||||||
### Implementation in `envs/factory.py`
|
### How It Works
|
||||||
|
|
||||||
|
Each `EnvConfig` subclass can override `get_env_processors()` to return benchmark-specific
|
||||||
|
processor pipelines. The base class returns identity (no-op) processors by default.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def make_env_pre_post_processors(
|
# In your EnvConfig subclass:
|
||||||
env_cfg: EnvConfig,
|
def get_env_processors(self):
|
||||||
) -> tuple[
|
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
return (
|
||||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
PolicyProcessorPipeline(steps=[MyProcessorStep()]),
|
||||||
]:
|
PolicyProcessorPipeline(steps=[]),
|
||||||
"""
|
)
|
||||||
Create preprocessor and postprocessor pipelines for environment observations.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
env_cfg: The configuration of the environment.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple containing:
|
|
||||||
- preprocessor: Pipeline that processes environment observations
|
|
||||||
- postprocessor: Pipeline that processes environment outputs
|
|
||||||
"""
|
|
||||||
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
|
|
||||||
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
|
|
||||||
preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
|
|
||||||
else:
|
|
||||||
# For all other environments, return an identity preprocessor
|
|
||||||
preprocessor = PolicyProcessorPipeline(steps=[])
|
|
||||||
|
|
||||||
# Postprocessor is currently identity for all environments
|
|
||||||
# Future: Could add environment-specific action transformations
|
|
||||||
postprocessor = PolicyProcessorPipeline(steps=[])
|
|
||||||
|
|
||||||
return preprocessor, postprocessor
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
The factory function `make_env_pre_post_processors` simply delegates to this method,
|
||||||
|
with a special case for `XVLAConfig` policies which override the env processors entirely.
|
||||||
|
|
||||||
### Integration in Evaluation
|
### Integration in Evaluation
|
||||||
|
|
||||||
In `lerobot_eval.py`, the environment processors are created once and used throughout:
|
In `lerobot_eval.py`, the environment processors are created once and used throughout:
|
||||||
|
|||||||
@@ -67,7 +67,8 @@ class EvalConfig:
|
|||||||
# `batch_size` specifies the number of environments to use in a gym.vector.VectorEnv.
|
# `batch_size` specifies the number of environments to use in a gym.vector.VectorEnv.
|
||||||
batch_size: int = 50
|
batch_size: int = 50
|
||||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
||||||
use_async_envs: bool = False
|
# Defaults to True; automatically downgraded to SyncVectorEnv when batch_size=1.
|
||||||
|
use_async_envs: bool = True
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
if self.batch_size > self.n_episodes:
|
if self.batch_size > self.n_episodes:
|
||||||
|
|||||||
@@ -12,11 +12,16 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
|
import importlib
|
||||||
from dataclasses import dataclass, field, fields
|
from dataclasses import dataclass, field, fields
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import draccus
|
import draccus
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium.envs.registration import registry as gym_registry
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
from lerobot.robots import RobotConfig
|
from lerobot.robots import RobotConfig
|
||||||
@@ -67,6 +72,45 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
|||||||
def gym_kwargs(self) -> dict:
|
def gym_kwargs(self) -> dict:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def create_envs(
|
||||||
|
self,
|
||||||
|
n_envs: int,
|
||||||
|
use_async_envs: bool = True,
|
||||||
|
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
|
||||||
|
"""Create {suite: {task_id: VectorEnv}}.
|
||||||
|
|
||||||
|
Default: single-task env via gym.make(). Multi-task benchmarks override.
|
||||||
|
AsyncVectorEnv is the default for n_envs > 1; auto-downgraded to Sync for n_envs=1.
|
||||||
|
"""
|
||||||
|
env_cls = gym.vector.AsyncVectorEnv if (use_async_envs and n_envs > 1) else gym.vector.SyncVectorEnv
|
||||||
|
|
||||||
|
if self.gym_id not in gym_registry:
|
||||||
|
print(f"gym id '{self.gym_id}' not found, attempting to import '{self.package_name}'...")
|
||||||
|
try:
|
||||||
|
importlib.import_module(self.package_name)
|
||||||
|
except ModuleNotFoundError as e:
|
||||||
|
raise ModuleNotFoundError(
|
||||||
|
f"Package '{self.package_name}' required for env '{self.type}' not found. "
|
||||||
|
f"Please install it or check PYTHONPATH."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
if self.gym_id not in gym_registry:
|
||||||
|
raise gym.error.NameNotFound(
|
||||||
|
f"Environment '{self.gym_id}' not registered even after importing '{self.package_name}'."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _make_one():
|
||||||
|
return gym.make(self.gym_id, disable_env_checker=self.disable_env_checker, **self.gym_kwargs)
|
||||||
|
|
||||||
|
vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)
|
||||||
|
return {self.type: {0: vec}}
|
||||||
|
|
||||||
|
def get_env_processors(self):
|
||||||
|
"""Return (preprocessor, postprocessor) for this env. Default: identity."""
|
||||||
|
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||||
|
|
||||||
|
return PolicyProcessorPipeline(steps=[]), PolicyProcessorPipeline(steps=[])
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class HubEnvConfig(EnvConfig):
|
class HubEnvConfig(EnvConfig):
|
||||||
@@ -345,6 +389,32 @@ class LiberoEnv(EnvConfig):
|
|||||||
kwargs["task_ids"] = self.task_ids
|
kwargs["task_ids"] = self.task_ids
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
|
def create_envs(self, n_envs: int, use_async_envs: bool = True):
|
||||||
|
from lerobot.envs.libero import create_libero_envs
|
||||||
|
|
||||||
|
if self.task is None:
|
||||||
|
raise ValueError("LiberoEnv requires a task to be specified")
|
||||||
|
env_cls = gym.vector.AsyncVectorEnv if (use_async_envs and n_envs > 1) else gym.vector.SyncVectorEnv
|
||||||
|
return create_libero_envs(
|
||||||
|
task=self.task,
|
||||||
|
n_envs=n_envs,
|
||||||
|
camera_name=self.camera_name,
|
||||||
|
init_states=self.init_states,
|
||||||
|
gym_kwargs=self.gym_kwargs,
|
||||||
|
env_cls=env_cls,
|
||||||
|
control_mode=self.control_mode,
|
||||||
|
episode_length=self.episode_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_env_processors(self):
|
||||||
|
from lerobot.processor.env_processor import LiberoProcessorStep
|
||||||
|
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||||
|
|
||||||
|
return (
|
||||||
|
PolicyProcessorPipeline(steps=[LiberoProcessorStep()]),
|
||||||
|
PolicyProcessorPipeline(steps=[]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@EnvConfig.register_subclass("metaworld")
|
@EnvConfig.register_subclass("metaworld")
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -387,6 +457,19 @@ class MetaworldEnv(EnvConfig):
|
|||||||
"render_mode": self.render_mode,
|
"render_mode": self.render_mode,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def create_envs(self, n_envs: int, use_async_envs: bool = True):
|
||||||
|
from lerobot.envs.metaworld import create_metaworld_envs
|
||||||
|
|
||||||
|
if self.task is None:
|
||||||
|
raise ValueError("MetaWorld requires a task to be specified")
|
||||||
|
env_cls = gym.vector.AsyncVectorEnv if (use_async_envs and n_envs > 1) else gym.vector.SyncVectorEnv
|
||||||
|
return create_metaworld_envs(
|
||||||
|
task=self.task,
|
||||||
|
n_envs=n_envs,
|
||||||
|
gym_kwargs=self.gym_kwargs,
|
||||||
|
env_cls=env_cls,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@EnvConfig.register_subclass("isaaclab_arena")
|
@EnvConfig.register_subclass("isaaclab_arena")
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -454,3 +537,18 @@ class IsaaclabArenaEnv(HubEnvConfig):
|
|||||||
@property
|
@property
|
||||||
def gym_kwargs(self) -> dict:
|
def gym_kwargs(self) -> dict:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
def get_env_processors(self):
|
||||||
|
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep
|
||||||
|
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||||
|
|
||||||
|
state_keys = tuple(k.strip() for k in (self.state_keys or "").split(",") if k.strip())
|
||||||
|
camera_keys = tuple(k.strip() for k in (self.camera_keys or "").split(",") if k.strip())
|
||||||
|
if not state_keys and not camera_keys:
|
||||||
|
raise ValueError("At least one of state_keys or camera_keys must be specified.")
|
||||||
|
return (
|
||||||
|
PolicyProcessorPipeline(
|
||||||
|
steps=[IsaaclabArenaProcessorStep(state_keys=state_keys, camera_keys=camera_keys)]
|
||||||
|
),
|
||||||
|
PolicyProcessorPipeline(steps=[]),
|
||||||
|
)
|
||||||
|
|||||||
+21
-118
@@ -13,96 +13,52 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import importlib
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.envs.registration import registry as gym_registry
|
|
||||||
|
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.envs.configs import EnvConfig, HubEnvConfig
|
||||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, HubEnvConfig, IsaaclabArenaEnv, LiberoEnv, PushtEnv
|
|
||||||
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
|
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
|
||||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
|
||||||
from lerobot.processor import ProcessorStep
|
|
||||||
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep
|
|
||||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
|
||||||
|
|
||||||
|
|
||||||
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||||
if env_type == "aloha":
|
try:
|
||||||
return AlohaEnv(**kwargs)
|
cls = EnvConfig.get_choice_class(env_type)
|
||||||
elif env_type == "pusht":
|
except KeyError as err:
|
||||||
return PushtEnv(**kwargs)
|
raise ValueError(
|
||||||
elif env_type == "libero":
|
f"Environment type '{env_type}' is not registered. "
|
||||||
return LiberoEnv(**kwargs)
|
f"Available: {list(EnvConfig.get_known_choices().keys())}"
|
||||||
else:
|
) from err
|
||||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
return cls(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
def make_env_pre_post_processors(
|
def make_env_pre_post_processors(
|
||||||
env_cfg: EnvConfig,
|
env_cfg: EnvConfig,
|
||||||
policy_cfg: PreTrainedConfig,
|
policy_cfg: Any,
|
||||||
) -> tuple[
|
) -> tuple[Any, Any]:
|
||||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
|
||||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
|
||||||
]:
|
|
||||||
"""
|
"""
|
||||||
Create preprocessor and postprocessor pipelines for environment observations.
|
Create preprocessor and postprocessor pipelines for environment observations.
|
||||||
|
|
||||||
This function creates processor pipelines that transform raw environment
|
Returns a tuple of (preprocessor, postprocessor). By default, delegates to
|
||||||
observations and actions. By default, it returns identity processors that do nothing.
|
``env_cfg.get_env_processors()``. The XVLAConfig policy-specific override
|
||||||
For specific environments like LIBERO, it adds environment-specific processing steps.
|
stays here because it depends on the *policy* config, not the env config.
|
||||||
|
|
||||||
Args:
|
|
||||||
env_cfg: The configuration of the environment.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple containing:
|
|
||||||
- preprocessor: Pipeline that processes environment observations
|
|
||||||
- postprocessor: Pipeline that processes environment outputs (currently identity)
|
|
||||||
"""
|
"""
|
||||||
# Preprocessor and Postprocessor steps are Identity for most environments
|
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||||
preprocessor_steps: list[ProcessorStep] = []
|
|
||||||
postprocessor_steps: list[ProcessorStep] = []
|
|
||||||
if isinstance(policy_cfg, XVLAConfig):
|
if isinstance(policy_cfg, XVLAConfig):
|
||||||
from lerobot.policies.xvla.processor_xvla import make_xvla_libero_pre_post_processors
|
from lerobot.policies.xvla.processor_xvla import make_xvla_libero_pre_post_processors
|
||||||
|
|
||||||
return make_xvla_libero_pre_post_processors()
|
return make_xvla_libero_pre_post_processors()
|
||||||
|
|
||||||
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
|
return env_cfg.get_env_processors()
|
||||||
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
|
|
||||||
preprocessor_steps.append(LiberoProcessorStep())
|
|
||||||
|
|
||||||
# For Isaaclab Arena environments, add the IsaaclabArenaProcessorStep
|
|
||||||
if isinstance(env_cfg, IsaaclabArenaEnv) or "isaaclab_arena" in env_cfg.type:
|
|
||||||
# Parse comma-separated keys (handle None for state-based policies)
|
|
||||||
if env_cfg.state_keys:
|
|
||||||
state_keys = tuple(k.strip() for k in env_cfg.state_keys.split(",") if k.strip())
|
|
||||||
else:
|
|
||||||
state_keys = ()
|
|
||||||
if env_cfg.camera_keys:
|
|
||||||
camera_keys = tuple(k.strip() for k in env_cfg.camera_keys.split(",") if k.strip())
|
|
||||||
else:
|
|
||||||
camera_keys = ()
|
|
||||||
if not state_keys and not camera_keys:
|
|
||||||
raise ValueError("At least one of state_keys or camera_keys must be specified.")
|
|
||||||
preprocessor_steps.append(
|
|
||||||
IsaaclabArenaProcessorStep(
|
|
||||||
state_keys=state_keys,
|
|
||||||
camera_keys=camera_keys,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps)
|
|
||||||
postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps)
|
|
||||||
|
|
||||||
return preprocessor, postprocessor
|
|
||||||
|
|
||||||
|
|
||||||
def make_env(
|
def make_env(
|
||||||
cfg: EnvConfig | str,
|
cfg: EnvConfig | str,
|
||||||
n_envs: int = 1,
|
n_envs: int = 1,
|
||||||
use_async_envs: bool = False,
|
use_async_envs: bool = True,
|
||||||
hub_cache_dir: str | None = None,
|
hub_cache_dir: str | None = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
|
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
|
||||||
@@ -163,57 +119,4 @@ def make_env(
|
|||||||
if n_envs < 1:
|
if n_envs < 1:
|
||||||
raise ValueError("`n_envs` must be at least 1")
|
raise ValueError("`n_envs` must be at least 1")
|
||||||
|
|
||||||
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
return cfg.create_envs(n_envs=n_envs, use_async_envs=use_async_envs)
|
||||||
|
|
||||||
if "libero" in cfg.type:
|
|
||||||
from lerobot.envs.libero import create_libero_envs
|
|
||||||
|
|
||||||
if cfg.task is None:
|
|
||||||
raise ValueError("LiberoEnv requires a task to be specified")
|
|
||||||
|
|
||||||
return create_libero_envs(
|
|
||||||
task=cfg.task,
|
|
||||||
n_envs=n_envs,
|
|
||||||
camera_name=cfg.camera_name,
|
|
||||||
init_states=cfg.init_states,
|
|
||||||
gym_kwargs=cfg.gym_kwargs,
|
|
||||||
env_cls=env_cls,
|
|
||||||
control_mode=cfg.control_mode,
|
|
||||||
episode_length=cfg.episode_length,
|
|
||||||
)
|
|
||||||
elif "metaworld" in cfg.type:
|
|
||||||
from lerobot.envs.metaworld import create_metaworld_envs
|
|
||||||
|
|
||||||
if cfg.task is None:
|
|
||||||
raise ValueError("MetaWorld requires a task to be specified")
|
|
||||||
|
|
||||||
return create_metaworld_envs(
|
|
||||||
task=cfg.task,
|
|
||||||
n_envs=n_envs,
|
|
||||||
gym_kwargs=cfg.gym_kwargs,
|
|
||||||
env_cls=env_cls,
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.gym_id not in gym_registry:
|
|
||||||
print(f"gym id '{cfg.gym_id}' not found, attempting to import '{cfg.package_name}'...")
|
|
||||||
try:
|
|
||||||
importlib.import_module(cfg.package_name)
|
|
||||||
except ModuleNotFoundError as e:
|
|
||||||
raise ModuleNotFoundError(
|
|
||||||
f"Package '{cfg.package_name}' required for env '{cfg.type}' not found. "
|
|
||||||
f"Please install it or check PYTHONPATH."
|
|
||||||
) from e
|
|
||||||
|
|
||||||
if cfg.gym_id not in gym_registry:
|
|
||||||
raise gym.error.NameNotFound(
|
|
||||||
f"Environment '{cfg.gym_id}' not registered even after importing '{cfg.package_name}'."
|
|
||||||
)
|
|
||||||
|
|
||||||
def _make_one():
|
|
||||||
return gym.make(cfg.gym_id, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
|
|
||||||
|
|
||||||
vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)
|
|
||||||
|
|
||||||
# normalize to {suite: {task_id: vec_env}} for consistency
|
|
||||||
suite_name = cfg.type # e.g., "pusht", "aloha"
|
|
||||||
return {suite_name: {0: vec}}
|
|
||||||
|
|||||||
+35
-17
@@ -150,7 +150,17 @@ class LiberoEnv(gym.Env):
|
|||||||
|
|
||||||
self.init_state_id = self.episode_index # tie each sub-env to a fixed init state
|
self.init_state_id = self.episode_index # tie each sub-env to a fixed init state
|
||||||
|
|
||||||
self._env = self._make_envs_task(task_suite, self.task_id)
|
# Extract task metadata without allocating GPU resources (safe before fork).
|
||||||
|
task = task_suite.get_task(task_id)
|
||||||
|
self.task = task.name
|
||||||
|
self.task_description = task.language
|
||||||
|
self._task_bddl_file = os.path.join(
|
||||||
|
get_libero_path("bddl_files"), task.problem_folder, task.bddl_file
|
||||||
|
)
|
||||||
|
self._env: OffScreenRenderEnv | None = (
|
||||||
|
None # deferred — created on first reset() inside the worker subprocess
|
||||||
|
)
|
||||||
|
|
||||||
default_steps = 500
|
default_steps = 500
|
||||||
self._max_episode_steps = (
|
self._max_episode_steps = (
|
||||||
TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
|
TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
|
||||||
@@ -221,28 +231,32 @@ class LiberoEnv(gym.Env):
|
|||||||
low=ACTION_LOW, high=ACTION_HIGH, shape=(ACTION_DIM,), dtype=np.float32
|
low=ACTION_LOW, high=ACTION_HIGH, shape=(ACTION_DIM,), dtype=np.float32
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _ensure_env(self) -> None:
|
||||||
|
"""Create the underlying OffScreenRenderEnv on first use.
|
||||||
|
|
||||||
|
Called inside the worker subprocess after fork(), so each worker gets
|
||||||
|
its own clean EGL context rather than inheriting a stale one from the
|
||||||
|
parent process (which causes EGL_BAD_CONTEXT crashes with AsyncVectorEnv).
|
||||||
|
"""
|
||||||
|
if self._env is not None:
|
||||||
|
return
|
||||||
|
env = OffScreenRenderEnv(
|
||||||
|
bddl_file_name=self._task_bddl_file,
|
||||||
|
camera_heights=self.observation_height,
|
||||||
|
camera_widths=self.observation_width,
|
||||||
|
)
|
||||||
|
env.reset()
|
||||||
|
self._env = env
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
|
self._ensure_env()
|
||||||
raw_obs = self._env.env._get_observations()
|
raw_obs = self._env.env._get_observations()
|
||||||
image = self._format_raw_obs(raw_obs)["pixels"]["image"]
|
image = self._format_raw_obs(raw_obs)["pixels"]["image"]
|
||||||
image = image[::-1, ::-1] # flip both H and W for visualization
|
image = image[::-1, ::-1] # flip both H and W for visualization
|
||||||
return image
|
return image
|
||||||
|
|
||||||
def _make_envs_task(self, task_suite: Any, task_id: int = 0):
|
|
||||||
task = task_suite.get_task(task_id)
|
|
||||||
self.task = task.name
|
|
||||||
self.task_description = task.language
|
|
||||||
task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
|
|
||||||
|
|
||||||
env_args = {
|
|
||||||
"bddl_file_name": task_bddl_file,
|
|
||||||
"camera_heights": self.observation_height,
|
|
||||||
"camera_widths": self.observation_width,
|
|
||||||
}
|
|
||||||
env = OffScreenRenderEnv(**env_args)
|
|
||||||
env.reset()
|
|
||||||
return env
|
|
||||||
|
|
||||||
def _format_raw_obs(self, raw_obs: RobotObservation) -> RobotObservation:
|
def _format_raw_obs(self, raw_obs: RobotObservation) -> RobotObservation:
|
||||||
|
assert self._env is not None, "_format_raw_obs called before _ensure_env()"
|
||||||
images = {}
|
images = {}
|
||||||
for camera_name in self.camera_name:
|
for camera_name in self.camera_name:
|
||||||
image = raw_obs[camera_name]
|
image = raw_obs[camera_name]
|
||||||
@@ -294,6 +308,7 @@ class LiberoEnv(gym.Env):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def reset(self, seed=None, **kwargs):
|
def reset(self, seed=None, **kwargs):
|
||||||
|
self._ensure_env()
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
self._env.seed(seed)
|
self._env.seed(seed)
|
||||||
raw_obs = self._env.reset()
|
raw_obs = self._env.reset()
|
||||||
@@ -320,6 +335,8 @@ class LiberoEnv(gym.Env):
|
|||||||
return observation, info
|
return observation, info
|
||||||
|
|
||||||
def step(self, action: np.ndarray) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]:
|
def step(self, action: np.ndarray) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]:
|
||||||
|
self._ensure_env()
|
||||||
|
assert self._env is not None
|
||||||
if action.ndim != 1:
|
if action.ndim != 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Expected action to be 1-D (shape (action_dim,)), "
|
f"Expected action to be 1-D (shape (action_dim,)), "
|
||||||
@@ -350,7 +367,8 @@ class LiberoEnv(gym.Env):
|
|||||||
return observation, reward, terminated, truncated, info
|
return observation, reward, terminated, truncated, info
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self._env.close()
|
if self._env is not None:
|
||||||
|
self._env.close()
|
||||||
|
|
||||||
|
|
||||||
def _make_env_fns(
|
def _make_env_fns(
|
||||||
|
|||||||
@@ -97,8 +97,9 @@ class MetaworldEnv(gym.Env):
|
|||||||
self.visualization_height = visualization_height
|
self.visualization_height = visualization_height
|
||||||
self.camera_name = camera_name
|
self.camera_name = camera_name
|
||||||
|
|
||||||
self._env = self._make_envs_task(self.task)
|
self._env_name = self.task # already stripped of "metaworld-" prefix above
|
||||||
self._max_episode_steps = self._env.max_path_length
|
self._env = None # deferred — created on first reset() inside the worker subprocess
|
||||||
|
self._max_episode_steps = 500 # MT1 environments always have max_path_length=500
|
||||||
self.task_description = TASK_DESCRIPTIONS[self.task]
|
self.task_description = TASK_DESCRIPTIONS[self.task]
|
||||||
|
|
||||||
self.expert_policy = TASK_POLICY_MAPPING[self.task]()
|
self.expert_policy = TASK_POLICY_MAPPING[self.task]()
|
||||||
@@ -136,6 +137,24 @@ class MetaworldEnv(gym.Env):
|
|||||||
|
|
||||||
self.action_space = spaces.Box(low=-1, high=1, shape=(ACTION_DIM,), dtype=np.float32)
|
self.action_space = spaces.Box(low=-1, high=1, shape=(ACTION_DIM,), dtype=np.float32)
|
||||||
|
|
||||||
|
def _ensure_env(self) -> None:
|
||||||
|
"""Create the underlying MetaWorld env on first use.
|
||||||
|
|
||||||
|
Called inside the worker subprocess after fork(), so each worker gets
|
||||||
|
its own clean rendering context rather than inheriting a stale one from
|
||||||
|
the parent process (which causes crashes with AsyncVectorEnv).
|
||||||
|
"""
|
||||||
|
if self._env is not None:
|
||||||
|
return
|
||||||
|
mt1 = metaworld.MT1(self._env_name, seed=42)
|
||||||
|
env = mt1.train_classes[self._env_name](render_mode="rgb_array", camera_name=self.camera_name)
|
||||||
|
env.set_task(mt1.train_tasks[0])
|
||||||
|
if self.camera_name == "corner2":
|
||||||
|
env.model.cam_pos[2] = [0.75, 0.075, 0.7]
|
||||||
|
env.reset()
|
||||||
|
env._freeze_rand_vec = False # otherwise no randomization
|
||||||
|
self._env = env
|
||||||
|
|
||||||
def render(self) -> np.ndarray:
|
def render(self) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Render the current environment frame.
|
Render the current environment frame.
|
||||||
@@ -143,26 +162,13 @@ class MetaworldEnv(gym.Env):
|
|||||||
Returns:
|
Returns:
|
||||||
np.ndarray: The rendered RGB image from the environment.
|
np.ndarray: The rendered RGB image from the environment.
|
||||||
"""
|
"""
|
||||||
|
self._ensure_env()
|
||||||
image = self._env.render()
|
image = self._env.render()
|
||||||
if self.camera_name == "corner2":
|
if self.camera_name == "corner2":
|
||||||
# Images from this camera are flipped — correct them
|
# Images from this camera are flipped — correct them
|
||||||
image = np.flip(image, (0, 1))
|
image = np.flip(image, (0, 1))
|
||||||
return image
|
return image
|
||||||
|
|
||||||
def _make_envs_task(self, env_name: str):
|
|
||||||
mt1 = metaworld.MT1(env_name, seed=42)
|
|
||||||
env = mt1.train_classes[env_name](render_mode="rgb_array", camera_name=self.camera_name)
|
|
||||||
env.set_task(mt1.train_tasks[0])
|
|
||||||
if self.camera_name == "corner2":
|
|
||||||
env.model.cam_pos[2] = [
|
|
||||||
0.75,
|
|
||||||
0.075,
|
|
||||||
0.7,
|
|
||||||
] # corner2 position, similar to https://arxiv.org/pdf/2206.14244
|
|
||||||
env.reset()
|
|
||||||
env._freeze_rand_vec = False # otherwise no randomization
|
|
||||||
return env
|
|
||||||
|
|
||||||
def _format_raw_obs(self, raw_obs: np.ndarray) -> RobotObservation:
|
def _format_raw_obs(self, raw_obs: np.ndarray) -> RobotObservation:
|
||||||
image = None
|
image = None
|
||||||
if self._env is not None:
|
if self._env is not None:
|
||||||
@@ -209,6 +215,7 @@ class MetaworldEnv(gym.Env):
|
|||||||
observation (RobotObservation): The initial formatted observation.
|
observation (RobotObservation): The initial formatted observation.
|
||||||
info (Dict[str, Any]): Additional info about the reset state.
|
info (Dict[str, Any]): Additional info about the reset state.
|
||||||
"""
|
"""
|
||||||
|
self._ensure_env()
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
|
|
||||||
raw_obs, info = self._env.reset(seed=seed)
|
raw_obs, info = self._env.reset(seed=seed)
|
||||||
@@ -232,6 +239,7 @@ class MetaworldEnv(gym.Env):
|
|||||||
truncated (bool): Whether the episode was truncated due to a time limit.
|
truncated (bool): Whether the episode was truncated due to a time limit.
|
||||||
info (Dict[str, Any]): Additional environment info.
|
info (Dict[str, Any]): Additional environment info.
|
||||||
"""
|
"""
|
||||||
|
self._ensure_env()
|
||||||
if action.ndim != 1:
|
if action.ndim != 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Expected action to be 1-D (shape (action_dim,)), "
|
f"Expected action to be 1-D (shape (action_dim,)), "
|
||||||
@@ -263,7 +271,8 @@ class MetaworldEnv(gym.Env):
|
|||||||
return observation, reward, terminated, truncated, info
|
return observation, reward, terminated, truncated, info
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self._env.close()
|
if self._env is not None:
|
||||||
|
self._env.close()
|
||||||
|
|
||||||
|
|
||||||
# ---- Main API ----------------------------------------------------------------
|
# ---- Main API ----------------------------------------------------------------
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ You can learn about the CLI options for this script in the `EvalPipelineConfig`
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import concurrent.futures as cf
|
import concurrent.futures as cf
|
||||||
|
import copy
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
@@ -56,7 +57,6 @@ from collections.abc import Callable
|
|||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from functools import partial
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Any, TypedDict
|
from typing import Any, TypedDict
|
||||||
@@ -73,7 +73,6 @@ from lerobot.configs import parser
|
|||||||
from lerobot.configs.eval import EvalPipelineConfig
|
from lerobot.configs.eval import EvalPipelineConfig
|
||||||
from lerobot.envs.factory import make_env, make_env_pre_post_processors
|
from lerobot.envs.factory import make_env, make_env_pre_post_processors
|
||||||
from lerobot.envs.utils import (
|
from lerobot.envs.utils import (
|
||||||
add_envs_task,
|
|
||||||
check_env_attributes_and_types,
|
check_env_attributes_and_types,
|
||||||
close_envs,
|
close_envs,
|
||||||
preprocess_observation,
|
preprocess_observation,
|
||||||
@@ -166,9 +165,9 @@ def rollout(
|
|||||||
if return_observations:
|
if return_observations:
|
||||||
all_observations.append(deepcopy(observation))
|
all_observations.append(deepcopy(observation))
|
||||||
|
|
||||||
# Infer "task" from attributes of environments.
|
# Infer "task" from sub-environments.
|
||||||
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
|
# env.call() works with both SyncVectorEnv and AsyncVectorEnv.
|
||||||
observation = add_envs_task(env, observation)
|
observation["task"] = env.call("task")
|
||||||
|
|
||||||
# Apply environment-specific preprocessing (e.g., LiberoProcessorStep for LIBERO)
|
# Apply environment-specific preprocessing (e.g., LiberoProcessorStep for LIBERO)
|
||||||
observation = env_preprocessor(observation)
|
observation = env_preprocessor(observation)
|
||||||
@@ -734,34 +733,48 @@ def eval_policy_all(
|
|||||||
group_acc[group]["video_paths"].extend(paths)
|
group_acc[group]["video_paths"].extend(paths)
|
||||||
overall["video_paths"].extend(paths)
|
overall["video_paths"].extend(paths)
|
||||||
|
|
||||||
|
def _make_thread_policy(p: PreTrainedPolicy) -> PreTrainedPolicy:
|
||||||
|
"""Shallow copy sharing weight tensors, with independent per-thread state.
|
||||||
|
|
||||||
|
copy.copy() gives a new Python object whose _parameters dict is a shared
|
||||||
|
reference (same tensor storage, zero extra VRAM). reset() then rebinds
|
||||||
|
mutable state (action queues etc.) to fresh per-thread objects.
|
||||||
|
|
||||||
|
Note: does NOT work for ACT with temporal_ensemble_coeff — that policy's
|
||||||
|
reset() mutates a shared sub-object. Use max_parallel_tasks=1 for that config.
|
||||||
|
"""
|
||||||
|
thread_p = copy.copy(p)
|
||||||
|
thread_p.reset()
|
||||||
|
return thread_p
|
||||||
|
|
||||||
# Choose runner (sequential vs threaded)
|
# Choose runner (sequential vs threaded)
|
||||||
task_runner = partial(
|
_runner_kwargs = {
|
||||||
run_one,
|
"env_preprocessor": env_preprocessor,
|
||||||
policy=policy,
|
"env_postprocessor": env_postprocessor,
|
||||||
env_preprocessor=env_preprocessor,
|
"preprocessor": preprocessor,
|
||||||
env_postprocessor=env_postprocessor,
|
"postprocessor": postprocessor,
|
||||||
preprocessor=preprocessor,
|
"n_episodes": n_episodes,
|
||||||
postprocessor=postprocessor,
|
"max_episodes_rendered": max_episodes_rendered,
|
||||||
n_episodes=n_episodes,
|
"videos_dir": videos_dir,
|
||||||
max_episodes_rendered=max_episodes_rendered,
|
"return_episode_data": return_episode_data,
|
||||||
videos_dir=videos_dir,
|
"start_seed": start_seed,
|
||||||
return_episode_data=return_episode_data,
|
}
|
||||||
start_seed=start_seed,
|
|
||||||
)
|
|
||||||
|
|
||||||
if max_parallel_tasks <= 1:
|
if max_parallel_tasks <= 1:
|
||||||
# sequential path (single accumulator path on the main thread)
|
# sequential path (single accumulator path on the main thread)
|
||||||
# NOTE: keeping a single-threaded accumulator avoids concurrent list appends or locks
|
# NOTE: keeping a single-threaded accumulator avoids concurrent list appends or locks
|
||||||
for task_group, task_id, env in tasks:
|
for task_group, task_id, env in tasks:
|
||||||
tg, tid, metrics = task_runner(task_group, task_id, env)
|
tg, tid, metrics = run_one(task_group, task_id, env, policy=policy, **_runner_kwargs)
|
||||||
_accumulate_to(tg, metrics)
|
_accumulate_to(tg, metrics)
|
||||||
per_task_infos.append({"task_group": tg, "task_id": tid, "metrics": metrics})
|
per_task_infos.append({"task_group": tg, "task_id": tid, "metrics": metrics})
|
||||||
else:
|
else:
|
||||||
# threaded path: submit all tasks, consume completions on main thread and accumulate there
|
# threaded path: each thread gets a shallow policy copy (shared weights, independent state)
|
||||||
with cf.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor:
|
with cf.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor:
|
||||||
fut2meta = {}
|
fut2meta = {}
|
||||||
for task_group, task_id, env in tasks:
|
for task_group, task_id, env in tasks:
|
||||||
fut = executor.submit(task_runner, task_group, task_id, env)
|
fut = executor.submit(
|
||||||
|
run_one, task_group, task_id, env, policy=_make_thread_policy(policy), **_runner_kwargs
|
||||||
|
)
|
||||||
fut2meta[fut] = (task_group, task_id)
|
fut2meta[fut] = (task_group, task_id)
|
||||||
for fut in cf.as_completed(fut2meta):
|
for fut in cf.as_completed(fut2meta):
|
||||||
tg, tid, metrics = fut.result()
|
tg, tid, metrics = fut.result()
|
||||||
|
|||||||
@@ -0,0 +1,143 @@
|
|||||||
|
"""Tests for the benchmark dispatch refactor (create_envs / get_env_processors on EnvConfig)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
import pytest
|
||||||
|
from gymnasium.envs.registration import register, registry as gym_registry
|
||||||
|
|
||||||
|
from lerobot.configs.types import PolicyFeature
|
||||||
|
from lerobot.envs.configs import EnvConfig
|
||||||
|
from lerobot.envs.factory import make_env, make_env_config, make_env_pre_post_processors
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def test_registry_all_types():
|
||||||
|
"""make_env_config should resolve every registered EnvConfig subclass via the registry."""
|
||||||
|
known = list(EnvConfig.get_known_choices().keys())
|
||||||
|
assert len(known) >= 6
|
||||||
|
for t in known:
|
||||||
|
cfg = make_env_config(t)
|
||||||
|
assert cfg.type == t
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_type():
|
||||||
|
with pytest.raises(ValueError, match="not registered"):
|
||||||
|
make_env_config("nonexistent")
|
||||||
|
|
||||||
|
|
||||||
|
def test_identity_processors():
|
||||||
|
"""Base class get_env_processors() returns identity pipelines."""
|
||||||
|
cfg = make_env_config("aloha")
|
||||||
|
pre, post = cfg.get_env_processors()
|
||||||
|
assert len(pre.steps) == 0 and len(post.steps) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_delegation():
|
||||||
|
"""make_env() should call cfg.create_envs(), not use if/elif dispatch."""
|
||||||
|
sentinel = {"delegated": {0: "marker"}}
|
||||||
|
fake = type(
|
||||||
|
"Fake",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"hub_path": None,
|
||||||
|
"create_envs": lambda self, n_envs, use_async_envs=False: sentinel,
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
result = make_env(fake, n_envs=1)
|
||||||
|
assert result is sentinel
|
||||||
|
|
||||||
|
|
||||||
|
def test_processors_delegation():
|
||||||
|
"""make_env_pre_post_processors delegates to cfg.get_env_processors()."""
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
|
||||||
|
cfg = make_env_config("aloha")
|
||||||
|
pre, post = make_env_pre_post_processors(cfg, PreTrainedConfig())
|
||||||
|
assert len(pre.steps) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_base_create_envs():
|
||||||
|
"""Base class create_envs() should build a single-task VectorEnv via gym.make()."""
|
||||||
|
gym_id = "_dispatch_test/CartPole-v99"
|
||||||
|
if gym_id not in gym_registry:
|
||||||
|
register(id=gym_id, entry_point="gymnasium.envs.classic_control:CartPoleEnv")
|
||||||
|
|
||||||
|
@EnvConfig.register_subclass("_dispatch_base_test")
|
||||||
|
@dataclass
|
||||||
|
class _Env(EnvConfig):
|
||||||
|
task: str = "CartPole-v99"
|
||||||
|
fps: int = 10
|
||||||
|
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def package_name(self):
|
||||||
|
return "_dispatch_test"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def gym_id(self):
|
||||||
|
return gym_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def gym_kwargs(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
envs = _Env().create_envs(n_envs=2)
|
||||||
|
assert "_dispatch_base_test" in envs
|
||||||
|
env = envs["_dispatch_base_test"][0]
|
||||||
|
assert isinstance(env, gym.vector.SyncVectorEnv)
|
||||||
|
assert env.num_envs == 2
|
||||||
|
env.close()
|
||||||
|
finally:
|
||||||
|
if gym_id in gym_registry:
|
||||||
|
del gym_registry[gym_id]
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_create_envs_override():
|
||||||
|
"""A custom EnvConfig subclass can override create_envs()."""
|
||||||
|
mock_vec = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")])
|
||||||
|
|
||||||
|
@EnvConfig.register_subclass("_dispatch_custom_test")
|
||||||
|
@dataclass
|
||||||
|
class _Env(EnvConfig):
|
||||||
|
task: str = "x"
|
||||||
|
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def gym_kwargs(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def create_envs(self, n_envs, use_async_envs=False):
|
||||||
|
return {"custom_suite": {0: mock_vec}}
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = make_env(_Env(), n_envs=1)
|
||||||
|
assert "custom_suite" in result
|
||||||
|
finally:
|
||||||
|
mock_vec.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_get_env_processors_override():
|
||||||
|
"""A custom EnvConfig subclass can override get_env_processors()."""
|
||||||
|
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||||
|
|
||||||
|
@EnvConfig.register_subclass("_dispatch_proc_test")
|
||||||
|
@dataclass
|
||||||
|
class _Env(EnvConfig):
|
||||||
|
task: str = "x"
|
||||||
|
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def gym_kwargs(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def get_env_processors(self):
|
||||||
|
return PolicyProcessorPipeline(steps=[]), PolicyProcessorPipeline(steps=[])
|
||||||
|
|
||||||
|
pre, post = _Env().get_env_processors()
|
||||||
|
assert isinstance(pre, PolicyProcessorPipeline)
|
||||||
Reference in New Issue
Block a user