mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 13:40:00 +00:00
feat(batch_processor): Add task field processing to ToBatchProcessor
- Enhanced ToBatchProcessor to wrap string tasks in a list, adding batch dimensions for compatibility with model inference. - Implemented a new method for processing complementary data, ensuring that task values are correctly handled as either strings or lists of strings. - Added comprehensive unit tests to validate task processing, including edge cases and in-place mutation of complementary data.
This commit is contained in:
committed by
Steven Palma
parent
c4763f61a1
commit
c0013b130b
@@ -37,6 +37,10 @@ class ToBatchProcessor:
|
|||||||
- For actions:
|
- For actions:
|
||||||
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
|
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
|
||||||
|
|
||||||
|
- For task field in complementary data:
|
||||||
|
Wraps string task in a list to add batch dimension
|
||||||
|
(task must be a string or list of strings)
|
||||||
|
|
||||||
This is useful when processing single transitions that need to be batched for
|
This is useful when processing single transitions that need to be batched for
|
||||||
model inference or when converting from unbatched environment outputs to
|
model inference or when converting from unbatched environment outputs to
|
||||||
batched model inputs.
|
batched model inputs.
|
||||||
@@ -49,6 +53,7 @@ class ToBatchProcessor:
|
|||||||
# State: (7,) -> (1, 7)
|
# State: (7,) -> (1, 7)
|
||||||
# Image: (224, 224, 3) -> (1, 224, 224, 3)
|
# Image: (224, 224, 3) -> (1, 224, 224, 3)
|
||||||
# Action: (4,) -> (1, 4)
|
# Action: (4,) -> (1, 4)
|
||||||
|
# Task: "pick_cube" -> ["pick_cube"]
|
||||||
# Already batched: (1, 7) -> (1, 7) [unchanged]
|
# Already batched: (1, 7) -> (1, 7) [unchanged]
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
@@ -56,6 +61,7 @@ class ToBatchProcessor:
|
|||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
self._process_observation(transition)
|
self._process_observation(transition)
|
||||||
self._process_action(transition)
|
self._process_action(transition)
|
||||||
|
self._process_complementary_data(transition)
|
||||||
return transition
|
return transition
|
||||||
|
|
||||||
def _process_observation(self, transition: EnvTransition) -> None:
|
def _process_observation(self, transition: EnvTransition) -> None:
|
||||||
@@ -88,6 +94,18 @@ class ToBatchProcessor:
|
|||||||
if action is not None and isinstance(action, Tensor) and action.dim() == 1:
|
if action is not None and isinstance(action, Tensor) and action.dim() == 1:
|
||||||
transition[TransitionKey.ACTION] = action.unsqueeze(0)
|
transition[TransitionKey.ACTION] = action.unsqueeze(0)
|
||||||
|
|
||||||
|
def _process_complementary_data(self, transition: EnvTransition) -> None:
|
||||||
|
"""Process complementary data in-place, handling task field batching."""
|
||||||
|
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||||
|
if complementary_data is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Process task field - wrap string in list to add batch dimension
|
||||||
|
if "task" in complementary_data:
|
||||||
|
task_value = complementary_data["task"]
|
||||||
|
if isinstance(task_value, str):
|
||||||
|
complementary_data["task"] = [task_value]
|
||||||
|
|
||||||
def get_config(self) -> dict[str, Any]:
|
def get_config(self) -> dict[str, Any]:
|
||||||
"""Return configuration for serialization."""
|
"""Return configuration for serialization."""
|
||||||
return {}
|
return {}
|
||||||
|
|||||||
@@ -640,3 +640,262 @@ def test_empty_action_tensor():
|
|||||||
|
|
||||||
# Should remain unchanged
|
# Should remain unchanged
|
||||||
assert result[TransitionKey.ACTION].shape == (1, 0)
|
assert result[TransitionKey.ACTION].shape == (1, 0)
|
||||||
|
|
||||||
|
|
||||||
|
# Task-specific tests
|
||||||
|
def test_task_string_to_list():
|
||||||
|
"""Test that string tasks get wrapped in lists to add batch dimension."""
|
||||||
|
processor = ToBatchProcessor()
|
||||||
|
|
||||||
|
# Create complementary data with string task
|
||||||
|
complementary_data = {"task": "pick_cube"}
|
||||||
|
transition = create_transition(complementary_data=complementary_data)
|
||||||
|
|
||||||
|
result = processor(transition)
|
||||||
|
|
||||||
|
# String task should be wrapped in list
|
||||||
|
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||||
|
assert processed_comp_data["task"] == ["pick_cube"]
|
||||||
|
assert isinstance(processed_comp_data["task"], list)
|
||||||
|
assert len(processed_comp_data["task"]) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_string_validation():
|
||||||
|
"""Test that only string and list of strings are valid task values."""
|
||||||
|
processor = ToBatchProcessor()
|
||||||
|
|
||||||
|
# Valid string task - should be converted to list
|
||||||
|
complementary_data = {"task": "valid_task"}
|
||||||
|
transition = create_transition(complementary_data=complementary_data)
|
||||||
|
result = processor(transition)
|
||||||
|
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||||
|
assert processed_comp_data["task"] == ["valid_task"]
|
||||||
|
|
||||||
|
# Valid list of strings - should remain unchanged
|
||||||
|
complementary_data = {"task": ["task1", "task2"]}
|
||||||
|
transition = create_transition(complementary_data=complementary_data)
|
||||||
|
result = processor(transition)
|
||||||
|
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||||
|
assert processed_comp_data["task"] == ["task1", "task2"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_list_of_strings():
|
||||||
|
"""Test that lists of strings remain unchanged (already batched)."""
|
||||||
|
processor = ToBatchProcessor()
|
||||||
|
|
||||||
|
# Test various list of strings
|
||||||
|
test_lists = [
|
||||||
|
["pick_cube"], # Single string in list
|
||||||
|
["pick_cube", "place_cube"], # Multiple strings
|
||||||
|
["task1", "task2", "task3"], # Three strings
|
||||||
|
[], # Empty list
|
||||||
|
[""], # List with empty string
|
||||||
|
["task with spaces", "task_with_underscores"], # Mixed formats
|
||||||
|
]
|
||||||
|
|
||||||
|
for task_list in test_lists:
|
||||||
|
complementary_data = {"task": task_list}
|
||||||
|
transition = create_transition(complementary_data=complementary_data)
|
||||||
|
|
||||||
|
result = processor(transition)
|
||||||
|
|
||||||
|
# Should remain unchanged since it's already a list
|
||||||
|
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||||
|
assert processed_comp_data["task"] == task_list
|
||||||
|
assert isinstance(processed_comp_data["task"], list)
|
||||||
|
|
||||||
|
|
||||||
|
def test_complementary_data_none():
|
||||||
|
"""Test processor handles None complementary_data gracefully."""
|
||||||
|
processor = ToBatchProcessor()
|
||||||
|
|
||||||
|
transition = create_transition(complementary_data=None)
|
||||||
|
result = processor(transition)
|
||||||
|
|
||||||
|
assert result[TransitionKey.COMPLEMENTARY_DATA] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_complementary_data_empty():
|
||||||
|
"""Test processor handles empty complementary_data dict."""
|
||||||
|
processor = ToBatchProcessor()
|
||||||
|
|
||||||
|
complementary_data = {}
|
||||||
|
transition = create_transition(complementary_data=complementary_data)
|
||||||
|
|
||||||
|
result = processor(transition)
|
||||||
|
|
||||||
|
assert result[TransitionKey.COMPLEMENTARY_DATA] == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_complementary_data_no_task():
|
||||||
|
"""Test processor handles complementary_data without task field."""
|
||||||
|
processor = ToBatchProcessor()
|
||||||
|
|
||||||
|
complementary_data = {
|
||||||
|
"episode_id": 123,
|
||||||
|
"timestamp": 1234567890.0,
|
||||||
|
"extra_info": "some data",
|
||||||
|
}
|
||||||
|
transition = create_transition(complementary_data=complementary_data)
|
||||||
|
|
||||||
|
result = processor(transition)
|
||||||
|
|
||||||
|
# Should remain unchanged
|
||||||
|
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||||
|
assert processed_comp_data == complementary_data
|
||||||
|
|
||||||
|
|
||||||
|
def test_complementary_data_mixed():
|
||||||
|
"""Test processor with mixed complementary_data containing task and other fields."""
|
||||||
|
processor = ToBatchProcessor()
|
||||||
|
|
||||||
|
complementary_data = {
|
||||||
|
"task": "stack_blocks",
|
||||||
|
"episode_id": 456,
|
||||||
|
"difficulty": "hard",
|
||||||
|
"metadata": {"scene": "kitchen"},
|
||||||
|
}
|
||||||
|
transition = create_transition(complementary_data=complementary_data)
|
||||||
|
|
||||||
|
result = processor(transition)
|
||||||
|
|
||||||
|
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||||
|
|
||||||
|
# Task should be batched
|
||||||
|
assert processed_comp_data["task"] == ["stack_blocks"]
|
||||||
|
|
||||||
|
# Other fields should remain unchanged
|
||||||
|
assert processed_comp_data["episode_id"] == 456
|
||||||
|
assert processed_comp_data["difficulty"] == "hard"
|
||||||
|
assert processed_comp_data["metadata"] == {"scene": "kitchen"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_with_observation_and_action():
|
||||||
|
"""Test task processing together with observation and action processing."""
|
||||||
|
processor = ToBatchProcessor()
|
||||||
|
|
||||||
|
# All components need batching
|
||||||
|
observation = {
|
||||||
|
OBS_STATE: torch.randn(5),
|
||||||
|
OBS_IMAGE: torch.randn(32, 32, 3),
|
||||||
|
}
|
||||||
|
action = torch.randn(4)
|
||||||
|
complementary_data = {"task": "navigate_to_goal"}
|
||||||
|
|
||||||
|
transition = create_transition(
|
||||||
|
observation=observation, action=action, complementary_data=complementary_data
|
||||||
|
)
|
||||||
|
|
||||||
|
result = processor(transition)
|
||||||
|
|
||||||
|
# All should be batched
|
||||||
|
assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 5)
|
||||||
|
assert result[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 32, 32, 3)
|
||||||
|
assert result[TransitionKey.ACTION].shape == (1, 4)
|
||||||
|
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == ["navigate_to_goal"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_comprehensive_string_cases():
|
||||||
|
"""Test task processing with comprehensive string cases and edge cases."""
|
||||||
|
processor = ToBatchProcessor()
|
||||||
|
|
||||||
|
# Test various string formats
|
||||||
|
string_tasks = [
|
||||||
|
"pick_and_place",
|
||||||
|
"navigate",
|
||||||
|
"open_drawer",
|
||||||
|
"", # Empty string (valid but edge case)
|
||||||
|
"task with spaces",
|
||||||
|
"task_with_underscores",
|
||||||
|
"task-with-dashes",
|
||||||
|
"UPPERCASE_TASK",
|
||||||
|
"MixedCaseTask",
|
||||||
|
"task123",
|
||||||
|
"数字任务", # Unicode task
|
||||||
|
"🤖 robot task", # Emoji in task
|
||||||
|
"task\nwith\nnewlines", # Special characters
|
||||||
|
"task\twith\ttabs",
|
||||||
|
"task with 'quotes'",
|
||||||
|
'task with "double quotes"',
|
||||||
|
]
|
||||||
|
|
||||||
|
# Test that all string tasks get properly batched
|
||||||
|
for task in string_tasks:
|
||||||
|
complementary_data = {"task": task}
|
||||||
|
transition = create_transition(complementary_data=complementary_data)
|
||||||
|
|
||||||
|
result = processor(transition)
|
||||||
|
|
||||||
|
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||||
|
assert processed_comp_data["task"] == [task]
|
||||||
|
assert isinstance(processed_comp_data["task"], list)
|
||||||
|
assert len(processed_comp_data["task"]) == 1
|
||||||
|
|
||||||
|
# Test various list of strings (should remain unchanged)
|
||||||
|
list_tasks = [
|
||||||
|
["single_task"],
|
||||||
|
["task1", "task2"],
|
||||||
|
["pick", "place", "navigate"],
|
||||||
|
[], # Empty list
|
||||||
|
[""], # List with empty string
|
||||||
|
["task with spaces", "task_with_underscores", "UPPERCASE"],
|
||||||
|
["🤖 task", "数字任务", "normal_task"], # Mixed formats
|
||||||
|
]
|
||||||
|
|
||||||
|
for task_list in list_tasks:
|
||||||
|
complementary_data = {"task": task_list}
|
||||||
|
transition = create_transition(complementary_data=complementary_data)
|
||||||
|
|
||||||
|
result = processor(transition)
|
||||||
|
|
||||||
|
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||||
|
assert processed_comp_data["task"] == task_list
|
||||||
|
assert isinstance(processed_comp_data["task"], list)
|
||||||
|
assert processed_comp_data["task"] is task_list # Should be same object (in-place)
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_in_place_mutation():
|
||||||
|
"""Test that the processor mutates complementary_data in place for tasks."""
|
||||||
|
processor = ToBatchProcessor()
|
||||||
|
|
||||||
|
complementary_data = {"task": "sort_objects"}
|
||||||
|
transition = create_transition(complementary_data=complementary_data)
|
||||||
|
|
||||||
|
# Store reference to original transition and complementary_data
|
||||||
|
original_transition = transition
|
||||||
|
original_comp_data = complementary_data
|
||||||
|
|
||||||
|
# Process
|
||||||
|
result = processor(transition)
|
||||||
|
|
||||||
|
# Should be the same objects (in-place mutation)
|
||||||
|
assert result is original_transition
|
||||||
|
assert result[TransitionKey.COMPLEMENTARY_DATA] is original_comp_data
|
||||||
|
assert original_comp_data["task"] == ["sort_objects"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_preserves_other_keys():
|
||||||
|
"""Test that task processing preserves other keys in complementary_data."""
|
||||||
|
processor = ToBatchProcessor()
|
||||||
|
|
||||||
|
complementary_data = {
|
||||||
|
"task": "clean_table",
|
||||||
|
"robot_id": "robot_123",
|
||||||
|
"motor_id": "motor_456",
|
||||||
|
"config": {"speed": "slow", "precision": "high"},
|
||||||
|
"metrics": [1.0, 2.0, 3.0],
|
||||||
|
}
|
||||||
|
transition = create_transition(complementary_data=complementary_data)
|
||||||
|
|
||||||
|
result = processor(transition)
|
||||||
|
|
||||||
|
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||||
|
|
||||||
|
# Task should be processed
|
||||||
|
assert processed_comp_data["task"] == ["clean_table"]
|
||||||
|
|
||||||
|
# All other keys should be preserved exactly
|
||||||
|
assert processed_comp_data["robot_id"] == "robot_123"
|
||||||
|
assert processed_comp_data["motor_id"] == "motor_456"
|
||||||
|
assert processed_comp_data["config"] == {"speed": "slow", "precision": "high"}
|
||||||
|
assert processed_comp_data["metrics"] == [1.0, 2.0, 3.0]
|
||||||
|
|||||||
Reference in New Issue
Block a user