mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-02 23:57:24 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 67213e91d5 | |
| | 1df19ae468 | |
@@ -21,9 +21,9 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms.functional as tvf
|
||||
from PIL import Image
|
||||
from torchvision.transforms.functional import to_pil_image
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
@@ -191,40 +191,65 @@ class InternVL3Embedder(nn.Module):
|
||||
"Requested gradient checkpointing, but model does not expose checkpointing controls."
|
||||
)
|
||||
|
||||
def _preprocess_single_image(self, image: Image.Image | torch.Tensor) -> torch.Tensor:
    """Turn one image into a batch of ImageNet-normalized bfloat16 tiles.

    Tensor inputs are detached, copied to the CPU and converted to a PIL image;
    PIL inputs are converted to RGB. The tiles returned by ``dynamic_preprocess``
    are stacked along dim 0, moved to ``self.device`` as bfloat16, and normalized
    with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.
    """
    as_pil = to_pil_image(image.detach().cpu()) if isinstance(image, torch.Tensor) else image.convert("RGB")

    stacked = torch.stack(
        [tvf.to_tensor(tile) for tile in dynamic_preprocess(as_pil, image_size=self.image_size)]
    )
    stacked = stacked.to(device=self.device, dtype=torch.bfloat16)

    def _channel_stat(values) -> torch.Tensor:
        # Shape (1, 3, 1, 1) so the statistic broadcasts over (tiles, C, H, W).
        return torch.tensor(values, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1)

    return (stacked - _channel_stat(IMAGENET_MEAN)) / _channel_stat(IMAGENET_STD)
|
||||
def _to_chw_float01(self, image: Image.Image | torch.Tensor) -> torch.Tensor:
|
||||
"""Return a (3, H, W) float tensor in [0, 1], staying on the source device."""
|
||||
if not isinstance(image, torch.Tensor):
|
||||
# PIL only reaches this path on unusual callers; convert once and continue as tensor.
|
||||
image = tvf.to_tensor(image.convert("RGB"))
|
||||
image = image.detach()
|
||||
if image.dim() == 2:
|
||||
image = image.unsqueeze(0)
|
||||
if image.shape[0] == 1:
|
||||
image = image.expand(3, *image.shape[1:])
|
||||
if torch.is_floating_point(image):
|
||||
return image.float()
|
||||
# Integer tensors (e.g. uint8 in [0, 255]) are scaled to [0, 1] to match to_tensor().
|
||||
return image.float() / 255.0
|
||||
|
||||
def _preprocess_images(
|
||||
self,
|
||||
image_tensors_batch: Sequence[Sequence[Image.Image | torch.Tensor]],
|
||||
) -> tuple[torch.Tensor, list[list[int]]]:
|
||||
pixel_values_list = []
|
||||
# Every image is a single tile, so the per-image tile count is always 1.
|
||||
flat_images: list[torch.Tensor] = []
|
||||
batch_num_tiles_list: list[list[int]] = []
|
||||
|
||||
for image_tensors in image_tensors_batch:
|
||||
num_tiles_list: list[int] = []
|
||||
batch_num_tiles_list.append([1] * len(image_tensors))
|
||||
for image in image_tensors:
|
||||
tiles = self._preprocess_single_image(image)
|
||||
pixel_values_list.append(tiles)
|
||||
num_tiles_list.append(int(tiles.shape[0]))
|
||||
batch_num_tiles_list.append(num_tiles_list)
|
||||
flat_images.append(self._to_chw_float01(image))
|
||||
|
||||
if pixel_values_list:
|
||||
pixel_values = torch.cat(pixel_values_list, dim=0)
|
||||
else:
|
||||
if not flat_images:
|
||||
pixel_values = torch.empty(
|
||||
0, 3, self.image_size, self.image_size, dtype=torch.bfloat16, device=self.device
|
||||
)
|
||||
return pixel_values, batch_num_tiles_list
|
||||
|
||||
size = (self.image_size, self.image_size)
|
||||
# Resize on the GPU in a single batched kernel instead of converting each image to PIL on
|
||||
# the CPU. Cameras with matching resolutions stack into one interpolate call; differing
|
||||
# resolutions fall back to a per-image interpolate that still runs on the GPU.
|
||||
if len({tuple(img.shape) for img in flat_images}) == 1:
|
||||
images = torch.stack(flat_images, dim=0).to(self.device, non_blocking=True)
|
||||
resized = F.interpolate(images, size=size, mode="bicubic", antialias=True)
|
||||
else:
|
||||
resized = torch.cat(
|
||||
[
|
||||
F.interpolate(
|
||||
img.unsqueeze(0).to(self.device, non_blocking=True),
|
||||
size=size,
|
||||
mode="bicubic",
|
||||
antialias=True,
|
||||
)
|
||||
for img in flat_images
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# bicubic can overshoot [0, 1]; clamp to keep the input domain consistent before scaling.
|
||||
resized = resized.clamp_(0.0, 1.0).to(dtype=torch.bfloat16)
|
||||
mean = torch.tensor(IMAGENET_MEAN, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1)
|
||||
std = torch.tensor(IMAGENET_STD, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1)
|
||||
pixel_values = (resized - mean) / std
|
||||
return pixel_values, batch_num_tiles_list
|
||||
|
||||
def _build_multimodal_prompts(
|
||||
|
||||
@@ -342,10 +342,18 @@ class EVO1Policy(PreTrainedPolicy):
|
||||
image_batches: list[list[Tensor]] = []
|
||||
image_masks = torch.zeros(batch_size, self.config.max_views, dtype=torch.bool)
|
||||
|
||||
# Detach each camera tensor once for the whole batch and keep it on-device. The embedder
|
||||
# resizes on the GPU, so there is no host round-trip; indexing per sample below is then a
|
||||
# cheap view with no copy and no device sync.
|
||||
detached_images: dict[str, Tensor] = {
|
||||
camera_key: normalized[camera_key].detach()
|
||||
for camera_key in camera_keys[: self.config.max_views]
|
||||
}
|
||||
|
||||
for batch_index in range(batch_size):
|
||||
sample_images: list[Tensor] = []
|
||||
for camera_key in camera_keys[: self.config.max_views]:
|
||||
sample_images.append(normalized[camera_key][batch_index].detach().cpu())
|
||||
sample_images.append(detached_images[camera_key][batch_index])
|
||||
if not sample_images:
|
||||
raise ValueError("EVO1 received a batch without any image tensor.")
|
||||
while len(sample_images) < self.config.max_views:
|
||||
|
||||
Reference in New Issue
Block a user