feat(batch_processor): Add task field processing to ToBatchProcessor

- Enhanced ToBatchProcessor to wrap string tasks in a list, adding batch dimensions for compatibility with model inference.
- Implemented a new method for processing complementary data, ensuring that task values are correctly handled as either strings or lists of strings.
- Added comprehensive unit tests to validate task processing, including edge cases and in-place mutation of complementary data.
This commit is contained in:
Adil Zouitine
2025-07-25 19:05:44 +02:00
committed by Steven Palma
parent c4763f61a1
commit c0013b130b
2 changed files with 277 additions and 0 deletions
+18
View File
@@ -37,6 +37,10 @@ class ToBatchProcessor:
- For actions:
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
- For task field in complementary data:
Wraps string task in a list to add batch dimension
(task must be a string or list of strings)
This is useful when processing single transitions that need to be batched for
model inference or when converting from unbatched environment outputs to
batched model inputs.
@@ -49,6 +53,7 @@ class ToBatchProcessor:
# State: (7,) -> (1, 7)
# Image: (224, 224, 3) -> (1, 224, 224, 3)
# Action: (4,) -> (1, 4)
# Task: "pick_cube" -> ["pick_cube"]
# Already batched: (1, 7) -> (1, 7) [unchanged]
```
"""
@@ -56,6 +61,7 @@ class ToBatchProcessor:
def __call__(self, transition: EnvTransition) -> EnvTransition:
self._process_observation(transition)
self._process_action(transition)
self._process_complementary_data(transition)
return transition
def _process_observation(self, transition: EnvTransition) -> None:
@@ -88,6 +94,18 @@ class ToBatchProcessor:
if action is not None and isinstance(action, Tensor) and action.dim() == 1:
transition[TransitionKey.ACTION] = action.unsqueeze(0)
def _process_complementary_data(self, transition: EnvTransition) -> None:
"""Process complementary data in-place, handling task field batching."""
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None:
return
# Process task field - wrap string in list to add batch dimension
if "task" in complementary_data:
task_value = complementary_data["task"]
if isinstance(task_value, str):
complementary_data["task"] = [task_value]
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {}