fix(smolvla2): select_message must decode from the language position

``embed_prefix`` lays the prefix out as ``[images, lang, state]`` with
the state token LAST. Training supervises the text head on the
*language* positions (``_compute_text_loss`` / ``_compute_fused_loss``
slice ``prefix_out[lang_start:lang_end]`` and run lm_head there).

But ``select_message`` started AR generation from the full prefix and
read ``prefix_out[:, -1:]`` — the **state token** — to decode the
first subtask token. The state token's hidden state exists for the
action expert to read; the lm_head was never trained to produce
subtask text from it. So inference decoded the high-level head from a
position entirely outside the training distribution: the text head
collapses (``the arm the arm``, ``grasp the surface population``,
``_333 absburg…``) no matter how cleanly ``text_loss`` converged.

Fix: truncate the state token off the prefix before the AR loop, so
``prefix_out[:, -1:]`` is the last language token (right after the
``Assistant:`` generation prompt) — exactly where training supervised.

Inference-only change — no retraining needed; existing checkpoints
benefit immediately. The action path (``predict_action_chunk``) is
untouched: state belongs in the action expert's prefix.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-16 15:05:16 +02:00
parent 56068d37ea
commit db03fc6dc4
@@ -520,6 +520,22 @@ class SmolVLA2Policy(SmolVLAPolicy):
images, img_masks, lang_tokens, lang_masks, state=state images, img_masks, lang_tokens, lang_masks, state=state
) )
# ``embed_prefix`` lays the prefix out as ``[images, lang, state]``
# — the state token is LAST. Training supervises the text head on
# the *language* positions (see ``_compute_text_loss`` /
# ``_compute_fused_loss``: lm_head over ``prefix_out[lang_start:
# lang_end]``). So AR text generation must continue from the last
# language token (right after the ``Assistant:`` generation
# prompt) — NOT from the state token, whose hidden state exists
# for the action expert to read and which the lm_head was never
# trained to decode subtask text from. Truncating the state token
# here makes ``prefix_out[:, -1:]`` in the loop below the last
# language position, matching the training distribution.
_, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1])
prefix_embs = prefix_embs[:, :lang_end]
prefix_pad_masks = prefix_pad_masks[:, :lang_end]
prefix_att_masks = prefix_att_masks[:, :lang_end]
device = prefix_embs.device device = prefix_embs.device
bsize = prefix_embs.shape[0] bsize = prefix_embs.shape[0]
vlm = self.model.vlm_with_expert.vlm vlm = self.model.vlm_with_expert.vlm