mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
feat(pi05): expose lm_head_lr_scale for stronger text-CE gradient
With knowledge_insulation=True the LM head only receives gradients on text-CE samples (e.g. ~45% of the mix for subtask_mem.yaml). Under aggressive cosine LR decay this is enough for the head's first-token distribution to drift back toward PaliGemma's pretrained <loc> detection prior — teacher-forced argmax stays high while autoregressive generation collapses to <locDDDD> tokens. Add `lm_head_lr_scale` (default 1.0, no behavior change) on PI05Config. When != 1.0, PI05Policy.get_optim_params splits the policy into two param groups: the PaliGemma lm_head projection plus its tied embed_tokens at lr * lm_head_lr_scale, and the rest at lr. The cosine scheduler multiplies both groups by the same lambda each step, so the ratio is preserved across decay. Recommended starting point for pi052 + subtask_mem.yaml runs: 5.0, combined with a higher scheduler_decay_lr floor (e.g. 5e-6 instead of 1e-6) so the head doesn't get starved in the second half of training. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -96,6 +96,19 @@ class PI05Config(PreTrainedConfig):
|
||||
optimizer_foreach: bool | None = False
|
||||
optimizer_fused: bool | None = True
|
||||
|
||||
# LM-head LR multiplier. The PaliGemma `lm_head` projection (and its
|
||||
# tied `embed_tokens`) is the surface the LM head's first-token
|
||||
# distribution depends on. With ``knowledge_insulation`` blocking
|
||||
# action→VLM gradients, the LM head only sees gradients on text-CE
|
||||
# samples — which can be a small fraction of the mix (e.g. ~45% in
|
||||
# ``subtask_mem.yaml``). Under aggressive cosine LR decay the head's
|
||||
# first-token distribution can drift back toward PaliGemma's
|
||||
# pretrained ``<loc>`` detection prior, despite teacher-forced CE
|
||||
# staying near zero. Boosting just the LM-head LR (e.g. 5x) keeps
|
||||
# the head pinned to fine-tuning targets without perturbing the
|
||||
# backbone / vision tower / action expert. Default 1.0 = no change.
|
||||
lm_head_lr_scale: float = 1.0
|
||||
|
||||
# Scheduler settings: see openpi `CosineDecaySchedule`
|
||||
# Note: These will auto-scale if --steps < scheduler_decay_steps
|
||||
# For example, --steps=3000 will scale warmup to 100 and decay to 3000
|
||||
|
||||
@@ -1133,8 +1133,62 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
return fixed_state_dict
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
def get_optim_params(self):
|
||||
"""Return policy parameters, optionally split into LR-scaled groups.
|
||||
|
||||
When ``config.lm_head_lr_scale != 1.0``, the PaliGemma ``lm_head``
|
||||
and its tied ``embed_tokens`` are placed in their own param
|
||||
group with ``lr = base_lr * lm_head_lr_scale``. The cosine
|
||||
scheduler multiplies both groups by the same lambda each step,
|
||||
so the ratio is preserved across decay. Default ``1.0`` =
|
||||
return ``self.parameters()`` (back-compat with existing checkpoints
|
||||
and configs).
|
||||
"""
|
||||
scale = float(getattr(self.config, "lm_head_lr_scale", 1.0))
|
||||
if scale == 1.0:
|
||||
return self.parameters()
|
||||
head_params: list[torch.nn.Parameter] = []
|
||||
other_params: list[torch.nn.Parameter] = []
|
||||
# Both ``lm_head.weight`` and the tied ``embed_tokens.weight`` —
|
||||
# boosting only the projection without the embedding pulls them
|
||||
# apart and breaks the tie that PaliGemma was pre-trained with.
|
||||
head_substrings = (
|
||||
"paligemma_with_expert.paligemma.lm_head.",
|
||||
"paligemma_with_expert.paligemma.model.language_model.embed_tokens.",
|
||||
)
|
||||
for name, p in self.named_parameters():
|
||||
if not p.requires_grad:
|
||||
continue
|
||||
if any(s in name for s in head_substrings):
|
||||
head_params.append(p)
|
||||
else:
|
||||
other_params.append(p)
|
||||
base_lr = float(self.config.optimizer_lr)
|
||||
groups: list[dict[str, object]] = []
|
||||
if other_params:
|
||||
groups.append({"params": other_params, "lr": base_lr, "name": "policy"})
|
||||
if head_params:
|
||||
groups.append(
|
||||
{"params": head_params, "lr": base_lr * scale, "name": "lm_head"}
|
||||
)
|
||||
# Sanity: head_substrings must match at least one parameter, otherwise
|
||||
# the scale silently does nothing — surface that fast.
|
||||
if not head_params:
|
||||
raise RuntimeError(
|
||||
"lm_head_lr_scale != 1.0 but no parameters matched the LM-head "
|
||||
"name patterns: "
|
||||
f"{head_substrings!r}. Did the underlying PaliGemma module rename?"
|
||||
)
|
||||
logging.info(
|
||||
"PI05Policy: LM-head LR scale = %.3g (base=%.3g, head=%.3g) over "
|
||||
"%d head params + %d other params",
|
||||
scale,
|
||||
base_lr,
|
||||
base_lr * scale,
|
||||
len(head_params),
|
||||
len(other_params),
|
||||
)
|
||||
return groups
|
||||
|
||||
def reset(self):
|
||||
"""Reset internal state - called when environment resets."""
|
||||
|
||||
Reference in New Issue
Block a user