feat(normalization): Implement IDENTITY mode for normalization and unnormalization

- Enhanced NormalizerProcessor and UnnormalizerProcessor to support IDENTITY mode, allowing features to bypass normalization when specified.
- Updated processing logic to check normalization modes and handle missing statistics gracefully.
- Added comprehensive unit tests to validate IDENTITY mode functionality for both observations and actions, ensuring correct behavior across various scenarios.
- Improved error handling for unsupported normalization modes.
This commit is contained in:
Adil Zouitine
2025-07-25 19:06:13 +02:00
committed by Steven Palma
parent c0013b130b
commit fbe9009db2
2 changed files with 430 additions and 30 deletions
+103 -30
View File
@@ -128,7 +128,19 @@ class NormalizerProcessor:
processed = dict(observation)
for key in keys_to_norm:
if key not in processed or key not in self._tensor_stats:
if key not in processed or key not in self.features:
continue
# Check the normalization mode for this feature type
feature = self.features[key]
norm_mode = self.norm_map.get(feature.type, NormalizationMode.IDENTITY)
# Skip normalization if mode is IDENTITY
if norm_mode is NormalizationMode.IDENTITY:
continue
# Skip if no stats available for this key
if key not in self._tensor_stats:
continue
orig_val = processed[key]
@@ -139,16 +151,32 @@ class NormalizerProcessor:
)
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
processed[key] = (tensor - mean) / (std + self.eps)
elif "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
if norm_mode is NormalizationMode.MEAN_STD:
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
processed[key] = (tensor - mean) / (std + self.eps)
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
else:
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
return processed
def _normalize_action(self, action):
if action is None or "action" not in self._tensor_stats:
if action is None:
return action
# Check the normalization mode for actions
norm_mode = self.norm_map.get(FeatureType.ACTION, NormalizationMode.IDENTITY)
# Skip normalization if mode is IDENTITY
if norm_mode is NormalizationMode.IDENTITY:
return action
# Skip if no stats available for actions
if "action" not in self._tensor_stats:
return action
tensor = (
@@ -157,13 +185,20 @@ class NormalizerProcessor:
else torch.as_tensor(action, dtype=torch.float32)
)
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
return (tensor - mean) / (std + self.eps)
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
if norm_mode is NormalizationMode.MEAN_STD:
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
return (tensor - mean) / (std + self.eps)
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
else:
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
# If we reach here, the required stats for the normalization mode are not available
raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization")
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION))
@@ -259,8 +294,21 @@ class UnnormalizerProcessor:
keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION]
processed = dict(observation)
for key in keys:
if key not in processed or key not in self._tensor_stats:
if key not in processed or key not in self.features:
continue
# Check the normalization mode for this feature type
feature = self.features[key]
norm_mode = self.norm_map.get(feature.type, NormalizationMode.IDENTITY)
# Skip unnormalization if mode is IDENTITY
if norm_mode is NormalizationMode.IDENTITY:
continue
# Skip if no stats available for this key
if key not in self._tensor_stats:
continue
orig_val = processed[key]
tensor = (
orig_val.to(dtype=torch.float32)
@@ -268,30 +316,55 @@ class UnnormalizerProcessor:
else torch.as_tensor(orig_val, dtype=torch.float32)
)
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
processed[key] = tensor * std + mean
elif "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val
if norm_mode is NormalizationMode.MEAN_STD:
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
processed[key] = tensor * std + mean
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val
else:
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
return processed
def _unnormalize_action(self, action):
    """Map a normalized action back to its original scale.

    Pass-through cases (action returned unchanged):
    - ``action`` is ``None``;
    - the ACTION normalization mode is ``IDENTITY`` (or ACTION is absent
      from ``norm_map``, which defaults to ``IDENTITY``);
    - no "action" entry exists in the stored tensor stats.

    Raises:
        ValueError: if the mode is neither MEAN_STD nor MIN_MAX, or if the
            stats present do not match the configured mode.

    NOTE(review): this body is reconstructed from a diff view that fused the
    pre- and post-change lines; the duplicated pre-change branches (the old
    combined None/stats guard and the old unconditional mean-std/min-max
    fallthrough) are removed in favor of the mode-aware logic.
    """
    if action is None:
        return action
    # Resolve the normalization mode for actions; default to IDENTITY.
    norm_mode = self.norm_map.get(FeatureType.ACTION, NormalizationMode.IDENTITY)
    # IDENTITY mode: skip unnormalization entirely.
    if norm_mode is NormalizationMode.IDENTITY:
        return action
    # Without stats we cannot unnormalize; pass the action through.
    if "action" not in self._tensor_stats:
        return action
    tensor = (
        action.to(dtype=torch.float32)
        if isinstance(action, torch.Tensor)
        else torch.as_tensor(action, dtype=torch.float32)
    )
    # Move stats onto the action's device before the arithmetic.
    stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
    if norm_mode is NormalizationMode.MEAN_STD:
        if "mean" in stats and "std" in stats:
            mean, std = stats["mean"], stats["std"]
            return tensor * std + mean
    elif norm_mode is NormalizationMode.MIN_MAX:
        if "min" in stats and "max" in stats:
            min_val, max_val = stats["min"], stats["max"]
            # Inverse of the [-1, 1] min-max mapping.
            return (tensor + 1) / 2 * (max_val - min_val) + min_val
    else:
        raise ValueError(f"Unsupported normalization mode: {norm_mode}")
    # If we reach here, the required stats for the normalization mode are not available
    raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization")
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION))
+327
View File
@@ -626,3 +626,330 @@ def test_serialization_roundtrip(full_stats):
assert new_processor.features[key].shape == original_processor.features[key].shape
assert new_processor.norm_map == original_processor.norm_map
# Identity normalization tests
def test_identity_normalization_observations():
    """IDENTITY-mode observation features pass through normalization untouched."""
    features = {
        "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
        "observation.state": PolicyFeature(FeatureType.STATE, (2,)),
    }
    norm_map = {
        FeatureType.VISUAL: NormalizationMode.IDENTITY,
        FeatureType.STATE: NormalizationMode.MEAN_STD,  # normalized, for contrast
    }
    stats = {
        "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]},
        "observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]},
    }
    processor = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)

    image = torch.tensor([0.7, 0.5, 0.3])
    state = torch.tensor([1.0, -0.5])
    result = processor(
        create_transition(observation={"observation.image": image, "observation.state": state})
    )
    out_obs = result[TransitionKey.OBSERVATION]

    # IDENTITY: the image tensor comes back as-is, stats ignored.
    assert torch.allclose(out_obs["observation.image"], image)
    # MEAN_STD with zero mean and unit std: (x - 0) / 1 == x.
    expected_state = (state - torch.tensor([0.0, 0.0])) / torch.tensor([1.0, 1.0])
    assert torch.allclose(out_obs["observation.state"], expected_state)
