mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 21:19:53 +00:00
fix(pi052,smolvla2): unblock text generation when LM head drifted to <loc>
PaliGemma's pretraining puts heavy first-token mass on its <loc0000>..
<loc1023> ids at any "Assistant:" continuation. Our pi052 fine-tunes
with knowledge_insulation=True and a small text-CE budget (~45% of
samples) drift back toward that prior on long runs at low LR — teacher-
forced argmax stays at 100% (CE only measures next-token given correct
prefix) while autoregressive first-token selection collapses onto <loc>.
On the running poulain11 checkpoint at step 8000 this manifests as a
stream of <locDDDD> tokens for every subtask call — confirmed locally
against the saved checkpoint on a dataset frame.
Add a `suppress_loc_tokens` knob to `PI052Policy.select_message` that
masks ids [256000, 257024) to -inf before sampling, and pass it from
the three text-only inference steps (HighLevelSubtaskFwd,
MemoryUpdateFwd, UserInterjectionFwd). VQA steps keep the default
False so spatial answers can still emit locs. Verified end-to-end:
suppressed → "the robot arm moves the blue block to the green basket".
Also fix `_msgs_for_memory`: it was emitting the older
`User: ${task}\nPlan:..\nMemory:..` / `Assistant: ${subtask}` template,
which no longer matches the `memory_update` recipe layout
(`User: ${task}` / `Assistant: Previous memory: ..` /
`User: Completed subtask: ..`). The new prompt mirrors the training
recipe; `HighLevelSubtaskFwd` stashes the just-completed subtask in
`state['prior_subtask']` so the memory prompt can render
`Completed subtask: ..` for `MemoryUpdateFwd`.
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -930,12 +930,23 @@ class PI052Policy(PI05Policy):
|
||||
temperature: float = 0.0,
|
||||
top_p: float = 1.0,
|
||||
tokenizer: Any = None,
|
||||
suppress_loc_tokens: bool = False,
|
||||
) -> str:
|
||||
"""Generate text continuation from a multimodal prefix.
|
||||
|
||||
Mirrors ``SmolVLA2Policy.select_message`` so the same
|
||||
:class:`lerobot.policies.smolvla2.inference.SmolVLA2Runtime`
|
||||
can drive π0.5 v2 unchanged.
|
||||
|
||||
``suppress_loc_tokens`` masks PaliGemma's reserved ``<locDDDD>``
|
||||
ids ([256000, 257024)) to ``-inf`` before sampling. PaliGemma's
|
||||
pretraining puts heavy first-token mass on these ids for any
|
||||
``Assistant:`` continuation; with a small fine-tuning text-CE
|
||||
budget (or aggressive LR decay) the LM head can drift back
|
||||
toward that prior even when teacher-forced argmax stays at
|
||||
100%. Callsites that legitimately emit ``<loc>`` (VQA spatial
|
||||
answers) must keep this ``False``; subtask / memory / plan
|
||||
generation should pass ``True``.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
@@ -1015,6 +1026,8 @@ class PI052Policy(PI05Policy):
|
||||
if special_ids and len(generated) < min_new_tokens:
|
||||
for sid in special_ids:
|
||||
logits_step[..., sid] = float("-inf")
|
||||
if suppress_loc_tokens:
|
||||
logits_step[..., 256000:257024] = float("-inf")
|
||||
next_ids = self._sample_next_token(logits_step, temperature, top_p)
|
||||
tok_id = int(next_ids[0].item())
|
||||
generated.append(tok_id)
|
||||
|
||||
@@ -484,6 +484,11 @@ class HighLevelSubtaskFwd(InferenceStep):
|
||||
min_new_tokens=int(state.get("text_gen_min_new_tokens") or 0),
|
||||
temperature=float(state.get("text_gen_temperature") or 0.0),
|
||||
top_p=float(state.get("text_gen_top_p") or 1.0),
|
||||
# Subtasks never legitimately contain PaliGemma ``<loc>``
|
||||
# tokens — suppress them so a checkpoint whose LM head
|
||||
# has drifted toward the pretrained loc-prior falls back
|
||||
# to its (still-correct) text mass.
|
||||
suppress_loc_tokens=True,
|
||||
)
|
||||
# Diagnostics: surface what the model is *actually* producing
|
||||
# at chunk boundaries, even when the output gets rejected or
|
||||
@@ -526,8 +531,16 @@ class HighLevelSubtaskFwd(InferenceStep):
|
||||
)
|
||||
return None
|
||||
if msg:
|
||||
prev_subtask = state.get("current_subtask")
|
||||
changed = set_if_changed(state, "current_subtask", msg, label="subtask")
|
||||
if changed:
|
||||
# Stash the just-completed subtask so ``MemoryUpdateFwd``
|
||||
# can drop it into its prompt as ``Completed subtask:``
|
||||
# — the recipe binds ``completed_subtask`` to
|
||||
# ``nth_prev(style=subtask, offset=1)``, i.e. the subtask
|
||||
# that was active *before* the change.
|
||||
if prev_subtask:
|
||||
state["prior_subtask"] = prev_subtask
|
||||
# Subtask change is a downstream trigger.
|
||||
state.setdefault("events_this_tick", []).append("subtask_change")
|
||||
state["subtask_repeat_count"] = 0
|
||||
@@ -576,6 +589,7 @@ class MemoryUpdateFwd(InferenceStep):
|
||||
observation=observation,
|
||||
state=state,
|
||||
label="memory gen",
|
||||
suppress_loc_tokens=True,
|
||||
)
|
||||
state["last_memory_raw"] = new_memory or ""
|
||||
if new_memory and _looks_like_gibberish(new_memory):
|
||||
@@ -619,6 +633,7 @@ class UserInterjectionFwd(InferenceStep):
|
||||
observation=observation,
|
||||
state=state,
|
||||
label="plan/say gen",
|
||||
suppress_loc_tokens=True,
|
||||
)
|
||||
if not out:
|
||||
# Don't log every empty completion — happens repeatedly on
|
||||
@@ -851,21 +866,34 @@ def _msgs_for_subtask(state: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
|
||||
|
||||
def _msgs_for_memory(state: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""Memory-update prompt — boundary-frame tail of ``high_level_subtask``.
|
||||
"""Memory-update prompt — mirrors ``memory_update`` recipe layout.
|
||||
|
||||
Recipe layout on a boundary frame:
|
||||
user: "${task}\\nPlan: ${plan}\\nMemory: ${memory}"
|
||||
assistant: "${subtask}"
|
||||
Recipe layout (``subtask_mem.yaml``):
|
||||
|
||||
user: "${task}"
|
||||
assistant: "Previous memory: ${prior_memory}" (if_present prior)
|
||||
user: "Completed subtask: ${completed}" (if_present completed)
|
||||
assistant: → predicts new memory
|
||||
|
||||
Fired when the runtime detects a subtask transition; the
|
||||
just-predicted subtask lives in ``state['current_subtask']``.
|
||||
Fired by ``MemoryUpdateFwd`` on a ``subtask_change`` event:
|
||||
``state['current_memory']`` is the memory the policy last emitted
|
||||
(= the ``prior_memory`` binding at training), and
|
||||
``state['prior_subtask']`` is the subtask that just got replaced
|
||||
(= the ``completed_subtask`` binding at training).
|
||||
"""
|
||||
msgs: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": _hirobot_user_head(state)},
|
||||
{"role": "user", "content": state.get("task") or ""},
|
||||
]
|
||||
if state.get("current_subtask"):
|
||||
msgs.append({"role": "assistant", "content": state["current_subtask"]})
|
||||
prior_memory = state.get("current_memory")
|
||||
if prior_memory:
|
||||
msgs.append(
|
||||
{"role": "assistant", "content": f"Previous memory: {prior_memory}"}
|
||||
)
|
||||
completed_subtask = state.get("prior_subtask")
|
||||
if completed_subtask:
|
||||
msgs.append(
|
||||
{"role": "user", "content": f"Completed subtask: {completed_subtask}"}
|
||||
)
|
||||
return msgs
|
||||
|
||||
|
||||
@@ -925,6 +953,7 @@ def _generate_with_policy(
|
||||
min_new_tokens: int = 0,
|
||||
temperature: float = 0.0,
|
||||
top_p: float = 1.0,
|
||||
suppress_loc_tokens: bool = False,
|
||||
) -> str:
|
||||
"""Drive ``policy.select_message`` with a chat batch (and optional obs).
|
||||
|
||||
@@ -959,13 +988,20 @@ def _generate_with_policy(
|
||||
for k, v in observation.items():
|
||||
if isinstance(k, str) and k.startswith("observation.") and k not in batch:
|
||||
batch[k] = v
|
||||
return policy.select_message(
|
||||
batch,
|
||||
tokenizer=text_batch["tokenizer"],
|
||||
min_new_tokens=min_new_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
)
|
||||
kwargs: dict[str, Any] = {
|
||||
"tokenizer": text_batch["tokenizer"],
|
||||
"min_new_tokens": min_new_tokens,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
}
|
||||
# Only pass ``suppress_loc_tokens`` to backbones that accept it
|
||||
# (pi052). SmolVLA2's ``select_message`` does not, so we omit
|
||||
# the kwarg there to avoid breaking the shared runtime.
|
||||
import inspect # noqa: PLC0415
|
||||
|
||||
if "suppress_loc_tokens" in inspect.signature(policy.select_message).parameters:
|
||||
kwargs["suppress_loc_tokens"] = suppress_loc_tokens
|
||||
return policy.select_message(batch, **kwargs)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("%s failed: %s", label, exc, exc_info=logger.isEnabledFor(logging.DEBUG))
|
||||
if state is not None:
|
||||
|
||||
Reference in New Issue
Block a user