From ac323b011333303e233ad87d3e687a416c5b2aa9 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Wed, 10 Sep 2025 21:33:55 +0200 Subject: [PATCH] add pi05 --- push_pi0_to_hub.py | 35 --- .../pi0_openpi/configuration_pi0openpi.py | 24 ++ .../policies/pi0_openpi/modeling_pi0openpi.py | 60 ++++- test_pi05_hub_only.py | 27 +++ test_pi05_openpi.py | 227 ++++++++++++++++++ test_pi0_hub.py | 95 +++++++- 6 files changed, 422 insertions(+), 46 deletions(-) create mode 100644 test_pi05_hub_only.py create mode 100644 test_pi05_openpi.py diff --git a/push_pi0_to_hub.py b/push_pi0_to_hub.py index f6b4b1b09..bb1db156b 100644 --- a/push_pi0_to_hub.py +++ b/push_pi0_to_hub.py @@ -154,41 +154,6 @@ def create_and_push_model( print(f"\n✓ Model successfully uploaded to: https://huggingface.co/{repo_id}") - # Test loading the model back - print("\n" + "-" * 60) - print("Testing model loading from hub...") - - try: - loaded_policy = PI0OpenPIPolicy.from_pretrained( - repo_id, - token=token, - ) - print("✓ Model loaded successfully from hub") - - # Quick validation - batch_size = 1 - device = next(loaded_policy.parameters()).device - test_batch = { - "observation.state": torch.randn(batch_size, config.state_dim, device=device), - "action": torch.randn(batch_size, config.action_horizon, config.action_dim, device=device), - "task": ["Test task"], - } - - # Add images - for key in config.image_keys: - test_batch[key] = torch.rand(batch_size, 3, 224, 224, device=device) - - # Test forward pass - loaded_policy.train() - loss, loss_dict = loaded_policy.forward(test_batch) - print(f"✓ Forward pass successful - Loss: {loss_dict['loss']:.4f}") - - except Exception as e: - print(f"✗ Failed to load model: {e}") - import traceback - - traceback.print_exc() - print("\n" + "=" * 60) print("✓ Process complete!") print("=" * 60) diff --git a/src/lerobot/policies/pi0_openpi/configuration_pi0openpi.py b/src/lerobot/policies/pi0_openpi/configuration_pi0openpi.py index 4d4d70071..70a9130df 100644 --- a/src/lerobot/policies/pi0_openpi/configuration_pi0openpi.py +++ b/src/lerobot/policies/pi0_openpi/configuration_pi0openpi.py @@ -45,6 +45,19 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig # pip install transformers==4.53.2 +# Comparison of PI0 vs PI0.5 +# +# Feature | PI0 | PI0.5 +# ---------------------|---------------------------------------------|----------------------------------------- +# State Embedding | Uses state_proj layer | No state embedding +# Time Conditioning | Concatenates time with actions via | Uses time_mlp_* for AdaRMS conditioning +# | action_time_mlp_* | +# AdaRMS | Not used | Used in action expert +# Tokenizer Length | 200 tokens | 48 tokens +# discrete_state_input | False | True +# Parameter Count | Higher (includes state_proj) | Lower (no state embedding) + + @PreTrainedConfig.register_subclass("pi0_openpi") @dataclass class PI0OpenPIConfig(PreTrainedConfig): @@ -52,6 +65,7 @@ class PI0OpenPIConfig(PreTrainedConfig): paligemma_variant: str = "gemma_2b" action_expert_variant: str = "gemma_300m" pi05: bool = False # Whether to use PI0.5 variant with AdaRMS + discrete_state_input: bool | None = None # Whether to use discrete state input (defaults to pi05 value) dtype: str = "float32" # Options: "bfloat16", "float32" # Input / output structure @@ -108,6 +122,16 @@ class PI0OpenPIConfig(PreTrainedConfig): def __post_init__(self): super().__post_init__() + # Set discrete_state_input to pi05 value if not explicitly set + if self.discrete_state_input is None: # see openpi `Pi0Config, __post_init__` + object.__setattr__(self, "discrete_state_input", self.pi05) + + # Set tokenizer max length based on pi05 mode, see openpi `Pi0Config, __post_init__` + if self.pi05: + self.tokenizer_max_length = 48 + else: + self.tokenizer_max_length = 200 + # Validate configuration if self.n_action_steps > self.action_horizon: raise ValueError( diff --git a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py index baff3b15f..bef211674 100644 --- a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py +++ b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py @@ -924,11 +924,14 @@ class PI0OpenPIPolicy(PreTrainedPolicy): print(f"Could not load state dict from remote files: {e}") return model - # Create a new state dict with "model." prefix for all keys that don't already have it + # First, fix any pi05-specific key differences + fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config) + + # Then add "model." prefix for all keys that don't already have it remapped_state_dict = {} remap_count = 0 - for key, value in original_state_dict.items(): + for key, value in fixed_state_dict.items(): if not key.startswith("model."): new_key = f"model.{key}" remapped_state_dict[new_key] = value @@ -975,6 +978,59 @@ class PI0OpenPIPolicy(PreTrainedPolicy): return model + def _fix_pytorch_state_dict_keys( + self, state_dict, model_config + ): # see openpi `BaseModelConfig, _fix_pytorch_state_dict_keys` + """Fix state dict keys to match current model architecture.""" + import re + + fixed_state_dict = {} + + for key, value in state_dict.items(): + new_key = key + + # Handle layer norm structure changes: .weight -> .dense.weight + .dense.bias + # For gemma expert layers + if re.match( + r"paligemma_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight", + key, + ): + # This key structure suggests old model without adaRMS - keep as is or skip + logging.warning(f"Skipping old layer norm key (no adaRMS support): {key}") + continue + + if re.match(r"paligemma_with_expert\.gemma_expert\.model\.norm\.weight", key): + # Skip old norm structure + logging.warning(f"Skipping old norm key (no adaRMS support): {key}") + continue + + # Handle MLP naming changes for pi05 vs non-pi05 + if model_config.pi05: + # pi05 model expects time_mlp_*, but checkpoint might have action_time_mlp_* + if key.startswith("action_time_mlp_in."): + new_key = key.replace("action_time_mlp_in.", "time_mlp_in.") + elif key.startswith("action_time_mlp_out."): + new_key = key.replace("action_time_mlp_out.", "time_mlp_out.") + # Also handle state_proj which shouldn't exist in pi05 + if key.startswith("state_proj."): + logging.warning(f"Skipping state_proj key in pi05 mode: {key}") + continue + else: + # non-pi05 model expects action_time_mlp_*, but checkpoint might have time_mlp_* + if key.startswith("time_mlp_in."): + new_key = key.replace("time_mlp_in.", "action_time_mlp_in.") + elif key.startswith("time_mlp_out."): + new_key = key.replace("time_mlp_out.", "action_time_mlp_out.") + + # Handle vision tower embedding layer potential differences + if "patch_embedding" in key: + # Some checkpoints might have this, but current model expects different structure + logging.warning(f"Vision embedding key might need handling: {key}") + + fixed_state_dict[new_key] = value + + return fixed_state_dict + def get_optim_params(self) -> dict: # see lerobot pi0 `get_optim_params` return self.parameters() diff --git a/test_pi05_hub_only.py b/test_pi05_hub_only.py new file mode 100644 index 000000000..649445f8d --- /dev/null +++ b/test_pi05_hub_only.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python +# TODO(pepijn): delete this file +"""Quick test script to load and test only the PI0.5 model from HuggingFace hub.""" + +from test_pi0_hub import test_hub_loading + + +def main(): + """Test only the PI0.5 model.""" + print("\n") + print("=" * 60) + print("PI0.5 Model Quick Test") + print("=" * 60) + + success = test_hub_loading(model_id="pepijn223/pi05_base_fp32", model_name="PI0.5") + + if success: + print("\n✅ PI0.5 model loaded and tested successfully!") + else: + print("\n❌ PI0.5 test failed!") + + return success + + +if __name__ == "__main__": + success = main() + exit(0 if success else 1) diff --git a/test_pi05_openpi.py b/test_pi05_openpi.py new file mode 100644 index 000000000..60f51c92f --- /dev/null +++ b/test_pi05_openpi.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python + +"""Test script to verify PI0.5 (pi05) support in PI0OpenPI policy.""" + +import torch + +from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy + + +def test_pi05_model_architecture(): + """Test that pi05=True creates the correct model architecture.""" + print("Testing PI0.5 model architecture...") + + # Create config with pi05=True + config = PI0OpenPIConfig( + action_dim=7, + state_dim=14, + dtype="float32", + pi05=True, # Enable PI0.5 mode + ) + + # Verify tokenizer max length is set correctly + assert config.tokenizer_max_length == 48, ( + f"Expected tokenizer_max_length=48 for pi05, got {config.tokenizer_max_length}" + ) + print(f"✓ Tokenizer max length correctly set to {config.tokenizer_max_length}") + + # Verify discrete_state_input defaults to pi05 value + assert config.discrete_state_input == True, ( # noqa: E712 + f"Expected discrete_state_input=True for pi05, got {config.discrete_state_input}" + ) + print(f"✓ discrete_state_input correctly defaults to pi05 value: {config.discrete_state_input}") + + # Create dummy dataset stats + dataset_stats = { + "observation.state": { + "mean": torch.zeros(14), + "std": torch.ones(14), + }, + "action": { + "mean": torch.zeros(7), + "std": torch.ones(7), + }, + } + + # Instantiate policy + policy = PI0OpenPIPolicy(config, dataset_stats) + + # Verify pi05 model components exist + assert policy.model.pi05 == True, "Model pi05 flag not set" # noqa: E712 + print("✓ PI0.5 mode enabled in model") + + # Check that time_mlp layers exist (for AdaRMS conditioning) + assert hasattr(policy.model, "time_mlp_in"), "Missing time_mlp_in layer for pi05" + assert hasattr(policy.model, "time_mlp_out"), "Missing time_mlp_out layer for pi05" + print("✓ Time MLP layers present for AdaRMS conditioning") + + # Check that action_time_mlp layers don't exist (pi0 only) + assert not hasattr(policy.model, "action_time_mlp_in"), "action_time_mlp_in should not exist in pi05 mode" + assert not hasattr(policy.model, "action_time_mlp_out"), ( + "action_time_mlp_out should not exist in pi05 mode" + ) + print("✓ Action-time MLP layers correctly absent") + + # Check that state_proj doesn't exist in pi05 mode + assert not hasattr(policy.model, "state_proj"), "state_proj should not exist in pi05 mode" + print("✓ State projection layer correctly absent") + + # Check AdaRMS configuration in the underlying model + adarms_config = policy.model.paligemma_with_expert.paligemma.config.text_config.use_adarms + assert adarms_config == False, f"PaliGemma should not use AdaRMS, got {adarms_config}" # noqa: E712 + + adarms_expert_config = policy.model.paligemma_with_expert.gemma_expert.config.use_adarms + assert adarms_expert_config == True, ( # noqa: E712 + f"Action expert should use AdaRMS in pi05, got {adarms_expert_config}" + ) + print("✓ AdaRMS correctly configured: PaliGemma=False, Expert=True") + + return True + + +def test_pi05_forward_pass(): + """Test forward pass with pi05=True.""" + print("\nTesting PI0.5 forward pass...") + + # Create config with pi05=True + config = PI0OpenPIConfig( + action_dim=7, + state_dim=14, + dtype="float32", + pi05=True, + action_horizon=16, # Shorter horizon for testing + ) + + # Create dummy dataset stats + dataset_stats = { + "observation.state": { + "mean": torch.zeros(14), + "std": torch.ones(14), + }, + "action": { + "mean": torch.zeros(7), + "std": torch.ones(7), + }, + } + + # Instantiate policy + policy = PI0OpenPIPolicy(config, dataset_stats) + + # Create test batch + batch_size = 2 + device = next(policy.parameters()).device + batch = { + "observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device), + "action": torch.randn(batch_size, config.action_horizon, 7, dtype=torch.float32, device=device), + "observation.images.base_0_rgb": torch.rand( + batch_size, 3, 224, 224, dtype=torch.float32, device=device + ), + "task": ["Pick up the object"] * batch_size, + } + + # Test forward pass + try: + loss, loss_dict = policy.forward(batch) + print(f"✓ Forward pass successful. Loss: {loss_dict['loss']:.4f}") + assert not torch.isnan(loss), "Loss is NaN" + assert loss.item() >= 0, "Loss should be non-negative" + except Exception as e: + print(f"✗ Forward pass failed: {e}") + return False + + # Test action prediction + try: + with torch.no_grad(): + action = policy.select_action(batch) + print(f"✓ Action prediction successful. Action shape: {action.shape}") + assert action.shape == (7,), f"Expected action shape (7,), got {action.shape}" + assert not torch.isnan(action).any(), "Action contains NaN values" + except Exception as e: + print(f"✗ Action prediction failed: {e}") + return False + + return True + + +def test_pi0_vs_pi05_differences(): + """Test key differences between pi0 and pi05 modes.""" + print("\nComparing PI0 vs PI0.5 architectures...") + + # Create both configurations + config_pi0 = PI0OpenPIConfig(action_dim=7, state_dim=14, dtype="float32", pi05=False) + config_pi05 = PI0OpenPIConfig(action_dim=7, state_dim=14, dtype="float32", pi05=True) + + dataset_stats = { + "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)}, + "action": {"mean": torch.zeros(7), "std": torch.ones(7)}, + } + + # Create both models + policy_pi0 = PI0OpenPIPolicy(config_pi0, dataset_stats) + policy_pi05 = PI0OpenPIPolicy(config_pi05, dataset_stats) + + print("\nPI0 Model:") + print(f" - Tokenizer max length: {config_pi0.tokenizer_max_length}") + print(f" - discrete_state_input: {config_pi0.discrete_state_input}") + print(f" - Has state_proj: {hasattr(policy_pi0.model, 'state_proj')}") + print(f" - Has action_time_mlp: {hasattr(policy_pi0.model, 'action_time_mlp_in')}") + print(f" - Has time_mlp: {hasattr(policy_pi0.model, 'time_mlp_in')}") + print(f" - Uses AdaRMS: {policy_pi0.model.paligemma_with_expert.gemma_expert.config.use_adarms}") + + print("\nPI0.5 Model:") + print(f" - Tokenizer max length: {config_pi05.tokenizer_max_length}") + print(f" - discrete_state_input: {config_pi05.discrete_state_input}") + print(f" - Has state_proj: {hasattr(policy_pi05.model, 'state_proj')}") + print(f" - Has action_time_mlp: {hasattr(policy_pi05.model, 'action_time_mlp_in')}") + print(f" - Has time_mlp: {hasattr(policy_pi05.model, 'time_mlp_in')}") + print(f" - Uses AdaRMS: {policy_pi05.model.paligemma_with_expert.gemma_expert.config.use_adarms}") + + # Count parameters + pi0_params = sum(p.numel() for p in policy_pi0.parameters()) + pi05_params = sum(p.numel() for p in policy_pi05.parameters()) + + print("\nParameter counts:") + print(f" - PI0: {pi0_params:,}") + print(f" - PI0.5: {pi05_params:,}") + print(f" - Difference: {pi0_params - pi05_params:,} (PI0.5 has fewer params due to no state embedding)") + + return True + + +def main(): + """Run all PI0.5 tests.""" + print("=" * 60) + print("PI0.5 Support Test Suite") + print("=" * 60) + + tests = [ + ("Model Architecture", test_pi05_model_architecture), + ("Forward Pass", test_pi05_forward_pass), + ("PI0 vs PI0.5 Comparison", test_pi0_vs_pi05_differences), + ] + + all_passed = True + for test_name, test_func in tests: + print(f"\n[{test_name}]") + print("-" * 40) + try: + if not test_func(): + all_passed = False + print(f"✗ {test_name} failed") + except Exception as e: + all_passed = False + print(f"✗ {test_name} failed with exception: {e}") + import traceback + + traceback.print_exc() + + print("\n" + "=" * 60) + if all_passed: + print("✅ All PI0.5 tests passed!") + else: + print("❌ Some tests failed.") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/test_pi0_hub.py b/test_pi0_hub.py index dec52b3c8..96c15cfe7 100644 --- a/test_pi0_hub.py +++ b/test_pi0_hub.py @@ -30,14 +30,16 @@ def create_dummy_stats(config): return dummy_stats -def test_hub_loading(): - """Test loading model from HuggingFace hub.""" - print("=" * 60) - print("PI0OpenPI HuggingFace Hub Loading Test") - print("=" * 60) +def test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0"): + """Test loading model from HuggingFace hub. - # Model ID on HuggingFace hub - model_id = "pepijn223/pi0_base_fp32" # We made sure this config matches our code and `PI0OpenPIConfig` by uploading a model with push_pi0_to_hub.py and copying that config. + Args: + model_id: HuggingFace model ID to load + model_name: Display name for the model (e.g., "PI0", "PI0.5") + """ + print("=" * 60) + print(f"{model_name} OpenPI HuggingFace Hub Loading Test") + print("=" * 60) print(f"\nLoading model from: {model_id}") print("-" * 60) @@ -67,14 +69,45 @@ def test_hub_loading(): # Get model info print("\nModel configuration:") + print(f" - Model type: {'PI0.5' if policy.config.pi05 else 'PI0'}") print(f" - PaliGemma variant: {policy.config.paligemma_variant}") print(f" - Action expert variant: {policy.config.action_expert_variant}") print(f" - Action dimension: {policy.config.action_dim}") print(f" - State dimension: {policy.config.state_dim}") print(f" - Action horizon: {policy.config.action_horizon}") + print(f" - Tokenizer max length: {policy.config.tokenizer_max_length}") + print(f" - discrete_state_input: {policy.config.discrete_state_input}") print(f" - Device: {device}") print(f" - Dtype: {next(policy.parameters()).dtype}") + # Check model-specific features + if policy.config.pi05: + print("\nPI0.5 specific features:") + print(f" - Has time_mlp layers: {hasattr(policy.model, 'time_mlp_in')}") + print(f" - Has state_proj: {hasattr(policy.model, 'state_proj')} (should be False)") + print(f" - Uses AdaRMS: {policy.model.paligemma_with_expert.gemma_expert.config.use_adarms}") + + # Verify PI0.5 architecture + assert hasattr(policy.model, "time_mlp_in"), "PI0.5 should have time_mlp_in" + assert hasattr(policy.model, "time_mlp_out"), "PI0.5 should have time_mlp_out" + assert not hasattr(policy.model, "state_proj"), "PI0.5 should not have state_proj" + assert not hasattr(policy.model, "action_time_mlp_in"), "PI0.5 should not have action_time_mlp_in" + print(" ✓ PI0.5 architecture verified") + else: + print("\nPI0 specific features:") + print(f" - Has action_time_mlp layers: {hasattr(policy.model, 'action_time_mlp_in')}") + print(f" - Has state_proj: {hasattr(policy.model, 'state_proj')} (should be True)") + print( + f" - Uses AdaRMS: {policy.model.paligemma_with_expert.gemma_expert.config.use_adarms} (should be False)" + ) + + # Verify PI0 architecture + assert hasattr(policy.model, "action_time_mlp_in"), "PI0 should have action_time_mlp_in" + assert hasattr(policy.model, "action_time_mlp_out"), "PI0 should have action_time_mlp_out" + assert hasattr(policy.model, "state_proj"), "PI0 should have state_proj" + assert not hasattr(policy.model, "time_mlp_in"), "PI0 should not have time_mlp_in" + print(" ✓ PI0 architecture verified") + except Exception as e: print(f"✗ Failed to load model: {e}") return False @@ -177,11 +210,55 @@ def test_hub_loading(): return False print("\n" + "=" * 60) - print("✓ All tests passed!") + print(f"✓ All tests passed for {model_name}!") print("=" * 60) return True +def main(): + """Run tests for both PI0 and PI0.5 models.""" + print("\n") + print("╔" + "═" * 58 + "╗") + print("║" + " PI0 & PI0.5 HuggingFace Hub Loading Test Suite ".center(58) + "║") + print("╚" + "═" * 58 + "╝") + print() + + results = [] + + # Test PI0 model + print("\n[Test 1/2] Testing PI0 model...") + print("─" * 60) + pi0_success = test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0") + results.append(("PI0", pi0_success)) + + # Test PI0.5 model + print("\n\n[Test 2/2] Testing PI0.5 model...") + print("─" * 60) + pi05_success = test_hub_loading(model_id="pepijn223/pi05_base_fp32", model_name="PI0.5") + results.append(("PI0.5", pi05_success)) + + # Summary + print("\n\n") + print("╔" + "═" * 58 + "╗") + print("║" + " TEST SUMMARY ".center(58) + "║") + print("╚" + "═" * 58 + "╝") + + all_passed = True + for model_name, success in results: + status = "✅ PASSED" if success else "❌ FAILED" + print(f" {model_name:10} : {status}") + if not success: + all_passed = False + + print() + if all_passed: + print("🎉 All models loaded and tested successfully!") + else: + print("⚠️ Some tests failed. Check the output above for details.") + + return all_passed + + if __name__ == "__main__": - success = test_hub_loading() + success = main() exit(0 if success else 1)