Incorporate PR feedback

This commit is contained in:
Pepijn
2025-10-15 12:57:41 +02:00
parent 1d86482101
commit 300d614ae5
2 changed files with 5 additions and 5 deletions
+3 -1
View File
@@ -98,7 +98,9 @@ def update_policy(
if grad_clip_norm > 0:
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
else:
grad_norm = torch.tensor(0.0, device=accelerator.device)
grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(), float('inf'), error_if_nonfinite=False
)
# Optimizer step
with lock if lock is not None else nullcontext():
+2 -4
View File
@@ -52,15 +52,13 @@ def auto_select_torch_device() -> torch.device:
# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
def get_safe_torch_device(
try_device: str, log: bool = False, accelerator: Callable | None = None
) -> torch.device:
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
"""Given a string, return a torch.device with checks on whether the device is available."""
try_device = str(try_device)
match try_device:
case "cuda":
assert torch.cuda.is_available()
device = accelerator.device if accelerator else torch.device("cuda")
device = torch.device("cuda")
case "mps":
assert torch.backends.mps.is_available()
device = torch.device("mps")