mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-03 08:07:03 +00:00
fix(train): drive Accelerate mixed precision from policy.dtype (#3912)
* fix(train): drive Accelerate mixed precision from policy.dtype

  `accelerator.autocast()` was always a no-op because `mixed_precision` was never set, so `--policy.dtype=bfloat16` only cast the model params (via the policy) while autocast-eligible ops still ran in fp32/tf32.

  Map the active policy's `dtype` onto Accelerate's `mixed_precision` (bfloat16 -> bf16, float16 -> fp16, float32 -> no) so autocast is active for bf16/fp16 and stays full precision for float32. Policies without a string `dtype` field fall back to Accelerate's launcher default, so existing behavior is preserved.

* style(train): condense mixed-precision comment to one line
This commit is contained in:
@@ -211,8 +211,12 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
|
||||
# Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training).
|
||||
force_cpu = cfg.trainable_config.device == "cpu"
|
||||
# Drive Accelerate's autocast from policy.dtype (bfloat16/float16 -> bf16/fp16 autocast; float32 -> "no", i.e. full precision; absent/unrecognized -> None, i.e. launcher default).
|
||||
policy_dtype = getattr(cfg.trainable_config, "dtype", None)
|
||||
mixed_precision = {"bfloat16": "bf16", "float16": "fp16", "float32": "no"}.get(policy_dtype)
|
||||
accelerator = Accelerator(
|
||||
step_scheduler_with_optimizer=False,
|
||||
mixed_precision=mixed_precision,
|
||||
kwargs_handlers=[ddp_kwargs],
|
||||
cpu=force_cpu,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user