Amended patch for test_pi0_hub.py: build dummy dataset statistics for the
hub-loading test and use them to rebuild the policy's normalization layers
when those layers are missing, cannot be inspected, or hold NaN/inf stats.

NOTE(review): the added lines were amended during review, so the post-image
blob id on the "index" line below no longer describes the patched file.

diff --git a/test_pi0_hub.py b/test_pi0_hub.py
index f4729fbb0..dec52b3c8 100644
--- a/test_pi0_hub.py
+++ b/test_pi0_hub.py
@@ -7,6 +7,38 @@ import torch
 from lerobot.policies.pi0_openpi import PI0OpenPIPolicy
 
 
+def create_dummy_stats(config, device=None):
+    """Create dummy dataset statistics for testing.
+
+    Every feature gets an all-zeros mean and an all-ones std.
+
+    Args:
+        config: Policy config; ``state_dim``, ``action_dim``, ``image_keys``
+            and ``image_resolution`` are read from it.
+        device: Optional device on which to allocate the stat tensors.
+
+    Returns:
+        Dict mapping each feature key to its "mean" and "std" tensors.
+    """
+
+    def _unit_stats(*shape):
+        return {
+            "mean": torch.zeros(*shape, device=device),
+            "std": torch.ones(*shape, device=device),
+        }
+
+    dummy_stats = {
+        "observation.state": _unit_stats(config.state_dim),
+        "action": _unit_stats(config.action_dim),
+    }
+
+    # Add stats for image keys if they exist
+    for key in config.image_keys:
+        dummy_stats[key] = _unit_stats(3, config.image_resolution[0], config.image_resolution[1])
+
+    return dummy_stats
+
+
 def test_hub_loading():
     """Test loading model from HuggingFace hub."""
     print("=" * 60)
@@ -27,6 +59,12 @@ def test_hub_loading():
         )
         print("✓ Model loaded successfully from HuggingFace hub")
 
+        # Inject dummy stats since they aren't loaded from the hub
+        print("Creating dummy dataset stats for testing...")
+        device = next(policy.parameters()).device
+        dummy_stats = create_dummy_stats(policy.config, device=device)
+        print("✓ Dummy stats created and moved to device")
+
         # Get model info
         print("\nModel configuration:")
         print(f" - PaliGemma variant: {policy.config.paligemma_variant}")
@@ -34,7 +72,7 @@ def test_hub_loading():
         print(f" - Action dimension: {policy.config.action_dim}")
         print(f" - State dimension: {policy.config.state_dim}")
         print(f" - Action horizon: {policy.config.action_horizon}")
-        print(f" - Device: {next(policy.parameters()).device}")
+        print(f" - Device: {device}")
         print(f" - Dtype: {next(policy.parameters()).dtype}")
 
     except Exception as e:
@@ -46,30 +84,34 @@ def test_hub_loading():
 
     # Create dummy batch for testing
     batch_size = 1
-    device = next(policy.parameters()).device
 
-    # Create dummy dataset stats if not loaded with the model
-    if not hasattr(policy, "normalize_inputs") or policy.normalize_inputs is None:
+    # Rebuild the normalization layers from dummy stats when they are missing,
+    # cannot be inspected, or hold NaN/inf statistics.
+    normalize_inputs = getattr(policy, "normalize_inputs", None)
+    needs_dummy_stats = normalize_inputs is None
+    try:
+        if hasattr(normalize_inputs, "stats"):
+            obs_state_mean = normalize_inputs.stats.get("observation.state", {}).get("mean")
+            if obs_state_mean is not None and not torch.isfinite(obs_state_mean).all():
+                print("⚠️ Found invalid normalization stats, replacing with dummy stats...")
+                needs_dummy_stats = True
+    except Exception as e:
+        # Best effort: stats that cannot be inspected also trigger a rebuild.
+        print(f"⚠️ Error checking normalization stats, creating new ones: {e}")
+        needs_dummy_stats = True
+
+    if needs_dummy_stats:
         from lerobot.policies.normalize import Normalize, Unnormalize
 
-        dataset_stats = {
-            "observation.state": {
-                "mean": torch.zeros(policy.config.state_dim, device=device),
-                "std": torch.ones(policy.config.state_dim, device=device),
-            },
-            "action": {
-                "mean": torch.zeros(policy.config.action_dim, device=device),
-                "std": torch.ones(policy.config.action_dim, device=device),
-            },
-        }
         policy.normalize_inputs = Normalize(
-            policy.config.input_features, policy.config.normalization_mapping, dataset_stats
+            policy.config.input_features, policy.config.normalization_mapping, dummy_stats
         )
         policy.normalize_targets = Normalize(
-            policy.config.output_features, policy.config.normalization_mapping, dataset_stats
+            policy.config.output_features, policy.config.normalization_mapping, dummy_stats
         )
         policy.unnormalize_outputs = Unnormalize(
-            policy.config.output_features, policy.config.normalization_mapping, dataset_stats
+            policy.config.output_features, policy.config.normalization_mapping, dummy_stats
         )
+        print("✓ Normalization layers updated with dummy stats")
 
     # Create test batch