diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index f38536994..5cc7db457 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -38,7 +38,6 @@ for the LM head. from __future__ import annotations import logging -import math import types from typing import Any @@ -719,9 +718,9 @@ class PI052Policy(PI05Policy): fast_len = 0 if action_tokens is not None and action_mask is not None: - emb_dim = prefix_embs.shape[-1] + # embed_language_tokens already applies the Gemma sqrt(hidden) scale (tf>=5.4.0); + # do not scale FAST action tokens again (would double-scale). fast_emb = self.model.paligemma_with_expert.embed_language_tokens(action_tokens) - fast_emb = fast_emb * math.sqrt(emb_dim) fast_len = action_tokens.shape[1] ones_att = torch.ones( (action_tokens.shape[0], fast_len), @@ -865,9 +864,9 @@ class PI052Policy(PI05Policy): fast_len = 0 if action_tokens is not None and action_mask is not None: - emb_dim = prefix_embs.shape[-1] + # embed_language_tokens already applies the Gemma sqrt(hidden) scale (tf>=5.4.0); + # do not scale FAST action tokens again (would double-scale). fast_emb = self.model.paligemma_with_expert.embed_language_tokens(action_tokens) - fast_emb = fast_emb * math.sqrt(emb_dim) fast_len = action_tokens.shape[1] ones_att = torch.ones( @@ -1083,8 +1082,6 @@ class PI052Policy(PI05Policy): device = prefix_embs.device bsize = prefix_embs.shape[0] - emb_dim = prefix_embs.shape[-1] - text_emb_scale = math.sqrt(emb_dim) ones_step = torch.ones((bsize, 1), dtype=torch.bool, device=device) current_embs = prefix_embs @@ -1145,8 +1142,8 @@ class PI052Policy(PI05Policy): if eos_token_id is not None and tok_id == eos_token_id: break + # embed_language_tokens already applies the Gemma sqrt(hidden) scale (tf>=5.4.0). new_emb = backbone.embed_language_tokens(next_ids.unsqueeze(0)) - new_emb = new_emb * text_emb_scale current_embs = torch.cat([current_embs, new_emb], dim=1) current_pad = torch.cat([current_pad, ones_step], dim=1) current_att = torch.cat([current_att, ones_step], dim=1)