diff --git a/src/lerobot/policies/smolvla2/modeling_smolvla2.py b/src/lerobot/policies/smolvla2/modeling_smolvla2.py
index c164c7785..08660aa00 100644
--- a/src/lerobot/policies/smolvla2/modeling_smolvla2.py
+++ b/src/lerobot/policies/smolvla2/modeling_smolvla2.py
@@ -85,6 +85,40 @@ def _locate_lang_range(prefix_att_masks: Tensor, num_lang: int) -> tuple[int, in
     return lang_start, lang_end
 
 
+def _mark_target_span_causal(
+    prefix_att_masks: Tensor, text_labels: Tensor, lang_start: int, lang_end: int
+) -> Tensor:
+    """Make the supervised text-target span causally masked.
+
+    ``embed_prefix`` flags every language token with ``att=0``, which
+    ``make_att_2d_masks`` turns into one fully *bidirectional* block —
+    so a target token's hidden state attends to the very tokens it is
+    supposed to predict. The text cross-entropy then degenerates into
+    a copy task (loss → ~0) and the model never learns causal
+    next-token prediction — at inference, where ``select_message``
+    decodes autoregressively (causally), it collapses.
+
+    Fix: set ``att=1`` on the language positions that are supervised
+    targets (``text_labels != -100``). With ``make_att_2d_masks``'s
+    cumulative-block rule each target token then attends to images +
+    the user prompt bidirectionally and to *earlier* target tokens
+    only — i.e. genuine causal next-token prediction, matching
+    inference. Non-target language (the user prompt, and the
+    ``low_level_execution`` subtask which is a user turn, not a
+    target) stays ``att=0`` / bidirectional. The action expert is
+    unaffected: the suffix has a strictly higher cumsum so it still
+    attends to every prefix token.
+    """
+    att = prefix_att_masks.clone()
+    n = min(text_labels.shape[1], lang_end - lang_start)
+    if n <= 0:
+        return att
+    target = text_labels[:, :n] != -100  # (B, n) bool
+    seg = att[:, lang_start : lang_start + n].bool()
+    att[:, lang_start : lang_start + n] = seg | target
+    return att
+
+
 def _shifted_ce(logits: Tensor, text_labels: Tensor) -> Tensor:
     """Next-token CE: hidden at t predicts label at t+1, ignore_index=-100."""
     num_lang = logits.shape[1]
@@ -287,6 +321,12 @@ class SmolVLA2Policy(SmolVLAPolicy):
         prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
             images, img_masks, lang_tokens, lang_masks, state=state
         )
+        # Causally mask the supervised target span so the text-CE is
+        # genuine next-token prediction (see ``_mark_target_span_causal``).
+        lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1])
+        prefix_att_masks = _mark_target_span_causal(
+            prefix_att_masks, text_labels, lang_start, lang_end
+        )
         prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
         prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
 
@@ -306,7 +346,10 @@ class SmolVLA2Policy(SmolVLAPolicy):
                 "states — text-loss path needs them."
             )
 
-        lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1])
+        # ``lang_start`` / ``lang_end`` were located above on the
+        # *unmodified* att masks — don't recompute here, because
+        # ``_mark_target_span_causal`` set target lang tokens to 1 and
+        # ``_locate_lang_range`` keys on the first 1 (the state token).
         vlm = self.model.vlm_with_expert.vlm
         lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype)
         logits = vlm.lm_head(lang_hidden)  # (B, num_lang, vocab)
@@ -363,6 +406,17 @@ class SmolVLA2Policy(SmolVLAPolicy):
         prefix_embs, prefix_pad_masks, prefix_att_masks = inner.embed_prefix(
             images, img_masks, lang_tokens, lang_masks, state=state
         )
+        # Causally mask the supervised text-target span (see
+        # ``_mark_target_span_causal``). Per-sample: high_level_subtask
+        # samples have a subtask target → causal; low_level_execution
+        # samples have all -100 labels → untouched / bidirectional, so
+        # the action expert still reads the subtask as bidirectional
+        # context. ``lang_start`` / ``lang_end`` located here on the
+        # unmodified mask and reused for the text-loss slice below.
+        lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1])
+        prefix_att_masks = _mark_target_span_causal(
+            prefix_att_masks, text_labels, lang_start, lang_end
+        )
         suffix_embs, suffix_pad_masks, suffix_att_masks = inner.embed_suffix(x_t, time)
 
         pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
@@ -412,7 +466,7 @@ class SmolVLA2Policy(SmolVLAPolicy):
         flow_loss = per_sample_flow.mean()
 
         # ---------------- text loss (lang slice of prefix) ---------------
-        lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1])
+        # ``lang_start`` / ``lang_end`` from above (unmodified mask).
         vlm = inner.vlm_with_expert.vlm
         lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype)
         logits = vlm.lm_head(lang_hidden)
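
The docstring of ``_mark_target_span_causal`` leans on the cumulative-block rule in ``make_att_2d_masks``. Below is a minimal sketch of that rule as the docstring describes it (token i may attend to token j iff the cumulative sum of the att flags at j is <= the one at i); the helper name ``att_2d_from_flags`` and the toy token layout are illustrative only, not the library code.

import torch

def att_2d_from_flags(att_flags: torch.Tensor) -> torch.Tensor:
    # att=0 keeps a token in the previous block (bidirectional); each att=1
    # opens a new block, so consecutive att=1 tokens only see earlier ones.
    c = torch.cumsum(att_flags, dim=1)
    return c[:, None, :] <= c[:, :, None]  # (B, N, N) bool, True = may attend

# Toy language slice: [prompt, prompt, prompt, target, target, target]
before = torch.tensor([[0, 0, 0, 0, 0, 0]])  # as embed_prefix leaves it: one bidirectional block
after = torch.tensor([[0, 0, 0, 1, 1, 1]])   # target span flipped by _mark_target_span_causal
print(att_2d_from_flags(before)[0].int())  # every token sees every token, targets included
print(att_2d_from_flags(after)[0].int())   # targets see the prompt + earlier targets only

With the ``after`` flags the prompt rows stay bidirectional among themselves and never attend forward into the target span, while each target row sees the full prompt plus only the targets before it, which is exactly the causal next-token setup the patch aims for.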
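The text loss itself flows through ``_shifted_ce`` (visible as context in the first hunk). A hedged sketch of the shift its one-line docstring describes, written out here for reference rather than copied from the repo:

import torch
import torch.nn.functional as F

def shifted_ce_sketch(logits: torch.Tensor, text_labels: torch.Tensor) -> torch.Tensor:
    # Hidden state at position t predicts the label at t + 1; positions labeled
    # -100 (prompt tokens, padding, low_level_execution samples) are ignored.
    num_lang = logits.shape[1]
    pred = logits[:, : num_lang - 1].reshape(-1, logits.shape[-1])
    tgt = text_labels[:, 1:num_lang].reshape(-1)
    return F.cross_entropy(pred, tgt, ignore_index=-100)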