add preprocess tests

This commit is contained in:
Pepijn
2025-09-12 21:41:25 +02:00
parent 376cc772ff
commit c8163662ad
3 changed files with 130 additions and 52 deletions
+130 -50
View File
@@ -3,9 +3,11 @@
import os
import torch
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing
# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions.
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch
from transformers import AutoTokenizer
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy
@@ -54,7 +56,9 @@ class PI0BaseOriginalConfig:
def instantiate_lerobot_pi0(from_pretrained: bool = False):
if from_pretrained:
# Load the policy first
policy = PI0OpenPIPolicy.from_pretrained("pepijn223/pi0_base_fp32")
policy = PI0OpenPIPolicy.from_pretrained(
pretrained_name_or_path="pepijn223/pi0_base_fp32", strict=True
)
# Then reinitialize the normalization with proper stats
from lerobot.policies.normalize import Normalize, Unnormalize
@@ -153,16 +157,16 @@ def create_dummy_data():
"action": torch.randn(
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device
),
# Create images in [-1, 1] range as expected by both implementations
"observation.images.base_0_rgb": torch.randn(
# Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally)
"observation.images.base_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
).clamp(-1, 1),
"observation.images.left_wrist_0_rgb": torch.randn(
),
"observation.images.left_wrist_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
).clamp(-1, 1),
"observation.images.right_wrist_0_rgb": torch.randn(
),
"observation.images.right_wrist_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
).clamp(-1, 1),
),
# Add the task prompt for LeRobot - provide as list with single element to trigger expansion
"task": [prompt],
}
@@ -175,7 +179,7 @@ def extract_lerobot_processed_inputs(lerobot_pi0, batch):
lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch)
# Get the preprocessed images from LeRobot's internal method
images, img_masks = lerobot_pi0._preprocess_images(batch)
images, img_masks = lerobot_pi0._preprocess_images(batch, train=False)
# Create dummy token_ar_mask and token_loss_mask for original implementation
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
@@ -206,6 +210,72 @@ class PI0Observation:
self.token_loss_mask = token_loss_mask
def create_original_observation_with_openpi_preprocessing(batch):
"""Create observation object for OpenPI using OpenPI's own preprocessing."""
batch_size = batch["observation.state"].shape[0]
device = batch["observation.state"].device
# Create tokenizer for OpenPI (same as LeRobot uses)
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
# Get task description
if "task" in batch:
tasks = batch["task"]
if isinstance(tasks, str):
tasks = [tasks]
elif isinstance(tasks, list) and len(tasks) == 1:
# Expand to batch size
tasks = tasks * batch_size
else:
# Default task if not provided
tasks = ["Pick up the object"] * batch_size
# Tokenize with max_length padding to match OpenPI's expected format
tokenized = tokenizer(
tasks,
padding="max_length",
padding_side="right",
truncation=True,
max_length=DUMMY_MAX_TOKEN_LEN,
return_tensors="pt",
)
lang_tokens = tokenized["input_ids"].to(device)
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
# Create dummy token_ar_mask and token_loss_mask for OpenPI
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
# Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range)
image_dict = {
"base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0,
"left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0,
"right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0,
}
# Create image masks (all ones for real images)
image_masks_dict = {}
for key in image_dict:
image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device)
# Create raw observation object (before preprocessing)
raw_observation = PI0Observation(
state=batch["observation.state"],
images=image_dict,
image_masks=image_masks_dict,
tokenized_prompt=lang_tokens,
tokenized_prompt_mask=lang_masks,
token_ar_mask=token_ar_mask,
token_loss_mask=token_loss_mask,
)
# Now use OpenPI's preprocessing
processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False)
return processed_obs
def create_original_observation_from_lerobot(lerobot_pi0, batch):
"""Create observation object compatible with original OpenPI using the exact same inputs as LeRobot."""
_batch_size = batch["observation.state"].shape[0]
@@ -251,65 +321,75 @@ def main():
print("Creating dummy data...")
batch = create_dummy_data()
print("Creating observation for original PI0 using LeRobot's exact preprocessing...")
pi0_obs = create_original_observation_from_lerobot(lerobot_pi0, batch)
# Test 1: Each model with its own preprocessing (more realistic end-to-end test)
print("\n=== TEST 1: Each model with its own preprocessing ===")
print("Creating observation for OpenPI using OpenPI's own preprocessing...")
pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)
# Verify both implementations get the same inputs
print(f"Task prompt: '{batch['task'][0]}'")
print(f"Tokenized prompt shape: {pi0_obs.tokenized_prompt.shape}")
print(f"Image shapes: {[img.shape for img in pi0_obs.images.values()]}")
print(f"State shape: {pi0_obs.state.shape}")
print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}")
print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}")
print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}")
print("Testing original PI0...")
# Test training forward pass (returns loss)
print("1. Training forward pass (computing loss):")
original_pi0.train()
original_loss = original_pi0(observation=pi0_obs, actions=batch["action"])
print(f" Loss shape: {original_loss.shape}, Mean loss: {original_loss.mean().item():.6f}")
# Test inference (action sampling) with fixed noise for reproducibility
print("2. Inference (action sampling):")
print("Testing OpenPI with own preprocessing...")
original_pi0.eval()
# Create the same noise for both implementations
torch.manual_seed(42) # Set seed for reproducibility
batch_size = batch["observation.state"].shape[0]
noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM)
fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE)
with torch.no_grad():
original_actions = original_pi0.sample_actions(
device=DEVICE, observation=pi0_obs, noise=fixed_noise, num_steps=10
openpi_actions = original_pi0.sample_actions(
device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10
)
print(f"Original PI0 Actions shape: {original_actions.shape}")
print(f"Original PI0 Actions mean: {original_actions.mean().item():.6f}")
print(f"Original PI0 Actions std: {original_actions.std().item():.6f}")
print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}")
print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}")
print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}")
# Test LeRobot implementation with the same noise
print("\nTesting LeRobot PI0...")
print("Testing LeRobot with own preprocessing...")
lerobot_pi0.eval()
# For LeRobot, we need to modify the batch to force the same noise
# This is more complex since LeRobot generates noise internally
torch.manual_seed(42) # Set the same seed
with torch.no_grad():
# lerobot_pi0_actions = lerobot_pi0.select_action(batch)
lerobot_pi0_actions = lerobot_pi0.predict_action_chunk(batch)
print(f"LeRobot actions shape: {lerobot_pi0_actions.shape}")
print(f"LeRobot actions mean: {lerobot_pi0_actions.mean().item():.6f}")
print(f"LeRobot actions std: {lerobot_pi0_actions.std().item():.6f}")
lerobot_actions_own = lerobot_pi0.predict_action_chunk(batch)
print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}")
print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}")
print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}")
print("\nComparing implementations:")
print(f"Original actions shape: {original_actions.shape}")
print(f"LeRobot actions shape: {lerobot_pi0_actions.shape}")
print("\nComparing end-to-end implementations:")
print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}")
print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}")
print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")
# Compare the first action step (since LeRobot select_action returns a single step)
print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_pi0_actions, original_actions, atol=1e-4)}")
print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_pi0_actions, original_actions, atol=1e-2)}")
print(f"Max absolute difference: {torch.abs(lerobot_pi0_actions - original_actions).max().item():.6f}")
# Test 2: Both models with LeRobot preprocessing (isolates model differences)
print("\n=== TEST 2: Both models with LeRobot preprocessing (model comparison) ===")
print("Creating observation for OpenPI using LeRobot's preprocessing...")
pi0_obs_lerobot = create_original_observation_from_lerobot(lerobot_pi0, batch)
print("\nOriginal PI0 test completed successfully!")
print("Testing OpenPI with LeRobot preprocessing...")
torch.manual_seed(42) # Set seed for reproducibility
with torch.no_grad():
openpi_actions_lerobot_preproc = original_pi0.sample_actions(
device=DEVICE, observation=pi0_obs_lerobot, noise=fixed_noise, num_steps=10
)
print(f"OpenPI (LeRobot preprocessing) Actions shape: {openpi_actions_lerobot_preproc.shape}")
print(f"OpenPI (LeRobot preprocessing) Actions mean: {openpi_actions_lerobot_preproc.mean().item():.6f}")
print(f"OpenPI (LeRobot preprocessing) Actions std: {openpi_actions_lerobot_preproc.std().item():.6f}")
print("\nComparing models with same preprocessing:")
print(
f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions_lerobot_preproc, atol=1e-4)}"
)
print(
f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions_lerobot_preproc, atol=1e-2)}"
)
print(
f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions_lerobot_preproc).max().item():.6f}"
)
print("\n=== SUMMARY ===")
print("Test 1 compares end-to-end pipelines (each model with its own preprocessing)")
print("Test 2 isolates model differences (both models with LeRobot preprocessing)")
print("Both tests completed successfully!")
if __name__ == "__main__":