mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
always use accelerate
This commit is contained in:
@@ -10,13 +10,7 @@ First, ensure you have accelerate installed:
|
||||
pip install accelerate
|
||||
```
|
||||
|
||||
Or install it with the LeRobot accelerate extra:
|
||||
|
||||
```bash
|
||||
pip install -e ".[accelerate]"
|
||||
```
|
||||
|
||||
## Training with Multiple GPUs
|
||||
## Training with Multiple GPUss
|
||||
|
||||
You can launch training in two ways:
|
||||
|
||||
|
||||
+1
-1
@@ -62,6 +62,7 @@ dependencies = [
|
||||
"datasets>=4.0.0,<4.2.0",
|
||||
"diffusers>=0.27.2,<0.36.0",
|
||||
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
|
||||
"accelerate>=1.10.0,<2.0.0",
|
||||
|
||||
# Core dependencies
|
||||
"cmake>=3.29.0.1,<4.2.0",
|
||||
@@ -124,7 +125,6 @@ smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.11,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
# Features
|
||||
accelerate = ["accelerate>=1.10.0"]
|
||||
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
|
||||
|
||||
# Development
|
||||
|
||||
@@ -147,7 +147,7 @@ def update_policy(
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
||||
def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
"""
|
||||
Main function to train a policy.
|
||||
|
||||
@@ -161,12 +161,20 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
||||
|
||||
Args:
|
||||
cfg: A `TrainPipelineConfig` object containing all training configurations.
|
||||
accelerator: Optional Accelerator instance. If None, one will be created automatically.
|
||||
"""
|
||||
cfg.validate()
|
||||
|
||||
# Create Accelerator if not provided
|
||||
# It will automatically detect if running in distributed mode or single-process mode
|
||||
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting
|
||||
# the lr_scheduler steps based on the num_processes
|
||||
if accelerator is None:
|
||||
accelerator = Accelerator(step_scheduler_with_optimizer=False)
|
||||
|
||||
# Determine if this is the main process (for logging and checkpointing)
|
||||
# When using accelerate, only the main process should log to avoid duplicate outputs
|
||||
is_main_process = accelerator.is_main_process if accelerator else True
|
||||
is_main_process = accelerator.is_main_process
|
||||
|
||||
# Only log on main process
|
||||
if is_main_process:
|
||||
@@ -183,8 +191,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
||||
if cfg.seed is not None:
|
||||
set_seed(cfg.seed, accelerator=accelerator)
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(cfg.policy.device, log=True, accelerator=accelerator)
|
||||
# Use accelerator's device
|
||||
device = accelerator.device
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
@@ -194,8 +202,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
# Wait for main process to finish downloading/caching dataset
|
||||
if accelerator:
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Now all other processes can safely load the dataset
|
||||
if not is_main_process:
|
||||
@@ -217,13 +224,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
||||
ds_meta=dataset.meta,
|
||||
)
|
||||
|
||||
# Only move to device if not using accelerator (accelerator.prepare will handle device placement)
|
||||
if not accelerator:
|
||||
policy.to(device)
|
||||
|
||||
# Wait for all processes to finish policy creation before continuing
|
||||
if accelerator:
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Create processors - only provide dataset_stats if not resuming from saved processors
|
||||
processor_kwargs = {}
|
||||
@@ -259,7 +261,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
||||
if is_main_process:
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
|
||||
|
||||
step = 0 # number of policy updates (forward + backward + optim)
|
||||
|
||||
@@ -276,10 +277,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
||||
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
|
||||
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
|
||||
logging.info(f"{dataset.num_episodes=}")
|
||||
if accelerator:
|
||||
num_processes = accelerator.num_processes
|
||||
effective_bs = cfg.batch_size * num_processes
|
||||
logging.info(f"Effective batch size: {cfg.batch_size} x {num_processes} = {effective_bs}")
|
||||
num_processes = accelerator.num_processes
|
||||
effective_bs = cfg.batch_size * num_processes
|
||||
logging.info(f"Effective batch size: {cfg.batch_size} x {num_processes} = {effective_bs}")
|
||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
@@ -306,11 +306,12 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
||||
drop_last=False,
|
||||
prefetch_factor=2 if cfg.num_workers > 0 else None,
|
||||
)
|
||||
if accelerator:
|
||||
accelerator.wait_for_everyone()
|
||||
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
||||
policy, optimizer, dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# Prepare everything with accelerator
|
||||
accelerator.wait_for_everyone()
|
||||
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
||||
policy, optimizer, dataloader, lr_scheduler
|
||||
)
|
||||
dl_iter = cycle(dataloader)
|
||||
|
||||
policy.train()
|
||||
@@ -324,7 +325,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
||||
}
|
||||
|
||||
# Use effective batch size for proper epoch calculation in distributed training
|
||||
effective_batch_size = cfg.batch_size * (accelerator.num_processes if accelerator else 1)
|
||||
effective_batch_size = cfg.batch_size * accelerator.num_processes
|
||||
train_tracker = MetricsTracker(
|
||||
effective_batch_size,
|
||||
dataset.num_frames,
|
||||
@@ -349,10 +350,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
||||
batch,
|
||||
optimizer,
|
||||
cfg.optimizer.grad_clip_norm,
|
||||
grad_scaler=grad_scaler,
|
||||
lr_scheduler=lr_scheduler,
|
||||
use_amp=cfg.policy.use_amp,
|
||||
accelerator=accelerator,
|
||||
lr_scheduler=lr_scheduler,
|
||||
)
|
||||
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
@@ -379,7 +378,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
step=step,
|
||||
cfg=cfg,
|
||||
policy=policy if not accelerator else accelerator.unwrap_model(policy),
|
||||
policy=accelerator.unwrap_model(policy),
|
||||
optimizer=optimizer,
|
||||
scheduler=lr_scheduler,
|
||||
preprocessor=preprocessor,
|
||||
@@ -389,21 +388,15 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
||||
if wandb_logger:
|
||||
wandb_logger.log_policy(checkpoint_dir)
|
||||
|
||||
if accelerator:
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if cfg.env and is_eval_step:
|
||||
step_id = get_step_identifier(step, cfg.steps)
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
with (
|
||||
torch.no_grad(),
|
||||
torch.autocast(device_type=device.type)
|
||||
if cfg.policy.use_amp and not accelerator
|
||||
else nullcontext(),
|
||||
):
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
eval_info = eval_policy_all(
|
||||
envs=eval_env, # dict[suite][task_id] -> vec_env
|
||||
policy=policy if not accelerator else accelerator.unwrap_model(policy),
|
||||
policy=accelerator.unwrap_model(policy),
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=cfg.eval.n_episodes,
|
||||
@@ -441,8 +434,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
||||
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
|
||||
wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval")
|
||||
|
||||
if accelerator:
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if eval_env:
|
||||
close_envs(eval_env)
|
||||
@@ -451,7 +443,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
||||
logging.info("End of training")
|
||||
|
||||
if cfg.policy.push_to_hub:
|
||||
unwrapped_policy = policy if not accelerator else accelerator.unwrap_model(policy)
|
||||
unwrapped_policy = accelerator.unwrap_model(policy)
|
||||
unwrapped_policy.push_model_to_hub(cfg)
|
||||
preprocessor.push_to_hub(cfg.policy.repo_id)
|
||||
postprocessor.push_to_hub(cfg.policy.repo_id)
|
||||
@@ -463,25 +455,4 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
distributed_env_vars = {
|
||||
"LOCAL_RANK": os.environ.get("LOCAL_RANK", "NOT SET"),
|
||||
"WORLD_SIZE": os.environ.get("WORLD_SIZE", "NOT SET"),
|
||||
"RANK": os.environ.get("RANK", "NOT SET"),
|
||||
"ACCELERATE_MIXED_PRECISION": os.environ.get("ACCELERATE_MIXED_PRECISION", "NOT SET"),
|
||||
}
|
||||
print(f"[PID {os.getpid()}] Distributed env vars: {distributed_env_vars}")
|
||||
|
||||
if is_launched_with_accelerate():
|
||||
print(f"[PID {os.getpid()}] Detected distributed training mode")
|
||||
import accelerate
|
||||
|
||||
# We set step_scheduler_with_optimizer False to prevent accelerate from
|
||||
# adjusting the lr_scheduler steps based on the num_processes
|
||||
accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False)
|
||||
|
||||
init_logging(accelerator=accelerator)
|
||||
train(accelerator=accelerator)
|
||||
else:
|
||||
init_logging()
|
||||
train()
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user