mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 02:29:47 +00:00
fix(pi052): pass 4d masks to prefix-only forwards
Convert PI052 prefix-only attention masks before calling PaliGemma so text-only batches and generation use the same mask shape as fused training. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -726,9 +726,10 @@ class PI052Policy(PI05Policy):
|
||||
|
||||
att_2d = make_att_2d_masks(full_pad, full_att)
|
||||
position_ids = torch.cumsum(full_pad, dim=1) - 1
|
||||
att_2d_4d = self.model._prepare_attention_masks_4d(att_2d)
|
||||
|
||||
(vlm_out, _), _ = self.model.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d,
|
||||
attention_mask=att_2d_4d,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[full_embs, None],
|
||||
@@ -832,8 +833,9 @@ class PI052Policy(PI05Policy):
|
||||
for _ in range(max_new_tokens):
|
||||
att_2d = make_att_2d_masks(current_pad, current_att)
|
||||
position_ids = torch.cumsum(current_pad, dim=1) - 1
|
||||
att_2d_4d = self.model._prepare_attention_masks_4d(att_2d)
|
||||
(vlm_out, _), _ = backbone.forward(
|
||||
attention_mask=att_2d,
|
||||
attention_mask=att_2d_4d,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[current_embs, None],
|
||||
|
||||
Reference in New Issue
Block a user