From 5a83db89de75675a4c17ec65013c0f1e2ff5b843 Mon Sep 17 00:00:00 2001 From: Andy Wrenn Date: Sat, 20 Jun 2026 07:02:00 -0700 Subject: [PATCH] Remove PIL fallback from GR00T preprocessing --- src/lerobot/policies/groot/processor_groot.py | 20 +++++++------------ 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/src/lerobot/policies/groot/processor_groot.py b/src/lerobot/policies/groot/processor_groot.py index 2ea10b288..bd711c00b 100644 --- a/src/lerobot/policies/groot/processor_groot.py +++ b/src/lerobot/policies/groot/processor_groot.py @@ -24,7 +24,6 @@ import numpy as np import torch import torchvision.transforms.v2.functional as tv_functional from einops import rearrange -from PIL import Image from torchvision.transforms import InterpolationMode from lerobot.utils.import_utils import _transformers_available, require_package @@ -750,13 +749,13 @@ def _build_n1_7_processor(model_name: str = GROOT_N1_7_BACKBONE_MODEL) -> Proces def _transform_n1_7_image_for_vlm_albumentations( - image: np.ndarray | Image.Image, + image: np.ndarray, *, image_crop_size: list[int] | None, image_target_size: list[int] | None, shortest_image_edge: int | None, crop_fraction: float | None, -) -> np.ndarray | Image.Image: +) -> np.ndarray: """cv2/INTER_AREA eval transform mirroring Isaac-GR00T's albumentations preprocessing. Used only for checkpoints saved with ``use_albumentations=True``. cv2 is @@ -778,16 +777,11 @@ def _transform_n1_7_image_for_vlm_albumentations( "GR00T N1.7 checkpoints with use_albumentations=True require opencv-python-headless." ) from exc - if isinstance(image, Image.Image): - if image.mode != "RGB": - image = image.convert("RGB") - image_np = np.asarray(image) - else: - image_np = np.asarray(image) - if image_np.ndim == 2: - image_np = np.repeat(image_np[:, :, None], 3, axis=2) - elif image_np.shape[-1] == 4: - image_np = image_np[:, :, :3] + image_np = np.asarray(image) + if image_np.ndim == 2: + image_np = np.repeat(image_np[:, :, None], 3, axis=2) + elif image_np.ndim == 3 and image_np.shape[-1] == 4: + image_np = image_np[:, :, :3] if not image_np.flags.c_contiguous: image_np = np.ascontiguousarray(image_np)