Compare commits

...

1 Commits

Author SHA1 Message Date
Michel Aractingi 74b7cd246e add check for cfg.policy in force_cpu line 2026-01-19 13:54:44 +01:00
+2 -1
View File
@@ -263,7 +263,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
# Force the device to be CPU when policy.device is set to CPU.
force_cpu = cfg.policy.device == "cpu"
# Note (maractin): cfg.policy may be None before validate() fully loads from pretrained_path
force_cpu = cfg.policy is not None and cfg.policy.device == "cpu"
accelerator = Accelerator(
step_scheduler_with_optimizer=False,
kwargs_handlers=[ddp_kwargs],