TODO: Make test works

This commit is contained in:
AdilZouitine
2025-09-19 18:10:47 +02:00
parent 10f5ea854f
commit f077bbae5d
3 changed files with 57 additions and 26 deletions
@@ -866,6 +866,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
self.model.gradient_checkpointing_enable() self.model.gradient_checkpointing_enable()
self.reset() self.reset()
self.model.to(config.device)
@classmethod @classmethod
def from_pretrained( def from_pretrained(
@@ -908,8 +909,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
# Initialize model without loading weights # Initialize model without loading weights
# Check if dataset_stats were provided in kwargs # Check if dataset_stats were provided in kwargs
dataset_stats = kwargs.get("dataset_stats") # TODO(Adil, Pepijn): Remove this with pipeline model = cls(config, **kwargs)
model = cls(config, dataset_stats=dataset_stats, **kwargs)
# Now manually load and remap the state dict # Now manually load and remap the state dict
try: try:
+18 -5
View File
@@ -14,13 +14,22 @@ pytestmark = pytest.mark.skipif(
) )
from lerobot.policies.factory import make_policy_config # noqa: E402 from lerobot.policies.factory import make_policy_config # noqa: E402
from lerobot.policies.pi0 import PI0OpenPIConfig, PI0OpenPIPolicy # noqa: E402 from lerobot.policies.pi0_openpi import ( # noqa: E402
PI0OpenPIConfig,
PI0OpenPIPolicy,
make_pi0_openpi_pre_post_processors, # noqa: E402
)
from lerobot.utils.random_utils import set_seed # noqa: E402
from tests.utils import require_cuda # noqa: E402 from tests.utils import require_cuda # noqa: E402
# Set seed
@require_cuda @require_cuda
def test_policy_instantiation(): def test_policy_instantiation():
# Create config # Create config
set_seed(42)
config = PI0OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32") config = PI0OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32")
# Set up input_features and output_features in the config # Set up input_features and output_features in the config
@@ -61,11 +70,13 @@ def test_policy_instantiation():
} }
# Instantiate policy # Instantiate policy
policy = PI0OpenPIPolicy(config, dataset_stats) policy = PI0OpenPIPolicy(config)
preprocessor, postprocessor = make_pi0_openpi_pre_post_processors(
config=config, dataset_stats=dataset_stats
)
# Test forward pass with dummy data # Test forward pass with dummy data
batch_size = 1 batch_size = 1
device = policy.device if hasattr(policy, "device") else "cpu" device = config.device
batch = { batch = {
"observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device), "observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device),
"action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device), "action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device),
@@ -74,7 +85,7 @@ def test_policy_instantiation():
), # Use rand for [0,1] range ), # Use rand for [0,1] range
"task": ["Pick up the object"] * batch_size, "task": ["Pick up the object"] * batch_size,
} }
batch = preprocessor(batch)
try: try:
loss, loss_dict = policy.forward(batch) loss, loss_dict = policy.forward(batch)
print(f"Forward pass successful. Loss: {loss_dict['loss']:.4f}") print(f"Forward pass successful. Loss: {loss_dict['loss']:.4f}")
@@ -85,6 +96,8 @@ def test_policy_instantiation():
try: try:
with torch.no_grad(): with torch.no_grad():
action = policy.select_action(batch) action = policy.select_action(batch)
action = postprocessor(action)
print(f"Action: {action}")
print(f"Action prediction successful. Action shape: {action.shape}") print(f"Action prediction successful. Action shape: {action.shape}")
except Exception as e: except Exception as e:
print(f"Action prediction failed: {e}") print(f"Action prediction failed: {e}")
@@ -1,6 +1,8 @@
"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation, only meant to be run locally!""" """Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
import os import os
from copy import deepcopy
from typing import Any
import pytest import pytest
import torch import torch
@@ -21,8 +23,11 @@ from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402 from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
from transformers import AutoTokenizer # noqa: E402 from transformers import AutoTokenizer # noqa: E402
from lerobot.policies.pi0 import PI0Config, PI0Policy # noqa: E402 from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy # noqa: E402
from lerobot.policies.pi0_openpi.processor_pi0_openpi import make_pi0_openpi_pre_post_processors # noqa: E402
from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402
# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG
DUMMY_ACTION_DIM = 32 DUMMY_ACTION_DIM = 32
DUMMY_STATE_DIM = 32 DUMMY_STATE_DIM = 32
DUMMY_ACTION_HORIZON = 50 DUMMY_ACTION_HORIZON = 50
@@ -65,27 +70,29 @@ class PI0BaseOriginalConfig:
dtype: str = "float32" dtype: str = "float32"
def instantiate_lerobot_pi0(from_pretrained: bool = False): def instantiate_lerobot_pi0(
from_pretrained: bool = False,
) -> tuple[
PI0OpenPIPolicy,
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
if from_pretrained: if from_pretrained:
# Load the policy first # Load the policy first
policy = PI0Policy.from_pretrained(pretrained_name_or_path="pepijn223/pi0_base_fp32", strict=True) policy = PI0OpenPIPolicy.from_pretrained(
# Then reinitialize the normalization with proper stats pretrained_name_or_path="pepijn223/pi0_base_fp32", strict=True
from lerobot.policies.normalize import Normalize, Unnormalize
policy.normalize_inputs = Normalize(
policy.config.input_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS
)
policy.normalize_targets = Normalize(
policy.config.output_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS
)
policy.unnormalize_outputs = Unnormalize(
policy.config.output_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS
) )
else: else:
config = PI0Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32") config = PI0OpenPIConfig(
policy = PI0Policy(config, DUMMY_DATASET_STATS) max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32"
)
policy = PI0OpenPIPolicy(config)
policy.to(DEVICE) policy.to(DEVICE)
return policy preprocessor, postprocessor = make_pi0_openpi_pre_post_processors(
config=policy.config, dataset_stats=DUMMY_DATASET_STATS
)
return (policy, preprocessor, postprocessor)
def instantiate_original_pi0(from_pretrained: bool = False, model_path: str = None): def instantiate_original_pi0(from_pretrained: bool = False, model_path: str = None):
@@ -324,13 +331,16 @@ def create_original_observation_from_lerobot(lerobot_pi0, batch):
def test_pi0_original_vs_lerobot(): def test_pi0_original_vs_lerobot():
"""Test PI0 original implementation vs LeRobot implementation.""" """Test PI0 original implementation vs LeRobot implementation."""
print("Initializing models...") print("Initializing models...")
lerobot_pi0 = instantiate_lerobot_pi0(from_pretrained=True) # Load pretrained LeRobot model lerobot_pi0, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_pi0(
from_pretrained=True
) # Load pretrained LeRobot model
original_pi0 = instantiate_original_pi0( original_pi0 = instantiate_original_pi0(
from_pretrained=True from_pretrained=True
) # Load pretrained OpenPI model from HuggingFace Hub ) # Load pretrained OpenPI model from HuggingFace Hub
print("Creating dummy data...") print("Creating dummy data...")
batch = create_dummy_data() batch = create_dummy_data()
batch_lerobot = deepcopy(batch)
# Test 1: Each model with its own preprocessing (more realistic end-to-end test) # Test 1: Each model with its own preprocessing (more realistic end-to-end test)
print("\nTEST 1: Each model with its own preprocessing") print("\nTEST 1: Each model with its own preprocessing")
@@ -353,16 +363,24 @@ def test_pi0_original_vs_lerobot():
openpi_actions = original_pi0.sample_actions( openpi_actions = original_pi0.sample_actions(
device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10 device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10
) )
openpi_actions_unit = openpi_actions[:, 0, :]
print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}") print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}")
print(f"OpenPI (own preprocessing) Actions unit shape: {openpi_actions_unit.shape}")
print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}") print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}")
print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}") print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}")
print("Testing LeRobot with own preprocessing...") print("Testing LeRobot with own preprocessing...")
lerobot_pi0.eval() lerobot_pi0.eval()
torch.manual_seed(42) # Set the same seed torch.manual_seed(42) # Set the same seed
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
with torch.no_grad(): with torch.no_grad():
lerobot_actions_own = lerobot_pi0.predict_action_chunk(batch) lerobot_actions_own = lerobot_pi0.predict_action_chunk(
batch_lerobot_processed
) # batch_size, n_action_steps, action_dim
lerobot_ations_unit = lerobot_actions_own[:, 0, :]
print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}") print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}")
print(f"LeRobot (own preprocessing) Actions unit shape: {lerobot_ations_unit.shape}")
print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}") print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}")
print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}") print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}")