mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
Add normalization check for backward compatibility
This commit is contained in:
@@ -418,3 +418,37 @@ class UnnormalizeBuffer(nn.Module):
|
|||||||
raise ValueError(norm_mode)
|
raise ValueError(norm_mode)
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
def convert_normalize_to_buffer_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Convert state dict from Normalize/Unnormalize classes to NormalizeBuffer/UnnormalizeBuffer format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_dict: State dict from a model using Normalize/Unnormalize classes
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Converted state dict compatible with NormalizeBuffer/UnnormalizeBuffer classes
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# Old format (Normalize): "buffer_observation_image.mean"
|
||||||
|
# New format (NormalizeBuffer): "observation_image_mean"
|
||||||
|
"""
|
||||||
|
converted_state_dict = {}
|
||||||
|
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
# Check if this is a normalization buffer parameter
|
||||||
|
if key.startswith("buffer_") and ("." in key):
|
||||||
|
# Extract the prefix and stat type
|
||||||
|
# e.g. "buffer_observation_image.mean" -> "observation_image", "mean"
|
||||||
|
buffer_part = key[7:] # Remove "buffer_" prefix
|
||||||
|
prefix, stat_type = buffer_part.rsplit(".", 1)
|
||||||
|
|
||||||
|
# Convert to new format: "observation_image_mean"
|
||||||
|
new_key = f"{prefix}_{stat_type}"
|
||||||
|
converted_state_dict[new_key] = value
|
||||||
|
else:
|
||||||
|
# Keep non-normalization keys unchanged
|
||||||
|
converted_state_dict[key] = value
|
||||||
|
|
||||||
|
return converted_state_dict
|
||||||
|
|||||||
Reference in New Issue
Block a user