fix(smolvla2): force min_new_tokens + sampling so memorised LM emits something

Real-robot run confirmed the LM head is producing 0 tokens at every
chunk boundary (empty:N counter climbing, no exception in scrollback):
the model EOS-es at decode step 0. That's the memorisation collapse —
training reached text_loss=6e-6 by overfitting one trajectory whose
supervised subtask turn ended in EOS, and at inference the head's
argmax for token 0 is EOS regardless of the actual frame.

Two changes in select_message:

  * ``min_new_tokens`` parameter masks the EOS logit to -inf until at
    least N real tokens have been decoded. Without this the head's
    "EOS first" prior produces an empty completion every single time.

  * The runtime callers now pass ``min_new_tokens=5..10`` plus
    ``temperature=0.4..0.5`` + ``top_p=0.9``. Sampling at moderate
    temperature with nucleus filtering also helps break the greedy
    argmax collapse — when the model has memorised one continuation,
    greedy keeps replaying it; nucleus sampling forces it to commit
    to *some* coherent continuation that's well-supported by the
    prefix even when greedy's top-1 is degenerate.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-12 17:48:08 +02:00
parent fbcac95662
commit b95eebff77
2 changed files with 64 additions and 5 deletions
@@ -365,8 +365,23 @@ class HighLevelSubtaskFwd(InferenceStep):
return None
ctx = _msgs_for_subtask(state)
observation = _maybe_observation(self.observation_provider)
# Force the head to commit to ≥ 5 real tokens before it can
# close the turn, and sample at moderate temperature with
# nucleus filtering. On a memorised head whose argmax at
# position 0 is EOS, greedy decoding silently produced empty
# completions every chunk boundary (visible as the
# ``empty:N`` counter climbing). Temp 0.4 + top_p 0.9 is well
# below where SmolVLM goes incoherent and above where greedy
# collapse re-emerges.
msg = _generate_with_policy(
self.policy, ctx, observation=observation, state=state, label="subtask gen"
self.policy,
ctx,
observation=observation,
state=state,
label="subtask gen",
min_new_tokens=5,
temperature=0.4,
top_p=0.9,
)
# Diagnostics: surface what the model is *actually* producing
# at chunk boundaries, even when the output gets rejected or
@@ -447,7 +462,14 @@ class MemoryUpdateFwd(InferenceStep):
ctx = _msgs_for_memory(state)
observation = _maybe_observation(self.observation_provider)
new_memory = _generate_with_policy(
self.policy, ctx, observation=observation, state=state, label="memory gen"
self.policy,
ctx,
observation=observation,
state=state,
label="memory gen",
min_new_tokens=5,
temperature=0.4,
top_p=0.9,
)
state["last_memory_raw"] = new_memory or ""
if new_memory and _looks_like_gibberish(new_memory):
@@ -486,7 +508,14 @@ class UserInterjectionFwd(InferenceStep):
ctx = _msgs_for_interjection(state)
observation = _maybe_observation(self.observation_provider)
out = _generate_with_policy(
self.policy, ctx, observation=observation, state=state, label="plan/say gen"
self.policy,
ctx,
observation=observation,
state=state,
label="plan/say gen",
min_new_tokens=10,
temperature=0.5,
top_p=0.9,
)
if not out:
# Don't log every empty completion — happens repeatedly on
@@ -551,7 +580,14 @@ class AskVQAFwd(InferenceStep):
ctx = _msgs_for_vqa(question)
observation = _maybe_observation(self.observation_provider)
answer = _generate_with_policy(
self.policy, ctx, observation=observation, state=state, label="vqa gen"
self.policy,
ctx,
observation=observation,
state=state,
label="vqa gen",
min_new_tokens=3,
temperature=0.4,
top_p=0.9,
)
# VQA answers are intentionally JSON-like during training, so
# ``_looks_like_gibberish`` would false-positive on them. Keep
@@ -763,6 +799,9 @@ def _generate_with_policy(
observation: dict | None = None,
state: dict[str, Any] | None = None,
label: str = "select_message",
min_new_tokens: int = 0,
temperature: float = 0.0,
top_p: float = 1.0,
) -> str:
"""Drive ``policy.select_message`` with a chat batch (and optional obs).
@@ -797,7 +836,13 @@ def _generate_with_policy(
for k, v in observation.items():
if isinstance(k, str) and k.startswith("observation.") and k not in batch:
batch[k] = v
return policy.select_message(batch, tokenizer=text_batch["tokenizer"])
return policy.select_message(
batch,
tokenizer=text_batch["tokenizer"],
min_new_tokens=min_new_tokens,
temperature=temperature,
top_p=top_p,
)
except Exception as exc: # noqa: BLE001
logger.warning("%s failed: %s", label, exc, exc_info=logger.isEnabledFor(logging.DEBUG))
if state is not None:
@@ -263,6 +263,7 @@ class SmolVLA2Policy(SmolVLAPolicy):
batch: dict[str, Tensor],
*,
max_new_tokens: int = 256,
min_new_tokens: int = 0,
eos_token_id: int | None = None,
temperature: float = 0.0,
top_p: float = 1.0,
@@ -365,6 +366,19 @@ class SmolVLA2Policy(SmolVLAPolicy):
last_hidden = prefix_out[:, -1:].to(vlm.lm_head.weight.dtype)
logits_step = vlm.lm_head(last_hidden)[:, -1] # (B, V)
# Suppress EOS until we've decoded ``min_new_tokens`` real
# tokens. Without this, a memorised LM head whose argmax
# at position 0 is EOS produces an empty completion every
# time — confirmed in the real-robot run (the runtime's
# ``subtask_empty_count`` climbed every chunk boundary
# with no exception). Masking EOS for the first N steps
# forces the head to commit to a real token before it can
# close the turn.
if (
eos_token_id is not None
and len(generated) < min_new_tokens
):
logits_step[..., eos_token_id] = float("-inf")
next_ids = self._sample_next_token(logits_step, temperature, top_p)
tok_id = int(next_ids[0].item())
generated.append(tok_id)