From d41d87458119c2a5e562aed719cfb9972bf8e962 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Thu, 21 May 2026 15:09:36 +0200 Subject: [PATCH] fix(pi052): debug parity harness truncates prompt instead of masking MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The parity check in debug_text_predictions was producing false ✗ DIVERGED reports. Root cause: I built the "inference" batch by zero-masking the attention past the supervised span, but kept the full 512-token padded sequence. select_message reads the prompt-end hidden state via ``vlm_out[:, -1:]`` — the LAST position of the prefix — which in a padded batch is a padding-token hidden state, not the last prompt token. PaliGemma's prior on those padded positions reliably argmaxes to , falsely flagging a training/inference mismatch. Fix: truncate both tokens AND mask to length == first_sup before calling select_message, mirroring what the real runtime does (``tokenizer(prompt)`` returns un-padded ids). Now the parity check compares like-with-like. The actual training argmax in the dump was sensible English ("' move the blue cube into the green bin'" at acc=6/9) — the head is learning correctly. The "" salad was purely the harness reading from the wrong position. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lerobot/policies/pi052/modeling_pi052.py | 21 +++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index c0b30f707..bcf7f6a18 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -874,14 +874,21 @@ class PI052Policy(PI05Policy): inference_outputs.append({"first_token": None, "decoded": ""}) continue first_sup = int(sup_pos[0].item()) - # Build a single-sample batch with attention zeroed past - # the supervised span — that gives ``embed_prefix`` only - # the user-prompt portion to attend over. - prompt_mask = sub[OBS_LANGUAGE_ATTENTION_MASK][s : s + 1].clone() - prompt_mask[:, first_sup:] = 0 + # Build a single-sample batch by *truncating* the token + # sequence to the prompt-only portion (length == first_sup), + # not by zero-masking. ``select_message`` reads the + # prompt-end hidden state via ``vlm_out[:, -1:]`` — the + # *last position* of the prefix — so a padded sequence + # would make it read a padding-token hidden state + # (PaliGemma's prior on those happens to be ````, + # which would falsely flag a parity diverge). The real + # runtime feeds ``tokenizer(prompt)`` without padding, + # so we mirror that here. + prompt_tokens = sub[OBS_LANGUAGE_TOKENS][s : s + 1, :first_sup] + prompt_mask_orig = sub[OBS_LANGUAGE_ATTENTION_MASK][s : s + 1, :first_sup] inf_batch: dict[str, Any] = { - OBS_LANGUAGE_TOKENS: sub[OBS_LANGUAGE_TOKENS][s : s + 1], - OBS_LANGUAGE_ATTENTION_MASK: prompt_mask, + OBS_LANGUAGE_TOKENS: prompt_tokens, + OBS_LANGUAGE_ATTENTION_MASK: prompt_mask_orig, } for k, v in sub.items(): if isinstance(k, str) and k.startswith("observation.images."):