feat(policies): add EVO1 policy

javadcc_mac
2026-05-09 21:39:19 +08:00
parent 0e6114ac36
commit 8df8d3d866
13 changed files with 2190 additions and 2 deletions
@@ -49,6 +49,8 @@
title: π₀.₅ (Pi05)
- local: eo1
title: EO-1
- local: evo1
title: EVO1
- local: groot
title: NVIDIA GR00T N1.5
- local: xvla
@@ -0,0 +1,132 @@
# EVO1
EVO1 is a Vision-Language-Action policy for robot control built around an InternVL3 backbone and a continuous flow-matching action head. This LeRobot integration exposes EVO1 as a standard policy type so it can be trained and evaluated with the usual LeRobot dataset, checkpoint, and processor APIs.
## Model Overview
The policy embeds one or more camera images and the language task prompt with InternVL3, pads robot state/action vectors to fixed maximum dimensions, and predicts future action chunks with a flow-matching action head. During inference, the policy samples an action chunk and returns `n_action_steps` actions from that chunk before sampling again.
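As a rough sketch of this chunk-and-consume loop (illustrative only, not the actual `EVO1Policy` internals; `predict_chunk` is a hypothetical stand-in for the flow-matching sampler):
```python
from collections import deque

import torch


def make_chunked_controller(predict_chunk, n_action_steps: int = 50):
    """Re-sample an action chunk only when the queue of pending actions is empty."""
    queue = deque(maxlen=n_action_steps)

    def select_action(observation):
        if not queue:
            chunk = predict_chunk(observation)    # (chunk_size, action_dim) tensor
            queue.extend(chunk[:n_action_steps])  # keep only the first n_action_steps actions
        return queue.popleft()

    return select_action


# Toy usage: a dummy sampler that always predicts a zero chunk for a 6-DoF robot.
controller = make_chunked_controller(lambda obs: torch.zeros(50, 6))
first_action = controller({"observation.state": torch.zeros(6)})
```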
### What the LeRobot Integration Covers
- Standard `policy.type=evo1` configuration through LeRobot
- InternVL3 image/text embedding with optional FlashAttention fallback
- Stage-based finetuning controls for action-head-only and VLM finetuning runs
- Continuous flow-matching action prediction
- Checkpoint save/load through LeRobot policy APIs
- Training with `lerobot-train` and evaluation with standard policy inference APIs
The broader EVO1 project may include additional training scripts and dataset tooling. This page focuses on the LeRobot robot-control policy path.
## Installation Requirements
1. Install LeRobot by following the [Installation Guide](./installation).
2. Install EVO1 dependencies:
```bash
pip install -e ".[evo1]"
```
3. Install a `flash-attn` wheel only if it is compatible with your Python, PyTorch, CUDA, and GPU stack. EVO1 falls back to standard attention when `flash_attn` is not available.
EVO1 uses InternVL3 through the Hugging Face `transformers` remote-code path, so the first run may download the configured VLM checkpoint unless `policy.vlm_model_name` points to a local model directory.
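If the training machine cannot reach the Hub, one option (a sketch using `huggingface_hub`, not an EVO1-specific helper) is to pre-download the VLM checkpoint and pass the local path through `policy.vlm_model_name`:
```python
from huggingface_hub import snapshot_download

# Download the InternVL3 checkpoint once and reuse the returned local path,
# e.g. --policy.vlm_model_name=<local_dir> on the training command line.
local_dir = snapshot_download("OpenGVLab/InternVL3-1B")
print(local_dir)
```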
## Data Requirements
EVO1 expects a LeRobot dataset with:
- One to `policy.max_views` visual observations, for example `observation.images.image`
- `observation.state`
- `action`
- A language task instruction in the dataset `task` field, or another field configured with `policy.task_field`
State and action vectors are zero-padded to `policy.max_state_dim` and `policy.max_action_dim`; predicted actions are trimmed back to the dataset action dimension before being returned.
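The padding convention can be sketched as follows (illustrative only; the policy applies equivalent zero-padding and boolean validity masks internally):
```python
import torch


def pad_to_max(vec: torch.Tensor, max_dim: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Zero-pad a (batch, dim) tensor to (batch, max_dim) and return a validity mask."""
    batch, dim = vec.shape
    padded = torch.zeros(batch, max_dim, dtype=vec.dtype)
    padded[:, :dim] = vec
    mask = torch.zeros(batch, max_dim, dtype=torch.bool)
    mask[:, :dim] = True
    return padded, mask


# Example: a 14-dimensional state padded to max_state_dim=24.
state, state_mask = pad_to_max(torch.randn(4, 14), max_dim=24)
```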
## Usage
To use EVO1 in a LeRobot configuration, specify:
```bash
policy.type=evo1
```
By default, a new EVO1 policy initializes its VLM from:
```bash
policy.vlm_model_name=OpenGVLab/InternVL3-1B
```
Once a LeRobot-format EVO1 checkpoint is available, load it with:
```bash
policy.path=your-org/your-evo1-checkpoint
```
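A minimal inference sketch, assuming the standard LeRobot policy inference API and the placeholder checkpoint id above (shapes, the camera key, and the task string are illustrative; the pre/post processors that handle normalization are omitted):
```python
import torch

from lerobot.policies.evo1.modeling_evo1 import EVO1Policy

policy = EVO1Policy.from_pretrained("your-org/your-evo1-checkpoint")
policy.eval()

# One camera view, a state vector, and a language instruction, as described above.
batch = {
    "observation.images.image": torch.zeros(1, 3, 448, 448),
    "observation.state": torch.zeros(1, 14),
    "task": ["pick up the cube"],
}
with torch.no_grad():
    action = policy.select_action(batch)  # one action from the current chunk
```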
## Training
### Stage 1
Stage 1 freezes the VLM and trains the action head:
```bash
lerobot-train \
--dataset.repo_id=your_org/your_dataset \
--policy.type=evo1 \
--policy.training_stage=stage1 \
--policy.vlm_model_name=OpenGVLab/InternVL3-1B \
--policy.device=cuda \
--policy.chunk_size=50 \
--policy.n_action_steps=50 \
--policy.max_state_dim=24 \
--policy.max_action_dim=24 \
--policy.optimizer_lr=1e-5 \
--batch_size=4 \
--steps=5000 \
--output_dir=./outputs/evo1_stage1
```
### Stage 2
Stage 2 finetunes the VLM branches and action head. A common workflow starts from a Stage 1 checkpoint:
```bash
lerobot-train \
--dataset.repo_id=your_org/your_dataset \
--policy.path=./outputs/evo1_stage1/checkpoints/005000/pretrained_model \
--policy.training_stage=stage2 \
--policy.vlm_model_name=OpenGVLab/InternVL3-1B \
--policy.device=cuda \
--policy.chunk_size=50 \
--policy.n_action_steps=50 \
--policy.max_state_dim=24 \
--policy.max_action_dim=24 \
--policy.optimizer_lr=1e-5 \
--batch_size=4 \
--steps=80000 \
--output_dir=./outputs/evo1_stage2
```
### Key Training Parameters
| Parameter | Default | Description |
| --------------------------------------------- | ------------------------ | ----------------------------------------------------------------- |
| `policy.vlm_model_name` | `OpenGVLab/InternVL3-1B` | InternVL3 checkpoint or local model directory |
| `policy.training_stage` | `stage1` | `stage1` trains the action head; `stage2` finetunes VLM branches |
| `policy.vlm_num_layers` | `14` | Number of InternVL3 language layers kept for the policy |
| `policy.vlm_dtype` | `bfloat16` | Requested VLM dtype |
| `policy.use_flash_attn` | `true` | Requests FlashAttention when installed; otherwise falls back |
| `policy.enable_gradient_checkpointing` | `true` | Enables checkpointing on supported InternVL3 modules |
| `policy.gradient_checkpointing_use_reentrant` | `false` | Reentrant setting passed to gradient checkpointing when supported |
| `policy.chunk_size` | `50` | Number of future actions predicted per chunk |
| `policy.n_action_steps` | `50` | Number of actions consumed from a sampled chunk |
| `policy.max_state_dim` | `24` | State padding dimension |
| `policy.max_action_dim` | `24` | Action padding dimension |
| `policy.task_field` | `task` | Batch field used as the language prompt |
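The stage defaults in the table come from `Evo1Config.__post_init__` (included later in this diff); a quick way to inspect them (a sketch, assuming LeRobot is installed with the `evo1` extra):
```python
from lerobot.policies.evo1.configuration_evo1 import Evo1Config

stage1 = Evo1Config(training_stage="stage1")
# stage1 freezes the VLM and trains only the action head.
print(stage1.finetune_vlm, stage1.finetune_action_head)  # False True

stage2 = Evo1Config(training_stage="stage2")
# stage2 also unfreezes the language and vision branches.
print(stage2.finetune_language_model, stage2.finetune_vision_model)  # True True
```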
## References
- [EVO1 repository](https://github.com/MINT-SJTU/Evo-1)
- [InternVL3-1B](https://huggingface.co/OpenGVLab/InternVL3-1B)
## License
This LeRobot integration follows the Apache 2.0 License used by LeRobot. Check the upstream EVO1 and InternVL3 model pages for the licenses of released checkpoints and data.
@@ -195,6 +195,7 @@ groot = [
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
evo1 = ["lerobot[transformers-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features
@@ -334,6 +335,7 @@ ignore = [
# E402: conditional-import guards (TYPE_CHECKING / is_package_available) must precede the imports they protect
"src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"]
"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original
"src/lerobot/policies/evo1/**" = ["N801", "N812"]
[tool.ruff.lint.isort]
combine-as-imports = true
@@ -17,6 +17,7 @@ from lerobot.utils.action_interpolator import ActionInterpolator as ActionInterp
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .eo1.configuration_eo1 import EO1Config as EO1Config
from .evo1.configuration_evo1 import Evo1Config as Evo1Config
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
from .groot.configuration_groot import GrootConfig as GrootConfig
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
@@ -40,6 +41,7 @@ __all__ = [
# Configuration classes
"ACTConfig",
"DiffusionConfig",
"Evo1Config",
"GrootConfig",
"MultiTaskDiTConfig",
"EO1Config",
@@ -0,0 +1,19 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_evo1 import Evo1Config
from .modeling_evo1 import EVO1Policy
from .processor_evo1 import make_evo1_pre_post_processors
__all__ = ["Evo1Config", "EVO1Policy", "make_evo1_pre_post_processors"]
@@ -0,0 +1,211 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from dataclasses import dataclass, field
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
@LRSchedulerConfig.register_subclass("evo1_exact")
@dataclass
class Evo1SchedulerConfig(LRSchedulerConfig):
num_warmup_steps: int
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
def lr_lambda(current_step: int) -> float:
if current_step < self.num_warmup_steps:
return current_step / max(1, self.num_warmup_steps)
progress = (current_step - self.num_warmup_steps) / max(
1, num_training_steps - self.num_warmup_steps
)
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
return LambdaLR(optimizer, lr_lambda, -1)
@PreTrainedConfig.register_subclass("evo1")
@dataclass
class Evo1Config(PreTrainedConfig):
training_stage: str = "stage1"
use_amp: bool = True
n_obs_steps: int = 1
chunk_size: int = 50
n_action_steps: int = 50
max_state_dim: int = 24
max_action_dim: int = 24
max_views: int = 3
image_resolution: tuple[int, int] = (448, 448)
empty_cameras: int = 0
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
)
vlm_model_name: str = "OpenGVLab/InternVL3-1B"
vlm_num_layers: int | None = 14
vlm_dtype: str = "bfloat16"
use_flash_attn: bool = True
action_head: str = "flowmatching"
embed_dim: int = 896
hidden_dim: int = 1024
state_hidden_dim: int = 1024
num_heads: int = 8
num_layers: int = 8
dropout: float = 0.0
num_inference_timesteps: int = 32
num_categories: int = 1
return_cls_only: bool = False
enable_gradient_checkpointing: bool = True
gradient_checkpointing_use_reentrant: bool = False
finetune_vlm: bool | None = None
finetune_language_model: bool | None = None
finetune_vision_model: bool | None = None
finetune_action_head: bool | None = None
task_field: str = "task"
embodiment_id_field: str | None = None
default_embodiment_id: int = 0
optimizer_lr: float = 1e-5
optimizer_betas: tuple[float, float] = (0.9, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-5
optimizer_grad_clip_norm: float = 1.0
scheduler_warmup_steps: int = 300
drop_last: bool = True
def __post_init__(self):
super().__post_init__()
if self.training_stage not in {"stage1", "stage2"}:
raise ValueError(
f"Unsupported EVO1 training_stage '{self.training_stage}', expected 'stage1' or 'stage2'"
)
if self.training_stage == "stage1":
if self.finetune_vlm is None:
self.finetune_vlm = False
if self.finetune_language_model is None:
self.finetune_language_model = False
if self.finetune_vision_model is None:
self.finetune_vision_model = False
if self.finetune_action_head is None:
self.finetune_action_head = True
elif self.training_stage == "stage2":
has_explicit_branch_flags = any(
flag is not None for flag in (self.finetune_language_model, self.finetune_vision_model)
)
if not has_explicit_branch_flags:
if self.finetune_vlm is None:
self.finetune_vlm = True
if self.finetune_language_model is None:
self.finetune_language_model = True
if self.finetune_vision_model is None:
self.finetune_vision_model = True
elif self.finetune_vlm is None:
self.finetune_vlm = bool(self.finetune_language_model or self.finetune_vision_model)
if self.finetune_action_head is None:
self.finetune_action_head = True
if self.finetune_vlm is None:
self.finetune_vlm = False
if self.finetune_language_model is None:
self.finetune_language_model = False
if self.finetune_vision_model is None:
self.finetune_vision_model = False
if self.finetune_action_head is None:
self.finetune_action_head = False
branch_vlm = self.finetune_language_model or self.finetune_vision_model
if self.finetune_vlm != branch_vlm:
raise ValueError(
"Inconsistent EVO1 finetune config: "
f"finetune_vlm={self.finetune_vlm} but "
f"(finetune_language_model or finetune_vision_model)={branch_vlm}. "
"When branch-level flags are used, finetune_vlm must match their effective union."
)
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"n_action_steps ({self.n_action_steps}) must be <= chunk_size ({self.chunk_size})"
)
def validate_features(self) -> None:
if self.input_features is None:
self.input_features = {}
if self.output_features is None:
self.output_features = {}
for i in range(self.empty_cameras):
key = OBS_IMAGES + f".empty_camera_{i}"
if key not in self.input_features:
self.input_features[key] = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, *self.image_resolution),
)
if OBS_STATE not in self.input_features:
self.input_features[OBS_STATE] = PolicyFeature(
type=FeatureType.STATE,
shape=(self.max_state_dim,),
)
if ACTION not in self.output_features:
self.output_features[ACTION] = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.max_action_dim,),
)
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self):
return Evo1SchedulerConfig(
num_warmup_steps=self.scheduler_warmup_steps,
)
@property
def observation_delta_indices(self) -> list[int]:
return [0]
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None
@@ -0,0 +1,234 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
from typing import Any
import torch
import torch.nn as nn
from PIL import Image
from lerobot.policies.evo1.flow_matching import FlowmatchingActionHead
from lerobot.policies.evo1.internvl3_embedder import InternVL3Embedder
def _cfgget(config: Any, key: str, default=None):
if isinstance(config, dict):
return config.get(key, default)
return getattr(config, key, default)
class EVO1(nn.Module):
def __init__(self, config: dict):
super().__init__()
self.config = config
self._device = _cfgget(config, "device", "cuda")
self.return_cls_only = _cfgget(config, "return_cls_only", False)
vlm_name = _cfgget(config, "vlm_name", "OpenGVLab/InternVL3-1B")
image_size = _cfgget(config, "image_size", 448)
if image_size is None:
image_resolution = _cfgget(config, "image_resolution", (448, 448))
image_size = int(image_resolution[0])
self.embedder = InternVL3Embedder(
model_name=vlm_name,
image_size=image_size,
device=self._device,
num_language_layers=_cfgget(config, "vlm_num_layers", 14),
model_dtype=_cfgget(config, "vlm_dtype", "bfloat16"),
use_flash_attn=_cfgget(config, "use_flash_attn", True),
enable_gradient_checkpointing=_cfgget(config, "enable_gradient_checkpointing", True),
gradient_checkpointing_use_reentrant=_cfgget(
config, "gradient_checkpointing_use_reentrant", False
),
)
action_head_type = _cfgget(config, "action_head", "flowmatching").lower()
if action_head_type != "flowmatching":
raise NotImplementedError(f"Unknown action_head: {action_head_type}")
horizon = _cfgget(config, "action_horizon", _cfgget(config, "horizon", 16))
per_action_dim = _cfgget(config, "per_action_dim", 7)
action_dim = horizon * per_action_dim
if isinstance(config, dict):
config["horizon"] = horizon
config["per_action_dim"] = per_action_dim
config["action_dim"] = action_dim
self.horizon = horizon
self.per_action_dim = per_action_dim
self.action_head = FlowmatchingActionHead(config=config).to(self._device)
def _normalize_image_batches(
self,
images: Sequence[Image.Image | torch.Tensor] | Sequence[Sequence[Image.Image | torch.Tensor]],
prompt: str | list[str] | None,
image_mask: torch.Tensor,
) -> tuple[list[list[Image.Image | torch.Tensor]], list[str], torch.Tensor]:
if not images:
raise ValueError("EVO1 expects at least one image per sample.")
first = images[0]
if isinstance(first, (Image.Image, torch.Tensor)):
image_batches = [list(images)] # type: ignore[arg-type]
else:
image_batches = [list(sample) for sample in images] # type: ignore[arg-type]
batch_size = len(image_batches)
if prompt is None:
prompts = [""] * batch_size
elif isinstance(prompt, str):
prompts = [prompt] * batch_size
else:
prompts = [str(p) for p in prompt]
if len(prompts) != batch_size:
raise ValueError(
f"Prompt batch size {len(prompts)} does not match image batch size {batch_size}"
)
if image_mask.dim() == 1:
image_mask = image_mask.unsqueeze(0)
if image_mask.shape[0] != batch_size:
raise ValueError(
f"image_mask batch size {image_mask.shape[0]} does not match image batch size {batch_size}"
)
return image_batches, prompts, image_mask
def get_vl_embeddings(
self,
images: list[Image.Image | torch.Tensor] | list[list[Image.Image | torch.Tensor]],
image_mask: torch.Tensor,
prompt: str | list[str] | None = None,
return_cls_only: bool | None = None,
) -> torch.Tensor:
if return_cls_only is None:
return_cls_only = self.return_cls_only
image_batches, prompts, image_mask = self._normalize_image_batches(images, prompt, image_mask)
return self.embedder.get_fused_image_text_embedding_from_tensor_images(
image_tensors_batch=image_batches,
image_masks=image_mask,
text_prompts=prompts,
return_cls_only=return_cls_only,
)
def prepare_state(self, state_input: list | torch.Tensor) -> torch.Tensor:
if isinstance(state_input, list):
state_tensor = torch.tensor(state_input)
elif isinstance(state_input, torch.Tensor):
state_tensor = state_input
else:
raise TypeError(f"Unsupported state input type: {type(state_input)}")
if state_tensor.ndim == 1:
state_tensor = state_tensor.unsqueeze(0)
return state_tensor.to(self._device)
def predict_action(
self,
fused_tokens: torch.Tensor,
state: torch.Tensor,
actions_gt: torch.Tensor | None = None,
action_mask: torch.Tensor | None = None,
embodiment_ids: torch.Tensor | None = None,
):
if actions_gt is None:
return self.action_head.get_action(
fused_tokens,
state=state,
action_mask=action_mask,
embodiment_id=embodiment_ids,
)
return self.action_head(
fused_tokens,
state=state,
actions_gt=actions_gt,
action_mask=action_mask,
embodiment_id=embodiment_ids,
)
@torch.no_grad()
def run_inference(
self,
images: list[Image.Image | torch.Tensor],
image_mask: torch.Tensor,
prompt: str,
state_input: list | torch.Tensor,
return_cls_only: bool | None = None,
action_mask: torch.Tensor | None = None,
embodiment_ids: torch.Tensor | None = None,
) -> torch.Tensor:
if image_mask.dim() == 1:
image_mask = image_mask.unsqueeze(0)
fused_tokens = self.get_vl_embeddings(
images=[images],
image_mask=image_mask,
prompt=[prompt],
return_cls_only=return_cls_only,
)
state_tensor = self.prepare_state(state_input)
action = self.predict_action(
fused_tokens,
state_tensor,
action_mask=action_mask,
embodiment_ids=embodiment_ids,
)
if isinstance(action, torch.Tensor) and action.dtype == torch.bfloat16:
action = action.to(torch.float32)
return action
def forward(
self,
fused_tokens: torch.Tensor,
state: torch.Tensor | None = None,
actions_gt: torch.Tensor | None = None,
action_mask: torch.Tensor | None = None,
embodiment_ids: torch.Tensor | None = None,
):
return self.predict_action(fused_tokens, state, actions_gt, action_mask, embodiment_ids)
def _set_module_trainable(self, module: nn.Module, trainable: bool):
for param in module.parameters():
param.requires_grad = trainable
def set_finetune_flags(self):
finetune_vlm = _cfgget(self.config, "finetune_vlm", False)
finetune_language_model = _cfgget(self.config, "finetune_language_model", False)
finetune_vision_model = _cfgget(self.config, "finetune_vision_model", False)
has_explicit_branch_flags = any(
flag is not None for flag in (finetune_language_model, finetune_vision_model)
)
finetune_language_model = bool(finetune_language_model)
finetune_vision_model = bool(finetune_vision_model)
finetune_vlm = bool(finetune_vlm)
if has_explicit_branch_flags:
self._set_module_trainable(self.embedder, False)
if hasattr(self.embedder.model, "language_model"):
self._set_module_trainable(self.embedder.model.language_model, finetune_language_model)
if hasattr(self.embedder.model, "vision_model"):
self._set_module_trainable(self.embedder.model.vision_model, finetune_vision_model)
if hasattr(self.embedder.model, "mlp1"):
self._set_module_trainable(self.embedder.model.mlp1, finetune_vision_model)
elif not finetune_vlm:
self._set_module_trainable(self.embedder, False)
if not _cfgget(self.config, "finetune_action_head", False):
self._set_module_trainable(self.action_head, False)
@@ -0,0 +1,456 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import math
from types import SimpleNamespace
import torch
import torch.nn as nn
logger = logging.getLogger(__name__)
def _cfgget(config, key: str, default=None):
if isinstance(config, dict):
return config.get(key, default)
return getattr(config, key, default)
class SinusoidalPositionalEncoding(nn.Module):
def __init__(self, dim: int, max_len: int = 1000):
super().__init__()
pe = torch.zeros(max_len, dim)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, dim, 2) * -(math.log(10000.0) / dim))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer("pe", pe)
def forward(self, seq_len: int):
if seq_len > self.pe.size(1):
self._extend_pe(seq_len)
return self.pe[:, :seq_len, :]
def _extend_pe(self, new_max_len):
old_max_len, dim = self.pe.size(1), self.pe.size(2)
if new_max_len <= old_max_len:
return
extra_positions = torch.arange(old_max_len, new_max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim))
extra_pe = torch.zeros(new_max_len - old_max_len, dim)
extra_pe[:, 0::2] = torch.sin(extra_positions * div_term)
extra_pe[:, 1::2] = torch.cos(extra_positions * div_term)
extra_pe = extra_pe.unsqueeze(0)
new_pe = torch.cat([self.pe, extra_pe.to(self.pe.device)], dim=1)
self.pe = new_pe
class CategorySpecificLinear(nn.Module):
def __init__(self, in_dim: int, out_dim: int, num_categories: int = 1):
super().__init__()
self.num_categories = num_categories
if num_categories <= 1:
self.linear = nn.Linear(in_dim, out_dim)
else:
self.weight = nn.Parameter(torch.empty(num_categories, in_dim, out_dim))
self.bias = nn.Parameter(torch.zeros(num_categories, out_dim))
nn.init.xavier_uniform_(self.weight)
def forward(self, x: torch.Tensor, category_id: torch.LongTensor):
if self.num_categories <= 1:
if x.dtype != self.linear.weight.dtype:
x = x.to(dtype=self.linear.weight.dtype)
return self.linear(x)
if x.dtype != self.weight.dtype:
x = x.to(dtype=self.weight.dtype)
orig_shape = x.shape
x_flat = x.reshape(-1, orig_shape[-1])
if category_id.dim() == 0:
cid = category_id.item()
out = x_flat @ self.weight[cid] + self.bias[cid]
else:
category_id = category_id.reshape(-1)
if category_id.numel() != x_flat.size(0):
raise ValueError(
f"category_id length {category_id.numel()} does not match flattened batch {x_flat.size(0)}"
)
weight_selected = self.weight[category_id]
bias_selected = self.bias[category_id]
out = torch.bmm(x_flat.unsqueeze(1), weight_selected).squeeze(1) + bias_selected
out_shape = orig_shape[:-1] + (out.shape[-1],)
return out.view(out_shape)
class CategorySpecificMLP(nn.Module):
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_categories: int = 1):
super().__init__()
self.fc1 = CategorySpecificLinear(input_dim, hidden_dim, num_categories)
self.fc2 = CategorySpecificLinear(hidden_dim, output_dim, num_categories)
self.activation = nn.ReLU(inplace=True)
def forward(self, x: torch.Tensor, category_id: torch.LongTensor):
out = self.activation(self.fc1(x, category_id))
out = self.fc2(out, category_id)
return out
class MultiEmbodimentActionEncoder(nn.Module):
def __init__(
self, action_dim: int, embed_dim: int, hidden_dim: int, horizon: int, num_categories: int = 1
):
super().__init__()
self.horizon = horizon
self.embed_dim = embed_dim
self.num_categories = num_categories
self.W1 = CategorySpecificLinear(action_dim, hidden_dim, num_categories)
self.W2 = CategorySpecificLinear(hidden_dim, hidden_dim, num_categories)
self.W3 = CategorySpecificLinear(hidden_dim, embed_dim, num_categories)
self.pos_encoding = SinusoidalPositionalEncoding(hidden_dim, max_len=horizon)
self.activation = nn.ReLU(inplace=True)
def forward(self, action_seq: torch.Tensor, category_id: torch.LongTensor):
batch_size, horizon, action_dim = action_seq.shape
assert self.horizon == horizon, "Action sequence length must match horizon"
x = action_seq.reshape(batch_size * horizon, action_dim)
if category_id.dim() == 0:
cat_ids = category_id.expand(horizon * batch_size)
else:
cat_ids = category_id.unsqueeze(1).expand(batch_size, horizon).reshape(batch_size * horizon)
out = self.activation(self.W1(x, cat_ids))
pos_enc = self.pos_encoding(horizon).to(device=out.device, dtype=out.dtype)
out = out.view(batch_size, horizon, -1) + pos_enc
out = out.view(batch_size * horizon, -1)
out = self.activation(self.W2(out, cat_ids))
out = self.W3(out, cat_ids)
return out.view(batch_size, horizon, self.embed_dim)
class BasicTransformerBlock(nn.Module):
def __init__(self, embed_dim: int, num_heads: int, hidden_dim: int, dropout: float = 0.0):
super().__init__()
self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.ff = nn.Sequential(nn.Linear(embed_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, embed_dim))
def forward(self, action_tokens: torch.Tensor, context_tokens: torch.Tensor, time_emb: torch.Tensor):
x = self.norm1(action_tokens)
attn_out, _ = self.attn(x, context_tokens, context_tokens)
x = action_tokens + attn_out
x2 = self.norm2(x)
if time_emb is not None:
x2 = x2 + time_emb.unsqueeze(1)
ff_out = self.ff(x2)
return x + ff_out
class FlowmatchingActionHead(nn.Module):
def __init__(
self,
config=None,
embed_dim: int = 896,
hidden_dim: int = 1024,
action_dim: int = 16 * 7,
horizon: int = 16,
per_action_dim: int = 7,
num_heads: int = 8,
num_layers: int = 8,
dropout: float = 0.0,
num_inference_timesteps: int = 20,
num_categories: int = 1,
):
super().__init__()
if config is not None:
embed_dim = _cfgget(config, "embed_dim", embed_dim)
hidden_dim = _cfgget(config, "hidden_dim", hidden_dim)
action_dim = _cfgget(config, "action_dim", action_dim)
horizon = _cfgget(config, "horizon", horizon)
per_action_dim = _cfgget(config, "per_action_dim", per_action_dim)
num_heads = _cfgget(config, "num_heads", num_heads)
num_layers = _cfgget(config, "num_layers", num_layers)
dropout = _cfgget(config, "dropout", dropout)
num_inference_timesteps = _cfgget(config, "num_inference_timesteps", num_inference_timesteps)
num_categories = _cfgget(config, "num_categories", num_categories)
self.config = config
else:
self.config = SimpleNamespace(
embed_dim=embed_dim,
hidden_dim=hidden_dim,
action_dim=action_dim,
horizon=horizon,
per_action_dim=per_action_dim,
num_heads=num_heads,
num_layers=num_layers,
dropout=dropout,
num_inference_timesteps=num_inference_timesteps,
num_categories=num_categories,
)
logger.info("FlowmatchingActionHead num_inference_timesteps=%s", num_inference_timesteps)
self.embed_dim = embed_dim
self.horizon = horizon
self.per_action_dim = _cfgget(self.config, "per_action_dim", per_action_dim)
self.action_dim = _cfgget(self.config, "action_dim", action_dim)
self.time_pos_enc = SinusoidalPositionalEncoding(embed_dim, max_len=1000)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
embed_dim=embed_dim,
num_heads=num_heads,
hidden_dim=embed_dim * 4,
dropout=dropout,
)
for _ in range(num_layers)
]
)
self.norm_out = nn.LayerNorm(embed_dim)
self.seq_pool_proj = nn.Linear(self.horizon * self.embed_dim, self.embed_dim)
self.mlp_head = CategorySpecificMLP(
input_dim=embed_dim,
hidden_dim=hidden_dim,
output_dim=action_dim,
num_categories=num_categories,
)
self.state_encoder = None
state_dim = _cfgget(self.config, "state_dim")
if state_dim is not None:
state_hidden = _cfgget(self.config, "state_hidden_dim", embed_dim)
self.state_encoder = CategorySpecificMLP(
input_dim=state_dim,
hidden_dim=state_hidden,
output_dim=embed_dim,
num_categories=num_categories,
)
if horizon > 1:
self.action_encoder = MultiEmbodimentActionEncoder(
action_dim=self.per_action_dim,
embed_dim=embed_dim,
hidden_dim=embed_dim,
horizon=horizon,
num_categories=num_categories,
)
self.single_action_proj = None
else:
self.action_encoder = None
self.single_action_proj = nn.Linear(self.per_action_dim, self.embed_dim)
def _project_actions(self, action_seq: torch.Tensor, embodiment_id: torch.LongTensor) -> torch.Tensor:
if self.horizon > 1 and self.action_encoder is not None:
return self.action_encoder(action_seq, embodiment_id)
if self.single_action_proj is None:
raise RuntimeError("single_action_proj is not initialized for horizon <= 1.")
return self.single_action_proj(action_seq)
def _expand_action_mask(
self,
action_mask: torch.Tensor,
batch_size: int,
per_action_dim: int,
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
if action_mask is None:
raise ValueError("action_mask must be provided for flow matching inference.")
if action_mask.dim() == 2:
expected_last_dim = self.horizon * per_action_dim
if action_mask.shape == (batch_size, expected_last_dim):
expanded_mask = action_mask.reshape(batch_size, self.horizon, per_action_dim)
elif action_mask.shape == (batch_size, per_action_dim):
expanded_mask = action_mask.unsqueeze(1).expand(batch_size, self.horizon, per_action_dim)
else:
raise ValueError(
f"Expected action_mask shape {(batch_size, expected_last_dim)} or "
f"{(batch_size, per_action_dim)}, got {tuple(action_mask.shape)}"
)
elif action_mask.dim() == 3:
expected_shape = (batch_size, self.horizon, per_action_dim)
if tuple(action_mask.shape) != expected_shape:
raise ValueError(
f"Expected action_mask shape {expected_shape}, got {tuple(action_mask.shape)}"
)
expanded_mask = action_mask
else:
raise ValueError(f"Unsupported action_mask rank: {action_mask.dim()}")
return expanded_mask.to(device=device, dtype=dtype)
def forward(
self,
fused_tokens: torch.Tensor,
state: torch.Tensor = None,
actions_gt: torch.Tensor = None,
embodiment_id: torch.LongTensor = None,
state_mask: torch.Tensor = None,
action_mask: torch.Tensor = None,
):
if actions_gt is None:
return self.get_action(
fused_tokens, state=state, embodiment_id=embodiment_id, action_mask=action_mask
)
batch_size = fused_tokens.size(0)
device = fused_tokens.device
if embodiment_id is None:
embodiment_id = torch.zeros(batch_size, dtype=torch.long, device=device)
context_tokens = fused_tokens
if state is not None and self.state_encoder is not None:
state_emb = self.state_encoder(state, embodiment_id).unsqueeze(1)
context_tokens = torch.cat([context_tokens, state_emb], dim=1)
t = (
torch.distributions.Beta(2, 2)
.sample((batch_size,))
.clamp(0.02, 0.98)
.to(device)
.to(dtype=self.dtype)
)
time_index = (t * 999).long().clamp_(0, 999)
time_emb = self.time_pos_enc(1000)[:, time_index, :].squeeze(0).to(dtype=context_tokens.dtype)
actions_gt_seq = actions_gt
noise = torch.rand_like(actions_gt) * 2 - 1
if action_mask is not None:
action_mask = action_mask.to(dtype=noise.dtype, device=noise.device)
if action_mask.shape != noise.shape:
raise ValueError(f"action_mask shape {action_mask.shape} != noise shape {noise.shape}")
actions_gt_seq = actions_gt_seq * action_mask
noise = noise * action_mask
if self.horizon > 1:
noise_seq = noise.view(batch_size, self.horizon, self.per_action_dim)
else:
noise_seq = noise if noise.dim() == 3 else noise.unsqueeze(1)
t_broadcast = t.view(batch_size, 1, 1)
action_intermediate_seq = (1 - t_broadcast) * noise_seq + t_broadcast * actions_gt_seq
action_tokens = self._project_actions(action_intermediate_seq, embodiment_id)
target_dtype = self.dtype
action_tokens = action_tokens.to(dtype=target_dtype)
context_tokens = context_tokens.to(dtype=target_dtype)
time_emb = time_emb.to(dtype=target_dtype)
x = action_tokens
for block in self.transformer_blocks:
x = block(x, context_tokens, time_emb)
x = self.norm_out(x)
if self.horizon > 1:
x_flat = x.reshape(batch_size, -1)
x_pooled = self.seq_pool_proj(x_flat)
else:
x_pooled = x.squeeze(1)
pred_velocity = self.mlp_head(x_pooled, embodiment_id)
return pred_velocity, noise
def get_action(
self,
fused_tokens: torch.Tensor,
state: torch.Tensor = None,
embodiment_id: torch.LongTensor = None,
action_mask: torch.Tensor = None,
):
batch_size = fused_tokens.size(0)
device = fused_tokens.device
if embodiment_id is None:
embodiment_id = torch.zeros(batch_size, dtype=torch.long, device=device)
context_tokens = fused_tokens
if state is not None and self.state_encoder is not None:
state_emb = self.state_encoder(state, embodiment_id).unsqueeze(1)
context_tokens = torch.cat([context_tokens, state_emb], dim=1)
action_dim_total = _cfgget(self.config, "action_dim", self.action_dim)
per_action_dim = _cfgget(self.config, "per_action_dim", action_dim_total // max(self.horizon, 1))
action = torch.rand(batch_size, action_dim_total, device=device, dtype=context_tokens.dtype) * 2 - 1
action_seq = (
action.view(batch_size, self.horizon, per_action_dim)
if self.horizon > 1
else action.view(batch_size, 1, per_action_dim)
)
action_mask = self._expand_action_mask(
action_mask,
batch_size=batch_size,
per_action_dim=per_action_dim,
device=action_seq.device,
dtype=action_seq.dtype,
)
action_seq = action_seq * action_mask
target_dtype = self.dtype
context_tokens = context_tokens.to(dtype=target_dtype)
num_steps = int(_cfgget(self.config, "num_inference_timesteps", 32))
if num_steps <= 0:
raise ValueError(f"num_inference_timesteps must be positive, got {num_steps}")
dt = 1.0 / num_steps
for i in range(num_steps):
t = i / num_steps
time_index = min(int(t * 999), 999)
time_emb = (
self.time_pos_enc(1000)[:, time_index, :].to(device).squeeze(0).to(dtype=context_tokens.dtype)
)
time_emb = time_emb.unsqueeze(0).repeat(batch_size, 1)
action_seq = action_seq * action_mask
action_tokens = self._project_actions(action_seq, embodiment_id).to(dtype=target_dtype)
time_emb = time_emb.to(dtype=target_dtype)
x = action_tokens
for block in self.transformer_blocks:
x = block(x, context_tokens, time_emb)
x = self.norm_out(x)
if self.horizon > 1:
x_flat = x.reshape(batch_size, -1)
x_pooled = self.seq_pool_proj(x_flat)
else:
x_pooled = x.squeeze(1)
pred = self.mlp_head(x_pooled, embodiment_id)
action = action + dt * pred
action_seq = (
action.view(batch_size, self.horizon, per_action_dim)
if self.horizon > 1
else action.view(batch_size, 1, per_action_dim)
)
action_seq = action_seq * action_mask
return action_seq.reshape(batch_size, -1)
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
@@ -0,0 +1,366 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import functools
import logging
from collections.abc import Sequence
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
from PIL import Image
from torchvision.transforms.functional import to_pil_image
from lerobot.utils.import_utils import require_package
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>" # nosec B105
IMG_START_TOKEN = "<img>" # nosec B105
IMG_END_TOKEN = "</img>" # nosec B105
logger = logging.getLogger(__name__)
def flash_attn_is_available() -> bool:
try:
import flash_attn # noqa: F401
except ModuleNotFoundError:
return False
return True
@functools.lru_cache(maxsize=10000)
def get_target_aspect_ratio(orig_width: int, orig_height: int, image_size: int, min_num: int, max_num: int):
aspect_ratio = orig_width / orig_height
target_ratios = {
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
}
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
best_ratio_diff = float("inf")
best_ratio = (1, 1)
area = orig_width * orig_height
for ratio in target_ratios:
target_ar = ratio[0] / ratio[1]
diff = abs(aspect_ratio - target_ar)
if diff < best_ratio_diff:
best_ratio_diff = diff
best_ratio = ratio
elif diff == best_ratio_diff and area > 0.5 * image_size**2 * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
def dynamic_preprocess(image, min_num=1, max_num=1, image_size=448, use_thumbnail=False):
orig_width, orig_height = image.size
ratio_w, ratio_h = get_target_aspect_ratio(orig_width, orig_height, image_size, min_num, max_num)
target_width = image_size * ratio_w
target_height = image_size * ratio_h
blocks = ratio_w * ratio_h
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size,
)
processed_images.append(resized_img.crop(box))
if use_thumbnail and len(processed_images) != 1:
processed_images.append(image.resize((image_size, image_size)))
return processed_images
class InternVL3Embedder(nn.Module):
def __init__(
self,
model_name="OpenGVLab/InternVL3-1B",
image_size=448,
device="cuda",
num_language_layers: int | None = 14,
model_dtype: str | torch.dtype = "bfloat16",
use_flash_attn: bool = True,
enable_gradient_checkpointing: bool = True,
gradient_checkpointing_use_reentrant: bool = False,
):
super().__init__()
self._requested_device = device
self.image_size = image_size
self.num_language_layers = num_language_layers
self.max_text_length = 1024
self.enable_gradient_checkpointing = bool(enable_gradient_checkpointing)
self.gradient_checkpointing_use_reentrant = bool(gradient_checkpointing_use_reentrant)
require_package("transformers", extra="evo1")
from transformers import AutoModel, AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
if isinstance(model_dtype, str):
try:
model_dtype = getattr(torch, model_dtype)
except AttributeError as exc:
raise ValueError(f"Unsupported EVO1 vlm_dtype '{model_dtype}'") from exc
resolved_use_flash_attn = bool(use_flash_attn and flash_attn_is_available())
if use_flash_attn and not resolved_use_flash_attn:
logger.warning("flash_attn is not installed. Falling back to standard attention.")
self.model = AutoModel.from_pretrained(
model_name,
torch_dtype=model_dtype,
trust_remote_code=True,
use_flash_attn=resolved_use_flash_attn,
low_cpu_mem_usage=True,
_fast_init=False,
).to(self._requested_device)
if hasattr(self.model.language_model, "model"):
layers = self.model.language_model.model.layers
else:
layers = self.model.language_model.layers
if self.num_language_layers is not None:
layers = layers[: self.num_language_layers]
if hasattr(self.model.language_model, "model"):
self.model.language_model.model.layers = torch.nn.ModuleList(layers)
else:
self.model.language_model.layers = torch.nn.ModuleList(layers)
self.model.language_model.lm_head = torch.nn.Identity()
self._configure_memory_features()
self.img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
def _configure_memory_features(self) -> None:
checkpoint_kwargs = {"use_reentrant": self.gradient_checkpointing_use_reentrant}
if not self.enable_gradient_checkpointing:
if hasattr(self.model, "vision_model") and hasattr(self.model.vision_model, "encoder"):
self.model.vision_model.encoder.gradient_checkpointing = False
language_model = getattr(self.model, "language_model", None)
if language_model is not None:
if hasattr(language_model, "gradient_checkpointing_disable"):
language_model.gradient_checkpointing_disable()
elif hasattr(language_model, "gradient_checkpointing"):
language_model.gradient_checkpointing = False
if hasattr(language_model, "model"):
inner = language_model.model
if hasattr(inner, "gradient_checkpointing_disable"):
inner.gradient_checkpointing_disable()
elif hasattr(inner, "gradient_checkpointing"):
inner.gradient_checkpointing = False
return
def _enable_ckpt(module: nn.Module | None) -> bool:
if module is None:
return False
if hasattr(module, "gradient_checkpointing_enable"):
try:
module.gradient_checkpointing_enable(gradient_checkpointing_kwargs=checkpoint_kwargs)
except TypeError:
module.gradient_checkpointing_enable()
return True
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = True
return True
return False
enabled_any = _enable_ckpt(self.model)
if hasattr(self.model, "vision_model") and hasattr(self.model.vision_model, "encoder"):
self.model.vision_model.encoder.gradient_checkpointing = True
enabled_any = True
language_model = getattr(self.model, "language_model", None)
if language_model is not None:
enabled_any = _enable_ckpt(language_model) or enabled_any
if hasattr(language_model, "model"):
enabled_any = _enable_ckpt(language_model.model) or enabled_any
if hasattr(language_model, "config"):
language_model.config.use_cache = False
if hasattr(self.model, "config"):
self.model.config.use_cache = False
if hasattr(self.model, "enable_input_require_grads"):
self.model.enable_input_require_grads()
if enabled_any:
logger.info("Gradient checkpointing enabled for InternVL3 embedder.")
else:
logger.warning(
"Requested gradient checkpointing, but model does not expose checkpointing controls."
)
def _preprocess_single_image(self, image: Image.Image | torch.Tensor) -> torch.Tensor:
if isinstance(image, torch.Tensor):
pil_image = to_pil_image(image.detach().cpu())
else:
pil_image = image.convert("RGB")
tiles = dynamic_preprocess(pil_image, image_size=self.image_size)
tile_tensors = torch.stack([TF.to_tensor(tile) for tile in tiles]).to(
device=self.device, dtype=torch.bfloat16
)
mean = torch.tensor(IMAGENET_MEAN, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1)
std = torch.tensor(IMAGENET_STD, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1)
return (tile_tensors - mean) / std
def _preprocess_images(
self,
image_tensors_batch: Sequence[Sequence[Image.Image | torch.Tensor]],
) -> tuple[torch.Tensor, list[list[int]]]:
pixel_values_list = []
batch_num_tiles_list: list[list[int]] = []
for image_tensors in image_tensors_batch:
num_tiles_list: list[int] = []
for image in image_tensors:
tiles = self._preprocess_single_image(image)
pixel_values_list.append(tiles)
num_tiles_list.append(int(tiles.shape[0]))
batch_num_tiles_list.append(num_tiles_list)
if pixel_values_list:
pixel_values = torch.cat(pixel_values_list, dim=0)
else:
pixel_values = torch.empty(
0, 3, self.image_size, self.image_size, dtype=torch.bfloat16, device=self.device
)
return pixel_values, batch_num_tiles_list
def _build_multimodal_prompts(
self,
batch_num_tiles_list: list[list[int]],
text_prompts: Sequence[str],
) -> list[str]:
prompts = []
for num_tiles_list, text_prompt in zip(batch_num_tiles_list, text_prompts, strict=True):
prompt_segments = []
for i, tile_count in enumerate(num_tiles_list):
token_count = self.model.num_image_token * tile_count
image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * token_count + IMG_END_TOKEN
prompt_segments.append(f"Image-{i + 1}: {image_tokens}\n")
prompts.append("".join(prompt_segments) + text_prompt.strip())
return prompts
def _prepare_and_fuse_embeddings(
self,
prompts: Sequence[str],
vit_embeds: torch.Tensor,
image_masks: torch.Tensor,
batch_num_tiles_list: list[list[int]],
) -> tuple[torch.Tensor, torch.Tensor]:
untruncated_ids = self.tokenizer(list(prompts), padding=False, truncation=False)["input_ids"]
true_sequence_length = max((len(ids) for ids in untruncated_ids), default=0)
if true_sequence_length > self.max_text_length:
logger.warning(
"InternVL3 prompt truncated in batch: max_length=%s actual_max_length=%s",
self.max_text_length,
true_sequence_length,
)
model_inputs = self.tokenizer(
list(prompts),
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.max_text_length,
).to(self.device)
input_ids = model_inputs["input_ids"]
attention_mask = model_inputs["attention_mask"]
img_token_mask = input_ids == self.img_context_token_id
input_embeds = self.model.language_model.get_input_embeddings()(input_ids).clone()
batch_size, _, channels = input_embeds.shape
vit_embeds = vit_embeds.reshape(-1, channels).to(dtype=input_embeds.dtype, device=input_embeds.device)
tokens_per_tile = self.model.num_image_token
actual_vis_tokens_list = img_token_mask.sum(dim=1).tolist()
vit_idx = 0
for batch_index in range(batch_size):
expected_vis_tokens = sum(batch_num_tiles_list[batch_index]) * tokens_per_tile
mask_b = img_token_mask[batch_index]
actual_vis_tokens = actual_vis_tokens_list[batch_index]
item_vit_embeds = vit_embeds[vit_idx : vit_idx + expected_vis_tokens]
vit_idx += expected_vis_tokens
if actual_vis_tokens > 0:
if item_vit_embeds.shape[0] < actual_vis_tokens:
raise ValueError(
f"InternVL3 produced fewer image tokens than expected for sample {batch_index}: "
f"got {item_vit_embeds.shape[0]}, need {actual_vis_tokens}"
)
input_embeds[batch_index, mask_b] = item_vit_embeds[:actual_vis_tokens]
current_token_idx = 0
img_token_locations = torch.where(mask_b)[0]
for image_index, num_tiles in enumerate(batch_num_tiles_list[batch_index]):
num_tokens_for_image = num_tiles * tokens_per_tile
if not bool(image_masks[batch_index, image_index].item()):
start_offset = current_token_idx
end_offset = min(current_token_idx + num_tokens_for_image, len(img_token_locations))
if start_offset < end_offset:
idxs = img_token_locations[start_offset:end_offset]
attention_mask[batch_index, idxs] = 0
current_token_idx += num_tokens_for_image
return input_embeds, attention_mask
def get_fused_image_text_embedding_from_tensor_images(
self,
image_tensors_batch: Sequence[Sequence[Image.Image | torch.Tensor]],
image_masks: torch.Tensor,
text_prompts: Sequence[str],
return_cls_only: bool = True,
):
pixel_values, batch_num_tiles_list = self._preprocess_images(image_tensors_batch)
if pixel_values.shape[0] == 0:
logger.warning("InternVL3 received an empty image batch after preprocessing.")
hidden_size = getattr(self.model.config, "hidden_size", None)
if hidden_size is None and hasattr(self.model.language_model, "config"):
hidden_size = getattr(self.model.language_model.config, "hidden_size", None)
if hidden_size is None:
raise RuntimeError("Unable to infer hidden size for empty InternVL3 batch.")
empty = torch.empty(0, hidden_size, device=self.device, dtype=torch.float32)
return empty
prompts = self._build_multimodal_prompts(batch_num_tiles_list, text_prompts)
vit_embeds = self.model.extract_feature(pixel_values)
inputs_embeds, attention_mask = self._prepare_and_fuse_embeddings(
prompts,
vit_embeds,
image_masks.to(device=self.device),
batch_num_tiles_list,
)
outputs = self.model.language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
return_dict=True,
)
fused_hidden = outputs.hidden_states[-1].to(torch.float32)
return fused_hidden[:, 0, :] if return_cls_only else fused_hidden
@property
def device(self) -> torch.device:
return next(self.model.parameters()).device
@@ -0,0 +1,419 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import builtins
from collections import deque
from contextlib import nullcontext
from pathlib import Path
import torch
from torch import Tensor
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.evo1.configuration_evo1 import Evo1Config
from lerobot.policies.evo1.evo1_model import EVO1
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
class EVO1Policy(PreTrainedPolicy):
config_class = Evo1Config
name = "evo1"
def __init__(self, config: Evo1Config, **kwargs):
super().__init__(config)
config.validate_features()
if len(config.image_features) > config.max_views:
raise ValueError(
f"EVO1 supports at most {config.max_views} camera streams, got {len(config.image_features)}"
)
self.config = config
self.model = EVO1(self._build_model_config(config))
self.model.set_finetune_flags()
self.reset()
@classmethod
def from_pretrained(
cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
config: PreTrainedConfig | None = None,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
strict: bool | None = None,
**kwargs,
) -> T:
if strict is None:
strict = not (config is not None and getattr(config, "training_stage", None) == "stage2")
return super().from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
config=config,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
strict=strict,
**kwargs,
)
@staticmethod
def _build_model_config(config: Evo1Config) -> dict:
return {
"device": config.device,
"return_cls_only": config.return_cls_only,
"vlm_name": config.vlm_model_name,
"vlm_num_layers": config.vlm_num_layers,
"vlm_dtype": config.vlm_dtype,
"use_flash_attn": config.use_flash_attn,
"action_head": config.action_head,
"action_horizon": config.chunk_size,
"per_action_dim": config.max_action_dim,
"state_dim": config.max_state_dim,
"embed_dim": config.embed_dim,
"hidden_dim": config.hidden_dim,
"state_hidden_dim": config.state_hidden_dim,
"num_heads": config.num_heads,
"num_layers": config.num_layers,
"dropout": config.dropout,
"num_inference_timesteps": config.num_inference_timesteps,
"num_categories": config.num_categories,
"enable_gradient_checkpointing": config.enable_gradient_checkpointing,
"gradient_checkpointing_use_reentrant": config.gradient_checkpointing_use_reentrant,
"finetune_vlm": config.finetune_vlm,
"finetune_language_model": config.finetune_language_model,
"finetune_vision_model": config.finetune_vision_model,
"finetune_action_head": config.finetune_action_head,
}
@property
def _camera_keys(self) -> list[str]:
return list(self.config.image_features)
@property
def _env_action_dim(self) -> int:
action_feature = self.config.action_feature
if action_feature is None:
return self.config.max_action_dim
return int(action_feature.shape[0])
@property
def _compute_dtype(self) -> torch.dtype:
return next(self.model.action_head.parameters()).dtype
@property
def _training_compute_dtype(self) -> torch.dtype:
if str(self.config.device).startswith("cuda"):
return torch.bfloat16
return self._compute_dtype
@property
def _inference_compute_dtype(self) -> torch.dtype:
if str(self.config.device).startswith("cuda") and self.config.use_amp:
return torch.bfloat16
return self._compute_dtype
def get_optim_params(self) -> list[dict]:
decay, no_decay = [], []
for name, param in self.named_parameters():
if not param.requires_grad:
continue
is_bias = name.endswith("bias") or ".bias" in name
is_norm = param.dim() == 1 or "norm" in name.lower()
if is_bias or is_norm:
no_decay.append(param)
else:
decay.append(param)
return [
{"params": decay, "weight_decay": self.config.optimizer_weight_decay},
{"params": no_decay, "weight_decay": 0.0},
]
def reset(self):
self._action_queue = deque([], maxlen=self.config.n_action_steps)
def _normalize_task_batch(self, batch: dict[str, Tensor | list[str] | str]) -> list[str]:
prompts = batch.get(self.config.task_field)
if prompts is None and self.config.task_field != "task":
prompts = batch.get("task")
if prompts is None:
raise ValueError(f"EVO1 expects a '{self.config.task_field}' text field in the batch.")
if isinstance(prompts, str):
return [prompts]
if isinstance(prompts, (list, tuple)):
return [str(prompt) for prompt in prompts]
raise TypeError(f"Unsupported prompt batch type: {type(prompts)}")
def _prepare_state(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
if OBS_STATE not in batch:
raise ValueError(f"EVO1 requires '{OBS_STATE}' in the batch.")
state = batch[OBS_STATE]
if state.dim() == 1:
state = state.unsqueeze(0)
elif state.dim() == 3:
state = state[:, -1]
elif state.dim() != 2:
raise ValueError(f"Unsupported state tensor shape for EVO1: {tuple(state.shape)}")
batch_size, state_dim = state.shape
if state_dim > self.config.max_state_dim:
raise ValueError(
f"State dim {state_dim} exceeds configured max_state_dim {self.config.max_state_dim}"
)
explicit_mask = batch.get("state_mask")
if explicit_mask is not None:
if explicit_mask.dim() == 1:
explicit_mask = explicit_mask.unsqueeze(0)
elif explicit_mask.dim() == 3:
explicit_mask = explicit_mask[:, -1]
elif explicit_mask.dim() != 2:
raise ValueError(
f"Unsupported state_mask tensor shape for EVO1: {tuple(explicit_mask.shape)}"
)
if explicit_mask.shape != (batch_size, state_dim):
raise ValueError(
f"state_mask shape {tuple(explicit_mask.shape)} does not match state shape {(batch_size, state_dim)}"
)
padded = torch.zeros(
batch_size,
self.config.max_state_dim,
dtype=state.dtype,
device=self.config.device,
)
padded[:, :state_dim] = state.to(device=self.config.device)
mask = torch.zeros(
batch_size,
self.config.max_state_dim,
dtype=torch.bool,
device=self.config.device,
)
if explicit_mask is None:
mask[:, :state_dim] = True
else:
mask[:, :state_dim] = explicit_mask.to(device=self.config.device, dtype=torch.bool)
return padded.to(dtype=self._compute_dtype), mask
def _prepare_actions(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
if ACTION not in batch:
raise ValueError(f"EVO1 requires '{ACTION}' in the batch for training.")
action = batch[ACTION]
if action.dim() == 2:
action = action.unsqueeze(1)
batch_size, horizon, action_dim = action.shape
if horizon != self.config.chunk_size:
raise ValueError(
f"EVO1 expects chunk_size={self.config.chunk_size}, got action horizon {horizon}"
)
if action_dim > self.config.max_action_dim:
raise ValueError(
f"Action dim {action_dim} exceeds configured max_action_dim {self.config.max_action_dim}"
)
explicit_mask = batch.get("action_mask")
if explicit_mask is not None:
if explicit_mask.dim() == 2:
if horizon == 1:
explicit_mask = explicit_mask.unsqueeze(1)
else:
raise ValueError(
f"2D action_mask is only supported when chunk_size=1, got action horizon {horizon}"
)
elif explicit_mask.dim() != 3:
raise ValueError(
f"Unsupported action_mask tensor shape for EVO1: {tuple(explicit_mask.shape)}"
)
if explicit_mask.shape != (batch_size, horizon, action_dim):
raise ValueError(
"action_mask shape "
f"{tuple(explicit_mask.shape)} does not match action shape {(batch_size, horizon, action_dim)}"
)
padded = torch.zeros(
batch_size,
horizon,
self.config.max_action_dim,
dtype=action.dtype,
device=self.config.device,
)
padded[:, :, :action_dim] = action.to(device=self.config.device)
mask = torch.zeros(
batch_size,
horizon,
self.config.max_action_dim,
dtype=torch.bool,
device=self.config.device,
)
if explicit_mask is None:
mask[:, :, :action_dim] = True
else:
mask[:, :, :action_dim] = explicit_mask.to(device=self.config.device, dtype=torch.bool)
return padded.to(dtype=self._compute_dtype), mask
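# At inference time only the environment's true action dims are marked active in the padded vector.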
def _prepare_inference_action_mask(self, batch_size: int) -> Tensor:
mask = torch.zeros(
batch_size,
self.config.max_action_dim,
dtype=torch.bool,
device=self.config.device,
)
mask[:, : self._env_action_dim] = True
return mask
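# Resolve per-sample embodiment ids from the batch, falling back to the configured default id
# when none are provided.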
def _get_embodiment_ids(self, batch: dict[str, Tensor], batch_size: int) -> Tensor:
embodiment_ids = batch.get("embodiment_id")
if embodiment_ids is None and self.config.embodiment_id_field:
embodiment_ids = batch.get(self.config.embodiment_id_field)
if embodiment_ids is None:
return torch.full(
(batch_size,),
self.config.default_embodiment_id,
dtype=torch.long,
device=self.config.device,
)
if embodiment_ids.dim() == 0:
embodiment_ids = embodiment_ids.unsqueeze(0)
elif embodiment_ids.dim() > 1:
embodiment_ids = embodiment_ids[:, -1]
return embodiment_ids.to(device=self.config.device, dtype=torch.long)
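# Gather up to max_views camera frames per sample (detached to CPU) and zero-pad missing views;
# the returned mask flags which view slots hold real frames.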
def _collect_image_batches(self, batch: dict[str, Tensor]) -> tuple[list[list[Tensor]], Tensor]:
camera_keys = self._camera_keys or sorted(key for key in batch if key.startswith(f"{OBS_IMAGES}."))
if not camera_keys:
raise ValueError("EVO1 requires at least one visual observation feature.")
batch_size = batch[camera_keys[0]].shape[0]
image_batches: list[list[Tensor]] = []
image_masks = torch.zeros(batch_size, self.config.max_views, dtype=torch.bool)
for batch_index in range(batch_size):
sample_images: list[Tensor] = []
for camera_key in camera_keys[: self.config.max_views]:
image = batch[camera_key]
if image.dim() == 3:
image = image.unsqueeze(0)
elif image.dim() == 5:
image = image[:, -1]
elif image.dim() != 4:
raise ValueError(
f"Unsupported image tensor shape for EVO1: key={camera_key} shape={tuple(image.shape)}"
)
sample_images.append(image[batch_index].detach().cpu())
if not sample_images:
raise ValueError("EVO1 received a batch without any image tensor.")
while len(sample_images) < self.config.max_views:
sample_images.append(torch.zeros_like(sample_images[0]))
image_batches.append(sample_images[: self.config.max_views])
image_masks[batch_index, : min(len(camera_keys), self.config.max_views)] = True
return image_batches, image_masks
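# Embed the images and language prompt with the VLM backbone and cast the fused tokens to the
# action head's compute dtype.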
def _compute_fused_tokens(
self,
prompts: list[str],
image_batches: list[list[Tensor]],
image_masks: Tensor,
) -> Tensor:
fused_tokens = self.model.get_vl_embeddings(
images=image_batches,
image_mask=image_masks,
prompt=prompts,
return_cls_only=self.config.return_cls_only,
)
return fused_tokens.to(device=self.config.device, dtype=self._compute_dtype)
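# Masked squared error on the flattened action chunk, normalized by the number of active
# (non-padded) action dims per sample.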
def _compute_masked_loss(
self,
pred_velocity: Tensor,
target_velocity: Tensor,
action_mask: Tensor,
reduction: str,
) -> Tensor:
flat_mask = action_mask.view(action_mask.shape[0], -1).to(dtype=pred_velocity.dtype)
sq_error = ((pred_velocity - target_velocity) * flat_mask).pow(2)
active = flat_mask.sum(dim=1).clamp_min(1.0)
per_sample_loss = sq_error.sum(dim=1) / active
if reduction == "none":
return per_sample_loss
if reduction != "mean":
raise ValueError(f"Unsupported reduction '{reduction}'")
return sq_error.sum() / active.sum()
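# Training step: embed vision/language, pad state and actions, then regress the predicted
# flow-matching velocity toward (actions_gt - noise) under the action mask.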
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
prompts = self._normalize_task_batch(batch)
image_batches, image_masks = self._collect_image_batches(batch)
states, _state_mask = self._prepare_state(batch)
actions_gt, action_mask = self._prepare_actions(batch)
fused_tokens = self._compute_fused_tokens(prompts, image_batches, image_masks)
states = states.to(dtype=self._training_compute_dtype)
actions_gt = actions_gt.to(dtype=self._training_compute_dtype)
fused_tokens = fused_tokens.to(dtype=self._training_compute_dtype)
embodiment_ids = self._get_embodiment_ids(batch, states.shape[0])
pred_velocity, noise = self.model(
fused_tokens,
state=states,
actions_gt=actions_gt,
action_mask=action_mask.to(device=self.config.device, dtype=self._compute_dtype),
embodiment_ids=embodiment_ids,
)
flat_action_mask = action_mask.view(action_mask.shape[0], -1).to(dtype=actions_gt.dtype)
target_velocity = (actions_gt - noise).view(actions_gt.shape[0], -1) * flat_action_mask
loss = self._compute_masked_loss(pred_velocity, target_velocity, action_mask, reduction)
loss_mean = loss.mean().item() if loss.ndim > 0 else loss.item()
return loss, {
"loss": loss_mean,
"active_action_dims": float(action_mask.sum(dim=(1, 2)).float().mean().item()),
}
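# Inference: sample one action chunk from the flow-matching head (optionally under bf16 autocast)
# and crop it back to the environment's action dimension.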
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
self.eval()
prompts = self._normalize_task_batch(batch)
image_batches, image_masks = self._collect_image_batches(batch)
states, _state_mask = self._prepare_state(batch)
fused_tokens = self._compute_fused_tokens(prompts, image_batches, image_masks)
states = states.to(dtype=self._inference_compute_dtype)
fused_tokens = fused_tokens.to(dtype=self._inference_compute_dtype)
embodiment_ids = self._get_embodiment_ids(batch, states.shape[0])
action_mask = self._prepare_inference_action_mask(states.shape[0])
with (
torch.autocast(device_type="cuda", dtype=torch.bfloat16)
if self.config.use_amp and str(self.config.device).startswith("cuda")
else nullcontext()
):
actions = self.model(
fused_tokens,
state=states,
action_mask=action_mask,
embodiment_ids=embodiment_ids,
)
actions = actions.view(states.shape[0], self.config.chunk_size, self.config.max_action_dim)
return actions[:, :, : self._env_action_dim]
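# Queue-based action selection: whenever the queue is empty, refill it with the first
# n_action_steps of a freshly predicted chunk, then pop one action per call.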
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
self.eval()
if len(self._action_queue) == 0:
action_chunk = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
self._action_queue.extend(action_chunk.transpose(0, 1))
return self._action_queue.popleft()
+106
View File
@@ -0,0 +1,106 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
import torch
from lerobot.policies.evo1.configuration_evo1 import Evo1Config
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import (
batch_to_transition,
create_transition,
policy_action_to_transition,
transition_to_policy_action,
)
from lerobot.utils.constants import (
ACTION,
DONE,
INFO,
OBS_PREFIX,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
REWARD,
TRUNCATED,
)
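# Copy non-observation extras (e.g. the task prompt or embodiment id) into complementary_data
# so they survive the processor pipeline and reach the policy.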
def evo1_batch_to_transition(batch: dict[str, Any]):
transition = batch_to_transition(batch)
complementary_data = dict(transition.get("complementary_data") or {})
reserved = {ACTION, REWARD, DONE, TRUNCATED, INFO}
for key, value in batch.items():
if key in reserved or key.startswith(OBS_PREFIX):
continue
complementary_data.setdefault(key, value)
return create_transition(
observation=transition.get("observation"),
action=transition.get("action"),
reward=transition.get("reward", 0.0),
done=transition.get("done", False),
truncated=transition.get("truncated", False),
info=transition.get("info", {}),
complementary_data=complementary_data,
)
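# Build the standard EVO1 pre/post processors: batching, normalization, and device placement on
# the way in; unnormalization and a move back to CPU on the way out.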
def make_evo1_pre_post_processors(
config: Evo1Config,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeviceProcessorStep(device=config.device),
]
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
to_transition=evo1_batch_to_transition,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
+16 -2
View File
@@ -47,6 +47,7 @@ from lerobot.utils.feature_utils import dataset_to_policy_features
from .act.configuration_act import ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig
from .eo1.configuration_eo1 import EO1Config
from .evo1.configuration_evo1 import Evo1Config
from .groot.configuration_groot import GrootConfig
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config
@@ -88,7 +89,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "smolvla", "wall_x".
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "smolvla", "wall_x", "eo1", "evo1".
Returns:
The policy class corresponding to the given name.
@@ -151,6 +152,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .eo1.modeling_eo1 import EO1Policy
return EO1Policy
elif name == "evo1":
from .evo1.modeling_evo1 import EVO1Policy
return EVO1Policy
else:
try:
return _get_policy_cls_from_policy_name(name=name)
@@ -168,7 +173,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
"smolvla", "wall_x".
"smolvla", "wall_x", "eo1", "evo1".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -203,6 +208,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return WallXConfig(**kwargs)
elif policy_type == "eo1":
return EO1Config(**kwargs)
elif policy_type == "evo1":
return Evo1Config(**kwargs)
else:
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
@@ -413,6 +420,13 @@ def make_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, Evo1Config):
from .evo1.processor_evo1 import make_evo1_pre_post_processors
processors = make_evo1_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
try:
+225
View File
@@ -0,0 +1,225 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
from torch import nn
import lerobot.policies.evo1.modeling_evo1 as modeling_evo1
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.evo1.configuration_evo1 import Evo1Config
from lerobot.policies.evo1.flow_matching import FlowmatchingActionHead
from lerobot.policies.factory import get_policy_class, make_policy_config
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
STATE_DIM = 4
ACTION_DIM = 3
MAX_STATE_DIM = 6
MAX_ACTION_DIM = 5
CHUNK_SIZE = 2
EMBED_DIM = 8
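# Lightweight stand-in for the EVO1 backbone so the policy wrapper can be exercised without
# loading a real InternVL3 checkpoint.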
class DummyEVO1(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.action_head = nn.Linear(1, 1)
self.get_vl_embeddings_calls = 0
def set_finetune_flags(self):
return None
def get_vl_embeddings(self, images, image_mask, prompt=None, return_cls_only=False):
self.get_vl_embeddings_calls += 1
return torch.ones(len(images), 4, EMBED_DIM)
def forward(
self,
fused_tokens,
state=None,
actions_gt=None,
action_mask=None,
embodiment_ids=None,
):
batch_size = fused_tokens.shape[0]
if actions_gt is None:
return torch.ones(batch_size, CHUNK_SIZE * MAX_ACTION_DIM)
pred_velocity = torch.zeros(batch_size, CHUNK_SIZE * MAX_ACTION_DIM)
noise = torch.zeros_like(actions_gt)
return pred_velocity, noise
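# Small Evo1Config shared by the tests; dimensions are deliberately tiny to keep the dummy model cheap.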
def make_config(training_stage="stage1", **kwargs):
config_kwargs = {
"device": "cpu",
"vlm_model_name": "dummy-internvl3",
"training_stage": training_stage,
"chunk_size": CHUNK_SIZE,
"n_action_steps": 1,
"max_state_dim": MAX_STATE_DIM,
"max_action_dim": MAX_ACTION_DIM,
"max_views": 2,
"embed_dim": EMBED_DIM,
"hidden_dim": 16,
"state_hidden_dim": 16,
"num_heads": 2,
"num_layers": 1,
"num_inference_timesteps": 2,
"input_features": {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
f"{OBS_IMAGES}.front": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)),
},
"output_features": {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,)),
},
}
config_kwargs.update(kwargs)
return Evo1Config(**config_kwargs)
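# Two-sample batch with one camera view, a state vector, and (optionally) a full action chunk.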
def make_batch(include_action=True):
batch = {
"task": ["pick the block", "place the block"],
OBS_STATE: torch.randn(2, STATE_DIM),
f"{OBS_IMAGES}.front": torch.rand(2, 3, 16, 16),
}
if include_action:
batch[ACTION] = torch.randn(2, CHUNK_SIZE, ACTION_DIM)
return batch
def test_evo1_factory_registration():
cfg = make_policy_config(
"evo1",
device="cpu",
vlm_model_name="dummy-internvl3",
input_features={
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
f"{OBS_IMAGES}.front": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)),
},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))},
)
assert isinstance(cfg, Evo1Config)
assert get_policy_class("evo1") is modeling_evo1.EVO1Policy
def test_evo1_stage_defaults_and_consistency():
stage1 = make_config(training_stage="stage1")
assert (stage1.finetune_vlm, stage1.finetune_language_model, stage1.finetune_vision_model) == (
False,
False,
False,
)
assert stage1.finetune_action_head is True
stage2 = make_config(training_stage="stage2")
assert (stage2.finetune_vlm, stage2.finetune_language_model, stage2.finetune_vision_model) == (
True,
True,
True,
)
assert stage2.finetune_action_head is True
explicit_off = make_config(
training_stage="stage2",
finetune_vlm=False,
finetune_language_model=False,
finetune_vision_model=False,
finetune_action_head=False,
)
assert (
explicit_off.finetune_vlm,
explicit_off.finetune_language_model,
explicit_off.finetune_vision_model,
) == (
False,
False,
False,
)
assert explicit_off.finetune_action_head is False
try:
make_config(training_stage="stage2", finetune_vlm=True, finetune_language_model=False)
except ValueError as exc:
assert "Inconsistent EVO1 finetune config" in str(exc)
else:
raise AssertionError("Expected inconsistent finetune config to raise ValueError")
def test_evo1_policy_forward_and_inference_use_batched_embedding(monkeypatch):
monkeypatch.setattr(modeling_evo1, "EVO1", DummyEVO1)
policy = modeling_evo1.EVO1Policy(make_config())
loss, metrics = policy.forward(make_batch(include_action=True))
assert loss.ndim == 0
assert torch.isfinite(loss)
assert metrics["active_action_dims"] == ACTION_DIM * CHUNK_SIZE
assert policy.model.get_vl_embeddings_calls == 1
action_chunk = policy.predict_action_chunk(make_batch(include_action=False))
assert action_chunk.shape == (2, CHUNK_SIZE, ACTION_DIM)
policy.reset()
selected = policy.select_action(make_batch(include_action=False))
assert selected.shape == (2, ACTION_DIM)
def test_evo1_action_mask_accepts_chunk_size_one(monkeypatch):
monkeypatch.setattr(modeling_evo1, "EVO1", DummyEVO1)
config = make_config(chunk_size=1, n_action_steps=1)
policy = modeling_evo1.EVO1Policy(config)
batch = make_batch(include_action=True)
batch[ACTION] = torch.randn(2, ACTION_DIM)
batch["action_mask"] = torch.ones(2, ACTION_DIM, dtype=torch.bool)
actions, action_mask = policy._prepare_actions(batch)
assert actions.shape == (2, 1, MAX_ACTION_DIM)
assert action_mask.shape == (2, 1, MAX_ACTION_DIM)
assert action_mask[:, :, :ACTION_DIM].all()
assert not action_mask[:, :, ACTION_DIM:].any()
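# The flow-matching head should also accept a plain dict config and still build its state encoder
# when horizon == 1.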
def test_flowmatching_dict_config_enables_state_encoder_for_horizon_one():
head = FlowmatchingActionHead(
config={
"embed_dim": EMBED_DIM,
"hidden_dim": 16,
"action_dim": ACTION_DIM,
"horizon": 1,
"per_action_dim": ACTION_DIM,
"num_heads": 2,
"num_layers": 1,
"num_inference_timesteps": 2,
"state_dim": STATE_DIM,
"state_hidden_dim": 16,
"num_categories": 1,
}
)
assert head.state_encoder is not None
pred_velocity, noise = head(
torch.randn(2, 4, EMBED_DIM),
state=torch.randn(2, STATE_DIM),
actions_gt=torch.randn(2, 1, ACTION_DIM),
action_mask=torch.ones(2, 1, ACTION_DIM, dtype=torch.bool),
)
assert pred_velocity.shape == (2, ACTION_DIM)
assert noise.shape == (2, 1, ACTION_DIM)