test(processor): fix expected raise when normalization types are missing (#2040)

This commit is contained in:
Steven Palma
2025-09-25 15:13:12 +02:00
committed by GitHub
parent 77fc8cc4d9
commit a196639c09
+6 -9
View File
@@ -291,7 +291,7 @@ def test_quantile_division_by_zero():
def test_quantile_partial_stats():
"""Test that quantile normalization handles missing quantile stats gracefully."""
"""Test that quantile normalization handles missing quantile stats by raising."""
features = {
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
}
@@ -313,11 +313,8 @@ def test_quantile_partial_stats():
}
transition = create_transition(observation=observation)
normalized_transition = normalizer(transition)
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
# Should pass through unchanged when stats are incomplete
assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"])
with pytest.raises(ValueError, match="QUANTILES normalization mode requires q01 and q99 stats"):
_ = normalizer(transition)
def test_quantile_mixed_with_other_modes():
@@ -771,7 +768,7 @@ def test_empty_stats():
def test_partial_stats():
"""If statistics are incomplete, the value should pass through unchanged."""
"""If statistics are incomplete, we should raise."""
stats = {"observation.image": {"mean": [0.5]}} # Missing std / (min,max)
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
@@ -779,8 +776,8 @@ def test_partial_stats():
observation = {"observation.image": torch.tensor([0.7])}
transition = create_transition(observation=observation)
processed = normalizer(transition)[TransitionKey.OBSERVATION]
assert torch.allclose(processed["observation.image"], observation["observation.image"])
with pytest.raises(ValueError, match="MEAN_STD normalization mode requires mean and std stats"):
_ = normalizer(transition)[TransitionKey.OBSERVATION]
def test_missing_action_stats_no_error():