From 9d30d91021ac24c9755ab17be40314dd8ab68981 Mon Sep 17 00:00:00 2001 From: pepijn Date: Fri, 22 May 2026 09:50:14 +0000 Subject: [PATCH] fix(pi052,smolvla2): unblock text generation when LM head drifted to MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PaliGemma's pretraining puts heavy first-token mass on its .. ids at any "Assistant:" continuation. Our pi052 fine-tunes with knowledge_insulation=True and a small text-CE budget (~45% of samples) drift back toward that prior on long runs at low LR — teacher- forced argmax stays at 100% (CE only measures next-token given correct prefix) while autoregressive first-token selection collapses onto . On the running poulain11 checkpoint at step 8000 this manifests as a stream of tokens for every subtask call — confirmed locally against the saved checkpoint on a dataset frame. Add a `suppress_loc_tokens` knob to `PI052Policy.select_message` that masks ids [256000, 257024) to -inf before sampling, and pass it from the three text-only inference steps (HighLevelSubtaskFwd, MemoryUpdateFwd, UserInterjectionFwd). VQA steps keep the default False so spatial answers can still emit locs. Verified end-to-end: suppressed → "the robot arm moves the blue block to the green basket". Also fix `_msgs_for_memory`: it was emitting the older `User: ${task}\nPlan:..\nMemory:..` / `Assistant: ${subtask}` template, which no longer matches the `memory_update` recipe layout (`User: ${task}` / `Assistant: Previous memory: ..` / `User: Completed subtask: ..`). The new prompt mirrors the training recipe; `HighLevelSubtaskFwd` stashes the just-completed subtask in `state['prior_subtask']` so the memory prompt can render `Completed subtask: ..` for `MemoryUpdateFwd`. Co-authored-by: Cursor --- src/lerobot/policies/pi052/modeling_pi052.py | 13 ++++ .../policies/smolvla2/inference/steps.py | 68 ++++++++++++++----- 2 files changed, 65 insertions(+), 16 deletions(-) diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index bcf7f6a18..5a055a5d5 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -930,12 +930,23 @@ class PI052Policy(PI05Policy): temperature: float = 0.0, top_p: float = 1.0, tokenizer: Any = None, + suppress_loc_tokens: bool = False, ) -> str: """Generate text continuation from a multimodal prefix. Mirrors ``SmolVLA2Policy.select_message`` so the same :class:`lerobot.policies.smolvla2.inference.SmolVLA2Runtime` can drive π0.5 v2 unchanged. + + ``suppress_loc_tokens`` masks PaliGemma's reserved ```` + ids ([256000, 257024)) to ``-inf`` before sampling. PaliGemma's + pretraining puts heavy first-token mass on these ids for any + ``Assistant:`` continuation; with a small fine-tuning text-CE + budget (or aggressive LR decay) the LM head can drift back + toward that prior even when teacher-forced argmax stays at + 100%. Callsites that legitimately emit ```` (VQA spatial + answers) must keep this ``False``; subtask / memory / plan + generation should pass ``True``. """ self.eval() @@ -1015,6 +1026,8 @@ class PI052Policy(PI05Policy): if special_ids and len(generated) < min_new_tokens: for sid in special_ids: logits_step[..., sid] = float("-inf") + if suppress_loc_tokens: + logits_step[..., 256000:257024] = float("-inf") next_ids = self._sample_next_token(logits_step, temperature, top_p) tok_id = int(next_ids[0].item()) generated.append(tok_id) diff --git a/src/lerobot/policies/smolvla2/inference/steps.py b/src/lerobot/policies/smolvla2/inference/steps.py index 6366103e9..9e624c87c 100644 --- a/src/lerobot/policies/smolvla2/inference/steps.py +++ b/src/lerobot/policies/smolvla2/inference/steps.py @@ -484,6 +484,11 @@ class HighLevelSubtaskFwd(InferenceStep): min_new_tokens=int(state.get("text_gen_min_new_tokens") or 0), temperature=float(state.get("text_gen_temperature") or 0.0), top_p=float(state.get("text_gen_top_p") or 1.0), + # Subtasks never legitimately contain PaliGemma ```` + # tokens — suppress them so a checkpoint whose LM head + # has drifted toward the pretrained loc-prior falls back + # to its (still-correct) text mass. + suppress_loc_tokens=True, ) # Diagnostics: surface what the model is *actually* producing # at chunk boundaries, even when the output gets rejected or @@ -526,8 +531,16 @@ class HighLevelSubtaskFwd(InferenceStep): ) return None if msg: + prev_subtask = state.get("current_subtask") changed = set_if_changed(state, "current_subtask", msg, label="subtask") if changed: + # Stash the just-completed subtask so ``MemoryUpdateFwd`` + # can drop it into its prompt as ``Completed subtask:`` + # — the recipe binds ``completed_subtask`` to + # ``nth_prev(style=subtask, offset=1)``, i.e. the subtask + # that was active *before* the change. + if prev_subtask: + state["prior_subtask"] = prev_subtask # Subtask change is a downstream trigger. state.setdefault("events_this_tick", []).append("subtask_change") state["subtask_repeat_count"] = 0 @@ -576,6 +589,7 @@ class MemoryUpdateFwd(InferenceStep): observation=observation, state=state, label="memory gen", + suppress_loc_tokens=True, ) state["last_memory_raw"] = new_memory or "" if new_memory and _looks_like_gibberish(new_memory): @@ -619,6 +633,7 @@ class UserInterjectionFwd(InferenceStep): observation=observation, state=state, label="plan/say gen", + suppress_loc_tokens=True, ) if not out: # Don't log every empty completion — happens repeatedly on @@ -851,21 +866,34 @@ def _msgs_for_subtask(state: dict[str, Any]) -> list[dict[str, Any]]: def _msgs_for_memory(state: dict[str, Any]) -> list[dict[str, Any]]: - """Memory-update prompt — boundary-frame tail of ``high_level_subtask``. + """Memory-update prompt — mirrors ``memory_update`` recipe layout. - Recipe layout on a boundary frame: - user: "${task}\\nPlan: ${plan}\\nMemory: ${memory}" - assistant: "${subtask}" + Recipe layout (``subtask_mem.yaml``): + + user: "${task}" + assistant: "Previous memory: ${prior_memory}" (if_present prior) + user: "Completed subtask: ${completed}" (if_present completed) assistant: → predicts new memory - Fired when the runtime detects a subtask transition; the - just-predicted subtask lives in ``state['current_subtask']``. + Fired by ``MemoryUpdateFwd`` on a ``subtask_change`` event: + ``state['current_memory']`` is the memory the policy last emitted + (= the ``prior_memory`` binding at training), and + ``state['prior_subtask']`` is the subtask that just got replaced + (= the ``completed_subtask`` binding at training). """ msgs: list[dict[str, Any]] = [ - {"role": "user", "content": _hirobot_user_head(state)}, + {"role": "user", "content": state.get("task") or ""}, ] - if state.get("current_subtask"): - msgs.append({"role": "assistant", "content": state["current_subtask"]}) + prior_memory = state.get("current_memory") + if prior_memory: + msgs.append( + {"role": "assistant", "content": f"Previous memory: {prior_memory}"} + ) + completed_subtask = state.get("prior_subtask") + if completed_subtask: + msgs.append( + {"role": "user", "content": f"Completed subtask: {completed_subtask}"} + ) return msgs @@ -925,6 +953,7 @@ def _generate_with_policy( min_new_tokens: int = 0, temperature: float = 0.0, top_p: float = 1.0, + suppress_loc_tokens: bool = False, ) -> str: """Drive ``policy.select_message`` with a chat batch (and optional obs). @@ -959,13 +988,20 @@ def _generate_with_policy( for k, v in observation.items(): if isinstance(k, str) and k.startswith("observation.") and k not in batch: batch[k] = v - return policy.select_message( - batch, - tokenizer=text_batch["tokenizer"], - min_new_tokens=min_new_tokens, - temperature=temperature, - top_p=top_p, - ) + kwargs: dict[str, Any] = { + "tokenizer": text_batch["tokenizer"], + "min_new_tokens": min_new_tokens, + "temperature": temperature, + "top_p": top_p, + } + # Only pass ``suppress_loc_tokens`` to backbones that accept it + # (pi052). SmolVLA2's ``select_message`` does not, so we omit + # the kwarg there to avoid breaking the shared runtime. + import inspect # noqa: PLC0415 + + if "suppress_loc_tokens" in inspect.signature(policy.select_message).parameters: + kwargs["suppress_loc_tokens"] = suppress_loc_tokens + return policy.select_message(batch, **kwargs) except Exception as exc: # noqa: BLE001 logger.warning("%s failed: %s", label, exc, exc_info=logger.isEnabledFor(logging.DEBUG)) if state is not None: