mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
Replace normalize with register buffer version
This commit is contained in:
@@ -20,92 +20,6 @@ from torch import Tensor, nn
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
|
||||
def create_stats_buffers(
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
) -> dict[str, dict[str, nn.ParameterDict]]:
|
||||
"""
|
||||
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
|
||||
statistics.
|
||||
|
||||
Args: (see Normalize and Unnormalize)
|
||||
|
||||
Returns:
|
||||
dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing
|
||||
`nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation.
|
||||
"""
|
||||
stats_buffers = {}
|
||||
|
||||
for key, ft in features.items():
|
||||
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
assert isinstance(norm_mode, NormalizationMode)
|
||||
|
||||
shape = tuple(ft.shape)
|
||||
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
# sanity checks
|
||||
assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
|
||||
c, h, w = shape
|
||||
assert c < h and c < w, f"{key} is not channel first ({shape=})"
|
||||
# override image shape to be invariant to height and width
|
||||
shape = (c, 1, 1)
|
||||
|
||||
# Note: we initialize mean, std, min, max to infinity. They should be overwritten
|
||||
# downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
|
||||
# we assert they are not infinity anymore.
|
||||
|
||||
buffer = {}
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
std = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
buffer = nn.ParameterDict(
|
||||
{
|
||||
"mean": nn.Parameter(mean, requires_grad=False),
|
||||
"std": nn.Parameter(std, requires_grad=False),
|
||||
}
|
||||
)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
max = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
buffer = nn.ParameterDict(
|
||||
{
|
||||
"min": nn.Parameter(min, requires_grad=False),
|
||||
"max": nn.Parameter(max, requires_grad=False),
|
||||
}
|
||||
)
|
||||
|
||||
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
|
||||
if stats:
|
||||
if isinstance(stats[key]["mean"], np.ndarray):
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
||||
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
||||
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
||||
elif isinstance(stats[key]["mean"], torch.Tensor):
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
|
||||
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
|
||||
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
||||
else:
|
||||
type_ = type(stats[key]["mean"])
|
||||
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
|
||||
|
||||
stats_buffers[key] = buffer
|
||||
return stats_buffers
|
||||
|
||||
|
||||
def _no_stats_error_str(name: str) -> str:
|
||||
return (
|
||||
f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
|
||||
@@ -113,150 +27,6 @@ def _no_stats_error_str(name: str) -> str:
|
||||
)
|
||||
|
||||
|
||||
class Normalize(nn.Module):
|
||||
"""Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
|
||||
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
|
||||
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
|
||||
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
|
||||
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
|
||||
are their normalization modes among:
|
||||
- "mean_std": subtract the mean and divide by standard deviation.
|
||||
- "min_max": map to [-1, 1] range.
|
||||
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
|
||||
and values are dictionaries of statistic types and their values (e.g.
|
||||
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
|
||||
training the model for the first time, these statistics will overwrite the default buffers. If
|
||||
not provided, as expected for finetuning or evaluation, the default buffers should to be
|
||||
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
|
||||
dataset is not needed to get the stats, since they are already in the policy state_dict.
|
||||
"""
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
self.stats = stats
|
||||
stats_buffers = create_stats_buffers(features, norm_map, stats)
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
@torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
# TODO: Remove this shallow copy
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
# FIXME(aliberts, rcadene): This might lead to silent fail!
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
||||
# normalize to [0,1]
|
||||
batch[key] = (batch[key] - min) / (max - min + 1e-8)
|
||||
# normalize to [-1, 1]
|
||||
batch[key] = batch[key] * 2 - 1
|
||||
else:
|
||||
raise ValueError(norm_mode)
|
||||
return batch
|
||||
|
||||
|
||||
class Unnormalize(nn.Module):
|
||||
"""
|
||||
Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their
|
||||
original range used by the environment.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
|
||||
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
|
||||
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
|
||||
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
|
||||
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
|
||||
are their normalization modes among:
|
||||
- "mean_std": subtract the mean and divide by standard deviation.
|
||||
- "min_max": map to [-1, 1] range.
|
||||
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
|
||||
and values are dictionaries of statistic types and their values (e.g.
|
||||
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
|
||||
training the model for the first time, these statistics will overwrite the default buffers. If
|
||||
not provided, as expected for finetuning or evaluation, the default buffers should to be
|
||||
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
|
||||
dataset is not needed to get the stats, since they are already in the policy state_dict.
|
||||
"""
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
self.stats = stats
|
||||
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
|
||||
stats_buffers = create_stats_buffers(features, norm_map, stats)
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
@torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = batch[key] * std + mean
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
||||
batch[key] = (batch[key] + 1) / 2
|
||||
batch[key] = batch[key] * (max - min) + min
|
||||
else:
|
||||
raise ValueError(norm_mode)
|
||||
return batch
|
||||
|
||||
|
||||
# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
|
||||
# and remove the `Normalize` and `Unnormalize` classes.
|
||||
def _initialize_stats_buffers(
|
||||
module: nn.Module,
|
||||
features: dict[str, PolicyFeature],
|
||||
@@ -295,6 +65,9 @@ def _initialize_stats_buffers(
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
mean = mean_data.clone().to(dtype=torch.float32)
|
||||
std = std_data.clone().to(dtype=torch.float32)
|
||||
elif isinstance(mean_data, np.ndarray):
|
||||
mean = torch.from_numpy(mean_data).to(dtype=torch.float32)
|
||||
std = torch.from_numpy(std_data).to(dtype=torch.float32)
|
||||
else:
|
||||
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
|
||||
|
||||
@@ -312,6 +85,9 @@ def _initialize_stats_buffers(
|
||||
if isinstance(min_data, torch.Tensor):
|
||||
min_val = min_data.clone().to(dtype=torch.float32)
|
||||
max_val = max_data.clone().to(dtype=torch.float32)
|
||||
elif isinstance(min_data, np.ndarray):
|
||||
min_val = torch.from_numpy(min_data).to(dtype=torch.float32)
|
||||
max_val = torch.from_numpy(max_data).to(dtype=torch.float32)
|
||||
else:
|
||||
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
|
||||
|
||||
@@ -322,8 +98,8 @@ def _initialize_stats_buffers(
|
||||
raise ValueError(norm_mode)
|
||||
|
||||
|
||||
class NormalizeBuffer(nn.Module):
|
||||
"""Same as `Normalize` but statistics are stored as registered buffers rather than parameters."""
|
||||
class Normalize(nn.Module):
|
||||
"""Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -371,8 +147,8 @@ class NormalizeBuffer(nn.Module):
|
||||
return batch
|
||||
|
||||
|
||||
class UnnormalizeBuffer(nn.Module):
|
||||
"""Inverse operation of `NormalizeBuffer`. Uses registered buffers for statistics."""
|
||||
class Unnormalize(nn.Module):
|
||||
"""Inverse operation of `Normalize`. Uses registered buffers for statistics."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -1,268 +0,0 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.policies.normalize import (
|
||||
Normalize,
|
||||
NormalizeBuffer,
|
||||
Unnormalize,
|
||||
UnnormalizeBuffer,
|
||||
)
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
|
||||
def _dummy_setup():
|
||||
# feature definitions
|
||||
features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(5,)),
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
|
||||
}
|
||||
|
||||
# map feature types to a normalization strategy
|
||||
norm_map = {
|
||||
FeatureType.STATE: NormalizationMode.MEAN_STD,
|
||||
FeatureType.VISUAL: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
# build statistics (include all stats for each feature)
|
||||
stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.arange(5, dtype=torch.float32),
|
||||
"std": torch.arange(1, 6, dtype=torch.float32),
|
||||
"min": torch.zeros(5, dtype=torch.float32),
|
||||
"max": torch.ones(5, dtype=torch.float32) * 10.0,
|
||||
},
|
||||
# image statistics use (c,1,1) so they broadcast on spatial dims
|
||||
"observation.image": {
|
||||
"mean": torch.ones(3, 1, 1, dtype=torch.float32) * 127.5,
|
||||
"std": torch.ones(3, 1, 1, dtype=torch.float32) * 50.0,
|
||||
"min": torch.zeros(3, 1, 1, dtype=torch.float32),
|
||||
"max": torch.ones(3, 1, 1, dtype=torch.float32) * 255.0,
|
||||
},
|
||||
}
|
||||
|
||||
return features, norm_map, stats
|
||||
|
||||
|
||||
def _random_batch(stats):
|
||||
"""Generate a batch consistent with the provided statistics."""
|
||||
torch.manual_seed(0)
|
||||
batch_size = 2
|
||||
|
||||
state_mean = stats["observation.state"]["mean"]
|
||||
state_std = stats["observation.state"]["std"]
|
||||
state = torch.randn(batch_size, 5) * state_std + state_mean # shape (b,5)
|
||||
|
||||
image_min = stats["observation.image"]["min"]
|
||||
image_max = stats["observation.image"]["max"]
|
||||
image = torch.rand(batch_size, 3, 64, 64) * (image_max - image_min) + image_min # shape (b,3,64,64)
|
||||
|
||||
return {
|
||||
"observation.state": state,
|
||||
"observation.image": image,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"module_pair",
|
||||
[
|
||||
(Normalize, NormalizeBuffer),
|
||||
(Unnormalize, UnnormalizeBuffer),
|
||||
],
|
||||
)
|
||||
def test_equivalence(module_pair):
|
||||
features, norm_map, stats = _dummy_setup()
|
||||
ParamCls, BufferCls = module_pair # noqa: N806
|
||||
|
||||
param_module = ParamCls(features=features, norm_map=norm_map, stats=stats)
|
||||
buffer_module = BufferCls(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
batch = _random_batch(stats)
|
||||
|
||||
out_param = param_module(batch)
|
||||
out_buffer = buffer_module(batch)
|
||||
|
||||
# every tensor in the output dictionaries should match closely
|
||||
for key in out_param:
|
||||
torch.testing.assert_close(out_param[key], out_buffer[key])
|
||||
|
||||
|
||||
def test_round_trip():
|
||||
"""Normalize then unnormalize should give the original input back for both impls."""
|
||||
features, norm_map, stats = _dummy_setup()
|
||||
|
||||
norm_p = Normalize(features, norm_map, stats)
|
||||
unnorm_p = Unnormalize(features, norm_map, stats)
|
||||
|
||||
norm_b = NormalizeBuffer(features, norm_map, stats)
|
||||
unnorm_b = UnnormalizeBuffer(features, norm_map, stats)
|
||||
|
||||
batch = _random_batch(stats)
|
||||
recovered_p = unnorm_p(norm_p(batch))
|
||||
recovered_b = unnorm_b(norm_b(batch))
|
||||
|
||||
for key in batch:
|
||||
torch.testing.assert_close(recovered_p[key], batch[key])
|
||||
torch.testing.assert_close(recovered_b[key], batch[key])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"image_shape,use_numpy",
|
||||
[
|
||||
((3, 64, 64), True),
|
||||
((3, 128, 128), False),
|
||||
],
|
||||
)
|
||||
def test_various_shapes_and_numpy(image_shape, use_numpy):
|
||||
"""Ensure equivalence and round-trip correctness for different image shapes and numpy stats."""
|
||||
# feature definitions (state dim fixed at 5)
|
||||
features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(5,)),
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=image_shape),
|
||||
}
|
||||
|
||||
norm_map = {
|
||||
FeatureType.STATE: NormalizationMode.MEAN_STD,
|
||||
FeatureType.VISUAL: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
# statistics (torch or numpy)
|
||||
state_mean = torch.arange(5, dtype=torch.float32)
|
||||
state_std = torch.arange(1, 6, dtype=torch.float32)
|
||||
img_min = torch.zeros(image_shape[0], 1, 1, dtype=torch.float32)
|
||||
img_max = torch.ones(image_shape[0], 1, 1, dtype=torch.float32) * 10.0 # simple range [0,10]
|
||||
|
||||
if use_numpy:
|
||||
state_mean_stats = state_mean.numpy()
|
||||
state_std_stats = state_std.numpy()
|
||||
img_min_stats = img_min.numpy()
|
||||
img_max_stats = img_max.numpy()
|
||||
else:
|
||||
state_mean_stats = state_mean
|
||||
state_std_stats = state_std
|
||||
img_min_stats = img_min
|
||||
img_max_stats = img_max
|
||||
|
||||
stats = {
|
||||
"observation.state": {"mean": state_mean_stats, "std": state_std_stats},
|
||||
"observation.image": {"min": img_min_stats, "max": img_max_stats},
|
||||
}
|
||||
|
||||
# instantiate modules
|
||||
norm_p = Normalize(features, norm_map, stats)
|
||||
unnorm_p = Unnormalize(features, norm_map, stats)
|
||||
norm_b = NormalizeBuffer(features, norm_map, stats)
|
||||
unnorm_b = UnnormalizeBuffer(features, norm_map, stats)
|
||||
|
||||
# build random batch following stats
|
||||
batch_size = 3
|
||||
torch.manual_seed(42)
|
||||
state = torch.randn(batch_size, 5) * state_std + state_mean
|
||||
image = torch.rand(batch_size, *image_shape) * (img_max - img_min) + img_min
|
||||
|
||||
batch = {"observation.state": state, "observation.image": image}
|
||||
|
||||
# equivalence between param and buffer implementations
|
||||
torch.testing.assert_close(norm_p(batch)["observation.state"], norm_b(batch)["observation.state"])
|
||||
torch.testing.assert_close(norm_p(batch)["observation.image"], norm_b(batch)["observation.image"])
|
||||
|
||||
# round-trip
|
||||
recovered_p = unnorm_p(norm_p(batch))
|
||||
recovered_b = unnorm_b(norm_b(batch))
|
||||
|
||||
for key in batch:
|
||||
torch.testing.assert_close(recovered_p[key], batch[key])
|
||||
torch.testing.assert_close(recovered_b[key], batch[key])
|
||||
|
||||
|
||||
def test_state_dict_conversion():
|
||||
"""Test that state dict can be converted from Normalize to NormalizeBuffer format."""
|
||||
from lerobot.common.policies.normalize import convert_normalize_to_buffer_state_dict
|
||||
|
||||
features, norm_map, stats = _dummy_setup()
|
||||
|
||||
# Create Normalize module and get its state dict
|
||||
normalize_module = Normalize(features=features, norm_map=norm_map, stats=stats)
|
||||
old_state_dict = normalize_module.state_dict()
|
||||
|
||||
# Convert state dict
|
||||
new_state_dict = convert_normalize_to_buffer_state_dict(old_state_dict)
|
||||
|
||||
# Create NormalizeBuffer module and load converted state dict
|
||||
buffer_module = NormalizeBuffer(features=features, norm_map=norm_map, stats=None)
|
||||
buffer_module.load_state_dict(new_state_dict)
|
||||
|
||||
# Test that both modules produce the same output
|
||||
batch = _random_batch(stats)
|
||||
|
||||
old_output = normalize_module(batch)
|
||||
new_output = buffer_module(batch)
|
||||
|
||||
for key in old_output:
|
||||
torch.testing.assert_close(old_output[key], new_output[key])
|
||||
|
||||
|
||||
def test_state_dict_conversion_unnormalize():
|
||||
"""Test that state dict can be converted from Unnormalize to UnnormalizeBuffer format."""
|
||||
from lerobot.common.policies.normalize import convert_normalize_to_buffer_state_dict
|
||||
|
||||
features, norm_map, stats = _dummy_setup()
|
||||
|
||||
# Create Unnormalize module and get its state dict
|
||||
unnormalize_module = Unnormalize(features=features, norm_map=norm_map, stats=stats)
|
||||
old_state_dict = unnormalize_module.state_dict()
|
||||
|
||||
# Convert state dict
|
||||
new_state_dict = convert_normalize_to_buffer_state_dict(old_state_dict)
|
||||
|
||||
# Create UnnormalizeBuffer module and load converted state dict
|
||||
buffer_module = UnnormalizeBuffer(features=features, norm_map=norm_map, stats=None)
|
||||
buffer_module.load_state_dict(new_state_dict)
|
||||
|
||||
# Test that both modules produce the same output on normalized data
|
||||
batch = _random_batch(stats)
|
||||
|
||||
# First normalize the batch
|
||||
normalize_module = Normalize(features=features, norm_map=norm_map, stats=stats)
|
||||
normalized_batch = normalize_module(batch)
|
||||
|
||||
old_output = unnormalize_module(normalized_batch)
|
||||
new_output = buffer_module(normalized_batch)
|
||||
|
||||
for key in old_output:
|
||||
torch.testing.assert_close(old_output[key], new_output[key])
|
||||
|
||||
|
||||
def test_state_dict_conversion_key_format():
|
||||
"""Test that conversion produces the expected key format."""
|
||||
from lerobot.common.policies.normalize import convert_normalize_to_buffer_state_dict
|
||||
|
||||
# Mock state dict with the old format
|
||||
old_state_dict = {
|
||||
"buffer_observation_image.mean": torch.randn(3, 1, 1),
|
||||
"buffer_observation_image.std": torch.randn(3, 1, 1),
|
||||
"buffer_observation_state.min": torch.randn(5),
|
||||
"buffer_observation_state.max": torch.randn(5),
|
||||
"some_other_param": torch.randn(10), # Non-normalization parameter
|
||||
}
|
||||
|
||||
new_state_dict = convert_normalize_to_buffer_state_dict(old_state_dict)
|
||||
|
||||
# Check expected key transformations
|
||||
expected_keys = {
|
||||
"observation_image_mean",
|
||||
"observation_image_std",
|
||||
"observation_state_min",
|
||||
"observation_state_max",
|
||||
"some_other_param", # Should be unchanged
|
||||
}
|
||||
|
||||
assert set(new_state_dict.keys()) == expected_keys
|
||||
|
||||
# Check values are preserved
|
||||
torch.testing.assert_close(
|
||||
new_state_dict["observation_image_mean"], old_state_dict["buffer_observation_image.mean"]
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
new_state_dict["observation_image_std"], old_state_dict["buffer_observation_image.std"]
|
||||
)
|
||||
torch.testing.assert_close(new_state_dict["some_other_param"], old_state_dict["some_other_param"])
|
||||
@@ -0,0 +1,472 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.policies.normalize import (
|
||||
Normalize,
|
||||
Unnormalize,
|
||||
)
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
|
||||
# Legacy implementations for backward compatibility testing
|
||||
def create_stats_buffers_legacy(
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
) -> dict[str, dict[str, nn.ParameterDict]]:
|
||||
"""Legacy version of create_stats_buffers for testing backward compatibility."""
|
||||
stats_buffers = {}
|
||||
|
||||
for key, ft in features.items():
|
||||
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
assert isinstance(norm_mode, NormalizationMode)
|
||||
|
||||
shape = tuple(ft.shape)
|
||||
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
# sanity checks
|
||||
assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
|
||||
c, h, w = shape
|
||||
assert c < h and c < w, f"{key} is not channel first ({shape=})"
|
||||
# override image shape to be invariant to height and width
|
||||
shape = (c, 1, 1)
|
||||
|
||||
buffer = {}
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
std = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
buffer = nn.ParameterDict(
|
||||
{
|
||||
"mean": nn.Parameter(mean, requires_grad=False),
|
||||
"std": nn.Parameter(std, requires_grad=False),
|
||||
}
|
||||
)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
max = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
buffer = nn.ParameterDict(
|
||||
{
|
||||
"min": nn.Parameter(min, requires_grad=False),
|
||||
"max": nn.Parameter(max, requires_grad=False),
|
||||
}
|
||||
)
|
||||
|
||||
if stats:
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
if isinstance(stats[key]["mean"], np.ndarray):
|
||||
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
||||
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
|
||||
elif isinstance(stats[key]["mean"], torch.Tensor):
|
||||
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
|
||||
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
||||
else:
|
||||
type_ = type(stats[key]["mean"])
|
||||
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
if isinstance(stats[key]["min"], np.ndarray):
|
||||
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
||||
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
||||
elif isinstance(stats[key]["min"], torch.Tensor):
|
||||
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
|
||||
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
||||
else:
|
||||
type_ = type(stats[key]["min"])
|
||||
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
|
||||
|
||||
stats_buffers[key] = buffer
|
||||
return stats_buffers
|
||||
|
||||
|
||||
def _no_stats_error_str_legacy(name: str) -> str:
|
||||
return (
|
||||
f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
|
||||
"pretrained model."
|
||||
)
|
||||
|
||||
|
||||
class NormalizeLegacy(nn.Module):
|
||||
"""Legacy Normalize class using nn.ParameterDict for backward compatibility testing."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
self.stats = stats
|
||||
stats_buffers = create_stats_buffers_legacy(features, norm_map, stats)
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
@torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch)
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str_legacy("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str_legacy("std")
|
||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(min).any(), _no_stats_error_str_legacy("min")
|
||||
assert not torch.isinf(max).any(), _no_stats_error_str_legacy("max")
|
||||
batch[key] = (batch[key] - min) / (max - min + 1e-8)
|
||||
batch[key] = batch[key] * 2 - 1
|
||||
else:
|
||||
raise ValueError(norm_mode)
|
||||
return batch
|
||||
|
||||
|
||||
class UnnormalizeLegacy(nn.Module):
|
||||
"""Legacy Unnormalize class using nn.ParameterDict for backward compatibility testing."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
self.stats = stats
|
||||
stats_buffers = create_stats_buffers_legacy(features, norm_map, stats)
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
@torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch)
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str_legacy("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str_legacy("std")
|
||||
batch[key] = batch[key] * std + mean
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(min).any(), _no_stats_error_str_legacy("min")
|
||||
assert not torch.isinf(max).any(), _no_stats_error_str_legacy("max")
|
||||
batch[key] = (batch[key] + 1) / 2
|
||||
batch[key] = batch[key] * (max - min) + min
|
||||
else:
|
||||
raise ValueError(norm_mode)
|
||||
return batch
|
||||
|
||||
|
||||
def _dummy_setup():
|
||||
# feature definitions
|
||||
features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(5,)),
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
|
||||
}
|
||||
|
||||
# map feature types to a normalization strategy
|
||||
norm_map = {
|
||||
FeatureType.STATE: NormalizationMode.MEAN_STD,
|
||||
FeatureType.VISUAL: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
# build statistics (include all stats for each feature)
|
||||
stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.arange(5, dtype=torch.float32),
|
||||
"std": torch.arange(1, 6, dtype=torch.float32),
|
||||
"min": torch.zeros(5, dtype=torch.float32),
|
||||
"max": torch.ones(5, dtype=torch.float32) * 10.0,
|
||||
},
|
||||
# image statistics use (c,1,1) so they broadcast on spatial dims
|
||||
"observation.image": {
|
||||
"mean": torch.ones(3, 1, 1, dtype=torch.float32) * 127.5,
|
||||
"std": torch.ones(3, 1, 1, dtype=torch.float32) * 50.0,
|
||||
"min": torch.zeros(3, 1, 1, dtype=torch.float32),
|
||||
"max": torch.ones(3, 1, 1, dtype=torch.float32) * 255.0,
|
||||
},
|
||||
}
|
||||
|
||||
return features, norm_map, stats
|
||||
|
||||
|
||||
def _random_batch(stats):
|
||||
"""Generate a batch consistent with the provided statistics."""
|
||||
torch.manual_seed(0)
|
||||
batch_size = 2
|
||||
|
||||
state_mean = stats["observation.state"]["mean"]
|
||||
state_std = stats["observation.state"]["std"]
|
||||
state = torch.randn(batch_size, 5) * state_std + state_mean # shape (b,5)
|
||||
|
||||
image_min = stats["observation.image"]["min"]
|
||||
image_max = stats["observation.image"]["max"]
|
||||
image = torch.rand(batch_size, 3, 64, 64) * (image_max - image_min) + image_min # shape (b,3,64,64)
|
||||
|
||||
return {
|
||||
"observation.state": state,
|
||||
"observation.image": image,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"module_pair",
|
||||
[
|
||||
(NormalizeLegacy, Normalize),
|
||||
(UnnormalizeLegacy, Unnormalize),
|
||||
],
|
||||
)
|
||||
def test_equivalence(module_pair):
|
||||
features, norm_map, stats = _dummy_setup()
|
||||
ParamCls, BufferCls = module_pair # noqa: N806
|
||||
|
||||
param_module = ParamCls(features=features, norm_map=norm_map, stats=stats)
|
||||
buffer_module = BufferCls(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
batch = _random_batch(stats)
|
||||
|
||||
out_param = param_module(batch)
|
||||
out_buffer = buffer_module(batch)
|
||||
|
||||
# every tensor in the output dictionaries should match closely
|
||||
for key in out_param:
|
||||
torch.testing.assert_close(out_param[key], out_buffer[key])
|
||||
|
||||
|
||||
def test_round_trip():
|
||||
"""Normalize then unnormalize should give the original input back for both impls."""
|
||||
features, norm_map, stats = _dummy_setup()
|
||||
|
||||
norm_p = NormalizeLegacy(features, norm_map, stats)
|
||||
unnorm_p = UnnormalizeLegacy(features, norm_map, stats)
|
||||
|
||||
norm_b = Normalize(features, norm_map, stats)
|
||||
unnorm_b = Unnormalize(features, norm_map, stats)
|
||||
|
||||
batch = _random_batch(stats)
|
||||
recovered_p = unnorm_p(norm_p(batch))
|
||||
recovered_b = unnorm_b(norm_b(batch))
|
||||
|
||||
for key in batch:
|
||||
torch.testing.assert_close(recovered_p[key], batch[key])
|
||||
torch.testing.assert_close(recovered_b[key], batch[key])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"image_shape,use_numpy",
|
||||
[
|
||||
((3, 64, 64), True),
|
||||
((3, 128, 128), False),
|
||||
],
|
||||
)
|
||||
def test_various_shapes_and_numpy(image_shape, use_numpy):
|
||||
"""Ensure equivalence and round-trip correctness for different image shapes and numpy stats."""
|
||||
# feature definitions (state dim fixed at 5)
|
||||
features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(5,)),
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=image_shape),
|
||||
}
|
||||
|
||||
norm_map = {
|
||||
FeatureType.STATE: NormalizationMode.MEAN_STD,
|
||||
FeatureType.VISUAL: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
# statistics (torch or numpy)
|
||||
state_mean = torch.arange(5, dtype=torch.float32)
|
||||
state_std = torch.arange(1, 6, dtype=torch.float32)
|
||||
img_min = torch.zeros(image_shape[0], 1, 1, dtype=torch.float32)
|
||||
img_max = torch.ones(image_shape[0], 1, 1, dtype=torch.float32) * 10.0 # simple range [0,10]
|
||||
|
||||
if use_numpy:
|
||||
state_mean_stats = state_mean.numpy()
|
||||
state_std_stats = state_std.numpy()
|
||||
img_min_stats = img_min.numpy()
|
||||
img_max_stats = img_max.numpy()
|
||||
else:
|
||||
state_mean_stats = state_mean
|
||||
state_std_stats = state_std
|
||||
img_min_stats = img_min
|
||||
img_max_stats = img_max
|
||||
|
||||
stats = {
|
||||
"observation.state": {"mean": state_mean_stats, "std": state_std_stats},
|
||||
"observation.image": {"min": img_min_stats, "max": img_max_stats},
|
||||
}
|
||||
|
||||
# instantiate modules
|
||||
norm_p = NormalizeLegacy(features, norm_map, stats)
|
||||
unnorm_p = UnnormalizeLegacy(features, norm_map, stats)
|
||||
norm_b = Normalize(features, norm_map, stats)
|
||||
unnorm_b = Unnormalize(features, norm_map, stats)
|
||||
|
||||
# build random batch following stats
|
||||
batch_size = 3
|
||||
torch.manual_seed(42)
|
||||
state = torch.randn(batch_size, 5) * state_std + state_mean
|
||||
image = torch.rand(batch_size, *image_shape) * (img_max - img_min) + img_min
|
||||
|
||||
batch = {"observation.state": state, "observation.image": image}
|
||||
|
||||
# equivalence between param and buffer implementations
|
||||
torch.testing.assert_close(norm_p(batch)["observation.state"], norm_b(batch)["observation.state"])
|
||||
torch.testing.assert_close(norm_p(batch)["observation.image"], norm_b(batch)["observation.image"])
|
||||
|
||||
# round-trip
|
||||
recovered_p = unnorm_p(norm_p(batch))
|
||||
recovered_b = unnorm_b(norm_b(batch))
|
||||
|
||||
for key in batch:
|
||||
torch.testing.assert_close(recovered_p[key], batch[key])
|
||||
torch.testing.assert_close(recovered_b[key], batch[key])
|
||||
|
||||
|
||||
def test_state_dict_conversion():
|
||||
"""Test that state dict can be converted from Normalize to NormalizeBuffer format."""
|
||||
from lerobot.common.policies.normalize import convert_normalize_to_buffer_state_dict
|
||||
|
||||
features, norm_map, stats = _dummy_setup()
|
||||
|
||||
# Create Legacy Normalize module and get its state dict
|
||||
legacy_normalize_module = NormalizeLegacy(features=features, norm_map=norm_map, stats=stats)
|
||||
old_state_dict = legacy_normalize_module.state_dict()
|
||||
|
||||
# Convert state dict
|
||||
new_state_dict = convert_normalize_to_buffer_state_dict(old_state_dict)
|
||||
|
||||
# Create new Normalize module and load converted state dict
|
||||
buffer_module = Normalize(features=features, norm_map=norm_map, stats=None)
|
||||
buffer_module.load_state_dict(new_state_dict)
|
||||
|
||||
# Test that both modules produce the same output
|
||||
batch = _random_batch(stats)
|
||||
|
||||
old_output = legacy_normalize_module(batch)
|
||||
new_output = buffer_module(batch)
|
||||
|
||||
for key in old_output:
|
||||
torch.testing.assert_close(old_output[key], new_output[key])
|
||||
|
||||
|
||||
def test_state_dict_conversion_unnormalize():
|
||||
"""Test that state dict can be converted from Unnormalize to UnnormalizeBuffer format."""
|
||||
from lerobot.common.policies.normalize import convert_normalize_to_buffer_state_dict
|
||||
|
||||
features, norm_map, stats = _dummy_setup()
|
||||
|
||||
# Create Legacy Unnormalize module and get its state dict
|
||||
legacy_unnormalize_module = UnnormalizeLegacy(features=features, norm_map=norm_map, stats=stats)
|
||||
old_state_dict = legacy_unnormalize_module.state_dict()
|
||||
|
||||
# Convert state dict
|
||||
new_state_dict = convert_normalize_to_buffer_state_dict(old_state_dict)
|
||||
|
||||
# Create new Unnormalize module and load converted state dict
|
||||
buffer_module = Unnormalize(features=features, norm_map=norm_map, stats=None)
|
||||
buffer_module.load_state_dict(new_state_dict)
|
||||
|
||||
# Test that both modules produce the same output on normalized data
|
||||
batch = _random_batch(stats)
|
||||
|
||||
# First normalize the batch
|
||||
normalize_module = Normalize(features=features, norm_map=norm_map, stats=stats)
|
||||
normalized_batch = normalize_module(batch)
|
||||
|
||||
old_output = legacy_unnormalize_module(normalized_batch)
|
||||
new_output = buffer_module(normalized_batch)
|
||||
|
||||
for key in old_output:
|
||||
torch.testing.assert_close(old_output[key], new_output[key])
|
||||
|
||||
|
||||
def test_state_dict_conversion_key_format():
|
||||
"""Test that conversion produces the expected key format."""
|
||||
from lerobot.common.policies.normalize import convert_normalize_to_buffer_state_dict
|
||||
|
||||
# Mock state dict with the old format
|
||||
old_state_dict = {
|
||||
"buffer_observation_image.mean": torch.randn(3, 1, 1),
|
||||
"buffer_observation_image.std": torch.randn(3, 1, 1),
|
||||
"buffer_observation_state.min": torch.randn(5),
|
||||
"buffer_observation_state.max": torch.randn(5),
|
||||
"some_other_param": torch.randn(10), # Non-normalization parameter
|
||||
}
|
||||
|
||||
new_state_dict = convert_normalize_to_buffer_state_dict(old_state_dict)
|
||||
|
||||
# Check expected key transformations
|
||||
expected_keys = {
|
||||
"observation_image_mean",
|
||||
"observation_image_std",
|
||||
"observation_state_min",
|
||||
"observation_state_max",
|
||||
"some_other_param", # Should be unchanged
|
||||
}
|
||||
|
||||
assert set(new_state_dict.keys()) == expected_keys
|
||||
|
||||
# Check values are preserved
|
||||
torch.testing.assert_close(
|
||||
new_state_dict["observation_image_mean"], old_state_dict["buffer_observation_image.mean"]
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
new_state_dict["observation_image_std"], old_state_dict["buffer_observation_image.std"]
|
||||
)
|
||||
torch.testing.assert_close(new_state_dict["some_other_param"], old_state_dict["some_other_param"])
|
||||
|
||||
|
||||
def test_legacy_vs_buffer_equivalence():
|
||||
"""Test that legacy implementation produces same results as buffer implementation."""
|
||||
features, norm_map, stats = _dummy_setup()
|
||||
|
||||
# Create both legacy and buffer implementations
|
||||
legacy_normalize = NormalizeLegacy(features=features, norm_map=norm_map, stats=stats)
|
||||
buffer_normalize = Normalize(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
legacy_unnormalize = UnnormalizeLegacy(features=features, norm_map=norm_map, stats=stats)
|
||||
buffer_unnormalize = Unnormalize(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
# Test with random batch
|
||||
batch = _random_batch(stats)
|
||||
|
||||
# Compare normalize outputs
|
||||
legacy_norm_output = legacy_normalize(batch)
|
||||
buffer_norm_output = buffer_normalize(batch)
|
||||
|
||||
for key in legacy_norm_output:
|
||||
torch.testing.assert_close(legacy_norm_output[key], buffer_norm_output[key])
|
||||
|
||||
# Compare unnormalize outputs (using normalized batch)
|
||||
legacy_unnorm_output = legacy_unnormalize(legacy_norm_output)
|
||||
buffer_unnorm_output = buffer_unnormalize(buffer_norm_output)
|
||||
|
||||
for key in legacy_unnorm_output:
|
||||
torch.testing.assert_close(legacy_unnorm_output[key], buffer_unnorm_output[key])
|
||||
Reference in New Issue
Block a user