Compare commits

...

2 Commits

Author SHA1 Message Date
Martino Russi 67213e91d5 remove PIL from forward 2026-07-01 18:06:07 +00:00
Martino Russi 1df19ae468 only call .detach().cpu() once per caemra instead of once per image 2026-07-01 17:45:16 +00:00
2 changed files with 57 additions and 24 deletions
+48 -23
View File
@@ -21,9 +21,9 @@ from typing import TYPE_CHECKING
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as tvf
from PIL import Image
from torchvision.transforms.functional import to_pil_image
from lerobot.utils.import_utils import _transformers_available, require_package
@@ -191,40 +191,65 @@ class InternVL3Embedder(nn.Module):
"Requested gradient checkpointing, but model does not expose checkpointing controls."
)
def _preprocess_single_image(self, image: Image.Image | torch.Tensor) -> torch.Tensor:
if isinstance(image, torch.Tensor):
pil_image = to_pil_image(image.detach().cpu())
else:
pil_image = image.convert("RGB")
tiles = dynamic_preprocess(pil_image, image_size=self.image_size)
tile_tensors = torch.stack([tvf.to_tensor(tile) for tile in tiles]).to(
device=self.device, dtype=torch.bfloat16
)
mean = torch.tensor(IMAGENET_MEAN, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1)
std = torch.tensor(IMAGENET_STD, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1)
return (tile_tensors - mean) / std
def _to_chw_float01(self, image: Image.Image | torch.Tensor) -> torch.Tensor:
"""Return a (3, H, W) float tensor in [0, 1], staying on the source device."""
if not isinstance(image, torch.Tensor):
# PIL only reaches this path on unusual callers; convert once and continue as tensor.
image = tvf.to_tensor(image.convert("RGB"))
image = image.detach()
if image.dim() == 2:
image = image.unsqueeze(0)
if image.shape[0] == 1:
image = image.expand(3, *image.shape[1:])
if torch.is_floating_point(image):
return image.float()
# Integer tensors (e.g. uint8 in [0, 255]) are scaled to [0, 1] to match to_tensor().
return image.float() / 255.0
def _preprocess_images(
self,
image_tensors_batch: Sequence[Sequence[Image.Image | torch.Tensor]],
) -> tuple[torch.Tensor, list[list[int]]]:
pixel_values_list = []
# Every image is a single tile, so the per-image tile count is always 1.
flat_images: list[torch.Tensor] = []
batch_num_tiles_list: list[list[int]] = []
for image_tensors in image_tensors_batch:
num_tiles_list: list[int] = []
batch_num_tiles_list.append([1] * len(image_tensors))
for image in image_tensors:
tiles = self._preprocess_single_image(image)
pixel_values_list.append(tiles)
num_tiles_list.append(int(tiles.shape[0]))
batch_num_tiles_list.append(num_tiles_list)
flat_images.append(self._to_chw_float01(image))
if pixel_values_list:
pixel_values = torch.cat(pixel_values_list, dim=0)
else:
if not flat_images:
pixel_values = torch.empty(
0, 3, self.image_size, self.image_size, dtype=torch.bfloat16, device=self.device
)
return pixel_values, batch_num_tiles_list
size = (self.image_size, self.image_size)
# Resize on the GPU in a single batched kernel instead of converting each image to PIL on
# the CPU. Cameras with matching resolutions stack into one interpolate call; differing
# resolutions fall back to a per-image interpolate that still runs on the GPU.
if len({tuple(img.shape) for img in flat_images}) == 1:
images = torch.stack(flat_images, dim=0).to(self.device, non_blocking=True)
resized = F.interpolate(images, size=size, mode="bicubic", antialias=True)
else:
resized = torch.cat(
[
F.interpolate(
img.unsqueeze(0).to(self.device, non_blocking=True),
size=size,
mode="bicubic",
antialias=True,
)
for img in flat_images
],
dim=0,
)
# bicubic can overshoot [0, 1]; clamp to keep the input domain consistent before scaling.
resized = resized.clamp_(0.0, 1.0).to(dtype=torch.bfloat16)
mean = torch.tensor(IMAGENET_MEAN, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1)
std = torch.tensor(IMAGENET_STD, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1)
pixel_values = (resized - mean) / std
return pixel_values, batch_num_tiles_list
def _build_multimodal_prompts(
+9 -1
View File
@@ -342,10 +342,18 @@ class EVO1Policy(PreTrainedPolicy):
image_batches: list[list[Tensor]] = []
image_masks = torch.zeros(batch_size, self.config.max_views, dtype=torch.bool)
# Detach each camera tensor once for the whole batch and keep it on-device. The embedder
# resizes on the GPU, so there is no host round-trip; indexing per sample below is then a
# cheap view with no copy and no device sync.
detached_images: dict[str, Tensor] = {
camera_key: normalized[camera_key].detach()
for camera_key in camera_keys[: self.config.max_views]
}
for batch_index in range(batch_size):
sample_images: list[Tensor] = []
for camera_key in camera_keys[: self.config.max_views]:
sample_images.append(normalized[camera_key][batch_index].detach().cpu())
sample_images.append(detached_images[camera_key][batch_index])
if not sample_images:
raise ValueError("EVO1 received a batch without any image tensor.")
while len(sample_images) < self.config.max_views: