# Mirror of https://github.com/huggingface/lerobot.git
# (synced 2026-05-15 08:39:49 +00:00; 18 lines, 688 B, Python)
# Scratch script: load a PaliGemma checkpoint and its processor, pause in the
# debugger, then run the language model on prefix embeddings and keep both the
# key/value cache and the last hidden state.
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

model_id = "google/paligemma-3b-pt-224"
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
# NOTE(review): `processor` is loaded but never used in the code visible here.
processor = AutoProcessor.from_pretrained(model_id)

# Pauses in the debugger before the forward pass.
# NOTE(review): `inputs_embeds`, `attention_mask`, `position_ids` and
# `adarms_cond` are not defined anywhere in this script; presumably they are
# meant to be created from the debugger session before continuing, otherwise
# the call below raises NameError -- TODO confirm.
breakpoint()
# NOTE(review): `adarms_cond` does not look like a stock `transformers`
# keyword argument; presumably this needs a patched model implementation --
# TODO confirm.
prefix_output = model.language_model.forward(
    inputs_embeds=inputs_embeds[0],
    attention_mask=attention_mask,
    position_ids=position_ids,
    adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
)
# Grab the cache first: `prefix_output` is rebound to the hidden state below,
# after which the model output object is no longer reachable by that name.
prefix_past_key_values = prefix_output.past_key_values
# prefix_output to be used for the language head
# shape: [batch_size, seq_len, hidden_size] with hidden_size = 2048
prefix_output = prefix_output.last_hidden_state