diff --git a/src/lerobot/policies/evo1/evo1_model.py b/src/lerobot/policies/evo1/evo1_model.py index a9637eda0..c4f75af86 100644 --- a/src/lerobot/policies/evo1/evo1_model.py +++ b/src/lerobot/policies/evo1/evo1_model.py @@ -14,12 +14,10 @@ from __future__ import annotations -from collections.abc import Sequence from typing import Any import torch import torch.nn as nn -from PIL import Image from .flow_matching import FlowmatchingActionHead from .internvl3_embedder import InternVL3Embedder @@ -73,22 +71,25 @@ class EVO1(nn.Module): self.per_action_dim = per_action_dim self.action_head = FlowmatchingActionHead(config=config).to(self._device) - def _normalize_image_batches( + def get_vl_embeddings( self, - images: Sequence[Image.Image | torch.Tensor] | Sequence[Sequence[Image.Image | torch.Tensor]], - prompt: str | list[str] | None, + images: list[torch.Tensor], image_mask: torch.Tensor, - ) -> tuple[list[list[Image.Image | torch.Tensor]], list[str], torch.Tensor]: + prompt: str | list[str] | None = None, + return_cls_only: bool | None = None, + ) -> torch.Tensor: + """Fused VL embeddings from per-camera image batches. + + Args: + images: list of per-camera tensors, each shaped ``(B, C, H, W)`` with values in ``[0, 1]``. + image_mask: bool tensor ``(B, max_views)`` marking present views. + """ + if return_cls_only is None: + return_cls_only = self.return_cls_only if not images: raise ValueError("EVO1 expects at least one image per sample.") - first = images[0] - if isinstance(first, (Image.Image, torch.Tensor)): - image_batches = [list(images)] # type: ignore[arg-type] - else: - image_batches = [list(sample) for sample in images] # type: ignore[arg-type] - - batch_size = len(image_batches) + batch_size = images[0].shape[0] if prompt is None: prompts = [""] * batch_size elif isinstance(prompt, str): @@ -107,21 +108,8 @@ class EVO1(nn.Module): f"image_mask batch size {image_mask.shape[0]} does not match image batch size {batch_size}" ) - return image_batches, prompts, image_mask - - def get_vl_embeddings( - self, - images: list[Image.Image | torch.Tensor] | list[list[Image.Image | torch.Tensor]], - image_mask: torch.Tensor, - prompt: str | list[str] | None = None, - return_cls_only: bool | None = None, - ) -> torch.Tensor: - if return_cls_only is None: - return_cls_only = self.return_cls_only - - image_batches, prompts, image_mask = self._normalize_image_batches(images, prompt, image_mask) - return self.embedder.get_fused_image_text_embedding_from_tensor_images( - image_tensors_batch=image_batches, + return self.embedder.get_fused_image_text_embedding_batched( + camera_images=images, image_masks=image_mask, text_prompts=prompts, return_cls_only=return_cls_only, diff --git a/src/lerobot/policies/evo1/internvl3_embedder.py b/src/lerobot/policies/evo1/internvl3_embedder.py index ca9abbbeb..6040cf759 100644 --- a/src/lerobot/policies/evo1/internvl3_embedder.py +++ b/src/lerobot/policies/evo1/internvl3_embedder.py @@ -14,7 +14,6 @@ from __future__ import annotations -import functools import logging from collections.abc import Sequence from typing import TYPE_CHECKING @@ -22,8 +21,7 @@ from typing import TYPE_CHECKING import torch import torch.nn as nn import torchvision.transforms.functional as tvf -from PIL import Image -from torchvision.transforms.functional import to_pil_image +from torchvision.transforms.functional import InterpolationMode from lerobot.utils.import_utils import _transformers_available, require_package @@ -42,51 +40,64 @@ IMG_END_TOKEN = "" # nosec B105 logger = logging.getLogger(__name__) -@functools.lru_cache(maxsize=10000) -def get_target_aspect_ratio(orig_width: int, orig_height: int, image_size: int, min_num: int, max_num: int): - aspect_ratio = orig_width / orig_height - target_ratios = { - (i, j) - for n in range(min_num, max_num + 1) - for i in range(1, n + 1) - for j in range(1, n + 1) - if i * j <= max_num and i * j >= min_num - } - target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) +def _batched_resize_01(images: torch.Tensor, image_size: int) -> torch.Tensor: + """Resize a batch of ``[0, 1]`` images to ``(image_size, image_size)`` on-device. - best_ratio_diff = float("inf") - best_ratio = (1, 1) - area = orig_width * orig_height - for ratio in target_ratios: - target_ar = ratio[0] / ratio[1] - diff = abs(aspect_ratio - target_ar) - if diff < best_ratio_diff: - best_ratio_diff = diff - best_ratio = ratio - elif diff == best_ratio_diff and area > 0.5 * image_size**2 * ratio[0] * ratio[1]: - best_ratio = ratio - return best_ratio + Numerically mirrors InternVL3's per-image PIL preprocessing + (``to_pil_image`` -> ``Image.resize`` -> ``to_tensor``): the float input is quantized to uint8 + exactly as ``to_pil_image`` does, then resized with bicubic interpolation and antialiasing, + which matches PIL's default resampler. This runs as a single batched op instead of a per-image + Python loop with a GPU->CPU->PIL->GPU round-trip. + + Args: + images: float tensor of shape ``(N, C, H, W)`` with values in ``[0, 1]``. + + Returns: + float32 tensor of shape ``(N, C, image_size, image_size)`` with values in ``[0, 1]``. + """ + # to_pil_image() quantizes float [0, 1] to uint8 (x * 255, truncated); replicate that so the + # bicubic resample sees the same integer pixels PIL would. + pixels_u8 = (images * 255.0).clamp(0, 255).to(torch.uint8) + resized = tvf.resize( + pixels_u8, [image_size, image_size], interpolation=InterpolationMode.BICUBIC, antialias=True + ) + return resized.to(torch.float32) / 255.0 -def dynamic_preprocess(image, min_num=1, max_num=1, image_size=448, use_thumbnail=False): - orig_width, orig_height = image.size - ratio_w, ratio_h = get_target_aspect_ratio(orig_width, orig_height, image_size, min_num, max_num) - target_width = image_size * ratio_w - target_height = image_size * ratio_h - blocks = ratio_w * ratio_h - resized_img = image.resize((target_width, target_height)) - processed_images = [] - for i in range(blocks): - box = ( - (i % (target_width // image_size)) * image_size, - (i // (target_width // image_size)) * image_size, - ((i % (target_width // image_size)) + 1) * image_size, - ((i // (target_width // image_size)) + 1) * image_size, - ) - processed_images.append(resized_img.crop(box)) - if use_thumbnail and len(processed_images) != 1: - processed_images.append(image.resize((image_size, image_size))) - return processed_images +def _batched_pixel_values( + camera_images: Sequence[torch.Tensor], + max_views: int, + image_size: int, + mean: torch.Tensor, + std: torch.Tensor, + dtype: torch.dtype, + device: torch.device | str, +) -> torch.Tensor: + """Build InternVL3 ``pixel_values`` from per-camera ``[0, 1]`` image batches without leaving the device. + + Equivalent to running the old per-sample/per-image PIL path (resize -> to_tensor -> ImageNet + normalize, a single tile per image) but batched across the whole minibatch. Absent views (fewer + cameras than ``max_views``) are zero-padded to reproduce the previous ``torch.zeros_like`` + padding; those views are masked out downstream via the attention mask. + + Returns: + ``pixel_values`` of shape ``(B * max_views, C, image_size, image_size)``, ordered row-major + over ``(sample, view)`` to match the old preprocessing. + """ + resized: list[torch.Tensor] = [] + for image in camera_images: + resized.append(_batched_resize_01(image.to(device=device), image_size).to(dtype)) + + batch_size = resized[0].shape[0] + channels = resized[0].shape[1] + while len(resized) < max_views: + resized.append(torch.zeros(batch_size, channels, image_size, image_size, dtype=dtype, device=device)) + + stacked = torch.stack(resized[:max_views], dim=1) # (B, V, C, H, W) + mean = mean.to(device=device, dtype=dtype).view(1, 1, -1, 1, 1) + std = std.to(device=device, dtype=dtype).view(1, 1, -1, 1, 1) + normalized = (stacked - mean) / std + return normalized.reshape(batch_size * max_views, channels, image_size, image_size) class InternVL3Embedder(nn.Module): @@ -191,42 +202,6 @@ class InternVL3Embedder(nn.Module): "Requested gradient checkpointing, but model does not expose checkpointing controls." ) - def _preprocess_single_image(self, image: Image.Image | torch.Tensor) -> torch.Tensor: - if isinstance(image, torch.Tensor): - pil_image = to_pil_image(image.detach().cpu()) - else: - pil_image = image.convert("RGB") - tiles = dynamic_preprocess(pil_image, image_size=self.image_size) - tile_tensors = torch.stack([tvf.to_tensor(tile) for tile in tiles]).to( - device=self.device, dtype=torch.bfloat16 - ) - mean = torch.tensor(IMAGENET_MEAN, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1) - std = torch.tensor(IMAGENET_STD, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1) - return (tile_tensors - mean) / std - - def _preprocess_images( - self, - image_tensors_batch: Sequence[Sequence[Image.Image | torch.Tensor]], - ) -> tuple[torch.Tensor, list[list[int]]]: - pixel_values_list = [] - batch_num_tiles_list: list[list[int]] = [] - - for image_tensors in image_tensors_batch: - num_tiles_list: list[int] = [] - for image in image_tensors: - tiles = self._preprocess_single_image(image) - pixel_values_list.append(tiles) - num_tiles_list.append(int(tiles.shape[0])) - batch_num_tiles_list.append(num_tiles_list) - - if pixel_values_list: - pixel_values = torch.cat(pixel_values_list, dim=0) - else: - pixel_values = torch.empty( - 0, 3, self.image_size, self.image_size, dtype=torch.bfloat16, device=self.device - ) - return pixel_values, batch_num_tiles_list - def _build_multimodal_prompts( self, batch_num_tiles_list: list[list[int]], @@ -242,14 +217,70 @@ class InternVL3Embedder(nn.Module): prompts.append("".join(prompt_segments) + text_prompt.strip()) return prompts - def get_fused_image_text_embedding_from_tensor_images( + def get_fused_image_text_embedding_batched( self, - image_tensors_batch: Sequence[Sequence[Image.Image | torch.Tensor]], + camera_images: Sequence[torch.Tensor], image_masks: torch.Tensor, text_prompts: Sequence[str], return_cls_only: bool = True, ): - pixel_values, batch_num_tiles_list = self._preprocess_images(image_tensors_batch) + """Fused VL embedding from per-camera ``[0, 1]`` image batches (no PIL, no host round-trip). + + Args: + camera_images: list of per-camera tensors, each shaped ``(B, C, H, W)`` in ``[0, 1]``. + image_masks: bool tensor ``(B, max_views)`` marking present views. + """ + max_views = int(image_masks.shape[1]) + batch_size = int(image_masks.shape[0]) + mean = torch.tensor(IMAGENET_MEAN, device=self.device, dtype=torch.bfloat16) + std = torch.tensor(IMAGENET_STD, device=self.device, dtype=torch.bfloat16) + pixel_values = _batched_pixel_values( + camera_images, max_views, self.image_size, mean, std, torch.bfloat16, self.device + ) + # InternVL3 preprocessing uses a single tile per image (max_num=1). + batch_num_tiles_list = [[1] * max_views for _ in range(batch_size)] + return self._forward_vlm( + pixel_values, batch_num_tiles_list, image_masks, text_prompts, return_cls_only + ) + + def _mask_absent_image_tokens( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + image_masks: torch.Tensor, + batch_num_tiles_list: list[list[int]], + ) -> torch.Tensor: + """Zero attention over the image-context tokens of absent views, fully vectorized. + + Reproduces the previous per-sample/per-image Python loop, which called ``.item()`` once per + image and forced a device->host sync each time, without any host<->device synchronization. + """ + # A single tile per image (max_num=1), so every image occupies the same number of + # context tokens. + tiles_per_image = ( + batch_num_tiles_list[0][0] if batch_num_tiles_list and batch_num_tiles_list[0] else 1 + ) + tokens_per_image = self.num_image_token * tiles_per_image + + image_masks = image_masks.to(device=input_ids.device).bool() + img_token_mask = input_ids == self.img_context_token_id # (B, L) + # keep[b, k] tells whether the k-th image-context token (ordered view0, view1, ...) survives. + per_token_keep = image_masks.repeat_interleave(tokens_per_image, dim=1) # (B, V * tokens_per_image) + # Rank each context token by its running position among the row's context tokens. + ctx_index = img_token_mask.to(torch.long).cumsum(dim=1) - 1 + ctx_index = ctx_index.clamp(min=0, max=per_token_keep.shape[1] - 1) + keep_here = torch.gather(per_token_keep, 1, ctx_index) # (B, L) + drop = img_token_mask & ~keep_here + return attention_mask.masked_fill(drop, 0) + + def _forward_vlm( + self, + pixel_values: torch.Tensor, + batch_num_tiles_list: list[list[int]], + image_masks: torch.Tensor, + text_prompts: Sequence[str], + return_cls_only: bool, + ): if pixel_values.shape[0] == 0: logger.warning("InternVL3 received an empty image batch after preprocessing.") hidden_size = getattr(self.model.config, "hidden_size", None) @@ -257,8 +288,7 @@ class InternVL3Embedder(nn.Module): hidden_size = getattr(self.model.config.text_config, "hidden_size", None) if hidden_size is None: raise RuntimeError("Unable to infer hidden size for empty InternVL3 batch.") - empty = torch.empty(0, hidden_size, device=self.device, dtype=torch.float32) - return empty + return torch.empty(0, hidden_size, device=self.device, dtype=torch.float32) prompts = self._build_multimodal_prompts(batch_num_tiles_list, text_prompts) @@ -270,23 +300,9 @@ class InternVL3Embedder(nn.Module): max_length=self.max_text_length, ).to(self.device) input_ids = model_inputs["input_ids"] - attention_mask = model_inputs["attention_mask"] - - # Zero out attention for absent images - img_token_mask = input_ids == self.img_context_token_id - tokens_per_tile = self.num_image_token - for batch_index in range(input_ids.shape[0]): - current_token_idx = 0 - img_token_locations = torch.where(img_token_mask[batch_index])[0] - for image_index, num_tiles in enumerate(batch_num_tiles_list[batch_index]): - num_tokens_for_image = num_tiles * tokens_per_tile - if not bool(image_masks[batch_index, image_index].item()): - start_offset = current_token_idx - end_offset = min(current_token_idx + num_tokens_for_image, len(img_token_locations)) - if start_offset < end_offset: - idxs = img_token_locations[start_offset:end_offset] - attention_mask[batch_index, idxs] = 0 - current_token_idx += num_tokens_for_image + attention_mask = self._mask_absent_image_tokens( + input_ids, model_inputs["attention_mask"], image_masks, batch_num_tiles_list + ) outputs = self.model( input_ids=input_ids, diff --git a/src/lerobot/policies/evo1/modeling_evo1.py b/src/lerobot/policies/evo1/modeling_evo1.py index 3606d63ca..ade94bc53 100644 --- a/src/lerobot/policies/evo1/modeling_evo1.py +++ b/src/lerobot/policies/evo1/modeling_evo1.py @@ -318,17 +318,20 @@ class EVO1Policy(PreTrainedPolicy): self._keep_frozen_embedder_eval() return self - def _collect_image_batches(self, batch: dict[str, Tensor]) -> tuple[list[list[Tensor]], Tensor]: + def _collect_image_batches(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], Tensor]: camera_keys = self._camera_keys or sorted(key for key in batch if key.startswith(f"{OBS_IMAGES}.")) if not camera_keys: raise ValueError("EVO1 requires at least one visual observation feature.") + camera_keys = list(camera_keys)[: self.config.max_views] - # Normalize each camera tensor to (B, C, H, W) up-front so that batch_size is read - # from a real batch dim and not from C in the unbatched (C, H, W) case. - normalized: dict[str, Tensor] = {} - for camera_key in camera_keys[: self.config.max_views]: + # Keep each present camera as a batched (B, C, H, W) tensor on its current (GPU) device. + # Resizing/normalization and zero-padding of absent views happen batched inside the + # embedder, so images never leave the device here (no per-sample .cpu() round-trip). + camera_images: list[Tensor] = [] + for camera_key in camera_keys: image = batch[camera_key] if image.dim() == 3: + # Promote an unbatched (C, H, W) frame so batch_size is read from a real batch dim. image = image.unsqueeze(0) elif image.dim() == 5: image = image[:, -1] @@ -336,24 +339,16 @@ class EVO1Policy(PreTrainedPolicy): raise ValueError( f"Unsupported image tensor shape for EVO1: key={camera_key} shape={tuple(image.shape)}" ) - normalized[camera_key] = image + camera_images.append(image) - batch_size = normalized[camera_keys[0]].shape[0] - image_batches: list[list[Tensor]] = [] - image_masks = torch.zeros(batch_size, self.config.max_views, dtype=torch.bool) + batch_size = camera_images[0].shape[0] + n_present = len(camera_images) + image_masks = torch.zeros( + batch_size, self.config.max_views, dtype=torch.bool, device=camera_images[0].device + ) + image_masks[:, :n_present] = True - for batch_index in range(batch_size): - sample_images: list[Tensor] = [] - for camera_key in camera_keys[: self.config.max_views]: - sample_images.append(normalized[camera_key][batch_index].detach().cpu()) - if not sample_images: - raise ValueError("EVO1 received a batch without any image tensor.") - while len(sample_images) < self.config.max_views: - sample_images.append(torch.zeros_like(sample_images[0])) - image_batches.append(sample_images[: self.config.max_views]) - image_masks[batch_index, : min(len(camera_keys), self.config.max_views)] = True - - return image_batches, image_masks + return camera_images, image_masks def _compute_fused_tokens( self, diff --git a/tests/policies/evo1/test_evo1.py b/tests/policies/evo1/test_evo1.py index 9ab531f63..81070f92a 100644 --- a/tests/policies/evo1/test_evo1.py +++ b/tests/policies/evo1/test_evo1.py @@ -24,6 +24,11 @@ import lerobot.policies.evo1.modeling_evo1 as modeling_evo1 from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.policies.evo1.configuration_evo1 import Evo1Config from lerobot.policies.evo1.flow_matching import FlowmatchingActionHead +from lerobot.policies.evo1.internvl3_embedder import ( + IMAGENET_MEAN, + IMAGENET_STD, + _batched_pixel_values, +) from lerobot.policies.evo1.processor_evo1 import ( Evo1ActionProcessorStep, Evo1PadActionProcessorStep, @@ -60,7 +65,9 @@ class DummyEVO1(nn.Module): self.get_vl_embeddings_calls += 1 self.grad_enabled_calls.append(torch.is_grad_enabled()) self.embedder_training_calls.append(self.embedder.training) - return torch.ones(len(images), 4, EMBED_DIM, requires_grad=torch.is_grad_enabled()) + # images is a list of per-camera (B, C, H, W) tensors, so the batch dim is images[0].shape[0]. + batch_size = images[0].shape[0] + return torch.ones(batch_size, 4, EMBED_DIM, requires_grad=torch.is_grad_enabled()) def forward( self, @@ -397,10 +404,12 @@ def test_collect_image_batches_handles_unbatched_chw(monkeypatch): f"{OBS_IMAGES}.front": torch.rand(3, 16, 16), } - image_batches, image_masks = policy._collect_image_batches(batch) + camera_images, image_masks = policy._collect_image_batches(batch) - assert len(image_batches) == 1 - assert len(image_batches[0]) == policy.config.max_views + # One present camera, returned as a batched (B, C, H, W) tensor with the unbatched CHW frame + # promoted to batch_size=1 (not read as batch_size=C). + assert len(camera_images) == 1 + assert camera_images[0].shape == (1, 3, 16, 16) assert image_masks.tolist() == [[True, False]] @@ -447,3 +456,28 @@ def test_flowmatching_dict_config_enables_state_encoder_for_horizon_one(): assert pred_velocity.shape == (2, ACTION_DIM) assert noise.shape == (2, 1, ACTION_DIM) + + +def test_evo1_batched_pixel_values_shape_and_zero_padding(): + torch.manual_seed(0) + batch_size, image_size, max_views = 2, 448, 3 + camera_images = [torch.rand(batch_size, 3, 40, 50)] # a single present camera + mean = torch.tensor(IMAGENET_MEAN) + std = torch.tensor(IMAGENET_STD) + + pixel_values = _batched_pixel_values( + camera_images, max_views, image_size, mean, std, torch.float32, torch.device("cpu") + ) + + assert pixel_values.shape == (batch_size * max_views, 3, image_size, image_size) + grouped = pixel_values.reshape(batch_size, max_views, 3, image_size, image_size) + # Absent views (indices 1, 2) are zero images normalized to -mean/std, matching the old padding. + expected_pad = (-mean / std).view(1, 3, 1, 1) + for view in (1, 2): + assert torch.allclose( + grouped[:, view], expected_pad.expand(batch_size, 3, image_size, image_size), atol=1e-5 + ) + # The present view is genuinely different from the constant pad value. + assert not torch.allclose( + grouped[:, 0], expected_pad.expand(batch_size, 3, image_size, image_size), atol=1e-3 + )