always use accelerate

Pepijn
2025-10-14 14:24:55 +02:00
parent d2687e9486
commit 4061b3f5b3
3 changed files with 34 additions and 69 deletions
+1 -7
@@ -10,13 +10,7 @@ First, ensure you have accelerate installed:
pip install accelerate
```
-Or install it with the LeRobot accelerate extra:
-```bash
-pip install -e ".[accelerate]"
-```
## Training with Multiple GPUs
-## Training with Multiple GPUss
You can launch training in two ways:
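For context, the two launch modes look roughly as follows; the module path and flags here are assumptions for illustration, not taken from this diff:

```bash
# Single process: train() now creates its own Accelerator (world size 1).
python -m lerobot.scripts.train --config_path=path/to/config.yaml

# Multiple GPUs: accelerate spawns one process per GPU and sets the distributed env vars.
accelerate launch --num_processes=2 -m lerobot.scripts.train --config_path=path/to/config.yaml
```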
+1 -1
@@ -62,6 +62,7 @@ dependencies = [
"datasets>=4.0.0,<4.2.0",
"diffusers>=0.27.2,<0.36.0",
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
"accelerate>=1.10.0,<2.0.0",
# Core dependencies
"cmake>=3.29.0.1,<4.2.0",
@@ -124,7 +125,6 @@ smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.11,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features
accelerate = ["accelerate>=1.10.0"]
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
# Development
+32 -61
@@ -147,7 +147,7 @@ def update_policy(
@parser.wrap()
-def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
+def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
"""
Main function to train a policy.
@@ -161,12 +161,20 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
Args:
cfg: A `TrainPipelineConfig` object containing all training configurations.
+accelerator: Optional Accelerator instance. If None, one will be created automatically.
"""
cfg.validate()
+# Create Accelerator if not provided
+# It will automatically detect if running in distributed mode or single-process mode
+# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting
+# the lr_scheduler steps based on the num_processes
+if accelerator is None:
+    accelerator = Accelerator(step_scheduler_with_optimizer=False)
# Determine if this is the main process (for logging and checkpointing)
# When using accelerate, only the main process should log to avoid duplicate outputs
-is_main_process = accelerator.is_main_process if accelerator else True
+is_main_process = accelerator.is_main_process
# Only log on main process
if is_main_process:
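For reference, the creation pattern above in isolation (only the accelerate API is assumed; the print is illustrative):

```python
from accelerate import Accelerator

# step_scheduler_with_optimizer=False keeps accelerate from stepping the LR scheduler
# once per optimizer step per process, which would advance the schedule num_processes
# times too fast; the training loop steps the scheduler itself instead.
accelerator = Accelerator(step_scheduler_with_optimizer=False)

# Safe under both `accelerate launch` and plain `python`: in the single-process case
# num_processes == 1 and is_main_process is True.
if accelerator.is_main_process:
    print(f"running on {accelerator.num_processes} process(es)")
```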
@@ -183,8 +191,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
if cfg.seed is not None:
set_seed(cfg.seed, accelerator=accelerator)
-# Check device is available
-device = get_safe_torch_device(cfg.policy.device, log=True, accelerator=accelerator)
+# Use accelerator's device
+device = accelerator.device
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -194,8 +202,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
dataset = make_dataset(cfg)
# Wait for main process to finish downloading/caching dataset
-if accelerator:
-    accelerator.wait_for_everyone()
+accelerator.wait_for_everyone()
# Now all other processes can safely load the dataset
if not is_main_process:
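Sketched in isolation, the barrier idiom this hunk relies on looks like this (load_cached_dataset is a hypothetical stand-in for make_dataset):

```python
# Rank 0 downloads and caches the dataset first; all other ranks wait at the barrier,
# then load from the now-warm local cache instead of re-downloading.
if accelerator.is_main_process:
    dataset = load_cached_dataset()  # hypothetical helper: performs the download
accelerator.wait_for_everyone()      # barrier across all processes
if not accelerator.is_main_process:
    dataset = load_cached_dataset()  # cache hit: no duplicate network traffic
```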
@@ -217,13 +224,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
ds_meta=dataset.meta,
)
-# Only move to device if not using accelerator (accelerator.prepare will handle device placement)
-if not accelerator:
-    policy.to(device)
# Wait for all processes to finish policy creation before continuing
-if accelerator:
-    accelerator.wait_for_everyone()
+accelerator.wait_for_everyone()
# Create processors - only provide dataset_stats if not resuming from saved processors
processor_kwargs = {}
@@ -259,7 +261,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
if is_main_process:
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
-grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
step = 0 # number of policy updates (forward + backward + optim)
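Removing the GradScaler is consistent with accelerate owning mixed precision once an Accelerator is always present. A minimal sketch of an update step under that assumption (how the loss is obtained and the surrounding names are illustrative):

```python
# accelerator.backward() applies loss scaling internally when mixed precision is
# enabled (e.g. via `accelerate launch --mixed_precision fp16`), replacing the manual
# GradScaler scale/step/update sequence.
with accelerator.autocast():
    loss = policy(batch)["loss"]  # illustrative: however the policy exposes its loss
accelerator.backward(loss)
if accelerator.sync_gradients:    # True whenever gradients are synced (every step here)
    accelerator.clip_grad_norm_(policy.parameters(), max_norm=grad_clip_norm)
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()  # stepped manually because step_scheduler_with_optimizer=False
```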
@@ -276,10 +277,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
logging.info(f"{dataset.num_episodes=}")
-if accelerator:
-    num_processes = accelerator.num_processes
-    effective_bs = cfg.batch_size * num_processes
-    logging.info(f"Effective batch size: {cfg.batch_size} x {num_processes} = {effective_bs}")
+num_processes = accelerator.num_processes
+effective_bs = cfg.batch_size * num_processes
+logging.info(f"Effective batch size: {cfg.batch_size} x {num_processes} = {effective_bs}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
@@ -306,11 +306,12 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
drop_last=False,
prefetch_factor=2 if cfg.num_workers > 0 else None,
)
-if accelerator:
-    accelerator.wait_for_everyone()
-    policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
-        policy, optimizer, dataloader, lr_scheduler
-    )
+# Prepare everything with accelerator
+accelerator.wait_for_everyone()
+policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
+    policy, optimizer, dataloader, lr_scheduler
+)
dl_iter = cycle(dataloader)
policy.train()
@@ -324,7 +325,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
}
# Use effective batch size for proper epoch calculation in distributed training
-effective_batch_size = cfg.batch_size * (accelerator.num_processes if accelerator else 1)
+effective_batch_size = cfg.batch_size * accelerator.num_processes
train_tracker = MetricsTracker(
effective_batch_size,
dataset.num_frames,
@@ -349,10 +350,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
batch,
optimizer,
cfg.optimizer.grad_clip_norm,
-grad_scaler=grad_scaler,
-lr_scheduler=lr_scheduler,
-use_amp=cfg.policy.use_amp,
+accelerator=accelerator,
+lr_scheduler=lr_scheduler,
)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -379,7 +378,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
checkpoint_dir=checkpoint_dir,
step=step,
cfg=cfg,
-policy=policy if not accelerator else accelerator.unwrap_model(policy),
+policy=accelerator.unwrap_model(policy),
optimizer=optimizer,
scheduler=lr_scheduler,
preprocessor=preprocessor,
@@ -389,21 +388,15 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
if wandb_logger:
wandb_logger.log_policy(checkpoint_dir)
-if accelerator:
-    accelerator.wait_for_everyone()
+accelerator.wait_for_everyone()
if cfg.env and is_eval_step:
step_id = get_step_identifier(step, cfg.steps)
logging.info(f"Eval policy at step {step}")
-with (
-    torch.no_grad(),
-    torch.autocast(device_type=device.type)
-    if cfg.policy.use_amp and not accelerator
-    else nullcontext(),
-):
+with torch.no_grad(), accelerator.autocast():
eval_info = eval_policy_all(
envs=eval_env, # dict[suite][task_id] -> vec_env
-policy=policy if not accelerator else accelerator.unwrap_model(policy),
+policy=accelerator.unwrap_model(policy),
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=cfg.eval.n_episodes,
@@ -441,8 +434,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval")
-if accelerator:
-    accelerator.wait_for_everyone()
+accelerator.wait_for_everyone()
if eval_env:
close_envs(eval_env)
@@ -451,7 +443,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
logging.info("End of training")
if cfg.policy.push_to_hub:
-unwrapped_policy = policy if not accelerator else accelerator.unwrap_model(policy)
+unwrapped_policy = accelerator.unwrap_model(policy)
unwrapped_policy.push_model_to_hub(cfg)
preprocessor.push_to_hub(cfg.policy.repo_id)
postprocessor.push_to_hub(cfg.policy.repo_id)
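Checkpointing and hub-push both lean on the same idiom: under accelerate the policy may be wrapped (e.g. in DistributedDataParallel), so persistence has to go through the underlying module. A sketch (save_pretrained is a hypothetical method name):

```python
# unwrap_model() returns the raw module; in the single-process case it is effectively
# a no-op, so calling it unconditionally is safe.
unwrapped_policy = accelerator.unwrap_model(policy)
if accelerator.is_main_process:  # only one process should write the artifact
    unwrapped_policy.save_pretrained(checkpoint_dir)  # hypothetical save API
```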
@@ -463,25 +455,4 @@ def main():
if __name__ == "__main__":
-    import os
-    distributed_env_vars = {
-        "LOCAL_RANK": os.environ.get("LOCAL_RANK", "NOT SET"),
-        "WORLD_SIZE": os.environ.get("WORLD_SIZE", "NOT SET"),
-        "RANK": os.environ.get("RANK", "NOT SET"),
-        "ACCELERATE_MIXED_PRECISION": os.environ.get("ACCELERATE_MIXED_PRECISION", "NOT SET"),
-    }
-    print(f"[PID {os.getpid()}] Distributed env vars: {distributed_env_vars}")
-    if is_launched_with_accelerate():
-        print(f"[PID {os.getpid()}] Detected distributed training mode")
-        import accelerate
-        # We set step_scheduler_with_optimizer False to prevent accelerate from
-        # adjusting the lr_scheduler steps based on the num_processes
-        accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False)
-        init_logging(accelerator=accelerator)
-        train(accelerator=accelerator)
-    else:
-        init_logging()
-        train()
+    main()
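The deleted detection block is exactly what always creating an Accelerator makes unnecessary: Accelerator() falls back to a single-process setup when the distributed environment variables are absent, so there is nothing to detect. A small check, runnable with plain python:

```python
from accelerate import Accelerator
from accelerate.utils import DistributedType

acc = Accelerator(step_scheduler_with_optimizer=False)
if acc.distributed_type == DistributedType.NO:
    # Plain `python`: no RANK/WORLD_SIZE in the environment, world size 1.
    print("single process")
else:
    # Under `accelerate launch`: one process per device, env vars set by the launcher.
    print(f"distributed across {acc.num_processes} processes")
```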