mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
feat(batch_processor): Add task field processing to ToBatchProcessor
- Enhanced ToBatchProcessor to wrap string tasks in a list, adding batch dimensions for compatibility with model inference. - Implemented a new method for processing complementary data, ensuring that task values are correctly handled as either strings or lists of strings. - Added comprehensive unit tests to validate task processing, including edge cases and in-place mutation of complementary data.
This commit is contained in:
committed by
Steven Palma
parent
c4763f61a1
commit
c0013b130b
@@ -640,3 +640,262 @@ def test_empty_action_tensor():
|
||||
|
||||
# Should remain unchanged
|
||||
assert result[TransitionKey.ACTION].shape == (1, 0)
|
||||
|
||||
|
||||
# Task-specific tests
|
||||
def test_task_string_to_list():
|
||||
"""Test that string tasks get wrapped in lists to add batch dimension."""
|
||||
processor = ToBatchProcessor()
|
||||
|
||||
# Create complementary data with string task
|
||||
complementary_data = {"task": "pick_cube"}
|
||||
transition = create_transition(complementary_data=complementary_data)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# String task should be wrapped in list
|
||||
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||
assert processed_comp_data["task"] == ["pick_cube"]
|
||||
assert isinstance(processed_comp_data["task"], list)
|
||||
assert len(processed_comp_data["task"]) == 1
|
||||
|
||||
|
||||
def test_task_string_validation():
|
||||
"""Test that only string and list of strings are valid task values."""
|
||||
processor = ToBatchProcessor()
|
||||
|
||||
# Valid string task - should be converted to list
|
||||
complementary_data = {"task": "valid_task"}
|
||||
transition = create_transition(complementary_data=complementary_data)
|
||||
result = processor(transition)
|
||||
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||
assert processed_comp_data["task"] == ["valid_task"]
|
||||
|
||||
# Valid list of strings - should remain unchanged
|
||||
complementary_data = {"task": ["task1", "task2"]}
|
||||
transition = create_transition(complementary_data=complementary_data)
|
||||
result = processor(transition)
|
||||
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||
assert processed_comp_data["task"] == ["task1", "task2"]
|
||||
|
||||
|
||||
def test_task_list_of_strings():
|
||||
"""Test that lists of strings remain unchanged (already batched)."""
|
||||
processor = ToBatchProcessor()
|
||||
|
||||
# Test various list of strings
|
||||
test_lists = [
|
||||
["pick_cube"], # Single string in list
|
||||
["pick_cube", "place_cube"], # Multiple strings
|
||||
["task1", "task2", "task3"], # Three strings
|
||||
[], # Empty list
|
||||
[""], # List with empty string
|
||||
["task with spaces", "task_with_underscores"], # Mixed formats
|
||||
]
|
||||
|
||||
for task_list in test_lists:
|
||||
complementary_data = {"task": task_list}
|
||||
transition = create_transition(complementary_data=complementary_data)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# Should remain unchanged since it's already a list
|
||||
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||
assert processed_comp_data["task"] == task_list
|
||||
assert isinstance(processed_comp_data["task"], list)
|
||||
|
||||
|
||||
def test_complementary_data_none():
|
||||
"""Test processor handles None complementary_data gracefully."""
|
||||
processor = ToBatchProcessor()
|
||||
|
||||
transition = create_transition(complementary_data=None)
|
||||
result = processor(transition)
|
||||
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA] is None
|
||||
|
||||
|
||||
def test_complementary_data_empty():
|
||||
"""Test processor handles empty complementary_data dict."""
|
||||
processor = ToBatchProcessor()
|
||||
|
||||
complementary_data = {}
|
||||
transition = create_transition(complementary_data=complementary_data)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA] == {}
|
||||
|
||||
|
||||
def test_complementary_data_no_task():
|
||||
"""Test processor handles complementary_data without task field."""
|
||||
processor = ToBatchProcessor()
|
||||
|
||||
complementary_data = {
|
||||
"episode_id": 123,
|
||||
"timestamp": 1234567890.0,
|
||||
"extra_info": "some data",
|
||||
}
|
||||
transition = create_transition(complementary_data=complementary_data)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# Should remain unchanged
|
||||
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||
assert processed_comp_data == complementary_data
|
||||
|
||||
|
||||
def test_complementary_data_mixed():
|
||||
"""Test processor with mixed complementary_data containing task and other fields."""
|
||||
processor = ToBatchProcessor()
|
||||
|
||||
complementary_data = {
|
||||
"task": "stack_blocks",
|
||||
"episode_id": 456,
|
||||
"difficulty": "hard",
|
||||
"metadata": {"scene": "kitchen"},
|
||||
}
|
||||
transition = create_transition(complementary_data=complementary_data)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||
|
||||
# Task should be batched
|
||||
assert processed_comp_data["task"] == ["stack_blocks"]
|
||||
|
||||
# Other fields should remain unchanged
|
||||
assert processed_comp_data["episode_id"] == 456
|
||||
assert processed_comp_data["difficulty"] == "hard"
|
||||
assert processed_comp_data["metadata"] == {"scene": "kitchen"}
|
||||
|
||||
|
||||
def test_task_with_observation_and_action():
|
||||
"""Test task processing together with observation and action processing."""
|
||||
processor = ToBatchProcessor()
|
||||
|
||||
# All components need batching
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(5),
|
||||
OBS_IMAGE: torch.randn(32, 32, 3),
|
||||
}
|
||||
action = torch.randn(4)
|
||||
complementary_data = {"task": "navigate_to_goal"}
|
||||
|
||||
transition = create_transition(
|
||||
observation=observation, action=action, complementary_data=complementary_data
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# All should be batched
|
||||
assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 5)
|
||||
assert result[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 32, 32, 3)
|
||||
assert result[TransitionKey.ACTION].shape == (1, 4)
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == ["navigate_to_goal"]
|
||||
|
||||
|
||||
def test_task_comprehensive_string_cases():
|
||||
"""Test task processing with comprehensive string cases and edge cases."""
|
||||
processor = ToBatchProcessor()
|
||||
|
||||
# Test various string formats
|
||||
string_tasks = [
|
||||
"pick_and_place",
|
||||
"navigate",
|
||||
"open_drawer",
|
||||
"", # Empty string (valid but edge case)
|
||||
"task with spaces",
|
||||
"task_with_underscores",
|
||||
"task-with-dashes",
|
||||
"UPPERCASE_TASK",
|
||||
"MixedCaseTask",
|
||||
"task123",
|
||||
"数字任务", # Unicode task
|
||||
"🤖 robot task", # Emoji in task
|
||||
"task\nwith\nnewlines", # Special characters
|
||||
"task\twith\ttabs",
|
||||
"task with 'quotes'",
|
||||
'task with "double quotes"',
|
||||
]
|
||||
|
||||
# Test that all string tasks get properly batched
|
||||
for task in string_tasks:
|
||||
complementary_data = {"task": task}
|
||||
transition = create_transition(complementary_data=complementary_data)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||
assert processed_comp_data["task"] == [task]
|
||||
assert isinstance(processed_comp_data["task"], list)
|
||||
assert len(processed_comp_data["task"]) == 1
|
||||
|
||||
# Test various list of strings (should remain unchanged)
|
||||
list_tasks = [
|
||||
["single_task"],
|
||||
["task1", "task2"],
|
||||
["pick", "place", "navigate"],
|
||||
[], # Empty list
|
||||
[""], # List with empty string
|
||||
["task with spaces", "task_with_underscores", "UPPERCASE"],
|
||||
["🤖 task", "数字任务", "normal_task"], # Mixed formats
|
||||
]
|
||||
|
||||
for task_list in list_tasks:
|
||||
complementary_data = {"task": task_list}
|
||||
transition = create_transition(complementary_data=complementary_data)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||
assert processed_comp_data["task"] == task_list
|
||||
assert isinstance(processed_comp_data["task"], list)
|
||||
assert processed_comp_data["task"] is task_list # Should be same object (in-place)
|
||||
|
||||
|
||||
def test_task_in_place_mutation():
|
||||
"""Test that the processor mutates complementary_data in place for tasks."""
|
||||
processor = ToBatchProcessor()
|
||||
|
||||
complementary_data = {"task": "sort_objects"}
|
||||
transition = create_transition(complementary_data=complementary_data)
|
||||
|
||||
# Store reference to original transition and complementary_data
|
||||
original_transition = transition
|
||||
original_comp_data = complementary_data
|
||||
|
||||
# Process
|
||||
result = processor(transition)
|
||||
|
||||
# Should be the same objects (in-place mutation)
|
||||
assert result is original_transition
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA] is original_comp_data
|
||||
assert original_comp_data["task"] == ["sort_objects"]
|
||||
|
||||
|
||||
def test_task_preserves_other_keys():
|
||||
"""Test that task processing preserves other keys in complementary_data."""
|
||||
processor = ToBatchProcessor()
|
||||
|
||||
complementary_data = {
|
||||
"task": "clean_table",
|
||||
"robot_id": "robot_123",
|
||||
"motor_id": "motor_456",
|
||||
"config": {"speed": "slow", "precision": "high"},
|
||||
"metrics": [1.0, 2.0, 3.0],
|
||||
}
|
||||
transition = create_transition(complementary_data=complementary_data)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||
|
||||
# Task should be processed
|
||||
assert processed_comp_data["task"] == ["clean_table"]
|
||||
|
||||
# All other keys should be preserved exactly
|
||||
assert processed_comp_data["robot_id"] == "robot_123"
|
||||
assert processed_comp_data["motor_id"] == "motor_456"
|
||||
assert processed_comp_data["config"] == {"speed": "slow", "precision": "high"}
|
||||
assert processed_comp_data["metrics"] == [1.0, 2.0, 3.0]
|
||||
|
||||
Reference in New Issue
Block a user