From 453e0a995f982917028ffeb7b20ce4ab7b7ae6dc Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Fri, 4 Jul 2025 10:53:40 +0200 Subject: [PATCH] Enhance processing architecture with new components - Added `RenameProcessor` to facilitate key renaming in observations, improving data handling flexibility. - Updated `__init__.py` to include `RenameProcessor` in module exports. - Refactored `NormalizationProcessor` and `ObservationNormalizer` to use `rsplit` for better key handling. - Introduced comprehensive tests for `NormalizationProcessor` and `RenameProcessor` to ensure functionality and robustness. --- src/lerobot/processor/__init__.py | 2 + src/lerobot/processor/normalize_processor.py | 4 +- src/lerobot/processor/pipeline.py | 92 ++++ tests/processor/test_normalize_processor.py | 477 +++++++++++++++++++ tests/processor/test_rename_processor.py | 393 +++++++++++++++ 5 files changed, 966 insertions(+), 2 deletions(-) create mode 100644 tests/processor/test_normalize_processor.py create mode 100644 tests/processor/test_rename_processor.py diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index 1f62a81a9..bcf49c905 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -20,6 +20,7 @@ from .observation_processor import ( StateProcessor, ) from .pipeline import EnvTransition, ProcessorStep, RobotProcessor +from .rename_processor import RenameProcessor __all__ = [ "RobotProcessor", @@ -29,4 +30,5 @@ __all__ = [ "StateProcessor", "ObservationProcessor", "NormalizationProcessor", + "RenameProcessor", ] diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index 77b8d4236..08a334695 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -151,7 +151,7 @@ class ObservationNormalizer: def load_state_dict(self, state: Mapping[str, Tensor]) -> None: self._tensor_stats.clear() for flat_key, tensor in state.items(): - key, stat_name = flat_key.split(".", 1) + key, stat_name = flat_key.rsplit(".", 1) if key not in self._tensor_stats: self._tensor_stats[key] = {} self._tensor_stats[key][stat_name] = tensor @@ -382,7 +382,7 @@ class NormalizationProcessor: def load_state_dict(self, state: Mapping[str, Tensor]) -> None: self._tensor_stats.clear() for flat_key, tensor in state.items(): - key, stat_name = flat_key.split(".", 1) + key, stat_name = flat_key.rsplit(".", 1) if key not in self._tensor_stats: self._tensor_stats[key] = {} self._tensor_stats[key][stat_name] = tensor diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 28542a112..adbfaba19 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -464,3 +464,95 @@ class RobotProcessor(ModelHubMixin): profile_results[step_name] = avg_time return profile_results + + +class ObservationProcessor: + def observation(self, observation): + return observation + + def __call__(self, transition: EnvTransition) -> EnvTransition: + observation = transition[TransitionIndex.OBSERVATION] + observation = self.observation(observation) + transition = (observation, *transition[TransitionIndex.ACTION :]) + return transition + + +class ActionProcessor: + def action(self, action): + return action + + def __call__(self, transition: EnvTransition) -> EnvTransition: + action = transition[TransitionIndex.ACTION] + action = self.action(action) + transition = (transition[TransitionIndex.OBSERVATION], action, *transition[TransitionIndex.REWARD :]) + return transition + + +class RewardProcessor: + def reward(self, reward): + return reward + + def __call__(self, transition: EnvTransition) -> EnvTransition: + reward = transition[TransitionIndex.REWARD] + reward = self.reward(reward) + transition = ( + transition[TransitionIndex.OBSERVATION], + transition[TransitionIndex.ACTION], + reward, + *transition[TransitionIndex.DONE :], + ) + return transition + + +class DoneProcessor: + def done(self, done): + return done + + def __call__(self, transition: EnvTransition) -> EnvTransition: + done = transition[TransitionIndex.DONE] + done = self.done(done) + transition = ( + transition[TransitionIndex.OBSERVATION], + transition[TransitionIndex.ACTION], + transition[TransitionIndex.REWARD], + done, + *transition[TransitionIndex.TRUNCATED :], + ) + return transition + + +class TruncatedProcessor: + def truncated(self, truncated): + return truncated + + def __call__(self, transition: EnvTransition) -> EnvTransition: + truncated = transition[TransitionIndex.TRUNCATED] + truncated = self.truncated(truncated) + transition = ( + transition[TransitionIndex.OBSERVATION], + transition[TransitionIndex.ACTION], + transition[TransitionIndex.REWARD], + transition[TransitionIndex.DONE], + truncated, + *transition[TransitionIndex.INFO :], + ) + return transition + + +class InfoProcessor: + def info(self, info): + return info + + def __call__(self, transition: EnvTransition) -> EnvTransition: + info = transition[TransitionIndex.INFO] + info = self.info(info) + transition = ( + transition[TransitionIndex.OBSERVATION], + transition[TransitionIndex.ACTION], + transition[TransitionIndex.REWARD], + transition[TransitionIndex.DONE], + transition[TransitionIndex.TRUNCATED], + info, + *transition[TransitionIndex.COMPLEMENTARY_DATA :], + ) + return transition diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py new file mode 100644 index 000000000..8125a3520 --- /dev/null +++ b/tests/processor/test_normalize_processor.py @@ -0,0 +1,477 @@ +from unittest.mock import Mock + +import numpy as np +import pytest +import torch + +from lerobot.processor.normalize_processor import ( + ActionUnnormalizer, + NormalizationProcessor, + ObservationNormalizer, + _convert_stats_to_tensors, +) +from lerobot.processor.pipeline import RobotProcessor, TransitionIndex + + +def test_numpy_conversion(): + stats = { + "observation.image": { + "mean": np.array([0.5, 0.5, 0.5]), + "std": np.array([0.2, 0.2, 0.2]), + } + } + tensor_stats = _convert_stats_to_tensors(stats) + + assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor) + assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor) + assert torch.allclose(tensor_stats["observation.image"]["mean"], torch.tensor([0.5, 0.5, 0.5])) + assert torch.allclose(tensor_stats["observation.image"]["std"], torch.tensor([0.2, 0.2, 0.2])) + + +def test_tensor_conversion(): + stats = { + "action": { + "mean": torch.tensor([0.0, 0.0]), + "std": torch.tensor([1.0, 1.0]), + } + } + tensor_stats = _convert_stats_to_tensors(stats) + + assert tensor_stats["action"]["mean"].dtype == torch.float32 + assert tensor_stats["action"]["std"].dtype == torch.float32 + + +def test_scalar_conversion(): + stats = { + "reward": { + "mean": 0.5, + "std": 0.1, + } + } + tensor_stats = _convert_stats_to_tensors(stats) + + assert torch.allclose(tensor_stats["reward"]["mean"], torch.tensor(0.5)) + assert torch.allclose(tensor_stats["reward"]["std"], torch.tensor(0.1)) + + +def test_list_conversion(): + stats = { + "observation.state": { + "min": [0.0, -1.0, -2.0], + "max": [1.0, 1.0, 2.0], + } + } + tensor_stats = _convert_stats_to_tensors(stats) + + assert torch.allclose(tensor_stats["observation.state"]["min"], torch.tensor([0.0, -1.0, -2.0])) + assert torch.allclose(tensor_stats["observation.state"]["max"], torch.tensor([1.0, 1.0, 2.0])) + + +def test_unsupported_type(): + stats = { + "bad_key": { + "mean": "string_value", + } + } + with pytest.raises(TypeError, match="Unsupported type"): + _convert_stats_to_tensors(stats) + + +# Fixtures for ObservationNormalizer tests +@pytest.fixture +def observation_stats(): + return { + "observation.image": { + "mean": np.array([0.5, 0.5, 0.5]), + "std": np.array([0.2, 0.2, 0.2]), + }, + "observation.state": { + "min": np.array([0.0, -1.0]), + "max": np.array([1.0, 1.0]), + }, + } + + +@pytest.fixture +def observation_normalizer(observation_stats): + return ObservationNormalizer(stats=observation_stats) + + +def test_mean_std_normalization(observation_normalizer): + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([0.5, 0.0]), + } + transition = (observation, None, None, None, None, None, None) + + normalized_transition = observation_normalizer(transition) + normalized_obs = normalized_transition[TransitionIndex.OBSERVATION] + + # Check mean/std normalization + expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2 + assert torch.allclose(normalized_obs["observation.image"], expected_image) + + +def test_min_max_normalization(observation_normalizer): + observation = { + "observation.state": torch.tensor([0.5, 0.0]), + } + transition = (observation, None, None, None, None, None, None) + + normalized_transition = observation_normalizer(transition) + normalized_obs = normalized_transition[TransitionIndex.OBSERVATION] + + # Check min/max normalization to [-1, 1] + # For state[0]: 2 * (0.5 - 0.0) / (1.0 - 0.0) - 1 = 0.0 + # For state[1]: 2 * (0.0 - (-1.0)) / (1.0 - (-1.0)) - 1 = 0.0 + expected_state = torch.tensor([0.0, 0.0]) + assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6) + + +def test_selective_normalization(observation_stats): + normalizer = ObservationNormalizer(stats=observation_stats, normalize_keys={"observation.image"}) + + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([0.5, 0.0]), + } + transition = (observation, None, None, None, None, None, None) + + normalized_transition = normalizer(transition) + normalized_obs = normalized_transition[TransitionIndex.OBSERVATION] + + # Only image should be normalized + assert torch.allclose(normalized_obs["observation.image"], (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2) + # State should remain unchanged + assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"]) + + +def test_missing_stats_error(observation_stats): + normalizer = ObservationNormalizer( + stats={"observation.image": observation_stats["observation.image"]}, + normalize_keys={"observation.image", "observation.missing"}, + ) + + observation = { + "observation.image": torch.tensor([0.5, 0.5, 0.5]), + "observation.missing": torch.tensor([1.0, 2.0]), + } + transition = (observation, None, None, None, None, None, None) + + with pytest.raises(KeyError, match="Stats not found for requested key 'observation.missing'"): + normalizer(transition) + + +@pytest.mark.parametrize( + "input_type,input_value,expected_type", + [ + ("numpy", np.array([0.7, 0.5, 0.3], dtype=np.float32), torch.Tensor), + ("torch", torch.tensor([0.7, 0.5, 0.3]), torch.Tensor), + ], +) +def test_input_types(observation_normalizer, input_type, input_value, expected_type): + observation = { + "observation.image": input_value, + } + transition = (observation, None, None, None, None, None, None) + + normalized_transition = observation_normalizer(transition) + normalized_obs = normalized_transition[TransitionIndex.OBSERVATION] + + expected = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2 + assert isinstance(normalized_obs["observation.image"], expected_type) + assert torch.allclose(normalized_obs["observation.image"], expected) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_device_compatibility(observation_stats): + normalizer = ObservationNormalizer(stats=observation_stats) + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(), + } + transition = (observation, None, None, None, None, None, None) + + normalized_transition = normalizer(transition) + normalized_obs = normalized_transition[TransitionIndex.OBSERVATION] + + assert normalized_obs["observation.image"].device.type == "cuda" + + +def test_from_lerobot_dataset(): + # Mock dataset + mock_dataset = Mock() + mock_dataset.meta.stats = { + "observation.image": {"mean": [0.5], "std": [0.2]}, + "action": {"mean": [0.0], "std": [1.0]}, # Should be filtered out + } + + normalizer = ObservationNormalizer.from_lerobot_dataset(mock_dataset) + + # Check that action stats are filtered out + assert "observation.image" in normalizer._tensor_stats + assert "action" not in normalizer._tensor_stats + + +def test_state_dict_save_load(observation_normalizer): + # Save state + state_dict = observation_normalizer.state_dict() + + # Create new normalizer and load state + new_normalizer = ObservationNormalizer(stats={}) + new_normalizer.load_state_dict(state_dict) + + # Test that it works the same + observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])} + transition = (observation, None, None, None, None, None, None) + + result1 = observation_normalizer(transition)[0] + result2 = new_normalizer(transition)[0] + + assert torch.allclose(result1["observation.image"], result2["observation.image"]) + + +# Fixtures for ActionUnnormalizer tests +@pytest.fixture +def action_stats_mean_std(): + return { + "mean": np.array([0.0, 0.0, 0.0]), + "std": np.array([1.0, 2.0, 0.5]), + } + + +@pytest.fixture +def action_stats_min_max(): + return { + "min": np.array([-1.0, -2.0, 0.0]), + "max": np.array([1.0, 2.0, 1.0]), + } + + +def test_mean_std_unnormalization(action_stats_mean_std): + unnormalizer = ActionUnnormalizer(action_stats=action_stats_mean_std) + + normalized_action = torch.tensor([1.0, -0.5, 2.0]) + transition = (None, normalized_action, None, None, None, None, None) + + unnormalized_transition = unnormalizer(transition) + unnormalized_action = unnormalized_transition[TransitionIndex.ACTION] + + # action * std + mean + expected = torch.tensor([1.0 * 1.0 + 0.0, -0.5 * 2.0 + 0.0, 2.0 * 0.5 + 0.0]) + assert torch.allclose(unnormalized_action, expected) + + +def test_min_max_unnormalization(action_stats_min_max): + unnormalizer = ActionUnnormalizer(action_stats=action_stats_min_max) + + # Actions in [-1, 1] + normalized_action = torch.tensor([0.0, -1.0, 1.0]) + transition = (None, normalized_action, None, None, None, None, None) + + unnormalized_transition = unnormalizer(transition) + unnormalized_action = unnormalized_transition[TransitionIndex.ACTION] + + # Map from [-1, 1] to [min, max] + # (action + 1) / 2 * (max - min) + min + expected = torch.tensor( + [ + (0.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0), # 0.0 + (-1.0 + 1) / 2 * (2.0 - (-2.0)) + (-2.0), # -2.0 + (1.0 + 1) / 2 * (1.0 - 0.0) + 0.0, # 1.0 + ] + ) + assert torch.allclose(unnormalized_action, expected) + + +def test_numpy_action_input(action_stats_mean_std): + unnormalizer = ActionUnnormalizer(action_stats=action_stats_mean_std) + + normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32) + transition = (None, normalized_action, None, None, None, None, None) + + unnormalized_transition = unnormalizer(transition) + unnormalized_action = unnormalized_transition[TransitionIndex.ACTION] + + assert isinstance(unnormalized_action, torch.Tensor) + expected = torch.tensor([1.0, -1.0, 1.0]) + assert torch.allclose(unnormalized_action, expected) + + +def test_none_action(action_stats_mean_std): + unnormalizer = ActionUnnormalizer(action_stats=action_stats_mean_std) + + transition = (None, None, None, None, None, None, None) + result = unnormalizer(transition) + + # Should return transition unchanged + assert result == transition + + +def test_action_from_lerobot_dataset(): + # Mock dataset + mock_dataset = Mock() + mock_dataset.meta.stats = { + "action": {"mean": [0.0], "std": [1.0]}, + "observation.image": {"mean": [0.5], "std": [0.2]}, + } + + unnormalizer = ActionUnnormalizer.from_lerobot_dataset(mock_dataset) + + assert "mean" in unnormalizer._tensor_stats + assert "std" in unnormalizer._tensor_stats + + +def test_missing_action_stats_error(): + mock_dataset = Mock() + mock_dataset.meta.stats = { + "observation.image": {"mean": [0.5], "std": [0.2]}, + } + + with pytest.raises(ValueError, match="Dataset does not contain action statistics"): + ActionUnnormalizer.from_lerobot_dataset(mock_dataset) + + +def test_invalid_stats_error(): + unnormalizer = ActionUnnormalizer(action_stats={"invalid": [1.0]}) + + action = torch.tensor([1.0]) + transition = (None, action, None, None, None, None, None) + + with pytest.raises(ValueError, match="Action stats must contain"): + unnormalizer(transition) + + +# Fixtures for NormalizationProcessor tests +@pytest.fixture +def full_stats(): + return { + "observation.image": { + "mean": np.array([0.5, 0.5, 0.5]), + "std": np.array([0.2, 0.2, 0.2]), + }, + "observation.state": { + "min": np.array([0.0, -1.0]), + "max": np.array([1.0, 1.0]), + }, + "action": { + "mean": np.array([0.0, 0.0]), + "std": np.array([1.0, 2.0]), + }, + } + + +@pytest.fixture +def normalization_processor(full_stats): + return NormalizationProcessor(stats=full_stats) + + +def test_combined_normalization_unnormalization(normalization_processor): + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([0.5, 0.0]), + } + action = torch.tensor([1.0, -0.5]) + transition = (observation, action, 1.0, False, False, {}, {}) + + processed_transition = normalization_processor(transition) + + # Check normalized observations + processed_obs = processed_transition[TransitionIndex.OBSERVATION] + expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2 + assert torch.allclose(processed_obs["observation.image"], expected_image) + + # Check unnormalized action + processed_action = processed_transition[TransitionIndex.ACTION] + expected_action = torch.tensor([1.0 * 1.0 + 0.0, -0.5 * 2.0 + 0.0]) + assert torch.allclose(processed_action, expected_action) + + # Check other fields remain unchanged + assert processed_transition[TransitionIndex.REWARD] == 1.0 + assert not processed_transition[TransitionIndex.DONE] + + +def test_disable_action_unnormalization(full_stats): + processor = NormalizationProcessor(stats=full_stats, unnormalize_action=False) + + action = torch.tensor([1.0, -0.5]) + transition = (None, action, None, None, None, None, None) + + processed_transition = processor(transition) + + # Action should remain unchanged + assert torch.allclose(processed_transition[TransitionIndex.ACTION], action) + + +def test_processor_from_lerobot_dataset(full_stats): + # Mock dataset + mock_dataset = Mock() + mock_dataset.meta.stats = full_stats + + processor = NormalizationProcessor.from_lerobot_dataset( + mock_dataset, normalize_keys={"observation.image"}, unnormalize_action=True + ) + + assert processor.normalize_keys == {"observation.image"} + assert processor.unnormalize_action + assert "observation.image" in processor._tensor_stats + assert "action" in processor._tensor_stats + + +def test_get_config(full_stats): + processor = NormalizationProcessor( + stats=full_stats, normalize_keys={"observation.image"}, unnormalize_action=False, eps=1e-6 + ) + + config = processor.get_config() + assert config == {"normalize_keys": ["observation.image"], "unnormalize_action": False, "eps": 1e-6} + + +def test_integration_with_robot_processor(normalization_processor): + """Test integration with RobotProcessor pipeline""" + robot_processor = RobotProcessor([normalization_processor]) + + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([0.5, 0.0]), + } + action = torch.tensor([1.0, -0.5]) + transition = (observation, action, 1.0, False, False, {}, {}) + + processed_transition = robot_processor(transition) + + # Verify the processing worked + assert isinstance(processed_transition[TransitionIndex.OBSERVATION], dict) + assert isinstance(processed_transition[TransitionIndex.ACTION], torch.Tensor) + + +# Edge case tests +def test_empty_observation(): + stats = {"observation.image": {"mean": [0.5], "std": [0.2]}} + normalizer = ObservationNormalizer(stats=stats) + + transition = (None, None, None, None, None, None, None) + result = normalizer(transition) + + assert result == transition + + +def test_empty_stats(): + normalizer = ObservationNormalizer(stats={}) + observation = {"observation.image": torch.tensor([0.5])} + transition = (observation, None, None, None, None, None, None) + + result = normalizer(transition) + # Should return observation unchanged + assert torch.allclose(result[0]["observation.image"], observation["observation.image"]) + + +def test_partial_stats(): + stats = { + "observation.image": {"mean": [0.5]}, # Missing std + } + normalizer = ObservationNormalizer(stats=stats) + observation = {"observation.image": torch.tensor([0.7])} + transition = (observation, None, None, None, None, None, None) + + with pytest.raises(ValueError, match="must contain either"): + normalizer(transition) diff --git a/tests/processor/test_rename_processor.py b/tests/processor/test_rename_processor.py new file mode 100644 index 000000000..1b7b28425 --- /dev/null +++ b/tests/processor/test_rename_processor.py @@ -0,0 +1,393 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +from pathlib import Path + +import numpy as np +import torch + +from lerobot.processor.pipeline import ProcessorStepRegistry, RobotProcessor, TransitionIndex +from lerobot.processor.rename_processor import RenameProcessor + + +def test_basic_renaming(): + """Test basic key renaming functionality.""" + rename_map = { + "old_key1": "new_key1", + "old_key2": "new_key2", + } + processor = RenameProcessor(rename_map=rename_map) + + observation = { + "old_key1": torch.tensor([1.0, 2.0]), + "old_key2": np.array([3.0, 4.0]), + "unchanged_key": "keep_me", + } + transition = (observation, None, None, None, None, None, None) + + result = processor(transition) + processed_obs = result[TransitionIndex.OBSERVATION] + + # Check renamed keys + assert "new_key1" in processed_obs + assert "new_key2" in processed_obs + assert "old_key1" not in processed_obs + assert "old_key2" not in processed_obs + + # Check values are preserved + torch.testing.assert_close(processed_obs["new_key1"], torch.tensor([1.0, 2.0])) + np.testing.assert_array_equal(processed_obs["new_key2"], np.array([3.0, 4.0])) + + # Check unchanged key is preserved + assert processed_obs["unchanged_key"] == "keep_me" + + +def test_empty_rename_map(): + """Test processor with empty rename map (should pass through unchanged).""" + processor = RenameProcessor(rename_map={}) + + observation = { + "key1": torch.tensor([1.0]), + "key2": "value2", + } + transition = (observation, None, None, None, None, None, None) + + result = processor(transition) + processed_obs = result[TransitionIndex.OBSERVATION] + + # All keys should be unchanged + assert processed_obs.keys() == observation.keys() + torch.testing.assert_close(processed_obs["key1"], observation["key1"]) + assert processed_obs["key2"] == observation["key2"] + + +def test_none_observation(): + """Test processor with None observation.""" + processor = RenameProcessor(rename_map={"old": "new"}) + + transition = (None, None, None, None, None, None, None) + result = processor(transition) + + # Should return transition unchanged + assert result == transition + + +def test_overlapping_rename(): + """Test renaming when new names might conflict.""" + rename_map = { + "a": "b", + "b": "c", # This creates a potential conflict + } + processor = RenameProcessor(rename_map=rename_map) + + observation = { + "a": 1, + "b": 2, + "x": 3, + } + transition = (observation, None, None, None, None, None, None) + + result = processor(transition) + processed_obs = result[TransitionIndex.OBSERVATION] + + # Check that renaming happens correctly + assert "a" not in processed_obs + assert processed_obs["b"] == 1 # 'a' renamed to 'b' + assert processed_obs["c"] == 2 # original 'b' renamed to 'c' + assert processed_obs["x"] == 3 + + +def test_partial_rename(): + """Test renaming only some keys.""" + rename_map = { + "observation.state": "observation.proprio_state", + "pixels": "observation.image", + } + processor = RenameProcessor(rename_map=rename_map) + + observation = { + "observation.state": torch.randn(10), + "pixels": np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8), + "reward": 1.0, + "info": {"episode": 1}, + } + transition = (observation, None, None, None, None, None, None) + + result = processor(transition) + processed_obs = result[TransitionIndex.OBSERVATION] + + # Check renamed keys + assert "observation.proprio_state" in processed_obs + assert "observation.image" in processed_obs + assert "observation.state" not in processed_obs + assert "pixels" not in processed_obs + + # Check unchanged keys + assert processed_obs["reward"] == 1.0 + assert processed_obs["info"] == {"episode": 1} + + +def test_get_config(): + """Test configuration serialization.""" + rename_map = { + "old1": "new1", + "old2": "new2", + } + processor = RenameProcessor(rename_map=rename_map) + + config = processor.get_config() + assert config == {"rename_map": rename_map} + + +def test_state_dict(): + """Test state dict (should be empty for RenameProcessor).""" + processor = RenameProcessor(rename_map={"old": "new"}) + + state = processor.state_dict() + assert state == {} + + # Load state dict should work even with empty dict + processor.load_state_dict({}) + + +def test_integration_with_robot_processor(): + """Test integration with RobotProcessor pipeline.""" + rename_map = { + "agent_pos": "observation.state", + "pixels": "observation.image", + } + rename_processor = RenameProcessor(rename_map=rename_map) + + pipeline = RobotProcessor([rename_processor]) + + observation = { + "agent_pos": np.array([1.0, 2.0, 3.0]), + "pixels": np.zeros((32, 32, 3), dtype=np.uint8), + "other_data": "preserve_me", + } + transition = (observation, None, 0.5, False, False, {}, {}) + + result = pipeline(transition) + processed_obs = result[TransitionIndex.OBSERVATION] + + # Check renaming worked through pipeline + assert "observation.state" in processed_obs + assert "observation.image" in processed_obs + assert "agent_pos" not in processed_obs + assert "pixels" not in processed_obs + assert processed_obs["other_data"] == "preserve_me" + + # Check other transition elements unchanged + assert result[TransitionIndex.REWARD] == 0.5 + assert result[TransitionIndex.DONE] is False + + +def test_save_and_load_pretrained(): + """Test saving and loading processor with RobotProcessor.""" + rename_map = { + "old_state": "observation.state", + "old_image": "observation.image", + } + processor = RenameProcessor(rename_map=rename_map) + pipeline = RobotProcessor([processor], name="TestRenameProcessor") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Save pipeline + pipeline.save_pretrained(tmp_dir) + + # Check files were created + config_path = Path(tmp_dir) / "processor.json" + assert config_path.exists() + + # No state files should be created for RenameProcessor + state_files = list(Path(tmp_dir).glob("*.safetensors")) + assert len(state_files) == 0 + + # Load pipeline + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + + assert loaded_pipeline.name == "TestRenameProcessor" + assert len(loaded_pipeline) == 1 + + # Check that loaded processor works correctly + loaded_processor = loaded_pipeline.steps[0] + assert isinstance(loaded_processor, RenameProcessor) + assert loaded_processor.rename_map == rename_map + + # Test functionality after loading + observation = {"old_state": [1, 2, 3], "old_image": "image_data"} + transition = (observation, None, None, None, None, None, None) + + result = loaded_pipeline(transition) + processed_obs = result[TransitionIndex.OBSERVATION] + + assert "observation.state" in processed_obs + assert "observation.image" in processed_obs + assert processed_obs["observation.state"] == [1, 2, 3] + assert processed_obs["observation.image"] == "image_data" + + +def test_registry_functionality(): + """Test that RenameProcessor is properly registered.""" + # Check that it's registered + assert "rename_processor" in ProcessorStepRegistry.list() + + # Get from registry + retrieved_class = ProcessorStepRegistry.get("rename_processor") + assert retrieved_class is RenameProcessor + + # Create instance from registry + instance = retrieved_class(rename_map={"old": "new"}) + assert isinstance(instance, RenameProcessor) + assert instance.rename_map == {"old": "new"} + + +def test_registry_based_save_load(): + """Test save/load using registry name instead of module path.""" + processor = RenameProcessor(rename_map={"key1": "renamed_key1"}) + pipeline = RobotProcessor([processor]) + + with tempfile.TemporaryDirectory() as tmp_dir: + # Save and load + pipeline.save_pretrained(tmp_dir) + + # Verify config uses registry name + import json + + with open(Path(tmp_dir) / "processor.json") as f: + config = json.load(f) + + assert "registry_name" in config["steps"][0] + assert config["steps"][0]["registry_name"] == "rename_processor" + assert "class" not in config["steps"][0] # Should use registry, not module path + + # Load should work + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + loaded_processor = loaded_pipeline.steps[0] + assert isinstance(loaded_processor, RenameProcessor) + assert loaded_processor.rename_map == {"key1": "renamed_key1"} + + +def test_chained_rename_processors(): + """Test multiple RenameProcessors in a pipeline.""" + # First processor: rename raw keys to intermediate format + processor1 = RenameProcessor( + rename_map={ + "pos": "agent_position", + "img": "camera_image", + } + ) + + # Second processor: rename to final format + processor2 = RenameProcessor( + rename_map={ + "agent_position": "observation.state", + "camera_image": "observation.image", + } + ) + + pipeline = RobotProcessor([processor1, processor2]) + + observation = { + "pos": np.array([1.0, 2.0]), + "img": "image_data", + "extra": "keep_me", + } + transition = (observation, None, None, None, None, None, None) + + # Step through to see intermediate results + results = list(pipeline.step_through(transition)) + + # After first processor + assert "agent_position" in results[1][TransitionIndex.OBSERVATION] + assert "camera_image" in results[1][TransitionIndex.OBSERVATION] + + # After second processor + final_obs = results[2][TransitionIndex.OBSERVATION] + assert "observation.state" in final_obs + assert "observation.image" in final_obs + assert final_obs["extra"] == "keep_me" + + # Original keys should be gone + assert "pos" not in final_obs + assert "img" not in final_obs + assert "agent_position" not in final_obs + assert "camera_image" not in final_obs + + +def test_nested_observation_rename(): + """Test renaming with nested observation structures.""" + rename_map = { + "observation.images.left": "observation.camera.left_view", + "observation.images.right": "observation.camera.right_view", + "observation.proprio": "observation.proprioception", + } + processor = RenameProcessor(rename_map=rename_map) + + observation = { + "observation.images.left": torch.randn(3, 64, 64), + "observation.images.right": torch.randn(3, 64, 64), + "observation.proprio": torch.randn(7), + "observation.gripper": torch.tensor([0.0]), # Not renamed + } + transition = (observation, None, None, None, None, None, None) + + result = processor(transition) + processed_obs = result[TransitionIndex.OBSERVATION] + + # Check renames + assert "observation.camera.left_view" in processed_obs + assert "observation.camera.right_view" in processed_obs + assert "observation.proprioception" in processed_obs + + # Check unchanged key + assert "observation.gripper" in processed_obs + + # Check old keys removed + assert "observation.images.left" not in processed_obs + assert "observation.images.right" not in processed_obs + assert "observation.proprio" not in processed_obs + + +def test_value_types_preserved(): + """Test that various value types are preserved during renaming.""" + rename_map = {"old_tensor": "new_tensor", "old_array": "new_array", "old_scalar": "new_scalar"} + processor = RenameProcessor(rename_map=rename_map) + + tensor_value = torch.randn(3, 3) + array_value = np.random.rand(2, 2) + + observation = { + "old_tensor": tensor_value, + "old_array": array_value, + "old_scalar": 42, + "old_string": "hello", + "old_dict": {"nested": "value"}, + "old_list": [1, 2, 3], + } + transition = (observation, None, None, None, None, None, None) + + result = processor(transition) + processed_obs = result[TransitionIndex.OBSERVATION] + + # Check that values and types are preserved + assert torch.equal(processed_obs["new_tensor"], tensor_value) + assert np.array_equal(processed_obs["new_array"], array_value) + assert processed_obs["new_scalar"] == 42 + assert processed_obs["old_string"] == "hello" + assert processed_obs["old_dict"] == {"nested": "value"} + assert processed_obs["old_list"] == [1, 2, 3]