diff --git a/src/lerobot/policies/groot/processor_groot.py b/src/lerobot/policies/groot/processor_groot.py index 65ee46954..2ea10b288 100644 --- a/src/lerobot/policies/groot/processor_groot.py +++ b/src/lerobot/policies/groot/processor_groot.py @@ -750,27 +750,26 @@ def _build_n1_7_processor(model_name: str = GROOT_N1_7_BACKBONE_MODEL) -> Proces def _transform_n1_7_image_for_vlm_albumentations( - image: Image.Image, + image: np.ndarray | Image.Image, *, image_crop_size: list[int] | None, image_target_size: list[int] | None, shortest_image_edge: int | None, crop_fraction: float | None, -) -> Image.Image: +) -> np.ndarray | Image.Image: """cv2/INTER_AREA eval transform mirroring Isaac-GR00T's albumentations preprocessing. Used only for checkpoints saved with ``use_albumentations=True``. cv2 is CPU/numpy-only so this path cannot run on GPU; the default (non-albumentations) geometry is handled on-device by :func:`_transform_n1_7_image_for_vlm_torch`. The cv2/INTER_AREA resize and floored center-crop here intentionally differ from that - torch path and must stay bit-exact to the upstream reference. + torch path and must stay bit-exact to the upstream reference. The hot path accepts + and returns numpy arrays to avoid per-frame PIL round-trips. """ if image_target_size is None: return image target_h, target_w = image_target_size - if image.mode != "RGB": - image = image.convert("RGB") try: import cv2 @@ -779,7 +778,20 @@ def _transform_n1_7_image_for_vlm_albumentations( "GR00T N1.7 checkpoints with use_albumentations=True require opencv-python-headless." ) from exc - image_np = np.asarray(image) + if isinstance(image, Image.Image): + if image.mode != "RGB": + image = image.convert("RGB") + image_np = np.asarray(image) + else: + image_np = np.asarray(image) + if image_np.ndim == 2: + image_np = np.repeat(image_np[:, :, None], 3, axis=2) + elif image_np.shape[-1] == 4: + image_np = image_np[:, :, :3] + + if not image_np.flags.c_contiguous: + image_np = np.ascontiguousarray(image_np) + height, width = image_np.shape[:2] if height != width: square_edge = max(height, width) @@ -811,7 +823,7 @@ def _transform_n1_7_image_for_vlm_albumentations( if image_np.shape[:2] != (target_h, target_w): image_np = cv2.resize(image_np, (target_w, target_h), interpolation=cv2.INTER_AREA) - return Image.fromarray(image_np) + return image_np def _transform_n1_7_image_for_vlm_torch( @@ -1192,7 +1204,7 @@ class GrootN17VLMEncodeStep(ProcessorStep): return self._proc def _target_device(self) -> torch.device | None: - # The albumentations path is cv2/PIL only, so it cannot run on GPU. + # The albumentations path is cv2/numpy only, so it cannot run on GPU. if self.device is None or self.use_albumentations: return None try: @@ -1209,7 +1221,7 @@ class GrootN17VLMEncodeStep(ProcessorStep): ) -> list[list[Any]]: """Return, per batch item, its ordered ``(timestep, view)`` frames. - ``use_albumentations`` keeps the legacy per-frame PIL/cv2 transform; + ``use_albumentations`` keeps the legacy per-frame cv2/INTER_AREA transform; otherwise frames are ``(C, H, W)`` uint8 tensors (moved to ``target_device`` when set) for the torchvision-backed Qwen processor. """ @@ -1218,7 +1230,7 @@ class GrootN17VLMEncodeStep(ProcessorStep): return [ [ _transform_n1_7_image_for_vlm_albumentations( - Image.fromarray(video_np[batch_idx, timestep, view_idx]), + video_np[batch_idx, timestep, view_idx], image_crop_size=self.image_crop_size, image_target_size=self.image_target_size, shortest_image_edge=self.shortest_image_edge,