From 93e29b0cfc70a4b71f4a4fc02204b6adf92a32ee Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 15 Jun 2026 11:11:34 +0200 Subject: [PATCH] fix(groot): GPU/tensor N1.7 image preprocessing + resize to trained resolution GR00T training was dataloader-bound (0->100->0 GPU-utilization sawtooth). GrootN17VLMEncodeStep ran the Qwen3-VL image processor per frame on PIL images on the single CPU main-loop thread, and that cost is timed inside dataloading_s (preprocessor(batch) runs in the main process, not the dataloader workers), so adding workers cannot hide it. - Feed the torchvision-backed Qwen3-VL processor (C,H,W) uint8 tensors instead of a per-frame Image.fromarray PIL roundtrip, and run resize/normalize/patchify on config.device (GPU) when available. Bit-identical on CPU when no resize is configured; with a resize only the PIL->torchvision bicubic backend differs (<2/255 per pixel). The use_albumentations path stays PIL/cv2; reload on a box without the saved device falls back to CPU. - Default image_target_size/crop to the N1.7 backbone's training geometry (256x256 / 230x230) when a checkpoint ships no image sizing (checkpoint_assets is None, e.g. finetuning nvidia/GR00T-N1.7-3B via repo-id with a new embodiment). Previously image_target_size=None disabled the resize, so full-resolution frames were patchified into ~4.7x more vision tokens than the model was trained on -- inflating dataloading_s (patchify) and update_s (VLM sequence) and skewing the input distribution. Checkpoints that pin their own sizing are honored; the default constants are shared with GR00T_N1_7_DEFAULTS. Net: preprocessing leaves the CPU critical path and the VLM sees the resolution it was trained on -- faster training/inference and a correct train/serve distribution. Affects inference too (shared preprocessor); existing checkpoints still load (backward compatible) but must be retrained to gain the benefits. 
--- .../policies/groot/configuration_groot.py | 4 + src/lerobot/policies/groot/groot_n1_7.py | 5 +- src/lerobot/policies/groot/processor_groot.py | 301 +++++++++++++----- tests/policies/groot/test_groot_n1_7.py | 5 +- 4 files changed, 224 insertions(+), 91 deletions(-) diff --git a/src/lerobot/policies/groot/configuration_groot.py b/src/lerobot/policies/groot/configuration_groot.py index f108e6f06..9a645af1f 100644 --- a/src/lerobot/policies/groot/configuration_groot.py +++ b/src/lerobot/policies/groot/configuration_groot.py @@ -42,6 +42,10 @@ GROOT_N1_5_REMOVAL_GUIDANCE = ( ) GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B" GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B" +# Default GR00T N1.7 training resolution. Fallback if processor_config lacks sizing. Prevents mismatched +# full-res patchification by forcing a resize. Mirrored by GR00T_N1_7_DEFAULTS in groot_n1_7.py. +N1_7_DEFAULT_IMAGE_TARGET_SIZE = (256, 256) +N1_7_DEFAULT_IMAGE_CROP_SIZE = (230, 230) GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero" # Sentinel meaning "the user did not pick an action decode transform": __post_init__ resolves it # to the embodiment default ('libero' for 'libero_sim', otherwise None). 
It is distinct from an diff --git a/src/lerobot/policies/groot/groot_n1_7.py b/src/lerobot/policies/groot/groot_n1_7.py index e062e2c5c..ecf8fe8c7 100644 --- a/src/lerobot/policies/groot/groot_n1_7.py +++ b/src/lerobot/policies/groot/groot_n1_7.py @@ -32,6 +32,7 @@ from torch.distributions import Beta from lerobot.utils.import_utils import _transformers_available, require_package from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionTransformer +from .configuration_groot import N1_7_DEFAULT_IMAGE_CROP_SIZE, N1_7_DEFAULT_IMAGE_TARGET_SIZE if TYPE_CHECKING or _transformers_available: from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel @@ -76,8 +77,8 @@ GR00T_N1_7_DEFAULTS: dict[str, Any] = { "use_flash_attention": True, "load_bf16": False, "backbone_trainable_params_fp32": True, - "image_crop_size": (230, 230), - "image_target_size": (256, 256), + "image_crop_size": N1_7_DEFAULT_IMAGE_CROP_SIZE, + "image_target_size": N1_7_DEFAULT_IMAGE_TARGET_SIZE, "shortest_image_edge": None, "crop_fraction": None, "random_rotation_angle": None, diff --git a/src/lerobot/policies/groot/processor_groot.py b/src/lerobot/policies/groot/processor_groot.py index ee20d6471..d75621d72 100644 --- a/src/lerobot/policies/groot/processor_groot.py +++ b/src/lerobot/policies/groot/processor_groot.py @@ -23,8 +23,10 @@ from typing import TYPE_CHECKING, Any import numpy as np import torch +import torchvision.transforms.v2.functional as tv_functional from einops import rearrange from PIL import Image +from torchvision.transforms import InterpolationMode from lerobot.utils.import_utils import _transformers_available @@ -57,11 +59,14 @@ from lerobot.utils.constants import ( POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME, ) +from lerobot.utils.device_utils import get_safe_torch_device from .configuration_groot import ( GROOT_ACTION_DECODE_TRANSFORM_LIBERO, GROOT_N1_5_REMOVAL_GUIDANCE, GROOT_N1_7_BACKBONE_MODEL, + 
N1_7_DEFAULT_IMAGE_CROP_SIZE, + N1_7_DEFAULT_IMAGE_TARGET_SIZE, GrootConfig, is_raw_groot_n1_7_checkpoint, ) @@ -729,21 +734,36 @@ def make_groot_pre_post_processors( modality_config=checkpoint_assets.modality_config if checkpoint_assets is not None else None, ) + # Resolve the image preprocessing geometry. Honor the checkpoint's processor_config + # when it provides an image_target_size; otherwise fall back to the geometry the + # N1.7 backbone was trained on. Without this fallback a raw base checkpoint with no + # processor_config image sizing (e.g. fine-tuning nvidia/GR00T-N1.7-3B with a new + # embodiment, where checkpoint_assets is None) would patchify full-resolution camera + # frames, inflating the VLM token count and feeding the model a resolution it was not trained on. + if checkpoint_assets is not None and checkpoint_assets.image_target_size is not None: + image_target_size = checkpoint_assets.image_target_size + image_crop_size = checkpoint_assets.image_crop_size + shortest_image_edge = checkpoint_assets.shortest_image_edge + crop_fraction = checkpoint_assets.crop_fraction + else: + image_target_size = list(N1_7_DEFAULT_IMAGE_TARGET_SIZE) + image_crop_size = list(N1_7_DEFAULT_IMAGE_CROP_SIZE) + shortest_image_edge = None + crop_fraction = None + use_albumentations = checkpoint_assets.use_albumentations if checkpoint_assets is not None else False + input_steps: list[ProcessorStep] = [ RenameObservationsProcessorStep(rename_map={}), AddBatchDimensionProcessorStep(), pack_step, GrootN17VLMEncodeStep( model_name=config.n1_7_backbone_model, - image_crop_size=checkpoint_assets.image_crop_size if checkpoint_assets is not None else None, - image_target_size=checkpoint_assets.image_target_size if checkpoint_assets is not None else None, - shortest_image_edge=checkpoint_assets.shortest_image_edge - if checkpoint_assets is not None - else None, - crop_fraction=checkpoint_assets.crop_fraction if checkpoint_assets is not None else None, - 
use_albumentations=checkpoint_assets.use_albumentations - if checkpoint_assets is not None - else False, + image_crop_size=image_crop_size, + image_target_size=image_target_size, + shortest_image_edge=shortest_image_edge, + crop_fraction=crop_fraction, + use_albumentations=use_albumentations, + device=config.device, ), DeviceProcessorStep(device=config.device), ] @@ -899,15 +919,22 @@ def _build_n1_7_processor(model_name: str = GROOT_N1_7_BACKBONE_MODEL) -> Proces return proc -def _transform_n1_7_image_for_vlm( +def _transform_n1_7_image_for_vlm_albumentations( image: Image.Image, *, image_crop_size: list[int] | None, image_target_size: list[int] | None, shortest_image_edge: int | None, crop_fraction: float | None, - use_albumentations: bool = False, ) -> Image.Image: + """cv2/INTER_AREA eval transform mirroring Isaac-GR00T's albumentations preprocessing. + + Used only for checkpoints saved with ``use_albumentations=True``. cv2 is + CPU/numpy-only so this path cannot run on GPU; the default (non-albumentations) + geometry is handled on-device by :func:`_transform_n1_7_image_for_vlm_torch`. The + cv2/INTER_AREA resize and floored center-crop here intentionally differ from that + torch path and must stay bit-exact to the upstream reference. + """ if image_target_size is None: return image @@ -915,70 +942,101 @@ def _transform_n1_7_image_for_vlm( if image.mode != "RGB": image = image.convert("RGB") - if use_albumentations: - try: - import cv2 - except ImportError as exc: - raise ImportError( - "GR00T N1.7 checkpoints with use_albumentations=True require opencv-python-headless." - ) from exc + try: + import cv2 + except ImportError as exc: + raise ImportError( + "GR00T N1.7 checkpoints with use_albumentations=True require opencv-python-headless." 
+ ) from exc - image_np = np.asarray(image) - height, width = image_np.shape[:2] - if height != width: - square_edge = max(height, width) - pad_h = square_edge - height - pad_w = square_edge - width - image_np = cv2.copyMakeBorder( - image_np, - pad_h // 2, - pad_h - pad_h // 2, - pad_w // 2, - pad_w - pad_w // 2, - cv2.BORDER_CONSTANT, - value=(0, 0, 0), - ) - - resize_edge = shortest_image_edge or target_h - if image_np.shape[:2] != (resize_edge, resize_edge): - image_np = cv2.resize(image_np, (resize_edge, resize_edge), interpolation=cv2.INTER_AREA) - - if crop_fraction is None and image_crop_size is not None: - crop_fraction = image_crop_size[0] / float(target_h) - if crop_fraction is not None and 0.0 < crop_fraction < 1.0: - height, width = image_np.shape[:2] - crop_h = max(1, int(height * crop_fraction)) - crop_w = max(1, int(width * crop_fraction)) - top = max(0, (height - crop_h) // 2) - left = max(0, (width - crop_w) // 2) - image_np = image_np[top : top + crop_h, left : left + crop_w] - - if image_np.shape[:2] != (target_h, target_w): - image_np = cv2.resize(image_np, (target_w, target_h), interpolation=cv2.INTER_AREA) - return Image.fromarray(image_np) - - square_edge = max(image.width, image.height) - if image.width != image.height: - padded = Image.new("RGB", (square_edge, square_edge)) - left = (square_edge - image.width) // 2 - top = (square_edge - image.height) // 2 - padded.paste(image, (left, top)) - image = padded + image_np = np.asarray(image) + height, width = image_np.shape[:2] + if height != width: + square_edge = max(height, width) + pad_h = square_edge - height + pad_w = square_edge - width + image_np = cv2.copyMakeBorder( + image_np, + pad_h // 2, + pad_h - pad_h // 2, + pad_w // 2, + pad_w - pad_w // 2, + cv2.BORDER_CONSTANT, + value=(0, 0, 0), + ) resize_edge = shortest_image_edge or target_h - image = image.resize((resize_edge, resize_edge), Image.Resampling.BICUBIC) + if image_np.shape[:2] != (resize_edge, resize_edge): + image_np = 
cv2.resize(image_np, (resize_edge, resize_edge), interpolation=cv2.INTER_AREA) if crop_fraction is None and image_crop_size is not None: crop_fraction = image_crop_size[0] / float(target_h) if crop_fraction is not None and 0.0 < crop_fraction < 1.0: - crop_w = max(1, int(round(image.width * crop_fraction))) - crop_h = max(1, int(round(image.height * crop_fraction))) - left = max(0, (image.width - crop_w) // 2) - top = max(0, (image.height - crop_h) // 2) - image = image.crop((left, top, left + crop_w, top + crop_h)) + height, width = image_np.shape[:2] + crop_h = max(1, int(height * crop_fraction)) + crop_w = max(1, int(width * crop_fraction)) + top = max(0, (height - crop_h) // 2) + left = max(0, (width - crop_w) // 2) + image_np = image_np[top : top + crop_h, left : left + crop_w] - if image.size != (target_w, target_h): - image = image.resize((target_w, target_h), Image.Resampling.BICUBIC) + if image_np.shape[:2] != (target_h, target_w): + image_np = cv2.resize(image_np, (target_w, target_h), interpolation=cv2.INTER_AREA) + return Image.fromarray(image_np) + + +def _transform_n1_7_image_for_vlm_torch( + image: torch.Tensor, + *, + image_crop_size: list[int] | None, + image_target_size: list[int] | None, + shortest_image_edge: int | None, + crop_fraction: float | None, +) -> torch.Tensor: + """Default (non-albumentations) N1.7 image transform: pad-to-square, resize to + ``shortest_image_edge``, center-crop by ``crop_fraction``, resize to ``image_target_size``. + + Operates on a ``(C, H, W)`` uint8 tensor and keeps the result on the input + tensor's device so the resize/crop run on GPU when the tensor is. Bicubic + interpolation with antialiasing matches PIL's ``Image.Resampling.BICUBIC`` + closely (sub-``2/255`` per-pixel on worst-case inputs). The ``use_albumentations`` + cv2/INTER_AREA path has no torch equivalent and stays on + :func:`_transform_n1_7_image_for_vlm_albumentations`. 
+ """ + if image_target_size is None: + return image + + target_h, target_w = image_target_size + _, height, width = image.shape + + square_edge = max(height, width) + if height != width: + left = (square_edge - width) // 2 + top = (square_edge - height) // 2 + image = tv_functional.pad( + image, [left, top, square_edge - width - left, square_edge - height - top], fill=0 + ) + + resize_edge = shortest_image_edge or target_h + image = tv_functional.resize( + image, [resize_edge, resize_edge], interpolation=InterpolationMode.BICUBIC, antialias=True + ) + + if crop_fraction is None and image_crop_size is not None: + crop_fraction = image_crop_size[0] / float(target_h) + if crop_fraction is not None and 0.0 < crop_fraction < 1.0: + # Match the PIL helper's center crop exactly: round() the crop size but + # floor() the offset (torchvision.center_crop rounds the offset, which + # shifts the region by 1px when (edge - crop) is odd). + crop_h = max(1, int(round(image.shape[-2] * crop_fraction))) + crop_w = max(1, int(round(image.shape[-1] * crop_fraction))) + top = max(0, (image.shape[-2] - crop_h) // 2) + left = max(0, (image.shape[-1] - crop_w) // 2) + image = image[..., top : top + crop_h, left : left + crop_w] + + if tuple(image.shape[-2:]) != (target_h, target_w): + image = tv_functional.resize( + image, [target_h, target_w], interpolation=InterpolationMode.BICUBIC, antialias=True + ) return image @@ -1280,6 +1338,12 @@ class GrootN17VLMEncodeStep(ProcessorStep): The packed video has shape ``(B, T, V, H, W, C)``. Each frame/view becomes an image item in the same chat message so the resulting image tokens match the temporal VLM packing used by Isaac-GR00T. + + Images are handed to the torchvision-backed Qwen3-VL processor as ``(C, H, W)`` + uint8 tensors (no per-frame PIL roundtrip), and, when ``device`` resolves to a + CUDA device, the resize/rescale/normalize/patchify run there. 
This keeps the + output bit-identical on CPU when no resize is configured (a resize differs from PIL bicubic by + under 2/255 per pixel) and moves the dominant preprocessing cost off the critical path on GPU. """ model_name: str = GROOT_N1_7_BACKBONE_MODEL @@ -1288,6 +1352,7 @@ class GrootN17VLMEncodeStep(ProcessorStep): shortest_image_edge: int | None = None crop_fraction: float | None = None use_albumentations: bool = False + device: str | None = None _proc: ProcessorMixin | None = field(default=None, init=False, repr=False) @property @@ -1296,6 +1361,69 @@ class GrootN17VLMEncodeStep(ProcessorStep): self._proc = _build_n1_7_processor(self.model_name) return self._proc + def _target_device(self) -> torch.device | None: + # The albumentations path is cv2/PIL only, so it cannot run on GPU. + if self.device is None or self.use_albumentations: + return None + try: + return get_safe_torch_device(self.device) + except (AssertionError, RuntimeError): + # A device serialized at train time (e.g. "cuda") may be unavailable + # when the processor is reloaded elsewhere (e.g. CPU-only eval), and + # this step is not in the standard device-override set. Fall back to + # the CPU path, which is bit-identical, instead of crashing. + return None + + def _build_sample_images( + self, video: Any, batch_size: int, target_device: torch.device | None + ) -> list[list[Any]]: + """Return, per batch item, its ordered ``(timestep, view)`` frames. + + ``use_albumentations`` keeps the legacy per-frame PIL/cv2 transform; + otherwise frames are ``(C, H, W)`` uint8 tensors (moved to + ``target_device`` when set) for the torchvision-backed Qwen processor. 
+ """ + if self.use_albumentations: + video_np = np.asarray(video) + return [ + [ + _transform_n1_7_image_for_vlm_albumentations( + Image.fromarray(video_np[batch_idx, timestep, view_idx]), + image_crop_size=self.image_crop_size, + image_target_size=self.image_target_size, + shortest_image_edge=self.shortest_image_edge, + crop_fraction=self.crop_fraction, + ) + for timestep in range(video_np.shape[1]) + for view_idx in range(video_np.shape[2]) + ] + for batch_idx in range(batch_size) + ] + + video_t = video if torch.is_tensor(video) else torch.from_numpy(np.ascontiguousarray(video)) + # (B, T, V, H, W, C) uint8 -> (B, T, V, C, H, W) + video_t = video_t.permute(0, 1, 2, 5, 3, 4).contiguous() + if target_device is not None and video_t.device != target_device: + video_t = video_t.to(target_device, non_blocking=(target_device.type == "cuda")) + + frames_per_sample: list[list[Any]] = [] + for batch_idx in range(batch_size): + sample = video_t[batch_idx] # (T, V, C, H, W) + frames_per_sample.append( + [ + _transform_n1_7_image_for_vlm_torch( + sample[timestep, view_idx], + image_crop_size=self.image_crop_size, + image_target_size=self.image_target_size, + shortest_image_edge=self.shortest_image_edge, + crop_fraction=self.crop_fraction, + ) + for timestep in range(sample.shape[0]) + for view_idx in range(sample.shape[1]) + ] + ) + return frames_per_sample + def __call__(self, transition: EnvTransition) -> EnvTransition: obs = transition.get(TransitionKey.OBSERVATION, {}) or {} comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {} @@ -1303,33 +1431,25 @@ class GrootN17VLMEncodeStep(ProcessorStep): if video is None: return transition + batch_size = int(video.shape[0]) languages = _prepare_n1_7_language_batch( comp.get("language"), - video.shape[0], + batch_size, formalize_language=False, ) + target_device = self._target_device() + sample_images = self._build_sample_images(video, batch_size, target_device) + texts: list[str] = [] - images: list[Image.Image] = [] 
- for batch_idx in range(video.shape[0]): - sample = video[batch_idx] # (T, V, H, W, C) - sample_images = [ - _transform_n1_7_image_for_vlm( - Image.fromarray(sample[timestep, view_idx]), - image_crop_size=self.image_crop_size, - image_target_size=self.image_target_size, - shortest_image_edge=self.shortest_image_edge, - crop_fraction=self.crop_fraction, - use_albumentations=self.use_albumentations, - ) - for timestep in range(sample.shape[0]) - for view_idx in range(sample.shape[1]) - ] + images: list[Any] = [] + for batch_idx in range(batch_size): + frames = sample_images[batch_idx] conversation = [ { "role": "user", "content": [ - *[{"type": "image", "image": image} for image in sample_images], + *[{"type": "image", "image": image} for image in frames], {"type": "text", "text": languages[batch_idx]}, ], } @@ -1341,9 +1461,17 @@ class GrootN17VLMEncodeStep(ProcessorStep): add_generation_prompt=False, ) ) - images.extend(sample_images) + images.extend(frames) - encoded = self.proc(text=texts, images=images, return_tensors="pt", padding=True) + proc_kwargs: dict[str, Any] = { + "text": texts, + "images": images, + "return_tensors": "pt", + "padding": True, + } + if target_device is not None: + proc_kwargs["device"] = str(target_device) + encoded = self.proc(**proc_kwargs) for key, value in encoded.items(): comp[key] = value obs.pop("video", None) @@ -1362,6 +1490,7 @@ class GrootN17VLMEncodeStep(ProcessorStep): "shortest_image_edge": self.shortest_image_edge, "crop_fraction": self.crop_fraction, "use_albumentations": self.use_albumentations, + "device": self.device, } diff --git a/tests/policies/groot/test_groot_n1_7.py b/tests/policies/groot/test_groot_n1_7.py index 5f06bb73e..dd6e8eb30 100644 --- a/tests/policies/groot/test_groot_n1_7.py +++ b/tests/policies/groot/test_groot_n1_7.py @@ -41,7 +41,7 @@ from lerobot.policies.groot.processor_groot import ( GrootN17ActionDecodeStep, GrootN17PackInputsStep, GrootN17VLMEncodeStep, - _transform_n1_7_image_for_vlm, + 
_transform_n1_7_image_for_vlm_albumentations, make_groot_pre_post_processors, ) from lerobot.processor import ( @@ -1529,13 +1529,12 @@ def test_groot_n1_7_vlm_image_transform_matches_albumentations_eval_path(): image_np = (np.arange(360 * 360 * 3, dtype=np.uint32) % 251).astype(np.uint8).reshape(360, 360, 3) - transformed = _transform_n1_7_image_for_vlm( + transformed = _transform_n1_7_image_for_vlm_albumentations( Image.fromarray(image_np), image_crop_size=[230, 230], image_target_size=[256, 256], shortest_image_edge=256, crop_fraction=0.95, - use_albumentations=True, ) expected = cv2.resize(image_np, (256, 256), interpolation=cv2.INTER_AREA)