Compare commits

...

1 Commits

Author SHA1 Message Date
Steven Palma 93e29b0cfc fix(groot): GPU/tensor N1.7 image preprocessing + resize to trained resolution
GR00T training was dataloader-bound (0->100->0 GPU-utilization sawtooth).
GrootN17VLMEncodeStep ran the Qwen3-VL image processor per frame on PIL images
on the single CPU main-loop thread, and that cost is timed inside dataloading_s
(preprocessor(batch) runs in the main process, not the dataloader workers), so
adding workers cannot hide it.

- Feed the torchvision-backed Qwen3-VL processor (C,H,W) uint8 tensors instead
  of a per-frame Image.fromarray PIL roundtrip, and run resize/normalize/patchify
  on config.device (GPU) when available. Bit-identical on CPU when no resize is
  configured; with a resize only the PIL->torchvision bicubic backend differs
  (<2/255 per pixel). The use_albumentations path stays PIL/cv2; reload on a box
  without the saved device falls back to CPU.

- Default image_target_size/crop to the N1.7 backbone's training geometry
  (256x256 / 230x230) when a checkpoint ships no image sizing (checkpoint_assets
  is None, e.g. finetuning nvidia/GR00T-N1.7-3B via repo-id with a new
  embodiment). Previously image_target_size=None disabled the resize, so
  full-resolution frames were patchified into ~4.7x more vision tokens than the
  model was trained on -- inflating dataloading_s (patchify) and update_s (VLM
  sequence) and skewing the input distribution. Checkpoints that pin their own
  sizing are honored; the default constants are shared with GR00T_N1_7_DEFAULTS.

