mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
feat(normalization): Implement IDENTITY mode for normalization and unnormalization
- Enhanced NormalizerProcessor and UnnormalizerProcessor to support IDENTITY mode, allowing features to bypass normalization when specified. - Updated processing logic to check normalization modes and handle missing statistics gracefully. - Added comprehensive unit tests to validate IDENTITY mode functionality for both observations and actions, ensuring correct behavior across various scenarios. - Improved error handling for unsupported normalization modes.
This commit is contained in:
committed by
Steven Palma
parent
c0013b130b
commit
fbe9009db2
@@ -626,3 +626,330 @@ def test_serialization_roundtrip(full_stats):
|
||||
assert new_processor.features[key].shape == original_processor.features[key].shape
|
||||
|
||||
assert new_processor.norm_map == original_processor.norm_map
|
||||
|
||||
|
||||
# Identity normalization tests
|
||||
def test_identity_normalization_observations():
|
||||
"""Test that IDENTITY mode skips normalization for observations."""
|
||||
features = {
|
||||
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
|
||||
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.IDENTITY, # IDENTITY mode
|
||||
FeatureType.STATE: NormalizationMode.MEAN_STD, # Normal mode for comparison
|
||||
}
|
||||
stats = {
|
||||
"observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]},
|
||||
"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]},
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
"observation.state": torch.tensor([1.0, -0.5]),
|
||||
}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
normalized_transition = normalizer(transition)
|
||||
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
|
||||
|
||||
# Image should remain unchanged (IDENTITY)
|
||||
assert torch.allclose(normalized_obs["observation.image"], observation["observation.image"])
|
||||
|
||||
# State should be normalized (MEAN_STD)
|
||||
expected_state = (torch.tensor([1.0, -0.5]) - torch.tensor([0.0, 0.0])) / torch.tensor([1.0, 1.0])
|
||||
assert torch.allclose(normalized_obs["observation.state"], expected_state)
|
||||
|
||||
|
||||
def test_identity_normalization_actions():
|
||||
"""Test that IDENTITY mode skips normalization for actions."""
|
||||
features = {"action": PolicyFeature(FeatureType.ACTION, (2,))}
|
||||
norm_map = {FeatureType.ACTION: NormalizationMode.IDENTITY}
|
||||
stats = {"action": {"mean": [0.0, 0.0], "std": [1.0, 2.0]}}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
action = torch.tensor([1.0, -0.5])
|
||||
transition = create_transition(action=action)
|
||||
|
||||
normalized_transition = normalizer(transition)
|
||||
|
||||
# Action should remain unchanged
|
||||
assert torch.allclose(normalized_transition[TransitionKey.ACTION], action)
|
||||
|
||||
|
||||
def test_identity_unnormalization_observations():
|
||||
"""Test that IDENTITY mode skips unnormalization for observations."""
|
||||
features = {
|
||||
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
|
||||
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.IDENTITY, # IDENTITY mode
|
||||
FeatureType.STATE: NormalizationMode.MIN_MAX, # Normal mode for comparison
|
||||
}
|
||||
stats = {
|
||||
"observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]},
|
||||
"observation.state": {"min": [-1.0, -1.0], "max": [1.0, 1.0]},
|
||||
}
|
||||
|
||||
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
"observation.state": torch.tensor([0.0, -1.0]), # Normalized values in [-1, 1]
|
||||
}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
unnormalized_transition = unnormalizer(transition)
|
||||
unnormalized_obs = unnormalized_transition[TransitionKey.OBSERVATION]
|
||||
|
||||
# Image should remain unchanged (IDENTITY)
|
||||
assert torch.allclose(unnormalized_obs["observation.image"], observation["observation.image"])
|
||||
|
||||
# State should be unnormalized (MIN_MAX)
|
||||
# (0.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0) = 0.0
|
||||
# (-1.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0) = -1.0
|
||||
expected_state = torch.tensor([0.0, -1.0])
|
||||
assert torch.allclose(unnormalized_obs["observation.state"], expected_state)
|
||||
|
||||
|
||||
def test_identity_unnormalization_actions():
|
||||
"""Test that IDENTITY mode skips unnormalization for actions."""
|
||||
features = {"action": PolicyFeature(FeatureType.ACTION, (2,))}
|
||||
norm_map = {FeatureType.ACTION: NormalizationMode.IDENTITY}
|
||||
stats = {"action": {"min": [-1.0, -2.0], "max": [1.0, 2.0]}}
|
||||
|
||||
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
action = torch.tensor([0.5, -0.8]) # Normalized values
|
||||
transition = create_transition(action=action)
|
||||
|
||||
unnormalized_transition = unnormalizer(transition)
|
||||
|
||||
# Action should remain unchanged
|
||||
assert torch.allclose(unnormalized_transition[TransitionKey.ACTION], action)
|
||||
|
||||
|
||||
def test_identity_with_missing_stats():
|
||||
"""Test that IDENTITY mode works even when stats are missing."""
|
||||
features = {
|
||||
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.IDENTITY,
|
||||
FeatureType.ACTION: NormalizationMode.IDENTITY,
|
||||
}
|
||||
stats = {} # No stats provided
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])}
|
||||
action = torch.tensor([1.0, -0.5])
|
||||
transition = create_transition(observation=observation, action=action)
|
||||
|
||||
# Both should work without errors and return unchanged data
|
||||
normalized_transition = normalizer(transition)
|
||||
unnormalized_transition = unnormalizer(transition)
|
||||
|
||||
assert torch.allclose(
|
||||
normalized_transition[TransitionKey.OBSERVATION]["observation.image"],
|
||||
observation["observation.image"],
|
||||
)
|
||||
assert torch.allclose(normalized_transition[TransitionKey.ACTION], action)
|
||||
assert torch.allclose(
|
||||
unnormalized_transition[TransitionKey.OBSERVATION]["observation.image"],
|
||||
observation["observation.image"],
|
||||
)
|
||||
assert torch.allclose(unnormalized_transition[TransitionKey.ACTION], action)
|
||||
|
||||
|
||||
def test_identity_mixed_with_other_modes():
|
||||
"""Test IDENTITY mode mixed with other normalization modes."""
|
||||
features = {
|
||||
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
|
||||
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.IDENTITY,
|
||||
FeatureType.STATE: NormalizationMode.MEAN_STD,
|
||||
FeatureType.ACTION: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
stats = {
|
||||
"observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, # Will be ignored
|
||||
"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]},
|
||||
"action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]},
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
"observation.state": torch.tensor([1.0, -0.5]),
|
||||
}
|
||||
action = torch.tensor([0.5, 0.0])
|
||||
transition = create_transition(observation=observation, action=action)
|
||||
|
||||
normalized_transition = normalizer(transition)
|
||||
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
|
||||
normalized_action = normalized_transition[TransitionKey.ACTION]
|
||||
|
||||
# Image should remain unchanged (IDENTITY)
|
||||
assert torch.allclose(normalized_obs["observation.image"], observation["observation.image"])
|
||||
|
||||
# State should be normalized (MEAN_STD)
|
||||
expected_state = torch.tensor([1.0, -0.5]) # (x - 0) / 1 = x
|
||||
assert torch.allclose(normalized_obs["observation.state"], expected_state)
|
||||
|
||||
# Action should be normalized (MIN_MAX) to [-1, 1]
|
||||
# 2 * (0.5 - (-1)) / (1 - (-1)) - 1 = 2 * 1.5 / 2 - 1 = 0.5
|
||||
# 2 * (0.0 - (-1)) / (1 - (-1)) - 1 = 2 * 1.0 / 2 - 1 = 0.0
|
||||
expected_action = torch.tensor([0.5, 0.0])
|
||||
assert torch.allclose(normalized_action, expected_action)
|
||||
|
||||
|
||||
def test_identity_defaults_when_not_in_norm_map():
|
||||
"""Test that IDENTITY is used as default when feature type not in norm_map."""
|
||||
features = {
|
||||
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
|
||||
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.STATE: NormalizationMode.MEAN_STD,
|
||||
# VISUAL not specified, should default to IDENTITY
|
||||
}
|
||||
stats = {
|
||||
"observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]},
|
||||
"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]},
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
"observation.state": torch.tensor([1.0, -0.5]),
|
||||
}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
normalized_transition = normalizer(transition)
|
||||
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
|
||||
|
||||
# Image should remain unchanged (defaults to IDENTITY)
|
||||
assert torch.allclose(normalized_obs["observation.image"], observation["observation.image"])
|
||||
|
||||
# State should be normalized (explicitly MEAN_STD)
|
||||
expected_state = torch.tensor([1.0, -0.5])
|
||||
assert torch.allclose(normalized_obs["observation.state"], expected_state)
|
||||
|
||||
|
||||
def test_identity_roundtrip():
|
||||
"""Test that IDENTITY normalization and unnormalization are true inverses."""
|
||||
features = {
|
||||
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.IDENTITY,
|
||||
FeatureType.ACTION: NormalizationMode.IDENTITY,
|
||||
}
|
||||
stats = {
|
||||
"observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]},
|
||||
"action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]},
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
original_observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])}
|
||||
original_action = torch.tensor([0.5, -0.2])
|
||||
original_transition = create_transition(observation=original_observation, action=original_action)
|
||||
|
||||
# Normalize then unnormalize
|
||||
normalized = normalizer(original_transition)
|
||||
roundtrip = unnormalizer(normalized)
|
||||
|
||||
# Should be identical to original
|
||||
assert torch.allclose(
|
||||
roundtrip[TransitionKey.OBSERVATION]["observation.image"], original_observation["observation.image"]
|
||||
)
|
||||
assert torch.allclose(roundtrip[TransitionKey.ACTION], original_action)
|
||||
|
||||
|
||||
def test_identity_config_serialization():
|
||||
"""Test that IDENTITY mode is properly saved and loaded in config."""
|
||||
features = {
|
||||
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.IDENTITY,
|
||||
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||
}
|
||||
stats = {
|
||||
"observation.image": {"mean": [0.5], "std": [0.2]},
|
||||
"action": {"mean": [0.0, 0.0], "std": [1.0, 1.0]},
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
# Get config
|
||||
config = normalizer.get_config()
|
||||
|
||||
# Check that IDENTITY is properly serialized
|
||||
assert config["norm_map"]["VISUAL"] == "IDENTITY"
|
||||
assert config["norm_map"]["ACTION"] == "MEAN_STD"
|
||||
|
||||
# Create new processor from config (simulating load)
|
||||
new_normalizer = NormalizerProcessor(
|
||||
features=config["features"],
|
||||
norm_map=config["norm_map"],
|
||||
stats=stats,
|
||||
eps=config["eps"],
|
||||
)
|
||||
|
||||
# Test that both work the same way
|
||||
observation = {"observation.image": torch.tensor([0.7])}
|
||||
action = torch.tensor([1.0, -0.5])
|
||||
transition = create_transition(observation=observation, action=action)
|
||||
|
||||
result1 = normalizer(transition)
|
||||
result2 = new_normalizer(transition)
|
||||
|
||||
# Results should be identical
|
||||
assert torch.allclose(
|
||||
result1[TransitionKey.OBSERVATION]["observation.image"],
|
||||
result2[TransitionKey.OBSERVATION]["observation.image"],
|
||||
)
|
||||
assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION])
|
||||
|
||||
|
||||
def test_unsupported_normalization_mode_error():
|
||||
"""Test that unsupported normalization modes raise appropriate errors."""
|
||||
features = {"observation.state": PolicyFeature(FeatureType.STATE, (2,))}
|
||||
|
||||
# Create an invalid norm_map (this would never happen in practice, but tests error handling)
|
||||
from enum import Enum
|
||||
|
||||
class InvalidMode(str, Enum):
|
||||
INVALID = "INVALID"
|
||||
|
||||
# We can't actually pass an invalid enum to the processor due to type checking,
|
||||
# but we can test the error by manipulating the norm_map after creation
|
||||
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
|
||||
stats = {"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
# Manually inject an invalid mode to test error handling
|
||||
normalizer.norm_map[FeatureType.STATE] = "INVALID_MODE"
|
||||
|
||||
observation = {"observation.state": torch.tensor([1.0, -0.5])}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported normalization mode"):
|
||||
normalizer(transition)
|
||||
|
||||
Reference in New Issue
Block a user