From 36f828221c2a6e8cfced40ad74cba5e3afbc9ac4 Mon Sep 17 00:00:00 2001 From: pepijn Date: Thu, 21 May 2026 13:25:03 +0000 Subject: [PATCH] fix(pi05): preserve pretrained paligemma lm head Keep the PaliGemma LM head in float32 and initialize it from pretrained weights or token embeddings when loading pi05 checkpoints. Co-authored-by: Cursor --- src/lerobot/policies/pi05/modeling_pi05.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 4f3dce1c9..2c0301e73 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -409,6 +409,7 @@ class PaliGemmaWithExpertModel( params_to_keep_float32 = [ "vision_tower", "multi_modal_projector", + "lm_head", "input_layernorm", "post_attention_layernorm", "model.norm", @@ -1029,6 +1030,16 @@ class PI05Policy(PreTrainedPolicy): if remap_count > 0: print(f"Remapped {remap_count} state dict keys") + lm_head_key = "model.paligemma_with_expert.paligemma.lm_head.weight" + embed_tokens_key = ( + "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight" + ) + if lm_head_key not in remapped_state_dict and embed_tokens_key in remapped_state_dict: + remapped_state_dict[lm_head_key] = remapped_state_dict[embed_tokens_key].clone().float() + print("Initialized PaliGemma lm_head from language token embeddings") + elif lm_head_key in remapped_state_dict: + remapped_state_dict[lm_head_key] = remapped_state_dict[lm_head_key].float() + # Load the remapped state dict into the model missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)