mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-14 16:19:45 +00:00
remove todo
This commit is contained in:
@@ -863,9 +863,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
||||
self.reset()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs
|
||||
): # TODO(pepijn): modify this back so we do not have to add model. prefix to all keys in the state dict
|
||||
def from_pretrained(cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs):
|
||||
"""Override the from_pretrained method to handle key remapping and display important disclaimer."""
|
||||
print(
|
||||
"⚠️ DISCLAIMER: The PI0OpenPI model is a direct PyTorch port of the OpenPI implementation. \n"
|
||||
@@ -911,7 +909,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
||||
print("Returning model without loading pretrained weights")
|
||||
return model
|
||||
|
||||
# First, fix any pi key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
|
||||
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
|
||||
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
|
||||
|
||||
# Then add "model." prefix for all keys that don't already have it
|
||||
|
||||
@@ -882,9 +882,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
|
||||
self.reset()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs
|
||||
): # TODO(pepijn): modify this back so we do not have to add model. prefix to all keys in the state dict
|
||||
def from_pretrained(cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs):
|
||||
"""Override the from_pretrained method to handle key remapping and display important disclaimer."""
|
||||
print(
|
||||
"⚠️ DISCLAIMER: The PI0OpenPI model is a direct PyTorch port of the OpenPI implementation. \n"
|
||||
@@ -930,7 +928,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
|
||||
print("Returning model without loading pretrained weights")
|
||||
return model
|
||||
|
||||
# First, fix any pi key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
|
||||
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
|
||||
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
|
||||
|
||||
# Then add "model." prefix for all keys that don't already have it
|
||||
|
||||
Reference in New Issue
Block a user