Compare commits

..

2 Commits

Author SHA1 Message Date
Steven Palma 378897800a fix(groot): skip normalization overrides for training 2026-06-13 19:51:29 +02:00
Steven Palma fcb371eddd fix(groot): N1.7 config defaults, N1.5 rejection, and processor/model runtime fixes
Covers the GR00T N1.7 source trio (configuration, processor, model wrapper).

Config:
- GrootConfig defaults are the N1.7 values; explicitly passed legacy N1.5-era
  values (chunk_size=50, max_state_dim=64, ...) are remapped with a warning
  instead of silently.
- action_decode_transform gains an 'auto' sentinel so an explicit 'none'
  opt-out wins over the libero_sim default and survives save/load round-trips.
- action_delta_indices is cached on the inputs that determine it.
- Legacy N1.5 checkpoints/configs (tokenizer_assets_repo, model_type/
  architectures/eagle backbone markers) are rejected with a single clear
  error pointing to lerobot==0.5.1.

Processor:
- GrootN17ActionDecodeStep handles the 2-D (B, D) actions delivered by sync
  select_action (relative eef/non-eef decode in eval/record flows).
- Postprocessor falls back to dataset stats when a raw checkpoint lacks the
  configured embodiment tag; raw-state cache is per-instance, not
  process-global; caller overrides (device, rename_map) are honored on the
  raw-checkpoint branch.
- Camera/modality-key mismatches warn (including the zero-match fallback);
  deprecated Qwen2VLImageProcessorFast replaced with Qwen2VLImageProcessor;
  removed N1.5 processor steps are stubbed to raise the removal guidance and
  the action-unpack step is re-registered as _v2.

Model:
- Flash-attention probe is diagnostic-only; forward raises on a missing loss;
  print() replaced with logging; N1.5 base-path mismatch includes the
  removal guidance.
