Replace normalize with register buffer version

This commit is contained in:
AdilZouitine
2025-06-05 14:13:45 +02:00
parent cf25b77805
commit 6ff4afff8f
3 changed files with 482 additions and 502 deletions
+10 -234
View File
@@ -20,92 +20,6 @@ from torch import Tensor, nn
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
def _stat_to_float32_tensor(value) -> Tensor:
    """Convert a single statistic (`np.ndarray` or `torch.Tensor`) to a float32 tensor.

    Raises:
        ValueError: If `value` is neither an `np.ndarray` nor a `torch.Tensor`.
    """
    # TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
    if isinstance(value, np.ndarray):
        return torch.from_numpy(value).to(dtype=torch.float32)
    if isinstance(value, torch.Tensor):
        # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
        # tensors anywhere (for example, when we use the same stats for normalization and
        # unnormalization). See the logic here
        # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
        return value.clone().to(dtype=torch.float32)
    type_ = type(value)
    raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")


def create_stats_buffers(
    features: dict[str, PolicyFeature],
    norm_map: dict[str, NormalizationMode],
    stats: dict[str, dict[str, Tensor]] | None = None,
) -> dict[str, dict[str, nn.ParameterDict]]:
    """
    Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
    statistics.

    Only the statistics required by the feature's normalization mode are created and read from `stats`
    ("mean"/"std" for MEAN_STD, "min"/"max" for MIN_MAX), so `stats` entries do not need to carry
    statistics that the mode does not use.

    Args: (see Normalize and Unnormalize)

    Returns:
        dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing
            `nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation.
            Features whose mode is `NormalizationMode.IDENTITY` are omitted.

    Raises:
        ValueError: If a provided statistic is neither an `np.ndarray` nor a `torch.Tensor`.
    """
    stats_buffers = {}

    for key, ft in features.items():
        norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
        if norm_mode is NormalizationMode.IDENTITY:
            continue

        assert isinstance(norm_mode, NormalizationMode)

        shape = tuple(ft.shape)

        if ft.type is FeatureType.VISUAL:
            # sanity checks
            assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
            c, h, w = shape
            assert c < h and c < w, f"{key} is not channel first ({shape=})"
            # override image shape to be invariant to height and width
            shape = (c, 1, 1)

        # Statistics needed by this normalization mode.
        if norm_mode is NormalizationMode.MEAN_STD:
            stat_names = ("mean", "std")
        elif norm_mode is NormalizationMode.MIN_MAX:
            stat_names = ("min", "max")
        else:
            stat_names = ()

        # Note: we initialize mean, std, min, max to infinity. They should be overwritten
        # downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
        # we assert they are not infinity anymore.
        buffer = {}
        if stat_names:
            buffer = nn.ParameterDict(
                {
                    name: nn.Parameter(
                        torch.ones(shape, dtype=torch.float32) * torch.inf, requires_grad=False
                    )
                    for name in stat_names
                }
            )

        if stats:
            # Read only the statistics of the active mode: looking at `stats[key]["mean"]`
            # unconditionally would raise a KeyError for MIN_MAX features that only provide min/max.
            for name in stat_names:
                buffer[name].data = _stat_to_float32_tensor(stats[key][name])

        stats_buffers[key] = buffer
    return stats_buffers
def _no_stats_error_str(name: str) -> str:
return (
f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
@@ -113,150 +27,6 @@ def _no_stats_error_str(name: str) -> str:
)
class Normalize(nn.Module):
    """Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""

    def __init__(
        self,
        features: dict[str, PolicyFeature],
        norm_map: dict[str, NormalizationMode],
        stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """
        Args:
            features (dict): A dictionary where keys are input modalities (e.g. "observation.image") and
                values are `PolicyFeature`s. Their shapes (e.g. `[3,96,96]`) are used to create the tensor
                buffers containing mean, std, min, max statistics. For visual features, the shape is
                adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
            norm_map (dict): A dictionary looked up with each feature's `type`, giving the normalization
                mode to apply. Types missing from the map fall back to `NormalizationMode.IDENTITY`
                (the data is left untouched). Other modes:
                    - `MEAN_STD`: subtract the mean and divide by standard deviation.
                    - `MIN_MAX`: map to [-1, 1] range.
            stats (dict, optional): A dictionary where keys are modalities (e.g. "observation.image")
                and values are dictionaries of statistic types and their values (e.g.
                `{"mean": torch.randn(3,1,1), "std": torch.randn(3,1,1)}`). If provided, as expected for
                training the model for the first time, these statistics will overwrite the default buffers. If
                not provided, as expected for finetuning or evaluation, the default buffers should be
                overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
                dataset is not needed to get the stats, since they are already in the policy state_dict.
        """
        super().__init__()
        self.features = features
        self.norm_map = norm_map
        self.stats = stats
        stats_buffers = create_stats_buffers(features, norm_map, stats)
        for key, buffer in stats_buffers.items():
            # Attribute names cannot contain dots, hence "observation.state" -> "buffer_observation_state".
            setattr(self, "buffer_" + key.replace(".", "_"), buffer)

    # TODO(rcadene): should we remove torch.no_grad?
    @torch.no_grad
    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Return a copy of `batch` where every configured feature present in it is normalized.

        Raises:
            AssertionError: If a required statistic still holds its infinity placeholder.
            ValueError: If a feature maps to an unsupported normalization mode.
        """
        # TODO: Remove this shallow copy
        batch = dict(batch)  # shallow copy avoids mutating the input batch
        for key, ft in self.features.items():
            if key not in batch:
                # FIXME(aliberts, rcadene): This might lead to silent fail!
                continue

            norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
            if norm_mode is NormalizationMode.IDENTITY:
                continue

            buffer = getattr(self, "buffer_" + key.replace(".", "_"))

            if norm_mode is NormalizationMode.MEAN_STD:
                mean = buffer["mean"]
                std = buffer["std"]
                assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                assert not torch.isinf(std).any(), _no_stats_error_str("std")
                # The small epsilon guards against division by zero.
                batch[key] = (batch[key] - mean) / (std + 1e-8)
            elif norm_mode is NormalizationMode.MIN_MAX:
                min_val = buffer["min"]
                max_val = buffer["max"]
                assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
                assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
                # normalize to [0,1]
                batch[key] = (batch[key] - min_val) / (max_val - min_val + 1e-8)
                # normalize to [-1, 1]
                batch[key] = batch[key] * 2 - 1
            else:
                raise ValueError(norm_mode)
        return batch
