mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 04:30:10 +00:00
fix(pi052): call policy preprocessing helpers
Use PI05Policy helpers for action padding and image preprocessing in PI052 fused losses instead of looking them up on the inner PI05Pytorch module. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -546,7 +546,7 @@ class PI052Policy(PI05Policy):
|
|||||||
from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
|
from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
|
||||||
|
|
||||||
# ---- preamble (mirrors PI05Pytorch.forward) ------------------
|
# ---- preamble (mirrors PI05Pytorch.forward) ------------------
|
||||||
actions = self.model.prepare_action(batch)
|
actions = self.prepare_action(batch)
|
||||||
noise = self.model.sample_noise(actions.shape, actions.device)
|
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||||
time = self.model.sample_time(actions.shape[0], actions.device)
|
time = self.model.sample_time(actions.shape[0], actions.device)
|
||||||
time_expanded = time[:, None, None]
|
time_expanded = time[:, None, None]
|
||||||
@@ -554,7 +554,7 @@ class PI052Policy(PI05Policy):
|
|||||||
u_t = noise - actions
|
u_t = noise - actions
|
||||||
|
|
||||||
# ---- prefix: images + language + (optional FAST) -------------
|
# ---- prefix: images + language + (optional FAST) -------------
|
||||||
images, img_masks = self.model._preprocess_images(batch)
|
images, img_masks = self._preprocess_images(batch)
|
||||||
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
|
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
|
||||||
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
||||||
prefix_embs, prefix_pad, prefix_att = self.model.embed_prefix(
|
prefix_embs, prefix_pad, prefix_att = self.model.embed_prefix(
|
||||||
@@ -686,7 +686,7 @@ class PI052Policy(PI05Policy):
|
|||||||
"""
|
"""
|
||||||
from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
|
from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
|
||||||
|
|
||||||
images, img_masks = self.model._preprocess_images(batch)
|
images, img_masks = self._preprocess_images(batch)
|
||||||
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
|
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
|
||||||
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
||||||
|
|
||||||
@@ -805,7 +805,7 @@ class PI052Policy(PI05Policy):
|
|||||||
if eos_token_id is not None:
|
if eos_token_id is not None:
|
||||||
special_ids.add(int(eos_token_id))
|
special_ids.add(int(eos_token_id))
|
||||||
|
|
||||||
images, img_masks = self.model._preprocess_images(batch)
|
images, img_masks = self._preprocess_images(batch)
|
||||||
tokens = batch[OBS_LANGUAGE_TOKENS]
|
tokens = batch[OBS_LANGUAGE_TOKENS]
|
||||||
masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user