mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
use dummy stats
This commit is contained in:
+68
-17
@@ -7,6 +7,29 @@ import torch
|
||||
from lerobot.policies.pi0_openpi import PI0OpenPIPolicy
|
||||
|
||||
|
||||
def create_dummy_stats(config):
|
||||
"""Create dummy dataset statistics for testing."""
|
||||
dummy_stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(config.state_dim),
|
||||
"std": torch.ones(config.state_dim),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.zeros(config.action_dim),
|
||||
"std": torch.ones(config.action_dim),
|
||||
},
|
||||
}
|
||||
|
||||
# Add stats for image keys if they exist
|
||||
for key in config.image_keys:
|
||||
dummy_stats[key] = {
|
||||
"mean": torch.zeros(3, config.image_resolution[0], config.image_resolution[1]),
|
||||
"std": torch.ones(3, config.image_resolution[0], config.image_resolution[1]),
|
||||
}
|
||||
|
||||
return dummy_stats
|
||||
|
||||
|
||||
def test_hub_loading():
|
||||
"""Test loading model from HuggingFace hub."""
|
||||
print("=" * 60)
|
||||
@@ -27,6 +50,21 @@ def test_hub_loading():
|
||||
)
|
||||
print("✓ Model loaded successfully from HuggingFace hub")
|
||||
|
||||
# Inject dummy stats since they aren't loaded from the hub
|
||||
print("Creating dummy dataset stats for testing...")
|
||||
device = next(policy.parameters()).device
|
||||
dummy_stats = create_dummy_stats(policy.config)
|
||||
|
||||
# Move dummy stats to device
|
||||
for key, stats in dummy_stats.items():
|
||||
dummy_stats[key] = {
|
||||
"mean": stats["mean"].to(device),
|
||||
"std": stats["std"].to(device),
|
||||
}
|
||||
|
||||
# Initialize normalization layers with dummy stats if they have NaN/inf values
|
||||
print("✓ Dummy stats created and moved to device")
|
||||
|
||||
# Get model info
|
||||
print("\nModel configuration:")
|
||||
print(f" - PaliGemma variant: {policy.config.paligemma_variant}")
|
||||
@@ -34,7 +72,7 @@ def test_hub_loading():
|
||||
print(f" - Action dimension: {policy.config.action_dim}")
|
||||
print(f" - State dimension: {policy.config.state_dim}")
|
||||
print(f" - Action horizon: {policy.config.action_horizon}")
|
||||
print(f" - Device: {next(policy.parameters()).device}")
|
||||
print(f" - Device: {device}")
|
||||
print(f" - Dtype: {next(policy.parameters()).dtype}")
|
||||
|
||||
except Exception as e:
|
||||
@@ -46,30 +84,43 @@ def test_hub_loading():
|
||||
|
||||
# Create dummy batch for testing
|
||||
batch_size = 1
|
||||
device = next(policy.parameters()).device
|
||||
|
||||
# Create dummy dataset stats if not loaded with the model
|
||||
if not hasattr(policy, "normalize_inputs") or policy.normalize_inputs is None:
|
||||
# Check if normalization layers have invalid stats and replace with dummy stats if needed
|
||||
try:
|
||||
# Check if the normalize_inputs has valid stats
|
||||
if hasattr(policy.normalize_inputs, "stats"):
|
||||
obs_state_mean = policy.normalize_inputs.stats.get("observation.state", {}).get("mean")
|
||||
if obs_state_mean is not None and (
|
||||
torch.isinf(obs_state_mean).any() or torch.isnan(obs_state_mean).any()
|
||||
):
|
||||
print("⚠️ Found invalid normalization stats, replacing with dummy stats...")
|
||||
|
||||
# Replace with dummy stats
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
|
||||
policy.normalize_inputs = Normalize(
|
||||
policy.config.input_features, policy.config.normalization_mapping, dummy_stats
|
||||
)
|
||||
policy.normalize_targets = Normalize(
|
||||
policy.config.output_features, policy.config.normalization_mapping, dummy_stats
|
||||
)
|
||||
policy.unnormalize_outputs = Unnormalize(
|
||||
policy.config.output_features, policy.config.normalization_mapping, dummy_stats
|
||||
)
|
||||
print("✓ Normalization layers updated with dummy stats")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Error checking normalization stats, creating new ones: {e}")
|
||||
# Fallback: create new normalization layers
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
|
||||
dataset_stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(policy.config.state_dim, device=device),
|
||||
"std": torch.ones(policy.config.state_dim, device=device),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.zeros(policy.config.action_dim, device=device),
|
||||
"std": torch.ones(policy.config.action_dim, device=device),
|
||||
},
|
||||
}
|
||||
policy.normalize_inputs = Normalize(
|
||||
policy.config.input_features, policy.config.normalization_mapping, dataset_stats
|
||||
policy.config.input_features, policy.config.normalization_mapping, dummy_stats
|
||||
)
|
||||
policy.normalize_targets = Normalize(
|
||||
policy.config.output_features, policy.config.normalization_mapping, dataset_stats
|
||||
policy.config.output_features, policy.config.normalization_mapping, dummy_stats
|
||||
)
|
||||
policy.unnormalize_outputs = Unnormalize(
|
||||
policy.config.output_features, policy.config.normalization_mapping, dataset_stats
|
||||
policy.config.output_features, policy.config.normalization_mapping, dummy_stats
|
||||
)
|
||||
|
||||
# Create test batch
|
||||
|
||||
Reference in New Issue
Block a user