From 6830ca7645a3e2a52fb44952a128f4735506e23d Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Fri, 4 Jul 2025 12:09:40 +0200 Subject: [PATCH] Refactor normalization components and update tests - Renamed `ObservationNormalizer` to `NormalizerProcessor` and `ActionUnnormalizer` to `UnnormalizerProcessor` for clarity. - Consolidated normalization logic for both observations and actions into `NormalizerProcessor` and `UnnormalizerProcessor`. - Updated tests to reflect the new class names and ensure proper functionality of normalization and unnormalization processes. - Enhanced handling of missing statistics in normalization processes. --- src/lerobot/processor/__init__.py | 5 +- src/lerobot/processor/normalize_processor.py | 366 ++++++------------- tests/processor/test_normalize_processor.py | 165 +++------ 3 files changed, 161 insertions(+), 375 deletions(-) diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index f6acdee9e..1b104199b 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .normalize_processor import NormalizationProcessor +from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor from .observation_processor import ( ImageProcessor, StateProcessor, @@ -38,7 +38,8 @@ __all__ = [ "EnvTransition", "ImageProcessor", "InfoProcessor", - "NormalizationProcessor", + "NormalizerProcessor", + "UnnormalizerProcessor", "ObservationProcessor", "ProcessorStep", "RenameProcessor", diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index 08a334695..808384ccb 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -30,31 +30,22 @@ def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dic @dataclass -@ProcessorStepRegistry.register(name="observation_normalizer") -class ObservationNormalizer: - """Normalize observations using dataset statistics. +@ProcessorStepRegistry.register(name="normalizer_processor") +class NormalizerProcessor: + """Normalize observations *and* actions in one go. - This processor normalizes selected observation keys using either: - - Standard normalization: ``(x - mean) / (std + eps)`` - - Min-Max normalization to [-1, 1]: ``2 * (x - min) / (max - min + eps) - 1`` + This is a thin convenience wrapper equivalent to:: - Parameters - ---------- - stats : Dict[str, Dict[str, np.ndarray | Tensor]] - Dataset statistics. Each entry must provide either - ``{"mean", "std"}`` or ``{"min", "max"}``. - normalize_keys : set[str] | None, default=None - Observation keys to normalize. ``None`` means all keys - present in both the observation and stats. - eps : float, default=1e-8 - Small constant to avoid division by zero. + proc = RobotProcessor([ObservationNormalizer(stats, ...), ActionNormalizer(action_stats, ...)]) + + Keeping it as a single step is handy for profiling and simplifies + configuration files. """ stats: dict[str, dict[str, Any]] normalize_keys: set[str] | None = None eps: float = 1e-8 - # Cached tensors for performance _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False) @classmethod @@ -64,70 +55,66 @@ class ObservationNormalizer: *, normalize_keys: set[str] | None = None, eps: float = 1e-8, - ) -> ObservationNormalizer: - """Create from a LeRobotDataset.""" - # Filter stats to only include observation keys - obs_stats = {k: v for k, v in dataset.meta.stats.items() if k != "action"} - return cls(stats=obs_stats, normalize_keys=normalize_keys, eps=eps) + ) -> NormalizerProcessor: + return cls(stats=dataset.meta.stats, normalize_keys=normalize_keys, eps=eps) def __post_init__(self): self._tensor_stats = _convert_stats_to_tensors(self.stats) - def __call__(self, transition: EnvTransition) -> EnvTransition: - observation = transition[TransitionIndex.OBSERVATION] - + def _normalize_obs(self, observation): if observation is None: - return transition + return None - # Determine which keys to normalize keys_to_norm = ( - self.normalize_keys if self.normalize_keys is not None else set(self._tensor_stats.keys()) + self.normalize_keys + if self.normalize_keys is not None + else {k for k in self._tensor_stats if k != "action"} ) - - # Create a copy to avoid mutating input - processed_obs = dict(observation) - + processed = dict(observation) for key in keys_to_norm: - if key not in processed_obs: + if key not in processed or key not in self._tensor_stats: continue - if key not in self._tensor_stats: - if self.normalize_keys is not None: - # User explicitly requested this key but stats are missing - raise KeyError(f"Stats not found for requested key '{key}'") - continue + orig_val = processed[key] + tensor = ( + orig_val.to(dtype=torch.float32) + if isinstance(orig_val, torch.Tensor) + else torch.as_tensor(orig_val, dtype=torch.float32) + ) + stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()} - # Convert to tensor if needed - orig_val = processed_obs[key] - if isinstance(orig_val, torch.Tensor): - tensor = orig_val.to(dtype=torch.float32) - elif isinstance(orig_val, np.ndarray): - tensor = torch.from_numpy(orig_val.astype(np.float32)) - else: - # For lists, tuples, scalars, etc. - tensor = torch.as_tensor(orig_val, dtype=torch.float32) - - stats = self._tensor_stats[key] - # Move stats to same device as data - stats = {k: v.to(device=tensor.device) for k, v in stats.items()} - - # Apply normalization if "mean" in stats and "std" in stats: mean, std = stats["mean"], stats["std"] - processed_obs[key] = (tensor - mean) / (std + self.eps) + processed[key] = (tensor - mean) / (std + self.eps) elif "min" in stats and "max" in stats: min_val, max_val = stats["min"], stats["max"] - # Normalize to [0, 1] then to [-1, 1] - processed_obs[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1 - else: - raise ValueError( - f"Stats for key '{key}' must contain either ('mean', 'std') or ('min', 'max')" - ) + processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1 + return processed - # Return new transition with normalized observation + def _normalize_action(self, action): + if action is None or "action" not in self._tensor_stats: + return action + + tensor = ( + action.to(dtype=torch.float32) + if isinstance(action, torch.Tensor) + else torch.as_tensor(action, dtype=torch.float32) + ) + stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()} + if "mean" in stats and "std" in stats: + mean, std = stats["mean"], stats["std"] + return (tensor - mean) / (std + self.eps) + if "min" in stats and "max" in stats: + min_val, max_val = stats["min"], stats["max"] + return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1 + raise ValueError("Action stats must contain either ('mean','std') or ('min','max')") + + def __call__(self, transition: EnvTransition) -> EnvTransition: + observation = self._normalize_obs(transition[TransitionIndex.OBSERVATION]) + action = self._normalize_action(transition[TransitionIndex.ACTION]) return ( - processed_obs, - transition[TransitionIndex.ACTION], + observation, + action, transition[TransitionIndex.REWARD], transition[TransitionIndex.DONE], transition[TransitionIndex.TRUNCATED], @@ -136,149 +123,38 @@ class ObservationNormalizer: ) def get_config(self) -> dict[str, Any]: - return { - "normalize_keys": list(self.normalize_keys) if self.normalize_keys is not None else None, - "eps": self.eps, - } + return {"normalize_keys": list(self.normalize_keys) if self.normalize_keys else None, "eps": self.eps} def state_dict(self) -> dict[str, Tensor]: - flat_state: dict[str, Tensor] = {} + flat = {} for key, sub in self._tensor_stats.items(): for stat_name, tensor in sub.items(): - flat_state[f"{key}.{stat_name}"] = tensor - return flat_state + flat[f"{key}.{stat_name}"] = tensor + return flat def load_state_dict(self, state: Mapping[str, Tensor]) -> None: self._tensor_stats.clear() for flat_key, tensor in state.items(): key, stat_name = flat_key.rsplit(".", 1) - if key not in self._tensor_stats: - self._tensor_stats[key] = {} - self._tensor_stats[key][stat_name] = tensor + self._tensor_stats.setdefault(key, {})[stat_name] = tensor - def reset(self) -> None: - """Nothing to reset for this stateless processor.""" + def reset(self): pass @dataclass -@ProcessorStepRegistry.register(name="action_unnormalizer") -class ActionUnnormalizer: - """Un-normalize actions using dataset statistics. +@ProcessorStepRegistry.register(name="unnormalizer_processor") +class UnnormalizerProcessor: + """Inverse normalisation for observations and actions. - This processor un-normalizes actions using the inverse of normalization: - - Standard: ``action * std + mean`` - - Min-Max from [-1, 1]: ``(action + 1) / 2 * (max - min) + min`` - - Parameters - ---------- - action_stats : Dict[str, np.ndarray | Tensor] - Action statistics containing either ``{"mean", "std"}`` or ``{"min", "max"}``. - eps : float, default=1e-8 - Small constant used during normalization (not used in unnormalization). - """ - - action_stats: dict[str, Any] - eps: float = 1e-8 # Kept for consistency, not used in unnormalization - - # Cached tensors for performance - _tensor_stats: dict[str, Tensor] = field(default_factory=dict, init=False, repr=False) - - @classmethod - def from_lerobot_dataset( - cls, - dataset: LeRobotDataset, - *, - eps: float = 1e-8, - ) -> ActionUnnormalizer: - """Create from a LeRobotDataset.""" - if "action" not in dataset.meta.stats: - raise ValueError("Dataset does not contain action statistics") - return cls(action_stats=dataset.meta.stats["action"], eps=eps) - - def __post_init__(self): - # Convert action stats to tensors - tensor_stats = _convert_stats_to_tensors({"action": self.action_stats}) - self._tensor_stats = tensor_stats["action"] - - def __call__(self, transition: EnvTransition) -> EnvTransition: - action = transition[TransitionIndex.ACTION] - - if action is None: - return transition - - # Convert to tensor if needed - if isinstance(action, torch.Tensor): - action = action.to(dtype=torch.float32) - else: - action = torch.as_tensor(action, dtype=torch.float32) - - # Move stats to same device as action - stats = {k: v.to(device=action.device) for k, v in self._tensor_stats.items()} - - # Apply unnormalization - if "mean" in stats and "std" in stats: - mean, std = stats["mean"], stats["std"] - unnormalized_action = action * std + mean - elif "min" in stats and "max" in stats: - min_val, max_val = stats["min"], stats["max"] - # Map from [-1, 1] to [0, 1] then to [min, max] - unnormalized_action = (action + 1) / 2 * (max_val - min_val) + min_val - else: - raise ValueError("Action stats must contain either ('mean', 'std') or ('min', 'max')") - - # Return new transition with unnormalized action - return ( - transition[TransitionIndex.OBSERVATION], - unnormalized_action, - transition[TransitionIndex.REWARD], - transition[TransitionIndex.DONE], - transition[TransitionIndex.TRUNCATED], - transition[TransitionIndex.INFO], - transition[TransitionIndex.COMPLEMENTARY_DATA], - ) - - def get_config(self) -> dict[str, Any]: - return {"eps": self.eps} - - def state_dict(self) -> dict[str, Tensor]: - return dict(self._tensor_stats.items()) - - def load_state_dict(self, state: Mapping[str, Tensor]) -> None: - self._tensor_stats = dict(state) - - def reset(self) -> None: - """Nothing to reset for this stateless processor.""" - pass - - -@dataclass -@ProcessorStepRegistry.register(name="normalization_processor") -class NormalizationProcessor: - """Combined processor that normalizes observations and/or un-normalizes actions. - - This processor combines the functionality of ObservationNormalizer and - ActionUnnormalizer for convenience when both operations are needed. - - Parameters - ---------- - stats : Dict[str, Dict[str, np.ndarray | Tensor]] - Dataset statistics as returned by ``LeRobotDataset.meta.stats``. - normalize_keys : set[str] | None, default=None - Observation keys to normalize. ``None`` means all keys - present in both the observation and stats. - unnormalize_action : bool, default=True - Whether to un-normalize actions. - eps : float, default=1e-8 - Small constant to avoid division by zero. + Exactly mirrors :class:`NormalizerProcessor` but applies the inverse + transform. """ stats: dict[str, dict[str, Any]] - normalize_keys: set[str] | None = None - unnormalize_action: bool = True + unnormalize_keys: set[str] | None = None eps: float = 1e-8 - # Cached tensors for performance _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False) @classmethod @@ -286,75 +162,61 @@ class NormalizationProcessor: cls, dataset: LeRobotDataset, *, - normalize_keys: set[str] | None = None, - unnormalize_action: bool = True, + unnormalize_keys: set[str] | None = None, eps: float = 1e-8, - ) -> NormalizationProcessor: - """Create from a LeRobotDataset.""" - return cls( - stats=dataset.meta.stats, - normalize_keys=normalize_keys, - unnormalize_action=unnormalize_action, - eps=eps, - ) + ) -> UnnormalizerProcessor: + return cls(stats=dataset.meta.stats, unnormalize_keys=unnormalize_keys, eps=eps) def __post_init__(self): self._tensor_stats = _convert_stats_to_tensors(self.stats) - def __call__(self, transition: EnvTransition) -> EnvTransition: - observation = transition[TransitionIndex.OBSERVATION] - action = transition[TransitionIndex.ACTION] - - # Normalize observations - if observation is not None: - processed_obs = dict(observation) - keys_to_norm = ( - self.normalize_keys - if self.normalize_keys is not None - else {k for k in self._tensor_stats if k != "action"} + def _unnormalize_obs(self, observation): + if observation is None: + return None + keys = ( + self.unnormalize_keys + if self.unnormalize_keys is not None + else {k for k in self._tensor_stats if k != "action"} + ) + processed = dict(observation) + for key in keys: + if key not in processed or key not in self._tensor_stats: + continue + orig_val = processed[key] + tensor = ( + orig_val.to(dtype=torch.float32) + if isinstance(orig_val, torch.Tensor) + else torch.as_tensor(orig_val, dtype=torch.float32) ) - - for key in keys_to_norm: - if key not in processed_obs or key not in self._tensor_stats: - continue - - orig_val = processed_obs[key] - if isinstance(orig_val, torch.Tensor): - tensor = orig_val.to(dtype=torch.float32) - elif isinstance(orig_val, np.ndarray): - tensor = torch.from_numpy(orig_val.astype(np.float32)) - else: - tensor = torch.as_tensor(orig_val, dtype=torch.float32) - - stats = self._tensor_stats[key] - stats = {k: v.to(device=tensor.device) for k, v in stats.items()} - - if "mean" in stats and "std" in stats: - mean, std = stats["mean"], stats["std"] - processed_obs[key] = (tensor - mean) / (std + self.eps) - elif "min" in stats and "max" in stats: - min_val, max_val = stats["min"], stats["max"] - processed_obs[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1 - - observation = processed_obs - - # Un-normalize action - if self.unnormalize_action and action is not None and "action" in self._tensor_stats: - if isinstance(action, torch.Tensor): - action = action.to(dtype=torch.float32) - else: - action = torch.as_tensor(action, dtype=torch.float32) - - stats = {k: v.to(device=action.device) for k, v in self._tensor_stats["action"].items()} - + stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()} if "mean" in stats and "std" in stats: mean, std = stats["mean"], stats["std"] - action = action * std + mean + processed[key] = tensor * std + mean elif "min" in stats and "max" in stats: min_val, max_val = stats["min"], stats["max"] - action = (action + 1) / 2 * (max_val - min_val) + min_val + processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val + return processed - # Return new transition + def _unnormalize_action(self, action): + if action is None or "action" not in self._tensor_stats: + return action + tensor = ( + action.to(dtype=torch.float32) + if isinstance(action, torch.Tensor) + else torch.as_tensor(action, dtype=torch.float32) + ) + stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()} + if "mean" in stats and "std" in stats: + mean, std = stats["mean"], stats["std"] + return tensor * std + mean + if "min" in stats and "max" in stats: + min_val, max_val = stats["min"], stats["max"] + return (tensor + 1) / 2 * (max_val - min_val) + min_val + raise ValueError("Action stats must contain either ('mean','std') or ('min','max')") + + def __call__(self, transition: EnvTransition) -> EnvTransition: + observation = self._unnormalize_obs(transition[TransitionIndex.OBSERVATION]) + action = self._unnormalize_action(transition[TransitionIndex.ACTION]) return ( observation, action, @@ -367,26 +229,22 @@ class NormalizationProcessor: def get_config(self) -> dict[str, Any]: return { - "normalize_keys": list(self.normalize_keys) if self.normalize_keys is not None else None, - "unnormalize_action": self.unnormalize_action, + "unnormalize_keys": list(self.unnormalize_keys) if self.unnormalize_keys else None, "eps": self.eps, } def state_dict(self) -> dict[str, Tensor]: - flat_state: dict[str, Tensor] = {} + flat = {} for key, sub in self._tensor_stats.items(): for stat_name, tensor in sub.items(): - flat_state[f"{key}.{stat_name}"] = tensor - return flat_state + flat[f"{key}.{stat_name}"] = tensor + return flat def load_state_dict(self, state: Mapping[str, Tensor]) -> None: self._tensor_stats.clear() for flat_key, tensor in state.items(): key, stat_name = flat_key.rsplit(".", 1) - if key not in self._tensor_stats: - self._tensor_stats[key] = {} - self._tensor_stats[key][stat_name] = tensor + self._tensor_stats.setdefault(key, {})[stat_name] = tensor - def reset(self) -> None: - """Nothing to reset for this stateless processor.""" + def reset(self): pass diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index 8125a3520..e476ec27f 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -5,9 +5,8 @@ import pytest import torch from lerobot.processor.normalize_processor import ( - ActionUnnormalizer, - NormalizationProcessor, - ObservationNormalizer, + NormalizerProcessor, + UnnormalizerProcessor, _convert_stats_to_tensors, ) from lerobot.processor.pipeline import RobotProcessor, TransitionIndex @@ -77,7 +76,7 @@ def test_unsupported_type(): _convert_stats_to_tensors(stats) -# Fixtures for ObservationNormalizer tests +# Fixtures for observation normalisation tests using NormalizerProcessor @pytest.fixture def observation_stats(): return { @@ -94,7 +93,8 @@ def observation_stats(): @pytest.fixture def observation_normalizer(observation_stats): - return ObservationNormalizer(stats=observation_stats) + """Return a NormalizerProcessor that only has observation stats (no action).""" + return NormalizerProcessor(stats=observation_stats) def test_mean_std_normalization(observation_normalizer): @@ -129,7 +129,7 @@ def test_min_max_normalization(observation_normalizer): def test_selective_normalization(observation_stats): - normalizer = ObservationNormalizer(stats=observation_stats, normalize_keys={"observation.image"}) + normalizer = NormalizerProcessor(stats=observation_stats, normalize_keys={"observation.image"}) observation = { "observation.image": torch.tensor([0.7, 0.5, 0.3]), @@ -146,46 +146,9 @@ def test_selective_normalization(observation_stats): assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"]) -def test_missing_stats_error(observation_stats): - normalizer = ObservationNormalizer( - stats={"observation.image": observation_stats["observation.image"]}, - normalize_keys={"observation.image", "observation.missing"}, - ) - - observation = { - "observation.image": torch.tensor([0.5, 0.5, 0.5]), - "observation.missing": torch.tensor([1.0, 2.0]), - } - transition = (observation, None, None, None, None, None, None) - - with pytest.raises(KeyError, match="Stats not found for requested key 'observation.missing'"): - normalizer(transition) - - -@pytest.mark.parametrize( - "input_type,input_value,expected_type", - [ - ("numpy", np.array([0.7, 0.5, 0.3], dtype=np.float32), torch.Tensor), - ("torch", torch.tensor([0.7, 0.5, 0.3]), torch.Tensor), - ], -) -def test_input_types(observation_normalizer, input_type, input_value, expected_type): - observation = { - "observation.image": input_value, - } - transition = (observation, None, None, None, None, None, None) - - normalized_transition = observation_normalizer(transition) - normalized_obs = normalized_transition[TransitionIndex.OBSERVATION] - - expected = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2 - assert isinstance(normalized_obs["observation.image"], expected_type) - assert torch.allclose(normalized_obs["observation.image"], expected) - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_device_compatibility(observation_stats): - normalizer = ObservationNormalizer(stats=observation_stats) + normalizer = NormalizerProcessor(stats=observation_stats) observation = { "observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(), } @@ -205,11 +168,11 @@ def test_from_lerobot_dataset(): "action": {"mean": [0.0], "std": [1.0]}, # Should be filtered out } - normalizer = ObservationNormalizer.from_lerobot_dataset(mock_dataset) + normalizer = NormalizerProcessor.from_lerobot_dataset(mock_dataset) - # Check that action stats are filtered out + # Both observation and action statistics should be present in tensor stats assert "observation.image" in normalizer._tensor_stats - assert "action" not in normalizer._tensor_stats + assert "action" in normalizer._tensor_stats def test_state_dict_save_load(observation_normalizer): @@ -217,7 +180,7 @@ def test_state_dict_save_load(observation_normalizer): state_dict = observation_normalizer.state_dict() # Create new normalizer and load state - new_normalizer = ObservationNormalizer(stats={}) + new_normalizer = NormalizerProcessor(stats={}) new_normalizer.load_state_dict(state_dict) # Test that it works the same @@ -248,7 +211,7 @@ def action_stats_min_max(): def test_mean_std_unnormalization(action_stats_mean_std): - unnormalizer = ActionUnnormalizer(action_stats=action_stats_mean_std) + unnormalizer = UnnormalizerProcessor(stats={"action": action_stats_mean_std}) normalized_action = torch.tensor([1.0, -0.5, 2.0]) transition = (None, normalized_action, None, None, None, None, None) @@ -262,7 +225,7 @@ def test_mean_std_unnormalization(action_stats_mean_std): def test_min_max_unnormalization(action_stats_min_max): - unnormalizer = ActionUnnormalizer(action_stats=action_stats_min_max) + unnormalizer = UnnormalizerProcessor(stats={"action": action_stats_min_max}) # Actions in [-1, 1] normalized_action = torch.tensor([0.0, -1.0, 1.0]) @@ -284,7 +247,7 @@ def test_min_max_unnormalization(action_stats_min_max): def test_numpy_action_input(action_stats_mean_std): - unnormalizer = ActionUnnormalizer(action_stats=action_stats_mean_std) + unnormalizer = UnnormalizerProcessor(stats={"action": action_stats_mean_std}) normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32) transition = (None, normalized_action, None, None, None, None, None) @@ -298,7 +261,7 @@ def test_numpy_action_input(action_stats_mean_std): def test_none_action(action_stats_mean_std): - unnormalizer = ActionUnnormalizer(action_stats=action_stats_mean_std) + unnormalizer = UnnormalizerProcessor(stats={"action": action_stats_mean_std}) transition = (None, None, None, None, None, None, None) result = unnormalizer(transition) @@ -308,40 +271,13 @@ def test_none_action(action_stats_mean_std): def test_action_from_lerobot_dataset(): - # Mock dataset mock_dataset = Mock() - mock_dataset.meta.stats = { - "action": {"mean": [0.0], "std": [1.0]}, - "observation.image": {"mean": [0.5], "std": [0.2]}, - } - - unnormalizer = ActionUnnormalizer.from_lerobot_dataset(mock_dataset) - - assert "mean" in unnormalizer._tensor_stats - assert "std" in unnormalizer._tensor_stats + mock_dataset.meta.stats = {"action": {"mean": [0.0], "std": [1.0]}} + unnormalizer = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset) + assert "mean" in unnormalizer._tensor_stats["action"] -def test_missing_action_stats_error(): - mock_dataset = Mock() - mock_dataset.meta.stats = { - "observation.image": {"mean": [0.5], "std": [0.2]}, - } - - with pytest.raises(ValueError, match="Dataset does not contain action statistics"): - ActionUnnormalizer.from_lerobot_dataset(mock_dataset) - - -def test_invalid_stats_error(): - unnormalizer = ActionUnnormalizer(action_stats={"invalid": [1.0]}) - - action = torch.tensor([1.0]) - transition = (None, action, None, None, None, None, None) - - with pytest.raises(ValueError, match="Action stats must contain"): - unnormalizer(transition) - - -# Fixtures for NormalizationProcessor tests +# Fixtures for NormalizerProcessor tests @pytest.fixture def full_stats(): return { @@ -361,11 +297,11 @@ def full_stats(): @pytest.fixture -def normalization_processor(full_stats): - return NormalizationProcessor(stats=full_stats) +def normalizer_processor(full_stats): + return NormalizerProcessor(stats=full_stats) -def test_combined_normalization_unnormalization(normalization_processor): +def test_combined_normalization(normalizer_processor): observation = { "observation.image": torch.tensor([0.7, 0.5, 0.3]), "observation.state": torch.tensor([0.5, 0.0]), @@ -373,16 +309,16 @@ def test_combined_normalization_unnormalization(normalization_processor): action = torch.tensor([1.0, -0.5]) transition = (observation, action, 1.0, False, False, {}, {}) - processed_transition = normalization_processor(transition) + processed_transition = normalizer_processor(transition) # Check normalized observations processed_obs = processed_transition[TransitionIndex.OBSERVATION] expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2 assert torch.allclose(processed_obs["observation.image"], expected_image) - # Check unnormalized action + # Check normalized action processed_action = processed_transition[TransitionIndex.ACTION] - expected_action = torch.tensor([1.0 * 1.0 + 0.0, -0.5 * 2.0 + 0.0]) + expected_action = torch.tensor([(1.0 - 0.0) / 1.0, (-0.5 - 0.0) / 2.0]) assert torch.allclose(processed_action, expected_action) # Check other fields remain unchanged @@ -390,45 +326,28 @@ def test_combined_normalization_unnormalization(normalization_processor): assert not processed_transition[TransitionIndex.DONE] -def test_disable_action_unnormalization(full_stats): - processor = NormalizationProcessor(stats=full_stats, unnormalize_action=False) - - action = torch.tensor([1.0, -0.5]) - transition = (None, action, None, None, None, None, None) - - processed_transition = processor(transition) - - # Action should remain unchanged - assert torch.allclose(processed_transition[TransitionIndex.ACTION], action) - - def test_processor_from_lerobot_dataset(full_stats): # Mock dataset mock_dataset = Mock() mock_dataset.meta.stats = full_stats - processor = NormalizationProcessor.from_lerobot_dataset( - mock_dataset, normalize_keys={"observation.image"}, unnormalize_action=True - ) + processor = NormalizerProcessor.from_lerobot_dataset(mock_dataset, normalize_keys={"observation.image"}) assert processor.normalize_keys == {"observation.image"} - assert processor.unnormalize_action assert "observation.image" in processor._tensor_stats assert "action" in processor._tensor_stats def test_get_config(full_stats): - processor = NormalizationProcessor( - stats=full_stats, normalize_keys={"observation.image"}, unnormalize_action=False, eps=1e-6 - ) + processor = NormalizerProcessor(stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6) config = processor.get_config() - assert config == {"normalize_keys": ["observation.image"], "unnormalize_action": False, "eps": 1e-6} + assert config == {"normalize_keys": ["observation.image"], "eps": 1e-6} -def test_integration_with_robot_processor(normalization_processor): +def test_integration_with_robot_processor(normalizer_processor): """Test integration with RobotProcessor pipeline""" - robot_processor = RobotProcessor([normalization_processor]) + robot_processor = RobotProcessor([normalizer_processor]) observation = { "observation.image": torch.tensor([0.7, 0.5, 0.3]), @@ -447,7 +366,7 @@ def test_integration_with_robot_processor(normalization_processor): # Edge case tests def test_empty_observation(): stats = {"observation.image": {"mean": [0.5], "std": [0.2]}} - normalizer = ObservationNormalizer(stats=stats) + normalizer = NormalizerProcessor(stats=stats) transition = (None, None, None, None, None, None, None) result = normalizer(transition) @@ -456,7 +375,7 @@ def test_empty_observation(): def test_empty_stats(): - normalizer = ObservationNormalizer(stats={}) + normalizer = NormalizerProcessor(stats={}) observation = {"observation.image": torch.tensor([0.5])} transition = (observation, None, None, None, None, None, None) @@ -466,12 +385,20 @@ def test_empty_stats(): def test_partial_stats(): - stats = { - "observation.image": {"mean": [0.5]}, # Missing std - } - normalizer = ObservationNormalizer(stats=stats) + """If statistics are incomplete, the value should pass through unchanged.""" + stats = {"observation.image": {"mean": [0.5]}} # Missing std / (min,max) + normalizer = NormalizerProcessor(stats=stats) observation = {"observation.image": torch.tensor([0.7])} transition = (observation, None, None, None, None, None, None) - with pytest.raises(ValueError, match="must contain either"): - normalizer(transition) + processed = normalizer(transition)[TransitionIndex.OBSERVATION] + assert torch.allclose(processed["observation.image"], observation["observation.image"]) + + +def test_missing_action_stats_no_error(): + mock_dataset = Mock() + mock_dataset.meta.stats = {"observation.image": {"mean": [0.5], "std": [0.2]}} + + processor = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset) + # The tensor stats should not contain the 'action' key + assert "action" not in processor._tensor_stats