mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 05:29:55 +00:00
Revert "perf(observation_processor): add CUDA support for image processing"
This reverts commit 38b88c414c.
This commit is contained in:
@@ -52,8 +52,6 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):
|
|||||||
- Adds a batch dimension if one is not already present.
|
- Adds a batch dimension if one is not already present.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
device: str | None = None
|
|
||||||
|
|
||||||
def _process_single_image(self, img: np.ndarray) -> Tensor:
|
def _process_single_image(self, img: np.ndarray) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Processes a single NumPy image array into a channel-first, normalized tensor.
|
Processes a single NumPy image array into a channel-first, normalized tensor.
|
||||||
@@ -73,9 +71,6 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):
|
|||||||
# Convert to tensor
|
# Convert to tensor
|
||||||
img_tensor = torch.from_numpy(img)
|
img_tensor = torch.from_numpy(img)
|
||||||
|
|
||||||
if self.device is not None:
|
|
||||||
img_tensor = img_tensor.to(device=self.device)
|
|
||||||
|
|
||||||
# Add batch dimension if needed
|
# Add batch dimension if needed
|
||||||
if img_tensor.ndim == 3:
|
if img_tensor.ndim == 3:
|
||||||
img_tensor = img_tensor.unsqueeze(0)
|
img_tensor = img_tensor.unsqueeze(0)
|
||||||
|
|||||||
@@ -52,34 +52,6 @@ def test_process_single_image():
|
|||||||
assert processed_img.max() <= 1.0
|
assert processed_img.max() <= 1.0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
|
||||||
def test_process_single_image_cuda():
|
|
||||||
"""Test processing a single image with CUDA device."""
|
|
||||||
processor = VanillaObservationProcessorStep(device="cuda")
|
|
||||||
|
|
||||||
# Create a mock image (H, W, C) format, uint8
|
|
||||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
|
||||||
|
|
||||||
observation = {"pixels": image}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
result = processor(transition)
|
|
||||||
processed_obs = result[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
# Check that the image was processed correctly
|
|
||||||
assert OBS_IMAGE in processed_obs
|
|
||||||
processed_img = processed_obs[OBS_IMAGE]
|
|
||||||
|
|
||||||
# Check shape: should be (1, 3, 64, 64) - batch, channels, height, width
|
|
||||||
assert processed_img.shape == (1, 3, 64, 64)
|
|
||||||
|
|
||||||
# Check dtype and range
|
|
||||||
assert processed_img.dtype == torch.float32
|
|
||||||
assert processed_img.device.type == "cuda"
|
|
||||||
assert processed_img.min() >= 0.0
|
|
||||||
assert processed_img.max() <= 1.0
|
|
||||||
|
|
||||||
|
|
||||||
def test_process_image_dict():
|
def test_process_image_dict():
|
||||||
"""Test processing multiple images in a dictionary."""
|
"""Test processing multiple images in a dictionary."""
|
||||||
processor = VanillaObservationProcessorStep()
|
processor = VanillaObservationProcessorStep()
|
||||||
|
|||||||
Reference in New Issue
Block a user