mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 08:47:05 +00:00
fix from pretrained
This commit is contained in:
@@ -864,41 +864,35 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls, *args, **kwargs
|
||||
cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs
|
||||
): # TODO(pepijn): modify this back so we do not have to add model. prefix to all keys in the state dict
|
||||
"""Override the from_pretrained method to handle key remapping and display important disclaimer."""
|
||||
print(
|
||||
"⚠️ DISCLAIMER: The PI05OpenPI model is a direct PyTorch port of the OpenPI implementation. \n"
|
||||
"⚠️ DISCLAIMER: The PI0OpenPI model is a direct PyTorch port of the OpenPI implementation. \n"
|
||||
" This implementation follows the original OpenPI structure for compatibility. \n"
|
||||
" Original implementation: https://github.com/Physical-Intelligence/openpi"
|
||||
)
|
||||
if pretrained_name_or_path is None:
|
||||
raise ValueError("pretrained_name_or_path is required")
|
||||
|
||||
# Store original strict mode
|
||||
original_strict = kwargs.get("strict", True)
|
||||
# Temporarily set strict=False to avoid loading issues, we'll handle it manually
|
||||
kwargs["strict"] = False
|
||||
# Create default config
|
||||
config = cls.config_class()
|
||||
|
||||
# Call parent from_pretrained with strict=False
|
||||
model = super().from_pretrained(*args, **kwargs)
|
||||
|
||||
# Extract the pretrained_model_name_or_path from args or kwargs for remapping
|
||||
if len(args) > 0:
|
||||
pretrained_model_name_or_path = args[0]
|
||||
elif "pretrained_model_name_or_path" in kwargs:
|
||||
pretrained_model_name_or_path = kwargs["pretrained_model_name_or_path"]
|
||||
else:
|
||||
return model
|
||||
# Initialize model without loading weights
|
||||
# Check if dataset_stats were provided in kwargs
|
||||
dataset_stats = kwargs.get("dataset_stats")
|
||||
model = cls(config=config, dataset_stats=dataset_stats)
|
||||
|
||||
# Now manually load and remap the state dict
|
||||
try:
|
||||
from transformers.utils import cached_file
|
||||
|
||||
# Try to load the pytorch_model.bin or model.safetensors file
|
||||
print(f"Loading model from: {pretrained_model_name_or_path}")
|
||||
print(f"Loading model from: {pretrained_name_or_path}")
|
||||
try:
|
||||
from transformers.utils import cached_file
|
||||
|
||||
# Try safetensors first
|
||||
resolved_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
pretrained_name_or_path,
|
||||
"model.safetensors",
|
||||
cache_dir=kwargs.get("cache_dir"),
|
||||
force_download=kwargs.get("force_download", False),
|
||||
@@ -914,9 +908,10 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
||||
print("✓ Loaded state dict from model.safetensors")
|
||||
except Exception as e:
|
||||
print(f"Could not load state dict from remote files: {e}")
|
||||
print("Returning model without loading pretrained weights")
|
||||
return model
|
||||
|
||||
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
|
||||
# First, fix any pi key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
|
||||
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
|
||||
|
||||
# Then add "model." prefix for all keys that don't already have it
|
||||
@@ -939,10 +934,10 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
||||
print(f"Total keys remapped: {remap_count}")
|
||||
|
||||
# Load the remapped state dict into the model
|
||||
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=original_strict)
|
||||
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
|
||||
|
||||
if missing_keys:
|
||||
print(f"⚠️ Missing keys when loading state dict: {len(missing_keys)} keys")
|
||||
print(f"Missing keys when loading state dict: {len(missing_keys)} keys")
|
||||
if len(missing_keys) <= 5:
|
||||
for key in missing_keys:
|
||||
print(f" - {key}")
|
||||
@@ -952,7 +947,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
||||
print(f" ... and {len(missing_keys) - 5} more")
|
||||
|
||||
if unexpected_keys:
|
||||
print(f"⚠️ Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
|
||||
print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
|
||||
if len(unexpected_keys) <= 5:
|
||||
for key in unexpected_keys:
|
||||
print(f" - {key}")
|
||||
@@ -962,11 +957,11 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
||||
print(f" ... and {len(unexpected_keys) - 5} more")
|
||||
|
||||
if not missing_keys and not unexpected_keys:
|
||||
print("✅ All keys loaded successfully!")
|
||||
print("All keys loaded successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ Warning: Could not remap state dict keys: {e}")
|
||||
print("Using default loading behavior")
|
||||
print(f"Warning: Could not remap state dict keys: {e}")
|
||||
print("Returning model without loading pretrained weights")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@@ -881,7 +881,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls, *args, **kwargs
|
||||
cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs
|
||||
): # TODO(pepijn): modify this back so we do not have to add model. prefix to all keys in the state dict
|
||||
"""Override the from_pretrained method to handle key remapping and display important disclaimer."""
|
||||
print(
|
||||
@@ -889,33 +889,27 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
|
||||
" This implementation follows the original OpenPI structure for compatibility. \n"
|
||||
" Original implementation: https://github.com/Physical-Intelligence/openpi"
|
||||
)
|
||||
if pretrained_name_or_path is None:
|
||||
raise ValueError("pretrained_name_or_path is required")
|
||||
|
||||
# Store original strict mode
|
||||
original_strict = kwargs.get("strict", True)
|
||||
# Temporarily set strict=False to avoid loading issues, we'll handle it manually
|
||||
kwargs["strict"] = False
|
||||
# Create default config
|
||||
config = cls.config_class()
|
||||
|
||||
# Call parent from_pretrained with strict=False
|
||||
model = super().from_pretrained(*args, **kwargs)
|
||||
|
||||
# Extract the pretrained_model_name_or_path from args or kwargs for remapping
|
||||
if len(args) > 0:
|
||||
pretrained_model_name_or_path = args[0]
|
||||
elif "pretrained_model_name_or_path" in kwargs:
|
||||
pretrained_model_name_or_path = kwargs["pretrained_model_name_or_path"]
|
||||
else:
|
||||
return model
|
||||
# Initialize model without loading weights
|
||||
# Check if dataset_stats were provided in kwargs
|
||||
dataset_stats = kwargs.get("dataset_stats")
|
||||
model = cls(config=config, dataset_stats=dataset_stats)
|
||||
|
||||
# Now manually load and remap the state dict
|
||||
try:
|
||||
from transformers.utils import cached_file
|
||||
|
||||
# Try to load the pytorch_model.bin or model.safetensors file
|
||||
print(f"Loading model from: {pretrained_model_name_or_path}")
|
||||
print(f"Loading model from: {pretrained_name_or_path}")
|
||||
try:
|
||||
from transformers.utils import cached_file
|
||||
|
||||
# Try safetensors first
|
||||
resolved_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
pretrained_name_or_path,
|
||||
"model.safetensors",
|
||||
cache_dir=kwargs.get("cache_dir"),
|
||||
force_download=kwargs.get("force_download", False),
|
||||
@@ -931,6 +925,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
|
||||
print("✓ Loaded state dict from model.safetensors")
|
||||
except Exception as e:
|
||||
print(f"Could not load state dict from remote files: {e}")
|
||||
print("Returning model without loading pretrained weights")
|
||||
return model
|
||||
|
||||
# First, fix any pi key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
|
||||
@@ -956,10 +951,10 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
|
||||
print(f"Total keys remapped: {remap_count}")
|
||||
|
||||
# Load the remapped state dict into the model
|
||||
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=original_strict)
|
||||
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
|
||||
|
||||
if missing_keys:
|
||||
print(f"⚠️ Missing keys when loading state dict: {len(missing_keys)} keys")
|
||||
print(f"Missing keys when loading state dict: {len(missing_keys)} keys")
|
||||
if len(missing_keys) <= 5:
|
||||
for key in missing_keys:
|
||||
print(f" - {key}")
|
||||
@@ -969,7 +964,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
|
||||
print(f" ... and {len(missing_keys) - 5} more")
|
||||
|
||||
if unexpected_keys:
|
||||
print(f"⚠️ Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
|
||||
print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
|
||||
if len(unexpected_keys) <= 5:
|
||||
for key in unexpected_keys:
|
||||
print(f" - {key}")
|
||||
@@ -979,11 +974,11 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
|
||||
print(f" ... and {len(unexpected_keys) - 5} more")
|
||||
|
||||
if not missing_keys and not unexpected_keys:
|
||||
print("✅ All keys loaded successfully!")
|
||||
print("All keys loaded successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ Warning: Could not remap state dict keys: {e}")
|
||||
print("Using default loading behavior")
|
||||
print(f"Warning: Could not remap state dict keys: {e}")
|
||||
print("Returning model without loading pretrained weights")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
Reference in New Issue
Block a user