cleanup from pretrained

This commit is contained in:
Pepijn
2025-09-29 15:26:02 +02:00
parent 8de5280fd3
commit 4dcfc4cda9
2 changed files with 11 additions and 17 deletions
+5 -9
View File
@@ -894,9 +894,9 @@ class PI0Policy(PreTrainedPolicy):
) -> T:
"""Override the from_pretrained method to handle key remapping and display important disclaimer."""
print(
"⚠️ DISCLAIMER: The PI0 model is a direct PyTorch port of the OpenPI implementation. \n"
" This implementation follows the original OpenPI structure for compatibility. \n"
" Original implementation: https://github.com/Physical-Intelligence/openpi"
"The PI05 model is a direct port of the OpenPI implementation. \n"
"This implementation follows the original OpenPI structure for compatibility. \n"
"Original implementation: https://github.com/Physical-Intelligence/openpi"
)
if pretrained_name_or_path is None:
raise ValueError("pretrained_name_or_path is required")
@@ -959,15 +959,11 @@ class PI0Policy(PreTrainedPolicy):
new_key = f"model.{key}"
remapped_state_dict[new_key] = value
remap_count += 1
if remap_count <= 10: # Only print first 10 to avoid spam
print(f"Remapped: {key} -> {new_key}")
else:
remapped_state_dict[key] = value
if remap_count > 10:
print(f"... and {remap_count - 10} more keys remapped")
print(f"Total keys remapped: {remap_count}")
if remap_count > 0:
print(f"Remapped {remap_count} state dict keys")
# Load the remapped state dict into the model
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
+6 -8
View File
@@ -867,9 +867,9 @@ class PI05Policy(PreTrainedPolicy):
) -> T:
"""Override the from_pretrained method to handle key remapping and display important disclaimer."""
print(
"⚠️ DISCLAIMER: The PI05 model is a direct PyTorch port of the OpenPI implementation. \n"
" This implementation follows the original OpenPI structure for compatibility. \n"
" Original implementation: https://github.com/Physical-Intelligence/openpi"
"The PI05 model is a direct port of the OpenPI implementation. \n"
"This implementation follows the original OpenPI structure for compatibility. \n"
"Original implementation: https://github.com/Physical-Intelligence/openpi"
)
if pretrained_name_or_path is None:
raise ValueError("pretrained_name_or_path is required")
@@ -878,7 +878,7 @@ class PI05Policy(PreTrainedPolicy):
if config is None:
config = PreTrainedConfig.from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
force_download=force_download,
force_download=force_download,s
resume_download=resume_download,
proxies=proxies,
token=token,
@@ -937,10 +937,8 @@ class PI05Policy(PreTrainedPolicy):
else:
remapped_state_dict[key] = value
if remap_count > 10:
print(f"... and {remap_count - 10} more keys remapped")
print(f"Total keys remapped: {remap_count}")
if remap_count > 0:
print(f"Remapped {remap_count} state dict keys")
# Load the remapped state dict into the model
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)