mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-19 01:07:18 +00:00
fix(smolvla2): force min_new_tokens + sampling so memorised LM emits something
Real-robot run confirmed the LM head is producing 0 tokens at every
chunk boundary (empty:N counter climbing, no exception in scrollback):
the model EOS-es at decode step 0. That's the memorisation collapse —
training reached text_loss=6e-6 by overfitting one trajectory whose
supervised subtask turn ended in EOS, and at inference the head's
argmax for token 0 is EOS regardless of the actual frame.
Two changes, one in select_message and one in its runtime callers:
* ``min_new_tokens`` parameter masks the EOS logit to -inf until at
least N real tokens have been decoded. Without this the head's
"EOS first" prior produces an empty completion every single time.
* The runtime callers now pass ``min_new_tokens=3..10`` plus
``temperature=0.4..0.5`` + ``top_p=0.9``. Sampling at moderate
temperature with nucleus filtering also helps break the greedy
argmax collapse — when the model has memorised one continuation,
greedy keeps replaying it; nucleus sampling forces it to commit
to *some* coherent continuation that's well-supported by the
prefix even when greedy's top-1 is degenerate.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -365,8 +365,23 @@ class HighLevelSubtaskFwd(InferenceStep):
|
||||
return None
|
||||
ctx = _msgs_for_subtask(state)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
# Force the head to commit to ≥ 5 real tokens before it can
|
||||
# close the turn, and sample at moderate temperature with
|
||||
# nucleus filtering. On a memorised head whose argmax at
|
||||
# position 0 is EOS, greedy decoding silently produced empty
|
||||
# completions every chunk boundary (visible as the
|
||||
# ``empty:N`` counter climbing). Temp 0.4 + top_p 0.9 is well
|
||||
# below where SmolVLM goes incoherent and above where greedy
|
||||
# collapse re-emerges.
|
||||
msg = _generate_with_policy(
|
||||
self.policy, ctx, observation=observation, state=state, label="subtask gen"
|
||||
self.policy,
|
||||
ctx,
|
||||
observation=observation,
|
||||
state=state,
|
||||
label="subtask gen",
|
||||
min_new_tokens=5,
|
||||
temperature=0.4,
|
||||
top_p=0.9,
|
||||
)
|
||||
# Diagnostics: surface what the model is *actually* producing
|
||||
# at chunk boundaries, even when the output gets rejected or
|
||||
@@ -447,7 +462,14 @@ class MemoryUpdateFwd(InferenceStep):
|
||||
ctx = _msgs_for_memory(state)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
new_memory = _generate_with_policy(
|
||||
self.policy, ctx, observation=observation, state=state, label="memory gen"
|
||||
self.policy,
|
||||
ctx,
|
||||
observation=observation,
|
||||
state=state,
|
||||
label="memory gen",
|
||||
min_new_tokens=5,
|
||||
temperature=0.4,
|
||||
top_p=0.9,
|
||||
)
|
||||
state["last_memory_raw"] = new_memory or ""
|
||||
if new_memory and _looks_like_gibberish(new_memory):
|
||||
@@ -486,7 +508,14 @@ class UserInterjectionFwd(InferenceStep):
|
||||
ctx = _msgs_for_interjection(state)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
out = _generate_with_policy(
|
||||
self.policy, ctx, observation=observation, state=state, label="plan/say gen"
|
||||
self.policy,
|
||||
ctx,
|
||||
observation=observation,
|
||||
state=state,
|
||||
label="plan/say gen",
|
||||
min_new_tokens=10,
|
||||
temperature=0.5,
|
||||
top_p=0.9,
|
||||
)
|
||||
if not out:
|
||||
# Don't log every empty completion — happens repeatedly on
|
||||
@@ -551,7 +580,14 @@ class AskVQAFwd(InferenceStep):
|
||||
ctx = _msgs_for_vqa(question)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
answer = _generate_with_policy(
|
||||
self.policy, ctx, observation=observation, state=state, label="vqa gen"
|
||||
self.policy,
|
||||
ctx,
|
||||
observation=observation,
|
||||
state=state,
|
||||
label="vqa gen",
|
||||
min_new_tokens=3,
|
||||
temperature=0.4,
|
||||
top_p=0.9,
|
||||
)
|
||||
# VQA answers are intentionally JSON-like during training, so
|
||||
# ``_looks_like_gibberish`` would false-positive on them. Keep
|
||||
@@ -763,6 +799,9 @@ def _generate_with_policy(
|
||||
observation: dict | None = None,
|
||||
state: dict[str, Any] | None = None,
|
||||
label: str = "select_message",
|
||||
min_new_tokens: int = 0,
|
||||
temperature: float = 0.0,
|
||||
top_p: float = 1.0,
|
||||
) -> str:
|
||||
"""Drive ``policy.select_message`` with a chat batch (and optional obs).
|
||||
|
||||
@@ -797,7 +836,13 @@ def _generate_with_policy(
|
||||
for k, v in observation.items():
|
||||
if isinstance(k, str) and k.startswith("observation.") and k not in batch:
|
||||
batch[k] = v
|
||||
return policy.select_message(batch, tokenizer=text_batch["tokenizer"])
|
||||
return policy.select_message(
|
||||
batch,
|
||||
tokenizer=text_batch["tokenizer"],
|
||||
min_new_tokens=min_new_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("%s failed: %s", label, exc, exc_info=logger.isEnabledFor(logging.DEBUG))
|
||||
if state is not None:
|
||||
|
||||
@@ -263,6 +263,7 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
||||
batch: dict[str, Tensor],
|
||||
*,
|
||||
max_new_tokens: int = 256,
|
||||
min_new_tokens: int = 0,
|
||||
eos_token_id: int | None = None,
|
||||
temperature: float = 0.0,
|
||||
top_p: float = 1.0,
|
||||
@@ -365,6 +366,19 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
||||
|
||||
last_hidden = prefix_out[:, -1:].to(vlm.lm_head.weight.dtype)
|
||||
logits_step = vlm.lm_head(last_hidden)[:, -1] # (B, V)
|
||||
# Suppress EOS until we've decoded ``min_new_tokens`` real
|
||||
# tokens. Without this, a memorised LM head whose argmax
|
||||
# at position 0 is EOS produces an empty completion every
|
||||
# time — confirmed in the real-robot run (the runtime's
|
||||
# ``subtask_empty_count`` climbed every chunk boundary
|
||||
# with no exception). Masking EOS for the first N steps
|
||||
# forces the head to commit to a real token before it can
|
||||
# close the turn.
|
||||
if (
|
||||
eos_token_id is not None
|
||||
and len(generated) < min_new_tokens
|
||||
):
|
||||
logits_step[..., eos_token_id] = float("-inf")
|
||||
next_ids = self._sample_next_token(logits_step, temperature, top_p)
|
||||
tok_id = int(next_ids[0].item())
|
||||
generated.append(tok_id)
|
||||
|
||||
Reference in New Issue
Block a user