mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
Fix: Respect policy.device=cpu config in training (#2778)
* fix cpu training in lerobot_train
* Update src/lerobot/scripts/lerobot_train.py

Signed-off-by: Michel Aractingi <michel.aractingi@huggingface.co>
This commit is contained in:
@@ -259,7 +259,14 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
from accelerate.utils import DistributedDataParallelKwargs
|
from accelerate.utils import DistributedDataParallelKwargs
|
||||||
|
|
||||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||||
accelerator = Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs])
|
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
|
||||||
|
# Force the device to be CPU when policy.device is set to CPU.
|
||||||
|
force_cpu = cfg.policy.device == "cpu"
|
||||||
|
accelerator = Accelerator(
|
||||||
|
step_scheduler_with_optimizer=False,
|
||||||
|
kwargs_handlers=[ddp_kwargs],
|
||||||
|
cpu=force_cpu,
|
||||||
|
)
|
||||||
|
|
||||||
init_logging(accelerator=accelerator)
|
init_logging(accelerator=accelerator)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user