mirror of https://github.com/huggingface/lerobot.git, synced 2026-05-17 09:39:47 +00:00
fix(smolvla2): select_message must decode from the language position
``embed_prefix`` lays the prefix out as ``[images, lang, state]`` with the state token LAST. Training supervises the text head on the *language* positions (``_compute_text_loss`` / ``_compute_fused_loss`` slice ``prefix_out[lang_start:lang_end]`` and run lm_head there). But ``select_message`` started AR generation from the full prefix and read ``prefix_out[:, -1:]`` — the **state token** — to decode the first subtask token.

The state token's hidden state exists for the action expert to read; the lm_head was never trained to produce subtask text from it. So inference decoded the high-level head from a position entirely outside the training distribution: the text head collapses (``the arm the arm``, ``grasp the surface population``, ``_333 absburg…``) no matter how cleanly ``text_loss`` converged.

Fix: truncate the state token off the prefix before the AR loop, so ``prefix_out[:, -1:]`` is the last language token (right after the ``Assistant:`` generation prompt) — exactly where training supervised.

Inference-only change — no retraining needed; existing checkpoints benefit immediately. The action path (``predict_action_chunk``) is untouched: state belongs in the action expert's prefix.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
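Below is a minimal, self-contained sketch of the position mismatch described above. It is not the repository's code: ``n_img``, ``n_lang``, ``n_state``, and the random hidden states are made-up stand-ins, and the real model's lm_head and padding handling are omitted. It only illustrates which index of ``prefix_out`` the old and new code read.

```python
import torch

n_img, n_lang, n_state, hidden = 64, 12, 1, 16
# Prefix hidden states laid out as [images, lang, state], state token last.
prefix_out = torch.randn(1, n_img + n_lang + n_state, hidden)

lang_start, lang_end = n_img, n_img + n_lang

# Training: the text loss only ever runs lm_head over the language span.
supervised = prefix_out[:, lang_start:lang_end]

# Old inference: first subtask token decoded from position -1 of the full
# prefix, i.e. the state token (index n_img + n_lang), never supervised.
buggy_first = prefix_out[:, -1:]

# Fixed inference: truncate the state token first, so -1 is the last
# language position, inside the supervised span.
fixed_first = prefix_out[:, :lang_end][:, -1:]

assert torch.equal(fixed_first, supervised[:, -1:])
assert not torch.equal(buggy_first, fixed_first)
```

The second assertion only holds here because the random tensors differ; in the real model the point is that the state position was never supervised by the text loss, not that its values happen to differ.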
@@ -520,6 +520,22 @@ class SmolVLA2Policy(SmolVLAPolicy):
             images, img_masks, lang_tokens, lang_masks, state=state
         )
 
+        # ``embed_prefix`` lays the prefix out as ``[images, lang, state]``
+        # — the state token is LAST. Training supervises the text head on
+        # the *language* positions (see ``_compute_text_loss`` /
+        # ``_compute_fused_loss``: lm_head over ``prefix_out[lang_start:
+        # lang_end]``). So AR text generation must continue from the last
+        # language token (right after the ``Assistant:`` generation
+        # prompt) — NOT from the state token, whose hidden state exists
+        # for the action expert to read and which the lm_head was never
+        # trained to decode subtask text from. Truncating the state token
+        # here makes ``prefix_out[:, -1:]`` in the loop below the last
+        # language position, matching the training distribution.
+        _, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1])
+        prefix_embs = prefix_embs[:, :lang_end]
+        prefix_pad_masks = prefix_pad_masks[:, :lang_end]
+        prefix_att_masks = prefix_att_masks[:, :lang_end]
+
         device = prefix_embs.device
         bsize = prefix_embs.shape[0]
         vlm = self.model.vlm_with_expert.vlm
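``_locate_lang_range`` is called in the hunk but defined elsewhere in the file, so its body is not shown here. A hypothetical reconstruction of what such a helper could look like, assuming the ``[images, lang, state]`` layout with the state token(s) at the tail of the prefix (the argument list is taken from the call site above; the body is a guess, not the actual implementation):

```python
import torch

def _locate_lang_range_sketch(
    prefix_att_masks: torch.Tensor,  # (batch, prefix_len) attention mask over the prefix
    n_lang: int,                     # number of language positions (lang_tokens.shape[1])
    n_state: int = 1,                # assumed: a single state token appended last
) -> tuple[int, int]:
    # Assumed layout: [images, lang, state]; the state token(s) occupy the
    # tail, so the language block ends right before them.
    prefix_len = prefix_att_masks.shape[1]
    lang_end = prefix_len - n_state
    lang_start = lang_end - n_lang
    return lang_start, lang_end
```

Under that assumption, ``prefix_embs[:, :lang_end]`` keeps the images plus all language positions and drops only the state token, which is what the hunk applies to the embeddings and both masks.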