diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index c0b30f707..bcf7f6a18 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -874,14 +874,21 @@ class PI052Policy(PI05Policy): inference_outputs.append({"first_token": None, "decoded": ""}) continue first_sup = int(sup_pos[0].item()) - # Build a single-sample batch with attention zeroed past - # the supervised span — that gives ``embed_prefix`` only - # the user-prompt portion to attend over. - prompt_mask = sub[OBS_LANGUAGE_ATTENTION_MASK][s : s + 1].clone() - prompt_mask[:, first_sup:] = 0 + # Build a single-sample batch by *truncating* the token + # sequence to the prompt-only portion (length == first_sup), + # not by zero-masking. ``select_message`` reads the + # prompt-end hidden state via ``vlm_out[:, -1:]`` — the + # *last position* of the prefix — so a padded sequence + # would make it read a padding-token hidden state + # (PaliGemma's prior on those happens to be ````, + # which would falsely flag a parity diverge). The real + # runtime feeds ``tokenizer(prompt)`` without padding, + # so we mirror that here. + prompt_tokens = sub[OBS_LANGUAGE_TOKENS][s : s + 1, :first_sup] + prompt_mask_orig = sub[OBS_LANGUAGE_ATTENTION_MASK][s : s + 1, :first_sup] inf_batch: dict[str, Any] = { - OBS_LANGUAGE_TOKENS: sub[OBS_LANGUAGE_TOKENS][s : s + 1], - OBS_LANGUAGE_ATTENTION_MASK: prompt_mask, + OBS_LANGUAGE_TOKENS: prompt_tokens, + OBS_LANGUAGE_ATTENTION_MASK: prompt_mask_orig, } for k, v in sub.items(): if isinstance(k, str) and k.startswith("observation.images."):