use dummy stats

This commit is contained in:
Pepijn
2025-09-10 20:42:48 +02:00
parent 2eafcc7ca1
commit b028907d21
+68 -17
View File
@@ -7,6 +7,29 @@ import torch
from lerobot.policies.pi0_openpi import PI0OpenPIPolicy
def create_dummy_stats(config):
"""Create dummy dataset statistics for testing."""
dummy_stats = {
"observation.state": {
"mean": torch.zeros(config.state_dim),
"std": torch.ones(config.state_dim),
},
"action": {
"mean": torch.zeros(config.action_dim),
"std": torch.ones(config.action_dim),
},
}
# Add stats for image keys if they exist
for key in config.image_keys:
dummy_stats[key] = {
"mean": torch.zeros(3, config.image_resolution[0], config.image_resolution[1]),
"std": torch.ones(3, config.image_resolution[0], config.image_resolution[1]),
}
return dummy_stats
def test_hub_loading():
"""Test loading model from HuggingFace hub."""
print("=" * 60)
@@ -27,6 +50,21 @@ def test_hub_loading():
)
print("✓ Model loaded successfully from HuggingFace hub")
# Inject dummy stats since they aren't loaded from the hub
print("Creating dummy dataset stats for testing...")
device = next(policy.parameters()).device
dummy_stats = create_dummy_stats(policy.config)
# Move dummy stats to device
for key, stats in dummy_stats.items():
dummy_stats[key] = {
"mean": stats["mean"].to(device),
"std": stats["std"].to(device),
}
# Initialize normalization layers with dummy stats if they have NaN/inf values
print("✓ Dummy stats created and moved to device")
# Get model info
print("\nModel configuration:")
print(f" - PaliGemma variant: {policy.config.paligemma_variant}")
@@ -34,7 +72,7 @@ def test_hub_loading():
print(f" - Action dimension: {policy.config.action_dim}")
print(f" - State dimension: {policy.config.state_dim}")
print(f" - Action horizon: {policy.config.action_horizon}")
print(f" - Device: {next(policy.parameters()).device}")
print(f" - Device: {device}")
print(f" - Dtype: {next(policy.parameters()).dtype}")
except Exception as e:
@@ -46,30 +84,43 @@ def test_hub_loading():
# Create dummy batch for testing
batch_size = 1
device = next(policy.parameters()).device
# Create dummy dataset stats if not loaded with the model
if not hasattr(policy, "normalize_inputs") or policy.normalize_inputs is None:
# Check if normalization layers have invalid stats and replace with dummy stats if needed
try:
# Check if the normalize_inputs has valid stats
if hasattr(policy.normalize_inputs, "stats"):
obs_state_mean = policy.normalize_inputs.stats.get("observation.state", {}).get("mean")
if obs_state_mean is not None and (
torch.isinf(obs_state_mean).any() or torch.isnan(obs_state_mean).any()
):
print("⚠️ Found invalid normalization stats, replacing with dummy stats...")
# Replace with dummy stats
from lerobot.policies.normalize import Normalize, Unnormalize
policy.normalize_inputs = Normalize(
policy.config.input_features, policy.config.normalization_mapping, dummy_stats
)
policy.normalize_targets = Normalize(
policy.config.output_features, policy.config.normalization_mapping, dummy_stats
)
policy.unnormalize_outputs = Unnormalize(
policy.config.output_features, policy.config.normalization_mapping, dummy_stats
)
print("✓ Normalization layers updated with dummy stats")
except Exception as e:
print(f"⚠️ Error checking normalization stats, creating new ones: {e}")
# Fallback: create new normalization layers
from lerobot.policies.normalize import Normalize, Unnormalize
dataset_stats = {
"observation.state": {
"mean": torch.zeros(policy.config.state_dim, device=device),
"std": torch.ones(policy.config.state_dim, device=device),
},
"action": {
"mean": torch.zeros(policy.config.action_dim, device=device),
"std": torch.ones(policy.config.action_dim, device=device),
},
}
policy.normalize_inputs = Normalize(
policy.config.input_features, policy.config.normalization_mapping, dataset_stats
policy.config.input_features, policy.config.normalization_mapping, dummy_stats
)
policy.normalize_targets = Normalize(
policy.config.output_features, policy.config.normalization_mapping, dataset_stats
policy.config.output_features, policy.config.normalization_mapping, dummy_stats
)
policy.unnormalize_outputs = Unnormalize(
policy.config.output_features, policy.config.normalization_mapping, dataset_stats
policy.config.output_features, policy.config.normalization_mapping, dummy_stats
)
# Create test batch