diff --git a/src/lerobot/policies/evo1/README.md b/src/lerobot/policies/evo1/README.md deleted file mode 100644 index 3c6d31c83..000000000 --- a/src/lerobot/policies/evo1/README.md +++ /dev/null @@ -1,18 +0,0 @@ -# EVO1 - -EVO1 is a Vision-Language-Action policy for robot control. The LeRobot -integration uses an InternVL3 vision-language backbone with a flow-matching -action head, and supports staged training through the standard LeRobot policy -APIs. - -The upstream EVO1 project is available at -[MINT-SJTU/Evo-1](https://github.com/MINT-SJTU/Evo-1). - -```bibtex -@misc{evo1, - title = {EVO1}, - author = {{MINT-SJTU}}, - year = {2026}, - howpublished = {\url{https://github.com/MINT-SJTU/Evo-1}}, -} -``` diff --git a/src/lerobot/policies/evo1/README.md b/src/lerobot/policies/evo1/README.md new file mode 120000 index 000000000..6c4284fb9 --- /dev/null +++ b/src/lerobot/policies/evo1/README.md @@ -0,0 +1 @@ +../../../../docs/source/policy_evo1_README.md \ No newline at end of file diff --git a/src/lerobot/policies/evo1/configuration_evo1.py b/src/lerobot/policies/evo1/configuration_evo1.py index b7dd72a95..a9f6ffe38 100644 --- a/src/lerobot/policies/evo1/configuration_evo1.py +++ b/src/lerobot/policies/evo1/configuration_evo1.py @@ -77,7 +77,7 @@ class Evo1Config(PreTrainedConfig): } ) - vlm_model_name: str = "OpenGVLab/InternVL3-1B" + vlm_model_name: str = "OpenGVLab/InternVL3-1B-hf" vlm_num_layers: int | None = 14 vlm_dtype: str = "bfloat16" use_flash_attn: bool = True diff --git a/src/lerobot/policies/evo1/internvl3_embedder.py b/src/lerobot/policies/evo1/internvl3_embedder.py index fa7b6eb7d..ca9abbbeb 100644 --- a/src/lerobot/policies/evo1/internvl3_embedder.py +++ b/src/lerobot/policies/evo1/internvl3_embedder.py @@ -16,14 +16,11 @@ from __future__ import annotations import functools import logging -import types from collections.abc import Sequence -from contextlib import contextmanager from typing import TYPE_CHECKING import torch import torch.nn as nn -import torch.utils.checkpoint import torchvision.transforms.functional as tvf from PIL import Image from torchvision.transforms.functional import to_pil_image @@ -45,88 +42,6 @@ IMG_END_TOKEN = "" # nosec B105 logger = logging.getLogger(__name__) -def _patch_vision_encoder_checkpointing(encoder: nn.Module, use_reentrant: bool) -> None: - for attr_name in ("_gradient_checkpointing_func", "gradient_checkpointing_func"): - original_func = getattr(encoder, attr_name, None) - if not callable(original_func): - continue - patch_attr = f"_evo1_{attr_name}_patch_applied" - if getattr(encoder, patch_attr, False): - encoder.gradient_checkpointing_use_reentrant = use_reentrant - return - - def checkpoint_with_kwargs( - function, *checkpoint_args, _original_func=original_func, **checkpoint_kwargs - ): - checkpoint_kwargs.setdefault("use_reentrant", encoder.gradient_checkpointing_use_reentrant) - return _original_func(function, *checkpoint_args, **checkpoint_kwargs) - - encoder.gradient_checkpointing_use_reentrant = use_reentrant - setattr(encoder, attr_name, checkpoint_with_kwargs) - setattr(encoder, patch_attr, True) - return - - if getattr(encoder, "_evo1_checkpoint_patch_applied", False): - encoder.gradient_checkpointing_use_reentrant = use_reentrant - return - - original_forward = encoder.forward - - def forward_with_checkpoint_kwargs(self, *args, **kwargs): - original_checkpoint = torch.utils.checkpoint.checkpoint - - def checkpoint(function, *checkpoint_args, **checkpoint_kwargs): - checkpoint_kwargs.setdefault("use_reentrant", self.gradient_checkpointing_use_reentrant) - return original_checkpoint(function, *checkpoint_args, **checkpoint_kwargs) - - # Some InternVL3 remote-code versions call torch.utils.checkpoint.checkpoint - # directly and do not expose a per-encoder checkpoint function to patch. - # Keep this compatibility fallback scoped to encoder.forward and restore it. - torch.utils.checkpoint.checkpoint = checkpoint - try: - return original_forward(*args, **kwargs) - finally: - torch.utils.checkpoint.checkpoint = original_checkpoint - - encoder.gradient_checkpointing_use_reentrant = use_reentrant - encoder.forward = types.MethodType(forward_with_checkpoint_kwargs, encoder) - encoder._evo1_checkpoint_patch_applied = True - - -def flash_attn_is_available() -> bool: - try: - import flash_attn # noqa: F401 - except ModuleNotFoundError: - return False - return True - - -@contextmanager -def _internvl_transformers5_load_compatibility(): - from transformers.modeling_utils import PreTrainedModel - - original_linspace = torch.linspace - original_mark_tied = PreTrainedModel.mark_tied_weights_as_initialized - - def linspace(*args, **kwargs): - if kwargs.get("device") is None: - kwargs["device"] = torch.device("cpu") - return original_linspace(*args, **kwargs) - - def mark_tied_weights_as_initialized(self, loading_info): - if not hasattr(self, "all_tied_weights_keys"): - self.all_tied_weights_keys = {} - return original_mark_tied(self, loading_info) - - torch.linspace = linspace - PreTrainedModel.mark_tied_weights_as_initialized = mark_tied_weights_as_initialized - try: - yield - finally: - torch.linspace = original_linspace - PreTrainedModel.mark_tied_weights_as_initialized = original_mark_tied - - @functools.lru_cache(maxsize=10000) def get_target_aspect_ratio(orig_width: int, orig_height: int, image_size: int, min_num: int, max_num: int): aspect_ratio = orig_width / orig_height @@ -175,9 +90,11 @@ def dynamic_preprocess(image, min_num=1, max_num=1, image_size=448, use_thumbnai class InternVL3Embedder(nn.Module): + """Vision-language embedder using the native HF InternVL3 model (no trust_remote_code).""" + def __init__( self, - model_name="OpenGVLab/InternVL3-1B", + model_name="OpenGVLab/InternVL3-1B-hf", image_size=448, device="cuda", num_language_layers: int | None = 14, @@ -196,43 +113,31 @@ class InternVL3Embedder(nn.Module): require_package("transformers", extra="evo1") - self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False) + self.tokenizer = AutoTokenizer.from_pretrained(model_name) if isinstance(model_dtype, str): try: model_dtype = getattr(torch, model_dtype) except AttributeError as exc: raise ValueError(f"Unsupported EVO1 vlm_dtype '{model_dtype}'") from exc - resolved_use_flash_attn = bool(use_flash_attn and flash_attn_is_available()) - if use_flash_attn and not resolved_use_flash_attn: - logger.warning("flash_attn is not installed. Falling back to standard attention.") + attn_implementation = "flash_attention_2" if (use_flash_attn and _flash_attn_available()) else "eager" + if use_flash_attn and attn_implementation == "eager": + logger.warning("flash_attn is not installed. Falling back to eager attention.") - # InternVL3 remote code predates Transformers 5 post-init conventions: - # it computes stochastic-depth scalars via torch.linspace(...).item() - # while Transformers initializes under torch.device("meta"), and it - # does not populate all_tied_weights_keys before loading finalization. - with _internvl_transformers5_load_compatibility(): - self.model = AutoModel.from_pretrained( - model_name, - torch_dtype=model_dtype, - trust_remote_code=True, - use_flash_attn=resolved_use_flash_attn, - low_cpu_mem_usage=True, - _fast_init=False, - ).to(self._requested_device) + self.model = AutoModel.from_pretrained( + model_name, + torch_dtype=model_dtype, + attn_implementation=attn_implementation, + low_cpu_mem_usage=True, + ).to(self._requested_device) - if hasattr(self.model.language_model, "model"): - layers = self.model.language_model.model.layers - else: - layers = self.model.language_model.layers + self.num_image_token = self.model.config.image_seq_length + + # Truncate language model to the requested number of layers + layers = self.model.language_model.layers if self.num_language_layers is not None: layers = layers[: self.num_language_layers] - - if hasattr(self.model.language_model, "model"): - self.model.language_model.model.layers = torch.nn.ModuleList(layers) - else: - self.model.language_model.layers = torch.nn.ModuleList(layers) - self.model.language_model.lm_head = torch.nn.Identity() + self.model.language_model.layers = torch.nn.ModuleList(layers) self._configure_memory_features() self.img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN) @@ -241,20 +146,12 @@ class InternVL3Embedder(nn.Module): checkpoint_kwargs = {"use_reentrant": self.gradient_checkpointing_use_reentrant} if not self.enable_gradient_checkpointing: - if hasattr(self.model, "vision_model") and hasattr(self.model.vision_model, "encoder"): - self.model.vision_model.encoder.gradient_checkpointing = False - language_model = getattr(self.model, "language_model", None) - if language_model is not None: - if hasattr(language_model, "gradient_checkpointing_disable"): - language_model.gradient_checkpointing_disable() - elif hasattr(language_model, "gradient_checkpointing"): - language_model.gradient_checkpointing = False - if hasattr(language_model, "model"): - inner = language_model.model - if hasattr(inner, "gradient_checkpointing_disable"): - inner.gradient_checkpointing_disable() - elif hasattr(inner, "gradient_checkpointing"): - inner.gradient_checkpointing = False + language_model = self.model.language_model + if hasattr(language_model, "gradient_checkpointing_disable"): + language_model.gradient_checkpointing_disable() + vision_tower = getattr(self.model, "vision_tower", None) + if vision_tower is not None and hasattr(vision_tower, "encoder"): + vision_tower.encoder.gradient_checkpointing = False return def _enable_ckpt(module: nn.Module | None) -> bool: @@ -273,21 +170,14 @@ class InternVL3Embedder(nn.Module): enabled_any = _enable_ckpt(self.model) - if hasattr(self.model, "vision_model") and hasattr(self.model.vision_model, "encoder"): - encoder = self.model.vision_model.encoder - encoder.gradient_checkpointing = True - _patch_vision_encoder_checkpointing( - encoder, use_reentrant=self.gradient_checkpointing_use_reentrant - ) - enabled_any = True + vision_tower = getattr(self.model, "vision_tower", None) + if vision_tower is not None: + enabled_any = _enable_ckpt(vision_tower) or enabled_any - language_model = getattr(self.model, "language_model", None) - if language_model is not None: - enabled_any = _enable_ckpt(language_model) or enabled_any - if hasattr(language_model, "model"): - enabled_any = _enable_ckpt(language_model.model) or enabled_any - if hasattr(language_model, "config"): - language_model.config.use_cache = False + language_model = self.model.language_model + enabled_any = _enable_ckpt(language_model) or enabled_any + if hasattr(language_model, "config"): + language_model.config.use_cache = False if hasattr(self.model, "config"): self.model.config.use_cache = False @@ -303,8 +193,6 @@ class InternVL3Embedder(nn.Module): def _preprocess_single_image(self, image: Image.Image | torch.Tensor) -> torch.Tensor: if isinstance(image, torch.Tensor): - # Match upstream EVO1/InternVL preprocessing, which converts tensors - # through PIL before tiling and ImageNet normalization. pil_image = to_pil_image(image.detach().cpu()) else: pil_image = image.convert("RGB") @@ -348,76 +236,12 @@ class InternVL3Embedder(nn.Module): for num_tiles_list, text_prompt in zip(batch_num_tiles_list, text_prompts, strict=True): prompt_segments = [] for i, tile_count in enumerate(num_tiles_list): - token_count = self.model.num_image_token * tile_count + token_count = self.num_image_token * tile_count image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * token_count + IMG_END_TOKEN prompt_segments.append(f"Image-{i + 1}: {image_tokens}\n") prompts.append("".join(prompt_segments) + text_prompt.strip()) return prompts - def _prepare_and_fuse_embeddings( - self, - prompts: Sequence[str], - vit_embeds: torch.Tensor, - image_masks: torch.Tensor, - batch_num_tiles_list: list[list[int]], - ) -> tuple[torch.Tensor, torch.Tensor]: - untruncated_ids = self.tokenizer(list(prompts), padding=False, truncation=False)["input_ids"] - true_sequence_length = max((len(ids) for ids in untruncated_ids), default=0) - if true_sequence_length > self.max_text_length: - logger.warning( - "InternVL3 prompt truncated in batch: max_length=%s actual_max_length=%s", - self.max_text_length, - true_sequence_length, - ) - - model_inputs = self.tokenizer( - list(prompts), - return_tensors="pt", - padding="max_length", - truncation=True, - max_length=self.max_text_length, - ).to(self.device) - input_ids = model_inputs["input_ids"] - attention_mask = model_inputs["attention_mask"] - - img_token_mask = input_ids == self.img_context_token_id - input_embeds = self.model.language_model.get_input_embeddings()(input_ids).clone() - - batch_size, _, channels = input_embeds.shape - vit_embeds = vit_embeds.reshape(-1, channels).to(dtype=input_embeds.dtype, device=input_embeds.device) - tokens_per_tile = self.model.num_image_token - actual_vis_tokens_list = img_token_mask.sum(dim=1).tolist() - - vit_idx = 0 - for batch_index in range(batch_size): - expected_vis_tokens = sum(batch_num_tiles_list[batch_index]) * tokens_per_tile - mask_b = img_token_mask[batch_index] - actual_vis_tokens = actual_vis_tokens_list[batch_index] - - item_vit_embeds = vit_embeds[vit_idx : vit_idx + expected_vis_tokens] - vit_idx += expected_vis_tokens - if actual_vis_tokens > 0: - if item_vit_embeds.shape[0] < actual_vis_tokens: - raise ValueError( - f"InternVL3 produced fewer image tokens than expected for sample {batch_index}: " - f"got {item_vit_embeds.shape[0]}, need {actual_vis_tokens}" - ) - input_embeds[batch_index, mask_b] = item_vit_embeds[:actual_vis_tokens] - - current_token_idx = 0 - img_token_locations = torch.where(mask_b)[0] - for image_index, num_tiles in enumerate(batch_num_tiles_list[batch_index]): - num_tokens_for_image = num_tiles * tokens_per_tile - if not bool(image_masks[batch_index, image_index].item()): - start_offset = current_token_idx - end_offset = min(current_token_idx + num_tokens_for_image, len(img_token_locations)) - if start_offset < end_offset: - idxs = img_token_locations[start_offset:end_offset] - attention_mask[batch_index, idxs] = 0 - current_token_idx += num_tokens_for_image - - return input_embeds, attention_mask - def get_fused_image_text_embedding_from_tensor_images( self, image_tensors_batch: Sequence[Sequence[Image.Image | torch.Tensor]], @@ -429,27 +253,46 @@ class InternVL3Embedder(nn.Module): if pixel_values.shape[0] == 0: logger.warning("InternVL3 received an empty image batch after preprocessing.") hidden_size = getattr(self.model.config, "hidden_size", None) - if hidden_size is None and hasattr(self.model.language_model, "config"): - hidden_size = getattr(self.model.language_model.config, "hidden_size", None) + if hidden_size is None: + hidden_size = getattr(self.model.config.text_config, "hidden_size", None) if hidden_size is None: raise RuntimeError("Unable to infer hidden size for empty InternVL3 batch.") empty = torch.empty(0, hidden_size, device=self.device, dtype=torch.float32) return empty prompts = self._build_multimodal_prompts(batch_num_tiles_list, text_prompts) - vit_embeds = self.model.extract_feature(pixel_values) - inputs_embeds, attention_mask = self._prepare_and_fuse_embeddings( - prompts, - vit_embeds, - image_masks.to(device=self.device), - batch_num_tiles_list, - ) - outputs = self.model.language_model( - inputs_embeds=inputs_embeds, + model_inputs = self.tokenizer( + list(prompts), + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.max_text_length, + ).to(self.device) + input_ids = model_inputs["input_ids"] + attention_mask = model_inputs["attention_mask"] + + # Zero out attention for absent images + img_token_mask = input_ids == self.img_context_token_id + tokens_per_tile = self.num_image_token + for batch_index in range(input_ids.shape[0]): + current_token_idx = 0 + img_token_locations = torch.where(img_token_mask[batch_index])[0] + for image_index, num_tiles in enumerate(batch_num_tiles_list[batch_index]): + num_tokens_for_image = num_tiles * tokens_per_tile + if not bool(image_masks[batch_index, image_index].item()): + start_offset = current_token_idx + end_offset = min(current_token_idx + num_tokens_for_image, len(img_token_locations)) + if start_offset < end_offset: + idxs = img_token_locations[start_offset:end_offset] + attention_mask[batch_index, idxs] = 0 + current_token_idx += num_tokens_for_image + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, attention_mask=attention_mask, output_hidden_states=True, - use_cache=False, return_dict=True, ) fused_hidden = outputs.hidden_states[-1].to(torch.float32) @@ -458,3 +301,11 @@ class InternVL3Embedder(nn.Module): @property def device(self) -> torch.device: return next(self.model.parameters()).device + + +def _flash_attn_available() -> bool: + try: + import flash_attn # noqa: F401 + except ModuleNotFoundError: + return False + return True