Optimize GR00T N1.7 image preprocessing

This commit is contained in:
Andy Wrenn
2026-06-20 06:58:10 -07:00
parent 229299d937
commit ee41109d35
+22 -10
View File
@@ -750,27 +750,26 @@ def _build_n1_7_processor(model_name: str = GROOT_N1_7_BACKBONE_MODEL) -> Proces
def _transform_n1_7_image_for_vlm_albumentations(
image: Image.Image,
image: np.ndarray | Image.Image,
*,
image_crop_size: list[int] | None,
image_target_size: list[int] | None,
shortest_image_edge: int | None,
crop_fraction: float | None,
) -> Image.Image:
) -> np.ndarray | Image.Image:
"""cv2/INTER_AREA eval transform mirroring Isaac-GR00T's albumentations preprocessing.
Used only for checkpoints saved with ``use_albumentations=True``. cv2 is
CPU/numpy-only so this path cannot run on GPU; the default (non-albumentations)
geometry is handled on-device by :func:`_transform_n1_7_image_for_vlm_torch`. The
cv2/INTER_AREA resize and floored center-crop here intentionally differ from that
torch path and must stay bit-exact to the upstream reference.
torch path and must stay bit-exact to the upstream reference. The hot path accepts
and returns numpy arrays to avoid per-frame PIL round-trips.
"""
if image_target_size is None:
return image
target_h, target_w = image_target_size
if image.mode != "RGB":
image = image.convert("RGB")
try:
import cv2
@@ -779,7 +778,20 @@ def _transform_n1_7_image_for_vlm_albumentations(
"GR00T N1.7 checkpoints with use_albumentations=True require opencv-python-headless."
) from exc
image_np = np.asarray(image)
if isinstance(image, Image.Image):
if image.mode != "RGB":
image = image.convert("RGB")
image_np = np.asarray(image)
else:
image_np = np.asarray(image)
if image_np.ndim == 2:
image_np = np.repeat(image_np[:, :, None], 3, axis=2)
elif image_np.shape[-1] == 4:
image_np = image_np[:, :, :3]
if not image_np.flags.c_contiguous:
image_np = np.ascontiguousarray(image_np)
height, width = image_np.shape[:2]
if height != width:
square_edge = max(height, width)
@@ -811,7 +823,7 @@ def _transform_n1_7_image_for_vlm_albumentations(
if image_np.shape[:2] != (target_h, target_w):
image_np = cv2.resize(image_np, (target_w, target_h), interpolation=cv2.INTER_AREA)
return Image.fromarray(image_np)
return image_np
def _transform_n1_7_image_for_vlm_torch(
@@ -1192,7 +1204,7 @@ class GrootN17VLMEncodeStep(ProcessorStep):
return self._proc
def _target_device(self) -> torch.device | None:
# The albumentations path is cv2/PIL only, so it cannot run on GPU.
# The albumentations path is cv2/numpy only, so it cannot run on GPU.
if self.device is None or self.use_albumentations:
return None
try:
@@ -1209,7 +1221,7 @@ class GrootN17VLMEncodeStep(ProcessorStep):
) -> list[list[Any]]:
"""Return, per batch item, its ordered ``(timestep, view)`` frames.
``use_albumentations`` keeps the legacy per-frame PIL/cv2 transform;
``use_albumentations`` keeps the legacy per-frame cv2/INTER_AREA transform;
otherwise frames are ``(C, H, W)`` uint8 tensors (moved to
``target_device`` when set) for the torchvision-backed Qwen processor.
"""
@@ -1218,7 +1230,7 @@ class GrootN17VLMEncodeStep(ProcessorStep):
return [
[
_transform_n1_7_image_for_vlm_albumentations(
Image.fromarray(video_np[batch_idx, timestep, view_idx]),
video_np[batch_idx, timestep, view_idx],
image_crop_size=self.image_crop_size,
image_target_size=self.image_target_size,
shortest_image_edge=self.shortest_image_edge,