mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 08:47:05 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 895eaf0d7c |
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import logging
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -42,6 +43,9 @@ else:
|
|||||||
Timesteps = None
|
Timesteps = None
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TimestepEncoder(nn.Module):
|
class TimestepEncoder(nn.Module):
|
||||||
def __init__(self, embedding_dim, compute_dtype=torch.float32):
|
def __init__(self, embedding_dim, compute_dtype=torch.float32):
|
||||||
require_package("diffusers", extra="groot")
|
require_package("diffusers", extra="groot")
|
||||||
@@ -265,8 +269,8 @@ class DiT(ModelMixin, ConfigMixin):
|
|||||||
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
|
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
|
||||||
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
|
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
|
||||||
self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
|
self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
|
||||||
print(
|
logger.debug(
|
||||||
"Total number of DiT parameters: ",
|
"Total number of DiT parameters: %d",
|
||||||
sum(p.numel() for p in self.parameters() if p.requires_grad),
|
sum(p.numel() for p in self.parameters() if p.requires_grad),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -426,8 +430,8 @@ class SelfAttentionTransformer(ModelMixin, ConfigMixin):
|
|||||||
for _ in range(self.config.num_layers)
|
for _ in range(self.config.num_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
print(
|
logger.debug(
|
||||||
"Total number of SelfAttentionTransformer parameters: ",
|
"Total number of SelfAttentionTransformer parameters: %d",
|
||||||
sum(p.numel() for p in self.parameters() if p.requires_grad),
|
sum(p.numel() for p in self.parameters() if p.requires_grad),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ GR00T_N1_7_DEFAULTS: dict[str, Any] = {
|
|||||||
"backbone_embedding_dim": 2048,
|
"backbone_embedding_dim": 2048,
|
||||||
"tune_llm": False,
|
"tune_llm": False,
|
||||||
"tune_visual": False,
|
"tune_visual": False,
|
||||||
"select_layer": 12,
|
"select_layer": 16,
|
||||||
"reproject_vision": False,
|
"reproject_vision": False,
|
||||||
"use_flash_attention": True,
|
"use_flash_attention": True,
|
||||||
"load_bf16": False,
|
"load_bf16": False,
|
||||||
@@ -819,11 +819,14 @@ def _cosmos_reason2_qwen3_vl_config() -> PretrainedConfig:
|
|||||||
|
|
||||||
|
|
||||||
def get_backbone_cls(config: GR00TN17Config):
|
def get_backbone_cls(config: GR00TN17Config):
|
||||||
if (
|
if "nvidia/Cosmos-Reason2" in config.model_name or "Qwen/Qwen3-VL" in config.model_name:
|
||||||
config.backbone_model_type == "qwen"
|
return Qwen3Backbone
|
||||||
or "nvidia/Cosmos-Reason2" in config.model_name
|
if config.backbone_model_type == "qwen":
|
||||||
or "Qwen/Qwen3-VL" in config.model_name
|
logger.warning(
|
||||||
):
|
"Unrecognized GR00T N1.7 backbone model name '%s'; assuming a Qwen3-VL-compatible "
|
||||||
|
"backbone because backbone_model_type='qwen'.",
|
||||||
|
config.model_name,
|
||||||
|
)
|
||||||
return Qwen3Backbone
|
return Qwen3Backbone
|
||||||
raise ValueError(f"Unsupported GR00T N1.7 backbone model: {config.model_name}")
|
raise ValueError(f"Unsupported GR00T N1.7 backbone model: {config.model_name}")
|
||||||
|
|
||||||
@@ -909,7 +912,7 @@ class GR00TN17(PreTrainedModel):
|
|||||||
"trust_remote_code": True
|
"trust_remote_code": True
|
||||||
}
|
}
|
||||||
load_backbone_weights = kwargs.pop("load_backbone_weights", False)
|
load_backbone_weights = kwargs.pop("load_backbone_weights", False)
|
||||||
for key in ("revision", "cache_dir", "local_files_only", "token"):
|
for key in ("cache_dir", "local_files_only", "token"):
|
||||||
if key in kwargs:
|
if key in kwargs:
|
||||||
transformers_loading_kwargs.setdefault(key, kwargs[key])
|
transformers_loading_kwargs.setdefault(key, kwargs[key])
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user