fix(train): drive Accelerate mixed precision from policy.dtype (#3912)

* fix(train): drive Accelerate mixed precision from policy.dtype

`accelerator.autocast()` was always a no-op because `mixed_precision`
was never set, so `--policy.dtype=bfloat16` only cast the model params
(via the policy) while autocast-eligible ops still ran in fp32/tf32.

Map the active policy's `dtype` onto Accelerate's `mixed_precision`
(bfloat16 -> bf16, float16 -> fp16, float32 -> no) so autocast is active
for bf16/fp16 and stays full precision for float32. Policies without a
string `dtype` field fall back to Accelerate's launcher default, so
existing behavior is preserved.

* style(train): condense mixed-precision comment to one line
This commit is contained in:
Pepijn
2026-07-02 19:15:19 +02:00
committed by GitHub
parent 7ae12124b0
commit 07285677a3
+4
View File
@@ -211,8 +211,12 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
# Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training).
force_cpu = cfg.trainable_config.device == "cpu"
# Drive Accelerate's autocast from policy.dtype (bf16/fp16 activate it; float32/absent -> launcher default).
policy_dtype = getattr(cfg.trainable_config, "dtype", None)
mixed_precision = {"bfloat16": "bf16", "float16": "fp16", "float32": "no"}.get(policy_dtype)
accelerator = Accelerator(
step_scheduler_with_optimizer=False,
mixed_precision=mixed_precision,
kwargs_handlers=[ddp_kwargs],
cpu=force_cpu,
)