fix(groot): make N1.7 letterbox opt-in

This commit is contained in:
Andy Wrenn
2026-06-30 12:25:28 -07:00
parent c74eb20abd
commit da9ce79678
2 changed files with 97 additions and 9 deletions
+38 -9
View File
@@ -143,6 +143,7 @@ class _GrootN17CheckpointProcessorAssets:
shortest_image_edge: int | None
crop_fraction: float | None
use_albumentations: bool
letter_box_transform: bool
@dataclass(frozen=True)
@@ -199,6 +200,9 @@ def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17Chec
use_albumentations = processor_kwargs.get("use_albumentations", False)
if not isinstance(use_albumentations, bool):
use_albumentations = False
letter_box_transform = processor_kwargs.get("letter_box_transform", False)
if not isinstance(letter_box_transform, bool):
letter_box_transform = False
valid_action_horizon = _load_n1_7_checkpoint_action_horizon(processor_kwargs, config.embodiment_tag)
video_horizon = _load_n1_7_checkpoint_video_horizon(processor_kwargs, config.embodiment_tag)
@@ -225,6 +229,7 @@ def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17Chec
shortest_image_edge=as_optional_int(processor_kwargs.get("shortest_image_edge")),
crop_fraction=as_optional_float(processor_kwargs.get("crop_fraction")),
use_albumentations=use_albumentations,
letter_box_transform=letter_box_transform,
)
@@ -1058,6 +1063,7 @@ def _build_n1_7_relative_action_processor_assets(
shortest_image_edge=base_assets.shortest_image_edge if base_assets is not None else None,
crop_fraction=base_assets.crop_fraction if base_assets is not None else None,
use_albumentations=base_assets.use_albumentations if base_assets is not None else False,
letter_box_transform=base_assets.letter_box_transform if base_assets is not None else False,
)
@@ -1179,6 +1185,7 @@ def make_groot_pre_post_processors(
shortest_image_edge = None
crop_fraction = None
use_albumentations = checkpoint_assets.use_albumentations if checkpoint_assets is not None else False
letter_box_transform = checkpoint_assets.letter_box_transform if checkpoint_assets is not None else False
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}),
@@ -1191,6 +1198,7 @@ def make_groot_pre_post_processors(
shortest_image_edge=shortest_image_edge,
crop_fraction=crop_fraction,
use_albumentations=use_albumentations,
letter_box_transform=letter_box_transform,
device=config.device,
),
DeviceProcessorStep(device=config.device),
@@ -1315,6 +1323,7 @@ def _transform_n1_7_image_for_vlm_albumentations(
image_target_size: list[int] | None,
shortest_image_edge: int | None,
crop_fraction: float | None,
letter_box_transform: bool = False,
) -> np.ndarray:
"""cv2/INTER_AREA eval transform mirroring Isaac-GR00T's albumentations preprocessing.
@@ -1339,6 +1348,18 @@ def _transform_n1_7_image_for_vlm_albumentations(
if not image_np.flags.c_contiguous:
image_np = np.ascontiguousarray(image_np)
if letter_box_transform:
height, width = image_np.shape[:2]
if height != width:
square_edge = max(height, width)
pad_h = square_edge - height
pad_w = square_edge - width
top = pad_h // 2
bottom = pad_h - top
left = pad_w // 2
right = pad_w - left
image_np = cv2.copyMakeBorder(image_np, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)
resize_edge = shortest_image_edge or target_h
def resize_shortest_edge(frame: np.ndarray) -> np.ndarray:
@@ -1377,9 +1398,12 @@ def _transform_n1_7_image_for_vlm_torch(
image_target_size: list[int] | None,
shortest_image_edge: int | None,
crop_fraction: float | None,
letter_box_transform: bool = False,
) -> torch.Tensor:
"""Default (non-albumentations) N1.7 image transform: pad-to-square, resize to
``shortest_image_edge``, center-crop by ``crop_fraction``, resize to ``image_target_size``.
"""Default (non-albumentations) N1.7 image transform.
Optionally pads to square, then resizes to ``shortest_image_edge``, center-crops
by ``crop_fraction``, and resizes to ``image_target_size``.
Operates on a ``(C, H, W)`` uint8 tensor and keeps the result on the input
tensor's device so the resize/crop run on GPU when the tensor is. Bicubic
@@ -1394,13 +1418,14 @@ def _transform_n1_7_image_for_vlm_torch(
target_h, target_w = image_target_size
_, height, width = image.shape
square_edge = max(height, width)
if height != width:
left = (square_edge - width) // 2
top = (square_edge - height) // 2
image = tv_functional.pad(
image, [left, top, square_edge - width - left, square_edge - height - top], fill=0
)
if letter_box_transform:
square_edge = max(height, width)
if height != width:
left = (square_edge - width) // 2
top = (square_edge - height) // 2
image = tv_functional.pad(
image, [left, top, square_edge - width - left, square_edge - height - top], fill=0
)
resize_edge = shortest_image_edge or target_h
image = tv_functional.resize(
@@ -1945,6 +1970,7 @@ class GrootN17VLMEncodeStep(ProcessorStep):
shortest_image_edge: int | None = None
crop_fraction: float | None = None
use_albumentations: bool = False
letter_box_transform: bool = False
device: str | None = None
_proc: ProcessorMixin | None = field(default=None, init=False, repr=False)
@@ -1986,6 +2012,7 @@ class GrootN17VLMEncodeStep(ProcessorStep):
image_target_size=self.image_target_size,
shortest_image_edge=self.shortest_image_edge,
crop_fraction=self.crop_fraction,
letter_box_transform=self.letter_box_transform,
)
for timestep in range(video_np.shape[1])
for view_idx in range(video_np.shape[2])
@@ -2010,6 +2037,7 @@ class GrootN17VLMEncodeStep(ProcessorStep):
image_target_size=self.image_target_size,
shortest_image_edge=self.shortest_image_edge,
crop_fraction=self.crop_fraction,
letter_box_transform=self.letter_box_transform,
)
for timestep in range(sample.shape[0])
for view_idx in range(sample.shape[1])
@@ -2083,6 +2111,7 @@ class GrootN17VLMEncodeStep(ProcessorStep):
"shortest_image_edge": self.shortest_image_edge,
"crop_fraction": self.crop_fraction,
"use_albumentations": self.use_albumentations,
"letter_box_transform": self.letter_box_transform,
"device": self.device,
}
+59
View File
@@ -46,6 +46,7 @@ from lerobot.policies.groot.processor_groot import (
N1_7_NATIVE_ACTION_HORIZON,
_make_relative_action_training_stats,
_transform_n1_7_image_for_vlm_albumentations,
_transform_n1_7_image_for_vlm_torch,
make_groot_pre_post_processors,
)
from lerobot.processor import (
@@ -245,6 +246,7 @@ def _write_raw_n1_7_libero_checkpoint(path):
"shortest_image_edge": 256,
"crop_fraction": 0.95,
"use_albumentations": True,
"letter_box_transform": False,
"max_action_horizon": 40,
"max_state_dim": 132,
"max_action_dim": 132,
@@ -609,6 +611,7 @@ def test_raw_n1_7_libero_checkpoint_processors_use_checkpoint_assets(tmp_path):
assert vlm_encode.shortest_image_edge == 256
assert vlm_encode.crop_fraction == 0.95
assert vlm_encode.use_albumentations is True
assert vlm_encode.letter_box_transform is False
assert decode_actions.raw_stats["action"]["gripper"]["q99"] == [115.0]
assert decode_actions.env_action_dim == 7
assert decode_actions.use_percentiles is True
@@ -682,6 +685,7 @@ def test_groot_n1_7_saved_processors_round_trip_checkpoint_specific_fields(tmp_p
config_filename="policy_postprocessor.json",
)
pack_inputs = next(step for step in loaded_preprocessor.steps if isinstance(step, GrootN17PackInputsStep))
vlm_encode = next(step for step in loaded_preprocessor.steps if isinstance(step, GrootN17VLMEncodeStep))
decode_actions = next(
step for step in loaded_postprocessor.steps if isinstance(step, GrootN17ActionDecodeStep)
)
@@ -690,6 +694,7 @@ def test_groot_n1_7_saved_processors_round_trip_checkpoint_specific_fields(tmp_p
assert pack_inputs.action_horizon == 40
assert pack_inputs.video_modality_keys == ["image", "wrist_image"]
assert pack_inputs.clip_outliers is True
assert vlm_encode.letter_box_transform is False
torch.testing.assert_close(
pack_inputs.stats[OBS_STATE]["min"],
torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
@@ -1858,6 +1863,58 @@ def test_groot_n1_7_vlm_image_transform_matches_albumentations_eval_path():
np.testing.assert_array_equal(np.asarray(transformed), expected)
def test_groot_n1_7_albumentations_letterbox_is_opt_in():
    """Check that the cv2-based N1.7 transform only letterboxes when asked to.

    A solid-white frame of shape ``(3, 5, 3)`` is pushed through the transform
    twice with identical settings; only ``letter_box_transform`` differs.
    The test uses a sub-255 minimum in the output as the signal that padding
    pixels were introduced.
    """
    pytest.importorskip("cv2", exc_type=ImportError)

    white_frame = np.full((3, 5, 3), 255, dtype=np.uint8)

    def run_transform(**overrides):
        # Keyword arguments are rebuilt on every call so both runs receive
        # their own ``image_target_size`` list, as separate literals would.
        return _transform_n1_7_image_for_vlm_albumentations(
            white_frame,
            image_crop_size=None,
            image_target_size=[10, 10],
            shortest_image_edge=10,
            crop_fraction=None,
            **overrides,
        )

    without_letterbox = run_transform()
    with_letterbox = run_transform(letter_box_transform=True)

    # Flag left at its default: output is not square and stays fully white.
    assert without_letterbox.shape == (10, 17, 3)
    assert without_letterbox.min() == 255

    # Flag enabled: output is square and contains non-white pixels.
    assert with_letterbox.shape == (10, 10, 3)
    assert with_letterbox.min() < 255
def test_groot_n1_7_torch_letterbox_is_opt_in():
    """Check that the torch N1.7 transform only letterboxes when asked to.

    A solid-white uint8 tensor of shape ``(3, 3, 5)`` is pushed through the
    transform twice with identical settings; only ``letter_box_transform``
    differs. Both outputs are expected to be ``(3, 10, 10)``; the test uses a
    sub-255 minimum as the signal that padding pixels were introduced.
    """
    white_frame = torch.full((3, 3, 5), 255, dtype=torch.uint8)

    def run_transform(**overrides):
        # Keyword arguments are rebuilt on every call so both runs receive
        # their own ``image_target_size`` list, as separate literals would.
        return _transform_n1_7_image_for_vlm_torch(
            white_frame,
            image_crop_size=None,
            image_target_size=[10, 10],
            shortest_image_edge=10,
            crop_fraction=None,
            **overrides,
        )

    without_letterbox = run_transform()
    with_letterbox = run_transform(letter_box_transform=True)

    # Flag left at its default: target-sized output that stays fully white.
    assert tuple(without_letterbox.shape) == (3, 10, 10)
    assert int(without_letterbox.min()) == 255

    # Flag enabled: same output size, but non-white pixels are present.
    assert tuple(with_letterbox.shape) == (3, 10, 10)
    assert int(with_letterbox.min()) < 255
def test_groot_n1_7_vlm_encode_transforms_non_square_two_camera_sample_like_core_albumentations():
cv2 = pytest.importorskip("cv2", exc_type=ImportError)
@@ -1928,6 +1985,7 @@ def test_groot_n1_7_vlm_encode_config_round_trips_model_name():
shortest_image_edge=256,
crop_fraction=0.95,
use_albumentations=True,
letter_box_transform=True,
)
restored = GrootN17VLMEncodeStep(**step.get_config())
@@ -1938,6 +1996,7 @@ def test_groot_n1_7_vlm_encode_config_round_trips_model_name():
assert restored.shortest_image_edge == 256
assert restored.crop_fraction == 0.95
assert restored.use_albumentations is True
assert restored.letter_box_transform is True
def test_groot_n1_7_processor_uses_qwen_component_assets(monkeypatch):