chore(normalization): addressing comments from copilot

This commit is contained in:
Adil Zouitine
2025-07-07 13:37:25 +02:00
parent 3b8a3a32a0
commit 097842c70f
2 changed files with 107 additions and 3 deletions
+48 -2
View File
@@ -87,6 +87,23 @@ class NormalizerProcessor:
)
def __post_init__(self):
# Handle deserialization from JSON config
if self.features and isinstance(list(self.features.values())[0], dict):
# Features came from JSON - need to reconstruct PolicyFeature objects
reconstructed_features = {}
for key, ft_dict in self.features.items():
reconstructed_features[key] = PolicyFeature(
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
)
self.features = reconstructed_features
if self.norm_map and isinstance(list(self.norm_map.keys())[0], str):
# norm_map came from JSON - need to reconstruct enum keys and values
reconstructed_norm_map = {}
for ft_type_str, norm_mode_str in self.norm_map.items():
reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
self.norm_map = reconstructed_norm_map
# Convert statistics once so we avoid repeated numpy→Tensor conversions
# during runtime.
self.stats = self.stats or {}
@@ -161,7 +178,13 @@ class NormalizerProcessor:
)
def get_config(self) -> dict[str, Any]:
config = {"eps": self.eps}
config = {
"eps": self.eps,
"features": {
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items()
},
"norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()},
}
if self.normalize_keys is not None:
# Serialise as a list for YAML / JSON friendliness
config["normalize_keys"] = sorted(self.normalize_keys)
@@ -212,6 +235,23 @@ class UnnormalizerProcessor:
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, eps=eps)
def __post_init__(self):
# Handle deserialization from JSON config
if self.features and isinstance(list(self.features.values())[0], dict):
# Features came from JSON - need to reconstruct PolicyFeature objects
reconstructed_features = {}
for key, ft_dict in self.features.items():
reconstructed_features[key] = PolicyFeature(
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
)
self.features = reconstructed_features
if self.norm_map and isinstance(list(self.norm_map.keys())[0], str):
# norm_map came from JSON - need to reconstruct enum keys and values
reconstructed_norm_map = {}
for ft_type_str, norm_mode_str in self.norm_map.items():
reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
self.norm_map = reconstructed_norm_map
self.stats = self.stats or {}
self._tensor_stats = _convert_stats_to_tensors(self.stats)
@@ -269,7 +309,13 @@ class UnnormalizerProcessor:
)
def get_config(self) -> dict[str, Any]:
return {"eps": self.eps}
return {
"eps": self.eps,
"features": {
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items()
},
"norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()},
}
def state_dict(self) -> dict[str, Tensor]:
flat = {}
+59 -1
View File
@@ -440,7 +440,21 @@ def test_get_config(full_stats):
)
config = processor.get_config()
assert config == {"normalize_keys": ["observation.image"], "eps": 1e-6}
expected_config = {
"normalize_keys": ["observation.image"],
"eps": 1e-6,
"features": {
"observation.image": {"type": "VISUAL", "shape": (3, 96, 96)},
"observation.state": {"type": "STATE", "shape": (2,)},
"action": {"type": "ACTION", "shape": (2,)},
},
"norm_map": {
"VISUAL": "MEAN_STD",
"STATE": "MIN_MAX",
"ACTION": "MEAN_STD",
},
}
assert config == expected_config
def test_integration_with_robot_processor(normalizer_processor):
@@ -509,3 +523,47 @@ def test_missing_action_stats_no_error():
processor = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)
# The tensor stats should not contain the 'action' key
assert "action" not in processor._tensor_stats
def test_serialization_roundtrip(full_stats):
"""Test that features and norm_map can be serialized and deserialized correctly."""
features = _create_full_features()
norm_map = _create_full_norm_map()
original_processor = NormalizerProcessor(
features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6
)
# Get config (serialization)
config = original_processor.get_config()
# Create a new processor from the config (deserialization)
new_processor = NormalizerProcessor(
features=config["features"],
norm_map=config["norm_map"],
stats=full_stats,
normalize_keys=set(config["normalize_keys"]),
eps=config["eps"],
)
# Test that both processors work the same way
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([0.5, 0.0]),
}
action = torch.tensor([1.0, -0.5])
transition = (observation, action, 1.0, False, False, {}, {})
result1 = original_processor(transition)
result2 = new_processor(transition)
# Compare results
assert torch.allclose(result1[0]["observation.image"], result2[0]["observation.image"])
assert torch.allclose(result1[1], result2[1])
# Verify features and norm_map are correctly reconstructed
assert new_processor.features.keys() == original_processor.features.keys()
for key in new_processor.features:
assert new_processor.features[key].type == original_processor.features[key].type
assert new_processor.features[key].shape == original_processor.features[key].shape
assert new_processor.norm_map == original_processor.norm_map