also for pi05

This commit is contained in:
Pepijn
2025-09-12 19:02:13 +02:00
parent bf90efa7e1
commit 6ce2a00135
@@ -974,14 +974,22 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
r"paligemma_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight",
key,
):
# This key structure suggests old model without adaRMS - keep as is or skip
logging.warning(f"Skipping old layer norm key (no adaRMS support): {key}")
continue
# Check if the model actually has adaRMS enabled for the expert
expert_uses_adarms = getattr(
self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False
)
if expert_uses_adarms:
logging.warning(f"Skipping layer norm key (adaRMS mismatch): {key}")
continue
if re.match(r"paligemma_with_expert\.gemma_expert\.model\.norm\.weight", key):
# Skip old norm structure
logging.warning(f"Skipping old norm key (no adaRMS support): {key}")
continue
# Check if the model actually has adaRMS enabled for the expert
expert_uses_adarms = getattr(
self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False
)
if expert_uses_adarms:
logging.warning(f"Skipping norm key (adaRMS mismatch): {key}")
continue
# Handle MLP naming changes for pi05
# pi05 model expects time_mlp_*, but checkpoint might have action_time_mlp_*