fix pre-commit and fix tests

This commit is contained in:
Pepijn
2025-10-16 14:52:32 +02:00
parent ccdf06f0f1
commit 8765b57c0a
2 changed files with 4 additions and 6 deletions
+3 -3
View File
@@ -98,8 +98,8 @@ def update_policy(
if grad_clip_norm > 0:
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
else:
grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(), float('inf'), error_if_nonfinite=False
grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(), float("inf"), error_if_nonfinite=False
)
# Optimizer step
@@ -151,7 +151,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs])
init_logging(accelerator=accelerator)
# Determine if this is the main process (for logging and checkpointing)
+1 -3
View File
@@ -20,7 +20,6 @@ import select
import subprocess
import sys
import time
from collections.abc import Callable
from copy import copy, deepcopy
from datetime import datetime
from pathlib import Path
@@ -143,8 +142,7 @@ def init_logging(
logger.handlers.clear()
# Determine if this is a non-main process in distributed training
if accelerator is not None:
is_main_process = accelerator.is_main_process
is_main_process = accelerator.is_main_process if accelerator is not None else True
# Console logging (main process only)
if is_main_process: