pi052: PaLM-style z-loss on text CE (default weight 1e-4)

Penalise the log-partition function z = log Σ exp(logits) drifting away
from zero on text-CE supervised positions. Without it, large-vocab
models (PaliGemma's 257k vocab) can let logsumexp grow unboundedly
while CE stays low — a uniform additive logit bias cancels in softmax
but pushes the partition function out of bounds, causing numerical
instability and generation drift.

PaLM appendix B / Chinchilla report z-loss is essential for stable
large-vocab CE. It is especially valuable for pi052 because the recent
default lm_head_lr_scale=5.0 amplifies head-drift risk: the 5x boost
keeps the head pinned to fine-tuning targets, and z-loss caps the
partition function so the head can't just bias all logits high uniformly.

Implementation:
  * _shifted_ce(logits, labels, z_loss_weight=0.0) gains the new arg
    with default 0.0 (back-compat for any other caller).
  * Both call sites in PI052Policy.forward read self.config.text_ce_
    z_loss_weight and pass it through.
  * PI052Config.text_ce_z_loss_weight defaults to 1e-4 (commonly cited
    PaLM value); set to 0 to disable.

Cheap to compute: one extra logsumexp shares the softmax kernel that
F.cross_entropy already runs. No memory overhead beyond a (B*T,) tensor.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-25 21:08:56 +02:00
parent 8ba3b187a1
commit 738e317caa
2 changed files with 39 additions and 8 deletions
@@ -180,6 +180,16 @@ class PI052Config(PI05Config):
# the same cosine lambda, so the 5x ratio is preserved across decay. # the same cosine lambda, so the 5x ratio is preserved across decay.
lm_head_lr_scale: float = 5.0 lm_head_lr_scale: float = 5.0
# PaLM-style z-loss on text CE. Penalises the log-partition function
# ``z = log Σ exp(logits)`` drifting away from zero — without it, large-
# vocab models (PaliGemma is 257k) can let ``logsumexp`` grow unbounded
# while CE stays low, because a uniform additive logit bias cancels in
# softmax. PaLM appendix B / Chinchilla report z-loss is essential for
# stable large-vocab CE; it especially helps under ``lm_head_lr_scale=
# 5.0`` which amplifies drift risk on the LM head. ``1e-4`` is the
# commonly cited weight; set 0 to disable entirely.
text_ce_z_loss_weight: float = 1e-4
def __post_init__(self) -> None: def __post_init__(self) -> None:
super().__post_init__() super().__post_init__()
# Backbone needs gradients flowing through the text head when # Backbone needs gradients flowing through the text head when
+29 -8
View File
@@ -68,22 +68,35 @@ def _mask_per_sample(per_sample: Tensor, predict_actions_t: Tensor | None) -> Te
return (per_sample * mask).sum() / mask.sum().clamp(min=1.0) return (per_sample * mask).sum() / mask.sum().clamp(min=1.0)
def _shifted_ce(logits: Tensor, labels: Tensor) -> Tensor: def _shifted_ce(logits: Tensor, labels: Tensor, z_loss_weight: float = 0.0) -> Tensor:
"""Next-token CE: hidden at t predicts label at t+1, ignore_index=-100. """Next-token CE: hidden at t predicts label at t+1, ignore_index=-100.
Mean over non-ignored positions across the batch. Returns 0 cleanly Mean over non-ignored positions across the batch. Returns 0 cleanly
when no positions are supervised (clamp(min=1) on the denominator). when no positions are supervised (clamp(min=1) on the denominator).
When ``z_loss_weight > 0``, also adds PaLM-style z-loss
(``z² · w``, where ``z = log Σ exp(logits)``) on every supervised
position. Penalises the log-partition function drifting away from
zero — without it, large-vocab models (PaliGemma is 257k) can let
``logsumexp`` grow unboundedly while CE stays low, because uniform
additive logit bias cancels in softmax. PaLM appendix B / Chinchilla
report this is essential for stable large-vocab CE; cheap insurance
here especially with ``lm_head_lr_scale=5.0`` amplifying drift risk.
""" """
shift_logits = logits[:, :-1, :].contiguous() shift_logits = logits[:, :-1, :].contiguous()
shift_labels = labels[:, 1:].contiguous().long() shift_labels = labels[:, 1:].contiguous().long()
valid = shift_labels != -100 valid = shift_labels != -100
if not bool(valid.any().item()): if not bool(valid.any().item()):
return shift_logits.sum() * 0.0 return shift_logits.sum() * 0.0
return F.cross_entropy( valid_logits = shift_logits[valid]
shift_logits[valid], valid_labels = shift_labels[valid]
shift_labels[valid], ce = F.cross_entropy(valid_logits, valid_labels, reduction="mean")
reduction="mean", if z_loss_weight <= 0.0:
) return ce
# PaLM z-loss: penalise (log Σ exp(logits))² per supervised position.
# ``logsumexp`` is numerically stable and shares the softmax kernel.
z = torch.logsumexp(valid_logits, dim=-1)
return ce + z_loss_weight * (z**2).mean()
def _mark_target_span_causal( def _mark_target_span_causal(
@@ -668,7 +681,11 @@ class PI052Policy(PI05Policy):
else: else:
text_hidden = prefix_out[:, -lang_len:, :] text_hidden = prefix_out[:, -lang_len:, :]
text_logits = lm_head(text_hidden.to(lm_head.weight.dtype)) text_logits = lm_head(text_hidden.to(lm_head.weight.dtype))
text_loss = _shifted_ce(text_logits, text_labels) text_loss = _shifted_ce(
text_logits,
text_labels,
z_loss_weight=getattr(self.config, "text_ce_z_loss_weight", 0.0),
)
fast_loss: Tensor | None = None fast_loss: Tensor | None = None
if fast_len > 0 and prefix_out is not None and action_code_mask is not None: if fast_len > 0 and prefix_out is not None and action_code_mask is not None:
@@ -768,7 +785,11 @@ class PI052Policy(PI05Policy):
else: else:
text_hidden = vlm_out[:, -lang_len:, :] text_hidden = vlm_out[:, -lang_len:, :]
text_logits = lm_head(text_hidden.to(lm_head.weight.dtype)) text_logits = lm_head(text_hidden.to(lm_head.weight.dtype))
text_loss = _shifted_ce(text_logits, text_labels) text_loss = _shifted_ce(
text_logits,
text_labels,
z_loss_weight=getattr(self.config, "text_ce_z_loss_weight", 0.0),
)
fast_loss: Tensor | None = None fast_loss: Tensor | None = None
if ( if (