Compare commits

...

24 Commits

Author SHA1 Message Date
Steven Palma 3fa1415057 Merge branch 'main' into feat/implement_evo1 2026-07-02 17:45:31 +02:00
Steven Palma 8dcbfd9d25 feat(policies): implement RTC to EVO1 2026-07-02 17:32:09 +02:00
Steven Palma edc01c3b94 chore: update docs + remove legacy codepaths 2026-07-02 17:21:49 +02:00
Steven Palma f5ac58adb9 refactor(policies): multiple improvements 2026-07-02 17:21:45 +02:00
Steven Palma 2afe2864e9 refactor(policies): use config for evo1 + local imports 2026-07-02 17:21:31 +02:00
Steven Palma d61941fe68 chore(evo1): delete added test + reduce diff 2026-07-02 17:21:09 +02:00
Steven Palma e181f2e383 Merge branch 'main' into feat/implement_evo1 2026-07-02 17:20:53 +02:00
Steven Palma 9f5ddeb761 fix(style): pre-commit
oops
2026-07-02 17:20:40 +02:00
Steven Palma 13adcea522 Merge branch 'main' into feat/implement_evo1 2026-07-02 10:55:31 +02:00
Steven Palma 5b541c042d refactor(policy): evo1 GPU-batched preprocessing + vectorized attention masking + remove dead code 2026-07-02 10:47:50 +02:00
Martino Russi 9423deda02 refactor(evo1): use native HF InternVL3-1B-hf, drop trust_remote_code
- Switch from OpenGVLab/InternVL3-1B (requires trust_remote_code=True)
  to OpenGVLab/InternVL3-1B-hf (native transformers implementation).
- Replace manual _extract_feature + _prepare_and_fuse_embeddings with
  a single model.forward() call — verified bit-for-bit identical output.
- Remove ~170 lines of manual ViT/pixel-shuffle/projection logic.
- Symlink README.md to docs/source/ following repo convention.

Weights are byte-identical between both model variants; only the module
naming differs. All 12 existing unit tests pass. Local training (10 steps)
on maximellerbach/omx_pickandplace confirmed working.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-23 17:17:19 +02:00
javadcc_mac 25556ceefe fix(evo1): move LIBERO padding into policy processors 2026-06-21 15:58:38 +08:00
javadcc_mac 4cfa762da8 Fix eval action conversion for bf16 policies 2026-06-13 10:51:33 +08:00
javadcc_mac fa984990c0 Fix EVO1 LIBERO eval action postprocessing 2026-06-13 10:18:34 +08:00
javadcc_mac f9b8f297b4 Fix EVO1 LIBERO rollout processors 2026-06-09 15:10:10 +08:00
javadcc_mac 95527f6051 Merge remote-tracking branch 'upstream/main' into codex/add-evo1-policy 2026-05-12 17:40:59 +08:00
javadcc_mac 407ee867b9 docs(evo1): format results table 2026-05-12 17:40:18 +08:00
javadcc_mac a5e6409985 fix(evo1): finalize policy guide alignment 2026-05-11 21:51:41 +08:00
javadcc_mac 1c9fbba9a9 chore(evo1): align with policy contribution guide conventions
- Add `src/lerobot/policies/evo1/README.md` symlink into `docs/source/evo1.mdx`
  to match the in-tree README convention (mirroring the EO-1 layout).
- Convert `transformers` import in `internvl3_embedder.py` to the standard
  `TYPE_CHECKING + _transformers_available` two-step gating used by other
  optional-backbone policies (e.g. diffusion). The previous lazy-in-`__init__`
  import was functionally equivalent for runtime gating but didn't expose the
  real symbols to type checkers.
- Add `lerobot[evo1]` to the `all` extra in `pyproject.toml` so
  `pip install 'lerobot[all]'` keeps installing every optional policy.

Per the guidance in https://moon-ci-docs.huggingface.co/docs/lerobot/pr_3534/en/contributing_a_policy.
2026-05-10 23:14:23 +08:00
javadcc_mac 6a1b5ceb9d Merge remote-tracking branch 'upstream/main' into codex/add-evo1-policy
# Conflicts:
#	uv.lock
2026-05-10 22:48:17 +08:00
javadcc_mac daa4c4dd30 chore(lock): regenerate uv.lock for evo1 extra
Adds the `evo1` entry to `[package.metadata.requires-dist]` and the
`provides-extras` list so that `uv sync --locked --extra test` (used by
fast_tests.yml) no longer reports the lockfile as stale.

Generated with `uv 0.8.0` (matching `UV_VERSION` in fast_tests.yml).
The non-evo1 marker tweaks are produced by `uv lock` re-resolving the
existing dep graph and are not introduced by this PR.
2026-05-10 22:43:26 +08:00
Yiming Wang ff992a7a1d Merge branch 'main' into codex/add-evo1-policy 2026-05-10 18:54:35 +08:00
javadcc_mac 48269dddb3 fix(evo1): infer batch size after normalizing image dims
`_collect_image_batches` read `batch_size = batch[camera_keys[0]].shape[0]`
before normalizing per-camera tensors to `(B, C, H, W)`. For an unbatched
`(C, H, W)` input (which the function tries to support via the `image.dim() == 3`
branch), this picked up the channel count `C` instead of the real batch size,
making the subsequent per-sample loop iterate `C` times and indexing go
out of bounds.

Normalize each camera tensor up-front, then read `batch_size` from the
normalized batch dim. Adds `test_collect_image_batches_handles_unbatched_chw`
covering the regression.

