mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
fix(pi05): update pi05 with transformers v5.4.0 interface (#3603)
This commit is contained in:
@@ -441,13 +441,13 @@ class PaliGemmaWithExpertModel(
|
|||||||
if image.dtype != torch.float32:
|
if image.dtype != torch.float32:
|
||||||
image = image.to(torch.float32)
|
image = image.to(torch.float32)
|
||||||
image_outputs = self.paligemma.model.get_image_features(image)
|
image_outputs = self.paligemma.model.get_image_features(image)
|
||||||
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
features = image_outputs.pooler_output
|
||||||
if features.dtype != out_dtype:
|
if features.dtype != out_dtype:
|
||||||
features = features.to(out_dtype)
|
features = features.to(out_dtype)
|
||||||
return features
|
return features
|
||||||
|
|
||||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -662,8 +662,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
# Process language tokens
|
# Process language tokens
|
||||||
def lang_embed_func(tokens):
|
def lang_embed_func(tokens):
|
||||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
||||||
lang_emb_dim = lang_emb.shape[-1]
|
return lang_emb
|
||||||
return lang_emb * math.sqrt(lang_emb_dim)
|
|
||||||
|
|
||||||
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
|
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
|
||||||
embs.append(lang_emb)
|
embs.append(lang_emb)
|
||||||
|
|||||||
Reference in New Issue
Block a user