def _shifted_ce(logits: Tensor, labels: Tensor, z_loss_weight: float = 0.0) -> Tensor:
    """Next-token cross-entropy with an optional PaLM-style z-loss term.

    The logits at position ``t`` are scored against the label at ``t + 1``,
    so the last logit row and the first label of every sequence are dropped
    before comparing. Positions whose (shifted) label is ``-100`` are
    ignored.

    Args:
        logits: Indexed as ``[batch, position, vocab]``.
        labels: Indexed as ``[batch, position]``; ``-100`` marks positions
            that are not supervised.
        z_loss_weight: Weight ``w`` of the auxiliary term ``w * mean(z ** 2)``
            where ``z = logsumexp(logits, dim=-1)`` is taken over the
            supervised positions only. Softmax is invariant to adding a
            constant to every logit, so plain CE cannot stop ``z`` from
            drifting; this term pulls it back towards zero. Any value
            ``<= 0`` disables the term and returns the plain CE.

    Returns:
        A scalar tensor in the dtype of ``logits``: the mean CE over the
        supervised positions, plus the z-loss term when enabled. When no
        position is supervised the result is a zero that is still attached
        to the autograd graph of ``logits`` (early return, no division).
    """
    # Plain slices are views. The boolean-mask gather below copies only the
    # supervised rows, so a full ``.contiguous()`` copy of the vocab-wide
    # logits would be wasted memory.
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:].long()
    valid = shift_labels != -100
    if not bool(valid.any().item()):
        # Keeps the graph connected so ``backward()`` yields zero gradients.
        return shift_logits.sum() * 0.0
    valid_logits = shift_logits[valid]
    ce = F.cross_entropy(valid_logits, shift_labels[valid], reduction="mean")
    if z_loss_weight <= 0.0:
        return ce
    z = torch.logsumexp(valid_logits, dim=-1)
    if z.dtype in (torch.float16, torch.bfloat16):
        # Square and average in float32: ``z ** 2`` overflows fp16 once
        # |z| exceeds ~256, and bf16 keeps only ~3 significant digits.
        z = z.float()
    z_term = z_loss_weight * z.square().mean()
    # Cast back so the return dtype does not depend on whether z-loss is on.
    return ce + z_term.to(ce.dtype)