fix(pi05): preserve pretrained paligemma lm head

Keep the PaliGemma LM head in float32 and initialize it from pretrained weights or token embeddings when loading pi05 checkpoints.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-05-21 13:25:03 +00:00
parent d41d874581
commit 36f828221c
@@ -409,6 +409,7 @@ class PaliGemmaWithExpertModel(
params_to_keep_float32 = [
"vision_tower",
"multi_modal_projector",
"lm_head",
"input_layernorm",
"post_attention_layernorm",
"model.norm",
@@ -1029,6 +1030,16 @@ class PI05Policy(PreTrainedPolicy):
if remap_count > 0:
print(f"Remapped {remap_count} state dict keys")
lm_head_key = "model.paligemma_with_expert.paligemma.lm_head.weight"
embed_tokens_key = (
"model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
)
if lm_head_key not in remapped_state_dict and embed_tokens_key in remapped_state_dict:
remapped_state_dict[lm_head_key] = remapped_state_dict[embed_tokens_key].clone().float()
print("Initialized PaliGemma lm_head from language token embeddings")
elif lm_head_key in remapped_state_dict:
remapped_state_dict[lm_head_key] = remapped_state_dict[lm_head_key].float()
# Load the remapped state dict into the model
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)