mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 20:50:02 +00:00
feat(normalization): Implement IDENTITY mode for normalization and unnormalization
- Enhanced NormalizerProcessor and UnnormalizerProcessor to support IDENTITY mode, allowing features to bypass normalization when specified.
- Updated processing logic to check normalization modes and handle missing statistics gracefully.
- Added comprehensive unit tests to validate IDENTITY mode functionality for both observations and actions, ensuring correct behavior across various scenarios.
- Improved error handling for unsupported normalization modes.
This commit is contained in:
committed by
Steven Palma
parent
c0013b130b
commit
fbe9009db2
@@ -128,7 +128,19 @@ class NormalizerProcessor:
|
|||||||
|
|
||||||
processed = dict(observation)
|
processed = dict(observation)
|
||||||
for key in keys_to_norm:
|
for key in keys_to_norm:
|
||||||
if key not in processed or key not in self._tensor_stats:
|
if key not in processed or key not in self.features:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check the normalization mode for this feature type
|
||||||
|
feature = self.features[key]
|
||||||
|
norm_mode = self.norm_map.get(feature.type, NormalizationMode.IDENTITY)
|
||||||
|
|
||||||
|
# Skip normalization if mode is IDENTITY
|
||||||
|
if norm_mode is NormalizationMode.IDENTITY:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Skip if no stats available for this key
|
||||||
|
if key not in self._tensor_stats:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
orig_val = processed[key]
|
orig_val = processed[key]
|
||||||
@@ -139,16 +151,32 @@ class NormalizerProcessor:
|
|||||||
)
|
)
|
||||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
|
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
|
||||||
|
|
||||||
if "mean" in stats and "std" in stats:
|
if norm_mode is NormalizationMode.MEAN_STD:
|
||||||
mean, std = stats["mean"], stats["std"]
|
if "mean" in stats and "std" in stats:
|
||||||
processed[key] = (tensor - mean) / (std + self.eps)
|
mean, std = stats["mean"], stats["std"]
|
||||||
elif "min" in stats and "max" in stats:
|
processed[key] = (tensor - mean) / (std + self.eps)
|
||||||
min_val, max_val = stats["min"], stats["max"]
|
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||||
processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
|
if "min" in stats and "max" in stats:
|
||||||
|
min_val, max_val = stats["min"], stats["max"]
|
||||||
|
processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
|
||||||
|
|
||||||
return processed
|
return processed
|
||||||
|
|
||||||
def _normalize_action(self, action):
|
def _normalize_action(self, action):
|
||||||
if action is None or "action" not in self._tensor_stats:
|
if action is None:
|
||||||
|
return action
|
||||||
|
|
||||||
|
# Check the normalization mode for actions
|
||||||
|
norm_mode = self.norm_map.get(FeatureType.ACTION, NormalizationMode.IDENTITY)
|
||||||
|
|
||||||
|
# Skip normalization if mode is IDENTITY
|
||||||
|
if norm_mode is NormalizationMode.IDENTITY:
|
||||||
|
return action
|
||||||
|
|
||||||
|
# Skip if no stats available for actions
|
||||||
|
if "action" not in self._tensor_stats:
|
||||||
return action
|
return action
|
||||||
|
|
||||||
tensor = (
|
tensor = (
|
||||||
@@ -157,13 +185,20 @@ class NormalizerProcessor:
|
|||||||
else torch.as_tensor(action, dtype=torch.float32)
|
else torch.as_tensor(action, dtype=torch.float32)
|
||||||
)
|
)
|
||||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
|
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
|
||||||
if "mean" in stats and "std" in stats:
|
|
||||||
mean, std = stats["mean"], stats["std"]
|
if norm_mode is NormalizationMode.MEAN_STD:
|
||||||
return (tensor - mean) / (std + self.eps)
|
if "mean" in stats and "std" in stats:
|
||||||
if "min" in stats and "max" in stats:
|
mean, std = stats["mean"], stats["std"]
|
||||||
min_val, max_val = stats["min"], stats["max"]
|
return (tensor - mean) / (std + self.eps)
|
||||||
return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
|
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||||
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
|
if "min" in stats and "max" in stats:
|
||||||
|
min_val, max_val = stats["min"], stats["max"]
|
||||||
|
return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
|
||||||
|
|
||||||
|
# If we reach here, the required stats for the normalization mode are not available
|
||||||
|
raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization")
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION))
|
observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION))
|
||||||
@@ -259,8 +294,21 @@ class UnnormalizerProcessor:
|
|||||||
keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION]
|
keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION]
|
||||||
processed = dict(observation)
|
processed = dict(observation)
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if key not in processed or key not in self._tensor_stats:
|
if key not in processed or key not in self.features:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Check the normalization mode for this feature type
|
||||||
|
feature = self.features[key]
|
||||||
|
norm_mode = self.norm_map.get(feature.type, NormalizationMode.IDENTITY)
|
||||||
|
|
||||||
|
# Skip unnormalization if mode is IDENTITY
|
||||||
|
if norm_mode is NormalizationMode.IDENTITY:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Skip if no stats available for this key
|
||||||
|
if key not in self._tensor_stats:
|
||||||
|
continue
|
||||||
|
|
||||||
orig_val = processed[key]
|
orig_val = processed[key]
|
||||||
tensor = (
|
tensor = (
|
||||||
orig_val.to(dtype=torch.float32)
|
orig_val.to(dtype=torch.float32)
|
||||||
@@ -268,30 +316,55 @@ class UnnormalizerProcessor:
|
|||||||
else torch.as_tensor(orig_val, dtype=torch.float32)
|
else torch.as_tensor(orig_val, dtype=torch.float32)
|
||||||
)
|
)
|
||||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
|
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
|
||||||
if "mean" in stats and "std" in stats:
|
|
||||||
mean, std = stats["mean"], stats["std"]
|
if norm_mode is NormalizationMode.MEAN_STD:
|
||||||
processed[key] = tensor * std + mean
|
if "mean" in stats and "std" in stats:
|
||||||
elif "min" in stats and "max" in stats:
|
mean, std = stats["mean"], stats["std"]
|
||||||
min_val, max_val = stats["min"], stats["max"]
|
processed[key] = tensor * std + mean
|
||||||
processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val
|
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||||
|
if "min" in stats and "max" in stats:
|
||||||
|
min_val, max_val = stats["min"], stats["max"]
|
||||||
|
processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
|
||||||
|
|
||||||
return processed
|
return processed
|
||||||
|
|
||||||
def _unnormalize_action(self, action):
|
def _unnormalize_action(self, action):
|
||||||
if action is None or "action" not in self._tensor_stats:
|
if action is None:
|
||||||
return action
|
return action
|
||||||
|
|
||||||
|
# Check the normalization mode for actions
|
||||||
|
norm_mode = self.norm_map.get(FeatureType.ACTION, NormalizationMode.IDENTITY)
|
||||||
|
|
||||||
|
# Skip unnormalization if mode is IDENTITY
|
||||||
|
if norm_mode is NormalizationMode.IDENTITY:
|
||||||
|
return action
|
||||||
|
|
||||||
|
# Skip if no stats available for actions
|
||||||
|
if "action" not in self._tensor_stats:
|
||||||
|
return action
|
||||||
|
|
||||||
tensor = (
|
tensor = (
|
||||||
action.to(dtype=torch.float32)
|
action.to(dtype=torch.float32)
|
||||||
if isinstance(action, torch.Tensor)
|
if isinstance(action, torch.Tensor)
|
||||||
else torch.as_tensor(action, dtype=torch.float32)
|
else torch.as_tensor(action, dtype=torch.float32)
|
||||||
)
|
)
|
||||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
|
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
|
||||||
if "mean" in stats and "std" in stats:
|
|
||||||
mean, std = stats["mean"], stats["std"]
|
if norm_mode is NormalizationMode.MEAN_STD:
|
||||||
return tensor * std + mean
|
if "mean" in stats and "std" in stats:
|
||||||
if "min" in stats and "max" in stats:
|
mean, std = stats["mean"], stats["std"]
|
||||||
min_val, max_val = stats["min"], stats["max"]
|
return tensor * std + mean
|
||||||
return (tensor + 1) / 2 * (max_val - min_val) + min_val
|
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||||
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
|
if "min" in stats and "max" in stats:
|
||||||
|
min_val, max_val = stats["min"], stats["max"]
|
||||||
|
return (tensor + 1) / 2 * (max_val - min_val) + min_val
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
|
||||||
|
|
||||||
|
# If we reach here, the required stats for the normalization mode are not available
|
||||||
|
raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization")
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION))
|
observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION))
|
||||||
|
|||||||
@@ -626,3 +626,330 @@ def test_serialization_roundtrip(full_stats):
|
|||||||
assert new_processor.features[key].shape == original_processor.features[key].shape
|
assert new_processor.features[key].shape == original_processor.features[key].shape
|
||||||
|
|
||||||
assert new_processor.norm_map == original_processor.norm_map
|
assert new_processor.norm_map == original_processor.norm_map
|
||||||
|
|
||||||
|
|
||||||
|
# Identity normalization tests
|
||||||
|
def test_identity_normalization_observations():
|
||||||
|
"""Test that IDENTITY mode skips normalization for observations."""
|
||||||
|
features = {
|
||||||
|
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
|
||||||
|
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
||||||
|
}
|
||||||
|
norm_map = {
|
||||||
|
FeatureType.VISUAL: NormalizationMode.IDENTITY, # IDENTITY mode
|
||||||
|
FeatureType.STATE: NormalizationMode.MEAN_STD, # Normal mode for comparison
|
||||||
|
}
|
||||||
|
stats = {
|
||||||
|
"observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]},
|
||||||
|
"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]},
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||||
|
|
||||||
|
observation = {
|
||||||
|
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||||
|
"observation.state": torch.tensor([1.0, -0.5]),
|
||||||
|
}
|
||||||
|
transition = create_transition(observation=observation)
|
||||||
|
|
||||||
|
normalized_transition = normalizer(transition)
|
||||||
|
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
|
||||||
|
|
||||||
|
# Image should remain unchanged (IDENTITY)
|
||||||
|
assert torch.allclose(normalized_obs["observation.image"], observation["observation.image"])
|
||||||
|
|
||||||
|
# State should be normalized (MEAN_STD)
|
||||||
|
expected_state = (torch.tensor([1.0, -0.5]) - torch.tensor([0.0, 0.0])) / torch.tensor([1.0, 1.0])
|
||||||
|
assert torch.allclose(normalized_obs["observation.state"], expected_state)
|
||||||
|
|
||||||
|
|
||||||
|
def test_identity_normalization_actions():
|
||||||
|
"""Test that IDENTITY mode skips normalization for actions."""
|
||||||
|
features = {"action": PolicyFeature(FeatureType.ACTION, (2,))}
|
||||||
|
norm_map = {FeatureType.ACTION: NormalizationMode.IDENTITY}
|
||||||
|
stats = {"action": {"mean": [0.0, 0.0], "std": [1.0, 2.0]}}
|
||||||
|
|
||||||
|
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||||
|
|
||||||
|
action = torch.tensor([1.0, -0.5])
|
||||||
|
transition = create_transition(action=action)
|
||||||
|
|
||||||
|
normalized_transition = normalizer(transition)
|
||||||
|
|
||||||
|
# Action should remain unchanged
|
||||||
|
assert torch.allclose(normalized_transition[TransitionKey.ACTION], action)
|
||||||
|
|
||||||
|
|
||||||
|
def test_identity_unnormalization_observations():
|
||||||
|
"""Test that IDENTITY mode skips unnormalization for observations."""
|
||||||
|
features = {
|
||||||
|
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
|
||||||
|
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
||||||
|
}
|
||||||
|
norm_map = {
|
||||||
|
FeatureType.VISUAL: NormalizationMode.IDENTITY, # IDENTITY mode
|
||||||
|
FeatureType.STATE: NormalizationMode.MIN_MAX, # Normal mode for comparison
|
||||||
|
}
|
||||||
|
stats = {
|
||||||
|
"observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]},
|
||||||
|
"observation.state": {"min": [-1.0, -1.0], "max": [1.0, 1.0]},
|
||||||
|
}
|
||||||
|
|
||||||
|
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||||
|
|
||||||
|
observation = {
|
||||||
|
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||||
|
"observation.state": torch.tensor([0.0, -1.0]), # Normalized values in [-1, 1]
|
||||||
|
}
|
||||||
|
transition = create_transition(observation=observation)
|
||||||
|
|
||||||
|
unnormalized_transition = unnormalizer(transition)
|
||||||
|
unnormalized_obs = unnormalized_transition[TransitionKey.OBSERVATION]
|
||||||
|
|
||||||
|
# Image should remain unchanged (IDENTITY)
|
||||||
|
assert torch.allclose(unnormalized_obs["observation.image"], observation["observation.image"])
|
||||||
|
|
||||||
|
# State should be unnormalized (MIN_MAX)
|
||||||
|
# (0.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0) = 0.0
|
||||||
|
# (-1.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0) = -1.0
|
||||||
|
expected_state = torch.tensor([0.0, -1.0])
|
||||||
|
assert torch.allclose(unnormalized_obs["observation.state"], expected_state)
|
||||||
|
|
||||||
|
|
||||||
|
def test_identity_unnormalization_actions():
|
||||||
|
"""Test that IDENTITY mode skips unnormalization for actions."""
|
||||||
|
features = {"action": PolicyFeature(FeatureType.ACTION, (2,))}
|
||||||
|
norm_map = {FeatureType.ACTION: NormalizationMode.IDENTITY}
|
||||||
|
stats = {"action": {"min": [-1.0, -2.0], "max": [1.0, 2.0]}}
|
||||||
|
|
||||||
|
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||||
|
|
||||||
|
action = torch.tensor([0.5, -0.8]) # Normalized values
|
||||||
|
transition = create_transition(action=action)
|
||||||
|
|
||||||
|
unnormalized_transition = unnormalizer(transition)
|
||||||
|
|
||||||
|
# Action should remain unchanged
|
||||||
|
assert torch.allclose(unnormalized_transition[TransitionKey.ACTION], action)
|
||||||
|
|
||||||
|
|
||||||
|
def test_identity_with_missing_stats():
|
||||||
|
"""Test that IDENTITY mode works even when stats are missing."""
|
||||||
|
features = {
|
||||||
|
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
|
||||||
|
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||||
|
}
|
||||||
|
norm_map = {
|
||||||
|
FeatureType.VISUAL: NormalizationMode.IDENTITY,
|
||||||
|
FeatureType.ACTION: NormalizationMode.IDENTITY,
|
||||||
|
}
|
||||||
|
stats = {} # No stats provided
|
||||||
|
|
||||||
|
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||||
|
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||||
|
|
||||||
|
observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])}
|
||||||
|
action = torch.tensor([1.0, -0.5])
|
||||||
|
transition = create_transition(observation=observation, action=action)
|
||||||
|
|
||||||
|
# Both should work without errors and return unchanged data
|
||||||
|
normalized_transition = normalizer(transition)
|
||||||
|
unnormalized_transition = unnormalizer(transition)
|
||||||
|
|
||||||
|
assert torch.allclose(
|
||||||
|
normalized_transition[TransitionKey.OBSERVATION]["observation.image"],
|
||||||
|
observation["observation.image"],
|
||||||
|
)
|
||||||
|
assert torch.allclose(normalized_transition[TransitionKey.ACTION], action)
|
||||||
|
assert torch.allclose(
|
||||||
|
unnormalized_transition[TransitionKey.OBSERVATION]["observation.image"],
|
||||||
|
observation["observation.image"],
|
||||||
|
)
|
||||||
|
assert torch.allclose(unnormalized_transition[TransitionKey.ACTION], action)
|
||||||
|
|
||||||
|
|
||||||
|
def test_identity_mixed_with_other_modes():
|
||||||
|
"""Test IDENTITY mode mixed with other normalization modes."""
|
||||||
|
features = {
|
||||||
|
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
|
||||||
|
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
||||||
|
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||||
|
}
|
||||||
|
norm_map = {
|
||||||
|
FeatureType.VISUAL: NormalizationMode.IDENTITY,
|
||||||
|
FeatureType.STATE: NormalizationMode.MEAN_STD,
|
||||||
|
FeatureType.ACTION: NormalizationMode.MIN_MAX,
|
||||||
|
}
|
||||||
|
stats = {
|
||||||
|
"observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, # Will be ignored
|
||||||
|
"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]},
|
||||||
|
"action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]},
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||||
|
|
||||||
|
observation = {
|
||||||
|
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||||
|
"observation.state": torch.tensor([1.0, -0.5]),
|
||||||
|
}
|
||||||
|
action = torch.tensor([0.5, 0.0])
|
||||||
|
transition = create_transition(observation=observation, action=action)
|
||||||
|
|
||||||
|
normalized_transition = normalizer(transition)
|
||||||
|
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
|
||||||
|
normalized_action = normalized_transition[TransitionKey.ACTION]
|
||||||
|
|
||||||
|
# Image should remain unchanged (IDENTITY)
|
||||||
|
assert torch.allclose(normalized_obs["observation.image"], observation["observation.image"])
|
||||||
|
|
||||||
|
# State should be normalized (MEAN_STD)
|
||||||
|
expected_state = torch.tensor([1.0, -0.5]) # (x - 0) / 1 = x
|
||||||
|
assert torch.allclose(normalized_obs["observation.state"], expected_state)
|
||||||
|
|
||||||
|
# Action should be normalized (MIN_MAX) to [-1, 1]
|
||||||
|
# 2 * (0.5 - (-1)) / (1 - (-1)) - 1 = 2 * 1.5 / 2 - 1 = 0.5
|
||||||
|
# 2 * (0.0 - (-1)) / (1 - (-1)) - 1 = 2 * 1.0 / 2 - 1 = 0.0
|
||||||
|
expected_action = torch.tensor([0.5, 0.0])
|
||||||
|
assert torch.allclose(normalized_action, expected_action)
|
||||||
|
|
||||||
|
|
||||||
|
def test_identity_defaults_when_not_in_norm_map():
|
||||||
|
"""Test that IDENTITY is used as default when feature type not in norm_map."""
|
||||||
|
features = {
|
||||||
|
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
|
||||||
|
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
||||||
|
}
|
||||||
|
norm_map = {
|
||||||
|
FeatureType.STATE: NormalizationMode.MEAN_STD,
|
||||||
|
# VISUAL not specified, should default to IDENTITY
|
||||||
|
}
|
||||||
|
stats = {
|
||||||
|
"observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]},
|
||||||
|
"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]},
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||||
|
|
||||||
|
observation = {
|
||||||
|
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||||
|
"observation.state": torch.tensor([1.0, -0.5]),
|
||||||
|
}
|
||||||
|
transition = create_transition(observation=observation)
|
||||||
|
|
||||||
|
normalized_transition = normalizer(transition)
|
||||||
|
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
|
||||||
|
|
||||||
|
# Image should remain unchanged (defaults to IDENTITY)
|
||||||
|
assert torch.allclose(normalized_obs["observation.image"], observation["observation.image"])
|
||||||
|
|
||||||
|
# State should be normalized (explicitly MEAN_STD)
|
||||||
|
expected_state = torch.tensor([1.0, -0.5])
|
||||||
|
assert torch.allclose(normalized_obs["observation.state"], expected_state)
|
||||||
|
|
||||||
|
|
||||||
|
def test_identity_roundtrip():
|
||||||
|
"""Test that IDENTITY normalization and unnormalization are true inverses."""
|
||||||
|
features = {
|
||||||
|
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
|
||||||
|
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||||
|
}
|
||||||
|
norm_map = {
|
||||||
|
FeatureType.VISUAL: NormalizationMode.IDENTITY,
|
||||||
|
FeatureType.ACTION: NormalizationMode.IDENTITY,
|
||||||
|
}
|
||||||
|
stats = {
|
||||||
|
"observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]},
|
||||||
|
"action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]},
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||||
|
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||||
|
|
||||||
|
original_observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])}
|
||||||
|
original_action = torch.tensor([0.5, -0.2])
|
||||||
|
original_transition = create_transition(observation=original_observation, action=original_action)
|
||||||
|
|
||||||
|
# Normalize then unnormalize
|
||||||
|
normalized = normalizer(original_transition)
|
||||||
|
roundtrip = unnormalizer(normalized)
|
||||||
|
|
||||||
|
# Should be identical to original
|
||||||
|
assert torch.allclose(
|
||||||
|
roundtrip[TransitionKey.OBSERVATION]["observation.image"], original_observation["observation.image"]
|
||||||
|
)
|
||||||
|
assert torch.allclose(roundtrip[TransitionKey.ACTION], original_action)
|
||||||
|
|
||||||
|
|
||||||
|
def test_identity_config_serialization():
|
||||||
|
"""Test that IDENTITY mode is properly saved and loaded in config."""
|
||||||
|
features = {
|
||||||
|
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
|
||||||
|
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||||
|
}
|
||||||
|
norm_map = {
|
||||||
|
FeatureType.VISUAL: NormalizationMode.IDENTITY,
|
||||||
|
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||||
|
}
|
||||||
|
stats = {
|
||||||
|
"observation.image": {"mean": [0.5], "std": [0.2]},
|
||||||
|
"action": {"mean": [0.0, 0.0], "std": [1.0, 1.0]},
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||||
|
|
||||||
|
# Get config
|
||||||
|
config = normalizer.get_config()
|
||||||
|
|
||||||
|
# Check that IDENTITY is properly serialized
|
||||||
|
assert config["norm_map"]["VISUAL"] == "IDENTITY"
|
||||||
|
assert config["norm_map"]["ACTION"] == "MEAN_STD"
|
||||||
|
|
||||||
|
# Create new processor from config (simulating load)
|
||||||
|
new_normalizer = NormalizerProcessor(
|
||||||
|
features=config["features"],
|
||||||
|
norm_map=config["norm_map"],
|
||||||
|
stats=stats,
|
||||||
|
eps=config["eps"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test that both work the same way
|
||||||
|
observation = {"observation.image": torch.tensor([0.7])}
|
||||||
|
action = torch.tensor([1.0, -0.5])
|
||||||
|
transition = create_transition(observation=observation, action=action)
|
||||||
|
|
||||||
|
result1 = normalizer(transition)
|
||||||
|
result2 = new_normalizer(transition)
|
||||||
|
|
||||||
|
# Results should be identical
|
||||||
|
assert torch.allclose(
|
||||||
|
result1[TransitionKey.OBSERVATION]["observation.image"],
|
||||||
|
result2[TransitionKey.OBSERVATION]["observation.image"],
|
||||||
|
)
|
||||||
|
assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION])
|
||||||
|
|
||||||
|
|
||||||
|
def test_unsupported_normalization_mode_error():
|
||||||
|
"""Test that unsupported normalization modes raise appropriate errors."""
|
||||||
|
features = {"observation.state": PolicyFeature(FeatureType.STATE, (2,))}
|
||||||
|
|
||||||
|
# Create an invalid norm_map (this would never happen in practice, but tests error handling)
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
class InvalidMode(str, Enum):
|
||||||
|
INVALID = "INVALID"
|
||||||
|
|
||||||
|
# We can't actually pass an invalid enum to the processor due to type checking,
|
||||||
|
# but we can test the error by manipulating the norm_map after creation
|
||||||
|
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
|
||||||
|
stats = {"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}}
|
||||||
|
|
||||||
|
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||||
|
|
||||||
|
# Manually inject an invalid mode to test error handling
|
||||||
|
normalizer.norm_map[FeatureType.STATE] = "INVALID_MODE"
|
||||||
|
|
||||||
|
observation = {"observation.state": torch.tensor([1.0, -0.5])}
|
||||||
|
transition = create_transition(observation=observation)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Unsupported normalization mode"):
|
||||||
|
normalizer(transition)
|
||||||
|
|||||||
Reference in New Issue
Block a user