mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
perf(observation_processor): add CUDA support for image processing
This commit is contained in:
@@ -52,6 +52,34 @@ def test_process_single_image():
|
||||
assert processed_img.max() <= 1.0
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_process_single_image_cuda():
|
||||
"""Test processing a single image with CUDA device."""
|
||||
processor = VanillaObservationProcessorStep(device="cuda")
|
||||
|
||||
# Create a mock image (H, W, C) format, uint8
|
||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
||||
|
||||
observation = {"pixels": image}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that the image was processed correctly
|
||||
assert OBS_IMAGE in processed_obs
|
||||
processed_img = processed_obs[OBS_IMAGE]
|
||||
|
||||
# Check shape: should be (1, 3, 64, 64) - batch, channels, height, width
|
||||
assert processed_img.shape == (1, 3, 64, 64)
|
||||
|
||||
# Check dtype and range
|
||||
assert processed_img.dtype == torch.float32
|
||||
assert processed_img.device.type == "cuda"
|
||||
assert processed_img.min() >= 0.0
|
||||
assert processed_img.max() <= 1.0
|
||||
|
||||
|
||||
def test_process_image_dict():
|
||||
"""Test processing multiple images in a dictionary."""
|
||||
processor = VanillaObservationProcessorStep()
|
||||
|
||||
Reference in New Issue
Block a user