mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
fix(pi0,pi0_fast): scale text embeddings by sqrt(embed_dim) to match OpenPI
OpenPI (pi0 and pi0-FAST) multiplies language token embeddings by sqrt(embed_dim) — the Gemma embedder normalizer — before the transformer. LeRobot's pi0/pi0_fast omitted this step, leaving text tokens ~45x under-scaled relative to the residual stream (the same class of bug as the pi05 image scaling).

- pi0: applied in embed_prefix's lang_embed_func.
- pi0_fast: applied inside embed_language_tokens so that the prompt, FAST action tokens, and autoregressive next-token embeds are all scaled consistently.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -678,7 +678,11 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Process language tokens
|
||||
def lang_embed_func(lang_tokens):
    """Embed language tokens and scale them by sqrt(embed_dim).

    Matches OpenPI, which multiplies text embeddings by the Gemma embedder
    normalizer (sqrt of the embedding width) before the transformer. Without
    this factor the language tokens are ~45x under-scaled relative to the
    residual stream.

    Args:
        lang_tokens: Token ids forwarded unchanged to
            ``self.paligemma_with_expert.embed_language_tokens``.

    Returns:
        The embeddings multiplied by ``sqrt(lang_emb.shape[-1])``.
    """
    lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
    # The last axis is taken as the embedding width (embed_dim); the scale is
    # a plain Python float, so the product keeps the embedding's own dtype.
    lang_emb_dim = lang_emb.shape[-1]
    return lang_emb * math.sqrt(lang_emb_dim)
|
||||
|
||||
lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
|
||||
embs.append(lang_emb)
|
||||
|
||||
@@ -268,7 +268,12 @@ class PI0FastPaliGemma(nn.Module):
|
||||
return features
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
    """Embed token ids and scale the result by sqrt(embed_dim).

    Matches OpenPI, whose ``llm.embed`` normalizes all embedded tokens with
    the Gemma embedder normalizer. The scaling lives here, rather than at the
    call sites, so that every caller — prompt, FAST action tokens, and
    autoregressive next-token embeds — is scaled consistently.

    Args:
        tokens: Token ids passed to the language model's input-embedding
            layer.

    Returns:
        The input embeddings multiplied by ``lang_emb.shape[-1] ** 0.5``.
    """
    embedding_layer = self.paligemma.model.language_model.get_input_embeddings()
    lang_emb = embedding_layer(tokens)
    # The last axis is taken as the embedding width (embed_dim).
    return lang_emb * (lang_emb.shape[-1] ** 0.5)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user