feat(batch_processor): Add task field processing to ToBatchProcessor

- Enhanced ToBatchProcessor to wrap string tasks in a list, adding batch dimensions for compatibility with model inference.
- Implemented a new method for processing complementary data, ensuring that task values are correctly handled as either strings or lists of strings.
- Added comprehensive unit tests to validate task processing, including edge cases and in-place mutation of complementary data.
This commit is contained in:
Adil Zouitine
2025-07-25 19:05:44 +02:00
committed by Steven Palma
parent c4763f61a1
commit c0013b130b
2 changed files with 277 additions and 0 deletions
+259
View File
@@ -640,3 +640,262 @@ def test_empty_action_tensor():
# Should remain unchanged
assert result[TransitionKey.ACTION].shape == (1, 0)
# Task-specific tests
def test_task_string_to_list():
"""Test that string tasks get wrapped in lists to add batch dimension."""
processor = ToBatchProcessor()
# Create complementary data with string task
complementary_data = {"task": "pick_cube"}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
# String task should be wrapped in list
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert processed_comp_data["task"] == ["pick_cube"]
assert isinstance(processed_comp_data["task"], list)
assert len(processed_comp_data["task"]) == 1
def test_task_string_validation():
"""Test that only string and list of strings are valid task values."""
processor = ToBatchProcessor()
# Valid string task - should be converted to list
complementary_data = {"task": "valid_task"}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert processed_comp_data["task"] == ["valid_task"]
# Valid list of strings - should remain unchanged
complementary_data = {"task": ["task1", "task2"]}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert processed_comp_data["task"] == ["task1", "task2"]
def test_task_list_of_strings():
"""Test that lists of strings remain unchanged (already batched)."""
processor = ToBatchProcessor()
# Test various list of strings
test_lists = [
["pick_cube"], # Single string in list
["pick_cube", "place_cube"], # Multiple strings
["task1", "task2", "task3"], # Three strings
[], # Empty list
[""], # List with empty string
["task with spaces", "task_with_underscores"], # Mixed formats
]
for task_list in test_lists:
complementary_data = {"task": task_list}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
# Should remain unchanged since it's already a list
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert processed_comp_data["task"] == task_list
assert isinstance(processed_comp_data["task"], list)
def test_complementary_data_none():
"""Test processor handles None complementary_data gracefully."""
processor = ToBatchProcessor()
transition = create_transition(complementary_data=None)
result = processor(transition)
assert result[TransitionKey.COMPLEMENTARY_DATA] is None
def test_complementary_data_empty():
"""Test processor handles empty complementary_data dict."""
processor = ToBatchProcessor()
complementary_data = {}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
assert result[TransitionKey.COMPLEMENTARY_DATA] == {}
def test_complementary_data_no_task():
"""Test processor handles complementary_data without task field."""
processor = ToBatchProcessor()
complementary_data = {
"episode_id": 123,
"timestamp": 1234567890.0,
"extra_info": "some data",
}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
# Should remain unchanged
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert processed_comp_data == complementary_data
def test_complementary_data_mixed():
"""Test processor with mixed complementary_data containing task and other fields."""
processor = ToBatchProcessor()
complementary_data = {
"task": "stack_blocks",
"episode_id": 456,
"difficulty": "hard",
"metadata": {"scene": "kitchen"},
}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
# Task should be batched
assert processed_comp_data["task"] == ["stack_blocks"]
# Other fields should remain unchanged
assert processed_comp_data["episode_id"] == 456
assert processed_comp_data["difficulty"] == "hard"
assert processed_comp_data["metadata"] == {"scene": "kitchen"}
def test_task_with_observation_and_action():
"""Test task processing together with observation and action processing."""
processor = ToBatchProcessor()
# All components need batching
observation = {
OBS_STATE: torch.randn(5),
OBS_IMAGE: torch.randn(32, 32, 3),
}
action = torch.randn(4)
complementary_data = {"task": "navigate_to_goal"}
transition = create_transition(
observation=observation, action=action, complementary_data=complementary_data
)
result = processor(transition)
# All should be batched
assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 5)
assert result[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 32, 32, 3)
assert result[TransitionKey.ACTION].shape == (1, 4)
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == ["navigate_to_goal"]
def test_task_comprehensive_string_cases():
"""Test task processing with comprehensive string cases and edge cases."""
processor = ToBatchProcessor()
# Test various string formats
string_tasks = [
"pick_and_place",
"navigate",
"open_drawer",
"", # Empty string (valid but edge case)
"task with spaces",
"task_with_underscores",
"task-with-dashes",
"UPPERCASE_TASK",
"MixedCaseTask",
"task123",
"数字任务", # Unicode task
"🤖 robot task", # Emoji in task
"task\nwith\nnewlines", # Special characters
"task\twith\ttabs",
"task with 'quotes'",
'task with "double quotes"',
]
# Test that all string tasks get properly batched
for task in string_tasks:
complementary_data = {"task": task}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert processed_comp_data["task"] == [task]
assert isinstance(processed_comp_data["task"], list)
assert len(processed_comp_data["task"]) == 1
# Test various list of strings (should remain unchanged)
list_tasks = [
["single_task"],
["task1", "task2"],
["pick", "place", "navigate"],
[], # Empty list
[""], # List with empty string
["task with spaces", "task_with_underscores", "UPPERCASE"],
["🤖 task", "数字任务", "normal_task"], # Mixed formats
]
for task_list in list_tasks:
complementary_data = {"task": task_list}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert processed_comp_data["task"] == task_list
assert isinstance(processed_comp_data["task"], list)
assert processed_comp_data["task"] is task_list # Should be same object (in-place)
def test_task_in_place_mutation():
"""Test that the processor mutates complementary_data in place for tasks."""
processor = ToBatchProcessor()
complementary_data = {"task": "sort_objects"}
transition = create_transition(complementary_data=complementary_data)
# Store reference to original transition and complementary_data
original_transition = transition
original_comp_data = complementary_data
# Process
result = processor(transition)
# Should be the same objects (in-place mutation)
assert result is original_transition
assert result[TransitionKey.COMPLEMENTARY_DATA] is original_comp_data
assert original_comp_data["task"] == ["sort_objects"]
def test_task_preserves_other_keys():
"""Test that task processing preserves other keys in complementary_data."""
processor = ToBatchProcessor()
complementary_data = {
"task": "clean_table",
"robot_id": "robot_123",
"motor_id": "motor_456",
"config": {"speed": "slow", "precision": "high"},
"metrics": [1.0, 2.0, 3.0],
}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
# Task should be processed
assert processed_comp_data["task"] == ["clean_table"]
# All other keys should be preserved exactly
assert processed_comp_data["robot_id"] == "robot_123"
assert processed_comp_data["motor_id"] == "motor_456"
assert processed_comp_data["config"] == {"speed": "slow", "precision": "high"}
assert processed_comp_data["metrics"] == [1.0, 2.0, 3.0]