diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index 14628727f..e4ce45d49 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -128,7 +128,19 @@ class NormalizerProcessor: processed = dict(observation) for key in keys_to_norm: - if key not in processed or key not in self._tensor_stats: + if key not in processed or key not in self.features: + continue + + # Check the normalization mode for this feature type + feature = self.features[key] + norm_mode = self.norm_map.get(feature.type, NormalizationMode.IDENTITY) + + # Skip normalization if mode is IDENTITY + if norm_mode is NormalizationMode.IDENTITY: + continue + + # Skip if no stats available for this key + if key not in self._tensor_stats: continue orig_val = processed[key] @@ -139,16 +151,32 @@ class NormalizerProcessor: ) stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()} - if "mean" in stats and "std" in stats: - mean, std = stats["mean"], stats["std"] - processed[key] = (tensor - mean) / (std + self.eps) - elif "min" in stats and "max" in stats: - min_val, max_val = stats["min"], stats["max"] - processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1 + if norm_mode is NormalizationMode.MEAN_STD: + if "mean" in stats and "std" in stats: + mean, std = stats["mean"], stats["std"] + processed[key] = (tensor - mean) / (std + self.eps) + elif norm_mode is NormalizationMode.MIN_MAX: + if "min" in stats and "max" in stats: + min_val, max_val = stats["min"], stats["max"] + processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1 + else: + raise ValueError(f"Unsupported normalization mode: {norm_mode}") + return processed def _normalize_action(self, action): - if action is None or "action" not in self._tensor_stats: + if action is None: + return action + + # Check the normalization mode for actions + norm_mode = self.norm_map.get(FeatureType.ACTION, NormalizationMode.IDENTITY) + + # Skip normalization if mode is IDENTITY + if norm_mode is NormalizationMode.IDENTITY: + return action + + # Skip if no stats available for actions + if "action" not in self._tensor_stats: return action tensor = ( @@ -157,13 +185,20 @@ class NormalizerProcessor: else torch.as_tensor(action, dtype=torch.float32) ) stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()} - if "mean" in stats and "std" in stats: - mean, std = stats["mean"], stats["std"] - return (tensor - mean) / (std + self.eps) - if "min" in stats and "max" in stats: - min_val, max_val = stats["min"], stats["max"] - return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1 - raise ValueError("Action stats must contain either ('mean','std') or ('min','max')") + + if norm_mode is NormalizationMode.MEAN_STD: + if "mean" in stats and "std" in stats: + mean, std = stats["mean"], stats["std"] + return (tensor - mean) / (std + self.eps) + elif norm_mode is NormalizationMode.MIN_MAX: + if "min" in stats and "max" in stats: + min_val, max_val = stats["min"], stats["max"] + return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1 + else: + raise ValueError(f"Unsupported normalization mode: {norm_mode}") + + # If we reach here, the required stats for the normalization mode are not available + raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization") def __call__(self, transition: EnvTransition) -> EnvTransition: observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION)) @@ -259,8 +294,21 @@ class UnnormalizerProcessor: keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION] processed = dict(observation) for key in keys: - if key not in processed or key not in self._tensor_stats: + if key not in processed or key not in self.features: continue + + # Check the normalization mode for this feature type + feature = self.features[key] + norm_mode = self.norm_map.get(feature.type, NormalizationMode.IDENTITY) + + # Skip unnormalization if mode is IDENTITY + if norm_mode is NormalizationMode.IDENTITY: + continue + + # Skip if no stats available for this key + if key not in self._tensor_stats: + continue + orig_val = processed[key] tensor = ( orig_val.to(dtype=torch.float32) @@ -268,30 +316,55 @@ class UnnormalizerProcessor: else torch.as_tensor(orig_val, dtype=torch.float32) ) stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()} - if "mean" in stats and "std" in stats: - mean, std = stats["mean"], stats["std"] - processed[key] = tensor * std + mean - elif "min" in stats and "max" in stats: - min_val, max_val = stats["min"], stats["max"] - processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val + + if norm_mode is NormalizationMode.MEAN_STD: + if "mean" in stats and "std" in stats: + mean, std = stats["mean"], stats["std"] + processed[key] = tensor * std + mean + elif norm_mode is NormalizationMode.MIN_MAX: + if "min" in stats and "max" in stats: + min_val, max_val = stats["min"], stats["max"] + processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val + else: + raise ValueError(f"Unsupported normalization mode: {norm_mode}") + return processed def _unnormalize_action(self, action): - if action is None or "action" not in self._tensor_stats: + if action is None: return action + + # Check the normalization mode for actions + norm_mode = self.norm_map.get(FeatureType.ACTION, NormalizationMode.IDENTITY) + + # Skip unnormalization if mode is IDENTITY + if norm_mode is NormalizationMode.IDENTITY: + return action + + # Skip if no stats available for actions + if "action" not in self._tensor_stats: + return action + tensor = ( action.to(dtype=torch.float32) if isinstance(action, torch.Tensor) else torch.as_tensor(action, dtype=torch.float32) ) stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()} - if "mean" in stats and "std" in stats: - mean, std = stats["mean"], stats["std"] - return tensor * std + mean - if "min" in stats and "max" in stats: - min_val, max_val = stats["min"], stats["max"] - return (tensor + 1) / 2 * (max_val - min_val) + min_val - raise ValueError("Action stats must contain either ('mean','std') or ('min','max')") + + if norm_mode is NormalizationMode.MEAN_STD: + if "mean" in stats and "std" in stats: + mean, std = stats["mean"], stats["std"] + return tensor * std + mean + elif norm_mode is NormalizationMode.MIN_MAX: + if "min" in stats and "max" in stats: + min_val, max_val = stats["min"], stats["max"] + return (tensor + 1) / 2 * (max_val - min_val) + min_val + else: + raise ValueError(f"Unsupported normalization mode: {norm_mode}") + + # If we reach here, the required stats for the normalization mode are not available + raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization") def __call__(self, transition: EnvTransition) -> EnvTransition: observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION)) diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index 26aea56c7..5611ace04 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -626,3 +626,330 @@ def test_serialization_roundtrip(full_stats): assert new_processor.features[key].shape == original_processor.features[key].shape assert new_processor.norm_map == original_processor.norm_map + + +# Identity normalization tests +def test_identity_normalization_observations(): + """Test that IDENTITY mode skips normalization for observations.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.IDENTITY, # IDENTITY mode + FeatureType.STATE: NormalizationMode.MEAN_STD, # Normal mode for comparison + } + stats = { + "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, + "observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, + } + + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([1.0, -0.5]), + } + transition = create_transition(observation=observation) + + normalized_transition = normalizer(transition) + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] + + # Image should remain unchanged (IDENTITY) + assert torch.allclose(normalized_obs["observation.image"], observation["observation.image"]) + + # State should be normalized (MEAN_STD) + expected_state = (torch.tensor([1.0, -0.5]) - torch.tensor([0.0, 0.0])) / torch.tensor([1.0, 1.0]) + assert torch.allclose(normalized_obs["observation.state"], expected_state) + + +def test_identity_normalization_actions(): + """Test that IDENTITY mode skips normalization for actions.""" + features = {"action": PolicyFeature(FeatureType.ACTION, (2,))} + norm_map = {FeatureType.ACTION: NormalizationMode.IDENTITY} + stats = {"action": {"mean": [0.0, 0.0], "std": [1.0, 2.0]}} + + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + + action = torch.tensor([1.0, -0.5]) + transition = create_transition(action=action) + + normalized_transition = normalizer(transition) + + # Action should remain unchanged + assert torch.allclose(normalized_transition[TransitionKey.ACTION], action) + + +def test_identity_unnormalization_observations(): + """Test that IDENTITY mode skips unnormalization for observations.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.IDENTITY, # IDENTITY mode + FeatureType.STATE: NormalizationMode.MIN_MAX, # Normal mode for comparison + } + stats = { + "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, + "observation.state": {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, + } + + unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([0.0, -1.0]), # Normalized values in [-1, 1] + } + transition = create_transition(observation=observation) + + unnormalized_transition = unnormalizer(transition) + unnormalized_obs = unnormalized_transition[TransitionKey.OBSERVATION] + + # Image should remain unchanged (IDENTITY) + assert torch.allclose(unnormalized_obs["observation.image"], observation["observation.image"]) + + # State should be unnormalized (MIN_MAX) + # (0.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0) = 0.0 + # (-1.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0) = -1.0 + expected_state = torch.tensor([0.0, -1.0]) + assert torch.allclose(unnormalized_obs["observation.state"], expected_state) + + +def test_identity_unnormalization_actions(): + """Test that IDENTITY mode skips unnormalization for actions.""" + features = {"action": PolicyFeature(FeatureType.ACTION, (2,))} + norm_map = {FeatureType.ACTION: NormalizationMode.IDENTITY} + stats = {"action": {"min": [-1.0, -2.0], "max": [1.0, 2.0]}} + + unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + + action = torch.tensor([0.5, -0.8]) # Normalized values + transition = create_transition(action=action) + + unnormalized_transition = unnormalizer(transition) + + # Action should remain unchanged + assert torch.allclose(unnormalized_transition[TransitionKey.ACTION], action) + + +def test_identity_with_missing_stats(): + """Test that IDENTITY mode works even when stats are missing.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.IDENTITY, + FeatureType.ACTION: NormalizationMode.IDENTITY, + } + stats = {} # No stats provided + + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + + observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])} + action = torch.tensor([1.0, -0.5]) + transition = create_transition(observation=observation, action=action) + + # Both should work without errors and return unchanged data + normalized_transition = normalizer(transition) + unnormalized_transition = unnormalizer(transition) + + assert torch.allclose( + normalized_transition[TransitionKey.OBSERVATION]["observation.image"], + observation["observation.image"], + ) + assert torch.allclose(normalized_transition[TransitionKey.ACTION], action) + assert torch.allclose( + unnormalized_transition[TransitionKey.OBSERVATION]["observation.image"], + observation["observation.image"], + ) + assert torch.allclose(unnormalized_transition[TransitionKey.ACTION], action) + + +def test_identity_mixed_with_other_modes(): + """Test IDENTITY mode mixed with other normalization modes.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.IDENTITY, + FeatureType.STATE: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MIN_MAX, + } + stats = { + "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, # Will be ignored + "observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, + "action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, + } + + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([1.0, -0.5]), + } + action = torch.tensor([0.5, 0.0]) + transition = create_transition(observation=observation, action=action) + + normalized_transition = normalizer(transition) + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] + normalized_action = normalized_transition[TransitionKey.ACTION] + + # Image should remain unchanged (IDENTITY) + assert torch.allclose(normalized_obs["observation.image"], observation["observation.image"]) + + # State should be normalized (MEAN_STD) + expected_state = torch.tensor([1.0, -0.5]) # (x - 0) / 1 = x + assert torch.allclose(normalized_obs["observation.state"], expected_state) + + # Action should be normalized (MIN_MAX) to [-1, 1] + # 2 * (0.5 - (-1)) / (1 - (-1)) - 1 = 2 * 1.5 / 2 - 1 = 0.5 + # 2 * (0.0 - (-1)) / (1 - (-1)) - 1 = 2 * 1.0 / 2 - 1 = 0.0 + expected_action = torch.tensor([0.5, 0.0]) + assert torch.allclose(normalized_action, expected_action) + + +def test_identity_defaults_when_not_in_norm_map(): + """Test that IDENTITY is used as default when feature type not in norm_map.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + } + norm_map = { + FeatureType.STATE: NormalizationMode.MEAN_STD, + # VISUAL not specified, should default to IDENTITY + } + stats = { + "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, + "observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, + } + + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([1.0, -0.5]), + } + transition = create_transition(observation=observation) + + normalized_transition = normalizer(transition) + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] + + # Image should remain unchanged (defaults to IDENTITY) + assert torch.allclose(normalized_obs["observation.image"], observation["observation.image"]) + + # State should be normalized (explicitly MEAN_STD) + expected_state = torch.tensor([1.0, -0.5]) + assert torch.allclose(normalized_obs["observation.state"], expected_state) + + +def test_identity_roundtrip(): + """Test that IDENTITY normalization and unnormalization are true inverses.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.IDENTITY, + FeatureType.ACTION: NormalizationMode.IDENTITY, + } + stats = { + "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, + "action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, + } + + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + + original_observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])} + original_action = torch.tensor([0.5, -0.2]) + original_transition = create_transition(observation=original_observation, action=original_action) + + # Normalize then unnormalize + normalized = normalizer(original_transition) + roundtrip = unnormalizer(normalized) + + # Should be identical to original + assert torch.allclose( + roundtrip[TransitionKey.OBSERVATION]["observation.image"], original_observation["observation.image"] + ) + assert torch.allclose(roundtrip[TransitionKey.ACTION], original_action) + + +def test_identity_config_serialization(): + """Test that IDENTITY mode is properly saved and loaded in config.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.IDENTITY, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + stats = { + "observation.image": {"mean": [0.5], "std": [0.2]}, + "action": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, + } + + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + + # Get config + config = normalizer.get_config() + + # Check that IDENTITY is properly serialized + assert config["norm_map"]["VISUAL"] == "IDENTITY" + assert config["norm_map"]["ACTION"] == "MEAN_STD" + + # Create new processor from config (simulating load) + new_normalizer = NormalizerProcessor( + features=config["features"], + norm_map=config["norm_map"], + stats=stats, + eps=config["eps"], + ) + + # Test that both work the same way + observation = {"observation.image": torch.tensor([0.7])} + action = torch.tensor([1.0, -0.5]) + transition = create_transition(observation=observation, action=action) + + result1 = normalizer(transition) + result2 = new_normalizer(transition) + + # Results should be identical + assert torch.allclose( + result1[TransitionKey.OBSERVATION]["observation.image"], + result2[TransitionKey.OBSERVATION]["observation.image"], + ) + assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION]) + + +def test_unsupported_normalization_mode_error(): + """Test that unsupported normalization modes raise appropriate errors.""" + features = {"observation.state": PolicyFeature(FeatureType.STATE, (2,))} + + # Create an invalid norm_map (this would never happen in practice, but tests error handling) + from enum import Enum + + class InvalidMode(str, Enum): + INVALID = "INVALID" + + # We can't actually pass an invalid enum to the processor due to type checking, + # but we can test the error by manipulating the norm_map after creation + norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} + stats = {"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}} + + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + + # Manually inject an invalid mode to test error handling + normalizer.norm_map[FeatureType.STATE] = "INVALID_MODE" + + observation = {"observation.state": torch.tensor([1.0, -0.5])} + transition = create_transition(observation=observation) + + with pytest.raises(ValueError, match="Unsupported normalization mode"): + normalizer(transition)