def test_identity_normalization_actions():
    """IDENTITY-mode actions pass through the normalizer untouched."""
    processor = NormalizerProcessor(
        features={"action": PolicyFeature(FeatureType.ACTION, (2,))},
        norm_map={FeatureType.ACTION: NormalizationMode.IDENTITY},
        stats={"action": {"mean": [0.0, 0.0], "std": [1.0, 2.0]}},
    )
    action = torch.tensor([1.0, -0.5])
    result = processor(create_transition(action=action))
    # The provided stats must be ignored entirely under IDENTITY mode.
    assert torch.allclose(result[TransitionKey.ACTION], action)
def test_identity_unnormalization_observations():
    """IDENTITY-mode observation features pass through unnormalization untouched."""
    features = {
        "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
        "observation.state": PolicyFeature(FeatureType.STATE, (2,)),
    }
    norm_map = {
        FeatureType.VISUAL: NormalizationMode.IDENTITY,
        FeatureType.STATE: NormalizationMode.MIN_MAX,  # unnormalized, for contrast
    }
    stats = {
        "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]},
        "observation.state": {"min": [-1.0, -1.0], "max": [1.0, 1.0]},
    }
    processor = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)

    image = torch.tensor([0.7, 0.5, 0.3])
    state = torch.tensor([0.0, -1.0])  # normalized values in [-1, 1]
    result = processor(
        create_transition(observation={"observation.image": image, "observation.state": state})
    )
    out_obs = result[TransitionKey.OBSERVATION]

    # IDENTITY: image is returned as-is.
    assert torch.allclose(out_obs["observation.image"], image)
    # MIN_MAX inverse with bounds [-1, 1] is the identity map on [-1, 1]:
    # (x + 1) / 2 * (1 - (-1)) + (-1) == x.
    assert torch.allclose(out_obs["observation.state"], torch.tensor([0.0, -1.0]))
