mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
encorperate feedback pr
This commit is contained in:
@@ -98,7 +98,9 @@ def update_policy(
|
||||
if grad_clip_norm > 0:
|
||||
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
|
||||
else:
|
||||
grad_norm = torch.tensor(0.0, device=accelerator.device)
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
policy.parameters(), float('inf'), error_if_nonfinite=False
|
||||
)
|
||||
|
||||
# Optimizer step
|
||||
with lock if lock is not None else nullcontext():
|
||||
|
||||
@@ -52,15 +52,13 @@ def auto_select_torch_device() -> torch.device:
|
||||
|
||||
|
||||
# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
|
||||
def get_safe_torch_device(
|
||||
try_device: str, log: bool = False, accelerator: Callable | None = None
|
||||
) -> torch.device:
|
||||
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
|
||||
"""Given a string, return a torch.device with checks on whether the device is available."""
|
||||
try_device = str(try_device)
|
||||
match try_device:
|
||||
case "cuda":
|
||||
assert torch.cuda.is_available()
|
||||
device = accelerator.device if accelerator else torch.device("cuda")
|
||||
device = torch.device("cuda")
|
||||
case "mps":
|
||||
assert torch.backends.mps.is_available()
|
||||
device = torch.device("mps")
|
||||
|
||||
Reference in New Issue
Block a user