From 5924d4d9eb6f5d2a39e9e657191093982f752f47 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Tue, 16 Sep 2025 15:15:24 +0200 Subject: [PATCH] remove todo --- src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py | 6 ++---- src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py index 4a3b05043..73e3a0f44 100644 --- a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py +++ b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py @@ -863,9 +863,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy): self.reset() @classmethod - def from_pretrained( - cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs - ): # TODO(pepijn): modify this back so we do not have to add model. prefix to all keys in the state dict + def from_pretrained(cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs): """Override the from_pretrained method to handle key remapping and display important disclaimer.""" print( "⚠️ DISCLAIMER: The PI0OpenPI model is a direct PyTorch port of the OpenPI implementation. \n" @@ -911,7 +909,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy): print("Returning model without loading pretrained weights") return model - # First, fix any pi key differences # see openpi `model.py, _fix_pytorch_state_dict_keys` + # First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys` fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config) # Then add "model." prefix for all keys that don't already have it diff --git a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py index 51d050253..db137f025 100644 --- a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py +++ b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py @@ -882,9 +882,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy): self.reset() @classmethod - def from_pretrained( - cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs - ): # TODO(pepijn): modify this back so we do not have to add model. prefix to all keys in the state dict + def from_pretrained(cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs): """Override the from_pretrained method to handle key remapping and display important disclaimer.""" print( "⚠️ DISCLAIMER: The PI0OpenPI model is a direct PyTorch port of the OpenPI implementation. \n" @@ -930,7 +928,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy): print("Returning model without loading pretrained weights") return model - # First, fix any pi key differences # see openpi `model.py, _fix_pytorch_state_dict_keys` + # First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys` fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config) # Then add "model." prefix for all keys that don't already have it