mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-13 22:39:56 +00:00
fix(groot): N1.7 config defaults, N1.5 rejection, and processor/model runtime fixes
Covers the GR00T N1.7 source trio (configuration, processor, model wrapper). These three files are grouped together because processor_groot and modeling_groot import GROOT_N1_5_REMOVAL_GUIDANCE defined in configuration_groot. Config: - GrootConfig defaults are the N1.7 values; explicitly passed legacy N1.5-era values (chunk_size=50, max_state_dim=64, ...) are remapped with a warning instead of silently. - action_decode_transform gains an 'auto' sentinel so an explicit 'none' opt-out wins over the libero_sim default and survives save/load round-trips. - action_delta_indices is cached on the inputs that determine it. - Legacy N1.5 checkpoints/configs (tokenizer_assets_repo, model_type/ architectures/eagle backbone markers) are rejected with a single clear error pointing to lerobot==0.5.1. Processor: - GrootN17ActionDecodeStep handles the 2-D (B, D) actions delivered by sync select_action (relative eef/non-eef decode in eval/record flows). - Postprocessor falls back to dataset stats when a raw checkpoint lacks the configured embodiment tag; raw-state cache is per-instance, not process-global; caller overrides (device, rename_map) are honored on the raw-checkpoint branch. - Hub-hosted finetuned checkpoints resolve the processor config via hf_hub_download, with a tolerant retry when inspection fails. - Camera/modality-key mismatches warn (including the zero-match fallback); deprecated Qwen2VLImageProcessorFast replaced with Qwen2VLImageProcessor; removed N1.5 processor steps are stubbed to raise the removal guidance and the action-unpack step is re-registered as _v2. Model: - use_bf16=False no longer crashes (compute_dtype only set when used). - Flash-attention probe is diagnostic-only; forward raises on a missing loss; print() replaced with logging; N1.5 base-path mismatch includes the removal guidance. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
@@ -23,15 +24,29 @@ from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTr
|
||||
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GROOT_N1_7 = "n1.7"
|
||||
# Legacy GR00T N1.5 identifier. N1.5 is NOT a supported model_version (it is
|
||||
# intentionally absent from _GROOT_MODEL_VERSION_ALIASES so normalize_groot_model_version
|
||||
# still rejects it). It is retained only so that infer_groot_model_version can recognise
|
||||
# an N1.5 base path/checkpoint and the N1.7 config/loader can reject the mismatch.
|
||||
GROOT_N1_5 = "n1.5"
|
||||
# Canonical guidance appended to every error raised when an N1.5 checkpoint, config,
|
||||
# or processor pipeline is detected. Keep this message in sync with docs/source/groot.mdx.
|
||||
GROOT_N1_5_REMOVAL_GUIDANCE = (
|
||||
"GR00T N1.5 support was removed from LeRobot. "
|
||||
"To keep using an N1.5 checkpoint, pin the last release that supports it: "
|
||||
"`pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 "
|
||||
"(model_version='n1.7', base model nvidia/GR00T-N1.7-3B)."
|
||||
)
|
||||
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
|
||||
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
|
||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
|
||||
# Sentinel meaning "the user did not pick an action decode transform": __post_init__ resolves it
|
||||
# to the embodiment default ('libero' for 'libero_sim', otherwise None). It is distinct from an
|
||||
# explicit 'none' (resolved to None) so an opt-out survives a draccus save/load round-trip.
|
||||
GROOT_ACTION_DECODE_TRANSFORM_AUTO = "auto"
|
||||
|
||||
_GROOT_MODEL_VERSION_ALIASES = {
|
||||
"n1.7": GROOT_N1_7,
|
||||
@@ -41,7 +56,12 @@ _GROOT_MODEL_VERSION_ALIASES = {
|
||||
"1.7": GROOT_N1_7,
|
||||
}
|
||||
|
||||
# Legacy N1.5 spellings, kept ONLY so they can be detected and rejected with
|
||||
# GROOT_N1_5_REMOVAL_GUIDANCE (see GROOT_N1_5 above). Never map these to a supported version.
|
||||
_GROOT_N1_5_VERSION_ALIASES = {"n1.5", "n1_5", "n1d5", "n15", "1.5"}
|
||||
|
||||
_GROOT_ACTION_DECODE_TRANSFORM_ALIASES = {
|
||||
GROOT_ACTION_DECODE_TRANSFORM_AUTO: GROOT_ACTION_DECODE_TRANSFORM_AUTO,
|
||||
"none": None,
|
||||
"": None,
|
||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO: GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
|
||||
@@ -52,9 +72,10 @@ def normalize_groot_model_version(model_version: str) -> str:
|
||||
normalized = _GROOT_MODEL_VERSION_ALIASES.get(model_version.lower())
|
||||
if normalized is None:
|
||||
supported = GROOT_N1_7
|
||||
raise ValueError(
|
||||
f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}."
|
||||
)
|
||||
message = f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}."
|
||||
if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES:
|
||||
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||
raise ValueError(message)
|
||||
return normalized
|
||||
|
||||
|
||||
@@ -286,6 +307,8 @@ def _infer_groot_model_version_from_local_config(model_path: str) -> str | None:
|
||||
def _infer_groot_model_version_from_config(config: dict) -> str | None:
|
||||
model_version = config.get("model_version")
|
||||
if isinstance(model_version, str):
|
||||
if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES:
|
||||
return GROOT_N1_5
|
||||
try:
|
||||
return normalize_groot_model_version(model_version)
|
||||
except ValueError:
|
||||
@@ -298,8 +321,17 @@ def _infer_groot_model_version_from_config(config: dict) -> str | None:
|
||||
normalized = candidate.lower().replace("-", "_")
|
||||
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
|
||||
return GROOT_N1_7
|
||||
# nvidia/GR00T-N1.5-3B ships model_type 'gr00t_n1_5' and architectures ['GR00T_N1_5'].
|
||||
# Recognise them so N1.5 checkpoints at generic local paths are rejected loudly
|
||||
# instead of being silently treated as N1.7 (see infer_groot_model_version).
|
||||
if normalized in {"gr00t_n1_5", "gr00tn1_5", "gr00t_n15", "gr00t_n1d5", "gr00tn1d5"}:
|
||||
return GROOT_N1_5
|
||||
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
|
||||
return GROOT_N1_7
|
||||
# The Eagle VLM backbone is specific to pre-N1.7 GR00T checkpoints (N1.7 uses Cosmos/Qwen3-VL).
|
||||
backbone_cfg = config.get("backbone_cfg")
|
||||
if isinstance(backbone_cfg, dict) and "eagle_path" in backbone_cfg:
|
||||
return GROOT_N1_5
|
||||
return None
|
||||
|
||||
|
||||
@@ -310,27 +342,32 @@ class GrootConfig(PreTrainedConfig):
|
||||
|
||||
# Basic policy settings
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 50
|
||||
n_action_steps: int = 50
|
||||
chunk_size: int = 40
|
||||
n_action_steps: int = 40
|
||||
|
||||
# Dimension settings (must match pretrained GR00T model expectations)
|
||||
# Maximum state dimension. Shorter states will be zero-padded.
|
||||
max_state_dim: int = 64
|
||||
max_state_dim: int = 132
|
||||
|
||||
# Maximum action dimension. Shorter actions will be zero-padded.
|
||||
max_action_dim: int = 32
|
||||
max_action_dim: int = 132
|
||||
|
||||
# Normalization (start with identity, adjust as needed)
|
||||
# GR00T normalizes state/action internally in its processor steps (min/max with
|
||||
# q01/q99 percentiles, per embodiment), and the Qwen3-VL backbone's image processor
|
||||
# handles image normalization. The policy therefore does NOT use LeRobot's
|
||||
# NormalizerProcessorStep/UnnormalizerProcessorStep, so this mapping is intentionally
|
||||
# IDENTITY for every feature and is not consulted by make_groot_pre_post_processors.
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
)
|
||||
|
||||
# Image preprocessing (adjust to match Groot's expected input)
|
||||
image_size: tuple[int, int] = (224, 224)
|
||||
# Deprecated and unused: image sizing is handled by the backbone's image processor.
|
||||
# Kept only so config.json files saved with earlier versions still parse.
|
||||
image_size: tuple[int, int] = (256, 256)
|
||||
|
||||
# Groot-specific model parameters (from groot_finetune_script.py)
|
||||
|
||||
@@ -344,7 +381,14 @@ class GrootConfig(PreTrainedConfig):
|
||||
n1_7_backbone_model: str = GROOT_N1_7_BACKBONE_MODEL
|
||||
|
||||
# Optional named action transform applied after raw N1.7 checkpoint decoding and before env.step().
|
||||
action_decode_transform: str | None = None
|
||||
# 'auto' (default) resolves to the embodiment default ('libero' for 'libero_sim', otherwise no
|
||||
# transform). Pass 'none' to explicitly disable the transform, including for 'libero_sim'.
|
||||
action_decode_transform: str | None = GROOT_ACTION_DECODE_TRANSFORM_AUTO
|
||||
|
||||
# Deprecated, GR00T N1.5 only — do not set. Kept so config.json files saved by lerobot<=0.5.1
|
||||
# still parse (draccus rejects unknown fields) and can be rejected in __post_init__ with a
|
||||
# clear error pointing at GROOT_N1_5_REMOVAL_GUIDANCE instead of a cryptic DecodingError.
|
||||
tokenizer_assets_repo: str | None = None
|
||||
|
||||
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
|
||||
embodiment_tag: str = "new_embodiment"
|
||||
@@ -384,17 +428,13 @@ class GrootConfig(PreTrainedConfig):
|
||||
warmup_ratio: float = 0.05
|
||||
use_bf16: bool = True
|
||||
|
||||
# Dataset parameters
|
||||
# Video backend to use for training ('decord' or 'torchvision_av')
|
||||
# Deprecated Isaac-GR00T runner fields below — unused by the LeRobot N1.7 implementation
|
||||
# (nothing in src/lerobot reads them). They are kept only so config.json files saved by
|
||||
# earlier lerobot releases still parse: draccus rejects unknown fields, so removing them
|
||||
# would break every previously saved groot checkpoint at config-load time.
|
||||
video_backend: str = "decord"
|
||||
|
||||
# Whether to balance dataset weights in mixture datasets
|
||||
balance_dataset_weights: bool = True
|
||||
|
||||
# Whether to sample trajectories weighted by their length
|
||||
balance_trajectory_weights: bool = True
|
||||
|
||||
# Optional dataset paths for delegating training to Isaac-GR00T runner
|
||||
dataset_paths: list[str] | None = None
|
||||
output_dir: str = "./tmp/gr00t"
|
||||
save_steps: int = 1000
|
||||
@@ -405,6 +445,15 @@ class GrootConfig(PreTrainedConfig):
|
||||
resume: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# 'tokenizer_assets_repo' only ever existed for GR00T N1.5 (lerobot<=0.5.1) and was
|
||||
# serialized into every groot checkpoint config.json, so a value here means a legacy
|
||||
# N1.5 checkpoint or config is being loaded.
|
||||
if self.tokenizer_assets_repo is not None:
|
||||
raise ValueError(
|
||||
"Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks "
|
||||
f"like a legacy GR00T N1.5 checkpoint or config. {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||
)
|
||||
|
||||
self.model_version = normalize_groot_model_version(self.model_version)
|
||||
self.action_decode_transform = normalize_groot_action_decode_transform(self.action_decode_transform)
|
||||
if self.base_model_path is None:
|
||||
@@ -416,26 +465,48 @@ class GrootConfig(PreTrainedConfig):
|
||||
# 'libero_sim' embodiment grasps correctly instead of scoring 0% success.
|
||||
# This matches the embodiment-specific handling already done for the
|
||||
# action execution horizon (see infer_groot_n1_7_action_execution_horizon).
|
||||
if self.action_decode_transform is None and self.embodiment_tag == "libero_sim":
|
||||
self.action_decode_transform = GROOT_ACTION_DECODE_TRANSFORM_LIBERO
|
||||
# Only the 'auto' sentinel resolves to the embodiment default; an explicit
|
||||
# 'none' (normalized to None above) keeps the transform disabled.
|
||||
if self.action_decode_transform == GROOT_ACTION_DECODE_TRANSFORM_AUTO:
|
||||
self.action_decode_transform = (
|
||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO if self.embodiment_tag == "libero_sim" else None
|
||||
)
|
||||
|
||||
if self.max_state_dim == 64:
|
||||
self.max_state_dim = 132
|
||||
if self.max_action_dim == 32:
|
||||
self.max_action_dim = 132
|
||||
if self.chunk_size == 50:
|
||||
self.chunk_size = 40
|
||||
if self.n_action_steps == 50:
|
||||
self.n_action_steps = 40
|
||||
if tuple(self.image_size) == (224, 224):
|
||||
self.image_size = (256, 256)
|
||||
# GR00T N1.5-era default values (e.g. --policy.chunk_size=50 from old commands or
|
||||
# stale configs) are migrated to the values the N1.7 checkpoints expect, with a
|
||||
# warning. The dataclass defaults are already the N1.7 values, so a plain
|
||||
# GrootConfig() never triggers this.
|
||||
legacy_default_remaps = (
|
||||
("max_state_dim", 64, 132),
|
||||
("max_action_dim", 32, 132),
|
||||
("chunk_size", 50, 40),
|
||||
("n_action_steps", 50, 40),
|
||||
("image_size", (224, 224), (256, 256)),
|
||||
)
|
||||
for field_name, legacy_value, n1_7_value in legacy_default_remaps:
|
||||
current_value = getattr(self, field_name)
|
||||
if isinstance(legacy_value, tuple):
|
||||
current_value = tuple(current_value)
|
||||
if current_value == legacy_value:
|
||||
logger.warning(
|
||||
"GrootConfig.%s=%s matches a legacy GR00T N1.5-era default; remapping it to %s, "
|
||||
"the value expected by GR00T N1.7 checkpoints. Set a different value explicitly "
|
||||
"if this is not what you want.",
|
||||
field_name,
|
||||
legacy_value,
|
||||
n1_7_value,
|
||||
)
|
||||
setattr(self, field_name, n1_7_value)
|
||||
|
||||
inferred_version = infer_groot_model_version(self.base_model_path)
|
||||
if inferred_version is not None and inferred_version != self.model_version:
|
||||
raise ValueError(
|
||||
message = (
|
||||
f"GR00T model_version '{self.model_version}' does not match base_model_path "
|
||||
f"'{self.base_model_path}', which looks like '{inferred_version}'."
|
||||
)
|
||||
if inferred_version == GROOT_N1_5:
|
||||
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||
raise ValueError(message)
|
||||
|
||||
super().__post_init__()
|
||||
|
||||
@@ -511,9 +582,22 @@ class GrootConfig(PreTrainedConfig):
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
"""Return indices for delta actions."""
|
||||
model_action_horizon = infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
|
||||
return list(range(min(self.chunk_size, model_action_horizon)))
|
||||
"""Return indices for delta actions.
|
||||
|
||||
The model action horizon is read from the checkpoint's processor_config.json
|
||||
when available; the result is cached (keyed on the inputs that determine it) so
|
||||
repeated access during dataset/training setup does not re-read from disk.
|
||||
"""
|
||||
cache_key = (self.base_model_path, self.embodiment_tag, self.chunk_size)
|
||||
cached = getattr(self, "_action_delta_indices_cache", None)
|
||||
if cached is not None and cached[0] == cache_key:
|
||||
return cached[1]
|
||||
model_action_horizon = (
|
||||
infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
|
||||
)
|
||||
indices = list(range(min(self.chunk_size, model_action_horizon)))
|
||||
object.__setattr__(self, "_action_delta_indices_cache", (cache_key, indices))
|
||||
return indices
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
|
||||
@@ -18,15 +18,12 @@
|
||||
Groot Policy Wrapper for LeRobot Integration
|
||||
|
||||
Minimal integration that delegates to Isaac-GR00T N1.7 components where
|
||||
possible without porting their code.
|
||||
|
||||
Notes:
|
||||
- Dataset loading and full training orchestration is handled by Isaac-GR00T
|
||||
TrainRunner in their codebase. If you want to invoke that flow end-to-end
|
||||
from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
|
||||
possible without porting their code. Dataset loading and training
|
||||
orchestration are handled by LeRobot's standard training stack.
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import logging
|
||||
import os
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
@@ -42,6 +39,8 @@ from lerobot.utils.import_utils import require_package
|
||||
from ..pretrained import PreTrainedPolicy
|
||||
from ..utils import get_device_from_parameters
|
||||
from .configuration_groot import (
|
||||
GROOT_N1_5,
|
||||
GROOT_N1_5_REMOVAL_GUIDANCE,
|
||||
GROOT_N1_7,
|
||||
GrootConfig,
|
||||
infer_groot_model_version,
|
||||
@@ -50,6 +49,8 @@ from .configuration_groot import (
|
||||
normalize_groot_model_version,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T", bound="GrootPolicy")
|
||||
|
||||
|
||||
@@ -92,8 +93,11 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
transformers_loading_kwargs={"trust_remote_code": True},
|
||||
)
|
||||
|
||||
model.compute_dtype = "bfloat16" if self.config.use_bf16 else model.compute_dtype
|
||||
model.config.compute_dtype = model.compute_dtype
|
||||
# GR00TN17 defines no compute_dtype attribute, so only record the
|
||||
# bf16 preference when it is enabled instead of reading a default back.
|
||||
if self.config.use_bf16:
|
||||
model.compute_dtype = "bfloat16"
|
||||
model.config.compute_dtype = "bfloat16"
|
||||
|
||||
return model
|
||||
|
||||
@@ -148,9 +152,10 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
if config is not None
|
||||
else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
|
||||
)
|
||||
print(
|
||||
f"The Groot policy is a wrapper around Nvidia's GR00T {requested_version} model.\n"
|
||||
f"Loading pretrained model from: {pretrained_name_or_path}"
|
||||
logger.info(
|
||||
"The Groot policy wraps NVIDIA's GR00T %s model. Loading pretrained model from: %s",
|
||||
requested_version,
|
||||
pretrained_name_or_path,
|
||||
)
|
||||
|
||||
model_id = str(pretrained_name_or_path)
|
||||
@@ -181,7 +186,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
|
||||
if is_finetuned_checkpoint:
|
||||
# This is a fine-tuned LeRobot checkpoint - use parent class loading
|
||||
print("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
|
||||
logger.info("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
|
||||
return super().from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
config=config,
|
||||
@@ -197,7 +202,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
)
|
||||
|
||||
# This is a base GR00T model - load it fresh
|
||||
print("Detected base GR00T model, loading from HuggingFace...")
|
||||
logger.info("Detected base GR00T model, loading from HuggingFace...")
|
||||
|
||||
if config is None:
|
||||
model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
|
||||
@@ -229,10 +234,13 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
config.model_version = normalize_groot_model_version(config.model_version)
|
||||
inferred_version = infer_groot_model_version(config.base_model_path)
|
||||
if inferred_version is not None and inferred_version != config.model_version:
|
||||
raise ValueError(
|
||||
message = (
|
||||
f"GR00T model_version '{config.model_version}' does not match base_model_path "
|
||||
f"'{config.base_model_path}', which looks like '{inferred_version}'."
|
||||
)
|
||||
if inferred_version == GROOT_N1_5:
|
||||
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||
raise ValueError(message)
|
||||
# Create a fresh policy instance - this will automatically load the GR00T model
|
||||
# in __init__ via _create_groot_model()
|
||||
policy = cls(config)
|
||||
@@ -297,9 +305,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
allowed_base.add("action_mask")
|
||||
|
||||
return {
|
||||
k: v
|
||||
for k, v in batch.items()
|
||||
if k in allowed_base and not (k.startswith("next.") or k == "info")
|
||||
k: v for k, v in batch.items() if k in allowed_base and not (k.startswith("next.") or k == "info")
|
||||
}
|
||||
|
||||
def _prepare_n1_7_rtc_inputs(
|
||||
@@ -320,9 +326,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
if prev_actions.ndim == 2:
|
||||
prev_actions = prev_actions.unsqueeze(0)
|
||||
elif prev_actions.ndim != 3:
|
||||
raise ValueError(
|
||||
"prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC."
|
||||
)
|
||||
raise ValueError("prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC.")
|
||||
|
||||
state = inputs.get("state")
|
||||
if state is None:
|
||||
@@ -331,9 +335,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
if prev_actions.shape[0] == 1 and batch_size > 1:
|
||||
prev_actions = prev_actions.expand(batch_size, -1, -1).clone()
|
||||
elif prev_actions.shape[0] != batch_size:
|
||||
raise ValueError(
|
||||
"prev_chunk_left_over batch size must match the current GR00T N1.7 batch size."
|
||||
)
|
||||
raise ValueError("prev_chunk_left_over batch size must match the current GR00T N1.7 batch size.")
|
||||
|
||||
# The generic LeRobot RTC engine pads short leftovers with exact zero
|
||||
# rows for fixed-shape policy calls. Native GR00T N1.7 RTC treats every
|
||||
@@ -346,7 +348,9 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
else:
|
||||
return inputs, None
|
||||
|
||||
model_action_horizon = int(getattr(self._groot_model.config, "action_horizon", self.config.chunk_size))
|
||||
model_action_horizon = int(
|
||||
getattr(self._groot_model.config, "action_horizon", self.config.chunk_size)
|
||||
)
|
||||
max_action_dim = int(getattr(self._groot_model.config, "max_action_dim", self.config.max_action_dim))
|
||||
if prev_actions.shape[1] > model_action_horizon:
|
||||
prev_actions = prev_actions[:, -model_action_horizon:, :]
|
||||
@@ -409,6 +413,11 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
|
||||
# Isaac-GR00T returns a BatchFeature; loss key is typically 'loss'
|
||||
loss = outputs.get("loss")
|
||||
if loss is None:
|
||||
raise RuntimeError(
|
||||
"GR00T model.forward did not return a 'loss'. Training batches must include "
|
||||
"'action' and 'action_mask'; check the preprocessor output."
|
||||
)
|
||||
|
||||
loss_dict = {"loss": loss.item()}
|
||||
|
||||
@@ -471,33 +480,21 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
# Internal helpers
|
||||
# -------------------------
|
||||
def _handle_flash_attention_compatibility(self) -> None:
|
||||
"""Handle Flash Attention compatibility issues by setting environment variables.
|
||||
"""Log Flash Attention availability (diagnostic only).
|
||||
|
||||
This addresses the common 'undefined symbol' error that occurs when Flash Attention
|
||||
is compiled against a different PyTorch version than what's currently installed.
|
||||
The GR00T N1.7 backbone automatically falls back to SDPA when ``flash_attn`` is
|
||||
unavailable (see ``Qwen3Backbone``), so this probe only emits a hint; it does not
|
||||
change behaviour or mutate global state.
|
||||
"""
|
||||
|
||||
# Set environment variables to handle Flash Attention compatibility
|
||||
# These help with symbol resolution issues
|
||||
os.environ.setdefault("FLASH_ATTENTION_FORCE_BUILD", "0")
|
||||
os.environ.setdefault("FLASH_ATTENTION_SKIP_CUDA_BUILD", "0")
|
||||
|
||||
# Try to import flash_attn and handle failures gracefully
|
||||
try:
|
||||
import flash_attn
|
||||
|
||||
print(f"[GROOT] Flash Attention version: {flash_attn.__version__}")
|
||||
except ImportError as e:
|
||||
print(f"[GROOT] Flash Attention not available: {e}")
|
||||
print("[GROOT] Will use fallback attention mechanism")
|
||||
except Exception as e:
|
||||
if "undefined symbol" in str(e):
|
||||
print(f"[GROOT] Flash Attention compatibility issue detected: {e}")
|
||||
print("[GROOT] This is likely due to PyTorch/Flash Attention version mismatch")
|
||||
print("[GROOT] Consider reinstalling Flash Attention with compatible version:")
|
||||
print(" pip uninstall flash-attn")
|
||||
print(" pip install --no-build-isolation flash-attn==2.6.3")
|
||||
print("[GROOT] Continuing with fallback attention mechanism")
|
||||
else:
|
||||
print(f"[GROOT] Flash Attention error: {e}")
|
||||
print("[GROOT] Continuing with fallback attention mechanism")
|
||||
logger.debug("Flash Attention %s is available.", flash_attn.__version__)
|
||||
except ImportError:
|
||||
logger.debug("Flash Attention is not installed; the GR00T backbone will use SDPA.")
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning(
|
||||
"Flash Attention failed to import (%s); the GR00T backbone will use SDPA. If this is "
|
||||
"an 'undefined symbol' error, reinstall a flash-attn build matching your torch version.",
|
||||
e,
|
||||
)
|
||||
|
||||
@@ -15,14 +15,16 @@
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from copy import copy
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass, field, fields, is_dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
@@ -59,6 +61,7 @@ from lerobot.utils.constants import (
|
||||
|
||||
from .configuration_groot import (
|
||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
|
||||
GROOT_N1_5_REMOVAL_GUIDANCE,
|
||||
GROOT_N1_7_BACKBONE_MODEL,
|
||||
GrootConfig,
|
||||
is_raw_groot_n1_7_checkpoint,
|
||||
@@ -80,12 +83,6 @@ N1_7_EMBODIMENT_MAPPING = {
|
||||
"new_embodiment": 10,
|
||||
}
|
||||
|
||||
_N1_7_RAW_STATE_CACHE: dict[str, dict[str, np.ndarray]] = {}
|
||||
|
||||
|
||||
def _n1_7_state_cache_key(value: str | None) -> str:
|
||||
return value or "groot_n1_7_default"
|
||||
|
||||
|
||||
@dataclass
|
||||
class _GrootN17CheckpointProcessorAssets:
|
||||
@@ -471,25 +468,98 @@ def _legacy_groot_processor_overrides(
|
||||
env_action_dim = int(config.output_features[ACTION].shape[0])
|
||||
except Exception:
|
||||
env_action_dim = 0
|
||||
action_unpack_overrides = dict(postprocessor_overrides.get("groot_action_unpack_unnormalize_v1", {}))
|
||||
action_unpack_overrides = dict(postprocessor_overrides.get("groot_action_unpack_unnormalize_v2", {}))
|
||||
action_unpack_overrides["normalize_min_max"] = True
|
||||
action_unpack_overrides["env_action_dim"] = env_action_dim
|
||||
postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = action_unpack_overrides
|
||||
postprocessor_overrides["groot_action_unpack_unnormalize_v2"] = action_unpack_overrides
|
||||
|
||||
return preprocessor_overrides, postprocessor_overrides
|
||||
|
||||
|
||||
def _local_processor_config_has_step(pretrained_path: str, config_filename: str, step_name: str) -> bool:
|
||||
def _pretrained_processor_config_has_step(pretrained_path: str, config_filename: str, step_name: str) -> bool:
|
||||
"""Check whether a serialized processor pipeline contains a registry step.
|
||||
|
||||
Resolves the processor config from a local directory or, for Hub repo ids,
|
||||
via ``hf_hub_download`` (which serves the cached copy when offline). Returns
|
||||
False when the config cannot be resolved; loading then proceeds with the
|
||||
legacy overrides and `make_groot_pre_post_processors_from_pretrained` retries
|
||||
without them if they do not match the serialized pipeline.
|
||||
"""
|
||||
path = Path(pretrained_path).expanduser()
|
||||
if not path.is_dir():
|
||||
if path.is_dir():
|
||||
config = _read_json(path / config_filename)
|
||||
elif path.exists():
|
||||
return False
|
||||
config = _read_json(path / config_filename)
|
||||
else:
|
||||
try:
|
||||
config_path = hf_hub_download(
|
||||
repo_id=str(pretrained_path), filename=config_filename, repo_type="model"
|
||||
)
|
||||
except Exception:
|
||||
return False
|
||||
config = _read_json(Path(config_path))
|
||||
steps = config.get("steps", [])
|
||||
if not isinstance(steps, list):
|
||||
return False
|
||||
return any(isinstance(step, dict) and step.get("registry_name") == step_name for step in steps)
|
||||
|
||||
|
||||
def _apply_groot_step_overrides(
|
||||
pipeline: PolicyProcessorPipeline,
|
||||
overrides: dict[str, Any] | None,
|
||||
) -> None:
|
||||
"""Apply ``from_pretrained``-style step overrides to a freshly built pipeline.
|
||||
|
||||
Raw N1.7 checkpoints build their processors from scratch instead of
|
||||
deserializing them, so caller overrides must be applied to the constructed
|
||||
steps. Override keys match a step's registry name or, as a convenience, its
|
||||
class name (``PolicyProcessorPipeline.from_pretrained`` matches registered
|
||||
steps by registry name only — prefer registry names so overrides keep
|
||||
working after the checkpoint is converted and reloaded from a serialized
|
||||
pipeline). Keys or fields that match nothing raise instead of being dropped
|
||||
silently.
|
||||
"""
|
||||
|
||||
if not overrides:
|
||||
return
|
||||
|
||||
def _step_keys(step: ProcessorStep) -> set[str]:
|
||||
keys = {type(step).__name__}
|
||||
registry_name = getattr(type(step), "_registry_name", None)
|
||||
if registry_name:
|
||||
keys.add(registry_name)
|
||||
return keys
|
||||
|
||||
for override_key, step_overrides in overrides.items():
|
||||
matched_steps = [step for step in pipeline.steps if override_key in _step_keys(step)]
|
||||
if not matched_steps:
|
||||
available = [
|
||||
getattr(type(step), "_registry_name", None) or type(step).__name__ for step in pipeline.steps
|
||||
]
|
||||
raise KeyError(
|
||||
f"Override key '{override_key}' does not match any step of the GR00T processor pipeline "
|
||||
f"built for this raw N1.7 checkpoint. Available step keys: {available}."
|
||||
)
|
||||
for step in matched_steps:
|
||||
if not is_dataclass(step):
|
||||
raise TypeError(
|
||||
f"Cannot apply overrides to step '{override_key}': it is not a dataclass step."
|
||||
)
|
||||
init_field_names = {f.name for f in fields(step) if f.init}
|
||||
for field_name, value in dict(step_overrides).items():
|
||||
if field_name not in init_field_names:
|
||||
raise TypeError(
|
||||
f"Override field '{field_name}' is not a config field of step '{override_key}'. "
|
||||
f"Available fields: {sorted(init_field_names)}."
|
||||
)
|
||||
setattr(step, field_name, value)
|
||||
# Re-derive attributes computed from the overridden config (e.g.
|
||||
# DeviceProcessorStep resolves its torch.device in __post_init__).
|
||||
post_init = getattr(step, "__post_init__", None)
|
||||
if callable(post_init):
|
||||
post_init()
|
||||
|
||||
|
||||
def make_groot_pre_post_processors_from_pretrained(
|
||||
config: GrootConfig,
|
||||
pretrained_path: str,
|
||||
@@ -508,12 +578,20 @@ def make_groot_pre_post_processors_from_pretrained(
|
||||
if is_raw_groot_n1_7_checkpoint(pretrained_path):
|
||||
processor_cfg = copy(config)
|
||||
processor_cfg.base_model_path = str(pretrained_path)
|
||||
return make_groot_pre_post_processors(
|
||||
preprocessor, postprocessor = make_groot_pre_post_processors(
|
||||
config=processor_cfg,
|
||||
dataset_stats=dataset_stats,
|
||||
)
|
||||
# Raw checkpoints have no serialized pipelines to load overrides into,
|
||||
# so apply the caller overrides (e.g. device and rename_map from
|
||||
# lerobot-eval or the policy server) to the freshly built steps.
|
||||
_apply_groot_step_overrides(preprocessor, preprocessor_overrides)
|
||||
_apply_groot_step_overrides(postprocessor, postprocessor_overrides)
|
||||
return preprocessor, postprocessor
|
||||
|
||||
if _local_processor_config_has_step(
|
||||
caller_preprocessor_overrides = dict(preprocessor_overrides or {})
|
||||
caller_postprocessor_overrides = dict(postprocessor_overrides or {})
|
||||
if _pretrained_processor_config_has_step(
|
||||
pretrained_path,
|
||||
postprocessor_config_filename,
|
||||
"groot_n1_7_action_decode_v1",
|
||||
@@ -521,15 +599,55 @@ def make_groot_pre_post_processors_from_pretrained(
|
||||
# Converted raw N1.7 checkpoints already carry the checkpoint-specific
|
||||
# action decoder. Adding the legacy action-unpack override would target
|
||||
# a step that is not present and break loading.
|
||||
preprocessor_overrides = dict(preprocessor_overrides or {})
|
||||
postprocessor_overrides = dict(postprocessor_overrides or {})
|
||||
applied_legacy_overrides = False
|
||||
preprocessor_overrides = caller_preprocessor_overrides
|
||||
postprocessor_overrides = caller_postprocessor_overrides
|
||||
else:
|
||||
applied_legacy_overrides = True
|
||||
preprocessor_overrides, postprocessor_overrides = _legacy_groot_processor_overrides(
|
||||
config=config,
|
||||
dataset_stats=dataset_stats,
|
||||
preprocessor_overrides=preprocessor_overrides,
|
||||
postprocessor_overrides=postprocessor_overrides,
|
||||
)
|
||||
try:
|
||||
preprocessor, postprocessor = _load_groot_processor_pipelines(
|
||||
pretrained_path,
|
||||
preprocessor_overrides=preprocessor_overrides,
|
||||
postprocessor_overrides=postprocessor_overrides,
|
||||
preprocessor_config_filename=preprocessor_config_filename,
|
||||
postprocessor_config_filename=postprocessor_config_filename,
|
||||
)
|
||||
except KeyError:
|
||||
if not applied_legacy_overrides:
|
||||
raise
|
||||
# The legacy overrides target steps that are absent from the serialized
|
||||
# pipelines (e.g. a converted raw N1.7 checkpoint whose postprocessor
|
||||
# config could not be inspected before loading); retry with the caller
|
||||
# overrides only.
|
||||
preprocessor, postprocessor = _load_groot_processor_pipelines(
|
||||
pretrained_path,
|
||||
preprocessor_overrides=caller_preprocessor_overrides,
|
||||
postprocessor_overrides=caller_postprocessor_overrides,
|
||||
preprocessor_config_filename=preprocessor_config_filename,
|
||||
postprocessor_config_filename=postprocessor_config_filename,
|
||||
)
|
||||
_reconnect_groot_relative_absolute_steps(preprocessor, postprocessor)
|
||||
_reconnect_groot_n1_7_pack_decode_steps(preprocessor, postprocessor)
|
||||
return preprocessor, postprocessor
|
||||
|
||||
|
||||
def _load_groot_processor_pipelines(
|
||||
pretrained_path: str,
|
||||
*,
|
||||
preprocessor_overrides: dict[str, Any],
|
||||
postprocessor_overrides: dict[str, Any],
|
||||
preprocessor_config_filename: str,
|
||||
postprocessor_config_filename: str,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
preprocessor = PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
config_filename=preprocessor_config_filename,
|
||||
@@ -544,7 +662,6 @@ def make_groot_pre_post_processors_from_pretrained(
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
)
|
||||
_reconnect_groot_relative_absolute_steps(preprocessor, postprocessor)
|
||||
return preprocessor, postprocessor
|
||||
|
||||
|
||||
@@ -564,6 +681,28 @@ def _reconnect_groot_relative_absolute_steps(
|
||||
step.relative_step = relative_step
|
||||
|
||||
|
||||
def _reconnect_groot_n1_7_pack_decode_steps(
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
) -> None:
|
||||
"""Re-link a deserialized N1.7 action decode step to its pack step.
|
||||
|
||||
The pack step holds the per-instance raw-state cache that relative-action
|
||||
decoding reads its reference state from; the link itself is not serialized.
|
||||
"""
|
||||
|
||||
pack_step = next(
|
||||
(step for step in preprocessor.steps if isinstance(step, GrootN17PackInputsStep)),
|
||||
None,
|
||||
)
|
||||
if pack_step is None:
|
||||
return
|
||||
|
||||
for step in postprocessor.steps:
|
||||
if isinstance(step, GrootN17ActionDecodeStep) and step.pack_step is None:
|
||||
step.pack_step = pack_step
|
||||
|
||||
|
||||
def make_groot_pre_post_processors(
|
||||
config: GrootConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
|
||||
) -> tuple[
|
||||
@@ -607,9 +746,12 @@ def make_groot_pre_post_processors(
|
||||
else action_horizon
|
||||
)
|
||||
checkpoint_stats = checkpoint_assets.stats if checkpoint_assets is not None else None
|
||||
padded_stats = checkpoint_stats if _has_modality_stats(checkpoint_stats) else (dataset_stats or {})
|
||||
checkpoint_has_stats = _has_modality_stats(checkpoint_stats)
|
||||
padded_stats = checkpoint_stats if checkpoint_has_stats else (dataset_stats or {})
|
||||
embodiment_mapping = (
|
||||
checkpoint_assets.embodiment_mapping if checkpoint_assets is not None else dict(N1_7_EMBODIMENT_MAPPING)
|
||||
checkpoint_assets.embodiment_mapping
|
||||
if checkpoint_assets is not None
|
||||
else dict(N1_7_EMBODIMENT_MAPPING)
|
||||
)
|
||||
formalize_language = checkpoint_assets.formalize_language if checkpoint_assets is not None else True
|
||||
clip_outliers = checkpoint_assets.clip_outliers if checkpoint_assets is not None else True
|
||||
@@ -618,7 +760,6 @@ def make_groot_pre_post_processors(
|
||||
env_action_dim = int(config.output_features[ACTION].shape[0])
|
||||
except Exception:
|
||||
env_action_dim = 0
|
||||
state_cache_key = f"groot_n1_7:{config.embodiment_tag}"
|
||||
pack_step = GrootN17PackInputsStep(
|
||||
state_horizon=1,
|
||||
action_horizon=action_horizon,
|
||||
@@ -636,7 +777,6 @@ def make_groot_pre_post_processors(
|
||||
video_modality_keys=video_modality_keys,
|
||||
raw_stats=checkpoint_assets.raw_stats if checkpoint_assets is not None else None,
|
||||
modality_config=checkpoint_assets.modality_config if checkpoint_assets is not None else None,
|
||||
state_cache_key=state_cache_key,
|
||||
)
|
||||
|
||||
input_steps: list[ProcessorStep] = [
|
||||
@@ -647,14 +787,31 @@ def make_groot_pre_post_processors(
|
||||
model_name=config.n1_7_backbone_model,
|
||||
image_crop_size=checkpoint_assets.image_crop_size if checkpoint_assets is not None else None,
|
||||
image_target_size=checkpoint_assets.image_target_size if checkpoint_assets is not None else None,
|
||||
shortest_image_edge=checkpoint_assets.shortest_image_edge if checkpoint_assets is not None else None,
|
||||
shortest_image_edge=checkpoint_assets.shortest_image_edge
|
||||
if checkpoint_assets is not None
|
||||
else None,
|
||||
crop_fraction=checkpoint_assets.crop_fraction if checkpoint_assets is not None else None,
|
||||
use_albumentations=checkpoint_assets.use_albumentations if checkpoint_assets is not None else False,
|
||||
use_albumentations=checkpoint_assets.use_albumentations
|
||||
if checkpoint_assets is not None
|
||||
else False,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
if checkpoint_assets is None:
|
||||
if checkpoint_assets is not None and not checkpoint_has_stats and not _has_modality_stats(padded_stats):
|
||||
raise ValueError(
|
||||
f"GR00T N1.7 checkpoint '{config.base_model_path}' has no statistics for embodiment tag "
|
||||
f"'{config.embodiment_tag}', and no dataset stats were provided to fall back to, so "
|
||||
"actions cannot be normalized or decoded. Pass dataset_stats, or set "
|
||||
"config.embodiment_tag to an embodiment present in the checkpoint's statistics.json."
|
||||
)
|
||||
if checkpoint_assets is None or not checkpoint_has_stats:
|
||||
# When the checkpoint sidecars have no stats for the configured
|
||||
# embodiment tag (e.g. finetuning a raw base checkpoint with the
|
||||
# default 'new_embodiment' tag), the pack step above normalized with
|
||||
# the dataset stats; the decode step must invert with the same stats
|
||||
# instead of using a checkpoint decoder whose empty stats would
|
||||
# silently return normalized [-1, 1] actions.
|
||||
action_decode_step: ProcessorStep = GrootActionUnpackUnnormalizeStep(
|
||||
env_action_dim=env_action_dim,
|
||||
stats=padded_stats,
|
||||
@@ -669,7 +826,6 @@ def make_groot_pre_post_processors(
|
||||
use_percentiles=checkpoint_assets.use_percentiles,
|
||||
use_relative_action=checkpoint_assets.use_relative_action,
|
||||
pack_step=pack_step,
|
||||
state_cache_key=state_cache_key,
|
||||
action_decode_transform=config.action_decode_transform,
|
||||
)
|
||||
|
||||
@@ -770,7 +926,7 @@ def _build_n1_7_processor(model_name: str = GROOT_N1_7_BACKBONE_MODEL) -> Proces
|
||||
try:
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
Qwen2VLImageProcessorFast,
|
||||
Qwen2VLImageProcessor,
|
||||
Qwen3VLProcessor,
|
||||
Qwen3VLVideoProcessor,
|
||||
)
|
||||
@@ -781,7 +937,7 @@ def _build_n1_7_processor(model_name: str = GROOT_N1_7_BACKBONE_MODEL) -> Proces
|
||||
) from exc
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||
image_processor = Qwen2VLImageProcessorFast.from_pretrained(model_name, trust_remote_code=True)
|
||||
image_processor = Qwen2VLImageProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||
video_processor = Qwen3VLVideoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||
proc = Qwen3VLProcessor(
|
||||
image_processor=image_processor,
|
||||
@@ -902,8 +1058,11 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
video_modality_keys: list[str] | None = None
|
||||
raw_stats: dict[str, Any] | None = None
|
||||
modality_config: dict[str, Any] | None = None
|
||||
# Unused: kept so serialized configs that include it still load. The raw
|
||||
# state cache is per instance (_last_raw_state), never process-global.
|
||||
state_cache_key: str = ""
|
||||
_last_raw_state: dict[str, np.ndarray] | None = field(default=None, init=False, repr=False)
|
||||
_warned_image_keys: bool = field(default=False, init=False, repr=False)
|
||||
|
||||
def _ordered_image_keys(self, obs: dict[str, Any]) -> list[str]:
|
||||
available = {key for key in obs if key.startswith(OBS_IMAGES)}
|
||||
@@ -913,19 +1072,56 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
return sorted(available)
|
||||
|
||||
ordered: list[str] = []
|
||||
unmatched: list[str] = []
|
||||
for modality_key in self.video_modality_keys:
|
||||
candidates = [f"{OBS_IMAGES}.{modality_key}"]
|
||||
# Alias for datasets converted with generic camera names (e.g. the
|
||||
# LIBERO conversions expose the wrist camera as
|
||||
# `observation.images.image2`), so raw N1.7 LIBERO checkpoints
|
||||
# match those datasets out of the box.
|
||||
if modality_key == "wrist_image":
|
||||
candidates.append(f"{OBS_IMAGES}.image2")
|
||||
elif modality_key == "image":
|
||||
candidates.append(f"{OBS_IMAGES}.image")
|
||||
|
||||
match = next((candidate for candidate in candidates if candidate in available), None)
|
||||
if match is not None:
|
||||
if match is None:
|
||||
unmatched.append(modality_key)
|
||||
else:
|
||||
ordered.append(match)
|
||||
|
||||
if not ordered:
|
||||
if not self._warned_image_keys:
|
||||
self._warned_image_keys = True
|
||||
logging.warning(
|
||||
"None of the GR00T N1.7 checkpoint video modality keys %s match a camera among %s; "
|
||||
"falling back to feeding all cameras in alphabetical order, which is unlikely to be "
|
||||
"the layout the checkpoint was trained with. Rename the dataset cameras (e.g. via "
|
||||
"--rename_map) to match the checkpoint keys.",
|
||||
self.video_modality_keys,
|
||||
sorted(available),
|
||||
)
|
||||
return sorted(available)
|
||||
unused = sorted(available - set(ordered))
|
||||
if (unmatched or unused) and not self._warned_image_keys:
|
||||
self._warned_image_keys = True
|
||||
if unmatched:
|
||||
logging.warning(
|
||||
"GR00T N1.7 checkpoint video modality keys %s have no matching camera among %s; "
|
||||
"the model will receive %d view(s) instead of the %d it was trained with. Rename "
|
||||
"the dataset cameras (e.g. via --rename_map) to match the checkpoint keys %s.",
|
||||
unmatched,
|
||||
sorted(available),
|
||||
len(ordered),
|
||||
len(self.video_modality_keys),
|
||||
self.video_modality_keys,
|
||||
)
|
||||
if unused:
|
||||
logging.warning(
|
||||
"Dropping camera(s) %s: the GR00T N1.7 checkpoint only consumes the video modality "
|
||||
"keys %s, which matched %s.",
|
||||
unused,
|
||||
self.video_modality_keys,
|
||||
ordered,
|
||||
)
|
||||
return ordered
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
@@ -988,7 +1184,6 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
start_idx += dim
|
||||
if grouped:
|
||||
self._last_raw_state = grouped
|
||||
_N1_7_RAW_STATE_CACHE[_n1_7_state_cache_key(self.state_cache_key)] = grouped
|
||||
|
||||
img_keys = self._ordered_image_keys(obs)
|
||||
if img_keys:
|
||||
@@ -1101,7 +1296,6 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
"video_modality_keys": self.video_modality_keys,
|
||||
"raw_stats": self.raw_stats,
|
||||
"modality_config": self.modality_config,
|
||||
"state_cache_key": self.state_cache_key,
|
||||
}
|
||||
|
||||
def get_cached_raw_state(self) -> dict[str, np.ndarray] | None:
|
||||
@@ -1223,6 +1417,7 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
"use_albumentations": self.use_albumentations,
|
||||
}
|
||||
|
||||
|
||||
def _stat_dim_from_entry(entry: dict[str, Any]) -> int:
|
||||
for stat_name in ("mean", "q01", "min", "max", "std"):
|
||||
value = entry.get(stat_name)
|
||||
@@ -1351,6 +1546,18 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
||||
group with the checkpoint stats, converts relative groups to absolute values
|
||||
using the raw state cached during packing, concatenates groups in checkpoint
|
||||
order, and finally slices to the environment action dimension.
|
||||
|
||||
Relative-action decoding reads the reference state from the connected
|
||||
``pack_step`` (re-linked after ``from_pretrained`` by
|
||||
``_reconnect_groot_n1_7_pack_decode_steps``), i.e. the state seen by the
|
||||
most recent preprocess call. Engines that decode the whole chunk right
|
||||
after prediction (RTC, async policy server) therefore use the
|
||||
prediction-time state, matching Isaac-GR00T. The sync per-step queue path
|
||||
instead decodes each popped (B, D) action against the latest observation:
|
||||
the reference can be newer than the observation the chunk was predicted
|
||||
from, and per-timestep relative stats are applied as if the popped action
|
||||
were chunk step 0. Fixing that would require carrying the reference state
|
||||
and chunk index alongside each queued action through the postprocessor.
|
||||
"""
|
||||
|
||||
env_action_dim: int = 0
|
||||
@@ -1358,6 +1565,7 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
||||
modality_config: dict[str, Any] | None = None
|
||||
use_percentiles: bool = False
|
||||
use_relative_action: bool = False
|
||||
# Unused: kept so serialized configs that include it still load.
|
||||
state_cache_key: str = ""
|
||||
action_decode_transform: str | None = None
|
||||
pack_step: GrootN17PackInputsStep | None = field(default=None, repr=False)
|
||||
@@ -1378,6 +1586,12 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
||||
return transition
|
||||
|
||||
action_np = action.detach().cpu().float().numpy()
|
||||
# The sync action queue postprocesses popped actions as (B, D); decode
|
||||
# them as single-step (B, 1, D) chunks and squeeze the horizon back at
|
||||
# the end so both ranks share the chunk decode logic below.
|
||||
squeeze_horizon = action_np.ndim == 2
|
||||
if squeeze_horizon:
|
||||
action_np = action_np[:, None, :]
|
||||
valid_horizon = _n1_7_decode_valid_horizon(action_config, action_np)
|
||||
if valid_horizon is not None:
|
||||
action_np = action_np[:, :valid_horizon]
|
||||
@@ -1405,17 +1619,24 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
||||
use_relative_action=self.use_relative_action,
|
||||
use_percentiles=self.use_percentiles,
|
||||
)
|
||||
# Per-timestep stats carry one row per chunk step; align them with
|
||||
# the decoded horizon (chunks always start at step 0, and a popped
|
||||
# (B, D) action is decoded as step 0).
|
||||
if min_v.ndim == 2 and normalized.shape[1] <= min_v.shape[0]:
|
||||
min_v = min_v[: normalized.shape[1]]
|
||||
max_v = max_v[: normalized.shape[1]]
|
||||
decoded_groups[key] = _unnormalize_min_max(normalized, min_v, max_v)
|
||||
start_idx += dim
|
||||
|
||||
if self.use_relative_action:
|
||||
raw_state = self.pack_step.get_cached_raw_state() if self.pack_step is not None else None
|
||||
if raw_state is None:
|
||||
raw_state = _N1_7_RAW_STATE_CACHE.get(_n1_7_state_cache_key(self.state_cache_key))
|
||||
if raw_state is None:
|
||||
raise RuntimeError(
|
||||
"GrootN17ActionDecodeStep requires cached raw state from GrootN17PackInputsStep "
|
||||
"to convert relative N1.7 actions back to absolute actions."
|
||||
"GrootN17ActionDecodeStep requires the raw state cached by its connected "
|
||||
"GrootN17PackInputsStep to convert relative N1.7 actions back to absolute actions. "
|
||||
"Build both pipelines through make_groot_pre_post_processors (or load them together "
|
||||
"via make_groot_pre_post_processors_from_pretrained) and run the preprocessor on an "
|
||||
"observation before decoding actions."
|
||||
)
|
||||
for idx, key in enumerate(action_keys):
|
||||
if not isinstance(key, str) or key not in decoded_groups or idx >= len(action_configs):
|
||||
@@ -1451,6 +1672,8 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
||||
action_keys=action_keys,
|
||||
decoded_groups=decoded_groups,
|
||||
)
|
||||
if squeeze_horizon:
|
||||
decoded = decoded[:, 0]
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.ACTION] = torch.as_tensor(
|
||||
decoded, dtype=action.dtype, device=action.device
|
||||
@@ -1467,13 +1690,15 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
||||
"modality_config": self.modality_config,
|
||||
"use_percentiles": self.use_percentiles,
|
||||
"use_relative_action": self.use_relative_action,
|
||||
"state_cache_key": self.state_cache_key,
|
||||
"action_decode_transform": self.action_decode_transform,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="groot_action_unpack_unnormalize_v1")
|
||||
# v2: unlike the N1.5-era v1 step, this step no longer collapses (B, T, D)
|
||||
# action chunks to the last timestep, so old serialized v1 pipelines must not
|
||||
# silently load into it (v1 is stubbed below with the removal guidance).
|
||||
@ProcessorStepRegistry.register(name="groot_action_unpack_unnormalize_v2")
|
||||
class GrootActionUnpackUnnormalizeStep(ProcessorStep):
|
||||
env_action_dim: int = 0
|
||||
# Apply inverse of min-max normalization if it was used in preprocessor
|
||||
@@ -1585,3 +1810,37 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep):
|
||||
|
||||
if reconstructed:
|
||||
self.stats = reconstructed
|
||||
|
||||
|
||||
def _register_removed_n1_5_step_stub(registry_name: str) -> None:
|
||||
"""Register a stub for a processor step that only GR00T N1.5 pipelines serialize.
|
||||
|
||||
Saved N1.5 checkpoints reference these registry names in their processor JSON
|
||||
files. Deserializing them must fail with the canonical N1.5 removal guidance
|
||||
instead of an opaque registry KeyError (or, for
|
||||
``groot_action_unpack_unnormalize_v1``, silently loading the v2 step whose
|
||||
action-chunk semantics changed).
|
||||
"""
|
||||
|
||||
@ProcessorStepRegistry.register(name=registry_name)
|
||||
class _RemovedGrootN15ProcessorStep(ProcessorStep):
|
||||
def __init__(self, **_kwargs: Any) -> None:
|
||||
raise ValueError(
|
||||
f"Processor step '{registry_name}' belongs to a GR00T N1.5 processor pipeline. "
|
||||
f"{GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||
)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
raise NotImplementedError
|
||||
|
||||
def transform_features(self, features):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
for _removed_n1_5_step_name in (
|
||||
"groot_pack_inputs_v3",
|
||||
"groot_eagle_encode_v3",
|
||||
"groot_eagle_collate_v3",
|
||||
"groot_action_unpack_unnormalize_v1",
|
||||
):
|
||||
_register_removed_n1_5_step_stub(_removed_n1_5_step_name)
|
||||
|
||||
Reference in New Issue
Block a user