mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
fix(smolvla2): suppress all special tokens during min_new_tokens window
Previous attempt only masked the tokenizer's eos_token_id during the min_new_tokens prefix. The empty-completion symptom persisted because a memorised SmolVLM head doesn't just want EOS — its top-1 at position 0 is *some* special token, and when EOS is masked the argmax shifts to a sibling (``<|im_end|>``, ``<image>``, ``<fake_token_around_image>``, ``<row_X_col_Y>``, …). Those tokens survive generation but then get stripped by ``decode(skip_special_tokens=True)``, so the runtime still saw ``last_raw='(empty)'`` every chunk boundary. Mask the full ``tokenizer.all_special_ids`` set instead. Forces the head to commit to a normal vocabulary token before it can close or quietly poison the turn. Also: when decode returns empty but tokens *were* generated, expose the raw token ids and the special-tokens-included decoded string via ``policy._last_select_message_debug``. The runtime surfaces this in the scrollback so the operator can see what the head is actually emitting — distinguishing "head EOS-ing" from "head emitting image placeholders" from "head emitting chat-template fragments". Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -397,13 +397,20 @@ class HighLevelSubtaskFwd(InferenceStep):
|
||||
if not msg:
|
||||
empties = state.get("subtask_empty_count", 0) + 1
|
||||
state["subtask_empty_count"] = empties
|
||||
if empties == 1 or empties % 10 == 0:
|
||||
push_log(
|
||||
state,
|
||||
f" [info] subtask gen returned empty (×{empties}) — "
|
||||
"model EOS-ing immediately or generation raised "
|
||||
"(check stderr / -v for traceback).",
|
||||
)
|
||||
if empties == 1 or empties % 5 == 0:
|
||||
debug = getattr(self.policy, "_last_select_message_debug", "") or ""
|
||||
if debug:
|
||||
push_log(
|
||||
state,
|
||||
f" [info] subtask gen empty (×{empties}); {debug}",
|
||||
)
|
||||
else:
|
||||
push_log(
|
||||
state,
|
||||
f" [info] subtask gen returned empty (×{empties}) — "
|
||||
"no tokens generated (head EOS-ing before any "
|
||||
"non-special token).",
|
||||
)
|
||||
if msg and _looks_like_gibberish(msg):
|
||||
# Bump a counter so the operator can see the model is
|
||||
# struggling without spamming the log every tick. A first
|
||||
|
||||
@@ -304,6 +304,25 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
||||
if eos_token_id is None:
|
||||
eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
# Build the full set of special-token ids to suppress during
|
||||
# the ``min_new_tokens`` window. EOS alone is not enough on a
|
||||
# memorised SmolVLM head — when EOS is masked, the argmax
|
||||
# falls onto a sibling special token (``<|im_end|>``,
|
||||
# ``<image>``, ``<fake_token_around_image>``, ``<row_X_col_Y>``,
|
||||
# …) which then survives generation but gets stripped by
|
||||
# ``skip_special_tokens=True`` so ``decode`` returns an empty
|
||||
# string and the runtime sees ``last_raw='(empty)'`` every
|
||||
# chunk boundary.
|
||||
special_ids_set: set[int] = set()
|
||||
try:
|
||||
for sid in (tokenizer.all_special_ids or []):
|
||||
if sid is not None:
|
||||
special_ids_set.add(int(sid))
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
if eos_token_id is not None:
|
||||
special_ids_set.add(int(eos_token_id))
|
||||
|
||||
# Match training's text-loss forward path (see
|
||||
# ``_compute_text_loss`` above): build the full prefix via
|
||||
# ``embed_prefix`` so images + state conditioning is intact,
|
||||
@@ -366,19 +385,20 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
||||
|
||||
last_hidden = prefix_out[:, -1:].to(vlm.lm_head.weight.dtype)
|
||||
logits_step = vlm.lm_head(last_hidden)[:, -1] # (B, V)
|
||||
# Suppress EOS until we've decoded ``min_new_tokens`` real
|
||||
# tokens. Without this, a memorised LM head whose argmax
|
||||
# at position 0 is EOS produces an empty completion every
|
||||
# time — confirmed in the real-robot run (the runtime's
|
||||
# ``subtask_empty_count`` climbed every chunk boundary
|
||||
# with no exception). Masking EOS for the first N steps
|
||||
# forces the head to commit to a real token before it can
|
||||
# close the turn.
|
||||
if (
|
||||
eos_token_id is not None
|
||||
and len(generated) < min_new_tokens
|
||||
):
|
||||
logits_step[..., eos_token_id] = float("-inf")
|
||||
# Suppress *all* special tokens until we've decoded
|
||||
# ``min_new_tokens`` real (renderable) tokens. Without
|
||||
# this, a memorised SmolVLM head whose argmax at position
|
||||
# 0 is a special token produces an empty completion every
|
||||
# time — either EOS directly, or (after we mask EOS) the
|
||||
# argmax shifts to a sibling special id (``<|im_end|>``,
|
||||
# ``<image>``, ``<row_X_col_Y>``, …) which decode strips
|
||||
# via ``skip_special_tokens=True``. Masking the full
|
||||
# ``all_special_ids`` set for the first N steps forces
|
||||
# the head to commit to a normal vocabulary token before
|
||||
# it can close (or quietly poison) the turn.
|
||||
if special_ids_set and len(generated) < min_new_tokens:
|
||||
for sid in special_ids_set:
|
||||
logits_step[..., sid] = float("-inf")
|
||||
next_ids = self._sample_next_token(logits_step, temperature, top_p)
|
||||
tok_id = int(next_ids[0].item())
|
||||
generated.append(tok_id)
|
||||
@@ -393,7 +413,25 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
||||
current_pad = torch.cat([current_pad, ones_step], dim=1)
|
||||
current_att = torch.cat([current_att, ones_step], dim=1)
|
||||
|
||||
return tokenizer.decode(generated, skip_special_tokens=True).strip()
|
||||
decoded = tokenizer.decode(generated, skip_special_tokens=True).strip()
|
||||
# When the visible decoded string is empty but tokens *were*
|
||||
# generated, expose what those raw tokens decoded to without
|
||||
# the special-token filter. This is what the runtime turns
|
||||
# into a scrollback line when ``last_raw='(empty)'`` so the
|
||||
# operator can tell whether the head is emitting EOS, image
|
||||
# placeholder tokens, the chat-template ``<|im_end|>`` shard,
|
||||
# or something else.
|
||||
if not decoded and generated:
|
||||
try:
|
||||
self._last_select_message_debug = (
|
||||
f"raw_ids={generated[:16]} "
|
||||
f"decoded_w_special={tokenizer.decode(generated, skip_special_tokens=False)!r}"
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
self._last_select_message_debug = f"raw_ids={generated[:16]}"
|
||||
else:
|
||||
self._last_select_message_debug = ""
|
||||
return decoded
|
||||
|
||||
@staticmethod
|
||||
def _sample_next_token(
|
||||
|
||||
Reference in New Issue
Block a user