diff --git a/tests/policies/pi0_pi05/test_pi05_rtc.py b/tests/policies/pi0_pi05/test_pi05_rtc.py index c5a54b25a..b58ff49dc 100644 --- a/tests/policies/pi0_pi05/test_pi05_rtc.py +++ b/tests/policies/pi0_pi05/test_pi05_rtc.py @@ -183,10 +183,20 @@ def test_pi05_rtc_inference_with_prev_chunk(): "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), } - # Create dataset stats + # Create dataset stats (PI0.5 uses QUANTILES normalization) dataset_stats = { - "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)}, - "action": {"mean": torch.zeros(7), "std": torch.ones(7)}, + "observation.state": { + "mean": torch.zeros(14), + "std": torch.ones(14), + "q01": -torch.ones(14), + "q99": torch.ones(14), + }, + "action": { + "mean": torch.zeros(7), + "std": torch.ones(7), + "q01": -torch.ones(7), + "q99": torch.ones(7), + }, "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)}, } @@ -260,10 +270,20 @@ def test_pi05_rtc_inference_without_prev_chunk(): "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), } - # Create dataset stats + # Create dataset stats (PI0.5 uses QUANTILES normalization) dataset_stats = { - "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)}, - "action": {"mean": torch.zeros(7), "std": torch.ones(7)}, + "observation.state": { + "mean": torch.zeros(14), + "std": torch.ones(14), + "q01": -torch.ones(14), + "q99": torch.ones(14), + }, + "action": { + "mean": torch.zeros(7), + "std": torch.ones(7), + "q01": -torch.ones(7), + "q99": torch.ones(7), + }, "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)}, } @@ -328,10 +348,20 @@ def test_pi05_rtc_validation_rules(): "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), } - # Create dataset stats + # Create dataset stats (PI0.5 uses QUANTILES normalization) dataset_stats = { - "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)}, - "action": {"mean": torch.zeros(7), "std": torch.ones(7)}, + "observation.state": { + "mean": torch.zeros(14), + "std": torch.ones(14), + "q01": -torch.ones(14), + "q99": torch.ones(14), + }, + "action": { + "mean": torch.zeros(7), + "std": torch.ones(7), + "q01": -torch.ones(7), + "q99": torch.ones(7), + }, "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)}, }