diff --git a/src/lerobot/policies/groot/action_head/cross_attention_dit.py b/src/lerobot/policies/groot/action_head/cross_attention_dit.py index 0991ef029..7b531d92e 100755 --- a/src/lerobot/policies/groot/action_head/cross_attention_dit.py +++ b/src/lerobot/policies/groot/action_head/cross_attention_dit.py @@ -14,6 +14,7 @@ # limitations under the License. +import logging from typing import TYPE_CHECKING import torch @@ -42,6 +43,9 @@ else: Timesteps = None +logger = logging.getLogger(__name__) + + class TimestepEncoder(nn.Module): def __init__(self, embedding_dim, compute_dtype=torch.float32): require_package("diffusers", extra="groot") @@ -265,8 +269,8 @@ class DiT(ModelMixin, ConfigMixin): self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim) self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim) - print( - "Total number of DiT parameters: ", + logger.debug( + "Total number of DiT parameters: %d", sum(p.numel() for p in self.parameters() if p.requires_grad), ) @@ -426,8 +430,8 @@ class SelfAttentionTransformer(ModelMixin, ConfigMixin): for _ in range(self.config.num_layers) ] ) - print( - "Total number of SelfAttentionTransformer parameters: ", + logger.debug( + "Total number of SelfAttentionTransformer parameters: %d", sum(p.numel() for p in self.parameters() if p.requires_grad), ) diff --git a/src/lerobot/policies/groot/configuration_groot.py b/src/lerobot/policies/groot/configuration_groot.py index f54683d6d..8a234c55a 100644 --- a/src/lerobot/policies/groot/configuration_groot.py +++ b/src/lerobot/policies/groot/configuration_groot.py @@ -352,12 +352,16 @@ class GrootConfig(PreTrainedConfig): # Maximum action dimension. Shorter actions will be zero-padded. max_action_dim: int = 132 - # Normalization (start with identity, adjust as needed) + # GR00T normalizes state/action internally in its processor steps (min/max with + # q01/q99 percentiles, per embodiment), and the Qwen3-VL backbone's image processor + # handles image normalization. The policy therefore does NOT use LeRobot's + # NormalizerProcessorStep/UnnormalizerProcessorStep, so this mapping is intentionally + # IDENTITY for every feature and is not consulted by make_groot_pre_post_processors. normalization_mapping: dict[str, NormalizationMode] = field( default_factory=lambda: { "VISUAL": NormalizationMode.IDENTITY, - "STATE": NormalizationMode.MEAN_STD, - "ACTION": NormalizationMode.MEAN_STD, + "STATE": NormalizationMode.IDENTITY, + "ACTION": NormalizationMode.IDENTITY, } ) @@ -578,11 +582,22 @@ class GrootConfig(PreTrainedConfig): @property def action_delta_indices(self) -> list[int]: - """Return indices for delta actions.""" + """Return indices for delta actions. + + The model action horizon is read from the checkpoint's processor_config.json + when available; the result is cached (keyed on the inputs that determine it) so + repeated access during dataset/training setup does not re-read from disk. + """ + cache_key = (self.base_model_path, self.embodiment_tag, self.chunk_size) + cached = getattr(self, "_action_delta_indices_cache", None) + if cached is not None and cached[0] == cache_key: + return cached[1] model_action_horizon = ( infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40 ) - return list(range(min(self.chunk_size, model_action_horizon))) + indices = list(range(min(self.chunk_size, model_action_horizon))) + object.__setattr__(self, "_action_delta_indices_cache", (cache_key, indices)) + return indices @property def reward_delta_indices(self) -> None: diff --git a/src/lerobot/policies/groot/groot_n1_7.py b/src/lerobot/policies/groot/groot_n1_7.py index f30b078a4..ef31b5b25 100644 --- a/src/lerobot/policies/groot/groot_n1_7.py +++ b/src/lerobot/policies/groot/groot_n1_7.py @@ -71,7 +71,7 @@ GR00T_N1_7_DEFAULTS: dict[str, Any] = { "backbone_embedding_dim": 2048, "tune_llm": False, "tune_visual": False, - "select_layer": 12, + "select_layer": 16, # N1.7-3B checkpoint value; real checkpoint loads override this from config.json "reproject_vision": False, "use_flash_attention": True, "load_bf16": False, diff --git a/src/lerobot/policies/groot/modeling_groot.py b/src/lerobot/policies/groot/modeling_groot.py index d2804d827..1cd3eb171 100644 --- a/src/lerobot/policies/groot/modeling_groot.py +++ b/src/lerobot/policies/groot/modeling_groot.py @@ -23,6 +23,7 @@ orchestration are handled by LeRobot's standard training stack. """ import builtins +import logging import os from collections import deque from pathlib import Path @@ -48,6 +49,8 @@ from .configuration_groot import ( normalize_groot_model_version, ) +logger = logging.getLogger(__name__) + T = TypeVar("T", bound="GrootPolicy") @@ -149,9 +152,10 @@ class GrootPolicy(PreTrainedPolicy): if config is not None else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7 ) - print( - f"The Groot policy is a wrapper around Nvidia's GR00T {requested_version} model.\n" - f"Loading pretrained model from: {pretrained_name_or_path}" + logger.info( + "The Groot policy wraps NVIDIA's GR00T %s model. Loading pretrained model from: %s", + requested_version, + pretrained_name_or_path, ) model_id = str(pretrained_name_or_path) @@ -182,7 +186,7 @@ class GrootPolicy(PreTrainedPolicy): if is_finetuned_checkpoint: # This is a fine-tuned LeRobot checkpoint - use parent class loading - print("Detected fine-tuned LeRobot checkpoint, loading with state dict...") + logger.info("Detected fine-tuned LeRobot checkpoint, loading with state dict...") return super().from_pretrained( pretrained_name_or_path=pretrained_name_or_path, config=config, @@ -198,7 +202,7 @@ class GrootPolicy(PreTrainedPolicy): ) # This is a base GR00T model - load it fresh - print("Detected base GR00T model, loading from HuggingFace...") + logger.info("Detected base GR00T model, loading from HuggingFace...") if config is None: model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7 @@ -409,6 +413,11 @@ class GrootPolicy(PreTrainedPolicy): # Isaac-GR00T returns a BatchFeature; loss key is typically 'loss' loss = outputs.get("loss") + if loss is None: + raise RuntimeError( + "GR00T model.forward did not return a 'loss'. Training batches must include " + "'action' and 'action_mask'; check the preprocessor output." + ) loss_dict = {"loss": loss.item()} @@ -471,33 +480,21 @@ class GrootPolicy(PreTrainedPolicy): # Internal helpers # ------------------------- def _handle_flash_attention_compatibility(self) -> None: - """Handle Flash Attention compatibility issues by setting environment variables. + """Log Flash Attention availability (diagnostic only). - This addresses the common 'undefined symbol' error that occurs when Flash Attention - is compiled against a different PyTorch version than what's currently installed. + The GR00T N1.7 backbone automatically falls back to SDPA when ``flash_attn`` is + unavailable (see ``Qwen3Backbone``), so this probe only emits a hint; it does not + change behaviour or mutate global state. """ - - # Set environment variables to handle Flash Attention compatibility - # These help with symbol resolution issues - os.environ.setdefault("FLASH_ATTENTION_FORCE_BUILD", "0") - os.environ.setdefault("FLASH_ATTENTION_SKIP_CUDA_BUILD", "0") - - # Try to import flash_attn and handle failures gracefully try: import flash_attn - print(f"[GROOT] Flash Attention version: {flash_attn.__version__}") - except ImportError as e: - print(f"[GROOT] Flash Attention not available: {e}") - print("[GROOT] Will use fallback attention mechanism") - except Exception as e: - if "undefined symbol" in str(e): - print(f"[GROOT] Flash Attention compatibility issue detected: {e}") - print("[GROOT] This is likely due to PyTorch/Flash Attention version mismatch") - print("[GROOT] Consider reinstalling Flash Attention with compatible version:") - print(" pip uninstall flash-attn") - print(" pip install --no-build-isolation flash-attn==2.6.3") - print("[GROOT] Continuing with fallback attention mechanism") - else: - print(f"[GROOT] Flash Attention error: {e}") - print("[GROOT] Continuing with fallback attention mechanism") + logger.debug("Flash Attention %s is available.", flash_attn.__version__) + except ImportError: + logger.debug("Flash Attention is not installed; the GR00T backbone will use SDPA.") + except Exception as e: # noqa: BLE001 + logger.warning( + "Flash Attention failed to import (%s); the GR00T backbone will use SDPA. If this is " + "an 'undefined symbol' error, reinstall a flash-attn build matching your torch version.", + e, + )