Reported by Copilot review on huggingface/lerobot#3545.
2026-05-10 11:29:23 +08:00
javadcc_mac 8df8d3d866 feat(policies): add EVO1 policy 2026-05-09 21:39:19 +08:00
18 changed files with 3409 additions and 6 deletions
+2
View File
@@ -71,6 +71,8 @@
title: EO-1
- local: fastwam
title: FastWAM
- local: evo1
title: EVO1
- local: groot
title: NVIDIA GR00T N1.5
- local: xvla
+191
View File
@@ -0,0 +1,191 @@
# EVO1
EVO1 is a Vision-Language-Action policy for robot control built around an InternVL3 backbone and a continuous flow-matching action head. This LeRobot integration exposes EVO1 as a standard policy type so it can be trained and evaluated with the usual LeRobot dataset, checkpoint, and processor APIs.
## Model Overview
The policy embeds one or more camera images and the language task prompt with InternVL3, pads robot state/action vectors to fixed maximum dimensions, and predicts future action chunks with a flow-matching action head. During inference, the policy samples an action chunk and returns `n_action_steps` actions from that chunk before sampling again.
### What the LeRobot Integration Covers
- Standard `policy.type=evo1` configuration through LeRobot
- InternVL3 image/text embedding with optional FlashAttention fallback
- Stage-based finetuning controls for action-head-only and VLM finetuning runs
- Continuous flow-matching action prediction
- Checkpoint save/load through LeRobot policy APIs
- Training with `lerobot-train` and evaluation with standard policy inference APIs
The broader EVO1 project may include additional training scripts and dataset tooling. This page focuses on the LeRobot robot-control policy path.
## Installation Requirements
1. Install LeRobot by following the [Installation Guide](./installation).
2. Install EVO1 dependencies:
```bash
pip install -e ".[evo1]"
```
For LIBERO evaluation, install the LIBERO extra as well:
```bash
pip install -e ".[evo1,libero]"
```
3. Install a `flash-attn` wheel only if it is compatible with your Python, PyTorch, CUDA, and GPU stack. EVO1 falls back to standard attention when `flash_attn` is not available.
EVO1 uses the native Hugging Face `transformers` InternVL implementation, so `policy.vlm_model_name` must point to a natively converted checkpoint such as `OpenGVLab/InternVL3-1B-hf` (note the `-hf` suffix). The first run may download the configured VLM checkpoint unless `policy.vlm_model_name` points to a local model directory.
## Data Requirements
EVO1 expects a LeRobot dataset with:
- One to `policy.max_views` visual observations, for example `observation.images.image`
- `observation.state`
- `action`
- A language task instruction in the dataset `task` field, or another field configured with `policy.task_field`
State and action vectors are padded to `policy.max_state_dim` and `policy.max_action_dim`. Predictions are cropped back to the dataset action dimension before being returned.
## Usage
To use EVO1 in a LeRobot configuration, specify:
```python
policy.type=evo1
```
By default, a new EVO1 policy initializes its VLM from:
```python
policy.vlm_model_name=OpenGVLab/InternVL3-1B-hf
```
Once a LeRobot-format EVO1 checkpoint is available, load it with:
```python
policy.path=your-org/your-evo1-checkpoint
```
## Training
### Stage 1
Stage 1 freezes the VLM and trains the action head:
```bash
lerobot-train \
--dataset.repo_id=your_org/your_dataset \
--policy.type=evo1 \
--policy.training_stage=stage1 \
--policy.vlm_model_name=OpenGVLab/InternVL3-1B-hf \
--policy.device=cuda \
--policy.chunk_size=50 \
--policy.n_action_steps=50 \
--policy.max_state_dim=24 \
--policy.max_action_dim=24 \
--policy.optimizer_lr=1e-5 \
--batch_size=4 \
--steps=5000 \
--output_dir=./outputs/evo1_stage1
```
### Stage 2
Stage 2 finetunes the VLM branches and action head. A common workflow starts from a Stage 1 checkpoint:
```bash
lerobot-train \
--dataset.repo_id=your_org/your_dataset \
--policy.path=./outputs/evo1_stage1/checkpoints/005000/pretrained_model \
--policy.training_stage=stage2 \
--policy.vlm_model_name=OpenGVLab/InternVL3-1B-hf \
--policy.device=cuda \
--policy.chunk_size=50 \
--policy.n_action_steps=50 \
--policy.max_state_dim=24 \
--policy.max_action_dim=24 \
--policy.optimizer_lr=1e-5 \
--batch_size=4 \
--steps=80000 \
--output_dir=./outputs/evo1_stage2
```
By default, `policy.training_stage` reapplies the finetuning defaults for that stage. This is important when
starting Stage 2 from a Stage 1 checkpoint, because the Stage 1 checkpoint config stores the VLM finetuning
flags as disabled. These stage defaults take precedence over saved or manually supplied `policy.finetune_*`
flags unless `policy.apply_training_stage_defaults=false`, so set that flag only when manually controlling
every finetuning flag.
### Key Training Parameters
| Parameter | Default | Description |
| --------------------------------------------- | --------------------------- | ----------------------------------------------------------------- |
| `policy.vlm_model_name` | `OpenGVLab/InternVL3-1B-hf` | Natively converted InternVL3 checkpoint or local model directory |
| `policy.training_stage` | `stage1` | `stage1` trains the action head; `stage2` finetunes VLM branches |
| `policy.apply_training_stage_defaults` | `true` | Reapplies stage finetuning defaults after loading a checkpoint |
| `policy.vlm_num_layers` | `14` | Number of InternVL3 language layers kept for the policy |
| `policy.vlm_dtype` | `bfloat16` | Requested VLM dtype |
| `policy.use_flash_attn` | `true` | Requests FlashAttention when installed; otherwise falls back |
| `policy.enable_gradient_checkpointing` | `true` | Enables checkpointing on supported InternVL3 modules |
| `policy.gradient_checkpointing_use_reentrant` | `false` | Reentrant setting passed to gradient checkpointing when supported |
| `policy.chunk_size` | `50` | Number of future actions predicted per chunk |
| `policy.n_action_steps` | `50` | Number of actions consumed from a sampled chunk |
| `policy.max_state_dim` | `24` | State padding dimension |
| `policy.max_action_dim` | `24` | Action padding dimension |
| `policy.postprocess_action_dim` | `null` | Optional action dimension returned after EVO1 postprocessing |
| `policy.binarize_gripper` | `false` | Binarizes the postprocessed gripper channel for LIBERO-style eval |
| `policy.task_field` | `task` | Batch field used as the language prompt |
## Inference
Try it out with a trained EVO1 checkpoint:
```bash
lerobot-rollout \
--policy.path=your-org/your-evo1-checkpoint \
--inference.type=rtc \ # optional
...
```
## Results
### LIBERO Evaluation
> [!NOTE]
> Benchmark results for a `lerobot`-hosted LIBERO checkpoint trained with this implementation
> will be added once training completes.
The official EVO1 LIBERO rollout protocol uses the raw LIBERO camera feature names
(`observation.images.agentview_image` and `observation.images.robot0_eye_in_hand_image`), replans every
14 actions, and binarizes the gripper command before stepping the simulator. The EVO1 policy postprocessor
can crop the padded 24D action back to the 7D LIBERO action space and apply that gripper binarization. To
evaluate a LIBERO checkpoint under the same one-episode-per-task setting, keep the raw camera names instead
of the default `image`/`image2` mapping and set the LIBERO action postprocessing flags:
```bash
lerobot-eval \
--policy.path=your-org/your-evo1-libero-checkpoint \
--policy.vlm_model_name=OpenGVLab/InternVL3-1B-hf \
--policy.device=cuda \
--policy.use_flash_attn=true \
--policy.n_action_steps=14 \
--policy.postprocess_action_dim=7 \
--policy.binarize_gripper=true \
--env.type=libero \
--env.task=libero_object \
--env.camera_name_mapping="{agentview_image: agentview_image, robot0_eye_in_hand_image: robot0_eye_in_hand_image}" \
--env.observation_height=448 \
--env.observation_width=448 \
--eval.batch_size=1 \
--eval.n_episodes=1
```
## References
- [EVO1 repository](https://github.com/MINT-SJTU/Evo-1)
- [InternVL3-1B-hf](https://huggingface.co/OpenGVLab/InternVL3-1B-hf)
## License
This LeRobot integration follows the Apache 2.0 License used by LeRobot. Check the upstream EVO1 and InternVL3 model pages for the licenses of released checkpoints and data.
+18
View File
@@ -0,0 +1,18 @@
# EVO1
EVO1 is a Vision-Language-Action policy for robot control. The LeRobot
integration uses an InternVL3 vision-language backbone with a flow-matching
action head, and supports staged training through the standard LeRobot policy
APIs.
The upstream EVO1 project is available at
[MINT-SJTU/Evo-1](https://github.com/MINT-SJTU/Evo-1).
```bibtex
@misc{evo1,
title = {EVO1},
author = {{MINT-SJTU}},
year = {2025},
howpublished = {\url{https://github.com/MINT-SJTU/Evo-1}},
}
```
+4 -1
View File
@@ -164,6 +164,7 @@ pynput-dep = ["pynput>=1.7.8,<1.9.0"]
pyzmq-dep = ["pyzmq>=26.2.1,<28.0.0"]
motorbridge-dep = ["motorbridge>=0.3.2,<0.4.0"]
motorbridge-smart-servo-dep = ["motorbridge-smart-servo>=0.0.4,<0.1.0"]
timm-dep = ["timm>=1.0.0,<1.1.0"]
# Motors
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0", "lerobot[pyserial-dep]", "lerobot[deepdiff-dep]"]
@@ -220,7 +221,7 @@ groot = [
"lerobot[peft-dep]",
"lerobot[diffusers-dep]",
"dm-tree>=0.1.8,<1.0.0",
"timm>=1.0.0,<1.1.0",
"lerobot[timm-dep]",
"decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
"ninja>=1.11.1,<2.0.0",
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
@@ -234,6 +235,7 @@ fastwam = [
"lerobot[transformers-dep]",
"lerobot[diffusers-dep]",
]
evo1 = ["lerobot[transformers-dep]"]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
@@ -316,6 +318,7 @@ all = [
"lerobot[fastwam]",
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
"lerobot[xvla]",
"lerobot[evo1]",
"lerobot[hilserl]",
"lerobot[vla_jepa]",
"lerobot[async]",
+22
View File
@@ -83,6 +83,28 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
return LambdaLR(optimizer, lr_lambda, -1)
@LRSchedulerConfig.register_subclass("cosine_annealing_with_warmup")
@dataclass
class CosineAnnealingWithWarmupSchedulerConfig(LRSchedulerConfig):
"""Linear warmup followed by cosine annealing from the peak LR to zero.
Used by EVO1; the annealing phase always spans the remaining training steps.
"""
num_warmup_steps: int
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
def lr_lambda(current_step: int) -> float:
if current_step < self.num_warmup_steps:
return current_step / max(1, self.num_warmup_steps)
progress = (current_step - self.num_warmup_steps) / max(
1, num_training_steps - self.num_warmup_steps
)
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
return LambdaLR(optimizer, lr_lambda, -1)
@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
@dataclass
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
+2
View File
@@ -17,6 +17,7 @@ from lerobot.utils.action_interpolator import ActionInterpolator as ActionInterp
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .eo1.configuration_eo1 import EO1Config as EO1Config
from .evo1.configuration_evo1 import Evo1Config as Evo1Config
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
from .fastwam.configuration_fastwam import FastWAMConfig as FastWAMConfig
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
@@ -45,6 +46,7 @@ __all__ = [
"EO1Config",
"FastWAMConfig",
"GaussianActorConfig",
"Evo1Config",
"GrootConfig",
"MolmoAct2Config",
"MultiTaskDiTConfig",
+1
View File
@@ -0,0 +1 @@
../../../../docs/source/policy_evo1_README.md
+19
View File
@@ -0,0 +1,19 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_evo1 import Evo1Config
from .modeling_evo1 import Evo1Policy
from .processor_evo1 import make_evo1_pre_post_processors
__all__ = ["Evo1Config", "Evo1Policy", "make_evo1_pre_post_processors"]
@@ -0,0 +1,252 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineAnnealingWithWarmupSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
from ..rtc.configuration_rtc import RTCConfig
logger = logging.getLogger(__name__)
@PreTrainedConfig.register_subclass("evo1")
@dataclass
class Evo1Config(PreTrainedConfig):
training_stage: str = "stage1"
# When True and the policy runs on CUDA, EVO1 wraps its own forward passes (training and
# inference) in a bfloat16 autocast block, so its numerics do not depend on the dtype of any
# outer autocast context opened by lerobot-train/lerobot-eval.
use_amp: bool = True
n_obs_steps: int = 1
chunk_size: int = 50
n_action_steps: int = 50
max_state_dim: int = 24
max_action_dim: int = 24
max_views: int = 3
image_resolution: tuple[int, int] = (448, 448)
empty_cameras: int = 0
postprocess_action_dim: int | None = None
binarize_gripper: bool = False
gripper_index: int = 6
gripper_threshold: float = 0.5
gripper_below_threshold_value: float = 1.0
gripper_above_threshold_value: float = -1.0
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
)
vlm_model_name: str = "OpenGVLab/InternVL3-1B-hf"
vlm_num_layers: int | None = 14
vlm_dtype: str = "bfloat16"
# Max token length for tokenizing the (image placeholders + instruction) prompt. Prompts longer
# than this are right-truncated, so raise it for tasks with long language instructions or many views.
max_text_length: int = 1024
use_flash_attn: bool = True
action_head: str = "flowmatching"
embed_dim: int = 896
hidden_dim: int = 1024
state_hidden_dim: int = 1024
num_heads: int = 8
num_layers: int = 8
dropout: float = 0.0
num_inference_timesteps: int = 32
num_categories: int = 1
# When True, the action head is conditioned on a single pooled VL token (the last non-padding
# token of the causal decoder) instead of the full fused token sequence.
return_cls_only: bool = False
enable_gradient_checkpointing: bool = True
gradient_checkpointing_use_reentrant: bool = False
finetune_vlm: bool | None = None
finetune_language_model: bool | None = None
finetune_vision_model: bool | None = None
finetune_action_head: bool | None = None
# Reapply stage defaults after loading checkpoint configs so stage2 cannot
# accidentally inherit the frozen VLM flags stored by a stage1 checkpoint.
apply_training_stage_defaults: bool = True
task_field: str = "task"
embodiment_id_field: str | None = None
default_embodiment_id: int = 0
# Real-Time Chunking guidance for asynchronous inference (lerobot-rollout --inference.type=rtc
# sets this and calls init_rtc_processor()); None disables RTC.
rtc_config: RTCConfig | None = None
optimizer_lr: float = 1e-5
optimizer_betas: tuple[float, float] = (0.9, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-5
optimizer_grad_clip_norm: float = 1.0
scheduler_warmup_steps: int = 300
def __post_init__(self):
super().__post_init__()
if self.training_stage not in {"stage1", "stage2"}:
raise ValueError(
f"Unsupported EVO1 training_stage '{self.training_stage}', expected 'stage1' or 'stage2'"
)
if self.apply_training_stage_defaults:
stage_defaults = {
"stage1": {
"finetune_vlm": False,
"finetune_language_model": False,
"finetune_vision_model": False,
"finetune_action_head": True,
},
"stage2": {
"finetune_vlm": True,
"finetune_language_model": True,
"finetune_vision_model": True,
"finetune_action_head": True,
},
}[self.training_stage]
for flag_name, default_value in stage_defaults.items():
current_value = getattr(self, flag_name)
if current_value is not None and current_value != default_value:
logger.warning(
"EVO1 %s=%s is overridden by training_stage=%s default %s. "
"Set apply_training_stage_defaults=false to keep explicit finetuning flags.",
flag_name,
current_value,
self.training_stage,
default_value,
)
setattr(self, flag_name, default_value)
elif self.training_stage == "stage1":
if self.finetune_vlm is None:
self.finetune_vlm = False
if self.finetune_language_model is None:
self.finetune_language_model = False
if self.finetune_vision_model is None:
self.finetune_vision_model = False
if self.finetune_action_head is None:
self.finetune_action_head = True
elif self.training_stage == "stage2":
has_explicit_branch_flags = any(
flag is not None for flag in (self.finetune_language_model, self.finetune_vision_model)
)
if not has_explicit_branch_flags:
# An explicit finetune_vlm decides both branches; otherwise stage2 defaults to a
# full-VLM finetune.
vlm_finetune = self.finetune_vlm if self.finetune_vlm is not None else True
self.finetune_vlm = vlm_finetune
self.finetune_language_model = vlm_finetune
self.finetune_vision_model = vlm_finetune
elif self.finetune_vlm is None:
self.finetune_vlm = bool(self.finetune_language_model or self.finetune_vision_model)
if self.finetune_action_head is None:
self.finetune_action_head = True
if self.finetune_vlm is None:
self.finetune_vlm = False
if self.finetune_language_model is None:
self.finetune_language_model = False
if self.finetune_vision_model is None:
self.finetune_vision_model = False
if self.finetune_action_head is None:
self.finetune_action_head = False
branch_vlm = self.finetune_language_model or self.finetune_vision_model
if self.finetune_vlm != branch_vlm:
raise ValueError(
"Inconsistent EVO1 finetune config: "
f"finetune_vlm={self.finetune_vlm} but "
f"(finetune_language_model or finetune_vision_model)={branch_vlm}. "
"When branch-level flags are used, finetune_vlm must match their effective union."
)
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"n_action_steps ({self.n_action_steps}) must be <= chunk_size ({self.chunk_size})"
)
if len(self.image_resolution) != 2 or self.image_resolution[0] != self.image_resolution[1]:
raise ValueError(
"EVO1 currently expects a square image_resolution because InternVL3 preprocessing "
f"uses a scalar image_size, got {self.image_resolution}."
)
if not 0 <= self.default_embodiment_id < self.num_categories:
raise ValueError(
f"default_embodiment_id ({self.default_embodiment_id}) must be in "
f"[0, num_categories={self.num_categories})"
)
def validate_features(self) -> None:
if self.input_features is None:
self.input_features = {}
if self.output_features is None:
self.output_features = {}
for i in range(self.empty_cameras):
key = OBS_IMAGES + f".empty_camera_{i}"
if key not in self.input_features:
self.input_features[key] = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, *self.image_resolution),
)
if OBS_STATE not in self.input_features:
self.input_features[OBS_STATE] = PolicyFeature(
type=FeatureType.STATE,
shape=(self.max_state_dim,),
)
if ACTION not in self.output_features:
self.output_features[ACTION] = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.max_action_dim,),
)
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self):
return CosineAnnealingWithWarmupSchedulerConfig(
num_warmup_steps=self.scheduler_warmup_steps,
)
@property
def observation_delta_indices(self) -> list[int]:
return [0]
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None
+210
View File
@@ -0,0 +1,210 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
import torch.nn as nn
from .configuration_evo1 import Evo1Config
from .flow_matching import FlowmatchingActionHead
from .internvl3_embedder import InternVL3Embedder
class Evo1Model(nn.Module):
def __init__(self, config: Evo1Config, vlm_hub_kwargs: dict | None = None):
super().__init__()
self.config = config
self._device = config.device
self.return_cls_only = config.return_cls_only
# Set by Evo1Policy.init_rtc_processor() when config.rtc_config is provided.
self.rtc_processor = None
# Gradient checkpointing only pays off when the VLM is actually being trained; keep it off
# whenever every VLM branch is frozen so the frozen forward stays cheap.
tracks_vlm_gradients = bool(
config.finetune_vlm or config.finetune_language_model or config.finetune_vision_model
)
enable_gradient_checkpointing = config.enable_gradient_checkpointing and tracks_vlm_gradients
self.embedder = InternVL3Embedder(
model_name=config.vlm_model_name,
image_size=int(config.image_resolution[0]),
device=self._device,
num_language_layers=config.vlm_num_layers,
model_dtype=config.vlm_dtype,
use_flash_attn=config.use_flash_attn,
max_text_length=config.max_text_length,
enable_gradient_checkpointing=enable_gradient_checkpointing,
gradient_checkpointing_use_reentrant=config.gradient_checkpointing_use_reentrant,
hub_kwargs=vlm_hub_kwargs,
)
action_head_type = config.action_head.lower()
if action_head_type != "flowmatching":
raise NotImplementedError(f"Unknown action_head: {action_head_type}")
horizon = config.chunk_size
per_action_dim = config.max_action_dim
action_dim = horizon * per_action_dim
self.horizon = horizon
self.per_action_dim = per_action_dim
self.action_head = FlowmatchingActionHead(
embed_dim=config.embed_dim,
hidden_dim=config.hidden_dim,
action_dim=action_dim,
horizon=horizon,
per_action_dim=per_action_dim,
num_heads=config.num_heads,
num_layers=config.num_layers,
dropout=config.dropout,
num_inference_timesteps=config.num_inference_timesteps,
num_categories=config.num_categories,
state_dim=config.max_state_dim,
state_hidden_dim=config.state_hidden_dim,
).to(self._device)
def get_vl_embeddings(
self,
images: list[torch.Tensor],
image_mask: torch.Tensor,
prompt: str | list[str] | None = None,
return_cls_only: bool | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Fused VL embeddings from per-camera image batches.
Args:
images: list of per-camera tensors, each shaped ``(B, C, H, W)`` with values in ``[0, 1]``.
image_mask: bool tensor ``(B, max_views)`` marking present views.
Returns:
``(embeddings, valid_mask)``: the fused tokens and the bool mask of attendable context
positions (None when a single pooled token is returned).
"""
if return_cls_only is None:
return_cls_only = self.return_cls_only
if not images:
raise ValueError("EVO1 expects at least one image per sample.")
batch_size = images[0].shape[0]
if prompt is None:
prompts = [""] * batch_size
elif isinstance(prompt, str):
prompts = [prompt] * batch_size
else:
prompts = [str(p) for p in prompt]
if len(prompts) != batch_size:
raise ValueError(
f"Prompt batch size {len(prompts)} does not match image batch size {batch_size}"
)
if image_mask.dim() == 1:
image_mask = image_mask.unsqueeze(0)
if image_mask.shape[0] != batch_size:
raise ValueError(
f"image_mask batch size {image_mask.shape[0]} does not match image batch size {batch_size}"
)
return self.embedder.get_fused_image_text_embedding_batched(
camera_images=images,
image_masks=image_mask,
text_prompts=prompts,
return_cls_only=return_cls_only,
)
def predict_action(
self,
fused_tokens: torch.Tensor,
state: torch.Tensor,
actions_gt: torch.Tensor | None = None,
action_mask: torch.Tensor | None = None,
embodiment_ids: torch.Tensor | None = None,
context_mask: torch.Tensor | None = None,
inference_delay: int | None = None,
prev_chunk_left_over: torch.Tensor | None = None,
execution_horizon: int | None = None,
):
if actions_gt is None:
return self.action_head.get_action(
fused_tokens,
state=state,
action_mask=action_mask,
embodiment_id=embodiment_ids,
context_mask=context_mask,
inference_delay=inference_delay,
prev_chunk_left_over=prev_chunk_left_over,
execution_horizon=execution_horizon,
rtc_processor=self.rtc_processor,
)
return self.action_head(
fused_tokens,
state=state,
actions_gt=actions_gt,
action_mask=action_mask,
embodiment_id=embodiment_ids,
context_mask=context_mask,
)
def forward(
self,
fused_tokens: torch.Tensor,
state: torch.Tensor | None = None,
actions_gt: torch.Tensor | None = None,
action_mask: torch.Tensor | None = None,
embodiment_ids: torch.Tensor | None = None,
context_mask: torch.Tensor | None = None,
inference_delay: int | None = None,
prev_chunk_left_over: torch.Tensor | None = None,
execution_horizon: int | None = None,
):
return self.predict_action(
fused_tokens,
state,
actions_gt,
action_mask,
embodiment_ids,
context_mask,
inference_delay,
prev_chunk_left_over,
execution_horizon,
)
def _set_module_trainable(self, module: nn.Module, trainable: bool):
for param in module.parameters():
param.requires_grad = trainable
def _vlm_submodule(self, name: str) -> nn.Module:
module = getattr(self.embedder.model, name, None)
if not isinstance(module, nn.Module):
raise AttributeError(
f"InternVL model {type(self.embedder.model).__name__} has no '{name}' submodule; "
"the native HF InternVL layout (language_model / vision_tower / "
"multi_modal_projector) is required to apply the EVO1 finetune flags."
)
return module
def set_finetune_flags(self):
# __post_init__ resolves every finetune flag to a concrete boolean, so branch-level flags
# are authoritative here. Freeze everything first, then re-enable the requested branches.
self._set_module_trainable(self.embedder, False)
self._set_module_trainable(
self._vlm_submodule("language_model"), bool(self.config.finetune_language_model)
)
finetune_vision = bool(self.config.finetune_vision_model)
self._set_module_trainable(self._vlm_submodule("vision_tower"), finetune_vision)
self._set_module_trainable(self._vlm_submodule("multi_modal_projector"), finetune_vision)
if not self.config.finetune_action_head:
self._set_module_trainable(self.action_head, False)
+483
View File
@@ -0,0 +1,483 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import math
import torch
import torch.nn as nn
logger = logging.getLogger(__name__)
class SinusoidalPositionalEncoding(nn.Module):
def __init__(self, dim: int, max_len: int = 1000):
super().__init__()
pe = torch.zeros(max_len, dim)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, dim, 2) * -(math.log(10000.0) / dim))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer("pe", pe)
def forward(self, seq_len: int):
if seq_len > self.pe.size(1):
self._extend_pe(seq_len)
return self.pe[:, :seq_len, :]
def _extend_pe(self, new_max_len):
old_max_len, dim = self.pe.size(1), self.pe.size(2)
if new_max_len <= old_max_len:
return
extra_positions = torch.arange(old_max_len, new_max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim))
extra_pe = torch.zeros(new_max_len - old_max_len, dim)
extra_pe[:, 0::2] = torch.sin(extra_positions * div_term)
extra_pe[:, 1::2] = torch.cos(extra_positions * div_term)
extra_pe = extra_pe.unsqueeze(0)
new_pe = torch.cat([self.pe, extra_pe.to(self.pe.device)], dim=1)
self.pe = new_pe
class CategorySpecificLinear(nn.Module):
def __init__(self, in_dim: int, out_dim: int, num_categories: int = 1):
super().__init__()
self.num_categories = num_categories
if num_categories <= 1:
self.linear = nn.Linear(in_dim, out_dim)
else:
self.weight = nn.Parameter(torch.empty(num_categories, in_dim, out_dim))
self.bias = nn.Parameter(torch.zeros(num_categories, out_dim))
# Initialize each per-category (in_dim, out_dim) matrix separately: xavier on the full
# 3D tensor would compute fan_in = in_dim * out_dim and badly under-scale the weights.
for category in range(num_categories):
nn.init.xavier_uniform_(self.weight[category])
def forward(self, x: torch.Tensor, category_id: torch.LongTensor):
if self.num_categories <= 1:
if x.dtype != self.linear.weight.dtype:
x = x.to(dtype=self.linear.weight.dtype)
return self.linear(x)
if x.dtype != self.weight.dtype:
x = x.to(dtype=self.weight.dtype)
orig_shape = x.shape
x_flat = x.reshape(-1, orig_shape[-1])
if category_id.dim() == 0:
cid = category_id.item()
out = x_flat @ self.weight[cid] + self.bias[cid]
else:
category_id = category_id.reshape(-1)
if category_id.numel() != x_flat.size(0):
raise ValueError(
f"category_id length {category_id.numel()} does not match flattened batch {x_flat.size(0)}"
)
weight_selected = self.weight[category_id]
bias_selected = self.bias[category_id]
out = torch.bmm(x_flat.unsqueeze(1), weight_selected).squeeze(1) + bias_selected
out_shape = orig_shape[:-1] + (out.shape[-1],)
return out.view(out_shape)
class CategorySpecificMLP(nn.Module):
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_categories: int = 1):
super().__init__()
self.fc1 = CategorySpecificLinear(input_dim, hidden_dim, num_categories)
self.fc2 = CategorySpecificLinear(hidden_dim, output_dim, num_categories)
self.activation = nn.ReLU(inplace=True)
def forward(self, x: torch.Tensor, category_id: torch.LongTensor):
out = self.activation(self.fc1(x, category_id))
out = self.fc2(out, category_id)
return out
class MultiEmbodimentActionEncoder(nn.Module):
def __init__(
self, action_dim: int, embed_dim: int, hidden_dim: int, horizon: int, num_categories: int = 1
):
super().__init__()
self.horizon = horizon
self.embed_dim = embed_dim
self.num_categories = num_categories
self.W1 = CategorySpecificLinear(action_dim, hidden_dim, num_categories)
self.W2 = CategorySpecificLinear(hidden_dim, hidden_dim, num_categories)
self.W3 = CategorySpecificLinear(hidden_dim, embed_dim, num_categories)
self.pos_encoding = SinusoidalPositionalEncoding(hidden_dim, max_len=horizon)
self.activation = nn.ReLU(inplace=True)
def forward(self, action_seq: torch.Tensor, category_id: torch.LongTensor):
batch_size, horizon, action_dim = action_seq.shape
if self.horizon != horizon:
raise ValueError(
f"Action sequence length must match horizon: got {horizon}, expected {self.horizon}."
)
x = action_seq.reshape(batch_size * horizon, action_dim)
if category_id.dim() == 0:
cat_ids = category_id.expand(horizon * batch_size)
else:
cat_ids = category_id.unsqueeze(1).expand(batch_size, horizon).reshape(batch_size * horizon)
out = self.activation(self.W1(x, cat_ids))
pos_enc = self.pos_encoding(horizon).to(device=out.device, dtype=out.dtype)
out = out.view(batch_size, horizon, -1) + pos_enc
out = out.view(batch_size * horizon, -1)
out = self.activation(self.W2(out, cat_ids))
out = self.W3(out, cat_ids)
return out.view(batch_size, horizon, self.embed_dim)
class BasicTransformerBlock(nn.Module):
def __init__(self, embed_dim: int, num_heads: int, hidden_dim: int, dropout: float = 0.0):
super().__init__()
self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.ff = nn.Sequential(nn.Linear(embed_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, embed_dim))
def forward(
self,
action_tokens: torch.Tensor,
context_tokens: torch.Tensor,
time_emb: torch.Tensor,
context_key_padding_mask: torch.Tensor | None = None,
):
x = self.norm1(action_tokens)
attn_out, _ = self.attn(x, context_tokens, context_tokens, key_padding_mask=context_key_padding_mask)
x = action_tokens + attn_out
x2 = self.norm2(x)
if time_emb is not None:
x2 = x2 + time_emb.unsqueeze(1)
ff_out = self.ff(x2)
return x + ff_out
class FlowmatchingActionHead(nn.Module):
def __init__(
self,
embed_dim: int = 896,
hidden_dim: int = 1024,
action_dim: int = 16 * 7,
horizon: int = 16,
per_action_dim: int = 7,
num_heads: int = 8,
num_layers: int = 8,
dropout: float = 0.0,
num_inference_timesteps: int = 20,
num_categories: int = 1,
state_dim: int | None = None,
state_hidden_dim: int | None = None,
):
super().__init__()
logger.info("FlowmatchingActionHead num_inference_timesteps=%s", num_inference_timesteps)
self.embed_dim = embed_dim
self.horizon = horizon
self.per_action_dim = per_action_dim
self.action_dim = action_dim
self.num_inference_timesteps = num_inference_timesteps
self.num_categories = num_categories
self.time_pos_enc = SinusoidalPositionalEncoding(embed_dim, max_len=1000)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
embed_dim=embed_dim,
num_heads=num_heads,
hidden_dim=embed_dim * 4,
dropout=dropout,
)
for _ in range(num_layers)
]
)
self.norm_out = nn.LayerNorm(embed_dim)
self.seq_pool_proj = nn.Linear(self.horizon * self.embed_dim, self.embed_dim)
self.mlp_head = CategorySpecificMLP(
input_dim=embed_dim,
hidden_dim=hidden_dim,
output_dim=action_dim,
num_categories=num_categories,
)
self.state_encoder = None
if state_dim is not None:
state_hidden = state_hidden_dim if state_hidden_dim is not None else embed_dim
self.state_encoder = CategorySpecificMLP(
input_dim=state_dim,
hidden_dim=state_hidden,
output_dim=embed_dim,
num_categories=num_categories,
)
if horizon > 1:
self.action_encoder = MultiEmbodimentActionEncoder(
action_dim=self.per_action_dim,
embed_dim=embed_dim,
hidden_dim=embed_dim,
horizon=horizon,
num_categories=num_categories,
)
self.single_action_proj = None
else:
self.action_encoder = None
self.single_action_proj = nn.Linear(self.per_action_dim, self.embed_dim)
def _project_actions(self, action_seq: torch.Tensor, embodiment_id: torch.LongTensor) -> torch.Tensor:
if self.horizon > 1 and self.action_encoder is not None:
return self.action_encoder(action_seq, embodiment_id)
if self.single_action_proj is None:
raise RuntimeError("single_action_proj is not initialized for horizon <= 1.")
return self.single_action_proj(action_seq)
def _expand_action_mask(
self,
action_mask: torch.Tensor,
batch_size: int,
per_action_dim: int,
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
if action_mask is None:
raise ValueError("action_mask must be provided for flow matching inference.")
if action_mask.dim() == 2:
expected_last_dim = self.horizon * per_action_dim
if action_mask.shape == (batch_size, expected_last_dim):
expanded_mask = action_mask.reshape(batch_size, self.horizon, per_action_dim)
elif action_mask.shape == (batch_size, per_action_dim):
expanded_mask = action_mask.unsqueeze(1).expand(batch_size, self.horizon, per_action_dim)
else:
raise ValueError(
f"Expected action_mask shape {(batch_size, expected_last_dim)} or "
f"{(batch_size, per_action_dim)}, got {tuple(action_mask.shape)}"
)
elif action_mask.dim() == 3:
expected_shape = (batch_size, self.horizon, per_action_dim)
if tuple(action_mask.shape) != expected_shape:
raise ValueError(
f"Expected action_mask shape {expected_shape}, got {tuple(action_mask.shape)}"
)
expanded_mask = action_mask
else:
raise ValueError(f"Unsupported action_mask rank: {action_mask.dim()}")
return expanded_mask.to(device=device, dtype=dtype)
def _prepare_context(
self,
fused_tokens: torch.Tensor,
state: torch.Tensor | None,
embodiment_id: torch.LongTensor | None,
context_mask: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.LongTensor]:
"""Normalize the VL context and embodiment ids shared by training and inference.
Returns the context tokens ``(B, S, E)``, a key_padding_mask for
``nn.MultiheadAttention`` (True = ignore) or None, and the resolved embodiment ids.
"""
batch_size = fused_tokens.size(0)
device = fused_tokens.device
if embodiment_id is None:
embodiment_id = torch.zeros(batch_size, dtype=torch.long, device=device)
elif self.num_categories > 1 and (
int(embodiment_id.min()) < 0 or int(embodiment_id.max()) >= self.num_categories
):
raise ValueError(
f"embodiment ids must be in [0, num_categories={self.num_categories}), "
f"got range [{int(embodiment_id.min())}, {int(embodiment_id.max())}]"
)
context_tokens = fused_tokens
if context_tokens.dim() == 2:
# A single pooled VL token (return_cls_only): give it a sequence dim of 1.
context_tokens = context_tokens.unsqueeze(1)
context_mask = None
if state is not None and self.state_encoder is not None:
state_emb = self.state_encoder(state, embodiment_id).unsqueeze(1)
context_tokens = torch.cat([context_tokens, state_emb], dim=1)
if context_mask is not None:
state_valid = torch.ones(batch_size, 1, dtype=torch.bool, device=context_mask.device)
context_mask = torch.cat([context_mask.to(torch.bool), state_valid], dim=1)
key_padding_mask = None if context_mask is None else ~context_mask.to(torch.bool)
return context_tokens, key_padding_mask, embodiment_id
def forward(
self,
fused_tokens: torch.Tensor,
state: torch.Tensor = None,
actions_gt: torch.Tensor = None,
embodiment_id: torch.LongTensor = None,
action_mask: torch.Tensor = None,
context_mask: torch.Tensor = None,
):
if actions_gt is None:
return self.get_action(
fused_tokens,
state=state,
embodiment_id=embodiment_id,
action_mask=action_mask,
context_mask=context_mask,
)
batch_size = fused_tokens.size(0)
device = fused_tokens.device
context_tokens, key_padding_mask, embodiment_id = self._prepare_context(
fused_tokens, state, embodiment_id, context_mask
)
t = (
torch.distributions.Beta(2, 2)
.sample((batch_size,))
.clamp(0.02, 0.98)
.to(device)
.to(dtype=self.dtype)
)
time_index = (t * 999).long().clamp_(0, 999)
time_emb = self.time_pos_enc(1000)[:, time_index, :].squeeze(0).to(dtype=context_tokens.dtype)
actions_gt_seq = actions_gt
noise = torch.rand_like(actions_gt) * 2 - 1
if action_mask is not None:
action_mask = action_mask.to(dtype=noise.dtype, device=noise.device)
if action_mask.shape != noise.shape:
raise ValueError(f"action_mask shape {action_mask.shape} != noise shape {noise.shape}")
actions_gt_seq = actions_gt_seq * action_mask
noise = noise * action_mask
if self.horizon > 1:
noise_seq = noise.view(batch_size, self.horizon, self.per_action_dim)
else:
noise_seq = noise if noise.dim() == 3 else noise.unsqueeze(1)
t_broadcast = t.view(batch_size, 1, 1)
action_intermediate_seq = (1 - t_broadcast) * noise_seq + t_broadcast * actions_gt_seq
action_tokens = self._project_actions(action_intermediate_seq, embodiment_id)
target_dtype = self.dtype
action_tokens = action_tokens.to(dtype=target_dtype)
context_tokens = context_tokens.to(dtype=target_dtype)
time_emb = time_emb.to(dtype=target_dtype)
x = action_tokens
for block in self.transformer_blocks:
x = block(x, context_tokens, time_emb, key_padding_mask)
x = self.norm_out(x)
if self.horizon > 1:
x_flat = x.reshape(batch_size, -1)
x_pooled = self.seq_pool_proj(x_flat)
else:
x_pooled = x.squeeze(1)
pred_velocity = self.mlp_head(x_pooled, embodiment_id)
return pred_velocity, noise
def get_action(
self,
fused_tokens: torch.Tensor,
state: torch.Tensor = None,
embodiment_id: torch.LongTensor = None,
action_mask: torch.Tensor = None,
context_mask: torch.Tensor = None,
inference_delay: int | None = None,
prev_chunk_left_over: torch.Tensor | None = None,
execution_horizon: int | None = None,
rtc_processor=None,
):
batch_size = fused_tokens.size(0)
device = fused_tokens.device
context_tokens, key_padding_mask, embodiment_id = self._prepare_context(
fused_tokens, state, embodiment_id, context_mask
)
action_dim_total = self.action_dim
per_action_dim = self.per_action_dim
action = torch.rand(batch_size, action_dim_total, device=device, dtype=context_tokens.dtype) * 2 - 1
action_seq = action.view(batch_size, self.horizon, per_action_dim)
action_mask = self._expand_action_mask(
action_mask,
batch_size=batch_size,
per_action_dim=per_action_dim,
device=action_seq.device,
dtype=action_seq.dtype,
)
action_seq = action_seq * action_mask
target_dtype = self.dtype
context_tokens = context_tokens.to(dtype=target_dtype)
num_steps = int(self.num_inference_timesteps)
if num_steps <= 0:
raise ValueError(f"num_inference_timesteps must be positive, got {num_steps}")
dt = 1.0 / num_steps
use_rtc = rtc_processor is not None and (
inference_delay is not None or prev_chunk_left_over is not None
)
def predict_velocity(seq: torch.Tensor, step_time_emb: torch.Tensor) -> torch.Tensor:
"""Predict the masked flow velocity (x1 - x0 convention) for one integration step."""
seq = seq * action_mask
action_tokens = self._project_actions(seq, embodiment_id).to(dtype=target_dtype)
x = action_tokens
for block in self.transformer_blocks:
x = block(x, context_tokens, step_time_emb, key_padding_mask)
x = self.norm_out(x)
x_pooled = self.seq_pool_proj(x.reshape(batch_size, -1)) if self.horizon > 1 else x.squeeze(1)
pred = self.mlp_head(x_pooled, embodiment_id)
return pred.view(batch_size, self.horizon, per_action_dim) * action_mask
for i in range(num_steps):
t = i / num_steps
time_index = min(int(t * 999), 999)
time_emb = self.time_pos_enc(1000)[:, time_index, :].to(device).squeeze(0).to(dtype=target_dtype)
time_emb = time_emb.unsqueeze(0).repeat(batch_size, 1)
if use_rtc:
# RTCProcessor assumes the pi0 flow convention: its `time` runs 1 -> 0 and the
# clean-action estimate is x1 = x_t - time * v. EVO1 integrates t: 0 -> 1 with
# velocity v = x1 - x0 (so x1 = x_t + (1 - t) * v); passing time = 1 - t and
# flipping the velocity sign in both directions maps one convention onto the other.
guided = rtc_processor.denoise_step(
x_t=action_seq,
prev_chunk_left_over=prev_chunk_left_over,
inference_delay=inference_delay,
time=1.0 - t,
original_denoise_step_partial=lambda seq, emb=time_emb: -predict_velocity(seq, emb),
execution_horizon=execution_horizon,
)
velocity = -guided
else:
velocity = predict_velocity(action_seq, time_emb)
action_seq = action_seq + dt * velocity
action_seq = action_seq * action_mask
return action_seq.reshape(batch_size, -1)
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
@@ -0,0 +1,369 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
from collections.abc import Sequence
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
import torchvision.transforms.functional as tvf
from torchvision.transforms.functional import InterpolationMode
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers import AutoModel, AutoTokenizer
else:
AutoModel = None
AutoTokenizer = None
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>" # nosec B105
IMG_START_TOKEN = "<img>" # nosec B105
IMG_END_TOKEN = "</img>" # nosec B105
logger = logging.getLogger(__name__)
def _batched_resize_01(images: torch.Tensor, image_size: int) -> torch.Tensor:
"""Resize a batch of ``[0, 1]`` images to ``(image_size, image_size)`` on-device.
Numerically mirrors InternVL3's reference PIL preprocessing
(``to_pil_image`` -> ``Image.resize`` -> ``to_tensor``): the float input is quantized to uint8
exactly as ``to_pil_image`` does, then resized with bicubic interpolation and antialiasing,
which matches PIL's default resampler. Matching the reference pixel-for-pixel keeps the policy
interchangeable with checkpoints produced by the upstream EVO1 preprocessing.
Args:
images: float tensor of shape ``(N, C, H, W)`` with values in ``[0, 1]``.
Returns:
float32 tensor of shape ``(N, C, image_size, image_size)`` with values in ``[0, 1]``.
"""
# to_pil_image() quantizes float [0, 1] to uint8 (x * 255, truncated); replicate that so the
# bicubic resample sees the same integer pixels PIL would.
pixels_u8 = (images * 255.0).clamp(0, 255).to(torch.uint8)
resized = tvf.resize(
pixels_u8, [image_size, image_size], interpolation=InterpolationMode.BICUBIC, antialias=True
)
return resized.to(torch.float32) / 255.0
def _batched_pixel_values(
camera_images: Sequence[torch.Tensor],
max_views: int,
image_size: int,
mean: torch.Tensor,
std: torch.Tensor,
dtype: torch.dtype,
device: torch.device | str,
) -> torch.Tensor:
"""Build InternVL3 ``pixel_values`` from per-camera ``[0, 1]`` image batches without leaving the device.
Each image is resized, converted to ``dtype``, and ImageNet-normalized (a single tile per
image), batched across the whole minibatch. Absent views (fewer cameras than ``max_views``)
are filled with zero images; their placeholder tokens are masked out of attention downstream
via ``_mask_absent_image_tokens``.
Returns:
``pixel_values`` of shape ``(B * max_views, C, image_size, image_size)``, ordered row-major
over ``(sample, view)`` to line up with the per-view image placeholders in the prompt.
"""
resized: list[torch.Tensor] = []
for image in camera_images:
resized.append(_batched_resize_01(image.to(device=device), image_size).to(dtype))
batch_size = resized[0].shape[0]
channels = resized[0].shape[1]
while len(resized) < max_views:
resized.append(torch.zeros(batch_size, channels, image_size, image_size, dtype=dtype, device=device))
stacked = torch.stack(resized[:max_views], dim=1) # (B, V, C, H, W)
mean = mean.to(device=device, dtype=dtype).view(1, 1, -1, 1, 1)
std = std.to(device=device, dtype=dtype).view(1, 1, -1, 1, 1)
normalized = (stacked - mean) / std
return normalized.reshape(batch_size * max_views, channels, image_size, image_size)
class InternVL3Embedder(nn.Module):
"""Vision-language embedder using the native HF InternVL3 model (no trust_remote_code)."""
def __init__(
self,
model_name="OpenGVLab/InternVL3-1B-hf",
image_size=448,
device="cuda",
num_language_layers: int | None = 14,
model_dtype: str | torch.dtype = "bfloat16",
use_flash_attn: bool = True,
max_text_length: int = 1024,
enable_gradient_checkpointing: bool = True,
gradient_checkpointing_use_reentrant: bool = False,
hub_kwargs: dict | None = None,
):
super().__init__()
self._requested_device = device
self.image_size = image_size
self.num_language_layers = num_language_layers
self.max_text_length = max_text_length
self.enable_gradient_checkpointing = bool(enable_gradient_checkpointing)
self.gradient_checkpointing_use_reentrant = bool(gradient_checkpointing_use_reentrant)
hub_kwargs = hub_kwargs or {}
require_package("transformers", extra="evo1")
self.tokenizer = AutoTokenizer.from_pretrained(model_name, **hub_kwargs)
if isinstance(model_dtype, str):
try:
model_dtype = getattr(torch, model_dtype)
except AttributeError as exc:
raise ValueError(f"Unsupported EVO1 vlm_dtype '{model_dtype}'") from exc
self.model_dtype = model_dtype
attn_implementation = "flash_attention_2" if (use_flash_attn and _flash_attn_available()) else "eager"
if use_flash_attn and attn_implementation == "eager":
logger.warning("flash_attn is not installed. Falling back to eager attention.")
self.model = AutoModel.from_pretrained(
model_name,
torch_dtype=model_dtype,
attn_implementation=attn_implementation,
low_cpu_mem_usage=True,
**hub_kwargs,
).to(self._requested_device)
checkpoint_image_size = getattr(self.model.config.vision_config, "image_size", None)
if isinstance(checkpoint_image_size, (list, tuple)):
checkpoint_image_size = checkpoint_image_size[0]
if checkpoint_image_size is not None and int(checkpoint_image_size) != int(image_size):
raise ValueError(
f"EVO1 image_resolution ({image_size}) must match the InternVL checkpoint's native "
f"image size ({checkpoint_image_size}): the checkpoint's image_seq_length assumes "
"its native resolution, so other sizes would desync the image placeholder tokens "
"from the vision features."
)
self.num_image_token = self.model.config.image_seq_length
# Truncate language model to the requested number of layers
layers = self.model.language_model.layers
if self.num_language_layers is not None:
layers = layers[: self.num_language_layers]
self.model.language_model.layers = torch.nn.ModuleList(layers)
self._configure_memory_features()
self.img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
def _configure_memory_features(self) -> None:
checkpoint_kwargs = {"use_reentrant": self.gradient_checkpointing_use_reentrant}
if not self.enable_gradient_checkpointing:
language_model = self.model.language_model
if hasattr(language_model, "gradient_checkpointing_disable"):
language_model.gradient_checkpointing_disable()
vision_tower = getattr(self.model, "vision_tower", None)
if vision_tower is not None and hasattr(vision_tower, "encoder"):
vision_tower.encoder.gradient_checkpointing = False
return
def _enable_ckpt(module: nn.Module | None) -> bool:
if module is None:
return False
if hasattr(module, "gradient_checkpointing_enable"):
try:
module.gradient_checkpointing_enable(gradient_checkpointing_kwargs=checkpoint_kwargs)
except TypeError:
module.gradient_checkpointing_enable()
return True
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = True
return True
return False
enabled_any = _enable_ckpt(self.model)
vision_tower = getattr(self.model, "vision_tower", None)
if vision_tower is not None:
enabled_any = _enable_ckpt(vision_tower) or enabled_any
language_model = self.model.language_model
enabled_any = _enable_ckpt(language_model) or enabled_any
if hasattr(language_model, "config"):
language_model.config.use_cache = False
if hasattr(self.model, "config"):
self.model.config.use_cache = False
if hasattr(self.model, "enable_input_require_grads"):
self.model.enable_input_require_grads()
if enabled_any:
logger.info("Gradient checkpointing enabled for InternVL3 embedder.")
else:
logger.warning(
"Requested gradient checkpointing, but model does not expose checkpointing controls."
)
def _build_multimodal_prompts(
self,
batch_num_tiles_list: list[list[int]],
text_prompts: Sequence[str],
) -> list[str]:
prompts = []
for num_tiles_list, text_prompt in zip(batch_num_tiles_list, text_prompts, strict=True):
prompt_segments = []
for i, tile_count in enumerate(num_tiles_list):
token_count = self.num_image_token * tile_count
image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * token_count + IMG_END_TOKEN
prompt_segments.append(f"Image-{i + 1}: {image_tokens}\n")
prompts.append("".join(prompt_segments) + text_prompt.strip())
return prompts
def get_fused_image_text_embedding_batched(
self,
camera_images: Sequence[torch.Tensor],
image_masks: torch.Tensor,
text_prompts: Sequence[str],
return_cls_only: bool = True,
):
"""Fused VL embedding from per-camera ``[0, 1]`` image batches (no PIL, no host round-trip).
Args:
camera_images: list of per-camera tensors, each shaped ``(B, C, H, W)`` in ``[0, 1]``.
image_masks: bool tensor ``(B, max_views)`` marking present views.
Returns:
A ``(embeddings, valid_mask)`` tuple. With ``return_cls_only=False``, ``embeddings`` is
``(B, L, H)`` and ``valid_mask`` is a ``(B, L)`` bool tensor marking tokens downstream
attention may attend to (padding and absent-view tokens are False). With
``return_cls_only=True``, ``embeddings`` is the pooled ``(B, H)`` last-valid-token state
and ``valid_mask`` is None.
"""
max_views = int(image_masks.shape[1])
batch_size = int(image_masks.shape[0])
mean = torch.tensor(IMAGENET_MEAN, device=self.device, dtype=self.model_dtype)
std = torch.tensor(IMAGENET_STD, device=self.device, dtype=self.model_dtype)
pixel_values = _batched_pixel_values(
camera_images, max_views, self.image_size, mean, std, self.model_dtype, self.device
)
# InternVL3 preprocessing uses a single tile per image (max_num=1).
batch_num_tiles_list = [[1] * max_views for _ in range(batch_size)]
return self._forward_vlm(
pixel_values, batch_num_tiles_list, image_masks, text_prompts, return_cls_only
)
def _mask_absent_image_tokens(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
image_masks: torch.Tensor,
batch_num_tiles_list: list[list[int]],
) -> torch.Tensor:
"""Zero attention over the image-context tokens of absent (zero-padded) views.
Fully vectorized: runs without any host<->device synchronization.
"""
# A single tile per image (max_num=1), so every image occupies the same number of
# context tokens.
tiles_per_image = (
batch_num_tiles_list[0][0] if batch_num_tiles_list and batch_num_tiles_list[0] else 1
)
tokens_per_image = self.num_image_token * tiles_per_image
image_masks = image_masks.to(device=input_ids.device).bool()
img_token_mask = input_ids == self.img_context_token_id # (B, L)
# keep[b, k] tells whether the k-th image-context token (ordered view0, view1, ...) survives.
per_token_keep = image_masks.repeat_interleave(tokens_per_image, dim=1) # (B, V * tokens_per_image)
# Rank each context token by its running position among the row's context tokens.
ctx_index = img_token_mask.to(torch.long).cumsum(dim=1) - 1
ctx_index = ctx_index.clamp(min=0, max=per_token_keep.shape[1] - 1)
keep_here = torch.gather(per_token_keep, 1, ctx_index) # (B, L)
drop = img_token_mask & ~keep_here
return attention_mask.masked_fill(drop, 0)
def _forward_vlm(
self,
pixel_values: torch.Tensor,
batch_num_tiles_list: list[list[int]],
image_masks: torch.Tensor,
text_prompts: Sequence[str],
return_cls_only: bool,
):
if pixel_values.shape[0] == 0:
logger.warning("InternVL3 received an empty image batch after preprocessing.")
hidden_size = getattr(self.model.config, "hidden_size", None)
if hidden_size is None:
hidden_size = getattr(self.model.config.text_config, "hidden_size", None)
if hidden_size is None:
raise RuntimeError("Unable to infer hidden size for empty InternVL3 batch.")
return torch.empty(0, hidden_size, device=self.device, dtype=torch.float32), None
prompts = self._build_multimodal_prompts(batch_num_tiles_list, text_prompts)
model_inputs = self.tokenizer(
list(prompts),
return_tensors="pt",
padding=True,
truncation=True,
max_length=self.max_text_length,
).to(self.device)
input_ids = model_inputs["input_ids"]
if input_ids.shape[1] >= self.max_text_length:
# Truncation cuts from the right, so text is dropped before image placeholders — but a
# large max_views * image_seq_length budget can still eat into them. Fail loudly instead
# of letting the VLM crash on a placeholder/vision-feature count mismatch.
expected_image_tokens = self.num_image_token * sum(batch_num_tiles_list[0])
image_token_counts = (input_ids == self.img_context_token_id).sum(dim=1)
if not bool((image_token_counts == expected_image_tokens).all()):
raise ValueError(
f"Prompt truncation at max_text_length={self.max_text_length} cut into the "
f"image placeholder tokens ({expected_image_tokens} expected per sample). "
"Increase max_text_length or reduce max_views."
)
attention_mask = self._mask_absent_image_tokens(
input_ids, model_inputs["attention_mask"], image_masks, batch_num_tiles_list
)
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
output_hidden_states=True,
return_dict=True,
)
fused_hidden = outputs.hidden_states[-1].to(torch.float32)
valid_mask = attention_mask.to(torch.bool)
if return_cls_only:
# Right-padded causal decoder: the last valid token is the only one that has attended
# to the full image + text prompt.
positions = torch.arange(valid_mask.shape[1], device=valid_mask.device)
last_valid = (valid_mask.long() * positions).argmax(dim=1)
batch_index = torch.arange(fused_hidden.shape[0], device=fused_hidden.device)
return fused_hidden[batch_index, last_valid], None
return fused_hidden, valid_mask
@property
def device(self) -> torch.device:
return next(self.model.parameters()).device
def _flash_attn_available() -> bool:
try:
import flash_attn # noqa: F401
except ModuleNotFoundError:
return False
return True
+532
View File
@@ -0,0 +1,532 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import builtins
from collections import deque
from contextlib import nullcontext
from pathlib import Path
from typing import TypedDict, Unpack
import torch
from torch import Tensor
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
from ..rtc.modeling_rtc import RTCProcessor
from .configuration_evo1 import Evo1Config
from .evo1_model import Evo1Model
class ActionSelectKwargs(TypedDict, total=False):
inference_delay: int | None
prev_chunk_left_over: Tensor | None
execution_horizon: int | None
class Evo1Policy(PreTrainedPolicy):
config_class = Evo1Config
name = "evo1"
def __init__(self, config: Evo1Config, *, vlm_hub_kwargs: dict | None = None, **kwargs):
super().__init__(config)
config.validate_features()
if len(config.image_features) > config.max_views:
raise ValueError(
f"EVO1 supports at most {config.max_views} camera streams, got {len(config.image_features)}"
)
self.config = config
self.model = Evo1Model(config, vlm_hub_kwargs=vlm_hub_kwargs)
self.model.set_finetune_flags()
self._keep_frozen_embedder_eval()
self.init_rtc_processor()
self.reset()
def init_rtc_processor(self):
"""Create the RTC processor when config.rtc_config is set.
The RTC rollout backend assigns config.rtc_config after loading the policy and re-invokes
this method.
"""
self.rtc_processor = None
if self.config.rtc_config is not None:
self.rtc_processor = RTCProcessor(self.config.rtc_config)
model = getattr(self, "model", None)
if model is not None:
model.rtc_processor = self.rtc_processor
def _rtc_enabled(self) -> bool:
return self.config.rtc_config is not None and self.config.rtc_config.enabled
@classmethod
def from_pretrained(
cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
config: PreTrainedConfig | None = None,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
strict: bool | None = None,
**kwargs,
) -> T:
if strict is None:
strict = True
vlm_hub_kwargs = kwargs.pop("vlm_hub_kwargs", None)
if config is None:
config = PreTrainedConfig.from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
**kwargs,
)
if vlm_hub_kwargs is None:
# Forward the hub download options to the base-VLM download as well; `revision` is not
# forwarded because it identifies the policy repo, not the VLM repo.
vlm_hub_kwargs = {
key: value
for key, value in (
("token", token),
("cache_dir", cache_dir),
("local_files_only", local_files_only),
("proxies", proxies),
)
if value not in (None, False)
}
kwargs["vlm_hub_kwargs"] = vlm_hub_kwargs
return super().from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
config=config,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
strict=strict,
**kwargs,
)
@property
def _camera_keys(self) -> list[str]:
return list(self.config.image_features)
@property
def _env_action_dim(self) -> int:
action_feature = self.config.action_feature
if action_feature is None:
return self.config.max_action_dim
return int(action_feature.shape[0])
@property
def _compute_dtype(self) -> torch.dtype:
return next(self.model.action_head.parameters()).dtype
@property
def _device(self) -> torch.device:
# The device the policy actually lives on. Derived from the parameters rather than
# config.device so the policy keeps working after accelerate (or a plain .to()) moves it.
return next(self.model.action_head.parameters()).device
@property
def _amp_enabled(self) -> bool:
return bool(self.config.use_amp) and self._device.type == "cuda"
def _maybe_autocast(self):
# EVO1 manages its own mixed precision: an explicit bf16 autocast that also overrides any
# outer autocast context (e.g. lerobot-eval's fp16 default), keeping train and eval
# numerics identical.
if self._amp_enabled:
return torch.autocast(device_type="cuda", dtype=torch.bfloat16)
return nullcontext()
def get_optim_params(self) -> list[dict]:
decay, no_decay = [], []
for name, param in self.named_parameters():
if not param.requires_grad:
continue
is_bias = name.endswith("bias") or ".bias" in name
is_norm = param.dim() == 1 or "norm" in name.lower()
if is_bias or is_norm:
no_decay.append(param)
else:
decay.append(param)
return [
{"params": decay, "weight_decay": self.config.optimizer_weight_decay},
{"params": no_decay, "weight_decay": 0.0},
]
def reset(self):
self._action_queue = deque([], maxlen=self.config.n_action_steps)
def _normalize_task_batch(self, batch: dict[str, Tensor | list[str] | str]) -> list[str]:
prompts = batch.get(self.config.task_field)
if prompts is None and self.config.task_field != "task":
prompts = batch.get("task")
if prompts is None:
raise ValueError(f"EVO1 expects a '{self.config.task_field}' text field in the batch.")
if isinstance(prompts, str):
return [prompts]
if isinstance(prompts, (list, tuple)):
return [str(prompt) for prompt in prompts]
raise TypeError(f"Unsupported prompt batch type: {type(prompts)}")
def _prepare_state(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
if OBS_STATE not in batch:
raise ValueError(f"EVO1 requires '{OBS_STATE}' in the batch.")
state = batch[OBS_STATE]
if state.dim() == 1:
state = state.unsqueeze(0)
elif state.dim() == 3:
state = state[:, -1]
elif state.dim() != 2:
raise ValueError(f"Unsupported state tensor shape for EVO1: {tuple(state.shape)}")
batch_size, state_dim = state.shape
if state_dim > self.config.max_state_dim:
raise ValueError(
f"State dim {state_dim} exceeds configured max_state_dim {self.config.max_state_dim}"
)
explicit_mask = batch.get("state_mask")
if explicit_mask is not None:
if explicit_mask.dim() == 1:
explicit_mask = explicit_mask.unsqueeze(0)
elif explicit_mask.dim() == 3:
explicit_mask = explicit_mask[:, -1]
elif explicit_mask.dim() != 2:
raise ValueError(
f"Unsupported state_mask tensor shape for EVO1: {tuple(explicit_mask.shape)}"
)
if explicit_mask.shape != (batch_size, state_dim):
raise ValueError(
f"state_mask shape {tuple(explicit_mask.shape)} does not match state shape {(batch_size, state_dim)}"
)
device = self._device
padded = torch.zeros(
batch_size,
self.config.max_state_dim,
dtype=state.dtype,
device=device,
)
padded[:, :state_dim] = state.to(device=device)
mask = torch.zeros(
batch_size,
self.config.max_state_dim,
dtype=torch.bool,
device=device,
)
if explicit_mask is None:
mask[:, :state_dim] = True
else:
mask[:, :state_dim] = explicit_mask.to(device=device, dtype=torch.bool)
# Zero out masked state dims so an explicit state_mask actually affects the model input
# (the state encoder has no mask argument of its own).
padded = padded * mask.to(dtype=padded.dtype)
return padded.to(dtype=self._compute_dtype), mask
def _prepare_actions(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
if ACTION not in batch:
raise ValueError(f"EVO1 requires '{ACTION}' in the batch for training.")
action = batch[ACTION]
if action.dim() == 2:
action = action.unsqueeze(1)
batch_size, horizon, action_dim = action.shape
if horizon != self.config.chunk_size:
raise ValueError(
f"EVO1 expects chunk_size={self.config.chunk_size}, got action horizon {horizon}"
)
if action_dim > self.config.max_action_dim:
raise ValueError(
f"Action dim {action_dim} exceeds configured max_action_dim {self.config.max_action_dim}"
)
explicit_mask = batch.get("action_mask")
if explicit_mask is not None:
if explicit_mask.dim() == 2:
if horizon == 1:
explicit_mask = explicit_mask.unsqueeze(1)
else:
raise ValueError(
f"2D action_mask is only supported when chunk_size=1, got action horizon {horizon}"
)
elif explicit_mask.dim() != 3:
raise ValueError(
f"Unsupported action_mask tensor shape for EVO1: {tuple(explicit_mask.shape)}"
)
if explicit_mask.shape != (batch_size, horizon, action_dim):
raise ValueError(
"action_mask shape "
f"{tuple(explicit_mask.shape)} does not match action shape {(batch_size, horizon, action_dim)}"
)
device = self._device
padded = torch.zeros(
batch_size,
horizon,
self.config.max_action_dim,
dtype=action.dtype,
device=device,
)
padded[:, :, :action_dim] = action.to(device=device)
mask = torch.zeros(
batch_size,
horizon,
self.config.max_action_dim,
dtype=torch.bool,
device=device,
)
if explicit_mask is None:
mask[:, :, :action_dim] = True
else:
mask[:, :, :action_dim] = explicit_mask.to(device=device, dtype=torch.bool)
# Timesteps beyond the episode end hold fabricated (repeated) actions; exclude them from
# the loss like the other chunked policies do.
action_is_pad = batch.get("action_is_pad")
if action_is_pad is not None:
if action_is_pad.shape != (batch_size, horizon):
raise ValueError(
f"action_is_pad shape {tuple(action_is_pad.shape)} does not match "
f"(batch_size, chunk_size)={(batch_size, horizon)}"
)
in_episode = ~action_is_pad.to(device=device, dtype=torch.bool)
mask = mask & in_episode.unsqueeze(-1)
return padded.to(dtype=self._compute_dtype), mask
def _prepare_inference_action_mask(self, batch_size: int) -> Tensor:
mask = torch.zeros(
batch_size,
self.config.max_action_dim,
dtype=torch.bool,
device=self._device,
)
mask[:, : self._env_action_dim] = True
return mask
def _get_embodiment_ids(self, batch: dict[str, Tensor], batch_size: int) -> Tensor:
embodiment_ids = batch.get("embodiment_id")
if embodiment_ids is None and self.config.embodiment_id_field:
embodiment_ids = batch.get(self.config.embodiment_id_field)
if embodiment_ids is None:
return torch.full(
(batch_size,),
self.config.default_embodiment_id,
dtype=torch.long,
device=self._device,
)
if embodiment_ids.dim() == 0:
embodiment_ids = embodiment_ids.unsqueeze(0)
elif embodiment_ids.dim() > 1:
embodiment_ids = embodiment_ids[:, -1]
return embodiment_ids.to(device=self._device, dtype=torch.long)
@property
def _tracks_vlm_gradients(self) -> bool:
return bool(
self.config.finetune_vlm
or self.config.finetune_language_model
or self.config.finetune_vision_model
)
def _keep_frozen_embedder_eval(self) -> None:
if self._tracks_vlm_gradients:
return
embedder = getattr(self.model, "embedder", None)
if embedder is not None:
embedder.eval()
def train(self, mode: bool = True):
super().train(mode)
self._keep_frozen_embedder_eval()
return self
def _collect_image_batches(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], Tensor]:
camera_keys = self._camera_keys or sorted(key for key in batch if key.startswith(f"{OBS_IMAGES}."))
if not camera_keys:
raise ValueError("EVO1 requires at least one visual observation feature.")
camera_keys = list(camera_keys)[: self.config.max_views]
# Configured cameras may be absent from the batch up to the empty_cameras budget (e.g. the
# placeholder features added by validate_features); they become masked-out views that the
# embedder zero-pads. Any other absent camera is an error.
present_keys = [key for key in camera_keys if key in batch]
missing_keys = [key for key in camera_keys if key not in batch]
if len(missing_keys) > self.config.empty_cameras:
raise ValueError(
f"Missing camera features {missing_keys} in batch; at most "
f"empty_cameras={self.config.empty_cameras} may be absent."
)
if not present_keys:
raise ValueError("EVO1 requires at least one visual observation in the batch.")
# Keep each present camera as a batched (B, C, H, W) tensor on its current (GPU) device.
# Resizing/normalization and zero-padding of absent views happen batched inside the
# embedder, so images never leave the device here.
camera_images: list[Tensor] = []
for camera_key in present_keys:
image = batch[camera_key]
if image.dim() == 3:
# Promote an unbatched (C, H, W) frame so batch_size is read from a real batch dim.
image = image.unsqueeze(0)
elif image.dim() == 5:
image = image[:, -1]
elif image.dim() != 4:
raise ValueError(
f"Unsupported image tensor shape for EVO1: key={camera_key} shape={tuple(image.shape)}"
)
camera_images.append(image)
batch_size = camera_images[0].shape[0]
n_present = len(camera_images)
image_masks = torch.zeros(
batch_size, self.config.max_views, dtype=torch.bool, device=camera_images[0].device
)
image_masks[:, :n_present] = True
return camera_images, image_masks
def _compute_fused_tokens(
self,
prompts: list[str],
image_batches: list[Tensor],
image_masks: Tensor,
) -> tuple[Tensor, Tensor | None]:
track_vlm_gradients = self._tracks_vlm_gradients
grad_context = nullcontext() if track_vlm_gradients else torch.no_grad()
with grad_context:
fused_tokens, context_mask = self.model.get_vl_embeddings(
images=image_batches,
image_mask=image_masks,
prompt=prompts,
return_cls_only=self.config.return_cls_only,
)
if not track_vlm_gradients:
fused_tokens = fused_tokens.detach()
fused_tokens = fused_tokens.to(device=self._device, dtype=self._compute_dtype)
if context_mask is not None:
context_mask = context_mask.to(device=self._device)
return fused_tokens, context_mask
def _compute_masked_loss(
self,
pred_velocity: Tensor,
target_velocity: Tensor,
action_mask: Tensor,
reduction: str,
) -> Tensor:
flat_mask = action_mask.view(action_mask.shape[0], -1).to(dtype=pred_velocity.dtype)
sq_error = ((pred_velocity - target_velocity) * flat_mask).pow(2)
active = flat_mask.sum(dim=1).clamp_min(1.0)
per_sample_loss = sq_error.sum(dim=1) / active
if reduction == "none":
return per_sample_loss
if reduction != "mean":
raise ValueError(f"Unsupported reduction '{reduction}'")
return sq_error.sum() / active.sum()
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
prompts = self._normalize_task_batch(batch)
image_batches, image_masks = self._collect_image_batches(batch)
states, _state_mask = self._prepare_state(batch)
actions_gt, action_mask = self._prepare_actions(batch)
embodiment_ids = self._get_embodiment_ids(batch, states.shape[0])
with self._maybe_autocast():
fused_tokens, context_mask = self._compute_fused_tokens(prompts, image_batches, image_masks)
pred_velocity, noise = self.model(
fused_tokens,
state=states,
actions_gt=actions_gt,
action_mask=action_mask.to(device=self._device, dtype=self._compute_dtype),
embodiment_ids=embodiment_ids,
context_mask=context_mask,
)
# Compute the flow-matching regression loss in fp32, outside the autocast block.
pred_velocity = pred_velocity.float()
noise = noise.float()
flat_action_mask = action_mask.view(action_mask.shape[0], -1).to(dtype=torch.float32)
# Flow-matching velocity target. Padded (masked-out) action dims are already zero on both sides
# here (`actions_gt` is zero-padded in `_prepare_actions`, and `noise` is masked inside the head),
# and the whole difference is multiplied by `flat_action_mask`, so padded dims contribute nothing.
target_velocity = (actions_gt.float() - noise).view(actions_gt.shape[0], -1) * flat_action_mask
loss = self._compute_masked_loss(pred_velocity, target_velocity, action_mask, reduction)
loss_mean = loss.mean().item() if loss.ndim > 0 else loss.item()
return loss, {
"loss": loss_mean,
"active_action_dims": float(action_mask.sum(dim=(1, 2)).float().mean().item()),
}
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
if (inference_delay is not None or prev_chunk_left_over is not None) and not self._rtc_enabled():
raise RuntimeError(
"Received RTC arguments but RTC is not configured for this EVO1 policy: set "
"config.rtc_config and call init_rtc_processor() (lerobot-rollout does this for "
"--inference.type=rtc)."
)
self.eval()
prompts = self._normalize_task_batch(batch)
image_batches, image_masks = self._collect_image_batches(batch)
states, _state_mask = self._prepare_state(batch)
embodiment_ids = self._get_embodiment_ids(batch, states.shape[0])
action_mask = self._prepare_inference_action_mask(states.shape[0])
if prev_chunk_left_over is not None:
prev_chunk_left_over = prev_chunk_left_over.to(device=self._device)
with self._maybe_autocast():
fused_tokens, context_mask = self._compute_fused_tokens(prompts, image_batches, image_masks)
actions = self.model(
fused_tokens,
state=states,
action_mask=action_mask,
embodiment_ids=embodiment_ids,
context_mask=context_mask,
inference_delay=inference_delay,
prev_chunk_left_over=prev_chunk_left_over,
execution_horizon=execution_horizon,
)
actions = actions.view(states.shape[0], self.config.chunk_size, self.config.max_action_dim)
return actions.to(dtype=torch.float32)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
assert not self._rtc_enabled(), (
"RTC is not supported for select_action, use it with predict_action_chunk"
)
self.eval()
if len(self._action_queue) == 0:
action_chunk = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
self._action_queue.extend(action_chunk.transpose(0, 1))
# Returns one step of shape (B, max_action_dim): actions are emitted at the padded max_action_dim
# width and cropped to the real action dim downstream by the postprocessor (Evo1ActionProcessorStep).
# Callers that bypass the postprocessor receive the padded width.
return self._action_queue.popleft()
+400
View File
@@ -0,0 +1,400 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from copy import deepcopy
from dataclasses import dataclass
from typing import Any
import torch
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
ObservationProcessorStep,
PolicyAction,
PolicyActionProcessorStep,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import (
batch_to_transition,
create_transition,
policy_action_to_transition,
transition_to_policy_action,
)
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import (
ACTION,
DONE,
INFO,
OBS_PREFIX,
OBS_STATE,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
REWARD,
TRUNCATED,
)
from .configuration_evo1 import Evo1Config
def evo1_batch_to_transition(batch: dict[str, Any]):
transition = batch_to_transition(batch)
complementary_data = dict(transition.get("complementary_data") or {})
reserved = {ACTION, REWARD, DONE, TRUNCATED, INFO}
for key, value in batch.items():
if key in reserved or key.startswith(OBS_PREFIX):
continue
complementary_data.setdefault(key, value)
return create_transition(
observation=transition.get("observation"),
action=transition.get("action"),
reward=transition.get("reward", 0.0),
done=transition.get("done", False),
truncated=transition.get("truncated", False),
info=transition.get("info", {}),
complementary_data=complementary_data,
)
@dataclass
@ProcessorStepRegistry.register(name="evo1_pad_state_processor")
class Evo1PadStateProcessorStep(ObservationProcessorStep):
"""Pad policy observations to EVO1's fixed state width before normalization."""
max_state_dim: int = 24
def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
if OBS_STATE not in observation:
return observation
state = observation[OBS_STATE]
state_dim = state.shape[-1]
if state_dim > self.max_state_dim:
raise ValueError(
f"EVO1 state has {state_dim} dims, which exceeds max_state_dim={self.max_state_dim}."
)
if state_dim < self.max_state_dim:
observation = observation.copy()
observation[OBS_STATE] = torch.nn.functional.pad(state, (0, self.max_state_dim - state_dim))
return observation
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
new_features = {ft: feats.copy() for ft, feats in features.items()}
obs_feats = new_features.setdefault(PipelineFeatureType.OBSERVATION, {})
if OBS_STATE in obs_feats:
obs_feats[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(self.max_state_dim,))
return new_features
def get_config(self) -> dict[str, Any]:
return {"max_state_dim": self.max_state_dim}
@dataclass
@ProcessorStepRegistry.register(name="evo1_pad_action_processor")
class Evo1PadActionProcessorStep(ProcessorStep):
"""Pad training actions and preserve the active action dimensions with action_mask."""
max_action_dim: int = 24
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is None:
return transition
if not isinstance(action, PolicyAction):
raise ValueError(f"EVO1 action should be a PolicyAction tensor, but got {type(action)}.")
action_dim = action.shape[-1]
if action_dim > self.max_action_dim:
raise ValueError(
f"EVO1 action has {action_dim} dims, which exceeds max_action_dim={self.max_action_dim}."
)
new_transition = transition.copy()
new_action = action
if action_dim < self.max_action_dim:
new_action = torch.nn.functional.pad(action, (0, self.max_action_dim - action_dim))
complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
action_mask = complementary_data.get("action_mask")
if action_mask is None:
action_mask = torch.ones(action.shape, dtype=torch.bool, device=action.device)
else:
action_mask = torch.as_tensor(action_mask, dtype=torch.bool, device=action.device)
if action_mask.shape != action.shape:
raise ValueError(
f"action_mask shape {tuple(action_mask.shape)} does not match action shape {tuple(action.shape)}."
)
if action_dim < self.max_action_dim:
action_mask = torch.nn.functional.pad(action_mask, (0, self.max_action_dim - action_dim))
complementary_data["action_mask"] = action_mask
new_transition[TransitionKey.ACTION] = new_action
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
return new_transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
new_features = {ft: feats.copy() for ft, feats in features.items()}
action_feats = new_features.setdefault(PipelineFeatureType.ACTION, {})
action_feats[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(self.max_action_dim,))
return new_features
def get_config(self) -> dict[str, Any]:
return {"max_action_dim": self.max_action_dim}
@dataclass
@ProcessorStepRegistry.register(name="evo1_action_processor")
class Evo1ActionProcessorStep(PolicyActionProcessorStep):
"""Crop padded EVO1 actions and optionally binarize the LIBERO gripper channel."""
action_dim: int
binarize_gripper: bool = False
gripper_index: int = 6
gripper_threshold: float = 0.5
gripper_below_threshold_value: float = 1.0
gripper_above_threshold_value: float = -1.0
def action(self, action: PolicyAction) -> PolicyAction:
if action.shape[-1] < self.action_dim:
raise ValueError(
f"EVO1 action has {action.shape[-1]} dims, which is smaller than action_dim={self.action_dim}."
)
action = action[..., : self.action_dim]
if not self.binarize_gripper:
return action
if not 0 <= self.gripper_index < self.action_dim:
raise ValueError(
f"gripper_index={self.gripper_index} must be within action_dim={self.action_dim}."
)
action = action.clone()
below = torch.as_tensor(
self.gripper_below_threshold_value,
dtype=action.dtype,
device=action.device,
)
above = torch.as_tensor(
self.gripper_above_threshold_value,
dtype=action.dtype,
device=action.device,
)
action[..., self.gripper_index] = torch.where(
action[..., self.gripper_index] > self.gripper_threshold,
above,
below,
)
return action
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
new_features = {ft: feats.copy() for ft, feats in features.items()}
action_feats = new_features.setdefault(PipelineFeatureType.ACTION, {})
action_feats[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(self.action_dim,))
return new_features
def get_config(self) -> dict[str, Any]:
return {
"action_dim": self.action_dim,
"binarize_gripper": self.binarize_gripper,
"gripper_index": self.gripper_index,
"gripper_threshold": self.gripper_threshold,
"gripper_below_threshold_value": self.gripper_below_threshold_value,
"gripper_above_threshold_value": self.gripper_above_threshold_value,
}
def _evo1_action_dim(config: Evo1Config) -> int:
if config.postprocess_action_dim is not None:
return config.postprocess_action_dim
action_feature = config.action_feature
if action_feature is None:
return config.max_action_dim
return int(action_feature.shape[0])
def _evo1_normalization_features(config: Evo1Config) -> dict[str, PolicyFeature]:
features = {**config.input_features, **config.output_features}
features[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(config.max_state_dim,))
features[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(config.max_action_dim,))
return features
def _evo1_action_features(config: Evo1Config) -> dict[str, PolicyFeature]:
return {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(config.max_action_dim,))}
_STAT_PAD_VALUES = {
"mean": 0.0,
"std": 1.0,
"min": -1.0,
"max": 1.0,
"q01": -1.0,
"q99": 1.0,
"q10": -1.0,
"q90": 1.0,
}
def _pad_stat_value(value: Any, target_dim: int, stat_name: str) -> torch.Tensor:
tensor = torch.as_tensor(value)
if not tensor.is_floating_point():
tensor = tensor.to(dtype=torch.float32)
if tensor.ndim == 0 or tensor.shape[-1] >= target_dim:
return tensor
pad_shape = (*tensor.shape[:-1], target_dim - tensor.shape[-1])
pad_value = _STAT_PAD_VALUES.get(stat_name, 0.0)
padding = torch.full(pad_shape, pad_value, dtype=tensor.dtype, device=tensor.device)
return torch.cat([tensor, padding], dim=-1)
def _pad_feature_stats(
stats: dict[str, dict[str, Any]],
feature_key: str,
target_dim: int,
) -> None:
if feature_key not in stats:
return
stats[feature_key] = {
stat_name: _pad_stat_value(stat_value, target_dim, stat_name)
for stat_name, stat_value in stats[feature_key].items()
}
def _pad_evo1_stats(
config: Evo1Config,
stats: dict[str, dict[str, Any]] | None,
) -> dict[str, dict[str, Any]] | None:
if stats is None:
return None
padded_stats = deepcopy(stats)
# Added dimensions represent zero-padding inside EVO1. These neutral stats keep
# padded observations at normalized zero and only provide shape compatibility.
_pad_feature_stats(padded_stats, OBS_STATE, config.max_state_dim)
_pad_feature_stats(padded_stats, ACTION, config.max_action_dim)
return padded_stats
def reconcile_evo1_processors(
config: Evo1Config,
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
"""Reconcile checkpoint-loaded pipelines with the current EVO1 config.
Two things cannot be restored from a serialized pipeline alone: the EVO1 batch converter
(converters are plain functions and are never serialized), and eval-time CLI overrides of the
action postprocessing flags (`postprocess_action_dim`, `binarize_gripper`, `gripper_*`). This
restores the converter and rebuilds the action step from the current config so those overrides
take effect.
"""
# Pipelines reloaded from a checkpoint come back with the default batch converter, which drops
# non-observation extras (embodiment_id, state_mask, custom task fields) needed by EVO1.
preprocessor.to_transition = evo1_batch_to_transition
action_step = Evo1ActionProcessorStep(
action_dim=_evo1_action_dim(config),
binarize_gripper=config.binarize_gripper,
gripper_index=config.gripper_index,
gripper_threshold=config.gripper_threshold,
gripper_below_threshold_value=config.gripper_below_threshold_value,
gripper_above_threshold_value=config.gripper_above_threshold_value,
)
steps = list(postprocessor.steps)
action_step_idx = next(
(idx for idx, step in enumerate(steps) if isinstance(step, Evo1ActionProcessorStep)), None
)
if action_step_idx is None:
insert_idx = next(
(idx + 1 for idx, step in enumerate(steps) if isinstance(step, UnnormalizerProcessorStep)),
0,
)
steps.insert(insert_idx, action_step)
else:
steps[action_step_idx] = action_step
postprocessor.steps = steps
return preprocessor, postprocessor
def make_evo1_pre_post_processors(
config: Evo1Config,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
normalization_features = _evo1_normalization_features(config)
action_features = _evo1_action_features(config)
normalization_stats = _pad_evo1_stats(config, dataset_stats)
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
Evo1PadStateProcessorStep(max_state_dim=config.max_state_dim),
Evo1PadActionProcessorStep(max_action_dim=config.max_action_dim),
NormalizerProcessorStep(
features=normalization_features,
norm_map=config.normalization_mapping,
stats=normalization_stats,
),
DeviceProcessorStep(device=config.device),
]
output_steps = [
UnnormalizerProcessorStep(
features=action_features,
norm_map=config.normalization_mapping,
stats=normalization_stats,
),
Evo1ActionProcessorStep(
action_dim=_evo1_action_dim(config),
binarize_gripper=config.binarize_gripper,
gripper_index=config.gripper_index,
gripper_threshold=config.gripper_threshold,
gripper_below_threshold_value=config.gripper_below_threshold_value,
gripper_above_threshold_value=config.gripper_above_threshold_value,
),
# float32 so downstream numpy conversion works even when the policy computes in bf16.
DeviceProcessorStep(device="cpu", float_dtype="float32"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
to_transition=evo1_batch_to_transition,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
+24 -2
View File
@@ -47,6 +47,7 @@ from lerobot.utils.feature_utils import dataset_to_policy_features
from .act.configuration_act import ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig
from .eo1.configuration_eo1 import EO1Config
from .evo1.configuration_evo1 import Evo1Config
from .fastwam.configuration_fastwam import FastWAMConfig
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from .groot.configuration_groot import GrootConfig
@@ -92,7 +93,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x",
"molmoact2".
"molmoact2", "eo1", "evo1".
Returns:
The policy class corresponding to the given name.
@@ -167,6 +168,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .fastwam.modeling_fastwam import FastWAMPolicy
return FastWAMPolicy
elif name == "evo1":
from .evo1.modeling_evo1 import Evo1Policy
return Evo1Policy
else:
try:
return _get_policy_cls_from_policy_name(name=name)
@@ -184,7 +189,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
"smolvla", "wall_x", "molmoact2".
"smolvla", "wall_x", "molmoact2", "eo1", "evo1".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -225,6 +230,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return VLAJEPAConfig(**kwargs)
elif policy_type == "fastwam":
return FastWAMConfig(**kwargs)
elif policy_type == "evo1":
return Evo1Config(**kwargs)
else:
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
@@ -330,6 +337,14 @@ def make_pre_post_processors(
revision=pretrained_revision,
)
_reconnect_relative_absolute_steps(preprocessor, postprocessor)
if isinstance(policy_cfg, Evo1Config):
from .evo1.processor_evo1 import reconcile_evo1_processors
preprocessor, postprocessor = reconcile_evo1_processors(
policy_cfg,
preprocessor,
postprocessor,
)
return preprocessor, postprocessor
# Create a new processor based on policy type
@@ -440,6 +455,13 @@ def make_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, Evo1Config):
from .evo1.processor_evo1 import make_evo1_pre_post_processors
processors = make_evo1_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, MolmoAct2Config):
from .molmoact2.processor_molmoact2 import make_molmoact2_pre_post_processors
+29 -1
View File
@@ -7,11 +7,14 @@ from dataclasses import dataclass, field
import gymnasium as gym
import pytest
import torch
from gymnasium.envs.registration import register, registry as gym_registry
from lerobot.configs.types import PolicyFeature
from lerobot.envs.configs import EnvConfig
from lerobot.envs.configs import EnvConfig, LiberoEnv
from lerobot.envs.factory import make_env, make_env_config, make_env_pre_post_processors
from lerobot.processor import LiberoProcessorStep
from lerobot.utils.constants import OBS_PREFIX, OBS_STATE
logger = logging.getLogger(__name__)
@@ -61,6 +64,31 @@ def test_processors_delegation():
assert len(pre.steps) == 0
def test_libero_processors_are_policy_agnostic():
cfg = LiberoEnv()
pre, post = make_env_pre_post_processors(cfg, policy_cfg=object())
assert isinstance(pre.steps[0], LiberoProcessorStep)
assert len(post.steps) == 0
def test_libero_processor_flattens_state_to_raw_8_dim():
step = LiberoProcessorStep()
observation = {
OBS_PREFIX + "robot_state": {
"eef": {
"pos": torch.tensor([[1.0, 2.0, 3.0]]),
"quat": torch.tensor([[0.0, 0.0, 0.0, 1.0]]),
},
"gripper": {"qpos": torch.tensor([[4.0, 5.0]])},
}
}
state = step.observation(observation)[OBS_STATE]
assert state.shape == (1, 8)
assert torch.allclose(state, torch.tensor([[1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 4.0, 5.0]]))
def test_base_create_envs():
"""Base class create_envs() should build a single-task VectorEnv via gym.make()."""
gym_id = "_dispatch_test/CartPole-v99"
+840
View File
@@ -0,0 +1,840 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import pytest
import torch
from torch import nn
import lerobot.policies.evo1.evo1_model as evo1_model
import lerobot.policies.evo1.modeling_evo1 as modeling_evo1
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.evo1.configuration_evo1 import Evo1Config
from lerobot.policies.evo1.flow_matching import FlowmatchingActionHead
from lerobot.policies.evo1.internvl3_embedder import (
IMAGENET_MEAN,
IMAGENET_STD,
_batched_pixel_values,
)
from lerobot.policies.evo1.processor_evo1 import (
Evo1ActionProcessorStep,
Evo1PadActionProcessorStep,
Evo1PadStateProcessorStep,
evo1_batch_to_transition,
make_evo1_pre_post_processors,
reconcile_evo1_processors,
)
from lerobot.policies.factory import get_policy_class, make_policy_config
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.processor import (
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyProcessorPipeline,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import (
batch_to_transition,
policy_action_to_transition,
transition_to_batch,
transition_to_policy_action,
)
from lerobot.utils.constants import (
ACTION,
OBS_IMAGES,
OBS_STATE,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
STATE_DIM = 4
ACTION_DIM = 3
MAX_STATE_DIM = 6
MAX_ACTION_DIM = 5
CHUNK_SIZE = 2
EMBED_DIM = 8
class DummyEvo1Model(nn.Module):
def __init__(self, config, vlm_hub_kwargs=None):
super().__init__()
self.config = config
self.embedder = nn.Dropout(p=0.0)
self.action_head = nn.Linear(1, 1)
self.get_vl_embeddings_calls = 0
self.grad_enabled_calls = []
self.embedder_training_calls = []
def set_finetune_flags(self):
return None
def get_vl_embeddings(self, images, image_mask, prompt=None, return_cls_only=False):
self.get_vl_embeddings_calls += 1
self.grad_enabled_calls.append(torch.is_grad_enabled())
self.embedder_training_calls.append(self.embedder.training)
# images is a list of per-camera (B, C, H, W) tensors, so the batch dim is images[0].shape[0].
batch_size = images[0].shape[0]
tokens = torch.ones(batch_size, 4, EMBED_DIM, requires_grad=torch.is_grad_enabled())
valid_mask = torch.ones(batch_size, 4, dtype=torch.bool)
return tokens, valid_mask
def forward(
self,
fused_tokens,
state=None,
actions_gt=None,
action_mask=None,
embodiment_ids=None,
context_mask=None,
**kwargs,
):
batch_size = fused_tokens.shape[0]
if actions_gt is None:
return torch.ones(batch_size, CHUNK_SIZE * MAX_ACTION_DIM)
pred_velocity = torch.zeros(batch_size, CHUNK_SIZE * MAX_ACTION_DIM)
noise = torch.zeros_like(actions_gt)
return pred_velocity, noise
class ChunkCountingDummyModel(DummyEvo1Model):
"""Emits per-step distinguishable actions so queue ordering and re-prediction are observable."""
def __init__(self, config, vlm_hub_kwargs=None):
super().__init__(config, vlm_hub_kwargs)
self.chunks_predicted = 0
def forward(
self,
fused_tokens,
state=None,
actions_gt=None,
action_mask=None,
embodiment_ids=None,
context_mask=None,
**kwargs,
):
if actions_gt is not None:
return super().forward(fused_tokens, state, actions_gt, action_mask, embodiment_ids, context_mask)
self.chunks_predicted += 1
batch_size = fused_tokens.shape[0]
step_values = torch.arange(CHUNK_SIZE, dtype=torch.float32) + 10.0 * self.chunks_predicted
chunk = step_values.repeat_interleave(MAX_ACTION_DIM).unsqueeze(0).repeat(batch_size, 1)
return chunk
def make_config(training_stage="stage1", **kwargs):
config_kwargs = {
"device": "cpu",
"vlm_model_name": "dummy-internvl3",
"training_stage": training_stage,
"chunk_size": CHUNK_SIZE,
"n_action_steps": 1,
"max_state_dim": MAX_STATE_DIM,
"max_action_dim": MAX_ACTION_DIM,
"max_views": 2,
"embed_dim": EMBED_DIM,
"hidden_dim": 16,
"state_hidden_dim": 16,
"num_heads": 2,
"num_layers": 1,
"num_inference_timesteps": 2,
"input_features": {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
f"{OBS_IMAGES}.front": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)),
},
"output_features": {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,)),
},
}
config_kwargs.update(kwargs)
return Evo1Config(**config_kwargs)
def make_batch(include_action=True):
batch = {
"task": ["pick the block", "place the block"],
OBS_STATE: torch.randn(2, STATE_DIM),
f"{OBS_IMAGES}.front": torch.rand(2, 3, 16, 16),
}
if include_action:
batch[ACTION] = torch.randn(2, CHUNK_SIZE, ACTION_DIM)
return batch
def make_stats(state_dim=STATE_DIM, action_dim=ACTION_DIM):
return {
OBS_STATE: {
"min": torch.full((state_dim,), -2.0),
"max": torch.full((state_dim,), 2.0),
},
ACTION: {
"min": torch.full((action_dim,), -1.0),
"max": torch.full((action_dim,), 1.0),
},
}
def make_flowmatching_head(**overrides):
kwargs = {
"embed_dim": EMBED_DIM,
"hidden_dim": 16,
"action_dim": CHUNK_SIZE * ACTION_DIM,
"horizon": CHUNK_SIZE,
"per_action_dim": ACTION_DIM,
"num_heads": 2,
"num_layers": 1,
"num_inference_timesteps": 2,
"state_dim": STATE_DIM,
"state_hidden_dim": 16,
"num_categories": 1,
}
kwargs.update(overrides)
return FlowmatchingActionHead(**kwargs)
def test_evo1_factory_registration():
cfg = make_policy_config(
"evo1",
device="cpu",
vlm_model_name="dummy-internvl3",
input_features={
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
f"{OBS_IMAGES}.front": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)),
},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))},
)
assert isinstance(cfg, Evo1Config)
assert get_policy_class("evo1") is modeling_evo1.Evo1Policy
def test_evo1_stage_defaults_and_consistency():
stage1 = make_config(training_stage="stage1")
assert (stage1.finetune_vlm, stage1.finetune_language_model, stage1.finetune_vision_model) == (
False,
False,
False,
)
assert stage1.finetune_action_head is True
stage2 = make_config(training_stage="stage2")
assert (stage2.finetune_vlm, stage2.finetune_language_model, stage2.finetune_vision_model) == (
True,
True,
True,
)
assert stage2.finetune_action_head is True
stage2_from_stage1_checkpoint_flags = make_config(
training_stage="stage2",
finetune_vlm=False,
finetune_language_model=False,
finetune_vision_model=False,
finetune_action_head=False,
)
assert (
stage2_from_stage1_checkpoint_flags.finetune_vlm,
stage2_from_stage1_checkpoint_flags.finetune_language_model,
stage2_from_stage1_checkpoint_flags.finetune_vision_model,
) == (
True,
True,
True,
)
assert stage2_from_stage1_checkpoint_flags.finetune_action_head is True
explicit_off = make_config(
training_stage="stage2",
apply_training_stage_defaults=False,
finetune_vlm=False,
finetune_language_model=False,
finetune_vision_model=False,
finetune_action_head=False,
)
assert (
explicit_off.finetune_vlm,
explicit_off.finetune_language_model,
explicit_off.finetune_vision_model,
) == (
False,
False,
False,
)
assert explicit_off.finetune_action_head is False
# An explicit finetune_vlm=False without branch-level flags freezes both branches instead of
# raising an inconsistency error.
frozen_vlm = make_config(
training_stage="stage2",
apply_training_stage_defaults=False,
finetune_vlm=False,
)
assert (
frozen_vlm.finetune_vlm,
frozen_vlm.finetune_language_model,
frozen_vlm.finetune_vision_model,
) == (False, False, False)
try:
make_config(
training_stage="stage2",
apply_training_stage_defaults=False,
finetune_vlm=True,
finetune_language_model=False,
)
except ValueError as exc:
assert "Inconsistent EVO1 finetune config" in str(exc)
else:
raise AssertionError("Expected inconsistent finetune config to raise ValueError")
def test_evo1_rejects_non_square_image_resolution():
with pytest.raises(ValueError, match="square image_resolution"):
make_config(image_resolution=(448, 320))
def test_evo1_rejects_out_of_range_default_embodiment_id():
with pytest.raises(ValueError, match="default_embodiment_id"):
make_config(default_embodiment_id=3, num_categories=2)
def test_evo1_model_uses_image_resolution_and_trainable_checkpointing(monkeypatch):
captured: dict = {}
class SpyEmbedder(nn.Module):
def __init__(self, **kwargs):
super().__init__()
captured.clear()
captured.update(kwargs)
monkeypatch.setattr(evo1_model, "InternVL3Embedder", SpyEmbedder)
stage1 = make_config(training_stage="stage1", image_resolution=(224, 224))
evo1_model.Evo1Model(stage1)
assert captured["image_size"] == 224
# VLM is frozen in stage1, so gradient checkpointing is gated off.
assert captured["enable_gradient_checkpointing"] is False
stage2 = make_config(training_stage="stage2", image_resolution=(224, 224))
evo1_model.Evo1Model(stage2)
assert captured["enable_gradient_checkpointing"] is True
class FakeInternVLModel(nn.Module):
"""Minimal stand-in with the native HF InternVL submodule layout."""
def __init__(self):
super().__init__()
self.language_model = nn.Linear(2, 2)
self.vision_tower = nn.Linear(2, 2)
self.multi_modal_projector = nn.Linear(2, 2)
class FakeEmbedder(nn.Module):
def __init__(self, **kwargs):
super().__init__()
self.model = FakeInternVLModel()
def test_set_finetune_flags_targets_native_hf_internvl_submodules(monkeypatch):
monkeypatch.setattr(evo1_model, "InternVL3Embedder", FakeEmbedder)
stage2_model = evo1_model.Evo1Model(make_config(training_stage="stage2"))
stage2_model.set_finetune_flags()
vlm = stage2_model.embedder.model
assert all(p.requires_grad for p in vlm.language_model.parameters())
assert all(p.requires_grad for p in vlm.vision_tower.parameters())
assert all(p.requires_grad for p in vlm.multi_modal_projector.parameters())
assert all(p.requires_grad for p in stage2_model.action_head.parameters())
stage1_model = evo1_model.Evo1Model(make_config(training_stage="stage1"))
stage1_model.set_finetune_flags()
vlm = stage1_model.embedder.model
assert not any(p.requires_grad for p in vlm.parameters())
assert all(p.requires_grad for p in stage1_model.action_head.parameters())
def test_set_finetune_flags_fails_loudly_on_unknown_vlm_layout(monkeypatch):
class LegacyLayoutModel(nn.Module):
def __init__(self):
super().__init__()
self.language_model = nn.Linear(2, 2)
self.vision_model = nn.Linear(2, 2) # trust_remote_code-era attribute name
self.mlp1 = nn.Linear(2, 2)
class FakeEmbedder(nn.Module):
def __init__(self, **kwargs):
super().__init__()
self.model = LegacyLayoutModel()
monkeypatch.setattr(evo1_model, "InternVL3Embedder", FakeEmbedder)
model = evo1_model.Evo1Model(make_config(training_stage="stage2"))
with pytest.raises(AttributeError, match="vision_tower"):
model.set_finetune_flags()
def test_evo1_policy_processors_pad_state_crop_action_and_binarize_gripper():
libero_action_dim = 7
config = make_config(
max_state_dim=MAX_STATE_DIM,
max_action_dim=8,
postprocess_action_dim=libero_action_dim,
binarize_gripper=True,
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(libero_action_dim,))},
)
stats = make_stats(action_dim=libero_action_dim)
preprocessor, postprocessor = make_evo1_pre_post_processors(config, dataset_stats=stats)
assert isinstance(preprocessor.steps[2], Evo1PadStateProcessorStep)
assert isinstance(preprocessor.steps[3], Evo1PadActionProcessorStep)
assert isinstance(preprocessor.steps[4], NormalizerProcessorStep)
assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep)
assert isinstance(postprocessor.steps[1], Evo1ActionProcessorStep)
normalizer = preprocessor.steps[4]
assert normalizer.features[OBS_STATE].shape == (MAX_STATE_DIM,)
assert normalizer.features[ACTION].shape == (8,)
assert normalizer._tensor_stats[OBS_STATE]["min"].shape == (MAX_STATE_DIM,)
assert normalizer._tensor_stats[ACTION]["min"].shape == (8,)
processed_batch = preprocessor(
{
"task": "pick the block",
OBS_STATE: torch.zeros(STATE_DIM),
ACTION: torch.zeros(libero_action_dim),
f"{OBS_IMAGES}.front": torch.rand(3, 16, 16),
}
)
processed_state = processed_batch[OBS_STATE]
assert processed_state.shape == (1, MAX_STATE_DIM)
assert torch.allclose(processed_state, torch.zeros_like(processed_state))
assert processed_batch[ACTION].shape == (1, 8)
assert torch.allclose(processed_batch[ACTION], torch.zeros_like(processed_batch[ACTION]))
assert processed_batch["action_mask"].shape == (1, 8)
assert processed_batch["action_mask"][:, :libero_action_dim].all()
assert not processed_batch["action_mask"][:, libero_action_dim:].any()
action = torch.tensor(
[
[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.5, 0.7],
[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7],
],
dtype=torch.float32,
)
processed = postprocessor(action)
assert processed.shape == (2, 7)
assert processed.dtype == torch.float32
assert torch.allclose(processed[:, :6], action[:, :6])
assert torch.equal(processed[:, 6], torch.tensor([1.0, -1.0]))
def test_evo1_postprocessor_returns_float32_for_bf16_actions():
config = make_config()
_preprocessor, postprocessor = make_evo1_pre_post_processors(config, dataset_stats=make_stats())
processed = postprocessor(torch.zeros(2, MAX_ACTION_DIM, dtype=torch.bfloat16))
assert processed.dtype == torch.float32
def test_evo1_processor_save_load_round_trip_applies_config_overrides(tmp_path):
train_config = make_config()
preprocessor, postprocessor = make_evo1_pre_post_processors(train_config, dataset_stats=make_stats())
preprocessor.save_pretrained(tmp_path)
postprocessor.save_pretrained(tmp_path)
loaded_pre = PolicyProcessorPipeline.from_pretrained(
tmp_path,
config_filename=f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json",
to_transition=batch_to_transition,
to_output=transition_to_batch,
)
loaded_post = PolicyProcessorPipeline.from_pretrained(
tmp_path,
config_filename=f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json",
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
)
# Simulate eval-time CLI overrides applied on top of the loaded pipelines.
eval_config = make_config(binarize_gripper=True, postprocess_action_dim=ACTION_DIM)
loaded_pre, loaded_post = reconcile_evo1_processors(eval_config, loaded_pre, loaded_post)
assert loaded_pre.to_transition is evo1_batch_to_transition
assert sum(isinstance(step, Evo1ActionProcessorStep) for step in loaded_post.steps) == 1
action_step = next(step for step in loaded_post.steps if isinstance(step, Evo1ActionProcessorStep))
assert action_step.binarize_gripper is True
assert action_step.action_dim == ACTION_DIM
# The float32 output dtype is part of the serialized pipeline itself.
device_step = next(step for step in loaded_post.steps if isinstance(step, DeviceProcessorStep))
assert device_step.float_dtype == "float32"
# Non-observation extras (embodiment_id, ...) must survive the reloaded preprocessor.
processed = loaded_pre(
{
"task": "pick the block",
OBS_STATE: torch.zeros(STATE_DIM),
f"{OBS_IMAGES}.front": torch.rand(3, 16, 16),
"embodiment_id": torch.tensor([0]),
}
)
assert "embodiment_id" in processed
def test_evo1_policy_forward_and_inference_use_batched_embedding(monkeypatch):
monkeypatch.setattr(modeling_evo1, "Evo1Model", DummyEvo1Model)
policy = modeling_evo1.Evo1Policy(make_config())
preprocessor, _postprocessor = make_evo1_pre_post_processors(policy.config, dataset_stats=make_stats())
training_batch = preprocessor(make_batch(include_action=True))
assert training_batch[ACTION].shape == (2, CHUNK_SIZE, MAX_ACTION_DIM)
assert training_batch["action_mask"].shape == (2, CHUNK_SIZE, MAX_ACTION_DIM)
assert training_batch["action_mask"][:, :, :ACTION_DIM].all()
assert not training_batch["action_mask"][:, :, ACTION_DIM:].any()
loss, metrics = policy.forward(training_batch)
assert loss.ndim == 0
assert torch.isfinite(loss)
assert metrics["active_action_dims"] == ACTION_DIM * CHUNK_SIZE
assert policy.model.get_vl_embeddings_calls == 1
action_chunk = policy.predict_action_chunk(make_batch(include_action=False))
assert action_chunk.shape == (2, CHUNK_SIZE, MAX_ACTION_DIM)
assert action_chunk.dtype == torch.float32
policy.reset()
selected = policy.select_action(make_batch(include_action=False))
assert selected.shape == (2, MAX_ACTION_DIM)
def test_evo1_forward_masks_padded_action_timesteps(monkeypatch):
monkeypatch.setattr(modeling_evo1, "Evo1Model", DummyEvo1Model)
policy = modeling_evo1.Evo1Policy(make_config())
batch = make_batch(include_action=True)
batch[ACTION] = torch.ones(2, CHUNK_SIZE, ACTION_DIM)
# Give the padded (past-episode-end) timestep a huge value: if it leaked into the loss, the
# loss would blow up far beyond 1.0.
batch[ACTION][:, -1, :] = 100.0
batch["action_is_pad"] = torch.zeros(2, CHUNK_SIZE, dtype=torch.bool)
batch["action_is_pad"][:, -1] = True
loss, metrics = policy.forward(batch)
# DummyEvo1Model predicts zero velocity and zero noise, so each active element contributes
# (0 - action)^2 = 1.0 for the in-episode ones-valued actions.
assert metrics["active_action_dims"] == ACTION_DIM * (CHUNK_SIZE - 1)
assert torch.isclose(loss, torch.tensor(1.0))
def test_evo1_select_action_queue_orders_steps_and_repredicts(monkeypatch):
monkeypatch.setattr(modeling_evo1, "Evo1Model", ChunkCountingDummyModel)
policy = modeling_evo1.Evo1Policy(make_config(n_action_steps=CHUNK_SIZE))
batch = make_batch(include_action=False)
first = policy.select_action(batch)
second = policy.select_action(batch)
third = policy.select_action(batch)
# First chunk provides steps 10, 11 in order; the third call triggers a fresh prediction (20).
assert torch.all(first == 10.0)
assert torch.all(second == 11.0)
assert torch.all(third == 20.0)
assert policy.model.chunks_predicted == 2
def test_evo1_predict_action_chunk_rejects_rtc_kwargs_without_rtc_config(monkeypatch):
monkeypatch.setattr(modeling_evo1, "Evo1Model", DummyEvo1Model)
policy = modeling_evo1.Evo1Policy(make_config())
with pytest.raises(RuntimeError, match="RTC"):
policy.predict_action_chunk(make_batch(include_action=False), inference_delay=2)
def test_evo1_rtc_processor_wiring(monkeypatch):
monkeypatch.setattr(evo1_model, "InternVL3Embedder", FakeEmbedder)
policy = modeling_evo1.Evo1Policy(make_config())
assert policy.rtc_processor is None
assert policy.model.rtc_processor is None
# The RTC rollout backend assigns rtc_config after loading and re-inits the processor.
policy.config.rtc_config = RTCConfig(execution_horizon=CHUNK_SIZE)
policy.init_rtc_processor()
assert isinstance(policy.rtc_processor, RTCProcessor)
assert policy.model.rtc_processor is policy.rtc_processor
# RTC drives predict_action_chunk directly; the select_action queue path is unsupported.
with pytest.raises(AssertionError, match="select_action"):
policy.select_action(make_batch(include_action=False))
def test_flowmatching_rtc_guidance_pulls_prefix_toward_previous_chunk():
head = make_flowmatching_head(num_inference_timesteps=16)
processor = RTCProcessor(RTCConfig(execution_horizon=CHUNK_SIZE))
fused = torch.randn(2, 4, EMBED_DIM)
state = torch.randn(2, STATE_DIM)
action_mask = torch.ones(2, ACTION_DIM, dtype=torch.bool)
prev_chunk = torch.tensor([0.7, -0.4, 0.2]).expand(2, CHUNK_SIZE, ACTION_DIM).contiguous()
torch.manual_seed(0)
unguided = head.get_action(fused, state=state, action_mask=action_mask)
unguided = unguided.view(2, CHUNK_SIZE, ACTION_DIM)
torch.manual_seed(0)
guided = head.get_action(
fused,
state=state,
action_mask=action_mask,
inference_delay=1,
prev_chunk_left_over=prev_chunk,
rtc_processor=processor,
)
guided = guided.view(2, CHUNK_SIZE, ACTION_DIM)
# The frozen prefix (first inference_delay steps) must land far closer to the previous chunk
# than the unguided sample from the same noise does.
guided_dist = (guided[:, 0] - prev_chunk[:, 0]).abs().mean()
unguided_dist = (unguided[:, 0] - prev_chunk[:, 0]).abs().mean()
assert guided_dist < 0.5 * unguided_dist
assert torch.isfinite(guided).all()
def test_flowmatching_rtc_first_chunk_without_leftover_matches_unguided():
head = make_flowmatching_head(num_inference_timesteps=4)
processor = RTCProcessor(RTCConfig(execution_horizon=CHUNK_SIZE))
fused = torch.randn(2, 4, EMBED_DIM)
state = torch.randn(2, STATE_DIM)
action_mask = torch.ones(2, ACTION_DIM, dtype=torch.bool)
torch.manual_seed(0)
unguided = head.get_action(fused, state=state, action_mask=action_mask)
torch.manual_seed(0)
first_chunk = head.get_action(
fused,
state=state,
action_mask=action_mask,
inference_delay=2,
prev_chunk_left_over=None,
rtc_processor=processor,
)
assert torch.allclose(unguided, first_chunk)
def test_evo1_missing_configured_camera_needs_empty_cameras_budget(monkeypatch):
monkeypatch.setattr(modeling_evo1, "Evo1Model", DummyEvo1Model)
batch = make_batch(include_action=False) # only provides the front camera
two_camera_features = {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
f"{OBS_IMAGES}.front": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)),
f"{OBS_IMAGES}.wrist": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)),
}
strict_policy = modeling_evo1.Evo1Policy(make_config(input_features=dict(two_camera_features)))
with pytest.raises(ValueError, match="empty_cameras"):
strict_policy._collect_image_batches(batch)
# empty_cameras adds placeholder camera features that are never present in the batch; they
# become masked-out views instead of crashing with a KeyError.
padded_policy = modeling_evo1.Evo1Policy(make_config(empty_cameras=1))
assert len(padded_policy.config.image_features) == 2
camera_images, image_masks = padded_policy._collect_image_batches(batch)
assert len(camera_images) == 1
assert image_masks.tolist() == [[True, False], [True, False]]
def test_stage1_frozen_vlm_embeddings_do_not_track_gradients(monkeypatch):
monkeypatch.setattr(modeling_evo1, "Evo1Model", DummyEvo1Model)
policy = modeling_evo1.Evo1Policy(make_config(training_stage="stage1"))
policy.train()
image_batches, image_masks = policy._collect_image_batches(make_batch(include_action=False))
fused_tokens, context_mask = policy._compute_fused_tokens(["pick", "place"], image_batches, image_masks)
assert policy.model.grad_enabled_calls == [False]
assert policy.model.embedder_training_calls == [False]
assert not fused_tokens.requires_grad
assert context_mask is not None
assert policy.model.embedder.training is False
def test_stage2_vlm_embeddings_track_gradients(monkeypatch):
monkeypatch.setattr(modeling_evo1, "Evo1Model", DummyEvo1Model)
policy = modeling_evo1.Evo1Policy(make_config(training_stage="stage2"))
policy.train()
image_batches, image_masks = policy._collect_image_batches(make_batch(include_action=False))
fused_tokens, _context_mask = policy._compute_fused_tokens(["pick", "place"], image_batches, image_masks)
assert policy.model.grad_enabled_calls == [True]
assert policy.model.embedder_training_calls == [True]
assert fused_tokens.requires_grad
def test_collect_image_batches_handles_unbatched_chw(monkeypatch):
# Regression for an issue where batch_size was read from shape[0] before normalizing
# per-camera tensor dims, so an unbatched (C, H, W) input was treated as batch_size=C.
monkeypatch.setattr(modeling_evo1, "Evo1Model", DummyEvo1Model)
policy = modeling_evo1.Evo1Policy(make_config())
batch = {
OBS_STATE: torch.randn(1, STATE_DIM),
f"{OBS_IMAGES}.front": torch.rand(3, 16, 16),
}
camera_images, image_masks = policy._collect_image_batches(batch)
# One present camera, returned as a batched (B, C, H, W) tensor with the unbatched CHW frame
# promoted to batch_size=1 (not read as batch_size=C).
assert len(camera_images) == 1
assert camera_images[0].shape == (1, 3, 16, 16)
assert image_masks.tolist() == [[True, False]]
def test_evo1_state_mask_zeroes_masked_dims(monkeypatch):
monkeypatch.setattr(modeling_evo1, "Evo1Model", DummyEvo1Model)
policy = modeling_evo1.Evo1Policy(make_config())
batch = {
OBS_STATE: torch.ones(2, STATE_DIM),
"state_mask": torch.tensor([[True, True, False, False]] * 2),
}
states, mask = policy._prepare_state(batch)
assert torch.all(states[:, :2] == 1.0)
assert torch.all(states[:, 2:] == 0.0)
assert mask[:, :2].all()
assert not mask[:, 2:].any()
def test_evo1_action_mask_accepts_chunk_size_one(monkeypatch):
monkeypatch.setattr(modeling_evo1, "Evo1Model", DummyEvo1Model)
config = make_config(chunk_size=1, n_action_steps=1)
policy = modeling_evo1.Evo1Policy(config)
batch = make_batch(include_action=True)
batch[ACTION] = torch.randn(2, ACTION_DIM)
batch["action_mask"] = torch.ones(2, ACTION_DIM, dtype=torch.bool)
actions, action_mask = policy._prepare_actions(batch)
assert actions.shape == (2, 1, MAX_ACTION_DIM)
assert action_mask.shape == (2, 1, MAX_ACTION_DIM)
assert action_mask[:, :, :ACTION_DIM].all()
assert not action_mask[:, :, ACTION_DIM:].any()
def test_flowmatching_state_encoder_for_horizon_one():
head = make_flowmatching_head(action_dim=ACTION_DIM, horizon=1)
assert head.state_encoder is not None
pred_velocity, noise = head(
torch.randn(2, 4, EMBED_DIM),
state=torch.randn(2, STATE_DIM),
actions_gt=torch.randn(2, 1, ACTION_DIM),
action_mask=torch.ones(2, 1, ACTION_DIM, dtype=torch.bool),
)
assert pred_velocity.shape == (2, ACTION_DIM)
assert noise.shape == (2, 1, ACTION_DIM)
def test_flowmatching_get_action_real_path_respects_action_mask():
torch.manual_seed(0)
head = make_flowmatching_head()
action_mask = torch.zeros(2, ACTION_DIM, dtype=torch.bool)
action_mask[:, :2] = True
actions = head.get_action(
torch.randn(2, 4, EMBED_DIM),
state=torch.randn(2, STATE_DIM),
action_mask=action_mask,
)
assert actions.shape == (2, CHUNK_SIZE * ACTION_DIM)
assert torch.isfinite(actions).all()
action_seq = actions.view(2, CHUNK_SIZE, ACTION_DIM)
assert torch.all(action_seq[..., 2] == 0.0)
def test_flowmatching_context_mask_blocks_masked_context_tokens():
head = make_flowmatching_head()
state = torch.randn(2, STATE_DIM)
action_mask = torch.ones(2, ACTION_DIM, dtype=torch.bool)
fused = torch.randn(2, 4, EMBED_DIM)
context_mask = torch.ones(2, 4, dtype=torch.bool)
context_mask[:, -1] = False
corrupted = fused.clone()
corrupted[:, -1] = 1e4
torch.manual_seed(0)
reference = head.get_action(fused, state=state, action_mask=action_mask, context_mask=context_mask)
torch.manual_seed(0)
with_garbage = head.get_action(corrupted, state=state, action_mask=action_mask, context_mask=context_mask)
assert torch.allclose(reference, with_garbage)
def test_flowmatching_head_accepts_pooled_2d_context():
head = make_flowmatching_head()
pred_velocity, noise = head(
torch.randn(2, EMBED_DIM), # pooled (B, E) context from return_cls_only
state=torch.randn(2, STATE_DIM),
actions_gt=torch.randn(2, CHUNK_SIZE, ACTION_DIM),
action_mask=torch.ones(2, CHUNK_SIZE, ACTION_DIM, dtype=torch.bool),
)
assert pred_velocity.shape == (2, CHUNK_SIZE * ACTION_DIM)
actions = head.get_action(
torch.randn(2, EMBED_DIM),
state=torch.randn(2, STATE_DIM),
action_mask=torch.ones(2, ACTION_DIM, dtype=torch.bool),
)
assert actions.shape == (2, CHUNK_SIZE * ACTION_DIM)
def test_flowmatching_rejects_out_of_range_embodiment_ids():
head = make_flowmatching_head(num_categories=2)
with pytest.raises(ValueError, match="num_categories"):
head.get_action(
torch.randn(2, 4, EMBED_DIM),
state=torch.randn(2, STATE_DIM),
action_mask=torch.ones(2, ACTION_DIM, dtype=torch.bool),
embodiment_id=torch.tensor([0, 5]),
)
def test_evo1_batched_pixel_values_shape_and_zero_padding():
torch.manual_seed(0)
batch_size, image_size, max_views = 2, 448, 3
camera_images = [torch.rand(batch_size, 3, 40, 50)] # a single present camera
mean = torch.tensor(IMAGENET_MEAN)
std = torch.tensor(IMAGENET_STD)
pixel_values = _batched_pixel_values(
camera_images, max_views, image_size, mean, std, torch.float32, torch.device("cpu")
)
assert pixel_values.shape == (batch_size * max_views, 3, image_size, image_size)
grouped = pixel_values.reshape(batch_size, max_views, 3, image_size, image_size)
# Absent views (indices 1, 2) are zero images, normalized to the constant -mean/std.
expected_pad = (-mean / std).view(1, 3, 1, 1)
for view in (1, 2):
assert torch.allclose(
grouped[:, view], expected_pad.expand(batch_size, 3, image_size, image_size), atol=1e-5
)
# The present view is genuinely different from the constant pad value.
assert not torch.allclose(
grouped[:, 0], expected_pad.expand(batch_size, 3, image_size, image_size), atol=1e-3
)
Generated
+11 -2
View File
@@ -2978,6 +2978,9 @@ eo1 = [
evaluation = [
{ name = "av" },
]
evo1 = [
{ name = "transformers" },
]
fastwam = [
{ name = "diffusers" },
{ name = "transformers" },
@@ -3179,6 +3182,9 @@ test = [
{ name = "pytest-cov" },
{ name = "pytest-timeout" },
]
timm-dep = [
{ name = "timm" },
]
topreward = [
{ name = "transformers" },
]
@@ -3296,6 +3302,7 @@ requires-dist = [
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'vla-jepa'" },
{ name = "lerobot", extras = ["diffusion"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["dynamixel"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["evo1"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["fastwam"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'hopejr'" },
@@ -3362,10 +3369,12 @@ requires-dist = [
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'wallx'" },
{ name = "lerobot", extras = ["smolvla"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["test"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["timm-dep"], marker = "extra == 'groot'" },
{ name = "lerobot", extras = ["topreward"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["training"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'annotations'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'eo1'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'evo1'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'fastwam'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'groot'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'hilserl'" },
@@ -3436,7 +3445,7 @@ requires-dist = [
{ name = "setuptools", specifier = ">=71.0.0,<81.0.0" },
{ name = "teleop", marker = "extra == 'phone'", specifier = ">=0.1.0,<0.2.0" },
{ name = "termcolor", specifier = ">=2.4.0,<4.0.0" },
{ name = "timm", marker = "extra == 'groot'", specifier = ">=1.0.0,<1.1.0" },
{ name = "timm", marker = "extra == 'timm-dep'", specifier = ">=1.0.0,<1.1.0" },
{ name = "torch", marker = "sys_platform != 'linux'", specifier = ">=2.7,<2.12.0" },
{ name = "torch", marker = "sys_platform == 'linux'", specifier = ">=2.7,<2.12.0", index = "https://download.pytorch.org/whl/cu128" },
{ name = "torchcodec", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin' and extra == 'dataset') or (platform_machine == 'AMD64' and sys_platform == 'linux' and extra == 'dataset') or (platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'dataset')", specifier = ">=0.3.0,<0.12.0" },
@@ -3449,7 +3458,7 @@ requires-dist = [
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" },
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.28.0" },
]
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "accelerate-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "fastwam", "hilserl", "vla-jepa", "async", "peft", "annotations", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "accelerate-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "timm-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "fastwam", "evo1", "hilserl", "vla-jepa", "async", "peft", "annotations", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
[[package]]
name = "librt"