fix(pi052,smolvla2): unblock text generation when LM head drifted to <loc>

PaliGemma's pretraining puts heavy first-token mass on its <loc0000>..
<loc1023> ids at any "Assistant:" continuation. Our pi052 fine-tunes
(trained with knowledge_insulation=True and a small text-CE budget,
~45% of samples) drift back toward that prior on long runs at low LR:
teacher-forced argmax stays at 100% (CE only measures the next token
given the correct prefix) while autoregressive first-token selection
collapses onto <loc>.
On the running poulain11 checkpoint at step 8000 this manifests as a
stream of <locDDDD> tokens for every subtask call — confirmed locally
against the saved checkpoint on a dataset frame.

Add a `suppress_loc_tokens` knob to `PI052Policy.select_message` that
masks ids [256000, 257024) to -inf before sampling, and pass it from
the three text-only inference steps (HighLevelSubtaskFwd,
MemoryUpdateFwd, UserInterjectionFwd). VQA steps keep the default
False so spatial answers can still emit locs. Verified end-to-end:
with suppression enabled, generation yields "the robot arm moves the
blue block to the green basket".

Also fix `_msgs_for_memory`: it was emitting the older
`User: ${task}\nPlan:..\nMemory:..` / `Assistant: ${subtask}` template,
which no longer matches the `memory_update` recipe layout
(`User: ${task}` / `Assistant: Previous memory: ..` /
`User: Completed subtask: ..`). The new prompt mirrors the training
recipe; `HighLevelSubtaskFwd` stashes the just-completed subtask in
`state['prior_subtask']` so the memory prompt can render
`Completed subtask: ..` for `MemoryUpdateFwd`.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-05-22 09:50:14 +00:00
parent e050d0fe0a
commit 9d30d91021
2 changed files with 65 additions and 16 deletions
@@ -930,12 +930,23 @@ class PI052Policy(PI05Policy):
temperature: float = 0.0,
top_p: float = 1.0,
tokenizer: Any = None,
suppress_loc_tokens: bool = False,
) -> str:
"""Generate text continuation from a multimodal prefix.
Mirrors ``SmolVLA2Policy.select_message`` so the same
:class:`lerobot.policies.smolvla2.inference.SmolVLA2Runtime`
can drive π0.5 v2 unchanged.
``suppress_loc_tokens`` masks PaliGemma's reserved ``<locDDDD>``
ids ([256000, 257024)) to ``-inf`` before sampling. PaliGemma's
pretraining puts heavy first-token mass on these ids for any
``Assistant:`` continuation; with a small fine-tuning text-CE
budget (or aggressive LR decay) the LM head can drift back
toward that prior even when teacher-forced argmax stays at
100%. Callsites that legitimately emit ``<loc>`` (VQA spatial
answers) must keep this ``False``; subtask / memory / plan
generation should pass ``True``.
"""
self.eval()
@@ -1015,6 +1026,8 @@ class PI052Policy(PI05Policy):
if special_ids and len(generated) < min_new_tokens:
for sid in special_ids:
logits_step[..., sid] = float("-inf")
if suppress_loc_tokens:
logits_step[..., 256000:257024] = float("-inf")
next_ids = self._sample_next_token(logits_step, temperature, top_p)
tok_id = int(next_ids[0].item())
generated.append(tok_id)
@@ -484,6 +484,11 @@ class HighLevelSubtaskFwd(InferenceStep):
min_new_tokens=int(state.get("text_gen_min_new_tokens") or 0),
temperature=float(state.get("text_gen_temperature") or 0.0),
top_p=float(state.get("text_gen_top_p") or 1.0),
# Subtasks never legitimately contain PaliGemma ``<loc>``
# tokens — suppress them so a checkpoint whose LM head
# has drifted toward the pretrained loc-prior falls back
# to its (still-correct) text mass.
suppress_loc_tokens=True,
)
# Diagnostics: surface what the model is *actually* producing
# at chunk boundaries, even when the output gets rejected or
@@ -526,8 +531,16 @@ class HighLevelSubtaskFwd(InferenceStep):
)
return None
if msg:
prev_subtask = state.get("current_subtask")
changed = set_if_changed(state, "current_subtask", msg, label="subtask")
if changed:
# Stash the just-completed subtask so ``MemoryUpdateFwd``
# can drop it into its prompt as ``Completed subtask:``
# — the recipe binds ``completed_subtask`` to
# ``nth_prev(style=subtask, offset=1)``, i.e. the subtask
# that was active *before* the change.
if prev_subtask:
state["prior_subtask"] = prev_subtask
# Subtask change is a downstream trigger.
state.setdefault("events_this_tick", []).append("subtask_change")
state["subtask_repeat_count"] = 0
@@ -576,6 +589,7 @@ class MemoryUpdateFwd(InferenceStep):
observation=observation,
state=state,
label="memory gen",
suppress_loc_tokens=True,
)
state["last_memory_raw"] = new_memory or ""
if new_memory and _looks_like_gibberish(new_memory):
@@ -619,6 +633,7 @@ class UserInterjectionFwd(InferenceStep):
observation=observation,
state=state,
label="plan/say gen",
suppress_loc_tokens=True,
)
if not out:
# Don't log every empty completion — happens repeatedly on
@@ -851,21 +866,34 @@ def _msgs_for_subtask(state: dict[str, Any]) -> list[dict[str, Any]]:
def _msgs_for_memory(state: dict[str, Any]) -> list[dict[str, Any]]:
"""Memory-update prompt — boundary-frame tail of ``high_level_subtask``.
"""Memory-update prompt — mirrors ``memory_update`` recipe layout.
Recipe layout on a boundary frame:
user: "${task}\\nPlan: ${plan}\\nMemory: ${memory}"
assistant: "${subtask}"
Recipe layout (``subtask_mem.yaml``):
user: "${task}"
assistant: "Previous memory: ${prior_memory}" (if_present prior)
user: "Completed subtask: ${completed}" (if_present completed)
assistant: predicts new memory
Fired when the runtime detects a subtask transition; the
just-predicted subtask lives in ``state['current_subtask']``.
Fired by ``MemoryUpdateFwd`` on a ``subtask_change`` event:
``state['current_memory']`` is the memory the policy last emitted
(= the ``prior_memory`` binding at training), and
``state['prior_subtask']`` is the subtask that just got replaced
(= the ``completed_subtask`` binding at training).
"""
msgs: list[dict[str, Any]] = [
{"role": "user", "content": _hirobot_user_head(state)},
{"role": "user", "content": state.get("task") or ""},
]
if state.get("current_subtask"):
msgs.append({"role": "assistant", "content": state["current_subtask"]})
prior_memory = state.get("current_memory")
if prior_memory:
msgs.append(
{"role": "assistant", "content": f"Previous memory: {prior_memory}"}
)
completed_subtask = state.get("prior_subtask")
if completed_subtask:
msgs.append(
{"role": "user", "content": f"Completed subtask: {completed_subtask}"}
)
return msgs
@@ -925,6 +953,7 @@ def _generate_with_policy(
min_new_tokens: int = 0,
temperature: float = 0.0,
top_p: float = 1.0,
suppress_loc_tokens: bool = False,
) -> str:
"""Drive ``policy.select_message`` with a chat batch (and optional obs).
@@ -959,13 +988,20 @@ def _generate_with_policy(
for k, v in observation.items():
if isinstance(k, str) and k.startswith("observation.") and k not in batch:
batch[k] = v
return policy.select_message(
batch,
tokenizer=text_batch["tokenizer"],
min_new_tokens=min_new_tokens,
temperature=temperature,
top_p=top_p,
)
kwargs: dict[str, Any] = {
"tokenizer": text_batch["tokenizer"],
"min_new_tokens": min_new_tokens,
"temperature": temperature,
"top_p": top_p,
}
# Only pass ``suppress_loc_tokens`` to backbones that accept it
# (pi052). SmolVLA2's ``select_message`` does not, so we omit
# the kwarg there to avoid breaking the shared runtime.
import inspect # noqa: PLC0415
if "suppress_loc_tokens" in inspect.signature(policy.select_message).parameters:
kwargs["suppress_loc_tokens"] = suppress_loc_tokens
return policy.select_message(batch, **kwargs)
except Exception as exc: # noqa: BLE001
logger.warning("%s failed: %s", label, exc, exc_info=logger.isEnabledFor(logging.DEBUG))
if state is not None: