Compare commits

...

10 Commits

Author SHA1 Message Date
javadcc_mac f9b8f297b4 Fix EVO1 LIBERO rollout processors 2026-06-09 15:10:10 +08:00
javadcc_mac 95527f6051 Merge remote-tracking branch 'upstream/main' into codex/add-evo1-policy 2026-05-12 17:40:59 +08:00
javadcc_mac 407ee867b9 docs(evo1): format results table 2026-05-12 17:40:18 +08:00
javadcc_mac a5e6409985 fix(evo1): finalize policy guide alignment 2026-05-11 21:51:41 +08:00
javadcc_mac 1c9fbba9a9 chore(evo1): align with policy contribution guide conventions
- Add `src/lerobot/policies/evo1/README.md` symlink into `docs/source/evo1.mdx`
  to match the in-tree README convention (mirroring the EO-1 layout).
- Convert `transformers` import in `internvl3_embedder.py` to the standard
  `TYPE_CHECKING + _transformers_available` two-step gating used by other
  optional-backbone policies (e.g. diffusion). The previous lazy-in-`__init__`
  import was functionally equivalent for runtime gating but didn't expose the
  real symbols to type checkers.
- Add `lerobot[evo1]` to the `all` extra in `pyproject.toml` so
  `pip install 'lerobot[all]'` keeps installing every optional policy.

Per the guidance in https://moon-ci-docs.huggingface.co/docs/lerobot/pr_3534/en/contributing_a_policy.
2026-05-10 23:14:23 +08:00
javadcc_mac 6a1b5ceb9d Merge remote-tracking branch 'upstream/main' into codex/add-evo1-policy
# Conflicts:
#	uv.lock
2026-05-10 22:48:17 +08:00
javadcc_mac daa4c4dd30 chore(lock): regenerate uv.lock for evo1 extra
Adds the `evo1` entry to `[package.metadata.requires-dist]` and the
`provides-extras` list so that `uv sync --locked --extra test` (used by
fast_tests.yml) no longer reports the lockfile as stale.

Generated with `uv 0.8.0` (matching `UV_VERSION` in fast_tests.yml).
The non-evo1 marker tweaks are produced by `uv lock` re-resolving the
existing dep graph and are not introduced by this PR.
2026-05-10 22:43:26 +08:00
Yiming Wang ff992a7a1d Merge branch 'main' into codex/add-evo1-policy 2026-05-10 18:54:35 +08:00
javadcc_mac 48269dddb3 fix(evo1): infer batch size after normalizing image dims
`_collect_image_batches` read `batch_size = batch[camera_keys[0]].shape[0]`
before normalizing per-camera tensors to `(B, C, H, W)`. For an unbatched
`(C, H, W)` input (which the function tries to support via the `image.dim() == 3`
branch), this picked up the channel count `C` instead of the real batch size,
making the subsequent per-sample loop iterate `C` times and indexing go
out of bounds.

Normalize each camera tensor up-front, then read `batch_size` from the
normalized batch dim. Adds `test_collect_image_batches_handles_unbatched_chw`
covering the regression.

