mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
cleanup from pretrained
This commit is contained in:
@@ -894,9 +894,9 @@ class PI0Policy(PreTrainedPolicy):
|
||||
) -> T:
|
||||
"""Override the from_pretrained method to handle key remapping and display important disclaimer."""
|
||||
print(
|
||||
"⚠️ DISCLAIMER: The PI0 model is a direct PyTorch port of the OpenPI implementation. \n"
|
||||
" This implementation follows the original OpenPI structure for compatibility. \n"
|
||||
" Original implementation: https://github.com/Physical-Intelligence/openpi"
|
||||
"The PI05 model is a direct port of the OpenPI implementation. \n"
|
||||
"This implementation follows the original OpenPI structure for compatibility. \n"
|
||||
"Original implementation: https://github.com/Physical-Intelligence/openpi"
|
||||
)
|
||||
if pretrained_name_or_path is None:
|
||||
raise ValueError("pretrained_name_or_path is required")
|
||||
@@ -959,15 +959,11 @@ class PI0Policy(PreTrainedPolicy):
|
||||
new_key = f"model.{key}"
|
||||
remapped_state_dict[new_key] = value
|
||||
remap_count += 1
|
||||
if remap_count <= 10: # Only print first 10 to avoid spam
|
||||
print(f"Remapped: {key} -> {new_key}")
|
||||
else:
|
||||
remapped_state_dict[key] = value
|
||||
|
||||
if remap_count > 10:
|
||||
print(f"... and {remap_count - 10} more keys remapped")
|
||||
|
||||
print(f"Total keys remapped: {remap_count}")
|
||||
if remap_count > 0:
|
||||
print(f"Remapped {remap_count} state dict keys")
|
||||
|
||||
# Load the remapped state dict into the model
|
||||
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
|
||||
|
||||
@@ -867,9 +867,9 @@ class PI05Policy(PreTrainedPolicy):
|
||||
) -> T:
|
||||
"""Override the from_pretrained method to handle key remapping and display important disclaimer."""
|
||||
print(
|
||||
"⚠️ DISCLAIMER: The PI05 model is a direct PyTorch port of the OpenPI implementation. \n"
|
||||
" This implementation follows the original OpenPI structure for compatibility. \n"
|
||||
" Original implementation: https://github.com/Physical-Intelligence/openpi"
|
||||
"The PI05 model is a direct port of the OpenPI implementation. \n"
|
||||
"This implementation follows the original OpenPI structure for compatibility. \n"
|
||||
"Original implementation: https://github.com/Physical-Intelligence/openpi"
|
||||
)
|
||||
if pretrained_name_or_path is None:
|
||||
raise ValueError("pretrained_name_or_path is required")
|
||||
@@ -878,7 +878,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
if config is None:
|
||||
config = PreTrainedConfig.from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
force_download=force_download,
|
||||
force_download=force_download,s
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
@@ -937,10 +937,8 @@ class PI05Policy(PreTrainedPolicy):
|
||||
else:
|
||||
remapped_state_dict[key] = value
|
||||
|
||||
if remap_count > 10:
|
||||
print(f"... and {remap_count - 10} more keys remapped")
|
||||
|
||||
print(f"Total keys remapped: {remap_count}")
|
||||
if remap_count > 0:
|
||||
print(f"Remapped {remap_count} state dict keys")
|
||||
|
||||
# Load the remapped state dict into the model
|
||||
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
|
||||
|
||||
Reference in New Issue
Block a user