mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 01:30:14 +00:00
fix(smolvla2): causal mask on the text-CE target span (THE collapse bug)
Root cause of every collapsed inference run. ``embed_prefix`` flags
all language tokens ``att=0``; ``make_att_2d_masks`` turns that into
a single fully BIDIRECTIONAL block. So during the text-loss forward,
a supervised subtask token's hidden state attends to the very tokens
it is trained to predict. The cross-entropy degenerates into a copy
task — ``text_loss → ~3e-5`` not because the model learned to
predict subtasks but because it can see the answer.
At inference ``select_message`` decodes autoregressively (causally):
each token must be predicted WITHOUT seeing it — a task the model
was never actually trained on. Hence the universal collapse: a
coherent first token or two ("grasp the yellow cube"), then a loop
("cover cover cover", "icatorsicators", "the the the").
Fix: ``_mark_target_span_causal`` sets ``att=1`` on the language
positions that are supervised targets (``text_labels != -100``).
With make_att_2d_masks's cumulative-block rule each target token
then attends to images + the user prompt bidirectionally and to
EARLIER target tokens only — genuine causal next-token prediction,
matching select_message. Applied in both ``_compute_text_loss`` and
``_compute_fused_loss``. Per-sample correct: high_level_subtask
targets become causal; low_level_execution subtasks (a user turn,
labels all -100) stay bidirectional so the action expert reads them
as bidirectional context. The action expert is otherwise unaffected
— the suffix has a strictly higher cumsum and still attends to the
whole prefix.
Requires retraining: this changes the training objective. Existing
checkpoints were all trained on the degenerate copy task and cannot
generate text. Expect ``text_loss`` to settle MUCH higher than 3e-5
after this — that is correct; it is now a real prediction task.
NOTE: pi052's text path (PaliGemma prefix-LM) has the same
bidirectional-block structure and needs the analogous fix.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -85,6 +85,40 @@ def _locate_lang_range(prefix_att_masks: Tensor, num_lang: int) -> tuple[int, in
|
||||
return lang_start, lang_end
|
||||
|
||||
|
||||
def _mark_target_span_causal(
|
||||
prefix_att_masks: Tensor, text_labels: Tensor, lang_start: int, lang_end: int
|
||||
) -> Tensor:
|
||||
"""Make the supervised text-target span causally masked.
|
||||
|
||||
``embed_prefix`` flags every language token with ``att=0``, which
|
||||
``make_att_2d_masks`` turns into one fully *bidirectional* block —
|
||||
so a target token's hidden state attends to the very tokens it is
|
||||
supposed to predict. The text cross-entropy then degenerates into
|
||||
a copy task (loss → ~0) and the model never learns causal
|
||||
next-token prediction — at inference, where ``select_message``
|
||||
decodes autoregressively (causally), it collapses.
|
||||
|
||||
Fix: set ``att=1`` on the language positions that are supervised
|
||||
targets (``text_labels != -100``). With ``make_att_2d_masks``'s
|
||||
cumulative-block rule each target token then attends to images +
|
||||
the user prompt bidirectionally and to *earlier* target tokens
|
||||
only — i.e. genuine causal next-token prediction, matching
|
||||
inference. Non-target language (the user prompt, and the
|
||||
``low_level_execution`` subtask which is a user turn, not a
|
||||
target) stays ``att=0`` / bidirectional. The action expert is
|
||||
unaffected: the suffix has a strictly higher cumsum so it still
|
||||
attends to every prefix token.
|
||||
"""
|
||||
att = prefix_att_masks.clone()
|
||||
n = min(text_labels.shape[1], lang_end - lang_start)
|
||||
if n <= 0:
|
||||
return att
|
||||
target = text_labels[:, :n] != -100 # (B, n) bool
|
||||
seg = att[:, lang_start : lang_start + n].bool()
|
||||
att[:, lang_start : lang_start + n] = seg | target
|
||||
return att
|
||||
|
||||
|
||||
def _shifted_ce(logits: Tensor, text_labels: Tensor) -> Tensor:
|
||||
"""Next-token CE: hidden at t predicts label at t+1, ignore_index=-100."""
|
||||
num_lang = logits.shape[1]
|
||||
@@ -287,6 +321,12 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks, state=state
|
||||
)
|
||||
# Causally mask the supervised target span so the text-CE is
|
||||
# genuine next-token prediction (see ``_mark_target_span_causal``).
|
||||
lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1])
|
||||
prefix_att_masks = _mark_target_span_causal(
|
||||
prefix_att_masks, text_labels, lang_start, lang_end
|
||||
)
|
||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
@@ -306,7 +346,10 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
||||
"states — text-loss path needs them."
|
||||
)
|
||||
|
||||
lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1])
|
||||
# ``lang_start`` / ``lang_end`` were located above on the
|
||||
# *unmodified* att masks — don't recompute here, because
|
||||
# ``_mark_target_span_causal`` set target lang tokens to 1 and
|
||||
# ``_locate_lang_range`` keys on the first 1 (the state token).
|
||||
vlm = self.model.vlm_with_expert.vlm
|
||||
lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype)
|
||||
logits = vlm.lm_head(lang_hidden) # (B, num_lang, vocab)
|
||||
@@ -363,6 +406,17 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = inner.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks, state=state
|
||||
)
|
||||
# Causally mask the supervised text-target span (see
|
||||
# ``_mark_target_span_causal``). Per-sample: high_level_subtask
|
||||
# samples have a subtask target → causal; low_level_execution
|
||||
# samples have all -100 labels → untouched / bidirectional, so
|
||||
# the action expert still reads the subtask as bidirectional
|
||||
# context. ``lang_start`` / ``lang_end`` located here on the
|
||||
# unmodified mask and reused for the text-loss slice below.
|
||||
lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1])
|
||||
prefix_att_masks = _mark_target_span_causal(
|
||||
prefix_att_masks, text_labels, lang_start, lang_end
|
||||
)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = inner.embed_suffix(x_t, time)
|
||||
|
||||
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||
@@ -412,7 +466,7 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
||||
flow_loss = per_sample_flow.mean()
|
||||
|
||||
# ---------------- text loss (lang slice of prefix) ---------------
|
||||
lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1])
|
||||
# ``lang_start`` / ``lang_end`` from above (unmodified mask).
|
||||
vlm = inner.vlm_with_expert.vlm
|
||||
lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype)
|
||||
logits = vlm.lm_head(lang_hidden)
|
||||
|
||||
Reference in New Issue
Block a user