Add normalization check for backward compatibility

This commit is contained in:
AdilZouitine
2025-06-05 13:18:23 +02:00
parent 79ec487af7
commit 113b3ba343
+34
View File
@@ -418,3 +418,37 @@ class UnnormalizeBuffer(nn.Module):
raise ValueError(norm_mode)
return batch
def convert_normalize_to_buffer_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""
Convert state dict from Normalize/Unnormalize classes to NormalizeBuffer/UnnormalizeBuffer format.
Args:
state_dict: State dict from a model using Normalize/Unnormalize classes
Returns:
Converted state dict compatible with NormalizeBuffer/UnnormalizeBuffer classes
Example:
# Old format (Normalize): "buffer_observation_image.mean"
# New format (NormalizeBuffer): "observation_image_mean"
"""
converted_state_dict = {}
for key, value in state_dict.items():
# Check if this is a normalization buffer parameter
if key.startswith("buffer_") and ("." in key):
# Extract the prefix and stat type
# e.g. "buffer_observation_image.mean" -> "observation_image", "mean"
buffer_part = key[7:] # Remove "buffer_" prefix
prefix, stat_type = buffer_part.rsplit(".", 1)
# Convert to new format: "observation_image_mean"
new_key = f"{prefix}_{stat_type}"
converted_state_dict[new_key] = value
else:
# Keep non-normalization keys unchanged
converted_state_dict[key] = value
return converted_state_dict