Match GR00T N1.7 OSS preprocessing and relative actions

This commit is contained in:
Andy Wrenn
2026-06-28 12:43:52 -07:00
parent 6126a85d60
commit 1fcc100790
4 changed files with 360 additions and 50 deletions
+5
View File
@@ -255,6 +255,11 @@ class Qwen3Backbone(nn.Module):
load_pretrained_weights: bool = True,
):
require_package("transformers", extra="groot")
if Qwen3VLForConditionalGeneration is None:
raise ImportError(
"Qwen3VLForConditionalGeneration is required for GR00T N1.7. "
"Install a transformers version with Qwen3-VL support."
)
super().__init__()
transformers_loading_kwargs = transformers_loading_kwargs or {"trust_remote_code": True}
+20 -21
View File
@@ -552,6 +552,7 @@ def _reconnect_groot_n1_7_pack_decode_steps(
step.pack_step = pack_step
def _resolve_feature_names_from_dataset_meta(dataset_meta: Any | None, feature_key: str) -> list[str] | None:
features = getattr(dataset_meta, "features", {}) or {}
feature = features.get(feature_key) if isinstance(features, dict) else None
@@ -634,9 +635,10 @@ def _relative_action_chunks_by_horizon(
if pad_mask is not None:
mask = torch.as_tensor(pad_mask, dtype=torch.bool).cpu()
if mask.ndim == 1 and batch_size == 1 and mask.numel() == horizon:
keep[0] = ~mask
keep[0, :] = not bool(mask.any())
elif mask.ndim == 2 and tuple(mask.shape) == (batch_size, horizon):
keep = ~mask
complete_chunks = ~mask.any(dim=1)
keep = complete_chunks[:, None].expand(batch_size, horizon).clone()
chunks: list[list[np.ndarray]] = [[] for _ in range(horizon)]
relative_np = relative_action.detach().cpu().numpy()
@@ -1308,24 +1310,23 @@ def _transform_n1_7_image_for_vlm_albumentations(
if not image_np.flags.c_contiguous:
image_np = np.ascontiguousarray(image_np)
height, width = image_np.shape[:2]
if height != width:
square_edge = max(height, width)
pad_h = square_edge - height
pad_w = square_edge - width
image_np = cv2.copyMakeBorder(
image_np,
pad_h // 2,
pad_h - pad_h // 2,
pad_w // 2,
pad_w - pad_w // 2,
cv2.BORDER_CONSTANT,
value=(0, 0, 0),
resize_edge = shortest_image_edge or target_h
def resize_shortest_edge(frame: np.ndarray) -> np.ndarray:
height, width = frame.shape[:2]
shortest_edge = min(height, width)
if shortest_edge == resize_edge:
return frame
scale = resize_edge / float(shortest_edge)
resized_height = max(1, int(round(height * scale)))
resized_width = max(1, int(round(width * scale)))
return cv2.resize(
frame,
(resized_width, resized_height),
interpolation=cv2.INTER_AREA,
)
resize_edge = shortest_image_edge or target_h
if image_np.shape[:2] != (resize_edge, resize_edge):
image_np = cv2.resize(image_np, (resize_edge, resize_edge), interpolation=cv2.INTER_AREA)
image_np = resize_shortest_edge(image_np)
if crop_fraction is None and image_crop_size is not None:
crop_fraction = image_crop_size[0] / float(target_h)
@@ -1337,9 +1338,7 @@ def _transform_n1_7_image_for_vlm_albumentations(
left = max(0, (width - crop_w) // 2)
image_np = image_np[top : top + crop_h, left : left + crop_w]
if image_np.shape[:2] != (target_h, target_w):
image_np = cv2.resize(image_np, (target_w, target_h), interpolation=cv2.INTER_AREA)
return image_np
return resize_shortest_edge(image_np)
def _transform_n1_7_image_for_vlm_torch(
+81 -29
View File
@@ -301,28 +301,22 @@ def _stats(values):
def _expected_albumentations_eval_image(image_np, cv2, *, target_size, shortest_edge, crop_fraction):
height, width = image_np.shape[:2]
if height != width:
square_edge = max(height, width)
pad_h = square_edge - height
pad_w = square_edge - width
image_np = cv2.copyMakeBorder(
image_np,
pad_h // 2,
pad_h - pad_h // 2,
pad_w // 2,
pad_w - pad_w // 2,
cv2.BORDER_CONSTANT,
value=(0, 0, 0),
)
del target_size
image_np = cv2.resize(image_np, (shortest_edge, shortest_edge), interpolation=cv2.INTER_AREA)
crop_h = max(1, int(shortest_edge * crop_fraction))
crop_w = max(1, int(shortest_edge * crop_fraction))
top = (shortest_edge - crop_h) // 2
left = (shortest_edge - crop_w) // 2
image_np = image_np[top : top + crop_h, left : left + crop_w]
return cv2.resize(image_np, (target_size[1], target_size[0]), interpolation=cv2.INTER_AREA)
def resize_shortest_edge(frame):
height, width = frame.shape[:2]
scale = shortest_edge / float(min(height, width))
resized_height = max(1, int(round(height * scale)))
resized_width = max(1, int(round(width * scale)))
return cv2.resize(frame, (resized_width, resized_height), interpolation=cv2.INTER_AREA)
image_np = resize_shortest_edge(image_np)
height, width = image_np.shape[:2]
crop_h = max(1, int(height * crop_fraction))
crop_w = max(1, int(width * crop_fraction))
top = (height - crop_h) // 2
left = (width - crop_w) // 2
return resize_shortest_edge(image_np[top : top + crop_h, left : left + crop_w])
class _DummyGrootModel(nn.Module):
@@ -1588,7 +1582,9 @@ def test_groot_n1_7_vlm_encode_uses_per_sample_language():
self.encoded_texts = None
def apply_chat_template(self, conversation, tokenize, add_generation_prompt):
text = conversation[0]["content"][-1]["text"]
content = conversation[0]["content"]
assert [item["type"] for item in content] == ["image", "text"]
text = content[-1]["text"]
self.rendered_texts.append(text)
return f"rendered:{text}"
@@ -1626,6 +1622,7 @@ def test_groot_n1_7_vlm_encode_packs_images_time_major_then_camera_order():
class FakeProcessor:
def __init__(self):
self.add_generation_prompts = []
self.conversation_content_types = []
self.conversation_image_values = []
self.conversation_texts = []
self.encoded_texts = None
@@ -1635,6 +1632,7 @@ def test_groot_n1_7_vlm_encode_packs_images_time_major_then_camera_order():
assert tokenize is False
self.add_generation_prompts.append(add_generation_prompt)
content = conversation[0]["content"]
self.conversation_content_types.append([item["type"] for item in content])
self.conversation_image_values.append(
[int(np.asarray(item["image"])[0, 0, 0]) for item in content if item["type"] == "image"]
)
@@ -1672,6 +1670,10 @@ def test_groot_n1_7_vlm_encode_packs_images_time_major_then_camera_order():
output = step(transition)
assert fake_proc.conversation_image_values == [[1, 2, 3, 4], [5, 6, 7, 8]]
assert fake_proc.conversation_content_types == [
["image", "image", "image", "image", "text"],
["image", "image", "image", "image", "text"],
]
assert fake_proc.encoded_image_values == [1, 2, 3, 4, 5, 6, 7, 8]
assert fake_proc.conversation_texts == ["task a", "task b"]
assert fake_proc.encoded_texts == ["rendered:task a", "rendered:task b"]
@@ -1705,7 +1707,7 @@ def test_groot_n1_7_vlm_image_transform_matches_albumentations_eval_path():
expected = expected[crop_start : crop_start + crop_edge, crop_start : crop_start + crop_edge]
expected = cv2.resize(expected, (256, 256), interpolation=cv2.INTER_AREA)
assert transformed.size == (256, 256)
assert transformed.shape == (256, 256, 3)
np.testing.assert_array_equal(np.asarray(transformed), expected)
@@ -1717,7 +1719,9 @@ def test_groot_n1_7_vlm_encode_transforms_non_square_two_camera_sample_like_core
self.images = None
def apply_chat_template(self, conversation, tokenize, add_generation_prompt):
return conversation[0]["content"][-1]["text"]
content = conversation[0]["content"]
assert [item["type"] for item in content] == ["image", "image", "text"]
return content[-1]["text"]
def __call__(self, text, images, return_tensors, padding):
self.images = images
@@ -1792,7 +1796,6 @@ def test_groot_n1_7_vlm_encode_config_round_trips_model_name():
def test_groot_n1_7_processor_uses_qwen_component_assets(monkeypatch):
pytest.importorskip("transformers")
import transformers
from lerobot.policies.groot import processor_groot
@@ -1833,10 +1836,10 @@ def test_groot_n1_7_processor_uses_qwen_component_assets(monkeypatch):
cls.from_pretrained_called = True
raise AssertionError("Cosmos does not publish processor_config.json")
monkeypatch.setattr(transformers, "AutoTokenizer", FakeTokenizer)
monkeypatch.setattr(transformers, "Qwen2VLImageProcessorFast", FakeImageProcessor)
monkeypatch.setattr(transformers, "Qwen3VLVideoProcessor", FakeVideoProcessor)
monkeypatch.setattr(transformers, "Qwen3VLProcessor", FakeProcessor)
monkeypatch.setattr(processor_groot, "AutoTokenizer", FakeTokenizer)
monkeypatch.setattr(processor_groot, "Qwen2VLImageProcessor", FakeImageProcessor)
monkeypatch.setattr(processor_groot, "Qwen3VLVideoProcessor", FakeVideoProcessor)
monkeypatch.setattr(processor_groot, "Qwen3VLProcessor", FakeProcessor)
processor = processor_groot._build_n1_7_processor("nvidia/Cosmos-Reason2-2B")
@@ -2306,6 +2309,55 @@ def test_groot_n1_7_generated_relative_stats_match_oss_gr00t_reference_numbers()
torch.testing.assert_close(decoded[TransitionKey.ACTION], action_a.unsqueeze(0), atol=1e-5, rtol=1e-5)
def test_groot_n1_7_relative_action_stats_skip_padded_tail_chunks():
samples = [
{
OBS_STATE: torch.tensor([10.0, 100.0]),
ACTION: torch.tensor([[11.0, 101.0], [12.0, 102.0], [13.0, 103.0]]),
f"{ACTION}_is_pad": torch.tensor([False, False, False]),
},
{
OBS_STATE: torch.tensor([20.0, 200.0]),
ACTION: torch.tensor([[18.0, 198.0], [16.0, 196.0], [14.0, 194.0]]),
f"{ACTION}_is_pad": torch.tensor([False, False, False]),
},
{
OBS_STATE: torch.tensor([0.0, 0.0]),
ACTION: torch.tensor([[999.0, 999.0], [888.0, 888.0], [777.0, 777.0]]),
f"{ACTION}_is_pad": torch.tensor([False, False, True]),
},
]
class _Dataset:
meta = SimpleNamespace(stats={})
def __len__(self):
return len(samples)
def __getitem__(self, idx):
return samples[idx]
relative_dataset_stats = _make_relative_action_training_stats(
_Dataset(),
exclude_joints=[],
action_names=None,
preserve_action_horizon=True,
)
torch.testing.assert_close(
torch.as_tensor(relative_dataset_stats[ACTION]["count"]),
torch.tensor([2, 2, 2]),
)
torch.testing.assert_close(
torch.as_tensor(relative_dataset_stats[ACTION]["min"]),
torch.tensor([[-2.0, -2.0], [-4.0, -4.0], [-6.0, -6.0]]),
)
torch.testing.assert_close(
torch.as_tensor(relative_dataset_stats[ACTION]["max"]),
torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]),
)
def test_groot_policy_selects_n1_7_model_class(monkeypatch):
from lerobot.policies.groot.groot_n1_7 import GR00TN17
@@ -0,0 +1,254 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import os
from pathlib import Path
import numpy as np
import pytest
import torch
from transformers.feature_extraction_utils import BatchFeature
from lerobot.policies.groot.action_head.cross_attention_dit import AlternateVLDiT
from lerobot.policies.groot.groot_n1_7 import GR00TN17
from lerobot.policies.groot.processor_groot import (
GrootN17ActionDecodeStep,
GrootN17PackInputsStep,
GrootN17VLMEncodeStep,
_transform_n1_7_image_for_vlm_albumentations,
)
from lerobot.types import TransitionKey
from lerobot.utils.constants import OBS_STATE
OSS_REFERENCE_COMMIT = "ab88b50c718f6528e1df9dcbaf75865d1b604760"
def _fixture_path(filename: str) -> Path:
fixture_dir = os.environ.get("GROOT_N17_OSS_PARITY_FIXTURE_DIR")
if fixture_dir is None:
pytest.skip("Set GROOT_N17_OSS_PARITY_FIXTURE_DIR to run external OSS parity fixtures.")
path = Path(fixture_dir) / filename
if not path.is_file():
pytest.skip(f"External OSS parity fixture not found: {path}")
return path
def test_groot_n1_7_eval_image_transform_matches_oss_reference():
"""Match the native N1.7 eval transform for a non-square SO-101 frame."""
y, x = np.indices((480, 640), dtype=np.uint16)
image = np.stack(
((x + 3 * y) % 256, (2 * x + y) % 256, (x + 5 * y) % 256),
axis=-1,
).astype(np.uint8)
actual = _transform_n1_7_image_for_vlm_albumentations(
image,
image_crop_size=[230, 230],
image_target_size=[256, 256],
shortest_image_edge=256,
crop_fraction=0.95,
)
assert actual.shape == (256, 340, 3)
assert hashlib.sha256(actual.tobytes()).hexdigest() == (
"c17e47af68a812aa79db3bb7b64b549ddf10148ac1b204a9686095018561ae9e"
)
def test_groot_n1_7_vlm_chat_content_order_matches_oss_reference():
"""Native OSS places all image items before the language item."""
class RecordingProcessor:
def __init__(self):
self.content_types = None
def apply_chat_template(self, conversation, tokenize, add_generation_prompt):
assert tokenize is False
assert add_generation_prompt is False
self.content_types = [item["type"] for item in conversation[0]["content"]]
return "rendered"
def __call__(self, **kwargs):
return {}
processor = RecordingProcessor()
step = GrootN17VLMEncodeStep(
image_crop_size=[230, 230],
image_target_size=[256, 256],
shortest_image_edge=256,
crop_fraction=0.95,
use_albumentations=True,
device="cpu",
)
step._proc = processor
transition = {
TransitionKey.OBSERVATION: {
"video": np.zeros((1, 1, 2, 480, 640, 3), dtype=np.uint8),
},
TransitionKey.COMPLEMENTARY_DATA: {"language": ["pick up the vial"]},
}
step(transition)
assert processor.content_types == ["image", "image", "text"]
def test_groot_n1_7_alternate_vl_dit_matches_oss_reference():
"""Run the LeRobot DiT with native OSS weights and identical inputs."""
fixture = torch.load(_fixture_path("alternate_vl_dit_small.pt"), map_location="cpu", weights_only=True)
model = AlternateVLDiT(
output_dim=8,
num_attention_heads=2,
attention_head_dim=4,
num_layers=4,
dropout=0.0,
final_dropout=False,
max_num_positional_embeddings=16,
compute_dtype=torch.float32,
interleave_self_attention=True,
cross_attention_dim=6,
).eval()
model.load_state_dict(fixture["state_dict"], strict=True)
actual = model(
hidden_states=fixture["hidden_states"],
encoder_hidden_states=fixture["encoder_hidden_states"],
timestep=fixture["timestep"],
image_mask=fixture["image_mask"],
backbone_attention_mask=fixture["backbone_attention_mask"],
)
torch.testing.assert_close(actual, fixture["output"], atol=1e-6, rtol=1e-6)
def _state_decode_reference():
fixture = np.load(_fixture_path("state_and_action_decode.npz"))
raw_stats = {
"state": {
"single_arm": {"q01": fixture["state_single_arm_q01"], "q99": fixture["state_single_arm_q99"]},
"gripper": {"q01": fixture["state_gripper_q01"], "q99": fixture["state_gripper_q99"]},
},
"action": {
"single_arm": {"q01": fixture["action_single_arm_q01"], "q99": fixture["action_single_arm_q99"]},
"gripper": {"q01": fixture["action_gripper_q01"], "q99": fixture["action_gripper_q99"]},
},
"relative_action": {
"single_arm": {
"min": fixture["relative_single_arm_min"],
"max": fixture["relative_single_arm_max"],
},
},
}
for modality_stats in raw_stats.values():
for entry in modality_stats.values():
for key, value in entry.items():
if isinstance(value, np.ndarray):
entry[key] = value.tolist()
modality_config = {
"state": {"modality_keys": ["single_arm", "gripper"]},
"action": {
"delta_indices": list(range(16)),
"modality_keys": ["single_arm", "gripper"],
"action_configs": [
{"rep": "RELATIVE", "type": "NON_EEF", "format": "DEFAULT", "state_key": None},
{"rep": "ABSOLUTE", "type": "NON_EEF", "format": "DEFAULT", "state_key": None},
],
},
}
state_min = np.concatenate((fixture["state_single_arm_q01"], fixture["state_gripper_q01"]))
state_max = np.concatenate((fixture["state_single_arm_q99"], fixture["state_gripper_q99"]))
pack_step = GrootN17PackInputsStep(
normalize_min_max=True,
stats={OBS_STATE: {"min": state_min, "max": state_max}},
raw_stats=raw_stats,
modality_config=modality_config,
use_percentiles=True,
)
raw_state = np.concatenate((fixture["state_single_arm"], fixture["state_gripper"]), axis=-1)
transition = {
TransitionKey.OBSERVATION: {OBS_STATE: torch.from_numpy(raw_state)},
TransitionKey.COMPLEMENTARY_DATA: {},
}
packed = pack_step(transition)
return fixture, raw_stats, modality_config, pack_step, packed
def test_groot_n1_7_state_normalization_matches_oss_checkpoint_reference():
fixture, _raw_stats, _modality_config, _pack_step, packed = _state_decode_reference()
expected = np.concatenate(
(fixture["normalized_state_single_arm"], fixture["normalized_state_gripper"]), axis=-1
)
actual = packed[TransitionKey.OBSERVATION]["state"][:, 0, :6]
torch.testing.assert_close(actual, torch.from_numpy(expected), atol=1e-6, rtol=1e-6)
def test_groot_n1_7_relative_action_decode_matches_oss_checkpoint_reference():
fixture, raw_stats, modality_config, pack_step, _packed = _state_decode_reference()
decode_step = GrootN17ActionDecodeStep(
env_action_dim=6,
raw_stats=raw_stats,
modality_config=modality_config,
use_percentiles=True,
use_relative_action=True,
pack_step=pack_step,
)
decoded = decode_step({TransitionKey.ACTION: torch.from_numpy(fixture["normalized_action"])})[
TransitionKey.ACTION
]
expected = np.concatenate((fixture["decoded_single_arm"], fixture["decoded_gripper"]), axis=-1).astype(
np.float32
)
torch.testing.assert_close(decoded, torch.from_numpy(expected), atol=1e-5, rtol=1e-5)
def test_groot_n1_7_qwen_backbone_matches_oss_checkpoint_reference():
"""Compare the actual 3B checkpoint backbone when explicitly enabled."""
checkpoint = os.environ.get("GROOT_N17_PARITY_CHECKPOINT")
if checkpoint is None:
pytest.skip("Set GROOT_N17_PARITY_CHECKPOINT to run the 3B OSS Qwen parity test.")
if not torch.cuda.is_available():
pytest.skip("The 3B OSS Qwen parity test requires CUDA.")
fixture = torch.load(_fixture_path("qwen_backbone_so101.pt"), map_location="cpu", weights_only=True)
model = GR00TN17.from_pretrained(checkpoint).to(device="cuda", dtype=torch.bfloat16).eval()
backbone_input = BatchFeature(
data={
key.removeprefix("input."): value.to("cuda")
for key, value in fixture.items()
if key.startswith("input.")
}
)
with torch.inference_mode():
actual = model.backbone(backbone_input)
feature_error = (
actual.backbone_features.cpu().float() - fixture["output.backbone_features"].float()
).abs()
# Native OSS and LeRobot use different Torch/Transformers/Flash-Attention releases.
# Require the measured BF16 accumulation envelope while rejecting structural drift.
assert feature_error.mean().item() <= 0.04
assert feature_error.max().item() <= 2.0
torch.testing.assert_close(
actual.backbone_attention_mask.cpu(), fixture["output.backbone_attention_mask"]
)
torch.testing.assert_close(actual.image_mask.cpu(), fixture["output.image_mask"])