diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 44c94a1eb..6e8458523 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -211,8 +211,12 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): # Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting. # Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training). force_cpu = cfg.trainable_config.device == "cpu" + # Drive Accelerate's mixed precision from cfg.trainable_config.dtype: "bfloat16" -> "bf16" and "float16" -> "fp16" enable it, "float32" -> "no" explicitly disables it, and a missing or unrecognized dtype yields None so Accelerate falls back to the launcher/environment default. NOTE(review): "float32" therefore overrides any mixed-precision setting supplied by the launcher; confirm this is intended. + policy_dtype = getattr(cfg.trainable_config, "dtype", None) + mixed_precision = {"bfloat16": "bf16", "float16": "fp16", "float32": "no"}.get(policy_dtype) accelerator = Accelerator( step_scheduler_with_optimizer=False, + mixed_precision=mixed_precision, kwargs_handlers=[ddp_kwargs], cpu=force_cpu, )