Reported by Copilot review on huggingface/lerobot#3545.
2026-05-10 11:29:23 +08:00
javadcc_mac 8df8d3d866 feat(policies): add EVO1 policy 2026-05-09 21:39:19 +08:00
21 changed files with 2607 additions and 16 deletions
+2
View File
@@ -55,6 +55,8 @@
title: π₀.₅ (Pi05)
- local: eo1
title: EO-1
- local: evo1
title: EVO1
- local: groot
title: NVIDIA GR00T N1.5
- local: xvla
+186
View File
@@ -0,0 +1,186 @@
# EVO1
EVO1 is a Vision-Language-Action policy for robot control built around an InternVL3 backbone and a continuous flow-matching action head. This LeRobot integration exposes EVO1 as a standard policy type so it can be trained and evaluated with the usual LeRobot dataset, checkpoint, and processor APIs.
## Model Overview
The policy embeds one or more camera images and the language task prompt with InternVL3, pads robot state/action vectors to fixed maximum dimensions, and predicts future action chunks with a flow-matching action head. During inference, the policy samples an action chunk and returns `n_action_steps` actions from that chunk before sampling again.
### What the LeRobot Integration Covers
- Standard `policy.type=evo1` configuration through LeRobot
- InternVL3 image/text embedding with optional FlashAttention fallback
- Stage-based finetuning controls for action-head-only and VLM finetuning runs
- Continuous flow-matching action prediction
- Checkpoint save/load through LeRobot policy APIs
- Training with `lerobot-train` and evaluation with standard policy inference APIs
The broader EVO1 project may include additional training scripts and dataset tooling. This page focuses on the LeRobot robot-control policy path.
## Installation Requirements
1. Install LeRobot by following the [Installation Guide](./installation).
2. Install EVO1 dependencies:
```bash
pip install -e ".[evo1]"
```
For LIBERO evaluation, install the LIBERO extra as well:
```bash
pip install -e ".[evo1,libero]"
```
3. Install a `flash-attn` wheel only if it is compatible with your Python, PyTorch, CUDA, and GPU stack. EVO1 falls back to standard attention when `flash_attn` is not available, but reproducing the official LIBERO checkpoint conversion result below requires the same FlashAttention path used by the original EVO1 checkpoint.
EVO1 uses InternVL3 through the Hugging Face `transformers` remote-code path, so the first run may download the configured VLM checkpoint unless `policy.vlm_model_name` points to a local model directory.
## Data Requirements
EVO1 expects a LeRobot dataset with:
- One to `policy.max_views` visual observations, for example `observation.images.image`
- `observation.state`
- `action`
- A language task instruction in the dataset `task` field, or another field configured with `policy.task_field`
State and action vectors are padded to `policy.max_state_dim` and `policy.max_action_dim`. Predictions are cropped back to the dataset action dimension before being returned.
## Usage
To use EVO1 in a LeRobot configuration, specify:
```python
policy.type=evo1
```
By default, a new EVO1 policy initializes its VLM from:
```python
policy.vlm_model_name=OpenGVLab/InternVL3-1B
```
Once a LeRobot-format EVO1 checkpoint is available, load it with:
```python
policy.path=your-org/your-evo1-checkpoint
```
The converted LIBERO checkpoint used for this PR is available at:
```python
policy.path=javadcc/evo1-libero-lerobot
```
## Training
### Stage 1
Stage 1 freezes the VLM and trains the action head:
```bash
lerobot-train \
--dataset.repo_id=your_org/your_dataset \
--policy.type=evo1 \
--policy.training_stage=stage1 \
--policy.vlm_model_name=OpenGVLab/InternVL3-1B \
--policy.device=cuda \
--policy.chunk_size=50 \
--policy.n_action_steps=50 \
--policy.max_state_dim=24 \
--policy.max_action_dim=24 \
--policy.optimizer_lr=1e-5 \
--batch_size=4 \
--steps=5000 \
--output_dir=./outputs/evo1_stage1
```
### Stage 2
Stage 2 finetunes the VLM branches and action head. A common workflow starts from a Stage 1 checkpoint:
```bash
lerobot-train \
--dataset.repo_id=your_org/your_dataset \
--policy.path=./outputs/evo1_stage1/checkpoints/005000/pretrained_model \
--policy.training_stage=stage2 \
--policy.vlm_model_name=OpenGVLab/InternVL3-1B \
--policy.device=cuda \
--policy.chunk_size=50 \
--policy.n_action_steps=50 \
--policy.max_state_dim=24 \
--policy.max_action_dim=24 \
--policy.optimizer_lr=1e-5 \
--batch_size=4 \
--steps=80000 \
--output_dir=./outputs/evo1_stage2
```
By default, `policy.training_stage` reapplies the finetuning defaults for that stage. This is important when
starting Stage 2 from a Stage 1 checkpoint, because the Stage 1 checkpoint config stores the VLM finetuning
flags as disabled. These stage defaults take precedence over saved or manually supplied `policy.finetune_*`
flags unless `policy.apply_training_stage_defaults=false`, so set that flag only when manually controlling
every finetuning flag.
### Key Training Parameters
| Parameter | Default | Description |
| --------------------------------------------- | ------------------------ | ----------------------------------------------------------------- |
| `policy.vlm_model_name` | `OpenGVLab/InternVL3-1B` | InternVL3 checkpoint or local model directory |
| `policy.training_stage` | `stage1` | `stage1` trains the action head; `stage2` finetunes VLM branches |
| `policy.apply_training_stage_defaults` | `true` | Reapplies stage finetuning defaults after loading a checkpoint |
| `policy.vlm_num_layers` | `14` | Number of InternVL3 language layers kept for the policy |
| `policy.vlm_dtype` | `bfloat16` | Requested VLM dtype |
| `policy.use_flash_attn` | `true` | Requests FlashAttention when installed; otherwise falls back |
| `policy.enable_gradient_checkpointing` | `true` | Enables checkpointing on supported InternVL3 modules |
| `policy.gradient_checkpointing_use_reentrant` | `false` | Reentrant setting passed to gradient checkpointing when supported |
| `policy.chunk_size` | `50` | Number of future actions predicted per chunk |
| `policy.n_action_steps` | `50` | Number of actions consumed from a sampled chunk |
| `policy.max_state_dim` | `24` | State padding dimension |
| `policy.max_action_dim` | `24` | Action padding dimension |
| `policy.task_field` | `task` | Batch field used as the language prompt |
## Results
### LIBERO Object Checkpoint Conversion
The checkpoint [javadcc/evo1-libero-lerobot](https://huggingface.co/javadcc/evo1-libero-lerobot)
is the LeRobot-format conversion of the official EVO1 LIBERO checkpoint. The conversion was checked against
the official EVO1 checkpoint with the same LIBERO Object initial states and action postprocessing.
| Checkpoint | Suite | Episodes | Success Rate |
| ---------------------------- | --------------- | ---------------- | ------------ |
| Official EVO1 checkpoint | `libero_object` | 10, one per task | 100% |
| LeRobot converted checkpoint | `libero_object` | 10, one per task | 100% |
For a fixed `libero_object` rollout, the official checkpoint and LeRobot checkpoint produced identical
pixel embeddings, VLM fused tokens, normalized actions, and denormalized actions for the checked action step
(`max_abs_diff=0.0`).
The published checkpoint expects the raw LIBERO camera feature names
`observation.images.agentview_image` and `observation.images.robot0_eye_in_hand_image`. To run the converted
checkpoint with LeRobot LIBERO evaluation for the same one-episode-per-task setting, keep those camera names
instead of the default `image`/`image2` mapping:
```bash
lerobot-eval \
--policy.path=javadcc/evo1-libero-lerobot \
--policy.device=cuda \
--env.type=libero \
--env.task=libero_object \
--env.camera_name_mapping="{agentview_image: agentview_image, robot0_eye_in_hand_image: robot0_eye_in_hand_image}" \
--env.observation_height=448 \
--env.observation_width=448 \
--eval.batch_size=1 \
--eval.n_episodes=1
```
## References
- [EVO1 repository](https://github.com/MINT-SJTU/Evo-1)
- [InternVL3-1B](https://huggingface.co/OpenGVLab/InternVL3-1B)
## License
This LeRobot integration follows the Apache 2.0 License used by LeRobot. Check the upstream EVO1 and InternVL3 model pages for the licenses of released checkpoints and data.
+18
View File
@@ -0,0 +1,18 @@
# EVO1
EVO1 is a Vision-Language-Action policy for robot control. The LeRobot
integration uses an InternVL3 vision-language backbone with a flow-matching
action head, and supports staged training through the standard LeRobot policy
APIs.
The upstream EVO1 project is available at
[MINT-SJTU/Evo-1](https://github.com/MINT-SJTU/Evo-1).
```bibtex
@misc{evo1,
title = {EVO1},
author = {{MINT-SJTU}},
year = {2026},
howpublished = {\url{https://github.com/MINT-SJTU/Evo-1}},
}
```
+3
View File
@@ -195,6 +195,7 @@ groot = [
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
evo1 = ["lerobot[transformers-dep]", "timm>=1.0.0,<1.1.0"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features
@@ -258,6 +259,7 @@ all = [
"lerobot[smolvla]",
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
"lerobot[xvla]",
"lerobot[evo1]",
"lerobot[hilserl]",
"lerobot[async]",
"lerobot[dev]",
@@ -348,6 +350,7 @@ ignore = [
# E402: conditional-import guards (TYPE_CHECKING / is_package_available) must precede the imports they protect
"src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"]
"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original
"src/lerobot/policies/evo1/**" = ["N801", "N812"]
[tool.ruff.lint.isort]
combine-as-imports = true
+14 -6
View File
@@ -24,7 +24,12 @@ import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry
from lerobot.configs import FeatureType, PolicyFeature
from lerobot.processor import IsaaclabArenaProcessorStep, LiberoProcessorStep, PolicyProcessorPipeline
from lerobot.processor import (
IsaaclabArenaProcessorStep,
LiberoActionProcessorStep,
LiberoProcessorStep,
PolicyProcessorPipeline,
)
from lerobot.robots import RobotConfig
from lerobot.teleoperators.config import TeleoperatorConfig
from lerobot.utils.constants import (
@@ -123,7 +128,7 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
vec = env_cls([_make_one for _ in range(n_envs)], **extra_kwargs)
return {self.type: {0: vec}}
def get_env_processors(self):
def get_env_processors(self, policy_cfg: Any | None = None):
"""Return (preprocessor, postprocessor) for this env. Default: identity."""
return PolicyProcessorPipeline(steps=[]), PolicyProcessorPipeline(steps=[])
@@ -436,10 +441,13 @@ class LiberoEnv(EnvConfig):
is_libero_plus=self.is_libero_plus,
)
def get_env_processors(self):
def get_env_processors(self, policy_cfg: Any | None = None):
max_state_dim = getattr(policy_cfg, "max_state_dim", None) if getattr(policy_cfg, "type", None) == "evo1" else None
action_feature = self.features.get(ACTION)
action_dim = int(action_feature.shape[0]) if action_feature is not None else 7
return (
PolicyProcessorPipeline(steps=[LiberoProcessorStep()]),
PolicyProcessorPipeline(steps=[]),
PolicyProcessorPipeline(steps=[LiberoProcessorStep(max_state_dim=max_state_dim)]),
PolicyProcessorPipeline(steps=[LiberoActionProcessorStep(action_dim=action_dim)]),
)
@@ -705,7 +713,7 @@ class IsaaclabArenaEnv(HubEnvConfig):
def gym_kwargs(self) -> dict:
return {}
def get_env_processors(self):
def get_env_processors(self, policy_cfg: Any | None = None):
state_keys = tuple(k.strip() for k in (self.state_keys or "").split(",") if k.strip())
camera_keys = tuple(k.strip() for k in (self.camera_keys or "").split(",") if k.strip())
if not state_keys and not camera_keys:
+9 -1
View File
@@ -15,6 +15,7 @@
# limitations under the License.
from __future__ import annotations
import inspect
from typing import Any
import gymnasium as gym
@@ -52,7 +53,14 @@ def make_env_pre_post_processors(
return make_xvla_libero_pre_post_processors()
return env_cfg.get_env_processors()
get_processors = env_cfg.get_env_processors
signature = inspect.signature(get_processors)
supports_policy_cfg = "policy_cfg" in signature.parameters or any(
param.kind is inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()
)
if supports_policy_cfg:
return get_processors(policy_cfg=policy_cfg)
return get_processors()
def make_env(
+2
View File
@@ -17,6 +17,7 @@ from lerobot.utils.action_interpolator import ActionInterpolator as ActionInterp
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .eo1.configuration_eo1 import EO1Config as EO1Config
from .evo1.configuration_evo1 import Evo1Config as Evo1Config
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
from .groot.configuration_groot import GrootConfig as GrootConfig
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
@@ -40,6 +41,7 @@ __all__ = [
# Configuration classes
"ACTConfig",
"DiffusionConfig",
"Evo1Config",
"GrootConfig",
"MultiTaskDiTConfig",
"EO1Config",
+1
View File
@@ -0,0 +1 @@
../../../../docs/source/policy_evo1_README.md
+19
View File
@@ -0,0 +1,19 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_evo1 import Evo1Config
from .modeling_evo1 import EVO1Policy
from .processor_evo1 import make_evo1_pre_post_processors
__all__ = ["Evo1Config", "EVO1Policy", "make_evo1_pre_post_processors"]
@@ -0,0 +1,225 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from dataclasses import dataclass, field
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
@LRSchedulerConfig.register_subclass("evo1_exact")
@dataclass
class Evo1SchedulerConfig(LRSchedulerConfig):
num_warmup_steps: int
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
def lr_lambda(current_step: int) -> float:
if current_step < self.num_warmup_steps:
return current_step / max(1, self.num_warmup_steps)
progress = (current_step - self.num_warmup_steps) / max(
1, num_training_steps - self.num_warmup_steps
)
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
return LambdaLR(optimizer, lr_lambda, -1)
@PreTrainedConfig.register_subclass("evo1")
@dataclass
class Evo1Config(PreTrainedConfig):
training_stage: str = "stage1"
use_amp: bool = True
n_obs_steps: int = 1
chunk_size: int = 50
n_action_steps: int = 50
max_state_dim: int = 24
max_action_dim: int = 24
max_views: int = 3
image_resolution: tuple[int, int] = (448, 448)
empty_cameras: int = 0
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
)
vlm_model_name: str = "OpenGVLab/InternVL3-1B"
vlm_num_layers: int | None = 14
vlm_dtype: str = "bfloat16"
use_flash_attn: bool = True
action_head: str = "flowmatching"
embed_dim: int = 896
hidden_dim: int = 1024
state_hidden_dim: int = 1024
num_heads: int = 8
num_layers: int = 8
dropout: float = 0.0
num_inference_timesteps: int = 32
num_categories: int = 1
return_cls_only: bool = False
enable_gradient_checkpointing: bool = True
gradient_checkpointing_use_reentrant: bool = False
finetune_vlm: bool | None = None
finetune_language_model: bool | None = None
finetune_vision_model: bool | None = None
finetune_action_head: bool | None = None
# Reapply stage defaults after loading checkpoint configs so stage2 cannot
# accidentally inherit the frozen VLM flags stored by a stage1 checkpoint.
apply_training_stage_defaults: bool = True
task_field: str = "task"
embodiment_id_field: str | None = None
default_embodiment_id: int = 0
optimizer_lr: float = 1e-5
optimizer_betas: tuple[float, float] = (0.9, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-5
optimizer_grad_clip_norm: float = 1.0
scheduler_warmup_steps: int = 300
drop_last: bool = True
def __post_init__(self):
super().__post_init__()
if self.training_stage not in {"stage1", "stage2"}:
raise ValueError(
f"Unsupported EVO1 training_stage '{self.training_stage}', expected 'stage1' or 'stage2'"
)
if self.apply_training_stage_defaults:
if self.training_stage == "stage1":
self.finetune_vlm = False
self.finetune_language_model = False
self.finetune_vision_model = False
self.finetune_action_head = True
elif self.training_stage == "stage2":
self.finetune_vlm = True
self.finetune_language_model = True
self.finetune_vision_model = True
self.finetune_action_head = True
elif self.training_stage == "stage1":
if self.finetune_vlm is None:
self.finetune_vlm = False
if self.finetune_language_model is None:
self.finetune_language_model = False
if self.finetune_vision_model is None:
self.finetune_vision_model = False
if self.finetune_action_head is None:
self.finetune_action_head = True
elif self.training_stage == "stage2":
has_explicit_branch_flags = any(
flag is not None for flag in (self.finetune_language_model, self.finetune_vision_model)
)
if not has_explicit_branch_flags:
if self.finetune_vlm is None:
self.finetune_vlm = True
if self.finetune_language_model is None:
self.finetune_language_model = True
if self.finetune_vision_model is None:
self.finetune_vision_model = True
elif self.finetune_vlm is None:
self.finetune_vlm = bool(self.finetune_language_model or self.finetune_vision_model)
if self.finetune_action_head is None:
self.finetune_action_head = True
if self.finetune_vlm is None:
self.finetune_vlm = False
if self.finetune_language_model is None:
self.finetune_language_model = False
if self.finetune_vision_model is None:
self.finetune_vision_model = False
if self.finetune_action_head is None:
self.finetune_action_head = False
branch_vlm = self.finetune_language_model or self.finetune_vision_model
if self.finetune_vlm != branch_vlm:
raise ValueError(
"Inconsistent EVO1 finetune config: "
f"finetune_vlm={self.finetune_vlm} but "
f"(finetune_language_model or finetune_vision_model)={branch_vlm}. "
"When branch-level flags are used, finetune_vlm must match their effective union."
)
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"n_action_steps ({self.n_action_steps}) must be <= chunk_size ({self.chunk_size})"
)
def validate_features(self) -> None:
if self.input_features is None:
self.input_features = {}
if self.output_features is None:
self.output_features = {}
for i in range(self.empty_cameras):
key = OBS_IMAGES + f".empty_camera_{i}"
if key not in self.input_features:
self.input_features[key] = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, *self.image_resolution),
)
if OBS_STATE not in self.input_features:
self.input_features[OBS_STATE] = PolicyFeature(
type=FeatureType.STATE,
shape=(self.max_state_dim,),
)
if ACTION not in self.output_features:
self.output_features[ACTION] = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.max_action_dim,),
)
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self):
return Evo1SchedulerConfig(
num_warmup_steps=self.scheduler_warmup_steps,
)
@property
def observation_delta_indices(self) -> list[int]:
return [0]
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None
+234
View File
@@ -0,0 +1,234 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
from typing import Any
import torch
import torch.nn as nn
from PIL import Image
from lerobot.policies.evo1.flow_matching import FlowmatchingActionHead
from lerobot.policies.evo1.internvl3_embedder import InternVL3Embedder
def _cfgget(config: Any, key: str, default=None):
if isinstance(config, dict):
return config.get(key, default)
return getattr(config, key, default)
class EVO1(nn.Module):
def __init__(self, config: dict):
super().__init__()
self.config = config
self._device = _cfgget(config, "device", "cuda")
self.return_cls_only = _cfgget(config, "return_cls_only", False)
vlm_name = _cfgget(config, "vlm_name", "OpenGVLab/InternVL3-1B")
image_size = _cfgget(config, "image_size", 448)
if image_size is None:
image_resolution = _cfgget(config, "image_resolution", (448, 448))
image_size = int(image_resolution[0])
self.embedder = InternVL3Embedder(
model_name=vlm_name,
image_size=image_size,
device=self._device,
num_language_layers=_cfgget(config, "vlm_num_layers", 14),
model_dtype=_cfgget(config, "vlm_dtype", "bfloat16"),
use_flash_attn=_cfgget(config, "use_flash_attn", True),
enable_gradient_checkpointing=_cfgget(config, "enable_gradient_checkpointing", True),
gradient_checkpointing_use_reentrant=_cfgget(
config, "gradient_checkpointing_use_reentrant", False
),
)
action_head_type = _cfgget(config, "action_head", "flowmatching").lower()
if action_head_type != "flowmatching":
raise NotImplementedError(f"Unknown action_head: {action_head_type}")
horizon = _cfgget(config, "action_horizon", _cfgget(config, "horizon", 16))
per_action_dim = _cfgget(config, "per_action_dim", 7)
action_dim = horizon * per_action_dim
if isinstance(config, dict):
config["horizon"] = horizon
config["per_action_dim"] = per_action_dim
config["action_dim"] = action_dim
self.horizon = horizon
self.per_action_dim = per_action_dim
self.action_head = FlowmatchingActionHead(config=config).to(self._device)
def _normalize_image_batches(
self,
images: Sequence[Image.Image | torch.Tensor] | Sequence[Sequence[Image.Image | torch.Tensor]],
prompt: str | list[str] | None,
image_mask: torch.Tensor,
) -> tuple[list[list[Image.Image | torch.Tensor]], list[str], torch.Tensor]:
if not images:
raise ValueError("EVO1 expects at least one image per sample.")
first = images[0]
if isinstance(first, (Image.Image, torch.Tensor)):
image_batches = [list(images)] # type: ignore[arg-type]
else:
image_batches = [list(sample) for sample in images] # type: ignore[arg-type]
batch_size = len(image_batches)
if prompt is None:
prompts = [""] * batch_size
elif isinstance(prompt, str):
prompts = [prompt] * batch_size
else:
prompts = [str(p) for p in prompt]
if len(prompts) != batch_size:
raise ValueError(
f"Prompt batch size {len(prompts)} does not match image batch size {batch_size}"
)
if image_mask.dim() == 1:
image_mask = image_mask.unsqueeze(0)
if image_mask.shape[0] != batch_size:
raise ValueError(
f"image_mask batch size {image_mask.shape[0]} does not match image batch size {batch_size}"
)
return image_batches, prompts, image_mask
def get_vl_embeddings(
self,
images: list[Image.Image | torch.Tensor] | list[list[Image.Image | torch.Tensor]],
image_mask: torch.Tensor,
prompt: str | list[str] | None = None,
return_cls_only: bool | None = None,
) -> torch.Tensor:
if return_cls_only is None:
return_cls_only = self.return_cls_only
image_batches, prompts, image_mask = self._normalize_image_batches(images, prompt, image_mask)
return self.embedder.get_fused_image_text_embedding_from_tensor_images(
image_tensors_batch=image_batches,
image_masks=image_mask,
text_prompts=prompts,
return_cls_only=return_cls_only,
)
def prepare_state(self, state_input: list | torch.Tensor) -> torch.Tensor:
if isinstance(state_input, list):
state_tensor = torch.tensor(state_input)
elif isinstance(state_input, torch.Tensor):
state_tensor = state_input
else:
raise TypeError(f"Unsupported state input type: {type(state_input)}")
if state_tensor.ndim == 1:
state_tensor = state_tensor.unsqueeze(0)
return state_tensor.to(self._device)
def predict_action(
self,
fused_tokens: torch.Tensor,
state: torch.Tensor,
actions_gt: torch.Tensor | None = None,
action_mask: torch.Tensor | None = None,
embodiment_ids: torch.Tensor | None = None,
):
if actions_gt is None:
return self.action_head.get_action(
fused_tokens,
state=state,
action_mask=action_mask,
embodiment_id=embodiment_ids,
)
return self.action_head(
fused_tokens,
state=state,
actions_gt=actions_gt,
action_mask=action_mask,
embodiment_id=embodiment_ids,
)
@torch.no_grad()
def run_inference(
self,
images: list[Image.Image | torch.Tensor],
image_mask: torch.Tensor,
prompt: str,
state_input: list | torch.Tensor,
return_cls_only: bool | None = None,
action_mask: torch.Tensor | None = None,
embodiment_ids: torch.Tensor | None = None,
) -> torch.Tensor:
if image_mask.dim() == 1:
image_mask = image_mask.unsqueeze(0)
fused_tokens = self.get_vl_embeddings(
images=[images],
image_mask=image_mask,
prompt=[prompt],
return_cls_only=return_cls_only,
)
state_tensor = self.prepare_state(state_input)
action = self.predict_action(
fused_tokens,
state_tensor,
action_mask=action_mask,
embodiment_ids=embodiment_ids,
)
if isinstance(action, torch.Tensor) and action.dtype == torch.bfloat16:
action = action.to(torch.float32)
return action
def forward(
self,
fused_tokens: torch.Tensor,
state: torch.Tensor | None = None,
actions_gt: torch.Tensor | None = None,
action_mask: torch.Tensor | None = None,
embodiment_ids: torch.Tensor | None = None,
):
return self.predict_action(fused_tokens, state, actions_gt, action_mask, embodiment_ids)
def _set_module_trainable(self, module: nn.Module, trainable: bool):
for param in module.parameters():
param.requires_grad = trainable
def set_finetune_flags(self):
finetune_vlm = _cfgget(self.config, "finetune_vlm", False)
finetune_language_model = _cfgget(self.config, "finetune_language_model", False)
finetune_vision_model = _cfgget(self.config, "finetune_vision_model", False)
has_explicit_branch_flags = any(
flag is not None for flag in (finetune_language_model, finetune_vision_model)
)
finetune_language_model = bool(finetune_language_model)
finetune_vision_model = bool(finetune_vision_model)
finetune_vlm = bool(finetune_vlm)
if has_explicit_branch_flags:
self._set_module_trainable(self.embedder, False)
if hasattr(self.embedder.model, "language_model"):
self._set_module_trainable(self.embedder.model.language_model, finetune_language_model)
if hasattr(self.embedder.model, "vision_model"):
self._set_module_trainable(self.embedder.model.vision_model, finetune_vision_model)
if hasattr(self.embedder.model, "mlp1"):
self._set_module_trainable(self.embedder.model.mlp1, finetune_vision_model)
elif not finetune_vlm:
self._set_module_trainable(self.embedder, False)
if not _cfgget(self.config, "finetune_action_head", False):
self._set_module_trainable(self.action_head, False)
+456
View File
@@ -0,0 +1,456 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import math
from types import SimpleNamespace
import torch
import torch.nn as nn
logger = logging.getLogger(__name__)
def _cfgget(config, key: str, default=None):
if isinstance(config, dict):
return config.get(key, default)
return getattr(config, key, default)
class SinusoidalPositionalEncoding(nn.Module):
def __init__(self, dim: int, max_len: int = 1000):
super().__init__()
pe = torch.zeros(max_len, dim)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, dim, 2) * -(math.log(10000.0) / dim))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer("pe", pe)
def forward(self, seq_len: int):
if seq_len > self.pe.size(1):
self._extend_pe(seq_len)
return self.pe[:, :seq_len, :]
def _extend_pe(self, new_max_len):
old_max_len, dim = self.pe.size(1), self.pe.size(2)
if new_max_len <= old_max_len:
return
extra_positions = torch.arange(old_max_len, new_max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim))
extra_pe = torch.zeros(new_max_len - old_max_len, dim)
extra_pe[:, 0::2] = torch.sin(extra_positions * div_term)
extra_pe[:, 1::2] = torch.cos(extra_positions * div_term)
extra_pe = extra_pe.unsqueeze(0)
new_pe = torch.cat([self.pe, extra_pe.to(self.pe.device)], dim=1)
self.pe = new_pe
class CategorySpecificLinear(nn.Module):
def __init__(self, in_dim: int, out_dim: int, num_categories: int = 1):
super().__init__()
self.num_categories = num_categories
if num_categories <= 1:
self.linear = nn.Linear(in_dim, out_dim)
else:
self.weight = nn.Parameter(torch.empty(num_categories, in_dim, out_dim))
self.bias = nn.Parameter(torch.zeros(num_categories, out_dim))
nn.init.xavier_uniform_(self.weight)
def forward(self, x: torch.Tensor, category_id: torch.LongTensor):
if self.num_categories <= 1:
if x.dtype != self.linear.weight.dtype:
x = x.to(dtype=self.linear.weight.dtype)
return self.linear(x)
if x.dtype != self.weight.dtype:
x = x.to(dtype=self.weight.dtype)
orig_shape = x.shape
x_flat = x.reshape(-1, orig_shape[-1])
if category_id.dim() == 0:
cid = category_id.item()
out = x_flat @ self.weight[cid] + self.bias[cid]
else:
category_id = category_id.reshape(-1)
if category_id.numel() != x_flat.size(0):
raise ValueError(
f"category_id length {category_id.numel()} does not match flattened batch {x_flat.size(0)}"
)
weight_selected = self.weight[category_id]
bias_selected = self.bias[category_id]
out = torch.bmm(x_flat.unsqueeze(1), weight_selected).squeeze(1) + bias_selected
out_shape = orig_shape[:-1] + (out.shape[-1],)
return out.view(out_shape)
class CategorySpecificMLP(nn.Module):
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_categories: int = 1):
super().__init__()
self.fc1 = CategorySpecificLinear(input_dim, hidden_dim, num_categories)
self.fc2 = CategorySpecificLinear(hidden_dim, output_dim, num_categories)
self.activation = nn.ReLU(inplace=True)
def forward(self, x: torch.Tensor, category_id: torch.LongTensor):
out = self.activation(self.fc1(x, category_id))
out = self.fc2(out, category_id)
return out
class MultiEmbodimentActionEncoder(nn.Module):
def __init__(
self, action_dim: int, embed_dim: int, hidden_dim: int, horizon: int, num_categories: int = 1
):
super().__init__()
self.horizon = horizon
self.embed_dim = embed_dim
self.num_categories = num_categories
self.W1 = CategorySpecificLinear(action_dim, hidden_dim, num_categories)
self.W2 = CategorySpecificLinear(hidden_dim, hidden_dim, num_categories)
self.W3 = CategorySpecificLinear(hidden_dim, embed_dim, num_categories)
self.pos_encoding = SinusoidalPositionalEncoding(hidden_dim, max_len=horizon)
self.activation = nn.ReLU(inplace=True)
def forward(self, action_seq: torch.Tensor, category_id: torch.LongTensor):
batch_size, horizon, action_dim = action_seq.shape
assert self.horizon == horizon, "Action sequence length must match horizon"
x = action_seq.reshape(batch_size * horizon, action_dim)
if category_id.dim() == 0:
cat_ids = category_id.expand(horizon * batch_size)
else:
cat_ids = category_id.unsqueeze(1).expand(batch_size, horizon).reshape(batch_size * horizon)
out = self.activation(self.W1(x, cat_ids))
pos_enc = self.pos_encoding(horizon).to(device=out.device, dtype=out.dtype)
out = out.view(batch_size, horizon, -1) + pos_enc
out = out.view(batch_size * horizon, -1)
out = self.activation(self.W2(out, cat_ids))
out = self.W3(out, cat_ids)
return out.view(batch_size, horizon, self.embed_dim)
class BasicTransformerBlock(nn.Module):
def __init__(self, embed_dim: int, num_heads: int, hidden_dim: int, dropout: float = 0.0):
super().__init__()
self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.ff = nn.Sequential(nn.Linear(embed_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, embed_dim))
def forward(self, action_tokens: torch.Tensor, context_tokens: torch.Tensor, time_emb: torch.Tensor):
x = self.norm1(action_tokens)
attn_out, _ = self.attn(x, context_tokens, context_tokens)
x = action_tokens + attn_out
x2 = self.norm2(x)
if time_emb is not None:
x2 = x2 + time_emb.unsqueeze(1)
ff_out = self.ff(x2)
return x + ff_out
class FlowmatchingActionHead(nn.Module):
def __init__(
self,
config=None,
embed_dim: int = 896,
hidden_dim: int = 1024,
action_dim: int = 16 * 7,
horizon: int = 16,
per_action_dim: int = 7,
num_heads: int = 8,
num_layers: int = 8,
dropout: float = 0.0,
num_inference_timesteps: int = 20,
num_categories: int = 1,
):
super().__init__()
if config is not None:
embed_dim = _cfgget(config, "embed_dim", embed_dim)
hidden_dim = _cfgget(config, "hidden_dim", hidden_dim)
action_dim = _cfgget(config, "action_dim", action_dim)
horizon = _cfgget(config, "horizon", horizon)
per_action_dim = _cfgget(config, "per_action_dim", per_action_dim)
num_heads = _cfgget(config, "num_heads", num_heads)
num_layers = _cfgget(config, "num_layers", num_layers)
dropout = _cfgget(config, "dropout", dropout)
num_inference_timesteps = _cfgget(config, "num_inference_timesteps", num_inference_timesteps)
num_categories = _cfgget(config, "num_categories", num_categories)
self.config = config
else:
self.config = SimpleNamespace(
embed_dim=embed_dim,
hidden_dim=hidden_dim,
action_dim=action_dim,
horizon=horizon,
per_action_dim=per_action_dim,
num_heads=num_heads,
num_layers=num_layers,
dropout=dropout,
num_inference_timesteps=num_inference_timesteps,
num_categories=num_categories,
)
logger.info("FlowmatchingActionHead num_inference_timesteps=%s", num_inference_timesteps)
self.embed_dim = embed_dim
self.horizon = horizon
self.per_action_dim = _cfgget(self.config, "per_action_dim", per_action_dim)
self.action_dim = _cfgget(self.config, "action_dim", action_dim)
self.time_pos_enc = SinusoidalPositionalEncoding(embed_dim, max_len=1000)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
embed_dim=embed_dim,
num_heads=num_heads,
hidden_dim=embed_dim * 4,
dropout=dropout,
)
for _ in range(num_layers)
]
)
self.norm_out = nn.LayerNorm(embed_dim)
self.seq_pool_proj = nn.Linear(self.horizon * self.embed_dim, self.embed_dim)
self.mlp_head = CategorySpecificMLP(
input_dim=embed_dim,
hidden_dim=hidden_dim,
output_dim=action_dim,
num_categories=num_categories,
)
self.state_encoder = None
state_dim = _cfgget(self.config, "state_dim")
if state_dim is not None:
state_hidden = _cfgget(self.config, "state_hidden_dim", embed_dim)
self.state_encoder = CategorySpecificMLP(
input_dim=state_dim,
hidden_dim=state_hidden,
output_dim=embed_dim,
num_categories=num_categories,
)
if horizon > 1:
self.action_encoder = MultiEmbodimentActionEncoder(
action_dim=self.per_action_dim,
embed_dim=embed_dim,
hidden_dim=embed_dim,
horizon=horizon,
num_categories=num_categories,
)
self.single_action_proj = None
else:
self.action_encoder = None
self.single_action_proj = nn.Linear(self.per_action_dim, self.embed_dim)
def _project_actions(self, action_seq: torch.Tensor, embodiment_id: torch.LongTensor) -> torch.Tensor:
if self.horizon > 1 and self.action_encoder is not None:
return self.action_encoder(action_seq, embodiment_id)
if self.single_action_proj is None:
raise RuntimeError("single_action_proj is not initialized for horizon <= 1.")
return self.single_action_proj(action_seq)
def _expand_action_mask(
self,
action_mask: torch.Tensor,
batch_size: int,
per_action_dim: int,
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
if action_mask is None:
raise ValueError("action_mask must be provided for flow matching inference.")
if action_mask.dim() == 2:
expected_last_dim = self.horizon * per_action_dim
if action_mask.shape == (batch_size, expected_last_dim):
expanded_mask = action_mask.reshape(batch_size, self.horizon, per_action_dim)
elif action_mask.shape == (batch_size, per_action_dim):
expanded_mask = action_mask.unsqueeze(1).expand(batch_size, self.horizon, per_action_dim)
else:
raise ValueError(
f"Expected action_mask shape {(batch_size, expected_last_dim)} or "
f"{(batch_size, per_action_dim)}, got {tuple(action_mask.shape)}"
)
elif action_mask.dim() == 3:
expected_shape = (batch_size, self.horizon, per_action_dim)
if tuple(action_mask.shape) != expected_shape:
raise ValueError(
f"Expected action_mask shape {expected_shape}, got {tuple(action_mask.shape)}"
)
expanded_mask = action_mask
else:
raise ValueError(f"Unsupported action_mask rank: {action_mask.dim()}")
return expanded_mask.to(device=device, dtype=dtype)
def forward(
self,
fused_tokens: torch.Tensor,
state: torch.Tensor = None,
actions_gt: torch.Tensor = None,
embodiment_id: torch.LongTensor = None,
state_mask: torch.Tensor = None,
action_mask: torch.Tensor = None,
):
if actions_gt is None:
return self.get_action(
fused_tokens, state=state, embodiment_id=embodiment_id, action_mask=action_mask
)
batch_size = fused_tokens.size(0)
device = fused_tokens.device
if embodiment_id is None:
embodiment_id = torch.zeros(batch_size, dtype=torch.long, device=device)
context_tokens = fused_tokens
if state is not None and self.state_encoder is not None:
state_emb = self.state_encoder(state, embodiment_id).unsqueeze(1)
context_tokens = torch.cat([context_tokens, state_emb], dim=1)
t = (
torch.distributions.Beta(2, 2)
.sample((batch_size,))
.clamp(0.02, 0.98)
.to(device)
.to(dtype=self.dtype)
)
time_index = (t * 999).long().clamp_(0, 999)
time_emb = self.time_pos_enc(1000)[:, time_index, :].squeeze(0).to(dtype=context_tokens.dtype)
actions_gt_seq = actions_gt
noise = torch.rand_like(actions_gt) * 2 - 1
if action_mask is not None:
action_mask = action_mask.to(dtype=noise.dtype, device=noise.device)
if action_mask.shape != noise.shape:
raise ValueError(f"action_mask shape {action_mask.shape} != noise shape {noise.shape}")
actions_gt_seq = actions_gt_seq * action_mask
noise = noise * action_mask
if self.horizon > 1:
noise_seq = noise.view(batch_size, self.horizon, self.per_action_dim)
else:
noise_seq = noise if noise.dim() == 3 else noise.unsqueeze(1)
t_broadcast = t.view(batch_size, 1, 1)
action_intermediate_seq = (1 - t_broadcast) * noise_seq + t_broadcast * actions_gt_seq
action_tokens = self._project_actions(action_intermediate_seq, embodiment_id)
target_dtype = self.dtype
action_tokens = action_tokens.to(dtype=target_dtype)
context_tokens = context_tokens.to(dtype=target_dtype)
time_emb = time_emb.to(dtype=target_dtype)
x = action_tokens
for block in self.transformer_blocks:
x = block(x, context_tokens, time_emb)
x = self.norm_out(x)
if self.horizon > 1:
x_flat = x.reshape(batch_size, -1)
x_pooled = self.seq_pool_proj(x_flat)
else:
x_pooled = x.squeeze(1)
pred_velocity = self.mlp_head(x_pooled, embodiment_id)
return pred_velocity, noise
def get_action(
self,
fused_tokens: torch.Tensor,
state: torch.Tensor = None,
embodiment_id: torch.LongTensor = None,
action_mask: torch.Tensor = None,
):
batch_size = fused_tokens.size(0)
device = fused_tokens.device
if embodiment_id is None:
embodiment_id = torch.zeros(batch_size, dtype=torch.long, device=device)
context_tokens = fused_tokens
if state is not None and self.state_encoder is not None:
state_emb = self.state_encoder(state, embodiment_id).unsqueeze(1)
context_tokens = torch.cat([context_tokens, state_emb], dim=1)
action_dim_total = _cfgget(self.config, "action_dim", self.action_dim)
per_action_dim = _cfgget(self.config, "per_action_dim", action_dim_total // max(self.horizon, 1))
action = torch.rand(batch_size, action_dim_total, device=device, dtype=context_tokens.dtype) * 2 - 1
action_seq = (
action.view(batch_size, self.horizon, per_action_dim)
if self.horizon > 1
else action.view(batch_size, 1, per_action_dim)
)
action_mask = self._expand_action_mask(
action_mask,
batch_size=batch_size,
per_action_dim=per_action_dim,
device=action_seq.device,
dtype=action_seq.dtype,
)
action_seq = action_seq * action_mask
target_dtype = self.dtype
context_tokens = context_tokens.to(dtype=target_dtype)
num_steps = int(_cfgget(self.config, "num_inference_timesteps", 32))
if num_steps <= 0:
raise ValueError(f"num_inference_timesteps must be positive, got {num_steps}")
dt = 1.0 / num_steps
for i in range(num_steps):
t = i / num_steps
time_index = min(int(t * 999), 999)
time_emb = (
self.time_pos_enc(1000)[:, time_index, :].to(device).squeeze(0).to(dtype=context_tokens.dtype)
)
time_emb = time_emb.unsqueeze(0).repeat(batch_size, 1)
action_seq = action_seq * action_mask
action_tokens = self._project_actions(action_seq, embodiment_id).to(dtype=target_dtype)
time_emb = time_emb.to(dtype=target_dtype)
x = action_tokens
for block in self.transformer_blocks:
x = block(x, context_tokens, time_emb)
x = self.norm_out(x)
if self.horizon > 1:
x_flat = x.reshape(batch_size, -1)
x_pooled = self.seq_pool_proj(x_flat)
else:
x_pooled = x.squeeze(1)
pred = self.mlp_head(x_pooled, embodiment_id)
action = action + dt * pred
action_seq = (
action.view(batch_size, self.horizon, per_action_dim)
if self.horizon > 1
else action.view(batch_size, 1, per_action_dim)
)
action_seq = action_seq * action_mask
return action_seq.reshape(batch_size, -1)
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
@@ -0,0 +1,435 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import functools
import logging
import types
from collections.abc import Sequence
from contextlib import contextmanager
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
import torch.utils.checkpoint
import torchvision.transforms.functional as TF
from PIL import Image
from torchvision.transforms.functional import to_pil_image
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers import AutoModel, AutoTokenizer
else:
AutoModel = None
AutoTokenizer = None
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>" # nosec B105
IMG_START_TOKEN = "<img>" # nosec B105
IMG_END_TOKEN = "</img>" # nosec B105
logger = logging.getLogger(__name__)
def _patch_vision_encoder_checkpointing(encoder: nn.Module, use_reentrant: bool) -> None:
if getattr(encoder, "_evo1_checkpoint_patch_applied", False):
encoder.gradient_checkpointing_use_reentrant = use_reentrant
return
original_forward = encoder.forward
def forward_with_checkpoint_kwargs(self, *args, **kwargs):
original_checkpoint = torch.utils.checkpoint.checkpoint
def checkpoint(function, *checkpoint_args, **checkpoint_kwargs):
checkpoint_kwargs.setdefault("use_reentrant", self.gradient_checkpointing_use_reentrant)
return original_checkpoint(function, *checkpoint_args, **checkpoint_kwargs)
torch.utils.checkpoint.checkpoint = checkpoint
try:
return original_forward(*args, **kwargs)
finally:
torch.utils.checkpoint.checkpoint = original_checkpoint
encoder.gradient_checkpointing_use_reentrant = use_reentrant
encoder.forward = types.MethodType(forward_with_checkpoint_kwargs, encoder)
encoder._evo1_checkpoint_patch_applied = True
def flash_attn_is_available() -> bool:
try:
import flash_attn # noqa: F401
except ModuleNotFoundError:
return False
return True
@contextmanager
def _internvl_transformers5_load_compatibility():
from transformers.modeling_utils import PreTrainedModel
original_linspace = torch.linspace
original_mark_tied = PreTrainedModel.mark_tied_weights_as_initialized
def linspace(*args, **kwargs):
if kwargs.get("device") is None:
kwargs["device"] = torch.device("cpu")
return original_linspace(*args, **kwargs)
def mark_tied_weights_as_initialized(self, loading_info):
if not hasattr(self, "all_tied_weights_keys"):
self.all_tied_weights_keys = {}
return original_mark_tied(self, loading_info)
torch.linspace = linspace
PreTrainedModel.mark_tied_weights_as_initialized = mark_tied_weights_as_initialized
try:
yield
finally:
torch.linspace = original_linspace
PreTrainedModel.mark_tied_weights_as_initialized = original_mark_tied
@functools.lru_cache(maxsize=10000)
def get_target_aspect_ratio(orig_width: int, orig_height: int, image_size: int, min_num: int, max_num: int):
aspect_ratio = orig_width / orig_height
target_ratios = {
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
}
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
best_ratio_diff = float("inf")
best_ratio = (1, 1)
area = orig_width * orig_height
for ratio in target_ratios:
target_ar = ratio[0] / ratio[1]
diff = abs(aspect_ratio - target_ar)
if diff < best_ratio_diff:
best_ratio_diff = diff
best_ratio = ratio
elif diff == best_ratio_diff and area > 0.5 * image_size**2 * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
def dynamic_preprocess(image, min_num=1, max_num=1, image_size=448, use_thumbnail=False):
orig_width, orig_height = image.size
ratio_w, ratio_h = get_target_aspect_ratio(orig_width, orig_height, image_size, min_num, max_num)
target_width = image_size * ratio_w
target_height = image_size * ratio_h
blocks = ratio_w * ratio_h
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size,
)
processed_images.append(resized_img.crop(box))
if use_thumbnail and len(processed_images) != 1:
processed_images.append(image.resize((image_size, image_size)))
return processed_images
class InternVL3Embedder(nn.Module):
def __init__(
self,
model_name="OpenGVLab/InternVL3-1B",
image_size=448,
device="cuda",
num_language_layers: int | None = 14,
model_dtype: str | torch.dtype = "bfloat16",
use_flash_attn: bool = True,
enable_gradient_checkpointing: bool = True,
gradient_checkpointing_use_reentrant: bool = False,
):
super().__init__()
self._requested_device = device
self.image_size = image_size
self.num_language_layers = num_language_layers
self.max_text_length = 1024
self.enable_gradient_checkpointing = bool(enable_gradient_checkpointing)
self.gradient_checkpointing_use_reentrant = bool(gradient_checkpointing_use_reentrant)
require_package("transformers", extra="evo1")
self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
if isinstance(model_dtype, str):
try:
model_dtype = getattr(torch, model_dtype)
except AttributeError as exc:
raise ValueError(f"Unsupported EVO1 vlm_dtype '{model_dtype}'") from exc
resolved_use_flash_attn = bool(use_flash_attn and flash_attn_is_available())
if use_flash_attn and not resolved_use_flash_attn:
logger.warning("flash_attn is not installed. Falling back to standard attention.")
# InternVL3 remote code predates Transformers 5 post-init conventions:
# it computes stochastic-depth scalars via torch.linspace(...).item()
# while Transformers initializes under torch.device("meta"), and it
# does not populate all_tied_weights_keys before loading finalization.
with _internvl_transformers5_load_compatibility():
self.model = AutoModel.from_pretrained(
model_name,
torch_dtype=model_dtype,
trust_remote_code=True,
use_flash_attn=resolved_use_flash_attn,
low_cpu_mem_usage=True,
_fast_init=False,
).to(self._requested_device)
if hasattr(self.model.language_model, "model"):
layers = self.model.language_model.model.layers
else:
layers = self.model.language_model.layers
if self.num_language_layers is not None:
layers = layers[: self.num_language_layers]
if hasattr(self.model.language_model, "model"):
self.model.language_model.model.layers = torch.nn.ModuleList(layers)
else:
self.model.language_model.layers = torch.nn.ModuleList(layers)
self.model.language_model.lm_head = torch.nn.Identity()
self._configure_memory_features()
self.img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
def _configure_memory_features(self) -> None:
checkpoint_kwargs = {"use_reentrant": self.gradient_checkpointing_use_reentrant}
if not self.enable_gradient_checkpointing:
if hasattr(self.model, "vision_model") and hasattr(self.model.vision_model, "encoder"):
self.model.vision_model.encoder.gradient_checkpointing = False
language_model = getattr(self.model, "language_model", None)
if language_model is not None:
if hasattr(language_model, "gradient_checkpointing_disable"):
language_model.gradient_checkpointing_disable()
elif hasattr(language_model, "gradient_checkpointing"):
language_model.gradient_checkpointing = False
if hasattr(language_model, "model"):
inner = language_model.model
if hasattr(inner, "gradient_checkpointing_disable"):
inner.gradient_checkpointing_disable()
elif hasattr(inner, "gradient_checkpointing"):
inner.gradient_checkpointing = False
return
def _enable_ckpt(module: nn.Module | None) -> bool:
if module is None:
return False
if hasattr(module, "gradient_checkpointing_enable"):
try:
module.gradient_checkpointing_enable(gradient_checkpointing_kwargs=checkpoint_kwargs)
except TypeError:
module.gradient_checkpointing_enable()
return True
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = True
return True
return False
enabled_any = _enable_ckpt(self.model)
if hasattr(self.model, "vision_model") and hasattr(self.model.vision_model, "encoder"):
encoder = self.model.vision_model.encoder
encoder.gradient_checkpointing = True
_patch_vision_encoder_checkpointing(
encoder, use_reentrant=self.gradient_checkpointing_use_reentrant
)
enabled_any = True
language_model = getattr(self.model, "language_model", None)
if language_model is not None:
enabled_any = _enable_ckpt(language_model) or enabled_any
if hasattr(language_model, "model"):
enabled_any = _enable_ckpt(language_model.model) or enabled_any
if hasattr(language_model, "config"):
language_model.config.use_cache = False
if hasattr(self.model, "config"):
self.model.config.use_cache = False
if hasattr(self.model, "enable_input_require_grads"):
self.model.enable_input_require_grads()
if enabled_any:
logger.info("Gradient checkpointing enabled for InternVL3 embedder.")
else:
logger.warning(
"Requested gradient checkpointing, but model does not expose checkpointing controls."
)
def _preprocess_single_image(self, image: Image.Image | torch.Tensor) -> torch.Tensor:
if isinstance(image, torch.Tensor):
pil_image = to_pil_image(image.detach().cpu())
else:
pil_image = image.convert("RGB")
tiles = dynamic_preprocess(pil_image, image_size=self.image_size)
tile_tensors = torch.stack([TF.to_tensor(tile) for tile in tiles]).to(
device=self.device, dtype=torch.bfloat16
)
mean = torch.tensor(IMAGENET_MEAN, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1)
std = torch.tensor(IMAGENET_STD, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1)
return (tile_tensors - mean) / std
def _preprocess_images(
self,
image_tensors_batch: Sequence[Sequence[Image.Image | torch.Tensor]],
) -> tuple[torch.Tensor, list[list[int]]]:
pixel_values_list = []
batch_num_tiles_list: list[list[int]] = []
for image_tensors in image_tensors_batch:
num_tiles_list: list[int] = []
for image in image_tensors:
tiles = self._preprocess_single_image(image)
pixel_values_list.append(tiles)
num_tiles_list.append(int(tiles.shape[0]))
batch_num_tiles_list.append(num_tiles_list)
if pixel_values_list:
pixel_values = torch.cat(pixel_values_list, dim=0)
else:
pixel_values = torch.empty(
0, 3, self.image_size, self.image_size, dtype=torch.bfloat16, device=self.device
)
return pixel_values, batch_num_tiles_list
def _build_multimodal_prompts(
self,
batch_num_tiles_list: list[list[int]],
text_prompts: Sequence[str],
) -> list[str]:
prompts = []
for num_tiles_list, text_prompt in zip(batch_num_tiles_list, text_prompts, strict=True):
prompt_segments = []
for i, tile_count in enumerate(num_tiles_list):
token_count = self.model.num_image_token * tile_count
image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * token_count + IMG_END_TOKEN
prompt_segments.append(f"Image-{i + 1}: {image_tokens}\n")
prompts.append("".join(prompt_segments) + text_prompt.strip())
return prompts
def _prepare_and_fuse_embeddings(
self,
prompts: Sequence[str],
vit_embeds: torch.Tensor,
image_masks: torch.Tensor,
batch_num_tiles_list: list[list[int]],
) -> tuple[torch.Tensor, torch.Tensor]:
untruncated_ids = self.tokenizer(list(prompts), padding=False, truncation=False)["input_ids"]
true_sequence_length = max((len(ids) for ids in untruncated_ids), default=0)
if true_sequence_length > self.max_text_length:
logger.warning(
"InternVL3 prompt truncated in batch: max_length=%s actual_max_length=%s",
self.max_text_length,
true_sequence_length,
)
model_inputs = self.tokenizer(
list(prompts),
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.max_text_length,
).to(self.device)
input_ids = model_inputs["input_ids"]
attention_mask = model_inputs["attention_mask"]
img_token_mask = input_ids == self.img_context_token_id
input_embeds = self.model.language_model.get_input_embeddings()(input_ids).clone()
batch_size, _, channels = input_embeds.shape
vit_embeds = vit_embeds.reshape(-1, channels).to(dtype=input_embeds.dtype, device=input_embeds.device)
tokens_per_tile = self.model.num_image_token
actual_vis_tokens_list = img_token_mask.sum(dim=1).tolist()
vit_idx = 0
for batch_index in range(batch_size):
expected_vis_tokens = sum(batch_num_tiles_list[batch_index]) * tokens_per_tile
mask_b = img_token_mask[batch_index]
actual_vis_tokens = actual_vis_tokens_list[batch_index]
item_vit_embeds = vit_embeds[vit_idx : vit_idx + expected_vis_tokens]
vit_idx += expected_vis_tokens
if actual_vis_tokens > 0:
if item_vit_embeds.shape[0] < actual_vis_tokens:
raise ValueError(
f"InternVL3 produced fewer image tokens than expected for sample {batch_index}: "
f"got {item_vit_embeds.shape[0]}, need {actual_vis_tokens}"
)
input_embeds[batch_index, mask_b] = item_vit_embeds[:actual_vis_tokens]
current_token_idx = 0
img_token_locations = torch.where(mask_b)[0]
for image_index, num_tiles in enumerate(batch_num_tiles_list[batch_index]):
num_tokens_for_image = num_tiles * tokens_per_tile
if not bool(image_masks[batch_index, image_index].item()):
start_offset = current_token_idx
end_offset = min(current_token_idx + num_tokens_for_image, len(img_token_locations))
if start_offset < end_offset:
idxs = img_token_locations[start_offset:end_offset]
attention_mask[batch_index, idxs] = 0
current_token_idx += num_tokens_for_image
return input_embeds, attention_mask
def get_fused_image_text_embedding_from_tensor_images(
self,
image_tensors_batch: Sequence[Sequence[Image.Image | torch.Tensor]],
image_masks: torch.Tensor,
text_prompts: Sequence[str],
return_cls_only: bool = True,
):
pixel_values, batch_num_tiles_list = self._preprocess_images(image_tensors_batch)
if pixel_values.shape[0] == 0:
logger.warning("InternVL3 received an empty image batch after preprocessing.")
hidden_size = getattr(self.model.config, "hidden_size", None)
if hidden_size is None and hasattr(self.model.language_model, "config"):
hidden_size = getattr(self.model.language_model.config, "hidden_size", None)
if hidden_size is None:
raise RuntimeError("Unable to infer hidden size for empty InternVL3 batch.")
empty = torch.empty(0, hidden_size, device=self.device, dtype=torch.float32)
return empty
prompts = self._build_multimodal_prompts(batch_num_tiles_list, text_prompts)
vit_embeds = self.model.extract_feature(pixel_values)
inputs_embeds, attention_mask = self._prepare_and_fuse_embeddings(
prompts,
vit_embeds,
image_masks.to(device=self.device),
batch_num_tiles_list,
)
outputs = self.model.language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
return_dict=True,
)
fused_hidden = outputs.hidden_states[-1].to(torch.float32)
return fused_hidden[:, 0, :] if return_cls_only else fused_hidden
@property
def device(self) -> torch.device:
return next(self.model.parameters()).device
+450
View File
@@ -0,0 +1,450 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import builtins
from collections import deque
from contextlib import nullcontext
from pathlib import Path
import torch
from torch import Tensor
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.evo1.configuration_evo1 import Evo1Config
from lerobot.policies.evo1.evo1_model import EVO1
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
class EVO1Policy(PreTrainedPolicy):
config_class = Evo1Config
name = "evo1"
def __init__(self, config: Evo1Config, **kwargs):
super().__init__(config)
config.validate_features()
if len(config.image_features) > config.max_views:
raise ValueError(
f"EVO1 supports at most {config.max_views} camera streams, got {len(config.image_features)}"
)
self.config = config
self.model = EVO1(self._build_model_config(config))
self.model.set_finetune_flags()
self.reset()
@classmethod
def from_pretrained(
cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
config: PreTrainedConfig | None = None,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
strict: bool | None = None,
**kwargs,
) -> T:
if strict is None:
strict = not (config is not None and getattr(config, "training_stage", None) == "stage2")
return super().from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
config=config,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
strict=strict,
**kwargs,
)
@staticmethod
def _build_model_config(config: Evo1Config) -> dict:
return {
"device": config.device,
"return_cls_only": config.return_cls_only,
"vlm_name": config.vlm_model_name,
"vlm_num_layers": config.vlm_num_layers,
"vlm_dtype": config.vlm_dtype,
"use_flash_attn": config.use_flash_attn,
"action_head": config.action_head,
"action_horizon": config.chunk_size,
"per_action_dim": config.max_action_dim,
"state_dim": config.max_state_dim,
"embed_dim": config.embed_dim,
"hidden_dim": config.hidden_dim,
"state_hidden_dim": config.state_hidden_dim,
"num_heads": config.num_heads,
"num_layers": config.num_layers,
"dropout": config.dropout,
"num_inference_timesteps": config.num_inference_timesteps,
"num_categories": config.num_categories,
"enable_gradient_checkpointing": config.enable_gradient_checkpointing,
"gradient_checkpointing_use_reentrant": config.gradient_checkpointing_use_reentrant,
"finetune_vlm": config.finetune_vlm,
"finetune_language_model": config.finetune_language_model,
"finetune_vision_model": config.finetune_vision_model,
"finetune_action_head": config.finetune_action_head,
}
@property
def _camera_keys(self) -> list[str]:
return list(self.config.image_features)
@property
def _env_action_dim(self) -> int:
action_feature = self.config.action_feature
if action_feature is None:
return self.config.max_action_dim
return int(action_feature.shape[0])
@property
def _compute_dtype(self) -> torch.dtype:
return next(self.model.action_head.parameters()).dtype
@property
def _training_compute_dtype(self) -> torch.dtype:
if str(self.config.device).startswith("cuda"):
return torch.bfloat16
return self._compute_dtype
@property
def _inference_compute_dtype(self) -> torch.dtype:
if str(self.config.device).startswith("cuda") and self.config.use_amp:
return torch.bfloat16
return self._compute_dtype
def get_optim_params(self) -> list[dict]:
decay, no_decay = [], []
for name, param in self.named_parameters():
if not param.requires_grad:
continue
is_bias = name.endswith("bias") or ".bias" in name
is_norm = param.dim() == 1 or "norm" in name.lower()
if is_bias or is_norm:
no_decay.append(param)
else:
decay.append(param)
return [
{"params": decay, "weight_decay": self.config.optimizer_weight_decay},
{"params": no_decay, "weight_decay": 0.0},
]
def reset(self):
self._action_queue = deque([], maxlen=self.config.n_action_steps)
def _normalize_task_batch(self, batch: dict[str, Tensor | list[str] | str]) -> list[str]:
prompts = batch.get(self.config.task_field)
if prompts is None and self.config.task_field != "task":
prompts = batch.get("task")
if prompts is None:
raise ValueError(f"EVO1 expects a '{self.config.task_field}' text field in the batch.")
if isinstance(prompts, str):
return [prompts]
if isinstance(prompts, (list, tuple)):
return [str(prompt) for prompt in prompts]
raise TypeError(f"Unsupported prompt batch type: {type(prompts)}")
def _prepare_state(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
if OBS_STATE not in batch:
raise ValueError(f"EVO1 requires '{OBS_STATE}' in the batch.")
state = batch[OBS_STATE]
if state.dim() == 1:
state = state.unsqueeze(0)
elif state.dim() == 3:
state = state[:, -1]
elif state.dim() != 2:
raise ValueError(f"Unsupported state tensor shape for EVO1: {tuple(state.shape)}")
batch_size, state_dim = state.shape
if state_dim > self.config.max_state_dim:
raise ValueError(
f"State dim {state_dim} exceeds configured max_state_dim {self.config.max_state_dim}"
)
explicit_mask = batch.get("state_mask")
if explicit_mask is not None:
if explicit_mask.dim() == 1:
explicit_mask = explicit_mask.unsqueeze(0)
elif explicit_mask.dim() == 3:
explicit_mask = explicit_mask[:, -1]
elif explicit_mask.dim() != 2:
raise ValueError(
f"Unsupported state_mask tensor shape for EVO1: {tuple(explicit_mask.shape)}"
)
if explicit_mask.shape != (batch_size, state_dim):
raise ValueError(
f"state_mask shape {tuple(explicit_mask.shape)} does not match state shape {(batch_size, state_dim)}"
)
padded = torch.zeros(
batch_size,
self.config.max_state_dim,
dtype=state.dtype,
device=self.config.device,
)
padded[:, :state_dim] = state.to(device=self.config.device)
mask = torch.zeros(
batch_size,
self.config.max_state_dim,
dtype=torch.bool,
device=self.config.device,
)
if explicit_mask is None:
mask[:, :state_dim] = True
else:
mask[:, :state_dim] = explicit_mask.to(device=self.config.device, dtype=torch.bool)
return padded.to(dtype=self._compute_dtype), mask
def _prepare_actions(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
if ACTION not in batch:
raise ValueError(f"EVO1 requires '{ACTION}' in the batch for training.")
action = batch[ACTION]
if action.dim() == 2:
action = action.unsqueeze(1)
batch_size, horizon, action_dim = action.shape
if horizon != self.config.chunk_size:
raise ValueError(
f"EVO1 expects chunk_size={self.config.chunk_size}, got action horizon {horizon}"
)
if action_dim > self.config.max_action_dim:
raise ValueError(
f"Action dim {action_dim} exceeds configured max_action_dim {self.config.max_action_dim}"
)
explicit_mask = batch.get("action_mask")
if explicit_mask is not None:
if explicit_mask.dim() == 2:
if horizon == 1:
explicit_mask = explicit_mask.unsqueeze(1)
else:
raise ValueError(
f"2D action_mask is only supported when chunk_size=1, got action horizon {horizon}"
)
elif explicit_mask.dim() != 3:
raise ValueError(
f"Unsupported action_mask tensor shape for EVO1: {tuple(explicit_mask.shape)}"
)
if explicit_mask.shape != (batch_size, horizon, action_dim):
raise ValueError(
"action_mask shape "
f"{tuple(explicit_mask.shape)} does not match action shape {(batch_size, horizon, action_dim)}"
)
padded = torch.zeros(
batch_size,
horizon,
self.config.max_action_dim,
dtype=action.dtype,
device=self.config.device,
)
padded[:, :, :action_dim] = action.to(device=self.config.device)
mask = torch.zeros(
batch_size,
horizon,
self.config.max_action_dim,
dtype=torch.bool,
device=self.config.device,
)
if explicit_mask is None:
mask[:, :, :action_dim] = True
else:
mask[:, :, :action_dim] = explicit_mask.to(device=self.config.device, dtype=torch.bool)
return padded.to(dtype=self._compute_dtype), mask
def _prepare_inference_action_mask(self, batch_size: int) -> Tensor:
mask = torch.zeros(
batch_size,
self.config.max_action_dim,
dtype=torch.bool,
device=self.config.device,
)
mask[:, : self._env_action_dim] = True
return mask
def _get_embodiment_ids(self, batch: dict[str, Tensor], batch_size: int) -> Tensor:
embodiment_ids = batch.get("embodiment_id")
if embodiment_ids is None and self.config.embodiment_id_field:
embodiment_ids = batch.get(self.config.embodiment_id_field)
if embodiment_ids is None:
return torch.full(
(batch_size,),
self.config.default_embodiment_id,
dtype=torch.long,
device=self.config.device,
)
if embodiment_ids.dim() == 0:
embodiment_ids = embodiment_ids.unsqueeze(0)
elif embodiment_ids.dim() > 1:
embodiment_ids = embodiment_ids[:, -1]
return embodiment_ids.to(device=self.config.device, dtype=torch.long)
@property
def _tracks_vlm_gradients(self) -> bool:
return bool(
self.config.finetune_vlm
or self.config.finetune_language_model
or self.config.finetune_vision_model
)
def _collect_image_batches(self, batch: dict[str, Tensor]) -> tuple[list[list[Tensor]], Tensor]:
camera_keys = self._camera_keys or sorted(key for key in batch if key.startswith(f"{OBS_IMAGES}."))
if not camera_keys:
raise ValueError("EVO1 requires at least one visual observation feature.")
# Normalize each camera tensor to (B, C, H, W) up-front so that batch_size is read
# from a real batch dim and not from C in the unbatched (C, H, W) case.
normalized: dict[str, Tensor] = {}
for camera_key in camera_keys[: self.config.max_views]:
image = batch[camera_key]
if image.dim() == 3:
image = image.unsqueeze(0)
elif image.dim() == 5:
image = image[:, -1]
elif image.dim() != 4:
raise ValueError(
f"Unsupported image tensor shape for EVO1: key={camera_key} shape={tuple(image.shape)}"
)
normalized[camera_key] = image
batch_size = normalized[camera_keys[0]].shape[0]
image_batches: list[list[Tensor]] = []
image_masks = torch.zeros(batch_size, self.config.max_views, dtype=torch.bool)
for batch_index in range(batch_size):
sample_images: list[Tensor] = []
for camera_key in camera_keys[: self.config.max_views]:
sample_images.append(normalized[camera_key][batch_index].detach().cpu())
if not sample_images:
raise ValueError("EVO1 received a batch without any image tensor.")
while len(sample_images) < self.config.max_views:
sample_images.append(torch.zeros_like(sample_images[0]))
image_batches.append(sample_images[: self.config.max_views])
image_masks[batch_index, : min(len(camera_keys), self.config.max_views)] = True
return image_batches, image_masks
def _compute_fused_tokens(
self,
prompts: list[str],
image_batches: list[list[Tensor]],
image_masks: Tensor,
) -> Tensor:
track_vlm_gradients = self._tracks_vlm_gradients
grad_context = nullcontext() if track_vlm_gradients else torch.no_grad()
embedder = getattr(self.model, "embedder", None)
embedder_was_training = embedder.training if embedder is not None else None
if not track_vlm_gradients and embedder is not None:
embedder.eval()
try:
with grad_context:
fused_tokens = self.model.get_vl_embeddings(
images=image_batches,
image_mask=image_masks,
prompt=prompts,
return_cls_only=self.config.return_cls_only,
)
finally:
if not track_vlm_gradients and embedder is not None and embedder_was_training is not None:
embedder.train(embedder_was_training)
if not track_vlm_gradients:
fused_tokens = fused_tokens.detach()
return fused_tokens.to(device=self.config.device, dtype=self._compute_dtype)
def _compute_masked_loss(
self,
pred_velocity: Tensor,
target_velocity: Tensor,
action_mask: Tensor,
reduction: str,
) -> Tensor:
flat_mask = action_mask.view(action_mask.shape[0], -1).to(dtype=pred_velocity.dtype)
sq_error = ((pred_velocity - target_velocity) * flat_mask).pow(2)
active = flat_mask.sum(dim=1).clamp_min(1.0)
per_sample_loss = sq_error.sum(dim=1) / active
if reduction == "none":
return per_sample_loss
if reduction != "mean":
raise ValueError(f"Unsupported reduction '{reduction}'")
return sq_error.sum() / active.sum()
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
prompts = self._normalize_task_batch(batch)
image_batches, image_masks = self._collect_image_batches(batch)
states, _state_mask = self._prepare_state(batch)
actions_gt, action_mask = self._prepare_actions(batch)
fused_tokens = self._compute_fused_tokens(prompts, image_batches, image_masks)
states = states.to(dtype=self._training_compute_dtype)
actions_gt = actions_gt.to(dtype=self._training_compute_dtype)
fused_tokens = fused_tokens.to(dtype=self._training_compute_dtype)
embodiment_ids = self._get_embodiment_ids(batch, states.shape[0])
pred_velocity, noise = self.model(
fused_tokens,
state=states,
actions_gt=actions_gt,
action_mask=action_mask.to(device=self.config.device, dtype=self._compute_dtype),
embodiment_ids=embodiment_ids,
)
flat_action_mask = action_mask.view(action_mask.shape[0], -1).to(dtype=actions_gt.dtype)
target_velocity = (actions_gt - noise).view(actions_gt.shape[0], -1) * flat_action_mask
loss = self._compute_masked_loss(pred_velocity, target_velocity, action_mask, reduction)
loss_mean = loss.mean().item() if loss.ndim > 0 else loss.item()
return loss, {
"loss": loss_mean,
"active_action_dims": float(action_mask.sum(dim=(1, 2)).float().mean().item()),
}
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
self.eval()
prompts = self._normalize_task_batch(batch)
image_batches, image_masks = self._collect_image_batches(batch)
states, _state_mask = self._prepare_state(batch)
fused_tokens = self._compute_fused_tokens(prompts, image_batches, image_masks)
states = states.to(dtype=self._inference_compute_dtype)
fused_tokens = fused_tokens.to(dtype=self._inference_compute_dtype)
embodiment_ids = self._get_embodiment_ids(batch, states.shape[0])
action_mask = self._prepare_inference_action_mask(states.shape[0])
with (
torch.autocast(device_type="cuda", dtype=torch.bfloat16)
if self.config.use_amp and str(self.config.device).startswith("cuda")
else nullcontext()
):
actions = self.model(
fused_tokens,
state=states,
action_mask=action_mask,
embodiment_ids=embodiment_ids,
)
actions = actions.view(states.shape[0], self.config.chunk_size, self.config.max_action_dim)
return actions[:, :, : self._env_action_dim]
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
self.eval()
if len(self._action_queue) == 0:
action_chunk = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
self._action_queue.extend(action_chunk.transpose(0, 1))
return self._action_queue.popleft()
+106
View File
@@ -0,0 +1,106 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
import torch
from lerobot.policies.evo1.configuration_evo1 import Evo1Config
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import (
batch_to_transition,
create_transition,
policy_action_to_transition,
transition_to_policy_action,
)
from lerobot.utils.constants import (
ACTION,
DONE,
INFO,
OBS_PREFIX,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
REWARD,
TRUNCATED,
)
def evo1_batch_to_transition(batch: dict[str, Any]):
transition = batch_to_transition(batch)
complementary_data = dict(transition.get("complementary_data") or {})
reserved = {ACTION, REWARD, DONE, TRUNCATED, INFO}
for key, value in batch.items():
if key in reserved or key.startswith(OBS_PREFIX):
continue
complementary_data.setdefault(key, value)
return create_transition(
observation=transition.get("observation"),
action=transition.get("action"),
reward=transition.get("reward", 0.0),
done=transition.get("done", False),
truncated=transition.get("truncated", False),
info=transition.get("info", {}),
complementary_data=complementary_data,
)
def make_evo1_pre_post_processors(
config: Evo1Config,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeviceProcessorStep(device=config.device),
]
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
to_transition=evo1_batch_to_transition,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
+16 -2
View File
@@ -47,6 +47,7 @@ from lerobot.utils.feature_utils import dataset_to_policy_features
from .act.configuration_act import ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig
from .eo1.configuration_eo1 import EO1Config
from .evo1.configuration_evo1 import Evo1Config
from .groot.configuration_groot import GrootConfig
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config
@@ -88,7 +89,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "smolvla", "wall_x".
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "smolvla", "wall_x", "eo1", "evo1".
Returns:
The policy class corresponding to the given name.
@@ -151,6 +152,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .eo1.modeling_eo1 import EO1Policy
return EO1Policy
elif name == "evo1":
from .evo1.modeling_evo1 import EVO1Policy
return EVO1Policy
else:
try:
return _get_policy_cls_from_policy_name(name=name)
@@ -168,7 +173,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
"smolvla", "wall_x".
"smolvla", "wall_x", "eo1", "evo1".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -203,6 +208,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return WallXConfig(**kwargs)
elif policy_type == "eo1":
return EO1Config(**kwargs)
elif policy_type == "evo1":
return Evo1Config(**kwargs)
else:
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
@@ -413,6 +420,13 @@ def make_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, Evo1Config):
from .evo1.processor_evo1 import make_evo1_pre_post_processors
processors = make_evo1_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
try:
+2 -1
View File
@@ -40,7 +40,7 @@ from .converters import (
)
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
from .device_processor import DeviceProcessorStep
from .env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep
from .env_processor import IsaaclabArenaProcessorStep, LiberoActionProcessorStep, LiberoProcessorStep
from .factory import (
make_default_processors,
make_default_robot_action_processor,
@@ -149,6 +149,7 @@ __all__ = [
"RewardProcessorStep",
"DataProcessorPipeline",
"IsaaclabArenaProcessorStep",
"LiberoActionProcessorStep",
"LiberoProcessorStep",
"TimeLimitProcessorStep",
"AddBatchDimensionProcessorStep",
+43 -3
View File
@@ -18,9 +18,9 @@ from dataclasses import dataclass
import torch
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
from .pipeline import ActionProcessorStep, ObservationProcessorStep, ProcessorStepRegistry
@dataclass
@@ -46,6 +46,8 @@ class LiberoProcessorStep(ObservationProcessorStep):
- This accounts for the HuggingFaceVLA/libero camera orientation convention.
"""
max_state_dim: int | None = None
def _process_observation(self, observation):
"""
Processes both image and robot_state observations from LIBERO.
@@ -78,6 +80,15 @@ class LiberoProcessorStep(ObservationProcessorStep):
state = state.float()
if state.dim() == 1:
state = state.unsqueeze(0)
if self.max_state_dim is not None:
if state.shape[-1] > self.max_state_dim:
raise ValueError(
f"LIBERO state has {state.shape[-1]} dims, which is larger than "
f"configured max_state_dim={self.max_state_dim}."
)
if state.shape[-1] < self.max_state_dim:
pad_width = self.max_state_dim - state.shape[-1]
state = torch.nn.functional.pad(state, (0, pad_width))
processed_obs[OBS_STATE] = state
return processed_obs
@@ -101,7 +112,7 @@ class LiberoProcessorStep(ObservationProcessorStep):
# add our new flattened state
state_feats[OBS_STATE] = PolicyFeature(
type=FeatureType.STATE,
shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)]
shape=(self.max_state_dim or 8,), # [eef_pos(3), axis_angle(3), gripper(2)] plus padding
)
new_features[FeatureType.STATE] = state_feats
@@ -111,6 +122,9 @@ class LiberoProcessorStep(ObservationProcessorStep):
def observation(self, observation):
return self._process_observation(observation)
def get_config(self) -> dict:
return {"max_state_dim": self.max_state_dim}
def _quat2axisangle(self, quat: torch.Tensor) -> torch.Tensor:
"""
Convert batched quaternions to axis-angle format.
@@ -153,6 +167,32 @@ class LiberoProcessorStep(ObservationProcessorStep):
return result
@dataclass
@ProcessorStepRegistry.register(name="libero_action_processor")
class LiberoActionProcessorStep(ActionProcessorStep):
"""Slices padded policy actions back to the executable LIBERO action space."""
action_dim: int = 7
def action(self, action):
if action.shape[-1] < self.action_dim:
raise ValueError(
f"LIBERO action has {action.shape[-1]} dims, which is smaller than action_dim={self.action_dim}."
)
return action[..., : self.action_dim]
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
new_features = {ft: feats.copy() for ft, feats in features.items()}
action_feats = new_features.setdefault(FeatureType.ACTION, {})
action_feats[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(self.action_dim,))
return new_features
def get_config(self) -> dict:
return {"action_dim": self.action_dim}
@dataclass
@ProcessorStepRegistry.register(name="isaaclab_arena_processor")
class IsaaclabArenaProcessorStep(ObservationProcessorStep):
+79 -2
View File
@@ -7,11 +7,14 @@ from dataclasses import dataclass, field
import gymnasium as gym
import pytest
import torch
from gymnasium.envs.registration import register, registry as gym_registry
from lerobot.configs.types import PolicyFeature
from lerobot.envs.configs import EnvConfig
from lerobot.envs.configs import EnvConfig, LiberoEnv
from lerobot.envs.factory import make_env, make_env_config, make_env_pre_post_processors
from lerobot.processor import LiberoActionProcessorStep, LiberoProcessorStep
from lerobot.utils.constants import OBS_PREFIX, OBS_STATE
logger = logging.getLogger(__name__)
@@ -61,6 +64,80 @@ def test_processors_delegation():
assert len(pre.steps) == 0
def test_processors_delegation_supports_legacy_override_signature():
"""External EnvConfig subclasses with the old get_env_processors() signature keep working."""
from lerobot.processor.pipeline import DataProcessorPipeline
@EnvConfig.register_subclass("_dispatch_legacy_proc_test")
@dataclass
class _Env(EnvConfig):
task: str = "x"
features: dict[str, PolicyFeature] = field(default_factory=dict)
@property
def gym_kwargs(self):
return {}
def get_env_processors(self):
return DataProcessorPipeline(steps=[]), DataProcessorPipeline(steps=[])
pre, post = make_env_pre_post_processors(_Env(), policy_cfg=object())
assert isinstance(pre, DataProcessorPipeline)
assert isinstance(post, DataProcessorPipeline)
def test_libero_evo1_processors_use_padded_state_and_env_action_dim():
"""EVO1 uses padded LIBERO state features while env actions stay executable."""
class _Evo1Config:
type = "evo1"
max_state_dim = 24
cfg = LiberoEnv()
pre, post = make_env_pre_post_processors(cfg, policy_cfg=_Evo1Config())
assert isinstance(pre.steps[0], LiberoProcessorStep)
assert pre.steps[0].max_state_dim == 24
assert isinstance(post.steps[0], LiberoActionProcessorStep)
assert post.steps[0].action_dim == cfg.features["action"].shape[0] == 7
class _OtherConfig:
type = "other"
pre_other, _ = make_env_pre_post_processors(cfg, policy_cfg=_OtherConfig())
assert pre_other.steps[0].max_state_dim is None
def test_libero_processor_pads_state_to_max_dim():
step = LiberoProcessorStep(max_state_dim=24)
observation = {
OBS_PREFIX
+ "robot_state": {
"eef": {
"pos": torch.tensor([[1.0, 2.0, 3.0]]),
"quat": torch.tensor([[0.0, 0.0, 0.0, 1.0]]),
},
"gripper": {"qpos": torch.tensor([[4.0, 5.0]])},
}
}
state = step.observation(observation)[OBS_STATE]
assert state.shape == (1, 24)
assert torch.allclose(state[:, :8], torch.tensor([[1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 4.0, 5.0]]))
assert torch.count_nonzero(state[:, 8:]).item() == 0
def test_libero_action_processor_slices_padded_action():
step = LiberoActionProcessorStep(action_dim=7)
action = torch.arange(2 * 3 * 24, dtype=torch.float32).reshape(2, 3, 24)
sliced = step.action(action)
assert sliced.shape == (2, 3, 7)
assert torch.equal(sliced, action[..., :7])
with pytest.raises(ValueError, match="smaller than action_dim=7"):
step.action(torch.zeros(2, 6))
def test_base_create_envs():
"""Base class create_envs() should build a single-task VectorEnv via gym.make()."""
gym_id = "_dispatch_test/CartPole-v99"
@@ -136,7 +213,7 @@ def test_custom_get_env_processors_override():
def gym_kwargs(self):
return {}
def get_env_processors(self):
def get_env_processors(self, policy_cfg=None):
return DataProcessorPipeline(steps=[]), DataProcessorPipeline(steps=[])
pre, post = _Env().get_env_processors()
+298
View File
@@ -0,0 +1,298 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
from torch import nn
import lerobot.policies.evo1.modeling_evo1 as modeling_evo1
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.evo1.configuration_evo1 import Evo1Config
from lerobot.policies.evo1.flow_matching import FlowmatchingActionHead
from lerobot.policies.factory import get_policy_class, make_policy_config
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
STATE_DIM = 4
ACTION_DIM = 3
MAX_STATE_DIM = 6
MAX_ACTION_DIM = 5
CHUNK_SIZE = 2
EMBED_DIM = 8
class DummyEVO1(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.embedder = nn.Dropout(p=0.0)
self.action_head = nn.Linear(1, 1)
self.get_vl_embeddings_calls = 0
self.grad_enabled_calls = []
self.embedder_training_calls = []
def set_finetune_flags(self):
return None
def get_vl_embeddings(self, images, image_mask, prompt=None, return_cls_only=False):
self.get_vl_embeddings_calls += 1
self.grad_enabled_calls.append(torch.is_grad_enabled())
self.embedder_training_calls.append(self.embedder.training)
return torch.ones(len(images), 4, EMBED_DIM, requires_grad=torch.is_grad_enabled())
def forward(
self,
fused_tokens,
state=None,
actions_gt=None,
action_mask=None,
embodiment_ids=None,
):
batch_size = fused_tokens.shape[0]
if actions_gt is None:
return torch.ones(batch_size, CHUNK_SIZE * MAX_ACTION_DIM)
pred_velocity = torch.zeros(batch_size, CHUNK_SIZE * MAX_ACTION_DIM)
noise = torch.zeros_like(actions_gt)
return pred_velocity, noise
def make_config(training_stage="stage1", **kwargs):
config_kwargs = {
"device": "cpu",
"vlm_model_name": "dummy-internvl3",
"training_stage": training_stage,
"chunk_size": CHUNK_SIZE,
"n_action_steps": 1,
"max_state_dim": MAX_STATE_DIM,
"max_action_dim": MAX_ACTION_DIM,
"max_views": 2,
"embed_dim": EMBED_DIM,
"hidden_dim": 16,
"state_hidden_dim": 16,
"num_heads": 2,
"num_layers": 1,
"num_inference_timesteps": 2,
"input_features": {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
f"{OBS_IMAGES}.front": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)),
},
"output_features": {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,)),
},
}
config_kwargs.update(kwargs)
return Evo1Config(**config_kwargs)
def make_batch(include_action=True):
batch = {
"task": ["pick the block", "place the block"],
OBS_STATE: torch.randn(2, STATE_DIM),
f"{OBS_IMAGES}.front": torch.rand(2, 3, 16, 16),
}
if include_action:
batch[ACTION] = torch.randn(2, CHUNK_SIZE, ACTION_DIM)
return batch
def test_evo1_factory_registration():
cfg = make_policy_config(
"evo1",
device="cpu",
vlm_model_name="dummy-internvl3",
input_features={
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
f"{OBS_IMAGES}.front": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)),
},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))},
)
assert isinstance(cfg, Evo1Config)
assert get_policy_class("evo1") is modeling_evo1.EVO1Policy
def test_evo1_stage_defaults_and_consistency():
stage1 = make_config(training_stage="stage1")
assert (stage1.finetune_vlm, stage1.finetune_language_model, stage1.finetune_vision_model) == (
False,
False,
False,
)
assert stage1.finetune_action_head is True
stage2 = make_config(training_stage="stage2")
assert (stage2.finetune_vlm, stage2.finetune_language_model, stage2.finetune_vision_model) == (
True,
True,
True,
)
assert stage2.finetune_action_head is True
stage2_from_stage1_checkpoint_flags = make_config(
training_stage="stage2",
finetune_vlm=False,
finetune_language_model=False,
finetune_vision_model=False,
finetune_action_head=False,
)
assert (
stage2_from_stage1_checkpoint_flags.finetune_vlm,
stage2_from_stage1_checkpoint_flags.finetune_language_model,
stage2_from_stage1_checkpoint_flags.finetune_vision_model,
) == (
True,
True,
True,
)
assert stage2_from_stage1_checkpoint_flags.finetune_action_head is True
explicit_off = make_config(
training_stage="stage2",
apply_training_stage_defaults=False,
finetune_vlm=False,
finetune_language_model=False,
finetune_vision_model=False,
finetune_action_head=False,
)
assert (
explicit_off.finetune_vlm,
explicit_off.finetune_language_model,
explicit_off.finetune_vision_model,
) == (
False,
False,
False,
)
assert explicit_off.finetune_action_head is False
try:
make_config(
training_stage="stage2",
apply_training_stage_defaults=False,
finetune_vlm=True,
finetune_language_model=False,
)
except ValueError as exc:
assert "Inconsistent EVO1 finetune config" in str(exc)
else:
raise AssertionError("Expected inconsistent finetune config to raise ValueError")
def test_evo1_policy_forward_and_inference_use_batched_embedding(monkeypatch):
monkeypatch.setattr(modeling_evo1, "EVO1", DummyEVO1)
policy = modeling_evo1.EVO1Policy(make_config())
loss, metrics = policy.forward(make_batch(include_action=True))
assert loss.ndim == 0
assert torch.isfinite(loss)
assert metrics["active_action_dims"] == ACTION_DIM * CHUNK_SIZE
assert policy.model.get_vl_embeddings_calls == 1
action_chunk = policy.predict_action_chunk(make_batch(include_action=False))
assert action_chunk.shape == (2, CHUNK_SIZE, ACTION_DIM)
policy.reset()
selected = policy.select_action(make_batch(include_action=False))
assert selected.shape == (2, ACTION_DIM)
def test_stage1_frozen_vlm_embeddings_do_not_track_gradients(monkeypatch):
monkeypatch.setattr(modeling_evo1, "EVO1", DummyEVO1)
policy = modeling_evo1.EVO1Policy(make_config(training_stage="stage1"))
policy.train()
image_batches, image_masks = policy._collect_image_batches(make_batch(include_action=False))
fused_tokens = policy._compute_fused_tokens(["pick", "place"], image_batches, image_masks)
assert policy.model.grad_enabled_calls == [False]
assert policy.model.embedder_training_calls == [False]
assert not fused_tokens.requires_grad
assert policy.model.embedder.training is True
def test_stage2_vlm_embeddings_track_gradients(monkeypatch):
monkeypatch.setattr(modeling_evo1, "EVO1", DummyEVO1)
policy = modeling_evo1.EVO1Policy(make_config(training_stage="stage2"))
policy.train()
image_batches, image_masks = policy._collect_image_batches(make_batch(include_action=False))
fused_tokens = policy._compute_fused_tokens(["pick", "place"], image_batches, image_masks)
assert policy.model.grad_enabled_calls == [True]
assert policy.model.embedder_training_calls == [True]
assert fused_tokens.requires_grad
def test_collect_image_batches_handles_unbatched_chw(monkeypatch):
# Regression for an issue where batch_size was read from shape[0] before normalizing
# per-camera tensor dims, so an unbatched (C, H, W) input was treated as batch_size=C.
monkeypatch.setattr(modeling_evo1, "EVO1", DummyEVO1)
policy = modeling_evo1.EVO1Policy(make_config())
batch = {
OBS_STATE: torch.randn(1, STATE_DIM),
f"{OBS_IMAGES}.front": torch.rand(3, 16, 16),
}
image_batches, image_masks = policy._collect_image_batches(batch)
assert len(image_batches) == 1
assert len(image_batches[0]) == policy.config.max_views
assert image_masks.tolist() == [[True, False]]
def test_evo1_action_mask_accepts_chunk_size_one(monkeypatch):
monkeypatch.setattr(modeling_evo1, "EVO1", DummyEVO1)
config = make_config(chunk_size=1, n_action_steps=1)
policy = modeling_evo1.EVO1Policy(config)
batch = make_batch(include_action=True)
batch[ACTION] = torch.randn(2, ACTION_DIM)
batch["action_mask"] = torch.ones(2, ACTION_DIM, dtype=torch.bool)
actions, action_mask = policy._prepare_actions(batch)
assert actions.shape == (2, 1, MAX_ACTION_DIM)
assert action_mask.shape == (2, 1, MAX_ACTION_DIM)
assert action_mask[:, :, :ACTION_DIM].all()
assert not action_mask[:, :, ACTION_DIM:].any()
def test_flowmatching_dict_config_enables_state_encoder_for_horizon_one():
head = FlowmatchingActionHead(
config={
"embed_dim": EMBED_DIM,
"hidden_dim": 16,
"action_dim": ACTION_DIM,
"horizon": 1,
"per_action_dim": ACTION_DIM,
"num_heads": 2,
"num_layers": 1,
"num_inference_timesteps": 2,
"state_dim": STATE_DIM,
"state_hidden_dim": 16,
"num_categories": 1,
}
)
assert head.state_encoder is not None
pred_velocity, noise = head(
torch.randn(2, 4, EMBED_DIM),
state=torch.randn(2, STATE_DIM),
actions_gt=torch.randn(2, 1, ACTION_DIM),
action_mask=torch.ones(2, 1, ACTION_DIM, dtype=torch.bool),
)
assert pred_velocity.shape == (2, ACTION_DIM)
assert noise.shape == (2, 1, ACTION_DIM)
Generated
+9 -1
View File
@@ -2729,6 +2729,7 @@ all = [
{ name = "scikit-image" },
{ name = "scipy" },
{ name = "teleop" },
{ name = "timm" },
{ name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" },
{ name = "torchdiffeq" },
{ name = "transformers" },
@@ -2820,6 +2821,10 @@ eo1 = [
evaluation = [
{ name = "av" },
]
evo1 = [
{ name = "timm" },
{ name = "transformers" },
]
feetech = [
{ name = "deepdiff" },
{ name = "feetech-servo-sdk" },
@@ -3082,6 +3087,7 @@ requires-dist = [
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'multi-task-dit'" },
{ name = "lerobot", extras = ["diffusion"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["dynamixel"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["evo1"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'hopejr'" },
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'lekiwi'" },
@@ -3139,6 +3145,7 @@ requires-dist = [
{ name = "lerobot", extras = ["test"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["training"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'eo1'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'evo1'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'groot'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'hilserl'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'libero'" },
@@ -3199,6 +3206,7 @@ requires-dist = [
{ name = "setuptools", specifier = ">=71.0.0,<81.0.0" },
{ name = "teleop", marker = "extra == 'phone'", specifier = ">=0.1.0,<0.2.0" },
{ name = "termcolor", specifier = ">=2.4.0,<4.0.0" },
{ name = "timm", marker = "extra == 'evo1'", specifier = ">=1.0.0,<1.1.0" },
{ name = "timm", marker = "extra == 'groot'", specifier = ">=1.0.0,<1.1.0" },
{ name = "torch", marker = "sys_platform != 'linux'", specifier = ">=2.7,<2.12.0" },
{ name = "torch", marker = "sys_platform == 'linux'", specifier = ">=2.7,<2.12.0", index = "https://download.pytorch.org/whl/cu128" },
@@ -3210,7 +3218,7 @@ requires-dist = [
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" },
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
]
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "xvla", "eo1", "hilserl", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "xvla", "eo1", "evo1", "hilserl", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
[[package]]
name = "librt"