diff --git a/src/lerobot/policies/smolvla2/inference/steps.py b/src/lerobot/policies/smolvla2/inference/steps.py index 513ce714c..4306f7d2d 100644 --- a/src/lerobot/policies/smolvla2/inference/steps.py +++ b/src/lerobot/policies/smolvla2/inference/steps.py @@ -365,8 +365,23 @@ class HighLevelSubtaskFwd(InferenceStep): return None ctx = _msgs_for_subtask(state) observation = _maybe_observation(self.observation_provider) + # Force the head to commit to ≥ 5 real tokens before it can + # close the turn, and sample at moderate temperature with + # nucleus filtering. On a memorised head whose argmax at + # position 0 is EOS, greedy decoding silently produced empty + # completions every chunk boundary (visible as the + # ``empty:N`` counter climbing). Temp 0.4 + top_p 0.9 is well + # below where SmolVLM goes incoherent and above where greedy + # collapse re-emerges. msg = _generate_with_policy( - self.policy, ctx, observation=observation, state=state, label="subtask gen" + self.policy, + ctx, + observation=observation, + state=state, + label="subtask gen", + min_new_tokens=5, + temperature=0.4, + top_p=0.9, ) # Diagnostics: surface what the model is *actually* producing # at chunk boundaries, even when the output gets rejected or @@ -447,7 +462,14 @@ class MemoryUpdateFwd(InferenceStep): ctx = _msgs_for_memory(state) observation = _maybe_observation(self.observation_provider) new_memory = _generate_with_policy( - self.policy, ctx, observation=observation, state=state, label="memory gen" + self.policy, + ctx, + observation=observation, + state=state, + label="memory gen", + min_new_tokens=5, + temperature=0.4, + top_p=0.9, ) state["last_memory_raw"] = new_memory or "" if new_memory and _looks_like_gibberish(new_memory): @@ -486,7 +508,14 @@ class UserInterjectionFwd(InferenceStep): ctx = _msgs_for_interjection(state) observation = _maybe_observation(self.observation_provider) out = _generate_with_policy( - self.policy, ctx, observation=observation, state=state, label="plan/say gen" + self.policy, + ctx, + observation=observation, + state=state, + label="plan/say gen", + min_new_tokens=10, + temperature=0.5, + top_p=0.9, ) if not out: # Don't log every empty completion — happens repeatedly on @@ -551,7 +580,14 @@ class AskVQAFwd(InferenceStep): ctx = _msgs_for_vqa(question) observation = _maybe_observation(self.observation_provider) answer = _generate_with_policy( - self.policy, ctx, observation=observation, state=state, label="vqa gen" + self.policy, + ctx, + observation=observation, + state=state, + label="vqa gen", + min_new_tokens=3, + temperature=0.4, + top_p=0.9, ) # VQA answers are intentionally JSON-like during training, so # ``_looks_like_gibberish`` would false-positive on them. Keep @@ -763,6 +799,9 @@ def _generate_with_policy( observation: dict | None = None, state: dict[str, Any] | None = None, label: str = "select_message", + min_new_tokens: int = 0, + temperature: float = 0.0, + top_p: float = 1.0, ) -> str: """Drive ``policy.select_message`` with a chat batch (and optional obs). @@ -797,7 +836,13 @@ def _generate_with_policy( for k, v in observation.items(): if isinstance(k, str) and k.startswith("observation.") and k not in batch: batch[k] = v - return policy.select_message(batch, tokenizer=text_batch["tokenizer"]) + return policy.select_message( + batch, + tokenizer=text_batch["tokenizer"], + min_new_tokens=min_new_tokens, + temperature=temperature, + top_p=top_p, + ) except Exception as exc: # noqa: BLE001 logger.warning("%s failed: %s", label, exc, exc_info=logger.isEnabledFor(logging.DEBUG)) if state is not None: diff --git a/src/lerobot/policies/smolvla2/modeling_smolvla2.py b/src/lerobot/policies/smolvla2/modeling_smolvla2.py index 79978cf95..a808363ef 100644 --- a/src/lerobot/policies/smolvla2/modeling_smolvla2.py +++ b/src/lerobot/policies/smolvla2/modeling_smolvla2.py @@ -263,6 +263,7 @@ class SmolVLA2Policy(SmolVLAPolicy): batch: dict[str, Tensor], *, max_new_tokens: int = 256, + min_new_tokens: int = 0, eos_token_id: int | None = None, temperature: float = 0.0, top_p: float = 1.0, @@ -365,6 +366,19 @@ class SmolVLA2Policy(SmolVLAPolicy): last_hidden = prefix_out[:, -1:].to(vlm.lm_head.weight.dtype) logits_step = vlm.lm_head(last_hidden)[:, -1] # (B, V) + # Suppress EOS until we've decoded ``min_new_tokens`` real + # tokens. Without this, a memorised LM head whose argmax + # at position 0 is EOS produces an empty completion every + # time — confirmed in the real-robot run (the runtime's + # ``subtask_empty_count`` climbed every chunk boundary + # with no exception). Masking EOS for the first N steps + # forces the head to commit to a real token before it can + # close the turn. + if ( + eos_token_id is not None + and len(generated) < min_new_tokens + ): + logits_step[..., eos_token_id] = float("-inf") next_ids = self._sample_next_token(logits_step, temperature, top_p) tok_id = int(next_ids[0].item()) generated.append(tok_id)