diff --git a/src/lerobot/policies/groot/configuration_groot.py b/src/lerobot/policies/groot/configuration_groot.py index f108e6f06..9ff6c11e6 100644 --- a/src/lerobot/policies/groot/configuration_groot.py +++ b/src/lerobot/policies/groot/configuration_groot.py @@ -42,6 +42,14 @@ GROOT_N1_5_REMOVAL_GUIDANCE = ( ) GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B" GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B" +# Image preprocessing geometry the GR00T N1.7 backbone was trained on. The processor +# falls back to these when a checkpoint ships no image sizing in its processor_config +# (e.g. fine-tuning the raw nvidia/GR00T-N1.7-3B base with a new embodiment), so frames +# are resized to the expected resolution instead of being patchified at full camera +# resolution (which both slows training and is a train/checkpoint distribution mismatch). +# Mirrored by GR00T_N1_7_DEFAULTS in groot_n1_7.py. +N1_7_DEFAULT_IMAGE_TARGET_SIZE = (256, 256) +N1_7_DEFAULT_IMAGE_CROP_SIZE = (230, 230) GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero" # Sentinel meaning "the user did not pick an action decode transform": __post_init__ resolves it # to the embodiment default ('libero' for 'libero_sim', otherwise None). It is distinct from an diff --git a/src/lerobot/policies/groot/groot_n1_7.py b/src/lerobot/policies/groot/groot_n1_7.py index e062e2c5c..ecf8fe8c7 100644 --- a/src/lerobot/policies/groot/groot_n1_7.py +++ b/src/lerobot/policies/groot/groot_n1_7.py @@ -32,6 +32,7 @@ from torch.distributions import Beta from lerobot.utils.import_utils import _transformers_available, require_package from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionTransformer +from .configuration_groot import N1_7_DEFAULT_IMAGE_CROP_SIZE, N1_7_DEFAULT_IMAGE_TARGET_SIZE if TYPE_CHECKING or _transformers_available: from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel @@ -76,8 +77,8 @@ GR00T_N1_7_DEFAULTS: dict[str, Any] = { "use_flash_attention": True, "load_bf16": False, "backbone_trainable_params_fp32": True, - "image_crop_size": (230, 230), - "image_target_size": (256, 256), + "image_crop_size": N1_7_DEFAULT_IMAGE_CROP_SIZE, + "image_target_size": N1_7_DEFAULT_IMAGE_TARGET_SIZE, "shortest_image_edge": None, "crop_fraction": None, "random_rotation_angle": None, diff --git a/src/lerobot/policies/groot/processor_groot.py b/src/lerobot/policies/groot/processor_groot.py index ee20d6471..da1ef16ac 100644 --- a/src/lerobot/policies/groot/processor_groot.py +++ b/src/lerobot/policies/groot/processor_groot.py @@ -23,8 +23,10 @@ from typing import TYPE_CHECKING, Any import numpy as np import torch +import torchvision.transforms.v2.functional as tv_functional from einops import rearrange from PIL import Image +from torchvision.transforms import InterpolationMode from lerobot.utils.import_utils import _transformers_available @@ -57,11 +59,14 @@ from lerobot.utils.constants import ( POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME, ) +from lerobot.utils.device_utils import get_safe_torch_device from .configuration_groot import ( GROOT_ACTION_DECODE_TRANSFORM_LIBERO, GROOT_N1_5_REMOVAL_GUIDANCE, GROOT_N1_7_BACKBONE_MODEL, + N1_7_DEFAULT_IMAGE_CROP_SIZE, + N1_7_DEFAULT_IMAGE_TARGET_SIZE, GrootConfig, is_raw_groot_n1_7_checkpoint, ) @@ -729,21 +734,40 @@ def make_groot_pre_post_processors( modality_config=checkpoint_assets.modality_config if checkpoint_assets is not None else None, ) + # Resolve the image preprocessing geometry. Honor the checkpoint's processor_config + # when it provides an image_target_size; otherwise fall back to the geometry the + # N1.7 backbone was trained on. Without this fallback a raw base checkpoint with no + # processor_config image sizing (e.g. fine-tuning nvidia/GR00T-N1.7-3B with a new + # embodiment, where checkpoint_assets is None) would patchify full-resolution camera + # frames, inflating the VLM token count -- slowing both dataloading_s and update_s -- + # and feeding the model a resolution it was not trained on. + if checkpoint_assets is not None and checkpoint_assets.image_target_size is not None: + image_target_size = checkpoint_assets.image_target_size + image_crop_size = checkpoint_assets.image_crop_size + shortest_image_edge = checkpoint_assets.shortest_image_edge + crop_fraction = checkpoint_assets.crop_fraction + else: + image_target_size = list(N1_7_DEFAULT_IMAGE_TARGET_SIZE) + image_crop_size = list(N1_7_DEFAULT_IMAGE_CROP_SIZE) + shortest_image_edge = None + crop_fraction = None + use_albumentations = checkpoint_assets.use_albumentations if checkpoint_assets is not None else False + input_steps: list[ProcessorStep] = [ RenameObservationsProcessorStep(rename_map={}), AddBatchDimensionProcessorStep(), pack_step, GrootN17VLMEncodeStep( model_name=config.n1_7_backbone_model, - image_crop_size=checkpoint_assets.image_crop_size if checkpoint_assets is not None else None, - image_target_size=checkpoint_assets.image_target_size if checkpoint_assets is not None else None, - shortest_image_edge=checkpoint_assets.shortest_image_edge - if checkpoint_assets is not None - else None, - crop_fraction=checkpoint_assets.crop_fraction if checkpoint_assets is not None else None, - use_albumentations=checkpoint_assets.use_albumentations - if checkpoint_assets is not None - else False, + image_crop_size=image_crop_size, + image_target_size=image_target_size, + shortest_image_edge=shortest_image_edge, + crop_fraction=crop_fraction, + use_albumentations=use_albumentations, + # Run the image resize/normalize/patchify on the training device when + # possible instead of the single CPU main-loop thread (the dominant + # cost folded into dataloading_s). + device=config.device, ), DeviceProcessorStep(device=config.device), ] @@ -982,6 +1006,61 @@ def _transform_n1_7_image_for_vlm( return image +def _transform_n1_7_image_for_vlm_torch( + image: torch.Tensor, + *, + image_crop_size: list[int] | None, + image_target_size: list[int] | None, + shortest_image_edge: int | None, + crop_fraction: float | None, +) -> torch.Tensor: + """Torch/torchvision port of the non-albumentations branch of + :func:`_transform_n1_7_image_for_vlm`. + + Operates on a ``(C, H, W)`` uint8 tensor and keeps the result on the input + tensor's device so the resize/crop run on GPU when the tensor is. Bicubic + interpolation with antialiasing matches PIL's ``Image.Resampling.BICUBIC`` + closely (sub-``2/255`` per-pixel on worst-case inputs). The ``use_albumentations`` + cv2/INTER_AREA path has no torch equivalent and stays on the PIL helper. + """ + if image_target_size is None: + return image + + target_h, target_w = image_target_size + _, height, width = image.shape + + square_edge = max(height, width) + if height != width: + left = (square_edge - width) // 2 + top = (square_edge - height) // 2 + image = tv_functional.pad( + image, [left, top, square_edge - width - left, square_edge - height - top], fill=0 + ) + + resize_edge = shortest_image_edge or target_h + image = tv_functional.resize( + image, [resize_edge, resize_edge], interpolation=InterpolationMode.BICUBIC, antialias=True + ) + + if crop_fraction is None and image_crop_size is not None: + crop_fraction = image_crop_size[0] / float(target_h) + if crop_fraction is not None and 0.0 < crop_fraction < 1.0: + # Match the PIL helper's center crop exactly: round() the crop size but + # floor() the offset (torchvision.center_crop rounds the offset, which + # shifts the region by 1px when (edge - crop) is odd). + crop_h = max(1, int(round(image.shape[-2] * crop_fraction))) + crop_w = max(1, int(round(image.shape[-1] * crop_fraction))) + top = max(0, (image.shape[-2] - crop_h) // 2) + left = max(0, (image.shape[-1] - crop_w) // 2) + image = image[..., top : top + crop_h, left : left + crop_w] + + if tuple(image.shape[-2:]) != (target_h, target_w): + image = tv_functional.resize( + image, [target_h, target_w], interpolation=InterpolationMode.BICUBIC, antialias=True + ) + return image + + @dataclass @ProcessorStepRegistry.register(name="groot_n1_7_pack_inputs_v1") class GrootN17PackInputsStep(ProcessorStep): @@ -1280,6 +1359,12 @@ class GrootN17VLMEncodeStep(ProcessorStep): The packed video has shape ``(B, T, V, H, W, C)``. Each frame/view becomes an image item in the same chat message so the resulting image tokens match the temporal VLM packing used by Isaac-GR00T. + + Images are handed to the torchvision-backed Qwen3-VL processor as ``(C, H, W)`` + uint8 tensors (no per-frame PIL roundtrip), and, when ``device`` resolves to a + CUDA device, the resize/rescale/normalize/patchify run there instead of on the + single CPU main-loop thread. This keeps the output bit-identical on CPU and + moves the dominant preprocessing cost off the critical path on GPU. """ model_name: str = GROOT_N1_7_BACKBONE_MODEL @@ -1288,6 +1373,7 @@ class GrootN17VLMEncodeStep(ProcessorStep): shortest_image_edge: int | None = None crop_fraction: float | None = None use_albumentations: bool = False + device: str | None = None _proc: ProcessorMixin | None = field(default=None, init=False, repr=False) @property @@ -1296,6 +1382,70 @@ class GrootN17VLMEncodeStep(ProcessorStep): self._proc = _build_n1_7_processor(self.model_name) return self._proc + def _target_device(self) -> torch.device | None: + # The albumentations path is cv2/PIL only, so it cannot run on GPU. + if self.device is None or self.use_albumentations: + return None + try: + return get_safe_torch_device(self.device) + except (AssertionError, RuntimeError): + # A device serialized at train time (e.g. "cuda") may be unavailable + # when the processor is reloaded elsewhere (e.g. CPU-only eval), and + # this step is not in the standard device-override set. Fall back to + # the CPU path, which is bit-identical, instead of crashing. + return None + + def _build_sample_images( + self, video: Any, batch_size: int, target_device: torch.device | None + ) -> list[list[Any]]: + """Return, per batch item, its ordered ``(timestep, view)`` frames. + + ``use_albumentations`` keeps the legacy per-frame PIL/cv2 transform; + otherwise frames are ``(C, H, W)`` uint8 tensors (moved to + ``target_device`` when set) for the torchvision-backed Qwen processor. + """ + if self.use_albumentations: + video_np = np.asarray(video) + return [ + [ + _transform_n1_7_image_for_vlm( + Image.fromarray(video_np[batch_idx, timestep, view_idx]), + image_crop_size=self.image_crop_size, + image_target_size=self.image_target_size, + shortest_image_edge=self.shortest_image_edge, + crop_fraction=self.crop_fraction, + use_albumentations=True, + ) + for timestep in range(video_np.shape[1]) + for view_idx in range(video_np.shape[2]) + ] + for batch_idx in range(batch_size) + ] + + video_t = video if torch.is_tensor(video) else torch.from_numpy(np.ascontiguousarray(video)) + # (B, T, V, H, W, C) uint8 -> (B, T, V, C, H, W) + video_t = video_t.permute(0, 1, 2, 5, 3, 4).contiguous() + if target_device is not None and video_t.device != target_device: + video_t = video_t.to(target_device, non_blocking=(target_device.type == "cuda")) + + frames_per_sample: list[list[Any]] = [] + for batch_idx in range(batch_size): + sample = video_t[batch_idx] # (T, V, C, H, W) + frames_per_sample.append( + [ + _transform_n1_7_image_for_vlm_torch( + sample[timestep, view_idx], + image_crop_size=self.image_crop_size, + image_target_size=self.image_target_size, + shortest_image_edge=self.shortest_image_edge, + crop_fraction=self.crop_fraction, + ) + for timestep in range(sample.shape[0]) + for view_idx in range(sample.shape[1]) + ] + ) + return frames_per_sample + def __call__(self, transition: EnvTransition) -> EnvTransition: obs = transition.get(TransitionKey.OBSERVATION, {}) or {} comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {} @@ -1303,33 +1453,25 @@ class GrootN17VLMEncodeStep(ProcessorStep): if video is None: return transition + batch_size = int(video.shape[0]) languages = _prepare_n1_7_language_batch( comp.get("language"), - video.shape[0], + batch_size, formalize_language=False, ) + target_device = self._target_device() + sample_images = self._build_sample_images(video, batch_size, target_device) + texts: list[str] = [] - images: list[Image.Image] = [] - for batch_idx in range(video.shape[0]): - sample = video[batch_idx] # (T, V, H, W, C) - sample_images = [ - _transform_n1_7_image_for_vlm( - Image.fromarray(sample[timestep, view_idx]), - image_crop_size=self.image_crop_size, - image_target_size=self.image_target_size, - shortest_image_edge=self.shortest_image_edge, - crop_fraction=self.crop_fraction, - use_albumentations=self.use_albumentations, - ) - for timestep in range(sample.shape[0]) - for view_idx in range(sample.shape[1]) - ] + images: list[Any] = [] + for batch_idx in range(batch_size): + frames = sample_images[batch_idx] conversation = [ { "role": "user", "content": [ - *[{"type": "image", "image": image} for image in sample_images], + *[{"type": "image", "image": image} for image in frames], {"type": "text", "text": languages[batch_idx]}, ], } @@ -1341,9 +1483,17 @@ class GrootN17VLMEncodeStep(ProcessorStep): add_generation_prompt=False, ) ) - images.extend(sample_images) + images.extend(frames) - encoded = self.proc(text=texts, images=images, return_tensors="pt", padding=True) + proc_kwargs: dict[str, Any] = { + "text": texts, + "images": images, + "return_tensors": "pt", + "padding": True, + } + if target_device is not None: + proc_kwargs["device"] = str(target_device) + encoded = self.proc(**proc_kwargs) for key, value in encoded.items(): comp[key] = value obs.pop("video", None) @@ -1362,6 +1512,7 @@ class GrootN17VLMEncodeStep(ProcessorStep): "shortest_image_edge": self.shortest_image_edge, "crop_fraction": self.crop_fraction, "use_albumentations": self.use_albumentations, + "device": self.device, }