From 3cd348ffe261998ae64482b937020c6f509d6a0b Mon Sep 17 00:00:00 2001 From: Pepijn Date: Sat, 16 May 2026 18:24:44 +0200 Subject: [PATCH] fix(smolvla2): causal mask on the text-CE target span (THE collapse bug) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Root cause of every collapsed inference run. ``embed_prefix`` flags all language tokens ``att=0``; ``make_att_2d_masks`` turns that into a single fully BIDIRECTIONAL block. So during the text-loss forward, a supervised subtask token's hidden state attends to the very tokens it is trained to predict. The cross-entropy degenerates into a copy task — ``text_loss → ~3e-5`` not because the model learned to predict subtasks but because it can see the answer. At inference ``select_message`` decodes autoregressively (causally): each token must be predicted WITHOUT seeing it — a task the model was never actually trained on. Hence the universal collapse: a coherent first token or two ("grasp the yellow cube"), then a loop ("cover cover cover", "icatorsicators", "the the the"). Fix: ``_mark_target_span_causal`` sets ``att=1`` on the language positions that are supervised targets (``text_labels != -100``). With make_att_2d_masks's cumulative-block rule each target token then attends to images + the user prompt bidirectionally and to EARLIER target tokens only — genuine causal next-token prediction, matching select_message. Applied in both ``_compute_text_loss`` and ``_compute_fused_loss``. Per-sample correct: high_level_subtask targets become causal; low_level_execution subtasks (a user turn, labels all -100) stay bidirectional so the action expert reads them as bidirectional context. The action expert is otherwise unaffected — the suffix has a strictly higher cumsum and still attends to the whole prefix. Requires retraining: this changes the training objective. Existing checkpoints were all trained on the degenerate copy task and cannot generate text. Expect ``text_loss`` to settle MUCH higher than 3e-5 after this — that is correct; it is now a real prediction task. NOTE: pi052's text path (PaliGemma prefix-LM) has the same bidirectional-block structure and needs the analogous fix. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../policies/smolvla2/modeling_smolvla2.py | 58 ++++++++++++++++++- 1 file changed, 56 insertions(+), 2 deletions(-) diff --git a/src/lerobot/policies/smolvla2/modeling_smolvla2.py b/src/lerobot/policies/smolvla2/modeling_smolvla2.py index c164c7785..08660aa00 100644 --- a/src/lerobot/policies/smolvla2/modeling_smolvla2.py +++ b/src/lerobot/policies/smolvla2/modeling_smolvla2.py @@ -85,6 +85,40 @@ def _locate_lang_range(prefix_att_masks: Tensor, num_lang: int) -> tuple[int, in return lang_start, lang_end +def _mark_target_span_causal( + prefix_att_masks: Tensor, text_labels: Tensor, lang_start: int, lang_end: int +) -> Tensor: + """Make the supervised text-target span causally masked. + + ``embed_prefix`` flags every language token with ``att=0``, which + ``make_att_2d_masks`` turns into one fully *bidirectional* block — + so a target token's hidden state attends to the very tokens it is + supposed to predict. The text cross-entropy then degenerates into + a copy task (loss → ~0) and the model never learns causal + next-token prediction — at inference, where ``select_message`` + decodes autoregressively (causally), it collapses. + + Fix: set ``att=1`` on the language positions that are supervised + targets (``text_labels != -100``). With ``make_att_2d_masks``'s + cumulative-block rule each target token then attends to images + + the user prompt bidirectionally and to *earlier* target tokens + only — i.e. genuine causal next-token prediction, matching + inference. Non-target language (the user prompt, and the + ``low_level_execution`` subtask which is a user turn, not a + target) stays ``att=0`` / bidirectional. The action expert is + unaffected: the suffix has a strictly higher cumsum so it still + attends to every prefix token. + """ + att = prefix_att_masks.clone() + n = min(text_labels.shape[1], lang_end - lang_start) + if n <= 0: + return att + target = text_labels[:, :n] != -100 # (B, n) bool + seg = att[:, lang_start : lang_start + n].bool() + att[:, lang_start : lang_start + n] = seg | target + return att + + def _shifted_ce(logits: Tensor, text_labels: Tensor) -> Tensor: """Next-token CE: hidden at t predicts label at t+1, ignore_index=-100.""" num_lang = logits.shape[1] @@ -287,6 +321,12 @@ class SmolVLA2Policy(SmolVLAPolicy): prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix( images, img_masks, lang_tokens, lang_masks, state=state ) + # Causally mask the supervised target span so the text-CE is + # genuine next-token prediction (see ``_mark_target_span_causal``). + lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1]) + prefix_att_masks = _mark_target_span_causal( + prefix_att_masks, text_labels, lang_start, lang_end + ) prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 @@ -306,7 +346,10 @@ class SmolVLA2Policy(SmolVLAPolicy): "states — text-loss path needs them." ) - lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1]) + # ``lang_start`` / ``lang_end`` were located above on the + # *unmodified* att masks — don't recompute here, because + # ``_mark_target_span_causal`` set target lang tokens to 1 and + # ``_locate_lang_range`` keys on the first 1 (the state token). vlm = self.model.vlm_with_expert.vlm lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype) logits = vlm.lm_head(lang_hidden) # (B, num_lang, vocab) @@ -363,6 +406,17 @@ class SmolVLA2Policy(SmolVLAPolicy): prefix_embs, prefix_pad_masks, prefix_att_masks = inner.embed_prefix( images, img_masks, lang_tokens, lang_masks, state=state ) + # Causally mask the supervised text-target span (see + # ``_mark_target_span_causal``). Per-sample: high_level_subtask + # samples have a subtask target → causal; low_level_execution + # samples have all -100 labels → untouched / bidirectional, so + # the action expert still reads the subtask as bidirectional + # context. ``lang_start`` / ``lang_end`` located here on the + # unmodified mask and reused for the text-loss slice below. + lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1]) + prefix_att_masks = _mark_target_span_causal( + prefix_att_masks, text_labels, lang_start, lang_end + ) suffix_embs, suffix_pad_masks, suffix_att_masks = inner.embed_suffix(x_t, time) pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) @@ -412,7 +466,7 @@ class SmolVLA2Policy(SmolVLAPolicy): flow_loss = per_sample_flow.mean() # ---------------- text loss (lang slice of prefix) --------------- - lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1]) + # ``lang_start`` / ``lang_end`` from above (unmodified mask). vlm = inner.vlm_with_expert.vlm lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype) logits = vlm.lm_head(lang_hidden)