fix(pi052): call policy preprocessing helpers

Use PI05Policy helpers for action padding and image preprocessing in PI052 fused losses instead of looking them up on the inner PI05Pytorch module.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-05-18 17:52:47 +00:00
parent 1750a87104
commit 7960cc14ec
+4 -4
View File
@@ -546,7 +546,7 @@ class PI052Policy(PI05Policy):
from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
# ---- preamble (mirrors PI05Pytorch.forward) ------------------
actions = self.model.prepare_action(batch)
actions = self.prepare_action(batch)
noise = self.model.sample_noise(actions.shape, actions.device)
time = self.model.sample_time(actions.shape[0], actions.device)
time_expanded = time[:, None, None]
@@ -554,7 +554,7 @@ class PI052Policy(PI05Policy):
u_t = noise - actions
# ---- prefix: images + language + (optional FAST) -------------
images, img_masks = self.model._preprocess_images(batch)
images, img_masks = self._preprocess_images(batch)
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
prefix_embs, prefix_pad, prefix_att = self.model.embed_prefix(
@@ -686,7 +686,7 @@ class PI052Policy(PI05Policy):
"""
from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
images, img_masks = self.model._preprocess_images(batch)
images, img_masks = self._preprocess_images(batch)
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
@@ -805,7 +805,7 @@ class PI052Policy(PI05Policy):
if eos_token_id is not None:
special_ids.add(int(eos_token_id))
images, img_masks = self.model._preprocess_images(batch)
images, img_masks = self._preprocess_images(batch)
tokens = batch[OBS_LANGUAGE_TOKENS]
masks = batch[OBS_LANGUAGE_ATTENTION_MASK]