mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-05 17:17:01 +00:00
fix(groot): make N1.7 letterbox opt-in
This commit is contained in:
@@ -143,6 +143,7 @@ class _GrootN17CheckpointProcessorAssets:
|
||||
shortest_image_edge: int | None
|
||||
crop_fraction: float | None
|
||||
use_albumentations: bool
|
||||
letter_box_transform: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -199,6 +200,9 @@ def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17Chec
|
||||
use_albumentations = processor_kwargs.get("use_albumentations", False)
|
||||
if not isinstance(use_albumentations, bool):
|
||||
use_albumentations = False
|
||||
letter_box_transform = processor_kwargs.get("letter_box_transform", False)
|
||||
if not isinstance(letter_box_transform, bool):
|
||||
letter_box_transform = False
|
||||
|
||||
valid_action_horizon = _load_n1_7_checkpoint_action_horizon(processor_kwargs, config.embodiment_tag)
|
||||
video_horizon = _load_n1_7_checkpoint_video_horizon(processor_kwargs, config.embodiment_tag)
|
||||
@@ -225,6 +229,7 @@ def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17Chec
|
||||
shortest_image_edge=as_optional_int(processor_kwargs.get("shortest_image_edge")),
|
||||
crop_fraction=as_optional_float(processor_kwargs.get("crop_fraction")),
|
||||
use_albumentations=use_albumentations,
|
||||
letter_box_transform=letter_box_transform,
|
||||
)
|
||||
|
||||
|
||||
@@ -1058,6 +1063,7 @@ def _build_n1_7_relative_action_processor_assets(
|
||||
shortest_image_edge=base_assets.shortest_image_edge if base_assets is not None else None,
|
||||
crop_fraction=base_assets.crop_fraction if base_assets is not None else None,
|
||||
use_albumentations=base_assets.use_albumentations if base_assets is not None else False,
|
||||
letter_box_transform=base_assets.letter_box_transform if base_assets is not None else False,
|
||||
)
|
||||
|
||||
|
||||
@@ -1179,6 +1185,7 @@ def make_groot_pre_post_processors(
|
||||
shortest_image_edge = None
|
||||
crop_fraction = None
|
||||
use_albumentations = checkpoint_assets.use_albumentations if checkpoint_assets is not None else False
|
||||
letter_box_transform = checkpoint_assets.letter_box_transform if checkpoint_assets is not None else False
|
||||
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
@@ -1191,6 +1198,7 @@ def make_groot_pre_post_processors(
|
||||
shortest_image_edge=shortest_image_edge,
|
||||
crop_fraction=crop_fraction,
|
||||
use_albumentations=use_albumentations,
|
||||
letter_box_transform=letter_box_transform,
|
||||
device=config.device,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
@@ -1315,6 +1323,7 @@ def _transform_n1_7_image_for_vlm_albumentations(
|
||||
image_target_size: list[int] | None,
|
||||
shortest_image_edge: int | None,
|
||||
crop_fraction: float | None,
|
||||
letter_box_transform: bool = False,
|
||||
) -> np.ndarray:
|
||||
"""cv2/INTER_AREA eval transform mirroring Isaac-GR00T's albumentations preprocessing.
|
||||
|
||||
@@ -1339,6 +1348,18 @@ def _transform_n1_7_image_for_vlm_albumentations(
|
||||
if not image_np.flags.c_contiguous:
|
||||
image_np = np.ascontiguousarray(image_np)
|
||||
|
||||
if letter_box_transform:
|
||||
height, width = image_np.shape[:2]
|
||||
if height != width:
|
||||
square_edge = max(height, width)
|
||||
pad_h = square_edge - height
|
||||
pad_w = square_edge - width
|
||||
top = pad_h // 2
|
||||
bottom = pad_h - top
|
||||
left = pad_w // 2
|
||||
right = pad_w - left
|
||||
image_np = cv2.copyMakeBorder(image_np, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)
|
||||
|
||||
resize_edge = shortest_image_edge or target_h
|
||||
|
||||
def resize_shortest_edge(frame: np.ndarray) -> np.ndarray:
|
||||
@@ -1377,9 +1398,12 @@ def _transform_n1_7_image_for_vlm_torch(
|
||||
image_target_size: list[int] | None,
|
||||
shortest_image_edge: int | None,
|
||||
crop_fraction: float | None,
|
||||
letter_box_transform: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Default (non-albumentations) N1.7 image transform: pad-to-square, resize to
|
||||
``shortest_image_edge``, center-crop by ``crop_fraction``, resize to ``image_target_size``.
|
||||
"""Default (non-albumentations) N1.7 image transform.
|
||||
|
||||
Optionally pads to square, then resizes to ``shortest_image_edge``, center-crops
|
||||
by ``crop_fraction``, and resizes to ``image_target_size``.
|
||||
|
||||
Operates on a ``(C, H, W)`` uint8 tensor and keeps the result on the input
|
||||
tensor's device so the resize/crop run on GPU when the tensor is. Bicubic
|
||||
@@ -1394,13 +1418,14 @@ def _transform_n1_7_image_for_vlm_torch(
|
||||
target_h, target_w = image_target_size
|
||||
_, height, width = image.shape
|
||||
|
||||
square_edge = max(height, width)
|
||||
if height != width:
|
||||
left = (square_edge - width) // 2
|
||||
top = (square_edge - height) // 2
|
||||
image = tv_functional.pad(
|
||||
image, [left, top, square_edge - width - left, square_edge - height - top], fill=0
|
||||
)
|
||||
if letter_box_transform:
|
||||
square_edge = max(height, width)
|
||||
if height != width:
|
||||
left = (square_edge - width) // 2
|
||||
top = (square_edge - height) // 2
|
||||
image = tv_functional.pad(
|
||||
image, [left, top, square_edge - width - left, square_edge - height - top], fill=0
|
||||
)
|
||||
|
||||
resize_edge = shortest_image_edge or target_h
|
||||
image = tv_functional.resize(
|
||||
@@ -1945,6 +1970,7 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
shortest_image_edge: int | None = None
|
||||
crop_fraction: float | None = None
|
||||
use_albumentations: bool = False
|
||||
letter_box_transform: bool = False
|
||||
device: str | None = None
|
||||
_proc: ProcessorMixin | None = field(default=None, init=False, repr=False)
|
||||
|
||||
@@ -1986,6 +2012,7 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
image_target_size=self.image_target_size,
|
||||
shortest_image_edge=self.shortest_image_edge,
|
||||
crop_fraction=self.crop_fraction,
|
||||
letter_box_transform=self.letter_box_transform,
|
||||
)
|
||||
for timestep in range(video_np.shape[1])
|
||||
for view_idx in range(video_np.shape[2])
|
||||
@@ -2010,6 +2037,7 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
image_target_size=self.image_target_size,
|
||||
shortest_image_edge=self.shortest_image_edge,
|
||||
crop_fraction=self.crop_fraction,
|
||||
letter_box_transform=self.letter_box_transform,
|
||||
)
|
||||
for timestep in range(sample.shape[0])
|
||||
for view_idx in range(sample.shape[1])
|
||||
@@ -2083,6 +2111,7 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
"shortest_image_edge": self.shortest_image_edge,
|
||||
"crop_fraction": self.crop_fraction,
|
||||
"use_albumentations": self.use_albumentations,
|
||||
"letter_box_transform": self.letter_box_transform,
|
||||
"device": self.device,
|
||||
}
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ from lerobot.policies.groot.processor_groot import (
|
||||
N1_7_NATIVE_ACTION_HORIZON,
|
||||
_make_relative_action_training_stats,
|
||||
_transform_n1_7_image_for_vlm_albumentations,
|
||||
_transform_n1_7_image_for_vlm_torch,
|
||||
make_groot_pre_post_processors,
|
||||
)
|
||||
from lerobot.processor import (
|
||||
@@ -245,6 +246,7 @@ def _write_raw_n1_7_libero_checkpoint(path):
|
||||
"shortest_image_edge": 256,
|
||||
"crop_fraction": 0.95,
|
||||
"use_albumentations": True,
|
||||
"letter_box_transform": False,
|
||||
"max_action_horizon": 40,
|
||||
"max_state_dim": 132,
|
||||
"max_action_dim": 132,
|
||||
@@ -609,6 +611,7 @@ def test_raw_n1_7_libero_checkpoint_processors_use_checkpoint_assets(tmp_path):
|
||||
assert vlm_encode.shortest_image_edge == 256
|
||||
assert vlm_encode.crop_fraction == 0.95
|
||||
assert vlm_encode.use_albumentations is True
|
||||
assert vlm_encode.letter_box_transform is False
|
||||
assert decode_actions.raw_stats["action"]["gripper"]["q99"] == [115.0]
|
||||
assert decode_actions.env_action_dim == 7
|
||||
assert decode_actions.use_percentiles is True
|
||||
@@ -682,6 +685,7 @@ def test_groot_n1_7_saved_processors_round_trip_checkpoint_specific_fields(tmp_p
|
||||
config_filename="policy_postprocessor.json",
|
||||
)
|
||||
pack_inputs = next(step for step in loaded_preprocessor.steps if isinstance(step, GrootN17PackInputsStep))
|
||||
vlm_encode = next(step for step in loaded_preprocessor.steps if isinstance(step, GrootN17VLMEncodeStep))
|
||||
decode_actions = next(
|
||||
step for step in loaded_postprocessor.steps if isinstance(step, GrootN17ActionDecodeStep)
|
||||
)
|
||||
@@ -690,6 +694,7 @@ def test_groot_n1_7_saved_processors_round_trip_checkpoint_specific_fields(tmp_p
|
||||
assert pack_inputs.action_horizon == 40
|
||||
assert pack_inputs.video_modality_keys == ["image", "wrist_image"]
|
||||
assert pack_inputs.clip_outliers is True
|
||||
assert vlm_encode.letter_box_transform is False
|
||||
torch.testing.assert_close(
|
||||
pack_inputs.stats[OBS_STATE]["min"],
|
||||
torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
|
||||
@@ -1858,6 +1863,58 @@ def test_groot_n1_7_vlm_image_transform_matches_albumentations_eval_path():
|
||||
np.testing.assert_array_equal(np.asarray(transformed), expected)
|
||||
|
||||
|
||||
def test_groot_n1_7_albumentations_letterbox_is_opt_in():
|
||||
pytest.importorskip("cv2", exc_type=ImportError)
|
||||
|
||||
image = np.full((3, 5, 3), 255, dtype=np.uint8)
|
||||
|
||||
default = _transform_n1_7_image_for_vlm_albumentations(
|
||||
image,
|
||||
image_crop_size=None,
|
||||
image_target_size=[10, 10],
|
||||
shortest_image_edge=10,
|
||||
crop_fraction=None,
|
||||
)
|
||||
letterboxed = _transform_n1_7_image_for_vlm_albumentations(
|
||||
image,
|
||||
image_crop_size=None,
|
||||
image_target_size=[10, 10],
|
||||
shortest_image_edge=10,
|
||||
crop_fraction=None,
|
||||
letter_box_transform=True,
|
||||
)
|
||||
|
||||
assert default.shape == (10, 17, 3)
|
||||
assert default.min() == 255
|
||||
assert letterboxed.shape == (10, 10, 3)
|
||||
assert letterboxed.min() < 255
|
||||
|
||||
|
||||
def test_groot_n1_7_torch_letterbox_is_opt_in():
|
||||
image = torch.full((3, 3, 5), 255, dtype=torch.uint8)
|
||||
|
||||
default = _transform_n1_7_image_for_vlm_torch(
|
||||
image,
|
||||
image_crop_size=None,
|
||||
image_target_size=[10, 10],
|
||||
shortest_image_edge=10,
|
||||
crop_fraction=None,
|
||||
)
|
||||
letterboxed = _transform_n1_7_image_for_vlm_torch(
|
||||
image,
|
||||
image_crop_size=None,
|
||||
image_target_size=[10, 10],
|
||||
shortest_image_edge=10,
|
||||
crop_fraction=None,
|
||||
letter_box_transform=True,
|
||||
)
|
||||
|
||||
assert tuple(default.shape) == (3, 10, 10)
|
||||
assert int(default.min()) == 255
|
||||
assert tuple(letterboxed.shape) == (3, 10, 10)
|
||||
assert int(letterboxed.min()) < 255
|
||||
|
||||
|
||||
def test_groot_n1_7_vlm_encode_transforms_non_square_two_camera_sample_like_core_albumentations():
|
||||
cv2 = pytest.importorskip("cv2", exc_type=ImportError)
|
||||
|
||||
@@ -1928,6 +1985,7 @@ def test_groot_n1_7_vlm_encode_config_round_trips_model_name():
|
||||
shortest_image_edge=256,
|
||||
crop_fraction=0.95,
|
||||
use_albumentations=True,
|
||||
letter_box_transform=True,
|
||||
)
|
||||
|
||||
restored = GrootN17VLMEncodeStep(**step.get_config())
|
||||
@@ -1938,6 +1996,7 @@ def test_groot_n1_7_vlm_encode_config_round_trips_model_name():
|
||||
assert restored.shortest_image_edge == 256
|
||||
assert restored.crop_fraction == 0.95
|
||||
assert restored.use_albumentations is True
|
||||
assert restored.letter_box_transform is True
|
||||
|
||||
|
||||
def test_groot_n1_7_processor_uses_qwen_component_assets(monkeypatch):
|
||||
|
||||
Reference in New Issue
Block a user