diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index bb206d608..bdaf01f2c 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -441,13 +441,13 @@ class PaliGemmaWithExpertModel( if image.dtype != torch.float32: image = image.to(torch.float32) image_outputs = self.paligemma.model.get_image_features(image) - features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5 + features = image_outputs.pooler_output if features.dtype != out_dtype: features = features.to(out_dtype) return features def embed_language_tokens(self, tokens: torch.Tensor): - return self.paligemma.model.language_model.embed_tokens(tokens) + return self.paligemma.model.language_model.get_input_embeddings()(tokens) def forward( self, @@ -662,8 +662,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` # Process language tokens def lang_embed_func(tokens): lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens) - lang_emb_dim = lang_emb.shape[-1] - return lang_emb * math.sqrt(lang_emb_dim) + return lang_emb lang_emb = self._apply_checkpoint(lang_embed_func, tokens) embs.append(lang_emb)