refactor(device_processor): Update device handling and improve type hints

- Changed device attribute type from torch.device to str for better clarity.
- Introduced a private _device attribute to store the actual torch.device instance.
- Updated tests to conditionally check for CUDA availability, ensuring compatibility across different environments.
- Refactored device-related assertions in tests to use a consistent approach for device type verification.
This commit is contained in:
Adil Zouitine
2025-08-06 18:08:15 +02:00
parent 2805ae347c
commit 0535f2a59a
2 changed files with 40 additions and 27 deletions
+4 -2
View File
@@ -34,11 +34,13 @@ class DeviceProcessor:
(int, long, bool, etc.).
"""
device: torch.device = "cpu"
device: str = "cpu"
float_dtype: str | None = None
_device: torch.device | None = None
def __post_init__(self):
self.device = get_safe_torch_device(self.device)
self._device = get_safe_torch_device(self.device)
self.device = self._device.type
self.non_blocking = "cuda" in str(self.device)
# Validate and convert float_dtype string to torch dtype