Compare commits

..

2 Commits

Author SHA1 Message Date
Steven Palma f81fc79564 minor fixes 2026-06-12 18:45:24 +02:00
Steven Palma 9ce6633518 fix(groot): address review findings for the N1.7 port
N1.5 removal is now explicit and actionable:
- Legacy N1.5 checkpoint configs (tokenizer_assets_repo) parse and fail
  with a single clear error pointing to lerobot==0.5.1 instead of a
  cryptic draccus DecodingError
- Removed N1.5 processor registry names (groot_pack_inputs_v3,
  groot_eagle_encode_v3, groot_eagle_collate_v3) are stubbed to raise the
  same guidance; groot_action_unpack_unnormalize_v1 changed semantics, so
  the step is re-registered as _v2 and _v1 is stubbed
- N1.5 detection also recognizes checkpoint config.json content
  (model_type/architectures/eagle backbone), not just path names; every
  rejection surface includes the migration guidance
- groot.mdx documents the breaking change and migration path

Runtime fixes:
- use_bf16=False no longer crashes (compute_dtype only set when used)
- GrootN17ActionDecodeStep handles the 2-D (B, D) actions delivered by
  sync select_action (relative eef/non-eef decode was broken in
  lerobot-eval/record flows)
- Postprocessor falls back to dataset stats when a raw checkpoint lacks
  the configured embodiment tag instead of silently emitting normalized
  [-1, 1] actions
- Hub-hosted finetuned N1.7 checkpoints load: the processor config is
  resolved via hf_hub_download for non-local paths, with a tolerant
  retry when inspection fails
- Raw-checkpoint processor branch honors caller overrides (device,
  rename_map) instead of dropping them
- Relative-action raw-state cache is per-instance instead of
  process-global (cross-instance contamination)
- Camera/modality-key mismatches warn, including the zero-match
  fallback; checkpoint revision is no longer forwarded into backbone
  loading; deprecated Qwen2VLImageProcessorFast replaced with
  Qwen2VLImageProcessor

Config/UX:
- GrootConfig defaults are the N1.7 values; explicitly passed legacy
  N1.5-era values (chunk_size=50, max_state_dim=64, ...) are remapped
  with a warning instead of silently
- Explicit action_decode_transform='none' wins over the libero_sim
  default (new 'auto' sentinel) and survives save/load round-trips

Tests/CI:
- pytest.importorskip guards so fast_tests tiers pass without
  transformers (was 10 failures, now 0)
- Regression tests for every fix; from_pretrained rejection tests now
  actually exercise from_pretrained
- Parity test reads the artifact seed, fails on shape mismatch instead
  of silently truncating, and a new case runs LeRobot's real Qwen3-VL
  preprocessing on raw observations dumped by the producer
