From 77cc35b9326ee28509be431f4c184212049ec155 Mon Sep 17 00:00:00 2001 From: pepijn223 Date: Thu, 4 Jun 2026 18:22:34 +0200 Subject: [PATCH] fix(pi0,pi05,pi0_fast): stop double-scaling text embeddings transformers >=5.4.0 (PR #44432) makes Gemma's embed_tokens a GemmaTextScaledWordEmbedding that already multiplies token embeddings by sqrt(hidden_size). The manual `* sqrt(embed_dim)` applied on top therefore double-scaled text (~2048x instead of ~45x), breaking VLM alignment for models trained/run on stock transformers. Remove the manual scaling and rely on embed_tokens' internal normalizer (matches main #3603). Image features stay raw (un-normalized), as before. Co-authored-by: Cursor --- src/lerobot/policies/pi0/modeling_pi0.py | 10 ++++------ src/lerobot/policies/pi05/modeling_pi05.py | 7 ++++--- src/lerobot/policies/pi0_fast/modeling_pi0_fast.py | 9 +++------ 3 files changed, 11 insertions(+), 15 deletions(-) diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index f25a686b4..753b8d45f 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -677,12 +677,10 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` # Process language tokens def lang_embed_func(lang_tokens): - lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) - # Match OpenPI: scale text embeddings by sqrt(embed_dim) (the Gemma embedder - # normalizer). lerobot/pi0_base (OpenPI port) expects this; main omitted it, - # leaving language tokens ~45x under-scaled relative to the residual stream. - lang_emb_dim = lang_emb.shape[-1] - return lang_emb * math.sqrt(lang_emb_dim) + # embed_language_tokens -> Gemma get_input_embeddings(), which is + # GemmaTextScaledWordEmbedding (transformers >=5.4.0): it already multiplies by + # sqrt(hidden_size) internally. Do NOT scale again here (would double-scale text). + return self.paligemma_with_expert.embed_language_tokens(lang_tokens) lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens) embs.append(lang_emb) diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 12d662c17..3c30cdeb8 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -704,9 +704,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` # Process language tokens def lang_embed_func(tokens): - lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens) - lang_emb_dim = lang_emb.shape[-1] - return lang_emb * math.sqrt(lang_emb_dim) + # embed_language_tokens -> Gemma embed_tokens, which is GemmaTextScaledWordEmbedding + # (transformers >=5.4.0): it already multiplies by sqrt(hidden_size) internally. Do NOT + # scale again here or text tokens get double-scaled (~45x) and break alignment. + return self.paligemma_with_expert.embed_language_tokens(tokens) lang_emb = self._apply_checkpoint(lang_embed_func, tokens) embs.append(lang_emb) diff --git a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py index d52b0033d..d342cffaf 100644 --- a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py +++ b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py @@ -268,12 +268,9 @@ class PI0FastPaliGemma(nn.Module): return features def embed_language_tokens(self, tokens: torch.Tensor): - # Match OpenPI: scale token embeddings by sqrt(embed_dim) (the Gemma embedder - # normalizer). Applied here so every caller — prompt, FAST action tokens, and - # autoregressive next-token embeds — is scaled consistently, mirroring OpenPI's - # llm.embed which normalizes all embedded tokens. main omitted this. - lang_emb = self.paligemma.model.language_model.get_input_embeddings()(tokens) - return lang_emb * (lang_emb.shape[-1] ** 0.5) + # get_input_embeddings() is GemmaTextScaledWordEmbedding (transformers >=5.4.0): it already + # multiplies by sqrt(hidden_size) internally, so no manual scaling is needed here. + return self.paligemma.model.language_model.get_input_embeddings()(tokens) def forward( self,