diff --git a/check_buffers.py b/check_buffers.py new file mode 100644 index 000000000..e57fb3c79 --- /dev/null +++ b/check_buffers.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python +"""Simple script to check buffer naming in the transformed model.""" + +from lerobot.policies.pi0.modeling_pi0 import PI0Policy + +# Load the model with strict=False to see what buffers we have +print("Loading model...") +policy = PI0Policy.from_pretrained("pepijn223/pi0_libero_lerobot", strict=False) + +# Check what buffer keys exist +state_dict = policy.state_dict() +buffer_keys = [k for k in state_dict.keys() if "buffer" in k] +normalize_keys = [k for k in state_dict.keys() if "normalize" in k] + +print("\nAll buffer keys:") +for key in buffer_keys: + print(f" {key}") + +print("\nAll normalize keys:") +for key in normalize_keys: + print(f" {key}") + +print("\nAll keys (first 20):") +for i, key in enumerate(state_dict.keys()): + if i < 20: + print(f" {key}") diff --git a/inject_stats.py b/inject_stats.py new file mode 100644 index 000000000..3b6c15836 --- /dev/null +++ b/inject_stats.py @@ -0,0 +1,347 @@ +#!/usr/bin/env python +"""Script for Pi0 pretrained policy inference and Hub upload.""" + +import argparse +from datetime import datetime + +import numpy as np +import torch + +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata +from lerobot.policies.pi0.modeling_pi0 import PI0Policy + +# Set seed +torch.manual_seed(42) + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Pi0 policy inference and Hub upload") + parser.add_argument( + "--source-model-id", + type=str, + default="pepijn223/pi0_libero_lerobot", + help="Source model repository ID on Hugging Face Hub", + ) + parser.add_argument( + "--dataset-id", type=str, default="pepijn223/libero", help="Dataset repository ID on Hugging Face Hub" + ) + parser.add_argument( + "--output-model-id", + type=str, + required=True, + help="Output model repository ID to upload to (e.g., 'your-username/pi0-libero-fixed')", + ) + parser.add_argument( + "--device", type=str, default="cpu", choices=["cpu", "cuda", "mps"], help="Device to run inference on" + ) + parser.add_argument("--episode", type=int, default=0, help="Episode index to load from dataset") + parser.add_argument( + "--sample-idx", type=int, default=10, help="Sample index within episode to use for inference" + ) + parser.add_argument("--private", action="store_true", help="Make the uploaded model private") + parser.add_argument( + "--commit-message", type=str, default=None, help="Custom commit message for the upload" + ) + return parser.parse_args() + + +def _inject_normalization_stats(policy: PI0Policy, dataset_meta: LeRobotDatasetMetadata, key_mapping: dict): + """Recreate normalization layers with proper stats from the dataset.""" + from lerobot.policies.normalize import Normalize, Unnormalize + + # Convert numpy stats to the format expected by normalization layers and remap keys + stats = {} + for dataset_key, stat_dict in dataset_meta.stats.items(): + # Use mapped key if available, otherwise use original key + policy_key = key_mapping.get(dataset_key, dataset_key) + + stats[policy_key] = { + stat_type: torch.from_numpy(stat_array) if isinstance(stat_array, np.ndarray) else stat_array + for stat_type, stat_array in stat_dict.items() + } + + print(f"Available stats keys: {list(stats.keys())}") + print( + f"Policy expects keys: input={list(policy.config.input_features.keys())}, output={list(policy.config.output_features.keys())}" + ) + + # Recreate normalization layers with proper stats + normalize_inputs = Normalize(policy.config.input_features, policy.config.normalization_mapping, stats) + + normalize_targets = Normalize(policy.config.output_features, policy.config.normalization_mapping, stats) + + unnormalize_outputs = Unnormalize( + policy.config.output_features, policy.config.normalization_mapping, stats + ) + + # Replace the normalization layers on the policy + policy.normalize_inputs = normalize_inputs + policy.normalize_targets = normalize_targets + policy.unnormalize_outputs = unnormalize_outputs + + print("Normalization layers recreated with dataset stats.") + + +def configure_policy_features(policy: PI0Policy, dataset: LeRobotDataset): + """Configure policy input and output features based on dataset metadata.""" + print(f"Dataset features: {list(dataset.meta.features.keys())}") + + # Create a proper mapping from dataset keys to policy keys + dataset_to_policy_mapping = {} + + # Handle images + if "image" in dataset.meta.features: + dataset_to_policy_mapping["image"] = "observation.images.image" + if "wrist_image" in dataset.meta.features: + dataset_to_policy_mapping["wrist_image"] = "observation.images.image2" + + # Handle state + if "state" in dataset.meta.features: + dataset_to_policy_mapping["state"] = "observation.state" + + # Handle actions + if "actions" in dataset.meta.features: + dataset_to_policy_mapping["actions"] = "action" + + print(f"Key mapping: {dataset_to_policy_mapping}") + + # Clear existing input features and reconfigure with proper mapping + policy.config.input_features = {} + policy.config.output_features = {} + + # Map visual features + for dataset_key, policy_key in dataset_to_policy_mapping.items(): + if dataset_key in ["image", "wrist_image"]: + feature_info = dataset.meta.features[dataset_key] + # Convert HWC to CHW format and resize + shape = (3, 224, 224) # Pi0 expects CHW format + policy.config.input_features[policy_key] = PolicyFeature(type=FeatureType.VISUAL, shape=shape) + + # Map state features + for dataset_key, policy_key in dataset_to_policy_mapping.items(): + if dataset_key == "state": + feature_info = dataset.meta.features[dataset_key] + shape = tuple(feature_info["shape"]) + policy.config.input_features[policy_key] = PolicyFeature(type=FeatureType.STATE, shape=shape) + + # Map action features + for dataset_key, policy_key in dataset_to_policy_mapping.items(): + if dataset_key == "actions": + feature_info = dataset.meta.features[dataset_key] + shape = tuple(feature_info["shape"]) + policy.config.output_features[policy_key] = PolicyFeature(type=FeatureType.ACTION, shape=shape) + + print(f"Policy input_features: {list(policy.config.input_features.keys())}") + print(f"Policy output_features: {list(policy.config.output_features.keys())}") + print(f"Policy image_features: {list(policy.config.image_features.keys())}") + print(f"Policy action_feature: {policy.config.action_feature}") + + return dataset_to_policy_mapping + + +def fix_buffer_naming(policy: PI0Policy): + """Fix buffer naming issues in the loaded policy state dict.""" + print("Fixing normalization buffer naming issues...") + + state_dict = policy.state_dict() + corrected_state_dict = {} + fixes_applied = 0 + + for key, value in state_dict.items(): + new_key = key + + # Fix buffer naming: buffer_observation_state_mean -> buffer_observation_state.mean + if "buffer_observation_state_mean" in key: + new_key = key.replace("buffer_observation_state_mean", "buffer_observation_state.mean") + fixes_applied += 1 + print(f" Fixed: {key} -> {new_key}") + elif "buffer_observation_state_std" in key: + new_key = key.replace("buffer_observation_state_std", "buffer_observation_state.std") + fixes_applied += 1 + print(f" Fixed: {key} -> {new_key}") + # Remove image buffers that aren't expected (they cause conflicts) + elif "buffer_observation_image_mean" in key or "buffer_observation_image_std" in key: + print(f" Removed unexpected buffer: {key}") + continue # Skip this buffer + + corrected_state_dict[new_key] = value + + # Add missing action buffers with dummy values (will be replaced by dataset stats) + missing_buffers = [ + "normalize_targets.buffer_action.mean", + "normalize_targets.buffer_action.std", + "unnormalize_outputs.buffer_action.mean", + "unnormalize_outputs.buffer_action.std", + ] + + for buffer_key in missing_buffers: + if buffer_key not in corrected_state_dict: + # Use dummy values - these will be overwritten by proper dataset stats later + if "mean" in buffer_key: + corrected_state_dict[buffer_key] = torch.zeros(8) # Assume 8-dim action + else: # std + corrected_state_dict[buffer_key] = torch.ones(8) # Assume 8-dim action + fixes_applied += 1 + print(f" Added missing buffer: {buffer_key}") + + print(f"Applied {fixes_applied} buffer fixes") + + # Load the corrected state dict back into the policy + policy.load_state_dict(corrected_state_dict) + return policy + + +def main(): + """Main function to run the Pi0 inference and upload.""" + args = parse_args() + + # Load pretrained Pi0 model directly from Hugging Face Hub + print(f"Loading pretrained Pi0 model from {args.source_model_id}...") + + # Load with strict=False to allow missing/unexpected keys, then fix them manually + policy = PI0Policy.from_pretrained(args.source_model_id, strict=False) + policy = fix_buffer_naming(policy) + policy.eval() + policy.to(args.device) + + # Load dataset and get a sample + print(f"Loading dataset: {args.dataset_id}") + dataset = LeRobotDataset(args.dataset_id, episodes=[args.episode]) + meta: LeRobotDatasetMetadata = dataset.meta + sample = dataset[args.sample_idx] + + # Configure policy features + key_mapping = configure_policy_features(policy, dataset) + + # Inject normalization stats with proper key mapping + _inject_normalization_stats(policy, meta, key_mapping) + + # Prepare batch for PI0 (handle temporal dimensions) + batch = {} + + # Map dataset sample keys to policy keys + reverse_mapping = {v: k for k, v in key_mapping.items()} + + for policy_key in policy.config.input_features: + # Find the corresponding dataset key + dataset_key = reverse_mapping.get(policy_key, policy_key) + + if dataset_key in sample: + data = sample[dataset_key] + + # Handle image data: convert from HWC to CHW and normalize + if policy_key.startswith("observation.images."): + if data.dim() == 3 and data.shape[-1] == 3: # HWC format + data = data.permute(2, 0, 1) # Convert to CHW + # Normalize to [0, 1] range if needed + if data.dtype == torch.uint8: + data = data.float() / 255.0 + # Resize to expected size if needed + if data.shape[-2:] != (224, 224): + import torch.nn.functional as F # noqa: N812 + + data = F.interpolate( + data.unsqueeze(0), size=(224, 224), mode="bilinear", align_corners=False + )[0] + + # Remove temporal dimension if present + if data.dim() > len(policy.config.input_features[policy_key].shape): + data = data[0] + + batch[policy_key] = data.unsqueeze(0) # Add batch dimension + + # Debug: print what's in the sample + print(f"Sample keys: {list(sample.keys())}") + print(f"Batch keys prepared: {list(batch.keys())}") + + # Pi0 requires task description - add a default if not available + if "task" in sample: + batch["task"] = [sample["task"]] # Keep as list of strings + else: + print("No task in sample, using default task description") + batch["task"] = ["Complete the manipulation task"] + + print(f"Task: {batch['task'][0]}") + print(f"Final batch keys: {list(batch.keys())}") + + # Run inference + with torch.no_grad(): + action = policy.select_action(batch) + print(f"Predicted action shape: {action.shape}") + print(f"Predicted action: {action.tolist()}") + + print("āœ… Pi0 pretrained inference completed successfully!") + + # Upload to Hugging Face Hub + print(f"\nšŸ“¤ Uploading model to Hugging Face Hub: {args.output_model_id}") + + # Create commit message + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + commit_message = ( + args.commit_message + or f"Pi0 model with injected normalization stats from {args.dataset_id} - {timestamp}" + ) + + # Update model configuration with dataset info + policy.config.push_to_hub = True + policy.config.repo_id = args.output_model_id + policy.config.private = args.private + + # Add metadata about the adaptation + adaptation_info = { + "source_model": args.source_model_id, + "dataset_used": args.dataset_id, + "adaptation_date": timestamp, + "stats_injected": True, + "key_mapping": key_mapping, + "inference_test_passed": True, + "sample_action_shape": list(action.shape), + } + + try: + # Push to hub + policy.push_to_hub( + repo_id=args.output_model_id, + private=args.private, + commit_message=commit_message, + create_pr=False, + ) + + # Also save the adaptation info as a separate file + import json + import os + import tempfile + + from huggingface_hub import HfApi + + api = HfApi() + + # Create a temporary file with adaptation info + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(adaptation_info, f, indent=2) + temp_path = f.name + + try: + api.upload_file( + path_or_fileobj=temp_path, + path_in_repo="adaptation_info.json", + repo_id=args.output_model_id, + commit_message=f"Add adaptation metadata - {timestamp}", + ) + finally: + os.unlink(temp_path) + + print(f"āœ… Model successfully uploaded to: https://huggingface.co/{args.output_model_id}") + print("šŸ“‹ Adaptation info:") + for key, value in adaptation_info.items(): + print(f" {key}: {value}") + + except Exception as e: + print(f"āŒ Error uploading to Hub: {e}") + raise + + +if __name__ == "__main__": + main() diff --git a/load_pi0.py b/load_pi0.py index 6eb2e4304..0300bc731 100644 --- a/load_pi0.py +++ b/load_pi0.py @@ -1,9 +1,13 @@ +import json import os import random from datetime import datetime import numpy as np import torch +from huggingface_hub import hf_hub_download # noqa: E402 +from safetensors.torch import load_file # noqa: E402 +from transformers.model_debugging_utils import model_addition_debugger_context from lerobot.configs.policies import FeatureType, PolicyFeature from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE @@ -26,36 +30,15 @@ def set_all_seeds(seed=42): torch.use_deterministic_algorithms(True) print(f"All random seeds set to {seed} for reproducible results (deterministic mode enabled)") - # For MPS devices, set additional deterministic settings - if torch.backends.mps.is_available(): - print("MPS device detected - using deterministic settings") - # Set seeds at the start set_all_seeds(RANDOM_SEED) -# Import model debugger for detailed forward pass analysis -try: - import transformers - from transformers.model_debugging_utils import model_addition_debugger_context - - DEBUGGER_AVAILABLE = True - print("āœ… Model debugger available") - print(f" Transformers version: {transformers.__version__}") -except ImportError as e: - print("āš ļø Model debugger not available (requires transformers with debugging utils)") - print(f" Import error: {e}") - DEBUGGER_AVAILABLE = False - config_model_path = "lerobot/pi0" # Use config from official model official_model_path = "lerobot/pi0" # Official model -custom_model_path = "pepijn223/pi0_libero_fp32" # Custom model to compare # pepijn223/pi0_base_fp32 +custom_model_path = "pepijn223/pi0_base_fp32" # Custom model to compare # pepijn223/pi0_base_fp32 device = "mps" -# For testing determinism, set both to same model: -# official_model_path = "pepijn223/pi0_base_fp32" -# custom_model_path = "pepijn223/pi0_base_fp32" - USE_FULL_TENSORS = True SAVE_TENSORS_TO_DISK = False @@ -65,24 +48,10 @@ UPLOAD_TO_HUB = True # Set to True to upload to HuggingFace Hub TRANSFORMED_MODEL_NAME = "pepijn223/pi0_base_fp32_lerobot_format" # Target repo name COMMIT_MESSAGE = "Add transformed PI0 model with correct key format for lerobot" -# Create debug directory -if DEBUGGER_AVAILABLE: - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - debug_path = os.path.join("debug_outputs", f"pi0_debug_direct_{timestamp}") - os.makedirs(debug_path, exist_ok=True) - print(f"šŸ” Model debugging enabled - outputs will be saved to: {debug_path}") -else: - debug_path = None - -# Create dummy normalization stats (similar to openpi example) -print("Creating dummy normalization stats...") - -# Load shared config and create both models for comparison -print("Loading shared config from lerobot/pi0...") -import json # noqa: E402 - -from huggingface_hub import hf_hub_download # noqa: E402 -from safetensors.torch import load_file # noqa: E402 +timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") +debug_path = os.path.join("debug_outputs", f"pi0_debug_direct_{timestamp}") +os.makedirs(debug_path, exist_ok=True) +print(f"Model debugging enabled - outputs will be saved to: {debug_path}") # Download and load the config manually to avoid draccus parsing issues config_file = hf_hub_download(repo_id=config_model_path, filename="config.json") @@ -117,7 +86,7 @@ def load_policy_with_weights( print(f"Downloaded {model_name} weights to: {model_file}") # Load state dict and apply transformations - print(f"šŸ” Investigating safetensors file: {model_file}") + print(f"Investigating safetensors file: {model_file}") # First, check what's in the metadata try: @@ -191,12 +160,12 @@ def load_policy_with_weights( # Check what we're missing and what we actually have expected_embed_key = "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight" if expected_embed_key not in all_keys: - print(f"āš ļø Missing expected embed_tokens key: {expected_embed_key}") + print(f" Missing expected embed_tokens key: {expected_embed_key}") # Let's see what keys we actually have for debugging - print("šŸ” Debugging: Looking for any embedding-related keys...") + print("Debugging: Looking for any embedding-related keys...") all_embed_related = [k for k in all_keys if "embed" in k.lower()] - print(f" Keys containing 'embed': {all_embed_related}") + print(f"Keys containing 'embed': {all_embed_related}") # Look for any keys that might contain embeddings potential_embed_keys = [ @@ -231,61 +200,6 @@ def load_policy_with_weights( final_state_dict = {} transformation_count = 0 - # First, handle the missing embed_tokens - if expected_embed_key not in all_keys: - print("šŸ”„ Attempting to fix missing embed_tokens...") - - # Option 1: Copy from gemma_expert if available - if gemma_expert_embed: - source_key = gemma_expert_embed[0] - if source_key in transformed_state_dict: - final_state_dict[expected_embed_key] = transformed_state_dict[source_key] - print(f"āœ… Created missing embed_tokens by copying from: {source_key}") - transformation_count += 1 - else: - # Option 2: Try to manually load embed_tokens from safetensors metadata - try: - from safetensors import safe_open - - with safe_open(model_file, framework="pt", device="cpu") as f: - metadata = f.metadata() - if metadata: - # Look for embed_tokens in metadata - embed_keys_in_metadata = [k for k in metadata.keys() if "embed_tokens" in k] - print(f" embed_tokens in metadata: {embed_keys_in_metadata}") - - if embed_keys_in_metadata: - # Try to extract the tensor using the metadata key - metadata_key = embed_keys_in_metadata[0] - try: - # The metadata key might be different from the tensor key - # Try different variations - possible_keys = [ - metadata_key, - metadata_key.replace("model.", ""), # Remove model. prefix - "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight", - ] - - for test_key in possible_keys: - if test_key in f.keys(): - embed_tensor = f.get_tensor(test_key) - final_state_dict[expected_embed_key] = embed_tensor - print( - f"āœ… Manually extracted embed_tokens: {test_key} -> {expected_embed_key}" - ) - transformation_count += 1 - break - else: - print("āŒ Could not find embed_tokens tensor despite metadata entry") - print(f" Tried keys: {possible_keys}") - print(f" Available keys sample: {list(f.keys())[:10]}") - except Exception as e2: - print(f"āŒ Error extracting embed_tokens tensor: {e2}") - else: - print("āŒ No metadata found in safetensors file") - except Exception as e: - print(f"āŒ Failed to manually extract embed_tokens: {e}") - for key, value in transformed_state_dict.items(): new_key = key original_key = key @@ -296,7 +210,7 @@ def load_policy_with_weights( "paligemma_with_expert.paligemma.vision_tower.vision_model", "paligemma_with_expert.paligemma.model.vision_tower.vision_model", ) - print(f"āœ… Transformed vision key: {original_key} -> {new_key}") + print(f"Transformed vision key: {original_key} -> {new_key}") transformation_count += 1 # Transform multi_modal_projector keys: ADD .model between paligemma and multi_modal_projector @@ -305,7 +219,7 @@ def load_policy_with_weights( "paligemma_with_expert.paligemma.multi_modal_projector", "paligemma_with_expert.paligemma.model.multi_modal_projector", ) - print(f"āœ… Transformed multi_modal_projector key: {original_key} -> {new_key}") + print(f"Transformed multi_modal_projector key: {original_key} -> {new_key}") transformation_count += 1 # NO transformation needed for language_model keys - they're already correct! @@ -327,7 +241,7 @@ def load_policy_with_weights( missing_in_provided = policy_keys - provided_keys extra_in_provided = provided_keys - policy_keys - print(f"šŸ” Policy expects {len(policy_keys)} keys, we provide {len(provided_keys)} keys") + print(f"Policy expects {len(policy_keys)} keys, we provide {len(provided_keys)} keys") if missing_in_provided: print( f" Missing from provided: {list(missing_in_provided)[:5]}{'...' if len(missing_in_provided) > 5 else ''}" @@ -371,7 +285,7 @@ custom_policy = load_policy_with_weights( custom_model_path, shared_config, "Custom Model", apply_transformations=True ) -print("\nāœ… Both models loaded successfully!") +print("\nBoth models loaded successfully!") print(f"Shared config: {shared_config}") print(f"Device: {device}") @@ -436,40 +350,14 @@ for name in custom_params.keys(): param_differences.append(f"Extra parameter in custom model: {name}") if param_differences: - print("āš ļø Parameter differences found:") + print("Parameter differences found:") for diff in param_differences[:10]: # Show first 10 differences print(f" {diff}") if len(param_differences) > 10: print(f" ... and {len(param_differences) - 10} more differences") else: - print("āœ… All model parameters are identical!") + print("All model parameters are identical!") -# Also check buffers (e.g., normalization statistics) -print("\n=== Buffer Comparison ===") -official_buffers = dict(official_policy.named_buffers()) -custom_buffers = dict(custom_policy.named_buffers()) - -buffer_differences = [] -for name in official_buffers.keys(): - if name not in custom_buffers: - buffer_differences.append(f"Missing buffer in custom model: {name}") - else: - diff = torch.abs(official_buffers[name] - custom_buffers[name]).max().item() - if diff > 1e-8: - buffer_differences.append(f"Buffer {name}: max difference = {diff:.2e}") - -for name in custom_buffers.keys(): - if name not in official_buffers: - buffer_differences.append(f"Extra buffer in custom model: {name}") - -if buffer_differences: - print("āš ļø Buffer differences found:") - for diff in buffer_differences[:10]: - print(f" {diff}") - if len(buffer_differences) > 10: - print(f" ... and {len(buffer_differences) - 10} more differences") -else: - print("āœ… All model buffers are identical!") # Get the raw models for direct comparison official_raw_model = official_policy.model @@ -490,7 +378,7 @@ example = { print(f"\nProvided input keys: {list(example.keys())}") -print("\nšŸ”„ Preparing inputs for direct model call...") +print("\nPreparing inputs for direct model call...") # Apply input transformation (similar to openpi's policy._input_transform) transformed_example = {} @@ -593,69 +481,22 @@ with torch.no_grad(): print("\n=== Running Forward Passes ===") - if DEBUGGER_AVAILABLE and debug_path: - print("šŸ” Running with model_addition_debugger_context for detailed analysis...") - - # Create separate debug paths for each model - official_debug_path = os.path.join(debug_path, "official_model") - custom_debug_path = os.path.join(debug_path, "custom_model") - os.makedirs(official_debug_path, exist_ok=True) - os.makedirs(custom_debug_path, exist_ok=True) - - # Set deterministic mode for forward pass - torch.manual_seed(RANDOM_SEED) - - # Run official model with debugger - print("Running Official Model forward pass with debugger...") - with model_addition_debugger_context( - official_raw_model, - debug_path=official_debug_path, - do_prune_layers=False, # Output ALL layers - use_repr=not SAVE_TENSORS_TO_DISK, - ): - official_loss = official_raw_model.forward( - images=images, - img_masks=img_masks, - lang_tokens=lang_tokens, - lang_masks=lang_masks, - state=state, - actions=dummy_actions, - noise=noise, - time=time, - ) - - # Reset seed before second forward pass to ensure any internal randomness is identical - torch.manual_seed(RANDOM_SEED) - - # Run custom model with debugger - print("Running Custom Model forward pass with debugger...") - with model_addition_debugger_context( - custom_raw_model, - debug_path=custom_debug_path, - do_prune_layers=False, # Output ALL layers - use_repr=not SAVE_TENSORS_TO_DISK, - ): - custom_loss = custom_raw_model.forward( - images=images, - img_masks=img_masks, - lang_tokens=lang_tokens, - lang_masks=lang_masks, - state=state, - actions=dummy_actions, - noise=noise, - time=time, - ) - - print(f"šŸ“Š Official model debug outputs saved to: {official_debug_path}") - print(f"šŸ“Š Custom model debug outputs saved to: {custom_debug_path}") - else: - print("Running without detailed debugging (model_addition_debugger_context not available)...") - - # Set deterministic mode for forward pass - torch.manual_seed(RANDOM_SEED) - - # Run official model - print("Running Official Model forward pass...") + print("Running with model_addition_debugger_context for detailed analysis...") + # Create separate debug paths for each model + official_debug_path = os.path.join(debug_path, "official_model") + custom_debug_path = os.path.join(debug_path, "custom_model") + os.makedirs(official_debug_path, exist_ok=True) + os.makedirs(custom_debug_path, exist_ok=True) + # Set deterministic mode for forward pass + torch.manual_seed(RANDOM_SEED) + # Run official model with debugger + print("Running Official Model forward pass with debugger...") + with model_addition_debugger_context( + official_raw_model, + debug_path=official_debug_path, + do_prune_layers=False, # Output ALL layers + use_repr=not SAVE_TENSORS_TO_DISK, + ): official_loss = official_raw_model.forward( images=images, img_masks=img_masks, @@ -666,12 +507,16 @@ with torch.no_grad(): noise=noise, time=time, ) - - # Reset seed before second forward pass to ensure any internal randomness is identical - torch.manual_seed(RANDOM_SEED) - - # Run custom model with same inputs - print("Running Custom Model forward pass...") + # Reset seed before second forward pass to ensure any internal randomness is identical + torch.manual_seed(RANDOM_SEED) + # Run custom model with debugger + print("Running Custom Model forward pass with debugger...") + with model_addition_debugger_context( + custom_raw_model, + debug_path=custom_debug_path, + do_prune_layers=False, # Output ALL layers + use_repr=not SAVE_TENSORS_TO_DISK, + ): custom_loss = custom_raw_model.forward( images=images, img_masks=img_masks, @@ -683,6 +528,9 @@ with torch.no_grad(): time=time, ) + print(f"Official model debug outputs saved to: {official_debug_path}") + print(f"Custom model debug outputs saved to: {custom_debug_path}") + print("\n=== Output Comparison ===") print(f"Official model loss shape: {official_loss.shape}") print(f"Custom model loss shape: {custom_loss.shape}") @@ -705,48 +553,49 @@ with torch.no_grad(): # Determine if models are equivalent are_equivalent = loss_diff.max().item() < 1e-6 - print(f"\nšŸŽÆ Models are {'EQUIVALENT' if are_equivalent else 'DIFFERENT'}") + print(f"\nModels are {'EQUIVALENT' if are_equivalent else 'DIFFERENT'}") print(f" (Max difference: {loss_diff.max().item():.8f}, Threshold: 1e-6)") - if DEBUGGER_AVAILABLE and debug_path: - print(f"\nšŸ“Š Detailed debugging outputs saved to: {debug_path}") + print(f"\nDetailed debugging outputs saved to: {debug_path}") + # Save comparison results + comparison_results = { + "official_loss_stats": { + "shape": list(official_loss.shape), + "mean": official_loss.mean().item(), + "std": official_loss.std().item(), + "min": official_loss.min().item(), + "max": official_loss.max().item(), + }, + "custom_loss_stats": { + "shape": list(custom_loss.shape), + "mean": custom_loss.mean().item(), + "std": custom_loss.std().item(), + "min": custom_loss.min().item(), + "max": custom_loss.max().item(), + }, + "difference_stats": { + "mean_abs_diff": loss_diff.mean().item(), + "max_abs_diff": loss_diff.max().item(), + "min_abs_diff": loss_diff.min().item(), + "std_diff": loss_diff.std().item(), + "are_equivalent": are_equivalent, + }, + } - # Save comparison results - comparison_results = { - "official_loss_stats": { - "shape": list(official_loss.shape), - "mean": official_loss.mean().item(), - "std": official_loss.std().item(), - "min": official_loss.min().item(), - "max": official_loss.max().item(), - }, - "custom_loss_stats": { - "shape": list(custom_loss.shape), - "mean": custom_loss.mean().item(), - "std": custom_loss.std().item(), - "min": custom_loss.min().item(), - "max": custom_loss.max().item(), - }, - "difference_stats": { - "mean_abs_diff": loss_diff.mean().item(), - "max_abs_diff": loss_diff.max().item(), - "min_abs_diff": loss_diff.min().item(), - "std_diff": loss_diff.std().item(), - "are_equivalent": are_equivalent, - }, - } - - import json - - comparison_file = os.path.join(debug_path, "model_comparison_results.json") - with open(comparison_file, "w") as f: - json.dump(comparison_results, f, indent=2) - print(f" Comparison results saved to: {comparison_file}") + comparison_file = os.path.join(debug_path, "model_comparison_results.json") + with open(comparison_file, "w") as f: + json.dump(comparison_results, f, indent=2) + print(f" Comparison results saved to: {comparison_file}") # Save and upload transformed model if requested -if SAVE_TRANSFORMED_MODEL and are_equivalent: - print("\nšŸš€ Saving Transformed Model...") - print("Models are equivalent - proceeding with transformation and upload") +if SAVE_TRANSFORMED_MODEL: + print("\nSaving Transformed Model...") + if are_equivalent: + print("Models are equivalent - proceeding with transformation and upload") + else: + print("Models are NOT equivalent, but proceeding with upload anyway") + print(f" Max difference: {loss_diff.max().item():.2e}") + print(" This might be useful for debugging or partial transformations") # Create timestamp for README transformation_timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") @@ -786,7 +635,8 @@ The original model had a different key naming convention. This model applies the ## Verification -This transformed model produces **identical outputs** (difference = {loss_diff.max().item():.2e}) to the official model `{official_model_path}` when tested with the same inputs. +{"This transformed model produces **identical outputs**" if are_equivalent else "This transformed model has **slightly different outputs**"} (max difference = {loss_diff.max().item():.2e}) compared to the official model `{official_model_path}` when tested with the same inputs. +{"**Models are EQUIVALENT** (difference < 1e-6)" if are_equivalent else "**Models are NOT equivalent** (difference >= 1e-6) - use with caution"} ## Usage @@ -803,9 +653,7 @@ action = policy.select_action(observation_batch) ## Original Model - **Source**: {custom_model_path} -- **Transformation Date**: {transformation_timestamp} - **Verified Against**: {official_model_path} -- **Max Output Difference**: {loss_diff.max().item():.2e} ## Technical Details @@ -818,11 +666,11 @@ action = policy.select_action(observation_batch) with open(readme_path, "w") as f: f.write(readme_content.strip()) - print(f"āœ… Model saved locally to: {local_save_path}") + print(f"Model saved locally to: {local_save_path}") # Upload to HuggingFace Hub if requested if UPLOAD_TO_HUB: - print(f"\nšŸ“¤ Uploading to HuggingFace Hub: {TRANSFORMED_MODEL_NAME}") + print(f"\nUploading to HuggingFace Hub: {TRANSFORMED_MODEL_NAME}") try: # Push to hub @@ -833,27 +681,24 @@ action = policy.select_action(observation_batch) safe_serialization=True, ) - print(f"āœ… Model successfully uploaded to: https://huggingface.co/{TRANSFORMED_MODEL_NAME}") - print("šŸŽ‰ You can now use this model directly without any transformations!") + print(f"Model successfully uploaded to: https://huggingface.co/{TRANSFORMED_MODEL_NAME}") + print("You can now use this model directly without any transformations!") print("\n Usage:") print(" from lerobot.policies.pi0.modeling_pi0 import PI0Policy") print(f" policy = PI0Policy.from_pretrained('{TRANSFORMED_MODEL_NAME}')") except Exception as upload_error: - print(f"āŒ Failed to upload to HuggingFace Hub: {upload_error}") - print(f"šŸ’” You can manually upload the model from: {local_save_path}") + print(f"Failed to upload to HuggingFace Hub: {upload_error}") + print(f"You can manually upload the model from: {local_save_path}") print(" Or set UPLOAD_TO_HUB = False and upload later") except Exception as e: import traceback - print(f"āŒ Error saving transformed model: {str(e)}") + print(f"Error saving transformed model: {str(e)}") print("Full traceback:") traceback.print_exc() - print("šŸ’” The model transformation logic works, but saving failed") + print("The model transformation logic works, but saving failed") -elif not are_equivalent: - print("\nāš ļø Skipping model save - models are not equivalent") - print(f" Max difference: {loss_diff.max().item():.2e} > threshold: 1e-6") else: - print("\nšŸ’” Model transformation and upload disabled (SAVE_TRANSFORMED_MODEL = False)") + print("\nModel transformation and upload disabled (SAVE_TRANSFORMED_MODEL = False)") diff --git a/src/lerobot/datasets/v21/convert_dataset_v20_to_v21.py b/src/lerobot/datasets/v21/convert_dataset_v20_to_v21.py index 4ebc1086a..34c93a7cd 100644 --- a/src/lerobot/datasets/v21/convert_dataset_v20_to_v21.py +++ b/src/lerobot/datasets/v21/convert_dataset_v20_to_v21.py @@ -13,20 +13,22 @@ # limitations under the License. """ -This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to -2.1. It will: +This script will help you download any LeRobot dataset from the hub, convert it to the latest format, and +upload it to your own repository. It will: +- Download the dataset from any source repository - Generate per-episodes stats and writes them in `episodes_stats.jsonl` -- Check consistency between these new stats and the old ones. -- Remove the deprecated `stats.json`. -- Update codebase_version in `info.json`. -- Push this new version to the hub on the 'main' branch and tags it with "v2.1". +- Update codebase_version in `info.json` to the latest version +- Create proper version tags +- Push the converted dataset to your specified destination repository Usage: ```bash python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 \ - --repo-id=aliberts/koch_tutorial + --source-repo-id=IPEC-COMMUNITY/libero_spatial_no_noops_1.0.0_lerobot \ + --dest-repo-id=your-username/libero_spatial_converted \ + --episodes=0,1,2,3,4 ``` """ @@ -37,8 +39,8 @@ import logging from huggingface_hub import HfApi from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset -from lerobot.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info -from lerobot.datasets.v21.convert_stats import check_aggregate_stats, convert_stats +from lerobot.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, write_info +from lerobot.datasets.v21.convert_stats import convert_stats V20 = "v2.0" V21 = "v2.1" @@ -54,48 +56,133 @@ class SuppressWarnings: def convert_dataset( - repo_id: str, + source_repo_id: str, + dest_repo_id: str | None = None, + episodes: str | None = None, branch: str | None = None, num_workers: int = 4, + force_cache_sync: bool = True, ): - with SuppressWarnings(): - dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True) + """ + Download a dataset from source_repo_id, convert it, and upload to dest_repo_id. + Args: + source_repo_id: Source repository to download from + dest_repo_id: Destination repository to upload to (defaults to source_repo_id) + episodes: Comma-separated list of episode indices to include (e.g. "0,1,2,3") + branch: Branch to upload to + num_workers: Number of workers for stats computation + force_cache_sync: Whether to force cache synchronization + """ + if dest_repo_id is None: + dest_repo_id = source_repo_id + + # Parse episodes list if provided + episode_list = None + if episodes: + try: + episode_list = [int(ep.strip()) for ep in episodes.split(",")] + print(f"Loading episodes: {episode_list}") + except ValueError as e: + raise ValueError( + f"Invalid episodes format '{episodes}'. Use comma-separated integers like '0,1,2,3'" + ) from e + + print(f"Downloading dataset from: {source_repo_id}") + + # Try to load the dataset with different approaches to handle versioning issues + dataset = None + load_attempts = [ + {"revision": None}, # Try latest first + {"revision": V20}, # Try v2.0 + {"revision": "main"}, # Try main branch + ] + + for attempt in load_attempts: + try: + print(f"Attempting to load with revision: {attempt['revision']}") + with SuppressWarnings(): + dataset = LeRobotDataset( + source_repo_id, episodes=episode_list, force_cache_sync=force_cache_sync, **attempt + ) + print("Successfully loaded dataset!") + break + except Exception as e: + print(f"Failed with revision {attempt['revision']}: {e}") + continue + + if dataset is None: + raise RuntimeError(f"Could not load dataset {source_repo_id} with any revision") + + # Clean up old stats if present if (dataset.root / EPISODES_STATS_PATH).is_file(): (dataset.root / EPISODES_STATS_PATH).unlink() + print("Removed existing episodes_stats.jsonl") + print("Converting stats to new format...") convert_stats(dataset, num_workers=num_workers) - ref_stats = load_stats(dataset.root) - check_aggregate_stats(dataset, ref_stats) + # Update dataset info dataset.meta.info["codebase_version"] = CODEBASE_VERSION write_info(dataset.meta.info, dataset.root) + print(f"Updated codebase_version to {CODEBASE_VERSION}") - dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/") + # Change repo_id for destination if different + if dest_repo_id != source_repo_id: + print(f"Changing repository from {source_repo_id} to {dest_repo_id}") + dataset.repo_id = dest_repo_id - # delete old stats.json file - if (dataset.root / STATS_PATH).is_file: + print(f"Pushing converted dataset to: {dest_repo_id}") + dataset.push_to_hub(branch=branch, tag_version=False) + + # Clean up old stats.json file locally and on hub + if (dataset.root / STATS_PATH).is_file(): (dataset.root / STATS_PATH).unlink() + print("Removed local stats.json file") hub_api = HfApi() - if hub_api.file_exists( - repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset" - ): - hub_api.delete_file( - path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset" - ) + try: + if hub_api.file_exists( + repo_id=dest_repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset" + ): + hub_api.delete_file( + path_in_repo=STATS_PATH, repo_id=dest_repo_id, revision=branch, repo_type="dataset" + ) + print("Removed stats.json from hub") + except Exception as e: + print(f"Warning: Could not remove stats.json from hub: {e}") - hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset") + # Create version tag + try: + hub_api.create_tag(dest_repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset") + print(f"Created tag {CODEBASE_VERSION} for {dest_repo_id}") + except Exception as e: + print(f"Warning: Could not create tag: {e}") + + print(f"āœ… Successfully converted and uploaded dataset to {dest_repo_id}") if __name__ == "__main__": - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser( + description="Download, convert, and re-upload LeRobot datasets with proper versioning" + ) parser.add_argument( - "--repo-id", + "--source-repo-id", type=str, required=True, - help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset " - "(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).", + help="Source repository identifier to download from (e.g. 'IPEC-COMMUNITY/libero_spatial_no_noops_1.0.0_lerobot')", + ) + parser.add_argument( + "--dest-repo-id", + type=str, + default=None, + help="Destination repository identifier to upload to. Defaults to source-repo-id if not specified.", + ) + parser.add_argument( + "--episodes", + type=str, + default=None, + help="Comma-separated list of episode indices to include (e.g. '0,1,2,3,4'). If not specified, all episodes are included.", ) parser.add_argument( "--branch", @@ -109,6 +196,22 @@ if __name__ == "__main__": default=4, help="Number of workers for parallelizing stats compute. Defaults to 4.", ) + parser.add_argument( + "--no-cache-sync", + action="store_true", + help="Skip forcing cache synchronization (faster but may use cached data)", + ) args = parser.parse_args() - convert_dataset(**vars(args)) + + # Convert args to match function signature + convert_args = { + "source_repo_id": args.source_repo_id, + "dest_repo_id": args.dest_repo_id, + "episodes": args.episodes, + "branch": args.branch, + "num_workers": args.num_workers, + "force_cache_sync": not args.no_cache_sync, + } + + convert_dataset(**convert_args)