fix(smolvla2): suppress all special tokens during min_new_tokens window

Previous attempt only masked the tokenizer's eos_token_id during the
min_new_tokens prefix. The empty-completion symptom persisted because a
memorised SmolVLM head doesn't just want EOS — its top-1 at position 0
is *some* special token, and when EOS is masked the argmax shifts to a
sibling (``<|im_end|>``, ``<image>``, ``<fake_token_around_image>``,
``<row_X_col_Y>``, …). Those tokens survive generation but then get
stripped by ``decode(skip_special_tokens=True)``, so the runtime still
saw ``last_raw='(empty)'`` every chunk boundary.

Mask the full ``tokenizer.all_special_ids`` set instead. Forces the
head to commit to a normal vocabulary token before it can close or
quietly poison the turn.

Also: when decode returns empty but tokens *were* generated, expose
the raw token ids and the special-tokens-included decoded string via
``policy._last_select_message_debug``. The runtime surfaces this in
the scrollback so the operator can see what the head is actually
emitting — distinguishing "head EOS-ing" from "head emitting image
placeholders" from "head emitting chat-template fragments".

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-12 17:49:53 +02:00
parent b95eebff77
commit 1292304c42
2 changed files with 66 additions and 21 deletions
@@ -397,13 +397,20 @@ class HighLevelSubtaskFwd(InferenceStep):
if not msg:
empties = state.get("subtask_empty_count", 0) + 1
state["subtask_empty_count"] = empties
if empties == 1 or empties % 10 == 0:
push_log(
state,
f" [info] subtask gen returned empty (×{empties}) — "
"model EOS-ing immediately or generation raised "
"(check stderr / -v for traceback).",
)
if empties == 1 or empties % 5 == 0:
debug = getattr(self.policy, "_last_select_message_debug", "") or ""
if debug:
push_log(
state,
f" [info] subtask gen empty (×{empties}); {debug}",
)
else:
push_log(
state,
f" [info] subtask gen returned empty (×{empties}) — "
"no tokens generated (head EOS-ing before any "
"non-special token).",
)
if msg and _looks_like_gibberish(msg):
# Bump a counter so the operator can see the model is
# struggling without spamming the log every tick. A first
@@ -304,6 +304,25 @@ class SmolVLA2Policy(SmolVLAPolicy):
if eos_token_id is None:
eos_token_id = tokenizer.eos_token_id
# Build the full set of special-token ids to suppress during
# the ``min_new_tokens`` window. EOS alone is not enough on a
# memorised SmolVLM head — when EOS is masked, the argmax
# falls onto a sibling special token (``<|im_end|>``,
# ``<image>``, ``<fake_token_around_image>``, ``<row_X_col_Y>``,
# …) which then survives generation but gets stripped by
# ``skip_special_tokens=True`` so ``decode`` returns an empty
# string and the runtime sees ``last_raw='(empty)'`` every
# chunk boundary.
special_ids_set: set[int] = set()
try:
for sid in (tokenizer.all_special_ids or []):
if sid is not None:
special_ids_set.add(int(sid))
except Exception: # noqa: BLE001
pass
if eos_token_id is not None:
special_ids_set.add(int(eos_token_id))
# Match training's text-loss forward path (see
# ``_compute_text_loss`` above): build the full prefix via
# ``embed_prefix`` so images + state conditioning is intact,
@@ -366,19 +385,20 @@ class SmolVLA2Policy(SmolVLAPolicy):
last_hidden = prefix_out[:, -1:].to(vlm.lm_head.weight.dtype)
logits_step = vlm.lm_head(last_hidden)[:, -1] # (B, V)
# Suppress EOS until we've decoded ``min_new_tokens`` real
# tokens. Without this, a memorised LM head whose argmax
# at position 0 is EOS produces an empty completion every
# time — confirmed in the real-robot run (the runtime's
# ``subtask_empty_count`` climbed every chunk boundary
# with no exception). Masking EOS for the first N steps
# forces the head to commit to a real token before it can
# close the turn.
if (
eos_token_id is not None
and len(generated) < min_new_tokens
):
logits_step[..., eos_token_id] = float("-inf")
# Suppress *all* special tokens until we've decoded
# ``min_new_tokens`` real (renderable) tokens. Without
# this, a memorised SmolVLM head whose argmax at position
# 0 is a special token produces an empty completion every
# time — either EOS directly, or (after we mask EOS) the
# argmax shifts to a sibling special id (``<|im_end|>``,
# ``<image>``, ``<row_X_col_Y>``, …) which decode strips
# via ``skip_special_tokens=True``. Masking the full
# ``all_special_ids`` set for the first N steps forces
# the head to commit to a normal vocabulary token before
# it can close (or quietly poison) the turn.
if special_ids_set and len(generated) < min_new_tokens:
for sid in special_ids_set:
logits_step[..., sid] = float("-inf")
next_ids = self._sample_next_token(logits_step, temperature, top_p)
tok_id = int(next_ids[0].item())
generated.append(tok_id)
@@ -393,7 +413,25 @@ class SmolVLA2Policy(SmolVLAPolicy):
current_pad = torch.cat([current_pad, ones_step], dim=1)
current_att = torch.cat([current_att, ones_step], dim=1)
return tokenizer.decode(generated, skip_special_tokens=True).strip()
decoded = tokenizer.decode(generated, skip_special_tokens=True).strip()
# When the visible decoded string is empty but tokens *were*
# generated, expose what those raw tokens decoded to without
# the special-token filter. This is what the runtime turns
# into a scrollback line when ``last_raw='(empty)'`` so the
# operator can tell whether the head is emitting EOS, image
# placeholder tokens, the chat-template ``<|im_end|>`` shard,
# or something else.
if not decoded and generated:
try:
self._last_select_message_debug = (
f"raw_ids={generated[:16]} "
f"decoded_w_special={tokenizer.decode(generated, skip_special_tokens=False)!r}"
)
except Exception: # noqa: BLE001
self._last_select_message_debug = f"raw_ids={generated[:16]}"
else:
self._last_select_message_debug = ""
return decoded
@staticmethod
def _sample_next_token(