mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
fix(pi052): read backbone dtype from q_proj, not first parameter
select_message's bf16 cast used next(paligemma.parameters()).dtype, which lands on an fp32-kept param (norm / embedding) under to_bfloat16_for_selected_params. The mask stayed fp32 while q/k/v were bf16 → SDPA still raised "invalid dtype for bias".

Read the dtype from layers[0].self_attn.q_proj.weight instead. q_proj is always cast with the rest, so its dtype matches what SDPA sees.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@@ -851,8 +851,15 @@ class PI052Policy(PI05Policy):
 # ``_prepare_attention_masks_4d`` always returns fp32 (0.0 / -inf
 # literals). When weights are bf16, HF's PaliGemma SDPA raises
 # "invalid dtype for bias - should match query's dtype". Pull the
-# backbone's dtype once and cast each step's mask to it.
-backbone_dtype = next(backbone.paligemma.parameters()).dtype
+# dtype from an attention *projection* weight specifically:
+# ``to_bfloat16_for_selected_params`` keeps norms / embeddings in
+# fp32 even when the rest is bf16, so ``next(parameters())``
+# would land on one of those and we'd skip the cast. q_proj is
+# always cast with the rest, so its dtype is the one SDPA sees.
+backbone_dtype = (
+    backbone.paligemma.model.language_model.layers[0]
+    .self_attn.q_proj.weight.dtype
+)
 
 for _ in range(max_new_tokens):
     att_2d = make_att_2d_masks(current_pad, current_att)
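The dtype mismatch described in the commit message can be reproduced in isolation. The following is a minimal sketch, not the lerobot code: ToyBackbone, ToyAttention and cast_selected_to_bf16 are hypothetical stand-ins, and the selective cast is assumed to skip only the embedding. It shows that the first registered parameter can stay fp32 while q_proj is bf16, and that the mask should follow q_proj.

```python
import torch
from torch import nn


class ToyAttention(nn.Module):
    """Hypothetical stand-in for one attention block; only q_proj matters here."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.q_proj = nn.Linear(dim, dim, bias=False)


class ToyBackbone(nn.Module):
    """Hypothetical backbone: the embedding is registered before the attention weights."""

    def __init__(self, vocab: int = 16, dim: int = 8) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dim)
        self.self_attn = ToyAttention(dim)


def cast_selected_to_bf16(model: nn.Module) -> None:
    """Assumed selective cast: everything goes to bf16 except embeddings, which stay fp32."""
    for name, param in model.named_parameters():
        if "embed" not in name:
            param.data = param.data.to(torch.bfloat16)


backbone = ToyBackbone()
cast_selected_to_bf16(backbone)

# parameters() yields in registration order, so this is the fp32 embedding weight.
first_param_dtype = next(backbone.parameters()).dtype
# The projection weight carries the dtype that the attention inputs will have.
q_proj_dtype = backbone.self_attn.q_proj.weight.dtype

# Additive mask built from fp32 literals (0.0 / -inf), then cast to the projection dtype.
mask_fp32 = torch.zeros(1, 1, 4, 4, dtype=torch.float32)
mask_fp32[..., 2:] = float("-inf")
mask_cast = mask_fp32.to(q_proj_dtype)

print(first_param_dtype, q_proj_dtype, mask_cast.dtype)
```

Probing a weight that is known to participate in attention avoids depending on parameter registration order, which is what made the original lookup fragile.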