diff --git a/src/lerobot/policies/groot/action_head/cross_attention_dit.py b/src/lerobot/policies/groot/action_head/cross_attention_dit.py index 0991ef029..7b531d92e 100755 --- a/src/lerobot/policies/groot/action_head/cross_attention_dit.py +++ b/src/lerobot/policies/groot/action_head/cross_attention_dit.py @@ -14,6 +14,7 @@ # limitations under the License. +import logging from typing import TYPE_CHECKING import torch @@ -42,6 +43,9 @@ else: Timesteps = None +logger = logging.getLogger(__name__) + + class TimestepEncoder(nn.Module): def __init__(self, embedding_dim, compute_dtype=torch.float32): require_package("diffusers", extra="groot") @@ -265,8 +269,8 @@ class DiT(ModelMixin, ConfigMixin): self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim) self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim) - print( - "Total number of DiT parameters: ", + logger.debug( + "Total number of DiT parameters: %d", sum(p.numel() for p in self.parameters() if p.requires_grad), ) @@ -426,8 +430,8 @@ class SelfAttentionTransformer(ModelMixin, ConfigMixin): for _ in range(self.config.num_layers) ] ) - print( - "Total number of SelfAttentionTransformer parameters: ", + logger.debug( + "Total number of SelfAttentionTransformer parameters: %d", sum(p.numel() for p in self.parameters() if p.requires_grad), ) diff --git a/src/lerobot/policies/groot/groot_n1_7.py b/src/lerobot/policies/groot/groot_n1_7.py index 103c03a58..e062e2c5c 100644 --- a/src/lerobot/policies/groot/groot_n1_7.py +++ b/src/lerobot/policies/groot/groot_n1_7.py @@ -71,7 +71,7 @@ GR00T_N1_7_DEFAULTS: dict[str, Any] = { "backbone_embedding_dim": 2048, "tune_llm": False, "tune_visual": False, - "select_layer": 12, + "select_layer": 16, "reproject_vision": False, "use_flash_attention": True, "load_bf16": False, @@ -819,11 +819,14 @@ def _cosmos_reason2_qwen3_vl_config() -> PretrainedConfig: def get_backbone_cls(config: GR00TN17Config): - if ( - config.backbone_model_type == "qwen" - or "nvidia/Cosmos-Reason2" in config.model_name - or "Qwen/Qwen3-VL" in config.model_name - ): + if "nvidia/Cosmos-Reason2" in config.model_name or "Qwen/Qwen3-VL" in config.model_name: + return Qwen3Backbone + if config.backbone_model_type == "qwen": + logger.warning( + "Unrecognized GR00T N1.7 backbone model name '%s'; assuming a Qwen3-VL-compatible " + "backbone because backbone_model_type='qwen'.", + config.model_name, + ) return Qwen3Backbone raise ValueError(f"Unsupported GR00T N1.7 backbone model: {config.model_name}") @@ -909,7 +912,7 @@ class GR00TN17(PreTrainedModel): "trust_remote_code": True } load_backbone_weights = kwargs.pop("load_backbone_weights", False) - for key in ("revision", "cache_dir", "local_files_only", "token"): + for key in ("cache_dir", "local_files_only", "token"): if key in kwargs: transformers_loading_kwargs.setdefault(key, kwargs[key])