mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
chore(normalization): addressing comments from copilot
This commit is contained in:
@@ -87,6 +87,23 @@ class NormalizerProcessor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
# Handle deserialization from JSON config
|
||||||
|
if self.features and isinstance(list(self.features.values())[0], dict):
|
||||||
|
# Features came from JSON - need to reconstruct PolicyFeature objects
|
||||||
|
reconstructed_features = {}
|
||||||
|
for key, ft_dict in self.features.items():
|
||||||
|
reconstructed_features[key] = PolicyFeature(
|
||||||
|
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
|
||||||
|
)
|
||||||
|
self.features = reconstructed_features
|
||||||
|
|
||||||
|
if self.norm_map and isinstance(list(self.norm_map.keys())[0], str):
|
||||||
|
# norm_map came from JSON - need to reconstruct enum keys and values
|
||||||
|
reconstructed_norm_map = {}
|
||||||
|
for ft_type_str, norm_mode_str in self.norm_map.items():
|
||||||
|
reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
|
||||||
|
self.norm_map = reconstructed_norm_map
|
||||||
|
|
||||||
# Convert statistics once so we avoid repeated numpy→Tensor conversions
|
# Convert statistics once so we avoid repeated numpy→Tensor conversions
|
||||||
# during runtime.
|
# during runtime.
|
||||||
self.stats = self.stats or {}
|
self.stats = self.stats or {}
|
||||||
@@ -161,7 +178,13 @@ class NormalizerProcessor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_config(self) -> dict[str, Any]:
|
def get_config(self) -> dict[str, Any]:
|
||||||
config = {"eps": self.eps}
|
config = {
|
||||||
|
"eps": self.eps,
|
||||||
|
"features": {
|
||||||
|
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items()
|
||||||
|
},
|
||||||
|
"norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()},
|
||||||
|
}
|
||||||
if self.normalize_keys is not None:
|
if self.normalize_keys is not None:
|
||||||
# Serialise as a list for YAML / JSON friendliness
|
# Serialise as a list for YAML / JSON friendliness
|
||||||
config["normalize_keys"] = sorted(self.normalize_keys)
|
config["normalize_keys"] = sorted(self.normalize_keys)
|
||||||
@@ -212,6 +235,23 @@ class UnnormalizerProcessor:
|
|||||||
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, eps=eps)
|
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, eps=eps)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
# Handle deserialization from JSON config
|
||||||
|
if self.features and isinstance(list(self.features.values())[0], dict):
|
||||||
|
# Features came from JSON - need to reconstruct PolicyFeature objects
|
||||||
|
reconstructed_features = {}
|
||||||
|
for key, ft_dict in self.features.items():
|
||||||
|
reconstructed_features[key] = PolicyFeature(
|
||||||
|
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
|
||||||
|
)
|
||||||
|
self.features = reconstructed_features
|
||||||
|
|
||||||
|
if self.norm_map and isinstance(list(self.norm_map.keys())[0], str):
|
||||||
|
# norm_map came from JSON - need to reconstruct enum keys and values
|
||||||
|
reconstructed_norm_map = {}
|
||||||
|
for ft_type_str, norm_mode_str in self.norm_map.items():
|
||||||
|
reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
|
||||||
|
self.norm_map = reconstructed_norm_map
|
||||||
|
|
||||||
self.stats = self.stats or {}
|
self.stats = self.stats or {}
|
||||||
self._tensor_stats = _convert_stats_to_tensors(self.stats)
|
self._tensor_stats = _convert_stats_to_tensors(self.stats)
|
||||||
|
|
||||||
@@ -269,7 +309,13 @@ class UnnormalizerProcessor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_config(self) -> dict[str, Any]:
|
def get_config(self) -> dict[str, Any]:
|
||||||
return {"eps": self.eps}
|
return {
|
||||||
|
"eps": self.eps,
|
||||||
|
"features": {
|
||||||
|
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items()
|
||||||
|
},
|
||||||
|
"norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()},
|
||||||
|
}
|
||||||
|
|
||||||
def state_dict(self) -> dict[str, Tensor]:
|
def state_dict(self) -> dict[str, Tensor]:
|
||||||
flat = {}
|
flat = {}
|
||||||
|
|||||||
@@ -440,7 +440,21 @@ def test_get_config(full_stats):
|
|||||||
)
|
)
|
||||||
|
|
||||||
config = processor.get_config()
|
config = processor.get_config()
|
||||||
assert config == {"normalize_keys": ["observation.image"], "eps": 1e-6}
|
expected_config = {
|
||||||
|
"normalize_keys": ["observation.image"],
|
||||||
|
"eps": 1e-6,
|
||||||
|
"features": {
|
||||||
|
"observation.image": {"type": "VISUAL", "shape": (3, 96, 96)},
|
||||||
|
"observation.state": {"type": "STATE", "shape": (2,)},
|
||||||
|
"action": {"type": "ACTION", "shape": (2,)},
|
||||||
|
},
|
||||||
|
"norm_map": {
|
||||||
|
"VISUAL": "MEAN_STD",
|
||||||
|
"STATE": "MIN_MAX",
|
||||||
|
"ACTION": "MEAN_STD",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert config == expected_config
|
||||||
|
|
||||||
|
|
||||||
def test_integration_with_robot_processor(normalizer_processor):
|
def test_integration_with_robot_processor(normalizer_processor):
|
||||||
@@ -509,3 +523,47 @@ def test_missing_action_stats_no_error():
|
|||||||
processor = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)
|
processor = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)
|
||||||
# The tensor stats should not contain the 'action' key
|
# The tensor stats should not contain the 'action' key
|
||||||
assert "action" not in processor._tensor_stats
|
assert "action" not in processor._tensor_stats
|
||||||
|
|
||||||
|
|
||||||
|
def test_serialization_roundtrip(full_stats):
|
||||||
|
"""Test that features and norm_map can be serialized and deserialized correctly."""
|
||||||
|
features = _create_full_features()
|
||||||
|
norm_map = _create_full_norm_map()
|
||||||
|
original_processor = NormalizerProcessor(
|
||||||
|
features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get config (serialization)
|
||||||
|
config = original_processor.get_config()
|
||||||
|
|
||||||
|
# Create a new processor from the config (deserialization)
|
||||||
|
new_processor = NormalizerProcessor(
|
||||||
|
features=config["features"],
|
||||||
|
norm_map=config["norm_map"],
|
||||||
|
stats=full_stats,
|
||||||
|
normalize_keys=set(config["normalize_keys"]),
|
||||||
|
eps=config["eps"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test that both processors work the same way
|
||||||
|
observation = {
|
||||||
|
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||||
|
"observation.state": torch.tensor([0.5, 0.0]),
|
||||||
|
}
|
||||||
|
action = torch.tensor([1.0, -0.5])
|
||||||
|
transition = (observation, action, 1.0, False, False, {}, {})
|
||||||
|
|
||||||
|
result1 = original_processor(transition)
|
||||||
|
result2 = new_processor(transition)
|
||||||
|
|
||||||
|
# Compare results
|
||||||
|
assert torch.allclose(result1[0]["observation.image"], result2[0]["observation.image"])
|
||||||
|
assert torch.allclose(result1[1], result2[1])
|
||||||
|
|
||||||
|
# Verify features and norm_map are correctly reconstructed
|
||||||
|
assert new_processor.features.keys() == original_processor.features.keys()
|
||||||
|
for key in new_processor.features:
|
||||||
|
assert new_processor.features[key].type == original_processor.features[key].type
|
||||||
|
assert new_processor.features[key].shape == original_processor.features[key].shape
|
||||||
|
|
||||||
|
assert new_processor.norm_map == original_processor.norm_map
|
||||||
|
|||||||
Reference in New Issue
Block a user