From abcbc16126568a01fe4b9d9407aa8c198059e4ae Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Fri, 8 Aug 2025 13:23:10 +0200 Subject: [PATCH] refactor(normalization): remove Normalize and Unnormalize classes - Deleted the Normalize and Unnormalize classes from the normalization module to streamline the codebase. - Updated tests to ensure compatibility with the removal of these classes, focusing on the new NormalizerProcessor and UnnormalizerProcessor implementations. - Enhanced the handling of normalization statistics and improved overall code clarity. --- src/lerobot/policies/normalize.py | 420 -------------------- tests/processor/test_normalize_processor.py | 154 +++++++ 2 files changed, 154 insertions(+), 420 deletions(-) delete mode 100644 src/lerobot/policies/normalize.py diff --git a/src/lerobot/policies/normalize.py b/src/lerobot/policies/normalize.py deleted file mode 100644 index 119055873..000000000 --- a/src/lerobot/policies/normalize.py +++ /dev/null @@ -1,420 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import numpy as np -import torch -from torch import Tensor, nn - -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature - - -def create_stats_buffers( - features: dict[str, PolicyFeature], - norm_map: dict[str, NormalizationMode], - stats: dict[str, dict[str, Tensor]] | None = None, -) -> dict[str, dict[str, nn.ParameterDict]]: - """ - Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max - statistics. - - Args: (see Normalize and Unnormalize) - - Returns: - dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing - `nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation. - """ - stats_buffers = {} - - for key, ft in features.items(): - norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY) - if norm_mode is NormalizationMode.IDENTITY: - continue - - assert isinstance(norm_mode, NormalizationMode) - - shape = tuple(ft.shape) - - if ft.type is FeatureType.VISUAL: - # sanity checks - assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}" - c, h, w = shape - assert c < h and c < w, f"{key} is not channel first ({shape=})" - # override image shape to be invariant to height and width - shape = (c, 1, 1) - - # Note: we initialize mean, std, min, max to infinity. They should be overwritten - # downstream by `stats` or `policy.load_state_dict`, as expected. During forward, - # we assert they are not infinity anymore. 
- - buffer = {} - if norm_mode is NormalizationMode.MEAN_STD: - mean = torch.ones(shape, dtype=torch.float32) * torch.inf - std = torch.ones(shape, dtype=torch.float32) * torch.inf - buffer = nn.ParameterDict( - { - "mean": nn.Parameter(mean, requires_grad=False), - "std": nn.Parameter(std, requires_grad=False), - } - ) - elif norm_mode is NormalizationMode.MIN_MAX: - min = torch.ones(shape, dtype=torch.float32) * torch.inf - max = torch.ones(shape, dtype=torch.float32) * torch.inf - buffer = nn.ParameterDict( - { - "min": nn.Parameter(min, requires_grad=False), - "max": nn.Parameter(max, requires_grad=False), - } - ) - - # TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch) - if stats: - if isinstance(stats[key]["mean"], np.ndarray): - if norm_mode is NormalizationMode.MEAN_STD: - buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32) - buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32) - elif norm_mode is NormalizationMode.MIN_MAX: - buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32) - buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32) - elif isinstance(stats[key]["mean"], torch.Tensor): - # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated - # tensors anywhere (for example, when we use the same stats for normalization and - # unnormalization). See the logic here - # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. - if norm_mode is NormalizationMode.MEAN_STD: - buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32) - buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32) - elif norm_mode is NormalizationMode.MIN_MAX: - buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32) - buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32) - else: - type_ = type(stats[key]["mean"]) - raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.") - - stats_buffers[key] = buffer - return stats_buffers - - -def _no_stats_error_str(name: str) -> str: - return ( - f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a " - "pretrained model." - ) - - -class Normalize(nn.Module): - """Normalizes data (e.g. "observation.image") for more stable and faster convergence during training.""" - - def __init__( - self, - features: dict[str, PolicyFeature], - norm_map: dict[str, NormalizationMode], - stats: dict[str, dict[str, Tensor]] | None = None, - ): - """ - Args: - shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values - are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing - mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape - is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format. - modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values - are their normalization modes among: - - "mean_std": subtract the mean and divide by standard deviation. - - "min_max": map to [-1, 1] range. - stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") - and values are dictionaries of statistic types and their values (e.g. 
- `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for - training the model for the first time, these statistics will overwrite the default buffers. If - not provided, as expected for finetuning or evaluation, the default buffers should to be - overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the - dataset is not needed to get the stats, since they are already in the policy state_dict. - """ - super().__init__() - self.features = features - self.norm_map = norm_map - self.stats = stats - stats_buffers = create_stats_buffers(features, norm_map, stats) - for key, buffer in stats_buffers.items(): - setattr(self, "buffer_" + key.replace(".", "_"), buffer) - - # TODO(rcadene): should we remove torch.no_grad? - @torch.no_grad() - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - # TODO: Remove this shallow copy - batch = dict(batch) # shallow copy avoids mutating the input batch - for key, ft in self.features.items(): - if key not in batch: - # FIXME(aliberts, rcadene): This might lead to silent fail! - continue - - norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) - if norm_mode is NormalizationMode.IDENTITY: - continue - - buffer = getattr(self, "buffer_" + key.replace(".", "_")) - - if norm_mode is NormalizationMode.MEAN_STD: - mean = buffer["mean"] - std = buffer["std"] - assert not torch.isinf(mean).any(), _no_stats_error_str("mean") - assert not torch.isinf(std).any(), _no_stats_error_str("std") - batch[key] = (batch[key] - mean) / (std + 1e-8) - elif norm_mode is NormalizationMode.MIN_MAX: - min = buffer["min"] - max = buffer["max"] - assert not torch.isinf(min).any(), _no_stats_error_str("min") - assert not torch.isinf(max).any(), _no_stats_error_str("max") - # normalize to [0,1] - batch[key] = (batch[key] - min) / (max - min + 1e-8) - # normalize to [-1, 1] - batch[key] = batch[key] * 2 - 1 - else: - raise ValueError(norm_mode) - return batch - - -class Unnormalize(nn.Module): - """ - Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their - original range used by the environment. - """ - - def __init__( - self, - features: dict[str, PolicyFeature], - norm_map: dict[str, NormalizationMode], - stats: dict[str, dict[str, Tensor]] | None = None, - ): - """ - Args: - shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values - are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing - mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape - is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format. - modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values - are their normalization modes among: - - "mean_std": subtract the mean and divide by standard deviation. - - "min_max": map to [-1, 1] range. - stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") - and values are dictionaries of statistic types and their values (e.g. - `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for - training the model for the first time, these statistics will overwrite the default buffers. If - not provided, as expected for finetuning or evaluation, the default buffers should to be - overwritten by a call to `policy.load_state_dict(state_dict)`. 
That way, initializing the - dataset is not needed to get the stats, since they are already in the policy state_dict. - """ - super().__init__() - self.features = features - self.norm_map = norm_map - self.stats = stats - # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)` - stats_buffers = create_stats_buffers(features, norm_map, stats) - for key, buffer in stats_buffers.items(): - setattr(self, "buffer_" + key.replace(".", "_"), buffer) - - # TODO(rcadene): should we remove torch.no_grad? - @torch.no_grad() - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - batch = dict(batch) # shallow copy avoids mutating the input batch - for key, ft in self.features.items(): - if key not in batch: - continue - - norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) - if norm_mode is NormalizationMode.IDENTITY: - continue - - buffer = getattr(self, "buffer_" + key.replace(".", "_")) - - if norm_mode is NormalizationMode.MEAN_STD: - mean = buffer["mean"] - std = buffer["std"] - assert not torch.isinf(mean).any(), _no_stats_error_str("mean") - assert not torch.isinf(std).any(), _no_stats_error_str("std") - batch[key] = batch[key] * std + mean - elif norm_mode is NormalizationMode.MIN_MAX: - min = buffer["min"] - max = buffer["max"] - assert not torch.isinf(min).any(), _no_stats_error_str("min") - assert not torch.isinf(max).any(), _no_stats_error_str("max") - batch[key] = (batch[key] + 1) / 2 - batch[key] = batch[key] * (max - min) + min - else: - raise ValueError(norm_mode) - return batch - - -# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization -# and remove the `Normalize` and `Unnormalize` classes. -def _initialize_stats_buffers( - module: nn.Module, - features: dict[str, PolicyFeature], - norm_map: dict[str, NormalizationMode], - stats: dict[str, dict[str, Tensor]] | None = None, -) -> None: - """Register statistics buffers (mean/std or min/max) on the given *module*. - - The logic matches the previous constructors of `NormalizeBuffer` and `UnnormalizeBuffer`, - but is factored out so it can be reused by both classes and stay in sync. - """ - for key, ft in features.items(): - norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY) - if norm_mode is NormalizationMode.IDENTITY: - continue - - shape: tuple[int, ...] = tuple(ft.shape) - if ft.type is FeatureType.VISUAL: - # reduce spatial dimensions, keep channel dimension only - c, *_ = shape - shape = (c, 1, 1) - - prefix = key.replace(".", "_") - - if norm_mode is NormalizationMode.MEAN_STD: - mean = torch.full(shape, torch.inf, dtype=torch.float32) - std = torch.full(shape, torch.inf, dtype=torch.float32) - - if stats and key in stats and "mean" in stats[key] and "std" in stats[key]: - mean_data = stats[key]["mean"] - std_data = stats[key]["std"] - if isinstance(mean_data, torch.Tensor): - # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated - # tensors anywhere (for example, when we use the same stats for normalization and - # unnormalization). See the logic here - # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. 
- mean = mean_data.clone().to(dtype=torch.float32) - std = std_data.clone().to(dtype=torch.float32) - else: - raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).") - - module.register_buffer(f"{prefix}_mean", mean) - module.register_buffer(f"{prefix}_std", std) - continue - - if norm_mode is NormalizationMode.MIN_MAX: - min_val = torch.full(shape, torch.inf, dtype=torch.float32) - max_val = torch.full(shape, torch.inf, dtype=torch.float32) - - if stats and key in stats and "min" in stats[key] and "max" in stats[key]: - min_data = stats[key]["min"] - max_data = stats[key]["max"] - if isinstance(min_data, torch.Tensor): - min_val = min_data.clone().to(dtype=torch.float32) - max_val = max_data.clone().to(dtype=torch.float32) - else: - raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).") - - module.register_buffer(f"{prefix}_min", min_val) - module.register_buffer(f"{prefix}_max", max_val) - continue - - raise ValueError(norm_mode) - - -class NormalizeBuffer(nn.Module): - """Same as `Normalize` but statistics are stored as registered buffers rather than parameters.""" - - def __init__( - self, - features: dict[str, PolicyFeature], - norm_map: dict[str, NormalizationMode], - stats: dict[str, dict[str, Tensor]] | None = None, - ): - super().__init__() - self.features = features - self.norm_map = norm_map - - _initialize_stats_buffers(self, features, norm_map, stats) - - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - batch = dict(batch) - for key, ft in self.features.items(): - if key not in batch: - continue - - norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) - if norm_mode is NormalizationMode.IDENTITY: - continue - - prefix = key.replace(".", "_") - - if norm_mode is NormalizationMode.MEAN_STD: - mean = getattr(self, f"{prefix}_mean") - std = getattr(self, f"{prefix}_std") - assert not torch.isinf(mean).any(), _no_stats_error_str("mean") - assert not torch.isinf(std).any(), _no_stats_error_str("std") - batch[key] = (batch[key] - mean) / (std + 1e-8) - continue - - if norm_mode is NormalizationMode.MIN_MAX: - min_val = getattr(self, f"{prefix}_min") - max_val = getattr(self, f"{prefix}_max") - assert not torch.isinf(min_val).any(), _no_stats_error_str("min") - assert not torch.isinf(max_val).any(), _no_stats_error_str("max") - batch[key] = (batch[key] - min_val) / (max_val - min_val + 1e-8) - batch[key] = batch[key] * 2 - 1 - continue - - raise ValueError(norm_mode) - - return batch - - -class UnnormalizeBuffer(nn.Module): - """Inverse operation of `NormalizeBuffer`. 
Uses registered buffers for statistics.""" - - def __init__( - self, - features: dict[str, PolicyFeature], - norm_map: dict[str, NormalizationMode], - stats: dict[str, dict[str, Tensor]] | None = None, - ): - super().__init__() - self.features = features - self.norm_map = norm_map - - _initialize_stats_buffers(self, features, norm_map, stats) - - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - # batch = dict(batch) - for key, ft in self.features.items(): - if key not in batch: - continue - - norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) - if norm_mode is NormalizationMode.IDENTITY: - continue - - prefix = key.replace(".", "_") - - if norm_mode is NormalizationMode.MEAN_STD: - mean = getattr(self, f"{prefix}_mean") - std = getattr(self, f"{prefix}_std") - assert not torch.isinf(mean).any(), _no_stats_error_str("mean") - assert not torch.isinf(std).any(), _no_stats_error_str("std") - batch[key] = batch[key] * std + mean - continue - - if norm_mode is NormalizationMode.MIN_MAX: - min_val = getattr(self, f"{prefix}_min") - max_val = getattr(self, f"{prefix}_max") - assert not torch.isinf(min_val).any(), _no_stats_error_str("min") - assert not torch.isinf(max_val).any(), _no_stats_error_str("max") - batch[key] = (batch[key] + 1) / 2 - batch[key] = batch[key] * (max_val - min_val) + min_val - continue - - raise ValueError(norm_mode) - - return batch diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index 97c737e0c..5813cc37d 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -25,6 +25,7 @@ from lerobot.processor.normalize_processor import ( UnnormalizerProcessor, _convert_stats_to_tensors, hotswap_stats, + rename_stats, ) from lerobot.processor.pipeline import IdentityProcessor, RobotProcessor, TransitionKey @@ -1604,3 +1605,156 @@ def test_hotswap_stats_functional_test(): new_result["observation"]["observation.image"], observation["observation.image"] ) assert not torch.allclose(new_result["action"], action) + + +def test_zero_std_uses_eps(): + """When std == 0, (x-mean)/(std+eps) is well-defined; x==mean should map to 0.""" + features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))} + norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} + stats = {"observation.state": {"mean": np.array([0.5]), "std": np.array([0.0])}} + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats, eps=1e-6) + + observation = {"observation.state": torch.tensor([0.5])} # equals mean + out = normalizer(create_transition(observation=observation)) + assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.state"], torch.tensor([0.0])) + + +def test_min_equals_max_maps_to_minus_one(): + """When min == max, MIN_MAX path maps to -1 after [-1,1] scaling for x==min.""" + features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))} + norm_map = {FeatureType.STATE: NormalizationMode.MIN_MAX} + stats = {"observation.state": {"min": np.array([2.0]), "max": np.array([2.0])}} + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats, eps=1e-6) + + observation = {"observation.state": torch.tensor([2.0])} + out = normalizer(create_transition(observation=observation)) + assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.state"], torch.tensor([-1.0])) + + +def test_action_normalized_despite_normalize_keys(): + """Action normalization is independent of normalize_keys filter 
for observations.""" + features = { + "observation.state": PolicyFeature(FeatureType.STATE, (1,)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + norm_map = {FeatureType.STATE: NormalizationMode.IDENTITY, FeatureType.ACTION: NormalizationMode.MEAN_STD} + stats = {"action": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}} + normalizer = NormalizerProcessor( + features=features, norm_map=norm_map, stats=stats, normalize_keys={"observation.state"} + ) + + transition = create_transition( + observation={"observation.state": torch.tensor([3.0])}, action=torch.tensor([3.0, 3.0]) + ) + out = normalizer(transition) + # (3-1)/2 = 1.0 ; (3-(-1))/4 = 1.0 + assert torch.allclose(out[TransitionKey.ACTION], torch.tensor([1.0, 1.0])) + + +def test_unnormalize_observations_mean_std_and_min_max(): + features = { + "observation.ms": PolicyFeature(FeatureType.STATE, (2,)), + "observation.mm": PolicyFeature(FeatureType.STATE, (2,)), + } + # Build two processors: one mean/std and one min/max + unnorm_ms = UnnormalizerProcessor( + features={"observation.ms": features["observation.ms"]}, + norm_map={FeatureType.STATE: NormalizationMode.MEAN_STD}, + stats={"observation.ms": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}}, + ) + unnorm_mm = UnnormalizerProcessor( + features={"observation.mm": features["observation.mm"]}, + norm_map={FeatureType.STATE: NormalizationMode.MIN_MAX}, + stats={"observation.mm": {"min": np.array([0.0, -2.0]), "max": np.array([2.0, 2.0])}}, + ) + + tr = create_transition( + observation={ + "observation.ms": torch.tensor([0.0, 0.0]), # → mean + "observation.mm": torch.tensor([0.0, 0.0]), # → mid-point + } + ) + out_ms = unnorm_ms(tr)[TransitionKey.OBSERVATION]["observation.ms"] + out_mm = unnorm_mm(tr)[TransitionKey.OBSERVATION]["observation.mm"] + assert torch.allclose(out_ms, torch.tensor([1.0, -1.0])) + assert torch.allclose(out_mm, torch.tensor([1.0, 0.0])) # mid of [0,2] and [-2,2] + + +def test_rename_stats_basic(): + orig = { + "observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}, + "action": {"mean": np.array([0.0])}, + } + mapping = {"observation.state": "observation.robot_state"} + renamed = rename_stats(orig, mapping) + assert "observation.robot_state" in renamed and "observation.state" not in renamed + # Ensure deep copy: mutate original and verify renamed unaffected + orig["observation.state"]["mean"][0] = 42.0 + assert renamed["observation.robot_state"]["mean"][0] != 42.0 + + +def test_unknown_observation_keys_ignored(): + features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))} + norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} + stats = {"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}} + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + + obs = {"observation.state": torch.tensor([1.0]), "observation.unknown": torch.tensor([5.0])} + tr = create_transition(observation=obs) + out = normalizer(tr) + + # Unknown key should pass through unchanged and not be tracked + assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.unknown"], obs["observation.unknown"]) + comp = out.get(TransitionKey.COMPLEMENTARY_DATA) or {} + assert "normalized_keys" in comp and "observation.unknown" not in comp["normalized_keys"] + + +def test_batched_action_normalization(): + features = {"action": PolicyFeature(FeatureType.ACTION, (2,))} + norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD} + stats = {"action": {"mean": np.array([1.0, -1.0]), 
"std": np.array([2.0, 4.0])}} + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + + actions = torch.tensor([[1.0, -1.0], [3.0, 3.0]]) # first equals mean → zeros; second → [1, 1] + out = normalizer(create_transition(action=actions))[TransitionKey.ACTION] + expected = torch.tensor([[0.0, 0.0], [1.0, 1.0]]) + assert torch.allclose(out, expected) + + +def test_complementary_data_preservation(): + features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))} + norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} + stats = {"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}} + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + + comp = {"existing": 123} + tr = create_transition(observation={"observation.state": torch.tensor([1.0])}, complementary_data=comp) + out = normalizer(tr) + new_comp = out[TransitionKey.COMPLEMENTARY_DATA] + assert new_comp["existing"] == 123 and "normalized_keys" in new_comp + + +def test_roundtrip_normalize_unnormalize_non_identity(): + features = { + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD, FeatureType.ACTION: NormalizationMode.MIN_MAX} + stats = { + "observation.state": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}, + "action": {"min": np.array([-2.0, 0.0]), "max": np.array([2.0, 4.0])}, + } + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + + # Add a time dimension in action for broadcasting check (B,T,D) + obs = {"observation.state": torch.tensor([[3.0, 3.0], [1.0, -1.0]])} + act = torch.tensor([[[0.0, -1.0], [1.0, 1.0]]]) # shape (1,2,2) already in [-1,1] + + tr = create_transition(observation=obs, action=act) + out = unnormalizer(normalizer(tr)) + + assert torch.allclose( + out[TransitionKey.OBSERVATION]["observation.state"], obs["observation.state"], atol=1e-5 + ) + assert torch.allclose(out[TransitionKey.ACTION], act, atol=1e-5)