diff --git a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py index 35c6f7c9a..ce813fdb8 100644 --- a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py +++ b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py @@ -991,14 +991,22 @@ class PI0OpenPIPolicy(PreTrainedPolicy): r"paligemma_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight", key, ): - # This key structure suggests old model without adaRMS - keep as is or skip - logging.warning(f"Skipping old layer norm key (no adaRMS support): {key}") - continue + # Check if the model actually has adaRMS enabled for the expert + expert_uses_adarms = getattr( + self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False + ) + if expert_uses_adarms: + logging.warning(f"Skipping layer norm key (adaRMS mismatch): {key}") + continue if re.match(r"paligemma_with_expert\.gemma_expert\.model\.norm\.weight", key): - # Skip old norm structure - logging.warning(f"Skipping old norm key (no adaRMS support): {key}") - continue + # Check if the model actually has adaRMS enabled for the expert + expert_uses_adarms = getattr( + self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False + ) + if expert_uses_adarms: + logging.warning(f"Skipping norm key (adaRMS mismatch): {key}") + continue # Handle MLP naming changes for pi0 # non-pi05 model expects action_time_mlp_*, but checkpoint might have time_mlp_* diff --git a/test_pi0_original_vs_lerobot.py b/test_pi0_original_vs_lerobot.py new file mode 100644 index 000000000..e5cdf3dd7 --- /dev/null +++ b/test_pi0_original_vs_lerobot.py @@ -0,0 +1,316 @@ +"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation.""" + +import os + +import torch + +# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions. +from openpi.models_pytorch.pi0_pytorch import PI0Pytorch + +from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy + +DUMMY_ACTION_DIM = 32 +DUMMY_STATE_DIM = 32 +DUMMY_ACTION_HORIZON = 50 +DUMMY_MAX_TOKEN_LEN = 48 # Default for PI0 (non-pi05) +DEVICE = "cpu" # Use CPU to avoid memory issues for testing + +DUMMY_DATASET_STATS = { + "observation.state": { + "mean": torch.zeros(DUMMY_STATE_DIM), + "std": torch.ones(DUMMY_STATE_DIM), + }, + "action": { + "mean": torch.zeros(DUMMY_ACTION_DIM), + "std": torch.ones(DUMMY_ACTION_DIM), + }, + "images": { + "base_0_rgb": { + "mean": torch.zeros(3, 224, 224), + "std": torch.ones(3, 224, 224), + }, + "left_wrist_0_rgb": { + "mean": torch.zeros(3, 224, 224), + "std": torch.ones(3, 224, 224), + }, + "right_wrist_0_rgb": { + "mean": torch.zeros(3, 224, 224), + "std": torch.ones(3, 224, 224), + }, + }, +} + + +class PI0BaseOriginalConfig: + action_dim: int = DUMMY_ACTION_DIM + action_horizon: int = DUMMY_ACTION_HORIZON + paligemma_variant: str = "gemma_2b" + action_expert_variant: str = "gemma_300m" + precision: str = "float32" + pi05: bool = False + dtype: str = "float32" + + +def instantiate_lerobot_pi0(from_pretrained: bool = False): + if from_pretrained: + # Load the policy first + policy = PI0OpenPIPolicy.from_pretrained("pepijn223/pi0_base_fp32") + # Then reinitialize the normalization with proper stats + from lerobot.policies.normalize import Normalize, Unnormalize + + policy.normalize_inputs = Normalize( + policy.config.input_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS + ) + policy.normalize_targets = Normalize( + policy.config.output_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS + ) + policy.unnormalize_outputs = Unnormalize( + policy.config.output_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS + ) + else: + config = PI0OpenPIConfig(action_dim=DUMMY_ACTION_DIM, state_dim=DUMMY_STATE_DIM, dtype="float32") + policy = PI0OpenPIPolicy(config, DUMMY_DATASET_STATS) + policy.to(DEVICE) + return policy + + +def instantiate_original_pi0(from_pretrained: bool = False, model_path: str = None): + config = PI0BaseOriginalConfig() + policy = PI0Pytorch(config) + + if from_pretrained: + try: + print("Loading converted PyTorch weights from HuggingFace Hub (pepijn223/pi0_base_fp32)...") + + # Download the model from HuggingFace Hub + import safetensors.torch + from huggingface_hub import snapshot_download + + # Download the entire repository + if model_path and os.path.exists(model_path): + cache_dir = model_path + print(f"Using cached model from: {cache_dir}") + else: + cache_dir = snapshot_download(repo_id="pepijn223/pi0_base_fp32", repo_type="model") + print(f"Downloaded model to: {cache_dir}") + + # Try to load safetensors format first + model_file = os.path.join(cache_dir, "model.safetensors") + if os.path.exists(model_file): + state_dict = safetensors.torch.load_file(model_file) + print(f"Loaded {len(state_dict)} parameters from safetensors") + else: + raise FileNotFoundError(f"No safetensors file found in {cache_dir}") + + # Load the state dict into the model + missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False) + + if missing_keys: + print(f"Missing keys: {len(missing_keys)}") + if len(missing_keys) <= 5: + for key in missing_keys: + print(f" - {key}") + else: + for key in missing_keys[:5]: + print(f" - {key}") + print(f" ... and {len(missing_keys) - 5} more") + + if unexpected_keys: + print(f"Unexpected keys: {len(unexpected_keys)}") + if len(unexpected_keys) <= 5: + for key in unexpected_keys: + print(f" - {key}") + else: + for key in unexpected_keys[:5]: + print(f" - {key}") + print(f" ... and {len(unexpected_keys) - 5} more") + + if not missing_keys and not unexpected_keys: + print("All pretrained weights loaded successfully!") + else: + print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)") + + except Exception as e: + print(f"Failed to load pretrained weights: {e}") + print(" Using randomly initialized weights...") + import traceback + + traceback.print_exc() + + policy.to(DEVICE) + return policy + + +def create_dummy_data(): + batch_size = 2 # Reduce batch size for testing + device = DEVICE + + # Use the exact same prompt for both implementations + prompt = "Pick up the red block and place it in the bin" + + batch = { + "observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device), + "action": torch.randn( + batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device + ), + # Create images in [-1, 1] range as expected by both implementations + "observation.images.base_0_rgb": torch.randn( + batch_size, 3, 224, 224, dtype=torch.float32, device=device + ).clamp(-1, 1), + "observation.images.left_wrist_0_rgb": torch.randn( + batch_size, 3, 224, 224, dtype=torch.float32, device=device + ).clamp(-1, 1), + "observation.images.right_wrist_0_rgb": torch.randn( + batch_size, 3, 224, 224, dtype=torch.float32, device=device + ).clamp(-1, 1), + # Add the task prompt for LeRobot - provide as list with single element to trigger expansion + "task": [prompt], + } + return batch + + +def extract_lerobot_processed_inputs(lerobot_pi0, batch): + """Extract the exact same processed inputs that LeRobot uses internally.""" + # Get the tokenized language from LeRobot's internal method + lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch) + + # Get the preprocessed images from LeRobot's internal method + images, img_masks = lerobot_pi0._preprocess_images(batch) + + # Create dummy token_ar_mask and token_loss_mask for original implementation + token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32) + token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool) + + return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask + + +class PI0Observation: + """Observation class that matches the original OpenPI format.""" + + def __init__( + self, + state, + images, + image_masks, + tokenized_prompt, + tokenized_prompt_mask, + token_ar_mask, + token_loss_mask, + ): + self.state = state + self.images = images + self.image_masks = image_masks + self.tokenized_prompt = tokenized_prompt + self.tokenized_prompt_mask = tokenized_prompt_mask + self.token_ar_mask = token_ar_mask + self.token_loss_mask = token_loss_mask + + +def create_original_observation_from_lerobot(lerobot_pi0, batch): + """Create observation object compatible with original OpenPI using the exact same inputs as LeRobot.""" + _batch_size = batch["observation.state"].shape[0] + _device = batch["observation.state"].device + + # Extract the exact same processed inputs that LeRobot uses + images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = ( + extract_lerobot_processed_inputs(lerobot_pi0, batch) + ) + + # Convert images list to dict with original OpenPI keys + image_dict = { + "base_0_rgb": images[0], + "left_wrist_0_rgb": images[1], + "right_wrist_0_rgb": images[2], + } + + # Convert image masks list to dict with original OpenPI keys + image_masks_dict = { + "base_0_rgb": img_masks[0], + "left_wrist_0_rgb": img_masks[1], + "right_wrist_0_rgb": img_masks[2], + } + + return PI0Observation( + state=batch["observation.state"], + images=image_dict, + image_masks=image_masks_dict, + tokenized_prompt=lang_tokens, + tokenized_prompt_mask=lang_masks, + token_ar_mask=token_ar_mask, + token_loss_mask=token_loss_mask, + ) + + +def main(): + print("Initializing models...") + lerobot_pi0 = instantiate_lerobot_pi0(from_pretrained=True) # Load pretrained LeRobot model + original_pi0 = instantiate_original_pi0( + from_pretrained=True + ) # Load pretrained OpenPI model from HuggingFace Hub + + print("Creating dummy data...") + batch = create_dummy_data() + + print("Creating observation for original PI0 using LeRobot's exact preprocessing...") + pi0_obs = create_original_observation_from_lerobot(lerobot_pi0, batch) + + # Verify both implementations get the same inputs + print(f"Task prompt: '{batch['task'][0]}'") + print(f"Tokenized prompt shape: {pi0_obs.tokenized_prompt.shape}") + print(f"Image shapes: {[img.shape for img in pi0_obs.images.values()]}") + print(f"State shape: {pi0_obs.state.shape}") + + print("Testing original PI0...") + + # Test training forward pass (returns loss) + print("1. Training forward pass (computing loss):") + original_pi0.train() + original_loss = original_pi0(observation=pi0_obs, actions=batch["action"]) + print(f" Loss shape: {original_loss.shape}, Mean loss: {original_loss.mean().item():.6f}") + + # Test inference (action sampling) with fixed noise for reproducibility + print("2. Inference (action sampling):") + original_pi0.eval() + + # Create the same noise for both implementations + torch.manual_seed(42) # Set seed for reproducibility + batch_size = batch["observation.state"].shape[0] + noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM) + fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE) + + with torch.no_grad(): + original_actions = original_pi0.sample_actions( + device=DEVICE, observation=pi0_obs, noise=fixed_noise, num_steps=10 + ) + print(f"Original PI0 Actions shape: {original_actions.shape}") + print(f"Original PI0 Actions mean: {original_actions.mean().item():.6f}") + print(f"Original PI0 Actions std: {original_actions.std().item():.6f}") + + # Test LeRobot implementation with the same noise + print("\nTesting LeRobot PI0...") + lerobot_pi0.eval() + + # For LeRobot, we need to modify the batch to force the same noise + # This is more complex since LeRobot generates noise internally + torch.manual_seed(42) # Set the same seed + with torch.no_grad(): + # lerobot_pi0_actions = lerobot_pi0.select_action(batch) + lerobot_pi0_actions = lerobot_pi0.predict_action_chunk(batch) + print(f"LeRobot actions shape: {lerobot_pi0_actions.shape}") + print(f"LeRobot actions mean: {lerobot_pi0_actions.mean().item():.6f}") + print(f"LeRobot actions std: {lerobot_pi0_actions.std().item():.6f}") + + print("\nComparing implementations:") + print(f"Original actions shape: {original_actions.shape}") + print(f"LeRobot actions shape: {lerobot_pi0_actions.shape}") + + # Compare the first action step (since LeRobot select_action returns a single step) + print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_pi0_actions, original_actions, atol=1e-4)}") + print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_pi0_actions, original_actions, atol=1e-2)}") + print(f"Max absolute difference: {torch.abs(lerobot_pi0_actions - original_actions).max().item():.6f}") + + print("\nOriginal PI0 test completed successfully!") + + +if __name__ == "__main__": + main()