test(processor): fix expected raise when normalization types are missing (#2040)

This commit is contained in:
Steven Palma
2025-09-25 15:13:12 +02:00
committed by GitHub
parent 77fc8cc4d9
commit a196639c09
+6 -9
View File
@@ -291,7 +291,7 @@ def test_quantile_division_by_zero():
def test_quantile_partial_stats(): def test_quantile_partial_stats():
"""Test that quantile normalization handles missing quantile stats gracefully.""" """Test that quantile normalization handles missing quantile stats by raising."""
features = { features = {
"observation.state": PolicyFeature(FeatureType.STATE, (2,)), "observation.state": PolicyFeature(FeatureType.STATE, (2,)),
} }
@@ -313,11 +313,8 @@ def test_quantile_partial_stats():
} }
transition = create_transition(observation=observation) transition = create_transition(observation=observation)
normalized_transition = normalizer(transition) with pytest.raises(ValueError, match="QUANTILES normalization mode requires q01 and q99 stats"):
normalized_obs = normalized_transition[TransitionKey.OBSERVATION] _ = normalizer(transition)
# Should pass through unchanged when stats are incomplete
assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"])
def test_quantile_mixed_with_other_modes(): def test_quantile_mixed_with_other_modes():
@@ -771,7 +768,7 @@ def test_empty_stats():
def test_partial_stats(): def test_partial_stats():
"""If statistics are incomplete, the value should pass through unchanged.""" """If statistics are incomplete, we should raise."""
stats = {"observation.image": {"mean": [0.5]}} # Missing std / (min,max) stats = {"observation.image": {"mean": [0.5]}} # Missing std / (min,max)
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
@@ -779,8 +776,8 @@ def test_partial_stats():
observation = {"observation.image": torch.tensor([0.7])} observation = {"observation.image": torch.tensor([0.7])}
transition = create_transition(observation=observation) transition = create_transition(observation=observation)
processed = normalizer(transition)[TransitionKey.OBSERVATION] with pytest.raises(ValueError, match="MEAN_STD normalization mode requires mean and std stats"):
assert torch.allclose(processed["observation.image"], observation["observation.image"]) _ = normalizer(transition)[TransitionKey.OBSERVATION]
def test_missing_action_stats_no_error(): def test_missing_action_stats_no_error():