diff --git a/src/lerobot/processor/batch_processor.py b/src/lerobot/processor/batch_processor.py index 71f037d5c..2c496fe8b 100644 --- a/src/lerobot/processor/batch_processor.py +++ b/src/lerobot/processor/batch_processor.py @@ -37,6 +37,10 @@ class ToBatchProcessor: - For actions: Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional + - For task field in complementary data: + Wraps string task in a list to add batch dimension + (task must be a string or list of strings) + This is useful when processing single transitions that need to be batched for model inference or when converting from unbatched environment outputs to batched model inputs. @@ -49,6 +53,7 @@ class ToBatchProcessor: # State: (7,) -> (1, 7) # Image: (224, 224, 3) -> (1, 224, 224, 3) # Action: (4,) -> (1, 4) + # Task: "pick_cube" -> ["pick_cube"] # Already batched: (1, 7) -> (1, 7) [unchanged] ``` """ @@ -56,6 +61,7 @@ class ToBatchProcessor: def __call__(self, transition: EnvTransition) -> EnvTransition: self._process_observation(transition) self._process_action(transition) + self._process_complementary_data(transition) return transition def _process_observation(self, transition: EnvTransition) -> None: @@ -88,6 +94,18 @@ class ToBatchProcessor: if action is not None and isinstance(action, Tensor) and action.dim() == 1: transition[TransitionKey.ACTION] = action.unsqueeze(0) + def _process_complementary_data(self, transition: EnvTransition) -> None: + """Process complementary data in-place, handling task field batching.""" + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) + if complementary_data is None: + return + + # Process task field - wrap string in list to add batch dimension + if "task" in complementary_data: + task_value = complementary_data["task"] + if isinstance(task_value, str): + complementary_data["task"] = [task_value] + def get_config(self) -> dict[str, Any]: """Return configuration for serialization.""" return {} diff --git a/tests/processor/test_batch_processor.py b/tests/processor/test_batch_processor.py index 2f5bdb962..3d8cb8d49 100644 --- a/tests/processor/test_batch_processor.py +++ b/tests/processor/test_batch_processor.py @@ -640,3 +640,262 @@ def test_empty_action_tensor(): # Should remain unchanged assert result[TransitionKey.ACTION].shape == (1, 0) + + +# Task-specific tests +def test_task_string_to_list(): + """Test that string tasks get wrapped in lists to add batch dimension.""" + processor = ToBatchProcessor() + + # Create complementary data with string task + complementary_data = {"task": "pick_cube"} + transition = create_transition(complementary_data=complementary_data) + + result = processor(transition) + + # String task should be wrapped in list + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["task"] == ["pick_cube"] + assert isinstance(processed_comp_data["task"], list) + assert len(processed_comp_data["task"]) == 1 + + +def test_task_string_validation(): + """Test that only string and list of strings are valid task values.""" + processor = ToBatchProcessor() + + # Valid string task - should be converted to list + complementary_data = {"task": "valid_task"} + transition = create_transition(complementary_data=complementary_data) + result = processor(transition) + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["task"] == ["valid_task"] + + # Valid list of strings - should remain unchanged + complementary_data = {"task": ["task1", "task2"]} + transition = create_transition(complementary_data=complementary_data) + result = processor(transition) + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["task"] == ["task1", "task2"] + + +def test_task_list_of_strings(): + """Test that lists of strings remain unchanged (already batched).""" + processor = ToBatchProcessor() + + # Test various list of strings + test_lists = [ + ["pick_cube"], # Single string in list + ["pick_cube", "place_cube"], # Multiple strings + ["task1", "task2", "task3"], # Three strings + [], # Empty list + [""], # List with empty string + ["task with spaces", "task_with_underscores"], # Mixed formats + ] + + for task_list in test_lists: + complementary_data = {"task": task_list} + transition = create_transition(complementary_data=complementary_data) + + result = processor(transition) + + # Should remain unchanged since it's already a list + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["task"] == task_list + assert isinstance(processed_comp_data["task"], list) + + +def test_complementary_data_none(): + """Test processor handles None complementary_data gracefully.""" + processor = ToBatchProcessor() + + transition = create_transition(complementary_data=None) + result = processor(transition) + + assert result[TransitionKey.COMPLEMENTARY_DATA] is None + + +def test_complementary_data_empty(): + """Test processor handles empty complementary_data dict.""" + processor = ToBatchProcessor() + + complementary_data = {} + transition = create_transition(complementary_data=complementary_data) + + result = processor(transition) + + assert result[TransitionKey.COMPLEMENTARY_DATA] == {} + + +def test_complementary_data_no_task(): + """Test processor handles complementary_data without task field.""" + processor = ToBatchProcessor() + + complementary_data = { + "episode_id": 123, + "timestamp": 1234567890.0, + "extra_info": "some data", + } + transition = create_transition(complementary_data=complementary_data) + + result = processor(transition) + + # Should remain unchanged + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data == complementary_data + + +def test_complementary_data_mixed(): + """Test processor with mixed complementary_data containing task and other fields.""" + processor = ToBatchProcessor() + + complementary_data = { + "task": "stack_blocks", + "episode_id": 456, + "difficulty": "hard", + "metadata": {"scene": "kitchen"}, + } + transition = create_transition(complementary_data=complementary_data) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + + # Task should be batched + assert processed_comp_data["task"] == ["stack_blocks"] + + # Other fields should remain unchanged + assert processed_comp_data["episode_id"] == 456 + assert processed_comp_data["difficulty"] == "hard" + assert processed_comp_data["metadata"] == {"scene": "kitchen"} + + +def test_task_with_observation_and_action(): + """Test task processing together with observation and action processing.""" + processor = ToBatchProcessor() + + # All components need batching + observation = { + OBS_STATE: torch.randn(5), + OBS_IMAGE: torch.randn(32, 32, 3), + } + action = torch.randn(4) + complementary_data = {"task": "navigate_to_goal"} + + transition = create_transition( + observation=observation, action=action, complementary_data=complementary_data + ) + + result = processor(transition) + + # All should be batched + assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 5) + assert result[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 32, 32, 3) + assert result[TransitionKey.ACTION].shape == (1, 4) + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == ["navigate_to_goal"] + + +def test_task_comprehensive_string_cases(): + """Test task processing with comprehensive string cases and edge cases.""" + processor = ToBatchProcessor() + + # Test various string formats + string_tasks = [ + "pick_and_place", + "navigate", + "open_drawer", + "", # Empty string (valid but edge case) + "task with spaces", + "task_with_underscores", + "task-with-dashes", + "UPPERCASE_TASK", + "MixedCaseTask", + "task123", + "数字任务", # Unicode task + "🤖 robot task", # Emoji in task + "task\nwith\nnewlines", # Special characters + "task\twith\ttabs", + "task with 'quotes'", + 'task with "double quotes"', + ] + + # Test that all string tasks get properly batched + for task in string_tasks: + complementary_data = {"task": task} + transition = create_transition(complementary_data=complementary_data) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["task"] == [task] + assert isinstance(processed_comp_data["task"], list) + assert len(processed_comp_data["task"]) == 1 + + # Test various list of strings (should remain unchanged) + list_tasks = [ + ["single_task"], + ["task1", "task2"], + ["pick", "place", "navigate"], + [], # Empty list + [""], # List with empty string + ["task with spaces", "task_with_underscores", "UPPERCASE"], + ["🤖 task", "数字任务", "normal_task"], # Mixed formats + ] + + for task_list in list_tasks: + complementary_data = {"task": task_list} + transition = create_transition(complementary_data=complementary_data) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["task"] == task_list + assert isinstance(processed_comp_data["task"], list) + assert processed_comp_data["task"] is task_list # Should be same object (in-place) + + +def test_task_in_place_mutation(): + """Test that the processor mutates complementary_data in place for tasks.""" + processor = ToBatchProcessor() + + complementary_data = {"task": "sort_objects"} + transition = create_transition(complementary_data=complementary_data) + + # Store reference to original transition and complementary_data + original_transition = transition + original_comp_data = complementary_data + + # Process + result = processor(transition) + + # Should be the same objects (in-place mutation) + assert result is original_transition + assert result[TransitionKey.COMPLEMENTARY_DATA] is original_comp_data + assert original_comp_data["task"] == ["sort_objects"] + + +def test_task_preserves_other_keys(): + """Test that task processing preserves other keys in complementary_data.""" + processor = ToBatchProcessor() + + complementary_data = { + "task": "clean_table", + "robot_id": "robot_123", + "motor_id": "motor_456", + "config": {"speed": "slow", "precision": "high"}, + "metrics": [1.0, 2.0, 3.0], + } + transition = create_transition(complementary_data=complementary_data) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + + # Task should be processed + assert processed_comp_data["task"] == ["clean_table"] + + # All other keys should be preserved exactly + assert processed_comp_data["robot_id"] == "robot_123" + assert processed_comp_data["motor_id"] == "motor_456" + assert processed_comp_data["config"] == {"speed": "slow", "precision": "high"} + assert processed_comp_data["metrics"] == [1.0, 2.0, 3.0]