mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
fix(smolvla2): suppress all special tokens during min_new_tokens window
Previous attempt only masked the tokenizer's eos_token_id during the min_new_tokens prefix. The empty-completion symptom persisted because a memorised SmolVLM head doesn't just want EOS — its top-1 at position 0 is *some* special token, and when EOS is masked the argmax shifts to a sibling (``<|im_end|>``, ``<image>``, ``<fake_token_around_image>``, ``<row_X_col_Y>``, …). Those tokens survive generation but then get stripped by ``decode(skip_special_tokens=True)``, so the runtime still saw ``last_raw='(empty)'`` every chunk boundary. Mask the full ``tokenizer.all_special_ids`` set instead. Forces the head to commit to a normal vocabulary token before it can close or quietly poison the turn. Also: when decode returns empty but tokens *were* generated, expose the raw token ids and the special-tokens-included decoded string via ``policy._last_select_message_debug``. The runtime surfaces this in the scrollback so the operator can see what the head is actually emitting — distinguishing "head EOS-ing" from "head emitting image placeholders" from "head emitting chat-template fragments". Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -397,12 +397,19 @@ class HighLevelSubtaskFwd(InferenceStep):
|
|||||||
if not msg:
|
if not msg:
|
||||||
empties = state.get("subtask_empty_count", 0) + 1
|
empties = state.get("subtask_empty_count", 0) + 1
|
||||||
state["subtask_empty_count"] = empties
|
state["subtask_empty_count"] = empties
|
||||||
if empties == 1 or empties % 10 == 0:
|
if empties == 1 or empties % 5 == 0:
|
||||||
|
debug = getattr(self.policy, "_last_select_message_debug", "") or ""
|
||||||
|
if debug:
|
||||||
|
push_log(
|
||||||
|
state,
|
||||||
|
f" [info] subtask gen empty (×{empties}); {debug}",
|
||||||
|
)
|
||||||
|
else:
|
||||||
push_log(
|
push_log(
|
||||||
state,
|
state,
|
||||||
f" [info] subtask gen returned empty (×{empties}) — "
|
f" [info] subtask gen returned empty (×{empties}) — "
|
||||||
"model EOS-ing immediately or generation raised "
|
"no tokens generated (head EOS-ing before any "
|
||||||
"(check stderr / -v for traceback).",
|
"non-special token).",
|
||||||
)
|
)
|
||||||
if msg and _looks_like_gibberish(msg):
|
if msg and _looks_like_gibberish(msg):
|
||||||
# Bump a counter so the operator can see the model is
|
# Bump a counter so the operator can see the model is
|
||||||
|
|||||||
@@ -304,6 +304,25 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
|||||||
if eos_token_id is None:
|
if eos_token_id is None:
|
||||||
eos_token_id = tokenizer.eos_token_id
|
eos_token_id = tokenizer.eos_token_id
|
||||||
|
|
||||||
|
# Build the full set of special-token ids to suppress during
|
||||||
|
# the ``min_new_tokens`` window. EOS alone is not enough on a
|
||||||
|
# memorised SmolVLM head — when EOS is masked, the argmax
|
||||||
|
# falls onto a sibling special token (``<|im_end|>``,
|
||||||
|
# ``<image>``, ``<fake_token_around_image>``, ``<row_X_col_Y>``,
|
||||||
|
# …) which then survives generation but gets stripped by
|
||||||
|
# ``skip_special_tokens=True`` so ``decode`` returns an empty
|
||||||
|
# string and the runtime sees ``last_raw='(empty)'`` every
|
||||||
|
# chunk boundary.
|
||||||
|
special_ids_set: set[int] = set()
|
||||||
|
try:
|
||||||
|
for sid in (tokenizer.all_special_ids or []):
|
||||||
|
if sid is not None:
|
||||||
|
special_ids_set.add(int(sid))
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
pass
|
||||||
|
if eos_token_id is not None:
|
||||||
|
special_ids_set.add(int(eos_token_id))
|
||||||
|
|
||||||
# Match training's text-loss forward path (see
|
# Match training's text-loss forward path (see
|
||||||
# ``_compute_text_loss`` above): build the full prefix via
|
# ``_compute_text_loss`` above): build the full prefix via
|
||||||
# ``embed_prefix`` so images + state conditioning is intact,
|
# ``embed_prefix`` so images + state conditioning is intact,
|
||||||
@@ -366,19 +385,20 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
|||||||
|
|
||||||
last_hidden = prefix_out[:, -1:].to(vlm.lm_head.weight.dtype)
|
last_hidden = prefix_out[:, -1:].to(vlm.lm_head.weight.dtype)
|
||||||
logits_step = vlm.lm_head(last_hidden)[:, -1] # (B, V)
|
logits_step = vlm.lm_head(last_hidden)[:, -1] # (B, V)
|
||||||
# Suppress EOS until we've decoded ``min_new_tokens`` real
|
# Suppress *all* special tokens until we've decoded
|
||||||
# tokens. Without this, a memorised LM head whose argmax
|
# ``min_new_tokens`` real (renderable) tokens. Without
|
||||||
# at position 0 is EOS produces an empty completion every
|
# this, a memorised SmolVLM head whose argmax at position
|
||||||
# time — confirmed in the real-robot run (the runtime's
|
# 0 is a special token produces an empty completion every
|
||||||
# ``subtask_empty_count`` climbed every chunk boundary
|
# time — either EOS directly, or (after we mask EOS) the
|
||||||
# with no exception). Masking EOS for the first N steps
|
# argmax shifts to a sibling special id (``<|im_end|>``,
|
||||||
# forces the head to commit to a real token before it can
|
# ``<image>``, ``<row_X_col_Y>``, …) which decode strips
|
||||||
# close the turn.
|
# via ``skip_special_tokens=True``. Masking the full
|
||||||
if (
|
# ``all_special_ids`` set for the first N steps forces
|
||||||
eos_token_id is not None
|
# the head to commit to a normal vocabulary token before
|
||||||
and len(generated) < min_new_tokens
|
# it can close (or quietly poison) the turn.
|
||||||
):
|
if special_ids_set and len(generated) < min_new_tokens:
|
||||||
logits_step[..., eos_token_id] = float("-inf")
|
for sid in special_ids_set:
|
||||||
|
logits_step[..., sid] = float("-inf")
|
||||||
next_ids = self._sample_next_token(logits_step, temperature, top_p)
|
next_ids = self._sample_next_token(logits_step, temperature, top_p)
|
||||||
tok_id = int(next_ids[0].item())
|
tok_id = int(next_ids[0].item())
|
||||||
generated.append(tok_id)
|
generated.append(tok_id)
|
||||||
@@ -393,7 +413,25 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
|||||||
current_pad = torch.cat([current_pad, ones_step], dim=1)
|
current_pad = torch.cat([current_pad, ones_step], dim=1)
|
||||||
current_att = torch.cat([current_att, ones_step], dim=1)
|
current_att = torch.cat([current_att, ones_step], dim=1)
|
||||||
|
|
||||||
return tokenizer.decode(generated, skip_special_tokens=True).strip()
|
decoded = tokenizer.decode(generated, skip_special_tokens=True).strip()
|
||||||
|
# When the visible decoded string is empty but tokens *were*
|
||||||
|
# generated, expose what those raw tokens decoded to without
|
||||||
|
# the special-token filter. This is what the runtime turns
|
||||||
|
# into a scrollback line when ``last_raw='(empty)'`` so the
|
||||||
|
# operator can tell whether the head is emitting EOS, image
|
||||||
|
# placeholder tokens, the chat-template ``<|im_end|>`` shard,
|
||||||
|
# or something else.
|
||||||
|
if not decoded and generated:
|
||||||
|
try:
|
||||||
|
self._last_select_message_debug = (
|
||||||
|
f"raw_ids={generated[:16]} "
|
||||||
|
f"decoded_w_special={tokenizer.decode(generated, skip_special_tokens=False)!r}"
|
||||||
|
)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
self._last_select_message_debug = f"raw_ids={generated[:16]}"
|
||||||
|
else:
|
||||||
|
self._last_select_message_debug = ""
|
||||||
|
return decoded
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _sample_next_token(
|
def _sample_next_token(
|
||||||
|
|||||||
Reference in New Issue
Block a user