From 38b88c414cdc1f53ebaab3211e688fe87522b732 Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Fri, 24 Apr 2026 13:36:26 +0200 Subject: [PATCH] perf(observation_processor): add CUDA support for image processing --- .../processor/observation_processor.py | 5 ++++ tests/processor/test_observation_processor.py | 28 +++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py index 12d1f82a2..03f7d661d 100644 --- a/src/lerobot/processor/observation_processor.py +++ b/src/lerobot/processor/observation_processor.py @@ -52,6 +52,8 @@ class VanillaObservationProcessorStep(ObservationProcessorStep): - Adds a batch dimension if one is not already present. """ + device: str | None = None + def _process_single_image(self, img: np.ndarray) -> Tensor: """ Processes a single NumPy image array into a channel-first, normalized tensor. @@ -71,6 +73,9 @@ class VanillaObservationProcessorStep(ObservationProcessorStep): # Convert to tensor img_tensor = torch.from_numpy(img) + if self.device is not None: + img_tensor = img_tensor.to(device=self.device) + # Add batch dimension if needed if img_tensor.ndim == 3: img_tensor = img_tensor.unsqueeze(0) diff --git a/tests/processor/test_observation_processor.py b/tests/processor/test_observation_processor.py index 923059210..e0208b30a 100644 --- a/tests/processor/test_observation_processor.py +++ b/tests/processor/test_observation_processor.py @@ -52,6 +52,34 @@ def test_process_single_image(): assert processed_img.max() <= 1.0 +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_process_single_image_cuda(): + """Test processing a single image with CUDA device.""" + processor = VanillaObservationProcessorStep(device="cuda") + + # Create a mock image (H, W, C) format, uint8 + image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) + + observation = {"pixels": image} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that the image was processed correctly + assert OBS_IMAGE in processed_obs + processed_img = processed_obs[OBS_IMAGE] + + # Check shape: should be (1, 3, 64, 64) - batch, channels, height, width + assert processed_img.shape == (1, 3, 64, 64) + + # Check dtype and range + assert processed_img.dtype == torch.float32 + assert processed_img.device.type == "cuda" + assert processed_img.min() >= 0.0 + assert processed_img.max() <= 1.0 + + def test_process_image_dict(): """Test processing multiple images in a dictionary.""" processor = VanillaObservationProcessorStep()