mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 03:30:10 +00:00
fix formatting
This commit is contained in:
@@ -14,7 +14,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import time
|
import time
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
@@ -143,16 +142,13 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
|
|
||||||
# Create Accelerator if not provided
|
# Create Accelerator if not provided
|
||||||
# It will automatically detect if running in distributed mode or single-process mode
|
# It will automatically detect if running in distributed mode or single-process mode
|
||||||
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting
|
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
|
||||||
# the lr_scheduler steps based on the num_processes
|
# We set find_unused_parameters=True to handle models with conditional computation
|
||||||
# We set find_unused_parameters=True to handle models with conditional computation paths
|
|
||||||
if accelerator is None:
|
if accelerator is None:
|
||||||
from accelerate.utils import DistributedDataParallelKwargs
|
from accelerate.utils import DistributedDataParallelKwargs
|
||||||
|
|
||||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||||
accelerator = Accelerator(
|
accelerator = Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs])
|
||||||
step_scheduler_with_optimizer=False,
|
|
||||||
kwargs_handlers=[ddp_kwargs]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Determine if this is the main process (for logging and checkpointing)
|
# Determine if this is the main process (for logging and checkpointing)
|
||||||
# When using accelerate, only the main process should log to avoid duplicate outputs
|
# When using accelerate, only the main process should log to avoid duplicate outputs
|
||||||
@@ -183,7 +179,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
logging.info("Creating dataset")
|
logging.info("Creating dataset")
|
||||||
dataset = make_dataset(cfg)
|
dataset = make_dataset(cfg)
|
||||||
|
|
||||||
# Wait for main process to finish downloading/caching dataset
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
# Now all other processes can safely load the dataset
|
# Now all other processes can safely load the dataset
|
||||||
@@ -341,7 +336,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
step += 1
|
step += 1
|
||||||
train_tracker.step()
|
train_tracker.step()
|
||||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
|
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
|
||||||
is_saving_step = (step % cfg.save_freq == 0 or step == cfg.steps)
|
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||||
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
|
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
|
||||||
|
|
||||||
if is_log_step:
|
if is_log_step:
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import subprocess
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from accelerate import Accelerator
|
|
||||||
from copy import copy, deepcopy
|
from copy import copy, deepcopy
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -29,6 +28,7 @@ from statistics import mean
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from accelerate import Accelerator
|
||||||
from datasets.utils.logging import disable_progress_bar, enable_progress_bar
|
from datasets.utils.logging import disable_progress_bar, enable_progress_bar
|
||||||
|
|
||||||
|
|
||||||
@@ -128,6 +128,7 @@ def init_logging(
|
|||||||
file_level: Logging level for file output
|
file_level: Logging level for file output
|
||||||
accelerator: Optional Accelerator instance (for multi-GPU detection)
|
accelerator: Optional Accelerator instance (for multi-GPU detection)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def custom_format(record: logging.LogRecord) -> str:
|
def custom_format(record: logging.LogRecord) -> str:
|
||||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
fnameline = f"{record.pathname}:{record.lineno}"
|
fnameline = f"{record.pathname}:{record.lineno}"
|
||||||
@@ -159,7 +160,6 @@ def init_logging(
|
|||||||
logger.addHandler(logging.NullHandler())
|
logger.addHandler(logging.NullHandler())
|
||||||
logger.setLevel(logging.ERROR)
|
logger.setLevel(logging.ERROR)
|
||||||
|
|
||||||
# File logging (optional, all processes)
|
|
||||||
if log_file is not None:
|
if log_file is not None:
|
||||||
file_handler = logging.FileHandler(log_file)
|
file_handler = logging.FileHandler(log_file)
|
||||||
file_handler.setFormatter(formatter)
|
file_handler.setFormatter(formatter)
|
||||||
@@ -178,6 +178,7 @@ def format_big_number(num, precision=0):
|
|||||||
|
|
||||||
return num
|
return num
|
||||||
|
|
||||||
|
|
||||||
def say(text: str, blocking: bool = False):
|
def say(text: str, blocking: bool = False):
|
||||||
system = platform.system()
|
system = platform.system()
|
||||||
|
|
||||||
|
|||||||
@@ -25,9 +25,7 @@ The tests automatically generate accelerate configs and launch training
|
|||||||
with subprocess to properly test the distributed training environment.
|
with subprocess to properly test the distributed training environment.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
|
||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -70,15 +68,6 @@ def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
|
|||||||
Returns:
|
Returns:
|
||||||
subprocess.CompletedProcess result
|
subprocess.CompletedProcess result
|
||||||
"""
|
"""
|
||||||
# Create accelerate config
|
|
||||||
accelerate_config = {
|
|
||||||
"compute_environment": "LOCAL_MACHINE",
|
|
||||||
"distributed_type": "MULTI_GPU",
|
|
||||||
"mixed_precision": "no",
|
|
||||||
"num_processes": num_processes,
|
|
||||||
"use_cpu": False,
|
|
||||||
"gpu_ids": "all",
|
|
||||||
}
|
|
||||||
|
|
||||||
config_path = Path(temp_dir) / "accelerate_config.yaml"
|
config_path = Path(temp_dir) / "accelerate_config.yaml"
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user