diff --git a/src/lerobot/processor/batch_processor.py b/src/lerobot/processor/batch_processor.py index 77ce0810b..71f037d5c 100644 --- a/src/lerobot/processor/batch_processor.py +++ b/src/lerobot/processor/batch_processor.py @@ -24,9 +24,9 @@ from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, Tra @dataclass @ProcessorStepRegistry.register(name="to_batch_processor") class ToBatchProcessor: - """Processor that adds batch dimensions to observations when needed. + """Processor that adds batch dimensions to observations and actions when needed. - This processor ensures that observations have proper batch dimensions for model processing: + This processor ensures that observations and actions have proper batch dimensions for model processing: - For state observations (observation.state, observation.environment_state): Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional @@ -34,6 +34,9 @@ class ToBatchProcessor: - For image observations (observation.image, observation.images.*): Adds batch dimension (unsqueeze at dim=0) if tensor is 3-dimensional (H, W, C) + - For actions: + Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional + This is useful when processing single transitions that need to be batched for model inference or when converting from unbatched environment outputs to batched model inputs. @@ -45,15 +48,21 @@ class ToBatchProcessor: ```python # State: (7,) -> (1, 7) # Image: (224, 224, 3) -> (1, 224, 224, 3) + # Action: (4,) -> (1, 4) # Already batched: (1, 7) -> (1, 7) [unchanged] ``` """ def __call__(self, transition: EnvTransition) -> EnvTransition: - observation = transition.get(TransitionKey.OBSERVATION) + self._process_observation(transition) + self._process_action(transition) + return transition + def _process_observation(self, transition: EnvTransition) -> None: + """Process observation component in-place, adding batch dimensions where needed.""" + observation = transition.get(TransitionKey.OBSERVATION) if observation is None: - return transition + return # Process state observations - add batch dim if 1D for state_key in [OBS_STATE, OBS_ENV_STATE]: @@ -73,7 +82,11 @@ class ToBatchProcessor: if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3: observation[key] = value.unsqueeze(0) - return transition + def _process_action(self, transition: EnvTransition) -> None: + """Process action component in-place, adding batch dimension if needed.""" + action = transition.get(TransitionKey.ACTION) + if action is not None and isinstance(action, Tensor) and action.dim() == 1: + transition[TransitionKey.ACTION] = action.unsqueeze(0) def get_config(self) -> dict[str, Any]: """Return configuration for serialization.""" diff --git a/src/lerobot/processor/migrate_policy_normalization.py b/src/lerobot/processor/migrate_policy_normalization.py index 9032ba48e..0245b9fb7 100644 --- a/src/lerobot/processor/migrate_policy_normalization.py +++ b/src/lerobot/processor/migrate_policy_normalization.py @@ -46,6 +46,7 @@ from huggingface_hub import hf_hub_download from safetensors.torch import load_file as load_safetensors from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.processor.batch_processor import ToBatchProcessor from lerobot.processor.normalize_processor import NormalizerProcessor, UnnormalizerProcessor from lerobot.processor.pipeline import RobotProcessor @@ -403,14 +404,16 @@ def main(): preprocessor_steps = [ NormalizerProcessor(features=input_features, norm_map=norm_map, stats=stats), NormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats), + ToBatchProcessor(), ] - preprocessor = RobotProcessor(preprocessor_steps, name=f"{policy_type}_preprocessor") + preprocessor = RobotProcessor(preprocessor_steps, name="preprocessor") # Create postprocessor with unnormalizer for outputs only postprocessor_steps = [ UnnormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats), + ToBatchProcessor(), ] - postprocessor = RobotProcessor(postprocessor_steps, name=f"{policy_type}_postprocessor") + postprocessor = RobotProcessor(postprocessor_steps, name="postprocessor") # Determine hub repo ID if pushing to hub if args.push_to_hub: diff --git a/tests/processor/test_batch_processor.py b/tests/processor/test_batch_processor.py index e6f199833..e46eee031 100644 --- a/tests/processor/test_batch_processor.py +++ b/tests/processor/test_batch_processor.py @@ -17,11 +17,14 @@ import tempfile from pathlib import Path +import numpy as np import pytest import torch from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE -from lerobot.processor import ProcessorStepRegistry, RobotProcessor, ToBatchProcessor, TransitionKey +from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor +from lerobot.processor.batch_processor import ToBatchProcessor +from lerobot.processor.pipeline import TransitionKey def create_transition( @@ -34,8 +37,8 @@ def create_transition( TransitionKey.REWARD: reward, TransitionKey.DONE: done, TransitionKey.TRUNCATED: truncated, - TransitionKey.INFO: info if info is not None else {}, - TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {}, + TransitionKey.INFO: info, + TransitionKey.COMPLEMENTARY_DATA: complementary_data, } @@ -421,3 +424,219 @@ def test_edge_case_zero_dimensional_tensors(): # 0D tensors should remain unchanged assert torch.allclose(processed_obs[OBS_STATE], scalar_tensor) assert torch.allclose(processed_obs["scalar_value"], scalar_tensor) + + +# Action-specific tests +def test_action_1d_to_2d(): + """Test that 1D action tensors get batch dimension added.""" + processor = ToBatchProcessor() + + # Create 1D action tensor + action_1d = torch.randn(4) + transition = create_transition(action=action_1d) + + result = processor(transition) + + # Should add batch dimension + assert result[TransitionKey.ACTION].shape == (1, 4) + assert torch.equal(result[TransitionKey.ACTION][0], action_1d) + + +def test_action_already_batched(): + """Test that already batched action tensors remain unchanged.""" + processor = ToBatchProcessor() + + # Test various batch sizes + action_batched_1 = torch.randn(1, 4) + action_batched_5 = torch.randn(5, 4) + + # Single batch + transition = create_transition(action=action_batched_1) + result = processor(transition) + assert torch.equal(result[TransitionKey.ACTION], action_batched_1) + + # Multiple batch + transition = create_transition(action=action_batched_5) + result = processor(transition) + assert torch.equal(result[TransitionKey.ACTION], action_batched_5) + + +def test_action_higher_dimensional(): + """Test that higher dimensional action tensors remain unchanged.""" + processor = ToBatchProcessor() + + # 3D action tensor (e.g., sequence of actions) + action_3d = torch.randn(2, 4, 3) + transition = create_transition(action=action_3d) + result = processor(transition) + assert torch.equal(result[TransitionKey.ACTION], action_3d) + + # 4D action tensor + action_4d = torch.randn(2, 10, 4, 3) + transition = create_transition(action=action_4d) + result = processor(transition) + assert torch.equal(result[TransitionKey.ACTION], action_4d) + + +def test_action_scalar_tensor(): + """Test that scalar (0D) action tensors remain unchanged.""" + processor = ToBatchProcessor() + + action_scalar = torch.tensor(1.5) + transition = create_transition(action=action_scalar) + result = processor(transition) + + # Should remain scalar + assert result[TransitionKey.ACTION].dim() == 0 + assert torch.equal(result[TransitionKey.ACTION], action_scalar) + + +def test_action_non_tensor(): + """Test that non-tensor actions remain unchanged.""" + processor = ToBatchProcessor() + + # List action + action_list = [0.1, 0.2, 0.3, 0.4] + transition = create_transition(action=action_list) + result = processor(transition) + assert result[TransitionKey.ACTION] == action_list + + # Numpy array action (as Python object, not converted) + action_numpy = np.array([1, 2, 3, 4]) + transition = create_transition(action=action_numpy) + result = processor(transition) + assert np.array_equal(result[TransitionKey.ACTION], action_numpy) + + # String action (edge case) + action_string = "forward" + transition = create_transition(action=action_string) + result = processor(transition) + assert result[TransitionKey.ACTION] == action_string + + # Dict action (structured action) + action_dict = {"linear": [0.5, 0.0], "angular": 0.2} + transition = create_transition(action=action_dict) + result = processor(transition) + assert result[TransitionKey.ACTION] == action_dict + + +def test_action_none(): + """Test that None action is handled correctly.""" + processor = ToBatchProcessor() + + transition = create_transition(action=None) + result = processor(transition) + assert result[TransitionKey.ACTION] is None + + +def test_action_with_observation(): + """Test action processing together with observation processing.""" + processor = ToBatchProcessor() + + # Both need batching + observation = { + OBS_STATE: torch.randn(7), + OBS_IMAGE: torch.randn(64, 64, 3), + } + action = torch.randn(4) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + # Both should be batched + assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7) + assert result[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 64, 64, 3) + assert result[TransitionKey.ACTION].shape == (1, 4) + + +def test_action_different_sizes(): + """Test action processing with various action dimensions.""" + processor = ToBatchProcessor() + + # Different action sizes (robot with different DOF) + action_sizes = [1, 2, 4, 7, 10, 20] + + for size in action_sizes: + action = torch.randn(size) + transition = create_transition(action=action) + result = processor(transition) + + assert result[TransitionKey.ACTION].shape == (1, size) + assert torch.equal(result[TransitionKey.ACTION][0], action) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_action_device_compatibility(): + """Test action processing on different devices.""" + processor = ToBatchProcessor() + + # CUDA action + action_cuda = torch.randn(4, device="cuda") + transition = create_transition(action=action_cuda) + result = processor(transition) + + assert result[TransitionKey.ACTION].shape == (1, 4) + assert result[TransitionKey.ACTION].device.type == "cuda" + + # CPU action + action_cpu = torch.randn(4, device="cpu") + transition = create_transition(action=action_cpu) + result = processor(transition) + + assert result[TransitionKey.ACTION].shape == (1, 4) + assert result[TransitionKey.ACTION].device.type == "cpu" + + +def test_action_dtype_preservation(): + """Test that action dtype is preserved during processing.""" + processor = ToBatchProcessor() + + # Different dtypes + dtypes = [torch.float32, torch.float64, torch.int32, torch.int64] + + for dtype in dtypes: + action = torch.randn(4).to(dtype) + transition = create_transition(action=action) + result = processor(transition) + + assert result[TransitionKey.ACTION].dtype == dtype + assert result[TransitionKey.ACTION].shape == (1, 4) + + +def test_action_in_place_mutation(): + """Test that the processor mutates the transition in place for actions.""" + processor = ToBatchProcessor() + + action = torch.randn(4) + transition = create_transition(action=action) + + # Store reference to original transition + original_transition = transition + + # Process + result = processor(transition) + + # Should be the same object (in-place mutation) + assert result is original_transition + assert result[TransitionKey.ACTION].shape == (1, 4) + + +def test_empty_action_tensor(): + """Test handling of empty action tensors.""" + processor = ToBatchProcessor() + + # Empty 1D tensor + action_empty = torch.tensor([]) + transition = create_transition(action=action_empty) + result = processor(transition) + + # Should add batch dimension even to empty tensor + assert result[TransitionKey.ACTION].shape == (1, 0) + + # Empty 2D tensor (already batched) + action_empty_2d = torch.randn(1, 0) + transition = create_transition(action=action_empty_2d) + result = processor(transition) + + # Should remain unchanged + assert result[TransitionKey.ACTION].shape == (1, 0)