From 7960cc14ec448a91f0b3327c5acd7d1622e54e91 Mon Sep 17 00:00:00 2001 From: pepijn Date: Mon, 18 May 2026 17:52:47 +0000 Subject: [PATCH] fix(pi052): call policy preprocessing helpers Use PI05Policy helpers for action padding and image preprocessing in PI052 fused losses instead of looking them up on the inner PI05Pytorch module. Co-authored-by: Cursor --- src/lerobot/policies/pi052/modeling_pi052.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index c429830b7..7fbf7921c 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -546,7 +546,7 @@ class PI052Policy(PI05Policy): from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415 # ---- preamble (mirrors PI05Pytorch.forward) ------------------ - actions = self.model.prepare_action(batch) + actions = self.prepare_action(batch) noise = self.model.sample_noise(actions.shape, actions.device) time = self.model.sample_time(actions.shape[0], actions.device) time_expanded = time[:, None, None] @@ -554,7 +554,7 @@ class PI052Policy(PI05Policy): u_t = noise - actions # ---- prefix: images + language + (optional FAST) ------------- - images, img_masks = self.model._preprocess_images(batch) + images, img_masks = self._preprocess_images(batch) lang_tokens = batch[OBS_LANGUAGE_TOKENS] lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK] prefix_embs, prefix_pad, prefix_att = self.model.embed_prefix( @@ -686,7 +686,7 @@ class PI052Policy(PI05Policy): """ from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415 - images, img_masks = self.model._preprocess_images(batch) + images, img_masks = self._preprocess_images(batch) lang_tokens = batch[OBS_LANGUAGE_TOKENS] lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK] @@ -805,7 +805,7 @@ class PI052Policy(PI05Policy): if eos_token_id is not None: special_ids.add(int(eos_token_id)) - images, img_masks = self.model._preprocess_images(batch) + images, img_masks = self._preprocess_images(batch) tokens = batch[OBS_LANGUAGE_TOKENS] masks = batch[OBS_LANGUAGE_ATTENTION_MASK]