From 113b3ba343ed138c6068759d5a42d131d549d682 Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Thu, 5 Jun 2025 13:18:23 +0200 Subject: [PATCH] Add normalization check for backward compatibility --- lerobot/common/policies/normalize.py | 34 ++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py index 9cc94b929..feed874e2 100644 --- a/lerobot/common/policies/normalize.py +++ b/lerobot/common/policies/normalize.py @@ -418,3 +418,37 @@ class UnnormalizeBuffer(nn.Module): raise ValueError(norm_mode) return batch + + +def convert_normalize_to_buffer_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Convert state dict from Normalize/Unnormalize classes to NormalizeBuffer/UnnormalizeBuffer format. + + Args: + state_dict: State dict from a model using Normalize/Unnormalize classes + + Returns: + Converted state dict compatible with NormalizeBuffer/UnnormalizeBuffer classes + + Example: + # Old format (Normalize): "buffer_observation_image.mean" + # New format (NormalizeBuffer): "observation_image_mean" + """ + converted_state_dict = {} + + for key, value in state_dict.items(): + # Check if this is a normalization buffer parameter + if key.startswith("buffer_") and ("." in key): + # Extract the prefix and stat type + # e.g. "buffer_observation_image.mean" -> "observation_image", "mean" + buffer_part = key[7:] # Remove "buffer_" prefix + prefix, stat_type = buffer_part.rsplit(".", 1) + + # Convert to new format: "observation_image_mean" + new_key = f"{prefix}_{stat_type}" + converted_state_dict[new_key] = value + else: + # Keep non-normalization keys unchanged + converted_state_dict[key] = value + + return converted_state_dict