From ca1b951e7b76cb03b0832cce916bacc9f7b21e64 Mon Sep 17 00:00:00 2001 From: pepijn Date: Fri, 22 May 2026 09:56:46 +0000 Subject: [PATCH] feat(pi05): expose lm_head_lr_scale for stronger text-CE gradient MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit With knowledge_insulation=True the LM head only receives gradients on text-CE samples (e.g. ~45% of the mix for subtask_mem.yaml). Under aggressive cosine LR decay this is enough for the head's first-token distribution to drift back toward PaliGemma's pretrained detection prior — teacher-forced argmax stays high while autoregressive generation collapses to `<loc>` tokens. Add `lm_head_lr_scale` (default 1.0, no behavior change) on PI05Config. When != 1.0, PI05Policy.get_optim_params splits the policy into two param groups: the PaliGemma lm_head projection plus its tied embed_tokens at lr * lm_head_lr_scale, and the rest at lr. The cosine scheduler multiplies both groups by the same lambda each step, so the ratio is preserved across decay. Recommended starting point for pi052 + subtask_mem.yaml runs: 5.0, combined with a higher scheduler_decay_lr floor (e.g. 5e-6 instead of 1e-6) so the head doesn't get starved in the second half of training. Co-authored-by: Cursor --- .../policies/pi05/configuration_pi05.py | 13 +++++ src/lerobot/policies/pi05/modeling_pi05.py | 58 ++++++++++++++++++- 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/src/lerobot/policies/pi05/configuration_pi05.py b/src/lerobot/policies/pi05/configuration_pi05.py index 192fce448..2d4c0e4d8 100644 --- a/src/lerobot/policies/pi05/configuration_pi05.py +++ b/src/lerobot/policies/pi05/configuration_pi05.py @@ -96,6 +96,19 @@ class PI05Config(PreTrainedConfig): optimizer_foreach: bool | None = False optimizer_fused: bool | None = True + # LM-head LR multiplier. The PaliGemma `lm_head` projection (and its + # tied `embed_tokens`) is the surface the LM head's first-token + # distribution depends on. 
With ``knowledge_insulation`` blocking + # action→VLM gradients, the LM head only sees gradients on text-CE + # samples — which can be a small fraction of the mix (e.g. ~45% in + # ``subtask_mem.yaml``). Under aggressive cosine LR decay the head's + # first-token distribution can drift back toward PaliGemma's + # pretrained ``<loc>`` detection prior, despite teacher-forced CE + # staying near zero. Boosting just the LM-head LR (e.g. 5x) keeps + # the head pinned to fine-tuning targets without perturbing the + # backbone / vision tower / action expert. Default 1.0 = no change. + lm_head_lr_scale: float = 1.0 + # Scheduler settings: see openpi `CosineDecaySchedule` # Note: These will auto-scale if --steps < scheduler_decay_steps # For example, --steps=3000 will scale warmup to 100 and decay to 3000 diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 2c0301e73..2f6e97321 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -1133,8 +1133,62 @@ class PI05Policy(PreTrainedPolicy): return fixed_state_dict - def get_optim_params(self) -> dict: - return self.parameters() + def get_optim_params(self): + """Return policy parameters, optionally split into LR-scaled groups. + + When ``config.lm_head_lr_scale != 1.0``, the PaliGemma ``lm_head`` + and its tied ``embed_tokens`` are placed in their own param + group with ``lr = base_lr * lm_head_lr_scale``. The cosine + scheduler multiplies both groups by the same lambda each step, + so the ratio is preserved across decay. Default ``1.0`` = + return ``self.parameters()`` (back-compat with existing checkpoints + and configs). 
+ """ + scale = float(getattr(self.config, "lm_head_lr_scale", 1.0)) + if scale == 1.0: + return self.parameters() + head_params: list[torch.nn.Parameter] = [] + other_params: list[torch.nn.Parameter] = [] + # Both ``lm_head.weight`` and the tied ``embed_tokens.weight`` — + # boosting only the projection without the embedding pulls them + # apart and breaks the tie that PaliGemma was pre-trained with. + head_substrings = ( + "paligemma_with_expert.paligemma.lm_head.", + "paligemma_with_expert.paligemma.model.language_model.embed_tokens.", + ) + for name, p in self.named_parameters(): + if not p.requires_grad: + continue + if any(s in name for s in head_substrings): + head_params.append(p) + else: + other_params.append(p) + base_lr = float(self.config.optimizer_lr) + groups: list[dict[str, object]] = [] + if other_params: + groups.append({"params": other_params, "lr": base_lr, "name": "policy"}) + if head_params: + groups.append( + {"params": head_params, "lr": base_lr * scale, "name": "lm_head"} + ) + # Sanity: head_substrings must match at least one parameter, otherwise + # the scale silently does nothing — surface that fast. + if not head_params: + raise RuntimeError( + "lm_head_lr_scale != 1.0 but no parameters matched the LM-head " + "name patterns: " + f"{head_substrings!r}. Did the underlying PaliGemma module rename?" + ) + logging.info( + "PI05Policy: LM-head LR scale = %.3g (base=%.3g, head=%.3g) over " + "%d head params + %d other params", + scale, + base_lr, + base_lr * scale, + len(head_params), + len(other_params), + ) + return groups def reset(self): """Reset internal state - called when environment resets."""