feat(groot): train-time random crop for N1.7 (eval keeps center crop)

Isaac-GR00T crops a random crop_fraction window during training and the
deterministic center window at eval, replaying the sampled window across all
camera views of a sample. This contract is unchanged since the N1.5 release
(gr00t/data/transform/video.py: "If mode is 'train', return a random crop
transform. If mode is 'eval', return a center crop transform.") and mirrors
LeRobot's own Diffusion/VQBeT crop_is_random pattern. The LeRobot N1.7 port
used the eval center crop for training too, so the fine-tuned projector/DiT
never sees frame borders and trains on a single fixed appearance point.

Scope: crop geometry ONLY - no color jitter, no new dependencies. The random
window is plain numpy slicing inside the existing cv2 eval transform:

- _transform_n1_7_image_for_vlm_albumentations gains crop_position=(y, x)
  fractions; None keeps the center crop byte-identical to before (verified
  by test)
- GrootN17VLMEncodeStep gains a runtime-only 'training' flag (never
  serialized; reloaded pipelines default to eval); training samples ONE
  window per sample and reuses it across (timestep, view) frames - Isaac's
  cross-view consistency
- gated on torch.is_grad_enabled() so no_grad validation and frozen-eval
  paths are unaffected
- wired via dataset_meta is not None in make_groot_pre_post_processors and
  the existing _set_groot_preprocessor_training on serialized reloads

Verification: tests/policies/groot/test_groot_train_random_crop.py (8 passed:
center-crop bit-exactness with crop_position=None, corner/center windows,
cross-view replay, train!=eval, no_grad gating, seed reproducibility,
serialization contract) + groot suite 23 passed / 5 skipped on RTX PRO 6000 /
CUDA 13.3.
This commit is contained in:
johnnynunez
2026-07-02 03:17:47 +02:00
parent 459d416bbf
commit f53490c15e
2 changed files with 199 additions and 17 deletions
+43 -17
View File
@@ -1225,6 +1225,7 @@ def make_groot_pre_post_processors(
crop_fraction=crop_fraction,
use_albumentations=use_albumentations,
letter_box_transform=letter_box_transform,
training=dataset_meta is not None,
device=config.device,
),
DeviceProcessorStep(device=config.device),
@@ -1350,6 +1351,7 @@ def _transform_n1_7_image_for_vlm_albumentations(
shortest_image_edge: int | None,
crop_fraction: float | None,
letter_box_transform: bool = False,
crop_position: tuple[float, float] | None = None,
) -> np.ndarray:
"""cv2/INTER_AREA eval transform mirroring Isaac-GR00T's albumentations preprocessing.
@@ -1359,6 +1361,12 @@ def _transform_n1_7_image_for_vlm_albumentations(
cv2/INTER_AREA resize and floored center-crop here intentionally differ from that
torch path and must stay bit-exact to the upstream reference. The hot path accepts
and returns numpy arrays to avoid per-frame PIL round-trips.
``crop_position`` selects where the ``crop_fraction`` window sits: ``None``
keeps the deterministic center crop (eval contract), while ``(y, x)``
fractions in [0, 1] place the window for Isaac's train-time random crop
(0.5, 0.5 == center). Training samples one position per sample and reuses
it across camera views.
"""
if image_target_size is None:
return image
@@ -1410,8 +1418,13 @@ def _transform_n1_7_image_for_vlm_albumentations(
height, width = image_np.shape[:2]
crop_h = max(1, int(height * crop_fraction))
crop_w = max(1, int(width * crop_fraction))
top = max(0, (height - crop_h) // 2)
left = max(0, (width - crop_w) // 2)
if crop_position is None:
top = max(0, (height - crop_h) // 2)
left = max(0, (width - crop_w) // 2)
else:
pos_y, pos_x = crop_position
top = int(round((height - crop_h) * min(max(pos_y, 0.0), 1.0)))
left = int(round((width - crop_w) * min(max(pos_x, 0.0), 1.0)))
image_np = image_np[top : top + crop_h, left : left + crop_w]
return resize_shortest_edge(image_np)
@@ -2007,6 +2020,11 @@ class GrootN17VLMEncodeStep(ProcessorStep):
crop_fraction: float | None = None
use_albumentations: bool = False
letter_box_transform: bool = False
# Runtime-only train/eval mode: True enables Isaac's train-time random crop
# (one window per sample, replayed across views); False keeps the
# deterministic center crop. Never serialized - reloaded pipelines default
# to eval and are re-enabled only when processors are built with dataset_meta.
training: bool = False
device: str | None = None
_proc: ProcessorMixin | None = field(default=None, init=False, repr=False)
@@ -2040,21 +2058,29 @@ class GrootN17VLMEncodeStep(ProcessorStep):
"""
if self.use_albumentations:
video_np = np.asarray(video)
return [
[
_transform_n1_7_image_for_vlm_albumentations(
video_np[batch_idx, timestep, view_idx],
image_crop_size=self.image_crop_size,
image_target_size=self.image_target_size,
shortest_image_edge=self.shortest_image_edge,
crop_fraction=self.crop_fraction,
letter_box_transform=self.letter_box_transform,
)
for timestep in range(video_np.shape[1])
for view_idx in range(video_np.shape[2])
]
for batch_idx in range(batch_size)
]
train_crop = self.training and torch.is_grad_enabled()
sample_images: list[list[Any]] = []
for batch_idx in range(batch_size):
# Isaac-GR00T samples ONE crop window per sample and replays it
# across every (timestep, view) frame of that sample, keeping
# cross-view geometry consistent. Eval keeps the center crop.
crop_position = (random.random(), random.random()) if train_crop else None
sample_images.append(
[
_transform_n1_7_image_for_vlm_albumentations(
video_np[batch_idx, timestep, view_idx],
image_crop_size=self.image_crop_size,
image_target_size=self.image_target_size,
shortest_image_edge=self.shortest_image_edge,
crop_fraction=self.crop_fraction,
letter_box_transform=self.letter_box_transform,
crop_position=crop_position,
)
for timestep in range(video_np.shape[1])
for view_idx in range(video_np.shape[2])
]
)
return sample_images
video_t = video if torch.is_tensor(video) else torch.from_numpy(np.ascontiguousarray(video))
# (B, T, V, H, W, C) uint8 -> (B, T, V, C, H, W)
@@ -0,0 +1,156 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Isaac-GR00T N1.7 train-time random crop contract (crop geometry only).
Isaac-GR00T crops a random ``crop_fraction`` window during training and the
deterministic center window at eval, replaying the sampled window across all
camera views of a sample (gr00t/data/transform/video.py, n1.5-release onward:
"If mode is 'train', return a random crop transform. If mode is 'eval', return
a center crop transform."). This mirrors LeRobot's own Diffusion/VQBeT
``crop_is_random`` pattern. Color jitter is intentionally out of scope here.
"""
import random
import numpy as np
import torch
from lerobot.policies.groot.processor_groot import (
GrootN17VLMEncodeStep,
_transform_n1_7_image_for_vlm_albumentations,
)
def _structured_image(h=480, w=640):
yy, xx = np.mgrid[0:h, 0:w]
return np.stack(
[(xx * 255 / w), (yy * 255 / h), ((xx + yy) * 255 / (h + w))], axis=-1
).astype(np.uint8)
def test_crop_position_none_is_bitexact_center_crop():
"""crop_position=None must remain byte-identical to the pre-change eval path."""
img = _structured_image()
ref = _transform_n1_7_image_for_vlm_albumentations(
img, image_crop_size=None, image_target_size=[256, 256],
shortest_image_edge=256, crop_fraction=0.95,
)
out = _transform_n1_7_image_for_vlm_albumentations(
img, image_crop_size=None, image_target_size=[256, 256],
shortest_image_edge=256, crop_fraction=0.95, crop_position=None,
)
np.testing.assert_array_equal(ref, out)
def test_crop_position_center_matches_center_crop():
img = _structured_image()
center = _transform_n1_7_image_for_vlm_albumentations(
img, image_crop_size=None, image_target_size=[256, 256],
shortest_image_edge=256, crop_fraction=0.95, crop_position=None,
)
explicit = _transform_n1_7_image_for_vlm_albumentations(
img, image_crop_size=None, image_target_size=[256, 256],
shortest_image_edge=256, crop_fraction=0.95, crop_position=(0.5, 0.5),
)
# int-floor center vs rounded positional center may differ by <=1 px of grid
assert center.shape == explicit.shape
diff = np.abs(center.astype(np.int16) - explicit.astype(np.int16))
assert diff.mean() < 3.0
def test_crop_position_corners_differ_from_center():
img = _structured_image()
def crop_at(position):
return _transform_n1_7_image_for_vlm_albumentations(
img,
image_crop_size=None,
image_target_size=[256, 256],
shortest_image_edge=256,
crop_fraction=0.95,
crop_position=position,
)
center = crop_at(None)
tl = crop_at((0.0, 0.0))
br = crop_at((1.0, 1.0))
assert not np.array_equal(center, tl)
assert not np.array_equal(tl, br)
def _video(img, views=2):
return np.stack([img] * views, axis=0).reshape(1, 1, views, *img.shape)
def _step(training):
return GrootN17VLMEncodeStep(
image_target_size=[256, 256],
shortest_image_edge=256,
crop_fraction=0.95,
use_albumentations=True,
training=training,
)
def test_training_crop_replays_one_window_across_views():
video = _video(_structured_image())
frames = _step(training=True)._build_sample_images(video, batch_size=1, target_device=None)[0]
np.testing.assert_array_equal(np.asarray(frames[0]), np.asarray(frames[1]))
def test_training_crop_differs_from_eval_center_crop():
video = _video(_structured_image())
random.seed(3) # a draw that is not the exact center
train_frame = np.asarray(
_step(training=True)._build_sample_images(video, batch_size=1, target_device=None)[0][0]
)
eval_frame = np.asarray(
_step(training=False)._build_sample_images(video, batch_size=1, target_device=None)[0][0]
)
assert not np.array_equal(train_frame, eval_frame)
def test_training_crop_is_disabled_under_no_grad():
video = _video(_structured_image())
with torch.no_grad():
no_grad_frame = np.asarray(
_step(training=True)._build_sample_images(video, batch_size=1, target_device=None)[0][0]
)
eval_frame = np.asarray(
_step(training=False)._build_sample_images(video, batch_size=1, target_device=None)[0][0]
)
np.testing.assert_array_equal(no_grad_frame, eval_frame)
def test_training_mode_is_not_serialized():
step = _step(training=True)
serialized = step.get_config()
assert "training" not in serialized
restored = GrootN17VLMEncodeStep(**serialized)
assert restored.training is False
def test_training_crop_respects_global_seed():
video = _video(_structured_image())
def draw():
random.seed(11)
return np.asarray(
_step(training=True)._build_sample_images(video, batch_size=1, target_device=None)[0][0]
)
np.testing.assert_array_equal(draw(), draw())