mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
encorperate feedback pr
This commit is contained in:
@@ -98,7 +98,9 @@ def update_policy(
|
|||||||
if grad_clip_norm > 0:
|
if grad_clip_norm > 0:
|
||||||
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
|
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
|
||||||
else:
|
else:
|
||||||
grad_norm = torch.tensor(0.0, device=accelerator.device)
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
|
policy.parameters(), float('inf'), error_if_nonfinite=False
|
||||||
|
)
|
||||||
|
|
||||||
# Optimizer step
|
# Optimizer step
|
||||||
with lock if lock is not None else nullcontext():
|
with lock if lock is not None else nullcontext():
|
||||||
|
|||||||
@@ -52,15 +52,13 @@ def auto_select_torch_device() -> torch.device:
|
|||||||
|
|
||||||
|
|
||||||
# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
|
# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
|
||||||
def get_safe_torch_device(
|
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
|
||||||
try_device: str, log: bool = False, accelerator: Callable | None = None
|
|
||||||
) -> torch.device:
|
|
||||||
"""Given a string, return a torch.device with checks on whether the device is available."""
|
"""Given a string, return a torch.device with checks on whether the device is available."""
|
||||||
try_device = str(try_device)
|
try_device = str(try_device)
|
||||||
match try_device:
|
match try_device:
|
||||||
case "cuda":
|
case "cuda":
|
||||||
assert torch.cuda.is_available()
|
assert torch.cuda.is_available()
|
||||||
device = accelerator.device if accelerator else torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
case "mps":
|
case "mps":
|
||||||
assert torch.backends.mps.is_available()
|
assert torch.backends.mps.is_available()
|
||||||
device = torch.device("mps")
|
device = torch.device("mps")
|
||||||
|
|||||||
Reference in New Issue
Block a user