mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
fix precommit and fix tests
This commit is contained in:
@@ -98,8 +98,8 @@ def update_policy(
|
||||
if grad_clip_norm > 0:
|
||||
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
|
||||
else:
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
policy.parameters(), float('inf'), error_if_nonfinite=False
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
policy.parameters(), float("inf"), error_if_nonfinite=False
|
||||
)
|
||||
|
||||
# Optimizer step
|
||||
|
||||
@@ -20,7 +20,6 @@ import select
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from copy import copy, deepcopy
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@@ -143,8 +142,7 @@ def init_logging(
|
||||
logger.handlers.clear()
|
||||
|
||||
# Determine if this is a non-main process in distributed training
|
||||
if accelerator is not None:
|
||||
is_main_process = accelerator.is_main_process
|
||||
is_main_process = accelerator.is_main_process if accelerator is not None else True
|
||||
|
||||
# Console logging (main process only)
|
||||
if is_main_process:
|
||||
|
||||
Reference in New Issue
Block a user