minor fixes

This commit is contained in:
Steven Palma
2026-06-12 18:45:24 +02:00
parent 9ce6633518
commit f81fc79564
4 changed files with 56 additions and 40 deletions
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
import torch
@@ -42,6 +43,9 @@ else:
Timesteps = None
logger = logging.getLogger(__name__)
class TimestepEncoder(nn.Module):
def __init__(self, embedding_dim, compute_dtype=torch.float32):
require_package("diffusers", extra="groot")
@@ -265,8 +269,8 @@ class DiT(ModelMixin, ConfigMixin):
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
print(
"Total number of DiT parameters: ",
logger.debug(
"Total number of DiT parameters: %d",
sum(p.numel() for p in self.parameters() if p.requires_grad),
)
@@ -426,8 +430,8 @@ class SelfAttentionTransformer(ModelMixin, ConfigMixin):
for _ in range(self.config.num_layers)
]
)
print(
"Total number of SelfAttentionTransformer parameters: ",
logger.debug(
"Total number of SelfAttentionTransformer parameters: %d",
sum(p.numel() for p in self.parameters() if p.requires_grad),
)
@@ -352,12 +352,16 @@ class GrootConfig(PreTrainedConfig):
# Maximum action dimension. Shorter actions will be zero-padded.
max_action_dim: int = 132
# Normalization (start with identity, adjust as needed)
# GR00T normalizes state/action internally in its processor steps (min/max with
# q01/q99 percentiles, per embodiment), and the Qwen3-VL backbone's image processor
# handles image normalization. The policy therefore does NOT use LeRobot's
# NormalizerProcessorStep/UnnormalizerProcessorStep, so this mapping is intentionally
# IDENTITY for every feature and is not consulted by make_groot_pre_post_processors.
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
)
@@ -578,11 +582,22 @@ class GrootConfig(PreTrainedConfig):
@property
def action_delta_indices(self) -> list[int]:
"""Return indices for delta actions."""
"""Return indices for delta actions.
The model action horizon is read from the checkpoint's processor_config.json
when available; the result is cached (keyed on the inputs that determine it) so
repeated access during dataset/training setup does not re-read from disk.
"""
cache_key = (self.base_model_path, self.embodiment_tag, self.chunk_size)
cached = getattr(self, "_action_delta_indices_cache", None)
if cached is not None and cached[0] == cache_key:
return cached[1]
model_action_horizon = (
infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
)
return list(range(min(self.chunk_size, model_action_horizon)))
indices = list(range(min(self.chunk_size, model_action_horizon)))
object.__setattr__(self, "_action_delta_indices_cache", (cache_key, indices))
return indices
@property
def reward_delta_indices(self) -> None:
+1 -1
View File
@@ -71,7 +71,7 @@ GR00T_N1_7_DEFAULTS: dict[str, Any] = {
"backbone_embedding_dim": 2048,
"tune_llm": False,
"tune_visual": False,
"select_layer": 12,
"select_layer": 16, # N1.7-3B checkpoint value; real checkpoint loads override this from config.json
"reproject_vision": False,
"use_flash_attention": True,
"load_bf16": False,
+27 -30
View File
@@ -23,6 +23,7 @@ orchestration are handled by LeRobot's standard training stack.
"""
import builtins
import logging
import os
from collections import deque
from pathlib import Path
@@ -48,6 +49,8 @@ from .configuration_groot import (
normalize_groot_model_version,
)
logger = logging.getLogger(__name__)
T = TypeVar("T", bound="GrootPolicy")
@@ -149,9 +152,10 @@ class GrootPolicy(PreTrainedPolicy):
if config is not None
else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
)
print(
f"The Groot policy is a wrapper around Nvidia's GR00T {requested_version} model.\n"
f"Loading pretrained model from: {pretrained_name_or_path}"
logger.info(
"The Groot policy wraps NVIDIA's GR00T %s model. Loading pretrained model from: %s",
requested_version,
pretrained_name_or_path,
)
model_id = str(pretrained_name_or_path)
@@ -182,7 +186,7 @@ class GrootPolicy(PreTrainedPolicy):
if is_finetuned_checkpoint:
# This is a fine-tuned LeRobot checkpoint - use parent class loading
print("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
logger.info("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
return super().from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
config=config,
@@ -198,7 +202,7 @@ class GrootPolicy(PreTrainedPolicy):
)
# This is a base GR00T model - load it fresh
print("Detected base GR00T model, loading from HuggingFace...")
logger.info("Detected base GR00T model, loading from HuggingFace...")
if config is None:
model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
@@ -409,6 +413,11 @@ class GrootPolicy(PreTrainedPolicy):
# Isaac-GR00T returns a BatchFeature; loss key is typically 'loss'
loss = outputs.get("loss")
if loss is None:
raise RuntimeError(
"GR00T model.forward did not return a 'loss'. Training batches must include "
"'action' and 'action_mask'; check the preprocessor output."
)
loss_dict = {"loss": loss.item()}
@@ -471,33 +480,21 @@ class GrootPolicy(PreTrainedPolicy):
# Internal helpers
# -------------------------
def _handle_flash_attention_compatibility(self) -> None:
"""Handle Flash Attention compatibility issues by setting environment variables.
"""Log Flash Attention availability (diagnostic only).
This addresses the common 'undefined symbol' error that occurs when Flash Attention
is compiled against a different PyTorch version than what's currently installed.
The GR00T N1.7 backbone automatically falls back to SDPA when ``flash_attn`` is
unavailable (see ``Qwen3Backbone``), so this probe only emits a hint; it does not
change behaviour or mutate global state.
"""
# Set environment variables to handle Flash Attention compatibility
# These help with symbol resolution issues
os.environ.setdefault("FLASH_ATTENTION_FORCE_BUILD", "0")
os.environ.setdefault("FLASH_ATTENTION_SKIP_CUDA_BUILD", "0")
# Try to import flash_attn and handle failures gracefully
try:
import flash_attn
print(f"[GROOT] Flash Attention version: {flash_attn.__version__}")
except ImportError as e:
print(f"[GROOT] Flash Attention not available: {e}")
print("[GROOT] Will use fallback attention mechanism")
except Exception as e:
if "undefined symbol" in str(e):
print(f"[GROOT] Flash Attention compatibility issue detected: {e}")
print("[GROOT] This is likely due to PyTorch/Flash Attention version mismatch")
print("[GROOT] Consider reinstalling Flash Attention with compatible version:")
print(" pip uninstall flash-attn")
print(" pip install --no-build-isolation flash-attn==2.6.3")
print("[GROOT] Continuing with fallback attention mechanism")
else:
print(f"[GROOT] Flash Attention error: {e}")
print("[GROOT] Continuing with fallback attention mechanism")
logger.debug("Flash Attention %s is available.", flash_attn.__version__)
except ImportError:
logger.debug("Flash Attention is not installed; the GR00T backbone will use SDPA.")
except Exception as e: # noqa: BLE001
logger.warning(
"Flash Attention failed to import (%s); the GR00T backbone will use SDPA. If this is "
"an 'undefined symbol' error, reinstall a flash-attn build matching your torch version.",
e,
)