From 300d614ae554ece73584d599f8d476f4c6205ae6 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Wed, 15 Oct 2025 12:57:41 +0200 Subject: [PATCH] encorperate feedback pr --- src/lerobot/scripts/lerobot_train.py | 4 +++- src/lerobot/utils/utils.py | 6 ++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 00054702e..c738b89f8 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -98,7 +98,9 @@ def update_policy( if grad_clip_norm > 0: grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm) else: - grad_norm = torch.tensor(0.0, device=accelerator.device) + grad_norm = torch.nn.utils.clip_grad_norm_( + policy.parameters(), float('inf'), error_if_nonfinite=False + ) # Optimizer step with lock if lock is not None else nullcontext(): diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py index fa7124ae6..1614a40cb 100644 --- a/src/lerobot/utils/utils.py +++ b/src/lerobot/utils/utils.py @@ -52,15 +52,13 @@ def auto_select_torch_device() -> torch.device: # TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level -def get_safe_torch_device( - try_device: str, log: bool = False, accelerator: Callable | None = None -) -> torch.device: +def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device: """Given a string, return a torch.device with checks on whether the device is available.""" try_device = str(try_device) match try_device: case "cuda": assert torch.cuda.is_available() - device = accelerator.device if accelerator else torch.device("cuda") + device = torch.device("cuda") case "mps": assert torch.backends.mps.is_available() device = torch.device("mps")