From 8765b57c0a3601f180a9da0073ff6320689fc03a Mon Sep 17 00:00:00 2001 From: Pepijn Date: Thu, 16 Oct 2025 14:52:32 +0200 Subject: [PATCH] fix precommit and fix tests --- src/lerobot/scripts/lerobot_train.py | 6 +++--- src/lerobot/utils/utils.py | 4 +--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 524e68bfd..84eb81ad4 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -98,8 +98,8 @@ def update_policy( if grad_clip_norm > 0: grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm) else: - grad_norm = torch.nn.utils.clip_grad_norm_( - policy.parameters(), float('inf'), error_if_nonfinite=False + grad_norm = torch.nn.utils.clip_grad_norm_( + policy.parameters(), float("inf"), error_if_nonfinite=False ) # Optimizer step @@ -151,7 +151,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) accelerator = Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs]) - + init_logging(accelerator=accelerator) # Determine if this is the main process (for logging and checkpointing) diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py index dd685654f..4447a1fcf 100644 --- a/src/lerobot/utils/utils.py +++ b/src/lerobot/utils/utils.py @@ -20,7 +20,6 @@ import select import subprocess import sys import time -from collections.abc import Callable from copy import copy, deepcopy from datetime import datetime from pathlib import Path @@ -143,8 +142,7 @@ def init_logging( logger.handlers.clear() # Determine if this is a non-main process in distributed training - if accelerator is not None: - is_main_process = accelerator.is_main_process + is_main_process = accelerator.is_main_process if accelerator is not None else True # Console logging (main process only) if is_main_process: