mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
test(processor): fix expected raise when normalization types are missing (#2040)
This commit is contained in:
@@ -291,7 +291,7 @@ def test_quantile_division_by_zero():
|
||||
|
||||
|
||||
def test_quantile_partial_stats():
|
||||
"""Test that quantile normalization handles missing quantile stats gracefully."""
|
||||
"""Test that quantile normalization handles missing quantile stats by raising."""
|
||||
features = {
|
||||
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
||||
}
|
||||
@@ -313,11 +313,8 @@ def test_quantile_partial_stats():
|
||||
}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
normalized_transition = normalizer(transition)
|
||||
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
|
||||
|
||||
# Should pass through unchanged when stats are incomplete
|
||||
assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"])
|
||||
with pytest.raises(ValueError, match="QUANTILES normalization mode requires q01 and q99 stats"):
|
||||
_ = normalizer(transition)
|
||||
|
||||
|
||||
def test_quantile_mixed_with_other_modes():
|
||||
@@ -771,7 +768,7 @@ def test_empty_stats():
|
||||
|
||||
|
||||
def test_partial_stats():
|
||||
"""If statistics are incomplete, the value should pass through unchanged."""
|
||||
"""If statistics are incomplete, we should raise."""
|
||||
stats = {"observation.image": {"mean": [0.5]}} # Missing std / (min,max)
|
||||
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
|
||||
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||
@@ -779,8 +776,8 @@ def test_partial_stats():
|
||||
observation = {"observation.image": torch.tensor([0.7])}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
processed = normalizer(transition)[TransitionKey.OBSERVATION]
|
||||
assert torch.allclose(processed["observation.image"], observation["observation.image"])
|
||||
with pytest.raises(ValueError, match="MEAN_STD normalization mode requires mean and std stats"):
|
||||
_ = normalizer(transition)[TransitionKey.OBSERVATION]
|
||||
|
||||
|
||||
def test_missing_action_stats_no_error():
|
||||
|
||||
Reference in New Issue
Block a user