mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
Add normalization check for backward compatibility
This commit is contained in:
@@ -418,3 +418,37 @@ class UnnormalizeBuffer(nn.Module):
|
||||
raise ValueError(norm_mode)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def convert_normalize_to_buffer_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Convert state dict from Normalize/Unnormalize classes to NormalizeBuffer/UnnormalizeBuffer format.
|
||||
|
||||
Args:
|
||||
state_dict: State dict from a model using Normalize/Unnormalize classes
|
||||
|
||||
Returns:
|
||||
Converted state dict compatible with NormalizeBuffer/UnnormalizeBuffer classes
|
||||
|
||||
Example:
|
||||
# Old format (Normalize): "buffer_observation_image.mean"
|
||||
# New format (NormalizeBuffer): "observation_image_mean"
|
||||
"""
|
||||
converted_state_dict = {}
|
||||
|
||||
for key, value in state_dict.items():
|
||||
# Check if this is a normalization buffer parameter
|
||||
if key.startswith("buffer_") and ("." in key):
|
||||
# Extract the prefix and stat type
|
||||
# e.g. "buffer_observation_image.mean" -> "observation_image", "mean"
|
||||
buffer_part = key[7:] # Remove "buffer_" prefix
|
||||
prefix, stat_type = buffer_part.rsplit(".", 1)
|
||||
|
||||
# Convert to new format: "observation_image_mean"
|
||||
new_key = f"{prefix}_{stat_type}"
|
||||
converted_state_dict[new_key] = value
|
||||
else:
|
||||
# Keep non-normalization keys unchanged
|
||||
converted_state_dict[key] = value
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
Reference in New Issue
Block a user