mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-28 21:57:27 +00:00
try with local rank
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from contextlib import nullcontext
|
||||
@@ -164,8 +165,15 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
||||
"""
|
||||
cfg.validate()
|
||||
|
||||
# Check if this is the main process (process_index == 0 or no accelerator)
|
||||
is_main_process = not accelerator or (hasattr(accelerator, 'process_index') and accelerator.process_index == 0) or accelerator.is_main_process
|
||||
# Check if this is the main process
|
||||
# Use LOCAL_RANK environment variable (set by accelerate) for reliable detection
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if local_rank == -1:
|
||||
# No LOCAL_RANK, check accelerator object or assume main process
|
||||
is_main_process = not accelerator or (hasattr(accelerator, 'is_main_process') and accelerator.is_main_process)
|
||||
else:
|
||||
# LOCAL_RANK is set, main process is rank 0
|
||||
is_main_process = local_rank == 0
|
||||
|
||||
if accelerator and not is_main_process:
|
||||
# Disable WandB and logging on non-main processes.
|
||||
|
||||
@@ -138,8 +138,13 @@ def init_logging(
|
||||
logger.removeHandler(handler)
|
||||
|
||||
# Check if this is a non-main process in multi-GPU training
|
||||
# Use process_index to be more explicit (main process is index 0)
|
||||
is_non_main_process = accelerator is not None and hasattr(accelerator, 'process_index') and accelerator.process_index != 0
|
||||
# Check environment variables set by accelerate (more reliable than checking accelerator object)
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
is_non_main_process = local_rank > 0
|
||||
|
||||
# Fallback to accelerator object check if LOCAL_RANK not set
|
||||
if local_rank == -1 and accelerator is not None:
|
||||
is_non_main_process = hasattr(accelerator, 'process_index') and accelerator.process_index != 0
|
||||
|
||||
# Write logs to console (only for main process in multi-GPU training)
|
||||
if not is_non_main_process:
|
||||
|
||||
Reference in New Issue
Block a user