diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index 5cd56f1f3..b1146ea1a 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -851,8 +851,15 @@ class PI052Policy(PI05Policy): # ``_prepare_attention_masks_4d`` always returns fp32 (0.0 / -inf # literals). When weights are bf16, HF's PaliGemma SDPA raises # "invalid dtype for bias - should match query's dtype". Pull the - # backbone's dtype once and cast each step's mask to it. - backbone_dtype = next(backbone.paligemma.parameters()).dtype + # dtype from an attention *projection* weight specifically: + # ``to_bfloat16_for_selected_params`` keeps norms / embeddings in + # fp32 even when the rest is bf16, so ``next(parameters())`` + # would land on one of those and we'd skip the cast. q_proj is + # always cast with the rest, so its dtype is the one SDPA sees. + backbone_dtype = ( + backbone.paligemma.model.language_model.layers[0] + .self_attn.q_proj.weight.dtype + ) for _ in range(max_new_tokens): att_2d = make_att_2d_masks(current_pad, current_att)