From 4dcfc4cda9168604fb749f5900eb041227439396 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Mon, 29 Sep 2025 15:26:02 +0200 Subject: [PATCH] cleanup from pretrained --- src/lerobot/policies/pi0/modeling_pi0.py | 14 +++++--------- src/lerobot/policies/pi05/modeling_pi05.py | 14 ++++++-------- 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 8ea741d08..833f96c51 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -894,9 +894,9 @@ class PI0Policy(PreTrainedPolicy): ) -> T: """Override the from_pretrained method to handle key remapping and display important disclaimer.""" print( - "⚠️ DISCLAIMER: The PI0 model is a direct PyTorch port of the OpenPI implementation. \n" - " This implementation follows the original OpenPI structure for compatibility. \n" - " Original implementation: https://github.com/Physical-Intelligence/openpi" + "The PI05 model is a direct port of the OpenPI implementation. \n" + "This implementation follows the original OpenPI structure for compatibility. \n" + "Original implementation: https://github.com/Physical-Intelligence/openpi" ) if pretrained_name_or_path is None: raise ValueError("pretrained_name_or_path is required") @@ -959,15 +959,11 @@ class PI0Policy(PreTrainedPolicy): new_key = f"model.{key}" remapped_state_dict[new_key] = value remap_count += 1 - if remap_count <= 10: # Only print first 10 to avoid spam - print(f"Remapped: {key} -> {new_key}") else: remapped_state_dict[key] = value - if remap_count > 10: - print(f"... and {remap_count - 10} more keys remapped") - - print(f"Total keys remapped: {remap_count}") + if remap_count > 0: + print(f"Remapped {remap_count} state dict keys") # Load the remapped state dict into the model missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict) diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 8db75913c..54bb576c0 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -867,9 +867,9 @@ class PI05Policy(PreTrainedPolicy): ) -> T: """Override the from_pretrained method to handle key remapping and display important disclaimer.""" print( - "⚠️ DISCLAIMER: The PI05 model is a direct PyTorch port of the OpenPI implementation. \n" - " This implementation follows the original OpenPI structure for compatibility. \n" - " Original implementation: https://github.com/Physical-Intelligence/openpi" + "The PI05 model is a direct port of the OpenPI implementation. \n" + "This implementation follows the original OpenPI structure for compatibility. \n" + "Original implementation: https://github.com/Physical-Intelligence/openpi" ) if pretrained_name_or_path is None: raise ValueError("pretrained_name_or_path is required") @@ -878,7 +878,7 @@ class PI05Policy(PreTrainedPolicy): if config is None: config = PreTrainedConfig.from_pretrained( pretrained_name_or_path=pretrained_name_or_path, - force_download=force_download, + force_download=force_download,s resume_download=resume_download, proxies=proxies, token=token, @@ -937,10 +937,8 @@ class PI05Policy(PreTrainedPolicy): else: remapped_state_dict[key] = value - if remap_count > 10: - print(f"... and {remap_count - 10} more keys remapped") - - print(f"Total keys remapped: {remap_count}") + if remap_count > 0: + print(f"Remapped {remap_count} state dict keys") # Load the remapped state dict into the model missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)