mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 14:09:47 +00:00
fix(pi05): preserve pretrained paligemma lm head
Keep the PaliGemma LM head in float32 and initialize it from pretrained weights or token embeddings when loading pi05 checkpoints. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -409,6 +409,7 @@ class PaliGemmaWithExpertModel(
|
||||
params_to_keep_float32 = [
|
||||
"vision_tower",
|
||||
"multi_modal_projector",
|
||||
"lm_head",
|
||||
"input_layernorm",
|
||||
"post_attention_layernorm",
|
||||
"model.norm",
|
||||
@@ -1029,6 +1030,16 @@ class PI05Policy(PreTrainedPolicy):
|
||||
if remap_count > 0:
|
||||
print(f"Remapped {remap_count} state dict keys")
|
||||
|
||||
lm_head_key = "model.paligemma_with_expert.paligemma.lm_head.weight"
|
||||
embed_tokens_key = (
|
||||
"model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
|
||||
)
|
||||
if lm_head_key not in remapped_state_dict and embed_tokens_key in remapped_state_dict:
|
||||
remapped_state_dict[lm_head_key] = remapped_state_dict[embed_tokens_key].clone().float()
|
||||
print("Initialized PaliGemma lm_head from language token embeddings")
|
||||
elif lm_head_key in remapped_state_dict:
|
||||
remapped_state_dict[lm_head_key] = remapped_state_dict[lm_head_key].float()
|
||||
|
||||
# Load the remapped state dict into the model
|
||||
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user