mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-04 00:27:15 +00:00
Match GR00T N1.7 OSS preprocessing and relative actions
This commit is contained in:
@@ -255,6 +255,11 @@ class Qwen3Backbone(nn.Module):
|
||||
load_pretrained_weights: bool = True,
|
||||
):
|
||||
require_package("transformers", extra="groot")
|
||||
if Qwen3VLForConditionalGeneration is None:
|
||||
raise ImportError(
|
||||
"Qwen3VLForConditionalGeneration is required for GR00T N1.7. "
|
||||
"Install a transformers version with Qwen3-VL support."
|
||||
)
|
||||
super().__init__()
|
||||
transformers_loading_kwargs = transformers_loading_kwargs or {"trust_remote_code": True}
|
||||
|
||||
|
||||
@@ -552,6 +552,7 @@ def _reconnect_groot_n1_7_pack_decode_steps(
|
||||
step.pack_step = pack_step
|
||||
|
||||
|
||||
|
||||
def _resolve_feature_names_from_dataset_meta(dataset_meta: Any | None, feature_key: str) -> list[str] | None:
|
||||
features = getattr(dataset_meta, "features", {}) or {}
|
||||
feature = features.get(feature_key) if isinstance(features, dict) else None
|
||||
@@ -634,9 +635,10 @@ def _relative_action_chunks_by_horizon(
|
||||
if pad_mask is not None:
|
||||
mask = torch.as_tensor(pad_mask, dtype=torch.bool).cpu()
|
||||
if mask.ndim == 1 and batch_size == 1 and mask.numel() == horizon:
|
||||
keep[0] = ~mask
|
||||
keep[0, :] = not bool(mask.any())
|
||||
elif mask.ndim == 2 and tuple(mask.shape) == (batch_size, horizon):
|
||||
keep = ~mask
|
||||
complete_chunks = ~mask.any(dim=1)
|
||||
keep = complete_chunks[:, None].expand(batch_size, horizon).clone()
|
||||
|
||||
chunks: list[list[np.ndarray]] = [[] for _ in range(horizon)]
|
||||
relative_np = relative_action.detach().cpu().numpy()
|
||||
@@ -1308,24 +1310,23 @@ def _transform_n1_7_image_for_vlm_albumentations(
|
||||
if not image_np.flags.c_contiguous:
|
||||
image_np = np.ascontiguousarray(image_np)
|
||||
|
||||
height, width = image_np.shape[:2]
|
||||
if height != width:
|
||||
square_edge = max(height, width)
|
||||
pad_h = square_edge - height
|
||||
pad_w = square_edge - width
|
||||
image_np = cv2.copyMakeBorder(
|
||||
image_np,
|
||||
pad_h // 2,
|
||||
pad_h - pad_h // 2,
|
||||
pad_w // 2,
|
||||
pad_w - pad_w // 2,
|
||||
cv2.BORDER_CONSTANT,
|
||||
value=(0, 0, 0),
|
||||
resize_edge = shortest_image_edge or target_h
|
||||
|
||||
def resize_shortest_edge(frame: np.ndarray) -> np.ndarray:
    """Scale ``frame`` so its shorter side equals ``resize_edge``, keeping the aspect ratio.

    ``resize_edge`` is taken from the enclosing scope. A frame whose shorter
    side already matches is returned as-is (same object, no copy).
    """
    rows, cols = frame.shape[:2]
    short_side = min(rows, cols)
    if short_side != resize_edge:
        factor = resize_edge / float(short_side)
        new_rows = max(1, int(round(rows * factor)))
        new_cols = max(1, int(round(cols * factor)))
        # cv2.resize takes the destination size as (width, height).
        frame = cv2.resize(frame, (new_cols, new_rows), interpolation=cv2.INTER_AREA)
    return frame
|
||||
|
||||
resize_edge = shortest_image_edge or target_h
|
||||
if image_np.shape[:2] != (resize_edge, resize_edge):
|
||||
image_np = cv2.resize(image_np, (resize_edge, resize_edge), interpolation=cv2.INTER_AREA)
|
||||
image_np = resize_shortest_edge(image_np)
|
||||
|
||||
if crop_fraction is None and image_crop_size is not None:
|
||||
crop_fraction = image_crop_size[0] / float(target_h)
|
||||
@@ -1337,9 +1338,7 @@ def _transform_n1_7_image_for_vlm_albumentations(
|
||||
left = max(0, (width - crop_w) // 2)
|
||||
image_np = image_np[top : top + crop_h, left : left + crop_w]
|
||||
|
||||
if image_np.shape[:2] != (target_h, target_w):
|
||||
image_np = cv2.resize(image_np, (target_w, target_h), interpolation=cv2.INTER_AREA)
|
||||
return image_np
|
||||
return resize_shortest_edge(image_np)
|
||||
|
||||
|
||||
def _transform_n1_7_image_for_vlm_torch(
|
||||
|
||||
@@ -301,28 +301,22 @@ def _stats(values):
|
||||
|
||||
|
||||
def _expected_albumentations_eval_image(image_np, cv2, *, target_size, shortest_edge, crop_fraction):
|
||||
height, width = image_np.shape[:2]
|
||||
if height != width:
|
||||
square_edge = max(height, width)
|
||||
pad_h = square_edge - height
|
||||
pad_w = square_edge - width
|
||||
image_np = cv2.copyMakeBorder(
|
||||
image_np,
|
||||
pad_h // 2,
|
||||
pad_h - pad_h // 2,
|
||||
pad_w // 2,
|
||||
pad_w - pad_w // 2,
|
||||
cv2.BORDER_CONSTANT,
|
||||
value=(0, 0, 0),
|
||||
)
|
||||
del target_size
|
||||
|
||||
image_np = cv2.resize(image_np, (shortest_edge, shortest_edge), interpolation=cv2.INTER_AREA)
|
||||
crop_h = max(1, int(shortest_edge * crop_fraction))
|
||||
crop_w = max(1, int(shortest_edge * crop_fraction))
|
||||
top = (shortest_edge - crop_h) // 2
|
||||
left = (shortest_edge - crop_w) // 2
|
||||
image_np = image_np[top : top + crop_h, left : left + crop_w]
|
||||
return cv2.resize(image_np, (target_size[1], target_size[0]), interpolation=cv2.INTER_AREA)
|
||||
def resize_shortest_edge(frame):
    """Resize ``frame`` so its shorter side equals ``shortest_edge``, keeping the aspect ratio."""
    rows, cols = frame.shape[:2]
    factor = shortest_edge / float(min(rows, cols))
    # Round each side independently and never let it collapse below one pixel.
    new_rows = max(1, int(round(rows * factor)))
    new_cols = max(1, int(round(cols * factor)))
    # cv2.resize takes the destination size as (width, height).
    target_size_wh = (new_cols, new_rows)
    return cv2.resize(frame, target_size_wh, interpolation=cv2.INTER_AREA)
|
||||
|
||||
image_np = resize_shortest_edge(image_np)
|
||||
height, width = image_np.shape[:2]
|
||||
crop_h = max(1, int(height * crop_fraction))
|
||||
crop_w = max(1, int(width * crop_fraction))
|
||||
top = (height - crop_h) // 2
|
||||
left = (width - crop_w) // 2
|
||||
return resize_shortest_edge(image_np[top : top + crop_h, left : left + crop_w])
|
||||
|
||||
|
||||
class _DummyGrootModel(nn.Module):
|
||||
@@ -1588,7 +1582,9 @@ def test_groot_n1_7_vlm_encode_uses_per_sample_language():
|
||||
self.encoded_texts = None
|
||||
|
||||
def apply_chat_template(self, conversation, tokenize, add_generation_prompt):
|
||||
text = conversation[0]["content"][-1]["text"]
|
||||
content = conversation[0]["content"]
|
||||
assert [item["type"] for item in content] == ["image", "text"]
|
||||
text = content[-1]["text"]
|
||||
self.rendered_texts.append(text)
|
||||
return f"rendered:{text}"
|
||||
|
||||
@@ -1626,6 +1622,7 @@ def test_groot_n1_7_vlm_encode_packs_images_time_major_then_camera_order():
|
||||
class FakeProcessor:
|
||||
def __init__(self):
|
||||
self.add_generation_prompts = []
|
||||
self.conversation_content_types = []
|
||||
self.conversation_image_values = []
|
||||
self.conversation_texts = []
|
||||
self.encoded_texts = None
|
||||
@@ -1635,6 +1632,7 @@ def test_groot_n1_7_vlm_encode_packs_images_time_major_then_camera_order():
|
||||
assert tokenize is False
|
||||
self.add_generation_prompts.append(add_generation_prompt)
|
||||
content = conversation[0]["content"]
|
||||
self.conversation_content_types.append([item["type"] for item in content])
|
||||
self.conversation_image_values.append(
|
||||
[int(np.asarray(item["image"])[0, 0, 0]) for item in content if item["type"] == "image"]
|
||||
)
|
||||
@@ -1672,6 +1670,10 @@ def test_groot_n1_7_vlm_encode_packs_images_time_major_then_camera_order():
|
||||
output = step(transition)
|
||||
|
||||
assert fake_proc.conversation_image_values == [[1, 2, 3, 4], [5, 6, 7, 8]]
|
||||
assert fake_proc.conversation_content_types == [
|
||||
["image", "image", "image", "image", "text"],
|
||||
["image", "image", "image", "image", "text"],
|
||||
]
|
||||
assert fake_proc.encoded_image_values == [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
assert fake_proc.conversation_texts == ["task a", "task b"]
|
||||
assert fake_proc.encoded_texts == ["rendered:task a", "rendered:task b"]
|
||||
@@ -1705,7 +1707,7 @@ def test_groot_n1_7_vlm_image_transform_matches_albumentations_eval_path():
|
||||
expected = expected[crop_start : crop_start + crop_edge, crop_start : crop_start + crop_edge]
|
||||
expected = cv2.resize(expected, (256, 256), interpolation=cv2.INTER_AREA)
|
||||
|
||||
assert transformed.size == (256, 256)
|
||||
assert transformed.shape == (256, 256, 3)
|
||||
np.testing.assert_array_equal(np.asarray(transformed), expected)
|
||||
|
||||
|
||||
@@ -1717,7 +1719,9 @@ def test_groot_n1_7_vlm_encode_transforms_non_square_two_camera_sample_like_core
|
||||
self.images = None
|
||||
|
||||
def apply_chat_template(self, conversation, tokenize, add_generation_prompt):
|
||||
return conversation[0]["content"][-1]["text"]
|
||||
content = conversation[0]["content"]
|
||||
assert [item["type"] for item in content] == ["image", "image", "text"]
|
||||
return content[-1]["text"]
|
||||
|
||||
def __call__(self, text, images, return_tensors, padding):
|
||||
self.images = images
|
||||
@@ -1792,7 +1796,6 @@ def test_groot_n1_7_vlm_encode_config_round_trips_model_name():
|
||||
def test_groot_n1_7_processor_uses_qwen_component_assets(monkeypatch):
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
import transformers
|
||||
|
||||
from lerobot.policies.groot import processor_groot
|
||||
|
||||
@@ -1833,10 +1836,10 @@ def test_groot_n1_7_processor_uses_qwen_component_assets(monkeypatch):
|
||||
cls.from_pretrained_called = True
|
||||
raise AssertionError("Cosmos does not publish processor_config.json")
|
||||
|
||||
monkeypatch.setattr(transformers, "AutoTokenizer", FakeTokenizer)
|
||||
monkeypatch.setattr(transformers, "Qwen2VLImageProcessorFast", FakeImageProcessor)
|
||||
monkeypatch.setattr(transformers, "Qwen3VLVideoProcessor", FakeVideoProcessor)
|
||||
monkeypatch.setattr(transformers, "Qwen3VLProcessor", FakeProcessor)
|
||||
monkeypatch.setattr(processor_groot, "AutoTokenizer", FakeTokenizer)
|
||||
monkeypatch.setattr(processor_groot, "Qwen2VLImageProcessor", FakeImageProcessor)
|
||||
monkeypatch.setattr(processor_groot, "Qwen3VLVideoProcessor", FakeVideoProcessor)
|
||||
monkeypatch.setattr(processor_groot, "Qwen3VLProcessor", FakeProcessor)
|
||||
|
||||
processor = processor_groot._build_n1_7_processor("nvidia/Cosmos-Reason2-2B")
|
||||
|
||||
@@ -2306,6 +2309,55 @@ def test_groot_n1_7_generated_relative_stats_match_oss_gr00t_reference_numbers()
|
||||
torch.testing.assert_close(decoded[TransitionKey.ACTION], action_a.unsqueeze(0), atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
def test_groot_n1_7_relative_action_stats_skip_padded_tail_chunks():
    """Relative-action stats must ignore a sample whose action chunk contains padding.

    Two fully valid samples contribute to the statistics; the third sample has a
    padded final step and carries sentinel action values that would be visible
    in min/max/count if it were included.
    """

    def make_sample(state, actions, is_pad):
        # Key order matches what the dataset items look like: state, action, pad mask.
        return {
            OBS_STATE: torch.tensor(state),
            ACTION: torch.tensor(actions),
            f"{ACTION}_is_pad": torch.tensor(is_pad),
        }

    samples = [
        make_sample(
            [10.0, 100.0],
            [[11.0, 101.0], [12.0, 102.0], [13.0, 103.0]],
            [False, False, False],
        ),
        make_sample(
            [20.0, 200.0],
            [[18.0, 198.0], [16.0, 196.0], [14.0, 194.0]],
            [False, False, False],
        ),
        # Padded tail: these sentinel values must not leak into the statistics.
        make_sample(
            [0.0, 0.0],
            [[999.0, 999.0], [888.0, 888.0], [777.0, 777.0]],
            [False, False, True],
        ),
    ]

    class _Dataset:
        """Minimal in-memory dataset exposing only what the stats helper is given here."""

        meta = SimpleNamespace(stats={})

        def __len__(self):
            return len(samples)

        def __getitem__(self, idx):
            return samples[idx]

    relative_dataset_stats = _make_relative_action_training_stats(
        _Dataset(),
        exclude_joints=[],
        action_names=None,
        preserve_action_horizon=True,
    )

    expected_by_stat = {
        "count": torch.tensor([2, 2, 2]),
        "min": torch.tensor([[-2.0, -2.0], [-4.0, -4.0], [-6.0, -6.0]]),
        "max": torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]),
    }
    for stat_name, expected in expected_by_stat.items():
        torch.testing.assert_close(
            torch.as_tensor(relative_dataset_stats[ACTION][stat_name]),
            expected,
        )
|
||||
|
||||
|
||||
def test_groot_policy_selects_n1_7_model_class(monkeypatch):
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
|
||||
|
||||
@@ -0,0 +1,254 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
|
||||
from lerobot.policies.groot.action_head.cross_attention_dit import AlternateVLDiT
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
from lerobot.policies.groot.processor_groot import (
|
||||
GrootN17ActionDecodeStep,
|
||||
GrootN17PackInputsStep,
|
||||
GrootN17VLMEncodeStep,
|
||||
_transform_n1_7_image_for_vlm_albumentations,
|
||||
)
|
||||
from lerobot.types import TransitionKey
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
|
||||
# Commit hash recorded for the OSS reference implementation; presumably the
# upstream revision the external parity fixtures were generated from — TODO
# confirm. No code visible in this module reads this constant.
OSS_REFERENCE_COMMIT = "ab88b50c718f6528e1df9dcbaf75865d1b604760"
|
||||
|
||||
|
||||
def _fixture_path(filename: str) -> Path:
    """Locate an external OSS parity fixture, skipping the calling test if it is unavailable.

    The fixture directory is read from the ``GROOT_N17_OSS_PARITY_FIXTURE_DIR``
    environment variable. The test is skipped when the variable is unset or when
    ``filename`` is not a regular file inside that directory.

    Returns:
        The path ``<fixture dir>/<filename>``.
    """
    root = os.environ.get("GROOT_N17_OSS_PARITY_FIXTURE_DIR")
    if root is None:
        pytest.skip("Set GROOT_N17_OSS_PARITY_FIXTURE_DIR to run external OSS parity fixtures.")

    candidate = Path(root, filename)
    if not candidate.is_file():
        pytest.skip(f"External OSS parity fixture not found: {candidate}")
    return candidate
|
||||
|
||||
|
||||
def test_groot_n1_7_eval_image_transform_matches_oss_reference():
    """Match the native N1.7 eval transform for a non-square SO-101 frame."""

    # Deterministic synthetic 480x640 frame: each channel is a distinct linear
    # ramp over the pixel coordinates, wrapped modulo 256.
    rows, cols = np.indices((480, 640), dtype=np.uint16)
    channels = (
        (cols + 3 * rows) % 256,
        (2 * cols + rows) % 256,
        (cols + 5 * rows) % 256,
    )
    image = np.stack(channels, axis=-1).astype(np.uint8)

    actual = _transform_n1_7_image_for_vlm_albumentations(
        image,
        image_crop_size=[230, 230],
        image_target_size=[256, 256],
        shortest_image_edge=256,
        crop_fraction=0.95,
    )

    # Pin the exact output bytes via their SHA-256 digest.
    expected_digest = "c17e47af68a812aa79db3bb7b64b549ddf10148ac1b204a9686095018561ae9e"
    assert actual.shape == (256, 340, 3)
    assert hashlib.sha256(actual.tobytes()).hexdigest() == expected_digest
|
||||
|
||||
|
||||
def test_groot_n1_7_vlm_chat_content_order_matches_oss_reference():
    """Native OSS places all image items before the language item."""

    class RecordingProcessor:
        """Processor stand-in that records the content item types it is asked to render."""

        def __init__(self):
            self.content_types = None

        def apply_chat_template(self, conversation, tokenize, add_generation_prompt):
            # The encode step is expected to request plain text without a generation prompt.
            assert tokenize is False
            assert add_generation_prompt is False
            first_message = conversation[0]
            self.content_types = [item["type"] for item in first_message["content"]]
            return "rendered"

        def __call__(self, **kwargs):
            return {}

    recorder = RecordingProcessor()
    encode_step = GrootN17VLMEncodeStep(
        image_crop_size=[230, 230],
        image_target_size=[256, 256],
        shortest_image_edge=256,
        crop_fraction=0.95,
        use_albumentations=True,
        device="cpu",
    )
    # Swap in the recorder so no real processor is needed.
    encode_step._proc = recorder

    blank_video = np.zeros((1, 1, 2, 480, 640, 3), dtype=np.uint8)
    transition = {
        TransitionKey.OBSERVATION: {"video": blank_video},
        TransitionKey.COMPLEMENTARY_DATA: {"language": ["pick up the vial"]},
    }

    encode_step(transition)

    assert recorder.content_types == ["image", "image", "text"]
|
||||
|
||||
|
||||
def test_groot_n1_7_alternate_vl_dit_matches_oss_reference():
    """Run the LeRobot DiT with native OSS weights and identical inputs."""

    reference = torch.load(
        _fixture_path("alternate_vl_dit_small.pt"),
        map_location="cpu",
        weights_only=True,
    )

    dit_kwargs = {
        "output_dim": 8,
        "num_attention_heads": 2,
        "attention_head_dim": 4,
        "num_layers": 4,
        "dropout": 0.0,
        "final_dropout": False,
        "max_num_positional_embeddings": 16,
        "compute_dtype": torch.float32,
        "interleave_self_attention": True,
        "cross_attention_dim": 6,
    }
    dit = AlternateVLDiT(**dit_kwargs).eval()
    # strict=True: the fixture's state dict must match the module's parameters exactly.
    dit.load_state_dict(reference["state_dict"], strict=True)

    input_names = (
        "hidden_states",
        "encoder_hidden_states",
        "timestep",
        "image_mask",
        "backbone_attention_mask",
    )
    prediction = dit(**{name: reference[name] for name in input_names})

    torch.testing.assert_close(prediction, reference["output"], atol=1e-6, rtol=1e-6)
|
||||
|
||||
|
||||
def _state_decode_reference():
    """Build the shared state/action decode inputs used by the checkpoint parity tests.

    Loads ``state_and_action_decode.npz`` (the calling test is skipped when the
    external fixture is unavailable), rebuilds the raw statistics and modality
    configuration from it, and runs a ``GrootN17PackInputsStep`` over the
    recorded raw state.

    Returns:
        A ``(fixture, raw_stats, modality_config, pack_step, packed)`` tuple.
        ``fixture`` maps array names to the arrays stored in the archive and
        ``packed`` is the result of calling ``pack_step`` on the recorded state.
    """
    # np.load on an .npz returns a lazily-read NpzFile that keeps the file open.
    # Copy every array out and close it so no file handle is leaked.
    with np.load(_fixture_path("state_and_action_decode.npz")) as archive:
        fixture = {name: archive[name] for name in archive.files}

    def stat_pair(prefix, low_key, high_key):
        # Raw stats are passed on as plain Python lists rather than arrays.
        return {
            low_key: fixture[f"{prefix}_{low_key}"].tolist(),
            high_key: fixture[f"{prefix}_{high_key}"].tolist(),
        }

    raw_stats = {
        "state": {
            "single_arm": stat_pair("state_single_arm", "q01", "q99"),
            "gripper": stat_pair("state_gripper", "q01", "q99"),
        },
        "action": {
            "single_arm": stat_pair("action_single_arm", "q01", "q99"),
            "gripper": stat_pair("action_gripper", "q01", "q99"),
        },
        "relative_action": {
            "single_arm": stat_pair("relative_single_arm", "min", "max"),
        },
    }
    modality_config = {
        "state": {"modality_keys": ["single_arm", "gripper"]},
        "action": {
            "delta_indices": list(range(16)),
            "modality_keys": ["single_arm", "gripper"],
            "action_configs": [
                {"rep": "RELATIVE", "type": "NON_EEF", "format": "DEFAULT", "state_key": None},
                {"rep": "ABSOLUTE", "type": "NON_EEF", "format": "DEFAULT", "state_key": None},
            ],
        },
    }

    # The flat state stats reuse the q01/q99 percentiles as min/max bounds.
    state_min = np.concatenate((fixture["state_single_arm_q01"], fixture["state_gripper_q01"]))
    state_max = np.concatenate((fixture["state_single_arm_q99"], fixture["state_gripper_q99"]))
    pack_step = GrootN17PackInputsStep(
        normalize_min_max=True,
        stats={OBS_STATE: {"min": state_min, "max": state_max}},
        raw_stats=raw_stats,
        modality_config=modality_config,
        use_percentiles=True,
    )

    raw_state = np.concatenate((fixture["state_single_arm"], fixture["state_gripper"]), axis=-1)
    transition = {
        TransitionKey.OBSERVATION: {OBS_STATE: torch.from_numpy(raw_state)},
        TransitionKey.COMPLEMENTARY_DATA: {},
    }
    packed = pack_step(transition)
    return fixture, raw_stats, modality_config, pack_step, packed
|
||||
|
||||
|
||||
def test_groot_n1_7_state_normalization_matches_oss_checkpoint_reference():
    """The packed state must equal the normalized state arrays stored in the OSS fixture."""
    fixture, _raw_stats, _modality_config, _pack_step, packed = _state_decode_reference()

    reference_parts = [
        fixture["normalized_state_single_arm"],
        fixture["normalized_state_gripper"],
    ]
    expected = torch.from_numpy(np.concatenate(reference_parts, axis=-1))

    # Compare only the first six entries of the first step along the second axis.
    packed_state = packed[TransitionKey.OBSERVATION]["state"]
    actual = packed_state[:, 0, :6]

    torch.testing.assert_close(actual, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
|
||||
def test_groot_n1_7_relative_action_decode_matches_oss_checkpoint_reference():
    """Decoding the fixture's normalized action must reproduce its decoded arrays."""
    fixture, raw_stats, modality_config, pack_step, _packed = _state_decode_reference()

    decoder = GrootN17ActionDecodeStep(
        env_action_dim=6,
        raw_stats=raw_stats,
        modality_config=modality_config,
        use_percentiles=True,
        use_relative_action=True,
        pack_step=pack_step,
    )
    normalized_action = torch.from_numpy(fixture["normalized_action"])
    decoded = decoder({TransitionKey.ACTION: normalized_action})[TransitionKey.ACTION]

    reference_parts = (fixture["decoded_single_arm"], fixture["decoded_gripper"])
    # Cast the reference to float32 before the comparison.
    expected = np.concatenate(reference_parts, axis=-1).astype(np.float32)

    torch.testing.assert_close(decoded, torch.from_numpy(expected), atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
def test_groot_n1_7_qwen_backbone_matches_oss_checkpoint_reference():
    """Compare the actual 3B checkpoint backbone when explicitly enabled."""

    checkpoint = os.environ.get("GROOT_N17_PARITY_CHECKPOINT")
    if checkpoint is None:
        pytest.skip("Set GROOT_N17_PARITY_CHECKPOINT to run the 3B OSS Qwen parity test.")
    if not torch.cuda.is_available():
        pytest.skip("The 3B OSS Qwen parity test requires CUDA.")

    reference = torch.load(
        _fixture_path("qwen_backbone_so101.pt"),
        map_location="cpu",
        weights_only=True,
    )
    model = GR00TN17.from_pretrained(checkpoint).to(device="cuda", dtype=torch.bfloat16).eval()

    # Fixture entries named "input.<key>" become the backbone inputs, moved to the GPU.
    input_prefix = "input."
    backbone_tensors = {}
    for name, tensor in reference.items():
        if name.startswith(input_prefix):
            backbone_tensors[name.removeprefix(input_prefix)] = tensor.to("cuda")
    backbone_input = BatchFeature(data=backbone_tensors)

    with torch.inference_mode():
        actual = model.backbone(backbone_input)

    expected_features = reference["output.backbone_features"].float()
    feature_error = (actual.backbone_features.cpu().float() - expected_features).abs()
    # Native OSS and LeRobot use different Torch/Transformers/Flash-Attention releases.
    # Require the measured BF16 accumulation envelope while rejecting structural drift.
    assert feature_error.mean().item() <= 0.04
    assert feature_error.max().item() <= 2.0

    # The masks are compared with assert_close's default tolerances.
    torch.testing.assert_close(
        actual.backbone_attention_mask.cpu(),
        reference["output.backbone_attention_mask"],
    )
    torch.testing.assert_close(actual.image_mask.cpu(), reference["output.image_mask"])
|
||||
Reference in New Issue
Block a user