fix(pi052): cast attention bias to model dtype for bf16 inference

`_prepare_attention_masks_4d` always returns fp32 (the 0.0 / -inf
literals); with bf16 weights, HF PaliGemma's SDPA path raises
"invalid dtype for bias - should match query's dtype" and
select_message returns empty every step. Cast in both attention
sites: `_compute_layer_ki` (training, when both experts run) and
`select_message` (inference, VLM-only branch). Bf16 training +
bf16 inference now run end to end with no dtype mismatch.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-20 18:42:26 +02:00
parent 34269a5d78
commit 3b4376aa33
@@ -230,6 +230,15 @@ def _compute_layer_ki(
mask_for_vlm = attention_mask[:, :, :vlm_len, :]
mask_for_action = attention_mask[:, :, vlm_len:, :]
# ``_prepare_attention_masks_4d`` always returns fp32 (0.0 / -inf
# literals), but PaliGemma weights are bf16 when ``dtype=bfloat16``,
# making q bf16. SDPA's ``scaled_dot_product_attention`` then raises
# "invalid dtype for bias - should match query's dtype". Cast each
# mask slice to the corresponding query dtype right before use.
if mask_for_vlm.dtype != Q_vlm.dtype:
mask_for_vlm = mask_for_vlm.to(dtype=Q_vlm.dtype)
if mask_for_action.dtype != Q_action.dtype:
mask_for_action = mask_for_action.to(dtype=Q_action.dtype)
att_vlm, _ = modeling_gemma.eager_attention_forward(
paligemma.model.language_model.layers[layer_idx].self_attn,
@@ -839,10 +848,18 @@ class PI052Policy(PI05Policy):
backbone = self.model.paligemma_with_expert
lm_head = backbone.paligemma.lm_head
# ``_prepare_attention_masks_4d`` always returns fp32 (0.0 / -inf
# literals). When weights are bf16, HF's PaliGemma SDPA raises
# "invalid dtype for bias - should match query's dtype". Pull the
# backbone's dtype once and cast each step's mask to it.
backbone_dtype = next(backbone.paligemma.parameters()).dtype
for _ in range(max_new_tokens):
att_2d = make_att_2d_masks(current_pad, current_att)
position_ids = torch.cumsum(current_pad, dim=1) - 1
att_2d_4d = self.model._prepare_attention_masks_4d(att_2d)
if att_2d_4d.dtype != backbone_dtype:
att_2d_4d = att_2d_4d.to(dtype=backbone_dtype)
(vlm_out, _), _ = backbone.forward(
attention_mask=att_2d_4d,
position_ids=position_ids,