# Reconstructed from git patch cf25b778053234b3dcc8e3144fbd4cdd206521ec
# (AdilZouitine, Thu 5 Jun 2025, "[PATCH] Add test for the backward compatibility"),
# which added this file as test_normalize_buffer.py (268 insertions).
"""Backward-compatibility tests for the buffer-based normalization modules.

Checks that ``NormalizeBuffer``/``UnnormalizeBuffer`` produce the same outputs
as the original ``Normalize``/``Unnormalize`` modules, that normalize/unnormalize
round-trips recover the input, and that old-format state dicts can be converted
to the new buffer format via ``convert_normalize_to_buffer_state_dict``.
"""

import pytest
import torch

from lerobot.common.policies.normalize import (
    Normalize,
    NormalizeBuffer,
    Unnormalize,
    UnnormalizeBuffer,
    convert_normalize_to_buffer_state_dict,
)
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature


def _dummy_setup():
    """Return (features, norm_map, stats) shared by most tests below."""
    # feature definitions
    features = {
        "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(5,)),
        "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
    }

    # map feature types to a normalization strategy
    norm_map = {
        FeatureType.STATE: NormalizationMode.MEAN_STD,
        FeatureType.VISUAL: NormalizationMode.MIN_MAX,
    }

    # build statistics (include all stats for each feature)
    stats = {
        "observation.state": {
            "mean": torch.arange(5, dtype=torch.float32),
            "std": torch.arange(1, 6, dtype=torch.float32),
            "min": torch.zeros(5, dtype=torch.float32),
            "max": torch.ones(5, dtype=torch.float32) * 10.0,
        },
        # image statistics use (c,1,1) so they broadcast on spatial dims
        "observation.image": {
            "mean": torch.ones(3, 1, 1, dtype=torch.float32) * 127.5,
            "std": torch.ones(3, 1, 1, dtype=torch.float32) * 50.0,
            "min": torch.zeros(3, 1, 1, dtype=torch.float32),
            "max": torch.ones(3, 1, 1, dtype=torch.float32) * 255.0,
        },
    }

    return features, norm_map, stats


def _random_batch(stats):
    """Generate a batch consistent with the provided statistics."""
    torch.manual_seed(0)  # deterministic batch so test failures are reproducible
    batch_size = 2

    state_mean = stats["observation.state"]["mean"]
    state_std = stats["observation.state"]["std"]
    state = torch.randn(batch_size, 5) * state_std + state_mean  # shape (b,5)

    image_min = stats["observation.image"]["min"]
    image_max = stats["observation.image"]["max"]
    image = torch.rand(batch_size, 3, 64, 64) * (image_max - image_min) + image_min  # shape (b,3,64,64)

    return {
        "observation.state": state,
        "observation.image": image,
    }


@pytest.mark.parametrize(
    "module_pair",
    [
        (Normalize, NormalizeBuffer),
        (Unnormalize, UnnormalizeBuffer),
    ],
)
def test_equivalence(module_pair):
    """Parameter-based and buffer-based modules must produce identical outputs."""
    features, norm_map, stats = _dummy_setup()
    ParamCls, BufferCls = module_pair  # noqa: N806

    param_module = ParamCls(features=features, norm_map=norm_map, stats=stats)
    buffer_module = BufferCls(features=features, norm_map=norm_map, stats=stats)

    batch = _random_batch(stats)

    out_param = param_module(batch)
    out_buffer = buffer_module(batch)

    # every tensor in the output dictionaries should match closely
    for key in out_param:
        torch.testing.assert_close(out_param[key], out_buffer[key])


def test_round_trip():
    """Normalize then unnormalize should give the original input back for both impls."""
    features, norm_map, stats = _dummy_setup()

    norm_p = Normalize(features, norm_map, stats)
    unnorm_p = Unnormalize(features, norm_map, stats)

    norm_b = NormalizeBuffer(features, norm_map, stats)
    unnorm_b = UnnormalizeBuffer(features, norm_map, stats)

    batch = _random_batch(stats)
    recovered_p = unnorm_p(norm_p(batch))
    recovered_b = unnorm_b(norm_b(batch))

    for key in batch:
        torch.testing.assert_close(recovered_p[key], batch[key])
        torch.testing.assert_close(recovered_b[key], batch[key])


