mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
clean up load, add inject stats and extend convert script for libero
This commit is contained in:
+98
-253
@@ -1,9 +1,13 @@
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download # noqa: E402
|
||||
from safetensors.torch import load_file # noqa: E402
|
||||
from transformers.model_debugging_utils import model_addition_debugger_context
|
||||
|
||||
from lerobot.configs.policies import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
@@ -26,36 +30,15 @@ def set_all_seeds(seed=42):
|
||||
torch.use_deterministic_algorithms(True)
|
||||
print(f"All random seeds set to {seed} for reproducible results (deterministic mode enabled)")
|
||||
|
||||
# For MPS devices, set additional deterministic settings
|
||||
if torch.backends.mps.is_available():
|
||||
print("MPS device detected - using deterministic settings")
|
||||
|
||||
|
||||
# Set seeds at the start
|
||||
set_all_seeds(RANDOM_SEED)
|
||||
|
||||
# Import model debugger for detailed forward pass analysis
|
||||
try:
|
||||
import transformers
|
||||
from transformers.model_debugging_utils import model_addition_debugger_context
|
||||
|
||||
DEBUGGER_AVAILABLE = True
|
||||
print("✅ Model debugger available")
|
||||
print(f" Transformers version: {transformers.__version__}")
|
||||
except ImportError as e:
|
||||
print("⚠️ Model debugger not available (requires transformers with debugging utils)")
|
||||
print(f" Import error: {e}")
|
||||
DEBUGGER_AVAILABLE = False
|
||||
|
||||
config_model_path = "lerobot/pi0" # Use config from official model
|
||||
official_model_path = "lerobot/pi0" # Official model
|
||||
custom_model_path = "pepijn223/pi0_libero_fp32" # Custom model to compare # pepijn223/pi0_base_fp32
|
||||
custom_model_path = "pepijn223/pi0_base_fp32" # Custom model to compare # pepijn223/pi0_base_fp32
|
||||
device = "mps"
|
||||
|
||||
# For testing determinism, set both to same model:
|
||||
# official_model_path = "pepijn223/pi0_base_fp32"
|
||||
# custom_model_path = "pepijn223/pi0_base_fp32"
|
||||
|
||||
USE_FULL_TENSORS = True
|
||||
SAVE_TENSORS_TO_DISK = False
|
||||
|
||||
@@ -65,24 +48,10 @@ UPLOAD_TO_HUB = True # Set to True to upload to HuggingFace Hub
|
||||
TRANSFORMED_MODEL_NAME = "pepijn223/pi0_base_fp32_lerobot_format" # Target repo name
|
||||
COMMIT_MESSAGE = "Add transformed PI0 model with correct key format for lerobot"
|
||||
|
||||
# Create debug directory
|
||||
if DEBUGGER_AVAILABLE:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
debug_path = os.path.join("debug_outputs", f"pi0_debug_direct_{timestamp}")
|
||||
os.makedirs(debug_path, exist_ok=True)
|
||||
print(f"🔍 Model debugging enabled - outputs will be saved to: {debug_path}")
|
||||
else:
|
||||
debug_path = None
|
||||
|
||||
# Create dummy normalization stats (similar to openpi example)
|
||||
print("Creating dummy normalization stats...")
|
||||
|
||||
# Load shared config and create both models for comparison
|
||||
print("Loading shared config from lerobot/pi0...")
|
||||
import json # noqa: E402
|
||||
|
||||
from huggingface_hub import hf_hub_download # noqa: E402
|
||||
from safetensors.torch import load_file # noqa: E402
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
debug_path = os.path.join("debug_outputs", f"pi0_debug_direct_{timestamp}")
|
||||
os.makedirs(debug_path, exist_ok=True)
|
||||
print(f"Model debugging enabled - outputs will be saved to: {debug_path}")
|
||||
|
||||
# Download and load the config manually to avoid draccus parsing issues
|
||||
config_file = hf_hub_download(repo_id=config_model_path, filename="config.json")
|
||||
@@ -117,7 +86,7 @@ def load_policy_with_weights(
|
||||
print(f"Downloaded {model_name} weights to: {model_file}")
|
||||
|
||||
# Load state dict and apply transformations
|
||||
print(f"🔍 Investigating safetensors file: {model_file}")
|
||||
print(f"Investigating safetensors file: {model_file}")
|
||||
|
||||
# First, check what's in the metadata
|
||||
try:
|
||||
@@ -191,12 +160,12 @@ def load_policy_with_weights(
|
||||
# Check what we're missing and what we actually have
|
||||
expected_embed_key = "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
|
||||
if expected_embed_key not in all_keys:
|
||||
print(f"⚠️ Missing expected embed_tokens key: {expected_embed_key}")
|
||||
print(f" Missing expected embed_tokens key: {expected_embed_key}")
|
||||
|
||||
# Let's see what keys we actually have for debugging
|
||||
print("🔍 Debugging: Looking for any embedding-related keys...")
|
||||
print("Debugging: Looking for any embedding-related keys...")
|
||||
all_embed_related = [k for k in all_keys if "embed" in k.lower()]
|
||||
print(f" Keys containing 'embed': {all_embed_related}")
|
||||
print(f"Keys containing 'embed': {all_embed_related}")
|
||||
|
||||
# Look for any keys that might contain embeddings
|
||||
potential_embed_keys = [
|
||||
@@ -231,61 +200,6 @@ def load_policy_with_weights(
|
||||
final_state_dict = {}
|
||||
transformation_count = 0
|
||||
|
||||
# First, handle the missing embed_tokens
|
||||
if expected_embed_key not in all_keys:
|
||||
print("🔄 Attempting to fix missing embed_tokens...")
|
||||
|
||||
# Option 1: Copy from gemma_expert if available
|
||||
if gemma_expert_embed:
|
||||
source_key = gemma_expert_embed[0]
|
||||
if source_key in transformed_state_dict:
|
||||
final_state_dict[expected_embed_key] = transformed_state_dict[source_key]
|
||||
print(f"✅ Created missing embed_tokens by copying from: {source_key}")
|
||||
transformation_count += 1
|
||||
else:
|
||||
# Option 2: Try to manually load embed_tokens from safetensors metadata
|
||||
try:
|
||||
from safetensors import safe_open
|
||||
|
||||
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata:
|
||||
# Look for embed_tokens in metadata
|
||||
embed_keys_in_metadata = [k for k in metadata.keys() if "embed_tokens" in k]
|
||||
print(f" embed_tokens in metadata: {embed_keys_in_metadata}")
|
||||
|
||||
if embed_keys_in_metadata:
|
||||
# Try to extract the tensor using the metadata key
|
||||
metadata_key = embed_keys_in_metadata[0]
|
||||
try:
|
||||
# The metadata key might be different from the tensor key
|
||||
# Try different variations
|
||||
possible_keys = [
|
||||
metadata_key,
|
||||
metadata_key.replace("model.", ""), # Remove model. prefix
|
||||
"paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight",
|
||||
]
|
||||
|
||||
for test_key in possible_keys:
|
||||
if test_key in f.keys():
|
||||
embed_tensor = f.get_tensor(test_key)
|
||||
final_state_dict[expected_embed_key] = embed_tensor
|
||||
print(
|
||||
f"✅ Manually extracted embed_tokens: {test_key} -> {expected_embed_key}"
|
||||
)
|
||||
transformation_count += 1
|
||||
break
|
||||
else:
|
||||
print("❌ Could not find embed_tokens tensor despite metadata entry")
|
||||
print(f" Tried keys: {possible_keys}")
|
||||
print(f" Available keys sample: {list(f.keys())[:10]}")
|
||||
except Exception as e2:
|
||||
print(f"❌ Error extracting embed_tokens tensor: {e2}")
|
||||
else:
|
||||
print("❌ No metadata found in safetensors file")
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to manually extract embed_tokens: {e}")
|
||||
|
||||
for key, value in transformed_state_dict.items():
|
||||
new_key = key
|
||||
original_key = key
|
||||
@@ -296,7 +210,7 @@ def load_policy_with_weights(
|
||||
"paligemma_with_expert.paligemma.vision_tower.vision_model",
|
||||
"paligemma_with_expert.paligemma.model.vision_tower.vision_model",
|
||||
)
|
||||
print(f"✅ Transformed vision key: {original_key} -> {new_key}")
|
||||
print(f"Transformed vision key: {original_key} -> {new_key}")
|
||||
transformation_count += 1
|
||||
|
||||
# Transform multi_modal_projector keys: ADD .model between paligemma and multi_modal_projector
|
||||
@@ -305,7 +219,7 @@ def load_policy_with_weights(
|
||||
"paligemma_with_expert.paligemma.multi_modal_projector",
|
||||
"paligemma_with_expert.paligemma.model.multi_modal_projector",
|
||||
)
|
||||
print(f"✅ Transformed multi_modal_projector key: {original_key} -> {new_key}")
|
||||
print(f"Transformed multi_modal_projector key: {original_key} -> {new_key}")
|
||||
transformation_count += 1
|
||||
|
||||
# NO transformation needed for language_model keys - they're already correct!
|
||||
@@ -327,7 +241,7 @@ def load_policy_with_weights(
|
||||
missing_in_provided = policy_keys - provided_keys
|
||||
extra_in_provided = provided_keys - policy_keys
|
||||
|
||||
print(f"🔍 Policy expects {len(policy_keys)} keys, we provide {len(provided_keys)} keys")
|
||||
print(f"Policy expects {len(policy_keys)} keys, we provide {len(provided_keys)} keys")
|
||||
if missing_in_provided:
|
||||
print(
|
||||
f" Missing from provided: {list(missing_in_provided)[:5]}{'...' if len(missing_in_provided) > 5 else ''}"
|
||||
@@ -371,7 +285,7 @@ custom_policy = load_policy_with_weights(
|
||||
custom_model_path, shared_config, "Custom Model", apply_transformations=True
|
||||
)
|
||||
|
||||
print("\n✅ Both models loaded successfully!")
|
||||
print("\nBoth models loaded successfully!")
|
||||
print(f"Shared config: {shared_config}")
|
||||
print(f"Device: {device}")
|
||||
|
||||
@@ -436,40 +350,14 @@ for name in custom_params.keys():
|
||||
param_differences.append(f"Extra parameter in custom model: {name}")
|
||||
|
||||
if param_differences:
|
||||
print("⚠️ Parameter differences found:")
|
||||
print("Parameter differences found:")
|
||||
for diff in param_differences[:10]: # Show first 10 differences
|
||||
print(f" {diff}")
|
||||
if len(param_differences) > 10:
|
||||
print(f" ... and {len(param_differences) - 10} more differences")
|
||||
else:
|
||||
print("✅ All model parameters are identical!")
|
||||
print("All model parameters are identical!")
|
||||
|
||||
# Also check buffers (e.g., normalization statistics)
|
||||
print("\n=== Buffer Comparison ===")
|
||||
official_buffers = dict(official_policy.named_buffers())
|
||||
custom_buffers = dict(custom_policy.named_buffers())
|
||||
|
||||
buffer_differences = []
|
||||
for name in official_buffers.keys():
|
||||
if name not in custom_buffers:
|
||||
buffer_differences.append(f"Missing buffer in custom model: {name}")
|
||||
else:
|
||||
diff = torch.abs(official_buffers[name] - custom_buffers[name]).max().item()
|
||||
if diff > 1e-8:
|
||||
buffer_differences.append(f"Buffer {name}: max difference = {diff:.2e}")
|
||||
|
||||
for name in custom_buffers.keys():
|
||||
if name not in official_buffers:
|
||||
buffer_differences.append(f"Extra buffer in custom model: {name}")
|
||||
|
||||
if buffer_differences:
|
||||
print("⚠️ Buffer differences found:")
|
||||
for diff in buffer_differences[:10]:
|
||||
print(f" {diff}")
|
||||
if len(buffer_differences) > 10:
|
||||
print(f" ... and {len(buffer_differences) - 10} more differences")
|
||||
else:
|
||||
print("✅ All model buffers are identical!")
|
||||
|
||||
# Get the raw models for direct comparison
|
||||
official_raw_model = official_policy.model
|
||||
@@ -490,7 +378,7 @@ example = {
|
||||
|
||||
print(f"\nProvided input keys: {list(example.keys())}")
|
||||
|
||||
print("\n🔄 Preparing inputs for direct model call...")
|
||||
print("\nPreparing inputs for direct model call...")
|
||||
|
||||
# Apply input transformation (similar to openpi's policy._input_transform)
|
||||
transformed_example = {}
|
||||
@@ -593,69 +481,22 @@ with torch.no_grad():
|
||||
|
||||
print("\n=== Running Forward Passes ===")
|
||||
|
||||
if DEBUGGER_AVAILABLE and debug_path:
|
||||
print("🔍 Running with model_addition_debugger_context for detailed analysis...")
|
||||
|
||||
# Create separate debug paths for each model
|
||||
official_debug_path = os.path.join(debug_path, "official_model")
|
||||
custom_debug_path = os.path.join(debug_path, "custom_model")
|
||||
os.makedirs(official_debug_path, exist_ok=True)
|
||||
os.makedirs(custom_debug_path, exist_ok=True)
|
||||
|
||||
# Set deterministic mode for forward pass
|
||||
torch.manual_seed(RANDOM_SEED)
|
||||
|
||||
# Run official model with debugger
|
||||
print("Running Official Model forward pass with debugger...")
|
||||
with model_addition_debugger_context(
|
||||
official_raw_model,
|
||||
debug_path=official_debug_path,
|
||||
do_prune_layers=False, # Output ALL layers
|
||||
use_repr=not SAVE_TENSORS_TO_DISK,
|
||||
):
|
||||
official_loss = official_raw_model.forward(
|
||||
images=images,
|
||||
img_masks=img_masks,
|
||||
lang_tokens=lang_tokens,
|
||||
lang_masks=lang_masks,
|
||||
state=state,
|
||||
actions=dummy_actions,
|
||||
noise=noise,
|
||||
time=time,
|
||||
)
|
||||
|
||||
# Reset seed before second forward pass to ensure any internal randomness is identical
|
||||
torch.manual_seed(RANDOM_SEED)
|
||||
|
||||
# Run custom model with debugger
|
||||
print("Running Custom Model forward pass with debugger...")
|
||||
with model_addition_debugger_context(
|
||||
custom_raw_model,
|
||||
debug_path=custom_debug_path,
|
||||
do_prune_layers=False, # Output ALL layers
|
||||
use_repr=not SAVE_TENSORS_TO_DISK,
|
||||
):
|
||||
custom_loss = custom_raw_model.forward(
|
||||
images=images,
|
||||
img_masks=img_masks,
|
||||
lang_tokens=lang_tokens,
|
||||
lang_masks=lang_masks,
|
||||
state=state,
|
||||
actions=dummy_actions,
|
||||
noise=noise,
|
||||
time=time,
|
||||
)
|
||||
|
||||
print(f"📊 Official model debug outputs saved to: {official_debug_path}")
|
||||
print(f"📊 Custom model debug outputs saved to: {custom_debug_path}")
|
||||
else:
|
||||
print("Running without detailed debugging (model_addition_debugger_context not available)...")
|
||||
|
||||
# Set deterministic mode for forward pass
|
||||
torch.manual_seed(RANDOM_SEED)
|
||||
|
||||
# Run official model
|
||||
print("Running Official Model forward pass...")
|
||||
print("Running with model_addition_debugger_context for detailed analysis...")
|
||||
# Create separate debug paths for each model
|
||||
official_debug_path = os.path.join(debug_path, "official_model")
|
||||
custom_debug_path = os.path.join(debug_path, "custom_model")
|
||||
os.makedirs(official_debug_path, exist_ok=True)
|
||||
os.makedirs(custom_debug_path, exist_ok=True)
|
||||
# Set deterministic mode for forward pass
|
||||
torch.manual_seed(RANDOM_SEED)
|
||||
# Run official model with debugger
|
||||
print("Running Official Model forward pass with debugger...")
|
||||
with model_addition_debugger_context(
|
||||
official_raw_model,
|
||||
debug_path=official_debug_path,
|
||||
do_prune_layers=False, # Output ALL layers
|
||||
use_repr=not SAVE_TENSORS_TO_DISK,
|
||||
):
|
||||
official_loss = official_raw_model.forward(
|
||||
images=images,
|
||||
img_masks=img_masks,
|
||||
@@ -666,12 +507,16 @@ with torch.no_grad():
|
||||
noise=noise,
|
||||
time=time,
|
||||
)
|
||||
|
||||
# Reset seed before second forward pass to ensure any internal randomness is identical
|
||||
torch.manual_seed(RANDOM_SEED)
|
||||
|
||||
# Run custom model with same inputs
|
||||
print("Running Custom Model forward pass...")
|
||||
# Reset seed before second forward pass to ensure any internal randomness is identical
|
||||
torch.manual_seed(RANDOM_SEED)
|
||||
# Run custom model with debugger
|
||||
print("Running Custom Model forward pass with debugger...")
|
||||
with model_addition_debugger_context(
|
||||
custom_raw_model,
|
||||
debug_path=custom_debug_path,
|
||||
do_prune_layers=False, # Output ALL layers
|
||||
use_repr=not SAVE_TENSORS_TO_DISK,
|
||||
):
|
||||
custom_loss = custom_raw_model.forward(
|
||||
images=images,
|
||||
img_masks=img_masks,
|
||||
@@ -683,6 +528,9 @@ with torch.no_grad():
|
||||
time=time,
|
||||
)
|
||||
|
||||
print(f"Official model debug outputs saved to: {official_debug_path}")
|
||||
print(f"Custom model debug outputs saved to: {custom_debug_path}")
|
||||
|
||||
print("\n=== Output Comparison ===")
|
||||
print(f"Official model loss shape: {official_loss.shape}")
|
||||
print(f"Custom model loss shape: {custom_loss.shape}")
|
||||
@@ -705,48 +553,49 @@ with torch.no_grad():
|
||||
|
||||
# Determine if models are equivalent
|
||||
are_equivalent = loss_diff.max().item() < 1e-6
|
||||
print(f"\n🎯 Models are {'EQUIVALENT' if are_equivalent else 'DIFFERENT'}")
|
||||
print(f"\nModels are {'EQUIVALENT' if are_equivalent else 'DIFFERENT'}")
|
||||
print(f" (Max difference: {loss_diff.max().item():.8f}, Threshold: 1e-6)")
|
||||
|
||||
if DEBUGGER_AVAILABLE and debug_path:
|
||||
print(f"\n📊 Detailed debugging outputs saved to: {debug_path}")
|
||||
print(f"\nDetailed debugging outputs saved to: {debug_path}")
|
||||
# Save comparison results
|
||||
comparison_results = {
|
||||
"official_loss_stats": {
|
||||
"shape": list(official_loss.shape),
|
||||
"mean": official_loss.mean().item(),
|
||||
"std": official_loss.std().item(),
|
||||
"min": official_loss.min().item(),
|
||||
"max": official_loss.max().item(),
|
||||
},
|
||||
"custom_loss_stats": {
|
||||
"shape": list(custom_loss.shape),
|
||||
"mean": custom_loss.mean().item(),
|
||||
"std": custom_loss.std().item(),
|
||||
"min": custom_loss.min().item(),
|
||||
"max": custom_loss.max().item(),
|
||||
},
|
||||
"difference_stats": {
|
||||
"mean_abs_diff": loss_diff.mean().item(),
|
||||
"max_abs_diff": loss_diff.max().item(),
|
||||
"min_abs_diff": loss_diff.min().item(),
|
||||
"std_diff": loss_diff.std().item(),
|
||||
"are_equivalent": are_equivalent,
|
||||
},
|
||||
}
|
||||
|
||||
# Save comparison results
|
||||
comparison_results = {
|
||||
"official_loss_stats": {
|
||||
"shape": list(official_loss.shape),
|
||||
"mean": official_loss.mean().item(),
|
||||
"std": official_loss.std().item(),
|
||||
"min": official_loss.min().item(),
|
||||
"max": official_loss.max().item(),
|
||||
},
|
||||
"custom_loss_stats": {
|
||||
"shape": list(custom_loss.shape),
|
||||
"mean": custom_loss.mean().item(),
|
||||
"std": custom_loss.std().item(),
|
||||
"min": custom_loss.min().item(),
|
||||
"max": custom_loss.max().item(),
|
||||
},
|
||||
"difference_stats": {
|
||||
"mean_abs_diff": loss_diff.mean().item(),
|
||||
"max_abs_diff": loss_diff.max().item(),
|
||||
"min_abs_diff": loss_diff.min().item(),
|
||||
"std_diff": loss_diff.std().item(),
|
||||
"are_equivalent": are_equivalent,
|
||||
},
|
||||
}
|
||||
|
||||
import json
|
||||
|
||||
comparison_file = os.path.join(debug_path, "model_comparison_results.json")
|
||||
with open(comparison_file, "w") as f:
|
||||
json.dump(comparison_results, f, indent=2)
|
||||
print(f" Comparison results saved to: {comparison_file}")
|
||||
comparison_file = os.path.join(debug_path, "model_comparison_results.json")
|
||||
with open(comparison_file, "w") as f:
|
||||
json.dump(comparison_results, f, indent=2)
|
||||
print(f" Comparison results saved to: {comparison_file}")
|
||||
|
||||
# Save and upload transformed model if requested
|
||||
if SAVE_TRANSFORMED_MODEL and are_equivalent:
|
||||
print("\n🚀 Saving Transformed Model...")
|
||||
print("Models are equivalent - proceeding with transformation and upload")
|
||||
if SAVE_TRANSFORMED_MODEL:
|
||||
print("\nSaving Transformed Model...")
|
||||
if are_equivalent:
|
||||
print("Models are equivalent - proceeding with transformation and upload")
|
||||
else:
|
||||
print("Models are NOT equivalent, but proceeding with upload anyway")
|
||||
print(f" Max difference: {loss_diff.max().item():.2e}")
|
||||
print(" This might be useful for debugging or partial transformations")
|
||||
|
||||
# Create timestamp for README
|
||||
transformation_timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
@@ -786,7 +635,8 @@ The original model had a different key naming convention. This model applies the
|
||||
|
||||
## Verification
|
||||
|
||||
This transformed model produces **identical outputs** (difference = {loss_diff.max().item():.2e}) to the official model `{official_model_path}` when tested with the same inputs.
|
||||
{"This transformed model produces **identical outputs**" if are_equivalent else "This transformed model has **slightly different outputs**"} (max difference = {loss_diff.max().item():.2e}) compared to the official model `{official_model_path}` when tested with the same inputs.
|
||||
{"**Models are EQUIVALENT** (difference < 1e-6)" if are_equivalent else "**Models are NOT equivalent** (difference >= 1e-6) - use with caution"}
|
||||
|
||||
## Usage
|
||||
|
||||
@@ -803,9 +653,7 @@ action = policy.select_action(observation_batch)
|
||||
## Original Model
|
||||
|
||||
- **Source**: {custom_model_path}
|
||||
- **Transformation Date**: {transformation_timestamp}
|
||||
- **Verified Against**: {official_model_path}
|
||||
- **Max Output Difference**: {loss_diff.max().item():.2e}
|
||||
|
||||
## Technical Details
|
||||
|
||||
@@ -818,11 +666,11 @@ action = policy.select_action(observation_batch)
|
||||
with open(readme_path, "w") as f:
|
||||
f.write(readme_content.strip())
|
||||
|
||||
print(f"✅ Model saved locally to: {local_save_path}")
|
||||
print(f"Model saved locally to: {local_save_path}")
|
||||
|
||||
# Upload to HuggingFace Hub if requested
|
||||
if UPLOAD_TO_HUB:
|
||||
print(f"\n📤 Uploading to HuggingFace Hub: {TRANSFORMED_MODEL_NAME}")
|
||||
print(f"\nUploading to HuggingFace Hub: {TRANSFORMED_MODEL_NAME}")
|
||||
|
||||
try:
|
||||
# Push to hub
|
||||
@@ -833,27 +681,24 @@ action = policy.select_action(observation_batch)
|
||||
safe_serialization=True,
|
||||
)
|
||||
|
||||
print(f"✅ Model successfully uploaded to: https://huggingface.co/{TRANSFORMED_MODEL_NAME}")
|
||||
print("🎉 You can now use this model directly without any transformations!")
|
||||
print(f"Model successfully uploaded to: https://huggingface.co/{TRANSFORMED_MODEL_NAME}")
|
||||
print("You can now use this model directly without any transformations!")
|
||||
print("\n Usage:")
|
||||
print(" from lerobot.policies.pi0.modeling_pi0 import PI0Policy")
|
||||
print(f" policy = PI0Policy.from_pretrained('{TRANSFORMED_MODEL_NAME}')")
|
||||
|
||||
except Exception as upload_error:
|
||||
print(f"❌ Failed to upload to HuggingFace Hub: {upload_error}")
|
||||
print(f"💡 You can manually upload the model from: {local_save_path}")
|
||||
print(f"Failed to upload to HuggingFace Hub: {upload_error}")
|
||||
print(f"You can manually upload the model from: {local_save_path}")
|
||||
print(" Or set UPLOAD_TO_HUB = False and upload later")
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
print(f"❌ Error saving transformed model: {str(e)}")
|
||||
print(f"Error saving transformed model: {str(e)}")
|
||||
print("Full traceback:")
|
||||
traceback.print_exc()
|
||||
print("💡 The model transformation logic works, but saving failed")
|
||||
print("The model transformation logic works, but saving failed")
|
||||
|
||||
elif not are_equivalent:
|
||||
print("\n⚠️ Skipping model save - models are not equivalent")
|
||||
print(f" Max difference: {loss_diff.max().item():.2e} > threshold: 1e-6")
|
||||
else:
|
||||
print("\n💡 Model transformation and upload disabled (SAVE_TRANSFORMED_MODEL = False)")
|
||||
print("\nModel transformation and upload disabled (SAVE_TRANSFORMED_MODEL = False)")
|
||||
|
||||
Reference in New Issue
Block a user