Fix PI0.5 RTC tests to use quantile stats (q01, q99) for normalization

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Eugene Mironov
2025-11-12 00:22:49 +07:00
parent fd88a3acda
commit 5659c77988
+39 -9
View File
@@ -183,10 +183,20 @@ def test_pi05_rtc_inference_with_prev_chunk():
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Create dataset stats
# Create dataset stats (PI0.5 uses QUANTILES normalization)
dataset_stats = {
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
"observation.state": {
"mean": torch.zeros(14),
"std": torch.ones(14),
"q01": -torch.ones(14),
"q99": torch.ones(14),
},
"action": {
"mean": torch.zeros(7),
"std": torch.ones(7),
"q01": -torch.ones(7),
"q99": torch.ones(7),
},
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
}
@@ -260,10 +270,20 @@ def test_pi05_rtc_inference_without_prev_chunk():
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Create dataset stats
# Create dataset stats (PI0.5 uses QUANTILES normalization)
dataset_stats = {
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
"observation.state": {
"mean": torch.zeros(14),
"std": torch.ones(14),
"q01": -torch.ones(14),
"q99": torch.ones(14),
},
"action": {
"mean": torch.zeros(7),
"std": torch.ones(7),
"q01": -torch.ones(7),
"q99": torch.ones(7),
},
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
}
@@ -328,10 +348,20 @@ def test_pi05_rtc_validation_rules():
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Create dataset stats
# Create dataset stats (PI0.5 uses QUANTILES normalization)
dataset_stats = {
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
"observation.state": {
"mean": torch.zeros(14),
"std": torch.ones(14),
"q01": -torch.ones(14),
"q99": torch.ones(14),
},
"action": {
"mean": torch.zeros(7),
"std": torch.ones(7),
"q01": -torch.ones(7),
"q99": torch.ones(7),
},
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
}