From 1fcc10079082753404c4810d7d4b5cb3e68d77fb Mon Sep 17 00:00:00 2001 From: Andy Wrenn Date: Sun, 28 Jun 2026 12:43:52 -0700 Subject: [PATCH] Match GR00T N1.7 OSS preprocessing and relative actions --- src/lerobot/policies/groot/groot_n1_7.py | 5 + src/lerobot/policies/groot/processor_groot.py | 41 ++- tests/policies/groot/test_groot_n1_7.py | 110 ++++++-- .../groot/test_groot_n1_7_oss_parity.py | 254 ++++++++++++++++++ 4 files changed, 360 insertions(+), 50 deletions(-) create mode 100644 tests/policies/groot/test_groot_n1_7_oss_parity.py diff --git a/src/lerobot/policies/groot/groot_n1_7.py b/src/lerobot/policies/groot/groot_n1_7.py index 517512237..5a49ceed2 100644 --- a/src/lerobot/policies/groot/groot_n1_7.py +++ b/src/lerobot/policies/groot/groot_n1_7.py @@ -255,6 +255,11 @@ class Qwen3Backbone(nn.Module): load_pretrained_weights: bool = True, ): require_package("transformers", extra="groot") + if Qwen3VLForConditionalGeneration is None: + raise ImportError( + "Qwen3VLForConditionalGeneration is required for GR00T N1.7. " + "Install a transformers version with Qwen3-VL support." + ) super().__init__() transformers_loading_kwargs = transformers_loading_kwargs or {"trust_remote_code": True} diff --git a/src/lerobot/policies/groot/processor_groot.py b/src/lerobot/policies/groot/processor_groot.py index 6c68aa693..62ad34094 100644 --- a/src/lerobot/policies/groot/processor_groot.py +++ b/src/lerobot/policies/groot/processor_groot.py @@ -552,6 +552,7 @@ def _reconnect_groot_n1_7_pack_decode_steps( step.pack_step = pack_step + def _resolve_feature_names_from_dataset_meta(dataset_meta: Any | None, feature_key: str) -> list[str] | None: features = getattr(dataset_meta, "features", {}) or {} feature = features.get(feature_key) if isinstance(features, dict) else None @@ -634,9 +635,10 @@ def _relative_action_chunks_by_horizon( if pad_mask is not None: mask = torch.as_tensor(pad_mask, dtype=torch.bool).cpu() if mask.ndim == 1 and batch_size == 1 and mask.numel() == horizon: - keep[0] = ~mask + keep[0, :] = not bool(mask.any()) elif mask.ndim == 2 and tuple(mask.shape) == (batch_size, horizon): - keep = ~mask + complete_chunks = ~mask.any(dim=1) + keep = complete_chunks[:, None].expand(batch_size, horizon).clone() chunks: list[list[np.ndarray]] = [[] for _ in range(horizon)] relative_np = relative_action.detach().cpu().numpy() @@ -1308,24 +1310,23 @@ def _transform_n1_7_image_for_vlm_albumentations( if not image_np.flags.c_contiguous: image_np = np.ascontiguousarray(image_np) - height, width = image_np.shape[:2] - if height != width: - square_edge = max(height, width) - pad_h = square_edge - height - pad_w = square_edge - width - image_np = cv2.copyMakeBorder( - image_np, - pad_h // 2, - pad_h - pad_h // 2, - pad_w // 2, - pad_w - pad_w // 2, - cv2.BORDER_CONSTANT, - value=(0, 0, 0), + resize_edge = shortest_image_edge or target_h + + def resize_shortest_edge(frame: np.ndarray) -> np.ndarray: + height, width = frame.shape[:2] + shortest_edge = min(height, width) + if shortest_edge == resize_edge: + return frame + scale = resize_edge / float(shortest_edge) + resized_height = max(1, int(round(height * scale))) + resized_width = max(1, int(round(width * scale))) + return cv2.resize( + frame, + (resized_width, resized_height), + interpolation=cv2.INTER_AREA, ) - resize_edge = shortest_image_edge or target_h - if image_np.shape[:2] != (resize_edge, resize_edge): - image_np = cv2.resize(image_np, (resize_edge, resize_edge), interpolation=cv2.INTER_AREA) + image_np = resize_shortest_edge(image_np) if crop_fraction is None and image_crop_size is not None: crop_fraction = image_crop_size[0] / float(target_h) @@ -1337,9 +1338,7 @@ def _transform_n1_7_image_for_vlm_albumentations( left = max(0, (width - crop_w) // 2) image_np = image_np[top : top + crop_h, left : left + crop_w] - if image_np.shape[:2] != (target_h, target_w): - image_np = cv2.resize(image_np, (target_w, target_h), interpolation=cv2.INTER_AREA) - return image_np + return resize_shortest_edge(image_np) def _transform_n1_7_image_for_vlm_torch( diff --git a/tests/policies/groot/test_groot_n1_7.py b/tests/policies/groot/test_groot_n1_7.py index 4a45393f8..6f31b5695 100644 --- a/tests/policies/groot/test_groot_n1_7.py +++ b/tests/policies/groot/test_groot_n1_7.py @@ -301,28 +301,22 @@ def _stats(values): def _expected_albumentations_eval_image(image_np, cv2, *, target_size, shortest_edge, crop_fraction): - height, width = image_np.shape[:2] - if height != width: - square_edge = max(height, width) - pad_h = square_edge - height - pad_w = square_edge - width - image_np = cv2.copyMakeBorder( - image_np, - pad_h // 2, - pad_h - pad_h // 2, - pad_w // 2, - pad_w - pad_w // 2, - cv2.BORDER_CONSTANT, - value=(0, 0, 0), - ) + del target_size - image_np = cv2.resize(image_np, (shortest_edge, shortest_edge), interpolation=cv2.INTER_AREA) - crop_h = max(1, int(shortest_edge * crop_fraction)) - crop_w = max(1, int(shortest_edge * crop_fraction)) - top = (shortest_edge - crop_h) // 2 - left = (shortest_edge - crop_w) // 2 - image_np = image_np[top : top + crop_h, left : left + crop_w] - return cv2.resize(image_np, (target_size[1], target_size[0]), interpolation=cv2.INTER_AREA) + def resize_shortest_edge(frame): + height, width = frame.shape[:2] + scale = shortest_edge / float(min(height, width)) + resized_height = max(1, int(round(height * scale))) + resized_width = max(1, int(round(width * scale))) + return cv2.resize(frame, (resized_width, resized_height), interpolation=cv2.INTER_AREA) + + image_np = resize_shortest_edge(image_np) + height, width = image_np.shape[:2] + crop_h = max(1, int(height * crop_fraction)) + crop_w = max(1, int(width * crop_fraction)) + top = (height - crop_h) // 2 + left = (width - crop_w) // 2 + return resize_shortest_edge(image_np[top : top + crop_h, left : left + crop_w]) class _DummyGrootModel(nn.Module): @@ -1588,7 +1582,9 @@ def test_groot_n1_7_vlm_encode_uses_per_sample_language(): self.encoded_texts = None def apply_chat_template(self, conversation, tokenize, add_generation_prompt): - text = conversation[0]["content"][-1]["text"] + content = conversation[0]["content"] + assert [item["type"] for item in content] == ["image", "text"] + text = content[-1]["text"] self.rendered_texts.append(text) return f"rendered:{text}" @@ -1626,6 +1622,7 @@ def test_groot_n1_7_vlm_encode_packs_images_time_major_then_camera_order(): class FakeProcessor: def __init__(self): self.add_generation_prompts = [] + self.conversation_content_types = [] self.conversation_image_values = [] self.conversation_texts = [] self.encoded_texts = None @@ -1635,6 +1632,7 @@ def test_groot_n1_7_vlm_encode_packs_images_time_major_then_camera_order(): assert tokenize is False self.add_generation_prompts.append(add_generation_prompt) content = conversation[0]["content"] + self.conversation_content_types.append([item["type"] for item in content]) self.conversation_image_values.append( [int(np.asarray(item["image"])[0, 0, 0]) for item in content if item["type"] == "image"] ) @@ -1672,6 +1670,10 @@ def test_groot_n1_7_vlm_encode_packs_images_time_major_then_camera_order(): output = step(transition) assert fake_proc.conversation_image_values == [[1, 2, 3, 4], [5, 6, 7, 8]] + assert fake_proc.conversation_content_types == [ + ["image", "image", "image", "image", "text"], + ["image", "image", "image", "image", "text"], + ] assert fake_proc.encoded_image_values == [1, 2, 3, 4, 5, 6, 7, 8] assert fake_proc.conversation_texts == ["task a", "task b"] assert fake_proc.encoded_texts == ["rendered:task a", "rendered:task b"] @@ -1705,7 +1707,7 @@ def test_groot_n1_7_vlm_image_transform_matches_albumentations_eval_path(): expected = expected[crop_start : crop_start + crop_edge, crop_start : crop_start + crop_edge] expected = cv2.resize(expected, (256, 256), interpolation=cv2.INTER_AREA) - assert transformed.size == (256, 256) + assert transformed.shape == (256, 256, 3) np.testing.assert_array_equal(np.asarray(transformed), expected) @@ -1717,7 +1719,9 @@ def test_groot_n1_7_vlm_encode_transforms_non_square_two_camera_sample_like_core self.images = None def apply_chat_template(self, conversation, tokenize, add_generation_prompt): - return conversation[0]["content"][-1]["text"] + content = conversation[0]["content"] + assert [item["type"] for item in content] == ["image", "image", "text"] + return content[-1]["text"] def __call__(self, text, images, return_tensors, padding): self.images = images @@ -1792,7 +1796,6 @@ def test_groot_n1_7_vlm_encode_config_round_trips_model_name(): def test_groot_n1_7_processor_uses_qwen_component_assets(monkeypatch): pytest.importorskip("transformers") - import transformers from lerobot.policies.groot import processor_groot @@ -1833,10 +1836,10 @@ def test_groot_n1_7_processor_uses_qwen_component_assets(monkeypatch): cls.from_pretrained_called = True raise AssertionError("Cosmos does not publish processor_config.json") - monkeypatch.setattr(transformers, "AutoTokenizer", FakeTokenizer) - monkeypatch.setattr(transformers, "Qwen2VLImageProcessorFast", FakeImageProcessor) - monkeypatch.setattr(transformers, "Qwen3VLVideoProcessor", FakeVideoProcessor) - monkeypatch.setattr(transformers, "Qwen3VLProcessor", FakeProcessor) + monkeypatch.setattr(processor_groot, "AutoTokenizer", FakeTokenizer) + monkeypatch.setattr(processor_groot, "Qwen2VLImageProcessor", FakeImageProcessor) + monkeypatch.setattr(processor_groot, "Qwen3VLVideoProcessor", FakeVideoProcessor) + monkeypatch.setattr(processor_groot, "Qwen3VLProcessor", FakeProcessor) processor = processor_groot._build_n1_7_processor("nvidia/Cosmos-Reason2-2B") @@ -2306,6 +2309,55 @@ def test_groot_n1_7_generated_relative_stats_match_oss_gr00t_reference_numbers() torch.testing.assert_close(decoded[TransitionKey.ACTION], action_a.unsqueeze(0), atol=1e-5, rtol=1e-5) +def test_groot_n1_7_relative_action_stats_skip_padded_tail_chunks(): + samples = [ + { + OBS_STATE: torch.tensor([10.0, 100.0]), + ACTION: torch.tensor([[11.0, 101.0], [12.0, 102.0], [13.0, 103.0]]), + f"{ACTION}_is_pad": torch.tensor([False, False, False]), + }, + { + OBS_STATE: torch.tensor([20.0, 200.0]), + ACTION: torch.tensor([[18.0, 198.0], [16.0, 196.0], [14.0, 194.0]]), + f"{ACTION}_is_pad": torch.tensor([False, False, False]), + }, + { + OBS_STATE: torch.tensor([0.0, 0.0]), + ACTION: torch.tensor([[999.0, 999.0], [888.0, 888.0], [777.0, 777.0]]), + f"{ACTION}_is_pad": torch.tensor([False, False, True]), + }, + ] + + class _Dataset: + meta = SimpleNamespace(stats={}) + + def __len__(self): + return len(samples) + + def __getitem__(self, idx): + return samples[idx] + + relative_dataset_stats = _make_relative_action_training_stats( + _Dataset(), + exclude_joints=[], + action_names=None, + preserve_action_horizon=True, + ) + + torch.testing.assert_close( + torch.as_tensor(relative_dataset_stats[ACTION]["count"]), + torch.tensor([2, 2, 2]), + ) + torch.testing.assert_close( + torch.as_tensor(relative_dataset_stats[ACTION]["min"]), + torch.tensor([[-2.0, -2.0], [-4.0, -4.0], [-6.0, -6.0]]), + ) + torch.testing.assert_close( + torch.as_tensor(relative_dataset_stats[ACTION]["max"]), + torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]), + ) + + def test_groot_policy_selects_n1_7_model_class(monkeypatch): from lerobot.policies.groot.groot_n1_7 import GR00TN17 diff --git a/tests/policies/groot/test_groot_n1_7_oss_parity.py b/tests/policies/groot/test_groot_n1_7_oss_parity.py new file mode 100644 index 000000000..c0d5beddc --- /dev/null +++ b/tests/policies/groot/test_groot_n1_7_oss_parity.py @@ -0,0 +1,254 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +import os +from pathlib import Path + +import numpy as np +import pytest +import torch +from transformers.feature_extraction_utils import BatchFeature + +from lerobot.policies.groot.action_head.cross_attention_dit import AlternateVLDiT +from lerobot.policies.groot.groot_n1_7 import GR00TN17 +from lerobot.policies.groot.processor_groot import ( + GrootN17ActionDecodeStep, + GrootN17PackInputsStep, + GrootN17VLMEncodeStep, + _transform_n1_7_image_for_vlm_albumentations, +) +from lerobot.types import TransitionKey +from lerobot.utils.constants import OBS_STATE + +OSS_REFERENCE_COMMIT = "ab88b50c718f6528e1df9dcbaf75865d1b604760" + + +def _fixture_path(filename: str) -> Path: + fixture_dir = os.environ.get("GROOT_N17_OSS_PARITY_FIXTURE_DIR") + if fixture_dir is None: + pytest.skip("Set GROOT_N17_OSS_PARITY_FIXTURE_DIR to run external OSS parity fixtures.") + path = Path(fixture_dir) / filename + if not path.is_file(): + pytest.skip(f"External OSS parity fixture not found: {path}") + return path + + +def test_groot_n1_7_eval_image_transform_matches_oss_reference(): + """Match the native N1.7 eval transform for a non-square SO-101 frame.""" + + y, x = np.indices((480, 640), dtype=np.uint16) + image = np.stack( + ((x + 3 * y) % 256, (2 * x + y) % 256, (x + 5 * y) % 256), + axis=-1, + ).astype(np.uint8) + actual = _transform_n1_7_image_for_vlm_albumentations( + image, + image_crop_size=[230, 230], + image_target_size=[256, 256], + shortest_image_edge=256, + crop_fraction=0.95, + ) + + assert actual.shape == (256, 340, 3) + assert hashlib.sha256(actual.tobytes()).hexdigest() == ( + "c17e47af68a812aa79db3bb7b64b549ddf10148ac1b204a9686095018561ae9e" + ) + + +def test_groot_n1_7_vlm_chat_content_order_matches_oss_reference(): + """Native OSS places all image items before the language item.""" + + class RecordingProcessor: + def __init__(self): + self.content_types = None + + def apply_chat_template(self, conversation, tokenize, add_generation_prompt): + assert tokenize is False + assert add_generation_prompt is False + self.content_types = [item["type"] for item in conversation[0]["content"]] + return "rendered" + + def __call__(self, **kwargs): + return {} + + processor = RecordingProcessor() + step = GrootN17VLMEncodeStep( + image_crop_size=[230, 230], + image_target_size=[256, 256], + shortest_image_edge=256, + crop_fraction=0.95, + use_albumentations=True, + device="cpu", + ) + step._proc = processor + transition = { + TransitionKey.OBSERVATION: { + "video": np.zeros((1, 1, 2, 480, 640, 3), dtype=np.uint8), + }, + TransitionKey.COMPLEMENTARY_DATA: {"language": ["pick up the vial"]}, + } + + step(transition) + + assert processor.content_types == ["image", "image", "text"] + + +def test_groot_n1_7_alternate_vl_dit_matches_oss_reference(): + """Run the LeRobot DiT with native OSS weights and identical inputs.""" + + fixture = torch.load(_fixture_path("alternate_vl_dit_small.pt"), map_location="cpu", weights_only=True) + model = AlternateVLDiT( + output_dim=8, + num_attention_heads=2, + attention_head_dim=4, + num_layers=4, + dropout=0.0, + final_dropout=False, + max_num_positional_embeddings=16, + compute_dtype=torch.float32, + interleave_self_attention=True, + cross_attention_dim=6, + ).eval() + model.load_state_dict(fixture["state_dict"], strict=True) + + actual = model( + hidden_states=fixture["hidden_states"], + encoder_hidden_states=fixture["encoder_hidden_states"], + timestep=fixture["timestep"], + image_mask=fixture["image_mask"], + backbone_attention_mask=fixture["backbone_attention_mask"], + ) + + torch.testing.assert_close(actual, fixture["output"], atol=1e-6, rtol=1e-6) + + +def _state_decode_reference(): + fixture = np.load(_fixture_path("state_and_action_decode.npz")) + raw_stats = { + "state": { + "single_arm": {"q01": fixture["state_single_arm_q01"], "q99": fixture["state_single_arm_q99"]}, + "gripper": {"q01": fixture["state_gripper_q01"], "q99": fixture["state_gripper_q99"]}, + }, + "action": { + "single_arm": {"q01": fixture["action_single_arm_q01"], "q99": fixture["action_single_arm_q99"]}, + "gripper": {"q01": fixture["action_gripper_q01"], "q99": fixture["action_gripper_q99"]}, + }, + "relative_action": { + "single_arm": { + "min": fixture["relative_single_arm_min"], + "max": fixture["relative_single_arm_max"], + }, + }, + } + for modality_stats in raw_stats.values(): + for entry in modality_stats.values(): + for key, value in entry.items(): + if isinstance(value, np.ndarray): + entry[key] = value.tolist() + modality_config = { + "state": {"modality_keys": ["single_arm", "gripper"]}, + "action": { + "delta_indices": list(range(16)), + "modality_keys": ["single_arm", "gripper"], + "action_configs": [ + {"rep": "RELATIVE", "type": "NON_EEF", "format": "DEFAULT", "state_key": None}, + {"rep": "ABSOLUTE", "type": "NON_EEF", "format": "DEFAULT", "state_key": None}, + ], + }, + } + state_min = np.concatenate((fixture["state_single_arm_q01"], fixture["state_gripper_q01"])) + state_max = np.concatenate((fixture["state_single_arm_q99"], fixture["state_gripper_q99"])) + pack_step = GrootN17PackInputsStep( + normalize_min_max=True, + stats={OBS_STATE: {"min": state_min, "max": state_max}}, + raw_stats=raw_stats, + modality_config=modality_config, + use_percentiles=True, + ) + raw_state = np.concatenate((fixture["state_single_arm"], fixture["state_gripper"]), axis=-1) + transition = { + TransitionKey.OBSERVATION: {OBS_STATE: torch.from_numpy(raw_state)}, + TransitionKey.COMPLEMENTARY_DATA: {}, + } + packed = pack_step(transition) + return fixture, raw_stats, modality_config, pack_step, packed + + +def test_groot_n1_7_state_normalization_matches_oss_checkpoint_reference(): + fixture, _raw_stats, _modality_config, _pack_step, packed = _state_decode_reference() + expected = np.concatenate( + (fixture["normalized_state_single_arm"], fixture["normalized_state_gripper"]), axis=-1 + ) + + actual = packed[TransitionKey.OBSERVATION]["state"][:, 0, :6] + + torch.testing.assert_close(actual, torch.from_numpy(expected), atol=1e-6, rtol=1e-6) + + +def test_groot_n1_7_relative_action_decode_matches_oss_checkpoint_reference(): + fixture, raw_stats, modality_config, pack_step, _packed = _state_decode_reference() + decode_step = GrootN17ActionDecodeStep( + env_action_dim=6, + raw_stats=raw_stats, + modality_config=modality_config, + use_percentiles=True, + use_relative_action=True, + pack_step=pack_step, + ) + decoded = decode_step({TransitionKey.ACTION: torch.from_numpy(fixture["normalized_action"])})[ + TransitionKey.ACTION + ] + expected = np.concatenate((fixture["decoded_single_arm"], fixture["decoded_gripper"]), axis=-1).astype( + np.float32 + ) + + torch.testing.assert_close(decoded, torch.from_numpy(expected), atol=1e-5, rtol=1e-5) + + +def test_groot_n1_7_qwen_backbone_matches_oss_checkpoint_reference(): + """Compare the actual 3B checkpoint backbone when explicitly enabled.""" + + checkpoint = os.environ.get("GROOT_N17_PARITY_CHECKPOINT") + if checkpoint is None: + pytest.skip("Set GROOT_N17_PARITY_CHECKPOINT to run the 3B OSS Qwen parity test.") + if not torch.cuda.is_available(): + pytest.skip("The 3B OSS Qwen parity test requires CUDA.") + + fixture = torch.load(_fixture_path("qwen_backbone_so101.pt"), map_location="cpu", weights_only=True) + model = GR00TN17.from_pretrained(checkpoint).to(device="cuda", dtype=torch.bfloat16).eval() + backbone_input = BatchFeature( + data={ + key.removeprefix("input."): value.to("cuda") + for key, value in fixture.items() + if key.startswith("input.") + } + ) + + with torch.inference_mode(): + actual = model.backbone(backbone_input) + + feature_error = ( + actual.backbone_features.cpu().float() - fixture["output.backbone_features"].float() + ).abs() + # Native OSS and LeRobot use different Torch/Transformers/Flash-Attention releases. + # Require the measured BF16 accumulation envelope while rejecting structural drift. + assert feature_error.mean().item() <= 0.04 + assert feature_error.max().item() <= 2.0 + torch.testing.assert_close( + actual.backbone_attention_mask.cpu(), fixture["output.backbone_attention_mask"] + ) + torch.testing.assert_close(actual.image_mask.cpu(), fixture["output.image_mask"])