From 6f1e49dbc4c9be76b93f041ed9441b7a154c47f6 Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Wed, 10 Sep 2025 13:08:44 +0200 Subject: [PATCH] refactor(tests): streamline transition creation in processor tests - Replaced custom transition creation functions with a centralized `create_transition` function imported from converters across multiple test files. - Updated test cases to utilize keyword arguments for better readability and maintainability, ensuring consistent transition creation throughout the test suite. --- tests/processor/test_act_processor.py | 53 ++++++++++---------- tests/processor/test_batch_conversion.py | 33 ++++++------ tests/processor/test_batch_processor.py | 20 ++------ tests/processor/test_classifier_processor.py | 44 ++++++++-------- tests/processor/test_device_processor.py | 31 ++---------- tests/processor/test_diffusion_processor.py | 47 +++++++++-------- tests/processor/test_pi0_processor.py | 28 ++++------- tests/processor/test_sac_processor.py | 52 +++++++++---------- tests/processor/test_smolvla_processor.py | 28 ++++------- tests/processor/test_vqbet_processor.py | 47 +++++++++-------- 10 files changed, 165 insertions(+), 218 deletions(-) diff --git a/tests/processor/test_act_processor.py b/tests/processor/test_act_processor.py index ef3b72f54..e4c7cc734 100644 --- a/tests/processor/test_act_processor.py +++ b/tests/processor/test_act_processor.py @@ -33,19 +33,7 @@ from lerobot.processor import ( TransitionKey, UnnormalizerProcessorStep, ) - - -def create_transition(observation=None, action=None, **kwargs): - """Helper function to create a transition dictionary.""" - transition = {} - if observation is not None: - transition[TransitionKey.OBSERVATION] = observation - if action is not None: - transition[TransitionKey.ACTION] = action - for key, value in kwargs.items(): - if hasattr(TransitionKey, key.upper()): - transition[getattr(TransitionKey, key.upper())] = value - return transition +from lerobot.processor.converters import create_transition def create_default_config(): @@ -112,7 +100,8 @@ def test_act_processor_normalization(): # Create test data observation = {OBS_STATE: torch.randn(7)} action = torch.randn(4) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -122,7 +111,8 @@ def test_act_processor_normalization(): assert processed[TransitionKey.ACTION].shape == (1, 4) # Process action through postprocessor - action_transition = create_transition(action=processed[TransitionKey.ACTION]) + action_transition = create_transition() + action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION] postprocessed = postprocessor(action_transition) # Check that action is unnormalized @@ -146,7 +136,8 @@ def test_act_processor_cuda(): # Create CPU data observation = {OBS_STATE: torch.randn(7)} action = torch.randn(4) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -156,7 +147,8 @@ def test_act_processor_cuda(): assert processed[TransitionKey.ACTION].device.type == "cuda" # Process through postprocessor - action_transition = create_transition(action=processed[TransitionKey.ACTION]) + action_transition = create_transition() + action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION] postprocessed = postprocessor(action_transition) # Check that action is back on CPU @@ -181,7 +173,8 @@ def test_act_processor_accelerate_scenario(): device = torch.device("cuda:0") observation = {OBS_STATE: torch.randn(1, 7).to(device)} # Already batched and on GPU action = torch.randn(1, 4).to(device) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -204,7 +197,8 @@ def test_act_processor_multi_gpu(): device = torch.device("cuda:1") observation = {OBS_STATE: torch.randn(1, 7).to(device)} action = torch.randn(1, 4).to(device) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -227,7 +221,8 @@ def test_act_processor_without_stats(): # Process should still work (but won't normalize without stats) observation = {OBS_STATE: torch.randn(7)} action = torch.randn(4) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action processed = preprocessor(transition) assert processed is not None @@ -257,7 +252,8 @@ def test_act_processor_save_and_load(): # Test that loaded processor works observation = {OBS_STATE: torch.randn(7)} action = torch.randn(4) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action processed = loaded_preprocessor(transition) assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7) @@ -281,7 +277,8 @@ def test_act_processor_device_placement_preservation(): # Process CPU data observation = {OBS_STATE: torch.randn(7)} action = torch.randn(4) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action processed = preprocessor(transition) assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu" @@ -326,7 +323,8 @@ def test_act_processor_mixed_precision(): # Create test data observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)} action = torch.randn(4, dtype=torch.float32) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -351,7 +349,8 @@ def test_act_processor_batch_consistency(): # Test single sample (unbatched) observation = {OBS_STATE: torch.randn(7)} action = torch.randn(4) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action processed = preprocessor(transition) assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 1 # Batched @@ -359,7 +358,8 @@ def test_act_processor_batch_consistency(): # Test already batched data observation_batched = {OBS_STATE: torch.randn(8, 7)} # Batch of 8 action_batched = torch.randn(8, 4) - transition_batched = create_transition(observation_batched, action_batched) + transition_batched = create_transition(observation=observation_batched) + transition_batched[TransitionKey.ACTION] = action_batched processed_batched = preprocessor(transition_batched) assert processed_batched[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 8 @@ -407,7 +407,8 @@ def test_act_processor_bfloat16_device_float32_normalizer(): # Create test data observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)} # Start with float32 action = torch.randn(4, dtype=torch.float32) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through full pipeline processed = preprocessor(transition) diff --git a/tests/processor/test_batch_conversion.py b/tests/processor/test_batch_conversion.py index 8d1f5e20e..8635c39e9 100644 --- a/tests/processor/test_batch_conversion.py +++ b/tests/processor/test_batch_conversion.py @@ -1,3 +1,5 @@ +from typing import Any + import torch from lerobot.processor import DataProcessorPipeline, TransitionKey @@ -20,7 +22,7 @@ def _dummy_batch(): def test_observation_grouping_roundtrip(): """Test that observation.* keys are properly grouped and ungrouped.""" - proc = DataProcessorPipeline([]) + proc = DataProcessorPipeline[dict[str, Any]]([]) batch_in = _dummy_batch() batch_out = proc(batch_in) @@ -45,11 +47,12 @@ def test_observation_grouping_roundtrip(): def test_batch_to_transition_observation_grouping(): """Test that batch_to_transition correctly groups observation.* keys.""" + base_batch = _dummy_batch() batch = { "observation.image.top": torch.randn(1, 3, 128, 128), - "observation.image.left": torch.randn(1, 3, 128, 128), + "observation.image.left": base_batch["observation.image.left"], "observation.state": [1, 2, 3, 4], - "action": "action_data", + "action": torch.tensor([[0.1, 0.2]]), "next.reward": 1.5, "next.done": True, "next.truncated": False, @@ -74,7 +77,7 @@ def test_batch_to_transition_observation_grouping(): assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4] # Check other fields - assert transition[TransitionKey.ACTION] == "action_data" + assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([[0.1, 0.2]])) assert transition[TransitionKey.REWARD] == 1.5 assert transition[TransitionKey.DONE] assert not transition[TransitionKey.TRUNCATED] @@ -84,15 +87,16 @@ def test_batch_to_transition_observation_grouping(): def test_transition_to_batch_observation_flattening(): """Test that transition_to_batch correctly flattens observation dict.""" + base_batch = _dummy_batch() observation_dict = { "observation.image.top": torch.randn(1, 3, 128, 128), - "observation.image.left": torch.randn(1, 3, 128, 128), + "observation.image.left": base_batch["observation.image.left"], "observation.state": [1, 2, 3, 4], } transition = { TransitionKey.OBSERVATION: observation_dict, - TransitionKey.ACTION: "action_data", + TransitionKey.ACTION: torch.tensor([[0.3, 0.4]]), TransitionKey.REWARD: 1.5, TransitionKey.DONE: True, TransitionKey.TRUNCATED: False, @@ -113,7 +117,7 @@ def test_transition_to_batch_observation_flattening(): assert batch["observation.state"] == [1, 2, 3, 4] # Check other fields are mapped to next.* format - assert batch["action"] == "action_data" + assert torch.allclose(batch["action"], torch.tensor([[0.3, 0.4]])) assert batch["next.reward"] == 1.5 assert batch["next.done"] assert not batch["next.truncated"] @@ -123,7 +127,7 @@ def test_transition_to_batch_observation_flattening(): def test_no_observation_keys(): """Test behavior when there are no observation.* keys.""" batch = { - "action": "action_data", + "action": torch.tensor([[0.7, 0.8]]), "next.reward": 2.0, "next.done": False, "next.truncated": True, @@ -136,7 +140,7 @@ def test_no_observation_keys(): assert transition[TransitionKey.OBSERVATION] is None # Check other fields - assert transition[TransitionKey.ACTION] == "action_data" + assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([[0.7, 0.8]])) assert transition[TransitionKey.REWARD] == 2.0 assert not transition[TransitionKey.DONE] assert transition[TransitionKey.TRUNCATED] @@ -144,7 +148,7 @@ def test_no_observation_keys(): # Round trip should work reconstructed_batch = transition_to_batch(transition) - assert reconstructed_batch["action"] == "action_data" + assert torch.allclose(reconstructed_batch["action"], torch.tensor([[0.7, 0.8]])) assert reconstructed_batch["next.reward"] == 2.0 assert not reconstructed_batch["next.done"] assert reconstructed_batch["next.truncated"] @@ -153,13 +157,13 @@ def test_no_observation_keys(): def test_minimal_batch(): """Test with minimal batch containing only observation.* and action.""" - batch = {"observation.state": "minimal_state", "action": "minimal_action"} + batch = {"observation.state": "minimal_state", "action": torch.tensor([[0.9]])} transition = batch_to_transition(batch) # Check observation assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"} - assert transition[TransitionKey.ACTION] == "minimal_action" + assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([[0.9]])) # Check defaults assert transition[TransitionKey.REWARD] == 0.0 @@ -171,7 +175,7 @@ def test_minimal_batch(): # Round trip reconstructed_batch = transition_to_batch(transition) assert reconstructed_batch["observation.state"] == "minimal_state" - assert reconstructed_batch["action"] == "minimal_action" + assert torch.allclose(reconstructed_batch["action"], torch.tensor([[0.9]])) assert reconstructed_batch["next.reward"] == 0.0 assert not reconstructed_batch["next.done"] assert not reconstructed_batch["next.truncated"] @@ -204,9 +208,10 @@ def test_empty_batch(): def test_complex_nested_observation(): """Test with complex nested observation data.""" + base_batch = _dummy_batch() batch = { "observation.image.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890}, - "observation.image.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891}, + "observation.image.left": {"image": base_batch["observation.image.left"], "timestamp": 1234567891}, "observation.state": torch.randn(7), "action": torch.randn(8), "next.reward": 3.14, diff --git a/tests/processor/test_batch_processor.py b/tests/processor/test_batch_processor.py index 568bfa5c4..371ef4ae7 100644 --- a/tests/processor/test_batch_processor.py +++ b/tests/processor/test_batch_processor.py @@ -28,21 +28,7 @@ from lerobot.processor import ( ProcessorStepRegistry, TransitionKey, ) - - -def create_transition( - observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None -): - """Helper to create an EnvTransition dictionary.""" - return { - TransitionKey.OBSERVATION: observation, - TransitionKey.ACTION: action, - TransitionKey.REWARD: reward, - TransitionKey.DONE: done, - TransitionKey.TRUNCATED: truncated, - TransitionKey.INFO: info, - TransitionKey.COMPLEMENTARY_DATA: complementary_data, - } +from lerobot.processor.converters import create_transition def test_state_1d_to_2d(): @@ -517,7 +503,7 @@ def test_action_non_tensor(): assert np.array_equal(result[TransitionKey.ACTION], action_numpy) # String action (edge case) - action_string = "forward" + action_string = "eef.pos.x" transition = create_transition(action=action_string) result = processor(transition) assert result[TransitionKey.ACTION] == action_string @@ -703,7 +689,7 @@ def test_complementary_data_none(): transition = create_transition(complementary_data=None) result = processor(transition) - assert result[TransitionKey.COMPLEMENTARY_DATA] is None + assert result[TransitionKey.COMPLEMENTARY_DATA] == {} def test_complementary_data_empty(): diff --git a/tests/processor/test_classifier_processor.py b/tests/processor/test_classifier_processor.py index 1c9118bd1..11540b8d5 100644 --- a/tests/processor/test_classifier_processor.py +++ b/tests/processor/test_classifier_processor.py @@ -31,19 +31,7 @@ from lerobot.processor import ( NormalizerProcessorStep, TransitionKey, ) - - -def create_transition(observation=None, action=None, **kwargs): - """Helper function to create a transition dictionary.""" - transition = {} - if observation is not None: - transition[TransitionKey.OBSERVATION] = observation - if action is not None: - transition[TransitionKey.ACTION] = action - for key, value in kwargs.items(): - if hasattr(TransitionKey, key.upper()): - transition[getattr(TransitionKey, key.upper())] = value - return transition +from lerobot.processor.converters import create_transition def create_default_config(): @@ -115,7 +103,8 @@ def test_classifier_processor_normalization(): OBS_IMAGE: torch.randn(3, 224, 224), } action = torch.randn(1) # Dummy action/reward - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -146,7 +135,8 @@ def test_classifier_processor_cuda(): OBS_IMAGE: torch.randn(3, 224, 224), } action = torch.randn(1) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -157,7 +147,8 @@ def test_classifier_processor_cuda(): assert processed[TransitionKey.ACTION].device.type == "cuda" # Process through postprocessor - reward_transition = create_transition(action=processed[TransitionKey.ACTION]) + reward_transition = create_transition() + reward_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION] postprocessed = postprocessor(reward_transition) # Check that output is back on CPU @@ -185,7 +176,8 @@ def test_classifier_processor_accelerate_scenario(): OBS_IMAGE: torch.randn(3, 224, 224).to(device), } action = torch.randn(1).to(device) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -212,7 +204,8 @@ def test_classifier_processor_multi_gpu(): OBS_IMAGE: torch.randn(3, 224, 224).to(device), } action = torch.randn(1).to(device) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -239,7 +232,8 @@ def test_classifier_processor_without_stats(): OBS_IMAGE: torch.randn(3, 224, 224), } action = torch.randn(1) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action processed = preprocessor(transition) assert processed is not None @@ -273,7 +267,8 @@ def test_classifier_processor_save_and_load(): OBS_IMAGE: torch.randn(3, 224, 224), } action = torch.randn(1) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action processed = loaded_preprocessor(transition) assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (10,) @@ -308,7 +303,8 @@ def test_classifier_processor_mixed_precision(): OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32), } action = torch.randn(1, dtype=torch.float32) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -338,7 +334,8 @@ def test_classifier_processor_batch_data(): OBS_IMAGE: torch.randn(batch_size, 3, 224, 224), } action = torch.randn(batch_size, 1) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -363,7 +360,8 @@ def test_classifier_processor_postprocessor_identity(): # Create test data for postprocessor reward = torch.tensor([[0.8], [0.3], [0.9]]) # Batch of rewards/predictions - transition = create_transition(action=reward) + transition = create_transition() + transition[TransitionKey.ACTION] = reward # Process through postprocessor processed = postprocessor(transition) diff --git a/tests/processor/test_device_processor.py b/tests/processor/test_device_processor.py index 113b1adf2..56de52774 100644 --- a/tests/processor/test_device_processor.py +++ b/tests/processor/test_device_processor.py @@ -20,28 +20,7 @@ import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.processor import DataProcessorPipeline, DeviceProcessorStep, TransitionKey - - -def create_transition( - observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None -): - """Helper function to create a transition dictionary.""" - transition = {} - if observation is not None: - transition[TransitionKey.OBSERVATION] = observation - if action is not None: - transition[TransitionKey.ACTION] = action - if reward is not None: - transition[TransitionKey.REWARD] = reward - if done is not None: - transition[TransitionKey.DONE] = done - if truncated is not None: - transition[TransitionKey.TRUNCATED] = truncated - if info is not None: - transition[TransitionKey.INFO] = info - if complementary_data is not None: - transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data - return transition +from lerobot.processor.converters import create_transition def test_basic_functionality(): @@ -147,14 +126,14 @@ def test_none_values(): # Test with None observation transition = create_transition(observation=None, action=torch.randn(5)) result = processor(transition) - assert TransitionKey.OBSERVATION not in result + assert result[TransitionKey.OBSERVATION] is None assert result[TransitionKey.ACTION].device.type == "cpu" # Test with None action transition = create_transition(observation={"observation.state": torch.randn(10)}, action=None) result = processor(transition) assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu" - assert TransitionKey.ACTION not in result + assert result[TransitionKey.ACTION] is None def test_empty_observation(): @@ -822,8 +801,8 @@ def test_complementary_data_none(): result = processor(transition) - # Complementary data should not be in the result (same as input) - assert TransitionKey.COMPLEMENTARY_DATA not in result + # Complementary data should be an empty dict (standardized behavior) + assert result[TransitionKey.COMPLEMENTARY_DATA] == {} @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") diff --git a/tests/processor/test_diffusion_processor.py b/tests/processor/test_diffusion_processor.py index 1e5e93b4d..e032f3e3d 100644 --- a/tests/processor/test_diffusion_processor.py +++ b/tests/processor/test_diffusion_processor.py @@ -33,19 +33,7 @@ from lerobot.processor import ( TransitionKey, UnnormalizerProcessorStep, ) - - -def create_transition(observation=None, action=None, **kwargs): - """Helper function to create a transition dictionary.""" - transition = {} - if observation is not None: - transition[TransitionKey.OBSERVATION] = observation - if action is not None: - transition[TransitionKey.ACTION] = action - for key, value in kwargs.items(): - if hasattr(TransitionKey, key.upper()): - transition[getattr(TransitionKey, key.upper())] = value - return transition +from lerobot.processor.converters import create_transition def create_default_config(): @@ -118,7 +106,8 @@ def test_diffusion_processor_with_images(): OBS_IMAGE: torch.randn(3, 224, 224), } action = torch.randn(6) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -149,7 +138,8 @@ def test_diffusion_processor_cuda(): OBS_IMAGE: torch.randn(3, 224, 224), } action = torch.randn(6) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -160,7 +150,8 @@ def test_diffusion_processor_cuda(): assert processed[TransitionKey.ACTION].device.type == "cuda" # Process through postprocessor - action_transition = create_transition(action=processed[TransitionKey.ACTION]) + action_transition = create_transition() + action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION] postprocessed = postprocessor(action_transition) # Check that action is back on CPU @@ -188,7 +179,8 @@ def test_diffusion_processor_accelerate_scenario(): OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), } action = torch.randn(1, 6).to(device) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -215,7 +207,8 @@ def test_diffusion_processor_multi_gpu(): OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), } action = torch.randn(1, 6).to(device) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -242,7 +235,8 @@ def test_diffusion_processor_without_stats(): OBS_IMAGE: torch.randn(3, 224, 224), } action = torch.randn(6) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action processed = preprocessor(transition) assert processed is not None @@ -276,7 +270,8 @@ def test_diffusion_processor_save_and_load(): OBS_IMAGE: torch.randn(3, 224, 224), } action = torch.randn(6) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action processed = loaded_preprocessor(transition) assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7) @@ -322,7 +317,8 @@ def test_diffusion_processor_mixed_precision(): OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32), } action = torch.randn(6, dtype=torch.float32) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -352,7 +348,8 @@ def test_diffusion_processor_identity_normalization(): OBS_IMAGE: image_value.clone(), } action = torch.randn(6) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -381,7 +378,8 @@ def test_diffusion_processor_batch_consistency(): OBS_IMAGE: torch.randn(batch_size, 3, 224, 224) if batch_size > 1 else torch.randn(3, 224, 224), } action = torch.randn(batch_size, 6) if batch_size > 1 else torch.randn(6) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action processed = preprocessor(transition) @@ -435,7 +433,8 @@ def test_diffusion_processor_bfloat16_device_float32_normalizer(): OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32), } action = torch.randn(6, dtype=torch.float32) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through full pipeline processed = preprocessor(transition) diff --git a/tests/processor/test_pi0_processor.py b/tests/processor/test_pi0_processor.py index e83635a48..6c87f826b 100644 --- a/tests/processor/test_pi0_processor.py +++ b/tests/processor/test_pi0_processor.py @@ -34,6 +34,7 @@ from lerobot.processor import ( TransitionKey, UnnormalizerProcessorStep, ) +from lerobot.processor.converters import create_transition class MockTokenizerProcessorStep(ProcessorStep): @@ -52,21 +53,6 @@ class MockTokenizerProcessorStep(ProcessorStep): return features -def create_transition(observation=None, action=None, **kwargs): - """Helper function to create a transition dictionary.""" - transition = {} - if observation is not None: - transition[TransitionKey.OBSERVATION] = observation - if action is not None: - transition[TransitionKey.ACTION] = action - for key, value in kwargs.items(): - if hasattr(TransitionKey, key.upper()): - transition[getattr(TransitionKey, key.upper())] = value - elif key == "complementary_data": - transition[TransitionKey.COMPLEMENTARY_DATA] = value - return transition - - def create_default_config(): """Create a default PI0 configuration for testing.""" config = PI0Config() @@ -219,7 +205,8 @@ def test_pi0_processor_cuda(): OBS_IMAGE: torch.randn(3, 224, 224), } action = torch.randn(6) - transition = create_transition(observation, action, complementary_data={"task": "test task"}) + transition = create_transition(observation=observation, complementary_data={"task": "test task"}) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -275,7 +262,8 @@ def test_pi0_processor_accelerate_scenario(): OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), } action = torch.randn(1, 6).to(device) - transition = create_transition(observation, action, complementary_data={"task": ["test task"]}) + transition = create_transition(observation=observation, complementary_data={"task": ["test task"]}) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -331,7 +319,8 @@ def test_pi0_processor_multi_gpu(): OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), } action = torch.randn(1, 6).to(device) - transition = create_transition(observation, action, complementary_data={"task": ["test task"]}) + transition = create_transition(observation=observation, complementary_data={"task": ["test task"]}) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -426,8 +415,9 @@ def test_pi0_processor_bfloat16_device_float32_normalizer(): } action = torch.randn(6, dtype=torch.float32) # PI0 expects size 6 transition = create_transition( - observation, action, complementary_data={"task": "test bfloat16 adaptation"} + observation=observation, complementary_data={"task": "test bfloat16 adaptation"} ) + transition[TransitionKey.ACTION] = action # Process through full pipeline processed = preprocessor(transition) diff --git a/tests/processor/test_sac_processor.py b/tests/processor/test_sac_processor.py index 3e26172c3..9fbe7979b 100644 --- a/tests/processor/test_sac_processor.py +++ b/tests/processor/test_sac_processor.py @@ -33,19 +33,7 @@ from lerobot.processor import ( TransitionKey, UnnormalizerProcessorStep, ) - - -def create_transition(observation=None, action=None, **kwargs): - """Helper function to create a transition dictionary.""" - transition = {} - if observation is not None: - transition[TransitionKey.OBSERVATION] = observation - if action is not None: - transition[TransitionKey.ACTION] = action - for key, value in kwargs.items(): - if hasattr(TransitionKey, key.upper()): - transition[getattr(TransitionKey, key.upper())] = value - return transition +from lerobot.processor.converters import create_transition def create_default_config(): @@ -117,7 +105,8 @@ def test_sac_processor_normalization_modes(): # Create test data observation = {OBS_STATE: torch.randn(10) * 2} # Larger values to test normalization action = torch.rand(5) * 2 - 1 # Range [-1, 1] - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -129,7 +118,8 @@ def test_sac_processor_normalization_modes(): assert processed[TransitionKey.ACTION].shape == (1, 5) # Process action through postprocessor - action_transition = create_transition(action=processed[TransitionKey.ACTION]) + action_transition = create_transition() + action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION] postprocessed = postprocessor(action_transition) # Check that action is unnormalized (but still batched) @@ -153,7 +143,8 @@ def test_sac_processor_cuda(): # Create CPU data observation = {OBS_STATE: torch.randn(10)} action = torch.randn(5) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -163,7 +154,8 @@ def test_sac_processor_cuda(): assert processed[TransitionKey.ACTION].device.type == "cuda" # Process through postprocessor - action_transition = create_transition(action=processed[TransitionKey.ACTION]) + action_transition = create_transition() + action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION] postprocessed = postprocessor(action_transition) # Check that action is back on CPU @@ -188,7 +180,8 @@ def test_sac_processor_accelerate_scenario(): device = torch.device("cuda:0") observation = {OBS_STATE: torch.randn(10).to(device)} action = torch.randn(5).to(device) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -216,7 +209,8 @@ def test_sac_processor_multi_gpu(): device = torch.device("cuda:1") observation = {OBS_STATE: torch.randn(10).to(device)} action = torch.randn(5).to(device) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -254,7 +248,8 @@ def test_sac_processor_without_stats(): # Process should still work observation = {OBS_STATE: torch.randn(10)} action = torch.randn(5) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action processed = preprocessor(transition) assert processed is not None @@ -284,7 +279,8 @@ def test_sac_processor_save_and_load(): # Test that loaded processor works observation = {OBS_STATE: torch.randn(10)} action = torch.randn(5) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action processed = loaded_preprocessor(transition) assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10) @@ -329,7 +325,8 @@ def test_sac_processor_mixed_precision(): # Create test data observation = {OBS_STATE: torch.randn(10, dtype=torch.float32)} action = torch.randn(5, dtype=torch.float32) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -355,7 +352,8 @@ def test_sac_processor_batch_data(): batch_size = 32 observation = {OBS_STATE: torch.randn(batch_size, 10)} action = torch.randn(batch_size, 5) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -378,13 +376,14 @@ def test_sac_processor_edge_cases(): ) # Test with empty observation - transition = create_transition(observation={}, action=torch.randn(5)) + transition = create_transition(observation={}) + transition[TransitionKey.ACTION] = torch.randn(5) processed = preprocessor(transition) assert processed[TransitionKey.OBSERVATION] == {} assert processed[TransitionKey.ACTION].shape == (1, 5) # Test with None action - transition = create_transition(observation={OBS_STATE: torch.randn(10)}, action=None) + transition = create_transition(observation={OBS_STATE: torch.randn(10)}) processed = preprocessor(transition) assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10) # When action is None, it may still be present with None value @@ -433,7 +432,8 @@ def test_sac_processor_bfloat16_device_float32_normalizer(): # Create test data observation = {OBS_STATE: torch.randn(10, dtype=torch.float32)} # Start with float32 action = torch.randn(5, dtype=torch.float32) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through full pipeline processed = preprocessor(transition) diff --git a/tests/processor/test_smolvla_processor.py b/tests/processor/test_smolvla_processor.py index 317b0feec..944901977 100644 --- a/tests/processor/test_smolvla_processor.py +++ b/tests/processor/test_smolvla_processor.py @@ -37,6 +37,7 @@ from lerobot.processor import ( TransitionKey, UnnormalizerProcessorStep, ) +from lerobot.processor.converters import create_transition class MockTokenizerProcessorStep(ProcessorStep): @@ -55,21 +56,6 @@ class MockTokenizerProcessorStep(ProcessorStep): return features -def create_transition(observation=None, action=None, **kwargs): - """Helper function to create a transition dictionary.""" - transition = {} - if observation is not None: - transition[TransitionKey.OBSERVATION] = observation - if action is not None: - transition[TransitionKey.ACTION] = action - for key, value in kwargs.items(): - if hasattr(TransitionKey, key.upper()): - transition[getattr(TransitionKey, key.upper())] = value - elif key == "complementary_data": - transition[TransitionKey.COMPLEMENTARY_DATA] = value - return transition - - def create_default_config(): """Create a default SmolVLA configuration for testing.""" config = SmolVLAConfig() @@ -228,7 +214,8 @@ def test_smolvla_processor_cuda(): OBS_IMAGE: torch.randn(3, 224, 224), } action = torch.randn(7) - transition = create_transition(observation, action, complementary_data={"task": "test task"}) + transition = create_transition(observation=observation, complementary_data={"task": "test task"}) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -286,7 +273,8 @@ def test_smolvla_processor_accelerate_scenario(): OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), } action = torch.randn(1, 7).to(device) - transition = create_transition(observation, action, complementary_data={"task": ["test task"]}) + transition = create_transition(observation=observation, complementary_data={"task": ["test task"]}) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -344,7 +332,8 @@ def test_smolvla_processor_multi_gpu(): OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), } action = torch.randn(1, 7).to(device) - transition = create_transition(observation, action, complementary_data={"task": ["test task"]}) + transition = create_transition(observation=observation, complementary_data={"task": ["test task"]}) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -455,8 +444,9 @@ def test_smolvla_processor_bfloat16_device_float32_normalizer(): } action = torch.randn(7, dtype=torch.float32) transition = create_transition( - observation, action, complementary_data={"task": "test bfloat16 adaptation"} + observation=observation, complementary_data={"task": "test bfloat16 adaptation"} ) + transition[TransitionKey.ACTION] = action # Process through full pipeline processed = preprocessor(transition) diff --git a/tests/processor/test_vqbet_processor.py b/tests/processor/test_vqbet_processor.py index c05fb15fe..fdfc62a76 100644 --- a/tests/processor/test_vqbet_processor.py +++ b/tests/processor/test_vqbet_processor.py @@ -33,19 +33,7 @@ from lerobot.processor import ( TransitionKey, UnnormalizerProcessorStep, ) - - -def create_transition(observation=None, action=None, **kwargs): - """Helper function to create a transition dictionary.""" - transition = {} - if observation is not None: - transition[TransitionKey.OBSERVATION] = observation - if action is not None: - transition[TransitionKey.ACTION] = action - for key, value in kwargs.items(): - if hasattr(TransitionKey, key.upper()): - transition[getattr(TransitionKey, key.upper())] = value - return transition +from lerobot.processor.converters import create_transition def create_default_config(): @@ -123,7 +111,8 @@ def test_vqbet_processor_with_images(): OBS_IMAGE: torch.randn(3, 224, 224), } action = torch.randn(7) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -154,7 +143,8 @@ def test_vqbet_processor_cuda(): OBS_IMAGE: torch.randn(3, 224, 224), } action = torch.randn(7) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -165,7 +155,8 @@ def test_vqbet_processor_cuda(): assert processed[TransitionKey.ACTION].device.type == "cuda" # Process through postprocessor - action_transition = create_transition(action=processed[TransitionKey.ACTION]) + action_transition = create_transition() + action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION] postprocessed = postprocessor(action_transition) # Check that action is back on CPU @@ -193,7 +184,8 @@ def test_vqbet_processor_accelerate_scenario(): OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), } action = torch.randn(1, 7).to(device) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -225,7 +217,8 @@ def test_vqbet_processor_multi_gpu(): OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), } action = torch.randn(1, 7).to(device) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -267,7 +260,8 @@ def test_vqbet_processor_without_stats(): OBS_IMAGE: torch.randn(3, 224, 224), } action = torch.randn(7) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action processed = preprocessor(transition) assert processed is not None @@ -300,7 +294,8 @@ def test_vqbet_processor_save_and_load(): OBS_IMAGE: torch.randn(3, 224, 224), } action = torch.randn(7) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action processed = loaded_preprocessor(transition) assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 8) @@ -349,7 +344,8 @@ def test_vqbet_processor_mixed_precision(): OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32), } action = torch.randn(7, dtype=torch.float32) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -379,7 +375,8 @@ def test_vqbet_processor_large_batch(): OBS_IMAGE: torch.randn(batch_size, 3, 224, 224), } action = torch.randn(batch_size, 7) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through preprocessor processed = preprocessor(transition) @@ -410,7 +407,8 @@ def test_vqbet_processor_sequential_processing(): OBS_IMAGE: torch.randn(3, 224, 224), } action = torch.randn(7) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action processed = preprocessor(transition) results.append(processed) @@ -467,7 +465,8 @@ def test_vqbet_processor_bfloat16_device_float32_normalizer(): OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32), } action = torch.randn(7, dtype=torch.float32) - transition = create_transition(observation, action) + transition = create_transition(observation=observation) + transition[TransitionKey.ACTION] = action # Process through full pipeline processed = preprocessor(transition)