mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
fix formatting
This commit is contained in:
@@ -14,7 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from pprint import pformat
|
||||
@@ -143,16 +142,13 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
|
||||
# Create Accelerator if not provided
|
||||
# It will automatically detect if running in distributed mode or single-process mode
|
||||
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting
|
||||
# the lr_scheduler steps based on the num_processes
|
||||
# We set find_unused_parameters=True to handle models with conditional computation paths
|
||||
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
|
||||
# We set find_unused_parameters=True to handle models with conditional computation
|
||||
if accelerator is None:
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
accelerator = Accelerator(
|
||||
step_scheduler_with_optimizer=False,
|
||||
kwargs_handlers=[ddp_kwargs]
|
||||
)
|
||||
accelerator = Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs])
|
||||
|
||||
# Determine if this is the main process (for logging and checkpointing)
|
||||
# When using accelerate, only the main process should log to avoid duplicate outputs
|
||||
@@ -183,7 +179,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
logging.info("Creating dataset")
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
# Wait for main process to finish downloading/caching dataset
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Now all other processes can safely load the dataset
|
||||
@@ -341,7 +336,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
step += 1
|
||||
train_tracker.step()
|
||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
|
||||
is_saving_step = (step % cfg.save_freq == 0 or step == cfg.steps)
|
||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
|
||||
|
||||
if is_log_step:
|
||||
|
||||
@@ -21,7 +21,6 @@ import subprocess
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from accelerate import Accelerator
|
||||
from copy import copy, deepcopy
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@@ -29,6 +28,7 @@ from statistics import mean
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from datasets.utils.logging import disable_progress_bar, enable_progress_bar
|
||||
|
||||
|
||||
@@ -128,6 +128,7 @@ def init_logging(
|
||||
file_level: Logging level for file output
|
||||
accelerator: Optional Accelerator instance (for multi-GPU detection)
|
||||
"""
|
||||
|
||||
def custom_format(record: logging.LogRecord) -> str:
|
||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
fnameline = f"{record.pathname}:{record.lineno}"
|
||||
@@ -159,7 +160,6 @@ def init_logging(
|
||||
logger.addHandler(logging.NullHandler())
|
||||
logger.setLevel(logging.ERROR)
|
||||
|
||||
# File logging (optional, all processes)
|
||||
if log_file is not None:
|
||||
file_handler = logging.FileHandler(log_file)
|
||||
file_handler.setFormatter(formatter)
|
||||
@@ -178,6 +178,7 @@ def format_big_number(num, precision=0):
|
||||
|
||||
return num
|
||||
|
||||
|
||||
def say(text: str, blocking: bool = False):
|
||||
system = platform.system()
|
||||
|
||||
|
||||
@@ -25,9 +25,7 @@ The tests automatically generate accelerate configs and launch training
|
||||
with subprocess to properly test the distributed training environment.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
@@ -70,15 +68,6 @@ def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
|
||||
Returns:
|
||||
subprocess.CompletedProcess result
|
||||
"""
|
||||
# Create accelerate config
|
||||
accelerate_config = {
|
||||
"compute_environment": "LOCAL_MACHINE",
|
||||
"distributed_type": "MULTI_GPU",
|
||||
"mixed_precision": "no",
|
||||
"num_processes": num_processes,
|
||||
"use_cpu": False,
|
||||
"gpu_ids": "all",
|
||||
}
|
||||
|
||||
config_path = Path(temp_dir) / "accelerate_config.yaml"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user