feat(pi05): expose lm_head_lr_scale for stronger text-CE gradient

With knowledge_insulation=True the LM head only receives gradients on
text-CE samples (e.g. ~45% of the mix for subtask_mem.yaml). Under
aggressive cosine LR decay this is enough for the head's first-token
distribution to drift back toward PaliGemma's pretrained <loc>
detection prior — teacher-forced argmax stays high while autoregressive
generation collapses to <locDDDD> tokens.

Add `lm_head_lr_scale` (default 1.0, no behavior change) on PI05Config.
When != 1.0, PI05Policy.get_optim_params splits the policy into two
param groups: the PaliGemma lm_head projection plus its tied
embed_tokens at lr * lm_head_lr_scale, and the rest at lr. The cosine
scheduler multiplies both groups by the same lambda each step, so the
ratio is preserved across decay.

Recommended starting point for pi052 + subtask_mem.yaml runs: 5.0,
combined with a higher scheduler_decay_lr floor (e.g. 5e-6 instead of
1e-6) so the head doesn't get starved in the second half of training.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-05-22 09:56:46 +00:00
parent 9d30d91021
commit ca1b951e7b
2 changed files with 69 additions and 2 deletions
@@ -96,6 +96,19 @@ class PI05Config(PreTrainedConfig):
optimizer_foreach: bool | None = False
optimizer_fused: bool | None = True
# LM-head LR multiplier. The PaliGemma `lm_head` projection (and its
# tied `embed_tokens`) is the surface the LM head's first-token
# distribution depends on. With ``knowledge_insulation`` blocking
# action→VLM gradients, the LM head only sees gradients on text-CE
# samples — which can be a small fraction of the mix (e.g. ~45% in
# ``subtask_mem.yaml``). Under aggressive cosine LR decay the head's
# first-token distribution can drift back toward PaliGemma's
# pretrained ``<loc>`` detection prior, despite teacher-forced CE
# staying near zero. Boosting just the LM-head LR (e.g. 5x) keeps
# the head pinned to fine-tuning targets without perturbing the
# backbone / vision tower / action expert. Default 1.0 = no change.
lm_head_lr_scale: float = 1.0
# Scheduler settings: see openpi `CosineDecaySchedule`
# Note: These will auto-scale if --steps < scheduler_decay_steps
# For example, --steps=3000 will scale warmup to 100 and decay to 3000
+56 -2
View File
@@ -1133,8 +1133,62 @@ class PI05Policy(PreTrainedPolicy):
return fixed_state_dict
def get_optim_params(self) -> dict:
return self.parameters()
def get_optim_params(self):
"""Return policy parameters, optionally split into LR-scaled groups.
When ``config.lm_head_lr_scale != 1.0``, the PaliGemma ``lm_head``
and its tied ``embed_tokens`` are placed in their own param
group with ``lr = base_lr * lm_head_lr_scale``. The cosine
scheduler multiplies both groups by the same lambda each step,
so the ratio is preserved across decay. Default ``1.0`` =
return ``self.parameters()`` (back-compat with existing checkpoints
and configs).
"""
scale = float(getattr(self.config, "lm_head_lr_scale", 1.0))
if scale == 1.0:
return self.parameters()
head_params: list[torch.nn.Parameter] = []
other_params: list[torch.nn.Parameter] = []
# Both ``lm_head.weight`` and the tied ``embed_tokens.weight`` —
# boosting only the projection without the embedding pulls them
# apart and breaks the tie that PaliGemma was pre-trained with.
head_substrings = (
"paligemma_with_expert.paligemma.lm_head.",
"paligemma_with_expert.paligemma.model.language_model.embed_tokens.",
)
for name, p in self.named_parameters():
if not p.requires_grad:
continue
if any(s in name for s in head_substrings):
head_params.append(p)
else:
other_params.append(p)
base_lr = float(self.config.optimizer_lr)
groups: list[dict[str, object]] = []
if other_params:
groups.append({"params": other_params, "lr": base_lr, "name": "policy"})
if head_params:
groups.append(
{"params": head_params, "lr": base_lr * scale, "name": "lm_head"}
)
# Sanity: head_substrings must match at least one parameter, otherwise
# the scale silently does nothing — surface that fast.
if not head_params:
raise RuntimeError(
"lm_head_lr_scale != 1.0 but no parameters matched the LM-head "
"name patterns: "
f"{head_substrings!r}. Did the underlying PaliGemma module rename?"
)
logging.info(
"PI05Policy: LM-head LR scale = %.3g (base=%.3g, head=%.3g) over "
"%d head params + %d other params",
scale,
base_lr,
base_lr * scale,
len(head_params),
len(other_params),
)
return groups
def reset(self):
"""Reset internal state - called when environment resets."""