fix(pi052): pass 4d masks to prefix-only forwards

Convert PI052 prefix-only attention masks before calling PaliGemma so text-only batches and generation use the same mask shape as fused training.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-05-18 21:07:13 +00:00
parent 2629175d2d
commit bb31988915
+4 -2
View File
@@ -726,9 +726,10 @@ class PI052Policy(PI05Policy):
att_2d = make_att_2d_masks(full_pad, full_att)
position_ids = torch.cumsum(full_pad, dim=1) - 1
att_2d_4d = self.model._prepare_attention_masks_4d(att_2d)
(vlm_out, _), _ = self.model.paligemma_with_expert.forward(
attention_mask=att_2d,
attention_mask=att_2d_4d,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[full_embs, None],
@@ -832,8 +833,9 @@ class PI052Policy(PI05Policy):
for _ in range(max_new_tokens):
att_2d = make_att_2d_masks(current_pad, current_att)
position_ids = torch.cumsum(current_pad, dim=1) - 1
att_2d_4d = self.model._prepare_attention_masks_4d(att_2d)
(vlm_out, _), _ = backbone.forward(
attention_mask=att_2d,
attention_mask=att_2d_4d,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[current_embs, None],