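Fix `from_pretrained`: build the model directly and load the remapped state dict manually instead of calling the parent loader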

This commit is contained in:
Pepijn
2025-09-12 21:12:48 +02:00
parent d1eefd4e97
commit 376cc772ff
2 changed files with 42 additions and 52 deletions
@@ -864,41 +864,35 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
@classmethod
def from_pretrained(
cls, *args, **kwargs
cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs
): # TODO(pepijn): modify this back so we do not have to add model. prefix to all keys in the state dict
"""Override the from_pretrained method to handle key remapping and display important disclaimer."""
print(
"⚠️ DISCLAIMER: The PI05OpenPI model is a direct PyTorch port of the OpenPI implementation. \n"
"⚠️ DISCLAIMER: The PI0OpenPI model is a direct PyTorch port of the OpenPI implementation. \n"
" This implementation follows the original OpenPI structure for compatibility. \n"
" Original implementation: https://github.com/Physical-Intelligence/openpi"
)
if pretrained_name_or_path is None:
raise ValueError("pretrained_name_or_path is required")
# Store original strict mode
original_strict = kwargs.get("strict", True)
# Temporarily set strict=False to avoid loading issues, we'll handle it manually
kwargs["strict"] = False
# Create default config
config = cls.config_class()
# Call parent from_pretrained with strict=False
model = super().from_pretrained(*args, **kwargs)
# Extract the pretrained_model_name_or_path from args or kwargs for remapping
if len(args) > 0:
pretrained_model_name_or_path = args[0]
elif "pretrained_model_name_or_path" in kwargs:
pretrained_model_name_or_path = kwargs["pretrained_model_name_or_path"]
else:
return model
# Initialize model without loading weights
# Check if dataset_stats were provided in kwargs
dataset_stats = kwargs.get("dataset_stats")
model = cls(config=config, dataset_stats=dataset_stats)
# Now manually load and remap the state dict
try:
from transformers.utils import cached_file
# Try to load the pytorch_model.bin or model.safetensors file
print(f"Loading model from: {pretrained_model_name_or_path}")
print(f"Loading model from: {pretrained_name_or_path}")
try:
from transformers.utils import cached_file
# Try safetensors first
resolved_file = cached_file(
pretrained_model_name_or_path,
pretrained_name_or_path,
"model.safetensors",
cache_dir=kwargs.get("cache_dir"),
force_download=kwargs.get("force_download", False),
@@ -914,9 +908,10 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
print("✓ Loaded state dict from model.safetensors")
except Exception as e:
print(f"Could not load state dict from remote files: {e}")
print("Returning model without loading pretrained weights")
return model
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
# First, fix any pi key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
# Then add "model." prefix for all keys that don't already have it
@@ -939,10 +934,10 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
print(f"Total keys remapped: {remap_count}")
# Load the remapped state dict into the model
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=original_strict)
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
if missing_keys:
print(f"⚠️ Missing keys when loading state dict: {len(missing_keys)} keys")
print(f"Missing keys when loading state dict: {len(missing_keys)} keys")
if len(missing_keys) <= 5:
for key in missing_keys:
print(f" - {key}")
@@ -952,7 +947,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
print(f" ... and {len(missing_keys) - 5} more")
if unexpected_keys:
print(f"⚠️ Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
if len(unexpected_keys) <= 5:
for key in unexpected_keys:
print(f" - {key}")
@@ -962,11 +957,11 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
print(f" ... and {len(unexpected_keys) - 5} more")
if not missing_keys and not unexpected_keys:
print("All keys loaded successfully!")
print("All keys loaded successfully!")
except Exception as e:
print(f"⚠️ Warning: Could not remap state dict keys: {e}")
print("Using default loading behavior")
print(f"Warning: Could not remap state dict keys: {e}")
print("Returning model without loading pretrained weights")
return model
@@ -881,7 +881,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
@classmethod
def from_pretrained(
cls, *args, **kwargs
cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs
): # TODO(pepijn): modify this back so we do not have to add model. prefix to all keys in the state dict
"""Override the from_pretrained method to handle key remapping and display important disclaimer."""
print(
@@ -889,33 +889,27 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
" This implementation follows the original OpenPI structure for compatibility. \n"
" Original implementation: https://github.com/Physical-Intelligence/openpi"
)
if pretrained_name_or_path is None:
raise ValueError("pretrained_name_or_path is required")
# Store original strict mode
original_strict = kwargs.get("strict", True)
# Temporarily set strict=False to avoid loading issues, we'll handle it manually
kwargs["strict"] = False
# Create default config
config = cls.config_class()
# Call parent from_pretrained with strict=False
model = super().from_pretrained(*args, **kwargs)
# Extract the pretrained_model_name_or_path from args or kwargs for remapping
if len(args) > 0:
pretrained_model_name_or_path = args[0]
elif "pretrained_model_name_or_path" in kwargs:
pretrained_model_name_or_path = kwargs["pretrained_model_name_or_path"]
else:
return model
# Initialize model without loading weights
# Check if dataset_stats were provided in kwargs
dataset_stats = kwargs.get("dataset_stats")
model = cls(config=config, dataset_stats=dataset_stats)
# Now manually load and remap the state dict
try:
from transformers.utils import cached_file
# Try to load the pytorch_model.bin or model.safetensors file
print(f"Loading model from: {pretrained_model_name_or_path}")
print(f"Loading model from: {pretrained_name_or_path}")
try:
from transformers.utils import cached_file
# Try safetensors first
resolved_file = cached_file(
pretrained_model_name_or_path,
pretrained_name_or_path,
"model.safetensors",
cache_dir=kwargs.get("cache_dir"),
force_download=kwargs.get("force_download", False),
@@ -931,6 +925,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
print("✓ Loaded state dict from model.safetensors")
except Exception as e:
print(f"Could not load state dict from remote files: {e}")
print("Returning model without loading pretrained weights")
return model
# First, fix any pi key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
@@ -956,10 +951,10 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
print(f"Total keys remapped: {remap_count}")
# Load the remapped state dict into the model
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=original_strict)
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
if missing_keys:
print(f"⚠️ Missing keys when loading state dict: {len(missing_keys)} keys")
print(f"Missing keys when loading state dict: {len(missing_keys)} keys")
if len(missing_keys) <= 5:
for key in missing_keys:
print(f" - {key}")
@@ -969,7 +964,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
print(f" ... and {len(missing_keys) - 5} more")
if unexpected_keys:
print(f"⚠️ Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
if len(unexpected_keys) <= 5:
for key in unexpected_keys:
print(f" - {key}")
@@ -979,11 +974,11 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
print(f" ... and {len(unexpected_keys) - 5} more")
if not missing_keys and not unexpected_keys:
print("All keys loaded successfully!")
print("All keys loaded successfully!")
except Exception as e:
print(f"⚠️ Warning: Could not remap state dict keys: {e}")
print("Using default loading behavior")
print(f"Warning: Could not remap state dict keys: {e}")
print("Returning model without loading pretrained weights")
return model