class Unnormalize(nn.Module):
    """
    Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their
    original range used by the environment.
    """

    def __init__(
        self,
        features: dict[str, PolicyFeature],
        norm_map: dict[str, NormalizationMode],
        stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """
        Args:
            features (dict): A dictionary where keys are modalities (e.g. "action") and values are
                `PolicyFeature`s. Their shapes (e.g. `[3,96,96]`) are used to create the tensor buffers
                containing mean, std, min, max statistics. For visual features, the shape is adjusted to
                be invariant to height and width, assuming a channel-first (c, h, w) format.
            norm_map (dict): A dictionary looked up with each feature's `type`, giving the normalization
                mode to invert. Types missing from the map fall back to `NormalizationMode.IDENTITY`
                (the data is left untouched). Other modes:
                    - `MEAN_STD`: multiply by the standard deviation and add the mean.
                    - `MIN_MAX`: map from the [-1, 1] range back to [min, max].
            stats (dict, optional): A dictionary where keys are modalities (e.g. "observation.image")
                and values are dictionaries of statistic types and their values (e.g.
                `{"mean": torch.randn(3,1,1), "std": torch.randn(3,1,1)}`). If provided, as expected for
                training the model for the first time, these statistics will overwrite the default buffers. If
                not provided, as expected for finetuning or evaluation, the default buffers should be
                overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
                dataset is not needed to get the stats, since they are already in the policy state_dict.
        """
        super().__init__()
        self.features = features
        self.norm_map = norm_map
        self.stats = stats
        # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
        stats_buffers = create_stats_buffers(features, norm_map, stats)
        for key, buffer in stats_buffers.items():
            # Attribute names cannot contain dots, hence "observation.state" -> "buffer_observation_state".
            setattr(self, "buffer_" + key.replace(".", "_"), buffer)

    # TODO(rcadene): should we remove torch.no_grad?
    @torch.no_grad
    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Return a copy of `batch` where every configured feature present in it is unnormalized.

        Raises:
            AssertionError: If a required statistic still holds its infinity placeholder.
            ValueError: If a feature maps to an unsupported normalization mode.
        """
        batch = dict(batch)  # shallow copy avoids mutating the input batch
        for key, ft in self.features.items():
            if key not in batch:
                continue

            norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
            if norm_mode is NormalizationMode.IDENTITY:
                continue

            buffer = getattr(self, "buffer_" + key.replace(".", "_"))

            if norm_mode is NormalizationMode.MEAN_STD:
                mean = buffer["mean"]
                std = buffer["std"]
                assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                assert not torch.isinf(std).any(), _no_stats_error_str("std")
                batch[key] = batch[key] * std + mean
            elif norm_mode is NormalizationMode.MIN_MAX:
                min_val = buffer["min"]
                max_val = buffer["max"]
                assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
                assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
                # map [-1, 1] back to [0, 1], then to [min, max]
                batch[key] = (batch[key] + 1) / 2
                batch[key] = batch[key] * (max_val - min_val) + min_val
            else:
                raise ValueError(norm_mode)
        return batch
# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
# and remove the `Normalize` and `Unnormalize` classes.
def _initialize_stats_buffers(
module: nn.Module,
features: dict[str, PolicyFeature],
@@ -295,6 +65,9 @@ def _initialize_stats_buffers(
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
mean = mean_data.clone().to(dtype=torch.float32)
std = std_data.clone().to(dtype=torch.float32)
elif isinstance(mean_data, np.ndarray):
mean = torch.from_numpy(mean_data).to(dtype=torch.float32)
std = torch.from_numpy(std_data).to(dtype=torch.float32)
else:
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
@@ -312,6 +85,9 @@ def _initialize_stats_buffers(
if isinstance(min_data, torch.Tensor):
min_val = min_data.clone().to(dtype=torch.float32)
max_val = max_data.clone().to(dtype=torch.float32)
elif isinstance(min_data, np.ndarray):
min_val = torch.from_numpy(min_data).to(dtype=torch.float32)
max_val = torch.from_numpy(max_data).to(dtype=torch.float32)
else:
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
@@ -322,8 +98,8 @@ def _initialize_stats_buffers(
raise ValueError(norm_mode)
class NormalizeBuffer(nn.Module):
"""Same as `Normalize` but statistics are stored as registered buffers rather than parameters."""
class Normalize(nn.Module):
"""Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
def __init__(
self,
@@ -371,8 +147,8 @@ class NormalizeBuffer(nn.Module):
return batch
class UnnormalizeBuffer(nn.Module):
"""Inverse operation of `NormalizeBuffer`. Uses registered buffers for statistics."""
class Unnormalize(nn.Module):
"""Inverse operation of `Normalize`. Uses registered buffers for statistics."""
def __init__(
self,
-268
View File
@@ -1,268 +0,0 @@
import pytest
import torch
from lerobot.common.policies.normalize import (
Normalize,
NormalizeBuffer,
Unnormalize,
UnnormalizeBuffer,
)
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
def _dummy_setup():
    """Return the `(features, norm_map, stats)` triple shared by the tests.

    The state feature is a 5-vector normalized with mean/std; the image feature is a
    (3, 64, 64) tensor normalized with min/max. Every feature carries all four statistics.
    """
    state_key = "observation.state"
    image_key = "observation.image"

    # feature definitions
    features = {
        state_key: PolicyFeature(type=FeatureType.STATE, shape=(5,)),
        image_key: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
    }

    # one normalization strategy per feature type
    norm_map = {
        FeatureType.STATE: NormalizationMode.MEAN_STD,
        FeatureType.VISUAL: NormalizationMode.MIN_MAX,
    }

    state_stats = {
        "mean": torch.arange(5, dtype=torch.float32),
        "std": torch.arange(1, 6, dtype=torch.float32),
        "min": torch.zeros(5, dtype=torch.float32),
        "max": torch.full((5,), 10.0, dtype=torch.float32),
    }

    # image statistics are shaped (c, 1, 1) so they broadcast over the spatial dims
    per_channel = (3, 1, 1)
    image_stats = {
        "mean": torch.full(per_channel, 127.5, dtype=torch.float32),
        "std": torch.full(per_channel, 50.0, dtype=torch.float32),
        "min": torch.zeros(per_channel, dtype=torch.float32),
        "max": torch.full(per_channel, 255.0, dtype=torch.float32),
    }

    stats = {state_key: state_stats, image_key: image_stats}
    return features, norm_map, stats
def _random_batch(stats):
    """Draw a reproducible 2-sample batch distributed according to `stats`.

    The RNG is re-seeded with 0 on every call, so repeated calls return equal tensors.
    The state is sampled first (normal), then the image (uniform).
    """
    torch.manual_seed(0)
    n_samples = 2

    state_stats = stats["observation.state"]
    image_stats = stats["observation.image"]

    # shape (b, 5)
    state = torch.randn(n_samples, 5) * state_stats["std"] + state_stats["mean"]

    # shape (b, 3, 64, 64)
    low = image_stats["min"]
    high = image_stats["max"]
    image = torch.rand(n_samples, 3, 64, 64) * (high - low) + low

    return {"observation.state": state, "observation.image": image}
@pytest.mark.parametrize(
    "module_pair",
    [
        (Normalize, NormalizeBuffer),
        (Unnormalize, UnnormalizeBuffer),
    ],
)
def test_equivalence(module_pair):
    """The parameter-based and buffer-based modules must produce matching outputs."""
    features, norm_map, stats = _dummy_setup()
    param_cls, buffer_cls = module_pair
    init_kwargs = {"features": features, "norm_map": norm_map, "stats": stats}

    param_module = param_cls(**init_kwargs)
    buffer_module = buffer_cls(**init_kwargs)

    batch = _random_batch(stats)
    expected = param_module(batch)
    actual = buffer_module(batch)

    # compare every tensor of the reference output with its counterpart
    for key, reference in expected.items():
        torch.testing.assert_close(reference, actual[key])
def test_round_trip():
    """Unnormalizing a normalized batch must give back the input for both implementations."""
    features, norm_map, stats = _dummy_setup()
    init_args = (features, norm_map, stats)

    param_norm = Normalize(*init_args)
    param_unnorm = Unnormalize(*init_args)
    buffer_norm = NormalizeBuffer(*init_args)
    buffer_unnorm = UnnormalizeBuffer(*init_args)

    batch = _random_batch(stats)
    param_result = param_unnorm(param_norm(batch))
    buffer_result = buffer_unnorm(buffer_norm(batch))

    for key, original in batch.items():
        torch.testing.assert_close(param_result[key], original)
        torch.testing.assert_close(buffer_result[key], original)
@pytest.mark.parametrize(
    ("image_shape", "use_numpy"),
    [
        ((3, 64, 64), True),
        ((3, 128, 128), False),
    ],
)
def test_various_shapes_and_numpy(image_shape, use_numpy):
    """Check equivalence and round-trip for several image shapes, with torch or numpy statistics."""
    state_key = "observation.state"
    image_key = "observation.image"
    channels = image_shape[0]

    # feature definitions (state dim fixed at 5)
    features = {
        state_key: PolicyFeature(type=FeatureType.STATE, shape=(5,)),
        image_key: PolicyFeature(type=FeatureType.VISUAL, shape=image_shape),
    }
    norm_map = {
        FeatureType.STATE: NormalizationMode.MEAN_STD,
        FeatureType.VISUAL: NormalizationMode.MIN_MAX,
    }

    # reference statistics, always kept as tensors to build the batch
    state_mean = torch.arange(5, dtype=torch.float32)
    state_std = torch.arange(1, 6, dtype=torch.float32)
    img_min = torch.zeros(channels, 1, 1, dtype=torch.float32)
    img_max = torch.ones(channels, 1, 1, dtype=torch.float32) * 10.0  # simple range [0,10]

    def as_stat(tensor):
        # hand the statistic to the modules either as numpy array or as torch tensor
        return tensor.numpy() if use_numpy else tensor

    stats = {
        state_key: {"mean": as_stat(state_mean), "std": as_stat(state_std)},
        image_key: {"min": as_stat(img_min), "max": as_stat(img_max)},
    }

    # instantiate modules
    init_args = (features, norm_map, stats)
    norm_p = Normalize(*init_args)
    unnorm_p = Unnormalize(*init_args)
    norm_b = NormalizeBuffer(*init_args)
    unnorm_b = UnnormalizeBuffer(*init_args)

    # build random batch following stats
    n_samples = 3
    torch.manual_seed(42)
    batch = {
        state_key: torch.randn(n_samples, 5) * state_std + state_mean,
        image_key: torch.rand(n_samples, *image_shape) * (img_max - img_min) + img_min,
    }

    # equivalence between param and buffer implementations
    for key in (state_key, image_key):
        torch.testing.assert_close(norm_p(batch)[key], norm_b(batch)[key])

    # round-trip
    restored_p = unnorm_p(norm_p(batch))
    restored_b = unnorm_b(norm_b(batch))
    for key, original in batch.items():
        torch.testing.assert_close(restored_p[key], original)
        torch.testing.assert_close(restored_b[key], original)
def test_state_dict_conversion():
    """A `Normalize` state dict, once converted, must load into `NormalizeBuffer` and behave the same."""
    from lerobot.common.policies.normalize import convert_normalize_to_buffer_state_dict

    features, norm_map, stats = _dummy_setup()

    # source module whose state dict uses the parameter layout
    source = Normalize(features=features, norm_map=norm_map, stats=stats)
    converted = convert_normalize_to_buffer_state_dict(source.state_dict())

    # target module starts without statistics and receives them from the converted state dict
    target = NormalizeBuffer(features=features, norm_map=norm_map, stats=None)
    target.load_state_dict(converted)

    # both modules must agree on the same batch
    batch = _random_batch(stats)
    expected = source(batch)
    actual = target(batch)
    for key, reference in expected.items():
        torch.testing.assert_close(reference, actual[key])
def test_state_dict_conversion_unnormalize():
    """An `Unnormalize` state dict, once converted, must load into `UnnormalizeBuffer` and behave the same."""
    from lerobot.common.policies.normalize import convert_normalize_to_buffer_state_dict

    features, norm_map, stats = _dummy_setup()
    init_kwargs = {"features": features, "norm_map": norm_map}

    # source module whose state dict uses the parameter layout
    source = Unnormalize(stats=stats, **init_kwargs)
    converted = convert_normalize_to_buffer_state_dict(source.state_dict())

    # target module starts without statistics and receives them from the converted state dict
    target = UnnormalizeBuffer(stats=None, **init_kwargs)
    target.load_state_dict(converted)

    # unnormalization is exercised on normalized data, so normalize the random batch first
    batch = _random_batch(stats)
    normalizer = Normalize(stats=stats, **init_kwargs)
    normalized = normalizer(batch)

    expected = source(normalized)
    actual = target(normalized)
    for key, reference in expected.items():
        torch.testing.assert_close(reference, actual[key])
def test_state_dict_conversion_key_format():
    """Test that conversion renames keys as expected and preserves every tensor value."""
    from lerobot.common.policies.normalize import convert_normalize_to_buffer_state_dict

    # Mock state dict with the old format
    old_state_dict = {
        "buffer_observation_image.mean": torch.randn(3, 1, 1),
        "buffer_observation_image.std": torch.randn(3, 1, 1),
        "buffer_observation_state.min": torch.randn(5),
        "buffer_observation_state.max": torch.randn(5),
        "some_other_param": torch.randn(10),  # Non-normalization parameter
    }

    # old key -> key expected after conversion
    expected_renames = {
        "buffer_observation_image.mean": "observation_image_mean",
        "buffer_observation_image.std": "observation_image_std",
        "buffer_observation_state.min": "observation_state_min",
        "buffer_observation_state.max": "observation_state_max",
        "some_other_param": "some_other_param",  # Should be unchanged
    }

    new_state_dict = convert_normalize_to_buffer_state_dict(old_state_dict)

    # Check expected key transformations
    assert set(new_state_dict.keys()) == set(expected_renames.values())

    # Check values are preserved for every entry, min/max included
    for old_key, new_key in expected_renames.items():
        torch.testing.assert_close(new_state_dict[new_key], old_state_dict[old_key])
+472
View File
@@ -0,0 +1,472 @@
import numpy as np
import pytest
import torch
from torch import Tensor, nn
from lerobot.common.policies.normalize import (
Normalize,
Unnormalize,
)
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
# Legacy implementations for backward compatibility testing
def create_stats_buffers_legacy(
    features: dict[str, PolicyFeature],
    norm_map: dict[str, NormalizationMode],
    stats: dict[str, dict[str, Tensor]] | None = None,
) -> dict[str, dict[str, nn.ParameterDict]]:
    """Build, per feature, the `nn.ParameterDict` of statistics used by the legacy modules.

    Features whose mode is `NormalizationMode.IDENTITY` are skipped. Each statistic starts as an
    all-infinity float32 tensor and is overwritten from `stats` when `stats` is provided. Visual
    features get a `(c, 1, 1)` shape so the statistics are independent of height and width.

    Raises:
        ValueError: If the first statistic of a feature is neither `np.ndarray` nor `torch.Tensor`.
    """
    all_buffers = {}

    for key, ft in features.items():
        norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
        if norm_mode is NormalizationMode.IDENTITY:
            continue

        assert isinstance(norm_mode, NormalizationMode)

        shape = tuple(ft.shape)
        if ft.type is FeatureType.VISUAL:
            # sanity checks
            assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
            c, h, w = shape
            assert c < h and c < w, f"{key} is not channel first ({shape=})"
            # override image shape to be invariant to height and width
            shape = (c, 1, 1)

        # names of the statistics required by this mode, in creation order
        if norm_mode is NormalizationMode.MEAN_STD:
            stat_names = ("mean", "std")
        elif norm_mode is NormalizationMode.MIN_MAX:
            stat_names = ("min", "max")
        else:
            stat_names = ()

        buffer = {}
        if stat_names:
            buffer = nn.ParameterDict(
                {
                    stat: nn.Parameter(
                        torch.full(shape, torch.inf, dtype=torch.float32), requires_grad=False
                    )
                    for stat in stat_names
                }
            )

        if stats and stat_names:
            # the type of the first statistic decides how both are converted
            probe = stats[key][stat_names[0]]
            if isinstance(probe, np.ndarray):
                for stat in stat_names:
                    buffer[stat].data = torch.from_numpy(stats[key][stat]).to(dtype=torch.float32)
            elif isinstance(probe, torch.Tensor):
                for stat in stat_names:
                    buffer[stat].data = stats[key][stat].clone().to(dtype=torch.float32)
            else:
                type_ = type(probe)
                raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")

        all_buffers[key] = buffer

    return all_buffers
def _no_stats_error_str_legacy(name: str) -> str:
    """Return the error message reporting that statistic `name` still holds its infinity placeholder."""
    problem = f"`{name}` is infinity. "
    advice = "You should either initialize with `stats` as an argument, or use a pretrained model."
    return problem + advice
class NormalizeLegacy(nn.Module):
    """Reference `nn.ParameterDict`-based normalizer, kept to test backward compatibility."""

    def __init__(
        self,
        features: dict[str, PolicyFeature],
        norm_map: dict[str, NormalizationMode],
        stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """Store the configuration and attach one `buffer_<feature>` attribute per normalized feature."""
        super().__init__()
        self.features = features
        self.norm_map = norm_map
        self.stats = stats

        per_feature = create_stats_buffers_legacy(features, norm_map, stats)
        for feature_name, feature_stats in per_feature.items():
            # dots are replaced so that "observation.state" becomes "buffer_observation_state"
            setattr(self, f"buffer_{feature_name.replace('.', '_')}", feature_stats)

    @torch.no_grad
    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Return a shallow copy of `batch` with every configured feature present in it normalized."""
        eps = 1e-8
        result = dict(batch)

        for feature_name, feature in self.features.items():
            if feature_name not in result:
                continue

            mode = self.norm_map.get(feature.type, NormalizationMode.IDENTITY)
            if mode is NormalizationMode.IDENTITY:
                continue

            feature_stats = getattr(self, f"buffer_{feature_name.replace('.', '_')}")
            value = result[feature_name]

            if mode is NormalizationMode.MEAN_STD:
                mean, std = feature_stats["mean"], feature_stats["std"]
                assert not torch.isinf(mean).any(), _no_stats_error_str_legacy("mean")
                assert not torch.isinf(std).any(), _no_stats_error_str_legacy("std")
                result[feature_name] = (value - mean) / (std + eps)
            elif mode is NormalizationMode.MIN_MAX:
                low, high = feature_stats["min"], feature_stats["max"]
                assert not torch.isinf(low).any(), _no_stats_error_str_legacy("min")
                assert not torch.isinf(high).any(), _no_stats_error_str_legacy("max")
                # first to [0, 1], then to [-1, 1]
                unit = (value - low) / (high - low + eps)
                result[feature_name] = unit * 2 - 1
            else:
                raise ValueError(mode)

        return result
class UnnormalizeLegacy(nn.Module):
    """Reference `nn.ParameterDict`-based unnormalizer, kept to test backward compatibility."""

    def __init__(
        self,
        features: dict[str, PolicyFeature],
        norm_map: dict[str, NormalizationMode],
        stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """Store the configuration and attach one `buffer_<feature>` attribute per normalized feature."""
        super().__init__()
        self.features = features
        self.norm_map = norm_map
        self.stats = stats

        per_feature = create_stats_buffers_legacy(features, norm_map, stats)
        for feature_name, feature_stats in per_feature.items():
            # dots are replaced so that "observation.state" becomes "buffer_observation_state"
            setattr(self, f"buffer_{feature_name.replace('.', '_')}", feature_stats)

    @torch.no_grad
    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Return a shallow copy of `batch` with every configured feature present in it unnormalized."""
        result = dict(batch)

        for feature_name, feature in self.features.items():
            if feature_name not in result:
                continue

            mode = self.norm_map.get(feature.type, NormalizationMode.IDENTITY)
            if mode is NormalizationMode.IDENTITY:
                continue

            feature_stats = getattr(self, f"buffer_{feature_name.replace('.', '_')}")
            value = result[feature_name]

            if mode is NormalizationMode.MEAN_STD:
                mean, std = feature_stats["mean"], feature_stats["std"]
                assert not torch.isinf(mean).any(), _no_stats_error_str_legacy("mean")
                assert not torch.isinf(std).any(), _no_stats_error_str_legacy("std")
                result[feature_name] = value * std + mean
            elif mode is NormalizationMode.MIN_MAX:
                low, high = feature_stats["min"], feature_stats["max"]
                assert not torch.isinf(low).any(), _no_stats_error_str_legacy("min")
                assert not torch.isinf(high).any(), _no_stats_error_str_legacy("max")
                # from [-1, 1] to [0, 1], then to [min, max]
                unit = (value + 1) / 2
                result[feature_name] = unit * (high - low) + low
            else:
                raise ValueError(mode)

        return result
def _dummy_setup():
    """Return the `(features, norm_map, stats)` triple shared by the tests.

    The state feature is a 5-vector normalized with mean/std; the image feature is a
    (3, 64, 64) tensor normalized with min/max. Every feature carries all four statistics.
    """
    state_key = "observation.state"
    image_key = "observation.image"

    # feature definitions
    features = {
        state_key: PolicyFeature(type=FeatureType.STATE, shape=(5,)),
        image_key: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
    }

    # one normalization strategy per feature type
    norm_map = {
        FeatureType.STATE: NormalizationMode.MEAN_STD,
        FeatureType.VISUAL: NormalizationMode.MIN_MAX,
    }

    state_stats = {
        "mean": torch.arange(5, dtype=torch.float32),
        "std": torch.arange(1, 6, dtype=torch.float32),
        "min": torch.zeros(5, dtype=torch.float32),
        "max": torch.full((5,), 10.0, dtype=torch.float32),
    }

    # image statistics are shaped (c, 1, 1) so they broadcast over the spatial dims
    per_channel = (3, 1, 1)
    image_stats = {
        "mean": torch.full(per_channel, 127.5, dtype=torch.float32),
        "std": torch.full(per_channel, 50.0, dtype=torch.float32),
        "min": torch.zeros(per_channel, dtype=torch.float32),
        "max": torch.full(per_channel, 255.0, dtype=torch.float32),
    }

    stats = {state_key: state_stats, image_key: image_stats}
    return features, norm_map, stats
def _random_batch(stats):
    """Draw a reproducible 2-sample batch distributed according to `stats`.

    The RNG is re-seeded with 0 on every call, so repeated calls return equal tensors.
    The state is sampled first (normal), then the image (uniform).
    """
    torch.manual_seed(0)
    n_samples = 2

    state_stats = stats["observation.state"]
    image_stats = stats["observation.image"]

    # shape (b, 5)
    state = torch.randn(n_samples, 5) * state_stats["std"] + state_stats["mean"]

    # shape (b, 3, 64, 64)
    low = image_stats["min"]
    high = image_stats["max"]
    image = torch.rand(n_samples, 3, 64, 64) * (high - low) + low

    return {"observation.state": state, "observation.image": image}
@pytest.mark.parametrize(
    "module_pair",
    [
        (NormalizeLegacy, Normalize),
        (UnnormalizeLegacy, Unnormalize),
    ],
)
def test_equivalence(module_pair):
    """The legacy (parameter-based) and current (buffer-based) modules must produce matching outputs."""
    features, norm_map, stats = _dummy_setup()
    legacy_cls, current_cls = module_pair
    init_kwargs = {"features": features, "norm_map": norm_map, "stats": stats}

    legacy_module = legacy_cls(**init_kwargs)
    current_module = current_cls(**init_kwargs)

    batch = _random_batch(stats)
    expected = legacy_module(batch)
    actual = current_module(batch)

    # compare every tensor of the reference output with its counterpart
    for key, reference in expected.items():
        torch.testing.assert_close(reference, actual[key])
def test_round_trip():
    """Unnormalizing a normalized batch must give back the input for both implementations."""
    features, norm_map, stats = _dummy_setup()
    init_args = (features, norm_map, stats)

    legacy_norm = NormalizeLegacy(*init_args)
    legacy_unnorm = UnnormalizeLegacy(*init_args)
    current_norm = Normalize(*init_args)
    current_unnorm = Unnormalize(*init_args)

    batch = _random_batch(stats)
    legacy_result = legacy_unnorm(legacy_norm(batch))
    current_result = current_unnorm(current_norm(batch))

    for key, original in batch.items():
        torch.testing.assert_close(legacy_result[key], original)
        torch.testing.assert_close(current_result[key], original)
@pytest.mark.parametrize(
    ("image_shape", "use_numpy"),
    [
        ((3, 64, 64), True),
        ((3, 128, 128), False),
    ],
)
def test_various_shapes_and_numpy(image_shape, use_numpy):
    """Check equivalence and round-trip for several image shapes, with torch or numpy statistics."""
    state_key = "observation.state"
    image_key = "observation.image"
    channels = image_shape[0]

    # feature definitions (state dim fixed at 5)
    features = {
        state_key: PolicyFeature(type=FeatureType.STATE, shape=(5,)),
        image_key: PolicyFeature(type=FeatureType.VISUAL, shape=image_shape),
    }
    norm_map = {
        FeatureType.STATE: NormalizationMode.MEAN_STD,
        FeatureType.VISUAL: NormalizationMode.MIN_MAX,
    }

    # reference statistics, always kept as tensors to build the batch
    state_mean = torch.arange(5, dtype=torch.float32)
    state_std = torch.arange(1, 6, dtype=torch.float32)
    img_min = torch.zeros(channels, 1, 1, dtype=torch.float32)
    img_max = torch.ones(channels, 1, 1, dtype=torch.float32) * 10.0  # simple range [0,10]

    def as_stat(tensor):
        # hand the statistic to the modules either as numpy array or as torch tensor
        return tensor.numpy() if use_numpy else tensor

    stats = {
        state_key: {"mean": as_stat(state_mean), "std": as_stat(state_std)},
        image_key: {"min": as_stat(img_min), "max": as_stat(img_max)},
    }

    # instantiate modules
    init_args = (features, norm_map, stats)
    norm_p = NormalizeLegacy(*init_args)
    unnorm_p = UnnormalizeLegacy(*init_args)
    norm_b = Normalize(*init_args)
    unnorm_b = Unnormalize(*init_args)

    # build random batch following stats
    n_samples = 3
    torch.manual_seed(42)
    batch = {
        state_key: torch.randn(n_samples, 5) * state_std + state_mean,
        image_key: torch.rand(n_samples, *image_shape) * (img_max - img_min) + img_min,
    }

    # equivalence between legacy and buffer implementations
    for key in (state_key, image_key):
        torch.testing.assert_close(norm_p(batch)[key], norm_b(batch)[key])

    # round-trip
    restored_p = unnorm_p(norm_p(batch))
    restored_b = unnorm_b(norm_b(batch))
    for key, original in batch.items():
        torch.testing.assert_close(restored_p[key], original)
        torch.testing.assert_close(restored_b[key], original)
def test_state_dict_conversion():
    """A `NormalizeLegacy` state dict, once converted, must load into `Normalize` and behave the same."""
    from lerobot.common.policies.normalize import convert_normalize_to_buffer_state_dict

    features, norm_map, stats = _dummy_setup()

    # legacy module whose state dict uses the parameter layout
    legacy = NormalizeLegacy(features=features, norm_map=norm_map, stats=stats)
    converted = convert_normalize_to_buffer_state_dict(legacy.state_dict())

    # current module starts without statistics and receives them from the converted state dict
    current = Normalize(features=features, norm_map=norm_map, stats=None)
    current.load_state_dict(converted)

    # both modules must agree on the same batch
    batch = _random_batch(stats)
    expected = legacy(batch)
    actual = current(batch)
    for key, reference in expected.items():
        torch.testing.assert_close(reference, actual[key])
def test_state_dict_conversion_unnormalize():
    """An `UnnormalizeLegacy` state dict, once converted, must load into `Unnormalize` and behave the same."""
    from lerobot.common.policies.normalize import convert_normalize_to_buffer_state_dict

    features, norm_map, stats = _dummy_setup()
    init_kwargs = {"features": features, "norm_map": norm_map}

    # legacy module whose state dict uses the parameter layout
    legacy = UnnormalizeLegacy(stats=stats, **init_kwargs)
    converted = convert_normalize_to_buffer_state_dict(legacy.state_dict())

    # current module starts without statistics and receives them from the converted state dict
    current = Unnormalize(stats=None, **init_kwargs)
    current.load_state_dict(converted)

    # unnormalization is exercised on normalized data, so normalize the random batch first
    batch = _random_batch(stats)
    normalizer = Normalize(stats=stats, **init_kwargs)
    normalized = normalizer(batch)

    expected = legacy(normalized)
    actual = current(normalized)
    for key, reference in expected.items():
        torch.testing.assert_close(reference, actual[key])
def test_state_dict_conversion_key_format():
    """Test that conversion renames keys as expected and preserves every tensor value."""
    from lerobot.common.policies.normalize import convert_normalize_to_buffer_state_dict

    # Mock state dict with the old format
    old_state_dict = {
        "buffer_observation_image.mean": torch.randn(3, 1, 1),
        "buffer_observation_image.std": torch.randn(3, 1, 1),
        "buffer_observation_state.min": torch.randn(5),
        "buffer_observation_state.max": torch.randn(5),
        "some_other_param": torch.randn(10),  # Non-normalization parameter
    }

    # old key -> key expected after conversion
    expected_renames = {
        "buffer_observation_image.mean": "observation_image_mean",
        "buffer_observation_image.std": "observation_image_std",
        "buffer_observation_state.min": "observation_state_min",
        "buffer_observation_state.max": "observation_state_max",
        "some_other_param": "some_other_param",  # Should be unchanged
    }

    new_state_dict = convert_normalize_to_buffer_state_dict(old_state_dict)

    # Check expected key transformations
    assert set(new_state_dict.keys()) == set(expected_renames.values())

    # Check values are preserved for every entry, min/max included
    for old_key, new_key in expected_renames.items():
        torch.testing.assert_close(new_state_dict[new_key], old_state_dict[old_key])
def test_legacy_vs_buffer_equivalence():
    """Legacy and buffer-based modules must agree when normalizing and when unnormalizing."""
    features, norm_map, stats = _dummy_setup()
    init_kwargs = {"features": features, "norm_map": norm_map, "stats": stats}

    # one legacy and one buffer-based instance of each direction
    legacy_normalize = NormalizeLegacy(**init_kwargs)
    buffer_normalize = Normalize(**init_kwargs)
    legacy_unnormalize = UnnormalizeLegacy(**init_kwargs)
    buffer_unnormalize = Unnormalize(**init_kwargs)

    batch = _random_batch(stats)

    # normalization must match
    legacy_normed = legacy_normalize(batch)
    buffer_normed = buffer_normalize(batch)
    for key, reference in legacy_normed.items():
        torch.testing.assert_close(reference, buffer_normed[key])

    # unnormalization must match, each implementation consuming its own normalized batch
    legacy_restored = legacy_unnormalize(legacy_normed)
    buffer_restored = buffer_unnormalize(buffer_normed)
    for key, reference in legacy_restored.items():
        torch.testing.assert_close(reference, buffer_restored[key])