diff --git a/src/lerobot/policies/evo1/internvl3_embedder.py b/src/lerobot/policies/evo1/internvl3_embedder.py index ca9abbbeb..a10797eea 100644 --- a/src/lerobot/policies/evo1/internvl3_embedder.py +++ b/src/lerobot/policies/evo1/internvl3_embedder.py @@ -21,9 +21,9 @@ from typing import TYPE_CHECKING import torch import torch.nn as nn +import torch.nn.functional as F import torchvision.transforms.functional as tvf from PIL import Image -from torchvision.transforms.functional import to_pil_image from lerobot.utils.import_utils import _transformers_available, require_package @@ -191,40 +191,65 @@ class InternVL3Embedder(nn.Module): "Requested gradient checkpointing, but model does not expose checkpointing controls." ) - def _preprocess_single_image(self, image: Image.Image | torch.Tensor) -> torch.Tensor: - if isinstance(image, torch.Tensor): - pil_image = to_pil_image(image.detach().cpu()) - else: - pil_image = image.convert("RGB") - tiles = dynamic_preprocess(pil_image, image_size=self.image_size) - tile_tensors = torch.stack([tvf.to_tensor(tile) for tile in tiles]).to( - device=self.device, dtype=torch.bfloat16 - ) - mean = torch.tensor(IMAGENET_MEAN, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1) - std = torch.tensor(IMAGENET_STD, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1) - return (tile_tensors - mean) / std + def _to_chw_float01(self, image: Image.Image | torch.Tensor) -> torch.Tensor: + """Return a (3, H, W) float tensor in [0, 1], staying on the source device.""" + if not isinstance(image, torch.Tensor): + # PIL only reaches this path on unusual callers; convert once and continue as tensor. + image = tvf.to_tensor(image.convert("RGB")) + image = image.detach() + if image.dim() == 2: + image = image.unsqueeze(0) + if image.shape[0] == 1: + image = image.expand(3, *image.shape[1:]) + if torch.is_floating_point(image): + return image.float() + # Integer tensors (e.g. uint8 in [0, 255]) are scaled to [0, 1] to match to_tensor(). + return image.float() / 255.0 def _preprocess_images( self, image_tensors_batch: Sequence[Sequence[Image.Image | torch.Tensor]], ) -> tuple[torch.Tensor, list[list[int]]]: - pixel_values_list = [] + # Every image is a single tile, so the per-image tile count is always 1. + flat_images: list[torch.Tensor] = [] batch_num_tiles_list: list[list[int]] = [] - for image_tensors in image_tensors_batch: - num_tiles_list: list[int] = [] + batch_num_tiles_list.append([1] * len(image_tensors)) for image in image_tensors: - tiles = self._preprocess_single_image(image) - pixel_values_list.append(tiles) - num_tiles_list.append(int(tiles.shape[0])) - batch_num_tiles_list.append(num_tiles_list) + flat_images.append(self._to_chw_float01(image)) - if pixel_values_list: - pixel_values = torch.cat(pixel_values_list, dim=0) - else: + if not flat_images: pixel_values = torch.empty( 0, 3, self.image_size, self.image_size, dtype=torch.bfloat16, device=self.device ) + return pixel_values, batch_num_tiles_list + + size = (self.image_size, self.image_size) + # Resize on the GPU in a single batched kernel instead of converting each image to PIL on + # the CPU. Cameras with matching resolutions stack into one interpolate call; differing + # resolutions fall back to a per-image interpolate that still runs on the GPU. + if len({tuple(img.shape) for img in flat_images}) == 1: + images = torch.stack(flat_images, dim=0).to(self.device, non_blocking=True) + resized = F.interpolate(images, size=size, mode="bicubic", antialias=True) + else: + resized = torch.cat( + [ + F.interpolate( + img.unsqueeze(0).to(self.device, non_blocking=True), + size=size, + mode="bicubic", + antialias=True, + ) + for img in flat_images + ], + dim=0, + ) + + # bicubic can overshoot [0, 1]; clamp to keep the input domain consistent before scaling. + resized = resized.clamp_(0.0, 1.0).to(dtype=torch.bfloat16) + mean = torch.tensor(IMAGENET_MEAN, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1) + std = torch.tensor(IMAGENET_STD, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1) + pixel_values = (resized - mean) / std return pixel_values, batch_num_tiles_list def _build_multimodal_prompts( diff --git a/src/lerobot/policies/evo1/modeling_evo1.py b/src/lerobot/policies/evo1/modeling_evo1.py index 1d0de98de..c79d10113 100644 --- a/src/lerobot/policies/evo1/modeling_evo1.py +++ b/src/lerobot/policies/evo1/modeling_evo1.py @@ -342,15 +342,18 @@ class EVO1Policy(PreTrainedPolicy): image_batches: list[list[Tensor]] = [] image_masks = torch.zeros(batch_size, self.config.max_views, dtype=torch.bool) - cpu_images: dict[str, Tensor] = { - camera_key: normalized[camera_key].detach().cpu() + # Detach each camera tensor once for the whole batch and keep it on-device. The embedder + # resizes on the GPU, so there is no host round-trip; indexing per sample below is then a + # cheap view with no copy and no device sync. + detached_images: dict[str, Tensor] = { + camera_key: normalized[camera_key].detach() for camera_key in camera_keys[: self.config.max_views] } for batch_index in range(batch_size): sample_images: list[Tensor] = [] for camera_key in camera_keys[: self.config.max_views]: - sample_images.append(cpu_images[camera_key][batch_index]) + sample_images.append(detached_images[camera_key][batch_index]) if not sample_images: raise ValueError("EVO1 received a batch without any image tensor.") while len(sample_images) < self.config.max_views: