Compare commits


1 Commit

Author | SHA1 | Message | Date
Steven Palma | 2e2c27da34 | fix(ci): use GITHUB_TOKEN for automated PR | 2026-04-06 21:09:21 +02:00
25 changed files with 255 additions and 679 deletions
-86
@@ -1,86 +0,0 @@
# LeRobot — Claude Code Instructions
You are a senior robotics ML engineer reviewing code for **LeRobot**, a PyTorch framework for real-world robot learning.
Apply these principles to every PR review, fix, or task.
---
## Core Abstractions
These are the load-bearing types. Handle them with care — breaking changes here affect every user.
| Type | Location | Role |
| ---------------- | ---------------------------- | ------------------------------------------------------------ |
| `LeRobotDataset` | `src/lerobot/datasets/` | Streaming replay buffer; HF Hub integration |
| `Policy` | `src/lerobot/policies/` | Base class for all learning agents (ACT, Diffusion, SARM, …) |
| `Robot` | `src/lerobot/robots/` | Hardware abstraction; carries `_output_pipeline` |
| `Teleoperator` | `src/lerobot/teleoperators/` | Leader-side hardware abstraction; carries `_output_pipeline` |
| `Env` | `src/lerobot/envs/` | Gym-like robotics environments |
| `Processor` | `src/lerobot/processor/` | Data transformation pipelines attached to robots/teleops |
**Never break their public APIs without a migration note and explicit user approval.**
---
## Engineering Principles
### Code quality
- Explicit over magic — no hidden control flow, no implicit state.
- No deep inheritance trees. Prefer composition.
- No decorative comment separators (`===`, `---`, etc.).
- Add comments only where the logic is non-obvious.
- No over-engineering. YAGNI applies strictly.
### Type safety
- All new and modified Python code must be fully typed (PEP 484); a minimal sketch follows this list.
- `mypy --strict` must pass on changed files.
- Do not widen or weaken existing type signatures.
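
A minimal sketch of this typing discipline (the function and tensor names are illustrative, not from the codebase):

```python
import torch


def normalize_state(
    state: torch.Tensor,
    mean: torch.Tensor,
    std: torch.Tensor,
    eps: float = 1e-8,
) -> torch.Tensor:
    """Standardize a state tensor; parameters and return type are all annotated."""
    return (state - mean) / (std + eps)
```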
### Backwards compatibility
- Public API changes require migration notes.
- Additive changes are preferred over modifications.
- `so100_follower` / `so101_follower` are aliases — never bleed changes there unintentionally.
### HF ecosystem
- Use `push_to_hub()`, HF Hub dataset streaming, and `evaluate` scripts (see the sketch after this list).
- Dataset changes must preserve streaming compatibility.
- Prefer reusing HF primitives over rolling custom solutions.
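
For illustration, a hedged sketch of the Hub round-trip (the repo id is a placeholder; check the current `push_to_hub()` signature before relying on it):

```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset

# Load a Hub-backed dataset; episode data streams from the HF Hub
# rather than requiring a handwritten loader.
dataset = LeRobotDataset("lerobot/pusht")

# Re-publish with the built-in Hub integration instead of a custom upload path.
dataset.push_to_hub()
```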
---
## PR Review Checklist
Before approving or marking P1 issues resolved, verify:
- [ ] `pre-commit run -a` would pass (ruff, mypy, typos, zizmor, bandit)
- [ ] All new/modified code is typed and passes `mypy --strict`
- [ ] New features have unit tests; no silent behavioral changes
- [ ] Public APIs of `LeRobotDataset`, `Policy`, `Robot`, `Teleoperator`, `Env` are unchanged (or migration note present)
- [ ] HF Hub streaming still works for dataset changes
- [ ] No unnecessary abstractions introduced
- [ ] No breaking changes to training scripts (`lerobot-train`, `lerobot-eval`, `lerobot-record`)
---
## ML-Specific Checks
Flag these as **P1** if found:
- **Data leakage**: construct train and val/test splits before any normalization or augmentation; statistics for those transforms must come from the train split only.
- **Loss function errors**: verify reduction mode (`mean` vs `sum`), correct masking, correct shape alignment (see the sketch after this list).
- **Gradient flow**: new modules must have gradients flowing (check `requires_grad`, no detached tensors in the loss path by accident).
- **Distributed training**: operations on tensors must be DDP-safe; no in-place ops on parameters; batch norm needs `SyncBatchNorm` if used.
- **Memory leaks**: no accumulation of tensors outside the training loop; `optimizer.zero_grad()` called correctly.
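
To make the loss-function check concrete, a hedged sketch (names are illustrative) of a masked loss whose reduction averages only over valid timesteps:

```python
import torch
import torch.nn.functional as F


def masked_l1_loss(pred: torch.Tensor, target: torch.Tensor, pad_mask: torch.Tensor) -> torch.Tensor:
    """L1 loss over (batch, time, dim) that ignores padded timesteps.

    pad_mask is (batch, time): 1.0 for valid steps, 0.0 for padding.
    A plain reduction="mean" would also average over padded positions,
    silently shrinking the loss for short episodes.
    """
    per_step = F.l1_loss(pred, target, reduction="none").mean(dim=-1)  # (batch, time)
    return (per_step * pad_mask).sum() / pad_mask.sum().clamp(min=1.0)
```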
---
## What to Skip
- Don't flag style nitpicks on unchanged surrounding code.
- Don't propose refactors outside the PR's scope.
- Don't add docstrings or comments to code the PR didn't touch.
- Don't suggest speculative future features (YAGNI).
-49
@@ -1,49 +0,0 @@
name: Claude Code Review

on:
  pull_request:
    types: [opened, synchronize, ready_for_review, reopened]

jobs:
  claude-review:
    runs-on: ubuntu-latest
    permissions:
      contents: read
      pull-requests: write
      issues: read
      id-token: write
      actions: read
    env:
      FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 1
          persist-credentials: false

      - name: Run Claude Code Review
        id: claude-review
        uses: anthropics/claude-code-action@26ddc358fe3befff50c5ec2f80304c90c763f6f8 # v1
        with:
          anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
          use_sticky_comment: true
          prompt: |
            Read `.github/CLAUDE.md` for lerobot-specific conventions, then review this PR.
            Provide structured, actionable feedback.

            Focus areas (in priority order):
            1. **Correctness**: Logic errors, off-by-ones, wrong tensor shapes, incorrect loss functions
            2. **Type safety**: All new/modified Python code must pass `mypy --strict`; check for missing annotations
            3. **Backwards compatibility**: Does this break `LeRobotDataset`, `Policy`, `Robot`, `Teleoperator`, `Env`, or `Processor` public APIs?
            4. **Tests**: New features must have tests; no silent behavioral changes
            5. **Code style**: Explicit over magic, no unnecessary abstractions, no decorative comments
            6. **HF integration**: Dataset streaming, `push_to_hub`, HF Hub compatibility preserved?
            7. **pre-commit**: Would `pre-commit run -a` pass? (ruff, mypy, typos, zizmor)

            Format findings as P1 (must fix) / P2 (should fix) / P3 (nice to have).
            Skip P3 if the PR is already high quality.
          claude_args: '--model claude-opus-4-6'
        # See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
        # or https://code.claude.com/docs/en/cli-reference for available options
-58
@@ -1,58 +0,0 @@
name: Claude Code

on:
  issue_comment:
    types: [created]
  pull_request_review_comment:
    types: [created]
  issues:
    types: [opened, assigned]
  pull_request_review:
    types: [submitted]

jobs:
  claude:
    if: |
      (github.event_name == 'issue_comment' &&
        contains(github.event.comment.body, '@claude') &&
        (github.event.comment.author_association == 'OWNER' || github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'COLLABORATOR')) ||
      (github.event_name == 'pull_request_review_comment' &&
        contains(github.event.comment.body, '@claude') &&
        (github.event.comment.author_association == 'OWNER' || github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'COLLABORATOR')) ||
      (github.event_name == 'pull_request_review' &&
        contains(github.event.review.body, '@claude') &&
        (github.event.review.author_association == 'OWNER' || github.event.review.author_association == 'MEMBER' || github.event.review.author_association == 'COLLABORATOR')) ||
      (github.event_name == 'issues' &&
        (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude')) &&
        (github.event.issue.author_association == 'OWNER' || github.event.issue.author_association == 'MEMBER' || github.event.issue.author_association == 'COLLABORATOR'))
    runs-on: ubuntu-latest
    permissions:
      contents: read
      pull-requests: write
      issues: write
      id-token: write
      actions: read
    env:
      FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 1
          persist-credentials: false

      - name: Run Claude Code
        id: claude
        uses: anthropics/claude-code-action@26ddc358fe3befff50c5ec2f80304c90c763f6f8 # v1
        with:
          anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
          use_sticky_comment: true
          # This is an optional setting that allows Claude to read CI results on PRs
          additional_permissions: |
            actions: read
          claude_args: '--system-prompt "Read .github/CLAUDE.md for lerobot-specific conventions before responding."'
        # See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
        # or https://code.claude.com/docs/en/cli-reference for available options
@@ -33,7 +33,7 @@ jobs:
      github.event.workflow_run.event == 'pull_request' &&
      github.event.workflow_run.conclusion == 'success' &&
      github.repository == 'huggingface/lerobot'
-   uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
+   uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
    with:
      package_name: lerobot
    secrets:
+2 -2
@@ -55,7 +55,7 @@ jobs:
      github.repository == 'huggingface/lerobot'
    permissions:
      contents: read
-   uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
+   uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
    with:
      commit_sha: ${{ github.sha }}
      package: lerobot
@@ -78,7 +78,7 @@ jobs:
    permissions:
      contents: read
      pull-requests: write
-   uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
+   uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
    with:
      commit_sha: ${{ github.event.pull_request.head.sha }}
      pr_number: ${{ github.event.number }}
+2 -2
@@ -65,7 +65,7 @@ jobs:
      HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
      HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
    steps:
-     - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+     - uses: actions/checkout@v6
        with:
          persist-credentials: false
          lfs: true
@@ -83,7 +83,7 @@ jobs:
          libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev
      - name: Setup uv and Python
-       uses: astral-sh/setup-uv@d0cc045d04ccac9d8b7881df0226f9e82c39688e # v6
+       uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
        with:
          enable-cache: true
          version: ${{ env.UV_VERSION }}
+6 -6
@@ -63,7 +63,7 @@ jobs:
      HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
      HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
    steps:
-     - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+     - uses: actions/checkout@v6
        with:
          lfs: true
          persist-credentials: false
@@ -80,7 +80,7 @@ jobs:
          speech-dispatcher libgeos-dev portaudio19-dev
      - name: Setup uv and Python
-       uses: astral-sh/setup-uv@d0cc045d04ccac9d8b7881df0226f9e82c39688e # v6
+       uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
        with:
          enable-cache: true
          version: ${{ env.UV_VERSION }}
@@ -137,21 +137,21 @@ jobs:
          sudo apt-get update
          sudo apt-get install git-lfs
          git lfs install
-     - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+     - uses: actions/checkout@v6
        with:
          lfs: true
          persist-credentials: false
      - name: Set up Docker Buildx
-       uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3
+       uses: docker/setup-buildx-action@v3 # zizmor: ignore[unpinned-uses]
        with:
          cache-binary: false
      - name: Login to Docker Hub
-       uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
+       uses: docker/login-action@v3 # zizmor: ignore[unpinned-uses]
        with:
          username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
          password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
      - name: Build and push Docker image
-       uses: docker/build-push-action@10e90e3645eae34f1e60eeb005ba3a3d33f178e8 # v6
+       uses: docker/build-push-action@v6 # zizmor: ignore[unpinned-uses]
        with:
          context: .
          file: ./docker/Dockerfile.internal
+1 -1
@@ -227,7 +227,7 @@ jobs:
      contents: write
      pull-requests: write
    env:
-     GH_TOKEN: ${{ secrets.UPDATE_LOCK_TOKEN }}
+     GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
    steps:
      - uses: actions/checkout@v6
        with:
+3 -3
@@ -43,16 +43,16 @@ jobs:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
-       uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+       uses: actions/checkout@v6
        with:
          persist-credentials: false
      - name: Set up Python
-       uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6
+       uses: actions/setup-python@v6
        with:
          python-version: '3.12'
      - name: Run pre-commit hooks
-       uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
+       uses: pre-commit/action@v3.0.1 # zizmor: ignore[unpinned-uses]
        with:
          extra_args: --all-files --show-diff-on-failure --color=always
+6 -6
@@ -38,12 +38,12 @@ jobs:
    steps:
      - name: Checkout code
-       uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+       uses: actions/checkout@v6
        with:
          persist-credentials: false
      - name: Set up Python
-       uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6
+       uses: actions/setup-python@v6
        with:
          python-version: '3.12'
@@ -104,7 +104,7 @@ jobs:
      - name: Publish to TestPyPI for pre-releases
        # True for tags like 'v0.2.0-rc1'
        if: startsWith(github.ref, 'refs/tags/v') && contains(github.ref, '-')
-       uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # v1.13.0
+       uses: pypa/gh-action-pypi-publish@v1.13.0 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
        with:
          repository-url: https://test.pypi.org/legacy/
          verbose: true
@@ -112,7 +112,7 @@ jobs:
      - name: Publish to PyPI
        if: startsWith(github.ref, 'refs/tags/v') && !contains(github.ref, '-')
-       uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # v1.13.0
+       uses: pypa/gh-action-pypi-publish@v1.13.0 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
        with:
          verbose: true
          print-hash: true
@@ -127,7 +127,7 @@ jobs:
    env:
      MUJOCO_GL: egl
    steps:
-     - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+     - uses: actions/checkout@v6
        with:
          lfs: true
          persist-credentials: false
@@ -137,7 +137,7 @@ jobs:
          git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
          speech-dispatcher libgeos-dev portaudio19-dev
      - name: Setup uv and Python
-       uses: astral-sh/setup-uv@d0cc045d04ccac9d8b7881df0226f9e82c39688e # v6
+       uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
        with:
          enable-cache: true # zizmor: ignore[cache-poisoning]
          version: ${{ env.UV_VERSION }}
+2 -2
@@ -43,12 +43,12 @@ jobs:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
-       uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+       uses: actions/checkout@v6 # zizmor: ignore[unpinned-uses]
        with:
          fetch-depth: 0
          persist-credentials: false
      - name: Secret Scanning
-       uses: trufflesecurity/trufflehog@eafb8c5f6a06175141c27f17bcc17941853d0047 # v3.90.0
+       uses: trufflesecurity/trufflehog@v3.90.0 # zizmor: ignore[unpinned-uses]
        with:
          extra_args: --only-verified
-1
@@ -4,7 +4,6 @@
  <div align="center">

- [![Tests](https://github.com/huggingface/lerobot/actions/workflows/latest_deps_tests.yml/badge.svg?branch=main)](https://github.com/huggingface/lerobot/actions/workflows/latest_deps_tests.yml?query=branch%3Amain)
  [![Tests](https://github.com/huggingface/lerobot/actions/workflows/docker_publish.yml/badge.svg?branch=main)](https://github.com/huggingface/lerobot/actions/workflows/docker_publish.yml?query=branch%3Amain)
  [![Python versions](https://img.shields.io/pypi/pyversions/lerobot)](https://www.python.org/downloads/)
  [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/huggingface/lerobot/blob/main/LICENSE)
+45 -36
@@ -115,22 +115,23 @@ Each `EnvConfig` subclass declares two dicts that tell the policy what to expect
  ## Step by step

  <Tip>
- At minimum, you need two files: a **gym.Env wrapper** and an **EnvConfig
- subclass** with a `create_envs()` override. Everything else is optional or
- documentation. No changes to `factory.py` are needed.
+ At minimum, you need three files: a **gym.Env wrapper**, an **EnvConfig
+ subclass**, and a **factory dispatch branch**. Everything else is optional or
+ documentation.
  </Tip>

  ### Checklist

  | File | Required | Why |
  | ---------------------------------------- | -------- | ----------------------------------------- |
  | `src/lerobot/envs/<benchmark>.py` | Yes | Wraps the simulator as a standard gym.Env |
- | `src/lerobot/envs/configs.py` | Yes | Registers your benchmark and its `create_envs()` for the CLI |
+ | `src/lerobot/envs/configs.py` | Yes | Registers your benchmark for the CLI |
+ | `src/lerobot/envs/factory.py` | Yes | Tells `make_env()` how to build your envs |
  | `src/lerobot/processor/env_processor.py` | Optional | Custom observation/action transforms |
  | `src/lerobot/envs/utils.py` | Optional | Only if you need new raw observation keys |
  | `pyproject.toml` | Yes | Declares benchmark-specific dependencies |
  | `docs/source/<benchmark>.mdx` | Yes | User-facing documentation page |
  | `docs/source/_toctree.yml` | Yes | Adds your page to the docs sidebar |

  ### 1. The gym.Env wrapper (`src/lerobot/envs/<benchmark>.py`)
@@ -178,10 +179,7 @@ See `create_libero_envs()` (multi-suite, multi-task) and `create_metaworld_envs(
  ### 2. The config (`src/lerobot/envs/configs.py`)

- Register a config dataclass so users can select your benchmark with `--env.type=<name>`. Each config owns its environment creation and processor logic via two methods:
-
- - **`create_envs(n_envs, use_async_envs)`** — Returns `{suite: {task_id: VectorEnv}}`. The base class default uses `gym.make()` for single-task envs. Multi-task benchmarks override this.
- - **`get_env_processors()`** — Returns `(preprocessor, postprocessor)`. The base class default returns identity (no-op) pipelines. Override if your benchmark needs observation/action transforms.
+ Register a config dataclass so users can select your benchmark with `--env.type=<name>`:

  ```python
  @EnvConfig.register_subclass("<benchmark_name>")
@@ -206,20 +204,6 @@ class MyBenchmarkEnvConfig(EnvConfig):
      @property
      def gym_kwargs(self) -> dict:
          return {"obs_type": self.obs_type, "render_mode": self.render_mode}
-
-     def create_envs(self, n_envs: int, use_async_envs: bool = False):
-         """Override for multi-task benchmarks or custom env creation."""
-         from lerobot.envs.<benchmark> import create_<benchmark>_envs
-         return create_<benchmark>_envs(task=self.task, n_envs=n_envs, ...)
-
-     def get_env_processors(self):
-         """Override if your benchmark needs observation/action transforms."""
-         from lerobot.processor.pipeline import PolicyProcessorPipeline
-         from lerobot.processor.env_processor import MyBenchmarkProcessorStep
-         return (
-             PolicyProcessorPipeline(steps=[MyBenchmarkProcessorStep()]),
-             PolicyProcessorPipeline(steps=[]),
-         )
  ```

  Key points:
@@ -227,11 +211,36 @@ Key points:
  - The `register_subclass` name is what users pass on the CLI (`--env.type=<name>`).
  - `features` tells the policy what the environment produces.
  - `features_map` maps raw observation keys to LeRobot convention keys.
- - **No changes to `factory.py` needed** — the factory delegates to `cfg.create_envs()` and `cfg.get_env_processors()` automatically.

- ### 3. Env processor (optional — `src/lerobot/processor/env_processor.py`)
+ ### 3. The factory dispatch (`src/lerobot/envs/factory.py`)

- Only needed if your benchmark requires observation transforms beyond what `preprocess_observation()` handles (e.g. image flipping, coordinate conversion). Define the processor step here and return it from `get_env_processors()` in your config (see step 2):
+ Add a branch in `make_env()` to call your factory function:
+
+ ```python
+ elif "<benchmark_name>" in cfg.type:
+     from lerobot.envs.<benchmark> import create_<benchmark>_envs
+
+     if cfg.task is None:
+         raise ValueError("<BenchmarkName> requires a task to be specified")
+     return create_<benchmark>_envs(
+         task=cfg.task,
+         n_envs=n_envs,
+         gym_kwargs=cfg.gym_kwargs,
+         env_cls=env_cls,
+     )
+ ```
+
+ If your benchmark needs an env processor, add it in `make_env_pre_post_processors()`:
+
+ ```python
+ if isinstance(env_cfg, MyBenchmarkEnvConfig) or "<benchmark_name>" in env_cfg.type:
+     preprocessor_steps.append(MyBenchmarkProcessorStep())
+ ```
+
+ ### 4. Env processor (optional — `src/lerobot/processor/env_processor.py`)
+
+ Only needed if your benchmark requires observation transforms beyond what `preprocess_observation()` handles (e.g. image flipping, coordinate conversion):

  ```python
  @dataclass
@@ -251,7 +260,7 @@ class MyBenchmarkProcessorStep(ObservationProcessorStep):
  See `LiberoProcessorStep` for a full example (image rotation, quaternion-to-axis-angle conversion).

- ### 4. Dependencies (`pyproject.toml`)
+ ### 5. Dependencies (`pyproject.toml`)

  Add a new optional-dependency group:
@@ -272,11 +281,11 @@ Users install with:
  pip install -e ".[mybenchmark]"
  ```

- ### 5. Documentation (`docs/source/<benchmark>.mdx`)
+ ### 6. Documentation (`docs/source/<benchmark>.mdx`)

  Write a user-facing page following the template in the next section. See `docs/source/libero.mdx` and `docs/source/metaworld.mdx` for full examples.

- ### 6. Table of contents (`docs/source/_toctree.yml`)
+ ### 7. Table of contents (`docs/source/_toctree.yml`)

  Add your benchmark to the "Benchmarks" section:
+48 -42
@@ -90,17 +90,11 @@ The same policy can work with different environment processors, and the same env
  ```python
  # Use SmolVLA policy with LIBERO environment
- # Use SmolVLA policy with LIBERO environment
- libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(
-     env_cfg=libero_cfg,
-     policy_cfg=smolvla_cfg,
- )
+ libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(libero_cfg)
  smolvla_preprocessor, smolvla_postprocessor = make_pre_post_processors(smolvla_cfg)

  # Or use ACT policy with the same LIBERO environment
- libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(
-     env_cfg=libero_cfg,
-     policy_cfg=act_cfg,
- )
+ libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(libero_cfg)
  act_preprocessor, act_postprocessor = make_pre_post_processors(act_cfg)
  ```
@@ -157,7 +151,7 @@ observation = {
  ### Factory Function

- The `make_env_pre_post_processors` function delegates to `env_cfg.get_env_processors()`:
+ The `make_env_pre_post_processors` function follows the same pattern as `make_pre_post_processors` for policies:

  ```python
  from lerobot.envs.factory import make_env_pre_post_processors
@@ -165,30 +159,46 @@ from lerobot.envs.configs import LiberoEnv, PushtEnv
  # For LIBERO: Returns LiberoProcessorStep in preprocessor
  libero_cfg = LiberoEnv(task="libero_spatial", camera_name=["agentview"])
- env_preprocessor, env_postprocessor = make_env_pre_post_processors(libero_cfg, policy_cfg)
+ env_preprocessor, env_postprocessor = make_env_pre_post_processors(libero_cfg)

  # For other environments: Returns identity processors (no-op)
  pusht_cfg = PushtEnv()
- env_preprocessor, env_postprocessor = make_env_pre_post_processors(pusht_cfg, policy_cfg)
+ env_preprocessor, env_postprocessor = make_env_pre_post_processors(pusht_cfg)
  ```

- ### How It Works
+ ### Implementation in `envs/factory.py`

- Each `EnvConfig` subclass can override `get_env_processors()` to return benchmark-specific
- processor pipelines. The base class returns identity (no-op) processors by default.

  ```python
- # In your EnvConfig subclass:
- def get_env_processors(self):
-     from lerobot.processor.pipeline import PolicyProcessorPipeline
-     return (
-         PolicyProcessorPipeline(steps=[MyProcessorStep()]),
-         PolicyProcessorPipeline(steps=[]),
-     )
- ```
-
- The factory function `make_env_pre_post_processors` simply delegates to this method,
- with a special case for `XVLAConfig` policies which override the env processors entirely.
+ def make_env_pre_post_processors(
+     env_cfg: EnvConfig,
+ ) -> tuple[
+     PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+     PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+ ]:
+     """
+     Create preprocessor and postprocessor pipelines for environment observations.
+
+     Args:
+         env_cfg: The configuration of the environment.
+
+     Returns:
+         A tuple containing:
+         - preprocessor: Pipeline that processes environment observations
+         - postprocessor: Pipeline that processes environment outputs
+     """
+     # For LIBERO environments, add the LiberoProcessorStep to preprocessor
+     if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
+         preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
+     else:
+         # For all other environments, return an identity preprocessor
+         preprocessor = PolicyProcessorPipeline(steps=[])
+
+     # Postprocessor is currently identity for all environments
+     # Future: Could add environment-specific action transformations
+     postprocessor = PolicyProcessorPipeline(steps=[])
+
+     return preprocessor, postprocessor
+ ```

  ### Integration in Evaluation
@@ -209,10 +219,7 @@ def eval_main(cfg: EvalPipelineConfig):
      )

      # Create environment processors (NEW!)
-     env_preprocessor, env_postprocessor = make_env_pre_post_processors(
-         env_cfg=cfg.env,
-         policy_cfg=cfg.policy,
-     )
+     env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)

      # Run evaluation with both processor types
      eval_policy_all(
@@ -316,22 +323,21 @@ class MyEnvProcessorStep(ObservationProcessorStep):
      return processed
  ```

- ### 2. Update Your `EnvConfig` Subclass
+ ### 2. Update the Factory

  ```python
- # In src/lerobot/envs/configs.py
- @EnvConfig.register_subclass("myenv")
- @dataclass
- class MyEnvConfig(EnvConfig):
-     # ... task/features/gym kwargs ...
-     def get_env_processors(self):
-         from lerobot.processor.pipeline import PolicyProcessorPipeline
-         return (
-             PolicyProcessorPipeline(steps=[MyEnvProcessorStep()]),
-             PolicyProcessorPipeline(steps=[]),
-         )
+ # In src/lerobot/envs/factory.py
+ def make_env_pre_post_processors(env_cfg: EnvConfig):
+     if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
+         preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
+     elif isinstance(env_cfg, MyEnvConfig) or "myenv" in env_cfg.type:
+         preprocessor = PolicyProcessorPipeline(steps=[MyEnvProcessorStep()])
+     else:
+         preprocessor = PolicyProcessorPipeline(steps=[])
+     postprocessor = PolicyProcessorPipeline(steps=[])
+     return preprocessor, postprocessor
  ```

  ### 3. Use in Evaluation
+1 -1
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
  [project]
  name = "lerobot"
- version = "0.5.2"
+ version = "0.5.1"
  description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
  dynamic = ["readme"]
  license = { text = "Apache-2.0" }
+4 -19
@@ -151,11 +151,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
          ``$HF_LEROBOT_HOME/hub``.
      episodes (list[int] | None, optional): If specified, this will only load episodes specified by
          their episode_index in this list. Defaults to None.
-     image_transforms (Callable | None, optional):
-         Transform applied to visual modalities inside `__getitem__` after image decoding / tensor
-         conversion. This works for both image-backed and video-backed observations and can later be
-         updated with `set_image_transforms()` or cleared with `clear_image_transforms()`.
-         Defaults to None.
+     image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
+         torchvision.transforms.v2 here which will be applied to visual modalities (whether they come
+         from videos or images). Defaults to None.
      delta_timestamps (dict[list[float]] | None, optional): _description_. Defaults to None.
      tolerance_s (float, optional): Tolerance in seconds used to ensure data timestamps are actually in
          sync with the fps value. It is used at the init of the dataset to make sure that each
@@ -194,8 +192,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
      super().__init__()
      self.repo_id = repo_id
      self._requested_root = Path(root) if root else None
-     self.reader = None
-     self.set_image_transforms(image_transforms)
+     self.image_transforms = image_transforms
      self.delta_timestamps = delta_timestamps
      self.episodes = episodes
      self.tolerance_s = tolerance_s
@@ -478,18 +475,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
          f"}})"
      )

-     def set_image_transforms(self, image_transforms: Callable | None) -> None:
-         """Replace the transform applied to visual observations."""
-         if image_transforms is not None and not callable(image_transforms):
-             raise TypeError("image_transforms must be callable or None.")
-         self.image_transforms = image_transforms
-         if self.reader is not None:
-             self.reader._image_transforms = image_transforms
-
-     def clear_image_transforms(self) -> None:
-         """Remove the transform applied to visual observations."""
-         self.set_image_transforms(None)
-
      # ── Hub methods (stay on facade) ──────────────────────────────────
      def push_to_hub(
+1 -13
@@ -89,24 +89,12 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
          )
          self.disabled_features.update(extra_keys)

+     self.image_transforms = image_transforms
      self.delta_timestamps = delta_timestamps
      # TODO(rcadene, aliberts): We should not perform this aggregation for datasets
      # with multiple robots of different ranges. Instead we should have one normalization
      # per robot.
      self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
-     self.set_image_transforms(image_transforms)
-
-     def set_image_transforms(self, image_transforms: Callable | None) -> None:
-         """Replace the transform for this dataset and its children."""
-         if image_transforms is not None and not callable(image_transforms):
-             raise TypeError("image_transforms must be callable or None.")
-         self.image_transforms = image_transforms
-         for dataset in getattr(self, "_datasets", []):
-             dataset.set_image_transforms(self.image_transforms)
-
-     def clear_image_transforms(self) -> None:
-         """Remove the transform from this dataset and its children."""
-         self.set_image_transforms(None)

      @property
      def repo_id_to_index(self):
-109
@@ -12,16 +12,11 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from __future__ import annotations
-
  import abc
- import importlib
  from dataclasses import dataclass, field, fields
  from typing import Any

  import draccus
- import gymnasium as gym
- from gymnasium.envs.registration import registry as gym_registry

  from lerobot.configs.types import FeatureType, PolicyFeature
  from lerobot.robots import RobotConfig
@@ -72,49 +67,6 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
      def gym_kwargs(self) -> dict:
          raise NotImplementedError()

-     def create_envs(
-         self,
-         n_envs: int,
-         use_async_envs: bool = False,
-     ) -> dict[str, dict[int, gym.vector.VectorEnv]]:
-         """Create {suite: {task_id: VectorEnv}}.
-
-         Default: single-task env via gym.make(). Multi-task benchmarks override.
-         """
-         env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
-
-         if self.gym_id not in gym_registry:
-             print(f"gym id '{self.gym_id}' not found, attempting to import '{self.package_name}'...")
-             try:
-                 importlib.import_module(self.package_name)
-             except ModuleNotFoundError as e:
-                 raise ModuleNotFoundError(
-                     f"Package '{self.package_name}' required for env '{self.type}' not found. "
-                     f"Please install it or check PYTHONPATH."
-                 ) from e
-             if self.gym_id not in gym_registry:
-                 raise gym.error.NameNotFound(
-                     f"Environment '{self.gym_id}' not registered even after importing '{self.package_name}'."
-                 )
-
-         def _make_one():
-             return gym.make(self.gym_id, disable_env_checker=self.disable_env_checker, **self.gym_kwargs)
-
-         try:
-             from gymnasium.vector import AutoresetMode
-
-             vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=AutoresetMode.SAME_STEP)
-         except ImportError:
-             vec = env_cls([_make_one for _ in range(n_envs)])
-
-         return {self.type: {0: vec}}
-
-     def get_env_processors(self):
-         """Return (preprocessor, postprocessor) for this env. Default: identity."""
-         from lerobot.processor.pipeline import PolicyProcessorPipeline
-
-         return PolicyProcessorPipeline(steps=[]), PolicyProcessorPipeline(steps=[])

  @dataclass
  class HubEnvConfig(EnvConfig):
@@ -386,12 +338,6 @@ class LiberoEnv(EnvConfig):
          else:
              raise ValueError(f"Unsupported obs_type: {self.obs_type}")

-         if self.camera_name_mapping is not None:
-             mapped_agentview = self.camera_name_mapping.get("agentview_image", "image")
-             mapped_eye_in_hand = self.camera_name_mapping.get("robot0_eye_in_hand_image", "image2")
-             self.features_map[LIBERO_KEY_PIXELS_AGENTVIEW] = f"{OBS_IMAGES}.{mapped_agentview}"
-             self.features_map[LIBERO_KEY_PIXELS_EYE_IN_HAND] = f"{OBS_IMAGES}.{mapped_eye_in_hand}"

      @property
      def gym_kwargs(self) -> dict:
          kwargs: dict[str, Any] = {"obs_type": self.obs_type, "render_mode": self.render_mode}
@@ -399,33 +345,6 @@ class LiberoEnv(EnvConfig):
              kwargs["task_ids"] = self.task_ids
          return kwargs

-     def create_envs(self, n_envs: int, use_async_envs: bool = False):
-         from lerobot.envs.libero import create_libero_envs
-
-         if self.task is None:
-             raise ValueError("LiberoEnv requires a task to be specified")
-         env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
-         return create_libero_envs(
-             task=self.task,
-             n_envs=n_envs,
-             camera_name=self.camera_name,
-             init_states=self.init_states,
-             gym_kwargs=self.gym_kwargs,
-             env_cls=env_cls,
-             control_mode=self.control_mode,
-             episode_length=self.episode_length,
-             camera_name_mapping=self.camera_name_mapping,
-         )
-
-     def get_env_processors(self):
-         from lerobot.processor.env_processor import LiberoProcessorStep
-         from lerobot.processor.pipeline import PolicyProcessorPipeline
-
-         return (
-             PolicyProcessorPipeline(steps=[LiberoProcessorStep()]),
-             PolicyProcessorPipeline(steps=[]),
-         )

  @EnvConfig.register_subclass("metaworld")
  @dataclass
@@ -468,19 +387,6 @@ class MetaworldEnv(EnvConfig):
              "render_mode": self.render_mode,
          }

-     def create_envs(self, n_envs: int, use_async_envs: bool = False):
-         from lerobot.envs.metaworld import create_metaworld_envs
-
-         if self.task is None:
-             raise ValueError("MetaWorld requires a task to be specified")
-         env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
-         return create_metaworld_envs(
-             task=self.task,
-             n_envs=n_envs,
-             gym_kwargs=self.gym_kwargs,
-             env_cls=env_cls,
-         )

  @EnvConfig.register_subclass("isaaclab_arena")
  @dataclass
@@ -548,18 +454,3 @@ class IsaaclabArenaEnv(HubEnvConfig):
      @property
      def gym_kwargs(self) -> dict:
          return {}

-     def get_env_processors(self):
-         from lerobot.processor.env_processor import IsaaclabArenaProcessorStep
-         from lerobot.processor.pipeline import PolicyProcessorPipeline
-
-         state_keys = tuple(k.strip() for k in (self.state_keys or "").split(",") if k.strip())
-         camera_keys = tuple(k.strip() for k in (self.camera_keys or "").split(",") if k.strip())
-         if not state_keys and not camera_keys:
-             raise ValueError("At least one of state_keys or camera_keys must be specified.")
-         return (
-             PolicyProcessorPipeline(
-                 steps=[IsaaclabArenaProcessorStep(state_keys=state_keys, camera_keys=camera_keys)]
-             ),
-             PolicyProcessorPipeline(steps=[]),
-         )
+117 -20
@@ -13,46 +13,90 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from __future__ import annotations
+ import importlib

  from typing import Any

  import gymnasium as gym
+ from gymnasium.envs.registration import registry as gym_registry

- from lerobot.envs.configs import EnvConfig, HubEnvConfig
+ from lerobot.configs.policies import PreTrainedConfig
+ from lerobot.envs.configs import AlohaEnv, EnvConfig, HubEnvConfig, IsaaclabArenaEnv, LiberoEnv, PushtEnv
  from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
+ from lerobot.policies.xvla.configuration_xvla import XVLAConfig
+ from lerobot.processor import ProcessorStep
+ from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep
+ from lerobot.processor.pipeline import PolicyProcessorPipeline


  def make_env_config(env_type: str, **kwargs) -> EnvConfig:
-     try:
-         cls = EnvConfig.get_choice_class(env_type)
-     except KeyError as err:
-         raise ValueError(
-             f"Environment type '{env_type}' is not registered. "
-             f"Available: {list(EnvConfig.get_known_choices().keys())}"
-         ) from err
-     return cls(**kwargs)
+     if env_type == "aloha":
+         return AlohaEnv(**kwargs)
+     elif env_type == "pusht":
+         return PushtEnv(**kwargs)
+     elif env_type == "libero":
+         return LiberoEnv(**kwargs)
+     else:
+         raise ValueError(f"Policy type '{env_type}' is not available.")


  def make_env_pre_post_processors(
      env_cfg: EnvConfig,
-     policy_cfg: Any,
- ) -> tuple[Any, Any]:
+     policy_cfg: PreTrainedConfig,
+ ) -> tuple[
+     PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+     PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+ ]:
      """
      Create preprocessor and postprocessor pipelines for environment observations.

-     Returns a tuple of (preprocessor, postprocessor). By default, delegates to
-     ``env_cfg.get_env_processors()``. The XVLAConfig policy-specific override
-     stays here because it depends on the *policy* config, not the env config.
-     """
-     from lerobot.policies.xvla.configuration_xvla import XVLAConfig
+     This function creates processor pipelines that transform raw environment
+     observations and actions. By default, it returns identity processors that do nothing.
+     For specific environments like LIBERO, it adds environment-specific processing steps.
+
+     Args:
+         env_cfg: The configuration of the environment.
+
+     Returns:
+         A tuple containing:
+         - preprocessor: Pipeline that processes environment observations
+         - postprocessor: Pipeline that processes environment outputs (currently identity)
+     """
+     # Preprocessor and Postprocessor steps are Identity for most environments
+     preprocessor_steps: list[ProcessorStep] = []
+     postprocessor_steps: list[ProcessorStep] = []

      if isinstance(policy_cfg, XVLAConfig):
          from lerobot.policies.xvla.processor_xvla import make_xvla_libero_pre_post_processors

          return make_xvla_libero_pre_post_processors()

-     return env_cfg.get_env_processors()
+     # For LIBERO environments, add the LiberoProcessorStep to preprocessor
+     if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
+         preprocessor_steps.append(LiberoProcessorStep())
+
+     # For Isaaclab Arena environments, add the IsaaclabArenaProcessorStep
+     if isinstance(env_cfg, IsaaclabArenaEnv) or "isaaclab_arena" in env_cfg.type:
+         # Parse comma-separated keys (handle None for state-based policies)
+         if env_cfg.state_keys:
+             state_keys = tuple(k.strip() for k in env_cfg.state_keys.split(",") if k.strip())
+         else:
+             state_keys = ()
+         if env_cfg.camera_keys:
+             camera_keys = tuple(k.strip() for k in env_cfg.camera_keys.split(",") if k.strip())
+         else:
+             camera_keys = ()
+         if not state_keys and not camera_keys:
+             raise ValueError("At least one of state_keys or camera_keys must be specified.")
+         preprocessor_steps.append(
+             IsaaclabArenaProcessorStep(
+                 state_keys=state_keys,
+                 camera_keys=camera_keys,
+             )
+         )
+
+     preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps)
+     postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps)
+
+     return preprocessor, postprocessor


  def make_env(
@@ -119,4 +163,57 @@ def make_env(
      if n_envs < 1:
          raise ValueError("`n_envs` must be at least 1")

-     return cfg.create_envs(n_envs=n_envs, use_async_envs=use_async_envs)
+     env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
+
+     if "libero" in cfg.type:
+         from lerobot.envs.libero import create_libero_envs
+
+         if cfg.task is None:
+             raise ValueError("LiberoEnv requires a task to be specified")
+         return create_libero_envs(
+             task=cfg.task,
+             n_envs=n_envs,
+             camera_name=cfg.camera_name,
+             init_states=cfg.init_states,
+             gym_kwargs=cfg.gym_kwargs,
+             env_cls=env_cls,
+             control_mode=cfg.control_mode,
+             episode_length=cfg.episode_length,
+         )
+     elif "metaworld" in cfg.type:
+         from lerobot.envs.metaworld import create_metaworld_envs
+
+         if cfg.task is None:
+             raise ValueError("MetaWorld requires a task to be specified")
+         return create_metaworld_envs(
+             task=cfg.task,
+             n_envs=n_envs,
+             gym_kwargs=cfg.gym_kwargs,
+             env_cls=env_cls,
+         )
+
+     if cfg.gym_id not in gym_registry:
+         print(f"gym id '{cfg.gym_id}' not found, attempting to import '{cfg.package_name}'...")
+         try:
+             importlib.import_module(cfg.package_name)
+         except ModuleNotFoundError as e:
+             raise ModuleNotFoundError(
+                 f"Package '{cfg.package_name}' required for env '{cfg.type}' not found. "
+                 f"Please install it or check PYTHONPATH."
+             ) from e
+         if cfg.gym_id not in gym_registry:
+             raise gym.error.NameNotFound(
+                 f"Environment '{cfg.gym_id}' not registered even after importing '{cfg.package_name}'."
+             )
+
+     def _make_one():
+         return gym.make(cfg.gym_id, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
+
+     vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)
+
+     # normalize to {suite: {task_id: vec_env}} for consistency
+     suite_name = cfg.type  # e.g., "pusht", "aloha"
+     return {suite_name: {0: vec}}
+7 -6
@@ -223,8 +223,7 @@ class LiberoEnv(gym.Env):
      def render(self):
          raw_obs = self._env.env._get_observations()
-         pixels = self._format_raw_obs(raw_obs)["pixels"]
-         image = next(iter(pixels.values()))
+         image = self._format_raw_obs(raw_obs)["pixels"]["image"]
          image = image[::-1, ::-1]  # flip both H and W for visualization
          return image
@@ -340,6 +339,12 @@ class LiberoEnv(gym.Env):
          )
          observation = self._format_raw_obs(raw_obs)
          if terminated:
+             info["final_info"] = {
+                 "task": self.task,
+                 "task_id": self.task_id,
+                 "done": bool(done),
+                 "is_success": bool(is_success),
+             }
              self.reset()
          truncated = False
          return observation, reward, terminated, truncated, info
@@ -359,7 +364,6 @@ def _make_env_fns(
      init_states: bool,
      gym_kwargs: Mapping[str, Any],
      control_mode: str,
-     camera_name_mapping: dict[str, str] | None = None,
  ) -> list[Callable[[], LiberoEnv]]:
      """Build n_envs factory callables for a single (suite, task_id)."""
@@ -375,7 +379,6 @@ def _make_env_fns(
              episode_index=episode_index,
              n_envs=n_envs,
              control_mode=control_mode,
-             camera_name_mapping=camera_name_mapping,
              **local_kwargs,
          )
@@ -397,7 +400,6 @@ def create_libero_envs(
      env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
      control_mode: str = "relative",
      episode_length: int | None = None,
-     camera_name_mapping: dict[str, str] | None = None,
  ) -> dict[str, dict[int, Any]]:
      """
      Create vectorized LIBERO environments with a consistent return shape.
@@ -447,7 +449,6 @@ def create_libero_envs(
          init_states=init_states,
          gym_kwargs=gym_kwargs,
          control_mode=control_mode,
-         camera_name_mapping=camera_name_mapping,
      )
      out[suite_name][tid] = env_cls(fns)
      print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")
-5
@@ -201,11 +201,6 @@ def rollout(
              "You're likely using an older version of gymnasium (< 1.0). Please upgrade."
          )
          successes = final_info["is_success"].tolist()
-     elif "is_success" in info:
-         is_success = info["is_success"]
-         successes = (
-             is_success.tolist() if hasattr(is_success, "tolist") else [bool(is_success)] * env.num_envs
-         )
      else:
          successes = [False] * env.num_envs
-1
@@ -421,7 +421,6 @@ def record_loop(
          act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
          # Applies a pipeline to the action, default is IdentityProcessor
          robot_action_to_send = robot_action_processor((act_processed_policy, obs))
-         action_values = robot_action_to_send
      elif policy is None and isinstance(teleop, Teleoperator):
          act = teleop.get_action()
-58
@@ -24,7 +24,6 @@ import torch
  from huggingface_hub import HfApi
  from PIL import Image
  from safetensors.torch import load_file
- from torchvision.transforms import v2

  import lerobot
  from lerobot.configs.default import DatasetConfig
@@ -35,7 +34,6 @@ from lerobot.datasets.image_writer import image_array_to_pil_image
  from lerobot.datasets.io_utils import hf_transform_to_torch
  from lerobot.datasets.lerobot_dataset import LeRobotDataset
  from lerobot.datasets.multi_dataset import MultiLeRobotDataset
- from lerobot.datasets.transforms import ImageTransforms, ImageTransformsConfig
  from lerobot.datasets.utils import (
      DEFAULT_CHUNK_SIZE,
      DEFAULT_DATA_FILE_SIZE_IN_MB,
@@ -357,62 +355,6 @@ def test_add_frame_image_pil(image_dataset):
      assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)

- def test_set_image_transforms_applies_transparently(image_dataset):
-     dataset = image_dataset
-     dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
-     dataset.save_episode()
-     dataset.finalize()
-
-     dataset.set_image_transforms(v2.Resize((224, 224)))
-     assert dataset[0]["image"].shape == torch.Size((3, 224, 224))
-
-     dataset.set_image_transforms(v2.Resize((128, 128)))
-     assert dataset[0]["image"].shape == torch.Size((3, 128, 128))
-
-     dataset.clear_image_transforms()
-     assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
-
-
- def test_set_image_transforms_supports_lerobot_image_transforms(image_dataset):
-     dataset = image_dataset
-     dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
-     dataset.save_episode()
-     dataset.finalize()
-
-     image_transforms = ImageTransforms(ImageTransformsConfig(enable=False))
-     dataset.set_image_transforms(image_transforms)
-     assert dataset.image_transforms is image_transforms
-     assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
-
-
- def test_set_image_transforms_supports_loaded_dataset(tmp_path, lerobot_dataset_factory):
-     dataset = lerobot_dataset_factory(root=tmp_path / "test", use_videos=False)
-     dataset.set_image_transforms(v2.Compose([v2.Resize((224, 224)), v2.Resize((112, 112))]))
-     camera_key = dataset.meta.camera_keys[0]
-     assert dataset[0][camera_key].shape == torch.Size((3, 112, 112))
-
-
- def test_multilerobot_dataset_set_image_transforms_propagates(tmp_path, lerobot_dataset_factory):
-     root = tmp_path / "multi"
-     repo_ids = ["lerobot/test_multi_a", "lerobot/test_multi_b"]
-     for repo_id in repo_ids:
-         lerobot_dataset_factory(root=root / repo_id, repo_id=repo_id, use_videos=False)
-     dataset = MultiLeRobotDataset(repo_ids, root=root, download_videos=False)
-
-     dataset.set_image_transforms(v2.Resize((96, 96)))
-     camera_key = dataset.camera_keys[0]
-     assert dataset[0][camera_key].shape == torch.Size((3, 96, 96))
-     assert all(child.image_transforms is dataset.image_transforms for child in dataset._datasets)
-
-     dataset.clear_image_transforms()
-     assert dataset.image_transforms is None
-     assert all(child.image_transforms is None for child in dataset._datasets)
-
-
  def test_image_array_to_pil_image_wrong_range_float_0_255():
      image = np.random.rand(*DUMMY_HWC) * 255
      with pytest.raises(ValueError):
-143
@@ -1,143 +0,0 @@
"""Tests for the benchmark dispatch refactor (create_envs / get_env_processors on EnvConfig)."""

from __future__ import annotations

import logging
from dataclasses import dataclass, field

import gymnasium as gym
import pytest
from gymnasium.envs.registration import register, registry as gym_registry

from lerobot.configs.types import PolicyFeature
from lerobot.envs.configs import EnvConfig
from lerobot.envs.factory import make_env, make_env_config, make_env_pre_post_processors

logger = logging.getLogger(__name__)


def test_registry_all_types():
    """make_env_config should resolve every registered EnvConfig subclass via the registry."""
    known = list(EnvConfig.get_known_choices().keys())
    assert len(known) >= 6
    for t in known:
        cfg = make_env_config(t)
        if not isinstance(cfg, EnvConfig):
            continue
        assert cfg.type == t


def test_unknown_type():
    with pytest.raises(ValueError, match="not registered"):
        make_env_config("nonexistent")


def test_identity_processors():
    """Base class get_env_processors() returns identity pipelines."""
    cfg = make_env_config("aloha")
    pre, post = cfg.get_env_processors()
    assert len(pre.steps) == 0 and len(post.steps) == 0


def test_delegation():
    """make_env() should call cfg.create_envs(), not use if/elif dispatch."""
    sentinel = {"delegated": {0: "marker"}}
    fake = type(
        "Fake",
        (),
        {
            "hub_path": None,
            "create_envs": lambda self, n_envs, use_async_envs=False: sentinel,
        },
    )()
    result = make_env(fake, n_envs=1)
    assert result is sentinel


def test_processors_delegation():
    """make_env_pre_post_processors delegates to cfg.get_env_processors()."""
    cfg = make_env_config("aloha")
    pre, post = make_env_pre_post_processors(cfg, policy_cfg=None)
    assert len(pre.steps) == 0


def test_base_create_envs():
    """Base class create_envs() should build a single-task VectorEnv via gym.make()."""
    gym_id = "_dispatch_test/CartPole-v99"
    if gym_id not in gym_registry:
        register(id=gym_id, entry_point="gymnasium.envs.classic_control:CartPoleEnv")

    @EnvConfig.register_subclass("_dispatch_base_test")
    @dataclass
    class _Env(EnvConfig):
        task: str = "CartPole-v99"
        fps: int = 10
        features: dict[str, PolicyFeature] = field(default_factory=dict)

        @property
        def package_name(self):
            return "_dispatch_test"

        @property
        def gym_id(self):
            return gym_id

        @property
        def gym_kwargs(self):
            return {}

    try:
        envs = _Env().create_envs(n_envs=2)
        assert "_dispatch_base_test" in envs
        env = envs["_dispatch_base_test"][0]
        assert isinstance(env, gym.vector.SyncVectorEnv)
        assert env.num_envs == 2
        env.close()
    finally:
        if gym_id in gym_registry:
            del gym_registry[gym_id]


def test_custom_create_envs_override():
    """A custom EnvConfig subclass can override create_envs()."""
    mock_vec = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")])

    @EnvConfig.register_subclass("_dispatch_custom_test")
    @dataclass
    class _Env(EnvConfig):
        task: str = "x"
        features: dict[str, PolicyFeature] = field(default_factory=dict)

        @property
        def gym_kwargs(self):
            return {}

        def create_envs(self, n_envs, use_async_envs=False):
            return {"custom_suite": {0: mock_vec}}

    try:
        result = make_env(_Env(), n_envs=1)
        assert "custom_suite" in result
    finally:
        mock_vec.close()


def test_custom_get_env_processors_override():
    """A custom EnvConfig subclass can override get_env_processors()."""
    from lerobot.processor.pipeline import DataProcessorPipeline

    @EnvConfig.register_subclass("_dispatch_proc_test")
    @dataclass
    class _Env(EnvConfig):
        task: str = "x"
        features: dict[str, PolicyFeature] = field(default_factory=dict)

        @property
        def gym_kwargs(self):
            return {}

        def get_env_processors(self):
            return DataProcessorPipeline(steps=[]), DataProcessorPipeline(steps=[])

    pre, post = _Env().get_env_processors()
    assert isinstance(pre, DataProcessorPipeline)
+9 -9
@@ -832,10 +832,10 @@ wheels = [
  [[package]]
  name = "cuda-pathfinder"
- version = "1.5.2"
+ version = "1.5.1"
  source = { registry = "https://pypi.org/simple" }
  wheels = [
-     { url = "https://files.pythonhosted.org/packages/f2/f9/1b9b60a30fc463c14cdea7a77228131a0ccc89572e8df9cb86c9648271ab/cuda_pathfinder-1.5.2-py3-none-any.whl", hash = "sha256:0c5f160a7756c5b072723cbbd6d861e38917ef956c68150b02f0b6e9271c71fa", size = 49988, upload-time = "2026-04-06T23:01:05.17Z" },
+     { url = "https://files.pythonhosted.org/packages/c4/74/8c66861b873d8eed51fde56d3091baa4906a56f0d4390cae991f2d41dda5/cuda_pathfinder-1.5.1-py3-none-any.whl", hash = "sha256:b3718097fb57cf9e8a904dd072d806f2c9a27627e35c020b06ab9454bcec08c0", size = 49861, upload-time = "2026-04-03T16:41:22.203Z" },
  ]

  [[package]]
@@ -1828,7 +1828,7 @@ wheels = [
  [[package]]
  name = "huggingface-hub"
- version = "1.9.1"
+ version = "1.9.0"
  source = { registry = "https://pypi.org/simple" }
  dependencies = [
      { name = "filelock" },
@@ -1841,9 +1841,9 @@ dependencies = [
      { name = "typer" },
      { name = "typing-extensions" },
  ]
- sdist = { url = "https://files.pythonhosted.org/packages/44/40/68d9b286b125d9318ae95c8f8b206e8672e7244b0eea61ebb4a88037638c/huggingface_hub-1.9.1.tar.gz", hash = "sha256:442af372207cc24dcb089caf507fcd7dbc1217c11d6059a06f6b90afe64e8bd2", size = 750355, upload-time = "2026-04-07T13:47:59.167Z" }
+ sdist = { url = "https://files.pythonhosted.org/packages/88/bb/62c7aa86f63a05e2f9b96642fdef9b94526a23979820b09f5455deff4983/huggingface_hub-1.9.0.tar.gz", hash = "sha256:0ea5be7a56135c91797cae6ad726e38eaeb6eb4b77cefff5c9d38ba0ecf874f7", size = 750326, upload-time = "2026-04-03T08:35:55.888Z" }
  wheels = [
-     { url = "https://files.pythonhosted.org/packages/3d/af/10a89c54937dccf6c10792770f362d96dd67aedfde108e6e1fd7a0836789/huggingface_hub-1.9.1-py3-none-any.whl", hash = "sha256:8dae771b969b318203727a6c6c5209d25e661f6f0dd010fc09cc4a12cf81c657", size = 637356, upload-time = "2026-04-07T13:47:57.239Z" },
+     { url = "https://files.pythonhosted.org/packages/73/37/0d15d16150e1829f3e90962c99f28257f6de9e526a680b4c6f5acdb54fd2/huggingface_hub-1.9.0-py3-none-any.whl", hash = "sha256:2999328c058d39fd19ab748dd09bd4da2fbaa4f4c1ddea823eab103051e14a1f", size = 637355, upload-time = "2026-04-03T08:35:53.897Z" },
  ]

  [[package]]
@@ -2184,7 +2184,7 @@ wheels = [
  [[package]]
  name = "lerobot"
- version = "0.5.2"
+ version = "0.5.1"
  source = { editable = "." }
  dependencies = [
      { name = "accelerate" },
@@ -5561,15 +5561,15 @@ wheels = [
  [[package]]
  name = "uvicorn"
- version = "0.44.0"
+ version = "0.43.0"
  source = { registry = "https://pypi.org/simple" }
  dependencies = [
      { name = "click" },
      { name = "h11" },
  ]
- sdist = { url = "https://files.pythonhosted.org/packages/5e/da/6eee1ff8b6cbeed47eeb5229749168e81eb4b7b999a1a15a7176e51410c9/uvicorn-0.44.0.tar.gz", hash = "sha256:6c942071b68f07e178264b9152f1f16dfac5da85880c4ce06366a96d70d4f31e", size = 86947, upload-time = "2026-04-06T09:23:22.826Z" }
+ sdist = { url = "https://files.pythonhosted.org/packages/62/f2/368268300fb8af33743508d738ef7bb4d56afdb46c6d9c0fa3dd515df171/uvicorn-0.43.0.tar.gz", hash = "sha256:ab1652d2fb23abf124f36ccc399828558880def222c3cb3d98d24021520dc6e8", size = 85686, upload-time = "2026-04-03T18:37:48.984Z" }
  wheels = [
-     { url = "https://files.pythonhosted.org/packages/b7/23/a5bbd9600dd607411fa644c06ff4951bec3a4d82c4b852374024359c19c0/uvicorn-0.44.0-py3-none-any.whl", hash = "sha256:ce937c99a2cc70279556967274414c087888e8cec9f9c94644dfca11bd3ced89", size = 69425, upload-time = "2026-04-06T09:23:21.524Z" },
+     { url = "https://files.pythonhosted.org/packages/55/df/0cf5b0c451602748fdc7a702d4667f6e209bf96aa6e3160d754234445f2a/uvicorn-0.43.0-py3-none-any.whl", hash = "sha256:46fac64f487fd968cd999e5e49efbbe64bd231b5bd8b4a0b482a23ebce499620", size = 68591, upload-time = "2026-04-03T18:37:47.64Z" },
  ]

  [package.optional-dependencies]