mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
fix precommit and fix tests
This commit is contained in:
@@ -99,7 +99,7 @@ def update_policy(
|
|||||||
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
|
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
|
||||||
else:
|
else:
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
policy.parameters(), float('inf'), error_if_nonfinite=False
|
policy.parameters(), float("inf"), error_if_nonfinite=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Optimizer step
|
# Optimizer step
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ import select
|
|||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from collections.abc import Callable
|
|
||||||
from copy import copy, deepcopy
|
from copy import copy, deepcopy
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -143,8 +142,7 @@ def init_logging(
|
|||||||
logger.handlers.clear()
|
logger.handlers.clear()
|
||||||
|
|
||||||
# Determine if this is a non-main process in distributed training
|
# Determine if this is a non-main process in distributed training
|
||||||
if accelerator is not None:
|
is_main_process = accelerator.is_main_process if accelerator is not None else True
|
||||||
is_main_process = accelerator.is_main_process
|
|
||||||
|
|
||||||
# Console logging (main process only)
|
# Console logging (main process only)
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
|
|||||||
Reference in New Issue
Block a user