fix(pi0,pi0_fast): scale text embeddings by sqrt(embed_dim) to match OpenPI

OpenPI (pi0 and pi0-FAST) multiplies language token embeddings by
sqrt(embed_dim) — the Gemma embedder normalizer — before the transformer.
LeRobot pi0/pi0_fast omitted it, leaving text tokens ~45x under-scaled
relative to the residual stream (same class of bug as the pi05 image
scaling). pi0: applied in embed_prefix's lang_embed_func. pi0_fast:
applied inside embed_language_tokens so prompt, FAST action tokens, and
autoregressive next-token embeds are all scaled consistently.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn223
2026-06-04 18:14:27 +02:00
parent a48d4e32a1
commit f0757fc707
2 changed files with 11 additions and 2 deletions
+5 -1
View File
@@ -678,7 +678,11 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
# Process language tokens
def lang_embed_func(lang_tokens):
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
return lang_emb
# Match OpenPI: scale text embeddings by sqrt(embed_dim) (the Gemma embedder
# normalizer). lerobot/pi0_base (OpenPI port) expects this; main omitted it,
# leaving language tokens ~45x under-scaled relative to the residual stream.
lang_emb_dim = lang_emb.shape[-1]
return lang_emb * math.sqrt(lang_emb_dim)
lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
embs.append(lang_emb)
@@ -268,7 +268,12 @@ class PI0FastPaliGemma(nn.Module):
return features
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
# Match OpenPI: scale token embeddings by sqrt(embed_dim) (the Gemma embedder
# normalizer). Applied here so every caller — prompt, FAST action tokens, and
# autoregressive next-token embeds — is scaled consistently, mirroring OpenPI's
# llm.embed which normalizes all embedded tokens. main omitted this.
lang_emb = self.paligemma.model.language_model.get_input_embeddings()(tokens)
return lang_emb * (lang_emb.shape[-1] ** 0.5)
def forward(
self,