From bb31988915694f8285b3e04eaa5a98cce8494e00 Mon Sep 17 00:00:00 2001 From: pepijn Date: Mon, 18 May 2026 21:07:13 +0000 Subject: [PATCH] fix(pi052): pass 4d masks to prefix-only forwards Convert PI052 prefix-only attention masks before calling PaliGemma so text-only batches and generation use the same mask shape as fused training. Co-authored-by: Cursor --- src/lerobot/policies/pi052/modeling_pi052.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index b528eaca0..1819abd22 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -726,9 +726,10 @@ class PI052Policy(PI05Policy): att_2d = make_att_2d_masks(full_pad, full_att) position_ids = torch.cumsum(full_pad, dim=1) - 1 + att_2d_4d = self.model._prepare_attention_masks_4d(att_2d) (vlm_out, _), _ = self.model.paligemma_with_expert.forward( - attention_mask=att_2d, + attention_mask=att_2d_4d, position_ids=position_ids, past_key_values=None, inputs_embeds=[full_embs, None], @@ -832,8 +833,9 @@ class PI052Policy(PI05Policy): for _ in range(max_new_tokens): att_2d = make_att_2d_masks(current_pad, current_att) position_ids = torch.cumsum(current_pad, dim=1) - 1 + att_2d_4d = self.model._prepare_attention_masks_4d(att_2d) (vlm_out, _), _ = backbone.forward( - attention_mask=att_2d, + attention_mask=att_2d_4d, position_ids=position_ids, past_key_values=None, inputs_embeds=[current_embs, None],