@pytest.mark.parametrize(
    "image_shape,use_numpy",
    [
        ((3, 64, 64), True),
        ((3, 128, 128), False),
    ],
)
def test_various_shapes_and_numpy(image_shape, use_numpy):
    """Ensure equivalence and round-trip correctness for different image shapes and numpy stats."""
    # feature definitions (state dim fixed at 5)
    features = {
        "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(5,)),
        "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=image_shape),
    }

    norm_map = {
        FeatureType.STATE: NormalizationMode.MEAN_STD,
        FeatureType.VISUAL: NormalizationMode.MIN_MAX,
    }

    # statistics (torch or numpy)
    state_mean = torch.arange(5, dtype=torch.float32)
    state_std = torch.arange(1, 6, dtype=torch.float32)
    img_min = torch.zeros(image_shape[0], 1, 1, dtype=torch.float32)
    img_max = torch.ones(image_shape[0], 1, 1, dtype=torch.float32) * 10.0  # simple range [0,10]

    if use_numpy:
        state_mean_stats = state_mean.numpy()
        state_std_stats = state_std.numpy()
        img_min_stats = img_min.numpy()
        img_max_stats = img_max.numpy()
    else:
        state_mean_stats = state_mean
        state_std_stats = state_std
        img_min_stats = img_min
        img_max_stats = img_max

    stats = {
        "observation.state": {"mean": state_mean_stats, "std": state_std_stats},
        "observation.image": {"min": img_min_stats, "max": img_max_stats},
    }

    # instantiate modules
    norm_p = Normalize(features, norm_map, stats)
    unnorm_p = Unnormalize(features, norm_map, stats)
    norm_b = NormalizeBuffer(features, norm_map, stats)
    unnorm_b = UnnormalizeBuffer(features, norm_map, stats)

    # build random batch following stats
    batch_size = 3
    torch.manual_seed(42)
    state = torch.randn(batch_size, 5) * state_std + state_mean
    image = torch.rand(batch_size, *image_shape) * (img_max - img_min) + img_min

    batch = {"observation.state": state, "observation.image": image}

    # equivalence between param and buffer implementations
    torch.testing.assert_close(norm_p(batch)["observation.state"], norm_b(batch)["observation.state"])
    torch.testing.assert_close(norm_p(batch)["observation.image"], norm_b(batch)["observation.image"])

    # round-trip
    recovered_p = unnorm_p(norm_p(batch))
    recovered_b = unnorm_b(norm_b(batch))

    for key in batch:
        torch.testing.assert_close(recovered_p[key], batch[key])
        torch.testing.assert_close(recovered_b[key], batch[key])


def test_state_dict_conversion():
    """Test that state dict can be converted from Normalize to NormalizeBuffer format."""
    features, norm_map, stats = _dummy_setup()

    # Create Normalize module and get its state dict
    normalize_module = Normalize(features=features, norm_map=norm_map, stats=stats)
    old_state_dict = normalize_module.state_dict()

    # Convert state dict
    new_state_dict = convert_normalize_to_buffer_state_dict(old_state_dict)

    # Create NormalizeBuffer module and load converted state dict
    buffer_module = NormalizeBuffer(features=features, norm_map=norm_map, stats=None)
    buffer_module.load_state_dict(new_state_dict)

    # Test that both modules produce the same output
    batch = _random_batch(stats)

    old_output = normalize_module(batch)
    new_output = buffer_module(batch)

    for key in old_output:
        torch.testing.assert_close(old_output[key], new_output[key])


def test_state_dict_conversion_unnormalize():
    """Test that state dict can be converted from Unnormalize to UnnormalizeBuffer format."""
    features, norm_map, stats = _dummy_setup()

    # Create Unnormalize module and get its state dict
    unnormalize_module = Unnormalize(features=features, norm_map=norm_map, stats=stats)
    old_state_dict = unnormalize_module.state_dict()

    # Convert state dict
    new_state_dict = convert_normalize_to_buffer_state_dict(old_state_dict)

    # Create UnnormalizeBuffer module and load converted state dict
    buffer_module = UnnormalizeBuffer(features=features, norm_map=norm_map, stats=None)
    buffer_module.load_state_dict(new_state_dict)

    # Test that both modules produce the same output on normalized data
    batch = _random_batch(stats)

    # First normalize the batch
    normalize_module = Normalize(features=features, norm_map=norm_map, stats=stats)
    normalized_batch = normalize_module(batch)

    old_output = unnormalize_module(normalized_batch)
    new_output = buffer_module(normalized_batch)

    for key in old_output:
        torch.testing.assert_close(old_output[key], new_output[key])


def test_state_dict_conversion_key_format():
    """Test that conversion produces the expected key format."""
    # Mock state dict with the old format
    old_state_dict = {
        "buffer_observation_image.mean": torch.randn(3, 1, 1),
        "buffer_observation_image.std": torch.randn(3, 1, 1),
        "buffer_observation_state.min": torch.randn(5),
        "buffer_observation_state.max": torch.randn(5),
        "some_other_param": torch.randn(10),  # Non-normalization parameter
    }

    new_state_dict = convert_normalize_to_buffer_state_dict(old_state_dict)

    # Check expected key transformations
    expected_keys = {
        "observation_image_mean",
        "observation_image_std",
        "observation_state_min",
        "observation_state_max",
        "some_other_param",  # Should be unchanged
    }

    assert set(new_state_dict.keys()) == expected_keys

    # Check values are preserved
    torch.testing.assert_close(
        new_state_dict["observation_image_mean"], old_state_dict["buffer_observation_image.mean"]
    )
    torch.testing.assert_close(
        new_state_dict["observation_image_std"], old_state_dict["buffer_observation_image.std"]
    )
    torch.testing.assert_close(new_state_dict["some_other_param"], old_state_dict["some_other_param"])