From f0757fc7074d47048eb5647995f30fdbc218611f Mon Sep 17 00:00:00 2001 From: pepijn223 Date: Thu, 4 Jun 2026 18:14:27 +0200 Subject: [PATCH] fix(pi0,pi0_fast): scale text embeddings by sqrt(embed_dim) to match OpenPI MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit OpenPI (pi0 and pi0-FAST) multiplies language token embeddings by sqrt(embed_dim) — the Gemma embedder normalizer — before the transformer. LeRobot pi0/pi0_fast omitted it, leaving text tokens ~45x under-scaled relative to the residual stream (same class of bug as the pi05 image scaling). pi0: applied in embed_prefix's lang_embed_func. pi0_fast: applied inside embed_language_tokens so prompt, FAST action tokens, and autoregressive next-token embeds are all scaled consistently. Co-authored-by: Cursor --- src/lerobot/policies/pi0/modeling_pi0.py | 6 +++++- src/lerobot/policies/pi0_fast/modeling_pi0_fast.py | 7 ++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index f6f4212fb..f25a686b4 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -678,7 +678,11 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` # Process language tokens def lang_embed_func(lang_tokens): lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) - return lang_emb + # Match OpenPI: scale text embeddings by sqrt(embed_dim) (the Gemma embedder + # normalizer). lerobot/pi0_base (OpenPI port) expects this; main omitted it, + # leaving language tokens ~45x under-scaled relative to the residual stream. + lang_emb_dim = lang_emb.shape[-1] + return lang_emb * math.sqrt(lang_emb_dim) lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens) embs.append(lang_emb) diff --git a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py index d9342eb24..d52b0033d 100644 --- a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py +++ b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py @@ -268,7 +268,12 @@ class PI0FastPaliGemma(nn.Module): return features def embed_language_tokens(self, tokens: torch.Tensor): - return self.paligemma.model.language_model.get_input_embeddings()(tokens) + # Match OpenPI: scale token embeddings by sqrt(embed_dim) (the Gemma embedder + # normalizer). Applied here so every caller — prompt, FAST action tokens, and + # autoregressive next-token embeds — is scaled consistently, mirroring OpenPI's + # llm.embed which normalizes all embedded tokens. main omitted this. + lang_emb = self.paligemma.model.language_model.get_input_embeddings()(tokens) + return lang_emb * (lang_emb.shape[-1] ** 0.5) def forward( self,