def test_identity_unnormalization_actions():
    """IDENTITY-mode actions pass through the unnormalizer untouched."""
    processor = UnnormalizerProcessor(
        features={"action": PolicyFeature(FeatureType.ACTION, (2,))},
        norm_map={FeatureType.ACTION: NormalizationMode.IDENTITY},
        stats={"action": {"min": [-1.0, -2.0], "max": [1.0, 2.0]}},
    )
    action = torch.tensor([0.5, -0.8])  # already-normalized values
    result = processor(create_transition(action=action))
    # The min/max stats must be ignored entirely under IDENTITY mode.
    assert torch.allclose(result[TransitionKey.ACTION], action)
def test_identity_with_missing_stats():
    """IDENTITY mode must work even when no stats are provided at all."""
    features = {
        "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
        "action": PolicyFeature(FeatureType.ACTION, (2,)),
    }
    norm_map = {
        FeatureType.VISUAL: NormalizationMode.IDENTITY,
        FeatureType.ACTION: NormalizationMode.IDENTITY,
    }
    empty_stats = {}  # deliberately no stats
    normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=empty_stats)
    unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=empty_stats)

    image = torch.tensor([0.7, 0.5, 0.3])
    action = torch.tensor([1.0, -0.5])
    transition = create_transition(observation={"observation.image": image}, action=action)

    # Neither direction should raise, and both must return the data unchanged.
    for processor in (normalizer, unnormalizer):
        out = processor(transition)
        assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.image"], image)
        assert torch.allclose(out[TransitionKey.ACTION], action)
def test_identity_mixed_with_other_modes():
    """IDENTITY can coexist with MEAN_STD and MIN_MAX in one processor."""
    features = {
        "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
        "observation.state": PolicyFeature(FeatureType.STATE, (2,)),
        "action": PolicyFeature(FeatureType.ACTION, (2,)),
    }
    norm_map = {
        FeatureType.VISUAL: NormalizationMode.IDENTITY,
        FeatureType.STATE: NormalizationMode.MEAN_STD,
        FeatureType.ACTION: NormalizationMode.MIN_MAX,
    }
    stats = {
        "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]},  # ignored
        "observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]},
        "action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]},
    }
    processor = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)

    image = torch.tensor([0.7, 0.5, 0.3])
    state = torch.tensor([1.0, -0.5])
    action = torch.tensor([0.5, 0.0])
    result = processor(
        create_transition(
            observation={"observation.image": image, "observation.state": state},
            action=action,
        )
    )
    out_obs = result[TransitionKey.OBSERVATION]

    # IDENTITY leaves the image untouched.
    assert torch.allclose(out_obs["observation.image"], image)
    # MEAN_STD with zero mean / unit std maps x -> x.
    assert torch.allclose(out_obs["observation.state"], torch.tensor([1.0, -0.5]))
    # MIN_MAX over bounds [-1, 1]: 2 * (x + 1) / 2 - 1 == x for these stats.
    assert torch.allclose(result[TransitionKey.ACTION], torch.tensor([0.5, 0.0]))
