clean up load, add inject stats and extend convert script for libero

This commit is contained in:
Pepijn
2025-09-10 14:00:22 +02:00
parent e2740fe555
commit c75df5c3b9
4 changed files with 603 additions and 282 deletions
+26
View File
@@ -0,0 +1,26 @@
#!/usr/bin/env python
"""Simple script to check buffer naming in the transformed model."""
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
# Load the model with strict=False to see what buffers we have
print("Loading model...")
policy = PI0Policy.from_pretrained("pepijn223/pi0_libero_lerobot", strict=False)
# Check what buffer keys exist
state_dict = policy.state_dict()
buffer_keys = [k for k in state_dict.keys() if "buffer" in k]
normalize_keys = [k for k in state_dict.keys() if "normalize" in k]
print("\nAll buffer keys:")
for key in buffer_keys:
print(f" {key}")
print("\nAll normalize keys:")
for key in normalize_keys:
print(f" {key}")
print("\nAll keys (first 20):")
for i, key in enumerate(state_dict.keys()):
if i < 20:
print(f" {key}")
+347
View File
@@ -0,0 +1,347 @@
#!/usr/bin/env python
"""Script for Pi0 pretrained policy inference and Hub upload."""
import argparse
from datetime import datetime
import numpy as np
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
# Set seed
torch.manual_seed(42)
def parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(description="Pi0 policy inference and Hub upload")
parser.add_argument(
"--source-model-id",
type=str,
default="pepijn223/pi0_libero_lerobot",
help="Source model repository ID on Hugging Face Hub",
)
parser.add_argument(
"--dataset-id", type=str, default="pepijn223/libero", help="Dataset repository ID on Hugging Face Hub"
)
parser.add_argument(
"--output-model-id",
type=str,
required=True,
help="Output model repository ID to upload to (e.g., 'your-username/pi0-libero-fixed')",
)
parser.add_argument(
"--device", type=str, default="cpu", choices=["cpu", "cuda", "mps"], help="Device to run inference on"
)
parser.add_argument("--episode", type=int, default=0, help="Episode index to load from dataset")
parser.add_argument(
"--sample-idx", type=int, default=10, help="Sample index within episode to use for inference"
)
parser.add_argument("--private", action="store_true", help="Make the uploaded model private")
parser.add_argument(
"--commit-message", type=str, default=None, help="Custom commit message for the upload"
)
return parser.parse_args()
def _inject_normalization_stats(policy: PI0Policy, dataset_meta: LeRobotDatasetMetadata, key_mapping: dict):
"""Recreate normalization layers with proper stats from the dataset."""
from lerobot.policies.normalize import Normalize, Unnormalize
# Convert numpy stats to the format expected by normalization layers and remap keys
stats = {}
for dataset_key, stat_dict in dataset_meta.stats.items():
# Use mapped key if available, otherwise use original key
policy_key = key_mapping.get(dataset_key, dataset_key)
stats[policy_key] = {
stat_type: torch.from_numpy(stat_array) if isinstance(stat_array, np.ndarray) else stat_array
for stat_type, stat_array in stat_dict.items()
}
print(f"Available stats keys: {list(stats.keys())}")
print(
f"Policy expects keys: input={list(policy.config.input_features.keys())}, output={list(policy.config.output_features.keys())}"
)
# Recreate normalization layers with proper stats
normalize_inputs = Normalize(policy.config.input_features, policy.config.normalization_mapping, stats)
normalize_targets = Normalize(policy.config.output_features, policy.config.normalization_mapping, stats)
unnormalize_outputs = Unnormalize(
policy.config.output_features, policy.config.normalization_mapping, stats
)
# Replace the normalization layers on the policy
policy.normalize_inputs = normalize_inputs
policy.normalize_targets = normalize_targets
policy.unnormalize_outputs = unnormalize_outputs
print("Normalization layers recreated with dataset stats.")
def configure_policy_features(policy: PI0Policy, dataset: LeRobotDataset):
"""Configure policy input and output features based on dataset metadata."""
print(f"Dataset features: {list(dataset.meta.features.keys())}")
# Create a proper mapping from dataset keys to policy keys
dataset_to_policy_mapping = {}
# Handle images
if "image" in dataset.meta.features:
dataset_to_policy_mapping["image"] = "observation.images.image"
if "wrist_image" in dataset.meta.features:
dataset_to_policy_mapping["wrist_image"] = "observation.images.image2"
# Handle state
if "state" in dataset.meta.features:
dataset_to_policy_mapping["state"] = "observation.state"
# Handle actions
if "actions" in dataset.meta.features:
dataset_to_policy_mapping["actions"] = "action"
print(f"Key mapping: {dataset_to_policy_mapping}")
# Clear existing input features and reconfigure with proper mapping
policy.config.input_features = {}
policy.config.output_features = {}
# Map visual features
for dataset_key, policy_key in dataset_to_policy_mapping.items():
if dataset_key in ["image", "wrist_image"]:
feature_info = dataset.meta.features[dataset_key]
# Convert HWC to CHW format and resize
shape = (3, 224, 224) # Pi0 expects CHW format
policy.config.input_features[policy_key] = PolicyFeature(type=FeatureType.VISUAL, shape=shape)
# Map state features
for dataset_key, policy_key in dataset_to_policy_mapping.items():
if dataset_key == "state":
feature_info = dataset.meta.features[dataset_key]
shape = tuple(feature_info["shape"])
policy.config.input_features[policy_key] = PolicyFeature(type=FeatureType.STATE, shape=shape)
# Map action features
for dataset_key, policy_key in dataset_to_policy_mapping.items():
if dataset_key == "actions":
feature_info = dataset.meta.features[dataset_key]
shape = tuple(feature_info["shape"])
policy.config.output_features[policy_key] = PolicyFeature(type=FeatureType.ACTION, shape=shape)
print(f"Policy input_features: {list(policy.config.input_features.keys())}")
print(f"Policy output_features: {list(policy.config.output_features.keys())}")
print(f"Policy image_features: {list(policy.config.image_features.keys())}")
print(f"Policy action_feature: {policy.config.action_feature}")
return dataset_to_policy_mapping
def fix_buffer_naming(policy: PI0Policy):
"""Fix buffer naming issues in the loaded policy state dict."""
print("Fixing normalization buffer naming issues...")
state_dict = policy.state_dict()
corrected_state_dict = {}
fixes_applied = 0
for key, value in state_dict.items():
new_key = key
# Fix buffer naming: buffer_observation_state_mean -> buffer_observation_state.mean
if "buffer_observation_state_mean" in key:
new_key = key.replace("buffer_observation_state_mean", "buffer_observation_state.mean")
fixes_applied += 1
print(f" Fixed: {key} -> {new_key}")
elif "buffer_observation_state_std" in key:
new_key = key.replace("buffer_observation_state_std", "buffer_observation_state.std")
fixes_applied += 1
print(f" Fixed: {key} -> {new_key}")
# Remove image buffers that aren't expected (they cause conflicts)
elif "buffer_observation_image_mean" in key or "buffer_observation_image_std" in key:
print(f" Removed unexpected buffer: {key}")
continue # Skip this buffer
corrected_state_dict[new_key] = value
# Add missing action buffers with dummy values (will be replaced by dataset stats)
missing_buffers = [
"normalize_targets.buffer_action.mean",
"normalize_targets.buffer_action.std",
"unnormalize_outputs.buffer_action.mean",
"unnormalize_outputs.buffer_action.std",
]
for buffer_key in missing_buffers:
if buffer_key not in corrected_state_dict:
# Use dummy values - these will be overwritten by proper dataset stats later
if "mean" in buffer_key:
corrected_state_dict[buffer_key] = torch.zeros(8) # Assume 8-dim action
else: # std
corrected_state_dict[buffer_key] = torch.ones(8) # Assume 8-dim action
fixes_applied += 1
print(f" Added missing buffer: {buffer_key}")
print(f"Applied {fixes_applied} buffer fixes")
# Load the corrected state dict back into the policy
policy.load_state_dict(corrected_state_dict)
return policy
def main():
"""Main function to run the Pi0 inference and upload."""
args = parse_args()
# Load pretrained Pi0 model directly from Hugging Face Hub
print(f"Loading pretrained Pi0 model from {args.source_model_id}...")
# Load with strict=False to allow missing/unexpected keys, then fix them manually
policy = PI0Policy.from_pretrained(args.source_model_id, strict=False)
policy = fix_buffer_naming(policy)
policy.eval()
policy.to(args.device)
# Load dataset and get a sample
print(f"Loading dataset: {args.dataset_id}")
dataset = LeRobotDataset(args.dataset_id, episodes=[args.episode])
meta: LeRobotDatasetMetadata = dataset.meta
sample = dataset[args.sample_idx]
# Configure policy features
key_mapping = configure_policy_features(policy, dataset)
# Inject normalization stats with proper key mapping
_inject_normalization_stats(policy, meta, key_mapping)
# Prepare batch for PI0 (handle temporal dimensions)
batch = {}
# Map dataset sample keys to policy keys
reverse_mapping = {v: k for k, v in key_mapping.items()}
for policy_key in policy.config.input_features:
# Find the corresponding dataset key
dataset_key = reverse_mapping.get(policy_key, policy_key)
if dataset_key in sample:
data = sample[dataset_key]
# Handle image data: convert from HWC to CHW and normalize
if policy_key.startswith("observation.images."):
if data.dim() == 3 and data.shape[-1] == 3: # HWC format
data = data.permute(2, 0, 1) # Convert to CHW
# Normalize to [0, 1] range if needed
if data.dtype == torch.uint8:
data = data.float() / 255.0
# Resize to expected size if needed
if data.shape[-2:] != (224, 224):
import torch.nn.functional as F # noqa: N812
data = F.interpolate(
data.unsqueeze(0), size=(224, 224), mode="bilinear", align_corners=False
)[0]
# Remove temporal dimension if present
if data.dim() > len(policy.config.input_features[policy_key].shape):
data = data[0]
batch[policy_key] = data.unsqueeze(0) # Add batch dimension
# Debug: print what's in the sample
print(f"Sample keys: {list(sample.keys())}")
print(f"Batch keys prepared: {list(batch.keys())}")
# Pi0 requires task description - add a default if not available
if "task" in sample:
batch["task"] = [sample["task"]] # Keep as list of strings
else:
print("No task in sample, using default task description")
batch["task"] = ["Complete the manipulation task"]
print(f"Task: {batch['task'][0]}")
print(f"Final batch keys: {list(batch.keys())}")
# Run inference
with torch.no_grad():
action = policy.select_action(batch)
print(f"Predicted action shape: {action.shape}")
print(f"Predicted action: {action.tolist()}")
print("✅ Pi0 pretrained inference completed successfully!")
# Upload to Hugging Face Hub
print(f"\n📤 Uploading model to Hugging Face Hub: {args.output_model_id}")
# Create commit message
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
commit_message = (
args.commit_message
or f"Pi0 model with injected normalization stats from {args.dataset_id} - {timestamp}"
)
# Update model configuration with dataset info
policy.config.push_to_hub = True
policy.config.repo_id = args.output_model_id
policy.config.private = args.private
# Add metadata about the adaptation
adaptation_info = {
"source_model": args.source_model_id,
"dataset_used": args.dataset_id,
"adaptation_date": timestamp,
"stats_injected": True,
"key_mapping": key_mapping,
"inference_test_passed": True,
"sample_action_shape": list(action.shape),
}
try:
# Push to hub
policy.push_to_hub(
repo_id=args.output_model_id,
private=args.private,
commit_message=commit_message,
create_pr=False,
)
# Also save the adaptation info as a separate file
import json
import os
import tempfile
from huggingface_hub import HfApi
api = HfApi()
# Create a temporary file with adaptation info
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(adaptation_info, f, indent=2)
temp_path = f.name
try:
api.upload_file(
path_or_fileobj=temp_path,
path_in_repo="adaptation_info.json",
repo_id=args.output_model_id,
commit_message=f"Add adaptation metadata - {timestamp}",
)
finally:
os.unlink(temp_path)
print(f"✅ Model successfully uploaded to: https://huggingface.co/{args.output_model_id}")
print("📋 Adaptation info:")
for key, value in adaptation_info.items():
print(f" {key}: {value}")
except Exception as e:
print(f"❌ Error uploading to Hub: {e}")
raise
if __name__ == "__main__":
main()
+98 -253
View File
@@ -1,9 +1,13 @@
import json
import os
import random
from datetime import datetime
import numpy as np
import torch
from huggingface_hub import hf_hub_download # noqa: E402
from safetensors.torch import load_file # noqa: E402
from transformers.model_debugging_utils import model_addition_debugger_context
from lerobot.configs.policies import FeatureType, PolicyFeature
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
@@ -26,36 +30,15 @@ def set_all_seeds(seed=42):
torch.use_deterministic_algorithms(True)
print(f"All random seeds set to {seed} for reproducible results (deterministic mode enabled)")
# For MPS devices, set additional deterministic settings
if torch.backends.mps.is_available():
print("MPS device detected - using deterministic settings")
# Set seeds at the start
set_all_seeds(RANDOM_SEED)
# Import model debugger for detailed forward pass analysis
try:
import transformers
from transformers.model_debugging_utils import model_addition_debugger_context
DEBUGGER_AVAILABLE = True
print("✅ Model debugger available")
print(f" Transformers version: {transformers.__version__}")
except ImportError as e:
print("⚠️ Model debugger not available (requires transformers with debugging utils)")
print(f" Import error: {e}")
DEBUGGER_AVAILABLE = False
config_model_path = "lerobot/pi0" # Use config from official model
official_model_path = "lerobot/pi0" # Official model
custom_model_path = "pepijn223/pi0_libero_fp32" # Custom model to compare # pepijn223/pi0_base_fp32
custom_model_path = "pepijn223/pi0_base_fp32" # Custom model to compare # pepijn223/pi0_base_fp32
device = "mps"
# For testing determinism, set both to same model:
# official_model_path = "pepijn223/pi0_base_fp32"
# custom_model_path = "pepijn223/pi0_base_fp32"
USE_FULL_TENSORS = True
SAVE_TENSORS_TO_DISK = False
@@ -65,24 +48,10 @@ UPLOAD_TO_HUB = True # Set to True to upload to HuggingFace Hub
TRANSFORMED_MODEL_NAME = "pepijn223/pi0_base_fp32_lerobot_format" # Target repo name
COMMIT_MESSAGE = "Add transformed PI0 model with correct key format for lerobot"
# Create debug directory
if DEBUGGER_AVAILABLE:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
debug_path = os.path.join("debug_outputs", f"pi0_debug_direct_{timestamp}")
os.makedirs(debug_path, exist_ok=True)
print(f"🔍 Model debugging enabled - outputs will be saved to: {debug_path}")
else:
debug_path = None
# Create dummy normalization stats (similar to openpi example)
print("Creating dummy normalization stats...")
# Load shared config and create both models for comparison
print("Loading shared config from lerobot/pi0...")
import json # noqa: E402
from huggingface_hub import hf_hub_download # noqa: E402
from safetensors.torch import load_file # noqa: E402
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
debug_path = os.path.join("debug_outputs", f"pi0_debug_direct_{timestamp}")
os.makedirs(debug_path, exist_ok=True)
print(f"Model debugging enabled - outputs will be saved to: {debug_path}")
# Download and load the config manually to avoid draccus parsing issues
config_file = hf_hub_download(repo_id=config_model_path, filename="config.json")
@@ -117,7 +86,7 @@ def load_policy_with_weights(
print(f"Downloaded {model_name} weights to: {model_file}")
# Load state dict and apply transformations
print(f"🔍 Investigating safetensors file: {model_file}")
print(f"Investigating safetensors file: {model_file}")
# First, check what's in the metadata
try:
@@ -191,12 +160,12 @@ def load_policy_with_weights(
# Check what we're missing and what we actually have
expected_embed_key = "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
if expected_embed_key not in all_keys:
print(f"⚠️ Missing expected embed_tokens key: {expected_embed_key}")
print(f" Missing expected embed_tokens key: {expected_embed_key}")
# Let's see what keys we actually have for debugging
print("🔍 Debugging: Looking for any embedding-related keys...")
print("Debugging: Looking for any embedding-related keys...")
all_embed_related = [k for k in all_keys if "embed" in k.lower()]
print(f" Keys containing 'embed': {all_embed_related}")
print(f"Keys containing 'embed': {all_embed_related}")
# Look for any keys that might contain embeddings
potential_embed_keys = [
@@ -231,61 +200,6 @@ def load_policy_with_weights(
final_state_dict = {}
transformation_count = 0
# First, handle the missing embed_tokens
if expected_embed_key not in all_keys:
print("🔄 Attempting to fix missing embed_tokens...")
# Option 1: Copy from gemma_expert if available
if gemma_expert_embed:
source_key = gemma_expert_embed[0]
if source_key in transformed_state_dict:
final_state_dict[expected_embed_key] = transformed_state_dict[source_key]
print(f"✅ Created missing embed_tokens by copying from: {source_key}")
transformation_count += 1
else:
# Option 2: Try to manually load embed_tokens from safetensors metadata
try:
from safetensors import safe_open
with safe_open(model_file, framework="pt", device="cpu") as f:
metadata = f.metadata()
if metadata:
# Look for embed_tokens in metadata
embed_keys_in_metadata = [k for k in metadata.keys() if "embed_tokens" in k]
print(f" embed_tokens in metadata: {embed_keys_in_metadata}")
if embed_keys_in_metadata:
# Try to extract the tensor using the metadata key
metadata_key = embed_keys_in_metadata[0]
try:
# The metadata key might be different from the tensor key
# Try different variations
possible_keys = [
metadata_key,
metadata_key.replace("model.", ""), # Remove model. prefix
"paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight",
]
for test_key in possible_keys:
if test_key in f.keys():
embed_tensor = f.get_tensor(test_key)
final_state_dict[expected_embed_key] = embed_tensor
print(
f"✅ Manually extracted embed_tokens: {test_key} -> {expected_embed_key}"
)
transformation_count += 1
break
else:
print("❌ Could not find embed_tokens tensor despite metadata entry")
print(f" Tried keys: {possible_keys}")
print(f" Available keys sample: {list(f.keys())[:10]}")
except Exception as e2:
print(f"❌ Error extracting embed_tokens tensor: {e2}")
else:
print("❌ No metadata found in safetensors file")
except Exception as e:
print(f"❌ Failed to manually extract embed_tokens: {e}")
for key, value in transformed_state_dict.items():
new_key = key
original_key = key
@@ -296,7 +210,7 @@ def load_policy_with_weights(
"paligemma_with_expert.paligemma.vision_tower.vision_model",
"paligemma_with_expert.paligemma.model.vision_tower.vision_model",
)
print(f"Transformed vision key: {original_key} -> {new_key}")
print(f"Transformed vision key: {original_key} -> {new_key}")
transformation_count += 1
# Transform multi_modal_projector keys: ADD .model between paligemma and multi_modal_projector
@@ -305,7 +219,7 @@ def load_policy_with_weights(
"paligemma_with_expert.paligemma.multi_modal_projector",
"paligemma_with_expert.paligemma.model.multi_modal_projector",
)
print(f"Transformed multi_modal_projector key: {original_key} -> {new_key}")
print(f"Transformed multi_modal_projector key: {original_key} -> {new_key}")
transformation_count += 1
# NO transformation needed for language_model keys - they're already correct!
@@ -327,7 +241,7 @@ def load_policy_with_weights(
missing_in_provided = policy_keys - provided_keys
extra_in_provided = provided_keys - policy_keys
print(f"🔍 Policy expects {len(policy_keys)} keys, we provide {len(provided_keys)} keys")
print(f"Policy expects {len(policy_keys)} keys, we provide {len(provided_keys)} keys")
if missing_in_provided:
print(
f" Missing from provided: {list(missing_in_provided)[:5]}{'...' if len(missing_in_provided) > 5 else ''}"
@@ -371,7 +285,7 @@ custom_policy = load_policy_with_weights(
custom_model_path, shared_config, "Custom Model", apply_transformations=True
)
print("\nBoth models loaded successfully!")
print("\nBoth models loaded successfully!")
print(f"Shared config: {shared_config}")
print(f"Device: {device}")
@@ -436,40 +350,14 @@ for name in custom_params.keys():
param_differences.append(f"Extra parameter in custom model: {name}")
if param_differences:
print("⚠️ Parameter differences found:")
print("Parameter differences found:")
for diff in param_differences[:10]: # Show first 10 differences
print(f" {diff}")
if len(param_differences) > 10:
print(f" ... and {len(param_differences) - 10} more differences")
else:
print("All model parameters are identical!")
print("All model parameters are identical!")
# Also check buffers (e.g., normalization statistics)
print("\n=== Buffer Comparison ===")
official_buffers = dict(official_policy.named_buffers())
custom_buffers = dict(custom_policy.named_buffers())
buffer_differences = []
for name in official_buffers.keys():
if name not in custom_buffers:
buffer_differences.append(f"Missing buffer in custom model: {name}")
else:
diff = torch.abs(official_buffers[name] - custom_buffers[name]).max().item()
if diff > 1e-8:
buffer_differences.append(f"Buffer {name}: max difference = {diff:.2e}")
for name in custom_buffers.keys():
if name not in official_buffers:
buffer_differences.append(f"Extra buffer in custom model: {name}")
if buffer_differences:
print("⚠️ Buffer differences found:")
for diff in buffer_differences[:10]:
print(f" {diff}")
if len(buffer_differences) > 10:
print(f" ... and {len(buffer_differences) - 10} more differences")
else:
print("✅ All model buffers are identical!")
# Get the raw models for direct comparison
official_raw_model = official_policy.model
@@ -490,7 +378,7 @@ example = {
print(f"\nProvided input keys: {list(example.keys())}")
print("\n🔄 Preparing inputs for direct model call...")
print("\nPreparing inputs for direct model call...")
# Apply input transformation (similar to openpi's policy._input_transform)
transformed_example = {}
@@ -593,69 +481,22 @@ with torch.no_grad():
print("\n=== Running Forward Passes ===")
if DEBUGGER_AVAILABLE and debug_path:
print("🔍 Running with model_addition_debugger_context for detailed analysis...")
# Create separate debug paths for each model
official_debug_path = os.path.join(debug_path, "official_model")
custom_debug_path = os.path.join(debug_path, "custom_model")
os.makedirs(official_debug_path, exist_ok=True)
os.makedirs(custom_debug_path, exist_ok=True)
# Set deterministic mode for forward pass
torch.manual_seed(RANDOM_SEED)
# Run official model with debugger
print("Running Official Model forward pass with debugger...")
with model_addition_debugger_context(
official_raw_model,
debug_path=official_debug_path,
do_prune_layers=False, # Output ALL layers
use_repr=not SAVE_TENSORS_TO_DISK,
):
official_loss = official_raw_model.forward(
images=images,
img_masks=img_masks,
lang_tokens=lang_tokens,
lang_masks=lang_masks,
state=state,
actions=dummy_actions,
noise=noise,
time=time,
)
# Reset seed before second forward pass to ensure any internal randomness is identical
torch.manual_seed(RANDOM_SEED)
# Run custom model with debugger
print("Running Custom Model forward pass with debugger...")
with model_addition_debugger_context(
custom_raw_model,
debug_path=custom_debug_path,
do_prune_layers=False, # Output ALL layers
use_repr=not SAVE_TENSORS_TO_DISK,
):
custom_loss = custom_raw_model.forward(
images=images,
img_masks=img_masks,
lang_tokens=lang_tokens,
lang_masks=lang_masks,
state=state,
actions=dummy_actions,
noise=noise,
time=time,
)
print(f"📊 Official model debug outputs saved to: {official_debug_path}")
print(f"📊 Custom model debug outputs saved to: {custom_debug_path}")
else:
print("Running without detailed debugging (model_addition_debugger_context not available)...")
# Set deterministic mode for forward pass
torch.manual_seed(RANDOM_SEED)
# Run official model
print("Running Official Model forward pass...")
print("Running with model_addition_debugger_context for detailed analysis...")
# Create separate debug paths for each model
official_debug_path = os.path.join(debug_path, "official_model")
custom_debug_path = os.path.join(debug_path, "custom_model")
os.makedirs(official_debug_path, exist_ok=True)
os.makedirs(custom_debug_path, exist_ok=True)
# Set deterministic mode for forward pass
torch.manual_seed(RANDOM_SEED)
# Run official model with debugger
print("Running Official Model forward pass with debugger...")
with model_addition_debugger_context(
official_raw_model,
debug_path=official_debug_path,
do_prune_layers=False, # Output ALL layers
use_repr=not SAVE_TENSORS_TO_DISK,
):
official_loss = official_raw_model.forward(
images=images,
img_masks=img_masks,
@@ -666,12 +507,16 @@ with torch.no_grad():
noise=noise,
time=time,
)
# Reset seed before second forward pass to ensure any internal randomness is identical
torch.manual_seed(RANDOM_SEED)
# Run custom model with same inputs
print("Running Custom Model forward pass...")
# Reset seed before second forward pass to ensure any internal randomness is identical
torch.manual_seed(RANDOM_SEED)
# Run custom model with debugger
print("Running Custom Model forward pass with debugger...")
with model_addition_debugger_context(
custom_raw_model,
debug_path=custom_debug_path,
do_prune_layers=False, # Output ALL layers
use_repr=not SAVE_TENSORS_TO_DISK,
):
custom_loss = custom_raw_model.forward(
images=images,
img_masks=img_masks,
@@ -683,6 +528,9 @@ with torch.no_grad():
time=time,
)
print(f"Official model debug outputs saved to: {official_debug_path}")
print(f"Custom model debug outputs saved to: {custom_debug_path}")
print("\n=== Output Comparison ===")
print(f"Official model loss shape: {official_loss.shape}")
print(f"Custom model loss shape: {custom_loss.shape}")
@@ -705,48 +553,49 @@ with torch.no_grad():
# Determine if models are equivalent
are_equivalent = loss_diff.max().item() < 1e-6
print(f"\n🎯 Models are {'EQUIVALENT' if are_equivalent else 'DIFFERENT'}")
print(f"\nModels are {'EQUIVALENT' if are_equivalent else 'DIFFERENT'}")
print(f" (Max difference: {loss_diff.max().item():.8f}, Threshold: 1e-6)")
if DEBUGGER_AVAILABLE and debug_path:
print(f"\n📊 Detailed debugging outputs saved to: {debug_path}")
print(f"\nDetailed debugging outputs saved to: {debug_path}")
# Save comparison results
comparison_results = {
"official_loss_stats": {
"shape": list(official_loss.shape),
"mean": official_loss.mean().item(),
"std": official_loss.std().item(),
"min": official_loss.min().item(),
"max": official_loss.max().item(),
},
"custom_loss_stats": {
"shape": list(custom_loss.shape),
"mean": custom_loss.mean().item(),
"std": custom_loss.std().item(),
"min": custom_loss.min().item(),
"max": custom_loss.max().item(),
},
"difference_stats": {
"mean_abs_diff": loss_diff.mean().item(),
"max_abs_diff": loss_diff.max().item(),
"min_abs_diff": loss_diff.min().item(),
"std_diff": loss_diff.std().item(),
"are_equivalent": are_equivalent,
},
}
# Save comparison results
comparison_results = {
"official_loss_stats": {
"shape": list(official_loss.shape),
"mean": official_loss.mean().item(),
"std": official_loss.std().item(),
"min": official_loss.min().item(),
"max": official_loss.max().item(),
},
"custom_loss_stats": {
"shape": list(custom_loss.shape),
"mean": custom_loss.mean().item(),
"std": custom_loss.std().item(),
"min": custom_loss.min().item(),
"max": custom_loss.max().item(),
},
"difference_stats": {
"mean_abs_diff": loss_diff.mean().item(),
"max_abs_diff": loss_diff.max().item(),
"min_abs_diff": loss_diff.min().item(),
"std_diff": loss_diff.std().item(),
"are_equivalent": are_equivalent,
},
}
import json
comparison_file = os.path.join(debug_path, "model_comparison_results.json")
with open(comparison_file, "w") as f:
json.dump(comparison_results, f, indent=2)
print(f" Comparison results saved to: {comparison_file}")
comparison_file = os.path.join(debug_path, "model_comparison_results.json")
with open(comparison_file, "w") as f:
json.dump(comparison_results, f, indent=2)
print(f" Comparison results saved to: {comparison_file}")
# Save and upload transformed model if requested
if SAVE_TRANSFORMED_MODEL and are_equivalent:
print("\n🚀 Saving Transformed Model...")
print("Models are equivalent - proceeding with transformation and upload")
if SAVE_TRANSFORMED_MODEL:
print("\nSaving Transformed Model...")
if are_equivalent:
print("Models are equivalent - proceeding with transformation and upload")
else:
print("Models are NOT equivalent, but proceeding with upload anyway")
print(f" Max difference: {loss_diff.max().item():.2e}")
print(" This might be useful for debugging or partial transformations")
# Create timestamp for README
transformation_timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@@ -786,7 +635,8 @@ The original model had a different key naming convention. This model applies the
## Verification
This transformed model produces **identical outputs** (difference = {loss_diff.max().item():.2e}) to the official model `{official_model_path}` when tested with the same inputs.
{"This transformed model produces **identical outputs**" if are_equivalent else "This transformed model has **slightly different outputs**"} (max difference = {loss_diff.max().item():.2e}) compared to the official model `{official_model_path}` when tested with the same inputs.
{"**Models are EQUIVALENT** (difference < 1e-6)" if are_equivalent else "**Models are NOT equivalent** (difference >= 1e-6) - use with caution"}
## Usage
@@ -803,9 +653,7 @@ action = policy.select_action(observation_batch)
## Original Model
- **Source**: {custom_model_path}
- **Transformation Date**: {transformation_timestamp}
- **Verified Against**: {official_model_path}
- **Max Output Difference**: {loss_diff.max().item():.2e}
## Technical Details
@@ -818,11 +666,11 @@ action = policy.select_action(observation_batch)
with open(readme_path, "w") as f:
f.write(readme_content.strip())
print(f"Model saved locally to: {local_save_path}")
print(f"Model saved locally to: {local_save_path}")
# Upload to HuggingFace Hub if requested
if UPLOAD_TO_HUB:
print(f"\n📤 Uploading to HuggingFace Hub: {TRANSFORMED_MODEL_NAME}")
print(f"\nUploading to HuggingFace Hub: {TRANSFORMED_MODEL_NAME}")
try:
# Push to hub
@@ -833,27 +681,24 @@ action = policy.select_action(observation_batch)
safe_serialization=True,
)
print(f"Model successfully uploaded to: https://huggingface.co/{TRANSFORMED_MODEL_NAME}")
print("🎉 You can now use this model directly without any transformations!")
print(f"Model successfully uploaded to: https://huggingface.co/{TRANSFORMED_MODEL_NAME}")
print("You can now use this model directly without any transformations!")
print("\n Usage:")
print(" from lerobot.policies.pi0.modeling_pi0 import PI0Policy")
print(f" policy = PI0Policy.from_pretrained('{TRANSFORMED_MODEL_NAME}')")
except Exception as upload_error:
print(f"Failed to upload to HuggingFace Hub: {upload_error}")
print(f"💡 You can manually upload the model from: {local_save_path}")
print(f"Failed to upload to HuggingFace Hub: {upload_error}")
print(f"You can manually upload the model from: {local_save_path}")
print(" Or set UPLOAD_TO_HUB = False and upload later")
except Exception as e:
import traceback
print(f"Error saving transformed model: {str(e)}")
print(f"Error saving transformed model: {str(e)}")
print("Full traceback:")
traceback.print_exc()
print("💡 The model transformation logic works, but saving failed")
print("The model transformation logic works, but saving failed")
elif not are_equivalent:
print("\n⚠️ Skipping model save - models are not equivalent")
print(f" Max difference: {loss_diff.max().item():.2e} > threshold: 1e-6")
else:
print("\n💡 Model transformation and upload disabled (SAVE_TRANSFORMED_MODEL = False)")
print("\nModel transformation and upload disabled (SAVE_TRANSFORMED_MODEL = False)")
@@ -13,20 +13,22 @@
# limitations under the License.
"""
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
2.1. It will:
This script will help you download any LeRobot dataset from the hub, convert it to the latest format, and
upload it to your own repository. It will:
- Download the dataset from any source repository
- Generate per-episodes stats and writes them in `episodes_stats.jsonl`
- Check consistency between these new stats and the old ones.
- Remove the deprecated `stats.json`.
- Update codebase_version in `info.json`.
- Push this new version to the hub on the 'main' branch and tags it with "v2.1".
- Update codebase_version in `info.json` to the latest version
- Create proper version tags
- Push the converted dataset to your specified destination repository
Usage:
```bash
python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 \
--repo-id=aliberts/koch_tutorial
--source-repo-id=IPEC-COMMUNITY/libero_spatial_no_noops_1.0.0_lerobot \
--dest-repo-id=your-username/libero_spatial_converted \
--episodes=0,1,2,3,4
```
"""
@@ -37,8 +39,8 @@ import logging
from huggingface_hub import HfApi
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
from lerobot.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
from lerobot.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, write_info
from lerobot.datasets.v21.convert_stats import convert_stats
V20 = "v2.0"
V21 = "v2.1"
@@ -54,48 +56,133 @@ class SuppressWarnings:
def convert_dataset(
repo_id: str,
source_repo_id: str,
dest_repo_id: str | None = None,
episodes: str | None = None,
branch: str | None = None,
num_workers: int = 4,
force_cache_sync: bool = True,
):
with SuppressWarnings():
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
"""
Download a dataset from source_repo_id, convert it, and upload to dest_repo_id.
Args:
source_repo_id: Source repository to download from
dest_repo_id: Destination repository to upload to (defaults to source_repo_id)
episodes: Comma-separated list of episode indices to include (e.g. "0,1,2,3")
branch: Branch to upload to
num_workers: Number of workers for stats computation
force_cache_sync: Whether to force cache synchronization
"""
if dest_repo_id is None:
dest_repo_id = source_repo_id
# Parse episodes list if provided
episode_list = None
if episodes:
try:
episode_list = [int(ep.strip()) for ep in episodes.split(",")]
print(f"Loading episodes: {episode_list}")
except ValueError as e:
raise ValueError(
f"Invalid episodes format '{episodes}'. Use comma-separated integers like '0,1,2,3'"
) from e
print(f"Downloading dataset from: {source_repo_id}")
# Try to load the dataset with different approaches to handle versioning issues
dataset = None
load_attempts = [
{"revision": None}, # Try latest first
{"revision": V20}, # Try v2.0
{"revision": "main"}, # Try main branch
]
for attempt in load_attempts:
try:
print(f"Attempting to load with revision: {attempt['revision']}")
with SuppressWarnings():
dataset = LeRobotDataset(
source_repo_id, episodes=episode_list, force_cache_sync=force_cache_sync, **attempt
)
print("Successfully loaded dataset!")
break
except Exception as e:
print(f"Failed with revision {attempt['revision']}: {e}")
continue
if dataset is None:
raise RuntimeError(f"Could not load dataset {source_repo_id} with any revision")
# Clean up old stats if present
if (dataset.root / EPISODES_STATS_PATH).is_file():
(dataset.root / EPISODES_STATS_PATH).unlink()
print("Removed existing episodes_stats.jsonl")
print("Converting stats to new format...")
convert_stats(dataset, num_workers=num_workers)
ref_stats = load_stats(dataset.root)
check_aggregate_stats(dataset, ref_stats)
# Update dataset info
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
write_info(dataset.meta.info, dataset.root)
print(f"Updated codebase_version to {CODEBASE_VERSION}")
dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
# Change repo_id for destination if different
if dest_repo_id != source_repo_id:
print(f"Changing repository from {source_repo_id} to {dest_repo_id}")
dataset.repo_id = dest_repo_id
# delete old stats.json file
if (dataset.root / STATS_PATH).is_file:
print(f"Pushing converted dataset to: {dest_repo_id}")
dataset.push_to_hub(branch=branch, tag_version=False)
# Clean up old stats.json file locally and on hub
if (dataset.root / STATS_PATH).is_file():
(dataset.root / STATS_PATH).unlink()
print("Removed local stats.json file")
hub_api = HfApi()
if hub_api.file_exists(
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
):
hub_api.delete_file(
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
)
try:
if hub_api.file_exists(
repo_id=dest_repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
):
hub_api.delete_file(
path_in_repo=STATS_PATH, repo_id=dest_repo_id, revision=branch, repo_type="dataset"
)
print("Removed stats.json from hub")
except Exception as e:
print(f"Warning: Could not remove stats.json from hub: {e}")
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
# Create version tag
try:
hub_api.create_tag(dest_repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
print(f"Created tag {CODEBASE_VERSION} for {dest_repo_id}")
except Exception as e:
print(f"Warning: Could not create tag: {e}")
print(f"✅ Successfully converted and uploaded dataset to {dest_repo_id}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
description="Download, convert, and re-upload LeRobot datasets with proper versioning"
)
parser.add_argument(
"--repo-id",
"--source-repo-id",
type=str,
required=True,
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
help="Source repository identifier to download from (e.g. 'IPEC-COMMUNITY/libero_spatial_no_noops_1.0.0_lerobot')",
)
parser.add_argument(
"--dest-repo-id",
type=str,
default=None,
help="Destination repository identifier to upload to. Defaults to source-repo-id if not specified.",
)
parser.add_argument(
"--episodes",
type=str,
default=None,
help="Comma-separated list of episode indices to include (e.g. '0,1,2,3,4'). If not specified, all episodes are included.",
)
parser.add_argument(
"--branch",
@@ -109,6 +196,22 @@ if __name__ == "__main__":
default=4,
help="Number of workers for parallelizing stats compute. Defaults to 4.",
)
parser.add_argument(
"--no-cache-sync",
action="store_true",
help="Skip forcing cache synchronization (faster but may use cached data)",
)
args = parser.parse_args()
convert_dataset(**vars(args))
# Convert args to match function signature
convert_args = {
"source_repo_id": args.source_repo_id,
"dest_repo_id": args.dest_repo_id,
"episodes": args.episodes,
"branch": args.branch,
"num_workers": args.num_workers,
"force_cache_sync": not args.no_cache_sync,
}
convert_dataset(**convert_args)