mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-05 00:57:06 +00:00
Merge pull request #39 from johnnynunez/split/groot-n17-training-optim-contract
fix(groot): align N1.7 fine-tuning optimizer/scheduler/precision with Isaac-GR00T
This commit is contained in:
@@ -15,11 +15,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
|
||||
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.optim import AdamWConfig, DiffuserSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
from .utils import read_json
|
||||
@@ -336,11 +337,14 @@ class GrootConfig(PreTrainedConfig):
|
||||
|
||||
# Training parameters
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.95, 0.999)
|
||||
# Isaac-GR00T N1.7 fine-tunes with AdamW betas (0.9, 0.999).
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-5
|
||||
warmup_ratio: float = 0.05
|
||||
use_bf16: bool = True
|
||||
# The native N1.7 fine-tuning recipe keeps model parameters in FP32 and computes under BF16 autocast.
|
||||
model_params_fp32: bool = True
|
||||
|
||||
# TODO(Steven): Remove these deprecated fields in a future release.
|
||||
# Deprecated Isaac-GR00T runner / GR00T N1.5 fields, plus the (never-wired) LoRA fields — all
|
||||
@@ -480,15 +484,20 @@ class GrootConfig(PreTrainedConfig):
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=1.0,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||
"""Return scheduler configuration."""
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
num_warmup_steps=int(10000 * self.warmup_ratio), # 5% warmup by default
|
||||
num_decay_steps=10000, # Adjust based on training steps
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.optimizer_lr * 0.1,
|
||||
def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
|
||||
"""Return scheduler configuration.
|
||||
|
||||
Isaac-GR00T uses the HF Trainer cosine schedule with ~5% warmup over the
|
||||
actual training update count; DiffuserSchedulerConfig wraps the same
|
||||
diffusers/transformers `get_scheduler("cosine")` implementation and
|
||||
derives num_training_steps from the outer --steps value at runtime.
|
||||
"""
|
||||
return DiffuserSchedulerConfig(
|
||||
name="cosine",
|
||||
num_warmup_steps=math.ceil(self.max_steps * self.warmup_ratio),
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -504,6 +513,11 @@ class GrootConfig(PreTrainedConfig):
|
||||
)
|
||||
return list(range(min(self.chunk_size, model_action_horizon)))
|
||||
|
||||
@property
|
||||
def drop_n_last_frames(self) -> int:
|
||||
"""Exclude episode tails that cannot supply a complete N1.7 action chunk."""
|
||||
return max(0, len(self.action_delta_indices) - 1)
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
"""Return indices for delta rewards (None for Groot)."""
|
||||
|
||||
@@ -60,6 +60,19 @@ except ImportError:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _tie_unused_qwen_lm_head(model: nn.Module) -> None:
|
||||
"""Restore the TF4 weight tie so the unused LM head stays frozen and is omitted on save."""
|
||||
lm_head = getattr(model, "lm_head", None)
|
||||
get_input_embeddings = getattr(model, "get_input_embeddings", None)
|
||||
if lm_head is None or not callable(get_input_embeddings):
|
||||
return
|
||||
input_embeddings = get_input_embeddings()
|
||||
embedding_weight = getattr(input_embeddings, "weight", None)
|
||||
if embedding_weight is None:
|
||||
return
|
||||
lm_head.weight = embedding_weight
|
||||
|
||||
|
||||
GR00T_N1_7_DEFAULTS: dict[str, Any] = {
|
||||
"model_dtype": "bfloat16",
|
||||
"dtype": "bfloat16",
|
||||
@@ -288,6 +301,7 @@ class Qwen3Backbone(nn.Module):
|
||||
config_kwargs=transformers_loading_kwargs,
|
||||
).eval()
|
||||
|
||||
_tie_unused_qwen_lm_head(self.model)
|
||||
while len(self.language_model.layers) > select_layer:
|
||||
self.language_model.layers.pop(-1)
|
||||
|
||||
@@ -603,7 +617,7 @@ class GR00TN17ActionHead(nn.Module):
|
||||
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
pred_actions = pred[:, -actions.shape[1] :]
|
||||
action_mask = action_input.action_mask.to(dtype=pred_actions.dtype)
|
||||
action_mask = action_input.action_mask
|
||||
action_loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask
|
||||
loss = action_loss.sum() / (action_mask.sum() + 1e-6)
|
||||
return BatchFeature(
|
||||
|
||||
@@ -34,6 +34,7 @@ from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
from torch import Tensor
|
||||
from transformers.trainer_pt_utils import get_parameter_names
|
||||
|
||||
from lerobot.configs import FeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||
@@ -50,7 +51,7 @@ from .configuration_groot import (
|
||||
infer_groot_n1_7_action_execution_horizon,
|
||||
infer_groot_n1_7_action_horizon,
|
||||
)
|
||||
from .groot_n1_7 import GR00TN17
|
||||
from .groot_n1_7 import GR00TN17, _tie_unused_qwen_lm_head
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -96,11 +97,49 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
if self.config.rtc_ramp_rate is not None:
|
||||
model_kwargs["rtc_ramp_rate"] = self.config.rtc_ramp_rate
|
||||
|
||||
return GR00TN17.from_pretrained(
|
||||
model = GR00TN17.from_pretrained(
|
||||
**model_kwargs,
|
||||
tune_vlln=self.config.tune_vlln,
|
||||
transformers_loading_kwargs={"trust_remote_code": True},
|
||||
)
|
||||
backbone = getattr(model, "backbone", None)
|
||||
qwen_model = getattr(backbone, "model", None)
|
||||
if qwen_model is not None:
|
||||
_tie_unused_qwen_lm_head(qwen_model)
|
||||
if self.config.model_params_fp32:
|
||||
self._cast_model_parameters_to_fp32(model)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _cast_model_parameters_to_fp32(model: torch.nn.Module) -> None:
|
||||
for parameter in model.parameters():
|
||||
if parameter.is_floating_point():
|
||||
parameter.data = parameter.data.to(torch.float32)
|
||||
|
||||
@staticmethod
|
||||
def _build_weight_decay_parameter_groups(model: torch.nn.Module) -> list[dict[str, object]]:
|
||||
forbidden_name_patterns = [
|
||||
r"bias",
|
||||
r"layernorm",
|
||||
r"rmsnorm",
|
||||
r"(?:^|\.)norm(?:$|\.)",
|
||||
r"_norm(?:$|\.)",
|
||||
]
|
||||
decay_names = set(get_parameter_names(model, [torch.nn.LayerNorm], forbidden_name_patterns))
|
||||
decay_params = [
|
||||
parameter
|
||||
for name, parameter in model.named_parameters()
|
||||
if parameter.requires_grad and name in decay_names
|
||||
]
|
||||
no_decay_params = [
|
||||
parameter
|
||||
for name, parameter in model.named_parameters()
|
||||
if parameter.requires_grad and name not in decay_names
|
||||
]
|
||||
return [
|
||||
{"params": decay_params},
|
||||
{"params": no_decay_params, "weight_decay": 0.0},
|
||||
]
|
||||
|
||||
def reset(self):
|
||||
"""Reset policy state when environment resets."""
|
||||
@@ -238,8 +277,9 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
policy.eval()
|
||||
return policy
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
def get_optim_params(self): # type: ignore[override]
|
||||
"""Isaac-GR00T excludes biases and normalization parameters from weight decay."""
|
||||
return self._build_weight_decay_parameter_groups(self)
|
||||
|
||||
def _resolve_action_queue_steps(self) -> int:
|
||||
n_action_steps = int(self.config.n_action_steps)
|
||||
|
||||
@@ -0,0 +1,121 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Isaac-GR00T N1.7 optimizer/scheduler/precision training contract.
|
||||
|
||||
Pins the LeRobot GR00T fine-tuning recipe to the native Isaac-GR00T contract:
|
||||
AdamW(lr=1e-4, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-5, grad clip 1.0),
|
||||
HF cosine schedule with ~5% warmup over the actual update count, FP32 master
|
||||
parameters under BF16 autocast, transformers-style weight-decay grouping, the
|
||||
frozen LM-head weight tie, and episode-tail exclusion for incomplete chunks.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.optim.schedulers import DiffuserSchedulerConfig
|
||||
from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||
from lerobot.policies.groot.groot_n1_7 import _tie_unused_qwen_lm_head
|
||||
from lerobot.policies.groot.modeling_groot import GrootPolicy
|
||||
|
||||
|
||||
def test_groot_n1_7_optimizer_matches_isaac_training_contract():
|
||||
optimizer = GrootConfig().get_optimizer_preset()
|
||||
|
||||
assert optimizer.lr == pytest.approx(1e-4)
|
||||
assert optimizer.betas == pytest.approx((0.9, 0.999))
|
||||
assert optimizer.eps == pytest.approx(1e-8)
|
||||
assert optimizer.weight_decay == pytest.approx(1e-5)
|
||||
assert optimizer.grad_clip_norm == pytest.approx(1.0)
|
||||
|
||||
|
||||
def test_groot_n1_7_sampler_excludes_incomplete_action_tails():
|
||||
config = GrootConfig(chunk_size=16, n_action_steps=16)
|
||||
|
||||
assert len(config.action_delta_indices) == 16
|
||||
assert config.drop_n_last_frames == 15
|
||||
|
||||
|
||||
def test_groot_n1_7_scheduler_matches_isaac_hf_cosine_contract():
|
||||
config = GrootConfig(max_steps=20_000)
|
||||
scheduler_config = config.get_scheduler_preset()
|
||||
|
||||
assert isinstance(scheduler_config, DiffuserSchedulerConfig)
|
||||
assert scheduler_config.name == "cosine"
|
||||
assert scheduler_config.num_warmup_steps == 1_000
|
||||
|
||||
parameter = torch.nn.Parameter(torch.ones(()))
|
||||
optimizer = torch.optim.AdamW([parameter], lr=config.optimizer_lr)
|
||||
scheduler = scheduler_config.build(optimizer, num_training_steps=20_000)
|
||||
lr_factor = scheduler.lr_lambdas[0]
|
||||
|
||||
assert lr_factor(0) == pytest.approx(0.0)
|
||||
assert lr_factor(1_000) == pytest.approx(1.0)
|
||||
assert lr_factor(10_500) == pytest.approx(0.5)
|
||||
assert lr_factor(20_000) == pytest.approx(0.0, abs=1e-12)
|
||||
|
||||
|
||||
def test_groot_n1_7_scheduler_rounds_fractional_warmup_up_like_transformers():
|
||||
scheduler_config = GrootConfig(max_steps=777).get_scheduler_preset()
|
||||
|
||||
assert scheduler_config.num_warmup_steps == 39
|
||||
|
||||
|
||||
def test_groot_n1_7_model_parameters_use_fp32_checkpoint_and_optimizer_precision():
|
||||
module = torch.nn.Module()
|
||||
module.trainable = torch.nn.Parameter(torch.ones(3, dtype=torch.bfloat16))
|
||||
module.frozen = torch.nn.Parameter(torch.ones(3, dtype=torch.bfloat16), requires_grad=False)
|
||||
|
||||
GrootPolicy._cast_model_parameters_to_fp32(module)
|
||||
|
||||
assert module.trainable.dtype == torch.float32
|
||||
assert module.frozen.dtype == torch.float32
|
||||
|
||||
|
||||
def test_groot_n1_7_ties_unused_qwen_lm_head_to_frozen_input_embeddings():
|
||||
class DummyQwen(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.embed_tokens = torch.nn.Embedding(7, 3)
|
||||
self.lm_head = torch.nn.Linear(3, 7, bias=False)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
model = DummyQwen()
|
||||
_tie_unused_qwen_lm_head(model)
|
||||
|
||||
assert model.lm_head.weight is model.embed_tokens.weight
|
||||
assert len(list(model.parameters())) == 1
|
||||
|
||||
|
||||
def test_groot_n1_7_optimizer_groups_match_transformers_weight_decay_rules():
|
||||
module = torch.nn.Module()
|
||||
module.linear = torch.nn.Linear(3, 2)
|
||||
module.norm = torch.nn.LayerNorm(2)
|
||||
module.frozen = torch.nn.Parameter(torch.ones(1), requires_grad=False)
|
||||
|
||||
groups = GrootPolicy._build_weight_decay_parameter_groups(module)
|
||||
|
||||
assert len(groups) == 2
|
||||
assert "weight_decay" not in groups[0]
|
||||
assert groups[1]["weight_decay"] == 0.0
|
||||
assert groups[0]["params"] == [module.linear.weight]
|
||||
assert {id(parameter) for parameter in groups[1]["params"]} == {
|
||||
id(module.linear.bias),
|
||||
id(module.norm.weight),
|
||||
id(module.norm.bias),
|
||||
}
|
||||
Reference in New Issue
Block a user