mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
fix precommit and fix tests
This commit is contained in:
@@ -98,8 +98,8 @@ def update_policy(
|
|||||||
if grad_clip_norm > 0:
|
if grad_clip_norm > 0:
|
||||||
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
|
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
|
||||||
else:
|
else:
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
policy.parameters(), float('inf'), error_if_nonfinite=False
|
policy.parameters(), float("inf"), error_if_nonfinite=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Optimizer step
|
# Optimizer step
|
||||||
@@ -151,7 +151,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
|
|
||||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||||
accelerator = Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs])
|
accelerator = Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs])
|
||||||
|
|
||||||
init_logging(accelerator=accelerator)
|
init_logging(accelerator=accelerator)
|
||||||
|
|
||||||
# Determine if this is the main process (for logging and checkpointing)
|
# Determine if this is the main process (for logging and checkpointing)
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ import select
|
|||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from collections.abc import Callable
|
|
||||||
from copy import copy, deepcopy
|
from copy import copy, deepcopy
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -143,8 +142,7 @@ def init_logging(
|
|||||||
logger.handlers.clear()
|
logger.handlers.clear()
|
||||||
|
|
||||||
# Determine if this is a non-main process in distributed training
|
# Determine if this is a non-main process in distributed training
|
||||||
if accelerator is not None:
|
is_main_process = accelerator.is_main_process if accelerator is not None else True
|
||||||
is_main_process = accelerator.is_main_process
|
|
||||||
|
|
||||||
# Console logging (main process only)
|
# Console logging (main process only)
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
|
|||||||
Reference in New Issue
Block a user