From 5659c77988cc8ddf9dff9b12b622f35dd6cfba1d Mon Sep 17 00:00:00 2001 From: Eugene Mironov Date: Wed, 12 Nov 2025 00:22:49 +0700 Subject: [PATCH] Fix PI0.5 RTC tests to use quantile stats (q01, q99) for normalization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- tests/policies/pi0_pi05/test_pi05_rtc.py | 48 +++++++++++++++++++----- 1 file changed, 39 insertions(+), 9 deletions(-) diff --git a/tests/policies/pi0_pi05/test_pi05_rtc.py b/tests/policies/pi0_pi05/test_pi05_rtc.py index c5a54b25a..b58ff49dc 100644 --- a/tests/policies/pi0_pi05/test_pi05_rtc.py +++ b/tests/policies/pi0_pi05/test_pi05_rtc.py @@ -183,10 +183,20 @@ def test_pi05_rtc_inference_with_prev_chunk(): "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), } - # Create dataset stats + # Create dataset stats (PI0.5 uses QUANTILES normalization) dataset_stats = { - "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)}, - "action": {"mean": torch.zeros(7), "std": torch.ones(7)}, + "observation.state": { + "mean": torch.zeros(14), + "std": torch.ones(14), + "q01": -torch.ones(14), + "q99": torch.ones(14), + }, + "action": { + "mean": torch.zeros(7), + "std": torch.ones(7), + "q01": -torch.ones(7), + "q99": torch.ones(7), + }, "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)}, } @@ -260,10 +270,20 @@ def test_pi05_rtc_inference_without_prev_chunk(): "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), } - # Create dataset stats + # Create dataset stats (PI0.5 uses QUANTILES normalization) dataset_stats = { - "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)}, - "action": {"mean": torch.zeros(7), "std": torch.ones(7)}, + "observation.state": { + "mean": torch.zeros(14), + "std": torch.ones(14), + "q01": -torch.ones(14), + "q99": torch.ones(14), + }, + "action": { + "mean": torch.zeros(7), + "std": torch.ones(7), + "q01": -torch.ones(7), + "q99": torch.ones(7), + }, "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)}, } @@ -328,10 +348,20 @@ def test_pi05_rtc_validation_rules(): "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), } - # Create dataset stats + # Create dataset stats (PI0.5 uses QUANTILES normalization) dataset_stats = { - "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)}, - "action": {"mean": torch.zeros(7), "std": torch.ones(7)}, + "observation.state": { + "mean": torch.zeros(14), + "std": torch.ones(14), + "q01": -torch.ones(14), + "q99": torch.ones(14), + }, + "action": { + "mean": torch.zeros(7), + "std": torch.ones(7), + "q01": -torch.ones(7), + "q99": torch.ones(7), + }, "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)}, }