From a196639c099b1ee79e0eafe34a99e0ca4b7be6fe Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 25 Sep 2025 15:13:12 +0200 Subject: [PATCH] test(processor): fix expected raise when normalization types are missing (#2040) --- tests/processor/test_normalize_processor.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index 9669b4ea9..f07b7af45 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -291,7 +291,7 @@ def test_quantile_division_by_zero(): def test_quantile_partial_stats(): - """Test that quantile normalization handles missing quantile stats gracefully.""" + """Test that quantile normalization handles missing quantile stats by raising.""" features = { "observation.state": PolicyFeature(FeatureType.STATE, (2,)), } @@ -313,11 +313,8 @@ def test_quantile_partial_stats(): } transition = create_transition(observation=observation) - normalized_transition = normalizer(transition) - normalized_obs = normalized_transition[TransitionKey.OBSERVATION] - - # Should pass through unchanged when stats are incomplete - assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"]) + with pytest.raises(ValueError, match="QUANTILES normalization mode requires q01 and q99 stats"): + _ = normalizer(transition) def test_quantile_mixed_with_other_modes(): @@ -771,7 +768,7 @@ def test_empty_stats(): def test_partial_stats(): - """If statistics are incomplete, the value should pass through unchanged.""" + """If statistics are incomplete, we should raise.""" stats = {"observation.image": {"mean": [0.5]}} # Missing std / (min,max) features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} @@ -779,8 +776,8 @@ def test_partial_stats(): observation = {"observation.image": torch.tensor([0.7])} transition = create_transition(observation=observation) - processed = normalizer(transition)[TransitionKey.OBSERVATION] - assert torch.allclose(processed["observation.image"], observation["observation.image"]) + with pytest.raises(ValueError, match="MEAN_STD normalization mode requires mean and std stats"): + _ = normalizer(transition)[TransitionKey.OBSERVATION] def test_missing_action_stats_no_error():