diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index 3f0267eae..c5c4af9fa 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -25,9 +25,11 @@ from .pipeline import ( ActionProcessor, DoneProcessor, EnvTransition, + IdentityProcessor, InfoProcessor, ObservationProcessor, ProcessorStep, + ProcessorStepRegistry, RewardProcessor, RobotProcessor, TruncatedProcessor, @@ -39,12 +41,14 @@ __all__ = [ "DeviceProcessor", "DoneProcessor", "EnvTransition", + "IdentityProcessor", "ImageProcessor", "InfoProcessor", "NormalizerProcessor", "UnnormalizerProcessor", "ObservationProcessor", "ProcessorStep", + "ProcessorStepRegistry", "RenameProcessor", "RewardProcessor", "RobotProcessor", diff --git a/src/lerobot/processor/device_processor.py b/src/lerobot/processor/device_processor.py index 0ff6ef9da..7dcd8abda 100644 --- a/src/lerobot/processor/device_processor.py +++ b/src/lerobot/processor/device_processor.py @@ -33,6 +33,9 @@ class DeviceProcessor: device: str = "cpu" + def __post_init__(self): + self.non_blocking = "cuda" in self.device + def __call__(self, transition: EnvTransition) -> EnvTransition: observation: dict[str, torch.Tensor] = transition[TransitionIndex.OBSERVATION] action = transition[TransitionIndex.ACTION] @@ -43,7 +46,9 @@ class DeviceProcessor: complementary_data = transition[TransitionIndex.COMPLEMENTARY_DATA] if observation is not None: - observation = {k: v.to(self.device) for k, v in observation.items()} + observation = { + k: v.to(self.device, non_blocking=self.non_blocking) for k, v in observation.items() + } if action is not None: action = action.to(self.device) diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index 6b3cae5be..8e836dd50 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -1,12 +1,13 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Mapping +from typing import Any, Mapping, Optional, Set import numpy as np import torch from torch import Tensor +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionIndex @@ -45,8 +46,18 @@ class NormalizerProcessor: the normalize_keys parameter. """ - stats: dict[str, dict[str, Any]] - normalize_keys: set[str] | None = None + # Features and normalisation map are mandatory to match the design of normalize.py + features: dict[str, PolicyFeature] + norm_map: dict[FeatureType, NormalizationMode] + + # Pre-computed statistics coming from dataset.meta.stats for instance. + stats: Optional[dict[str, dict[str, Any]]] = None + + # Explicit subset of keys to normalise. If ``None`` every key (except + # "action") found in ``stats`` will be normalised. Using a ``set`` makes + # membership checks O(1). + normalize_keys: Optional[Set[str]] = None + eps: float = 1e-8 _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False) @@ -55,24 +66,48 @@ class NormalizerProcessor: def from_lerobot_dataset( cls, dataset: LeRobotDataset, + features: dict[str, PolicyFeature], + norm_map: dict[FeatureType, NormalizationMode], *, - normalize_keys: set[str] | None = None, + normalize_keys: Optional[Set[str]] = None, eps: float = 1e-8, - ) -> NormalizerProcessor: - return cls(stats=dataset.meta.stats, normalize_keys=normalize_keys, eps=eps) + ) -> "NormalizerProcessor": + """Factory helper that pulls statistics from a :class:`LeRobotDataset`. + + The features and norm_map parameters are mandatory to match the design + pattern used in normalize.py. + """ + + return cls( + features=features, + norm_map=norm_map, + stats=dataset.meta.stats, + normalize_keys=normalize_keys, + eps=eps, + ) def __post_init__(self): + # Convert statistics once so we avoid repeated numpy→Tensor conversions + # during runtime. + self.stats = self.stats or {} self._tensor_stats = _convert_stats_to_tensors(self.stats) + # Ensure *normalize_keys* is a set for fast look-ups and compare by + # value later when returning the configuration. + if self.normalize_keys is not None and not isinstance(self.normalize_keys, set): + self.normalize_keys = set(self.normalize_keys) + def _normalize_obs(self, observation): if observation is None: return None - keys_to_norm = ( - self.normalize_keys - if self.normalize_keys is not None - else {k for k in self._tensor_stats if k != "action"} - ) + # Decide which keys should be normalised for this call. + if self.normalize_keys is not None: + keys_to_norm = self.normalize_keys + else: + # Use feature map to skip action keys. + keys_to_norm = {k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION} + processed = dict(observation) for key in keys_to_norm: if key not in processed or key not in self._tensor_stats: @@ -126,7 +161,11 @@ class NormalizerProcessor: ) def get_config(self) -> dict[str, Any]: - return {"normalize_keys": list(self.normalize_keys) if self.normalize_keys else None, "eps": self.eps} + config = {"eps": self.eps} + if self.normalize_keys is not None: + # Serialise as a list for YAML / JSON friendliness + config["normalize_keys"] = sorted(self.normalize_keys) + return config def state_dict(self) -> dict[str, Tensor]: flat = {} @@ -154,8 +193,9 @@ class UnnormalizerProcessor: transform. """ - stats: dict[str, dict[str, Any]] - unnormalize_keys: set[str] | None = None + features: dict[str, PolicyFeature] + norm_map: dict[FeatureType, NormalizationMode] + stats: Optional[dict[str, dict[str, Any]]] = None eps: float = 1e-8 _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False) @@ -164,23 +204,21 @@ class UnnormalizerProcessor: def from_lerobot_dataset( cls, dataset: LeRobotDataset, + features: dict[str, PolicyFeature], + norm_map: dict[FeatureType, NormalizationMode], *, - unnormalize_keys: set[str] | None = None, eps: float = 1e-8, - ) -> UnnormalizerProcessor: - return cls(stats=dataset.meta.stats, unnormalize_keys=unnormalize_keys, eps=eps) + ) -> "UnnormalizerProcessor": + return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, eps=eps) def __post_init__(self): + self.stats = self.stats or {} self._tensor_stats = _convert_stats_to_tensors(self.stats) def _unnormalize_obs(self, observation): if observation is None: return None - keys = ( - self.unnormalize_keys - if self.unnormalize_keys is not None - else {k for k in self._tensor_stats if k != "action"} - ) + keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION] processed = dict(observation) for key in keys: if key not in processed or key not in self._tensor_stats: @@ -231,10 +269,7 @@ class UnnormalizerProcessor: ) def get_config(self) -> dict[str, Any]: - return { - "unnormalize_keys": list(self.unnormalize_keys) if self.unnormalize_keys else None, - "eps": self.eps, - } + return {"eps": self.eps} def state_dict(self) -> dict[str, Tensor]: flat = {} diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index ac900fcff..7c88f6d3a 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -42,7 +42,7 @@ class TransitionIndex(IntEnum): # (observation, action, reward, done, truncated, info, complementary_data) EnvTransition = Tuple[ - Any | None, # observation + dict[str, Any] | None, # observation Any | None, # action float | None, # reward bool | None, # done @@ -162,6 +162,79 @@ class ProcessorStep(Protocol): def reset(self) -> None: ... +def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401 + """Convert a *batch* dict coming from Learobot replay/dataset code into an + ``EnvTransition`` tuple. + + The function is intentionally **strictly positional** – it maps well known + keys to the fixed slot order used inside the pipeline. Missing keys are + filled with sane defaults (``None`` or ``0.0``/``False``). + + Keys recognised (case-sensitive): + + * "observation.*" (keys starting with "observation." are grouped into observation dict) + * "action" + * "next.reward" + * "next.done" + * "next.truncated" + * "info" + + Additional keys are ignored so that existing dataloaders can carry extra + metadata without breaking the processor. + """ + + # Handle observation and observation.* keys + observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")} + + observation = None + if observation_keys: + observation = {} + # Add observation.* keys to the observation dict, removing the "observation." prefix + for key, value in observation_keys.items(): + observation[key] = value + + return ( + observation, + batch.get("action"), + batch.get("next.reward", 0.0), + batch.get("next.done", False), + batch.get("next.truncated", False), + batch.get("info", {}), + {}, + ) + + +def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]: # noqa: D401 + """Inverse of :pyfunc:`_default_batch_to_transition`. Returns a dict with + the canonical field names used throughout *LeRobot*. + """ + + ( + observation, + action, + reward, + done, + truncated, + info, + _, + ) = transition + + batch = { + "action": action, + "next.reward": reward, + "next.done": done, + "next.truncated": truncated, + "info": info, + } + + # Handle observation - flatten dict to observation.* keys if it's a dict + if isinstance(observation, dict): + # Check if this looks like a dict that was created from observation.* keys + for key, value in observation.items(): + batch[key] = value + return batch + + @dataclass class RobotProcessor(ModelHubMixin): """ @@ -200,6 +273,13 @@ class RobotProcessor(ModelHubMixin): name: str = "RobotProcessor" seed: int | None = None + to_transition: Callable[[dict[str, Any]], EnvTransition] = field( + default_factory=lambda: _default_batch_to_transition, repr=False + ) + to_batch: Callable[[EnvTransition], dict[str, Any]] = field( + default_factory=lambda: _default_transition_to_batch, repr=False + ) + # Processor-level hooks # A hook can optionally return a modified transition. If it returns # ``None`` the current value is left untouched. @@ -211,14 +291,29 @@ class RobotProcessor(ModelHubMixin): ) reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False) - def __call__(self, transition: EnvTransition) -> EnvTransition: - """Run *transition* through every step, firing hooks on the way.""" + def __call__(self, data: EnvTransition | dict[str, Any]): + """Process *data* through all steps. - # Basic validation with helpful error message + The method accepts **either** the classic :pydata:`EnvTransition` tuple + **or** a *batch* dictionary (like the ones returned by + :class:`lerobot.utils.buffer.ReplayBuffer` or + :class:`lerobot.datasets.lerobot_dataset.LeRobotDataset`). If a dict is + supplied it is first converted to the internal tuple format using + :pyattr:`to_transition`; after all steps are executed the tuple is + transformed back into a dict with :pyattr:`to_batch` and the result is + returned – thereby preserving the caller's original data type. + """ + + called_with_batch = isinstance(data, dict) + + transition = self.to_transition(data) if called_with_batch else data + + # Basic validation with helpful error message for tuple input if not isinstance(transition, tuple) or len(transition) != 7: raise ValueError( - f"EnvTransition must be a 7-tuple of (observation, action, reward, done, truncated, info, complementary_data), " - f"got {type(transition).__name__} with length {len(transition) if hasattr(transition, '__len__') else 'unknown'}" + "EnvTransition must be a 7-tuple of (observation, action, reward, done, " + "truncated, info, complementary_data). " + f"Got {type(transition).__name__} with length {len(transition) if hasattr(transition, '__len__') else 'unknown'}." ) for idx, processor_step in enumerate(self.steps): @@ -234,7 +329,7 @@ class RobotProcessor(ModelHubMixin): if updated is not None: transition = updated - return transition + return self.to_batch(transition) if called_with_batch else transition def step_through(self, transition: EnvTransition) -> Iterable[EnvTransition]: """Yield the intermediate Transition instances after each processor step.""" @@ -737,3 +832,22 @@ class InfoProcessor: *transition[TransitionIndex.COMPLEMENTARY_DATA :], ) return transition + + +class IdentityProcessor: + """Identity processor that does nothing.""" + + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition + + def get_config(self) -> dict[str, Any]: + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass diff --git a/tests/processor/test_batch_conversion.py b/tests/processor/test_batch_conversion.py new file mode 100644 index 000000000..5b399a675 --- /dev/null +++ b/tests/processor/test_batch_conversion.py @@ -0,0 +1,288 @@ +import torch + +from lerobot.processor.pipeline import ( + RobotProcessor, + TransitionIndex, + _default_batch_to_transition, + _default_transition_to_batch, +) + + +def _dummy_batch(): + """Create a dummy batch using the new format with observation.* and next.* keys.""" + return { + "observation.image.left": torch.randn(1, 3, 128, 128), + "observation.image.right": torch.randn(1, 3, 128, 128), + "observation.state": torch.tensor([[0.1, 0.2, 0.3, 0.4]]), + "action": torch.tensor([[0.5]]), + "next.reward": 1.0, + "next.done": False, + "next.truncated": False, + "info": {"key": "value"}, + } + + +def test_observation_grouping_roundtrip(): + """Test that observation.* keys are properly grouped and ungrouped.""" + proc = RobotProcessor([]) + batch_in = _dummy_batch() + batch_out = proc(batch_in) + + # Check that all observation.* keys are preserved + original_obs_keys = {k: v for k, v in batch_in.items() if k.startswith("observation.")} + reconstructed_obs_keys = {k: v for k, v in batch_out.items() if k.startswith("observation.")} + + assert set(original_obs_keys.keys()) == set(reconstructed_obs_keys.keys()) + + # Check tensor values + assert torch.allclose(batch_out["observation.image.left"], batch_in["observation.image.left"]) + assert torch.allclose(batch_out["observation.image.right"], batch_in["observation.image.right"]) + assert torch.allclose(batch_out["observation.state"], batch_in["observation.state"]) + + # Check other fields + assert torch.allclose(batch_out["action"], batch_in["action"]) + assert batch_out["next.reward"] == batch_in["next.reward"] + assert batch_out["next.done"] == batch_in["next.done"] + assert batch_out["next.truncated"] == batch_in["next.truncated"] + assert batch_out["info"] == batch_in["info"] + + +def test_batch_to_transition_observation_grouping(): + """Test that _default_batch_to_transition correctly groups observation.* keys.""" + batch = { + "observation.image.top": torch.randn(1, 3, 128, 128), + "observation.image.left": torch.randn(1, 3, 128, 128), + "observation.state": [1, 2, 3, 4], + "action": "action_data", + "next.reward": 1.5, + "next.done": True, + "next.truncated": False, + "info": {"episode": 42}, + } + + transition = _default_batch_to_transition(batch) + + # Check observation is a dict with all observation.* keys + assert isinstance(transition[TransitionIndex.OBSERVATION], dict) + assert "observation.image.top" in transition[TransitionIndex.OBSERVATION] + assert "observation.image.left" in transition[TransitionIndex.OBSERVATION] + assert "observation.state" in transition[TransitionIndex.OBSERVATION] + + # Check values are preserved + assert torch.allclose( + transition[TransitionIndex.OBSERVATION]["observation.image.top"], batch["observation.image.top"] + ) + assert torch.allclose( + transition[TransitionIndex.OBSERVATION]["observation.image.left"], batch["observation.image.left"] + ) + assert transition[TransitionIndex.OBSERVATION]["observation.state"] == [1, 2, 3, 4] + + # Check other fields + assert transition[TransitionIndex.ACTION] == "action_data" + assert transition[TransitionIndex.REWARD] == 1.5 + assert transition[TransitionIndex.DONE] + assert not transition[TransitionIndex.TRUNCATED] + assert transition[TransitionIndex.INFO] == {"episode": 42} + assert transition[TransitionIndex.COMPLEMENTARY_DATA] == {} + + +def test_transition_to_batch_observation_flattening(): + """Test that _default_transition_to_batch correctly flattens observation dict.""" + observation_dict = { + "observation.image.top": torch.randn(1, 3, 128, 128), + "observation.image.left": torch.randn(1, 3, 128, 128), + "observation.state": [1, 2, 3, 4], + } + + transition = ( + observation_dict, # observation + "action_data", # action + 1.5, # reward + True, # done + False, # truncated + {"episode": 42}, # info + {}, # complementary_data + ) + + batch = _default_transition_to_batch(transition) + + # Check that observation.* keys are flattened back to batch + assert "observation.image.top" in batch + assert "observation.image.left" in batch + assert "observation.state" in batch + + # Check values are preserved + assert torch.allclose(batch["observation.image.top"], observation_dict["observation.image.top"]) + assert torch.allclose(batch["observation.image.left"], observation_dict["observation.image.left"]) + assert batch["observation.state"] == [1, 2, 3, 4] + + # Check other fields are mapped to next.* format + assert batch["action"] == "action_data" + assert batch["next.reward"] == 1.5 + assert batch["next.done"] + assert not batch["next.truncated"] + assert batch["info"] == {"episode": 42} + + +def test_no_observation_keys(): + """Test behavior when there are no observation.* keys.""" + batch = { + "action": "action_data", + "next.reward": 2.0, + "next.done": False, + "next.truncated": True, + "info": {"test": "no_obs"}, + } + + transition = _default_batch_to_transition(batch) + + # Observation should be None when no observation.* keys + assert transition[TransitionIndex.OBSERVATION] is None + + # Check other fields + assert transition[TransitionIndex.ACTION] == "action_data" + assert transition[TransitionIndex.REWARD] == 2.0 + assert not transition[TransitionIndex.DONE] + assert transition[TransitionIndex.TRUNCATED] + assert transition[TransitionIndex.INFO] == {"test": "no_obs"} + + # Round trip should work + reconstructed_batch = _default_transition_to_batch(transition) + assert reconstructed_batch["action"] == "action_data" + assert reconstructed_batch["next.reward"] == 2.0 + assert not reconstructed_batch["next.done"] + assert reconstructed_batch["next.truncated"] + assert reconstructed_batch["info"] == {"test": "no_obs"} + + +def test_minimal_batch(): + """Test with minimal batch containing only observation.* and action.""" + batch = {"observation.state": "minimal_state", "action": "minimal_action"} + + transition = _default_batch_to_transition(batch) + + # Check observation + assert transition[TransitionIndex.OBSERVATION] == {"observation.state": "minimal_state"} + assert transition[TransitionIndex.ACTION] == "minimal_action" + + # Check defaults + assert transition[TransitionIndex.REWARD] == 0.0 + assert not transition[TransitionIndex.DONE] + assert not transition[TransitionIndex.TRUNCATED] + assert transition[TransitionIndex.INFO] == {} + assert transition[TransitionIndex.COMPLEMENTARY_DATA] == {} + + # Round trip + reconstructed_batch = _default_transition_to_batch(transition) + assert reconstructed_batch["observation.state"] == "minimal_state" + assert reconstructed_batch["action"] == "minimal_action" + assert reconstructed_batch["next.reward"] == 0.0 + assert not reconstructed_batch["next.done"] + assert not reconstructed_batch["next.truncated"] + assert reconstructed_batch["info"] == {} + + +def test_empty_batch(): + """Test behavior with empty batch.""" + batch = {} + + transition = _default_batch_to_transition(batch) + + # All fields should have defaults + assert transition[TransitionIndex.OBSERVATION] is None + assert transition[TransitionIndex.ACTION] is None + assert transition[TransitionIndex.REWARD] == 0.0 + assert not transition[TransitionIndex.DONE] + assert not transition[TransitionIndex.TRUNCATED] + assert transition[TransitionIndex.INFO] == {} + assert transition[TransitionIndex.COMPLEMENTARY_DATA] == {} + + # Round trip + reconstructed_batch = _default_transition_to_batch(transition) + assert reconstructed_batch["action"] is None + assert reconstructed_batch["next.reward"] == 0.0 + assert not reconstructed_batch["next.done"] + assert not reconstructed_batch["next.truncated"] + assert reconstructed_batch["info"] == {} + + +def test_complex_nested_observation(): + """Test with complex nested observation data.""" + batch = { + "observation.image.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890}, + "observation.image.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891}, + "observation.state": torch.randn(7), + "action": torch.randn(8), + "next.reward": 3.14, + "next.done": False, + "next.truncated": True, + "info": {"episode_length": 200, "success": True}, + } + + transition = _default_batch_to_transition(batch) + reconstructed_batch = _default_transition_to_batch(transition) + + # Check that all observation keys are preserved + original_obs_keys = {k for k in batch.keys() if k.startswith("observation.")} + reconstructed_obs_keys = {k for k in reconstructed_batch.keys() if k.startswith("observation.")} + + assert original_obs_keys == reconstructed_obs_keys + + # Check tensor values + assert torch.allclose(batch["observation.state"], reconstructed_batch["observation.state"]) + + # Check nested dict with tensors + assert torch.allclose( + batch["observation.image.top"]["image"], reconstructed_batch["observation.image.top"]["image"] + ) + assert torch.allclose( + batch["observation.image.left"]["image"], reconstructed_batch["observation.image.left"]["image"] + ) + + # Check action tensor + assert torch.allclose(batch["action"], reconstructed_batch["action"]) + + # Check other fields + assert batch["next.reward"] == reconstructed_batch["next.reward"] + assert batch["next.done"] == reconstructed_batch["next.done"] + assert batch["next.truncated"] == reconstructed_batch["next.truncated"] + assert batch["info"] == reconstructed_batch["info"] + + +def test_custom_converter(): + """Test that custom converters can still be used.""" + + def to_tr(batch): + # Custom converter that modifies the reward + tr = _default_batch_to_transition(batch) + # Double the reward + reward = tr[TransitionIndex.REWARD] * 2 if tr[TransitionIndex.REWARD] is not None else 0.0 + return ( + tr[TransitionIndex.OBSERVATION], + tr[TransitionIndex.ACTION], + reward, + tr[TransitionIndex.DONE], + tr[TransitionIndex.TRUNCATED], + tr[TransitionIndex.INFO], + tr[TransitionIndex.COMPLEMENTARY_DATA], + ) + + def to_batch(tr): + # Custom converter that adds a custom field + batch = _default_transition_to_batch(tr) + batch["custom_field"] = "custom_value" + return batch + + proc = RobotProcessor([], to_transition=to_tr, to_batch=to_batch) + batch = _dummy_batch() + out = proc(batch) + + # Check that custom modifications were applied + assert out["next.reward"] == batch["next.reward"] * 2 + assert out["custom_field"] == "custom_value" + + # Check that observation.* keys are still preserved + original_obs_keys = {k: v for k, v in batch.items() if k.startswith("observation.")} + output_obs_keys = {k: v for k, v in out.items() if k.startswith("observation.")} + + assert set(original_obs_keys.keys()) == set(output_obs_keys.keys()) diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index e476ec27f..0c48433e8 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -4,6 +4,7 @@ import numpy as np import pytest import torch +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.processor.normalize_processor import ( NormalizerProcessor, UnnormalizerProcessor, @@ -76,6 +77,21 @@ def test_unsupported_type(): _convert_stats_to_tensors(stats) +# Helper functions to create feature maps and norm maps +def _create_observation_features(): + return { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + } + + +def _create_observation_norm_map(): + return { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.STATE: NormalizationMode.MIN_MAX, + } + + # Fixtures for observation normalisation tests using NormalizerProcessor @pytest.fixture def observation_stats(): @@ -94,7 +110,9 @@ def observation_stats(): @pytest.fixture def observation_normalizer(observation_stats): """Return a NormalizerProcessor that only has observation stats (no action).""" - return NormalizerProcessor(stats=observation_stats) + features = _create_observation_features() + norm_map = _create_observation_norm_map() + return NormalizerProcessor(features=features, norm_map=norm_map, stats=observation_stats) def test_mean_std_normalization(observation_normalizer): @@ -129,7 +147,11 @@ def test_min_max_normalization(observation_normalizer): def test_selective_normalization(observation_stats): - normalizer = NormalizerProcessor(stats=observation_stats, normalize_keys={"observation.image"}) + features = _create_observation_features() + norm_map = _create_observation_norm_map() + normalizer = NormalizerProcessor( + features=features, norm_map=norm_map, stats=observation_stats, normalize_keys={"observation.image"} + ) observation = { "observation.image": torch.tensor([0.7, 0.5, 0.3]), @@ -148,7 +170,9 @@ def test_selective_normalization(observation_stats): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_device_compatibility(observation_stats): - normalizer = NormalizerProcessor(stats=observation_stats) + features = _create_observation_features() + norm_map = _create_observation_norm_map() + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=observation_stats) observation = { "observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(), } @@ -165,10 +189,19 @@ def test_from_lerobot_dataset(): mock_dataset = Mock() mock_dataset.meta.stats = { "observation.image": {"mean": [0.5], "std": [0.2]}, - "action": {"mean": [0.0], "std": [1.0]}, # Should be filtered out + "action": {"mean": [0.0], "std": [1.0]}, } - normalizer = NormalizerProcessor.from_lerobot_dataset(mock_dataset) + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + "action": PolicyFeature(FeatureType.ACTION, (1,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + + normalizer = NormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map) # Both observation and action statistics should be present in tensor stats assert "observation.image" in normalizer._tensor_stats @@ -180,7 +213,9 @@ def test_state_dict_save_load(observation_normalizer): state_dict = observation_normalizer.state_dict() # Create new normalizer and load state - new_normalizer = NormalizerProcessor(stats={}) + features = _create_observation_features() + norm_map = _create_observation_norm_map() + new_normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={}) new_normalizer.load_state_dict(state_dict) # Test that it works the same @@ -210,8 +245,30 @@ def action_stats_min_max(): } +def _create_action_features(): + return { + "action": PolicyFeature(FeatureType.ACTION, (3,)), + } + + +def _create_action_norm_map_mean_std(): + return { + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + + +def _create_action_norm_map_min_max(): + return { + FeatureType.ACTION: NormalizationMode.MIN_MAX, + } + + def test_mean_std_unnormalization(action_stats_mean_std): - unnormalizer = UnnormalizerProcessor(stats={"action": action_stats_mean_std}) + features = _create_action_features() + norm_map = _create_action_norm_map_mean_std() + unnormalizer = UnnormalizerProcessor( + features=features, norm_map=norm_map, stats={"action": action_stats_mean_std} + ) normalized_action = torch.tensor([1.0, -0.5, 2.0]) transition = (None, normalized_action, None, None, None, None, None) @@ -225,7 +282,11 @@ def test_mean_std_unnormalization(action_stats_mean_std): def test_min_max_unnormalization(action_stats_min_max): - unnormalizer = UnnormalizerProcessor(stats={"action": action_stats_min_max}) + features = _create_action_features() + norm_map = _create_action_norm_map_min_max() + unnormalizer = UnnormalizerProcessor( + features=features, norm_map=norm_map, stats={"action": action_stats_min_max} + ) # Actions in [-1, 1] normalized_action = torch.tensor([0.0, -1.0, 1.0]) @@ -247,7 +308,11 @@ def test_min_max_unnormalization(action_stats_min_max): def test_numpy_action_input(action_stats_mean_std): - unnormalizer = UnnormalizerProcessor(stats={"action": action_stats_mean_std}) + features = _create_action_features() + norm_map = _create_action_norm_map_mean_std() + unnormalizer = UnnormalizerProcessor( + features=features, norm_map=norm_map, stats={"action": action_stats_mean_std} + ) normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32) transition = (None, normalized_action, None, None, None, None, None) @@ -261,7 +326,11 @@ def test_numpy_action_input(action_stats_mean_std): def test_none_action(action_stats_mean_std): - unnormalizer = UnnormalizerProcessor(stats={"action": action_stats_mean_std}) + features = _create_action_features() + norm_map = _create_action_norm_map_mean_std() + unnormalizer = UnnormalizerProcessor( + features=features, norm_map=norm_map, stats={"action": action_stats_mean_std} + ) transition = (None, None, None, None, None, None, None) result = unnormalizer(transition) @@ -273,7 +342,9 @@ def test_none_action(action_stats_mean_std): def test_action_from_lerobot_dataset(): mock_dataset = Mock() mock_dataset.meta.stats = {"action": {"mean": [0.0], "std": [1.0]}} - unnormalizer = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset) + features = {"action": PolicyFeature(FeatureType.ACTION, (1,))} + norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD} + unnormalizer = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map) assert "mean" in unnormalizer._tensor_stats["action"] @@ -296,9 +367,27 @@ def full_stats(): } +def _create_full_features(): + return { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + + +def _create_full_norm_map(): + return { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.STATE: NormalizationMode.MIN_MAX, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + + @pytest.fixture def normalizer_processor(full_stats): - return NormalizerProcessor(stats=full_stats) + features = _create_full_features() + norm_map = _create_full_norm_map() + return NormalizerProcessor(features=features, norm_map=norm_map, stats=full_stats) def test_combined_normalization(normalizer_processor): @@ -331,7 +420,12 @@ def test_processor_from_lerobot_dataset(full_stats): mock_dataset = Mock() mock_dataset.meta.stats = full_stats - processor = NormalizerProcessor.from_lerobot_dataset(mock_dataset, normalize_keys={"observation.image"}) + features = _create_full_features() + norm_map = _create_full_norm_map() + + processor = NormalizerProcessor.from_lerobot_dataset( + mock_dataset, features, norm_map, normalize_keys={"observation.image"} + ) assert processor.normalize_keys == {"observation.image"} assert "observation.image" in processor._tensor_stats @@ -339,7 +433,11 @@ def test_processor_from_lerobot_dataset(full_stats): def test_get_config(full_stats): - processor = NormalizerProcessor(stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6) + features = _create_full_features() + norm_map = _create_full_norm_map() + processor = NormalizerProcessor( + features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6 + ) config = processor.get_config() assert config == {"normalize_keys": ["observation.image"], "eps": 1e-6} @@ -366,7 +464,9 @@ def test_integration_with_robot_processor(normalizer_processor): # Edge case tests def test_empty_observation(): stats = {"observation.image": {"mean": [0.5], "std": [0.2]}} - normalizer = NormalizerProcessor(stats=stats) + features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} + norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) transition = (None, None, None, None, None, None, None) result = normalizer(transition) @@ -375,19 +475,23 @@ def test_empty_observation(): def test_empty_stats(): - normalizer = NormalizerProcessor(stats={}) + features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} + norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={}) observation = {"observation.image": torch.tensor([0.5])} transition = (observation, None, None, None, None, None, None) result = normalizer(transition) - # Should return observation unchanged + # Should return observation unchanged since no stats are available assert torch.allclose(result[0]["observation.image"], observation["observation.image"]) def test_partial_stats(): """If statistics are incomplete, the value should pass through unchanged.""" stats = {"observation.image": {"mean": [0.5]}} # Missing std / (min,max) - normalizer = NormalizerProcessor(stats=stats) + features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} + norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) observation = {"observation.image": torch.tensor([0.7])} transition = (observation, None, None, None, None, None, None) @@ -399,6 +503,9 @@ def test_missing_action_stats_no_error(): mock_dataset = Mock() mock_dataset.meta.stats = {"observation.image": {"mean": [0.5], "std": [0.2]}} - processor = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset) + features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} + norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} + + processor = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map) # The tensor stats should not contain the 'action' key assert "action" not in processor._tensor_stats