Net: preprocessing leaves the CPU critical path and the VLM sees the resolution
it was trained on -- faster training/inference and a correct train/serve
distribution. Affects inference too (shared preprocessor); existing checkpoints
still load (backward compatible) but must be retrained to gain the benefits.
2026-06-15 18:14:09 +02:00
4 changed files with 224 additions and 91 deletions
@@ -42,6 +42,10 @@ GROOT_N1_5_REMOVAL_GUIDANCE = (
)
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
# Default GR00T N1.7 training resolution. Fallback if processor_config lacks sizing. Prevents mismatched
# full-res patchification by forcing a resize. Mirrored by GR00T_N1_7_DEFAULTS in groot_n1_7.py.
N1_7_DEFAULT_IMAGE_TARGET_SIZE = (256, 256)
N1_7_DEFAULT_IMAGE_CROP_SIZE = (230, 230)
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
# Sentinel meaning "the user did not pick an action decode transform": __post_init__ resolves it
# to the embodiment default ('libero' for 'libero_sim', otherwise None). It is distinct from an
+3 -2
View File
@@ -32,6 +32,7 @@ from torch.distributions import Beta
from lerobot.utils.import_utils import _transformers_available, require_package
from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionTransformer
from .configuration_groot import N1_7_DEFAULT_IMAGE_CROP_SIZE, N1_7_DEFAULT_IMAGE_TARGET_SIZE
if TYPE_CHECKING or _transformers_available:
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
@@ -76,8 +77,8 @@ GR00T_N1_7_DEFAULTS: dict[str, Any] = {
"use_flash_attention": True,
"load_bf16": False,
"backbone_trainable_params_fp32": True,
"image_crop_size": (230, 230),
"image_target_size": (256, 256),
"image_crop_size": N1_7_DEFAULT_IMAGE_CROP_SIZE,
"image_target_size": N1_7_DEFAULT_IMAGE_TARGET_SIZE,
"shortest_image_edge": None,
"crop_fraction": None,
"random_rotation_angle": None,
+215 -86
View File
@@ -23,8 +23,10 @@ from typing import TYPE_CHECKING, Any
import numpy as np
import torch
import torchvision.transforms.v2.functional as tv_functional
from einops import rearrange
from PIL import Image
from torchvision.transforms import InterpolationMode
from lerobot.utils.import_utils import _transformers_available
@@ -57,11 +59,14 @@ from lerobot.utils.constants import (
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
from lerobot.utils.device_utils import get_safe_torch_device
from .configuration_groot import (
GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
GROOT_N1_5_REMOVAL_GUIDANCE,
GROOT_N1_7_BACKBONE_MODEL,
N1_7_DEFAULT_IMAGE_CROP_SIZE,
N1_7_DEFAULT_IMAGE_TARGET_SIZE,
GrootConfig,
is_raw_groot_n1_7_checkpoint,
)
@@ -729,21 +734,36 @@ def make_groot_pre_post_processors(
modality_config=checkpoint_assets.modality_config if checkpoint_assets is not None else None,
)
# Resolve the image preprocessing geometry. Honor the checkpoint's processor_config
# when it provides an image_target_size; otherwise fall back to the geometry the
# N1.7 backbone was trained on. Without this fallback a raw base checkpoint with no
# processor_config image sizing (e.g. fine-tuning nvidia/GR00T-N1.7-3B with a new
# embodiment, where checkpoint_assets is None) would patchify full-resolution camera
# frames, inflating the VLM token count and feeding the model a resolution it was not trained on.
if checkpoint_assets is not None and checkpoint_assets.image_target_size is not None:
image_target_size = checkpoint_assets.image_target_size
image_crop_size = checkpoint_assets.image_crop_size
shortest_image_edge = checkpoint_assets.shortest_image_edge
crop_fraction = checkpoint_assets.crop_fraction
else:
image_target_size = list(N1_7_DEFAULT_IMAGE_TARGET_SIZE)
image_crop_size = list(N1_7_DEFAULT_IMAGE_CROP_SIZE)
shortest_image_edge = None
crop_fraction = None
use_albumentations = checkpoint_assets.use_albumentations if checkpoint_assets is not None else False
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
pack_step,
GrootN17VLMEncodeStep(
model_name=config.n1_7_backbone_model,
image_crop_size=checkpoint_assets.image_crop_size if checkpoint_assets is not None else None,
image_target_size=checkpoint_assets.image_target_size if checkpoint_assets is not None else None,
shortest_image_edge=checkpoint_assets.shortest_image_edge
if checkpoint_assets is not None
else None,
crop_fraction=checkpoint_assets.crop_fraction if checkpoint_assets is not None else None,
use_albumentations=checkpoint_assets.use_albumentations
if checkpoint_assets is not None
else False,
image_crop_size=image_crop_size,
image_target_size=image_target_size,
shortest_image_edge=shortest_image_edge,
crop_fraction=crop_fraction,
use_albumentations=use_albumentations,
device=config.device,
),
DeviceProcessorStep(device=config.device),
]
@@ -899,15 +919,22 @@ def _build_n1_7_processor(model_name: str = GROOT_N1_7_BACKBONE_MODEL) -> Proces
return proc
def _transform_n1_7_image_for_vlm(
def _transform_n1_7_image_for_vlm_albumentations(
image: Image.Image,
*,
image_crop_size: list[int] | None,
image_target_size: list[int] | None,
shortest_image_edge: int | None,
crop_fraction: float | None,
use_albumentations: bool = False,
) -> Image.Image:
"""cv2/INTER_AREA eval transform mirroring Isaac-GR00T's albumentations preprocessing.
Used only for checkpoints saved with ``use_albumentations=True``. cv2 is
CPU/numpy-only so this path cannot run on GPU; the default (non-albumentations)
geometry is handled on-device by :func:`_transform_n1_7_image_for_vlm_torch`. The
cv2/INTER_AREA resize and floored center-crop here intentionally differ from that
torch path and must stay bit-exact to the upstream reference.
"""
if image_target_size is None:
return image
@@ -915,70 +942,101 @@ def _transform_n1_7_image_for_vlm(
if image.mode != "RGB":
image = image.convert("RGB")
if use_albumentations:
try:
import cv2
except ImportError as exc:
raise ImportError(
"GR00T N1.7 checkpoints with use_albumentations=True require opencv-python-headless."
) from exc
try:
import cv2
except ImportError as exc:
raise ImportError(
"GR00T N1.7 checkpoints with use_albumentations=True require opencv-python-headless."
) from exc
image_np = np.asarray(image)
height, width = image_np.shape[:2]
if height != width:
square_edge = max(height, width)
pad_h = square_edge - height
pad_w = square_edge - width
image_np = cv2.copyMakeBorder(
image_np,
pad_h // 2,
pad_h - pad_h // 2,
pad_w // 2,
pad_w - pad_w // 2,
cv2.BORDER_CONSTANT,
value=(0, 0, 0),
)
resize_edge = shortest_image_edge or target_h
if image_np.shape[:2] != (resize_edge, resize_edge):
image_np = cv2.resize(image_np, (resize_edge, resize_edge), interpolation=cv2.INTER_AREA)
if crop_fraction is None and image_crop_size is not None:
crop_fraction = image_crop_size[0] / float(target_h)
if crop_fraction is not None and 0.0 < crop_fraction < 1.0:
height, width = image_np.shape[:2]
crop_h = max(1, int(height * crop_fraction))
crop_w = max(1, int(width * crop_fraction))
top = max(0, (height - crop_h) // 2)
left = max(0, (width - crop_w) // 2)
image_np = image_np[top : top + crop_h, left : left + crop_w]
if image_np.shape[:2] != (target_h, target_w):
image_np = cv2.resize(image_np, (target_w, target_h), interpolation=cv2.INTER_AREA)
return Image.fromarray(image_np)
square_edge = max(image.width, image.height)
if image.width != image.height:
padded = Image.new("RGB", (square_edge, square_edge))
left = (square_edge - image.width) // 2
top = (square_edge - image.height) // 2
padded.paste(image, (left, top))
image = padded
image_np = np.asarray(image)
height, width = image_np.shape[:2]
if height != width:
square_edge = max(height, width)
pad_h = square_edge - height
pad_w = square_edge - width
image_np = cv2.copyMakeBorder(
image_np,
pad_h // 2,
pad_h - pad_h // 2,
pad_w // 2,
pad_w - pad_w // 2,
cv2.BORDER_CONSTANT,
value=(0, 0, 0),
)
resize_edge = shortest_image_edge or target_h
image = image.resize((resize_edge, resize_edge), Image.Resampling.BICUBIC)
if image_np.shape[:2] != (resize_edge, resize_edge):
image_np = cv2.resize(image_np, (resize_edge, resize_edge), interpolation=cv2.INTER_AREA)
if crop_fraction is None and image_crop_size is not None:
crop_fraction = image_crop_size[0] / float(target_h)
if crop_fraction is not None and 0.0 < crop_fraction < 1.0:
crop_w = max(1, int(round(image.width * crop_fraction)))
crop_h = max(1, int(round(image.height * crop_fraction)))
left = max(0, (image.width - crop_w) // 2)
top = max(0, (image.height - crop_h) // 2)
image = image.crop((left, top, left + crop_w, top + crop_h))
height, width = image_np.shape[:2]
crop_h = max(1, int(height * crop_fraction))
crop_w = max(1, int(width * crop_fraction))
top = max(0, (height - crop_h) // 2)
left = max(0, (width - crop_w) // 2)
image_np = image_np[top : top + crop_h, left : left + crop_w]
if image.size != (target_w, target_h):
image = image.resize((target_w, target_h), Image.Resampling.BICUBIC)
if image_np.shape[:2] != (target_h, target_w):
image_np = cv2.resize(image_np, (target_w, target_h), interpolation=cv2.INTER_AREA)
return Image.fromarray(image_np)
def _transform_n1_7_image_for_vlm_torch(
image: torch.Tensor,
*,
image_crop_size: list[int] | None,
image_target_size: list[int] | None,
shortest_image_edge: int | None,
crop_fraction: float | None,
) -> torch.Tensor:
"""Default (non-albumentations) N1.7 image transform: pad-to-square, resize to
``shortest_image_edge``, center-crop by ``crop_fraction``, resize to ``image_target_size``.
Operates on a ``(C, H, W)`` uint8 tensor and keeps the result on the input
tensor's device so the resize/crop run on GPU when the tensor is. Bicubic
interpolation with antialiasing matches PIL's ``Image.Resampling.BICUBIC``
closely (sub-``2/255`` per-pixel on worst-case inputs). The ``use_albumentations``
cv2/INTER_AREA path has no torch equivalent and stays on
:func:`_transform_n1_7_image_for_vlm_albumentations`.
"""
if image_target_size is None:
return image
target_h, target_w = image_target_size
_, height, width = image.shape
square_edge = max(height, width)
if height != width:
left = (square_edge - width) // 2
top = (square_edge - height) // 2
image = tv_functional.pad(
image, [left, top, square_edge - width - left, square_edge - height - top], fill=0
)
resize_edge = shortest_image_edge or target_h
image = tv_functional.resize(
image, [resize_edge, resize_edge], interpolation=InterpolationMode.BICUBIC, antialias=True
)
if crop_fraction is None and image_crop_size is not None:
crop_fraction = image_crop_size[0] / float(target_h)
if crop_fraction is not None and 0.0 < crop_fraction < 1.0:
# Match the PIL helper's center crop exactly: round() the crop size but
# floor() the offset (torchvision.center_crop rounds the offset, which
# shifts the region by 1px when (edge - crop) is odd).
crop_h = max(1, int(round(image.shape[-2] * crop_fraction)))
crop_w = max(1, int(round(image.shape[-1] * crop_fraction)))
top = max(0, (image.shape[-2] - crop_h) // 2)
left = max(0, (image.shape[-1] - crop_w) // 2)
image = image[..., top : top + crop_h, left : left + crop_w]
if tuple(image.shape[-2:]) != (target_h, target_w):
image = tv_functional.resize(
image, [target_h, target_w], interpolation=InterpolationMode.BICUBIC, antialias=True
)
return image
@@ -1280,6 +1338,12 @@ class GrootN17VLMEncodeStep(ProcessorStep):
The packed video has shape ``(B, T, V, H, W, C)``. Each frame/view becomes
an image item in the same chat message so the resulting image tokens match
the temporal VLM packing used by Isaac-GR00T.
Images are handed to the torchvision-backed Qwen3-VL processor as ``(C, H, W)``
uint8 tensors (no per-frame PIL roundtrip), and, when ``device`` resolves to a
CUDA device, the resize/rescale/normalize/patchify run there. This keeps the
output bit-identical on CPU and moves the dominant preprocessing cost off
the critical path on GPU.
"""
model_name: str = GROOT_N1_7_BACKBONE_MODEL
@@ -1288,6 +1352,7 @@ class GrootN17VLMEncodeStep(ProcessorStep):
shortest_image_edge: int | None = None
crop_fraction: float | None = None
use_albumentations: bool = False
device: str | None = None
_proc: ProcessorMixin | None = field(default=None, init=False, repr=False)
@property
@@ -1296,6 +1361,69 @@ class GrootN17VLMEncodeStep(ProcessorStep):
self._proc = _build_n1_7_processor(self.model_name)
return self._proc
def _target_device(self) -> torch.device | None:
# The albumentations path is cv2/PIL only, so it cannot run on GPU.
if self.device is None or self.use_albumentations:
return None
try:
return get_safe_torch_device(self.device)
except (AssertionError, RuntimeError):
# A device serialized at train time (e.g. "cuda") may be unavailable
# when the processor is reloaded elsewhere (e.g. CPU-only eval), and
# this step is not in the standard device-override set. Fall back to
# the CPU path, which is bit-identical, instead of crashing.
return None
def _build_sample_images(
self, video: Any, batch_size: int, target_device: torch.device | None
) -> list[list[Any]]:
"""Return, per batch item, its ordered ``(timestep, view)`` frames.
``use_albumentations`` keeps the legacy per-frame PIL/cv2 transform;
otherwise frames are ``(C, H, W)`` uint8 tensors (moved to
``target_device`` when set) for the torchvision-backed Qwen processor.
"""
if self.use_albumentations:
video_np = np.asarray(video)
return [
[
_transform_n1_7_image_for_vlm_albumentations(
Image.fromarray(video_np[batch_idx, timestep, view_idx]),
image_crop_size=self.image_crop_size,
image_target_size=self.image_target_size,
shortest_image_edge=self.shortest_image_edge,
crop_fraction=self.crop_fraction,
)
for timestep in range(video_np.shape[1])
for view_idx in range(video_np.shape[2])
]
for batch_idx in range(batch_size)
]
video_t = video if torch.is_tensor(video) else torch.from_numpy(np.ascontiguousarray(video))
# (B, T, V, H, W, C) uint8 -> (B, T, V, C, H, W)
video_t = video_t.permute(0, 1, 2, 5, 3, 4).contiguous()
if target_device is not None and video_t.device != target_device:
video_t = video_t.to(target_device, non_blocking=(target_device.type == "cuda"))
frames_per_sample: list[list[Any]] = []
for batch_idx in range(batch_size):
sample = video_t[batch_idx] # (T, V, C, H, W)
frames_per_sample.append(
[
_transform_n1_7_image_for_vlm_torch(
sample[timestep, view_idx],
image_crop_size=self.image_crop_size,
image_target_size=self.image_target_size,
shortest_image_edge=self.shortest_image_edge,
crop_fraction=self.crop_fraction,
)
for timestep in range(sample.shape[0])
for view_idx in range(sample.shape[1])
]
)
return frames_per_sample
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
@@ -1303,33 +1431,25 @@ class GrootN17VLMEncodeStep(ProcessorStep):
if video is None:
return transition
batch_size = int(video.shape[0])
languages = _prepare_n1_7_language_batch(
comp.get("language"),
video.shape[0],
batch_size,
formalize_language=False,
)
target_device = self._target_device()
sample_images = self._build_sample_images(video, batch_size, target_device)
texts: list[str] = []
images: list[Image.Image] = []
for batch_idx in range(video.shape[0]):
sample = video[batch_idx] # (T, V, H, W, C)
sample_images = [
_transform_n1_7_image_for_vlm(
Image.fromarray(sample[timestep, view_idx]),
image_crop_size=self.image_crop_size,
image_target_size=self.image_target_size,
shortest_image_edge=self.shortest_image_edge,
crop_fraction=self.crop_fraction,
use_albumentations=self.use_albumentations,
)
for timestep in range(sample.shape[0])
for view_idx in range(sample.shape[1])
]
images: list[Any] = []
for batch_idx in range(batch_size):
frames = sample_images[batch_idx]
conversation = [
{
"role": "user",
"content": [
*[{"type": "image", "image": image} for image in sample_images],
*[{"type": "image", "image": image} for image in frames],
{"type": "text", "text": languages[batch_idx]},
],
}
@@ -1341,9 +1461,17 @@ class GrootN17VLMEncodeStep(ProcessorStep):
add_generation_prompt=False,
)
)
images.extend(sample_images)
images.extend(frames)
encoded = self.proc(text=texts, images=images, return_tensors="pt", padding=True)
proc_kwargs: dict[str, Any] = {
"text": texts,
"images": images,
"return_tensors": "pt",
"padding": True,
}
if target_device is not None:
proc_kwargs["device"] = str(target_device)
encoded = self.proc(**proc_kwargs)
for key, value in encoded.items():
comp[key] = value
obs.pop("video", None)
@@ -1362,6 +1490,7 @@ class GrootN17VLMEncodeStep(ProcessorStep):
"shortest_image_edge": self.shortest_image_edge,
"crop_fraction": self.crop_fraction,
"use_albumentations": self.use_albumentations,
"device": self.device,
}
+2 -3
View File
@@ -41,7 +41,7 @@ from lerobot.policies.groot.processor_groot import (
GrootN17ActionDecodeStep,
GrootN17PackInputsStep,
GrootN17VLMEncodeStep,
_transform_n1_7_image_for_vlm,
_transform_n1_7_image_for_vlm_albumentations,
make_groot_pre_post_processors,
)
from lerobot.processor import (
@@ -1529,13 +1529,12 @@ def test_groot_n1_7_vlm_image_transform_matches_albumentations_eval_path():
image_np = (np.arange(360 * 360 * 3, dtype=np.uint32) % 251).astype(np.uint8).reshape(360, 360, 3)
transformed = _transform_n1_7_image_for_vlm(
transformed = _transform_n1_7_image_for_vlm_albumentations(
Image.fromarray(image_np),
image_crop_size=[230, 230],
image_target_size=[256, 256],
shortest_image_edge=256,
crop_fraction=0.95,
use_albumentations=True,
)
expected = cv2.resize(image_np, (256, 256), interpolation=cv2.INTER_AREA)