diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 4f3dce1c9..2c0301e73 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -409,6 +409,7 @@ class PaliGemmaWithExpertModel( params_to_keep_float32 = [ "vision_tower", "multi_modal_projector", + "lm_head", "input_layernorm", "post_attention_layernorm", "model.norm", @@ -1029,6 +1030,16 @@ class PI05Policy(PreTrainedPolicy): if remap_count > 0: print(f"Remapped {remap_count} state dict keys") + lm_head_key = "model.paligemma_with_expert.paligemma.lm_head.weight" + embed_tokens_key = ( + "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight" + ) + if lm_head_key not in remapped_state_dict and embed_tokens_key in remapped_state_dict: + remapped_state_dict[lm_head_key] = remapped_state_dict[embed_tokens_key].clone().float() + print("Initialized PaliGemma lm_head from language token embeddings") + elif lm_head_key in remapped_state_dict: + remapped_state_dict[lm_head_key] = remapped_state_dict[lm_head_key].float() + # Load the remapped state dict into the model missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)