mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
Revert "perf(observation_processor): add CUDA support for image processing"
This reverts commit 38b88c414c.
This commit is contained in:
@@ -52,8 +52,6 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):
|
||||
- Adds a batch dimension if one is not already present.
|
||||
"""
|
||||
|
||||
device: str | None = None
|
||||
|
||||
def _process_single_image(self, img: np.ndarray) -> Tensor:
|
||||
"""
|
||||
Processes a single NumPy image array into a channel-first, normalized tensor.
|
||||
@@ -73,9 +71,6 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):
|
||||
# Convert to tensor
|
||||
img_tensor = torch.from_numpy(img)
|
||||
|
||||
if self.device is not None:
|
||||
img_tensor = img_tensor.to(device=self.device)
|
||||
|
||||
# Add batch dimension if needed
|
||||
if img_tensor.ndim == 3:
|
||||
img_tensor = img_tensor.unsqueeze(0)
|
||||
|
||||
@@ -52,34 +52,6 @@ def test_process_single_image():
|
||||
assert processed_img.max() <= 1.0
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_process_single_image_cuda():
|
||||
"""Test processing a single image with CUDA device."""
|
||||
processor = VanillaObservationProcessorStep(device="cuda")
|
||||
|
||||
# Create a mock image (H, W, C) format, uint8
|
||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
||||
|
||||
observation = {"pixels": image}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that the image was processed correctly
|
||||
assert OBS_IMAGE in processed_obs
|
||||
processed_img = processed_obs[OBS_IMAGE]
|
||||
|
||||
# Check shape: should be (1, 3, 64, 64) - batch, channels, height, width
|
||||
assert processed_img.shape == (1, 3, 64, 64)
|
||||
|
||||
# Check dtype and range
|
||||
assert processed_img.dtype == torch.float32
|
||||
assert processed_img.device.type == "cuda"
|
||||
assert processed_img.min() >= 0.0
|
||||
assert processed_img.max() <= 1.0
|
||||
|
||||
|
||||
def test_process_image_dict():
|
||||
"""Test processing multiple images in a dictionary."""
|
||||
processor = VanillaObservationProcessorStep()
|
||||
|
||||
Reference in New Issue
Block a user