remove todo

This commit is contained in:
Pepijn
2025-09-16 15:15:24 +02:00
parent aaae109447
commit 5924d4d9eb
2 changed files with 4 additions and 8 deletions
@@ -863,9 +863,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
self.reset()
@classmethod
def from_pretrained(
cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs
): # TODO(pepijn): modify this back so we do not have to add model. prefix to all keys in the state dict
def from_pretrained(cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs):
"""Override the from_pretrained method to handle key remapping and display important disclaimer."""
print(
"⚠️ DISCLAIMER: The PI0OpenPI model is a direct PyTorch port of the OpenPI implementation. \n"
@@ -911,7 +909,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
print("Returning model without loading pretrained weights")
return model
# First, fix any pi key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
# Then add "model." prefix for all keys that don't already have it
@@ -882,9 +882,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
self.reset()
@classmethod
def from_pretrained(
cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs
): # TODO(pepijn): modify this back so we do not have to add model. prefix to all keys in the state dict
def from_pretrained(cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs):
"""Override the from_pretrained method to handle key remapping and display important disclaimer."""
print(
"⚠️ DISCLAIMER: The PI0OpenPI model is a direct PyTorch port of the OpenPI implementation. \n"
@@ -930,7 +928,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
print("Returning model without loading pretrained weights")
return model
# First, fix any pi key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
# Then add "model." prefix for all keys that don't already have it