From 4061b3f5b3d11d0147186e68c7ad21a5ad6f487a Mon Sep 17 00:00:00 2001
From: Pepijn
Date: Tue, 14 Oct 2025 14:24:55 +0200
Subject: [PATCH] always use accelerate

---
 docs/source/multi_gpu_training.mdx   |  8 +--
 pyproject.toml                       |  2 +-
 src/lerobot/scripts/lerobot_train.py | 93 ++++++++++------------------
 3 files changed, 34 insertions(+), 69 deletions(-)

diff --git a/docs/source/multi_gpu_training.mdx b/docs/source/multi_gpu_training.mdx
index 41e5b0794..5d8319acb 100644
--- a/docs/source/multi_gpu_training.mdx
+++ b/docs/source/multi_gpu_training.mdx
@@ -10,13 +10,7 @@ First, ensure you have accelerate installed:
 pip install accelerate
 ```
 
-Or install it with the LeRobot accelerate extra:
-
-```bash
-pip install -e ".[accelerate]"
-```
-
-## Training with Multiple GPUs
+## Training with Multiple GPUs
 
 You can launch training in two ways:
 
diff --git a/pyproject.toml b/pyproject.toml
index 6947acdb9..a8cf8dcae 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -62,6 +62,7 @@ dependencies = [
     "datasets>=4.0.0,<4.2.0",
     "diffusers>=0.27.2,<0.36.0",
     "huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
+    "accelerate>=1.10.0,<2.0.0",
 
     # Core dependencies
     "cmake>=3.29.0.1,<4.2.0",
@@ -124,7 +125,6 @@ smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>
 hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.11,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
 
 # Features
-accelerate = ["accelerate>=1.10.0"]
 async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
 
 # Development
diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py
index b89ca9148..19d2a98ad 100644
--- a/src/lerobot/scripts/lerobot_train.py
+++ b/src/lerobot/scripts/lerobot_train.py
@@ -147,7 +147,7 @@ def update_policy(
 
 
 @parser.wrap()
-def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
+def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
     """
     Main function to train a policy.
 
@@ -161,12 +161,20 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
 
     Args:
         cfg: A `TrainPipelineConfig` object containing all training configurations.
+        accelerator: Optional Accelerator instance. If None, one will be created automatically.
     """
     cfg.validate()
 
+    # Create Accelerator if not provided
+    # It will automatically detect if running in distributed mode or single-process mode
+    # We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting
+    # the lr_scheduler steps based on the num_processes
+    if accelerator is None:
+        accelerator = Accelerator(step_scheduler_with_optimizer=False)
+
     # Determine if this is the main process (for logging and checkpointing)
     # When using accelerate, only the main process should log to avoid duplicate outputs
-    is_main_process = accelerator.is_main_process if accelerator else True
+    is_main_process = accelerator.is_main_process
 
     # Only log on main process
     if is_main_process:
@@ -183,8 +191,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
     if cfg.seed is not None:
         set_seed(cfg.seed, accelerator=accelerator)
 
-    # Check device is available
-    device = get_safe_torch_device(cfg.policy.device, log=True, accelerator=accelerator)
+    # Use accelerator's device
+    device = accelerator.device
 
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
@@ -194,8 +202,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
     dataset = make_dataset(cfg)
 
     # Wait for main process to finish downloading/caching dataset
-    if accelerator:
-        accelerator.wait_for_everyone()
+    accelerator.wait_for_everyone()
 
     # Now all other processes can safely load the dataset
     if not is_main_process:
@@ -217,13 +224,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
         ds_meta=dataset.meta,
     )
 
-    # Only move to device if not using accelerator (accelerator.prepare will handle device placement)
-    if not accelerator:
-        policy.to(device)
-
     # Wait for all processes to finish policy creation before continuing
-    if accelerator:
-        accelerator.wait_for_everyone()
+    accelerator.wait_for_everyone()
 
     # Create processors - only provide dataset_stats if not resuming from saved processors
     processor_kwargs = {}
@@ -259,7 +261,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
     if is_main_process:
         logging.info("Creating optimizer and scheduler")
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
-    grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
 
     step = 0  # number of policy updates (forward + backward + optim)
 
@@ -276,10 +277,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
         logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
         logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
         logging.info(f"{dataset.num_episodes=}")
-        if accelerator:
-            num_processes = accelerator.num_processes
-            effective_bs = cfg.batch_size * num_processes
-            logging.info(f"Effective batch size: {cfg.batch_size} x {num_processes} = {effective_bs}")
+        num_processes = accelerator.num_processes
+        effective_bs = cfg.batch_size * num_processes
+        logging.info(f"Effective batch size: {cfg.batch_size} x {num_processes} = {effective_bs}")
         logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
         logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
 
@@ -306,11 +306,12 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
         drop_last=False,
         prefetch_factor=2 if cfg.num_workers > 0 else None,
     )
-    if accelerator:
-        accelerator.wait_for_everyone()
-        policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
-            policy, optimizer, dataloader, lr_scheduler
-        )
+
+    # Prepare everything with accelerator
+    accelerator.wait_for_everyone()
+    policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
+        policy, optimizer, dataloader, lr_scheduler
+    )
     dl_iter = cycle(dataloader)
 
     policy.train()
@@ -324,7 +325,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
     }
 
     # Use effective batch size for proper epoch calculation in distributed training
-    effective_batch_size = cfg.batch_size * (accelerator.num_processes if accelerator else 1)
+    effective_batch_size = cfg.batch_size * accelerator.num_processes
     train_tracker = MetricsTracker(
         effective_batch_size,
         dataset.num_frames,
@@ -349,10 +350,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
             batch,
             optimizer,
            cfg.optimizer.grad_clip_norm,
-            grad_scaler=grad_scaler,
-            lr_scheduler=lr_scheduler,
-            use_amp=cfg.policy.use_amp,
             accelerator=accelerator,
+            lr_scheduler=lr_scheduler,
         )
 
         # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -379,7 +378,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
                 checkpoint_dir=checkpoint_dir,
                 step=step,
                 cfg=cfg,
-                policy=policy if not accelerator else accelerator.unwrap_model(policy),
+                policy=accelerator.unwrap_model(policy),
                 optimizer=optimizer,
                 scheduler=lr_scheduler,
                 preprocessor=preprocessor,
@@ -389,21 +388,15 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
             if wandb_logger:
                 wandb_logger.log_policy(checkpoint_dir)
 
-            if accelerator:
-                accelerator.wait_for_everyone()
+            accelerator.wait_for_everyone()
 
         if cfg.env and is_eval_step:
             step_id = get_step_identifier(step, cfg.steps)
             logging.info(f"Eval policy at step {step}")
-            with (
-                torch.no_grad(),
-                torch.autocast(device_type=device.type)
-                if cfg.policy.use_amp and not accelerator
-                else nullcontext(),
-            ):
+            with torch.no_grad(), accelerator.autocast():
                 eval_info = eval_policy_all(
                     envs=eval_env,  # dict[suite][task_id] -> vec_env
-                    policy=policy if not accelerator else accelerator.unwrap_model(policy),
+                    policy=accelerator.unwrap_model(policy),
                     preprocessor=preprocessor,
                     postprocessor=postprocessor,
                     n_episodes=cfg.eval.n_episodes,
@@ -441,8 +434,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
                     wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
                 wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval")
 
-            if accelerator:
-                accelerator.wait_for_everyone()
+            accelerator.wait_for_everyone()
 
     if eval_env:
         close_envs(eval_env)
@@ -451,7 +443,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
         logging.info("End of training")
 
     if cfg.policy.push_to_hub:
-        unwrapped_policy = policy if not accelerator else accelerator.unwrap_model(policy)
+        unwrapped_policy = accelerator.unwrap_model(policy)
         unwrapped_policy.push_model_to_hub(cfg)
         preprocessor.push_to_hub(cfg.policy.repo_id)
         postprocessor.push_to_hub(cfg.policy.repo_id)
@@ -463,25 +455,4 @@ def main():
 
 
 if __name__ == "__main__":
-    import os
-    distributed_env_vars = {
-        "LOCAL_RANK": os.environ.get("LOCAL_RANK", "NOT SET"),
-        "WORLD_SIZE": os.environ.get("WORLD_SIZE", "NOT SET"),
-        "RANK": os.environ.get("RANK", "NOT SET"),
-        "ACCELERATE_MIXED_PRECISION": os.environ.get("ACCELERATE_MIXED_PRECISION", "NOT SET"),
-    }
-    print(f"[PID {os.getpid()}] Distributed env vars: {distributed_env_vars}")
-
-    if is_launched_with_accelerate():
-        print(f"[PID {os.getpid()}] Detected distributed training mode")
-        import accelerate
-
-        # We set step_scheduler_with_optimizer False to prevent accelerate from
-        # adjusting the lr_scheduler steps based on the num_processes
-        accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False)
-
-        init_logging(accelerator=accelerator)
-        train(accelerator=accelerator)
-    else:
-        init_logging()
-        train()
+    main()
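
For reference, below is a minimal, runnable sketch of the "always use accelerate" flow this patch standardizes on. It is illustrative only: the `Linear` model, `AdamW` optimizer, and random batch are stand-ins for LeRobot's policy, optimizer factory, and dataloader; only the `Accelerator` calls mirror the ones in `lerobot_train.py` above.

```python
# Minimal sketch (not part of the patch) of the "always use accelerate" flow.
# The Linear model and random batch are stand-ins for the real policy/dataloader.
import torch
from accelerate import Accelerator

# step_scheduler_with_optimizer=False stops accelerate from scaling
# lr_scheduler steps by num_processes, matching the patch above.
accelerator = Accelerator(step_scheduler_with_optimizer=False)

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# prepare() handles device placement, and DDP wrapping when distributed
model, optimizer = accelerator.prepare(model, optimizer)

batch = torch.randn(4, 8, device=accelerator.device)
with accelerator.autocast():  # replaces the manual torch.autocast/GradScaler path
    loss = model(batch).pow(2).mean()
accelerator.backward(loss)  # replaces loss.backward(); handles scaling and grad sync
optimizer.step()
optimizer.zero_grad()

if accelerator.is_main_process:  # log once, even with multiple processes
    print(f"num_processes={accelerator.num_processes}, device={accelerator.device}")
```

Because `Accelerator` falls back to a single process when no distributed environment is detected, the same entry point works under both plain `python` and `accelerate launch`, which is what allows the `is_launched_with_accelerate()` branch to be deleted.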