fix(processor): specialized processors respect contract by raising if none (#1909)

* fix(processor): specialized processor now raise

* test(processor): fix tests for now raise specialized processors

* test(processor): use identity in newly introduced pipeline
This commit is contained in:
Steven Palma
2025-09-10 18:45:47 +02:00
committed by GitHub
parent 51588f741b
commit 6745958362
7 changed files with 78 additions and 74 deletions
+58 -58
View File
@@ -38,7 +38,7 @@ def test_state_1d_to_2d():
# Test observation.state
state_1d = torch.randn(7)
observation = {OBS_STATE: state_1d}
transition = create_transition(observation=observation)
transition = create_transition(observation=observation, action={})
result = processor(transition)
@@ -54,7 +54,7 @@ def test_env_state_1d_to_2d():
# Test observation.environment_state
env_state_1d = torch.randn(10)
observation = {OBS_ENV_STATE: env_state_1d}
transition = create_transition(observation=observation)
transition = create_transition(observation=observation, action={})
result = processor(transition)
@@ -70,7 +70,7 @@ def test_image_3d_to_4d():
# Test observation.image
image_3d = torch.randn(224, 224, 3)
observation = {OBS_IMAGE: image_3d}
transition = create_transition(observation=observation)
transition = create_transition(observation=observation, action={})
result = processor(transition)
@@ -90,7 +90,7 @@ def test_multiple_images_3d_to_4d():
f"{OBS_IMAGES}.camera1": image1_3d,
f"{OBS_IMAGES}.camera2": image2_3d,
}
transition = create_transition(observation=observation)
transition = create_transition(observation=observation, action={})
result = processor(transition)
@@ -118,7 +118,7 @@ def test_already_batched_tensors_unchanged():
OBS_ENV_STATE: env_state_2d,
OBS_IMAGE: image_4d,
}
transition = create_transition(observation=observation)
transition = create_transition(observation=observation, action={})
result = processor(transition)
@@ -142,7 +142,7 @@ def test_higher_dimensional_tensors_unchanged():
OBS_STATE: state_3d,
OBS_IMAGE: image_5d,
}
transition = create_transition(observation=observation)
transition = create_transition(observation=observation, action={})
result = processor(transition)
@@ -163,7 +163,7 @@ def test_non_tensor_values_unchanged():
"custom_key": 42, # Integer
"another_key": {"nested": "dict"}, # Dict
}
transition = create_transition(observation=observation)
transition = create_transition(observation=observation, action={})
result = processor(transition)
@@ -180,10 +180,10 @@ def test_none_observation():
"""Test processor handles None observation gracefully."""
processor = AddBatchDimensionProcessorStep()
transition = create_transition(observation=None)
transition = create_transition(observation={}, action={})
result = processor(transition)
assert result[TransitionKey.OBSERVATION] is None
assert result[TransitionKey.OBSERVATION] == {}
def test_empty_observation():
@@ -191,7 +191,7 @@ def test_empty_observation():
processor = AddBatchDimensionProcessorStep()
observation = {}
transition = create_transition(observation=observation)
transition = create_transition(observation=observation, action={})
result = processor(transition)
@@ -216,7 +216,7 @@ def test_mixed_observation():
"other_tensor": other_tensor,
"non_tensor": "string_value",
}
transition = create_transition(observation=observation)
transition = create_transition(observation=observation, action={})
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
@@ -243,7 +243,7 @@ def test_integration_with_robot_processor():
OBS_STATE: torch.randn(7),
OBS_IMAGE: torch.randn(224, 224, 3),
}
transition = create_transition(observation=observation)
transition = create_transition(observation=observation, action={})
result = pipeline(transition)
processed_obs = result[TransitionKey.OBSERVATION]
@@ -299,7 +299,7 @@ def test_save_and_load_pretrained():
# Test functionality of loaded processor
observation = {OBS_STATE: torch.randn(5)}
transition = create_transition(observation=observation)
transition = create_transition(observation=observation, action={})
result = loaded_pipeline(transition)
assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 5)
@@ -333,7 +333,7 @@ def test_registry_based_save_load():
OBS_STATE: torch.randn(3),
OBS_IMAGE: torch.randn(100, 100, 3),
}
transition = create_transition(observation=observation)
transition = create_transition(observation=observation, action={})
result = loaded_pipeline(transition)
processed_obs = result[TransitionKey.OBSERVATION]
@@ -355,7 +355,7 @@ def test_device_compatibility():
OBS_STATE: state_1d,
OBS_IMAGE: image_3d,
}
transition = create_transition(observation=observation)
transition = create_transition(observation=observation, action={})
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
@@ -415,7 +415,7 @@ def test_edge_case_zero_dimensional_tensors():
OBS_STATE: scalar_tensor,
"scalar_value": scalar_tensor,
}
transition = create_transition(observation=observation)
transition = create_transition(observation=observation, action={})
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
@@ -432,7 +432,7 @@ def test_action_1d_to_2d():
# Create 1D action tensor
action_1d = torch.randn(4)
transition = create_transition(action=action_1d)
transition = create_transition(observation={}, action=action_1d)
result = processor(transition)
@@ -450,12 +450,12 @@ def test_action_already_batched():
action_batched_5 = torch.randn(5, 4)
# Single batch
transition = create_transition(action=action_batched_1)
transition = create_transition(action=action_batched_1, observation={})
result = processor(transition)
assert torch.equal(result[TransitionKey.ACTION], action_batched_1)
# Multiple batch
transition = create_transition(action=action_batched_5)
transition = create_transition(action=action_batched_5, observation={})
result = processor(transition)
assert torch.equal(result[TransitionKey.ACTION], action_batched_5)
@@ -466,13 +466,13 @@ def test_action_higher_dimensional():
# 3D action tensor (e.g., sequence of actions)
action_3d = torch.randn(2, 4, 3)
transition = create_transition(action=action_3d)
transition = create_transition(action=action_3d, observation={})
result = processor(transition)
assert torch.equal(result[TransitionKey.ACTION], action_3d)
# 4D action tensor
action_4d = torch.randn(2, 10, 4, 3)
transition = create_transition(action=action_4d)
transition = create_transition(action=action_4d, observation={})
result = processor(transition)
assert torch.equal(result[TransitionKey.ACTION], action_4d)
@@ -482,7 +482,7 @@ def test_action_scalar_tensor():
processor = AddBatchDimensionProcessorStep()
action_scalar = torch.tensor(1.5)
transition = create_transition(action=action_scalar)
transition = create_transition(action=action_scalar, observation={})
result = processor(transition)
# Should remain scalar
@@ -496,25 +496,25 @@ def test_action_non_tensor():
# List action
action_list = [0.1, 0.2, 0.3, 0.4]
transition = create_transition(action=action_list)
transition = create_transition(action=action_list, observation={})
result = processor(transition)
assert result[TransitionKey.ACTION] == action_list
# Numpy array action (as Python object, not converted)
action_numpy = np.array([1, 2, 3, 4])
transition = create_transition(action=action_numpy)
transition = create_transition(action=action_numpy, observation={})
result = processor(transition)
assert np.array_equal(result[TransitionKey.ACTION], action_numpy)
# String action (edge case)
action_string = "forward"
transition = create_transition(action=action_string)
transition = create_transition(action=action_string, observation={})
result = processor(transition)
assert result[TransitionKey.ACTION] == action_string
# Dict action (structured action)
action_dict = {"linear": [0.5, 0.0], "angular": 0.2}
transition = create_transition(action=action_dict)
transition = create_transition(action=action_dict, observation={})
result = processor(transition)
assert result[TransitionKey.ACTION] == action_dict
@@ -523,9 +523,9 @@ def test_action_none():
"""Test that None action is handled correctly."""
processor = AddBatchDimensionProcessorStep()
transition = create_transition(action=None)
transition = create_transition(action={}, observation={})
result = processor(transition)
assert result[TransitionKey.ACTION] is None
assert result[TransitionKey.ACTION] == {}
def test_action_with_observation():
@@ -557,7 +557,7 @@ def test_action_different_sizes():
for size in action_sizes:
action = torch.randn(size)
transition = create_transition(action=action)
transition = create_transition(action=action, observation={})
result = processor(transition)
assert result[TransitionKey.ACTION].shape == (1, size)
@@ -571,7 +571,7 @@ def test_action_device_compatibility():
# CUDA action
action_cuda = torch.randn(4, device="cuda")
transition = create_transition(action=action_cuda)
transition = create_transition(action=action_cuda, observation={})
result = processor(transition)
assert result[TransitionKey.ACTION].shape == (1, 4)
@@ -579,7 +579,7 @@ def test_action_device_compatibility():
# CPU action
action_cpu = torch.randn(4, device="cpu")
transition = create_transition(action=action_cpu)
transition = create_transition(action=action_cpu, observation={})
result = processor(transition)
assert result[TransitionKey.ACTION].shape == (1, 4)
@@ -595,7 +595,7 @@ def test_action_dtype_preservation():
for dtype in dtypes:
action = torch.randn(4).to(dtype)
transition = create_transition(action=action)
transition = create_transition(action=action, observation={})
result = processor(transition)
assert result[TransitionKey.ACTION].dtype == dtype
@@ -608,7 +608,7 @@ def test_empty_action_tensor():
# Empty 1D tensor
action_empty = torch.tensor([])
transition = create_transition(action=action_empty)
transition = create_transition(action=action_empty, observation={})
result = processor(transition)
# Should add batch dimension even to empty tensor
@@ -616,7 +616,7 @@ def test_empty_action_tensor():
# Empty 2D tensor (already batched)
action_empty_2d = torch.randn(1, 0)
transition = create_transition(action=action_empty_2d)
transition = create_transition(action=action_empty_2d, observation={})
result = processor(transition)
# Should remain unchanged
@@ -630,7 +630,7 @@ def test_task_string_to_list():
# Create complementary data with string task
complementary_data = {"task": "pick_cube"}
transition = create_transition(complementary_data=complementary_data)
transition = create_transition(action={}, observation={}, complementary_data=complementary_data)
result = processor(transition)
@@ -647,14 +647,14 @@ def test_task_string_validation():
# Valid string task - should be converted to list
complementary_data = {"task": "valid_task"}
transition = create_transition(complementary_data=complementary_data)
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
result = processor(transition)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert processed_comp_data["task"] == ["valid_task"]
# Valid list of strings - should remain unchanged
complementary_data = {"task": ["task1", "task2"]}
transition = create_transition(complementary_data=complementary_data)
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
result = processor(transition)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert processed_comp_data["task"] == ["task1", "task2"]
@@ -676,7 +676,7 @@ def test_task_list_of_strings():
for task_list in test_lists:
complementary_data = {"task": task_list}
transition = create_transition(complementary_data=complementary_data)
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
result = processor(transition)
@@ -690,7 +690,7 @@ def test_complementary_data_none():
"""Test processor handles None complementary_data gracefully."""
processor = AddBatchDimensionProcessorStep()
transition = create_transition(complementary_data=None)
transition = create_transition(complementary_data=None, action={}, observation={})
result = processor(transition)
assert result[TransitionKey.COMPLEMENTARY_DATA] == {}
@@ -701,7 +701,7 @@ def test_complementary_data_empty():
processor = AddBatchDimensionProcessorStep()
complementary_data = {}
transition = create_transition(complementary_data=complementary_data)
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
result = processor(transition)
@@ -717,7 +717,7 @@ def test_complementary_data_no_task():
"timestamp": 1234567890.0,
"extra_info": "some data",
}
transition = create_transition(complementary_data=complementary_data)
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
result = processor(transition)
@@ -736,7 +736,7 @@ def test_complementary_data_mixed():
"difficulty": "hard",
"metadata": {"scene": "kitchen"},
}
transition = create_transition(complementary_data=complementary_data)
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
result = processor(transition)
@@ -803,7 +803,7 @@ def test_task_comprehensive_string_cases():
# Test that all string tasks get properly batched
for task in string_tasks:
complementary_data = {"task": task}
transition = create_transition(complementary_data=complementary_data)
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
result = processor(transition)
@@ -825,7 +825,7 @@ def test_task_comprehensive_string_cases():
for task_list in list_tasks:
complementary_data = {"task": task_list}
transition = create_transition(complementary_data=complementary_data)
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
result = processor(transition)
@@ -845,7 +845,7 @@ def test_task_preserves_other_keys():
"config": {"speed": "slow", "precision": "high"},
"metrics": [1.0, 2.0, 3.0],
}
transition = create_transition(complementary_data=complementary_data)
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
result = processor(transition)
@@ -869,7 +869,7 @@ def test_index_scalar_to_1d():
# Create 0D index tensor (scalar)
index_0d = torch.tensor(42, dtype=torch.int64)
complementary_data = {"index": index_0d}
transition = create_transition(complementary_data=complementary_data)
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
result = processor(transition)
@@ -886,7 +886,7 @@ def test_task_index_scalar_to_1d():
# Create 0D task_index tensor (scalar)
task_index_0d = torch.tensor(7, dtype=torch.int64)
complementary_data = {"task_index": task_index_0d}
transition = create_transition(complementary_data=complementary_data)
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
result = processor(transition)
@@ -908,7 +908,7 @@ def test_index_and_task_index_together():
"task_index": task_index_0d,
"task": "pick_object",
}
transition = create_transition(complementary_data=complementary_data)
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
result = processor(transition)
@@ -936,13 +936,13 @@ def test_index_already_batched():
# Test 1D (already batched)
complementary_data = {"index": index_1d}
transition = create_transition(complementary_data=complementary_data)
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
result = processor(transition)
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["index"], index_1d)
# Test 2D
complementary_data = {"index": index_2d}
transition = create_transition(complementary_data=complementary_data)
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
result = processor(transition)
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["index"], index_2d)
@@ -957,13 +957,13 @@ def test_task_index_already_batched():
# Test 1D (already batched)
complementary_data = {"task_index": task_index_1d}
transition = create_transition(complementary_data=complementary_data)
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
result = processor(transition)
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["task_index"], task_index_1d)
# Test 2D
complementary_data = {"task_index": task_index_2d}
transition = create_transition(complementary_data=complementary_data)
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
result = processor(transition)
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["task_index"], task_index_2d)
@@ -976,7 +976,7 @@ def test_index_non_tensor_unchanged():
"index": 42, # Plain int, not tensor
"task_index": [1, 2, 3], # List, not tensor
}
transition = create_transition(complementary_data=complementary_data)
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
result = processor(transition)
@@ -999,7 +999,7 @@ def test_index_dtype_preservation():
"index": index_0d,
"task_index": task_index_0d,
}
transition = create_transition(complementary_data=complementary_data)
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
result = processor(transition)
@@ -1062,7 +1062,7 @@ def test_index_device_compatibility():
"index": index_0d,
"task_index": task_index_0d,
}
transition = create_transition(complementary_data=complementary_data)
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
result = processor(transition)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
@@ -1081,7 +1081,7 @@ def test_empty_index_tensor():
# Empty 0D tensor doesn't make sense, but test empty 1D
index_empty = torch.tensor([], dtype=torch.int64)
complementary_data = {"index": index_empty}
transition = create_transition(complementary_data=complementary_data)
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
result = processor(transition)
@@ -1094,7 +1094,7 @@ def test_action_processing_creates_new_transition():
processor = AddBatchDimensionProcessorStep()
action = torch.randn(4)
transition = create_transition(action=action)
transition = create_transition(action=action, observation={})
# Store reference to original transition
original_transition = transition
@@ -1116,7 +1116,7 @@ def test_task_processing_creates_new_transition():
processor = AddBatchDimensionProcessorStep()
complementary_data = {"task": "sort_objects"}
transition = create_transition(complementary_data=complementary_data)
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
# Store reference to original transition and complementary_data
original_transition = transition
+5 -1
View File
@@ -218,7 +218,11 @@ def test_diffusion_processor_without_stats():
"""Test Diffusion processor creation without dataset statistics."""
config = create_default_config()
preprocessor, postprocessor = make_diffusion_pre_post_processors(config, dataset_stats=None)
preprocessor, postprocessor = make_diffusion_pre_post_processors(
config,
dataset_stats=None,
preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition},
)
# Should still create processors
assert preprocessor is not None
@@ -136,7 +136,7 @@ def test_none_observation():
"""Test processor with None observation."""
processor = VanillaObservationProcessorStep()
transition = create_transition()
transition = create_transition(observation={})
result = processor(transition)
assert result == transition
+1 -1
View File
@@ -86,7 +86,7 @@ def test_none_observation():
"""Test processor with None observation."""
processor = RenameObservationsProcessorStep(rename_map={"old": "new"})
transition = create_transition()
transition = create_transition(observation={})
result = processor(transition)
# Should return transition unchanged
+1 -1
View File
@@ -372,7 +372,7 @@ def test_sac_processor_edge_cases():
assert processed[TransitionKey.ACTION].shape == (1, 5)
# Test with None action
transition = create_transition(observation={OBS_STATE: torch.randn(10)}, action=None)
transition = create_transition(observation={OBS_STATE: torch.randn(10)}, action={})
processed = preprocessor(transition)
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10)
# When action is None, it may still be present with None value
+5 -5
View File
@@ -203,7 +203,7 @@ def test_none_complementary_data(mock_auto_tokenizer):
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer")
transition = create_transition(complementary_data=None)
transition = create_transition(observation={}, complementary_data=None)
result = processor(transition)
assert result == transition # Should return unchanged
@@ -218,7 +218,7 @@ def test_missing_task_key(mock_auto_tokenizer):
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer")
transition = create_transition(complementary_data={"other_field": "some value"})
transition = create_transition(observation={}, complementary_data={"other_field": "some value"})
result = processor(transition)
assert result == transition # Should return unchanged
@@ -233,7 +233,7 @@ def test_none_task_value(mock_auto_tokenizer):
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer")
transition = create_transition(complementary_data={"task": None})
transition = create_transition(observation={}, complementary_data={"task": None})
result = processor(transition)
assert result == transition # Should return unchanged
@@ -249,13 +249,13 @@ def test_unsupported_task_type(mock_auto_tokenizer):
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer")
# Test with integer task
transition = create_transition(complementary_data={"task": 123})
transition = create_transition(observation={}, complementary_data={"task": 123})
result = processor(transition)
assert result == transition # Should return unchanged
# Test with mixed list
transition = create_transition(complementary_data={"task": ["text", 123, "more text"]})
transition = create_transition(observation={}, complementary_data={"task": ["text", 123, "more text"]})
result = processor(transition)
assert result == transition # Should return unchanged