diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index bcf7f6a18..5a055a5d5 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -930,12 +930,23 @@ class PI052Policy(PI05Policy): temperature: float = 0.0, top_p: float = 1.0, tokenizer: Any = None, + suppress_loc_tokens: bool = False, ) -> str: """Generate text continuation from a multimodal prefix. Mirrors ``SmolVLA2Policy.select_message`` so the same :class:`lerobot.policies.smolvla2.inference.SmolVLA2Runtime` can drive π0.5 v2 unchanged. + + ``suppress_loc_tokens`` masks PaliGemma's reserved ```` + ids ([256000, 257024)) to ``-inf`` before sampling. PaliGemma's + pretraining puts heavy first-token mass on these ids for any + ``Assistant:`` continuation; with a small fine-tuning text-CE + budget (or aggressive LR decay) the LM head can drift back + toward that prior even when teacher-forced argmax stays at + 100%. Callsites that legitimately emit ```` (VQA spatial + answers) must keep this ``False``; subtask / memory / plan + generation should pass ``True``. """ self.eval() @@ -1015,6 +1026,8 @@ class PI052Policy(PI05Policy): if special_ids and len(generated) < min_new_tokens: for sid in special_ids: logits_step[..., sid] = float("-inf") + if suppress_loc_tokens: + logits_step[..., 256000:257024] = float("-inf") next_ids = self._sample_next_token(logits_step, temperature, top_p) tok_id = int(next_ids[0].item()) generated.append(tok_id) diff --git a/src/lerobot/policies/smolvla2/inference/steps.py b/src/lerobot/policies/smolvla2/inference/steps.py index 6366103e9..9e624c87c 100644 --- a/src/lerobot/policies/smolvla2/inference/steps.py +++ b/src/lerobot/policies/smolvla2/inference/steps.py @@ -484,6 +484,11 @@ class HighLevelSubtaskFwd(InferenceStep): min_new_tokens=int(state.get("text_gen_min_new_tokens") or 0), temperature=float(state.get("text_gen_temperature") or 0.0), top_p=float(state.get("text_gen_top_p") or 1.0), + # Subtasks never legitimately contain PaliGemma ```` + # tokens — suppress them so a checkpoint whose LM head + # has drifted toward the pretrained loc-prior falls back + # to its (still-correct) text mass. + suppress_loc_tokens=True, ) # Diagnostics: surface what the model is *actually* producing # at chunk boundaries, even when the output gets rejected or @@ -526,8 +531,16 @@ class HighLevelSubtaskFwd(InferenceStep): ) return None if msg: + prev_subtask = state.get("current_subtask") changed = set_if_changed(state, "current_subtask", msg, label="subtask") if changed: + # Stash the just-completed subtask so ``MemoryUpdateFwd`` + # can drop it into its prompt as ``Completed subtask:`` + # — the recipe binds ``completed_subtask`` to + # ``nth_prev(style=subtask, offset=1)``, i.e. the subtask + # that was active *before* the change. + if prev_subtask: + state["prior_subtask"] = prev_subtask # Subtask change is a downstream trigger. state.setdefault("events_this_tick", []).append("subtask_change") state["subtask_repeat_count"] = 0 @@ -576,6 +589,7 @@ class MemoryUpdateFwd(InferenceStep): observation=observation, state=state, label="memory gen", + suppress_loc_tokens=True, ) state["last_memory_raw"] = new_memory or "" if new_memory and _looks_like_gibberish(new_memory): @@ -619,6 +633,7 @@ class UserInterjectionFwd(InferenceStep): observation=observation, state=state, label="plan/say gen", + suppress_loc_tokens=True, ) if not out: # Don't log every empty completion — happens repeatedly on @@ -851,21 +866,34 @@ def _msgs_for_subtask(state: dict[str, Any]) -> list[dict[str, Any]]: def _msgs_for_memory(state: dict[str, Any]) -> list[dict[str, Any]]: - """Memory-update prompt — boundary-frame tail of ``high_level_subtask``. + """Memory-update prompt — mirrors ``memory_update`` recipe layout. - Recipe layout on a boundary frame: - user: "${task}\\nPlan: ${plan}\\nMemory: ${memory}" - assistant: "${subtask}" + Recipe layout (``subtask_mem.yaml``): + + user: "${task}" + assistant: "Previous memory: ${prior_memory}" (if_present prior) + user: "Completed subtask: ${completed}" (if_present completed) assistant: → predicts new memory - Fired when the runtime detects a subtask transition; the - just-predicted subtask lives in ``state['current_subtask']``. + Fired by ``MemoryUpdateFwd`` on a ``subtask_change`` event: + ``state['current_memory']`` is the memory the policy last emitted + (= the ``prior_memory`` binding at training), and + ``state['prior_subtask']`` is the subtask that just got replaced + (= the ``completed_subtask`` binding at training). """ msgs: list[dict[str, Any]] = [ - {"role": "user", "content": _hirobot_user_head(state)}, + {"role": "user", "content": state.get("task") or ""}, ] - if state.get("current_subtask"): - msgs.append({"role": "assistant", "content": state["current_subtask"]}) + prior_memory = state.get("current_memory") + if prior_memory: + msgs.append( + {"role": "assistant", "content": f"Previous memory: {prior_memory}"} + ) + completed_subtask = state.get("prior_subtask") + if completed_subtask: + msgs.append( + {"role": "user", "content": f"Completed subtask: {completed_subtask}"} + ) return msgs @@ -925,6 +953,7 @@ def _generate_with_policy( min_new_tokens: int = 0, temperature: float = 0.0, top_p: float = 1.0, + suppress_loc_tokens: bool = False, ) -> str: """Drive ``policy.select_message`` with a chat batch (and optional obs). @@ -959,13 +988,20 @@ def _generate_with_policy( for k, v in observation.items(): if isinstance(k, str) and k.startswith("observation.") and k not in batch: batch[k] = v - return policy.select_message( - batch, - tokenizer=text_batch["tokenizer"], - min_new_tokens=min_new_tokens, - temperature=temperature, - top_p=top_p, - ) + kwargs: dict[str, Any] = { + "tokenizer": text_batch["tokenizer"], + "min_new_tokens": min_new_tokens, + "temperature": temperature, + "top_p": top_p, + } + # Only pass ``suppress_loc_tokens`` to backbones that accept it + # (pi052). SmolVLA2's ``select_message`` does not, so we omit + # the kwarg there to avoid breaking the shared runtime. + import inspect # noqa: PLC0415 + + if "suppress_loc_tokens" in inspect.signature(policy.select_message).parameters: + kwargs["suppress_loc_tokens"] = suppress_loc_tokens + return policy.select_message(batch, **kwargs) except Exception as exc: # noqa: BLE001 logger.warning("%s failed: %s", label, exc, exc_info=logger.isEnabledFor(logging.DEBUG)) if state is not None: