diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index a20761226..82c265c93 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -493,9 +493,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): # Prepare everything with accelerator accelerator.wait_for_everyone() - policy, optimizer, dataloader, lr_scheduler = accelerator.prepare( - policy, optimizer, dataloader, lr_scheduler - ) + if eval_dataloader is not None: + policy, optimizer, dataloader, lr_scheduler, eval_dataloader = accelerator.prepare( + policy, optimizer, dataloader, lr_scheduler, eval_dataloader + ) + else: + policy, optimizer, dataloader, lr_scheduler = accelerator.prepare( + policy, optimizer, dataloader, lr_scheduler + ) # FSDP optimizer state is sharded across ranks, so it can only be loaded once the optimizer and # model are FSDP-wrapped (i.e. after `prepare`). Collective: every rank must participate. @@ -614,6 +619,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): eval_loss_sum += loss.item() n_eval_batches += 1 eval_loss = eval_loss_sum / max(n_eval_batches, 1) + eval_loss = torch.tensor(eval_loss, device=device) + eval_loss = accelerator.reduce(eval_loss, reduction="mean").item() policy.train() if is_main_process: