refactor(evo1): use native HF InternVL3-1B-hf, drop trust_remote_code

- Switch from OpenGVLab/InternVL3-1B (requires trust_remote_code=True)
  to OpenGVLab/InternVL3-1B-hf (native transformers implementation).
- Replace manual _extract_feature + _prepare_and_fuse_embeddings with
  a single model.forward() call — verified bit-for-bit identical output.
- Remove ~170 lines of manual ViT/pixel-shuffle/projection logic.
- Symlink README.md to docs/source/ following repo convention.

Weights are byte-identical between both model variants; only the module
naming differs. All 12 existing unit tests pass. Local training (10 steps)
on maximellerbach/omx_pickandplace confirmed working.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
Martino Russi
2026-06-23 17:17:19 +02:00
parent 25556ceefe
commit 9423deda02
3 changed files with 73 additions and 239 deletions
-18
View File
@@ -1,18 +0,0 @@
# EVO1
EVO1 is a Vision-Language-Action policy for robot control. The LeRobot
integration uses an InternVL3 vision-language backbone with a flow-matching
action head, and supports staged training through the standard LeRobot policy
APIs.
The upstream EVO1 project is available at
[MINT-SJTU/Evo-1](https://github.com/MINT-SJTU/Evo-1).
```bibtex
@misc{evo1,
title = {EVO1},
author = {{MINT-SJTU}},
year = {2026},
howpublished = {\url{https://github.com/MINT-SJTU/Evo-1}},
}
```
+1
View File
@@ -0,0 +1 @@
../../../../docs/source/policy_evo1_README.md
@@ -77,7 +77,7 @@ class Evo1Config(PreTrainedConfig):
}
)
vlm_model_name: str = "OpenGVLab/InternVL3-1B"
vlm_model_name: str = "OpenGVLab/InternVL3-1B-hf"
vlm_num_layers: int | None = 14
vlm_dtype: str = "bfloat16"
use_flash_attn: bool = True
+71 -220
View File
@@ -16,14 +16,11 @@ from __future__ import annotations
import functools
import logging
import types
from collections.abc import Sequence
from contextlib import contextmanager
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
import torch.utils.checkpoint
import torchvision.transforms.functional as tvf
from PIL import Image
from torchvision.transforms.functional import to_pil_image
@@ -45,88 +42,6 @@ IMG_END_TOKEN = "</img>" # nosec B105
logger = logging.getLogger(__name__)
def _patch_vision_encoder_checkpointing(encoder: nn.Module, use_reentrant: bool) -> None:
for attr_name in ("_gradient_checkpointing_func", "gradient_checkpointing_func"):
original_func = getattr(encoder, attr_name, None)
if not callable(original_func):
continue
patch_attr = f"_evo1_{attr_name}_patch_applied"
if getattr(encoder, patch_attr, False):
encoder.gradient_checkpointing_use_reentrant = use_reentrant
return
def checkpoint_with_kwargs(
function, *checkpoint_args, _original_func=original_func, **checkpoint_kwargs
):
checkpoint_kwargs.setdefault("use_reentrant", encoder.gradient_checkpointing_use_reentrant)
return _original_func(function, *checkpoint_args, **checkpoint_kwargs)
encoder.gradient_checkpointing_use_reentrant = use_reentrant
setattr(encoder, attr_name, checkpoint_with_kwargs)
setattr(encoder, patch_attr, True)
return
if getattr(encoder, "_evo1_checkpoint_patch_applied", False):
encoder.gradient_checkpointing_use_reentrant = use_reentrant
return
original_forward = encoder.forward
def forward_with_checkpoint_kwargs(self, *args, **kwargs):
original_checkpoint = torch.utils.checkpoint.checkpoint
def checkpoint(function, *checkpoint_args, **checkpoint_kwargs):
checkpoint_kwargs.setdefault("use_reentrant", self.gradient_checkpointing_use_reentrant)
return original_checkpoint(function, *checkpoint_args, **checkpoint_kwargs)
# Some InternVL3 remote-code versions call torch.utils.checkpoint.checkpoint
# directly and do not expose a per-encoder checkpoint function to patch.
# Keep this compatibility fallback scoped to encoder.forward and restore it.
torch.utils.checkpoint.checkpoint = checkpoint
try:
return original_forward(*args, **kwargs)
finally:
torch.utils.checkpoint.checkpoint = original_checkpoint
encoder.gradient_checkpointing_use_reentrant = use_reentrant
encoder.forward = types.MethodType(forward_with_checkpoint_kwargs, encoder)
encoder._evo1_checkpoint_patch_applied = True
def flash_attn_is_available() -> bool:
try:
import flash_attn # noqa: F401
except ModuleNotFoundError:
return False
return True
@contextmanager
def _internvl_transformers5_load_compatibility():
from transformers.modeling_utils import PreTrainedModel
original_linspace = torch.linspace
original_mark_tied = PreTrainedModel.mark_tied_weights_as_initialized
def linspace(*args, **kwargs):
if kwargs.get("device") is None:
kwargs["device"] = torch.device("cpu")
return original_linspace(*args, **kwargs)
def mark_tied_weights_as_initialized(self, loading_info):
if not hasattr(self, "all_tied_weights_keys"):
self.all_tied_weights_keys = {}
return original_mark_tied(self, loading_info)
torch.linspace = linspace
PreTrainedModel.mark_tied_weights_as_initialized = mark_tied_weights_as_initialized
try:
yield
finally:
torch.linspace = original_linspace
PreTrainedModel.mark_tied_weights_as_initialized = original_mark_tied
@functools.lru_cache(maxsize=10000)
def get_target_aspect_ratio(orig_width: int, orig_height: int, image_size: int, min_num: int, max_num: int):
aspect_ratio = orig_width / orig_height
@@ -175,9 +90,11 @@ def dynamic_preprocess(image, min_num=1, max_num=1, image_size=448, use_thumbnai
class InternVL3Embedder(nn.Module):
"""Vision-language embedder using the native HF InternVL3 model (no trust_remote_code)."""
def __init__(
self,
model_name="OpenGVLab/InternVL3-1B",
model_name="OpenGVLab/InternVL3-1B-hf",
image_size=448,
device="cuda",
num_language_layers: int | None = 14,
@@ -196,43 +113,31 @@ class InternVL3Embedder(nn.Module):
require_package("transformers", extra="evo1")
self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
if isinstance(model_dtype, str):
try:
model_dtype = getattr(torch, model_dtype)
except AttributeError as exc:
raise ValueError(f"Unsupported EVO1 vlm_dtype '{model_dtype}'") from exc
resolved_use_flash_attn = bool(use_flash_attn and flash_attn_is_available())
if use_flash_attn and not resolved_use_flash_attn:
logger.warning("flash_attn is not installed. Falling back to standard attention.")
attn_implementation = "flash_attention_2" if (use_flash_attn and _flash_attn_available()) else "eager"
if use_flash_attn and attn_implementation == "eager":
logger.warning("flash_attn is not installed. Falling back to eager attention.")
# InternVL3 remote code predates Transformers 5 post-init conventions:
# it computes stochastic-depth scalars via torch.linspace(...).item()
# while Transformers initializes under torch.device("meta"), and it
# does not populate all_tied_weights_keys before loading finalization.
with _internvl_transformers5_load_compatibility():
self.model = AutoModel.from_pretrained(
model_name,
torch_dtype=model_dtype,
trust_remote_code=True,
use_flash_attn=resolved_use_flash_attn,
low_cpu_mem_usage=True,
_fast_init=False,
).to(self._requested_device)
self.model = AutoModel.from_pretrained(
model_name,
torch_dtype=model_dtype,
attn_implementation=attn_implementation,
low_cpu_mem_usage=True,
).to(self._requested_device)
if hasattr(self.model.language_model, "model"):
layers = self.model.language_model.model.layers
else:
layers = self.model.language_model.layers
self.num_image_token = self.model.config.image_seq_length
# Truncate language model to the requested number of layers
layers = self.model.language_model.layers
if self.num_language_layers is not None:
layers = layers[: self.num_language_layers]
if hasattr(self.model.language_model, "model"):
self.model.language_model.model.layers = torch.nn.ModuleList(layers)
else:
self.model.language_model.layers = torch.nn.ModuleList(layers)
self.model.language_model.lm_head = torch.nn.Identity()
self.model.language_model.layers = torch.nn.ModuleList(layers)
self._configure_memory_features()
self.img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
@@ -241,20 +146,12 @@ class InternVL3Embedder(nn.Module):
checkpoint_kwargs = {"use_reentrant": self.gradient_checkpointing_use_reentrant}
if not self.enable_gradient_checkpointing:
if hasattr(self.model, "vision_model") and hasattr(self.model.vision_model, "encoder"):
self.model.vision_model.encoder.gradient_checkpointing = False
language_model = getattr(self.model, "language_model", None)
if language_model is not None:
if hasattr(language_model, "gradient_checkpointing_disable"):
language_model.gradient_checkpointing_disable()
elif hasattr(language_model, "gradient_checkpointing"):
language_model.gradient_checkpointing = False
if hasattr(language_model, "model"):
inner = language_model.model
if hasattr(inner, "gradient_checkpointing_disable"):
inner.gradient_checkpointing_disable()
elif hasattr(inner, "gradient_checkpointing"):
inner.gradient_checkpointing = False
language_model = self.model.language_model
if hasattr(language_model, "gradient_checkpointing_disable"):
language_model.gradient_checkpointing_disable()
vision_tower = getattr(self.model, "vision_tower", None)
if vision_tower is not None and hasattr(vision_tower, "encoder"):
vision_tower.encoder.gradient_checkpointing = False
return
def _enable_ckpt(module: nn.Module | None) -> bool:
@@ -273,21 +170,14 @@ class InternVL3Embedder(nn.Module):
enabled_any = _enable_ckpt(self.model)
if hasattr(self.model, "vision_model") and hasattr(self.model.vision_model, "encoder"):
encoder = self.model.vision_model.encoder
encoder.gradient_checkpointing = True
_patch_vision_encoder_checkpointing(
encoder, use_reentrant=self.gradient_checkpointing_use_reentrant
)
enabled_any = True
vision_tower = getattr(self.model, "vision_tower", None)
if vision_tower is not None:
enabled_any = _enable_ckpt(vision_tower) or enabled_any
language_model = getattr(self.model, "language_model", None)
if language_model is not None:
enabled_any = _enable_ckpt(language_model) or enabled_any
if hasattr(language_model, "model"):
enabled_any = _enable_ckpt(language_model.model) or enabled_any
if hasattr(language_model, "config"):
language_model.config.use_cache = False
language_model = self.model.language_model
enabled_any = _enable_ckpt(language_model) or enabled_any
if hasattr(language_model, "config"):
language_model.config.use_cache = False
if hasattr(self.model, "config"):
self.model.config.use_cache = False
@@ -303,8 +193,6 @@ class InternVL3Embedder(nn.Module):
def _preprocess_single_image(self, image: Image.Image | torch.Tensor) -> torch.Tensor:
if isinstance(image, torch.Tensor):
# Match upstream EVO1/InternVL preprocessing, which converts tensors
# through PIL before tiling and ImageNet normalization.
pil_image = to_pil_image(image.detach().cpu())
else:
pil_image = image.convert("RGB")
@@ -348,76 +236,12 @@ class InternVL3Embedder(nn.Module):
for num_tiles_list, text_prompt in zip(batch_num_tiles_list, text_prompts, strict=True):
prompt_segments = []
for i, tile_count in enumerate(num_tiles_list):
token_count = self.model.num_image_token * tile_count
token_count = self.num_image_token * tile_count
image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * token_count + IMG_END_TOKEN
prompt_segments.append(f"Image-{i + 1}: {image_tokens}\n")
prompts.append("".join(prompt_segments) + text_prompt.strip())
return prompts
def _prepare_and_fuse_embeddings(
self,
prompts: Sequence[str],
vit_embeds: torch.Tensor,
image_masks: torch.Tensor,
batch_num_tiles_list: list[list[int]],
) -> tuple[torch.Tensor, torch.Tensor]:
untruncated_ids = self.tokenizer(list(prompts), padding=False, truncation=False)["input_ids"]
true_sequence_length = max((len(ids) for ids in untruncated_ids), default=0)
if true_sequence_length > self.max_text_length:
logger.warning(
"InternVL3 prompt truncated in batch: max_length=%s actual_max_length=%s",
self.max_text_length,
true_sequence_length,
)
model_inputs = self.tokenizer(
list(prompts),
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.max_text_length,
).to(self.device)
input_ids = model_inputs["input_ids"]
attention_mask = model_inputs["attention_mask"]
img_token_mask = input_ids == self.img_context_token_id
input_embeds = self.model.language_model.get_input_embeddings()(input_ids).clone()
batch_size, _, channels = input_embeds.shape
vit_embeds = vit_embeds.reshape(-1, channels).to(dtype=input_embeds.dtype, device=input_embeds.device)
tokens_per_tile = self.model.num_image_token
actual_vis_tokens_list = img_token_mask.sum(dim=1).tolist()
vit_idx = 0
for batch_index in range(batch_size):
expected_vis_tokens = sum(batch_num_tiles_list[batch_index]) * tokens_per_tile
mask_b = img_token_mask[batch_index]
actual_vis_tokens = actual_vis_tokens_list[batch_index]
item_vit_embeds = vit_embeds[vit_idx : vit_idx + expected_vis_tokens]
vit_idx += expected_vis_tokens
if actual_vis_tokens > 0:
if item_vit_embeds.shape[0] < actual_vis_tokens:
raise ValueError(
f"InternVL3 produced fewer image tokens than expected for sample {batch_index}: "
f"got {item_vit_embeds.shape[0]}, need {actual_vis_tokens}"
)
input_embeds[batch_index, mask_b] = item_vit_embeds[:actual_vis_tokens]
current_token_idx = 0
img_token_locations = torch.where(mask_b)[0]
for image_index, num_tiles in enumerate(batch_num_tiles_list[batch_index]):
num_tokens_for_image = num_tiles * tokens_per_tile
if not bool(image_masks[batch_index, image_index].item()):
start_offset = current_token_idx
end_offset = min(current_token_idx + num_tokens_for_image, len(img_token_locations))
if start_offset < end_offset:
idxs = img_token_locations[start_offset:end_offset]
attention_mask[batch_index, idxs] = 0
current_token_idx += num_tokens_for_image
return input_embeds, attention_mask
def get_fused_image_text_embedding_from_tensor_images(
self,
image_tensors_batch: Sequence[Sequence[Image.Image | torch.Tensor]],
@@ -429,27 +253,46 @@ class InternVL3Embedder(nn.Module):
if pixel_values.shape[0] == 0:
logger.warning("InternVL3 received an empty image batch after preprocessing.")
hidden_size = getattr(self.model.config, "hidden_size", None)
if hidden_size is None and hasattr(self.model.language_model, "config"):
hidden_size = getattr(self.model.language_model.config, "hidden_size", None)
if hidden_size is None:
hidden_size = getattr(self.model.config.text_config, "hidden_size", None)
if hidden_size is None:
raise RuntimeError("Unable to infer hidden size for empty InternVL3 batch.")
empty = torch.empty(0, hidden_size, device=self.device, dtype=torch.float32)
return empty
prompts = self._build_multimodal_prompts(batch_num_tiles_list, text_prompts)
vit_embeds = self.model.extract_feature(pixel_values)
inputs_embeds, attention_mask = self._prepare_and_fuse_embeddings(
prompts,
vit_embeds,
image_masks.to(device=self.device),
batch_num_tiles_list,
)
outputs = self.model.language_model(
inputs_embeds=inputs_embeds,
model_inputs = self.tokenizer(
list(prompts),
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.max_text_length,
).to(self.device)
input_ids = model_inputs["input_ids"]
attention_mask = model_inputs["attention_mask"]
# Zero out attention for absent images
img_token_mask = input_ids == self.img_context_token_id
tokens_per_tile = self.num_image_token
for batch_index in range(input_ids.shape[0]):
current_token_idx = 0
img_token_locations = torch.where(img_token_mask[batch_index])[0]
for image_index, num_tiles in enumerate(batch_num_tiles_list[batch_index]):
num_tokens_for_image = num_tiles * tokens_per_tile
if not bool(image_masks[batch_index, image_index].item()):
start_offset = current_token_idx
end_offset = min(current_token_idx + num_tokens_for_image, len(img_token_locations))
if start_offset < end_offset:
idxs = img_token_locations[start_offset:end_offset]
attention_mask[batch_index, idxs] = 0
current_token_idx += num_tokens_for_image
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
return_dict=True,
)
fused_hidden = outputs.hidden_states[-1].to(torch.float32)
@@ -458,3 +301,11 @@ class InternVL3Embedder(nn.Module):
@property
def device(self) -> torch.device:
return next(self.model.parameters()).device
def _flash_attn_available() -> bool:
try:
import flash_attn # noqa: F401
except ModuleNotFoundError:
return False
return True