#!/usr/bin/env python """Script for Pi0 pretrained policy inference and Hub upload.""" import argparse from datetime import datetime import numpy as np import torch from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata from lerobot.policies.pi0.modeling_pi0 import PI0Policy # Set seed torch.manual_seed(42) def parse_args(): """Parse command line arguments.""" parser = argparse.ArgumentParser(description="Pi0 policy inference and Hub upload") parser.add_argument( "--source-model-id", type=str, default="pepijn223/pi0_libero_lerobot", help="Source model repository ID on Hugging Face Hub", ) parser.add_argument( "--dataset-id", type=str, default="pepijn223/libero", help="Dataset repository ID on Hugging Face Hub" ) parser.add_argument( "--output-model-id", type=str, required=True, help="Output model repository ID to upload to (e.g., 'your-username/pi0-libero-fixed')", ) parser.add_argument( "--device", type=str, default="cpu", choices=["cpu", "cuda", "mps"], help="Device to run inference on" ) parser.add_argument("--episode", type=int, default=0, help="Episode index to load from dataset") parser.add_argument( "--sample-idx", type=int, default=10, help="Sample index within episode to use for inference" ) parser.add_argument("--private", action="store_true", help="Make the uploaded model private") parser.add_argument( "--commit-message", type=str, default=None, help="Custom commit message for the upload" ) return parser.parse_args() def _inject_normalization_stats(policy: PI0Policy, dataset_meta: LeRobotDatasetMetadata, key_mapping: dict): """Recreate normalization layers with proper stats from the dataset.""" from lerobot.policies.normalize import Normalize, Unnormalize # Convert numpy stats to the format expected by normalization layers and remap keys stats = {} for dataset_key, stat_dict in dataset_meta.stats.items(): # Use mapped key if available, otherwise use original key policy_key = key_mapping.get(dataset_key, dataset_key) stats[policy_key] = { stat_type: torch.from_numpy(stat_array) if isinstance(stat_array, np.ndarray) else stat_array for stat_type, stat_array in stat_dict.items() } print(f"Available stats keys: {list(stats.keys())}") print( f"Policy expects keys: input={list(policy.config.input_features.keys())}, output={list(policy.config.output_features.keys())}" ) # Recreate normalization layers with proper stats normalize_inputs = Normalize(policy.config.input_features, policy.config.normalization_mapping, stats) normalize_targets = Normalize(policy.config.output_features, policy.config.normalization_mapping, stats) unnormalize_outputs = Unnormalize( policy.config.output_features, policy.config.normalization_mapping, stats ) # Replace the normalization layers on the policy policy.normalize_inputs = normalize_inputs policy.normalize_targets = normalize_targets policy.unnormalize_outputs = unnormalize_outputs print("Normalization layers recreated with dataset stats.") def configure_policy_features(policy: PI0Policy, dataset: LeRobotDataset): """Configure policy input and output features based on dataset metadata.""" print(f"Dataset features: {list(dataset.meta.features.keys())}") # Create a proper mapping from dataset keys to policy keys dataset_to_policy_mapping = {} # Handle images if "image" in dataset.meta.features: dataset_to_policy_mapping["image"] = "observation.images.image" if "wrist_image" in dataset.meta.features: dataset_to_policy_mapping["wrist_image"] = "observation.images.image2" # Handle state if "state" in dataset.meta.features: dataset_to_policy_mapping["state"] = "observation.state" # Handle actions if "actions" in dataset.meta.features: dataset_to_policy_mapping["actions"] = "action" print(f"Key mapping: {dataset_to_policy_mapping}") # Clear existing input features and reconfigure with proper mapping policy.config.input_features = {} policy.config.output_features = {} # Map visual features for dataset_key, policy_key in dataset_to_policy_mapping.items(): if dataset_key in ["image", "wrist_image"]: feature_info = dataset.meta.features[dataset_key] # Convert HWC to CHW format and resize shape = (3, 224, 224) # Pi0 expects CHW format policy.config.input_features[policy_key] = PolicyFeature(type=FeatureType.VISUAL, shape=shape) # Map state features for dataset_key, policy_key in dataset_to_policy_mapping.items(): if dataset_key == "state": feature_info = dataset.meta.features[dataset_key] shape = tuple(feature_info["shape"]) policy.config.input_features[policy_key] = PolicyFeature(type=FeatureType.STATE, shape=shape) # Map action features for dataset_key, policy_key in dataset_to_policy_mapping.items(): if dataset_key == "actions": feature_info = dataset.meta.features[dataset_key] shape = tuple(feature_info["shape"]) policy.config.output_features[policy_key] = PolicyFeature(type=FeatureType.ACTION, shape=shape) print(f"Policy input_features: {list(policy.config.input_features.keys())}") print(f"Policy output_features: {list(policy.config.output_features.keys())}") print(f"Policy image_features: {list(policy.config.image_features.keys())}") print(f"Policy action_feature: {policy.config.action_feature}") return dataset_to_policy_mapping def fix_buffer_naming(policy: PI0Policy): """Fix buffer naming issues in the loaded policy state dict.""" print("Fixing normalization buffer naming issues...") state_dict = policy.state_dict() corrected_state_dict = {} fixes_applied = 0 for key, value in state_dict.items(): new_key = key # Fix buffer naming: buffer_observation_state_mean -> buffer_observation_state.mean if "buffer_observation_state_mean" in key: new_key = key.replace("buffer_observation_state_mean", "buffer_observation_state.mean") fixes_applied += 1 print(f" Fixed: {key} -> {new_key}") elif "buffer_observation_state_std" in key: new_key = key.replace("buffer_observation_state_std", "buffer_observation_state.std") fixes_applied += 1 print(f" Fixed: {key} -> {new_key}") # Remove image buffers that aren't expected (they cause conflicts) elif "buffer_observation_image_mean" in key or "buffer_observation_image_std" in key: print(f" Removed unexpected buffer: {key}") continue # Skip this buffer corrected_state_dict[new_key] = value # Add missing action buffers with dummy values (will be replaced by dataset stats) missing_buffers = [ "normalize_targets.buffer_action.mean", "normalize_targets.buffer_action.std", "unnormalize_outputs.buffer_action.mean", "unnormalize_outputs.buffer_action.std", ] for buffer_key in missing_buffers: if buffer_key not in corrected_state_dict: # Use dummy values - these will be overwritten by proper dataset stats later if "mean" in buffer_key: corrected_state_dict[buffer_key] = torch.zeros(8) # Assume 8-dim action else: # std corrected_state_dict[buffer_key] = torch.ones(8) # Assume 8-dim action fixes_applied += 1 print(f" Added missing buffer: {buffer_key}") print(f"Applied {fixes_applied} buffer fixes") # Load the corrected state dict back into the policy policy.load_state_dict(corrected_state_dict) return policy def main(): """Main function to run the Pi0 inference and upload.""" args = parse_args() # Load pretrained Pi0 model directly from Hugging Face Hub print(f"Loading pretrained Pi0 model from {args.source_model_id}...") # Load with strict=False to allow missing/unexpected keys, then fix them manually policy = PI0Policy.from_pretrained(args.source_model_id, strict=False) policy = fix_buffer_naming(policy) policy.eval() policy.to(args.device) # Load dataset and get a sample print(f"Loading dataset: {args.dataset_id}") dataset = LeRobotDataset(args.dataset_id, episodes=[args.episode]) meta: LeRobotDatasetMetadata = dataset.meta sample = dataset[args.sample_idx] # Configure policy features key_mapping = configure_policy_features(policy, dataset) # Inject normalization stats with proper key mapping _inject_normalization_stats(policy, meta, key_mapping) # Prepare batch for PI0 (handle temporal dimensions) batch = {} # Map dataset sample keys to policy keys reverse_mapping = {v: k for k, v in key_mapping.items()} for policy_key in policy.config.input_features: # Find the corresponding dataset key dataset_key = reverse_mapping.get(policy_key, policy_key) if dataset_key in sample: data = sample[dataset_key] # Handle image data: convert from HWC to CHW and normalize if policy_key.startswith("observation.images."): if data.dim() == 3 and data.shape[-1] == 3: # HWC format data = data.permute(2, 0, 1) # Convert to CHW # Normalize to [0, 1] range if needed if data.dtype == torch.uint8: data = data.float() / 255.0 # Resize to expected size if needed if data.shape[-2:] != (224, 224): import torch.nn.functional as F # noqa: N812 data = F.interpolate( data.unsqueeze(0), size=(224, 224), mode="bilinear", align_corners=False )[0] # Remove temporal dimension if present if data.dim() > len(policy.config.input_features[policy_key].shape): data = data[0] batch[policy_key] = data.unsqueeze(0) # Add batch dimension # Debug: print what's in the sample print(f"Sample keys: {list(sample.keys())}") print(f"Batch keys prepared: {list(batch.keys())}") # Pi0 requires task description - add a default if not available if "task" in sample: batch["task"] = [sample["task"]] # Keep as list of strings else: print("No task in sample, using default task description") batch["task"] = ["Complete the manipulation task"] print(f"Task: {batch['task'][0]}") print(f"Final batch keys: {list(batch.keys())}") # Run inference with torch.no_grad(): action = policy.select_action(batch) print(f"Predicted action shape: {action.shape}") print(f"Predicted action: {action.tolist()}") print("āœ… Pi0 pretrained inference completed successfully!") # Upload to Hugging Face Hub print(f"\nšŸ“¤ Uploading model to Hugging Face Hub: {args.output_model_id}") # Create commit message timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") commit_message = ( args.commit_message or f"Pi0 model with injected normalization stats from {args.dataset_id} - {timestamp}" ) # Update model configuration with dataset info policy.config.push_to_hub = True policy.config.repo_id = args.output_model_id policy.config.private = args.private # Add metadata about the adaptation adaptation_info = { "source_model": args.source_model_id, "dataset_used": args.dataset_id, "adaptation_date": timestamp, "stats_injected": True, "key_mapping": key_mapping, "inference_test_passed": True, "sample_action_shape": list(action.shape), } try: # Push to hub policy.push_to_hub( repo_id=args.output_model_id, private=args.private, commit_message=commit_message, create_pr=False, ) # Also save the adaptation info as a separate file import json import os import tempfile from huggingface_hub import HfApi api = HfApi() # Create a temporary file with adaptation info with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: json.dump(adaptation_info, f, indent=2) temp_path = f.name try: api.upload_file( path_or_fileobj=temp_path, path_in_repo="adaptation_info.json", repo_id=args.output_model_id, commit_message=f"Add adaptation metadata - {timestamp}", ) finally: os.unlink(temp_path) print(f"āœ… Model successfully uploaded to: https://huggingface.co/{args.output_model_id}") print("šŸ“‹ Adaptation info:") for key, value in adaptation_info.items(): print(f" {key}: {value}") except Exception as e: print(f"āŒ Error uploading to Hub: {e}") raise if __name__ == "__main__": main()