mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-04 16:47:14 +00:00
Optimize GR00T N1.7 image preprocessing
This commit is contained in:
@@ -750,27 +750,26 @@ def _build_n1_7_processor(model_name: str = GROOT_N1_7_BACKBONE_MODEL) -> Proces
|
||||
|
||||
|
||||
def _transform_n1_7_image_for_vlm_albumentations(
|
||||
image: Image.Image,
|
||||
image: np.ndarray | Image.Image,
|
||||
*,
|
||||
image_crop_size: list[int] | None,
|
||||
image_target_size: list[int] | None,
|
||||
shortest_image_edge: int | None,
|
||||
crop_fraction: float | None,
|
||||
) -> Image.Image:
|
||||
) -> np.ndarray | Image.Image:
|
||||
"""cv2/INTER_AREA eval transform mirroring Isaac-GR00T's albumentations preprocessing.
|
||||
|
||||
Used only for checkpoints saved with ``use_albumentations=True``. cv2 is
|
||||
CPU/numpy-only so this path cannot run on GPU; the default (non-albumentations)
|
||||
geometry is handled on-device by :func:`_transform_n1_7_image_for_vlm_torch`. The
|
||||
cv2/INTER_AREA resize and floored center-crop here intentionally differ from that
|
||||
torch path and must stay bit-exact to the upstream reference.
|
||||
torch path and must stay bit-exact to the upstream reference. The hot path accepts
|
||||
and returns numpy arrays to avoid per-frame PIL round-trips.
|
||||
"""
|
||||
if image_target_size is None:
|
||||
return image
|
||||
|
||||
target_h, target_w = image_target_size
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
try:
|
||||
import cv2
|
||||
@@ -779,7 +778,20 @@ def _transform_n1_7_image_for_vlm_albumentations(
|
||||
"GR00T N1.7 checkpoints with use_albumentations=True require opencv-python-headless."
|
||||
) from exc
|
||||
|
||||
image_np = np.asarray(image)
|
||||
if isinstance(image, Image.Image):
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
image_np = np.asarray(image)
|
||||
else:
|
||||
image_np = np.asarray(image)
|
||||
if image_np.ndim == 2:
|
||||
image_np = np.repeat(image_np[:, :, None], 3, axis=2)
|
||||
elif image_np.shape[-1] == 4:
|
||||
image_np = image_np[:, :, :3]
|
||||
|
||||
if not image_np.flags.c_contiguous:
|
||||
image_np = np.ascontiguousarray(image_np)
|
||||
|
||||
height, width = image_np.shape[:2]
|
||||
if height != width:
|
||||
square_edge = max(height, width)
|
||||
@@ -811,7 +823,7 @@ def _transform_n1_7_image_for_vlm_albumentations(
|
||||
|
||||
if image_np.shape[:2] != (target_h, target_w):
|
||||
image_np = cv2.resize(image_np, (target_w, target_h), interpolation=cv2.INTER_AREA)
|
||||
return Image.fromarray(image_np)
|
||||
return image_np
|
||||
|
||||
|
||||
def _transform_n1_7_image_for_vlm_torch(
|
||||
@@ -1192,7 +1204,7 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
return self._proc
|
||||
|
||||
def _target_device(self) -> torch.device | None:
|
||||
# The albumentations path is cv2/PIL only, so it cannot run on GPU.
|
||||
# The albumentations path is cv2/numpy only, so it cannot run on GPU.
|
||||
if self.device is None or self.use_albumentations:
|
||||
return None
|
||||
try:
|
||||
@@ -1209,7 +1221,7 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
) -> list[list[Any]]:
|
||||
"""Return, per batch item, its ordered ``(timestep, view)`` frames.
|
||||
|
||||
``use_albumentations`` keeps the legacy per-frame PIL/cv2 transform;
|
||||
``use_albumentations`` keeps the legacy per-frame cv2/INTER_AREA transform;
|
||||
otherwise frames are ``(C, H, W)`` uint8 tensors (moved to
|
||||
``target_device`` when set) for the torchvision-backed Qwen processor.
|
||||
"""
|
||||
@@ -1218,7 +1230,7 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
return [
|
||||
[
|
||||
_transform_n1_7_image_for_vlm_albumentations(
|
||||
Image.fromarray(video_np[batch_idx, timestep, view_idx]),
|
||||
video_np[batch_idx, timestep, view_idx],
|
||||
image_crop_size=self.image_crop_size,
|
||||
image_target_size=self.image_target_size,
|
||||
shortest_image_edge=self.shortest_image_edge,
|
||||
|
||||
Reference in New Issue
Block a user