Compare commits

..

2 Commits

Author SHA1 Message Date
Steven Palma f81fc79564 minor fixes 2026-06-12 18:45:24 +02:00
Steven Palma 9ce6633518 fix(groot): address review findings for the N1.7 port
N1.5 removal is now explicit and actionable:
- Legacy N1.5 checkpoint configs (tokenizer_assets_repo) parse and fail
  with a single clear error pointing to lerobot==0.5.1 instead of a
  cryptic draccus DecodingError
- Removed N1.5 processor registry names (groot_pack_inputs_v3,
  groot_eagle_encode_v3, groot_eagle_collate_v3) are stubbed to raise the
  same guidance; groot_action_unpack_unnormalize_v1 changed semantics, so
  the step is re-registered as _v2 and _v1 is stubbed
- N1.5 detection also recognizes checkpoint config.json content
  (model_type/architectures/eagle backbone), not just path names; every
  rejection surface includes the migration guidance
- groot.mdx documents the breaking change and migration path

Runtime fixes:
- use_bf16=False no longer crashes (compute_dtype only set when used)
- GrootN17ActionDecodeStep handles the 2-D (B, D) actions delivered by
  sync select_action (relative eef/non-eef decode was broken in
  lerobot-eval/record flows)
- Postprocessor falls back to dataset stats when a raw checkpoint lacks
  the configured embodiment tag instead of silently emitting normalized
  [-1, 1] actions
- Hub-hosted finetuned N1.7 checkpoints load: the processor config is
  resolved via hf_hub_download for non-local paths, with a tolerant
  retry when inspection fails
- Raw-checkpoint processor branch honors caller overrides (device,
  rename_map) instead of dropping them
- Relative-action raw-state cache is per-instance instead of
  process-global (cross-instance contamination)
- Camera/modality-key mismatches warn, including the zero-match
  fallback; checkpoint revision is no longer forwarded into backbone
  loading; deprecated Qwen2VLImageProcessorFast replaced with
  Qwen2VLImageProcessor

Config/UX:
- GrootConfig defaults are the N1.7 values; explicitly passed legacy
  N1.5-era values (chunk_size=50, max_state_dim=64, ...) are remapped
  with a warning instead of being remapped silently
- Explicit action_decode_transform='none' wins over the libero_sim
  default (new 'auto' sentinel) and survives save/load round-trips

Tests/CI:
- pytest.importorskip guards so fast_tests tiers pass without
  transformers (was 10 failures, now 0)
- Regression tests for every fix; from_pretrained rejection tests now
  actually exercise from_pretrained
- Parity test reads the artifact seed, fails on shape mismatch instead
  of silently truncating, and a new case runs LeRobot's real Qwen3-VL
  preprocessing on raw observations dumped by the producer
