diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index 440c5afa9..5cd56f1f3 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -230,6 +230,15 @@ def _compute_layer_ki( mask_for_vlm = attention_mask[:, :, :vlm_len, :] mask_for_action = attention_mask[:, :, vlm_len:, :] + # ``_prepare_attention_masks_4d`` always returns fp32 (0.0 / -inf + # literals), but PaliGemma weights are bf16 when ``dtype=bfloat16``, + # making q bf16. SDPA's ``scaled_dot_product_attention`` then raises + # "invalid dtype for bias - should match query's dtype". Cast each + # mask slice to the corresponding query dtype right before use. + if mask_for_vlm.dtype != Q_vlm.dtype: + mask_for_vlm = mask_for_vlm.to(dtype=Q_vlm.dtype) + if mask_for_action.dtype != Q_action.dtype: + mask_for_action = mask_for_action.to(dtype=Q_action.dtype) att_vlm, _ = modeling_gemma.eager_attention_forward( paligemma.model.language_model.layers[layer_idx].self_attn, @@ -839,10 +848,18 @@ class PI052Policy(PI05Policy): backbone = self.model.paligemma_with_expert lm_head = backbone.paligemma.lm_head + # ``_prepare_attention_masks_4d`` always returns fp32 (0.0 / -inf + # literals). When weights are bf16, HF's PaliGemma SDPA raises + # "invalid dtype for bias - should match query's dtype". Pull the + # backbone's dtype once and cast each step's mask to it. + backbone_dtype = next(backbone.paligemma.parameters()).dtype + for _ in range(max_new_tokens): att_2d = make_att_2d_masks(current_pad, current_att) position_ids = torch.cumsum(current_pad, dim=1) - 1 att_2d_4d = self.model._prepare_attention_masks_4d(att_2d) + if att_2d_4d.dtype != backbone_dtype: + att_2d_4d = att_2d_4d.to(dtype=backbone_dtype) (vlm_out, _), _ = backbone.forward( attention_mask=att_2d_4d, position_ids=position_ids,