mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
fix(pi052): cast attention bias to model dtype for bf16 inference
`_prepare_attention_masks_4d` always returns an fp32 mask (it is built from the 0.0 / -inf literals). With bf16 weights, HF PaliGemma's SDPA path then raises "invalid dtype for bias - should match query's dtype", and `select_message` returns an empty result on every step. Cast the mask to the query/backbone dtype at both attention sites: `_compute_layer_ki` (training, when both experts run) and `select_message` (inference, VLM-only branch). With this change, bf16 training and bf16 inference run end to end with no dtype mismatch. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -230,6 +230,15 @@ def _compute_layer_ki(
|
||||
|
||||
mask_for_vlm = attention_mask[:, :, :vlm_len, :]
|
||||
mask_for_action = attention_mask[:, :, vlm_len:, :]
|
||||
# ``_prepare_attention_masks_4d`` always returns fp32 (0.0 / -inf
|
||||
# literals), but PaliGemma weights are bf16 when ``dtype=bfloat16``,
|
||||
# making q bf16. SDPA's ``scaled_dot_product_attention`` then raises
|
||||
# "invalid dtype for bias - should match query's dtype". Cast each
|
||||
# mask slice to the corresponding query dtype right before use.
|
||||
if mask_for_vlm.dtype != Q_vlm.dtype:
|
||||
mask_for_vlm = mask_for_vlm.to(dtype=Q_vlm.dtype)
|
||||
if mask_for_action.dtype != Q_action.dtype:
|
||||
mask_for_action = mask_for_action.to(dtype=Q_action.dtype)
|
||||
|
||||
att_vlm, _ = modeling_gemma.eager_attention_forward(
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
@@ -839,10 +848,18 @@ class PI052Policy(PI05Policy):
|
||||
backbone = self.model.paligemma_with_expert
|
||||
lm_head = backbone.paligemma.lm_head
|
||||
|
||||
# ``_prepare_attention_masks_4d`` always returns fp32 (0.0 / -inf
|
||||
# literals). When weights are bf16, HF's PaliGemma SDPA raises
|
||||
# "invalid dtype for bias - should match query's dtype". Pull the
|
||||
# backbone's dtype once and cast each step's mask to it.
|
||||
backbone_dtype = next(backbone.paligemma.parameters()).dtype
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
att_2d = make_att_2d_masks(current_pad, current_att)
|
||||
position_ids = torch.cumsum(current_pad, dim=1) - 1
|
||||
att_2d_4d = self.model._prepare_attention_masks_4d(att_2d)
|
||||
if att_2d_4d.dtype != backbone_dtype:
|
||||
att_2d_4d = att_2d_4d.to(dtype=backbone_dtype)
|
||||
(vlm_out, _), _ = backbone.forward(
|
||||
attention_mask=att_2d_4d,
|
||||
position_ids=position_ids,
|
||||
|
||||
Reference in New Issue
Block a user