From 1292304c425db617253cc4556bff8b8f5f3ec496 Mon Sep 17 00:00:00 2001
From: Pepijn
Date: Tue, 12 May 2026 17:49:53 +0200
Subject: [PATCH] fix(smolvla2): suppress all special tokens during
 min_new_tokens window
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The previous attempt masked only the tokenizer's eos_token_id during
the min_new_tokens prefix. The empty-completion symptom persisted
because a memorised SmolVLM head doesn't just want EOS — its top-1 at
position 0 is *some* special token, and when EOS is masked the argmax
shifts to a sibling special (``<|im_end|>``, the image-placeholder
specials, …). Those tokens survive generation but are then stripped
by ``decode(skip_special_tokens=True)``, so the runtime still saw
``last_raw='(empty)'`` at every chunk boundary.

Mask the full ``tokenizer.all_special_ids`` set instead. This forces
the head to commit to a normal vocabulary token before it can close
(or quietly poison) the turn.

Also: when decode returns an empty string but tokens *were* generated,
expose the raw token ids and the special-tokens-included decoded
string via ``policy._last_select_message_debug``. The runtime surfaces
this in the scrollback so the operator can see what the head is
actually emitting — distinguishing "head EOS-ing" from "head emitting
image placeholders" from "head emitting chat-template fragments".

Co-Authored-By: Claude Opus 4.7 (1M context)
---
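[Note below the ``---`` fold, so ``git am`` drops it.]

A minimal repro sketch of the stripping behaviour described above,
assuming a SmolVLM2-style tokenizer. The checkpoint id is illustrative
only, not necessarily the one the policy actually loads:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(
        "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"  # assumed checkpoint
    )
    # A completion made entirely of special ids decodes to '' once
    # skip_special_tokens=True strips them: exactly the
    # last_raw='(empty)' symptom seen at every chunk boundary.
    generated = tok.all_special_ids[:3]
    print(repr(tok.decode(generated, skip_special_tokens=True)))   # ''
    print(repr(tok.decode(generated, skip_special_tokens=False)))  # raw specials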
 .../policies/smolvla2/inference/steps.py   | 21 ++++--
 .../policies/smolvla2/modeling_smolvla2.py | 66 +++++++++++++++----
 2 files changed, 66 insertions(+), 21 deletions(-)

diff --git a/src/lerobot/policies/smolvla2/inference/steps.py b/src/lerobot/policies/smolvla2/inference/steps.py
index 4306f7d2d..03e902dfe 100644
--- a/src/lerobot/policies/smolvla2/inference/steps.py
+++ b/src/lerobot/policies/smolvla2/inference/steps.py
@@ -397,13 +397,20 @@ class HighLevelSubtaskFwd(InferenceStep):
         if not msg:
             empties = state.get("subtask_empty_count", 0) + 1
             state["subtask_empty_count"] = empties
-            if empties == 1 or empties % 10 == 0:
-                push_log(
-                    state,
-                    f" [info] subtask gen returned empty (×{empties}) — "
-                    "model EOS-ing immediately or generation raised "
-                    "(check stderr / -v for traceback).",
-                )
+            if empties == 1 or empties % 5 == 0:
+                debug = getattr(self.policy, "_last_select_message_debug", "") or ""
+                if debug:
+                    push_log(
+                        state,
+                        f" [info] subtask gen empty (×{empties}); {debug}",
+                    )
+                else:
+                    push_log(
+                        state,
+                        f" [info] subtask gen returned empty (×{empties}) — "
+                        "no tokens generated (head EOS-ing before any "
+                        "non-special token).",
+                    )
         if msg and _looks_like_gibberish(msg):
             # Bump a counter so the operator can see the model is
             # struggling without spamming the log every tick. A first
diff --git a/src/lerobot/policies/smolvla2/modeling_smolvla2.py b/src/lerobot/policies/smolvla2/modeling_smolvla2.py
index a808363ef..e0235e53f 100644
--- a/src/lerobot/policies/smolvla2/modeling_smolvla2.py
+++ b/src/lerobot/policies/smolvla2/modeling_smolvla2.py
@@ -304,6 +304,25 @@
         if eos_token_id is None:
             eos_token_id = tokenizer.eos_token_id
 
+        # Build the full set of special-token ids to suppress during
+        # the ``min_new_tokens`` window. EOS alone is not enough on
+        # a memorised SmolVLM head — when EOS is masked, the argmax
+        # simply falls onto a sibling special token (``<|im_end|>``,
+        # the image-placeholder specials, …), which then survives
+        # generation but gets stripped by ``skip_special_tokens=True``,
+        # so ``decode`` returns an empty string and the runtime sees
+        # ``last_raw='(empty)'`` at every chunk boundary. Adding EOS
+        # explicitly covers tokenizers that omit it from the list.
+        special_ids_set: set[int] = set()
+        try:
+            for sid in (tokenizer.all_special_ids or []):
+                if sid is not None:
+                    special_ids_set.add(int(sid))
+        except Exception:  # noqa: BLE001
+            pass
+        if eos_token_id is not None:
+            special_ids_set.add(int(eos_token_id))
+
         # Match training's text-loss forward path (see
         # ``_compute_text_loss`` above): build the full prefix via
         # ``embed_prefix`` so images + state conditioning is intact,
@@ -366,19 +385,20 @@
             last_hidden = prefix_out[:, -1:].to(vlm.lm_head.weight.dtype)
             logits_step = vlm.lm_head(last_hidden)[:, -1]  # (B, V)
 
-            # Suppress EOS until we've decoded ``min_new_tokens`` real
-            # tokens. Without this, a memorised LM head whose argmax
-            # at position 0 is EOS produces an empty completion every
-            # time — confirmed in the real-robot run (the runtime's
-            # ``subtask_empty_count`` climbed every chunk boundary
-            # with no exception). Masking EOS for the first N steps
-            # forces the head to commit to a real token before it can
-            # close the turn.
-            if (
-                eos_token_id is not None
-                and len(generated) < min_new_tokens
-            ):
-                logits_step[..., eos_token_id] = float("-inf")
+            # Suppress *all* special tokens until we've decoded
+            # ``min_new_tokens`` real (renderable) tokens. Without
+            # this, a memorised SmolVLM head whose argmax at position
+            # 0 is a special token produces an empty completion every
+            # time — either EOS directly, or (after we mask EOS) the
+            # argmax shifts to a sibling special id (``<|im_end|>``,
+            # the image-placeholder specials, …) which decode strips
+            # via ``skip_special_tokens=True``. Masking the full
+            # ``all_special_ids`` set for the first N steps forces
+            # the head to commit to a normal vocabulary token before
+            # it can close (or quietly poison) the turn.
+            if special_ids_set and len(generated) < min_new_tokens:
+                for sid in special_ids_set:
+                    logits_step[..., sid] = float("-inf")
             next_ids = self._sample_next_token(logits_step, temperature, top_p)
             tok_id = int(next_ids[0].item())
             generated.append(tok_id)
@@ -393,7 +413,25 @@
                 current_pad = torch.cat([current_pad, ones_step], dim=1)
                 current_att = torch.cat([current_att, ones_step], dim=1)
 
-        return tokenizer.decode(generated, skip_special_tokens=True).strip()
+        decoded = tokenizer.decode(generated, skip_special_tokens=True).strip()
+        # When the visible decoded string is empty but tokens *were*
+        # generated, expose what those raw tokens decode to without
+        # the special-token filter. This is what the runtime turns
+        # into a scrollback line when ``last_raw='(empty)'``, so the
+        # operator can tell whether the head is emitting EOS, image
+        # placeholder tokens, the chat-template ``<|im_end|>`` shard,
+        # or something else.
+        if not decoded and generated:
+            try:
+                self._last_select_message_debug = (
+                    f"raw_ids={generated[:16]} "
+                    f"decoded_w_special={tokenizer.decode(generated, skip_special_tokens=False)!r}"
+                )
+            except Exception:  # noqa: BLE001
+                self._last_select_message_debug = f"raw_ids={generated[:16]}"
+        else:
+            self._last_select_message_debug = ""
+        return decoded
 
     @staticmethod
     def _sample_next_token(
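
[Note, not part of the patch.]

The per-id masking loop above could also be written as a single
advanced-indexing write; a sketch under the assumption that
``logits_step`` is a ``(B, V)`` float tensor (stand-in shapes and ids
below):

    import torch

    special_ids_set = {0, 2, 5}            # stand-in special ids
    logits_step = torch.randn(1, 32)       # stand-in (B, V) logits

    # One indexed write instead of a Python loop over ids.
    special_ids = torch.tensor(sorted(special_ids_set))
    logits_step[..., special_ids] = float("-inf")

    assert int(logits_step.argmax(dim=-1)) not in special_ids_set

For the handful of ids in ``all_special_ids`` either form is cheap;
the indexed write would only start to matter with a much larger id set,
since the mask runs on every decode step inside the window.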