Compare commits

...

1 Commits

Author SHA1 Message Date
Steven Palma ffdfb3d25f fix(groot): N1.7 config defaults, N1.5 rejection, and processor/model runtime fixes
Covers the GR00T N1.7 source trio (configuration, processor, model wrapper).
These three files are grouped together because processor_groot and
modeling_groot import GROOT_N1_5_REMOVAL_GUIDANCE defined in
configuration_groot.

Config:
- GrootConfig defaults are the N1.7 values; explicitly passed legacy N1.5-era
  values (chunk_size=50, max_state_dim=64, ...) are remapped with a warning
  instead of silently.
- action_decode_transform gains an 'auto' sentinel so an explicit 'none'
  opt-out wins over the libero_sim default and survives save/load round-trips.
- action_delta_indices is cached on the inputs that determine it.
- Legacy N1.5 checkpoints/configs (tokenizer_assets_repo, model_type/
  architectures/eagle backbone markers) are rejected with a single clear
  error pointing to lerobot==0.5.1.

Processor:
- GrootN17ActionDecodeStep handles the 2-D (B, D) actions delivered by sync
  select_action (relative eef/non-eef decode in eval/record flows).
- Postprocessor falls back to dataset stats when a raw checkpoint lacks the
  configured embodiment tag; raw-state cache is per-instance, not
  process-global; caller overrides (device, rename_map) are honored on the
  raw-checkpoint branch.
- Hub-hosted finetuned checkpoints resolve the processor config via
  hf_hub_download, with a tolerant retry when inspection fails.
- Camera/modality-key mismatches warn (including the zero-match fallback);
  deprecated Qwen2VLImageProcessorFast replaced with Qwen2VLImageProcessor;
  removed N1.5 processor steps are stubbed to raise the removal guidance and
  the action-unpack step is re-registered as _v2.

Model:
- use_bf16=False no longer crashes (compute_dtype only set when used).
- Flash-attention probe is diagnostic-only; forward raises on a missing loss;
  print() replaced with logging; N1.5 base-path mismatch includes the
  removal guidance.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-12 23:37:44 +02:00
