diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 3534c7ae8..f655e7601 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -444,13 +444,13 @@ class PaliGemmaWithExpertModel( if image.dtype != torch.float32: image = image.to(torch.float32) image_outputs = self.paligemma.model.get_image_features(image) - features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5 + features = image_outputs.pooler_output if features.dtype != out_dtype: features = features.to(out_dtype) return features def embed_language_tokens(self, tokens: torch.Tensor): - return self.paligemma.model.language_model.embed_tokens(tokens) + return self.paligemma.model.language_model.get_inputs_embeddings()(tokens) def forward( self, @@ -666,8 +666,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` # Process language tokens def lang_embed_func(lang_tokens): lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) - lang_emb_dim = lang_emb.shape[-1] - return lang_emb * math.sqrt(lang_emb_dim) + return lang_emb lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens) embs.append(lang_emb) diff --git a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py index a49828ad1..dde26169f 100644 --- a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py +++ b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py @@ -260,13 +260,13 @@ class PI0FastPaliGemma(nn.Module): if image.dtype != torch.float32: image = image.to(torch.float32) image_outputs = self.paligemma.model.get_image_features(image) - features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5 + features = image_outputs.pooler_output if features.dtype != out_dtype: features = features.to(out_dtype) return features def embed_language_tokens(self, tokens: torch.Tensor): - return self.paligemma.model.language_model.embed_tokens(tokens) + return self.paligemma.model.language_model.get_inputs_embeddings()(tokens) def forward( self, @@ -416,8 +416,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch` # Process language instruction tokens def lang_embed_func(tokens): lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens) - lang_emb_dim = lang_emb.shape[-1] - return lang_emb * math.sqrt(lang_emb_dim) + return lang_emb lang_emb = self._apply_checkpoint(lang_embed_func, tokens) embs.append(lang_emb) @@ -431,8 +430,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch` def fast_action_embed_func(fast_action_tokens): fast_emb = self.paligemma_with_expert.embed_language_tokens(fast_action_tokens) - fast_emb_dim = fast_emb.shape[-1] - return fast_emb * math.sqrt(fast_emb_dim) + return fast_emb fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens) embs.append(fast_action_emb) @@ -665,7 +663,6 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch` if t < max_decoding_steps - 1: # embed the newly generated token next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token) - next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1]) if prefix_embs.dtype == torch.bfloat16: next_token_emb = next_token_emb.to(dtype=torch.bfloat16) @@ -770,7 +767,6 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch` # Embed the single previous token # We use embed_language_tokens directly to avoid overhead of full prefix embedding next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token) - next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1]) if prefix_embs.dtype == torch.bfloat16: next_token_emb = next_token_emb.to(dtype=torch.bfloat16)