2026-06-13 18:30:21 +02:00
3 changed files with 52 additions and 136 deletions
@@ -321,9 +321,6 @@ def _infer_groot_model_version_from_config(config: dict) -> str | None:
normalized = candidate.lower().replace("-", "_")
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
return GROOT_N1_7
# nvidia/GR00T-N1.5-3B ships model_type 'gr00t_n1_5' and architectures ['GR00T_N1_5'].
# Recognise them so N1.5 checkpoints at generic local paths are rejected loudly
# instead of being silently treated as N1.7 (see infer_groot_model_version).
if normalized in {"gr00t_n1_5", "gr00tn1_5", "gr00t_n15", "gr00t_n1d5", "gr00tn1d5"}:
return GROOT_N1_5
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
@@ -365,11 +362,7 @@ class GrootConfig(PreTrainedConfig):
}
)
# Deprecated and unused: image sizing is handled by the backbone's image processor.
# Kept only so config.json files saved with earlier versions still parse.
image_size: tuple[int, int] = (256, 256)
# Groot-specific model parameters (from groot_finetune_script.py)
# Groot-specific model parameters
# Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only.
model_version: str = GROOT_N1_7
@@ -385,11 +378,6 @@ class GrootConfig(PreTrainedConfig):
# transform). Pass 'none' to explicitly disable the transform, including for 'libero_sim'.
action_decode_transform: str | None = GROOT_ACTION_DECODE_TRANSFORM_AUTO
# Deprecated, GR00T N1.5 only — do not set. Kept so config.json files saved by lerobot<=0.5.1
# still parse (draccus rejects unknown fields) and can be rejected in __post_init__ with a
# clear error pointing at GROOT_N1_5_REMOVAL_GUIDANCE instead of a cryptic DecodingError.
tokenizer_assets_repo: str | None = None
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
embodiment_tag: str = "new_embodiment"
@@ -428,10 +416,13 @@ class GrootConfig(PreTrainedConfig):
warmup_ratio: float = 0.05
use_bf16: bool = True
# Deprecated Isaac-GR00T runner fields below — unused by the LeRobot N1.7 implementation
# TODO(Steven): Remove these deprecated fields in a future release.
# Deprecated Isaac-GR00T runner/N1.5 fields below — unused by the LeRobot N1.7 implementation
# (nothing in src/lerobot reads them). They are kept only so config.json files saved by
# earlier lerobot releases still parse: draccus rejects unknown fields, so removing them
# would break every previously saved groot checkpoint at config-load time.
image_size: tuple[int, int] = (256, 256) # image sizing is handled by the backbone's image processor.
tokenizer_assets_repo: str | None = None
video_backend: str = "decord"
balance_dataset_weights: bool = True
balance_trajectory_weights: bool = True
@@ -445,9 +436,6 @@ class GrootConfig(PreTrainedConfig):
resume: bool = False
def __post_init__(self):
# 'tokenizer_assets_repo' only ever existed for GR00T N1.5 (lerobot<=0.5.1) and was
# serialized into every groot checkpoint config.json, so a value here means a legacy
# N1.5 checkpoint or config is being loaded.
if self.tokenizer_assets_repo is not None:
raise ValueError(
"Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks "
@@ -582,22 +570,11 @@ class GrootConfig(PreTrainedConfig):
@property
def action_delta_indices(self) -> list[int]:
"""Return indices for delta actions.
The model action horizon is read from the checkpoint's processor_config.json
when available; the result is cached (keyed on the inputs that determine it) so
repeated access during dataset/training setup does not re-read from disk.
"""
cache_key = (self.base_model_path, self.embodiment_tag, self.chunk_size)
cached = getattr(self, "_action_delta_indices_cache", None)
if cached is not None and cached[0] == cache_key:
return cached[1]
"""Return indices for delta actions."""
model_action_horizon = (
infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
)
indices = list(range(min(self.chunk_size, model_action_horizon)))
object.__setattr__(self, "_action_delta_indices_cache", (cache_key, indices))
return indices
return list(range(min(self.chunk_size, model_action_horizon)))
@property
def reward_delta_indices(self) -> None:
@@ -93,12 +93,6 @@ class GrootPolicy(PreTrainedPolicy):
transformers_loading_kwargs={"trust_remote_code": True},
)
# GR00TN17 defines no compute_dtype attribute, so only record the
# bf16 preference when it is enabled instead of reading a default back.
if self.config.use_bf16:
model.compute_dtype = "bfloat16"
model.config.compute_dtype = "bfloat16"
return model
def reset(self):
+45 -100
View File
@@ -24,7 +24,6 @@ from typing import TYPE_CHECKING, Any
import numpy as np
import torch
from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image
from lerobot.utils.import_utils import _transformers_available
@@ -448,60 +447,40 @@ def _has_modality_stats(stats: dict[str, dict[str, Any]] | None) -> bool:
return any(bool(modality_stats) for modality_stats in stats.values())
def _legacy_groot_processor_overrides(
config: GrootConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None,
preprocessor_overrides: dict[str, Any] | None = None,
postprocessor_overrides: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Patch older serialized Groot processors with fields current processors expect."""
preprocessor_overrides = dict(preprocessor_overrides or {})
postprocessor_overrides = dict(postprocessor_overrides or {})
pack_inputs_key = "groot_n1_7_pack_inputs_v1"
pack_input_overrides = dict(preprocessor_overrides.get(pack_inputs_key, {}))
pack_input_overrides["normalize_min_max"] = True
preprocessor_overrides[pack_inputs_key] = pack_input_overrides
try:
env_action_dim = int(config.output_features[ACTION].shape[0])
except Exception:
env_action_dim = 0
action_unpack_overrides = dict(postprocessor_overrides.get("groot_action_unpack_unnormalize_v2", {}))
action_unpack_overrides["normalize_min_max"] = True
action_unpack_overrides["env_action_dim"] = env_action_dim
postprocessor_overrides["groot_action_unpack_unnormalize_v2"] = action_unpack_overrides
return preprocessor_overrides, postprocessor_overrides
# GR00T normalizes state/action inside its own processor steps and so deliberately has no
# NormalizerProcessorStep/UnnormalizerProcessorStep (see GrootConfig.normalization_mapping, which is
# IDENTITY for every feature). lerobot-train nonetheless emits these standard override keys
# unconditionally, so for a GR00T pipeline they legitimately match no step. They are dropped up front
# by _drop_groot_absent_standard_overrides so they neither break loading nor mask genuine typos.
_GROOT_ABSENT_STANDARD_OVERRIDE_KEYS = frozenset({"normalizer_processor", "unnormalizer_processor"})
def _pretrained_processor_config_has_step(pretrained_path: str, config_filename: str, step_name: str) -> bool:
"""Check whether a serialized processor pipeline contains a registry step.
def _drop_groot_absent_standard_overrides(overrides: dict[str, Any] | None) -> dict[str, Any] | None:
"""Strip standard normalization override keys that a GR00T pipeline has no step for.
Resolves the processor config from a local directory or, for Hub repo ids,
via ``hf_hub_download`` (which serves the cached copy when offline). Returns
False when the config cannot be resolved; loading then proceeds with the
legacy overrides and `make_groot_pre_post_processors_from_pretrained` retries
without them if they do not match the serialized pipeline.
``lerobot-train`` emits ``normalizer_processor``/``unnormalizer_processor`` overrides
unconditionally, but GR00T normalizes inside its own steps and has no such step (see
``GrootConfig.normalization_mapping``). Both override-application paths reject keys that match no
step — ``_apply_groot_step_overrides`` raises for the freshly built raw-checkpoint pipeline, and
``PolicyProcessorPipeline.from_pretrained`` raises via its used-override validation for the
serialized pipeline — so these keys are removed before either path runs. Any other unknown key
(e.g. a typo) is left in place and still raises.
"""
path = Path(pretrained_path).expanduser()
if path.is_dir():
config = _read_json(path / config_filename)
elif path.exists():
return False
else:
try:
config_path = hf_hub_download(
repo_id=str(pretrained_path), filename=config_filename, repo_type="model"
if not overrides:
return overrides
filtered: dict[str, Any] = {}
for key, value in overrides.items():
if key in _GROOT_ABSENT_STANDARD_OVERRIDE_KEYS:
logging.debug(
"Ignoring override key '%s': GR00T normalizes inside its own processor steps and has "
"no matching step (see GrootConfig.normalization_mapping).",
key,
)
except Exception:
return False
config = _read_json(Path(config_path))
steps = config.get("steps", [])
if not isinstance(steps, list):
return False
return any(isinstance(step, dict) and step.get("registry_name") == step_name for step in steps)
continue
filtered[key] = value
return filtered
def _apply_groot_step_overrides(
@@ -517,7 +496,8 @@ def _apply_groot_step_overrides(
steps by registry name only — prefer registry names so overrides keep
working after the checkpoint is converted and reloaded from a serialized
pipeline). Keys or fields that match nothing raise instead of being dropped
silently.
silently (standard normalization keys GR00T has no step for are removed
beforehand by ``_drop_groot_absent_standard_overrides``).
"""
if not overrides:
@@ -573,7 +553,13 @@ def make_groot_pre_post_processors_from_pretrained(
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Load Groot processors while preserving compatibility with older serialized configs."""
"""Load Groot processors for a raw N1.7 checkpoint or a serialized LeRobot pipeline."""
# Drop the standard normalizer/unnormalizer override keys lerobot-train emits unconditionally:
# GR00T has no such steps, so they would make both the raw-checkpoint and serialized override
# paths raise. This must happen before either branch below.
preprocessor_overrides = _drop_groot_absent_standard_overrides(preprocessor_overrides)
postprocessor_overrides = _drop_groot_absent_standard_overrides(postprocessor_overrides)
if is_raw_groot_n1_7_checkpoint(pretrained_path):
processor_cfg = copy(config)
@@ -589,49 +575,13 @@ def make_groot_pre_post_processors_from_pretrained(
_apply_groot_step_overrides(postprocessor, postprocessor_overrides)
return preprocessor, postprocessor
caller_preprocessor_overrides = dict(preprocessor_overrides or {})
caller_postprocessor_overrides = dict(postprocessor_overrides or {})
if _pretrained_processor_config_has_step(
preprocessor, postprocessor = _load_groot_processor_pipelines(
pretrained_path,
postprocessor_config_filename,
"groot_n1_7_action_decode_v1",
):
# Converted raw N1.7 checkpoints already carry the checkpoint-specific
# action decoder. Adding the legacy action-unpack override would target
# a step that is not present and break loading.
applied_legacy_overrides = False
preprocessor_overrides = caller_preprocessor_overrides
postprocessor_overrides = caller_postprocessor_overrides
else:
applied_legacy_overrides = True
preprocessor_overrides, postprocessor_overrides = _legacy_groot_processor_overrides(
config=config,
dataset_stats=dataset_stats,
preprocessor_overrides=preprocessor_overrides,
postprocessor_overrides=postprocessor_overrides,
)
try:
preprocessor, postprocessor = _load_groot_processor_pipelines(
pretrained_path,
preprocessor_overrides=preprocessor_overrides,
postprocessor_overrides=postprocessor_overrides,
preprocessor_config_filename=preprocessor_config_filename,
postprocessor_config_filename=postprocessor_config_filename,
)
except KeyError:
if not applied_legacy_overrides:
raise
# The legacy overrides target steps that are absent from the serialized
# pipelines (e.g. a converted raw N1.7 checkpoint whose postprocessor
# config could not be inspected before loading); retry with the caller
# overrides only.
preprocessor, postprocessor = _load_groot_processor_pipelines(
pretrained_path,
preprocessor_overrides=caller_preprocessor_overrides,
postprocessor_overrides=caller_postprocessor_overrides,
preprocessor_config_filename=preprocessor_config_filename,
postprocessor_config_filename=postprocessor_config_filename,
)
preprocessor_overrides=preprocessor_overrides,
postprocessor_overrides=postprocessor_overrides,
preprocessor_config_filename=preprocessor_config_filename,
postprocessor_config_filename=postprocessor_config_filename,
)
_reconnect_groot_relative_absolute_steps(preprocessor, postprocessor)
_reconnect_groot_n1_7_pack_decode_steps(preprocessor, postprocessor)
return preprocessor, postprocessor
@@ -1058,9 +1008,6 @@ class GrootN17PackInputsStep(ProcessorStep):
video_modality_keys: list[str] | None = None
raw_stats: dict[str, Any] | None = None
modality_config: dict[str, Any] | None = None
# Unused: kept so serialized configs that include it still load. The raw
# state cache is per instance (_last_raw_state), never process-global.
state_cache_key: str = ""
_last_raw_state: dict[str, np.ndarray] | None = field(default=None, init=False, repr=False)
_warned_image_keys: bool = field(default=False, init=False, repr=False)
@@ -1565,8 +1512,6 @@ class GrootN17ActionDecodeStep(ProcessorStep):
modality_config: dict[str, Any] | None = None
use_percentiles: bool = False
use_relative_action: bool = False
# Unused: kept so serialized configs that include it still load.
state_cache_key: str = ""
action_decode_transform: str | None = None
pack_step: GrootN17PackInputsStep | None = field(default=None, repr=False)
@@ -1694,10 +1639,10 @@ class GrootN17ActionDecodeStep(ProcessorStep):
}
@dataclass
# v2: unlike the N1.5-era v1 step, this step no longer collapses (B, T, D)
# action chunks to the last timestep, so old serialized v1 pipelines must not
# silently load into it (v1 is stubbed below with the removal guidance).
@dataclass
@ProcessorStepRegistry.register(name="groot_action_unpack_unnormalize_v2")
class GrootActionUnpackUnnormalizeStep(ProcessorStep):
env_action_dim: int = 0