- docs: dead huggingface-cli download replaced with hf download

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-12 16:51:14 +02:00
8 changed files with 1315 additions and 694 deletions
@@ -42,14 +42,6 @@ GROOT_N1_5_REMOVAL_GUIDANCE = (
) )
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B" GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B" GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
# Image preprocessing geometry the GR00T N1.7 backbone was trained on. The processor
# falls back to these when a checkpoint ships no image sizing in its processor_config
# (e.g. fine-tuning the raw nvidia/GR00T-N1.7-3B base with a new embodiment), so frames
# are resized to the expected resolution instead of being patchified at full camera
# resolution (which both slows training and is a train/checkpoint distribution mismatch).
# Mirrored by GR00T_N1_7_DEFAULTS in groot_n1_7.py.
N1_7_DEFAULT_IMAGE_TARGET_SIZE = (256, 256)
N1_7_DEFAULT_IMAGE_CROP_SIZE = (230, 230)
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero" GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
# Sentinel meaning "the user did not pick an action decode transform": __post_init__ resolves it # Sentinel meaning "the user did not pick an action decode transform": __post_init__ resolves it
# to the embodiment default ('libero' for 'libero_sim', otherwise None). It is distinct from an # to the embodiment default ('libero' for 'libero_sim', otherwise None). It is distinct from an
@@ -329,6 +321,9 @@ def _infer_groot_model_version_from_config(config: dict) -> str | None:
normalized = candidate.lower().replace("-", "_") normalized = candidate.lower().replace("-", "_")
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}: if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
return GROOT_N1_7 return GROOT_N1_7
# nvidia/GR00T-N1.5-3B ships model_type 'gr00t_n1_5' and architectures ['GR00T_N1_5'].
# Recognise them so N1.5 checkpoints at generic local paths are rejected loudly
# instead of being silently treated as N1.7 (see infer_groot_model_version).
if normalized in {"gr00t_n1_5", "gr00tn1_5", "gr00t_n15", "gr00t_n1d5", "gr00tn1d5"}: if normalized in {"gr00t_n1_5", "gr00tn1_5", "gr00t_n15", "gr00t_n1d5", "gr00tn1d5"}:
return GROOT_N1_5 return GROOT_N1_5
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL: if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
@@ -370,7 +365,11 @@ class GrootConfig(PreTrainedConfig):
} }
) )
# Groot-specific model parameters # Deprecated and unused: image sizing is handled by the backbone's image processor.
# Kept only so config.json files saved with earlier versions still parse.
image_size: tuple[int, int] = (256, 256)
# Groot-specific model parameters (from groot_finetune_script.py)
# Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only. # Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only.
model_version: str = GROOT_N1_7 model_version: str = GROOT_N1_7
@@ -386,43 +385,14 @@ class GrootConfig(PreTrainedConfig):
# transform). Pass 'none' to explicitly disable the transform, including for 'libero_sim'. # transform). Pass 'none' to explicitly disable the transform, including for 'libero_sim'.
action_decode_transform: str | None = GROOT_ACTION_DECODE_TRANSFORM_AUTO action_decode_transform: str | None = GROOT_ACTION_DECODE_TRANSFORM_AUTO
# Deprecated, GR00T N1.5 only — do not set. Kept so config.json files saved by lerobot<=0.5.1
# still parse (draccus rejects unknown fields) and can be rejected in __post_init__ with a
# clear error pointing at GROOT_N1_5_REMOVAL_GUIDANCE instead of a cryptic DecodingError.
tokenizer_assets_repo: str | None = None
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1') # Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
embodiment_tag: str = "new_embodiment" embodiment_tag: str = "new_embodiment"
# Inference-only override for the number of flow-matching denoising steps used to decode an
# action chunk. None = use the model checkpoint default (currently 4). Higher values trade
# inference speed for action quality; applied at base-model load via _create_groot_model.
num_inference_timesteps: int | None = None
# If set, caps the number of open-loop actions executed before replanning (inference cadence).
# Overrides the value inferred from the checkpoint/embodiment in _resolve_action_queue_steps.
execution_horizon: int | None = None
# Opt-in. Copy a pretrained embodiment category slot's action-head weights into the target
# embodiment slot at base-model build (in _create_groot_model), to warm-start a cold
# 'new_embodiment' slot. Accepts an embodiment name (e.g.
# 'oxe_droid_relative_eef_relative_joint') or an int embodiment id. Runs on every fresh
# base-model build (so it applies during lerobot-train, which uses __init__ not
# from_pretrained); on a fine-tuned checkpoint reload it is harmlessly overwritten.
warm_start_embodiment_slot: int | str | None = None
# Opt-in relative-action support for the 'new_embodiment' slot (sync-safe, GR00T-native).
# When True, GR00T converts absolute->relative inside its own pack step (training) and
# reconstructs absolute inside its own flat decode step (inference), using a cached
# reference state. The dataset stays absolute; compute relative ACTION stats with
# `lerobot-edit-dataset --operation.relative_action true --operation.relative_exclude_joints
# "['gripper']"` (this only rewrites stats, not actions).
use_relative_actions: bool = False
# Joint names kept absolute (not converted to relative) when use_relative_actions is True.
# Case-insensitive token match against action_feature_names.
relative_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
# Action dimension names from dataset metadata; auto-populated by the factory from dataset
# meta (see factory.py:528). Used to build the relative-action mask so the gripper can be
# identified and kept absolute. When None, the gripper cannot be identified.
action_feature_names: list[str] | None = None
# Fine-tuning control arguments # Fine-tuning control arguments
# Whether to fine-tune the llm backbone # Whether to fine-tune the llm backbone
@@ -458,13 +428,10 @@ class GrootConfig(PreTrainedConfig):
warmup_ratio: float = 0.05 warmup_ratio: float = 0.05
use_bf16: bool = True use_bf16: bool = True
# TODO(Steven): Remove these deprecated fields in a future release. # Deprecated Isaac-GR00T runner fields below — unused by the LeRobot N1.7 implementation
# Deprecated Isaac-GR00T runner/N1.5 fields below — unused by the LeRobot N1.7 implementation
# (nothing in src/lerobot reads them). They are kept only so config.json files saved by # (nothing in src/lerobot reads them). They are kept only so config.json files saved by
# earlier lerobot releases still parse: draccus rejects unknown fields, so removing them # earlier lerobot releases still parse: draccus rejects unknown fields, so removing them
# would break every previously saved groot checkpoint at config-load time. # would break every previously saved groot checkpoint at config-load time.
image_size: tuple[int, int] = (256, 256) # image sizing is handled by the backbone's image processor.
tokenizer_assets_repo: str | None = None
video_backend: str = "decord" video_backend: str = "decord"
balance_dataset_weights: bool = True balance_dataset_weights: bool = True
balance_trajectory_weights: bool = True balance_trajectory_weights: bool = True
@@ -478,6 +445,9 @@ class GrootConfig(PreTrainedConfig):
resume: bool = False resume: bool = False
def __post_init__(self): def __post_init__(self):
# 'tokenizer_assets_repo' only ever existed for GR00T N1.5 (lerobot<=0.5.1) and was
# serialized into every groot checkpoint config.json, so a value here means a legacy
# N1.5 checkpoint or config is being loaded.
if self.tokenizer_assets_repo is not None: if self.tokenizer_assets_repo is not None:
raise ValueError( raise ValueError(
"Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks " "Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks "
@@ -612,11 +582,22 @@ class GrootConfig(PreTrainedConfig):
@property @property
def action_delta_indices(self) -> list[int]: def action_delta_indices(self) -> list[int]:
"""Return indices for delta actions.""" """Return indices for delta actions.
The model action horizon is read from the checkpoint's processor_config.json
when available; the result is cached (keyed on the inputs that determine it) so
repeated access during dataset/training setup does not re-read from disk.
"""
cache_key = (self.base_model_path, self.embodiment_tag, self.chunk_size)
cached = getattr(self, "_action_delta_indices_cache", None)
if cached is not None and cached[0] == cache_key:
return cached[1]
model_action_horizon = ( model_action_horizon = (
infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40 infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
) )
return list(range(min(self.chunk_size, model_action_horizon))) indices = list(range(min(self.chunk_size, model_action_horizon)))
object.__setattr__(self, "_action_delta_indices_cache", (cache_key, indices))
return indices
@property @property
def reward_delta_indices(self) -> None: def reward_delta_indices(self) -> None:
+9 -4
View File
@@ -32,7 +32,6 @@ from torch.distributions import Beta
from lerobot.utils.import_utils import _transformers_available, require_package from lerobot.utils.import_utils import _transformers_available, require_package
from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionTransformer from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionTransformer
from .configuration_groot import N1_7_DEFAULT_IMAGE_CROP_SIZE, N1_7_DEFAULT_IMAGE_TARGET_SIZE
if TYPE_CHECKING or _transformers_available: if TYPE_CHECKING or _transformers_available:
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
@@ -72,13 +71,13 @@ GR00T_N1_7_DEFAULTS: dict[str, Any] = {
"backbone_embedding_dim": 2048, "backbone_embedding_dim": 2048,
"tune_llm": False, "tune_llm": False,
"tune_visual": False, "tune_visual": False,
"select_layer": 16, "select_layer": 16, # N1.7-3B checkpoint value; real checkpoint loads override this from config.json
"reproject_vision": False, "reproject_vision": False,
"use_flash_attention": True, "use_flash_attention": True,
"load_bf16": False, "load_bf16": False,
"backbone_trainable_params_fp32": True, "backbone_trainable_params_fp32": True,
"image_crop_size": N1_7_DEFAULT_IMAGE_CROP_SIZE, "image_crop_size": (230, 230),
"image_target_size": N1_7_DEFAULT_IMAGE_TARGET_SIZE, "image_target_size": (256, 256),
"shortest_image_edge": None, "shortest_image_edge": None,
"crop_fraction": None, "crop_fraction": None,
"random_rotation_angle": None, "random_rotation_angle": None,
@@ -823,6 +822,8 @@ def get_backbone_cls(config: GR00TN17Config):
if "nvidia/Cosmos-Reason2" in config.model_name or "Qwen/Qwen3-VL" in config.model_name: if "nvidia/Cosmos-Reason2" in config.model_name or "Qwen/Qwen3-VL" in config.model_name:
return Qwen3Backbone return Qwen3Backbone
if config.backbone_model_type == "qwen": if config.backbone_model_type == "qwen":
# Local backbone checkpoints (e.g. hub-cache snapshot paths) contain neither hub
# marker, so trust the explicit backbone type but surface what is being assumed.
logger.warning( logger.warning(
"Unrecognized GR00T N1.7 backbone model name '%s'; assuming a Qwen3-VL-compatible " "Unrecognized GR00T N1.7 backbone model name '%s'; assuming a Qwen3-VL-compatible "
"backbone because backbone_model_type='qwen'.", "backbone because backbone_model_type='qwen'.",
@@ -913,6 +914,10 @@ class GR00TN17(PreTrainedModel):
"trust_remote_code": True "trust_remote_code": True
} }
load_backbone_weights = kwargs.pop("load_backbone_weights", False) load_backbone_weights = kwargs.pop("load_backbone_weights", False)
# Only repo-agnostic hub kwargs are forwarded to the backbone loading kwargs:
# ``revision`` pins the GR00T checkpoint repo (see snapshot_download below) and would
# be invalid for the unrelated backbone repo (``config.model_name``). Pin the backbone
# itself by passing ``revision`` inside ``transformers_loading_kwargs``.
for key in ("cache_dir", "local_files_only", "token"): for key in ("cache_dir", "local_files_only", "token"):
if key in kwargs: if key in kwargs:
transformers_loading_kwargs.setdefault(key, kwargs[key]) transformers_loading_kwargs.setdefault(key, kwargs[key])
+6 -125
View File
@@ -54,98 +54,6 @@ logger = logging.getLogger(__name__)
T = TypeVar("T", bound="GrootPolicy") T = TypeVar("T", bound="GrootPolicy")
def _resolve_embodiment_id(value: int | str) -> int:
    """Turn an embodiment reference into its integer embodiment id.

    Args:
        value: Either an integer id (returned unchanged) or an N1.7 embodiment
            name that is a key of ``N1_7_EMBODIMENT_MAPPING``.

    Returns:
        The integer embodiment id.

    Raises:
        ValueError: If ``value`` is a bool, or a name that is not a key of
            ``N1_7_EMBODIMENT_MAPPING`` (the message lists the known names).
    """
    # Imported at call time rather than at module import time.
    from .processor_groot import N1_7_EMBODIMENT_MAPPING

    # bool must be tested before int: bool is a subclass of int, so True/False
    # would otherwise be accepted as ids 1/0.
    if isinstance(value, bool):
        raise ValueError(f"Embodiment id must be an int or embodiment name, got bool {value!r}.")
    if isinstance(value, int):
        return value
    if value not in N1_7_EMBODIMENT_MAPPING:
        raise ValueError(
            f"Unknown GR00T N1.7 embodiment name '{value}'. Known names: "
            f"{sorted(N1_7_EMBODIMENT_MAPPING.keys())}."
        )
    return N1_7_EMBODIMENT_MAPPING[value]
def _warm_start_embodiment_slot(model, source_id: int, target_id: int) -> None:
    """Copy per-embodiment action-head weights from one slot into another.

    For each known layer of ``model.action_head`` (``state_encoder``,
    ``action_encoder``, ``action_decoder``) that exposes ``W`` and ``b``, the
    entries at index ``source_id`` are cloned into index ``target_id``. The
    first dimension of ``W`` is taken as the number of embodiment categories.

    Nothing is raised for unusable inputs; each case logs a warning instead:
    identical ids, a model without ``action_head``, ids out of range for a
    layer (that layer alone is skipped), or no layer copied at all.

    Args:
        model: Object expected to carry an ``action_head`` attribute.
        source_id: Embodiment slot to copy weights from.
        target_id: Embodiment slot to overwrite.
    """
    if source_id == target_id:
        logger.warning(
            "GR00T warm_start_embodiment_slot: source and target embodiment id are both %d; "
            "skipping (nothing to copy).",
            source_id,
        )
        return

    action_head = getattr(model, "action_head", None)
    if action_head is None:
        logger.warning("GR00T warm_start_embodiment_slot: model has no action_head; skipping.")
        return

    # Action-head submodule attribute -> names of its per-category linear layers.
    layer_names_by_submodule = {
        "state_encoder": ("layer1", "layer2"),
        "action_encoder": ("W1", "W2", "W3"),
        "action_decoder": ("layer1", "layer2"),
    }

    copied_layers: list[str] = []
    # The copy writes to ``.data`` directly; no_grad keeps autograd out of it.
    with torch.no_grad():
        for submodule_attr, layer_names in layer_names_by_submodule.items():
            submodule = getattr(action_head, submodule_attr, None)
            if submodule is None:
                continue
            owner_name = type(submodule).__name__
            for layer_name in layer_names:
                layer = getattr(submodule, layer_name, None)
                if layer is None or not (hasattr(layer, "W") and hasattr(layer, "b")):
                    continue
                num_categories = layer.W.shape[0]
                ids_in_range = all(0 <= slot < num_categories for slot in (source_id, target_id))
                if not ids_in_range:
                    logger.warning(
                        "GR00T warm_start_embodiment_slot: source_id=%d/target_id=%d out of range "
                        "for %s.%s (num_categories=%d); skipping this layer.",
                        source_id,
                        target_id,
                        owner_name,
                        layer_name,
                        num_categories,
                    )
                    continue
                layer.W.data[target_id] = layer.W.data[source_id].clone()
                layer.b.data[target_id] = layer.b.data[source_id].clone()
                copied_layers.append(f"{owner_name}.{layer_name}")

    if not copied_layers:
        logger.warning(
            "GR00T warm_start_embodiment_slot: no action-head weights were copied "
            "(source_id=%d, target_id=%d).",
            source_id,
            target_id,
        )
        return

    logger.info(
        "GR00T warm_start_embodiment_slot: copied action-head weights from embodiment slot %d "
        "to slot %d for: %s.",
        source_id,
        target_id,
        ", ".join(copied_layers),
    )
class GrootPolicy(PreTrainedPolicy): class GrootPolicy(PreTrainedPolicy):
"""Wrapper around external Groot model for LeRobot integration.""" """Wrapper around external Groot model for LeRobot integration."""
@@ -185,24 +93,11 @@ class GrootPolicy(PreTrainedPolicy):
transformers_loading_kwargs={"trust_remote_code": True}, transformers_loading_kwargs={"trust_remote_code": True},
) )
# Inference-only override for the number of flow-matching denoising steps. The action # GR00TN17 defines no compute_dtype attribute, so only record the
# head reads self.num_inference_timesteps in get_action_with_features; dt (1/n) and the # bf16 preference when it is enabled instead of reading a default back.
# t schedule adapt automatically. if self.config.use_bf16:
if self.config.num_inference_timesteps is not None: model.compute_dtype = "bfloat16"
n = int(self.config.num_inference_timesteps) model.config.compute_dtype = "bfloat16"
model.config.num_inference_timesteps = n
model.action_head.num_inference_timesteps = n
# Opt-in: warm-start a cold embodiment slot (e.g. 'new_embodiment') from a pretrained
# slot's action-head weights. Done here (not in from_pretrained) so it applies on every
# fresh base-model build -- training via make_policy instantiates GrootPolicy(config)
# directly (factory uses __init__ when cfg.pretrained_path is unset), it does NOT go
# through from_pretrained. On a fine-tuned checkpoint reload this also runs but is
# immediately overwritten by the loaded state_dict, so it is a harmless no-op there.
if self.config.warm_start_embodiment_slot is not None:
source_id = _resolve_embodiment_id(self.config.warm_start_embodiment_slot)
target_id = _resolve_embodiment_id(self.config.embodiment_tag)
_warm_start_embodiment_slot(model, source_id, target_id)
return model return model
@@ -371,11 +266,7 @@ class GrootPolicy(PreTrainedPolicy):
horizons.append(checkpoint_action_horizon) horizons.append(checkpoint_action_horizon)
if execution_horizon is not None: if execution_horizon is not None:
horizons.append(execution_horizon) horizons.append(execution_horizon)
# An explicit config override caps the open-loop horizon (inference cadence), overriding return min(horizons)
# the value inferred from the checkpoint/embodiment.
if self.config.execution_horizon is not None:
horizons.append(max(1, int(self.config.execution_horizon)))
return max(1, min(horizons))
def _resolve_prediction_horizon(self, actions: Tensor) -> int: def _resolve_prediction_horizon(self, actions: Tensor) -> int:
"""Return the policy-facing action horizon for a native GR00T prediction.""" """Return the policy-facing action horizon for a native GR00T prediction."""
@@ -543,16 +434,6 @@ class GrootPolicy(PreTrainedPolicy):
""" """
self.eval() self.eval()
# Freeze the relative-action reference at the exact chunk-prediction event so every popped
# delta of this chunk is reconstructed (in the postprocessor) against this S_T, not the
# per-tick latest state. Driven by the predict event, so it is correct under any runtime
# n_action_steps/execution_horizon. No-op for non-relative checkpoints (holder absent/unused).
from .processor_groot import _GROOT_REF_HOLDER_KEY
holder = batch.get(_GROOT_REF_HOLDER_KEY)
if holder is not None:
holder.freeze()
# Preprocessing is handled by the processor pipeline, so we just filter the batch. # Preprocessing is handled by the processor pipeline, so we just filter the batch.
# During inference, we do not pass action because it is predicted. # During inference, we do not pass action because it is predicted.
# N1.7 still carries a 2-D action horizon mask from its checkpoint processor. # N1.7 still carries a 2-D action horizon mask from its checkpoint processor.
+133 -440
View File
@@ -23,10 +23,9 @@ from typing import TYPE_CHECKING, Any
import numpy as np import numpy as np
import torch import torch
import torchvision.transforms.v2.functional as tv_functional
from einops import rearrange from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image from PIL import Image
from torchvision.transforms import InterpolationMode
from lerobot.utils.import_utils import _transformers_available from lerobot.utils.import_utils import _transformers_available
@@ -47,8 +46,6 @@ from lerobot.processor import (
RenameObservationsProcessorStep, RenameObservationsProcessorStep,
batch_to_transition, batch_to_transition,
policy_action_to_transition, policy_action_to_transition,
to_absolute_actions,
to_relative_actions,
transition_to_batch, transition_to_batch,
transition_to_policy_action, transition_to_policy_action,
) )
@@ -61,14 +58,11 @@ from lerobot.utils.constants import (
POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME,
) )
from lerobot.utils.device_utils import get_safe_torch_device
from .configuration_groot import ( from .configuration_groot import (
GROOT_ACTION_DECODE_TRANSFORM_LIBERO, GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
GROOT_N1_5_REMOVAL_GUIDANCE, GROOT_N1_5_REMOVAL_GUIDANCE,
GROOT_N1_7_BACKBONE_MODEL, GROOT_N1_7_BACKBONE_MODEL,
N1_7_DEFAULT_IMAGE_CROP_SIZE,
N1_7_DEFAULT_IMAGE_TARGET_SIZE,
GrootConfig, GrootConfig,
is_raw_groot_n1_7_checkpoint, is_raw_groot_n1_7_checkpoint,
) )
@@ -90,30 +84,6 @@ N1_7_EMBODIMENT_MAPPING = {
} }
_GROOT_REF_HOLDER_KEY = "_groot_relative_ref_holder" # private; dropped by _filter_groot_inputs, never reaches the model
class _GrootRelativeRefHolder:
"""Runtime-only carrier shared (by object identity) between the pack step (owner/writer of the
live reference), GrootPolicy.predict_action_chunk (freezes it at a real predict event), and the
decode step (reads the frozen reference). Not serialized. One instance per pack step."""
__slots__ = ("reference_state", "raw_state", "frozen_reference", "frozen_raw")
def __init__(self):
self.reference_state = None
self.raw_state = None
self.frozen_reference = None
self.frozen_raw = None
def freeze(self) -> None:
self.frozen_reference = self.reference_state
self.frozen_raw = self.raw_state
def clear(self) -> None:
self.reference_state = self.raw_state = self.frozen_reference = self.frozen_raw = None
@dataclass @dataclass
class _GrootN17CheckpointProcessorAssets: class _GrootN17CheckpointProcessorAssets:
"""Processor metadata loaded from a raw Isaac-GR00T N1.7 checkpoint. """Processor metadata loaded from a raw Isaac-GR00T N1.7 checkpoint.
@@ -143,39 +113,6 @@ class _GrootN17CheckpointProcessorAssets:
use_albumentations: bool use_albumentations: bool
def _resolve_base_model_local_dir(base_model_path: str | None) -> str | None:
"""Resolve a base model path to a local snapshot dir holding its sidecar JSONs.
``is_raw_groot_n1_7_checkpoint`` needs a local directory (or config.json) to inspect, so a
bare HF repo-id (e.g. ``nvidia/GR00T-N1.7-3B``) would never be recognised as a raw N1.7
checkpoint and the processor would fall back to LeRobot default image geometry instead of the
checkpoint's processor_config.json geometry. When the path is not already a local dir, this
downloads just the JSON sidecars and returns the local snapshot dir. Offline-safe: any failure
returns the original string unchanged. Only used on the fresh-build (training) path; inference
loads the serialized processor, so no per-inference network call is added.
"""
if base_model_path is None:
return None
if Path(base_model_path).expanduser().is_dir():
return base_model_path
try:
from huggingface_hub import snapshot_download
local_dir = snapshot_download(
base_model_path,
repo_type="model",
allow_patterns=["*.json"],
)
logging.debug(
"Resolved GR00T base model '%s' to local snapshot '%s' for processor asset loading.",
base_model_path,
local_dir,
)
return local_dir
except Exception: # noqa: BLE001 (offline-safe: fall back to the original path on any failure)
return base_model_path
def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17CheckpointProcessorAssets | None: def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17CheckpointProcessorAssets | None:
"""Load N1.7 processor settings from checkpoint sidecar JSON files. """Load N1.7 processor settings from checkpoint sidecar JSON files.
@@ -183,11 +120,10 @@ def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17Chec
can keep using caller-provided dataset stats and config values. can keep using caller-provided dataset stats and config values.
""" """
resolved_base_model_path = _resolve_base_model_local_dir(config.base_model_path) if not is_raw_groot_n1_7_checkpoint(config.base_model_path):
if not is_raw_groot_n1_7_checkpoint(resolved_base_model_path):
return None return None
checkpoint_path = Path(resolved_base_model_path).expanduser() checkpoint_path = Path(config.base_model_path).expanduser()
processor_config = _read_json(checkpoint_path / "processor_config.json") processor_config = _read_json(checkpoint_path / "processor_config.json")
processor_kwargs = processor_config.get("processor_kwargs", {}) processor_kwargs = processor_config.get("processor_kwargs", {})
if not isinstance(processor_kwargs, dict): if not isinstance(processor_kwargs, dict):
@@ -512,74 +448,60 @@ def _has_modality_stats(stats: dict[str, dict[str, Any]] | None) -> bool:
return any(bool(modality_stats) for modality_stats in stats.values()) return any(bool(modality_stats) for modality_stats in stats.values())
def _build_relative_action_mask( def _legacy_groot_processor_overrides(
action_dim: int, config: GrootConfig,
exclude_joints: list[str] | None, dataset_stats: dict[str, dict[str, torch.Tensor]] | None,
action_names: list[str] | None, preprocessor_overrides: dict[str, Any] | None = None,
) -> list[bool]: postprocessor_overrides: dict[str, Any] | None = None,
"""Build the per-dim relative-action mask (True = convert to relative, False = keep absolute). ) -> tuple[dict[str, Any], dict[str, Any]]:
"""Patch older serialized Groot processors with fields current processors expect."""
Replicates ``RelativeActionsProcessorStep._build_mask`` semantics: dims are excluded preprocessor_overrides = dict(preprocessor_overrides or {})
(kept absolute) by case-insensitive token match against ``action_names``. postprocessor_overrides = dict(postprocessor_overrides or {})
pack_inputs_key = "groot_n1_7_pack_inputs_v1"
When ``action_names`` is None we cannot identify the gripper, so this returns all-True pack_input_overrides = dict(preprocessor_overrides.get(pack_inputs_key, {}))
(every dim treated as relative). The user should ensure ``config.action_feature_names`` is pack_input_overrides["normalize_min_max"] = True
populated (the factory does this from dataset meta) so the gripper can be kept absolute; preprocessor_overrides[pack_inputs_key] = pack_input_overrides
arm-relative still works either way, but a missing-name gripper would be treated as relative.
try:
env_action_dim = int(config.output_features[ACTION].shape[0])
except Exception:
env_action_dim = 0
action_unpack_overrides = dict(postprocessor_overrides.get("groot_action_unpack_unnormalize_v2", {}))
action_unpack_overrides["normalize_min_max"] = True
action_unpack_overrides["env_action_dim"] = env_action_dim
postprocessor_overrides["groot_action_unpack_unnormalize_v2"] = action_unpack_overrides
return preprocessor_overrides, postprocessor_overrides
def _pretrained_processor_config_has_step(pretrained_path: str, config_filename: str, step_name: str) -> bool:
"""Check whether a serialized processor pipeline contains a registry step.
Resolves the processor config from a local directory or, for Hub repo ids,
via ``hf_hub_download`` (which serves the cached copy when offline). Returns
False when the config cannot be resolved; loading then proceeds with the
legacy overrides and `make_groot_pre_post_processors_from_pretrained` retries
without them if they do not match the serialized pipeline.
""" """
if not exclude_joints or action_names is None: path = Path(pretrained_path).expanduser()
return [True] * action_dim if path.is_dir():
config = _read_json(path / config_filename)
exclude_tokens = [str(name).lower() for name in exclude_joints if name] elif path.exists():
if not exclude_tokens: return False
return [True] * action_dim else:
try:
mask: list[bool] = [] config_path = hf_hub_download(
for name in action_names[:action_dim]: repo_id=str(pretrained_path), filename=config_filename, repo_type="model"
action_name = str(name).lower()
is_excluded = any(token == action_name or token in action_name for token in exclude_tokens)
mask.append(not is_excluded)
if len(mask) < action_dim:
mask.extend([True] * (action_dim - len(mask)))
return mask
# GR00T normalizes state/action inside its own processor steps and so deliberately has no
# NormalizerProcessorStep/UnnormalizerProcessorStep (see GrootConfig.normalization_mapping, which is
# IDENTITY for every feature). lerobot-train nonetheless emits these standard override keys
# unconditionally, so for a GR00T pipeline they legitimately match no step. They are dropped up front
# by _drop_groot_absent_standard_overrides so they neither break loading nor mask genuine typos.
_GROOT_ABSENT_STANDARD_OVERRIDE_KEYS = frozenset({"normalizer_processor", "unnormalizer_processor"})
def _drop_groot_absent_standard_overrides(overrides: dict[str, Any] | None) -> dict[str, Any] | None:
"""Strip standard normalization override keys that a GR00T pipeline has no step for.
``lerobot-train`` emits ``normalizer_processor``/``unnormalizer_processor`` overrides
unconditionally, but GR00T normalizes inside its own steps and has no such step (see
``GrootConfig.normalization_mapping``). Both override-application paths reject keys that match no
step — ``_apply_groot_step_overrides`` raises for the freshly built raw-checkpoint pipeline, and
``PolicyProcessorPipeline.from_pretrained`` raises via its used-override validation for the
serialized pipeline — so these keys are removed before either path runs. Any other unknown key
(e.g. a typo) is left in place and still raises.
"""
if not overrides:
return overrides
filtered: dict[str, Any] = {}
for key, value in overrides.items():
if key in _GROOT_ABSENT_STANDARD_OVERRIDE_KEYS:
logging.debug(
"Ignoring override key '%s': GR00T normalizes inside its own processor steps and has "
"no matching step (see GrootConfig.normalization_mapping).",
key,
) )
continue except Exception:
filtered[key] = value return False
return filtered config = _read_json(Path(config_path))
steps = config.get("steps", [])
if not isinstance(steps, list):
return False
return any(isinstance(step, dict) and step.get("registry_name") == step_name for step in steps)
def _apply_groot_step_overrides( def _apply_groot_step_overrides(
@@ -595,8 +517,7 @@ def _apply_groot_step_overrides(
steps by registry name only — prefer registry names so overrides keep steps by registry name only — prefer registry names so overrides keep
working after the checkpoint is converted and reloaded from a serialized working after the checkpoint is converted and reloaded from a serialized
pipeline). Keys or fields that match nothing raise instead of being dropped pipeline). Keys or fields that match nothing raise instead of being dropped
silently (standard normalization keys GR00T has no step for are removed silently.
beforehand by ``_drop_groot_absent_standard_overrides``).
""" """
if not overrides: if not overrides:
@@ -652,13 +573,7 @@ def make_groot_pre_post_processors_from_pretrained(
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction], PolicyProcessorPipeline[PolicyAction, PolicyAction],
]: ]:
"""Load Groot processors for a raw N1.7 checkpoint or a serialized LeRobot pipeline.""" """Load Groot processors while preserving compatibility with older serialized configs."""
# Drop the standard normalizer/unnormalizer override keys lerobot-train emits unconditionally:
# GR00T has no such steps, so they would make both the raw-checkpoint and serialized override
# paths raise. This must happen before either branch below.
preprocessor_overrides = _drop_groot_absent_standard_overrides(preprocessor_overrides)
postprocessor_overrides = _drop_groot_absent_standard_overrides(postprocessor_overrides)
if is_raw_groot_n1_7_checkpoint(pretrained_path): if is_raw_groot_n1_7_checkpoint(pretrained_path):
processor_cfg = copy(config) processor_cfg = copy(config)
@@ -674,13 +589,49 @@ def make_groot_pre_post_processors_from_pretrained(
_apply_groot_step_overrides(postprocessor, postprocessor_overrides) _apply_groot_step_overrides(postprocessor, postprocessor_overrides)
return preprocessor, postprocessor return preprocessor, postprocessor
preprocessor, postprocessor = _load_groot_processor_pipelines( caller_preprocessor_overrides = dict(preprocessor_overrides or {})
caller_postprocessor_overrides = dict(postprocessor_overrides or {})
if _pretrained_processor_config_has_step(
pretrained_path, pretrained_path,
preprocessor_overrides=preprocessor_overrides, postprocessor_config_filename,
postprocessor_overrides=postprocessor_overrides, "groot_n1_7_action_decode_v1",
preprocessor_config_filename=preprocessor_config_filename, ):
postprocessor_config_filename=postprocessor_config_filename, # Converted raw N1.7 checkpoints already carry the checkpoint-specific
) # action decoder. Adding the legacy action-unpack override would target
# a step that is not present and break loading.
applied_legacy_overrides = False
preprocessor_overrides = caller_preprocessor_overrides
postprocessor_overrides = caller_postprocessor_overrides
else:
applied_legacy_overrides = True
preprocessor_overrides, postprocessor_overrides = _legacy_groot_processor_overrides(
config=config,
dataset_stats=dataset_stats,
preprocessor_overrides=preprocessor_overrides,
postprocessor_overrides=postprocessor_overrides,
)
try:
preprocessor, postprocessor = _load_groot_processor_pipelines(
pretrained_path,
preprocessor_overrides=preprocessor_overrides,
postprocessor_overrides=postprocessor_overrides,
preprocessor_config_filename=preprocessor_config_filename,
postprocessor_config_filename=postprocessor_config_filename,
)
except KeyError:
if not applied_legacy_overrides:
raise
# The legacy overrides target steps that are absent from the serialized
# pipelines (e.g. a converted raw N1.7 checkpoint whose postprocessor
# config could not be inspected before loading); retry with the caller
# overrides only.
preprocessor, postprocessor = _load_groot_processor_pipelines(
pretrained_path,
preprocessor_overrides=caller_preprocessor_overrides,
postprocessor_overrides=caller_postprocessor_overrides,
preprocessor_config_filename=preprocessor_config_filename,
postprocessor_config_filename=postprocessor_config_filename,
)
_reconnect_groot_relative_absolute_steps(preprocessor, postprocessor) _reconnect_groot_relative_absolute_steps(preprocessor, postprocessor)
_reconnect_groot_n1_7_pack_decode_steps(preprocessor, postprocessor) _reconnect_groot_n1_7_pack_decode_steps(preprocessor, postprocessor)
return preprocessor, postprocessor return preprocessor, postprocessor
@@ -747,15 +698,8 @@ def _reconnect_groot_n1_7_pack_decode_steps(
if pack_step is None: if pack_step is None:
return return
# Both decode steps read the pack step's cached state via a non-serialized ``pack_step`` link:
# GrootN17ActionDecodeStep reads the per-modality raw state; the relative-action path
# (GrootActionUnpackUnnormalizeStep) reads the cached reference state. Restore both links after
# deserialization.
for step in postprocessor.steps: for step in postprocessor.steps:
if ( if isinstance(step, GrootN17ActionDecodeStep) and step.pack_step is None:
isinstance(step, (GrootN17ActionDecodeStep, GrootActionUnpackUnnormalizeStep))
and step.pack_step is None
):
step.pack_step = pack_step step.pack_step = pack_step
@@ -833,45 +777,23 @@ def make_groot_pre_post_processors(
video_modality_keys=video_modality_keys, video_modality_keys=video_modality_keys,
raw_stats=checkpoint_assets.raw_stats if checkpoint_assets is not None else None, raw_stats=checkpoint_assets.raw_stats if checkpoint_assets is not None else None,
modality_config=checkpoint_assets.modality_config if checkpoint_assets is not None else None, modality_config=checkpoint_assets.modality_config if checkpoint_assets is not None else None,
use_relative_actions=config.use_relative_actions,
relative_exclude_joints=config.relative_exclude_joints,
action_feature_names=config.action_feature_names,
) )
# Resolve the image preprocessing geometry. Honor the checkpoint's processor_config
# when it provides an image_target_size; otherwise fall back to the geometry the
# N1.7 backbone was trained on. Without this fallback a raw base checkpoint with no
# processor_config image sizing (e.g. fine-tuning nvidia/GR00T-N1.7-3B with a new
# embodiment, where checkpoint_assets is None) would patchify full-resolution camera
# frames, inflating the VLM token count -- slowing both dataloading_s and update_s --
# and feeding the model a resolution it was not trained on.
if checkpoint_assets is not None and checkpoint_assets.image_target_size is not None:
image_target_size = checkpoint_assets.image_target_size
image_crop_size = checkpoint_assets.image_crop_size
shortest_image_edge = checkpoint_assets.shortest_image_edge
crop_fraction = checkpoint_assets.crop_fraction
else:
image_target_size = list(N1_7_DEFAULT_IMAGE_TARGET_SIZE)
image_crop_size = list(N1_7_DEFAULT_IMAGE_CROP_SIZE)
shortest_image_edge = None
crop_fraction = None
use_albumentations = checkpoint_assets.use_albumentations if checkpoint_assets is not None else False
input_steps: list[ProcessorStep] = [ input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(), AddBatchDimensionProcessorStep(),
pack_step, pack_step,
GrootN17VLMEncodeStep( GrootN17VLMEncodeStep(
model_name=config.n1_7_backbone_model, model_name=config.n1_7_backbone_model,
image_crop_size=image_crop_size, image_crop_size=checkpoint_assets.image_crop_size if checkpoint_assets is not None else None,
image_target_size=image_target_size, image_target_size=checkpoint_assets.image_target_size if checkpoint_assets is not None else None,
shortest_image_edge=shortest_image_edge, shortest_image_edge=checkpoint_assets.shortest_image_edge
crop_fraction=crop_fraction, if checkpoint_assets is not None
use_albumentations=use_albumentations, else None,
# Run the image resize/normalize/patchify on the training device when crop_fraction=checkpoint_assets.crop_fraction if checkpoint_assets is not None else None,
# possible instead of the single CPU main-loop thread (the dominant use_albumentations=checkpoint_assets.use_albumentations
# cost folded into dataloading_s). if checkpoint_assets is not None
device=config.device, else False,
), ),
DeviceProcessorStep(device=config.device), DeviceProcessorStep(device=config.device),
] ]
@@ -895,10 +817,6 @@ def make_groot_pre_post_processors(
stats=padded_stats, stats=padded_stats,
normalize_min_max=True, normalize_min_max=True,
clip_normalized_action=True, clip_normalized_action=True,
use_relative_actions=config.use_relative_actions,
relative_exclude_joints=config.relative_exclude_joints,
action_feature_names=config.action_feature_names,
pack_step=pack_step,
) )
else: else:
action_decode_step = GrootN17ActionDecodeStep( action_decode_step = GrootN17ActionDecodeStep(
@@ -1114,61 +1032,6 @@ def _transform_n1_7_image_for_vlm(
return image return image
def _transform_n1_7_image_for_vlm_torch(
image: torch.Tensor,
*,
image_crop_size: list[int] | None,
image_target_size: list[int] | None,
shortest_image_edge: int | None,
crop_fraction: float | None,
) -> torch.Tensor:
"""Torch/torchvision port of the non-albumentations branch of
:func:`_transform_n1_7_image_for_vlm`.
Operates on a ``(C, H, W)`` uint8 tensor and keeps the result on the input
tensor's device so the resize/crop run on GPU when the tensor is. Bicubic
interpolation with antialiasing matches PIL's ``Image.Resampling.BICUBIC``
closely (sub-``2/255`` per-pixel on worst-case inputs). The ``use_albumentations``
cv2/INTER_AREA path has no torch equivalent and stays on the PIL helper.
"""
if image_target_size is None:
return image
target_h, target_w = image_target_size
_, height, width = image.shape
square_edge = max(height, width)
if height != width:
left = (square_edge - width) // 2
top = (square_edge - height) // 2
image = tv_functional.pad(
image, [left, top, square_edge - width - left, square_edge - height - top], fill=0
)
resize_edge = shortest_image_edge or target_h
image = tv_functional.resize(
image, [resize_edge, resize_edge], interpolation=InterpolationMode.BICUBIC, antialias=True
)
if crop_fraction is None and image_crop_size is not None:
crop_fraction = image_crop_size[0] / float(target_h)
if crop_fraction is not None and 0.0 < crop_fraction < 1.0:
# Match the PIL helper's center crop exactly: round() the crop size but
# floor() the offset (torchvision.center_crop rounds the offset, which
# shifts the region by 1px when (edge - crop) is odd).
crop_h = max(1, int(round(image.shape[-2] * crop_fraction)))
crop_w = max(1, int(round(image.shape[-1] * crop_fraction)))
top = max(0, (image.shape[-2] - crop_h) // 2)
left = max(0, (image.shape[-1] - crop_w) // 2)
image = image[..., top : top + crop_h, left : left + crop_w]
if tuple(image.shape[-2:]) != (target_h, target_w):
image = tv_functional.resize(
image, [target_h, target_w], interpolation=InterpolationMode.BICUBIC, antialias=True
)
return image
@dataclass @dataclass
@ProcessorStepRegistry.register(name="groot_n1_7_pack_inputs_v1") @ProcessorStepRegistry.register(name="groot_n1_7_pack_inputs_v1")
class GrootN17PackInputsStep(ProcessorStep): class GrootN17PackInputsStep(ProcessorStep):
@@ -1195,18 +1058,11 @@ class GrootN17PackInputsStep(ProcessorStep):
video_modality_keys: list[str] | None = None video_modality_keys: list[str] | None = None
raw_stats: dict[str, Any] | None = None raw_stats: dict[str, Any] | None = None
modality_config: dict[str, Any] | None = None modality_config: dict[str, Any] | None = None
# Opt-in relative-action support: convert absolute->relative actions inside this pack step # Unused: kept so serialized configs that include it still load. The raw
# (training) using the cached raw reference state, keeping excluded joints (e.g. gripper) # state cache is per instance (_last_raw_state), never process-global.
# absolute. The paired GrootActionUnpackUnnormalizeStep reconstructs absolute on decode. state_cache_key: str = ""
use_relative_actions: bool = False
relative_exclude_joints: list[str] = field(default_factory=list)
action_feature_names: list[str] | None = None
_last_raw_state: dict[str, np.ndarray] | None = field(default=None, init=False, repr=False) _last_raw_state: dict[str, np.ndarray] | None = field(default=None, init=False, repr=False)
_last_reference_state: torch.Tensor | None = field(default=None, init=False, repr=False)
_warned_image_keys: bool = field(default=False, init=False, repr=False) _warned_image_keys: bool = field(default=False, init=False, repr=False)
_ref_holder: "_GrootRelativeRefHolder" = field(
default_factory=_GrootRelativeRefHolder, init=False, repr=False
)
def _ordered_image_keys(self, obs: dict[str, Any]) -> list[str]: def _ordered_image_keys(self, obs: dict[str, Any]) -> list[str]:
available = {key for key in obs if key.startswith(OBS_IMAGES)} available = {key for key in obs if key.startswith(OBS_IMAGES)}
@@ -1328,7 +1184,6 @@ class GrootN17PackInputsStep(ProcessorStep):
start_idx += dim start_idx += dim
if grouped: if grouped:
self._last_raw_state = grouped self._last_raw_state = grouped
self._ref_holder.raw_state = grouped
img_keys = self._ordered_image_keys(obs) img_keys = self._ordered_image_keys(obs)
if img_keys: if img_keys:
@@ -1348,9 +1203,6 @@ class GrootN17PackInputsStep(ProcessorStep):
formalize_language=self.formalize_language, formalize_language=self.formalize_language,
) )
# Reference state for relative-action conversion (RAW, pre-normalization, (B, D)). Cached
# regardless of whether an action is present so inference caches it too for decode.
relative_reference_state: torch.Tensor | None = None
if OBS_STATE in obs: if OBS_STATE in obs:
state = obs[OBS_STATE] state = obs[OBS_STATE]
if state.dim() != 2: if state.dim() != 2:
@@ -1359,10 +1211,6 @@ class GrootN17PackInputsStep(ProcessorStep):
if dim > self.max_state_dim: if dim > self.max_state_dim:
raise ValueError(f"State dimension {dim} exceeds max_state_dim {self.max_state_dim}.") raise ValueError(f"State dimension {dim} exceeds max_state_dim {self.max_state_dim}.")
_cache_raw_state(state) _cache_raw_state(state)
if self.use_relative_actions:
relative_reference_state = state.detach().clone()
self._last_reference_state = relative_reference_state
self._ref_holder.reference_state = relative_reference_state
if self.normalize_min_max: if self.normalize_min_max:
state = _min_max_norm(state, OBS_STATE) state = _min_max_norm(state, OBS_STATE)
state = state.unsqueeze(1) state = state.unsqueeze(1)
@@ -1385,19 +1233,6 @@ class GrootN17PackInputsStep(ProcessorStep):
raise ValueError(f"Action horizon {horizon} exceeds action_horizon {self.action_horizon}.") raise ValueError(f"Action horizon {horizon} exceeds action_horizon {self.action_horizon}.")
if dim > self.max_action_dim: if dim > self.max_action_dim:
raise ValueError(f"Action dimension {dim} exceeds max_action_dim {self.max_action_dim}.") raise ValueError(f"Action dimension {dim} exceeds max_action_dim {self.max_action_dim}.")
# Convert absolute->relative BEFORE normalization. The mask keeps excluded joints (e.g.
# gripper) absolute; to_relative_actions broadcasts the (B, D) reference state over T.
if self.use_relative_actions:
if relative_reference_state is None:
raise RuntimeError(
"GrootN17PackInputsStep.use_relative_actions requires observation.state "
"(OBS_STATE) to be present alongside the action to build the relative "
"reference, but no state was found in this transition."
)
mask = _build_relative_action_mask(
action.shape[-1], self.relative_exclude_joints, self.action_feature_names
)
action = to_relative_actions(action, relative_reference_state, mask)
if self.normalize_min_max: if self.normalize_min_max:
flat = _min_max_norm(action.reshape(bsz * horizon, dim), ACTION) flat = _min_max_norm(action.reshape(bsz * horizon, dim), ACTION)
action = flat.view(bsz, horizon, dim) action = flat.view(bsz, horizon, dim)
@@ -1437,12 +1272,6 @@ class GrootN17PackInputsStep(ProcessorStep):
comp["action_mask"] = action_mask comp["action_mask"] = action_mask
comp["embodiment_id"] = torch.full((bsz,), emb_id, dtype=torch.int32, device=device) comp["embodiment_id"] = torch.full((bsz,), emb_id, dtype=torch.int32, device=device)
# Publish the runtime-only reference holder so the policy can freeze it at the predict
# event and the decode step can read the frozen reference. It rides in COMPLEMENTARY_DATA,
# survives the VLM-encode step and DeviceProcessorStep as a non-tensor, and reaches the
# policy via the batch (by object identity) through the pipeline's shallow copies.
comp[_GROOT_REF_HOLDER_KEY] = self._ref_holder
transition[TransitionKey.OBSERVATION] = obs transition[TransitionKey.OBSERVATION] = obs
transition[TransitionKey.COMPLEMENTARY_DATA] = comp transition[TransitionKey.COMPLEMENTARY_DATA] = comp
return transition return transition
@@ -1467,9 +1296,6 @@ class GrootN17PackInputsStep(ProcessorStep):
"video_modality_keys": self.video_modality_keys, "video_modality_keys": self.video_modality_keys,
"raw_stats": self.raw_stats, "raw_stats": self.raw_stats,
"modality_config": self.modality_config, "modality_config": self.modality_config,
"use_relative_actions": self.use_relative_actions,
"relative_exclude_joints": self.relative_exclude_joints,
"action_feature_names": self.action_feature_names,
} }
def get_cached_raw_state(self) -> dict[str, np.ndarray] | None: def get_cached_raw_state(self) -> dict[str, np.ndarray] | None:
@@ -1477,23 +1303,6 @@ class GrootN17PackInputsStep(ProcessorStep):
return self._last_raw_state return self._last_raw_state
def get_cached_reference_state(self) -> torch.Tensor | None:
"""Return the latest RAW (pre-normalization) (B, D) state used for relative-action conversion."""
return self._last_reference_state
def get_reference_holder(self) -> "_GrootRelativeRefHolder":
"""Return the runtime-only holder shared with the policy (writer) and decode step (reader)."""
return self._ref_holder
def reset(self) -> None:
"""Clear cached per-episode relative-action references (sync engine resets on episode boundaries)."""
self._last_reference_state = None
self._last_raw_state = None
self._ref_holder.clear()
def state_dict(self) -> dict[str, torch.Tensor]: def state_dict(self) -> dict[str, torch.Tensor]:
if not self.stats: if not self.stats:
return {} return {}
@@ -1524,12 +1333,6 @@ class GrootN17VLMEncodeStep(ProcessorStep):
The packed video has shape ``(B, T, V, H, W, C)``. Each frame/view becomes The packed video has shape ``(B, T, V, H, W, C)``. Each frame/view becomes
an image item in the same chat message so the resulting image tokens match an image item in the same chat message so the resulting image tokens match
the temporal VLM packing used by Isaac-GR00T. the temporal VLM packing used by Isaac-GR00T.
Images are handed to the torchvision-backed Qwen3-VL processor as ``(C, H, W)``
uint8 tensors (no per-frame PIL roundtrip), and, when ``device`` resolves to a
CUDA device, the resize/rescale/normalize/patchify run there instead of on the
single CPU main-loop thread. This keeps the output bit-identical on CPU and
moves the dominant preprocessing cost off the critical path on GPU.
""" """
model_name: str = GROOT_N1_7_BACKBONE_MODEL model_name: str = GROOT_N1_7_BACKBONE_MODEL
@@ -1538,7 +1341,6 @@ class GrootN17VLMEncodeStep(ProcessorStep):
shortest_image_edge: int | None = None shortest_image_edge: int | None = None
crop_fraction: float | None = None crop_fraction: float | None = None
use_albumentations: bool = False use_albumentations: bool = False
device: str | None = None
_proc: ProcessorMixin | None = field(default=None, init=False, repr=False) _proc: ProcessorMixin | None = field(default=None, init=False, repr=False)
@property @property
@@ -1547,70 +1349,6 @@ class GrootN17VLMEncodeStep(ProcessorStep):
self._proc = _build_n1_7_processor(self.model_name) self._proc = _build_n1_7_processor(self.model_name)
return self._proc return self._proc
def _target_device(self) -> torch.device | None:
# The albumentations path is cv2/PIL only, so it cannot run on GPU.
if self.device is None or self.use_albumentations:
return None
try:
return get_safe_torch_device(self.device)
except (AssertionError, RuntimeError):
# A device serialized at train time (e.g. "cuda") may be unavailable
# when the processor is reloaded elsewhere (e.g. CPU-only eval), and
# this step is not in the standard device-override set. Fall back to
# the CPU path, which is bit-identical, instead of crashing.
return None
def _build_sample_images(
self, video: Any, batch_size: int, target_device: torch.device | None
) -> list[list[Any]]:
"""Return, per batch item, its ordered ``(timestep, view)`` frames.
``use_albumentations`` keeps the legacy per-frame PIL/cv2 transform;
otherwise frames are ``(C, H, W)`` uint8 tensors (moved to
``target_device`` when set) for the torchvision-backed Qwen processor.
"""
if self.use_albumentations:
video_np = np.asarray(video)
return [
[
_transform_n1_7_image_for_vlm(
Image.fromarray(video_np[batch_idx, timestep, view_idx]),
image_crop_size=self.image_crop_size,
image_target_size=self.image_target_size,
shortest_image_edge=self.shortest_image_edge,
crop_fraction=self.crop_fraction,
use_albumentations=True,
)
for timestep in range(video_np.shape[1])
for view_idx in range(video_np.shape[2])
]
for batch_idx in range(batch_size)
]
video_t = video if torch.is_tensor(video) else torch.from_numpy(np.ascontiguousarray(video))
# (B, T, V, H, W, C) uint8 -> (B, T, V, C, H, W)
video_t = video_t.permute(0, 1, 2, 5, 3, 4).contiguous()
if target_device is not None and video_t.device != target_device:
video_t = video_t.to(target_device, non_blocking=(target_device.type == "cuda"))
frames_per_sample: list[list[Any]] = []
for batch_idx in range(batch_size):
sample = video_t[batch_idx] # (T, V, C, H, W)
frames_per_sample.append(
[
_transform_n1_7_image_for_vlm_torch(
sample[timestep, view_idx],
image_crop_size=self.image_crop_size,
image_target_size=self.image_target_size,
shortest_image_edge=self.shortest_image_edge,
crop_fraction=self.crop_fraction,
)
for timestep in range(sample.shape[0])
for view_idx in range(sample.shape[1])
]
)
return frames_per_sample
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition.get(TransitionKey.OBSERVATION, {}) or {} obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {} comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
@@ -1618,25 +1356,33 @@ class GrootN17VLMEncodeStep(ProcessorStep):
if video is None: if video is None:
return transition return transition
batch_size = int(video.shape[0])
languages = _prepare_n1_7_language_batch( languages = _prepare_n1_7_language_batch(
comp.get("language"), comp.get("language"),
batch_size, video.shape[0],
formalize_language=False, formalize_language=False,
) )
target_device = self._target_device()
sample_images = self._build_sample_images(video, batch_size, target_device)
texts: list[str] = [] texts: list[str] = []
images: list[Any] = [] images: list[Image.Image] = []
for batch_idx in range(batch_size): for batch_idx in range(video.shape[0]):
frames = sample_images[batch_idx] sample = video[batch_idx] # (T, V, H, W, C)
sample_images = [
_transform_n1_7_image_for_vlm(
Image.fromarray(sample[timestep, view_idx]),
image_crop_size=self.image_crop_size,
image_target_size=self.image_target_size,
shortest_image_edge=self.shortest_image_edge,
crop_fraction=self.crop_fraction,
use_albumentations=self.use_albumentations,
)
for timestep in range(sample.shape[0])
for view_idx in range(sample.shape[1])
]
conversation = [ conversation = [
{ {
"role": "user", "role": "user",
"content": [ "content": [
*[{"type": "image", "image": image} for image in frames], *[{"type": "image", "image": image} for image in sample_images],
{"type": "text", "text": languages[batch_idx]}, {"type": "text", "text": languages[batch_idx]},
], ],
} }
@@ -1648,17 +1394,9 @@ class GrootN17VLMEncodeStep(ProcessorStep):
add_generation_prompt=False, add_generation_prompt=False,
) )
) )
images.extend(frames) images.extend(sample_images)
proc_kwargs: dict[str, Any] = { encoded = self.proc(text=texts, images=images, return_tensors="pt", padding=True)
"text": texts,
"images": images,
"return_tensors": "pt",
"padding": True,
}
if target_device is not None:
proc_kwargs["device"] = str(target_device)
encoded = self.proc(**proc_kwargs)
for key, value in encoded.items(): for key, value in encoded.items():
comp[key] = value comp[key] = value
obs.pop("video", None) obs.pop("video", None)
@@ -1677,7 +1415,6 @@ class GrootN17VLMEncodeStep(ProcessorStep):
"shortest_image_edge": self.shortest_image_edge, "shortest_image_edge": self.shortest_image_edge,
"crop_fraction": self.crop_fraction, "crop_fraction": self.crop_fraction,
"use_albumentations": self.use_albumentations, "use_albumentations": self.use_albumentations,
"device": self.device,
} }
@@ -1828,6 +1565,8 @@ class GrootN17ActionDecodeStep(ProcessorStep):
modality_config: dict[str, Any] | None = None modality_config: dict[str, Any] | None = None
use_percentiles: bool = False use_percentiles: bool = False
use_relative_action: bool = False use_relative_action: bool = False
# Unused: kept so serialized configs that include it still load.
state_cache_key: str = ""
action_decode_transform: str | None = None action_decode_transform: str | None = None
pack_step: GrootN17PackInputsStep | None = field(default=None, repr=False) pack_step: GrootN17PackInputsStep | None = field(default=None, repr=False)
@@ -1890,14 +1629,7 @@ class GrootN17ActionDecodeStep(ProcessorStep):
start_idx += dim start_idx += dim
if self.use_relative_action: if self.use_relative_action:
# Prefer the raw state frozen at the chunk-prediction event (see the relative-action raw_state = self.pack_step.get_cached_raw_state() if self.pack_step is not None else None
# branch of GrootActionUnpackUnnormalizeStep). Falls back to the live cached raw state.
holder = self.pack_step.get_reference_holder() if self.pack_step is not None else None
raw_state = None
if holder is not None:
raw_state = holder.frozen_raw if holder.frozen_raw is not None else holder.raw_state
if raw_state is None and self.pack_step is not None:
raw_state = self.pack_step.get_cached_raw_state()
if raw_state is None: if raw_state is None:
raise RuntimeError( raise RuntimeError(
"GrootN17ActionDecodeStep requires the raw state cached by its connected " "GrootN17ActionDecodeStep requires the raw state cached by its connected "
@@ -1962,10 +1694,10 @@ class GrootN17ActionDecodeStep(ProcessorStep):
} }
@dataclass
# v2: unlike the N1.5-era v1 step, this step no longer collapses (B, T, D) # v2: unlike the N1.5-era v1 step, this step no longer collapses (B, T, D)
# action chunks to the last timestep, so old serialized v1 pipelines must not # action chunks to the last timestep, so old serialized v1 pipelines must not
# silently load into it (v1 is stubbed below with the removal guidance). # silently load into it (v1 is stubbed below with the removal guidance).
@dataclass
@ProcessorStepRegistry.register(name="groot_action_unpack_unnormalize_v2") @ProcessorStepRegistry.register(name="groot_action_unpack_unnormalize_v2")
class GrootActionUnpackUnnormalizeStep(ProcessorStep): class GrootActionUnpackUnnormalizeStep(ProcessorStep):
env_action_dim: int = 0 env_action_dim: int = 0
@@ -1975,13 +1707,6 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep):
clip_normalized_action: bool = False clip_normalized_action: bool = False
libero_gripper_action: bool = False libero_gripper_action: bool = False
libero_gripper_binarize: bool = True libero_gripper_binarize: bool = True
# Opt-in relative-action reconstruction (paired with GrootN17PackInputsStep). After the
# min-max inverse, relative deltas (arm) + absolute gripper are converted back to absolute
# using the reference state cached by the linked pack_step (re-linked on reload).
use_relative_actions: bool = False
relative_exclude_joints: list[str] = field(default_factory=list)
action_feature_names: list[str] | None = None
pack_step: "GrootN17PackInputsStep | None" = field(default=None, repr=False)
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
# Expect model outputs to be in TransitionKey.ACTION as (B, T, D_model) # Expect model outputs to be in TransitionKey.ACTION as (B, T, D_model)
@@ -2021,35 +1746,6 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep):
inv = (action + 1.0) * 0.5 * safe_denom + min_v inv = (action + 1.0) * 0.5 * safe_denom + min_v
action = torch.where(mask, inv, min_v) action = torch.where(mask, inv, min_v)
# Reconstruct absolute actions from relative deltas (arm) + absolute gripper, using the
# reference state cached by the linked pack step. The link is restored on reload by
# _reconnect_groot_n1_7_pack_decode_steps.
if self.use_relative_actions:
if self.pack_step is None:
raise RuntimeError(
"GrootActionUnpackUnnormalizeStep.use_relative_actions requires a linked "
"GrootN17PackInputsStep to read the cached reference state, but pack_step is None. "
"Build both pipelines through make_groot_pre_post_processors (or load them together "
"via make_groot_pre_post_processors_from_pretrained)."
)
# Prefer the reference frozen at the chunk-prediction event (set by
# GrootPolicy.predict_action_chunk via the shared holder) so every popped delta of a
# chunk reconstructs against that chunk's start state S_T, not the per-tick latest
# state. Falls back to the live reference when nothing was frozen (e.g. decode without
# a preceding predict event, or RTC/async where frozen == live).
holder = self.pack_step.get_reference_holder()
ref = holder.frozen_reference if holder.frozen_reference is not None else holder.reference_state
if ref is None:
raise RuntimeError(
"GrootActionUnpackUnnormalizeStep.use_relative_actions requires the reference state "
"cached by its connected GrootN17PackInputsStep to convert relative actions back to "
"absolute. Run the preprocessor on an observation before decoding actions."
)
relative_mask = _build_relative_action_mask(
action.shape[-1], self.relative_exclude_joints, self.action_feature_names
)
action = to_absolute_actions(action, ref, relative_mask)
if self.libero_gripper_action and action.shape[-1] >= 7: if self.libero_gripper_action and action.shape[-1] >= 7:
gripper = action[..., -1] gripper = action[..., -1]
if self.libero_gripper_binarize: if self.libero_gripper_binarize:
@@ -2077,9 +1773,6 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep):
"clip_normalized_action": self.clip_normalized_action, "clip_normalized_action": self.clip_normalized_action,
"libero_gripper_action": self.libero_gripper_action, "libero_gripper_action": self.libero_gripper_action,
"libero_gripper_binarize": self.libero_gripper_binarize, "libero_gripper_binarize": self.libero_gripper_binarize,
"use_relative_actions": self.use_relative_actions,
"relative_exclude_joints": self.relative_exclude_joints,
"action_feature_names": self.action_feature_names,
} }
def state_dict(self) -> dict[str, torch.Tensor]: def state_dict(self) -> dict[str, torch.Tensor]:
@@ -207,6 +207,11 @@ def test_lerobot_groot_forward_pass():
with torch.no_grad(): with torch.no_grad():
lerobot_loss, lerobot_metrics = lerobot_policy.forward(batch_lerobot_processed) lerobot_loss, lerobot_metrics = lerobot_policy.forward(batch_lerobot_processed)
assert isinstance(lerobot_loss, torch.Tensor)
assert torch.isfinite(lerobot_loss).all()
assert "loss" in lerobot_metrics
assert np.isfinite(lerobot_metrics["loss"])
print("\nForward pass successful.") print("\nForward pass successful.")
print(f" - Loss: {lerobot_loss.item():.6f}") print(f" - Loss: {lerobot_loss.item():.6f}")
print(f" - Metrics: {lerobot_metrics}") print(f" - Metrics: {lerobot_metrics}")
File diff suppressed because it is too large Load Diff
+175 -26
View File
@@ -14,31 +14,36 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Parity test: original NVIDIA GR00T N1.7 vs the GR00T N1.7 integration in LeRobot. """Parity tests: original NVIDIA GR00T N1.7 vs the GR00T N1.7 integration in LeRobot.
Verifies that the self-contained LeRobot reimplementation of the GR00T N1.7 action Two comparisons run per embodiment tag, against per-tag ``.npz`` artifacts produced
head + Qwen3-VL backbone produces the SAME raw model output (``action_pred``, the once in the original ``gr00t`` env by the companion script
normalized flow-matching prediction before any action decoding) as NVIDIA's original ``utils/dump_original_n1_7.py`` (in the ``utils`` package next to this file):
``gr00t`` package, given byte-identical pre-processed inputs and the same
flow-matching seed. The comparison is parametrized over every embodiment tag present
in the checkpoint.
To keep the comparison fair, the original outputs + the exact collated inputs are 1. **Model parity** -- the self-contained LeRobot reimplementation of the GR00T N1.7
produced once per embodiment in the original ``gr00t`` env via the companion script action head + Qwen3-VL backbone must produce the SAME raw model output
``utils/dump_original_n1_7.py`` (in the ``utils`` package next to this file) and saved (``action_pred``, the normalized flow-matching prediction before any action
to per-tag ``.npz`` files. decoding) as NVIDIA's original ``gr00t`` package, given byte-identical
This test discovers those artifacts, replays the identical inputs through the LeRobot pre-processed inputs and the flow-matching seed recorded in the artifact.
model, and compares. 2. **Preprocessor parity** -- LeRobot's own preprocessor pipeline (real Qwen3-VL chat
template / tokenizer / image packing + state normalization, no mocks) must produce
the SAME collated model inputs (``input_ids``, ``pixel_values``, ``state``, ...)
as the original package's processor, given the identical raw observations
(images, state, language) recorded in the artifact. Artifacts written by older
versions of the dump script carry no raw observations; this case then SKIPS with
a regeneration hint.
This test is LOCAL-only and skips on CI, when ``gr00t``-side prerequisites are not These tests are LOCAL-only and skip on CI, when ``gr00t``-side prerequisites are not
present, or when no artifact has been generated. By default it looks for artifacts in present, or when no artifact has been generated. By default they look for artifacts in
``<this dir>/artifacts/``; override with ``GROOT_N1_7_PARITY_DIR``. See the ``<this dir>/artifacts/``; override with ``GROOT_N1_7_PARITY_DIR``. See the
"Original-vs-LeRobot parity test" section of ``src/lerobot/policies/groot/README.md`` "Original-vs-LeRobot parity test" section of ``src/lerobot/policies/groot/README.md``
for the full run procedure. for the full run procedure.
""" """
import os import os
import warnings
from pathlib import Path from pathlib import Path
from typing import Any
import numpy as np import numpy as np
import pytest import pytest
@@ -50,7 +55,9 @@ pytestmark = pytest.mark.skipif(
) )
from lerobot.policies.groot.configuration_groot import GROOT_N1_7 # noqa: E402,F401 from lerobot.policies.groot.configuration_groot import GROOT_N1_7 # noqa: E402,F401
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # noqa: E402
# Fallback flow-matching seed for artifacts predating the recorded ``seed`` field.
SEED = 42 SEED = 42
DEVICE = os.environ.get("GROOT_PARITY_DEVICE", "cuda" if torch.cuda.is_available() else "cpu") DEVICE = os.environ.get("GROOT_PARITY_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
ATOL = float(os.environ.get("GROOT_PARITY_ATOL", "1e-3")) ATOL = float(os.environ.get("GROOT_PARITY_ATOL", "1e-3"))
@@ -60,6 +67,11 @@ RTOL = float(os.environ.get("GROOT_PARITY_RTOL", "1e-3"))
_ARTIFACT_PREFIX = "original_n1_7_" _ARTIFACT_PREFIX = "original_n1_7_"
_ARTIFACT_SUFFIX = ".npz" _ARTIFACT_SUFFIX = ".npz"
# Collated model-input keys checked by the preprocessor parity case, grouped by how
# they are compared: keys in _COLLATED_EXACT_KEYS (integer/id tensors) must match the
# original processor output element-for-element, while keys in _COLLATED_CLOSE_KEYS
# (float tensors) only need to agree within ATOL/RTOL.
_COLLATED_EXACT_KEYS = ("input_ids", "attention_mask", "image_grid_thw", "embodiment_id")
_COLLATED_CLOSE_KEYS = ("pixel_values", "state")
def _artifact_dir() -> Path: def _artifact_dir() -> Path:
"""Directory holding the per-embodiment .npz artifacts. """Directory holding the per-embodiment .npz artifacts.
@@ -109,9 +121,20 @@ def _resolve_checkpoint() -> str:
return str(ckpt) return str(ckpt)
def _load_artifact(path: Path): def _load_artifact(path: Path) -> tuple[torch.Tensor, dict[str, torch.Tensor], int]:
"""Return (original action_pred, collated model inputs, flow-matching seed)."""
data = np.load(path, allow_pickle=True) data = np.load(path, allow_pickle=True)
original_action = torch.from_numpy(data["action_pred"]).float() original_action = torch.from_numpy(data["action_pred"]).float()
if "seed" in data.files:
seed = int(data["seed"])
else:
warnings.warn(
f"Artifact '{path.name}' does not record the producer seed (it predates the current "
f"dump_original_n1_7.py); falling back to seed={SEED}. If the parity comparison fails, "
"regenerate the artifact with the current dump script.",
stacklevel=2,
)
seed = SEED
dtypes = dict(zip(data["meta_keys"].tolist(), data["meta_dtypes"].tolist(), strict=False)) dtypes = dict(zip(data["meta_keys"].tolist(), data["meta_dtypes"].tolist(), strict=False))
inputs = {} inputs = {}
for key in data.files: for key in data.files:
@@ -124,7 +147,45 @@ def _load_artifact(path: Path):
if "int" in declared or "long" in declared: if "int" in declared or "long" in declared:
t = t.long() t = t.long()
inputs[name] = t inputs[name] = t
return original_action, inputs return original_action, inputs, seed
def _load_raw_observation(path: Path) -> dict[str, Any] | None:
"""Return the raw observation recorded in the artifact, or None for old artifacts.
Artifacts produced by the current ``dump_original_n1_7.py`` additionally store the
exact raw observation the producer fed to the original processor: per-camera uint8
frames (``raw::video.<key>``, (B, T, H, W, C)), per-key state vectors
(``raw::state.<key>``, (B, T, dim)) and the language instruction
(``raw::language``, one string per batch element). ``raw_video_keys`` /
``raw_state_keys`` record the checkpoint modality-key order.
"""
data = np.load(path, allow_pickle=True)
markers = ("raw_video_keys", "raw_state_keys", "raw::language")
if any(marker not in data.files for marker in markers):
return None
video_keys = [str(k) for k in data["raw_video_keys"].tolist()]
state_keys = [str(k) for k in data["raw_state_keys"].tolist()]
return {
"video": {k: data[f"raw::video.{k}"] for k in video_keys},
"state": {k: data[f"raw::state.{k}"] for k in state_keys},
"language": [str(t) for t in data["raw::language"].tolist()],
}
def _raw_observation_to_lerobot_batch(raw: dict[str, Any]) -> dict[str, Any]:
    """Build a LeRobot policy batch from the raw observation stored in an artifact.

    Args:
        raw: Mapping with ``"video"`` (camera key -> frames array), ``"state"``
            (state key -> array) and ``"language"`` (one instruction per batch element).

    Returns:
        A dict holding one ``f"{OBS_IMAGES}.<camera>"`` tensor per camera, the
        concatenated state under ``OBS_STATE`` and the instructions under ``"task"``.
    """
    # Frames arrive channels-last, (B, T, H, W, C) uint8; move channels ahead of the
    # spatial axes to get (B, T, C, H, W). The pack step converts back losslessly.
    images = {
        f"{OBS_IMAGES}.{camera}": torch.from_numpy(video).permute(0, 1, 4, 2, 3).contiguous()
        for camera, video in raw["video"].items()
    }

    # observation.state is each key's latest-frame vector, concatenated in the
    # checkpoint modality-key order (the iteration order of raw["state"]) -- the layout
    # the LeRobot pack step and the flattened checkpoint statistics expect.
    latest_states = []
    for values in raw["state"].values():
        latest_frame = np.asarray(values)[:, -1, :]
        latest_states.append(torch.from_numpy(latest_frame).float())
    flat_state = torch.cat(latest_states, dim=-1)

    return {**images, OBS_STATE: flat_state, "task": list(raw["language"])}
def _unflatten(inputs: dict[str, torch.Tensor]) -> dict: def _unflatten(inputs: dict[str, torch.Tensor]) -> dict:
@@ -139,6 +200,36 @@ def _unflatten(inputs: dict[str, torch.Tensor]) -> dict:
return nested.get("inputs", nested) return nested.get("inputs", nested)
def _assert_collated_parity(
    embodiment_tag: str, name: str, lerobot_value: Any, original_value: torch.Tensor, *, exact: bool
) -> None:
    """Assert that one LeRobot collated tensor matches the original processor's.

    Args:
        embodiment_tag: Embodiment under test; only used to prefix messages.
        name: Collated key being compared; only used in messages.
        lerobot_value: Value produced by the LeRobot preprocessor (must be a tensor).
        original_value: Reference tensor recorded in the artifact.
        exact: When true, both tensors are cast to ``long`` and every element must be
            equal; otherwise they are cast to ``float`` and compared with
            ``torch.allclose`` under the module-level ATOL/RTOL.

    Raises:
        AssertionError: On a non-tensor value, a shape mismatch, or a value mismatch.
    """
    assert isinstance(lerobot_value, torch.Tensor), (
        f"[{embodiment_tag}] LeRobot preprocessor output '{name}' is "
        f"{type(lerobot_value).__name__}, expected a tensor."
    )

    # Compare on CPU and detached so device placement and autograd state are irrelevant.
    ours = lerobot_value.detach().cpu()
    reference = original_value.detach().cpu()
    assert ours.shape == reference.shape, (
        f"[{embodiment_tag}] collated '{name}' shape mismatch: lerobot={tuple(ours.shape)} vs "
        f"original={tuple(reference.shape)}."
    )

    if exact:
        num_mismatched = int(torch.ne(ours.long(), reference.long()).sum())
        assert num_mismatched == 0, (
            f"[{embodiment_tag}] collated '{name}' differs from the original processor output: "
            f"{num_mismatched}/{reference.numel()} elements mismatch."
        )
        return

    ours_float = ours.float()
    reference_float = reference.float()
    max_diff = torch.abs(ours_float - reference_float).max().item()
    # Always report the observed gap, so a passing run still shows how close it was.
    print(f"[{embodiment_tag}] {name}: shape {tuple(ours.shape)} max|diff|={max_diff:.6e}")
    assert torch.allclose(ours_float, reference_float, atol=ATOL, rtol=RTOL), (
        f"[{embodiment_tag}] collated '{name}' differs from the original processor output beyond "
        f"atol={ATOL}, rtol={RTOL}: max|diff|={max_diff:.6e}."
    )
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def lerobot_model(): def lerobot_model():
"""Load the LeRobot GR00T N1.7 model once (fp32 + SDPA) and reuse across tags.""" """Load the LeRobot GR00T N1.7 model once (fp32 + SDPA) and reuse across tags."""
@@ -165,8 +256,7 @@ def lerobot_model():
_ARTIFACTS = _discover_artifacts() _ARTIFACTS = _discover_artifacts()
_requires_artifacts = pytest.mark.skipif(
@pytest.mark.skipif(
not _ARTIFACTS, not _ARTIFACTS,
reason=( reason=(
"No GR00T N1.7 parity artifacts found. Generate them first in the original gr00t " "No GR00T N1.7 parity artifacts found. Generate them first in the original gr00t "
@@ -174,24 +264,30 @@ _ARTIFACTS = _discover_artifacts()
"--ckpt <ckpt> --out-dir tests/policies/groot/artifacts --device cuda" "--ckpt <ckpt> --out-dir tests/policies/groot/artifacts --device cuda"
), ),
) )
@_requires_artifacts
@pytest.mark.parametrize("embodiment_tag,artifact", _ARTIFACTS, ids=[t for t, _ in _ARTIFACTS]) @pytest.mark.parametrize("embodiment_tag,artifact", _ARTIFACTS, ids=[t for t, _ in _ARTIFACTS])
def test_groot_get_action_parity(embodiment_tag, artifact, lerobot_model): def test_groot_get_action_parity(embodiment_tag, artifact, lerobot_model):
"""Raw model.get_action(action_pred) parity per embodiment: original vs LeRobot.""" """Raw model.get_action(action_pred) parity per embodiment: original vs LeRobot."""
original_action, flat_inputs = _load_artifact(artifact) original_action, flat_inputs, seed = _load_artifact(artifact)
model_inputs = _unflatten(flat_inputs) model_inputs = _unflatten(flat_inputs)
# Align the flow-matching RNG exactly as the producer did (seed right before sampling). # Align the flow-matching RNG exactly as the producer did (seed right before sampling).
torch.manual_seed(SEED) torch.manual_seed(seed)
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.manual_seed_all(SEED) torch.cuda.manual_seed_all(seed)
with torch.inference_mode(): with torch.inference_mode():
out = lerobot_model.get_action(model_inputs) out = lerobot_model.get_action(model_inputs)
lerobot_action = out["action_pred"].float().cpu() lerobot_action = out["action_pred"].float().cpu()
t = min(original_action.shape[1], lerobot_action.shape[1]) assert lerobot_action.shape == original_action.shape, (
d = min(original_action.shape[2], lerobot_action.shape[2]) f"GR00T N1.7 action_pred shape mismatch for embodiment '{embodiment_tag}': "
original_action = original_action[:, :t, :d] f"lerobot={tuple(lerobot_action.shape)} vs original={tuple(original_action.shape)}. "
lerobot_action = lerobot_action[:, :t, :d] "The same checkpoint and inputs must produce identical shapes; this indicates an "
"action-horizon or action-dim regression (or a stale artifact -- regenerate it with "
"utils/dump_original_n1_7.py)."
)
diff = torch.abs(lerobot_action - original_action) diff = torch.abs(lerobot_action - original_action)
max_diff = diff.max().item() max_diff = diff.max().item()
@@ -205,3 +301,56 @@ def test_groot_get_action_parity(embodiment_tag, artifact, lerobot_model):
f"GR00T N1.7 raw action_pred differs for embodiment '{embodiment_tag}' beyond " f"GR00T N1.7 raw action_pred differs for embodiment '{embodiment_tag}' beyond "
f"atol={ATOL}, rtol={RTOL}: max|diff|={max_diff:.6e}" f"atol={ATOL}, rtol={RTOL}: max|diff|={max_diff:.6e}"
) )
@_requires_artifacts
@pytest.mark.parametrize("embodiment_tag,artifact", _ARTIFACTS, ids=[tag for tag, _path in _ARTIFACTS])
def test_groot_preprocessor_parity(embodiment_tag, artifact):
    """Preprocessor parity per embodiment: LeRobot pipeline vs the original's collated inputs.

    The raw observation recorded in the artifact is pushed through the preprocessor
    built by ``make_groot_pre_post_processors`` (nothing is mocked here), and every key
    in ``_COLLATED_EXACT_KEYS`` / ``_COLLATED_CLOSE_KEYS`` of the result is compared
    against the collated inputs the original ``gr00t`` processor produced from the same
    raw observation. Artifacts that carry no raw observation make the case skip.
    """
    raw_observation = _load_raw_observation(artifact)
    if raw_observation is None:
        pytest.skip(
            f"Artifact '{artifact.name}' was produced by an older dump_original_n1_7.py that does "
            "not record raw observations; regenerate it with the current dump script to run the "
            "preprocessor parity case."
        )

    # Only the collated inputs are needed here; action_pred and the seed are ignored.
    _action, flat_reference, _seed = _load_artifact(artifact)
    reference_inputs = _unflatten(flat_reference)
    checkpoint = _resolve_checkpoint()

    from lerobot.policies.groot.configuration_groot import GrootConfig
    from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors

    # CPU keeps this case runnable without a GPU; the preprocessor is deterministic.
    policy_config = GrootConfig(base_model_path=checkpoint, embodiment_tag=embodiment_tag, device="cpu")
    preprocessor, _postprocessor = make_groot_pre_post_processors(policy_config)
    lerobot_inputs = preprocessor(_raw_observation_to_lerobot_batch(raw_observation))

    keys_to_compare = _COLLATED_EXACT_KEYS + _COLLATED_CLOSE_KEYS
    absent_in_artifact = [key for key in keys_to_compare if key not in reference_inputs]
    absent_in_lerobot = [key for key in keys_to_compare if key not in lerobot_inputs]
    # Report missing keys up front rather than failing on a bare lookup in the loop below.
    assert not absent_in_artifact, (
        f"[{embodiment_tag}] artifact collated inputs miss {absent_in_artifact} "
        f"(available: {sorted(reference_inputs)}); regenerate the artifact with the current dump script."
    )
    assert not absent_in_lerobot, (
        f"[{embodiment_tag}] LeRobot preprocessor output misses {absent_in_lerobot} (tensor keys "
        f"available: {sorted(key for key, value in lerobot_inputs.items() if isinstance(value, torch.Tensor))})."
    )

    for key in keys_to_compare:
        must_match_exactly = key in _COLLATED_EXACT_KEYS
        _assert_collated_parity(
            embodiment_tag,
            key,
            lerobot_inputs[key],
            reference_inputs[key],
            exact=must_match_exactly,
        )
@@ -9,6 +9,9 @@ LeRobot GR00T N1.7 integration requires. The two implementations therefore canno
imported in the same Python process. To keep the parity comparison FAIR, we run the imported in the same Python process. To keep the parity comparison FAIR, we run the
original model in its native env here and serialize, PER EMBODIMENT TAG: original model in its native env here and serialize, PER EMBODIMENT TAG:
* the RAW observation fed to the original processor (per-camera uint8 frames,
per-key state vectors, the language instruction), so the LeRobot side can also
run its OWN preprocessor on identical raw inputs and compare collated tensors,
* the exact pre-processed/collated model inputs (so the LeRobot side consumes the * the exact pre-processed/collated model inputs (so the LeRobot side consumes the
byte-identical tensors -- same image preprocessing, tokenization, normalization), byte-identical tensors -- same image preprocessing, tokenization, normalization),
* the random seed used right before the flow-matching sampler, * the random seed used right before the flow-matching sampler,
@@ -21,8 +24,10 @@ processor's per-embodiment modality configs. This lets us test many embodiment t
from the SAME checkpoint and confirm the LeRobot integration is not overfit to from the SAME checkpoint and confirm the LeRobot integration is not overfit to
``libero_sim``. ``libero_sim``.
The companion pytest (run in the LeRobot env) loads each .npz, replays the identical The companion pytest (run in the LeRobot env) loads each .npz and asserts parity
inputs + seed through the LeRobot GR00T N1.7 model, and asserts the outputs match. twice: the collated inputs + seed are replayed through the LeRobot GR00T N1.7 model
(model parity), and the raw observation is replayed through LeRobot's own
preprocessor pipeline and compared against the collated inputs (preprocessor parity).
Usage: Usage:
.venv-original/bin/python tests/policies/groot/utils/dump_original_n1_7.py \ .venv-original/bin/python tests/policies/groot/utils/dump_original_n1_7.py \
@@ -62,10 +67,7 @@ def make_observation(seed: int, video_keys, lang_key, state_spec):
# One ndarray per state key, shape (B, T=1, key_dim); dim taken from statistics. # One ndarray per state key, shape (B, T=1, key_dim); dim taken from statistics.
# Keys with dim 0 (e.g. disabled eef on some embodiments) are still emitted as # Keys with dim 0 (e.g. disabled eef on some embodiments) are still emitted as
# present-but-empty so the processor's state transform finds every expected key. # present-but-empty so the processor's state transform finds every expected key.
state = { state = {k: rng.standard_normal((BATCH_SIZE, 1, dim)).astype(np.float32) for k, dim in state_spec}
k: rng.standard_normal((BATCH_SIZE, 1, dim)).astype(np.float32)
for k, dim in state_spec
}
language = {lang_key: [[PROMPT] for _ in range(BATCH_SIZE)]} language = {lang_key: [[PROMPT] for _ in range(BATCH_SIZE)]}
return {"video": video, "state": state, "language": language} return {"video": video, "state": state, "language": language}
@@ -77,6 +79,25 @@ def dump_one_tag(policy, fair_model, tag, modality_cfg, state_spec, args, out_pa
lang_key = modality_cfg["language"].modality_keys[0] lang_key = modality_cfg["language"].modality_keys[0]
observation = make_observation(args.seed, video_keys, lang_key, state_spec) observation = make_observation(args.seed, video_keys, lang_key, state_spec)
# Snapshot the RAW observation exactly as fed to the original processor below. The
# consumer's preprocessor-parity case replays it through LeRobot's own preprocessor
# and compares the resulting collated tensors against the "in::" ones saved further
# down. raw_state_keys records the checkpoint modality-key order, which is the
# concatenation order of the flat LeRobot ``observation.state`` vector.
spec_keys = [key for key, _ in state_spec]
state_modality = modality_cfg.get("state")
state_keys = [key for key in state_modality.modality_keys if key in spec_keys] if state_modality else []
state_keys += [key for key in spec_keys if key not in state_keys]
raw_language = [
str(item[0]) if isinstance(item, (list, tuple)) else str(item)
for item in observation["language"][lang_key]
]
raw_flat = {f"raw::video.{key}": arr.copy() for key, arr in observation["video"].items()}
raw_flat.update({f"raw::state.{key}": arr.copy() for key, arr in observation["state"].items()})
raw_flat["raw::language"] = np.array(raw_language, dtype=object)
raw_flat["raw_video_keys"] = np.array([str(key) for key in video_keys], dtype=object)
raw_flat["raw_state_keys"] = np.array([str(key) for key in state_keys], dtype=object)
# Point the policy preprocessing at this embodiment (mirrors Gr00tPolicy.__init__). # Point the policy preprocessing at this embodiment (mirrors Gr00tPolicy.__init__).
policy.embodiment_tag = type(policy.embodiment_tag)(tag) policy.embodiment_tag = type(policy.embodiment_tag)(tag)
policy.modality_configs = { policy.modality_configs = {
@@ -136,6 +157,7 @@ def dump_one_tag(policy, fair_model, tag, modality_cfg, state_spec, args, out_pa
embodiment_tag=np.array(tag), embodiment_tag=np.array(tag),
meta_keys=np.array(list(meta.keys()), dtype=object), meta_keys=np.array(list(meta.keys()), dtype=object),
meta_dtypes=np.array(list(meta.values()), dtype=object), meta_dtypes=np.array(list(meta.values()), dtype=object),
**raw_flat,
**flat, **flat,
) )
print(f"[{tag}] action_pred {action_pred.shape} -> {out_path.name} ({os.path.getsize(out_path)} B)") print(f"[{tag}] action_pred {action_pred.shape} -> {out_path.name} ({os.path.getsize(out_path)} B)")
@@ -181,7 +203,12 @@ def main():
state_spec = [(k, len(v["min"])) for k, v in stats[tag]["state"].items()] state_spec = [(k, len(v["min"])) for k, v in stats[tag]["state"].items()]
try: try:
dump_one_tag( dump_one_tag(
policy, fair_model, tag, all_modality[tag], state_spec, args, policy,
fair_model,
tag,
all_modality[tag],
state_spec,
args,
out_dir / f"original_n1_7_{tag}.npz", out_dir / f"original_n1_7_{tag}.npz",
) )
done.append(tag) done.append(tag)