fix(pi052): read backbone dtype from q_proj, not first parameter

select_message's bf16 cast used next(paligemma.parameters()).dtype,
which lands on a fp32-kept param (norm / embedding) under
to_bfloat16_for_selected_params. Mask stayed fp32 while q/k/v were
bf16 → SDPA still raised "invalid dtype for bias". Read the dtype
from layers[0].self_attn.q_proj.weight instead — q_proj is always
cast with the rest, so its dtype matches what SDPA sees.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-20 18:46:08 +02:00
parent 3b4376aa33
commit f7b989ad97
+9 -2
View File
@@ -851,8 +851,15 @@ class PI052Policy(PI05Policy):
# ``_prepare_attention_masks_4d`` always returns fp32 (0.0 / -inf
# literals). When weights are bf16, HF's PaliGemma SDPA raises
# "invalid dtype for bias - should match query's dtype". Pull the
# backbone's dtype once and cast each step's mask to it.
backbone_dtype = next(backbone.paligemma.parameters()).dtype
# dtype from an attention *projection* weight specifically:
# ``to_bfloat16_for_selected_params`` keeps norms / embeddings in
# fp32 even when the rest is bf16, so ``next(parameters())``
# would land on one of those and we'd skip the cast. q_proj is
# always cast with the rest, so its dtype is the one SDPA sees.
backbone_dtype = (
backbone.paligemma.model.language_model.layers[0]
.self_attn.q_proj.weight.dtype
)
for _ in range(max_new_tokens):
att_2d = make_att_2d_masks(current_pad, current_att)