mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 22:49:48 +00:00
fix: support cuda:0, cuda:1 in string selection (#2256)
* fix * update func 2 * update nightly * fix quality * ignore test_dynamixel
This commit is contained in:
+18
-20
@@ -57,25 +57,23 @@ def auto_select_torch_device() -> torch.device:
|
||||
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
|
||||
"""Given a string, return a torch.device with checks on whether the device is available."""
|
||||
try_device = str(try_device)
|
||||
match try_device:
|
||||
case "cuda":
|
||||
assert torch.cuda.is_available()
|
||||
device = torch.device("cuda")
|
||||
case "mps":
|
||||
assert torch.backends.mps.is_available()
|
||||
device = torch.device("mps")
|
||||
case "xpu":
|
||||
assert torch.xpu.is_available()
|
||||
device = torch.device("xpu")
|
||||
case "cpu":
|
||||
device = torch.device("cpu")
|
||||
if log:
|
||||
logging.warning("Using CPU, this will be slow.")
|
||||
case _:
|
||||
device = torch.device(try_device)
|
||||
if log:
|
||||
logging.warning(f"Using custom {try_device} device.")
|
||||
|
||||
if try_device.startswith("cuda"):
|
||||
assert torch.cuda.is_available()
|
||||
device = torch.device(try_device)
|
||||
elif try_device == "mps":
|
||||
assert torch.backends.mps.is_available()
|
||||
device = torch.device("mps")
|
||||
elif try_device == "xpu":
|
||||
assert torch.xpu.is_available()
|
||||
device = torch.device("xpu")
|
||||
elif try_device == "cpu":
|
||||
device = torch.device("cpu")
|
||||
if log:
|
||||
logging.warning("Using CPU, this will be slow.")
|
||||
else:
|
||||
device = torch.device(try_device)
|
||||
if log:
|
||||
logging.warning(f"Using custom {try_device} device.")
|
||||
return device
|
||||
|
||||
|
||||
@@ -108,7 +106,7 @@ def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
|
||||
|
||||
def is_torch_device_available(try_device: str) -> bool:
|
||||
try_device = str(try_device) # Ensure try_device is a string
|
||||
if try_device == "cuda":
|
||||
if try_device.startswith("cuda"):
|
||||
return torch.cuda.is_available()
|
||||
elif try_device == "mps":
|
||||
return torch.backends.mps.is_available()
|
||||
|
||||
Reference in New Issue
Block a user