mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
fix(smolvla2): causal mask on the text-CE target span (THE collapse bug)
Root cause of every collapsed inference run. ``embed_prefix`` flags
all language tokens ``att=0``; ``make_att_2d_masks`` turns that into
a single fully BIDIRECTIONAL block. So during the text-loss forward,
a supervised subtask token's hidden state attends to the very tokens
it is trained to predict. The cross-entropy degenerates into a copy
task — ``text_loss → ~3e-5`` not because the model learned to
predict subtasks but because it can see the answer.
At inference ``select_message`` decodes autoregressively (causally):
each token must be predicted WITHOUT seeing it — a task the model
was never actually trained on. Hence the universal collapse: a
coherent first token or two ("grasp the yellow cube"), then a loop
("cover cover cover", "icatorsicators", "the the the").
Fix: ``_mark_target_span_causal`` sets ``att=1`` on the language
positions that are supervised targets (``text_labels != -100``).
With make_att_2d_masks's cumulative-block rule each target token
then attends to images + the user prompt bidirectionally and to
EARLIER target tokens only — genuine causal next-token prediction,
matching select_message. Applied in both ``_compute_text_loss`` and
``_compute_fused_loss``. Per-sample correct: high_level_subtask
targets become causal; low_level_execution subtasks (a user turn,
labels all -100) stay bidirectional so the action expert reads them
as bidirectional context. The action expert is otherwise unaffected
— the suffix has a strictly higher cumsum and still attends to the
whole prefix.
Requires retraining: this changes the training objective. Existing
checkpoints were all trained on the degenerate copy task and cannot
generate text. Expect ``text_loss`` to settle MUCH higher than 3e-5
after this — that is correct; it is now a real prediction task.
NOTE: pi052's text path (PaliGemma prefix-LM) has the same
bidirectional-block structure and needs the analogous fix.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -85,6 +85,40 @@ def _locate_lang_range(prefix_att_masks: Tensor, num_lang: int) -> tuple[int, in
|
|||||||
return lang_start, lang_end
|
return lang_start, lang_end
|
||||||
|
|
||||||
|
|
||||||
|
def _mark_target_span_causal(
    prefix_att_masks: Tensor, text_labels: Tensor, lang_start: int, lang_end: int
) -> Tensor:
    """Make the supervised text-target span causally masked.

    ``embed_prefix`` flags every language token with ``att=0``, which
    ``make_att_2d_masks`` turns into one fully *bidirectional* block —
    so a target token's hidden state attends to the very tokens it is
    supposed to predict. The text cross-entropy then degenerates into
    a copy task (loss → ~0) and the model never learns causal
    next-token prediction — at inference, where ``select_message``
    decodes autoregressively (causally), it collapses.

    Fix: set ``att=1`` on the language positions that are supervised
    targets (``text_labels != -100``). With ``make_att_2d_masks``'s
    cumulative-block rule each target token then attends to images +
    the user prompt bidirectionally and to *earlier* target tokens
    only — i.e. genuine causal next-token prediction, matching
    inference. Non-target language (the user prompt, and the
    ``low_level_execution`` subtask which is a user turn, not a
    target) stays ``att=0`` / bidirectional. The action expert is
    unaffected: the suffix has a strictly higher cumsum so it still
    attends to every prefix token.

    Args:
        prefix_att_masks: 2-D ``(B, prefix_len)`` attention-block flags.
            The input is never modified; a clone is edited and returned.
        text_labels: 2-D ``(B, num_labels)`` label ids; ``-100`` marks
            positions that are not supervised. Column ``i`` is taken to
            describe the prefix position ``lang_start + i``.
        lang_start: Index of the first language token in the prefix.
        lang_end: Index one past the last language token in the prefix.

    Returns:
        A copy of ``prefix_att_masks`` (same dtype) in which every
        supervised language position is set, and every other position
        keeps its original value.
    """
    att = prefix_att_masks.clone()
    # Clip the span to the label width, the language range AND the mask's
    # own width, so the mask slice and the label slice below always have
    # the same length (a language range overshooting the mask would
    # otherwise fail with a shape mismatch in the ``|``).
    n = min(
        text_labels.shape[1],
        lang_end - lang_start,
        att.shape[1] - lang_start,
    )
    if n <= 0:
        return att
    span = slice(lang_start, lang_start + n)
    target = text_labels[:, :n] != -100  # (B, n) bool
    # OR with the existing flags: positions that were already set stay
    # set; supervised positions become set.
    att[:, span] = att[:, span].bool() | target
    return att
|
||||||
|
|
||||||
|
|
||||||
def _shifted_ce(logits: Tensor, text_labels: Tensor) -> Tensor:
|
def _shifted_ce(logits: Tensor, text_labels: Tensor) -> Tensor:
|
||||||
"""Next-token CE: hidden at t predicts label at t+1, ignore_index=-100."""
|
"""Next-token CE: hidden at t predicts label at t+1, ignore_index=-100."""
|
||||||
num_lang = logits.shape[1]
|
num_lang = logits.shape[1]
|
||||||
@@ -287,6 +321,12 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
|||||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
|
prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
|
||||||
images, img_masks, lang_tokens, lang_masks, state=state
|
images, img_masks, lang_tokens, lang_masks, state=state
|
||||||
)
|
)
|
||||||
|
# Causally mask the supervised target span so the text-CE is
|
||||||
|
# genuine next-token prediction (see ``_mark_target_span_causal``).
|
||||||
|
lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1])
|
||||||
|
prefix_att_masks = _mark_target_span_causal(
|
||||||
|
prefix_att_masks, text_labels, lang_start, lang_end
|
||||||
|
)
|
||||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||||
|
|
||||||
@@ -306,7 +346,10 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
|||||||
"states — text-loss path needs them."
|
"states — text-loss path needs them."
|
||||||
)
|
)
|
||||||
|
|
||||||
lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1])
|
# ``lang_start`` / ``lang_end`` were located above on the
|
||||||
|
# *unmodified* att masks — don't recompute here, because
|
||||||
|
# ``_mark_target_span_causal`` set target lang tokens to 1 and
|
||||||
|
# ``_locate_lang_range`` keys on the first 1 (the state token).
|
||||||
vlm = self.model.vlm_with_expert.vlm
|
vlm = self.model.vlm_with_expert.vlm
|
||||||
lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype)
|
lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype)
|
||||||
logits = vlm.lm_head(lang_hidden) # (B, num_lang, vocab)
|
logits = vlm.lm_head(lang_hidden) # (B, num_lang, vocab)
|
||||||
@@ -363,6 +406,17 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
|||||||
prefix_embs, prefix_pad_masks, prefix_att_masks = inner.embed_prefix(
|
prefix_embs, prefix_pad_masks, prefix_att_masks = inner.embed_prefix(
|
||||||
images, img_masks, lang_tokens, lang_masks, state=state
|
images, img_masks, lang_tokens, lang_masks, state=state
|
||||||
)
|
)
|
||||||
|
# Causally mask the supervised text-target span (see
|
||||||
|
# ``_mark_target_span_causal``). Per-sample: high_level_subtask
|
||||||
|
# samples have a subtask target → causal; low_level_execution
|
||||||
|
# samples have all -100 labels → untouched / bidirectional, so
|
||||||
|
# the action expert still reads the subtask as bidirectional
|
||||||
|
# context. ``lang_start`` / ``lang_end`` located here on the
|
||||||
|
# unmodified mask and reused for the text-loss slice below.
|
||||||
|
lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1])
|
||||||
|
prefix_att_masks = _mark_target_span_causal(
|
||||||
|
prefix_att_masks, text_labels, lang_start, lang_end
|
||||||
|
)
|
||||||
suffix_embs, suffix_pad_masks, suffix_att_masks = inner.embed_suffix(x_t, time)
|
suffix_embs, suffix_pad_masks, suffix_att_masks = inner.embed_suffix(x_t, time)
|
||||||
|
|
||||||
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||||
@@ -412,7 +466,7 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
|||||||
flow_loss = per_sample_flow.mean()
|
flow_loss = per_sample_flow.mean()
|
||||||
|
|
||||||
# ---------------- text loss (lang slice of prefix) ---------------
|
# ---------------- text loss (lang slice of prefix) ---------------
|
||||||
lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1])
|
# ``lang_start`` / ``lang_end`` from above (unmodified mask).
|
||||||
vlm = inner.vlm_with_expert.vlm
|
vlm = inner.vlm_with_expert.vlm
|
||||||
lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype)
|
lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype)
|
||||||
logits = vlm.lm_head(lang_hidden)
|
logits = vlm.lm_head(lang_hidden)
|
||||||
|
|||||||
Reference in New Issue
Block a user