mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 21:19:53 +00:00
fix(pi052): debug parity harness truncates prompt instead of masking
The parity check in debug_text_predictions was producing false ✗
DIVERGED reports. Root cause: I built the "inference" batch by
zero-masking the attention past the supervised span, but kept the
full 512-token padded sequence. select_message reads the prompt-end
hidden state via ``vlm_out[:, -1:]`` — the LAST position of the
prefix — which in a padded batch is a padding-token hidden state,
not the last prompt token. PaliGemma's prior on those padded
positions reliably argmaxes to <loc0879>, falsely flagging a
training/inference mismatch.
Fix: truncate both tokens AND mask to length == first_sup before
calling select_message, mirroring what the real runtime does
(``tokenizer(prompt)`` returns un-padded ids). Now the parity check
compares like-with-like.
The actual training argmax in the dump was sensible English
("' move the blue cube into the green bin'" at acc=6/9) — the head
is learning correctly. The "<loc>" salad was purely the harness
reading from the wrong position.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -874,14 +874,21 @@ class PI052Policy(PI05Policy):
|
||||
inference_outputs.append({"first_token": None, "decoded": ""})
|
||||
continue
|
||||
first_sup = int(sup_pos[0].item())
|
||||
# Build a single-sample batch with attention zeroed past
|
||||
# the supervised span — that gives ``embed_prefix`` only
|
||||
# the user-prompt portion to attend over.
|
||||
prompt_mask = sub[OBS_LANGUAGE_ATTENTION_MASK][s : s + 1].clone()
|
||||
prompt_mask[:, first_sup:] = 0
|
||||
# Build a single-sample batch by *truncating* the token
|
||||
# sequence to the prompt-only portion (length == first_sup),
|
||||
# not by zero-masking. ``select_message`` reads the
|
||||
# prompt-end hidden state via ``vlm_out[:, -1:]`` — the
|
||||
# *last position* of the prefix — so a padded sequence
|
||||
# would make it read a padding-token hidden state
|
||||
# (PaliGemma's prior on those happens to be ``<loc>``,
|
||||
# which would falsely flag a parity diverge). The real
|
||||
# runtime feeds ``tokenizer(prompt)`` without padding,
|
||||
# so we mirror that here.
|
||||
prompt_tokens = sub[OBS_LANGUAGE_TOKENS][s : s + 1, :first_sup]
|
||||
prompt_mask_orig = sub[OBS_LANGUAGE_ATTENTION_MASK][s : s + 1, :first_sup]
|
||||
inf_batch: dict[str, Any] = {
|
||||
OBS_LANGUAGE_TOKENS: sub[OBS_LANGUAGE_TOKENS][s : s + 1],
|
||||
OBS_LANGUAGE_ATTENTION_MASK: prompt_mask,
|
||||
OBS_LANGUAGE_TOKENS: prompt_tokens,
|
||||
OBS_LANGUAGE_ATTENTION_MASK: prompt_mask_orig,
|
||||
}
|
||||
for k, v in sub.items():
|
||||
if isinstance(k, str) and k.startswith("observation.images."):
|
||||
|
||||
Reference in New Issue
Block a user