mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
refactor(device_processor): Update device handling and improve type hints
- Changed device attribute type from torch.device to str for better clarity. - Introduced a private _device attribute to store the actual torch.device instance. - Updated tests to conditionally check for CUDA availability, ensuring compatibility across different environments. - Refactored device-related assertions in tests to use a consistent approach for device type verification.
This commit is contained in:
@@ -34,11 +34,13 @@ class DeviceProcessor:
|
||||
(int, long, bool, etc.).
|
||||
"""
|
||||
|
||||
device: torch.device = "cpu"
|
||||
device: str = "cpu"
|
||||
float_dtype: str | None = None
|
||||
_device: torch.device | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.device = get_safe_torch_device(self.device)
|
||||
self._device = get_safe_torch_device(self.device)
|
||||
self.device = self._device.type
|
||||
self.non_blocking = "cuda" in str(self.device)
|
||||
|
||||
# Validate and convert float_dtype string to torch dtype
|
||||
|
||||
Reference in New Issue
Block a user