- docs: dead huggingface-cli download replaced with hf download

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-12 16:51:14 +02:00
8 changed files with 1299 additions and 282 deletions
@@ -321,6 +321,9 @@ def _infer_groot_model_version_from_config(config: dict) -> str | None:
normalized = candidate.lower().replace("-", "_")
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
return GROOT_N1_7
# nvidia/GR00T-N1.5-3B ships model_type 'gr00t_n1_5' and architectures ['GR00T_N1_5'].
# Recognise them so N1.5 checkpoints at generic local paths are rejected loudly
# instead of being silently treated as N1.7 (see infer_groot_model_version).
if normalized in {"gr00t_n1_5", "gr00tn1_5", "gr00t_n15", "gr00t_n1d5", "gr00tn1d5"}:
return GROOT_N1_5
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
@@ -362,7 +365,11 @@ class GrootConfig(PreTrainedConfig):
}
)
# Groot-specific model parameters
# Deprecated and unused: image sizing is handled by the backbone's image processor.
# Kept only so config.json files saved with earlier versions still parse.
image_size: tuple[int, int] = (256, 256)
# Groot-specific model parameters (from groot_finetune_script.py)
# Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only.
model_version: str = GROOT_N1_7
@@ -378,6 +385,11 @@ class GrootConfig(PreTrainedConfig):
# transform). Pass 'none' to explicitly disable the transform, including for 'libero_sim'.
action_decode_transform: str | None = GROOT_ACTION_DECODE_TRANSFORM_AUTO
# Deprecated, GR00T N1.5 only — do not set. Kept so config.json files saved by lerobot<=0.5.1
# still parse (draccus rejects unknown fields) and can be rejected in __post_init__ with a
# clear error pointing at GROOT_N1_5_REMOVAL_GUIDANCE instead of a cryptic DecodingError.
tokenizer_assets_repo: str | None = None
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
embodiment_tag: str = "new_embodiment"
@@ -416,13 +428,10 @@ class GrootConfig(PreTrainedConfig):
warmup_ratio: float = 0.05
use_bf16: bool = True
# TODO(Steven): Remove these deprecated fields in a future release.
# Deprecated Isaac-GR00T runner/N1.5 fields below — unused by the LeRobot N1.7 implementation
# Deprecated Isaac-GR00T runner fields below — unused by the LeRobot N1.7 implementation
# (nothing in src/lerobot reads them). They are kept only so config.json files saved by
# earlier lerobot releases still parse: draccus rejects unknown fields, so removing them
# would break every previously saved groot checkpoint at config-load time.
image_size: tuple[int, int] = (256, 256) # image sizing is handled by the backbone's image processor.
tokenizer_assets_repo: str | None = None
video_backend: str = "decord"
balance_dataset_weights: bool = True
balance_trajectory_weights: bool = True
@@ -436,6 +445,9 @@ class GrootConfig(PreTrainedConfig):
resume: bool = False
def __post_init__(self):
# 'tokenizer_assets_repo' only ever existed for GR00T N1.5 (lerobot<=0.5.1) and was
# serialized into every groot checkpoint config.json, so a value here means a legacy
# N1.5 checkpoint or config is being loaded.
if self.tokenizer_assets_repo is not None:
raise ValueError(
"Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks "
@@ -570,11 +582,22 @@ class GrootConfig(PreTrainedConfig):
@property
def action_delta_indices(self) -> list[int]:
"""Return indices for delta actions."""
"""Return indices for delta actions.
The model action horizon is read from the checkpoint's processor_config.json
when available; the result is cached (keyed on the inputs that determine it) so
repeated access during dataset/training setup does not re-read from disk.
"""
cache_key = (self.base_model_path, self.embodiment_tag, self.chunk_size)
cached = getattr(self, "_action_delta_indices_cache", None)
if cached is not None and cached[0] == cache_key:
return cached[1]
model_action_horizon = (
infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
)
return list(range(min(self.chunk_size, model_action_horizon)))
indices = list(range(min(self.chunk_size, model_action_horizon)))
object.__setattr__(self, "_action_delta_indices_cache", (cache_key, indices))
return indices
@property
def reward_delta_indices(self) -> None:
+7 -1
View File
@@ -71,7 +71,7 @@ GR00T_N1_7_DEFAULTS: dict[str, Any] = {
"backbone_embedding_dim": 2048,
"tune_llm": False,
"tune_visual": False,
"select_layer": 16,
"select_layer": 16, # N1.7-3B checkpoint value; real checkpoint loads override this from config.json
"reproject_vision": False,
"use_flash_attention": True,
"load_bf16": False,
@@ -822,6 +822,8 @@ def get_backbone_cls(config: GR00TN17Config):
if "nvidia/Cosmos-Reason2" in config.model_name or "Qwen/Qwen3-VL" in config.model_name:
return Qwen3Backbone
if config.backbone_model_type == "qwen":
# Local backbone checkpoints (e.g. hub-cache snapshot paths) contain neither hub
# marker, so trust the explicit backbone type but surface what is being assumed.
logger.warning(
"Unrecognized GR00T N1.7 backbone model name '%s'; assuming a Qwen3-VL-compatible "
"backbone because backbone_model_type='qwen'.",
@@ -912,6 +914,10 @@ class GR00TN17(PreTrainedModel):
"trust_remote_code": True
}
load_backbone_weights = kwargs.pop("load_backbone_weights", False)
# Only repo-agnostic hub kwargs are forwarded to the backbone loading kwargs:
# ``revision`` pins the GR00T checkpoint repo (see snapshot_download below) and would
# be invalid for the unrelated backbone repo (``config.model_name``). Pin the backbone
# itself by passing ``revision`` inside ``transformers_loading_kwargs``.
for key in ("cache_dir", "local_files_only", "token"):
if key in kwargs:
transformers_loading_kwargs.setdefault(key, kwargs[key])
@@ -93,6 +93,12 @@ class GrootPolicy(PreTrainedPolicy):
transformers_loading_kwargs={"trust_remote_code": True},
)
# GR00TN17 defines no compute_dtype attribute, so only record the
# bf16 preference when it is enabled instead of reading a default back.
if self.config.use_bf16:
model.compute_dtype = "bfloat16"
model.config.compute_dtype = "bfloat16"
return model
def reset(self):
+119 -198
View File
@@ -23,10 +23,9 @@ from typing import TYPE_CHECKING, Any
import numpy as np
import torch
import torchvision.transforms.v2.functional as tv_functional
from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision.transforms import InterpolationMode
from lerobot.utils.import_utils import _transformers_available
@@ -59,7 +58,6 @@ from lerobot.utils.constants import (
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
from lerobot.utils.device_utils import get_safe_torch_device
from .configuration_groot import (
GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
@@ -450,40 +448,60 @@ def _has_modality_stats(stats: dict[str, dict[str, Any]] | None) -> bool:
return any(bool(modality_stats) for modality_stats in stats.values())
# GR00T normalizes state/action inside its own processor steps and so deliberately has no
# NormalizerProcessorStep/UnnormalizerProcessorStep (see GrootConfig.normalization_mapping, which is
# IDENTITY for every feature). lerobot-train nonetheless emits these standard override keys
# unconditionally, so for a GR00T pipeline they legitimately match no step. They are dropped up front
# by _drop_groot_absent_standard_overrides so they neither break loading nor mask genuine typos.
_GROOT_ABSENT_STANDARD_OVERRIDE_KEYS = frozenset({"normalizer_processor", "unnormalizer_processor"})
def _legacy_groot_processor_overrides(
config: GrootConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None,
preprocessor_overrides: dict[str, Any] | None = None,
postprocessor_overrides: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Patch older serialized Groot processors with fields current processors expect."""
preprocessor_overrides = dict(preprocessor_overrides or {})
postprocessor_overrides = dict(postprocessor_overrides or {})
pack_inputs_key = "groot_n1_7_pack_inputs_v1"
pack_input_overrides = dict(preprocessor_overrides.get(pack_inputs_key, {}))
pack_input_overrides["normalize_min_max"] = True
preprocessor_overrides[pack_inputs_key] = pack_input_overrides
try:
env_action_dim = int(config.output_features[ACTION].shape[0])
except Exception:
env_action_dim = 0
action_unpack_overrides = dict(postprocessor_overrides.get("groot_action_unpack_unnormalize_v2", {}))
action_unpack_overrides["normalize_min_max"] = True
action_unpack_overrides["env_action_dim"] = env_action_dim
postprocessor_overrides["groot_action_unpack_unnormalize_v2"] = action_unpack_overrides
return preprocessor_overrides, postprocessor_overrides
def _drop_groot_absent_standard_overrides(overrides: dict[str, Any] | None) -> dict[str, Any] | None:
"""Strip standard normalization override keys that a GR00T pipeline has no step for.
def _pretrained_processor_config_has_step(pretrained_path: str, config_filename: str, step_name: str) -> bool:
"""Check whether a serialized processor pipeline contains a registry step.
``lerobot-train`` emits ``normalizer_processor``/``unnormalizer_processor`` overrides
unconditionally, but GR00T normalizes inside its own steps and has no such step (see
``GrootConfig.normalization_mapping``). Both override-application paths reject keys that match no
step ``_apply_groot_step_overrides`` raises for the freshly built raw-checkpoint pipeline, and
``PolicyProcessorPipeline.from_pretrained`` raises via its used-override validation for the
serialized pipeline so these keys are removed before either path runs. Any other unknown key
(e.g. a typo) is left in place and still raises.
Resolves the processor config from a local directory or, for Hub repo ids,
via ``hf_hub_download`` (which serves the cached copy when offline). Returns
False when the config cannot be resolved; loading then proceeds with the
legacy overrides and `make_groot_pre_post_processors_from_pretrained` retries
without them if they do not match the serialized pipeline.
"""
if not overrides:
return overrides
filtered: dict[str, Any] = {}
for key, value in overrides.items():
if key in _GROOT_ABSENT_STANDARD_OVERRIDE_KEYS:
logging.debug(
"Ignoring override key '%s': GR00T normalizes inside its own processor steps and has "
"no matching step (see GrootConfig.normalization_mapping).",
key,
path = Path(pretrained_path).expanduser()
if path.is_dir():
config = _read_json(path / config_filename)
elif path.exists():
return False
else:
try:
config_path = hf_hub_download(
repo_id=str(pretrained_path), filename=config_filename, repo_type="model"
)
continue
filtered[key] = value
return filtered
except Exception:
return False
config = _read_json(Path(config_path))
steps = config.get("steps", [])
if not isinstance(steps, list):
return False
return any(isinstance(step, dict) and step.get("registry_name") == step_name for step in steps)
def _apply_groot_step_overrides(
@@ -499,8 +517,7 @@ def _apply_groot_step_overrides(
steps by registry name only prefer registry names so overrides keep
working after the checkpoint is converted and reloaded from a serialized
pipeline). Keys or fields that match nothing raise instead of being dropped
silently (standard normalization keys GR00T has no step for are removed
beforehand by ``_drop_groot_absent_standard_overrides``).
silently.
"""
if not overrides:
@@ -556,13 +573,7 @@ def make_groot_pre_post_processors_from_pretrained(
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Load Groot processors for a raw N1.7 checkpoint or a serialized LeRobot pipeline."""
# Drop the standard normalizer/unnormalizer override keys lerobot-train emits unconditionally:
# GR00T has no such steps, so they would make both the raw-checkpoint and serialized override
# paths raise. This must happen before either branch below.
preprocessor_overrides = _drop_groot_absent_standard_overrides(preprocessor_overrides)
postprocessor_overrides = _drop_groot_absent_standard_overrides(postprocessor_overrides)
"""Load Groot processors while preserving compatibility with older serialized configs."""
if is_raw_groot_n1_7_checkpoint(pretrained_path):
processor_cfg = copy(config)
@@ -578,13 +589,49 @@ def make_groot_pre_post_processors_from_pretrained(
_apply_groot_step_overrides(postprocessor, postprocessor_overrides)
return preprocessor, postprocessor
preprocessor, postprocessor = _load_groot_processor_pipelines(
caller_preprocessor_overrides = dict(preprocessor_overrides or {})
caller_postprocessor_overrides = dict(postprocessor_overrides or {})
if _pretrained_processor_config_has_step(
pretrained_path,
preprocessor_overrides=preprocessor_overrides,
postprocessor_overrides=postprocessor_overrides,
preprocessor_config_filename=preprocessor_config_filename,
postprocessor_config_filename=postprocessor_config_filename,
)
postprocessor_config_filename,
"groot_n1_7_action_decode_v1",
):
# Converted raw N1.7 checkpoints already carry the checkpoint-specific
# action decoder. Adding the legacy action-unpack override would target
# a step that is not present and break loading.
applied_legacy_overrides = False
preprocessor_overrides = caller_preprocessor_overrides
postprocessor_overrides = caller_postprocessor_overrides
else:
applied_legacy_overrides = True
preprocessor_overrides, postprocessor_overrides = _legacy_groot_processor_overrides(
config=config,
dataset_stats=dataset_stats,
preprocessor_overrides=preprocessor_overrides,
postprocessor_overrides=postprocessor_overrides,
)
try:
preprocessor, postprocessor = _load_groot_processor_pipelines(
pretrained_path,
preprocessor_overrides=preprocessor_overrides,
postprocessor_overrides=postprocessor_overrides,
preprocessor_config_filename=preprocessor_config_filename,
postprocessor_config_filename=postprocessor_config_filename,
)
except KeyError:
if not applied_legacy_overrides:
raise
# The legacy overrides target steps that are absent from the serialized
# pipelines (e.g. a converted raw N1.7 checkpoint whose postprocessor
# config could not be inspected before loading); retry with the caller
# overrides only.
preprocessor, postprocessor = _load_groot_processor_pipelines(
pretrained_path,
preprocessor_overrides=caller_preprocessor_overrides,
postprocessor_overrides=caller_postprocessor_overrides,
preprocessor_config_filename=preprocessor_config_filename,
postprocessor_config_filename=postprocessor_config_filename,
)
_reconnect_groot_relative_absolute_steps(preprocessor, postprocessor)
_reconnect_groot_n1_7_pack_decode_steps(preprocessor, postprocessor)
return preprocessor, postprocessor
@@ -747,10 +794,6 @@ def make_groot_pre_post_processors(
use_albumentations=checkpoint_assets.use_albumentations
if checkpoint_assets is not None
else False,
# Run the image resize/normalize/patchify on the training device when
# possible instead of the single CPU main-loop thread (the dominant
# cost folded into dataloading_s).
device=config.device,
),
DeviceProcessorStep(device=config.device),
]
@@ -989,61 +1032,6 @@ def _transform_n1_7_image_for_vlm(
return image
def _transform_n1_7_image_for_vlm_torch(
image: torch.Tensor,
*,
image_crop_size: list[int] | None,
image_target_size: list[int] | None,
shortest_image_edge: int | None,
crop_fraction: float | None,
) -> torch.Tensor:
"""Torch/torchvision port of the non-albumentations branch of
:func:`_transform_n1_7_image_for_vlm`.
Operates on a ``(C, H, W)`` uint8 tensor and keeps the result on the input
tensor's device so the resize/crop run on GPU when the tensor is. Bicubic
interpolation with antialiasing matches PIL's ``Image.Resampling.BICUBIC``
closely (sub-``2/255`` per-pixel on worst-case inputs). The ``use_albumentations``
cv2/INTER_AREA path has no torch equivalent and stays on the PIL helper.
"""
if image_target_size is None:
return image
target_h, target_w = image_target_size
_, height, width = image.shape
square_edge = max(height, width)
if height != width:
left = (square_edge - width) // 2
top = (square_edge - height) // 2
image = tv_functional.pad(
image, [left, top, square_edge - width - left, square_edge - height - top], fill=0
)
resize_edge = shortest_image_edge or target_h
image = tv_functional.resize(
image, [resize_edge, resize_edge], interpolation=InterpolationMode.BICUBIC, antialias=True
)
if crop_fraction is None and image_crop_size is not None:
crop_fraction = image_crop_size[0] / float(target_h)
if crop_fraction is not None and 0.0 < crop_fraction < 1.0:
# Match the PIL helper's center crop exactly: round() the crop size but
# floor() the offset (torchvision.center_crop rounds the offset, which
# shifts the region by 1px when (edge - crop) is odd).
crop_h = max(1, int(round(image.shape[-2] * crop_fraction)))
crop_w = max(1, int(round(image.shape[-1] * crop_fraction)))
top = max(0, (image.shape[-2] - crop_h) // 2)
left = max(0, (image.shape[-1] - crop_w) // 2)
image = image[..., top : top + crop_h, left : left + crop_w]
if tuple(image.shape[-2:]) != (target_h, target_w):
image = tv_functional.resize(
image, [target_h, target_w], interpolation=InterpolationMode.BICUBIC, antialias=True
)
return image
@dataclass
@ProcessorStepRegistry.register(name="groot_n1_7_pack_inputs_v1")
class GrootN17PackInputsStep(ProcessorStep):
@@ -1070,6 +1058,9 @@ class GrootN17PackInputsStep(ProcessorStep):
video_modality_keys: list[str] | None = None
raw_stats: dict[str, Any] | None = None
modality_config: dict[str, Any] | None = None
# Unused: kept so serialized configs that include it still load. The raw
# state cache is per instance (_last_raw_state), never process-global.
state_cache_key: str = ""
_last_raw_state: dict[str, np.ndarray] | None = field(default=None, init=False, repr=False)
_warned_image_keys: bool = field(default=False, init=False, repr=False)
@@ -1342,12 +1333,6 @@ class GrootN17VLMEncodeStep(ProcessorStep):
The packed video has shape ``(B, T, V, H, W, C)``. Each frame/view becomes
an image item in the same chat message so the resulting image tokens match
the temporal VLM packing used by Isaac-GR00T.
Images are handed to the torchvision-backed Qwen3-VL processor as ``(C, H, W)``
uint8 tensors (no per-frame PIL roundtrip), and, when ``device`` resolves to a
CUDA device, the resize/rescale/normalize/patchify run there instead of on the
single CPU main-loop thread. This keeps the output bit-identical on CPU and
moves the dominant preprocessing cost off the critical path on GPU.
"""
model_name: str = GROOT_N1_7_BACKBONE_MODEL
@@ -1356,7 +1341,6 @@ class GrootN17VLMEncodeStep(ProcessorStep):
shortest_image_edge: int | None = None
crop_fraction: float | None = None
use_albumentations: bool = False
device: str | None = None
_proc: ProcessorMixin | None = field(default=None, init=False, repr=False)
@property
@@ -1365,70 +1349,6 @@ class GrootN17VLMEncodeStep(ProcessorStep):
self._proc = _build_n1_7_processor(self.model_name)
return self._proc
def _target_device(self) -> torch.device | None:
# The albumentations path is cv2/PIL only, so it cannot run on GPU.
if self.device is None or self.use_albumentations:
return None
try:
return get_safe_torch_device(self.device)
except (AssertionError, RuntimeError):
# A device serialized at train time (e.g. "cuda") may be unavailable
# when the processor is reloaded elsewhere (e.g. CPU-only eval), and
# this step is not in the standard device-override set. Fall back to
# the CPU path, which is bit-identical, instead of crashing.
return None
def _build_sample_images(
self, video: Any, batch_size: int, target_device: torch.device | None
) -> list[list[Any]]:
"""Return, per batch item, its ordered ``(timestep, view)`` frames.
``use_albumentations`` keeps the legacy per-frame PIL/cv2 transform;
otherwise frames are ``(C, H, W)`` uint8 tensors (moved to
``target_device`` when set) for the torchvision-backed Qwen processor.
"""
if self.use_albumentations:
video_np = np.asarray(video)
return [
[
_transform_n1_7_image_for_vlm(
Image.fromarray(video_np[batch_idx, timestep, view_idx]),
image_crop_size=self.image_crop_size,
image_target_size=self.image_target_size,
shortest_image_edge=self.shortest_image_edge,
crop_fraction=self.crop_fraction,
use_albumentations=True,
)
for timestep in range(video_np.shape[1])
for view_idx in range(video_np.shape[2])
]
for batch_idx in range(batch_size)
]
video_t = video if torch.is_tensor(video) else torch.from_numpy(np.ascontiguousarray(video))
# (B, T, V, H, W, C) uint8 -> (B, T, V, C, H, W)
video_t = video_t.permute(0, 1, 2, 5, 3, 4).contiguous()
if target_device is not None and video_t.device != target_device:
video_t = video_t.to(target_device, non_blocking=(target_device.type == "cuda"))
frames_per_sample: list[list[Any]] = []
for batch_idx in range(batch_size):
sample = video_t[batch_idx] # (T, V, C, H, W)
frames_per_sample.append(
[
_transform_n1_7_image_for_vlm_torch(
sample[timestep, view_idx],
image_crop_size=self.image_crop_size,
image_target_size=self.image_target_size,
shortest_image_edge=self.shortest_image_edge,
crop_fraction=self.crop_fraction,
)
for timestep in range(sample.shape[0])
for view_idx in range(sample.shape[1])
]
)
return frames_per_sample
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
@@ -1436,25 +1356,33 @@ class GrootN17VLMEncodeStep(ProcessorStep):
if video is None:
return transition
batch_size = int(video.shape[0])
languages = _prepare_n1_7_language_batch(
comp.get("language"),
batch_size,
video.shape[0],
formalize_language=False,
)
target_device = self._target_device()
sample_images = self._build_sample_images(video, batch_size, target_device)
texts: list[str] = []
images: list[Any] = []
for batch_idx in range(batch_size):
frames = sample_images[batch_idx]
images: list[Image.Image] = []
for batch_idx in range(video.shape[0]):
sample = video[batch_idx] # (T, V, H, W, C)
sample_images = [
_transform_n1_7_image_for_vlm(
Image.fromarray(sample[timestep, view_idx]),
image_crop_size=self.image_crop_size,
image_target_size=self.image_target_size,
shortest_image_edge=self.shortest_image_edge,
crop_fraction=self.crop_fraction,
use_albumentations=self.use_albumentations,
)
for timestep in range(sample.shape[0])
for view_idx in range(sample.shape[1])
]
conversation = [
{
"role": "user",
"content": [
*[{"type": "image", "image": image} for image in frames],
*[{"type": "image", "image": image} for image in sample_images],
{"type": "text", "text": languages[batch_idx]},
],
}
@@ -1466,17 +1394,9 @@ class GrootN17VLMEncodeStep(ProcessorStep):
add_generation_prompt=False,
)
)
images.extend(frames)
images.extend(sample_images)
proc_kwargs: dict[str, Any] = {
"text": texts,
"images": images,
"return_tensors": "pt",
"padding": True,
}
if target_device is not None:
proc_kwargs["device"] = str(target_device)
encoded = self.proc(**proc_kwargs)
encoded = self.proc(text=texts, images=images, return_tensors="pt", padding=True)
for key, value in encoded.items():
comp[key] = value
obs.pop("video", None)
@@ -1495,7 +1415,6 @@ class GrootN17VLMEncodeStep(ProcessorStep):
"shortest_image_edge": self.shortest_image_edge,
"crop_fraction": self.crop_fraction,
"use_albumentations": self.use_albumentations,
"device": self.device,
}
@@ -1646,6 +1565,8 @@ class GrootN17ActionDecodeStep(ProcessorStep):
modality_config: dict[str, Any] | None = None
use_percentiles: bool = False
use_relative_action: bool = False
# Unused: kept so serialized configs that include it still load.
state_cache_key: str = ""
action_decode_transform: str | None = None
pack_step: GrootN17PackInputsStep | None = field(default=None, repr=False)
@@ -1773,10 +1694,10 @@ class GrootN17ActionDecodeStep(ProcessorStep):
}
@dataclass
# v2: unlike the N1.5-era v1 step, this step no longer collapses (B, T, D)
# action chunks to the last timestep, so old serialized v1 pipelines must not
# silently load into it (v1 is stubbed below with the removal guidance).
@dataclass
@ProcessorStepRegistry.register(name="groot_action_unpack_unnormalize_v2")
class GrootActionUnpackUnnormalizeStep(ProcessorStep):
env_action_dim: int = 0
@@ -207,6 +207,11 @@ def test_lerobot_groot_forward_pass():
with torch.no_grad():
lerobot_loss, lerobot_metrics = lerobot_policy.forward(batch_lerobot_processed)
assert isinstance(lerobot_loss, torch.Tensor)
assert torch.isfinite(lerobot_loss).all()
assert "loss" in lerobot_metrics
assert np.isfinite(lerobot_metrics["loss"])
print("\nForward pass successful.")
print(f" - Loss: {lerobot_loss.item():.6f}")
print(f" - Metrics: {lerobot_metrics}")
File diff suppressed because it is too large Load Diff
+175 -26
View File
@@ -14,31 +14,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Parity test: original NVIDIA GR00T N1.7 vs the GR00T N1.7 integration in LeRobot.
"""Parity tests: original NVIDIA GR00T N1.7 vs the GR00T N1.7 integration in LeRobot.
Verifies that the self-contained LeRobot reimplementation of the GR00T N1.7 action
head + Qwen3-VL backbone produces the SAME raw model output (``action_pred``, the
normalized flow-matching prediction before any action decoding) as NVIDIA's original
``gr00t`` package, given byte-identical pre-processed inputs and the same
flow-matching seed. The comparison is parametrized over every embodiment tag present
in the checkpoint.
Two comparisons run per embodiment tag, against per-tag ``.npz`` artifacts produced
once in the original ``gr00t`` env by the companion script
``utils/dump_original_n1_7.py`` (in the ``utils`` package next to this file):
To keep the comparison fair, the original outputs + the exact collated inputs are
produced once per embodiment in the original ``gr00t`` env via the companion script
``utils/dump_original_n1_7.py`` (in the ``utils`` package next to this file) and saved
to per-tag ``.npz`` files.
This test discovers those artifacts, replays the identical inputs through the LeRobot
model, and compares.
1. **Model parity** -- the self-contained LeRobot reimplementation of the GR00T N1.7
action head + Qwen3-VL backbone must produce the SAME raw model output
(``action_pred``, the normalized flow-matching prediction before any action
decoding) as NVIDIA's original ``gr00t`` package, given byte-identical
pre-processed inputs and the flow-matching seed recorded in the artifact.
2. **Preprocessor parity** -- LeRobot's own preprocessor pipeline (real Qwen3-VL chat
template / tokenizer / image packing + state normalization, no mocks) must produce
the SAME collated model inputs (``input_ids``, ``pixel_values``, ``state``, ...)
as the original package's processor, given the identical raw observations
(images, state, language) recorded in the artifact. Artifacts written by older
versions of the dump script carry no raw observations; this case then SKIPS with
a regeneration hint.
This test is LOCAL-only and skips on CI, when ``gr00t``-side prerequisites are not
present, or when no artifact has been generated. By default it looks for artifacts in
These tests are LOCAL-only and skip on CI, when ``gr00t``-side prerequisites are not
present, or when no artifact has been generated. By default they look for artifacts in
``<this dir>/artifacts/``; override with ``GROOT_N1_7_PARITY_DIR``. See the
"Original-vs-LeRobot parity test" section of ``src/lerobot/policies/groot/README.md``
for the full run procedure.
"""
import os
import warnings
from pathlib import Path
from typing import Any
import numpy as np
import pytest
@@ -50,7 +55,9 @@ pytestmark = pytest.mark.skipif(
)
from lerobot.policies.groot.configuration_groot import GROOT_N1_7 # noqa: E402,F401
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # noqa: E402
# Fallback flow-matching seed for artifacts predating the recorded ``seed`` field.
SEED = 42
DEVICE = os.environ.get("GROOT_PARITY_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
ATOL = float(os.environ.get("GROOT_PARITY_ATOL", "1e-3"))
@@ -60,6 +67,11 @@ RTOL = float(os.environ.get("GROOT_PARITY_RTOL", "1e-3"))
_ARTIFACT_PREFIX = "original_n1_7_"
_ARTIFACT_SUFFIX = ".npz"
# Collated keys compared by the preprocessor parity case: integer/id tensors must
# match exactly; float tensors within ATOL/RTOL.
_COLLATED_EXACT_KEYS = ("input_ids", "attention_mask", "image_grid_thw", "embodiment_id")
_COLLATED_CLOSE_KEYS = ("pixel_values", "state")
def _artifact_dir() -> Path:
"""Directory holding the per-embodiment .npz artifacts.
@@ -109,9 +121,20 @@ def _resolve_checkpoint() -> str:
return str(ckpt)
def _load_artifact(path: Path):
def _load_artifact(path: Path) -> tuple[torch.Tensor, dict[str, torch.Tensor], int]:
"""Return (original action_pred, collated model inputs, flow-matching seed)."""
data = np.load(path, allow_pickle=True)
original_action = torch.from_numpy(data["action_pred"]).float()
if "seed" in data.files:
seed = int(data["seed"])
else:
warnings.warn(
f"Artifact '{path.name}' does not record the producer seed (it predates the current "
f"dump_original_n1_7.py); falling back to seed={SEED}. If the parity comparison fails, "
"regenerate the artifact with the current dump script.",
stacklevel=2,
)
seed = SEED
dtypes = dict(zip(data["meta_keys"].tolist(), data["meta_dtypes"].tolist(), strict=False))
inputs = {}
for key in data.files:
@@ -124,7 +147,45 @@ def _load_artifact(path: Path):
if "int" in declared or "long" in declared:
t = t.long()
inputs[name] = t
return original_action, inputs
return original_action, inputs, seed
def _load_raw_observation(path: Path) -> dict[str, Any] | None:
"""Return the raw observation recorded in the artifact, or None for old artifacts.
Artifacts produced by the current ``dump_original_n1_7.py`` additionally store the
exact raw observation the producer fed to the original processor: per-camera uint8
frames (``raw::video.<key>``, (B, T, H, W, C)), per-key state vectors
(``raw::state.<key>``, (B, T, dim)) and the language instruction
(``raw::language``, one string per batch element). ``raw_video_keys`` /
``raw_state_keys`` record the checkpoint modality-key order.
"""
data = np.load(path, allow_pickle=True)
markers = ("raw_video_keys", "raw_state_keys", "raw::language")
if any(marker not in data.files for marker in markers):
return None
video_keys = [str(k) for k in data["raw_video_keys"].tolist()]
state_keys = [str(k) for k in data["raw_state_keys"].tolist()]
return {
"video": {k: data[f"raw::video.{k}"] for k in video_keys},
"state": {k: data[f"raw::state.{k}"] for k in state_keys},
"language": [str(t) for t in data["raw::language"].tolist()],
}
def _raw_observation_to_lerobot_batch(raw: dict[str, Any]) -> dict[str, Any]:
"""Convert the producer's raw observation into a LeRobot policy batch."""
batch: dict[str, Any] = {}
for key, frames in raw["video"].items():
# (B, T, H, W, C) uint8 -> (B, T, C, H, W); the pack step converts back losslessly.
batch[f"{OBS_IMAGES}.{key}"] = torch.from_numpy(frames).permute(0, 1, 4, 2, 3).contiguous()
# observation.state is the per-key state vectors (latest frame) concatenated in
# checkpoint modality-key order -- the layout the LeRobot pack step and the
# flattened checkpoint statistics expect.
state_parts = [torch.from_numpy(np.asarray(arr)[:, -1, :]).float() for arr in raw["state"].values()]
batch[OBS_STATE] = torch.cat(state_parts, dim=-1)
batch["task"] = list(raw["language"])
return batch
def _unflatten(inputs: dict[str, torch.Tensor]) -> dict:
@@ -139,6 +200,36 @@ def _unflatten(inputs: dict[str, torch.Tensor]) -> dict:
return nested.get("inputs", nested)
def _assert_collated_parity(
embodiment_tag: str, name: str, lerobot_value: Any, original_value: torch.Tensor, *, exact: bool
) -> None:
"""Compare one collated tensor produced by LeRobot against the original's."""
assert isinstance(lerobot_value, torch.Tensor), (
f"[{embodiment_tag}] LeRobot preprocessor output '{name}' is "
f"{type(lerobot_value).__name__}, expected a tensor."
)
lerobot_t = lerobot_value.detach().cpu()
original_t = original_value.detach().cpu()
assert lerobot_t.shape == original_t.shape, (
f"[{embodiment_tag}] collated '{name}' shape mismatch: lerobot={tuple(lerobot_t.shape)} vs "
f"original={tuple(original_t.shape)}."
)
if exact:
mismatched = int((lerobot_t.long() != original_t.long()).sum())
assert mismatched == 0, (
f"[{embodiment_tag}] collated '{name}' differs from the original processor output: "
f"{mismatched}/{original_t.numel()} elements mismatch."
)
else:
lerobot_f, original_f = lerobot_t.float(), original_t.float()
max_diff = (lerobot_f - original_f).abs().max().item()
print(f"[{embodiment_tag}] {name}: shape {tuple(lerobot_t.shape)} max|diff|={max_diff:.6e}")
assert torch.allclose(lerobot_f, original_f, atol=ATOL, rtol=RTOL), (
f"[{embodiment_tag}] collated '{name}' differs from the original processor output beyond "
f"atol={ATOL}, rtol={RTOL}: max|diff|={max_diff:.6e}."
)
@pytest.fixture(scope="module")
def lerobot_model():
"""Load the LeRobot GR00T N1.7 model once (fp32 + SDPA) and reuse across tags."""
@@ -165,8 +256,7 @@ def lerobot_model():
_ARTIFACTS = _discover_artifacts()
@pytest.mark.skipif(
_requires_artifacts = pytest.mark.skipif(
not _ARTIFACTS,
reason=(
"No GR00T N1.7 parity artifacts found. Generate them first in the original gr00t "
@@ -174,24 +264,30 @@ _ARTIFACTS = _discover_artifacts()
"--ckpt <ckpt> --out-dir tests/policies/groot/artifacts --device cuda"
),
)
@_requires_artifacts
@pytest.mark.parametrize("embodiment_tag,artifact", _ARTIFACTS, ids=[t for t, _ in _ARTIFACTS])
def test_groot_get_action_parity(embodiment_tag, artifact, lerobot_model):
"""Raw model.get_action(action_pred) parity per embodiment: original vs LeRobot."""
original_action, flat_inputs = _load_artifact(artifact)
original_action, flat_inputs, seed = _load_artifact(artifact)
model_inputs = _unflatten(flat_inputs)
# Align the flow-matching RNG exactly as the producer did (seed right before sampling).
torch.manual_seed(SEED)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(SEED)
torch.cuda.manual_seed_all(seed)
with torch.inference_mode():
out = lerobot_model.get_action(model_inputs)
lerobot_action = out["action_pred"].float().cpu()
t = min(original_action.shape[1], lerobot_action.shape[1])
d = min(original_action.shape[2], lerobot_action.shape[2])
original_action = original_action[:, :t, :d]
lerobot_action = lerobot_action[:, :t, :d]
assert lerobot_action.shape == original_action.shape, (
f"GR00T N1.7 action_pred shape mismatch for embodiment '{embodiment_tag}': "
f"lerobot={tuple(lerobot_action.shape)} vs original={tuple(original_action.shape)}. "
"The same checkpoint and inputs must produce identical shapes; this indicates an "
"action-horizon or action-dim regression (or a stale artifact -- regenerate it with "
"utils/dump_original_n1_7.py)."
)
diff = torch.abs(lerobot_action - original_action)
max_diff = diff.max().item()
@@ -205,3 +301,56 @@ def test_groot_get_action_parity(embodiment_tag, artifact, lerobot_model):
f"GR00T N1.7 raw action_pred differs for embodiment '{embodiment_tag}' beyond "
f"atol={ATOL}, rtol={RTOL}: max|diff|={max_diff:.6e}"
)
@_requires_artifacts
@pytest.mark.parametrize("embodiment_tag,artifact", _ARTIFACTS, ids=[t for t, _ in _ARTIFACTS])
def test_groot_preprocessor_parity(embodiment_tag, artifact):
"""LeRobot's real preprocessor vs the original's collated tensors, from identical raw obs.
Runs LeRobot's full preprocessor pipeline -- including the real Qwen3-VL chat
template, tokenizer and image packing plus the checkpoint-driven state
normalization (no mocks) -- on the raw observations recorded in the artifact, and
compares every collated model input against the ones the original ``gr00t``
processor produced from the same raw observations.
"""
raw = _load_raw_observation(artifact)
if raw is None:
pytest.skip(
f"Artifact '{artifact.name}' was produced by an older dump_original_n1_7.py that does "
"not record raw observations; regenerate it with the current dump script to run the "
"preprocessor parity case."
)
_, flat_inputs, _ = _load_artifact(artifact)
original_inputs = _unflatten(flat_inputs)
ckpt = _resolve_checkpoint()
from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
# CPU keeps this case runnable without a GPU; the preprocessor is deterministic.
config = GrootConfig(base_model_path=ckpt, embodiment_tag=embodiment_tag, device="cpu")
preprocessor, _ = make_groot_pre_post_processors(config)
processed = preprocessor(_raw_observation_to_lerobot_batch(raw))
compared_keys = (*_COLLATED_EXACT_KEYS, *_COLLATED_CLOSE_KEYS)
missing_original = [k for k in compared_keys if k not in original_inputs]
missing_lerobot = [k for k in compared_keys if k not in processed]
assert not missing_original, (
f"[{embodiment_tag}] artifact collated inputs miss {missing_original} "
f"(available: {sorted(original_inputs)}); regenerate the artifact with the current dump script."
)
assert not missing_lerobot, (
f"[{embodiment_tag}] LeRobot preprocessor output misses {missing_lerobot} (tensor keys "
f"available: {sorted(k for k, v in processed.items() if isinstance(v, torch.Tensor))})."
)
for name in compared_keys:
_assert_collated_parity(
embodiment_tag,
name,
processed[name],
original_inputs[name],
exact=name in _COLLATED_EXACT_KEYS,
)
@@ -9,6 +9,9 @@ LeRobot GR00T N1.7 integration requires. The two implementations therefore canno
imported in the same Python process. To keep the parity comparison FAIR, we run the
original model in its native env here and serialize, PER EMBODIMENT TAG:
* the RAW observation fed to the original processor (per-camera uint8 frames,
per-key state vectors, the language instruction), so the LeRobot side can also
run its OWN preprocessor on identical raw inputs and compare collated tensors,
* the exact pre-processed/collated model inputs (so the LeRobot side consumes the
byte-identical tensors -- same image preprocessing, tokenization, normalization),
* the random seed used right before the flow-matching sampler,
@@ -21,8 +24,10 @@ processor's per-embodiment modality configs. This lets us test many embodiment t
from the SAME checkpoint and confirm the LeRobot integration is not overfit to
``libero_sim``.
The companion pytest (run in the LeRobot env) loads each .npz, replays the identical
inputs + seed through the LeRobot GR00T N1.7 model, and asserts the outputs match.
The companion pytest (run in the LeRobot env) loads each .npz and asserts parity
twice: the collated inputs + seed are replayed through the LeRobot GR00T N1.7 model
(model parity), and the raw observation is replayed through LeRobot's own
preprocessor pipeline and compared against the collated inputs (preprocessor parity).
Usage:
.venv-original/bin/python tests/policies/groot/utils/dump_original_n1_7.py \
@@ -62,10 +67,7 @@ def make_observation(seed: int, video_keys, lang_key, state_spec):
# One ndarray per state key, shape (B, T=1, key_dim); dim taken from statistics.
# Keys with dim 0 (e.g. disabled eef on some embodiments) are still emitted as
# present-but-empty so the processor's state transform finds every expected key.
state = {
k: rng.standard_normal((BATCH_SIZE, 1, dim)).astype(np.float32)
for k, dim in state_spec
}
state = {k: rng.standard_normal((BATCH_SIZE, 1, dim)).astype(np.float32) for k, dim in state_spec}
language = {lang_key: [[PROMPT] for _ in range(BATCH_SIZE)]}
return {"video": video, "state": state, "language": language}
@@ -77,6 +79,25 @@ def dump_one_tag(policy, fair_model, tag, modality_cfg, state_spec, args, out_pa
lang_key = modality_cfg["language"].modality_keys[0]
observation = make_observation(args.seed, video_keys, lang_key, state_spec)
# Snapshot the RAW observation exactly as fed to the original processor below. The
# consumer's preprocessor-parity case replays it through LeRobot's own preprocessor
# and compares the resulting collated tensors against the "in::" ones saved further
# down. raw_state_keys records the checkpoint modality-key order, which is the
# concatenation order of the flat LeRobot ``observation.state`` vector.
spec_keys = [key for key, _ in state_spec]
state_modality = modality_cfg.get("state")
state_keys = [key for key in state_modality.modality_keys if key in spec_keys] if state_modality else []
state_keys += [key for key in spec_keys if key not in state_keys]
raw_language = [
str(item[0]) if isinstance(item, (list, tuple)) else str(item)
for item in observation["language"][lang_key]
]
raw_flat = {f"raw::video.{key}": arr.copy() for key, arr in observation["video"].items()}
raw_flat.update({f"raw::state.{key}": arr.copy() for key, arr in observation["state"].items()})
raw_flat["raw::language"] = np.array(raw_language, dtype=object)
raw_flat["raw_video_keys"] = np.array([str(key) for key in video_keys], dtype=object)
raw_flat["raw_state_keys"] = np.array([str(key) for key in state_keys], dtype=object)
# Point the policy preprocessing at this embodiment (mirrors Gr00tPolicy.__init__).
policy.embodiment_tag = type(policy.embodiment_tag)(tag)
policy.modality_configs = {
@@ -136,6 +157,7 @@ def dump_one_tag(policy, fair_model, tag, modality_cfg, state_spec, args, out_pa
embodiment_tag=np.array(tag),
meta_keys=np.array(list(meta.keys()), dtype=object),
meta_dtypes=np.array(list(meta.values()), dtype=object),
**raw_flat,
**flat,
)
print(f"[{tag}] action_pred {action_pred.shape} -> {out_path.name} ({os.path.getsize(out_path)} B)")
@@ -181,7 +203,12 @@ def main():
state_spec = [(k, len(v["min"])) for k, v in stats[tag]["state"].items()]
try:
dump_one_tag(
policy, fair_model, tag, all_modality[tag], state_spec, args,
policy,
fair_model,
tag,
all_modality[tag],
state_spec,
args,
out_dir / f"original_n1_7_{tag}.npz",
)
done.append(tag)