Remove PIL fallback from GR00T preprocessing

This commit is contained in:
Andy Wrenn
2026-06-20 07:02:00 -07:00
parent ee41109d35
commit 5a83db89de
+7 -13
View File
@@ -24,7 +24,6 @@ import numpy as np
import torch
import torchvision.transforms.v2.functional as tv_functional
from einops import rearrange
from PIL import Image
from torchvision.transforms import InterpolationMode
from lerobot.utils.import_utils import _transformers_available, require_package
@@ -750,13 +749,13 @@ def _build_n1_7_processor(model_name: str = GROOT_N1_7_BACKBONE_MODEL) -> Proces
def _transform_n1_7_image_for_vlm_albumentations(
image: np.ndarray | Image.Image,
image: np.ndarray,
*,
image_crop_size: list[int] | None,
image_target_size: list[int] | None,
shortest_image_edge: int | None,
crop_fraction: float | None,
) -> np.ndarray | Image.Image:
) -> np.ndarray:
"""cv2/INTER_AREA eval transform mirroring Isaac-GR00T's albumentations preprocessing.
Used only for checkpoints saved with ``use_albumentations=True``. cv2 is
@@ -778,16 +777,11 @@ def _transform_n1_7_image_for_vlm_albumentations(
"GR00T N1.7 checkpoints with use_albumentations=True require opencv-python-headless."
) from exc
if isinstance(image, Image.Image):
if image.mode != "RGB":
image = image.convert("RGB")
image_np = np.asarray(image)
else:
image_np = np.asarray(image)
if image_np.ndim == 2:
image_np = np.repeat(image_np[:, :, None], 3, axis=2)
elif image_np.shape[-1] == 4:
image_np = image_np[:, :, :3]
image_np = np.asarray(image)
if image_np.ndim == 2:
image_np = np.repeat(image_np[:, :, None], 3, axis=2)
elif image_np.ndim == 3 and image_np.shape[-1] == 4:
image_np = image_np[:, :, :3]
if not image_np.flags.c_contiguous:
image_np = np.ascontiguousarray(image_np)