fix formatting

This commit is contained in:
Pepijn
2025-10-14 17:37:47 +02:00
parent f8a185f753
commit ebf64bd80e
6 changed files with 50 additions and 65 deletions
+5 -10
View File
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import os
import time import time
from contextlib import nullcontext from contextlib import nullcontext
from pprint import pformat from pprint import pformat
@@ -143,16 +142,13 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
# Create Accelerator if not provided # Create Accelerator if not provided
# It will automatically detect if running in distributed mode or single-process mode # It will automatically detect if running in distributed mode or single-process mode
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting # We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
# the lr_scheduler steps based on the num_processes # We set find_unused_parameters=True to handle models with conditional computation
# We set find_unused_parameters=True to handle models with conditional computation paths
if accelerator is None: if accelerator is None:
from accelerate.utils import DistributedDataParallelKwargs from accelerate.utils import DistributedDataParallelKwargs
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator( accelerator = Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs])
step_scheduler_with_optimizer=False,
kwargs_handlers=[ddp_kwargs]
)
# Determine if this is the main process (for logging and checkpointing) # Determine if this is the main process (for logging and checkpointing)
# When using accelerate, only the main process should log to avoid duplicate outputs # When using accelerate, only the main process should log to avoid duplicate outputs
@@ -183,7 +179,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
logging.info("Creating dataset") logging.info("Creating dataset")
dataset = make_dataset(cfg) dataset = make_dataset(cfg)
# Wait for main process to finish downloading/caching dataset
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
# Now all other processes can safely load the dataset # Now all other processes can safely load the dataset
@@ -341,7 +336,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
step += 1 step += 1
train_tracker.step() train_tracker.step()
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
is_saving_step = (step % cfg.save_freq == 0 or step == cfg.steps) is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
if is_log_step: if is_log_step:
+3 -2
View File
@@ -21,7 +21,6 @@ import subprocess
import sys import sys
import time import time
from collections.abc import Callable from collections.abc import Callable
from accelerate import Accelerator
from copy import copy, deepcopy from copy import copy, deepcopy
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
@@ -29,6 +28,7 @@ from statistics import mean
import numpy as np import numpy as np
import torch import torch
from accelerate import Accelerator
from datasets.utils.logging import disable_progress_bar, enable_progress_bar from datasets.utils.logging import disable_progress_bar, enable_progress_bar
@@ -128,6 +128,7 @@ def init_logging(
file_level: Logging level for file output file_level: Logging level for file output
accelerator: Optional Accelerator instance (for multi-GPU detection) accelerator: Optional Accelerator instance (for multi-GPU detection)
""" """
def custom_format(record: logging.LogRecord) -> str: def custom_format(record: logging.LogRecord) -> str:
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S") dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
fnameline = f"{record.pathname}:{record.lineno}" fnameline = f"{record.pathname}:{record.lineno}"
@@ -159,7 +160,6 @@ def init_logging(
logger.addHandler(logging.NullHandler()) logger.addHandler(logging.NullHandler())
logger.setLevel(logging.ERROR) logger.setLevel(logging.ERROR)
# File logging (optional, all processes)
if log_file is not None: if log_file is not None:
file_handler = logging.FileHandler(log_file) file_handler = logging.FileHandler(log_file)
file_handler.setFormatter(formatter) file_handler.setFormatter(formatter)
@@ -178,6 +178,7 @@ def format_big_number(num, precision=0):
return num return num
def say(text: str, blocking: bool = False): def say(text: str, blocking: bool = False):
system = platform.system() system = platform.system()
-11
View File
@@ -25,9 +25,7 @@ The tests automatically generate accelerate configs and launch training
with subprocess to properly test the distributed training environment. with subprocess to properly test the distributed training environment.
""" """
import json
import os import os
import shutil
import subprocess import subprocess
import tempfile import tempfile
from pathlib import Path from pathlib import Path
@@ -70,15 +68,6 @@ def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
Returns: Returns:
subprocess.CompletedProcess result subprocess.CompletedProcess result
""" """
# Create accelerate config
accelerate_config = {
"compute_environment": "LOCAL_MACHINE",
"distributed_type": "MULTI_GPU",
"mixed_precision": "no",
"num_processes": num_processes,
"use_cpu": False,
"gpu_ids": "all",
}
config_path = Path(temp_dir) / "accelerate_config.yaml" config_path = Path(temp_dir) / "accelerate_config.yaml"