Merge pull request #39 from johnnynunez/split/groot-n17-training-optim-contract

fix(groot): align N1.7 fine-tuning optimizer/scheduler/precision with Isaac-GR00T
This commit is contained in:
acwrenn53
2026-07-01 16:16:34 -07:00
committed by GitHub
4 changed files with 203 additions and 14 deletions
@@ -15,11 +15,12 @@
# limitations under the License.
import logging
import math
from dataclasses import dataclass, field
from pathlib import Path
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
from lerobot.optim import AdamWConfig, DiffuserSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_STATE
from .utils import read_json
@@ -336,11 +337,14 @@ class GrootConfig(PreTrainedConfig):
# Training parameters
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.95, 0.999)
# Isaac-GR00T N1.7 fine-tunes with AdamW betas (0.9, 0.999).
optimizer_betas: tuple[float, float] = (0.9, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-5
warmup_ratio: float = 0.05
use_bf16: bool = True
# The native N1.7 fine-tuning recipe keeps model parameters in FP32 and computes under BF16 autocast.
model_params_fp32: bool = True
# TODO(Steven): Remove these deprecated fields in a future release.
# Deprecated Isaac-GR00T runner / GR00T N1.5 fields, plus the (never-wired) LoRA fields — all
@@ -480,15 +484,20 @@ class GrootConfig(PreTrainedConfig):
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=1.0,
)
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
"""Return scheduler configuration."""
return CosineDecayWithWarmupSchedulerConfig(
num_warmup_steps=int(10000 * self.warmup_ratio), # 5% warmup by default
num_decay_steps=10000, # Adjust based on training steps
peak_lr=self.optimizer_lr,
decay_lr=self.optimizer_lr * 0.1,
def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
"""Return scheduler configuration.
Isaac-GR00T uses the HF Trainer cosine schedule with ~5% warmup over the
actual training update count; DiffuserSchedulerConfig wraps the same
diffusers/transformers `get_scheduler("cosine")` implementation and
derives num_training_steps from the outer --steps value at runtime.
"""
return DiffuserSchedulerConfig(
name="cosine",
num_warmup_steps=math.ceil(self.max_steps * self.warmup_ratio),
)
@property
@@ -504,6 +513,11 @@ class GrootConfig(PreTrainedConfig):
)
return list(range(min(self.chunk_size, model_action_horizon)))
@property
def drop_n_last_frames(self) -> int:
"""Exclude episode tails that cannot supply a complete N1.7 action chunk."""
return max(0, len(self.action_delta_indices) - 1)
@property
def reward_delta_indices(self) -> None:
"""Return indices for delta rewards (None for Groot)."""
+15 -1
View File
@@ -60,6 +60,19 @@ except ImportError:
logger = logging.getLogger(__name__)
def _tie_unused_qwen_lm_head(model: nn.Module) -> None:
"""Restore the TF4 weight tie so the unused LM head stays frozen and is omitted on save."""
lm_head = getattr(model, "lm_head", None)
get_input_embeddings = getattr(model, "get_input_embeddings", None)
if lm_head is None or not callable(get_input_embeddings):
return
input_embeddings = get_input_embeddings()
embedding_weight = getattr(input_embeddings, "weight", None)
if embedding_weight is None:
return
lm_head.weight = embedding_weight
GR00T_N1_7_DEFAULTS: dict[str, Any] = {
"model_dtype": "bfloat16",
"dtype": "bfloat16",
@@ -288,6 +301,7 @@ class Qwen3Backbone(nn.Module):
config_kwargs=transformers_loading_kwargs,
).eval()
_tie_unused_qwen_lm_head(self.model)
while len(self.language_model.layers) > select_layer:
self.language_model.layers.pop(-1)
@@ -603,7 +617,7 @@ class GR00TN17ActionHead(nn.Module):
pred = self.action_decoder(model_output, embodiment_id)
pred_actions = pred[:, -actions.shape[1] :]
action_mask = action_input.action_mask.to(dtype=pred_actions.dtype)
action_mask = action_input.action_mask
action_loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask
loss = action_loss.sum() / (action_mask.sum() + 1e-6)
return BatchFeature(
+44 -4
View File
@@ -34,6 +34,7 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
from torch import Tensor
from transformers.trainer_pt_utils import get_parameter_names
from lerobot.configs import FeatureType, PolicyFeature
from lerobot.utils.constants import ACTION, OBS_IMAGES
@@ -50,7 +51,7 @@ from .configuration_groot import (
infer_groot_n1_7_action_execution_horizon,
infer_groot_n1_7_action_horizon,
)
from .groot_n1_7 import GR00TN17
from .groot_n1_7 import GR00TN17, _tie_unused_qwen_lm_head
logger = logging.getLogger(__name__)
@@ -96,11 +97,49 @@ class GrootPolicy(PreTrainedPolicy):
if self.config.rtc_ramp_rate is not None:
model_kwargs["rtc_ramp_rate"] = self.config.rtc_ramp_rate
return GR00TN17.from_pretrained(
model = GR00TN17.from_pretrained(
**model_kwargs,
tune_vlln=self.config.tune_vlln,
transformers_loading_kwargs={"trust_remote_code": True},
)
backbone = getattr(model, "backbone", None)
qwen_model = getattr(backbone, "model", None)
if qwen_model is not None:
_tie_unused_qwen_lm_head(qwen_model)
if self.config.model_params_fp32:
self._cast_model_parameters_to_fp32(model)
return model
@staticmethod
def _cast_model_parameters_to_fp32(model: torch.nn.Module) -> None:
for parameter in model.parameters():
if parameter.is_floating_point():
parameter.data = parameter.data.to(torch.float32)
@staticmethod
def _build_weight_decay_parameter_groups(model: torch.nn.Module) -> list[dict[str, object]]:
forbidden_name_patterns = [
r"bias",
r"layernorm",
r"rmsnorm",
r"(?:^|\.)norm(?:$|\.)",
r"_norm(?:$|\.)",
]
decay_names = set(get_parameter_names(model, [torch.nn.LayerNorm], forbidden_name_patterns))
decay_params = [
parameter
for name, parameter in model.named_parameters()
if parameter.requires_grad and name in decay_names
]
no_decay_params = [
parameter
for name, parameter in model.named_parameters()
if parameter.requires_grad and name not in decay_names
]
return [
{"params": decay_params},
{"params": no_decay_params, "weight_decay": 0.0},
]
def reset(self):
"""Reset policy state when environment resets."""
@@ -238,8 +277,9 @@ class GrootPolicy(PreTrainedPolicy):
policy.eval()
return policy
def get_optim_params(self) -> dict:
return self.parameters()
def get_optim_params(self): # type: ignore[override]
"""Isaac-GR00T excludes biases and normalization parameters from weight decay."""
return self._build_weight_decay_parameter_groups(self)
def _resolve_action_queue_steps(self) -> int:
n_action_steps = int(self.config.n_action_steps)
@@ -0,0 +1,121 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Isaac-GR00T N1.7 optimizer/scheduler/precision training contract.
Pins the LeRobot GR00T fine-tuning recipe to the native Isaac-GR00T contract:
AdamW(lr=1e-4, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-5, grad clip 1.0),
HF cosine schedule with ~5% warmup over the actual update count, FP32 master
parameters under BF16 autocast, transformers-style weight-decay grouping, the
frozen LM-head weight tie, and episode-tail exclusion for incomplete chunks.
"""
import pytest
import torch
from lerobot.optim.schedulers import DiffuserSchedulerConfig
from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.groot.groot_n1_7 import _tie_unused_qwen_lm_head
from lerobot.policies.groot.modeling_groot import GrootPolicy
def test_groot_n1_7_optimizer_matches_isaac_training_contract():
optimizer = GrootConfig().get_optimizer_preset()
assert optimizer.lr == pytest.approx(1e-4)
assert optimizer.betas == pytest.approx((0.9, 0.999))
assert optimizer.eps == pytest.approx(1e-8)
assert optimizer.weight_decay == pytest.approx(1e-5)
assert optimizer.grad_clip_norm == pytest.approx(1.0)
def test_groot_n1_7_sampler_excludes_incomplete_action_tails():
config = GrootConfig(chunk_size=16, n_action_steps=16)
assert len(config.action_delta_indices) == 16
assert config.drop_n_last_frames == 15
def test_groot_n1_7_scheduler_matches_isaac_hf_cosine_contract():
config = GrootConfig(max_steps=20_000)
scheduler_config = config.get_scheduler_preset()
assert isinstance(scheduler_config, DiffuserSchedulerConfig)
assert scheduler_config.name == "cosine"
assert scheduler_config.num_warmup_steps == 1_000
parameter = torch.nn.Parameter(torch.ones(()))
optimizer = torch.optim.AdamW([parameter], lr=config.optimizer_lr)
scheduler = scheduler_config.build(optimizer, num_training_steps=20_000)
lr_factor = scheduler.lr_lambdas[0]
assert lr_factor(0) == pytest.approx(0.0)
assert lr_factor(1_000) == pytest.approx(1.0)
assert lr_factor(10_500) == pytest.approx(0.5)
assert lr_factor(20_000) == pytest.approx(0.0, abs=1e-12)
def test_groot_n1_7_scheduler_rounds_fractional_warmup_up_like_transformers():
scheduler_config = GrootConfig(max_steps=777).get_scheduler_preset()
assert scheduler_config.num_warmup_steps == 39
def test_groot_n1_7_model_parameters_use_fp32_checkpoint_and_optimizer_precision():
module = torch.nn.Module()
module.trainable = torch.nn.Parameter(torch.ones(3, dtype=torch.bfloat16))
module.frozen = torch.nn.Parameter(torch.ones(3, dtype=torch.bfloat16), requires_grad=False)
GrootPolicy._cast_model_parameters_to_fp32(module)
assert module.trainable.dtype == torch.float32
assert module.frozen.dtype == torch.float32
def test_groot_n1_7_ties_unused_qwen_lm_head_to_frozen_input_embeddings():
class DummyQwen(torch.nn.Module):
def __init__(self):
super().__init__()
self.embed_tokens = torch.nn.Embedding(7, 3)
self.lm_head = torch.nn.Linear(3, 7, bias=False)
def get_input_embeddings(self):
return self.embed_tokens
model = DummyQwen()
_tie_unused_qwen_lm_head(model)
assert model.lm_head.weight is model.embed_tokens.weight
assert len(list(model.parameters())) == 1
def test_groot_n1_7_optimizer_groups_match_transformers_weight_decay_rules():
module = torch.nn.Module()
module.linear = torch.nn.Linear(3, 2)
module.norm = torch.nn.LayerNorm(2)
module.frozen = torch.nn.Parameter(torch.ones(1), requires_grad=False)
groups = GrootPolicy._build_weight_decay_parameter_groups(module)
assert len(groups) == 2
assert "weight_decay" not in groups[0]
assert groups[1]["weight_decay"] == 0.0
assert groups[0]["params"] == [module.linear.weight]
assert {id(parameter) for parameter in groups[1]["params"]} == {
id(module.linear.bias),
id(module.norm.weight),
id(module.norm.bias),
}