fix(sarm): handle BaseModelOutputWithPooling from transformers 5.x in… (#3419)
* fix(sarm): handle BaseModelOutputWithPooling from transformers 5.x in CLIP encoding

  In transformers 5.x, CLIPModel.get_image_features() and get_text_features() return a BaseModelOutputWithPooling instead of a plain torch.FloatTensor. Added an isinstance check that extracts pooler_output when the return value is not a tensor, maintaining backward compatibility with transformers 4.x.

  Fixes: AttributeError: 'BaseModelOutputWithPooling' object has no attribute 'detach'

* Add an assertion check for pooler_output of CLIP, in response to the comment below.

  https://github.com/huggingface/lerobot/pull/3419#discussion_r3112594387

* Change the assertion to a simple check-and-raise, in response to the comment below.

  https://github.com/huggingface/lerobot/pull/3419#discussion_r3126953776

---------

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
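For context, here is a minimal standalone sketch of the compatibility pattern the diff applies. The helper name to_pooled_tensor and the checkpoint name are illustrative assumptions, not part of the PR:

    # Minimal sketch of the version-compatibility pattern used in the diff.
    # `to_pooled_tensor` is an illustrative helper, not part of the PR.
    import torch
    from transformers import CLIPModel, CLIPProcessor

    def to_pooled_tensor(output):
        """Return a plain tensor whether `output` is already a tensor
        (transformers 4.x) or a BaseModelOutputWithPooling (transformers 5.x)."""
        if isinstance(output, torch.Tensor):
            return output
        pooled = output.pooler_output
        if pooled is None:
            raise ValueError("pooler_output should not be None for CLIP models.")
        return pooled

    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    inputs = processor(text=["pick up the cube"], return_tensors="pt", padding=True)
    with torch.no_grad():
        text_embedding = to_pooled_tensor(model.get_text_features(**inputs)).cpu()
    print(text_embedding.shape)  # e.g. torch.Size([1, 512]) for this checkpoint

Checking the return type rather than transformers.__version__ keeps the code agnostic to exactly which release changed the behavior, so the same path works on both 4.x and 5.x.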
@@ -455,7 +455,13 @@ class SARMEncodingProcessorStep(ProcessorStep):
             inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
             # Get image embeddings
-            embeddings = self.clip_model.get_image_features(**inputs).detach().cpu()
+            # transformers 5.x returns BaseModelOutputWithPooling instead of a plain tensor
+            output = self.clip_model.get_image_features(**inputs)
+            if not isinstance(output, torch.Tensor):
+                output = output.pooler_output
+                if output is None:
+                    raise ValueError("pooler_output should not be None for CLIP models.")
+            embeddings = output.detach().cpu()
 
             # Handle single frame case
             if embeddings.dim() == 1:
@@ -482,7 +488,13 @@ class SARMEncodingProcessorStep(ProcessorStep):
         inputs = self.clip_processor.tokenizer([text], return_tensors="pt", padding=True, truncation=True)
         inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
-        text_embedding = self.clip_model.get_text_features(**inputs).detach().cpu()
+        # transformers 5.x returns BaseModelOutputWithPooling instead of a plain tensor
+        output = self.clip_model.get_text_features(**inputs)
+        if not isinstance(output, torch.Tensor):
+            output = output.pooler_output
+            if output is None:
+                raise ValueError("pooler_output should not be None for CLIP models.")
+        text_embedding = output.detach().cpu()
         text_embedding = text_embedding.expand(batch_size, -1)
 
         return text_embedding