mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
fix(pi052): stop double-scaling FAST/text token embeddings
embed_language_tokens already applies Gemma's sqrt(hidden) normalizer (GemmaTextScaledWordEmbedding, transformers>=5.4.0). pi052 additionally multiplied the FAST action-token and autoregressive subtask-text embeddings by sqrt(emb_dim), so they were scaled twice (a combined factor of emb_dim, ~2048x, instead of sqrt(emb_dim)). Remove the manual scaling so FAST and text tokens are scaled exactly once, consistent with the pi05 fix and OpenPI. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -38,7 +38,6 @@ for the LM head.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
import types
|
||||
from typing import Any
|
||||
|
||||
@@ -719,9 +718,9 @@ class PI052Policy(PI05Policy):
|
||||
|
||||
fast_len = 0
|
||||
if action_tokens is not None and action_mask is not None:
|
||||
emb_dim = prefix_embs.shape[-1]
|
||||
# embed_language_tokens already applies the Gemma sqrt(hidden) scale (transformers>=5.4.0);
|
||||
# do not scale FAST action tokens again (would double-scale).
|
||||
fast_emb = self.model.paligemma_with_expert.embed_language_tokens(action_tokens)
|
||||
fast_emb = fast_emb * math.sqrt(emb_dim)
|
||||
fast_len = action_tokens.shape[1]
|
||||
ones_att = torch.ones(
|
||||
(action_tokens.shape[0], fast_len),
|
||||
@@ -865,9 +864,9 @@ class PI052Policy(PI05Policy):
|
||||
|
||||
fast_len = 0
|
||||
if action_tokens is not None and action_mask is not None:
|
||||
emb_dim = prefix_embs.shape[-1]
|
||||
# embed_language_tokens already applies the Gemma sqrt(hidden) scale (transformers>=5.4.0);
|
||||
# do not scale FAST action tokens again (would double-scale).
|
||||
fast_emb = self.model.paligemma_with_expert.embed_language_tokens(action_tokens)
|
||||
fast_emb = fast_emb * math.sqrt(emb_dim)
|
||||
|
||||
fast_len = action_tokens.shape[1]
|
||||
ones_att = torch.ones(
|
||||
@@ -1083,8 +1082,6 @@ class PI052Policy(PI05Policy):
|
||||
|
||||
device = prefix_embs.device
|
||||
bsize = prefix_embs.shape[0]
|
||||
emb_dim = prefix_embs.shape[-1]
|
||||
text_emb_scale = math.sqrt(emb_dim)
|
||||
ones_step = torch.ones((bsize, 1), dtype=torch.bool, device=device)
|
||||
|
||||
current_embs = prefix_embs
|
||||
@@ -1145,8 +1142,8 @@ class PI052Policy(PI05Policy):
|
||||
if eos_token_id is not None and tok_id == eos_token_id:
|
||||
break
|
||||
|
||||
# embed_language_tokens already applies the Gemma sqrt(hidden) scale (transformers>=5.4.0).
|
||||
new_emb = backbone.embed_language_tokens(next_ids.unsqueeze(0))
|
||||
new_emb = new_emb * text_emb_scale
|
||||
current_embs = torch.cat([current_embs, new_emb], dim=1)
|
||||
current_pad = torch.cat([current_pad, ones_step], dim=1)
|
||||
current_att = torch.cat([current_att, ones_step], dim=1)
|
||||
|
||||
Reference in New Issue
Block a user