try with local rank

This commit is contained in:
Pepijn
2025-10-10 15:52:49 +02:00
parent 63fcebd5a7
commit a74affad7c
2 changed files with 17 additions and 4 deletions
+10 -2
View File
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import time
from collections.abc import Callable
from contextlib import nullcontext
@@ -164,8 +165,15 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
"""
cfg.validate()
# Check if this is the main process (process_index == 0 or no accelerator)
is_main_process = not accelerator or (hasattr(accelerator, 'process_index') and accelerator.process_index == 0) or accelerator.is_main_process
# Check if this is the main process
# Use LOCAL_RANK environment variable (set by accelerate) for reliable detection
local_rank = int(os.environ.get("LOCAL_RANK", -1))
if local_rank == -1:
# No LOCAL_RANK, check accelerator object or assume main process
is_main_process = not accelerator or (hasattr(accelerator, 'is_main_process') and accelerator.is_main_process)
else:
# LOCAL_RANK is set: treat local rank 0 as the main process (LOCAL_RANK is per-node, so on multi-node runs each node has one such process)
is_main_process = local_rank == 0
if accelerator and not is_main_process:
# Disable WandB and logging on non-main processes.
+7 -2
View File
@@ -138,8 +138,13 @@ def init_logging(
logger.removeHandler(handler)
# Check if this is a non-main process in multi-GPU training
# Use process_index to be more explicit (main process is index 0)
is_non_main_process = accelerator is not None and hasattr(accelerator, 'process_index') and accelerator.process_index != 0
# Check the LOCAL_RANK environment variable set by accelerate (more reliable than checking the accelerator object)
local_rank = int(os.environ.get("LOCAL_RANK", -1))
is_non_main_process = local_rank > 0
# Fallback to accelerator object check if LOCAL_RANK not set
if local_rank == -1 and accelerator is not None:
is_non_main_process = hasattr(accelerator, 'process_index') and accelerator.process_index != 0
# Write logs to console (only for main process in multi-GPU training)
if not is_non_main_process: