mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 02:29:47 +00:00
fix(pi052): call policy preprocessing helpers
Use PI05Policy helpers for action padding and image preprocessing in PI052 fused losses instead of looking them up on the inner PI05Pytorch module. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -546,7 +546,7 @@ class PI052Policy(PI05Policy):
|
||||
from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
|
||||
|
||||
# ---- preamble (mirrors PI05Pytorch.forward) ------------------
|
||||
actions = self.model.prepare_action(batch)
|
||||
actions = self.prepare_action(batch)
|
||||
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||
time = self.model.sample_time(actions.shape[0], actions.device)
|
||||
time_expanded = time[:, None, None]
|
||||
@@ -554,7 +554,7 @@ class PI052Policy(PI05Policy):
|
||||
u_t = noise - actions
|
||||
|
||||
# ---- prefix: images + language + (optional FAST) -------------
|
||||
images, img_masks = self.model._preprocess_images(batch)
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
|
||||
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
||||
prefix_embs, prefix_pad, prefix_att = self.model.embed_prefix(
|
||||
@@ -686,7 +686,7 @@ class PI052Policy(PI05Policy):
|
||||
"""
|
||||
from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
|
||||
|
||||
images, img_masks = self.model._preprocess_images(batch)
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
|
||||
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
||||
|
||||
@@ -805,7 +805,7 @@ class PI052Policy(PI05Policy):
|
||||
if eos_token_id is not None:
|
||||
special_ids.add(int(eos_token_id))
|
||||
|
||||
images, img_masks = self.model._preprocess_images(batch)
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
tokens = batch[OBS_LANGUAGE_TOKENS]
|
||||
masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user