diff --git a/src/lerobot/policies/groot/configuration_groot.py b/src/lerobot/policies/groot/configuration_groot.py index 4ab899468..8a234c55a 100644 --- a/src/lerobot/policies/groot/configuration_groot.py +++ b/src/lerobot/policies/groot/configuration_groot.py @@ -15,6 +15,7 @@ # limitations under the License. import json +import logging import os from dataclasses import dataclass, field from pathlib import Path @@ -23,15 +24,29 @@ from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTr from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig from lerobot.utils.constants import ACTION, OBS_STATE +logger = logging.getLogger(__name__) + GROOT_N1_7 = "n1.7" # Legacy GR00T N1.5 identifier. N1.5 is NOT a supported model_version (it is # intentionally absent from _GROOT_MODEL_VERSION_ALIASES so normalize_groot_model_version # still rejects it). It is retained only so that infer_groot_model_version can recognise # an N1.5 base path/checkpoint and the N1.7 config/loader can reject the mismatch. GROOT_N1_5 = "n1.5" +# Canonical guidance appended to every error raised when an N1.5 checkpoint, config, +# or processor pipeline is detected. Keep this message in sync with docs/source/groot.mdx. +GROOT_N1_5_REMOVAL_GUIDANCE = ( + "GR00T N1.5 support was removed from LeRobot. " + "To keep using an N1.5 checkpoint, pin the last release that supports it: " + "`pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 " + "(model_version='n1.7', base model nvidia/GR00T-N1.7-3B)." +) GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B" GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B" GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero" +# Sentinel meaning "the user did not pick an action decode transform": __post_init__ resolves it +# to the embodiment default ('libero' for 'libero_sim', otherwise None). It is distinct from an +# explicit 'none' (resolved to None) so an opt-out survives a draccus save/load round-trip. +GROOT_ACTION_DECODE_TRANSFORM_AUTO = "auto" _GROOT_MODEL_VERSION_ALIASES = { "n1.7": GROOT_N1_7, @@ -41,7 +56,12 @@ _GROOT_MODEL_VERSION_ALIASES = { "1.7": GROOT_N1_7, } +# Legacy N1.5 spellings, kept ONLY so they can be detected and rejected with +# GROOT_N1_5_REMOVAL_GUIDANCE (see GROOT_N1_5 above). Never map these to a supported version. +_GROOT_N1_5_VERSION_ALIASES = {"n1.5", "n1_5", "n1d5", "n15", "1.5"} + _GROOT_ACTION_DECODE_TRANSFORM_ALIASES = { + GROOT_ACTION_DECODE_TRANSFORM_AUTO: GROOT_ACTION_DECODE_TRANSFORM_AUTO, "none": None, "": None, GROOT_ACTION_DECODE_TRANSFORM_LIBERO: GROOT_ACTION_DECODE_TRANSFORM_LIBERO, @@ -52,9 +72,10 @@ def normalize_groot_model_version(model_version: str) -> str: normalized = _GROOT_MODEL_VERSION_ALIASES.get(model_version.lower()) if normalized is None: supported = GROOT_N1_7 - raise ValueError( - f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}." - ) + message = f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}." + if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES: + message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}" + raise ValueError(message) return normalized @@ -286,6 +307,8 @@ def _infer_groot_model_version_from_local_config(model_path: str) -> str | None: def _infer_groot_model_version_from_config(config: dict) -> str | None: model_version = config.get("model_version") if isinstance(model_version, str): + if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES: + return GROOT_N1_5 try: return normalize_groot_model_version(model_version) except ValueError: @@ -298,8 +321,17 @@ def _infer_groot_model_version_from_config(config: dict) -> str | None: normalized = candidate.lower().replace("-", "_") if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}: return GROOT_N1_7 + # nvidia/GR00T-N1.5-3B ships model_type 'gr00t_n1_5' and architectures ['GR00T_N1_5']. + # Recognise them so N1.5 checkpoints at generic local paths are rejected loudly + # instead of being silently treated as N1.7 (see infer_groot_model_version). + if normalized in {"gr00t_n1_5", "gr00tn1_5", "gr00t_n15", "gr00t_n1d5", "gr00tn1d5"}: + return GROOT_N1_5 if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL: return GROOT_N1_7 + # The Eagle VLM backbone is specific to pre-N1.7 GR00T checkpoints (N1.7 uses Cosmos/Qwen3-VL). + backbone_cfg = config.get("backbone_cfg") + if isinstance(backbone_cfg, dict) and "eagle_path" in backbone_cfg: + return GROOT_N1_5 return None @@ -310,27 +342,32 @@ class GrootConfig(PreTrainedConfig): # Basic policy settings n_obs_steps: int = 1 - chunk_size: int = 50 - n_action_steps: int = 50 + chunk_size: int = 40 + n_action_steps: int = 40 # Dimension settings (must match pretrained GR00T model expectations) # Maximum state dimension. Shorter states will be zero-padded. - max_state_dim: int = 64 + max_state_dim: int = 132 # Maximum action dimension. Shorter actions will be zero-padded. - max_action_dim: int = 32 + max_action_dim: int = 132 - # Normalization (start with identity, adjust as needed) + # GR00T normalizes state/action internally in its processor steps (min/max with + # q01/q99 percentiles, per embodiment), and the Qwen3-VL backbone's image processor + # handles image normalization. The policy therefore does NOT use LeRobot's + # NormalizerProcessorStep/UnnormalizerProcessorStep, so this mapping is intentionally + # IDENTITY for every feature and is not consulted by make_groot_pre_post_processors. normalization_mapping: dict[str, NormalizationMode] = field( default_factory=lambda: { "VISUAL": NormalizationMode.IDENTITY, - "STATE": NormalizationMode.MEAN_STD, - "ACTION": NormalizationMode.MEAN_STD, + "STATE": NormalizationMode.IDENTITY, + "ACTION": NormalizationMode.IDENTITY, } ) - # Image preprocessing (adjust to match Groot's expected input) - image_size: tuple[int, int] = (224, 224) + # Deprecated and unused: image sizing is handled by the backbone's image processor. + # Kept only so config.json files saved with earlier versions still parse. + image_size: tuple[int, int] = (256, 256) # Groot-specific model parameters (from groot_finetune_script.py) @@ -344,7 +381,14 @@ class GrootConfig(PreTrainedConfig): n1_7_backbone_model: str = GROOT_N1_7_BACKBONE_MODEL # Optional named action transform applied after raw N1.7 checkpoint decoding and before env.step(). - action_decode_transform: str | None = None + # 'auto' (default) resolves to the embodiment default ('libero' for 'libero_sim', otherwise no + # transform). Pass 'none' to explicitly disable the transform, including for 'libero_sim'. + action_decode_transform: str | None = GROOT_ACTION_DECODE_TRANSFORM_AUTO + + # Deprecated, GR00T N1.5 only — do not set. Kept so config.json files saved by lerobot<=0.5.1 + # still parse (draccus rejects unknown fields) and can be rejected in __post_init__ with a + # clear error pointing at GROOT_N1_5_REMOVAL_GUIDANCE instead of a cryptic DecodingError. + tokenizer_assets_repo: str | None = None # Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1') embodiment_tag: str = "new_embodiment" @@ -384,17 +428,13 @@ class GrootConfig(PreTrainedConfig): warmup_ratio: float = 0.05 use_bf16: bool = True - # Dataset parameters - # Video backend to use for training ('decord' or 'torchvision_av') + # Deprecated Isaac-GR00T runner fields below — unused by the LeRobot N1.7 implementation + # (nothing in src/lerobot reads them). They are kept only so config.json files saved by + # earlier lerobot releases still parse: draccus rejects unknown fields, so removing them + # would break every previously saved groot checkpoint at config-load time. video_backend: str = "decord" - - # Whether to balance dataset weights in mixture datasets balance_dataset_weights: bool = True - - # Whether to sample trajectories weighted by their length balance_trajectory_weights: bool = True - - # Optional dataset paths for delegating training to Isaac-GR00T runner dataset_paths: list[str] | None = None output_dir: str = "./tmp/gr00t" save_steps: int = 1000 @@ -405,6 +445,15 @@ class GrootConfig(PreTrainedConfig): resume: bool = False def __post_init__(self): + # 'tokenizer_assets_repo' only ever existed for GR00T N1.5 (lerobot<=0.5.1) and was + # serialized into every groot checkpoint config.json, so a value here means a legacy + # N1.5 checkpoint or config is being loaded. + if self.tokenizer_assets_repo is not None: + raise ValueError( + "Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks " + f"like a legacy GR00T N1.5 checkpoint or config. {GROOT_N1_5_REMOVAL_GUIDANCE}" + ) + self.model_version = normalize_groot_model_version(self.model_version) self.action_decode_transform = normalize_groot_action_decode_transform(self.action_decode_transform) if self.base_model_path is None: @@ -416,26 +465,48 @@ class GrootConfig(PreTrainedConfig): # 'libero_sim' embodiment grasps correctly instead of scoring 0% success. # This matches the embodiment-specific handling already done for the # action execution horizon (see infer_groot_n1_7_action_execution_horizon). - if self.action_decode_transform is None and self.embodiment_tag == "libero_sim": - self.action_decode_transform = GROOT_ACTION_DECODE_TRANSFORM_LIBERO + # Only the 'auto' sentinel resolves to the embodiment default; an explicit + # 'none' (normalized to None above) keeps the transform disabled. + if self.action_decode_transform == GROOT_ACTION_DECODE_TRANSFORM_AUTO: + self.action_decode_transform = ( + GROOT_ACTION_DECODE_TRANSFORM_LIBERO if self.embodiment_tag == "libero_sim" else None + ) - if self.max_state_dim == 64: - self.max_state_dim = 132 - if self.max_action_dim == 32: - self.max_action_dim = 132 - if self.chunk_size == 50: - self.chunk_size = 40 - if self.n_action_steps == 50: - self.n_action_steps = 40 - if tuple(self.image_size) == (224, 224): - self.image_size = (256, 256) + # GR00T N1.5-era default values (e.g. --policy.chunk_size=50 from old commands or + # stale configs) are migrated to the values the N1.7 checkpoints expect, with a + # warning. The dataclass defaults are already the N1.7 values, so a plain + # GrootConfig() never triggers this. + legacy_default_remaps = ( + ("max_state_dim", 64, 132), + ("max_action_dim", 32, 132), + ("chunk_size", 50, 40), + ("n_action_steps", 50, 40), + ("image_size", (224, 224), (256, 256)), + ) + for field_name, legacy_value, n1_7_value in legacy_default_remaps: + current_value = getattr(self, field_name) + if isinstance(legacy_value, tuple): + current_value = tuple(current_value) + if current_value == legacy_value: + logger.warning( + "GrootConfig.%s=%s matches a legacy GR00T N1.5-era default; remapping it to %s, " + "the value expected by GR00T N1.7 checkpoints. Set a different value explicitly " + "if this is not what you want.", + field_name, + legacy_value, + n1_7_value, + ) + setattr(self, field_name, n1_7_value) inferred_version = infer_groot_model_version(self.base_model_path) if inferred_version is not None and inferred_version != self.model_version: - raise ValueError( + message = ( f"GR00T model_version '{self.model_version}' does not match base_model_path " f"'{self.base_model_path}', which looks like '{inferred_version}'." ) + if inferred_version == GROOT_N1_5: + message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}" + raise ValueError(message) super().__post_init__() @@ -511,9 +582,22 @@ class GrootConfig(PreTrainedConfig): @property def action_delta_indices(self) -> list[int]: - """Return indices for delta actions.""" - model_action_horizon = infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40 - return list(range(min(self.chunk_size, model_action_horizon))) + """Return indices for delta actions. + + The model action horizon is read from the checkpoint's processor_config.json + when available; the result is cached (keyed on the inputs that determine it) so + repeated access during dataset/training setup does not re-read from disk. + """ + cache_key = (self.base_model_path, self.embodiment_tag, self.chunk_size) + cached = getattr(self, "_action_delta_indices_cache", None) + if cached is not None and cached[0] == cache_key: + return cached[1] + model_action_horizon = ( + infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40 + ) + indices = list(range(min(self.chunk_size, model_action_horizon))) + object.__setattr__(self, "_action_delta_indices_cache", (cache_key, indices)) + return indices @property def reward_delta_indices(self) -> None: diff --git a/src/lerobot/policies/groot/modeling_groot.py b/src/lerobot/policies/groot/modeling_groot.py index d0587a425..1cd3eb171 100644 --- a/src/lerobot/policies/groot/modeling_groot.py +++ b/src/lerobot/policies/groot/modeling_groot.py @@ -18,15 +18,12 @@ Groot Policy Wrapper for LeRobot Integration Minimal integration that delegates to Isaac-GR00T N1.7 components where -possible without porting their code. - -Notes: -- Dataset loading and full training orchestration is handled by Isaac-GR00T - TrainRunner in their codebase. If you want to invoke that flow end-to-end - from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below. +possible without porting their code. Dataset loading and training +orchestration are handled by LeRobot's standard training stack. """ import builtins +import logging import os from collections import deque from pathlib import Path @@ -42,6 +39,8 @@ from lerobot.utils.import_utils import require_package from ..pretrained import PreTrainedPolicy from ..utils import get_device_from_parameters from .configuration_groot import ( + GROOT_N1_5, + GROOT_N1_5_REMOVAL_GUIDANCE, GROOT_N1_7, GrootConfig, infer_groot_model_version, @@ -50,6 +49,8 @@ from .configuration_groot import ( normalize_groot_model_version, ) +logger = logging.getLogger(__name__) + T = TypeVar("T", bound="GrootPolicy") @@ -92,8 +93,11 @@ class GrootPolicy(PreTrainedPolicy): transformers_loading_kwargs={"trust_remote_code": True}, ) - model.compute_dtype = "bfloat16" if self.config.use_bf16 else model.compute_dtype - model.config.compute_dtype = model.compute_dtype + # GR00TN17 defines no compute_dtype attribute, so only record the + # bf16 preference when it is enabled instead of reading a default back. + if self.config.use_bf16: + model.compute_dtype = "bfloat16" + model.config.compute_dtype = "bfloat16" return model @@ -148,9 +152,10 @@ class GrootPolicy(PreTrainedPolicy): if config is not None else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7 ) - print( - f"The Groot policy is a wrapper around Nvidia's GR00T {requested_version} model.\n" - f"Loading pretrained model from: {pretrained_name_or_path}" + logger.info( + "The Groot policy wraps NVIDIA's GR00T %s model. Loading pretrained model from: %s", + requested_version, + pretrained_name_or_path, ) model_id = str(pretrained_name_or_path) @@ -181,7 +186,7 @@ class GrootPolicy(PreTrainedPolicy): if is_finetuned_checkpoint: # This is a fine-tuned LeRobot checkpoint - use parent class loading - print("Detected fine-tuned LeRobot checkpoint, loading with state dict...") + logger.info("Detected fine-tuned LeRobot checkpoint, loading with state dict...") return super().from_pretrained( pretrained_name_or_path=pretrained_name_or_path, config=config, @@ -197,7 +202,7 @@ class GrootPolicy(PreTrainedPolicy): ) # This is a base GR00T model - load it fresh - print("Detected base GR00T model, loading from HuggingFace...") + logger.info("Detected base GR00T model, loading from HuggingFace...") if config is None: model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7 @@ -229,10 +234,13 @@ class GrootPolicy(PreTrainedPolicy): config.model_version = normalize_groot_model_version(config.model_version) inferred_version = infer_groot_model_version(config.base_model_path) if inferred_version is not None and inferred_version != config.model_version: - raise ValueError( + message = ( f"GR00T model_version '{config.model_version}' does not match base_model_path " f"'{config.base_model_path}', which looks like '{inferred_version}'." ) + if inferred_version == GROOT_N1_5: + message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}" + raise ValueError(message) # Create a fresh policy instance - this will automatically load the GR00T model # in __init__ via _create_groot_model() policy = cls(config) @@ -297,9 +305,7 @@ class GrootPolicy(PreTrainedPolicy): allowed_base.add("action_mask") return { - k: v - for k, v in batch.items() - if k in allowed_base and not (k.startswith("next.") or k == "info") + k: v for k, v in batch.items() if k in allowed_base and not (k.startswith("next.") or k == "info") } def _prepare_n1_7_rtc_inputs( @@ -320,9 +326,7 @@ class GrootPolicy(PreTrainedPolicy): if prev_actions.ndim == 2: prev_actions = prev_actions.unsqueeze(0) elif prev_actions.ndim != 3: - raise ValueError( - "prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC." - ) + raise ValueError("prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC.") state = inputs.get("state") if state is None: @@ -331,9 +335,7 @@ class GrootPolicy(PreTrainedPolicy): if prev_actions.shape[0] == 1 and batch_size > 1: prev_actions = prev_actions.expand(batch_size, -1, -1).clone() elif prev_actions.shape[0] != batch_size: - raise ValueError( - "prev_chunk_left_over batch size must match the current GR00T N1.7 batch size." - ) + raise ValueError("prev_chunk_left_over batch size must match the current GR00T N1.7 batch size.") # The generic LeRobot RTC engine pads short leftovers with exact zero # rows for fixed-shape policy calls. Native GR00T N1.7 RTC treats every @@ -346,7 +348,9 @@ class GrootPolicy(PreTrainedPolicy): else: return inputs, None - model_action_horizon = int(getattr(self._groot_model.config, "action_horizon", self.config.chunk_size)) + model_action_horizon = int( + getattr(self._groot_model.config, "action_horizon", self.config.chunk_size) + ) max_action_dim = int(getattr(self._groot_model.config, "max_action_dim", self.config.max_action_dim)) if prev_actions.shape[1] > model_action_horizon: prev_actions = prev_actions[:, -model_action_horizon:, :] @@ -409,6 +413,11 @@ class GrootPolicy(PreTrainedPolicy): # Isaac-GR00T returns a BatchFeature; loss key is typically 'loss' loss = outputs.get("loss") + if loss is None: + raise RuntimeError( + "GR00T model.forward did not return a 'loss'. Training batches must include " + "'action' and 'action_mask'; check the preprocessor output." + ) loss_dict = {"loss": loss.item()} @@ -471,33 +480,21 @@ class GrootPolicy(PreTrainedPolicy): # Internal helpers # ------------------------- def _handle_flash_attention_compatibility(self) -> None: - """Handle Flash Attention compatibility issues by setting environment variables. + """Log Flash Attention availability (diagnostic only). - This addresses the common 'undefined symbol' error that occurs when Flash Attention - is compiled against a different PyTorch version than what's currently installed. + The GR00T N1.7 backbone automatically falls back to SDPA when ``flash_attn`` is + unavailable (see ``Qwen3Backbone``), so this probe only emits a hint; it does not + change behaviour or mutate global state. """ - - # Set environment variables to handle Flash Attention compatibility - # These help with symbol resolution issues - os.environ.setdefault("FLASH_ATTENTION_FORCE_BUILD", "0") - os.environ.setdefault("FLASH_ATTENTION_SKIP_CUDA_BUILD", "0") - - # Try to import flash_attn and handle failures gracefully try: import flash_attn - print(f"[GROOT] Flash Attention version: {flash_attn.__version__}") - except ImportError as e: - print(f"[GROOT] Flash Attention not available: {e}") - print("[GROOT] Will use fallback attention mechanism") - except Exception as e: - if "undefined symbol" in str(e): - print(f"[GROOT] Flash Attention compatibility issue detected: {e}") - print("[GROOT] This is likely due to PyTorch/Flash Attention version mismatch") - print("[GROOT] Consider reinstalling Flash Attention with compatible version:") - print(" pip uninstall flash-attn") - print(" pip install --no-build-isolation flash-attn==2.6.3") - print("[GROOT] Continuing with fallback attention mechanism") - else: - print(f"[GROOT] Flash Attention error: {e}") - print("[GROOT] Continuing with fallback attention mechanism") + logger.debug("Flash Attention %s is available.", flash_attn.__version__) + except ImportError: + logger.debug("Flash Attention is not installed; the GR00T backbone will use SDPA.") + except Exception as e: # noqa: BLE001 + logger.warning( + "Flash Attention failed to import (%s); the GR00T backbone will use SDPA. If this is " + "an 'undefined symbol' error, reinstall a flash-attn build matching your torch version.", + e, + ) diff --git a/src/lerobot/policies/groot/processor_groot.py b/src/lerobot/policies/groot/processor_groot.py index c3e51b791..95db63a05 100644 --- a/src/lerobot/policies/groot/processor_groot.py +++ b/src/lerobot/policies/groot/processor_groot.py @@ -15,14 +15,16 @@ # limitations under the License. import json +import logging from copy import copy -from dataclasses import dataclass, field +from dataclasses import dataclass, field, fields, is_dataclass from pathlib import Path from typing import TYPE_CHECKING, Any import numpy as np import torch from einops import rearrange +from huggingface_hub import hf_hub_download from PIL import Image from lerobot.utils.import_utils import _transformers_available @@ -59,6 +61,7 @@ from lerobot.utils.constants import ( from .configuration_groot import ( GROOT_ACTION_DECODE_TRANSFORM_LIBERO, + GROOT_N1_5_REMOVAL_GUIDANCE, GROOT_N1_7_BACKBONE_MODEL, GrootConfig, is_raw_groot_n1_7_checkpoint, @@ -80,12 +83,6 @@ N1_7_EMBODIMENT_MAPPING = { "new_embodiment": 10, } -_N1_7_RAW_STATE_CACHE: dict[str, dict[str, np.ndarray]] = {} - - -def _n1_7_state_cache_key(value: str | None) -> str: - return value or "groot_n1_7_default" - @dataclass class _GrootN17CheckpointProcessorAssets: @@ -471,25 +468,98 @@ def _legacy_groot_processor_overrides( env_action_dim = int(config.output_features[ACTION].shape[0]) except Exception: env_action_dim = 0 - action_unpack_overrides = dict(postprocessor_overrides.get("groot_action_unpack_unnormalize_v1", {})) + action_unpack_overrides = dict(postprocessor_overrides.get("groot_action_unpack_unnormalize_v2", {})) action_unpack_overrides["normalize_min_max"] = True action_unpack_overrides["env_action_dim"] = env_action_dim - postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = action_unpack_overrides + postprocessor_overrides["groot_action_unpack_unnormalize_v2"] = action_unpack_overrides return preprocessor_overrides, postprocessor_overrides -def _local_processor_config_has_step(pretrained_path: str, config_filename: str, step_name: str) -> bool: +def _pretrained_processor_config_has_step(pretrained_path: str, config_filename: str, step_name: str) -> bool: + """Check whether a serialized processor pipeline contains a registry step. + + Resolves the processor config from a local directory or, for Hub repo ids, + via ``hf_hub_download`` (which serves the cached copy when offline). Returns + False when the config cannot be resolved; loading then proceeds with the + legacy overrides and `make_groot_pre_post_processors_from_pretrained` retries + without them if they do not match the serialized pipeline. + """ path = Path(pretrained_path).expanduser() - if not path.is_dir(): + if path.is_dir(): + config = _read_json(path / config_filename) + elif path.exists(): return False - config = _read_json(path / config_filename) + else: + try: + config_path = hf_hub_download( + repo_id=str(pretrained_path), filename=config_filename, repo_type="model" + ) + except Exception: + return False + config = _read_json(Path(config_path)) steps = config.get("steps", []) if not isinstance(steps, list): return False return any(isinstance(step, dict) and step.get("registry_name") == step_name for step in steps) +def _apply_groot_step_overrides( + pipeline: PolicyProcessorPipeline, + overrides: dict[str, Any] | None, +) -> None: + """Apply ``from_pretrained``-style step overrides to a freshly built pipeline. + + Raw N1.7 checkpoints build their processors from scratch instead of + deserializing them, so caller overrides must be applied to the constructed + steps. Override keys match a step's registry name or, as a convenience, its + class name (``PolicyProcessorPipeline.from_pretrained`` matches registered + steps by registry name only — prefer registry names so overrides keep + working after the checkpoint is converted and reloaded from a serialized + pipeline). Keys or fields that match nothing raise instead of being dropped + silently. + """ + + if not overrides: + return + + def _step_keys(step: ProcessorStep) -> set[str]: + keys = {type(step).__name__} + registry_name = getattr(type(step), "_registry_name", None) + if registry_name: + keys.add(registry_name) + return keys + + for override_key, step_overrides in overrides.items(): + matched_steps = [step for step in pipeline.steps if override_key in _step_keys(step)] + if not matched_steps: + available = [ + getattr(type(step), "_registry_name", None) or type(step).__name__ for step in pipeline.steps + ] + raise KeyError( + f"Override key '{override_key}' does not match any step of the GR00T processor pipeline " + f"built for this raw N1.7 checkpoint. Available step keys: {available}." + ) + for step in matched_steps: + if not is_dataclass(step): + raise TypeError( + f"Cannot apply overrides to step '{override_key}': it is not a dataclass step." + ) + init_field_names = {f.name for f in fields(step) if f.init} + for field_name, value in dict(step_overrides).items(): + if field_name not in init_field_names: + raise TypeError( + f"Override field '{field_name}' is not a config field of step '{override_key}'. " + f"Available fields: {sorted(init_field_names)}." + ) + setattr(step, field_name, value) + # Re-derive attributes computed from the overridden config (e.g. + # DeviceProcessorStep resolves its torch.device in __post_init__). + post_init = getattr(step, "__post_init__", None) + if callable(post_init): + post_init() + + def make_groot_pre_post_processors_from_pretrained( config: GrootConfig, pretrained_path: str, @@ -508,12 +578,20 @@ def make_groot_pre_post_processors_from_pretrained( if is_raw_groot_n1_7_checkpoint(pretrained_path): processor_cfg = copy(config) processor_cfg.base_model_path = str(pretrained_path) - return make_groot_pre_post_processors( + preprocessor, postprocessor = make_groot_pre_post_processors( config=processor_cfg, dataset_stats=dataset_stats, ) + # Raw checkpoints have no serialized pipelines to load overrides into, + # so apply the caller overrides (e.g. device and rename_map from + # lerobot-eval or the policy server) to the freshly built steps. + _apply_groot_step_overrides(preprocessor, preprocessor_overrides) + _apply_groot_step_overrides(postprocessor, postprocessor_overrides) + return preprocessor, postprocessor - if _local_processor_config_has_step( + caller_preprocessor_overrides = dict(preprocessor_overrides or {}) + caller_postprocessor_overrides = dict(postprocessor_overrides or {}) + if _pretrained_processor_config_has_step( pretrained_path, postprocessor_config_filename, "groot_n1_7_action_decode_v1", @@ -521,15 +599,55 @@ def make_groot_pre_post_processors_from_pretrained( # Converted raw N1.7 checkpoints already carry the checkpoint-specific # action decoder. Adding the legacy action-unpack override would target # a step that is not present and break loading. - preprocessor_overrides = dict(preprocessor_overrides or {}) - postprocessor_overrides = dict(postprocessor_overrides or {}) + applied_legacy_overrides = False + preprocessor_overrides = caller_preprocessor_overrides + postprocessor_overrides = caller_postprocessor_overrides else: + applied_legacy_overrides = True preprocessor_overrides, postprocessor_overrides = _legacy_groot_processor_overrides( config=config, dataset_stats=dataset_stats, preprocessor_overrides=preprocessor_overrides, postprocessor_overrides=postprocessor_overrides, ) + try: + preprocessor, postprocessor = _load_groot_processor_pipelines( + pretrained_path, + preprocessor_overrides=preprocessor_overrides, + postprocessor_overrides=postprocessor_overrides, + preprocessor_config_filename=preprocessor_config_filename, + postprocessor_config_filename=postprocessor_config_filename, + ) + except KeyError: + if not applied_legacy_overrides: + raise + # The legacy overrides target steps that are absent from the serialized + # pipelines (e.g. a converted raw N1.7 checkpoint whose postprocessor + # config could not be inspected before loading); retry with the caller + # overrides only. + preprocessor, postprocessor = _load_groot_processor_pipelines( + pretrained_path, + preprocessor_overrides=caller_preprocessor_overrides, + postprocessor_overrides=caller_postprocessor_overrides, + preprocessor_config_filename=preprocessor_config_filename, + postprocessor_config_filename=postprocessor_config_filename, + ) + _reconnect_groot_relative_absolute_steps(preprocessor, postprocessor) + _reconnect_groot_n1_7_pack_decode_steps(preprocessor, postprocessor) + return preprocessor, postprocessor + + +def _load_groot_processor_pipelines( + pretrained_path: str, + *, + preprocessor_overrides: dict[str, Any], + postprocessor_overrides: dict[str, Any], + preprocessor_config_filename: str, + postprocessor_config_filename: str, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: preprocessor = PolicyProcessorPipeline.from_pretrained( pretrained_model_name_or_path=pretrained_path, config_filename=preprocessor_config_filename, @@ -544,7 +662,6 @@ def make_groot_pre_post_processors_from_pretrained( to_transition=policy_action_to_transition, to_output=transition_to_policy_action, ) - _reconnect_groot_relative_absolute_steps(preprocessor, postprocessor) return preprocessor, postprocessor @@ -564,6 +681,28 @@ def _reconnect_groot_relative_absolute_steps( step.relative_step = relative_step +def _reconnect_groot_n1_7_pack_decode_steps( + preprocessor: PolicyProcessorPipeline, + postprocessor: PolicyProcessorPipeline, +) -> None: + """Re-link a deserialized N1.7 action decode step to its pack step. + + The pack step holds the per-instance raw-state cache that relative-action + decoding reads its reference state from; the link itself is not serialized. + """ + + pack_step = next( + (step for step in preprocessor.steps if isinstance(step, GrootN17PackInputsStep)), + None, + ) + if pack_step is None: + return + + for step in postprocessor.steps: + if isinstance(step, GrootN17ActionDecodeStep) and step.pack_step is None: + step.pack_step = pack_step + + def make_groot_pre_post_processors( config: GrootConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None ) -> tuple[ @@ -607,9 +746,12 @@ def make_groot_pre_post_processors( else action_horizon ) checkpoint_stats = checkpoint_assets.stats if checkpoint_assets is not None else None - padded_stats = checkpoint_stats if _has_modality_stats(checkpoint_stats) else (dataset_stats or {}) + checkpoint_has_stats = _has_modality_stats(checkpoint_stats) + padded_stats = checkpoint_stats if checkpoint_has_stats else (dataset_stats or {}) embodiment_mapping = ( - checkpoint_assets.embodiment_mapping if checkpoint_assets is not None else dict(N1_7_EMBODIMENT_MAPPING) + checkpoint_assets.embodiment_mapping + if checkpoint_assets is not None + else dict(N1_7_EMBODIMENT_MAPPING) ) formalize_language = checkpoint_assets.formalize_language if checkpoint_assets is not None else True clip_outliers = checkpoint_assets.clip_outliers if checkpoint_assets is not None else True @@ -618,7 +760,6 @@ def make_groot_pre_post_processors( env_action_dim = int(config.output_features[ACTION].shape[0]) except Exception: env_action_dim = 0 - state_cache_key = f"groot_n1_7:{config.embodiment_tag}" pack_step = GrootN17PackInputsStep( state_horizon=1, action_horizon=action_horizon, @@ -636,7 +777,6 @@ def make_groot_pre_post_processors( video_modality_keys=video_modality_keys, raw_stats=checkpoint_assets.raw_stats if checkpoint_assets is not None else None, modality_config=checkpoint_assets.modality_config if checkpoint_assets is not None else None, - state_cache_key=state_cache_key, ) input_steps: list[ProcessorStep] = [ @@ -647,14 +787,31 @@ def make_groot_pre_post_processors( model_name=config.n1_7_backbone_model, image_crop_size=checkpoint_assets.image_crop_size if checkpoint_assets is not None else None, image_target_size=checkpoint_assets.image_target_size if checkpoint_assets is not None else None, - shortest_image_edge=checkpoint_assets.shortest_image_edge if checkpoint_assets is not None else None, + shortest_image_edge=checkpoint_assets.shortest_image_edge + if checkpoint_assets is not None + else None, crop_fraction=checkpoint_assets.crop_fraction if checkpoint_assets is not None else None, - use_albumentations=checkpoint_assets.use_albumentations if checkpoint_assets is not None else False, + use_albumentations=checkpoint_assets.use_albumentations + if checkpoint_assets is not None + else False, ), DeviceProcessorStep(device=config.device), ] - if checkpoint_assets is None: + if checkpoint_assets is not None and not checkpoint_has_stats and not _has_modality_stats(padded_stats): + raise ValueError( + f"GR00T N1.7 checkpoint '{config.base_model_path}' has no statistics for embodiment tag " + f"'{config.embodiment_tag}', and no dataset stats were provided to fall back to, so " + "actions cannot be normalized or decoded. Pass dataset_stats, or set " + "config.embodiment_tag to an embodiment present in the checkpoint's statistics.json." + ) + if checkpoint_assets is None or not checkpoint_has_stats: + # When the checkpoint sidecars have no stats for the configured + # embodiment tag (e.g. finetuning a raw base checkpoint with the + # default 'new_embodiment' tag), the pack step above normalized with + # the dataset stats; the decode step must invert with the same stats + # instead of using a checkpoint decoder whose empty stats would + # silently return normalized [-1, 1] actions. action_decode_step: ProcessorStep = GrootActionUnpackUnnormalizeStep( env_action_dim=env_action_dim, stats=padded_stats, @@ -669,7 +826,6 @@ def make_groot_pre_post_processors( use_percentiles=checkpoint_assets.use_percentiles, use_relative_action=checkpoint_assets.use_relative_action, pack_step=pack_step, - state_cache_key=state_cache_key, action_decode_transform=config.action_decode_transform, ) @@ -770,7 +926,7 @@ def _build_n1_7_processor(model_name: str = GROOT_N1_7_BACKBONE_MODEL) -> Proces try: from transformers import ( AutoTokenizer, - Qwen2VLImageProcessorFast, + Qwen2VLImageProcessor, Qwen3VLProcessor, Qwen3VLVideoProcessor, ) @@ -781,7 +937,7 @@ def _build_n1_7_processor(model_name: str = GROOT_N1_7_BACKBONE_MODEL) -> Proces ) from exc tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - image_processor = Qwen2VLImageProcessorFast.from_pretrained(model_name, trust_remote_code=True) + image_processor = Qwen2VLImageProcessor.from_pretrained(model_name, trust_remote_code=True) video_processor = Qwen3VLVideoProcessor.from_pretrained(model_name, trust_remote_code=True) proc = Qwen3VLProcessor( image_processor=image_processor, @@ -902,8 +1058,11 @@ class GrootN17PackInputsStep(ProcessorStep): video_modality_keys: list[str] | None = None raw_stats: dict[str, Any] | None = None modality_config: dict[str, Any] | None = None + # Unused: kept so serialized configs that include it still load. The raw + # state cache is per instance (_last_raw_state), never process-global. state_cache_key: str = "" _last_raw_state: dict[str, np.ndarray] | None = field(default=None, init=False, repr=False) + _warned_image_keys: bool = field(default=False, init=False, repr=False) def _ordered_image_keys(self, obs: dict[str, Any]) -> list[str]: available = {key for key in obs if key.startswith(OBS_IMAGES)} @@ -913,19 +1072,56 @@ class GrootN17PackInputsStep(ProcessorStep): return sorted(available) ordered: list[str] = [] + unmatched: list[str] = [] for modality_key in self.video_modality_keys: candidates = [f"{OBS_IMAGES}.{modality_key}"] + # Alias for datasets converted with generic camera names (e.g. the + # LIBERO conversions expose the wrist camera as + # `observation.images.image2`), so raw N1.7 LIBERO checkpoints + # match those datasets out of the box. if modality_key == "wrist_image": candidates.append(f"{OBS_IMAGES}.image2") - elif modality_key == "image": - candidates.append(f"{OBS_IMAGES}.image") match = next((candidate for candidate in candidates if candidate in available), None) - if match is not None: + if match is None: + unmatched.append(modality_key) + else: ordered.append(match) if not ordered: + if not self._warned_image_keys: + self._warned_image_keys = True + logging.warning( + "None of the GR00T N1.7 checkpoint video modality keys %s match a camera among %s; " + "falling back to feeding all cameras in alphabetical order, which is unlikely to be " + "the layout the checkpoint was trained with. Rename the dataset cameras (e.g. via " + "--rename_map) to match the checkpoint keys.", + self.video_modality_keys, + sorted(available), + ) return sorted(available) + unused = sorted(available - set(ordered)) + if (unmatched or unused) and not self._warned_image_keys: + self._warned_image_keys = True + if unmatched: + logging.warning( + "GR00T N1.7 checkpoint video modality keys %s have no matching camera among %s; " + "the model will receive %d view(s) instead of the %d it was trained with. Rename " + "the dataset cameras (e.g. via --rename_map) to match the checkpoint keys %s.", + unmatched, + sorted(available), + len(ordered), + len(self.video_modality_keys), + self.video_modality_keys, + ) + if unused: + logging.warning( + "Dropping camera(s) %s: the GR00T N1.7 checkpoint only consumes the video modality " + "keys %s, which matched %s.", + unused, + self.video_modality_keys, + ordered, + ) return ordered def __call__(self, transition: EnvTransition) -> EnvTransition: @@ -988,7 +1184,6 @@ class GrootN17PackInputsStep(ProcessorStep): start_idx += dim if grouped: self._last_raw_state = grouped - _N1_7_RAW_STATE_CACHE[_n1_7_state_cache_key(self.state_cache_key)] = grouped img_keys = self._ordered_image_keys(obs) if img_keys: @@ -1101,7 +1296,6 @@ class GrootN17PackInputsStep(ProcessorStep): "video_modality_keys": self.video_modality_keys, "raw_stats": self.raw_stats, "modality_config": self.modality_config, - "state_cache_key": self.state_cache_key, } def get_cached_raw_state(self) -> dict[str, np.ndarray] | None: @@ -1223,6 +1417,7 @@ class GrootN17VLMEncodeStep(ProcessorStep): "use_albumentations": self.use_albumentations, } + def _stat_dim_from_entry(entry: dict[str, Any]) -> int: for stat_name in ("mean", "q01", "min", "max", "std"): value = entry.get(stat_name) @@ -1351,6 +1546,18 @@ class GrootN17ActionDecodeStep(ProcessorStep): group with the checkpoint stats, converts relative groups to absolute values using the raw state cached during packing, concatenates groups in checkpoint order, and finally slices to the environment action dimension. + + Relative-action decoding reads the reference state from the connected + ``pack_step`` (re-linked after ``from_pretrained`` by + ``_reconnect_groot_n1_7_pack_decode_steps``), i.e. the state seen by the + most recent preprocess call. Engines that decode the whole chunk right + after prediction (RTC, async policy server) therefore use the + prediction-time state, matching Isaac-GR00T. The sync per-step queue path + instead decodes each popped (B, D) action against the latest observation: + the reference can be newer than the observation the chunk was predicted + from, and per-timestep relative stats are applied as if the popped action + were chunk step 0. Fixing that would require carrying the reference state + and chunk index alongside each queued action through the postprocessor. """ env_action_dim: int = 0 @@ -1358,6 +1565,7 @@ class GrootN17ActionDecodeStep(ProcessorStep): modality_config: dict[str, Any] | None = None use_percentiles: bool = False use_relative_action: bool = False + # Unused: kept so serialized configs that include it still load. state_cache_key: str = "" action_decode_transform: str | None = None pack_step: GrootN17PackInputsStep | None = field(default=None, repr=False) @@ -1378,6 +1586,12 @@ class GrootN17ActionDecodeStep(ProcessorStep): return transition action_np = action.detach().cpu().float().numpy() + # The sync action queue postprocesses popped actions as (B, D); decode + # them as single-step (B, 1, D) chunks and squeeze the horizon back at + # the end so both ranks share the chunk decode logic below. + squeeze_horizon = action_np.ndim == 2 + if squeeze_horizon: + action_np = action_np[:, None, :] valid_horizon = _n1_7_decode_valid_horizon(action_config, action_np) if valid_horizon is not None: action_np = action_np[:, :valid_horizon] @@ -1405,17 +1619,24 @@ class GrootN17ActionDecodeStep(ProcessorStep): use_relative_action=self.use_relative_action, use_percentiles=self.use_percentiles, ) + # Per-timestep stats carry one row per chunk step; align them with + # the decoded horizon (chunks always start at step 0, and a popped + # (B, D) action is decoded as step 0). + if min_v.ndim == 2 and normalized.shape[1] <= min_v.shape[0]: + min_v = min_v[: normalized.shape[1]] + max_v = max_v[: normalized.shape[1]] decoded_groups[key] = _unnormalize_min_max(normalized, min_v, max_v) start_idx += dim if self.use_relative_action: raw_state = self.pack_step.get_cached_raw_state() if self.pack_step is not None else None - if raw_state is None: - raw_state = _N1_7_RAW_STATE_CACHE.get(_n1_7_state_cache_key(self.state_cache_key)) if raw_state is None: raise RuntimeError( - "GrootN17ActionDecodeStep requires cached raw state from GrootN17PackInputsStep " - "to convert relative N1.7 actions back to absolute actions." + "GrootN17ActionDecodeStep requires the raw state cached by its connected " + "GrootN17PackInputsStep to convert relative N1.7 actions back to absolute actions. " + "Build both pipelines through make_groot_pre_post_processors (or load them together " + "via make_groot_pre_post_processors_from_pretrained) and run the preprocessor on an " + "observation before decoding actions." ) for idx, key in enumerate(action_keys): if not isinstance(key, str) or key not in decoded_groups or idx >= len(action_configs): @@ -1451,6 +1672,8 @@ class GrootN17ActionDecodeStep(ProcessorStep): action_keys=action_keys, decoded_groups=decoded_groups, ) + if squeeze_horizon: + decoded = decoded[:, 0] new_transition = transition.copy() new_transition[TransitionKey.ACTION] = torch.as_tensor( decoded, dtype=action.dtype, device=action.device @@ -1467,13 +1690,15 @@ class GrootN17ActionDecodeStep(ProcessorStep): "modality_config": self.modality_config, "use_percentiles": self.use_percentiles, "use_relative_action": self.use_relative_action, - "state_cache_key": self.state_cache_key, "action_decode_transform": self.action_decode_transform, } @dataclass -@ProcessorStepRegistry.register(name="groot_action_unpack_unnormalize_v1") +# v2: unlike the N1.5-era v1 step, this step no longer collapses (B, T, D) +# action chunks to the last timestep, so old serialized v1 pipelines must not +# silently load into it (v1 is stubbed below with the removal guidance). +@ProcessorStepRegistry.register(name="groot_action_unpack_unnormalize_v2") class GrootActionUnpackUnnormalizeStep(ProcessorStep): env_action_dim: int = 0 # Apply inverse of min-max normalization if it was used in preprocessor @@ -1585,3 +1810,37 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep): if reconstructed: self.stats = reconstructed + + +def _register_removed_n1_5_step_stub(registry_name: str) -> None: + """Register a stub for a processor step that only GR00T N1.5 pipelines serialize. + + Saved N1.5 checkpoints reference these registry names in their processor JSON + files. Deserializing them must fail with the canonical N1.5 removal guidance + instead of an opaque registry KeyError (or, for + ``groot_action_unpack_unnormalize_v1``, silently loading the v2 step whose + action-chunk semantics changed). + """ + + @ProcessorStepRegistry.register(name=registry_name) + class _RemovedGrootN15ProcessorStep(ProcessorStep): + def __init__(self, **_kwargs: Any) -> None: + raise ValueError( + f"Processor step '{registry_name}' belongs to a GR00T N1.5 processor pipeline. " + f"{GROOT_N1_5_REMOVAL_GUIDANCE}" + ) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + raise NotImplementedError + + def transform_features(self, features): + raise NotImplementedError + + +for _removed_n1_5_step_name in ( + "groot_pack_inputs_v3", + "groot_eagle_encode_v3", + "groot_eagle_collate_v3", + "groot_action_unpack_unnormalize_v1", +): + _register_removed_n1_5_step_stub(_removed_n1_5_step_name)