mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 06:29:47 +00:00
pi052: PaLM-style z-loss on text CE (default weight 1e-4)
Penalise the log-partition function z = log Σ exp(logits) drifting away
from zero on text-CE supervised positions. Without it, large-vocab
models (PaliGemma's 257k vocab) can let logsumexp grow unboundedly
while CE stays low — a uniform additive logit bias cancels in softmax
but pushes the partition function out of bounds, causing numerical
instability and generation drift.
PaLM appendix B / Chinchilla report z-loss is essential for stable
large-vocab CE. It is especially valuable for pi052 because the recent
default lm_head_lr_scale=5.0 amplifies head-drift risk: the 5x boost
keeps the head pinned to fine-tuning targets, and z-loss caps the
partition function so the head can't just bias all logits high uniformly.
Implementation:
* _shifted_ce(logits, labels, z_loss_weight=0.0) gains the new arg
with default 0.0 (back-compat for any other caller).
* Both call sites in PI052Policy.forward read self.config.text_ce_
z_loss_weight and pass it through.
* PI052Config.text_ce_z_loss_weight defaults to 1e-4 (commonly cited
PaLM value); set to 0 to disable.
Cheap to compute: one extra logsumexp shares the softmax kernel that
F.cross_entropy already runs. No memory overhead beyond a (B*T,) tensor.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -180,6 +180,16 @@ class PI052Config(PI05Config):
|
||||
# the same cosine lambda, so the 5x ratio is preserved across decay.
|
||||
lm_head_lr_scale: float = 5.0
|
||||
|
||||
# PaLM-style z-loss on text CE. Penalises the log-partition function
|
||||
# ``z = log Σ exp(logits)`` drifting away from zero — without it, large-
|
||||
# vocab models (PaliGemma is 257k) can let ``logsumexp`` grow unbounded
|
||||
# while CE stays low, because a uniform additive logit bias cancels in
|
||||
# softmax. PaLM appendix B / Chinchilla report z-loss is essential for
|
||||
# stable large-vocab CE; it especially helps under ``lm_head_lr_scale=
|
||||
# 5.0`` which amplifies drift risk on the LM head. ``1e-4`` is the
|
||||
# commonly cited weight; set 0 to disable entirely.
|
||||
text_ce_z_loss_weight: float = 1e-4
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
# Backbone needs gradients flowing through the text head when
|
||||
|
||||
@@ -68,22 +68,35 @@ def _mask_per_sample(per_sample: Tensor, predict_actions_t: Tensor | None) -> Te
|
||||
return (per_sample * mask).sum() / mask.sum().clamp(min=1.0)
|
||||
|
||||
|
||||
def _shifted_ce(logits: Tensor, labels: Tensor) -> Tensor:
|
||||
def _shifted_ce(logits: Tensor, labels: Tensor, z_loss_weight: float = 0.0) -> Tensor:
|
||||
"""Next-token CE: hidden at t predicts label at t+1, ignore_index=-100.
|
||||
|
||||
Mean over non-ignored positions across the batch. Returns 0 cleanly
|
||||
when no positions are supervised (clamp(min=1) on the denominator).
|
||||
|
||||
When ``z_loss_weight > 0``, also adds PaLM-style z-loss
|
||||
(``z² · w``, where ``z = log Σ exp(logits)``) on every supervised
|
||||
position. Penalises the log-partition function drifting away from
|
||||
zero — without it, large-vocab models (PaliGemma is 257k) can let
|
||||
``logsumexp`` grow unboundedly while CE stays low, because uniform
|
||||
additive logit bias cancels in softmax. PaLM appendix B / Chinchilla
|
||||
report this is essential for stable large-vocab CE; cheap insurance
|
||||
here especially with ``lm_head_lr_scale=5.0`` amplifying drift risk.
|
||||
"""
|
||||
shift_logits = logits[:, :-1, :].contiguous()
|
||||
shift_labels = labels[:, 1:].contiguous().long()
|
||||
valid = shift_labels != -100
|
||||
if not bool(valid.any().item()):
|
||||
return shift_logits.sum() * 0.0
|
||||
return F.cross_entropy(
|
||||
shift_logits[valid],
|
||||
shift_labels[valid],
|
||||
reduction="mean",
|
||||
)
|
||||
valid_logits = shift_logits[valid]
|
||||
valid_labels = shift_labels[valid]
|
||||
ce = F.cross_entropy(valid_logits, valid_labels, reduction="mean")
|
||||
if z_loss_weight <= 0.0:
|
||||
return ce
|
||||
# PaLM z-loss: penalise (log Σ exp(logits))² per supervised position.
|
||||
# ``logsumexp`` is numerically stable and shares the softmax kernel.
|
||||
z = torch.logsumexp(valid_logits, dim=-1)
|
||||
return ce + z_loss_weight * (z**2).mean()
|
||||
|
||||
|
||||
def _mark_target_span_causal(
|
||||
@@ -668,7 +681,11 @@ class PI052Policy(PI05Policy):
|
||||
else:
|
||||
text_hidden = prefix_out[:, -lang_len:, :]
|
||||
text_logits = lm_head(text_hidden.to(lm_head.weight.dtype))
|
||||
text_loss = _shifted_ce(text_logits, text_labels)
|
||||
text_loss = _shifted_ce(
|
||||
text_logits,
|
||||
text_labels,
|
||||
z_loss_weight=getattr(self.config, "text_ce_z_loss_weight", 0.0),
|
||||
)
|
||||
|
||||
fast_loss: Tensor | None = None
|
||||
if fast_len > 0 and prefix_out is not None and action_code_mask is not None:
|
||||
@@ -768,7 +785,11 @@ class PI052Policy(PI05Policy):
|
||||
else:
|
||||
text_hidden = vlm_out[:, -lang_len:, :]
|
||||
text_logits = lm_head(text_hidden.to(lm_head.weight.dtype))
|
||||
text_loss = _shifted_ce(text_logits, text_labels)
|
||||
text_loss = _shifted_ce(
|
||||
text_logits,
|
||||
text_labels,
|
||||
z_loss_weight=getattr(self.config, "text_ce_z_loss_weight", 0.0),
|
||||
)
|
||||
|
||||
fast_loss: Tensor | None = None
|
||||
if (
|
||||
|
||||
Reference in New Issue
Block a user