fix(pi0,pi05,pi0_fast): stop double-scaling text embeddings

transformers >=5.4.0 (PR #44432) makes Gemma's embed_tokens a
GemmaTextScaledWordEmbedding that already multiplies token embeddings by
sqrt(hidden_size). The manual `* sqrt(embed_dim)` applied on top therefore
double-scaled text (~2048x instead of ~45x), breaking VLM alignment for
models trained/run on stock transformers. Remove the manual scaling and rely
on embed_tokens' internal normalizer (matches main #3603). Image features
stay raw (un-normalized), as before.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn223
2026-06-04 18:22:34 +02:00
parent f0757fc707
commit 77cc35b932
3 changed files with 11 additions and 15 deletions
+4 -6
View File
@@ -677,12 +677,10 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
# Process language tokens
def lang_embed_func(lang_tokens):
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
# Match OpenPI: scale text embeddings by sqrt(embed_dim) (the Gemma embedder
# normalizer). lerobot/pi0_base (OpenPI port) expects this; main omitted it,
# leaving language tokens ~45x under-scaled relative to the residual stream.
lang_emb_dim = lang_emb.shape[-1]
return lang_emb * math.sqrt(lang_emb_dim)
# embed_language_tokens -> Gemma get_input_embeddings(), which is
# GemmaTextScaledWordEmbedding (transformers >=5.4.0): it already multiplies by
# sqrt(hidden_size) internally. Do NOT scale again here (would double-scale text).
return self.paligemma_with_expert.embed_language_tokens(lang_tokens)
lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
embs.append(lang_emb)
+4 -3
View File
@@ -704,9 +704,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
# Process language tokens
def lang_embed_func(tokens):
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
lang_emb_dim = lang_emb.shape[-1]
return lang_emb * math.sqrt(lang_emb_dim)
# embed_language_tokens -> Gemma embed_tokens, which is GemmaTextScaledWordEmbedding
# (transformers >=5.4.0): it already multiplies by sqrt(hidden_size) internally. Do NOT
# scale again here or text tokens get double-scaled (~45x) and break alignment.
return self.paligemma_with_expert.embed_language_tokens(tokens)
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
embs.append(lang_emb)
@@ -268,12 +268,9 @@ class PI0FastPaliGemma(nn.Module):
return features
def embed_language_tokens(self, tokens: torch.Tensor):
# Match OpenPI: scale token embeddings by sqrt(embed_dim) (the Gemma embedder
# normalizer). Applied here so every caller — prompt, FAST action tokens, and
# autoregressive next-token embeds — is scaled consistently, mirroring OpenPI's
# llm.embed which normalizes all embedded tokens. main omitted this.
lang_emb = self.paligemma.model.language_model.get_input_embeddings()(tokens)
return lang_emb * (lang_emb.shape[-1] ** 0.5)
# get_input_embeddings() is GemmaTextScaledWordEmbedding (transformers >=5.4.0): it already
# multiplies by sqrt(hidden_size) internally, so no manual scaling is needed here.
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
def forward(
self,