mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 00:29:52 +00:00
Compare commits

1 commit

| Author | SHA1 | Date |
|---|---|---|
|  | 2c78c46fcd |  |
@@ -1,86 +0,0 @@
# LeRobot — Claude Code Instructions

You are a senior robotics ML engineer reviewing code for **LeRobot**, a PyTorch framework for real-world robot learning.

Apply these principles to every PR review, fix, or task.

---

## Core Abstractions

These are the load-bearing types. Handle them with care — breaking changes here affect every user.

| Type | Location | Role |
| ---------------- | ---------------------------- | ------------------------------------------------------------ |
| `LeRobotDataset` | `src/lerobot/datasets/` | Streaming replay buffer; HF Hub integration |
| `Policy` | `src/lerobot/policies/` | Base class for all learning agents (ACT, Diffusion, SARM, …) |
| `Robot` | `src/lerobot/robots/` | Hardware abstraction; carries `_output_pipeline` |
| `Teleoperator` | `src/lerobot/teleoperators/` | Leader-side hardware abstraction; carries `_output_pipeline` |
| `Env` | `src/lerobot/envs/` | Gym-like robotics environments |
| `Processor` | `src/lerobot/processor/` | Data transformation pipelines attached to robots/teleops |

**Never break their public APIs without a migration note and explicit user approval.**
---

## Engineering Principles

### Code quality

- Explicit over magic — no hidden control flow, no implicit state.
- No deep inheritance trees. Prefer composition.
- No decorative comment separators (`===`, `---`, etc.).
- Add comments only where the logic is non-obvious.
- No over-engineering. YAGNI applies strictly.

### Type safety

- All new and modified Python code must be fully typed (PEP 484).
- `mypy --strict` must pass on changed files.
- Do not widen or weaken existing type signatures.
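As an illustration of the fully-typed style these rules ask for, here is a small hedged sketch. The type and function below are invented for the example and are not part of LeRobot; the point is that every argument and return value carries an annotation `mypy --strict` can check.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class JointLimits:
    """Invented example type: per-joint position bounds in radians."""
    lower: float
    upper: float


def clamp_position(position: float, limits: JointLimits) -> float:
    """Clamp a target joint position into its limits.

    Fully annotated, so `mypy --strict` can verify every call site.
    """
    return min(max(position, limits.lower), limits.upper)
```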
### Backwards compatibility

- Public API changes require migration notes.
- Additive changes are preferred over modifications.
- `so100_follower` / `so101_follower` are aliases — never bleed changes there unintentionally.
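The alias caveat can be made concrete with a self-contained sketch. The registry and class names here are invented for illustration and do not match LeRobot's actual plumbing; the hazard shown is real, though: when two CLI names resolve to one class, a change made "for" one name silently reaches users of the other.

```python
class FollowerArm:
    """Invented stand-in for a robot driver class."""

    def __init__(self, port: str) -> None:
        self.port = port


# Two names resolving to the same class: editing the class
# "for so101" changes so100 behavior as well.
ROBOT_REGISTRY: dict[str, type[FollowerArm]] = {
    "so100_follower": FollowerArm,
    "so101_follower": FollowerArm,
}


def make_robot(name: str, port: str) -> FollowerArm:
    """Resolve a CLI name to a driver instance via the registry."""
    return ROBOT_REGISTRY[name](port)
```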
### HF ecosystem

- Use `push_to_hub()`, HF Hub dataset streaming, and `evaluate` scripts.
- Dataset changes must preserve streaming compatibility.
- Prefer reusing HF primitives over rolling custom solutions.
---

## PR Review Checklist

Before approving or marking P1 issues resolved, verify:

- [ ] `pre-commit run -a` would pass (ruff, mypy, typos, zizmor, bandit)
- [ ] All new/modified code is typed and passes `mypy --strict`
- [ ] New features have unit tests; no silent behavioral changes
- [ ] Public APIs of `LeRobotDataset`, `Policy`, `Robot`, `Teleoperator`, `Env` are unchanged (or migration note present)
- [ ] HF Hub streaming still works for dataset changes
- [ ] No unnecessary abstractions introduced
- [ ] No breaking changes to training scripts (`lerobot-train`, `lerobot-eval`, `lerobot-record`)
---

## ML-Specific Checks

Flag these as **P1** if found:

- **Data leakage**: train and val/test splits must be constructed before any normalization or augmentation that uses train statistics.
- **Loss function errors**: verify reduction mode (`mean` vs `sum`), correct masking, correct shape alignment.
- **Gradient flow**: new modules must have gradients flowing (check `requires_grad`, no detached tensors in the loss path by accident).
- **Distributed training**: operations on tensors must be DDP-safe; no in-place ops on parameters; batch norm needs `SyncBatchNorm` if used.
- **Memory leaks**: no accumulation of tensors outside the training loop; `optimizer.zero_grad()` called correctly.
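The first two checks can be sketched without any framework. Below is a minimal, dependency-free illustration of leakage-safe normalization and mask-aware loss reduction; all names and numbers are invented for the example, and real LeRobot code would of course operate on tensors rather than lists.

```python
def mean_std(values: list[float]) -> tuple[float, float]:
    """Population mean and standard deviation of a list."""
    m = sum(values) / len(values)
    var = sum((v - m) ** 2 for v in values) / len(values)
    return m, var ** 0.5


def normalize(values: list[float], mean: float, std: float) -> list[float]:
    return [(v - mean) / std for v in values]


# Leakage-safe: statistics come from the train split only,
# then are reused verbatim on the validation split.
train = [1.0, 2.0, 3.0, 4.0]
val = [10.0, 20.0]
mu, sigma = mean_std(train)          # never recomputed on val
train_norm = normalize(train, mu, sigma)
val_norm = normalize(val, mu, sigma)


def masked_mean_loss(errors: list[float], mask: list[int]) -> float:
    """Average only over valid steps: divide by the mask count,
    not by len(errors), or padded steps dilute the loss."""
    total = sum(e * m for e, m in zip(errors, mask))
    return total / sum(mask)
```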
---

## What to Skip

- Don't flag style nitpicks on unchanged surrounding code.
- Don't propose refactors outside the PR's scope.
- Don't add docstrings or comments to code the PR didn't touch.
- Don't suggest speculative future features (YAGNI).
@@ -1,49 +0,0 @@

name: Claude Code Review

on:
  pull_request:
    types: [opened, synchronize, ready_for_review, reopened]

jobs:
  claude-review:
    runs-on: ubuntu-latest
    permissions:
      contents: read
      pull-requests: write
      issues: read
      id-token: write
      actions: read
    env:
      FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 1
          persist-credentials: false

      - name: Run Claude Code Review
        id: claude-review
        uses: anthropics/claude-code-action@26ddc358fe3befff50c5ec2f80304c90c763f6f8 # v1
        with:
          anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
          use_sticky_comment: true
          prompt: |
            Read `.github/CLAUDE.md` for lerobot-specific conventions, then review this PR.
            Provide structured, actionable feedback.

            Focus areas (in priority order):
            1. **Correctness**: Logic errors, off-by-ones, wrong tensor shapes, incorrect loss functions
            2. **Type safety**: All new/modified Python code must pass `mypy --strict`; check for missing annotations
            3. **Backwards compatibility**: Does this break `LeRobotDataset`, `Policy`, `Robot`, `Teleoperator`, `Env`, or `Processor` public APIs?
            4. **Tests**: New features must have tests; no silent behavioral changes
            5. **Code style**: Explicit over magic, no unnecessary abstractions, no decorative comments
            6. **HF integration**: Dataset streaming, `push_to_hub`, HF Hub compatibility preserved?
            7. **pre-commit**: Would `pre-commit run -a` pass? (ruff, mypy, typos, zizmor)

            Format findings as P1 (must fix) / P2 (should fix) / P3 (nice to have).
            Skip P3 if the PR is already high quality.
          claude_args: '--model claude-opus-4-6'
          # See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
          # or https://code.claude.com/docs/en/cli-reference for available options
@@ -1,58 +0,0 @@

name: Claude Code

on:
  issue_comment:
    types: [created]
  pull_request_review_comment:
    types: [created]
  issues:
    types: [opened, assigned]
  pull_request_review:
    types: [submitted]

jobs:
  claude:
    if: |
      (github.event_name == 'issue_comment' &&
       contains(github.event.comment.body, '@claude') &&
       (github.event.comment.author_association == 'OWNER' || github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'COLLABORATOR')) ||
      (github.event_name == 'pull_request_review_comment' &&
       contains(github.event.comment.body, '@claude') &&
       (github.event.comment.author_association == 'OWNER' || github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'COLLABORATOR')) ||
      (github.event_name == 'pull_request_review' &&
       contains(github.event.review.body, '@claude') &&
       (github.event.review.author_association == 'OWNER' || github.event.review.author_association == 'MEMBER' || github.event.review.author_association == 'COLLABORATOR')) ||
      (github.event_name == 'issues' &&
       (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude')) &&
       (github.event.issue.author_association == 'OWNER' || github.event.issue.author_association == 'MEMBER' || github.event.issue.author_association == 'COLLABORATOR'))
    runs-on: ubuntu-latest
    permissions:
      contents: read
      pull-requests: write
      issues: write
      id-token: write
      actions: read
    env:
      FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 1
          persist-credentials: false

      - name: Run Claude Code
        id: claude
        uses: anthropics/claude-code-action@26ddc358fe3befff50c5ec2f80304c90c763f6f8 # v1
        with:
          anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
          use_sticky_comment: true

          # This is an optional setting that allows Claude to read CI results on PRs
          additional_permissions: |
            actions: read

          claude_args: '--system-prompt "Read .github/CLAUDE.md for lerobot-specific conventions before responding."'
          # See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
          # or https://code.claude.com/docs/en/cli-reference for available options
@@ -115,22 +115,23 @@ Each `EnvConfig` subclass declares two dicts that tell the policy what to expect

 ## Step by step
 
 <Tip>
-At minimum, you need two files: a **gym.Env wrapper** and an **EnvConfig
-subclass** with a `create_envs()` override. Everything else is optional or
-documentation. No changes to `factory.py` are needed.
+At minimum, you need three files: a **gym.Env wrapper**, an **EnvConfig
+subclass**, and a **factory dispatch branch**. Everything else is optional or
+documentation.
 </Tip>
 
 ### Checklist
 
 | File | Required | Why |
-| ---------------------------------------- | -------- | ------------------------------------------------------------ |
+| ---------------------------------------- | -------- | ----------------------------------------- |
 | `src/lerobot/envs/<benchmark>.py` | Yes | Wraps the simulator as a standard gym.Env |
-| `src/lerobot/envs/configs.py` | Yes | Registers your benchmark and its `create_envs()` for the CLI |
+| `src/lerobot/envs/configs.py` | Yes | Registers your benchmark for the CLI |
+| `src/lerobot/envs/factory.py` | Yes | Tells `make_env()` how to build your envs |
 | `src/lerobot/processor/env_processor.py` | Optional | Custom observation/action transforms |
 | `src/lerobot/envs/utils.py` | Optional | Only if you need new raw observation keys |
 | `pyproject.toml` | Yes | Declares benchmark-specific dependencies |
 | `docs/source/<benchmark>.mdx` | Yes | User-facing documentation page |
 | `docs/source/_toctree.yml` | Yes | Adds your page to the docs sidebar |
 
 ### 1. The gym.Env wrapper (`src/lerobot/envs/<benchmark>.py`)
 
@@ -178,10 +179,7 @@ See `create_libero_envs()` (multi-suite, multi-task) and `create_metaworld_envs(

 ### 2. The config (`src/lerobot/envs/configs.py`)
 
-Register a config dataclass so users can select your benchmark with `--env.type=<name>`. Each config owns its environment creation and processor logic via two methods:
-
-- **`create_envs(n_envs, use_async_envs)`** — Returns `{suite: {task_id: VectorEnv}}`. The base class default uses `gym.make()` for single-task envs. Multi-task benchmarks override this.
-- **`get_env_processors()`** — Returns `(preprocessor, postprocessor)`. The base class default returns identity (no-op) pipelines. Override if your benchmark needs observation/action transforms.
+Register a config dataclass so users can select your benchmark with `--env.type=<name>`:
 
 ```python
 @EnvConfig.register_subclass("<benchmark_name>")
@@ -206,20 +204,6 @@ class MyBenchmarkEnvConfig(EnvConfig):
     @property
     def gym_kwargs(self) -> dict:
         return {"obs_type": self.obs_type, "render_mode": self.render_mode}
-
-    def create_envs(self, n_envs: int, use_async_envs: bool = False):
-        """Override for multi-task benchmarks or custom env creation."""
-        from lerobot.envs.<benchmark> import create_<benchmark>_envs
-        return create_<benchmark>_envs(task=self.task, n_envs=n_envs, ...)
-
-    def get_env_processors(self):
-        """Override if your benchmark needs observation/action transforms."""
-        from lerobot.processor.pipeline import PolicyProcessorPipeline
-        from lerobot.processor.env_processor import MyBenchmarkProcessorStep
-        return (
-            PolicyProcessorPipeline(steps=[MyBenchmarkProcessorStep()]),
-            PolicyProcessorPipeline(steps=[]),
-        )
 ```
 
 Key points:
@@ -227,11 +211,36 @@ Key points:
 - The `register_subclass` name is what users pass on the CLI (`--env.type=<name>`).
 - `features` tells the policy what the environment produces.
 - `features_map` maps raw observation keys to LeRobot convention keys.
-- **No changes to `factory.py` needed** — the factory delegates to `cfg.create_envs()` and `cfg.get_env_processors()` automatically.
 
-### 3. Env processor (optional — `src/lerobot/processor/env_processor.py`)
+### 3. The factory dispatch (`src/lerobot/envs/factory.py`)
 
-Only needed if your benchmark requires observation transforms beyond what `preprocess_observation()` handles (e.g. image flipping, coordinate conversion). Define the processor step here and return it from `get_env_processors()` in your config (see step 2):
+Add a branch in `make_env()` to call your factory function:
+
+```python
+elif "<benchmark_name>" in cfg.type:
+    from lerobot.envs.<benchmark> import create_<benchmark>_envs
+
+    if cfg.task is None:
+        raise ValueError("<BenchmarkName> requires a task to be specified")
+
+    return create_<benchmark>_envs(
+        task=cfg.task,
+        n_envs=n_envs,
+        gym_kwargs=cfg.gym_kwargs,
+        env_cls=env_cls,
+    )
+```
+
+If your benchmark needs an env processor, add it in `make_env_pre_post_processors()`:
+
+```python
+if isinstance(env_cfg, MyBenchmarkEnvConfig) or "<benchmark_name>" in env_cfg.type:
+    preprocessor_steps.append(MyBenchmarkProcessorStep())
+```
+
+### 4. Env processor (optional — `src/lerobot/processor/env_processor.py`)
+
+Only needed if your benchmark requires observation transforms beyond what `preprocess_observation()` handles (e.g. image flipping, coordinate conversion):
 
 ```python
 @dataclass
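For intuition about the `register_subclass` mechanism that the "Key points" above rely on (provided in LeRobot by `draccus.ChoiceRegistry`), it can be approximated in a few lines of plain Python. This is an illustrative sketch, not the real implementation:

```python
class EnvConfig:
    """Illustrative base class with a name -> subclass registry."""

    _registry: dict[str, type["EnvConfig"]] = {}

    @classmethod
    def register_subclass(cls, name: str):
        """Decorator that records a subclass under a CLI name."""
        def decorator(subclass: type["EnvConfig"]) -> type["EnvConfig"]:
            cls._registry[name] = subclass
            return subclass
        return decorator

    @classmethod
    def from_type(cls, name: str) -> "EnvConfig":
        # Roughly what resolving `--env.type=<name>` boils down to.
        return cls._registry[name]()


@EnvConfig.register_subclass("mybenchmark")
class MyBenchmarkEnvConfig(EnvConfig):
    task: str = "default_task"
```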
@@ -251,7 +260,7 @@ class MyBenchmarkProcessorStep(ObservationProcessorStep):
 
 See `LiberoProcessorStep` for a full example (image rotation, quaternion-to-axis-angle conversion).
 
-### 4. Dependencies (`pyproject.toml`)
+### 5. Dependencies (`pyproject.toml`)
 
 Add a new optional-dependency group:
 
@@ -272,11 +281,11 @@ Users install with:
 pip install -e ".[mybenchmark]"
 ```
 
-### 5. Documentation (`docs/source/<benchmark>.mdx`)
+### 6. Documentation (`docs/source/<benchmark>.mdx`)
 
 Write a user-facing page following the template in the next section. See `docs/source/libero.mdx` and `docs/source/metaworld.mdx` for full examples.
 
-### 6. Table of contents (`docs/source/_toctree.yml`)
+### 7. Table of contents (`docs/source/_toctree.yml`)
 
 Add your benchmark to the "Benchmarks" section:
 
@@ -90,17 +90,11 @@ The same policy can work with different environment processors, and the same env
 
 ```python
 # Use SmolVLA policy with LIBERO environment
-libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(
-    env_cfg=libero_cfg,
-    policy_cfg=smolvla_cfg,
-)
+libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(libero_cfg)
 smolvla_preprocessor, smolvla_postprocessor = make_pre_post_processors(smolvla_cfg)
 
 # Or use ACT policy with the same LIBERO environment
-libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(
-    env_cfg=libero_cfg,
-    policy_cfg=act_cfg,
-)
+libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(libero_cfg)
 act_preprocessor, act_postprocessor = make_pre_post_processors(act_cfg)
 ```
 
@@ -157,7 +151,7 @@ observation = {
 
 ### Factory Function
 
-The `make_env_pre_post_processors` function delegates to `env_cfg.get_env_processors()`:
+The `make_env_pre_post_processors` function follows the same pattern as `make_pre_post_processors` for policies:
 
 ```python
 from lerobot.envs.factory import make_env_pre_post_processors
@@ -165,30 +159,46 @@ from lerobot.envs.configs import LiberoEnv, PushtEnv
 
 # For LIBERO: Returns LiberoProcessorStep in preprocessor
 libero_cfg = LiberoEnv(task="libero_spatial", camera_name=["agentview"])
-env_preprocessor, env_postprocessor = make_env_pre_post_processors(libero_cfg, policy_cfg)
+env_preprocessor, env_postprocessor = make_env_pre_post_processors(libero_cfg)
 
 # For other environments: Returns identity processors (no-op)
 pusht_cfg = PushtEnv()
-env_preprocessor, env_postprocessor = make_env_pre_post_processors(pusht_cfg, policy_cfg)
+env_preprocessor, env_postprocessor = make_env_pre_post_processors(pusht_cfg)
 ```
 
-### How It Works
+### Implementation in `envs/factory.py`
 
-Each `EnvConfig` subclass can override `get_env_processors()` to return benchmark-specific
-processor pipelines. The base class returns identity (no-op) processors by default.
 
 ```python
-# In your EnvConfig subclass:
-def get_env_processors(self):
-    from lerobot.processor.pipeline import PolicyProcessorPipeline
-    return (
-        PolicyProcessorPipeline(steps=[MyProcessorStep()]),
-        PolicyProcessorPipeline(steps=[]),
-    )
-```
-
-The factory function `make_env_pre_post_processors` simply delegates to this method,
-with a special case for `XVLAConfig` policies which override the env processors entirely.
+def make_env_pre_post_processors(
+    env_cfg: EnvConfig,
+) -> tuple[
+    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+]:
+    """
+    Create preprocessor and postprocessor pipelines for environment observations.
+
+    Args:
+        env_cfg: The configuration of the environment.
+
+    Returns:
+        A tuple containing:
+        - preprocessor: Pipeline that processes environment observations
+        - postprocessor: Pipeline that processes environment outputs
+    """
+    # For LIBERO environments, add the LiberoProcessorStep to preprocessor
+    if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
+        preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
+    else:
+        # For all other environments, return an identity preprocessor
+        preprocessor = PolicyProcessorPipeline(steps=[])
+
+    # Postprocessor is currently identity for all environments
+    # Future: Could add environment-specific action transformations
+    postprocessor = PolicyProcessorPipeline(steps=[])
+
+    return preprocessor, postprocessor
+```
 
 ### Integration in Evaluation
 
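As a rough mental model of what a processor pipeline does here, the toy class below applies its steps in order and is the identity when the step list is empty. It is invented for illustration; the real `PolicyProcessorPipeline` in `lerobot.processor.pipeline` is considerably richer.

```python
from typing import Any, Callable

Observation = dict[str, Any]


class Pipeline:
    """Toy processor pipeline: applies steps in order; empty = identity."""

    def __init__(self, steps: list[Callable[[Observation], Observation]]) -> None:
        self.steps = steps

    def __call__(self, obs: Observation) -> Observation:
        for step in self.steps:
            obs = step(obs)
        return obs


def flip_image_step(obs: Observation) -> Observation:
    # Example transform: reverse the rows of a (toy) image,
    # leaving other observation keys untouched.
    out = dict(obs)
    out["pixels"] = obs["pixels"][::-1]
    return out


preprocessor = Pipeline(steps=[flip_image_step])
identity = Pipeline(steps=[])
```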
@@ -209,10 +219,7 @@ def eval_main(cfg: EvalPipelineConfig):
     )
 
     # Create environment processors (NEW!)
-    env_preprocessor, env_postprocessor = make_env_pre_post_processors(
-        env_cfg=cfg.env,
-        policy_cfg=cfg.policy,
-    )
+    env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
 
     # Run evaluation with both processor types
     eval_policy_all(
@@ -316,22 +323,21 @@ class MyEnvProcessorStep(ObservationProcessorStep):
         return processed
 ```
 
-### 2. Update Your `EnvConfig` Subclass
+### 2. Update the Factory
 
 ```python
-# In src/lerobot/envs/configs.py
-@EnvConfig.register_subclass("myenv")
-@dataclass
-class MyEnvConfig(EnvConfig):
-    # ... task/features/gym kwargs ...
-
-    def get_env_processors(self):
-        from lerobot.processor.pipeline import PolicyProcessorPipeline
-        return (
-            PolicyProcessorPipeline(steps=[MyEnvProcessorStep()]),
-            PolicyProcessorPipeline(steps=[]),
-        )
+# In src/lerobot/envs/factory.py
+def make_env_pre_post_processors(env_cfg: EnvConfig):
+    if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
+        preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
+    elif isinstance(env_cfg, MyEnvConfig) or "myenv" in env_cfg.type:
+        preprocessor = PolicyProcessorPipeline(steps=[MyEnvProcessorStep()])
+    else:
+        preprocessor = PolicyProcessorPipeline(steps=[])
+
+    postprocessor = PolicyProcessorPipeline(steps=[])
+    return preprocessor, postprocessor
 ```
 
 ### 3. Use in Evaluation
 
@@ -72,6 +72,8 @@ class DatasetReader:
         self.episodes = episodes
         self._tolerance_s = tolerance_s
         self._video_backend = video_backend
+        if image_transforms is not None and not callable(image_transforms):
+            raise TypeError("image_transforms must be callable or None.")
         self._image_transforms = image_transforms
 
         self.hf_dataset: datasets.Dataset | None = None
@@ -83,6 +85,16 @@ class DatasetReader:
         check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
         self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
 
+    def set_image_transforms(self, image_transforms: Callable | None) -> None:
+        """Replace the transform applied to visual observations."""
+        if image_transforms is not None and not callable(image_transforms):
+            raise TypeError("image_transforms must be callable or None.")
+        self._image_transforms = image_transforms
+
+    def clear_image_transforms(self) -> None:
+        """Remove the transform applied to visual observations."""
+        self._image_transforms = None
+
     def try_load(self) -> bool:
         """Attempt to load from local cache. Returns True if data is sufficient."""
         try:
@@ -194,8 +194,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
         super().__init__()
         self.repo_id = repo_id
         self._requested_root = Path(root) if root else None
-        self.reader = None
-        self.set_image_transforms(image_transforms)
         self.delta_timestamps = delta_timestamps
         self.episodes = episodes
         self.tolerance_s = tolerance_s
@@ -225,6 +223,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             delta_timestamps=delta_timestamps,
             image_transforms=image_transforms,
         )
+        self.image_transforms = image_transforms
 
         # Load actual data
         if force_cache_sync or not self.reader.try_load():
@@ -480,15 +479,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
     def set_image_transforms(self, image_transforms: Callable | None) -> None:
         """Replace the transform applied to visual observations."""
-        if image_transforms is not None and not callable(image_transforms):
-            raise TypeError("image_transforms must be callable or None.")
+        self._ensure_reader().set_image_transforms(image_transforms)
         self.image_transforms = image_transforms
-        if self.reader is not None:
-            self.reader._image_transforms = image_transforms
 
     def clear_image_transforms(self) -> None:
         """Remove the transform applied to visual observations."""
-        self.set_image_transforms(None)
+        if self.reader is not None:
+            self.reader.set_image_transforms(None)
+        self.image_transforms = None
 
     # ── Hub methods (stay on facade) ──────────────────────────────────
 
@@ -12,16 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from __future__ import annotations
-
 import abc
-import importlib
 from dataclasses import dataclass, field, fields
 from typing import Any
 
 import draccus
-import gymnasium as gym
-from gymnasium.envs.registration import registry as gym_registry
 
 from lerobot.configs.types import FeatureType, PolicyFeature
 from lerobot.robots import RobotConfig
@@ -72,49 +67,6 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
     def gym_kwargs(self) -> dict:
         raise NotImplementedError()
 
-    def create_envs(
-        self,
-        n_envs: int,
-        use_async_envs: bool = False,
-    ) -> dict[str, dict[int, gym.vector.VectorEnv]]:
-        """Create {suite: {task_id: VectorEnv}}.
-
-        Default: single-task env via gym.make(). Multi-task benchmarks override.
-        """
-        env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
-
-        if self.gym_id not in gym_registry:
-            print(f"gym id '{self.gym_id}' not found, attempting to import '{self.package_name}'...")
-            try:
-                importlib.import_module(self.package_name)
-            except ModuleNotFoundError as e:
-                raise ModuleNotFoundError(
-                    f"Package '{self.package_name}' required for env '{self.type}' not found. "
-                    f"Please install it or check PYTHONPATH."
-                ) from e
-
-        if self.gym_id not in gym_registry:
-            raise gym.error.NameNotFound(
-                f"Environment '{self.gym_id}' not registered even after importing '{self.package_name}'."
-            )
-
-        def _make_one():
-            return gym.make(self.gym_id, disable_env_checker=self.disable_env_checker, **self.gym_kwargs)
-
-        try:
-            from gymnasium.vector import AutoresetMode
-
-            vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=AutoresetMode.SAME_STEP)
-        except ImportError:
-            vec = env_cls([_make_one for _ in range(n_envs)])
-        return {self.type: {0: vec}}
-
-    def get_env_processors(self):
-        """Return (preprocessor, postprocessor) for this env. Default: identity."""
-        from lerobot.processor.pipeline import PolicyProcessorPipeline
-
-        return PolicyProcessorPipeline(steps=[]), PolicyProcessorPipeline(steps=[])
-
 
 @dataclass
 class HubEnvConfig(EnvConfig):
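The removed base `create_envs` lazily imports the env's package when its gym id is missing from the registry, then re-checks before giving up. That fallback can be isolated as a sketch (the function name and the plain-dict stand-in for the gym registry are illustrative, not lerobot API):

```python
import importlib


def ensure_registered(gym_id: str, package_name: str, registry: dict) -> None:
    """If gym_id is absent, import package_name (whose import side effect is
    expected to register the env), then re-check the registry (sketch)."""
    if gym_id in registry:
        return
    try:
        importlib.import_module(package_name)
    except ModuleNotFoundError as e:
        raise ModuleNotFoundError(
            f"Package '{package_name}' required for '{gym_id}' not found."
        ) from e
    if gym_id not in registry:
        raise LookupError(f"'{gym_id}' not registered even after importing '{package_name}'.")
```

The two-step check matters: a successful import is not proof of registration, so the id is looked up again after the import side effects have run.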
@@ -386,12 +338,6 @@ class LiberoEnv(EnvConfig):
         else:
             raise ValueError(f"Unsupported obs_type: {self.obs_type}")
 
-        if self.camera_name_mapping is not None:
-            mapped_agentview = self.camera_name_mapping.get("agentview_image", "image")
-            mapped_eye_in_hand = self.camera_name_mapping.get("robot0_eye_in_hand_image", "image2")
-            self.features_map[LIBERO_KEY_PIXELS_AGENTVIEW] = f"{OBS_IMAGES}.{mapped_agentview}"
-            self.features_map[LIBERO_KEY_PIXELS_EYE_IN_HAND] = f"{OBS_IMAGES}.{mapped_eye_in_hand}"
-
     @property
     def gym_kwargs(self) -> dict:
         kwargs: dict[str, Any] = {"obs_type": self.obs_type, "render_mode": self.render_mode}
@@ -399,33 +345,6 @@ class LiberoEnv(EnvConfig):
             kwargs["task_ids"] = self.task_ids
         return kwargs
 
-    def create_envs(self, n_envs: int, use_async_envs: bool = False):
-        from lerobot.envs.libero import create_libero_envs
-
-        if self.task is None:
-            raise ValueError("LiberoEnv requires a task to be specified")
-        env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
-        return create_libero_envs(
-            task=self.task,
-            n_envs=n_envs,
-            camera_name=self.camera_name,
-            init_states=self.init_states,
-            gym_kwargs=self.gym_kwargs,
-            env_cls=env_cls,
-            control_mode=self.control_mode,
-            episode_length=self.episode_length,
-            camera_name_mapping=self.camera_name_mapping,
-        )
-
-    def get_env_processors(self):
-        from lerobot.processor.env_processor import LiberoProcessorStep
-        from lerobot.processor.pipeline import PolicyProcessorPipeline
-
-        return (
-            PolicyProcessorPipeline(steps=[LiberoProcessorStep()]),
-            PolicyProcessorPipeline(steps=[]),
-        )
-
 
 @EnvConfig.register_subclass("metaworld")
 @dataclass
@@ -468,19 +387,6 @@ class MetaworldEnv(EnvConfig):
             "render_mode": self.render_mode,
         }
 
-    def create_envs(self, n_envs: int, use_async_envs: bool = False):
-        from lerobot.envs.metaworld import create_metaworld_envs
-
-        if self.task is None:
-            raise ValueError("MetaWorld requires a task to be specified")
-        env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
-        return create_metaworld_envs(
-            task=self.task,
-            n_envs=n_envs,
-            gym_kwargs=self.gym_kwargs,
-            env_cls=env_cls,
-        )
-
 
 @EnvConfig.register_subclass("isaaclab_arena")
 @dataclass
@@ -548,18 +454,3 @@ class IsaaclabArenaEnv(HubEnvConfig):
     @property
     def gym_kwargs(self) -> dict:
         return {}
-
-    def get_env_processors(self):
-        from lerobot.processor.env_processor import IsaaclabArenaProcessorStep
-        from lerobot.processor.pipeline import PolicyProcessorPipeline
-
-        state_keys = tuple(k.strip() for k in (self.state_keys or "").split(",") if k.strip())
-        camera_keys = tuple(k.strip() for k in (self.camera_keys or "").split(",") if k.strip())
-        if not state_keys and not camera_keys:
-            raise ValueError("At least one of state_keys or camera_keys must be specified.")
-        return (
-            PolicyProcessorPipeline(
-                steps=[IsaaclabArenaProcessorStep(state_keys=state_keys, camera_keys=camera_keys)]
-            ),
-            PolicyProcessorPipeline(steps=[]),
-        )
+117 -20
@@ -13,46 +13,90 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from __future__ import annotations
+import importlib
 
 from typing import Any
 
 import gymnasium as gym
+from gymnasium.envs.registration import registry as gym_registry
 
-from lerobot.envs.configs import EnvConfig, HubEnvConfig
+from lerobot.configs.policies import PreTrainedConfig
+from lerobot.envs.configs import AlohaEnv, EnvConfig, HubEnvConfig, IsaaclabArenaEnv, LiberoEnv, PushtEnv
 from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
+from lerobot.policies.xvla.configuration_xvla import XVLAConfig
+from lerobot.processor import ProcessorStep
+from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep
+from lerobot.processor.pipeline import PolicyProcessorPipeline
 
 
 def make_env_config(env_type: str, **kwargs) -> EnvConfig:
-    try:
-        cls = EnvConfig.get_choice_class(env_type)
-    except KeyError as err:
-        raise ValueError(
-            f"Environment type '{env_type}' is not registered. "
-            f"Available: {list(EnvConfig.get_known_choices().keys())}"
-        ) from err
-    return cls(**kwargs)
+    if env_type == "aloha":
+        return AlohaEnv(**kwargs)
+    elif env_type == "pusht":
+        return PushtEnv(**kwargs)
+    elif env_type == "libero":
+        return LiberoEnv(**kwargs)
+    else:
+        raise ValueError(f"Policy type '{env_type}' is not available.")
 
 
 def make_env_pre_post_processors(
     env_cfg: EnvConfig,
-    policy_cfg: Any,
-) -> tuple[Any, Any]:
+    policy_cfg: PreTrainedConfig,
+) -> tuple[
+    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+]:
     """
     Create preprocessor and postprocessor pipelines for environment observations.
 
-    Returns a tuple of (preprocessor, postprocessor). By default, delegates to
-    ``env_cfg.get_env_processors()``. The XVLAConfig policy-specific override
-    stays here because it depends on the *policy* config, not the env config.
-    """
-    from lerobot.policies.xvla.configuration_xvla import XVLAConfig
-
+    This function creates processor pipelines that transform raw environment
+    observations and actions. By default, it returns identity processors that do nothing.
+    For specific environments like LIBERO, it adds environment-specific processing steps.
+
+    Args:
+        env_cfg: The configuration of the environment.
+
+    Returns:
+        A tuple containing:
+            - preprocessor: Pipeline that processes environment observations
+            - postprocessor: Pipeline that processes environment outputs (currently identity)
+    """
+    # Preprocessor and Postprocessor steps are Identity for most environments
+    preprocessor_steps: list[ProcessorStep] = []
+    postprocessor_steps: list[ProcessorStep] = []
     if isinstance(policy_cfg, XVLAConfig):
         from lerobot.policies.xvla.processor_xvla import make_xvla_libero_pre_post_processors
 
         return make_xvla_libero_pre_post_processors()
 
-    return env_cfg.get_env_processors()
+    # For LIBERO environments, add the LiberoProcessorStep to preprocessor
+    if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
+        preprocessor_steps.append(LiberoProcessorStep())
+
+    # For Isaaclab Arena environments, add the IsaaclabArenaProcessorStep
+    if isinstance(env_cfg, IsaaclabArenaEnv) or "isaaclab_arena" in env_cfg.type:
+        # Parse comma-separated keys (handle None for state-based policies)
+        if env_cfg.state_keys:
+            state_keys = tuple(k.strip() for k in env_cfg.state_keys.split(",") if k.strip())
+        else:
+            state_keys = ()
+        if env_cfg.camera_keys:
+            camera_keys = tuple(k.strip() for k in env_cfg.camera_keys.split(",") if k.strip())
+        else:
+            camera_keys = ()
+        if not state_keys and not camera_keys:
+            raise ValueError("At least one of state_keys or camera_keys must be specified.")
+        preprocessor_steps.append(
+            IsaaclabArenaProcessorStep(
+                state_keys=state_keys,
+                camera_keys=camera_keys,
+            )
+        )
+
+    preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps)
+    postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps)
+
+    return preprocessor, postprocessor
 
 
 def make_env(
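For contrast, the registry lookup that `make_env_config` drops in this commit and the if/elif chain that replaces it can be put side by side in a minimal sketch (the stand-in config classes and the `_REGISTRY` dict are hypothetical, not lerobot's draccus-based `ChoiceRegistry`):

```python
from dataclasses import dataclass


@dataclass
class EnvConfig:
    """Minimal stand-in for lerobot's EnvConfig (illustration only)."""
    task: str = ""


@dataclass
class AlohaEnv(EnvConfig):
    pass


@dataclass
class PushtEnv(EnvConfig):
    pass


# Registry-style dispatch (the pattern this commit removes): one lookup table;
# new env types register themselves without touching the factory.
_REGISTRY = {"aloha": AlohaEnv, "pusht": PushtEnv}


def make_env_config_registry(env_type: str, **kwargs) -> EnvConfig:
    try:
        cls = _REGISTRY[env_type]
    except KeyError as err:
        raise ValueError(f"Environment type '{env_type}' is not registered.") from err
    return cls(**kwargs)


# if/elif dispatch (the pattern this commit restores): explicit control flow,
# but the factory must be edited for every new env type.
def make_env_config_ifelif(env_type: str, **kwargs) -> EnvConfig:
    if env_type == "aloha":
        return AlohaEnv(**kwargs)
    elif env_type == "pusht":
        return PushtEnv(**kwargs)
    else:
        raise ValueError(f"Environment type '{env_type}' is not available.")
```

The trade-off is open extension (registry) versus a single readable dispatch site (if/elif); both raise `ValueError` for unknown types.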
@@ -119,4 +163,57 @@ def make_env(
     if n_envs < 1:
         raise ValueError("`n_envs` must be at least 1")
 
-    return cfg.create_envs(n_envs=n_envs, use_async_envs=use_async_envs)
+    env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
+
+    if "libero" in cfg.type:
+        from lerobot.envs.libero import create_libero_envs
+
+        if cfg.task is None:
+            raise ValueError("LiberoEnv requires a task to be specified")
+
+        return create_libero_envs(
+            task=cfg.task,
+            n_envs=n_envs,
+            camera_name=cfg.camera_name,
+            init_states=cfg.init_states,
+            gym_kwargs=cfg.gym_kwargs,
+            env_cls=env_cls,
+            control_mode=cfg.control_mode,
+            episode_length=cfg.episode_length,
+        )
+    elif "metaworld" in cfg.type:
+        from lerobot.envs.metaworld import create_metaworld_envs
+
+        if cfg.task is None:
+            raise ValueError("MetaWorld requires a task to be specified")
+
+        return create_metaworld_envs(
+            task=cfg.task,
+            n_envs=n_envs,
+            gym_kwargs=cfg.gym_kwargs,
+            env_cls=env_cls,
+        )
+
+    if cfg.gym_id not in gym_registry:
+        print(f"gym id '{cfg.gym_id}' not found, attempting to import '{cfg.package_name}'...")
+        try:
+            importlib.import_module(cfg.package_name)
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError(
+                f"Package '{cfg.package_name}' required for env '{cfg.type}' not found. "
+                f"Please install it or check PYTHONPATH."
+            ) from e
+
+    if cfg.gym_id not in gym_registry:
+        raise gym.error.NameNotFound(
+            f"Environment '{cfg.gym_id}' not registered even after importing '{cfg.package_name}'."
+        )
+
+    def _make_one():
+        return gym.make(cfg.gym_id, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
+
+    vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)
+
+    # normalize to {suite: {task_id: vec_env}} for consistency
+    suite_name = cfg.type  # e.g., "pusht", "aloha"
+    return {suite_name: {0: vec}}
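Every branch of `make_env` returns the same nested `{suite: {task_id: vec_env}}` mapping, so callers can iterate uniformly over single-task and multi-task benchmarks. A dependency-free sketch of that shape (`FakeVecEnv` and the helper name are illustrative):

```python
class FakeVecEnv:
    """Stand-in for gym.vector.VectorEnv (illustration only)."""

    def __init__(self, n_envs: int):
        self.num_envs = n_envs


def normalize_env_result(suite_name: str, vec_env: FakeVecEnv) -> dict[str, dict[int, FakeVecEnv]]:
    # Single-task envs get task_id 0, so evaluation code can always do:
    #   for suite, tasks in envs.items():
    #       for task_id, vec in tasks.items(): ...
    return {suite_name: {0: vec_env}}
```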

@@ -223,8 +223,7 @@ class LiberoEnv(gym.Env):
 
     def render(self):
        raw_obs = self._env.env._get_observations()
-        pixels = self._format_raw_obs(raw_obs)["pixels"]
-        image = next(iter(pixels.values()))
+        image = self._format_raw_obs(raw_obs)["pixels"]["image"]
         image = image[::-1, ::-1]  # flip both H and W for visualization
         return image
 
@@ -340,6 +339,12 @@ class LiberoEnv(gym.Env):
             )
         observation = self._format_raw_obs(raw_obs)
         if terminated:
+            info["final_info"] = {
+                "task": self.task,
+                "task_id": self.task_id,
+                "done": bool(done),
+                "is_success": bool(is_success),
+            }
             self.reset()
         truncated = False
         return observation, reward, terminated, truncated, info
@@ -359,7 +364,6 @@ def _make_env_fns(
     init_states: bool,
     gym_kwargs: Mapping[str, Any],
     control_mode: str,
-    camera_name_mapping: dict[str, str] | None = None,
 ) -> list[Callable[[], LiberoEnv]]:
     """Build n_envs factory callables for a single (suite, task_id)."""
 
@@ -375,7 +379,6 @@ def _make_env_fns(
             episode_index=episode_index,
             n_envs=n_envs,
             control_mode=control_mode,
-            camera_name_mapping=camera_name_mapping,
             **local_kwargs,
         )
 
@@ -397,7 +400,6 @@ def create_libero_envs(
     env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
     control_mode: str = "relative",
     episode_length: int | None = None,
-    camera_name_mapping: dict[str, str] | None = None,
 ) -> dict[str, dict[int, Any]]:
     """
     Create vectorized LIBERO environments with a consistent return shape.
@@ -447,7 +449,6 @@ def create_libero_envs(
             init_states=init_states,
             gym_kwargs=gym_kwargs,
             control_mode=control_mode,
-            camera_name_mapping=camera_name_mapping,
         )
         out[suite_name][tid] = env_cls(fns)
         print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")

@@ -201,11 +201,6 @@ def rollout(
                 "You're likely using an older version of gymnasium (< 1.0). Please upgrade."
             )
         successes = final_info["is_success"].tolist()
-    elif "is_success" in info:
-        is_success = info["is_success"]
-        successes = (
-            is_success.tolist() if hasattr(is_success, "tolist") else [bool(is_success)] * env.num_envs
-        )
     else:
         successes = [False] * env.num_envs
 
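The deleted `elif` branch covered envs that report `is_success` at the top level of `info` rather than under `final_info`. Isolated as a sketch (function name hypothetical; it works on plain lists and booleans as well as arrays exposing `.tolist()`):

```python
def extract_successes(info: dict, num_envs: int) -> list:
    """Per-env success flags from a vectorized step's info dict (sketch)."""
    final_info = info.get("final_info")
    if final_info is not None and "is_success" in final_info:
        flags = final_info["is_success"]
        # numpy arrays expose .tolist(); plain sequences pass through unchanged
        return flags.tolist() if hasattr(flags, "tolist") else list(flags)
    elif "is_success" in info:
        # top-level flag: may be a scalar, in which case it is broadcast to all envs
        is_success = info["is_success"]
        return (
            is_success.tolist() if hasattr(is_success, "tolist") else [bool(is_success)] * num_envs
        )
    return [False] * num_envs
```

Dropping the branch means envs that only set a top-level `is_success` now fall through to the all-`False` default.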

@@ -1,143 +0,0 @@
-"""Tests for the benchmark dispatch refactor (create_envs / get_env_processors on EnvConfig)."""
-
-from __future__ import annotations
-
-import logging
-from dataclasses import dataclass, field
-
-import gymnasium as gym
-import pytest
-from gymnasium.envs.registration import register, registry as gym_registry
-
-from lerobot.configs.types import PolicyFeature
-from lerobot.envs.configs import EnvConfig
-from lerobot.envs.factory import make_env, make_env_config, make_env_pre_post_processors
-
-logger = logging.getLogger(__name__)
-
-
-def test_registry_all_types():
-    """make_env_config should resolve every registered EnvConfig subclass via the registry."""
-    known = list(EnvConfig.get_known_choices().keys())
-    assert len(known) >= 6
-    for t in known:
-        cfg = make_env_config(t)
-        if not isinstance(cfg, EnvConfig):
-            continue
-        assert cfg.type == t
-
-
-def test_unknown_type():
-    with pytest.raises(ValueError, match="not registered"):
-        make_env_config("nonexistent")
-
-
-def test_identity_processors():
-    """Base class get_env_processors() returns identity pipelines."""
-    cfg = make_env_config("aloha")
-    pre, post = cfg.get_env_processors()
-    assert len(pre.steps) == 0 and len(post.steps) == 0
-
-
-def test_delegation():
-    """make_env() should call cfg.create_envs(), not use if/elif dispatch."""
-    sentinel = {"delegated": {0: "marker"}}
-    fake = type(
-        "Fake",
-        (),
-        {
-            "hub_path": None,
-            "create_envs": lambda self, n_envs, use_async_envs=False: sentinel,
-        },
-    )()
-    result = make_env(fake, n_envs=1)
-    assert result is sentinel
-
-
-def test_processors_delegation():
-    """make_env_pre_post_processors delegates to cfg.get_env_processors()."""
-    cfg = make_env_config("aloha")
-    pre, post = make_env_pre_post_processors(cfg, policy_cfg=None)
-    assert len(pre.steps) == 0
-
-
-def test_base_create_envs():
-    """Base class create_envs() should build a single-task VectorEnv via gym.make()."""
-    gym_id = "_dispatch_test/CartPole-v99"
-    if gym_id not in gym_registry:
-        register(id=gym_id, entry_point="gymnasium.envs.classic_control:CartPoleEnv")
-
-    @EnvConfig.register_subclass("_dispatch_base_test")
-    @dataclass
-    class _Env(EnvConfig):
-        task: str = "CartPole-v99"
-        fps: int = 10
-        features: dict[str, PolicyFeature] = field(default_factory=dict)
-
-        @property
-        def package_name(self):
-            return "_dispatch_test"
-
-        @property
-        def gym_id(self):
-            return gym_id
-
-        @property
-        def gym_kwargs(self):
-            return {}
-
-    try:
-        envs = _Env().create_envs(n_envs=2)
-        assert "_dispatch_base_test" in envs
-        env = envs["_dispatch_base_test"][0]
-        assert isinstance(env, gym.vector.SyncVectorEnv)
-        assert env.num_envs == 2
-        env.close()
-    finally:
-        if gym_id in gym_registry:
-            del gym_registry[gym_id]
-
-
-def test_custom_create_envs_override():
-    """A custom EnvConfig subclass can override create_envs()."""
-    mock_vec = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")])
-
-    @EnvConfig.register_subclass("_dispatch_custom_test")
-    @dataclass
-    class _Env(EnvConfig):
-        task: str = "x"
-        features: dict[str, PolicyFeature] = field(default_factory=dict)
-
-        @property
-        def gym_kwargs(self):
-            return {}
-
-        def create_envs(self, n_envs, use_async_envs=False):
-            return {"custom_suite": {0: mock_vec}}
-
-    try:
-        result = make_env(_Env(), n_envs=1)
-        assert "custom_suite" in result
-    finally:
-        mock_vec.close()
-
-
-def test_custom_get_env_processors_override():
-    """A custom EnvConfig subclass can override get_env_processors()."""
-    from lerobot.processor.pipeline import DataProcessorPipeline
-
-    @EnvConfig.register_subclass("_dispatch_proc_test")
-    @dataclass
-    class _Env(EnvConfig):
-        task: str = "x"
-        features: dict[str, PolicyFeature] = field(default_factory=dict)
-
-        @property
-        def gym_kwargs(self):
-            return {}
-
-        def get_env_processors(self):
-            return DataProcessorPipeline(steps=[]), DataProcessorPipeline(steps=[])
-
-    pre, post = _Env().get_env_processors()
-    assert isinstance(pre, DataProcessorPipeline)