fix(pi05): cast attention masks to model dtype

Ensure attention masks follow the backbone dtype during bf16 inference to avoid mixed dtype failures.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-05-21 09:51:12 +00:00
parent 5bb2da4da6
commit 28b86449a2
3 changed files with 19 additions and 10 deletions
+12 -5
View File
@@ -617,10 +617,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
)
return func(*args, **kwargs)
def _prepare_attention_masks_4d(self, att_2d_masks):
def _prepare_attention_masks_4d(self, att_2d_masks, dtype=None):
"""Helper method to prepare 4D attention masks for transformer."""
att_2d_masks_4d = att_2d_masks[:, None, :, :]
return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)
result = torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)
if dtype is not None:
result = result.to(dtype=dtype)
return result
def sample_noise(self, shape, device):
return torch.normal(
@@ -756,7 +759,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
position_ids = torch.cumsum(pad_masks, dim=1) - 1
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks, dtype=prefix_embs.dtype)
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
(_, suffix_out), _ = self.paligemma_with_expert.forward(
@@ -814,7 +817,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(
prefix_att_2d_masks, dtype=prefix_embs.dtype
)
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
_, past_key_values = self.paligemma_with_expert.forward(
@@ -884,7 +889,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
full_att_2d_masks_4d = self._prepare_attention_masks_4d(
full_att_2d_masks, dtype=suffix_embs.dtype
)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
past_key_values = copy.deepcopy(past_key_values)
+5 -5
View File
@@ -630,7 +630,9 @@ class PI052Policy(PI05Policy):
att_2d_masks[:, fast_end:, fast_start:fast_end] = False
position_ids = torch.cumsum(pad_masks, dim=1) - 1
att_2d_masks_4d = self.model._prepare_attention_masks_4d(att_2d_masks)
att_2d_masks_4d = self.model._prepare_attention_masks_4d(
att_2d_masks, dtype=prefix_embs.dtype
)
# ---- forward (capture BOTH expert outputs) ------------------
(prefix_out, suffix_out), _ = self.model.paligemma_with_expert.forward(
@@ -740,7 +742,7 @@ class PI052Policy(PI05Policy):
att_2d = make_att_2d_masks(full_pad, full_att)
position_ids = torch.cumsum(full_pad, dim=1) - 1
att_2d_4d = self.model._prepare_attention_masks_4d(att_2d)
att_2d_4d = self.model._prepare_attention_masks_4d(att_2d, dtype=full_embs.dtype)
(vlm_out, _), _ = self.model.paligemma_with_expert.forward(
attention_mask=att_2d_4d,
@@ -864,9 +866,7 @@ class PI052Policy(PI05Policy):
for _ in range(max_new_tokens):
att_2d = make_att_2d_masks(current_pad, current_att)
position_ids = torch.cumsum(current_pad, dim=1) - 1
att_2d_4d = self.model._prepare_attention_masks_4d(att_2d)
if att_2d_4d.dtype != backbone_dtype:
att_2d_4d = att_2d_4d.to(dtype=backbone_dtype)
att_2d_4d = self.model._prepare_attention_masks_4d(att_2d, dtype=backbone_dtype)
(vlm_out, _), _ = backbone.forward(
attention_mask=att_2d_4d,
position_ids=position_ids,
+2
View File
@@ -272,6 +272,8 @@ class PiGemmaModel(GemmaModel): # type: ignore[misc]
# Convert to bfloat16 if the first layer uses bfloat16
if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.bfloat16)
if causal_mask is not None and torch.is_floating_point(causal_mask):
causal_mask = causal_mask.to(dtype=hidden_states.dtype)
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)