mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 03:30:10 +00:00
use dummy stats
This commit is contained in:
+68
-17
@@ -7,6 +7,29 @@ import torch
|
|||||||
from lerobot.policies.pi0_openpi import PI0OpenPIPolicy
|
from lerobot.policies.pi0_openpi import PI0OpenPIPolicy
|
||||||
|
|
||||||
|
|
||||||
|
def create_dummy_stats(config):
|
||||||
|
"""Create dummy dataset statistics for testing."""
|
||||||
|
dummy_stats = {
|
||||||
|
"observation.state": {
|
||||||
|
"mean": torch.zeros(config.state_dim),
|
||||||
|
"std": torch.ones(config.state_dim),
|
||||||
|
},
|
||||||
|
"action": {
|
||||||
|
"mean": torch.zeros(config.action_dim),
|
||||||
|
"std": torch.ones(config.action_dim),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add stats for image keys if they exist
|
||||||
|
for key in config.image_keys:
|
||||||
|
dummy_stats[key] = {
|
||||||
|
"mean": torch.zeros(3, config.image_resolution[0], config.image_resolution[1]),
|
||||||
|
"std": torch.ones(3, config.image_resolution[0], config.image_resolution[1]),
|
||||||
|
}
|
||||||
|
|
||||||
|
return dummy_stats
|
||||||
|
|
||||||
|
|
||||||
def test_hub_loading():
|
def test_hub_loading():
|
||||||
"""Test loading model from HuggingFace hub."""
|
"""Test loading model from HuggingFace hub."""
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
@@ -27,6 +50,21 @@ def test_hub_loading():
|
|||||||
)
|
)
|
||||||
print("✓ Model loaded successfully from HuggingFace hub")
|
print("✓ Model loaded successfully from HuggingFace hub")
|
||||||
|
|
||||||
|
# Inject dummy stats since they aren't loaded from the hub
|
||||||
|
print("Creating dummy dataset stats for testing...")
|
||||||
|
device = next(policy.parameters()).device
|
||||||
|
dummy_stats = create_dummy_stats(policy.config)
|
||||||
|
|
||||||
|
# Move dummy stats to device
|
||||||
|
for key, stats in dummy_stats.items():
|
||||||
|
dummy_stats[key] = {
|
||||||
|
"mean": stats["mean"].to(device),
|
||||||
|
"std": stats["std"].to(device),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Initialize normalization layers with dummy stats if they have NaN/inf values
|
||||||
|
print("✓ Dummy stats created and moved to device")
|
||||||
|
|
||||||
# Get model info
|
# Get model info
|
||||||
print("\nModel configuration:")
|
print("\nModel configuration:")
|
||||||
print(f" - PaliGemma variant: {policy.config.paligemma_variant}")
|
print(f" - PaliGemma variant: {policy.config.paligemma_variant}")
|
||||||
@@ -34,7 +72,7 @@ def test_hub_loading():
|
|||||||
print(f" - Action dimension: {policy.config.action_dim}")
|
print(f" - Action dimension: {policy.config.action_dim}")
|
||||||
print(f" - State dimension: {policy.config.state_dim}")
|
print(f" - State dimension: {policy.config.state_dim}")
|
||||||
print(f" - Action horizon: {policy.config.action_horizon}")
|
print(f" - Action horizon: {policy.config.action_horizon}")
|
||||||
print(f" - Device: {next(policy.parameters()).device}")
|
print(f" - Device: {device}")
|
||||||
print(f" - Dtype: {next(policy.parameters()).dtype}")
|
print(f" - Dtype: {next(policy.parameters()).dtype}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -46,30 +84,43 @@ def test_hub_loading():
|
|||||||
|
|
||||||
# Create dummy batch for testing
|
# Create dummy batch for testing
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
device = next(policy.parameters()).device
|
|
||||||
|
|
||||||
# Create dummy dataset stats if not loaded with the model
|
# Check if normalization layers have invalid stats and replace with dummy stats if needed
|
||||||
if not hasattr(policy, "normalize_inputs") or policy.normalize_inputs is None:
|
try:
|
||||||
|
# Check if the normalize_inputs has valid stats
|
||||||
|
if hasattr(policy.normalize_inputs, "stats"):
|
||||||
|
obs_state_mean = policy.normalize_inputs.stats.get("observation.state", {}).get("mean")
|
||||||
|
if obs_state_mean is not None and (
|
||||||
|
torch.isinf(obs_state_mean).any() or torch.isnan(obs_state_mean).any()
|
||||||
|
):
|
||||||
|
print("⚠️ Found invalid normalization stats, replacing with dummy stats...")
|
||||||
|
|
||||||
|
# Replace with dummy stats
|
||||||
|
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||||
|
|
||||||
|
policy.normalize_inputs = Normalize(
|
||||||
|
policy.config.input_features, policy.config.normalization_mapping, dummy_stats
|
||||||
|
)
|
||||||
|
policy.normalize_targets = Normalize(
|
||||||
|
policy.config.output_features, policy.config.normalization_mapping, dummy_stats
|
||||||
|
)
|
||||||
|
policy.unnormalize_outputs = Unnormalize(
|
||||||
|
policy.config.output_features, policy.config.normalization_mapping, dummy_stats
|
||||||
|
)
|
||||||
|
print("✓ Normalization layers updated with dummy stats")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ Error checking normalization stats, creating new ones: {e}")
|
||||||
|
# Fallback: create new normalization layers
|
||||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||||
|
|
||||||
dataset_stats = {
|
|
||||||
"observation.state": {
|
|
||||||
"mean": torch.zeros(policy.config.state_dim, device=device),
|
|
||||||
"std": torch.ones(policy.config.state_dim, device=device),
|
|
||||||
},
|
|
||||||
"action": {
|
|
||||||
"mean": torch.zeros(policy.config.action_dim, device=device),
|
|
||||||
"std": torch.ones(policy.config.action_dim, device=device),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
policy.normalize_inputs = Normalize(
|
policy.normalize_inputs = Normalize(
|
||||||
policy.config.input_features, policy.config.normalization_mapping, dataset_stats
|
policy.config.input_features, policy.config.normalization_mapping, dummy_stats
|
||||||
)
|
)
|
||||||
policy.normalize_targets = Normalize(
|
policy.normalize_targets = Normalize(
|
||||||
policy.config.output_features, policy.config.normalization_mapping, dataset_stats
|
policy.config.output_features, policy.config.normalization_mapping, dummy_stats
|
||||||
)
|
)
|
||||||
policy.unnormalize_outputs = Unnormalize(
|
policy.unnormalize_outputs = Unnormalize(
|
||||||
policy.config.output_features, policy.config.normalization_mapping, dataset_stats
|
policy.config.output_features, policy.config.normalization_mapping, dummy_stats
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create test batch
|
# Create test batch
|
||||||
|
|||||||
Reference in New Issue
Block a user