mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 21:19:53 +00:00
perf(observation_processor): add CUDA support for image processing
This commit is contained in:
@@ -52,6 +52,8 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):
|
|||||||
- Adds a batch dimension if one is not already present.
|
- Adds a batch dimension if one is not already present.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
device: str | None = None
|
||||||
|
|
||||||
def _process_single_image(self, img: np.ndarray) -> Tensor:
|
def _process_single_image(self, img: np.ndarray) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Processes a single NumPy image array into a channel-first, normalized tensor.
|
Processes a single NumPy image array into a channel-first, normalized tensor.
|
||||||
@@ -71,6 +73,9 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):
|
|||||||
# Convert to tensor
|
# Convert to tensor
|
||||||
img_tensor = torch.from_numpy(img)
|
img_tensor = torch.from_numpy(img)
|
||||||
|
|
||||||
|
if self.device is not None:
|
||||||
|
img_tensor = img_tensor.to(device=self.device)
|
||||||
|
|
||||||
# Add batch dimension if needed
|
# Add batch dimension if needed
|
||||||
if img_tensor.ndim == 3:
|
if img_tensor.ndim == 3:
|
||||||
img_tensor = img_tensor.unsqueeze(0)
|
img_tensor = img_tensor.unsqueeze(0)
|
||||||
|
|||||||
@@ -52,6 +52,34 @@ def test_process_single_image():
|
|||||||
assert processed_img.max() <= 1.0
|
assert processed_img.max() <= 1.0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||||
|
def test_process_single_image_cuda():
|
||||||
|
"""Test processing a single image with CUDA device."""
|
||||||
|
processor = VanillaObservationProcessorStep(device="cuda")
|
||||||
|
|
||||||
|
# Create a mock image (H, W, C) format, uint8
|
||||||
|
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
||||||
|
|
||||||
|
observation = {"pixels": image}
|
||||||
|
transition = create_transition(observation=observation)
|
||||||
|
|
||||||
|
result = processor(transition)
|
||||||
|
processed_obs = result[TransitionKey.OBSERVATION]
|
||||||
|
|
||||||
|
# Check that the image was processed correctly
|
||||||
|
assert OBS_IMAGE in processed_obs
|
||||||
|
processed_img = processed_obs[OBS_IMAGE]
|
||||||
|
|
||||||
|
# Check shape: should be (1, 3, 64, 64) - batch, channels, height, width
|
||||||
|
assert processed_img.shape == (1, 3, 64, 64)
|
||||||
|
|
||||||
|
# Check dtype and range
|
||||||
|
assert processed_img.dtype == torch.float32
|
||||||
|
assert processed_img.device.type == "cuda"
|
||||||
|
assert processed_img.min() >= 0.0
|
||||||
|
assert processed_img.max() <= 1.0
|
||||||
|
|
||||||
|
|
||||||
def test_process_image_dict():
|
def test_process_image_dict():
|
||||||
"""Test processing multiple images in a dictionary."""
|
"""Test processing multiple images in a dictionary."""
|
||||||
processor = VanillaObservationProcessorStep()
|
processor = VanillaObservationProcessorStep()
|
||||||
|
|||||||
Reference in New Issue
Block a user