3 changed files with 464 additions and 124 deletions
+121 -37
View File
@@ -15,6 +15,7 @@
# limitations under the License.
import json
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
@@ -23,15 +24,29 @@ from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTr
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_STATE
logger = logging.getLogger(__name__)
GROOT_N1_7 = "n1.7"
# Legacy GR00T N1.5 identifier. N1.5 is NOT a supported model_version (it is
# intentionally absent from _GROOT_MODEL_VERSION_ALIASES so normalize_groot_model_version
# still rejects it). It is retained only so that infer_groot_model_version can recognise
# an N1.5 base path/checkpoint and the N1.7 config/loader can reject the mismatch.
GROOT_N1_5 = "n1.5"
# Canonical guidance appended to every error raised when an N1.5 checkpoint, config,
# or processor pipeline is detected. Keep this message in sync with docs/source/groot.mdx.
GROOT_N1_5_REMOVAL_GUIDANCE = (
"GR00T N1.5 support was removed from LeRobot. "
"To keep using an N1.5 checkpoint, pin the last release that supports it: "
"`pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 "
"(model_version='n1.7', base model nvidia/GR00T-N1.7-3B)."
)
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
# Sentinel meaning "the user did not pick an action decode transform": __post_init__ resolves it
# to the embodiment default ('libero' for 'libero_sim', otherwise None). It is distinct from an
# explicit 'none' (resolved to None) so an opt-out survives a draccus save/load round-trip.
GROOT_ACTION_DECODE_TRANSFORM_AUTO = "auto"
_GROOT_MODEL_VERSION_ALIASES = {
"n1.7": GROOT_N1_7,
@@ -41,7 +56,12 @@ _GROOT_MODEL_VERSION_ALIASES = {
"1.7": GROOT_N1_7,
}
# Legacy N1.5 spellings, kept ONLY so they can be detected and rejected with
# GROOT_N1_5_REMOVAL_GUIDANCE (see GROOT_N1_5 above). Never map these to a supported version.
_GROOT_N1_5_VERSION_ALIASES = {"n1.5", "n1_5", "n1d5", "n15", "1.5"}
_GROOT_ACTION_DECODE_TRANSFORM_ALIASES = {
GROOT_ACTION_DECODE_TRANSFORM_AUTO: GROOT_ACTION_DECODE_TRANSFORM_AUTO,
"none": None,
"": None,
GROOT_ACTION_DECODE_TRANSFORM_LIBERO: GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
@@ -52,9 +72,10 @@ def normalize_groot_model_version(model_version: str) -> str:
normalized = _GROOT_MODEL_VERSION_ALIASES.get(model_version.lower())
if normalized is None:
supported = GROOT_N1_7
raise ValueError(
f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}."
)
message = f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}."
if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES:
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
raise ValueError(message)
return normalized
@@ -286,6 +307,8 @@ def _infer_groot_model_version_from_local_config(model_path: str) -> str | None:
def _infer_groot_model_version_from_config(config: dict) -> str | None:
model_version = config.get("model_version")
if isinstance(model_version, str):
if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES:
return GROOT_N1_5
try:
return normalize_groot_model_version(model_version)
except ValueError:
@@ -298,8 +321,17 @@ def _infer_groot_model_version_from_config(config: dict) -> str | None:
normalized = candidate.lower().replace("-", "_")
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
return GROOT_N1_7
# nvidia/GR00T-N1.5-3B ships model_type 'gr00t_n1_5' and architectures ['GR00T_N1_5'].
# Recognise them so N1.5 checkpoints at generic local paths are rejected loudly
# instead of being silently treated as N1.7 (see infer_groot_model_version).
if normalized in {"gr00t_n1_5", "gr00tn1_5", "gr00t_n15", "gr00t_n1d5", "gr00tn1d5"}:
return GROOT_N1_5
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
return GROOT_N1_7
# The Eagle VLM backbone is specific to pre-N1.7 GR00T checkpoints (N1.7 uses Cosmos/Qwen3-VL).
backbone_cfg = config.get("backbone_cfg")
if isinstance(backbone_cfg, dict) and "eagle_path" in backbone_cfg:
return GROOT_N1_5
return None
@@ -310,27 +342,32 @@ class GrootConfig(PreTrainedConfig):
# Basic policy settings
n_obs_steps: int = 1
chunk_size: int = 50
n_action_steps: int = 50
chunk_size: int = 40
n_action_steps: int = 40
# Dimension settings (must match pretrained GR00T model expectations)
# Maximum state dimension. Shorter states will be zero-padded.
max_state_dim: int = 64
max_state_dim: int = 132
# Maximum action dimension. Shorter actions will be zero-padded.
max_action_dim: int = 32
max_action_dim: int = 132
# Normalization (start with identity, adjust as needed)
# GR00T normalizes state/action internally in its processor steps (min/max with
# q01/q99 percentiles, per embodiment), and the Qwen3-VL backbone's image processor
# handles image normalization. The policy therefore does NOT use LeRobot's
# NormalizerProcessorStep/UnnormalizerProcessorStep, so this mapping is intentionally
# IDENTITY for every feature and is not consulted by make_groot_pre_post_processors.
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
)
# Image preprocessing (adjust to match Groot's expected input)
image_size: tuple[int, int] = (224, 224)
# Deprecated and unused: image sizing is handled by the backbone's image processor.
# Kept only so config.json files saved with earlier versions still parse.
image_size: tuple[int, int] = (256, 256)
# Groot-specific model parameters (from groot_finetune_script.py)
@@ -344,7 +381,14 @@ class GrootConfig(PreTrainedConfig):
n1_7_backbone_model: str = GROOT_N1_7_BACKBONE_MODEL
# Optional named action transform applied after raw N1.7 checkpoint decoding and before env.step().
action_decode_transform: str | None = None
# 'auto' (default) resolves to the embodiment default ('libero' for 'libero_sim', otherwise no
# transform). Pass 'none' to explicitly disable the transform, including for 'libero_sim'.
action_decode_transform: str | None = GROOT_ACTION_DECODE_TRANSFORM_AUTO
# Deprecated, GR00T N1.5 only — do not set. Kept so config.json files saved by lerobot<=0.5.1
# still parse (draccus rejects unknown fields) and can be rejected in __post_init__ with a
# clear error pointing at GROOT_N1_5_REMOVAL_GUIDANCE instead of a cryptic DecodingError.
tokenizer_assets_repo: str | None = None
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
embodiment_tag: str = "new_embodiment"
@@ -384,17 +428,13 @@ class GrootConfig(PreTrainedConfig):
warmup_ratio: float = 0.05
use_bf16: bool = True
# Dataset parameters
# Video backend to use for training ('decord' or 'torchvision_av')
# Deprecated Isaac-GR00T runner fields below — unused by the LeRobot N1.7 implementation
# (nothing in src/lerobot reads them). They are kept only so config.json files saved by
# earlier lerobot releases still parse: draccus rejects unknown fields, so removing them
# would break every previously saved groot checkpoint at config-load time.
video_backend: str = "decord"
# Whether to balance dataset weights in mixture datasets
balance_dataset_weights: bool = True
# Whether to sample trajectories weighted by their length
balance_trajectory_weights: bool = True
# Optional dataset paths for delegating training to Isaac-GR00T runner
dataset_paths: list[str] | None = None
output_dir: str = "./tmp/gr00t"
save_steps: int = 1000
@@ -405,6 +445,15 @@ class GrootConfig(PreTrainedConfig):
resume: bool = False
def __post_init__(self):
# 'tokenizer_assets_repo' only ever existed for GR00T N1.5 (lerobot<=0.5.1) and was
# serialized into every groot checkpoint config.json, so a value here means a legacy
# N1.5 checkpoint or config is being loaded.
if self.tokenizer_assets_repo is not None:
raise ValueError(
"Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks "
f"like a legacy GR00T N1.5 checkpoint or config. {GROOT_N1_5_REMOVAL_GUIDANCE}"
)
self.model_version = normalize_groot_model_version(self.model_version)
self.action_decode_transform = normalize_groot_action_decode_transform(self.action_decode_transform)
if self.base_model_path is None:
@@ -416,26 +465,48 @@ class GrootConfig(PreTrainedConfig):
# 'libero_sim' embodiment grasps correctly instead of scoring 0% success.
# This matches the embodiment-specific handling already done for the
# action execution horizon (see infer_groot_n1_7_action_execution_horizon).
if self.action_decode_transform is None and self.embodiment_tag == "libero_sim":
self.action_decode_transform = GROOT_ACTION_DECODE_TRANSFORM_LIBERO
# Only the 'auto' sentinel resolves to the embodiment default; an explicit
# 'none' (normalized to None above) keeps the transform disabled.
if self.action_decode_transform == GROOT_ACTION_DECODE_TRANSFORM_AUTO:
self.action_decode_transform = (
GROOT_ACTION_DECODE_TRANSFORM_LIBERO if self.embodiment_tag == "libero_sim" else None
)
if self.max_state_dim == 64:
self.max_state_dim = 132
if self.max_action_dim == 32:
self.max_action_dim = 132
if self.chunk_size == 50:
self.chunk_size = 40
if self.n_action_steps == 50:
self.n_action_steps = 40
if tuple(self.image_size) == (224, 224):
self.image_size = (256, 256)
# GR00T N1.5-era default values (e.g. --policy.chunk_size=50 from old commands or
# stale configs) are migrated to the values the N1.7 checkpoints expect, with a
# warning. The dataclass defaults are already the N1.7 values, so a plain
# GrootConfig() never triggers this.
legacy_default_remaps = (
("max_state_dim", 64, 132),
("max_action_dim", 32, 132),
("chunk_size", 50, 40),
("n_action_steps", 50, 40),
("image_size", (224, 224), (256, 256)),
)
for field_name, legacy_value, n1_7_value in legacy_default_remaps:
current_value = getattr(self, field_name)
if isinstance(legacy_value, tuple):
current_value = tuple(current_value)
if current_value == legacy_value:
logger.warning(
"GrootConfig.%s=%s matches a legacy GR00T N1.5-era default; remapping it to %s, "
"the value expected by GR00T N1.7 checkpoints. Set a different value explicitly "
"if this is not what you want.",
field_name,
legacy_value,
n1_7_value,
)
setattr(self, field_name, n1_7_value)
inferred_version = infer_groot_model_version(self.base_model_path)
if inferred_version is not None and inferred_version != self.model_version:
raise ValueError(
message = (
f"GR00T model_version '{self.model_version}' does not match base_model_path "
f"'{self.base_model_path}', which looks like '{inferred_version}'."
)
if inferred_version == GROOT_N1_5:
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
raise ValueError(message)
super().__post_init__()
@@ -511,9 +582,22 @@ class GrootConfig(PreTrainedConfig):
@property
def action_delta_indices(self) -> list[int]:
"""Return indices for delta actions."""
model_action_horizon = infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
return list(range(min(self.chunk_size, model_action_horizon)))
"""Return indices for delta actions.
The model action horizon is read from the checkpoint's processor_config.json
when available; the result is cached (keyed on the inputs that determine it) so
repeated access during dataset/training setup does not re-read from disk.
"""
cache_key = (self.base_model_path, self.embodiment_tag, self.chunk_size)
cached = getattr(self, "_action_delta_indices_cache", None)
if cached is not None and cached[0] == cache_key:
return cached[1]
model_action_horizon = (
infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
)
indices = list(range(min(self.chunk_size, model_action_horizon)))
object.__setattr__(self, "_action_delta_indices_cache", (cache_key, indices))
return indices
@property
def reward_delta_indices(self) -> None:
+46 -49
View File
@@ -18,15 +18,12 @@
Groot Policy Wrapper for LeRobot Integration
Minimal integration that delegates to Isaac-GR00T N1.7 components where
possible without porting their code.
Notes:
- Dataset loading and full training orchestration is handled by Isaac-GR00T
TrainRunner in their codebase. If you want to invoke that flow end-to-end
from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
possible without porting their code. Dataset loading and training
orchestration are handled by LeRobot's standard training stack.
"""
import builtins
import logging
import os
from collections import deque
from pathlib import Path
@@ -42,6 +39,8 @@ from lerobot.utils.import_utils import require_package
from ..pretrained import PreTrainedPolicy
from ..utils import get_device_from_parameters
from .configuration_groot import (
GROOT_N1_5,
GROOT_N1_5_REMOVAL_GUIDANCE,
GROOT_N1_7,
GrootConfig,
infer_groot_model_version,
@@ -50,6 +49,8 @@ from .configuration_groot import (
normalize_groot_model_version,
)
logger = logging.getLogger(__name__)
T = TypeVar("T", bound="GrootPolicy")
@@ -92,8 +93,11 @@ class GrootPolicy(PreTrainedPolicy):
transformers_loading_kwargs={"trust_remote_code": True},
)
model.compute_dtype = "bfloat16" if self.config.use_bf16 else model.compute_dtype
model.config.compute_dtype = model.compute_dtype
# GR00TN17 defines no compute_dtype attribute, so only record the
# bf16 preference when it is enabled instead of reading a default back.
if self.config.use_bf16:
model.compute_dtype = "bfloat16"
model.config.compute_dtype = "bfloat16"
return model
@@ -148,9 +152,10 @@ class GrootPolicy(PreTrainedPolicy):
if config is not None
else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
)
print(
f"The Groot policy is a wrapper around Nvidia's GR00T {requested_version} model.\n"
f"Loading pretrained model from: {pretrained_name_or_path}"
logger.info(
"The Groot policy wraps NVIDIA's GR00T %s model. Loading pretrained model from: %s",
requested_version,
pretrained_name_or_path,
)
model_id = str(pretrained_name_or_path)
@@ -181,7 +186,7 @@ class GrootPolicy(PreTrainedPolicy):
if is_finetuned_checkpoint:
# This is a fine-tuned LeRobot checkpoint - use parent class loading
print("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
logger.info("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
return super().from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
config=config,
@@ -197,7 +202,7 @@ class GrootPolicy(PreTrainedPolicy):
)
# This is a base GR00T model - load it fresh
print("Detected base GR00T model, loading from HuggingFace...")
logger.info("Detected base GR00T model, loading from HuggingFace...")
if config is None:
model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
@@ -229,10 +234,13 @@ class GrootPolicy(PreTrainedPolicy):
config.model_version = normalize_groot_model_version(config.model_version)
inferred_version = infer_groot_model_version(config.base_model_path)
if inferred_version is not None and inferred_version != config.model_version:
raise ValueError(
message = (
f"GR00T model_version '{config.model_version}' does not match base_model_path "
f"'{config.base_model_path}', which looks like '{inferred_version}'."
)
if inferred_version == GROOT_N1_5:
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
raise ValueError(message)
# Create a fresh policy instance - this will automatically load the GR00T model
# in __init__ via _create_groot_model()
policy = cls(config)
@@ -297,9 +305,7 @@ class GrootPolicy(PreTrainedPolicy):
allowed_base.add("action_mask")
return {
k: v
for k, v in batch.items()
if k in allowed_base and not (k.startswith("next.") or k == "info")
k: v for k, v in batch.items() if k in allowed_base and not (k.startswith("next.") or k == "info")
}
def _prepare_n1_7_rtc_inputs(
@@ -320,9 +326,7 @@ class GrootPolicy(PreTrainedPolicy):
if prev_actions.ndim == 2:
prev_actions = prev_actions.unsqueeze(0)
elif prev_actions.ndim != 3:
raise ValueError(
"prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC."
)
raise ValueError("prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC.")
state = inputs.get("state")
if state is None:
@@ -331,9 +335,7 @@ class GrootPolicy(PreTrainedPolicy):
if prev_actions.shape[0] == 1 and batch_size > 1:
prev_actions = prev_actions.expand(batch_size, -1, -1).clone()
elif prev_actions.shape[0] != batch_size:
raise ValueError(
"prev_chunk_left_over batch size must match the current GR00T N1.7 batch size."
)
raise ValueError("prev_chunk_left_over batch size must match the current GR00T N1.7 batch size.")
# The generic LeRobot RTC engine pads short leftovers with exact zero
# rows for fixed-shape policy calls. Native GR00T N1.7 RTC treats every
@@ -346,7 +348,9 @@ class GrootPolicy(PreTrainedPolicy):
else:
return inputs, None
model_action_horizon = int(getattr(self._groot_model.config, "action_horizon", self.config.chunk_size))
model_action_horizon = int(
getattr(self._groot_model.config, "action_horizon", self.config.chunk_size)
)
max_action_dim = int(getattr(self._groot_model.config, "max_action_dim", self.config.max_action_dim))
if prev_actions.shape[1] > model_action_horizon:
prev_actions = prev_actions[:, -model_action_horizon:, :]
@@ -409,6 +413,11 @@ class GrootPolicy(PreTrainedPolicy):
# Isaac-GR00T returns a BatchFeature; loss key is typically 'loss'
loss = outputs.get("loss")
if loss is None:
raise RuntimeError(
"GR00T model.forward did not return a 'loss'. Training batches must include "
"'action' and 'action_mask'; check the preprocessor output."
)
loss_dict = {"loss": loss.item()}
@@ -471,33 +480,21 @@ class GrootPolicy(PreTrainedPolicy):
# Internal helpers
# -------------------------
def _handle_flash_attention_compatibility(self) -> None:
"""Handle Flash Attention compatibility issues by setting environment variables.
"""Log Flash Attention availability (diagnostic only).
This addresses the common 'undefined symbol' error that occurs when Flash Attention
is compiled against a different PyTorch version than what's currently installed.
The GR00T N1.7 backbone automatically falls back to SDPA when ``flash_attn`` is
unavailable (see ``Qwen3Backbone``), so this probe only emits a hint; it does not
change behaviour or mutate global state.
"""
# Set environment variables to handle Flash Attention compatibility
# These help with symbol resolution issues
os.environ.setdefault("FLASH_ATTENTION_FORCE_BUILD", "0")
os.environ.setdefault("FLASH_ATTENTION_SKIP_CUDA_BUILD", "0")
# Try to import flash_attn and handle failures gracefully
try:
import flash_attn
print(f"[GROOT] Flash Attention version: {flash_attn.__version__}")
except ImportError as e:
print(f"[GROOT] Flash Attention not available: {e}")
print("[GROOT] Will use fallback attention mechanism")
except Exception as e:
if "undefined symbol" in str(e):
print(f"[GROOT] Flash Attention compatibility issue detected: {e}")
print("[GROOT] This is likely due to PyTorch/Flash Attention version mismatch")
print("[GROOT] Consider reinstalling Flash Attention with compatible version:")
print(" pip uninstall flash-attn")
print(" pip install --no-build-isolation flash-attn==2.6.3")
print("[GROOT] Continuing with fallback attention mechanism")
else:
print(f"[GROOT] Flash Attention error: {e}")
print("[GROOT] Continuing with fallback attention mechanism")
logger.debug("Flash Attention %s is available.", flash_attn.__version__)
except ImportError:
logger.debug("Flash Attention is not installed; the GR00T backbone will use SDPA.")
except Exception as e: # noqa: BLE001
logger.warning(
"Flash Attention failed to import (%s); the GR00T backbone will use SDPA. If this is "
"an 'undefined symbol' error, reinstall a flash-attn build matching your torch version.",
e,
)
+297 -38
View File
@@ -15,14 +15,16 @@
# limitations under the License.
import json
import logging
from copy import copy
from dataclasses import dataclass, field
from dataclasses import dataclass, field, fields, is_dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any
import numpy as np
import torch
from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image
from lerobot.utils.import_utils import _transformers_available
@@ -59,6 +61,7 @@ from lerobot.utils.constants import (
from .configuration_groot import (
GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
GROOT_N1_5_REMOVAL_GUIDANCE,
GROOT_N1_7_BACKBONE_MODEL,
GrootConfig,
is_raw_groot_n1_7_checkpoint,
@@ -80,12 +83,6 @@ N1_7_EMBODIMENT_MAPPING = {
"new_embodiment": 10,
}
_N1_7_RAW_STATE_CACHE: dict[str, dict[str, np.ndarray]] = {}
def _n1_7_state_cache_key(value: str | None) -> str:
return value or "groot_n1_7_default"
@dataclass
class _GrootN17CheckpointProcessorAssets:
@@ -471,25 +468,98 @@ def _legacy_groot_processor_overrides(
env_action_dim = int(config.output_features[ACTION].shape[0])
except Exception:
env_action_dim = 0
action_unpack_overrides = dict(postprocessor_overrides.get("groot_action_unpack_unnormalize_v1", {}))
action_unpack_overrides = dict(postprocessor_overrides.get("groot_action_unpack_unnormalize_v2", {}))
action_unpack_overrides["normalize_min_max"] = True
action_unpack_overrides["env_action_dim"] = env_action_dim
postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = action_unpack_overrides
postprocessor_overrides["groot_action_unpack_unnormalize_v2"] = action_unpack_overrides
return preprocessor_overrides, postprocessor_overrides
def _local_processor_config_has_step(pretrained_path: str, config_filename: str, step_name: str) -> bool:
def _pretrained_processor_config_has_step(pretrained_path: str, config_filename: str, step_name: str) -> bool:
"""Check whether a serialized processor pipeline contains a registry step.
Resolves the processor config from a local directory or, for Hub repo ids,
via ``hf_hub_download`` (which serves the cached copy when offline). Returns
False when the config cannot be resolved; loading then proceeds with the
legacy overrides and `make_groot_pre_post_processors_from_pretrained` retries
without them if they do not match the serialized pipeline.
"""
path = Path(pretrained_path).expanduser()
if not path.is_dir():
if path.is_dir():
config = _read_json(path / config_filename)
elif path.exists():
return False
config = _read_json(path / config_filename)
else:
try:
config_path = hf_hub_download(
repo_id=str(pretrained_path), filename=config_filename, repo_type="model"
)
except Exception:
return False
config = _read_json(Path(config_path))
steps = config.get("steps", [])
if not isinstance(steps, list):
return False
return any(isinstance(step, dict) and step.get("registry_name") == step_name for step in steps)
def _apply_groot_step_overrides(
pipeline: PolicyProcessorPipeline,
overrides: dict[str, Any] | None,
) -> None:
"""Apply ``from_pretrained``-style step overrides to a freshly built pipeline.
Raw N1.7 checkpoints build their processors from scratch instead of
deserializing them, so caller overrides must be applied to the constructed
steps. Override keys match a step's registry name or, as a convenience, its
class name (``PolicyProcessorPipeline.from_pretrained`` matches registered
steps by registry name only prefer registry names so overrides keep
working after the checkpoint is converted and reloaded from a serialized
pipeline). Keys or fields that match nothing raise instead of being dropped
silently.
"""
if not overrides:
return
def _step_keys(step: ProcessorStep) -> set[str]:
keys = {type(step).__name__}
registry_name = getattr(type(step), "_registry_name", None)
if registry_name:
keys.add(registry_name)
return keys
for override_key, step_overrides in overrides.items():
matched_steps = [step for step in pipeline.steps if override_key in _step_keys(step)]
if not matched_steps:
available = [
getattr(type(step), "_registry_name", None) or type(step).__name__ for step in pipeline.steps
]
raise KeyError(
f"Override key '{override_key}' does not match any step of the GR00T processor pipeline "
f"built for this raw N1.7 checkpoint. Available step keys: {available}."
)
for step in matched_steps:
if not is_dataclass(step):
raise TypeError(
f"Cannot apply overrides to step '{override_key}': it is not a dataclass step."
)
init_field_names = {f.name for f in fields(step) if f.init}
for field_name, value in dict(step_overrides).items():
if field_name not in init_field_names:
raise TypeError(
f"Override field '{field_name}' is not a config field of step '{override_key}'. "
f"Available fields: {sorted(init_field_names)}."
)
setattr(step, field_name, value)
# Re-derive attributes computed from the overridden config (e.g.
# DeviceProcessorStep resolves its torch.device in __post_init__).
post_init = getattr(step, "__post_init__", None)
if callable(post_init):
post_init()
def make_groot_pre_post_processors_from_pretrained(
config: GrootConfig,
pretrained_path: str,
@@ -508,12 +578,20 @@ def make_groot_pre_post_processors_from_pretrained(
if is_raw_groot_n1_7_checkpoint(pretrained_path):
processor_cfg = copy(config)
processor_cfg.base_model_path = str(pretrained_path)
return make_groot_pre_post_processors(
preprocessor, postprocessor = make_groot_pre_post_processors(
config=processor_cfg,
dataset_stats=dataset_stats,
)
# Raw checkpoints have no serialized pipelines to load overrides into,
# so apply the caller overrides (e.g. device and rename_map from
# lerobot-eval or the policy server) to the freshly built steps.
_apply_groot_step_overrides(preprocessor, preprocessor_overrides)
_apply_groot_step_overrides(postprocessor, postprocessor_overrides)
return preprocessor, postprocessor
if _local_processor_config_has_step(
caller_preprocessor_overrides = dict(preprocessor_overrides or {})
caller_postprocessor_overrides = dict(postprocessor_overrides or {})
if _pretrained_processor_config_has_step(
pretrained_path,
postprocessor_config_filename,
"groot_n1_7_action_decode_v1",
@@ -521,15 +599,55 @@ def make_groot_pre_post_processors_from_pretrained(
# Converted raw N1.7 checkpoints already carry the checkpoint-specific
# action decoder. Adding the legacy action-unpack override would target
# a step that is not present and break loading.
preprocessor_overrides = dict(preprocessor_overrides or {})
postprocessor_overrides = dict(postprocessor_overrides or {})
applied_legacy_overrides = False
preprocessor_overrides = caller_preprocessor_overrides
postprocessor_overrides = caller_postprocessor_overrides
else:
applied_legacy_overrides = True
preprocessor_overrides, postprocessor_overrides = _legacy_groot_processor_overrides(
config=config,
dataset_stats=dataset_stats,
preprocessor_overrides=preprocessor_overrides,
postprocessor_overrides=postprocessor_overrides,
)
try:
preprocessor, postprocessor = _load_groot_processor_pipelines(
pretrained_path,
preprocessor_overrides=preprocessor_overrides,
postprocessor_overrides=postprocessor_overrides,
preprocessor_config_filename=preprocessor_config_filename,
postprocessor_config_filename=postprocessor_config_filename,
)
except KeyError:
if not applied_legacy_overrides:
raise
# The legacy overrides target steps that are absent from the serialized
# pipelines (e.g. a converted raw N1.7 checkpoint whose postprocessor
# config could not be inspected before loading); retry with the caller
# overrides only.
preprocessor, postprocessor = _load_groot_processor_pipelines(
pretrained_path,
preprocessor_overrides=caller_preprocessor_overrides,
postprocessor_overrides=caller_postprocessor_overrides,
preprocessor_config_filename=preprocessor_config_filename,
postprocessor_config_filename=postprocessor_config_filename,
)
_reconnect_groot_relative_absolute_steps(preprocessor, postprocessor)
_reconnect_groot_n1_7_pack_decode_steps(preprocessor, postprocessor)
return preprocessor, postprocessor
def _load_groot_processor_pipelines(
pretrained_path: str,
*,
preprocessor_overrides: dict[str, Any],
postprocessor_overrides: dict[str, Any],
preprocessor_config_filename: str,
postprocessor_config_filename: str,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
preprocessor = PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=preprocessor_config_filename,
@@ -544,7 +662,6 @@ def make_groot_pre_post_processors_from_pretrained(
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
)
_reconnect_groot_relative_absolute_steps(preprocessor, postprocessor)
return preprocessor, postprocessor
@@ -564,6 +681,28 @@ def _reconnect_groot_relative_absolute_steps(
step.relative_step = relative_step
def _reconnect_groot_n1_7_pack_decode_steps(
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
) -> None:
"""Re-link a deserialized N1.7 action decode step to its pack step.
The pack step holds the per-instance raw-state cache that relative-action
decoding reads its reference state from; the link itself is not serialized.
"""
pack_step = next(
(step for step in preprocessor.steps if isinstance(step, GrootN17PackInputsStep)),
None,
)
if pack_step is None:
return
for step in postprocessor.steps:
if isinstance(step, GrootN17ActionDecodeStep) and step.pack_step is None:
step.pack_step = pack_step
def make_groot_pre_post_processors(
config: GrootConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[
@@ -607,9 +746,12 @@ def make_groot_pre_post_processors(
else action_horizon
)
checkpoint_stats = checkpoint_assets.stats if checkpoint_assets is not None else None
padded_stats = checkpoint_stats if _has_modality_stats(checkpoint_stats) else (dataset_stats or {})
checkpoint_has_stats = _has_modality_stats(checkpoint_stats)
padded_stats = checkpoint_stats if checkpoint_has_stats else (dataset_stats or {})
embodiment_mapping = (
checkpoint_assets.embodiment_mapping if checkpoint_assets is not None else dict(N1_7_EMBODIMENT_MAPPING)
checkpoint_assets.embodiment_mapping
if checkpoint_assets is not None
else dict(N1_7_EMBODIMENT_MAPPING)
)
formalize_language = checkpoint_assets.formalize_language if checkpoint_assets is not None else True
clip_outliers = checkpoint_assets.clip_outliers if checkpoint_assets is not None else True
@@ -618,7 +760,6 @@ def make_groot_pre_post_processors(
env_action_dim = int(config.output_features[ACTION].shape[0])
except Exception:
env_action_dim = 0
state_cache_key = f"groot_n1_7:{config.embodiment_tag}"
pack_step = GrootN17PackInputsStep(
state_horizon=1,
action_horizon=action_horizon,
@@ -636,7 +777,6 @@ def make_groot_pre_post_processors(
video_modality_keys=video_modality_keys,
raw_stats=checkpoint_assets.raw_stats if checkpoint_assets is not None else None,
modality_config=checkpoint_assets.modality_config if checkpoint_assets is not None else None,
state_cache_key=state_cache_key,
)
input_steps: list[ProcessorStep] = [
@@ -647,14 +787,31 @@ def make_groot_pre_post_processors(
model_name=config.n1_7_backbone_model,
image_crop_size=checkpoint_assets.image_crop_size if checkpoint_assets is not None else None,
image_target_size=checkpoint_assets.image_target_size if checkpoint_assets is not None else None,
shortest_image_edge=checkpoint_assets.shortest_image_edge if checkpoint_assets is not None else None,
shortest_image_edge=checkpoint_assets.shortest_image_edge
if checkpoint_assets is not None
else None,
crop_fraction=checkpoint_assets.crop_fraction if checkpoint_assets is not None else None,
use_albumentations=checkpoint_assets.use_albumentations if checkpoint_assets is not None else False,
use_albumentations=checkpoint_assets.use_albumentations
if checkpoint_assets is not None
else False,
),
DeviceProcessorStep(device=config.device),
]
if checkpoint_assets is None:
if checkpoint_assets is not None and not checkpoint_has_stats and not _has_modality_stats(padded_stats):
raise ValueError(
f"GR00T N1.7 checkpoint '{config.base_model_path}' has no statistics for embodiment tag "
f"'{config.embodiment_tag}', and no dataset stats were provided to fall back to, so "
"actions cannot be normalized or decoded. Pass dataset_stats, or set "
"config.embodiment_tag to an embodiment present in the checkpoint's statistics.json."
)
if checkpoint_assets is None or not checkpoint_has_stats:
# When the checkpoint sidecars have no stats for the configured
# embodiment tag (e.g. finetuning a raw base checkpoint with the
# default 'new_embodiment' tag), the pack step above normalized with
# the dataset stats; the decode step must invert with the same stats
# instead of using a checkpoint decoder whose empty stats would
# silently return normalized [-1, 1] actions.
action_decode_step: ProcessorStep = GrootActionUnpackUnnormalizeStep(
env_action_dim=env_action_dim,
stats=padded_stats,
@@ -669,7 +826,6 @@ def make_groot_pre_post_processors(
use_percentiles=checkpoint_assets.use_percentiles,
use_relative_action=checkpoint_assets.use_relative_action,
pack_step=pack_step,
state_cache_key=state_cache_key,
action_decode_transform=config.action_decode_transform,
)
@@ -770,7 +926,7 @@ def _build_n1_7_processor(model_name: str = GROOT_N1_7_BACKBONE_MODEL) -> Proces
try:
from transformers import (
AutoTokenizer,
Qwen2VLImageProcessorFast,
Qwen2VLImageProcessor,
Qwen3VLProcessor,
Qwen3VLVideoProcessor,
)
@@ -781,7 +937,7 @@ def _build_n1_7_processor(model_name: str = GROOT_N1_7_BACKBONE_MODEL) -> Proces
) from exc
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
image_processor = Qwen2VLImageProcessorFast.from_pretrained(model_name, trust_remote_code=True)
image_processor = Qwen2VLImageProcessor.from_pretrained(model_name, trust_remote_code=True)
video_processor = Qwen3VLVideoProcessor.from_pretrained(model_name, trust_remote_code=True)
proc = Qwen3VLProcessor(
image_processor=image_processor,
@@ -902,8 +1058,11 @@ class GrootN17PackInputsStep(ProcessorStep):
video_modality_keys: list[str] | None = None
raw_stats: dict[str, Any] | None = None
modality_config: dict[str, Any] | None = None
# Unused: kept so serialized configs that include it still load. The raw
# state cache is per instance (_last_raw_state), never process-global.
state_cache_key: str = ""
_last_raw_state: dict[str, np.ndarray] | None = field(default=None, init=False, repr=False)
_warned_image_keys: bool = field(default=False, init=False, repr=False)
def _ordered_image_keys(self, obs: dict[str, Any]) -> list[str]:
available = {key for key in obs if key.startswith(OBS_IMAGES)}
@@ -913,19 +1072,56 @@ class GrootN17PackInputsStep(ProcessorStep):
return sorted(available)
ordered: list[str] = []
unmatched: list[str] = []
for modality_key in self.video_modality_keys:
candidates = [f"{OBS_IMAGES}.{modality_key}"]
# Alias for datasets converted with generic camera names (e.g. the
# LIBERO conversions expose the wrist camera as
# `observation.images.image2`), so raw N1.7 LIBERO checkpoints
# match those datasets out of the box.
if modality_key == "wrist_image":
candidates.append(f"{OBS_IMAGES}.image2")
elif modality_key == "image":
candidates.append(f"{OBS_IMAGES}.image")
match = next((candidate for candidate in candidates if candidate in available), None)
if match is not None:
if match is None:
unmatched.append(modality_key)
else:
ordered.append(match)
if not ordered:
if not self._warned_image_keys:
self._warned_image_keys = True
logging.warning(
"None of the GR00T N1.7 checkpoint video modality keys %s match a camera among %s; "
"falling back to feeding all cameras in alphabetical order, which is unlikely to be "
"the layout the checkpoint was trained with. Rename the dataset cameras (e.g. via "
"--rename_map) to match the checkpoint keys.",
self.video_modality_keys,
sorted(available),
)
return sorted(available)
unused = sorted(available - set(ordered))
if (unmatched or unused) and not self._warned_image_keys:
self._warned_image_keys = True
if unmatched:
logging.warning(
"GR00T N1.7 checkpoint video modality keys %s have no matching camera among %s; "
"the model will receive %d view(s) instead of the %d it was trained with. Rename "
"the dataset cameras (e.g. via --rename_map) to match the checkpoint keys %s.",
unmatched,
sorted(available),
len(ordered),
len(self.video_modality_keys),
self.video_modality_keys,
)
if unused:
logging.warning(
"Dropping camera(s) %s: the GR00T N1.7 checkpoint only consumes the video modality "
"keys %s, which matched %s.",
unused,
self.video_modality_keys,
ordered,
)
return ordered
def __call__(self, transition: EnvTransition) -> EnvTransition:
@@ -988,7 +1184,6 @@ class GrootN17PackInputsStep(ProcessorStep):
start_idx += dim
if grouped:
self._last_raw_state = grouped
_N1_7_RAW_STATE_CACHE[_n1_7_state_cache_key(self.state_cache_key)] = grouped
img_keys = self._ordered_image_keys(obs)
if img_keys:
@@ -1101,7 +1296,6 @@ class GrootN17PackInputsStep(ProcessorStep):
"video_modality_keys": self.video_modality_keys,
"raw_stats": self.raw_stats,
"modality_config": self.modality_config,
"state_cache_key": self.state_cache_key,
}
def get_cached_raw_state(self) -> dict[str, np.ndarray] | None:
@@ -1223,6 +1417,7 @@ class GrootN17VLMEncodeStep(ProcessorStep):
"use_albumentations": self.use_albumentations,
}
def _stat_dim_from_entry(entry: dict[str, Any]) -> int:
for stat_name in ("mean", "q01", "min", "max", "std"):
value = entry.get(stat_name)
@@ -1351,6 +1546,18 @@ class GrootN17ActionDecodeStep(ProcessorStep):
group with the checkpoint stats, converts relative groups to absolute values
using the raw state cached during packing, concatenates groups in checkpoint
order, and finally slices to the environment action dimension.
Relative-action decoding reads the reference state from the connected
``pack_step`` (re-linked after ``from_pretrained`` by
``_reconnect_groot_n1_7_pack_decode_steps``), i.e. the state seen by the
most recent preprocess call. Engines that decode the whole chunk right
after prediction (RTC, async policy server) therefore use the
prediction-time state, matching Isaac-GR00T. The sync per-step queue path
instead decodes each popped (B, D) action against the latest observation:
the reference can be newer than the observation the chunk was predicted
from, and per-timestep relative stats are applied as if the popped action
were chunk step 0. Fixing that would require carrying the reference state
and chunk index alongside each queued action through the postprocessor.
"""
env_action_dim: int = 0
@@ -1358,6 +1565,7 @@ class GrootN17ActionDecodeStep(ProcessorStep):
modality_config: dict[str, Any] | None = None
use_percentiles: bool = False
use_relative_action: bool = False
# Unused: kept so serialized configs that include it still load.
state_cache_key: str = ""
action_decode_transform: str | None = None
pack_step: GrootN17PackInputsStep | None = field(default=None, repr=False)
@@ -1378,6 +1586,12 @@ class GrootN17ActionDecodeStep(ProcessorStep):
return transition
action_np = action.detach().cpu().float().numpy()
# The sync action queue postprocesses popped actions as (B, D); decode
# them as single-step (B, 1, D) chunks and squeeze the horizon back at
# the end so both ranks share the chunk decode logic below.
squeeze_horizon = action_np.ndim == 2
if squeeze_horizon:
action_np = action_np[:, None, :]
valid_horizon = _n1_7_decode_valid_horizon(action_config, action_np)
if valid_horizon is not None:
action_np = action_np[:, :valid_horizon]
@@ -1405,17 +1619,24 @@ class GrootN17ActionDecodeStep(ProcessorStep):
use_relative_action=self.use_relative_action,
use_percentiles=self.use_percentiles,
)
# Per-timestep stats carry one row per chunk step; align them with
# the decoded horizon (chunks always start at step 0, and a popped
# (B, D) action is decoded as step 0).
if min_v.ndim == 2 and normalized.shape[1] <= min_v.shape[0]:
min_v = min_v[: normalized.shape[1]]
max_v = max_v[: normalized.shape[1]]
decoded_groups[key] = _unnormalize_min_max(normalized, min_v, max_v)
start_idx += dim
if self.use_relative_action:
raw_state = self.pack_step.get_cached_raw_state() if self.pack_step is not None else None
if raw_state is None:
raw_state = _N1_7_RAW_STATE_CACHE.get(_n1_7_state_cache_key(self.state_cache_key))
if raw_state is None:
raise RuntimeError(
"GrootN17ActionDecodeStep requires cached raw state from GrootN17PackInputsStep "
"to convert relative N1.7 actions back to absolute actions."
"GrootN17ActionDecodeStep requires the raw state cached by its connected "
"GrootN17PackInputsStep to convert relative N1.7 actions back to absolute actions. "
"Build both pipelines through make_groot_pre_post_processors (or load them together "
"via make_groot_pre_post_processors_from_pretrained) and run the preprocessor on an "
"observation before decoding actions."
)
for idx, key in enumerate(action_keys):
if not isinstance(key, str) or key not in decoded_groups or idx >= len(action_configs):
@@ -1451,6 +1672,8 @@ class GrootN17ActionDecodeStep(ProcessorStep):
action_keys=action_keys,
decoded_groups=decoded_groups,
)
if squeeze_horizon:
decoded = decoded[:, 0]
new_transition = transition.copy()
new_transition[TransitionKey.ACTION] = torch.as_tensor(
decoded, dtype=action.dtype, device=action.device
@@ -1467,13 +1690,15 @@ class GrootN17ActionDecodeStep(ProcessorStep):
"modality_config": self.modality_config,
"use_percentiles": self.use_percentiles,
"use_relative_action": self.use_relative_action,
"state_cache_key": self.state_cache_key,
"action_decode_transform": self.action_decode_transform,
}
@dataclass
@ProcessorStepRegistry.register(name="groot_action_unpack_unnormalize_v1")
# v2: unlike the N1.5-era v1 step, this step no longer collapses (B, T, D)
# action chunks to the last timestep, so old serialized v1 pipelines must not
# silently load into it (v1 is stubbed below with the removal guidance).
@ProcessorStepRegistry.register(name="groot_action_unpack_unnormalize_v2")
class GrootActionUnpackUnnormalizeStep(ProcessorStep):
env_action_dim: int = 0
# Apply inverse of min-max normalization if it was used in preprocessor
@@ -1585,3 +1810,37 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep):
if reconstructed:
self.stats = reconstructed
def _register_removed_n1_5_step_stub(registry_name: str) -> None:
"""Register a stub for a processor step that only GR00T N1.5 pipelines serialize.
Saved N1.5 checkpoints reference these registry names in their processor JSON
files. Deserializing them must fail with the canonical N1.5 removal guidance
instead of an opaque registry KeyError (or, for
``groot_action_unpack_unnormalize_v1``, silently loading the v2 step whose
action-chunk semantics changed).
"""
@ProcessorStepRegistry.register(name=registry_name)
class _RemovedGrootN15ProcessorStep(ProcessorStep):
def __init__(self, **_kwargs: Any) -> None:
raise ValueError(
f"Processor step '{registry_name}' belongs to a GR00T N1.5 processor pipeline. "
f"{GROOT_N1_5_REMOVAL_GUIDANCE}"
)
def __call__(self, transition: EnvTransition) -> EnvTransition:
raise NotImplementedError
def transform_features(self, features):
raise NotImplementedError
for _removed_n1_5_step_name in (
"groot_pack_inputs_v3",
"groot_eagle_encode_v3",
"groot_eagle_collate_v3",
"groot_action_unpack_unnormalize_v1",
):
_register_removed_n1_5_step_stub(_removed_n1_5_step_name)