feat(batch_processor): Add task field processing to ToBatchProcessor

- Enhanced ToBatchProcessor to wrap string tasks in a list, adding batch dimensions for compatibility with model inference.
- Implemented a new method for processing complementary data, ensuring that task values are correctly handled as either strings or lists of strings.
- Added comprehensive unit tests to validate task processing, including edge cases and in-place mutation of complementary data.
This commit is contained in:
Adil Zouitine
2025-07-25 19:05:44 +02:00
committed by Steven Palma
parent c4763f61a1
commit c0013b130b
2 changed files with 277 additions and 0 deletions
+18
View File
@@ -37,6 +37,10 @@ class ToBatchProcessor:
- For actions: - For actions:
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
- For task field in complementary data:
Wraps string task in a list to add batch dimension
(task must be a string or list of strings)
This is useful when processing single transitions that need to be batched for This is useful when processing single transitions that need to be batched for
model inference or when converting from unbatched environment outputs to model inference or when converting from unbatched environment outputs to
batched model inputs. batched model inputs.
@@ -49,6 +53,7 @@ class ToBatchProcessor:
# State: (7,) -> (1, 7) # State: (7,) -> (1, 7)
# Image: (224, 224, 3) -> (1, 224, 224, 3) # Image: (224, 224, 3) -> (1, 224, 224, 3)
# Action: (4,) -> (1, 4) # Action: (4,) -> (1, 4)
# Task: "pick_cube" -> ["pick_cube"]
# Already batched: (1, 7) -> (1, 7) [unchanged] # Already batched: (1, 7) -> (1, 7) [unchanged]
``` ```
""" """
@@ -56,6 +61,7 @@ class ToBatchProcessor:
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
self._process_observation(transition) self._process_observation(transition)
self._process_action(transition) self._process_action(transition)
self._process_complementary_data(transition)
return transition return transition
def _process_observation(self, transition: EnvTransition) -> None: def _process_observation(self, transition: EnvTransition) -> None:
@@ -88,6 +94,18 @@ class ToBatchProcessor:
if action is not None and isinstance(action, Tensor) and action.dim() == 1: if action is not None and isinstance(action, Tensor) and action.dim() == 1:
transition[TransitionKey.ACTION] = action.unsqueeze(0) transition[TransitionKey.ACTION] = action.unsqueeze(0)
def _process_complementary_data(self, transition: EnvTransition) -> None:
"""Process complementary data in-place, handling task field batching."""
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None:
return
# Process task field - wrap string in list to add batch dimension
if "task" in complementary_data:
task_value = complementary_data["task"]
if isinstance(task_value, str):
complementary_data["task"] = [task_value]
def get_config(self) -> dict[str, Any]: def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization.""" """Return configuration for serialization."""
return {} return {}
+259
View File
@@ -640,3 +640,262 @@ def test_empty_action_tensor():
# Should remain unchanged # Should remain unchanged
assert result[TransitionKey.ACTION].shape == (1, 0) assert result[TransitionKey.ACTION].shape == (1, 0)
# Task-specific tests
def test_task_string_to_list():
"""Test that string tasks get wrapped in lists to add batch dimension."""
processor = ToBatchProcessor()
# Create complementary data with string task
complementary_data = {"task": "pick_cube"}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
# String task should be wrapped in list
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert processed_comp_data["task"] == ["pick_cube"]
assert isinstance(processed_comp_data["task"], list)
assert len(processed_comp_data["task"]) == 1
def test_task_string_validation():
"""Test that only string and list of strings are valid task values."""
processor = ToBatchProcessor()
# Valid string task - should be converted to list
complementary_data = {"task": "valid_task"}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert processed_comp_data["task"] == ["valid_task"]
# Valid list of strings - should remain unchanged
complementary_data = {"task": ["task1", "task2"]}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert processed_comp_data["task"] == ["task1", "task2"]
def test_task_list_of_strings():
"""Test that lists of strings remain unchanged (already batched)."""
processor = ToBatchProcessor()
# Test various list of strings
test_lists = [
["pick_cube"], # Single string in list
["pick_cube", "place_cube"], # Multiple strings
["task1", "task2", "task3"], # Three strings
[], # Empty list
[""], # List with empty string
["task with spaces", "task_with_underscores"], # Mixed formats
]
for task_list in test_lists:
complementary_data = {"task": task_list}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
# Should remain unchanged since it's already a list
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert processed_comp_data["task"] == task_list
assert isinstance(processed_comp_data["task"], list)
def test_complementary_data_none():
"""Test processor handles None complementary_data gracefully."""
processor = ToBatchProcessor()
transition = create_transition(complementary_data=None)
result = processor(transition)
assert result[TransitionKey.COMPLEMENTARY_DATA] is None
def test_complementary_data_empty():
"""Test processor handles empty complementary_data dict."""
processor = ToBatchProcessor()
complementary_data = {}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
assert result[TransitionKey.COMPLEMENTARY_DATA] == {}
def test_complementary_data_no_task():
"""Test processor handles complementary_data without task field."""
processor = ToBatchProcessor()
complementary_data = {
"episode_id": 123,
"timestamp": 1234567890.0,
"extra_info": "some data",
}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
# Should remain unchanged
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert processed_comp_data == complementary_data
def test_complementary_data_mixed():
"""Test processor with mixed complementary_data containing task and other fields."""
processor = ToBatchProcessor()
complementary_data = {
"task": "stack_blocks",
"episode_id": 456,
"difficulty": "hard",
"metadata": {"scene": "kitchen"},
}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
# Task should be batched
assert processed_comp_data["task"] == ["stack_blocks"]
# Other fields should remain unchanged
assert processed_comp_data["episode_id"] == 456
assert processed_comp_data["difficulty"] == "hard"
assert processed_comp_data["metadata"] == {"scene": "kitchen"}
def test_task_with_observation_and_action():
"""Test task processing together with observation and action processing."""
processor = ToBatchProcessor()
# All components need batching
observation = {
OBS_STATE: torch.randn(5),
OBS_IMAGE: torch.randn(32, 32, 3),
}
action = torch.randn(4)
complementary_data = {"task": "navigate_to_goal"}
transition = create_transition(
observation=observation, action=action, complementary_data=complementary_data
)
result = processor(transition)
# All should be batched
assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 5)
assert result[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 32, 32, 3)
assert result[TransitionKey.ACTION].shape == (1, 4)
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == ["navigate_to_goal"]
def test_task_comprehensive_string_cases():
"""Test task processing with comprehensive string cases and edge cases."""
processor = ToBatchProcessor()
# Test various string formats
string_tasks = [
"pick_and_place",
"navigate",
"open_drawer",
"", # Empty string (valid but edge case)
"task with spaces",
"task_with_underscores",
"task-with-dashes",
"UPPERCASE_TASK",
"MixedCaseTask",
"task123",
"数字任务", # Unicode task
"🤖 robot task", # Emoji in task
"task\nwith\nnewlines", # Special characters
"task\twith\ttabs",
"task with 'quotes'",
'task with "double quotes"',
]
# Test that all string tasks get properly batched
for task in string_tasks:
complementary_data = {"task": task}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert processed_comp_data["task"] == [task]
assert isinstance(processed_comp_data["task"], list)
assert len(processed_comp_data["task"]) == 1
# Test various list of strings (should remain unchanged)
list_tasks = [
["single_task"],
["task1", "task2"],
["pick", "place", "navigate"],
[], # Empty list
[""], # List with empty string
["task with spaces", "task_with_underscores", "UPPERCASE"],
["🤖 task", "数字任务", "normal_task"], # Mixed formats
]
for task_list in list_tasks:
complementary_data = {"task": task_list}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert processed_comp_data["task"] == task_list
assert isinstance(processed_comp_data["task"], list)
assert processed_comp_data["task"] is task_list # Should be same object (in-place)
def test_task_in_place_mutation():
"""Test that the processor mutates complementary_data in place for tasks."""
processor = ToBatchProcessor()
complementary_data = {"task": "sort_objects"}
transition = create_transition(complementary_data=complementary_data)
# Store reference to original transition and complementary_data
original_transition = transition
original_comp_data = complementary_data
# Process
result = processor(transition)
# Should be the same objects (in-place mutation)
assert result is original_transition
assert result[TransitionKey.COMPLEMENTARY_DATA] is original_comp_data
assert original_comp_data["task"] == ["sort_objects"]
def test_task_preserves_other_keys():
"""Test that task processing preserves other keys in complementary_data."""
processor = ToBatchProcessor()
complementary_data = {
"task": "clean_table",
"robot_id": "robot_123",
"motor_id": "motor_456",
"config": {"speed": "slow", "precision": "high"},
"metrics": [1.0, 2.0, 3.0],
}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
# Task should be processed
assert processed_comp_data["task"] == ["clean_table"]
# All other keys should be preserved exactly
assert processed_comp_data["robot_id"] == "robot_123"
assert processed_comp_data["motor_id"] == "motor_456"
assert processed_comp_data["config"] == {"speed": "slow", "precision": "high"}
assert processed_comp_data["metrics"] == [1.0, 2.0, 3.0]