diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index 8dd244c27..b0390a887 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .batch_processor import ToBatchProcessor from .device_processor import DeviceProcessor from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor from .observation_processor import VanillaObservationProcessor @@ -48,6 +49,7 @@ __all__ = [ "RenameProcessor", "RewardProcessor", "RobotProcessor", + "ToBatchProcessor", "TransitionKey", "TruncatedProcessor", "VanillaObservationProcessor", diff --git a/src/lerobot/processor/batch_processor.py b/src/lerobot/processor/batch_processor.py new file mode 100644 index 000000000..77ce0810b --- /dev/null +++ b/src/lerobot/processor/batch_processor.py @@ -0,0 +1,92 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any + +import torch +from torch import Tensor + +from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE +from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey + + +@dataclass +@ProcessorStepRegistry.register(name="to_batch_processor") +class ToBatchProcessor: + """Processor that adds batch dimensions to observations when needed. + + This processor ensures that observations have proper batch dimensions for model processing: + + - For state observations (observation.state, observation.environment_state): + Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional + + - For image observations (observation.image, observation.images.*): + Adds batch dimension (unsqueeze at dim=0) if tensor is 3-dimensional (H, W, C) + + This is useful when processing single transitions that need to be batched for + model inference or when converting from unbatched environment outputs to + batched model inputs. + + The processor only modifies tensors that need batching and leaves already + batched tensors unchanged. + + Example: + ```python + # State: (7,) -> (1, 7) + # Image: (224, 224, 3) -> (1, 224, 224, 3) + # Already batched: (1, 7) -> (1, 7) [unchanged] + ``` + """ + + def __call__(self, transition: EnvTransition) -> EnvTransition: + observation = transition.get(TransitionKey.OBSERVATION) + + if observation is None: + return transition + + # Process state observations - add batch dim if 1D + for state_key in [OBS_STATE, OBS_ENV_STATE]: + if state_key in observation: + state_value = observation[state_key] + if isinstance(state_value, Tensor) and state_value.dim() == 1: + observation[state_key] = state_value.unsqueeze(0) + + # Process single image observation - add batch dim if 3D + if OBS_IMAGE in observation: + image_value = observation[OBS_IMAGE] + if isinstance(image_value, Tensor) and image_value.dim() == 3: + observation[OBS_IMAGE] = image_value.unsqueeze(0) + + # Process multiple image observations - add batch dim if 3D + for key, value in observation.items(): + if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3: + observation[key] = value.unsqueeze(0) + + return transition + + def get_config(self) -> dict[str, Any]: + """Return configuration for serialization.""" + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + """Return state dictionary (empty for this processor).""" + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + """Load state dictionary (no-op for this processor).""" + pass + + def reset(self) -> None: + """Reset processor state (no-op for this processor).""" + pass diff --git a/tests/processor/test_batch_processor.py b/tests/processor/test_batch_processor.py new file mode 100644 index 000000000..e6f199833 --- /dev/null +++ b/tests/processor/test_batch_processor.py @@ -0,0 +1,423 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +from pathlib import Path + +import pytest +import torch + +from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE +from lerobot.processor import ProcessorStepRegistry, RobotProcessor, ToBatchProcessor, TransitionKey + + +def create_transition( + observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None +): + """Helper to create an EnvTransition dictionary.""" + return { + TransitionKey.OBSERVATION: observation, + TransitionKey.ACTION: action, + TransitionKey.REWARD: reward, + TransitionKey.DONE: done, + TransitionKey.TRUNCATED: truncated, + TransitionKey.INFO: info if info is not None else {}, + TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {}, + } + + +def test_state_1d_to_2d(): + """Test that 1D state tensors get unsqueezed to 2D.""" + processor = ToBatchProcessor() + + # Test observation.state + state_1d = torch.randn(7) + observation = {OBS_STATE: state_1d} + transition = create_transition(observation=observation) + + result = processor(transition) + + processed_state = result[TransitionKey.OBSERVATION][OBS_STATE] + assert processed_state.shape == (1, 7) + assert torch.allclose(processed_state.squeeze(0), state_1d) + + +def test_env_state_1d_to_2d(): + """Test that 1D environment state tensors get unsqueezed to 2D.""" + processor = ToBatchProcessor() + + # Test observation.environment_state + env_state_1d = torch.randn(10) + observation = {OBS_ENV_STATE: env_state_1d} + transition = create_transition(observation=observation) + + result = processor(transition) + + processed_env_state = result[TransitionKey.OBSERVATION][OBS_ENV_STATE] + assert processed_env_state.shape == (1, 10) + assert torch.allclose(processed_env_state.squeeze(0), env_state_1d) + + +def test_image_3d_to_4d(): + """Test that 3D image tensors get unsqueezed to 4D.""" + processor = ToBatchProcessor() + + # Test observation.image + image_3d = torch.randn(224, 224, 3) + observation = {OBS_IMAGE: image_3d} + transition = create_transition(observation=observation) + + result = processor(transition) + + processed_image = result[TransitionKey.OBSERVATION][OBS_IMAGE] + assert processed_image.shape == (1, 224, 224, 3) + assert torch.allclose(processed_image.squeeze(0), image_3d) + + +def test_multiple_images_3d_to_4d(): + """Test that 3D image tensors in observation.images.* get unsqueezed to 4D.""" + processor = ToBatchProcessor() + + # Test observation.images.camera1 and observation.images.camera2 + image1_3d = torch.randn(64, 64, 3) + image2_3d = torch.randn(128, 128, 3) + observation = { + f"{OBS_IMAGES}.camera1": image1_3d, + f"{OBS_IMAGES}.camera2": image2_3d, + } + transition = create_transition(observation=observation) + + result = processor(transition) + + processed_obs = result[TransitionKey.OBSERVATION] + processed_image1 = processed_obs[f"{OBS_IMAGES}.camera1"] + processed_image2 = processed_obs[f"{OBS_IMAGES}.camera2"] + + assert processed_image1.shape == (1, 64, 64, 3) + assert processed_image2.shape == (1, 128, 128, 3) + assert torch.allclose(processed_image1.squeeze(0), image1_3d) + assert torch.allclose(processed_image2.squeeze(0), image2_3d) + + +def test_already_batched_tensors_unchanged(): + """Test that already batched tensors remain unchanged.""" + processor = ToBatchProcessor() + + # Create already batched tensors + state_2d = torch.randn(1, 7) + env_state_2d = torch.randn(1, 10) + image_4d = torch.randn(1, 224, 224, 3) + + observation = { + OBS_STATE: state_2d, + OBS_ENV_STATE: env_state_2d, + OBS_IMAGE: image_4d, + } + transition = create_transition(observation=observation) + + result = processor(transition) + + processed_obs = result[TransitionKey.OBSERVATION] + + # Should remain unchanged + assert torch.allclose(processed_obs[OBS_STATE], state_2d) + assert torch.allclose(processed_obs[OBS_ENV_STATE], env_state_2d) + assert torch.allclose(processed_obs[OBS_IMAGE], image_4d) + + +def test_higher_dimensional_tensors_unchanged(): + """Test that tensors with more dimensions than expected remain unchanged.""" + processor = ToBatchProcessor() + + # Create tensors with more dimensions + state_3d = torch.randn(2, 7, 5) # More than 1D + image_5d = torch.randn(2, 3, 224, 224, 1) # More than 3D + + observation = { + OBS_STATE: state_3d, + OBS_IMAGE: image_5d, + } + transition = create_transition(observation=observation) + + result = processor(transition) + + processed_obs = result[TransitionKey.OBSERVATION] + + # Should remain unchanged + assert torch.allclose(processed_obs[OBS_STATE], state_3d) + assert torch.allclose(processed_obs[OBS_IMAGE], image_5d) + + +def test_non_tensor_values_unchanged(): + """Test that non-tensor values in observations remain unchanged.""" + processor = ToBatchProcessor() + + observation = { + OBS_STATE: [1, 2, 3], # List, not tensor + OBS_IMAGE: "not_a_tensor", # String + "custom_key": 42, # Integer + "another_key": {"nested": "dict"}, # Dict + } + transition = create_transition(observation=observation) + + result = processor(transition) + + processed_obs = result[TransitionKey.OBSERVATION] + + # Should remain unchanged + assert processed_obs[OBS_STATE] == [1, 2, 3] + assert processed_obs[OBS_IMAGE] == "not_a_tensor" + assert processed_obs["custom_key"] == 42 + assert processed_obs["another_key"] == {"nested": "dict"} + + +def test_none_observation(): + """Test processor handles None observation gracefully.""" + processor = ToBatchProcessor() + + transition = create_transition(observation=None) + result = processor(transition) + + assert result[TransitionKey.OBSERVATION] is None + + +def test_empty_observation(): + """Test processor handles empty observation dict.""" + processor = ToBatchProcessor() + + observation = {} + transition = create_transition(observation=observation) + + result = processor(transition) + + assert result[TransitionKey.OBSERVATION] == {} + + +def test_mixed_observation(): + """Test processor with mixed observation containing various types and dimensions.""" + processor = ToBatchProcessor() + + state_1d = torch.randn(5) + env_state_2d = torch.randn(1, 8) # Already batched + image_3d = torch.randn(32, 32, 3) + other_tensor = torch.randn(3, 3, 3, 3) # 4D, should be unchanged + + observation = { + OBS_STATE: state_1d, + OBS_ENV_STATE: env_state_2d, + OBS_IMAGE: image_3d, + f"{OBS_IMAGES}.front": torch.randn(64, 64, 3), # 3D, should be batched + f"{OBS_IMAGES}.back": torch.randn(1, 64, 64, 3), # 4D, should be unchanged + "other_tensor": other_tensor, + "non_tensor": "string_value", + } + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check transformations + assert processed_obs[OBS_STATE].shape == (1, 5) + assert processed_obs[OBS_ENV_STATE].shape == (1, 8) # Unchanged + assert processed_obs[OBS_IMAGE].shape == (1, 32, 32, 3) + assert processed_obs[f"{OBS_IMAGES}.front"].shape == (1, 64, 64, 3) + assert processed_obs[f"{OBS_IMAGES}.back"].shape == (1, 64, 64, 3) # Unchanged + assert processed_obs["other_tensor"].shape == (3, 3, 3, 3) # Unchanged + assert processed_obs["non_tensor"] == "string_value" # Unchanged + + +def test_integration_with_robot_processor(): + """Test ToBatchProcessor integration with RobotProcessor.""" + to_batch_processor = ToBatchProcessor() + pipeline = RobotProcessor([to_batch_processor]) + + # Create unbatched observation + observation = { + OBS_STATE: torch.randn(7), + OBS_IMAGE: torch.randn(224, 224, 3), + } + transition = create_transition(observation=observation) + + result = pipeline(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + assert processed_obs[OBS_STATE].shape == (1, 7) + assert processed_obs[OBS_IMAGE].shape == (1, 224, 224, 3) + + +def test_serialization_methods(): + """Test get_config, state_dict, load_state_dict, and reset methods.""" + processor = ToBatchProcessor() + + # Test get_config + config = processor.get_config() + assert isinstance(config, dict) + assert config == {} + + # Test state_dict + state = processor.state_dict() + assert isinstance(state, dict) + assert state == {} + + # Test load_state_dict (should not raise an error) + processor.load_state_dict({}) + + # Test reset (should not raise an error) + processor.reset() + + +def test_save_and_load_pretrained(): + """Test saving and loading ToBatchProcessor with RobotProcessor.""" + processor = ToBatchProcessor() + pipeline = RobotProcessor([processor], name="BatchPipeline") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Save pipeline + pipeline.save_pretrained(tmp_dir) + + # Check config file exists + config_path = Path(tmp_dir) / "batchpipeline.json" + assert config_path.exists() + + # Load pipeline + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + + assert loaded_pipeline.name == "BatchPipeline" + assert len(loaded_pipeline) == 1 + assert isinstance(loaded_pipeline.steps[0], ToBatchProcessor) + + # Test functionality of loaded processor + observation = {OBS_STATE: torch.randn(5)} + transition = create_transition(observation=observation) + + result = loaded_pipeline(transition) + assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 5) + + +def test_registry_functionality(): + """Test that ToBatchProcessor is properly registered.""" + # Check that the processor is registered + registered_class = ProcessorStepRegistry.get("to_batch_processor") + assert registered_class is ToBatchProcessor + + # Check that it's in the list of registered processors + assert "to_batch_processor" in ProcessorStepRegistry.list() + + +def test_registry_based_save_load(): + """Test saving and loading using registry name.""" + processor = ToBatchProcessor() + pipeline = RobotProcessor([processor]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + + # Verify the loaded processor works + observation = { + OBS_STATE: torch.randn(3), + OBS_IMAGE: torch.randn(100, 100, 3), + } + transition = create_transition(observation=observation) + + result = loaded_pipeline(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + assert processed_obs[OBS_STATE].shape == (1, 3) + assert processed_obs[OBS_IMAGE].shape == (1, 100, 100, 3) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_device_compatibility(): + """Test processor works with tensors on different devices.""" + processor = ToBatchProcessor() + + # Create tensors on GPU + state_1d = torch.randn(7, device="cuda") + image_3d = torch.randn(64, 64, 3, device="cuda") + + observation = { + OBS_STATE: state_1d, + OBS_IMAGE: image_3d, + } + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check shapes and that tensors stayed on GPU + assert processed_obs[OBS_STATE].shape == (1, 7) + assert processed_obs[OBS_IMAGE].shape == (1, 64, 64, 3) + assert processed_obs[OBS_STATE].device.type == "cuda" + assert processed_obs[OBS_IMAGE].device.type == "cuda" + + +def test_processor_preserves_other_transition_keys(): + """Test that processor only modifies observation and preserves other transition keys.""" + processor = ToBatchProcessor() + + action = torch.randn(5) + reward = 1.5 + done = True + truncated = False + info = {"step": 10} + comp_data = {"extra": "data"} + + observation = {OBS_STATE: torch.randn(7)} + + transition = create_transition( + observation=observation, + action=action, + reward=reward, + done=done, + truncated=truncated, + info=info, + complementary_data=comp_data, + ) + + result = processor(transition) + + # Check that non-observation keys are preserved + assert torch.allclose(result[TransitionKey.ACTION], action) + assert result[TransitionKey.REWARD] == reward + assert result[TransitionKey.DONE] == done + assert result[TransitionKey.TRUNCATED] == truncated + assert result[TransitionKey.INFO] == info + assert result[TransitionKey.COMPLEMENTARY_DATA] == comp_data + + # Check that observation was processed + assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7) + + +def test_edge_case_zero_dimensional_tensors(): + """Test processor handles 0D tensors (scalars) correctly.""" + processor = ToBatchProcessor() + + # 0D tensors should not be modified + scalar_tensor = torch.tensor(42.0) + + observation = { + OBS_STATE: scalar_tensor, + "scalar_value": scalar_tensor, + } + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # 0D tensors should remain unchanged + assert torch.allclose(processed_obs[OBS_STATE], scalar_tensor) + assert torch.allclose(processed_obs["scalar_value"], scalar_tensor)