perf(observation_processor): add CUDA support for image processing

This commit is contained in:
Khalil Meftah
2026-04-24 13:36:26 +02:00
parent 1ed32210c7
commit 38b88c414c
2 changed files with 33 additions and 0 deletions
@@ -52,6 +52,34 @@ def test_process_single_image():
assert processed_img.max() <= 1.0
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_process_single_image_cuda():
"""Test processing a single image with CUDA device."""
processor = VanillaObservationProcessorStep(device="cuda")
# Create a mock image (H, W, C) format, uint8
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
observation = {"pixels": image}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check that the image was processed correctly
assert OBS_IMAGE in processed_obs
processed_img = processed_obs[OBS_IMAGE]
# Check shape: should be (1, 3, 64, 64) - batch, channels, height, width
assert processed_img.shape == (1, 3, 64, 64)
# Check dtype and range
assert processed_img.dtype == torch.float32
assert processed_img.device.type == "cuda"
assert processed_img.min() >= 0.0
assert processed_img.max() <= 1.0
def test_process_image_dict():
"""Test processing multiple images in a dictionary."""
processor = VanillaObservationProcessorStep()