From 097842c70f9784451d9844c57681cf6f5b59a942 Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Mon, 7 Jul 2025 13:37:25 +0200 Subject: [PATCH] chore(normalization): addressing comments from copilot --- src/lerobot/processor/normalize_processor.py | 50 +++++++++++++++- tests/processor/test_normalize_processor.py | 60 +++++++++++++++++++- 2 files changed, 107 insertions(+), 3 deletions(-) diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index cf653be1f..3a225c36f 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -87,6 +87,23 @@ class NormalizerProcessor: ) def __post_init__(self): + # Handle deserialization from JSON config + if self.features and isinstance(list(self.features.values())[0], dict): + # Features came from JSON - need to reconstruct PolicyFeature objects + reconstructed_features = {} + for key, ft_dict in self.features.items(): + reconstructed_features[key] = PolicyFeature( + type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"]) + ) + self.features = reconstructed_features + + if self.norm_map and isinstance(list(self.norm_map.keys())[0], str): + # norm_map came from JSON - need to reconstruct enum keys and values + reconstructed_norm_map = {} + for ft_type_str, norm_mode_str in self.norm_map.items(): + reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str) + self.norm_map = reconstructed_norm_map + # Convert statistics once so we avoid repeated numpy→Tensor conversions # during runtime. self.stats = self.stats or {} @@ -161,7 +178,13 @@ class NormalizerProcessor: ) def get_config(self) -> dict[str, Any]: - config = {"eps": self.eps} + config = { + "eps": self.eps, + "features": { + key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items() + }, + "norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()}, + } if self.normalize_keys is not None: # Serialise as a list for YAML / JSON friendliness config["normalize_keys"] = sorted(self.normalize_keys) @@ -212,6 +235,23 @@ class UnnormalizerProcessor: return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, eps=eps) def __post_init__(self): + # Handle deserialization from JSON config + if self.features and isinstance(list(self.features.values())[0], dict): + # Features came from JSON - need to reconstruct PolicyFeature objects + reconstructed_features = {} + for key, ft_dict in self.features.items(): + reconstructed_features[key] = PolicyFeature( + type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"]) + ) + self.features = reconstructed_features + + if self.norm_map and isinstance(list(self.norm_map.keys())[0], str): + # norm_map came from JSON - need to reconstruct enum keys and values + reconstructed_norm_map = {} + for ft_type_str, norm_mode_str in self.norm_map.items(): + reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str) + self.norm_map = reconstructed_norm_map + self.stats = self.stats or {} self._tensor_stats = _convert_stats_to_tensors(self.stats) @@ -269,7 +309,13 @@ class UnnormalizerProcessor: ) def get_config(self) -> dict[str, Any]: - return {"eps": self.eps} + return { + "eps": self.eps, + "features": { + key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items() + }, + "norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()}, + } def state_dict(self) -> dict[str, Tensor]: flat = {} diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index 0c48433e8..3aabbe532 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -440,7 +440,21 @@ def test_get_config(full_stats): ) config = processor.get_config() - assert config == {"normalize_keys": ["observation.image"], "eps": 1e-6} + expected_config = { + "normalize_keys": ["observation.image"], + "eps": 1e-6, + "features": { + "observation.image": {"type": "VISUAL", "shape": (3, 96, 96)}, + "observation.state": {"type": "STATE", "shape": (2,)}, + "action": {"type": "ACTION", "shape": (2,)}, + }, + "norm_map": { + "VISUAL": "MEAN_STD", + "STATE": "MIN_MAX", + "ACTION": "MEAN_STD", + }, + } + assert config == expected_config def test_integration_with_robot_processor(normalizer_processor): @@ -509,3 +523,47 @@ def test_missing_action_stats_no_error(): processor = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map) # The tensor stats should not contain the 'action' key assert "action" not in processor._tensor_stats + + +def test_serialization_roundtrip(full_stats): + """Test that features and norm_map can be serialized and deserialized correctly.""" + features = _create_full_features() + norm_map = _create_full_norm_map() + original_processor = NormalizerProcessor( + features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6 + ) + + # Get config (serialization) + config = original_processor.get_config() + + # Create a new processor from the config (deserialization) + new_processor = NormalizerProcessor( + features=config["features"], + norm_map=config["norm_map"], + stats=full_stats, + normalize_keys=set(config["normalize_keys"]), + eps=config["eps"], + ) + + # Test that both processors work the same way + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([0.5, 0.0]), + } + action = torch.tensor([1.0, -0.5]) + transition = (observation, action, 1.0, False, False, {}, {}) + + result1 = original_processor(transition) + result2 = new_processor(transition) + + # Compare results + assert torch.allclose(result1[0]["observation.image"], result2[0]["observation.image"]) + assert torch.allclose(result1[1], result2[1]) + + # Verify features and norm_map are correctly reconstructed + assert new_processor.features.keys() == original_processor.features.keys() + for key in new_processor.features: + assert new_processor.features[key].type == original_processor.features[key].type + assert new_processor.features[key].shape == original_processor.features[key].shape + + assert new_processor.norm_map == original_processor.norm_map