fix(pi052): debug parity harness truncates prompt instead of masking

The parity check in debug_text_predictions was producing false ✗
DIVERGED reports. Root cause: I built the "inference" batch by
zero-masking the attention past the supervised span, but kept the
full 512-token padded sequence. select_message reads the prompt-end
hidden state via ``vlm_out[:, -1:]`` — the LAST position of the
prefix — which in a padded batch is a padding-token hidden state,
not the last prompt token. PaliGemma's prior on those padded
positions reliably argmaxes to <loc0879>, falsely flagging a
training/inference mismatch.

Fix: truncate both tokens AND mask to length == first_sup before
calling select_message, mirroring what the real runtime does
(``tokenizer(prompt)`` returns un-padded ids). Now the parity check
compares like-with-like.

The actual training argmax in the dump was sensible English
("' move the blue cube into the green bin'" at acc=6/9) — the head
is learning correctly. The "<loc>" salad was purely the harness
reading from the wrong position.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-21 15:09:36 +02:00
parent efa05f0ada
commit d41d874581
+14 -7
View File
@@ -874,14 +874,21 @@ class PI052Policy(PI05Policy):
inference_outputs.append({"first_token": None, "decoded": ""})
continue
first_sup = int(sup_pos[0].item())
# Build a single-sample batch with attention zeroed past
# the supervised span — that gives ``embed_prefix`` only
# the user-prompt portion to attend over.
prompt_mask = sub[OBS_LANGUAGE_ATTENTION_MASK][s : s + 1].clone()
prompt_mask[:, first_sup:] = 0
# Build a single-sample batch by *truncating* the token
# sequence to the prompt-only portion (length == first_sup),
# not by zero-masking. ``select_message`` reads the
# prompt-end hidden state via ``vlm_out[:, -1:]`` — the
# *last position* of the prefix — so a padded sequence
# would make it read a padding-token hidden state
# (PaliGemma's prior on those happens to be ``<loc>``,
# which would falsely flag a parity diverge). The real
# runtime feeds ``tokenizer(prompt)`` without padding,
# so we mirror that here.
prompt_tokens = sub[OBS_LANGUAGE_TOKENS][s : s + 1, :first_sup]
prompt_mask_orig = sub[OBS_LANGUAGE_ATTENTION_MASK][s : s + 1, :first_sup]
inf_batch: dict[str, Any] = {
OBS_LANGUAGE_TOKENS: sub[OBS_LANGUAGE_TOKENS][s : s + 1],
OBS_LANGUAGE_ATTENTION_MASK: prompt_mask,
OBS_LANGUAGE_TOKENS: prompt_tokens,
OBS_LANGUAGE_ATTENTION_MASK: prompt_mask_orig,
}
for k, v in sub.items():
if isinstance(k, str) and k.startswith("observation.images."):