fix(pi052): stop double-scaling FAST/text token embeddings

embed_language_tokens already applies Gemma's sqrt(hidden) normalizer
(GemmaTextScaledWordEmbedding, transformers >=5.4.0). pi052 multiplied FAST
action-token and autoregressive subtask-text embeddings by sqrt(emb_dim) on
top of that, double-scaling them (~2048x). Remove the manual scaling so FAST
and text tokens are single-scaled, consistent with the pi05 fix and OpenPI.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn223
2026-06-04 18:31:41 +02:00
parent 77cc35b932
commit 8292548f0d
+5 -8
View File
@@ -38,7 +38,6 @@ for the LM head.
from __future__ import annotations
import logging
import math
import types
from typing import Any
@@ -719,9 +718,9 @@ class PI052Policy(PI05Policy):
fast_len = 0
if action_tokens is not None and action_mask is not None:
emb_dim = prefix_embs.shape[-1]
# embed_language_tokens already applies the Gemma sqrt(hidden) scale (tf>=5.4.0);
# do not scale FAST action tokens again (would double-scale).
fast_emb = self.model.paligemma_with_expert.embed_language_tokens(action_tokens)
fast_emb = fast_emb * math.sqrt(emb_dim)
fast_len = action_tokens.shape[1]
ones_att = torch.ones(
(action_tokens.shape[0], fast_len),
@@ -865,9 +864,9 @@ class PI052Policy(PI05Policy):
fast_len = 0
if action_tokens is not None and action_mask is not None:
emb_dim = prefix_embs.shape[-1]
# embed_language_tokens already applies the Gemma sqrt(hidden) scale (tf>=5.4.0);
# do not scale FAST action tokens again (would double-scale).
fast_emb = self.model.paligemma_with_expert.embed_language_tokens(action_tokens)
fast_emb = fast_emb * math.sqrt(emb_dim)
fast_len = action_tokens.shape[1]
ones_att = torch.ones(
@@ -1083,8 +1082,6 @@ class PI052Policy(PI05Policy):
device = prefix_embs.device
bsize = prefix_embs.shape[0]
emb_dim = prefix_embs.shape[-1]
text_emb_scale = math.sqrt(emb_dim)
ones_step = torch.ones((bsize, 1), dtype=torch.bool, device=device)
current_embs = prefix_embs
@@ -1145,8 +1142,8 @@ class PI052Policy(PI05Policy):
if eos_token_id is not None and tok_id == eos_token_id:
break
# embed_language_tokens already applies the Gemma sqrt(hidden) scale (tf>=5.4.0).
new_emb = backbone.embed_language_tokens(next_ids.unsqueeze(0))
new_emb = new_emb * text_emb_scale
current_embs = torch.cat([current_embs, new_emb], dim=1)
current_pad = torch.cat([current_pad, ones_step], dim=1)
current_att = torch.cat([current_att, ones_step], dim=1)