diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index b528eaca0..1819abd22 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -726,9 +726,10 @@ class PI052Policy(PI05Policy): att_2d = make_att_2d_masks(full_pad, full_att) position_ids = torch.cumsum(full_pad, dim=1) - 1 + att_2d_4d = self.model._prepare_attention_masks_4d(att_2d) (vlm_out, _), _ = self.model.paligemma_with_expert.forward( - attention_mask=att_2d, + attention_mask=att_2d_4d, position_ids=position_ids, past_key_values=None, inputs_embeds=[full_embs, None], @@ -832,8 +833,9 @@ class PI052Policy(PI05Policy): for _ in range(max_new_tokens): att_2d = make_att_2d_masks(current_pad, current_att) position_ids = torch.cumsum(current_pad, dim=1) - 1 + att_2d_4d = self.model._prepare_attention_masks_4d(att_2d) (vlm_out, _), _ = backbone.forward( - attention_mask=att_2d, + attention_mask=att_2d_4d, position_ids=position_ids, past_key_values=None, inputs_embeds=[current_embs, None],