mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
fix precommit and fix tests
This commit is contained in:
@@ -98,8 +98,8 @@ def update_policy(
|
||||
if grad_clip_norm > 0:
|
||||
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
|
||||
else:
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
policy.parameters(), float('inf'), error_if_nonfinite=False
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
policy.parameters(), float("inf"), error_if_nonfinite=False
|
||||
)
|
||||
|
||||
# Optimizer step
|
||||
@@ -151,7 +151,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
accelerator = Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs])
|
||||
|
||||
|
||||
init_logging(accelerator=accelerator)
|
||||
|
||||
# Determine if this is the main process (for logging and checkpointing)
|
||||
|
||||
@@ -20,7 +20,6 @@ import select
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from copy import copy, deepcopy
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@@ -143,8 +142,7 @@ def init_logging(
|
||||
logger.handlers.clear()
|
||||
|
||||
# Determine if this is a non-main process in distributed training
|
||||
if accelerator is not None:
|
||||
is_main_process = accelerator.is_main_process
|
||||
is_main_process = accelerator.is_main_process if accelerator is not None else True
|
||||
|
||||
# Console logging (main process only)
|
||||
if is_main_process:
|
||||
|
||||
Reference in New Issue
Block a user