feat(normalization): Implement IDENTITY mode for normalization and unnormalization

- Enhanced NormalizerProcessor and UnnormalizerProcessor to support IDENTITY mode, allowing features to bypass normalization when specified.
- Updated processing logic to check normalization modes and handle missing statistics gracefully.
- Added comprehensive unit tests to validate IDENTITY mode functionality for both observations and actions, ensuring correct behavior across various scenarios.
- Improved error handling for unsupported normalization modes.
This commit is contained in:
Adil Zouitine
2025-07-25 19:06:13 +02:00
committed by Steven Palma
parent c0013b130b
commit fbe9009db2
2 changed files with 430 additions and 30 deletions
+327
View File
@@ -626,3 +626,330 @@ def test_serialization_roundtrip(full_stats):
assert new_processor.features[key].shape == original_processor.features[key].shape
assert new_processor.norm_map == original_processor.norm_map
# Identity normalization tests
def test_identity_normalization_observations():
"""Test that IDENTITY mode skips normalization for observations."""
features = {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
}
norm_map = {
FeatureType.VISUAL: NormalizationMode.IDENTITY, # IDENTITY mode
FeatureType.STATE: NormalizationMode.MEAN_STD, # Normal mode for comparison
}
stats = {
"observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]},
"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]},
}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([1.0, -0.5]),
}
transition = create_transition(observation=observation)
normalized_transition = normalizer(transition)
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
# Image should remain unchanged (IDENTITY)
assert torch.allclose(normalized_obs["observation.image"], observation["observation.image"])
# State should be normalized (MEAN_STD)
expected_state = (torch.tensor([1.0, -0.5]) - torch.tensor([0.0, 0.0])) / torch.tensor([1.0, 1.0])
assert torch.allclose(normalized_obs["observation.state"], expected_state)
def test_identity_normalization_actions():
"""Test that IDENTITY mode skips normalization for actions."""
features = {"action": PolicyFeature(FeatureType.ACTION, (2,))}
norm_map = {FeatureType.ACTION: NormalizationMode.IDENTITY}
stats = {"action": {"mean": [0.0, 0.0], "std": [1.0, 2.0]}}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
action = torch.tensor([1.0, -0.5])
transition = create_transition(action=action)
normalized_transition = normalizer(transition)
# Action should remain unchanged
assert torch.allclose(normalized_transition[TransitionKey.ACTION], action)
def test_identity_unnormalization_observations():
"""Test that IDENTITY mode skips unnormalization for observations."""
features = {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
}
norm_map = {
FeatureType.VISUAL: NormalizationMode.IDENTITY, # IDENTITY mode
FeatureType.STATE: NormalizationMode.MIN_MAX, # Normal mode for comparison
}
stats = {
"observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]},
"observation.state": {"min": [-1.0, -1.0], "max": [1.0, 1.0]},
}
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([0.0, -1.0]), # Normalized values in [-1, 1]
}
transition = create_transition(observation=observation)
unnormalized_transition = unnormalizer(transition)
unnormalized_obs = unnormalized_transition[TransitionKey.OBSERVATION]
# Image should remain unchanged (IDENTITY)
assert torch.allclose(unnormalized_obs["observation.image"], observation["observation.image"])
# State should be unnormalized (MIN_MAX)
# (0.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0) = 0.0
# (-1.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0) = -1.0
expected_state = torch.tensor([0.0, -1.0])
assert torch.allclose(unnormalized_obs["observation.state"], expected_state)
def test_identity_unnormalization_actions():
"""Test that IDENTITY mode skips unnormalization for actions."""
features = {"action": PolicyFeature(FeatureType.ACTION, (2,))}
norm_map = {FeatureType.ACTION: NormalizationMode.IDENTITY}
stats = {"action": {"min": [-1.0, -2.0], "max": [1.0, 2.0]}}
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
action = torch.tensor([0.5, -0.8]) # Normalized values
transition = create_transition(action=action)
unnormalized_transition = unnormalizer(transition)
# Action should remain unchanged
assert torch.allclose(unnormalized_transition[TransitionKey.ACTION], action)
def test_identity_with_missing_stats():
"""Test that IDENTITY mode works even when stats are missing."""
features = {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
"action": PolicyFeature(FeatureType.ACTION, (2,)),
}
norm_map = {
FeatureType.VISUAL: NormalizationMode.IDENTITY,
FeatureType.ACTION: NormalizationMode.IDENTITY,
}
stats = {} # No stats provided
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])}
action = torch.tensor([1.0, -0.5])
transition = create_transition(observation=observation, action=action)
# Both should work without errors and return unchanged data
normalized_transition = normalizer(transition)
unnormalized_transition = unnormalizer(transition)
assert torch.allclose(
normalized_transition[TransitionKey.OBSERVATION]["observation.image"],
observation["observation.image"],
)
assert torch.allclose(normalized_transition[TransitionKey.ACTION], action)
assert torch.allclose(
unnormalized_transition[TransitionKey.OBSERVATION]["observation.image"],
observation["observation.image"],
)
assert torch.allclose(unnormalized_transition[TransitionKey.ACTION], action)
def test_identity_mixed_with_other_modes():
"""Test IDENTITY mode mixed with other normalization modes."""
features = {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
"action": PolicyFeature(FeatureType.ACTION, (2,)),
}
norm_map = {
FeatureType.VISUAL: NormalizationMode.IDENTITY,
FeatureType.STATE: NormalizationMode.MEAN_STD,
FeatureType.ACTION: NormalizationMode.MIN_MAX,
}
stats = {
"observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, # Will be ignored
"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]},
"action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]},
}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([1.0, -0.5]),
}
action = torch.tensor([0.5, 0.0])
transition = create_transition(observation=observation, action=action)
normalized_transition = normalizer(transition)
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
normalized_action = normalized_transition[TransitionKey.ACTION]
# Image should remain unchanged (IDENTITY)
assert torch.allclose(normalized_obs["observation.image"], observation["observation.image"])
# State should be normalized (MEAN_STD)
expected_state = torch.tensor([1.0, -0.5]) # (x - 0) / 1 = x
assert torch.allclose(normalized_obs["observation.state"], expected_state)
# Action should be normalized (MIN_MAX) to [-1, 1]
# 2 * (0.5 - (-1)) / (1 - (-1)) - 1 = 2 * 1.5 / 2 - 1 = 0.5
# 2 * (0.0 - (-1)) / (1 - (-1)) - 1 = 2 * 1.0 / 2 - 1 = 0.0
expected_action = torch.tensor([0.5, 0.0])
assert torch.allclose(normalized_action, expected_action)
def test_identity_defaults_when_not_in_norm_map():
"""Test that IDENTITY is used as default when feature type not in norm_map."""
features = {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
}
norm_map = {
FeatureType.STATE: NormalizationMode.MEAN_STD,
# VISUAL not specified, should default to IDENTITY
}
stats = {
"observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]},
"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]},
}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([1.0, -0.5]),
}
transition = create_transition(observation=observation)
normalized_transition = normalizer(transition)
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
# Image should remain unchanged (defaults to IDENTITY)
assert torch.allclose(normalized_obs["observation.image"], observation["observation.image"])
# State should be normalized (explicitly MEAN_STD)
expected_state = torch.tensor([1.0, -0.5])
assert torch.allclose(normalized_obs["observation.state"], expected_state)
def test_identity_roundtrip():
"""Test that IDENTITY normalization and unnormalization are true inverses."""
features = {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
"action": PolicyFeature(FeatureType.ACTION, (2,)),
}
norm_map = {
FeatureType.VISUAL: NormalizationMode.IDENTITY,
FeatureType.ACTION: NormalizationMode.IDENTITY,
}
stats = {
"observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]},
"action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]},
}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
original_observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])}
original_action = torch.tensor([0.5, -0.2])
original_transition = create_transition(observation=original_observation, action=original_action)
# Normalize then unnormalize
normalized = normalizer(original_transition)
roundtrip = unnormalizer(normalized)
# Should be identical to original
assert torch.allclose(
roundtrip[TransitionKey.OBSERVATION]["observation.image"], original_observation["observation.image"]
)
assert torch.allclose(roundtrip[TransitionKey.ACTION], original_action)
def test_identity_config_serialization():
"""Test that IDENTITY mode is properly saved and loaded in config."""
features = {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
"action": PolicyFeature(FeatureType.ACTION, (2,)),
}
norm_map = {
FeatureType.VISUAL: NormalizationMode.IDENTITY,
FeatureType.ACTION: NormalizationMode.MEAN_STD,
}
stats = {
"observation.image": {"mean": [0.5], "std": [0.2]},
"action": {"mean": [0.0, 0.0], "std": [1.0, 1.0]},
}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
# Get config
config = normalizer.get_config()
# Check that IDENTITY is properly serialized
assert config["norm_map"]["VISUAL"] == "IDENTITY"
assert config["norm_map"]["ACTION"] == "MEAN_STD"
# Create new processor from config (simulating load)
new_normalizer = NormalizerProcessor(
features=config["features"],
norm_map=config["norm_map"],
stats=stats,
eps=config["eps"],
)
# Test that both work the same way
observation = {"observation.image": torch.tensor([0.7])}
action = torch.tensor([1.0, -0.5])
transition = create_transition(observation=observation, action=action)
result1 = normalizer(transition)
result2 = new_normalizer(transition)
# Results should be identical
assert torch.allclose(
result1[TransitionKey.OBSERVATION]["observation.image"],
result2[TransitionKey.OBSERVATION]["observation.image"],
)
assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION])
def test_unsupported_normalization_mode_error():
"""Test that unsupported normalization modes raise appropriate errors."""
features = {"observation.state": PolicyFeature(FeatureType.STATE, (2,))}
# Create an invalid norm_map (this would never happen in practice, but tests error handling)
from enum import Enum
class InvalidMode(str, Enum):
INVALID = "INVALID"
# We can't actually pass an invalid enum to the processor due to type checking,
# but we can test the error by manipulating the norm_map after creation
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
stats = {"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
# Manually inject an invalid mode to test error handling
normalizer.norm_map[FeatureType.STATE] = "INVALID_MODE"
observation = {"observation.state": torch.tensor([1.0, -0.5])}
transition = create_transition(observation=observation)
with pytest.raises(ValueError, match="Unsupported normalization mode"):
normalizer(transition)