fix(smolvla2): causal mask on the text-CE target span (THE collapse bug)

Root cause of every collapsed inference run. ``embed_prefix`` flags
all language tokens with ``att=0``; ``make_att_2d_masks`` turns that
into a single fully BIDIRECTIONAL block. So during the text-loss
forward, a supervised subtask token's hidden state attends to the
very tokens it is trained to predict. The cross-entropy degenerates
into a copy task — ``text_loss → ~3e-5`` not because the model has
learned to predict subtasks, but because it can see the answer.
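
For reference, a minimal sketch of the cumulative-block rule as it
is described here (the real ``make_att_2d_masks`` may differ in
detail): with every language token at ``att=0`` the cumsum is flat,
so the whole block is mutually visible::

    import torch

    def att_2d_sketch(pad_masks: torch.Tensor, att_masks: torch.Tensor) -> torch.Tensor:
        # Token i may attend to token j iff cumsum(att)[j] <= cumsum(att)[i]
        # and neither position is padding.
        cs = torch.cumsum(att_masks, dim=1)
        att_2d = cs[:, None, :] <= cs[:, :, None]
        return att_2d & pad_masks[:, None, :] & pad_masks[:, :, None]

    att = torch.zeros(1, 4, dtype=torch.long)  # every lang token flagged att=0
    pad = torch.ones(1, 4, dtype=torch.bool)
    print(att_2d_sketch(pad, att)[0])          # all True: one bidirectional block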

At inference ``select_message`` decodes autoregressively (causally):
each token must be predicted WITHOUT seeing it — a task the model
was never actually trained on. Hence the universal collapse: a
coherent first token or two ("grasp the yellow cube"), then a loop
("cover cover cover", "icatorsicators", "the the the").
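
For contrast, a hypothetical greedy decode loop (placeholders, not
``select_message``'s actual code): at step t the model only ever
sees tokens 0..t-1, which is exactly the view the training mask has
to reproduce::

    import torch

    def greedy_decode_sketch(hidden_fn, lm_head, prompt_ids, max_new=16, eos_id=1):
        # hidden_fn, lm_head and eos_id are stand-ins, not the real API.
        ids = prompt_ids
        for _ in range(max_new):
            h = hidden_fn(ids)              # forward over the tokens so far
            nxt = lm_head(h[:, -1]).argmax(-1, keepdim=True)  # predict NEXT token
            ids = torch.cat([ids, nxt], dim=1)
            if (nxt == eos_id).all():
                break
        return ids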

Fix: ``_mark_target_span_causal`` sets ``att=1`` on the language
positions that are supervised targets (``text_labels != -100``).
Under ``make_att_2d_masks``'s cumulative-block rule, each target
token then attends to images + the user prompt bidirectionally and
to EARLIER target tokens only — genuine causal next-token
prediction, matching ``select_message``. Applied in both
``_compute_text_loss`` and ``_compute_fused_loss``. The per-sample
behavior is also correct: high_level_subtask targets become causal,
while low_level_execution subtasks (a user turn, labels all -100)
stay bidirectional, so the action expert still reads them as
bidirectional context. The action expert is otherwise unaffected —
the suffix has a strictly higher cumsum and still attends to the
whole prefix.
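
A toy before/after of the resulting attention pattern under the
same cumulative-block rule (positions: two image tokens, one prompt
token, two target tokens; values illustrative only)::

    import torch

    def att_2d_sketch(att, pad):
        cs = torch.cumsum(att, dim=1)
        return (cs[:, None, :] <= cs[:, :, None]) & pad[:, None, :] & pad[:, :, None]

    pad = torch.ones(1, 5, dtype=torch.bool)
    before = torch.tensor([[0, 0, 0, 0, 0]])  # targets share the bidirectional block
    after = torch.tensor([[0, 0, 0, 1, 1]])   # targets flagged att=1 by the fix

    print(att_2d_sketch(before, pad)[0].int())  # every row all 1s: tgt1 sees tgt2 (copy)
    print(att_2d_sketch(after, pad)[0].int())
    # row of tgt1: [1, 1, 1, 1, 0]  -> images + prompt + itself, not the later target
    # row of tgt2: [1, 1, 1, 1, 1]  -> everything earlier only: causal prediction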

Requires retraining: this changes the training objective. Existing
checkpoints were all trained on the degenerate copy task and cannot
generate text. Expect ``text_loss`` to settle MUCH higher than 3e-5
after this — that is correct; it is now a real prediction task.

NOTE: pi052's text path (PaliGemma prefix-LM) has the same
bidirectional-block structure and needs the analogous fix.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Pepijn committed 2026-05-16 18:24:44 +02:00
parent db03fc6dc4
commit 3cd348ffe2
@@ -85,6 +85,40 @@ def _locate_lang_range(prefix_att_masks: Tensor, num_lang: int) -> tuple[int, in
     return lang_start, lang_end
 
 
+def _mark_target_span_causal(
+    prefix_att_masks: Tensor, text_labels: Tensor, lang_start: int, lang_end: int
+) -> Tensor:
+    """Make the supervised text-target span causally masked.
+
+    ``embed_prefix`` flags every language token with ``att=0``, which
+    ``make_att_2d_masks`` turns into one fully *bidirectional* block —
+    so a target token's hidden state attends to the very tokens it is
+    supposed to predict. The text cross-entropy then degenerates into
+    a copy task (loss → ~0) and the model never learns causal
+    next-token prediction — at inference, where ``select_message``
+    decodes autoregressively (causally), it collapses.
+
+    Fix: set ``att=1`` on the language positions that are supervised
+    targets (``text_labels != -100``). With ``make_att_2d_masks``'s
+    cumulative-block rule each target token then attends to images +
+    the user prompt bidirectionally and to *earlier* target tokens
+    only — i.e. genuine causal next-token prediction, matching
+    inference. Non-target language (the user prompt, and the
+    ``low_level_execution`` subtask which is a user turn, not a
+    target) stays ``att=0`` / bidirectional. The action expert is
+    unaffected: the suffix has a strictly higher cumsum so it still
+    attends to every prefix token.
+    """
+    att = prefix_att_masks.clone()
+    n = min(text_labels.shape[1], lang_end - lang_start)
+    if n <= 0:
+        return att
+    target = text_labels[:, :n] != -100  # (B, n) bool
+    seg = att[:, lang_start : lang_start + n].bool()
+    att[:, lang_start : lang_start + n] = seg | target
+    return att
+
+
 def _shifted_ce(logits: Tensor, text_labels: Tensor) -> Tensor:
     """Next-token CE: hidden at t predicts label at t+1, ignore_index=-100."""
     num_lang = logits.shape[1]
@@ -287,6 +321,12 @@ class SmolVLA2Policy(SmolVLAPolicy):
         prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
             images, img_masks, lang_tokens, lang_masks, state=state
         )
 
+        # Causally mask the supervised target span so the text-CE is
+        # genuine next-token prediction (see ``_mark_target_span_causal``).
+        lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1])
+        prefix_att_masks = _mark_target_span_causal(
+            prefix_att_masks, text_labels, lang_start, lang_end
+        )
         prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
         prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
@@ -306,7 +346,10 @@ class SmolVLA2Policy(SmolVLAPolicy):
                 "states — text-loss path needs them."
             )
 
-        lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1])
+        # ``lang_start`` / ``lang_end`` were located above on the
+        # *unmodified* att masks — don't recompute here, because
+        # ``_mark_target_span_causal`` set target lang tokens to 1 and
+        # ``_locate_lang_range`` keys on the first 1 (the state token).
         vlm = self.model.vlm_with_expert.vlm
         lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype)
         logits = vlm.lm_head(lang_hidden)  # (B, num_lang, vocab)
@@ -363,6 +406,17 @@ class SmolVLA2Policy(SmolVLAPolicy):
         prefix_embs, prefix_pad_masks, prefix_att_masks = inner.embed_prefix(
             images, img_masks, lang_tokens, lang_masks, state=state
         )
 
+        # Causally mask the supervised text-target span (see
+        # ``_mark_target_span_causal``). Per-sample: high_level_subtask
+        # samples have a subtask target → causal; low_level_execution
+        # samples have all -100 labels → untouched / bidirectional, so
+        # the action expert still reads the subtask as bidirectional
+        # context. ``lang_start`` / ``lang_end`` located here on the
+        # unmodified mask and reused for the text-loss slice below.
+        lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1])
+        prefix_att_masks = _mark_target_span_causal(
+            prefix_att_masks, text_labels, lang_start, lang_end
+        )
         suffix_embs, suffix_pad_masks, suffix_att_masks = inner.embed_suffix(x_t, time)
         pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
@@ -412,7 +466,7 @@ class SmolVLA2Policy(SmolVLAPolicy):
         flow_loss = per_sample_flow.mean()
 
         # ---------------- text loss (lang slice of prefix) ---------------
-        lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1])
+        # ``lang_start`` / ``lang_end`` from above (unmodified mask).
         vlm = inner.vlm_with_expert.vlm
         lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype)
         logits = vlm.lm_head(lang_hidden)