def test_identity_defaults_when_not_in_norm_map():
    """A feature type absent from norm_map falls back to IDENTITY."""
    features = {
        "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
        "observation.state": PolicyFeature(FeatureType.STATE, (2,)),
    }
    # VISUAL is deliberately omitted: it should default to IDENTITY.
    norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
    stats = {
        "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]},
        "observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]},
    }
    processor = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)

    image = torch.tensor([0.7, 0.5, 0.3])
    state = torch.tensor([1.0, -0.5])
    result = processor(
        create_transition(observation={"observation.image": image, "observation.state": state})
    )
    out_obs = result[TransitionKey.OBSERVATION]

    # Unlisted VISUAL type is treated as IDENTITY, so the image is unchanged.
    assert torch.allclose(out_obs["observation.image"], image)
    # STATE was explicitly MEAN_STD; with these stats the mapping is x -> x.
    assert torch.allclose(out_obs["observation.state"], torch.tensor([1.0, -0.5]))
def test_identity_roundtrip():
    """Normalize followed by unnormalize is an exact round trip under IDENTITY."""
    features = {
        "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
        "action": PolicyFeature(FeatureType.ACTION, (2,)),
    }
    norm_map = {
        FeatureType.VISUAL: NormalizationMode.IDENTITY,
        FeatureType.ACTION: NormalizationMode.IDENTITY,
    }
    stats = {
        "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]},
        "action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]},
    }
    normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
    unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)

    image = torch.tensor([0.7, 0.5, 0.3])
    action = torch.tensor([0.5, -0.2])
    start = create_transition(observation={"observation.image": image}, action=action)

    roundtrip = unnormalizer(normalizer(start))

    # Both hops are identity maps, so the data must come back unchanged.
    assert torch.allclose(roundtrip[TransitionKey.OBSERVATION]["observation.image"], image)
    assert torch.allclose(roundtrip[TransitionKey.ACTION], action)
def test_identity_config_serialization():
    """IDENTITY mode survives a get_config serialize/rebuild round trip."""
    features = {
        "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
        "action": PolicyFeature(FeatureType.ACTION, (2,)),
    }
    norm_map = {
        FeatureType.VISUAL: NormalizationMode.IDENTITY,
        FeatureType.ACTION: NormalizationMode.MEAN_STD,
    }
    stats = {
        "observation.image": {"mean": [0.5], "std": [0.2]},
        "action": {"mean": [0.0, 0.0], "std": [1.0, 1.0]},
    }
    original = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)

    config = original.get_config()
    # The enum values must serialize to their string names.
    assert config["norm_map"]["VISUAL"] == "IDENTITY"
    assert config["norm_map"]["ACTION"] == "MEAN_STD"

    # Rebuild from the serialized config, as a load would do.
    restored = NormalizerProcessor(
        features=config["features"],
        norm_map=config["norm_map"],
        stats=stats,
        eps=config["eps"],
    )

    transition = create_transition(
        observation={"observation.image": torch.tensor([0.7])},
        action=torch.tensor([1.0, -0.5]),
    )
    out_original = original(transition)
    out_restored = restored(transition)

    # Original and rebuilt processors must behave identically.
    assert torch.allclose(
        out_original[TransitionKey.OBSERVATION]["observation.image"],
        out_restored[TransitionKey.OBSERVATION]["observation.image"],
    )
    assert torch.allclose(out_original[TransitionKey.ACTION], out_restored[TransitionKey.ACTION])
def test_unsupported_normalization_mode_error():
    """Test that unsupported normalization modes raise appropriate errors.

    The processor's type checking prevents constructing it with a bogus
    enum member directly, so an invalid mode is injected into norm_map
    after construction. (The previous version also defined an unused
    ``InvalidMode`` enum and a function-level ``from enum import Enum``
    import; both were dead code and have been removed.)
    """
    features = {"observation.state": PolicyFeature(FeatureType.STATE, (2,))}
    norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
    stats = {"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}}
    normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
    # Manually inject an invalid mode to exercise the error path.
    normalizer.norm_map[FeatureType.STATE] = "INVALID_MODE"
    observation = {"observation.state": torch.tensor([1.0, -0.5])}
    transition = create_transition(observation=observation)
    with pytest.raises(ValueError, match="Unsupported normalization mode"):
        normalizer(transition)