From 8292548f0dbf7b7da196ac31cdadeb589d597ef1 Mon Sep 17 00:00:00 2001 From: pepijn223 Date: Thu, 4 Jun 2026 18:31:41 +0200 Subject: [PATCH] fix(pi052): stop double-scaling FAST/text token embeddings embed_language_tokens already applies Gemma's sqrt(hidden) normalizer (GemmaTextScaledWordEmbedding, transformers >=5.4.0). pi052 multiplied FAST action-token and autoregressive subtask-text embeddings by sqrt(emb_dim) on top of that, double-scaling them (~2048x). Remove the manual scaling so FAST and text tokens are single-scaled, consistent with the pi05 fix and OpenPI. Co-authored-by: Cursor --- src/lerobot/policies/pi052/modeling_pi052.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index f38536994..5cc7db457 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -38,7 +38,6 @@ for the LM head. from __future__ import annotations import logging -import math import types from typing import Any @@ -719,9 +718,9 @@ class PI052Policy(PI05Policy): fast_len = 0 if action_tokens is not None and action_mask is not None: - emb_dim = prefix_embs.shape[-1] + # embed_language_tokens already applies the Gemma sqrt(hidden) scale (tf>=5.4.0); + # do not scale FAST action tokens again (would double-scale). fast_emb = self.model.paligemma_with_expert.embed_language_tokens(action_tokens) - fast_emb = fast_emb * math.sqrt(emb_dim) fast_len = action_tokens.shape[1] ones_att = torch.ones( (action_tokens.shape[0], fast_len), @@ -865,9 +864,9 @@ class PI052Policy(PI05Policy): fast_len = 0 if action_tokens is not None and action_mask is not None: - emb_dim = prefix_embs.shape[-1] + # embed_language_tokens already applies the Gemma sqrt(hidden) scale (tf>=5.4.0); + # do not scale FAST action tokens again (would double-scale). fast_emb = self.model.paligemma_with_expert.embed_language_tokens(action_tokens) - fast_emb = fast_emb * math.sqrt(emb_dim) fast_len = action_tokens.shape[1] ones_att = torch.ones( @@ -1083,8 +1082,6 @@ class PI052Policy(PI05Policy): device = prefix_embs.device bsize = prefix_embs.shape[0] - emb_dim = prefix_embs.shape[-1] - text_emb_scale = math.sqrt(emb_dim) ones_step = torch.ones((bsize, 1), dtype=torch.bool, device=device) current_embs = prefix_embs @@ -1145,8 +1142,8 @@ class PI052Policy(PI05Policy): if eos_token_id is not None and tok_id == eos_token_id: break + # embed_language_tokens already applies the Gemma sqrt(hidden) scale (tf>=5.4.0). new_emb = backbone.embed_language_tokens(next_ids.unsqueeze(0)) - new_emb = new_emb * text_emb_scale current_embs = torch.cat([current_embs, new_emb], dim=1) current_pad = torch.cat([current_pad, ones_step], dim=1) current_att = torch.cat([current_att, ones_step], dim=1)