mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
always use accelerate
This commit is contained in:
@@ -10,13 +10,7 @@ First, ensure you have accelerate installed:
|
|||||||
pip install accelerate
|
pip install accelerate
|
||||||
```
|
```
|
||||||
|
|
||||||
Or install it with the LeRobot accelerate extra:
|
## Training with Multiple GPUs
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install -e ".[accelerate]"
|
|
||||||
```
|
|
||||||
|
|
||||||
## Training with Multiple GPUs
|
|
||||||
|
|
||||||
You can launch training in two ways:
|
You can launch training in two ways:
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -62,6 +62,7 @@ dependencies = [
|
|||||||
"datasets>=4.0.0,<4.2.0",
|
"datasets>=4.0.0,<4.2.0",
|
||||||
"diffusers>=0.27.2,<0.36.0",
|
"diffusers>=0.27.2,<0.36.0",
|
||||||
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
|
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
|
||||||
|
"accelerate>=1.10.0,<2.0.0",
|
||||||
|
|
||||||
# Core dependencies
|
# Core dependencies
|
||||||
"cmake>=3.29.0.1,<4.2.0",
|
"cmake>=3.29.0.1,<4.2.0",
|
||||||
@@ -124,7 +125,6 @@ smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>
|
|||||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.11,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.11,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||||
|
|
||||||
# Features
|
# Features
|
||||||
accelerate = ["accelerate>=1.10.0"]
|
|
||||||
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
|
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
|
||||||
|
|
||||||
# Development
|
# Development
|
||||||
|
|||||||
@@ -147,7 +147,7 @@ def update_policy(
|
|||||||
|
|
||||||
|
|
||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||||
"""
|
"""
|
||||||
Main function to train a policy.
|
Main function to train a policy.
|
||||||
|
|
||||||
@@ -161,12 +161,20 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg: A `TrainPipelineConfig` object containing all training configurations.
|
cfg: A `TrainPipelineConfig` object containing all training configurations.
|
||||||
|
accelerator: Optional Accelerator instance. If None, one will be created automatically.
|
||||||
"""
|
"""
|
||||||
cfg.validate()
|
cfg.validate()
|
||||||
|
|
||||||
|
# Create Accelerator if not provided
|
||||||
|
# It will automatically detect if running in distributed mode or single-process mode
|
||||||
|
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting
|
||||||
|
# the lr_scheduler steps based on the num_processes
|
||||||
|
if accelerator is None:
|
||||||
|
accelerator = Accelerator(step_scheduler_with_optimizer=False)
|
||||||
|
|
||||||
# Determine if this is the main process (for logging and checkpointing)
|
# Determine if this is the main process (for logging and checkpointing)
|
||||||
# When using accelerate, only the main process should log to avoid duplicate outputs
|
# When using accelerate, only the main process should log to avoid duplicate outputs
|
||||||
is_main_process = accelerator.is_main_process if accelerator else True
|
is_main_process = accelerator.is_main_process
|
||||||
|
|
||||||
# Only log on main process
|
# Only log on main process
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
@@ -183,8 +191,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
|||||||
if cfg.seed is not None:
|
if cfg.seed is not None:
|
||||||
set_seed(cfg.seed, accelerator=accelerator)
|
set_seed(cfg.seed, accelerator=accelerator)
|
||||||
|
|
||||||
# Check device is available
|
# Use accelerator's device
|
||||||
device = get_safe_torch_device(cfg.policy.device, log=True, accelerator=accelerator)
|
device = accelerator.device
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
|
||||||
@@ -194,8 +202,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
|||||||
dataset = make_dataset(cfg)
|
dataset = make_dataset(cfg)
|
||||||
|
|
||||||
# Wait for main process to finish downloading/caching dataset
|
# Wait for main process to finish downloading/caching dataset
|
||||||
if accelerator:
|
accelerator.wait_for_everyone()
|
||||||
accelerator.wait_for_everyone()
|
|
||||||
|
|
||||||
# Now all other processes can safely load the dataset
|
# Now all other processes can safely load the dataset
|
||||||
if not is_main_process:
|
if not is_main_process:
|
||||||
@@ -217,13 +224,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
|||||||
ds_meta=dataset.meta,
|
ds_meta=dataset.meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only move to device if not using accelerator (accelerator.prepare will handle device placement)
|
|
||||||
if not accelerator:
|
|
||||||
policy.to(device)
|
|
||||||
|
|
||||||
# Wait for all processes to finish policy creation before continuing
|
# Wait for all processes to finish policy creation before continuing
|
||||||
if accelerator:
|
accelerator.wait_for_everyone()
|
||||||
accelerator.wait_for_everyone()
|
|
||||||
|
|
||||||
# Create processors - only provide dataset_stats if not resuming from saved processors
|
# Create processors - only provide dataset_stats if not resuming from saved processors
|
||||||
processor_kwargs = {}
|
processor_kwargs = {}
|
||||||
@@ -259,7 +261,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
|||||||
if is_main_process:
|
if is_main_process:
|
||||||
logging.info("Creating optimizer and scheduler")
|
logging.info("Creating optimizer and scheduler")
|
||||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||||
grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
|
|
||||||
|
|
||||||
step = 0 # number of policy updates (forward + backward + optim)
|
step = 0 # number of policy updates (forward + backward + optim)
|
||||||
|
|
||||||
@@ -276,10 +277,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
|||||||
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
|
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
|
||||||
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
|
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
|
||||||
logging.info(f"{dataset.num_episodes=}")
|
logging.info(f"{dataset.num_episodes=}")
|
||||||
if accelerator:
|
num_processes = accelerator.num_processes
|
||||||
num_processes = accelerator.num_processes
|
effective_bs = cfg.batch_size * num_processes
|
||||||
effective_bs = cfg.batch_size * num_processes
|
logging.info(f"Effective batch size: {cfg.batch_size} x {num_processes} = {effective_bs}")
|
||||||
logging.info(f"Effective batch size: {cfg.batch_size} x {num_processes} = {effective_bs}")
|
|
||||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
||||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||||
|
|
||||||
@@ -306,11 +306,12 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
|||||||
drop_last=False,
|
drop_last=False,
|
||||||
prefetch_factor=2 if cfg.num_workers > 0 else None,
|
prefetch_factor=2 if cfg.num_workers > 0 else None,
|
||||||
)
|
)
|
||||||
if accelerator:
|
|
||||||
accelerator.wait_for_everyone()
|
# Prepare everything with accelerator
|
||||||
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
accelerator.wait_for_everyone()
|
||||||
policy, optimizer, dataloader, lr_scheduler
|
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
||||||
)
|
policy, optimizer, dataloader, lr_scheduler
|
||||||
|
)
|
||||||
dl_iter = cycle(dataloader)
|
dl_iter = cycle(dataloader)
|
||||||
|
|
||||||
policy.train()
|
policy.train()
|
||||||
@@ -324,7 +325,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Use effective batch size for proper epoch calculation in distributed training
|
# Use effective batch size for proper epoch calculation in distributed training
|
||||||
effective_batch_size = cfg.batch_size * (accelerator.num_processes if accelerator else 1)
|
effective_batch_size = cfg.batch_size * accelerator.num_processes
|
||||||
train_tracker = MetricsTracker(
|
train_tracker = MetricsTracker(
|
||||||
effective_batch_size,
|
effective_batch_size,
|
||||||
dataset.num_frames,
|
dataset.num_frames,
|
||||||
@@ -349,10 +350,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
|||||||
batch,
|
batch,
|
||||||
optimizer,
|
optimizer,
|
||||||
cfg.optimizer.grad_clip_norm,
|
cfg.optimizer.grad_clip_norm,
|
||||||
grad_scaler=grad_scaler,
|
|
||||||
lr_scheduler=lr_scheduler,
|
|
||||||
use_amp=cfg.policy.use_amp,
|
|
||||||
accelerator=accelerator,
|
accelerator=accelerator,
|
||||||
|
lr_scheduler=lr_scheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||||
@@ -379,7 +378,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
|||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
step=step,
|
step=step,
|
||||||
cfg=cfg,
|
cfg=cfg,
|
||||||
policy=policy if not accelerator else accelerator.unwrap_model(policy),
|
policy=accelerator.unwrap_model(policy),
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
scheduler=lr_scheduler,
|
scheduler=lr_scheduler,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
@@ -389,21 +388,15 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
|||||||
if wandb_logger:
|
if wandb_logger:
|
||||||
wandb_logger.log_policy(checkpoint_dir)
|
wandb_logger.log_policy(checkpoint_dir)
|
||||||
|
|
||||||
if accelerator:
|
accelerator.wait_for_everyone()
|
||||||
accelerator.wait_for_everyone()
|
|
||||||
|
|
||||||
if cfg.env and is_eval_step:
|
if cfg.env and is_eval_step:
|
||||||
step_id = get_step_identifier(step, cfg.steps)
|
step_id = get_step_identifier(step, cfg.steps)
|
||||||
logging.info(f"Eval policy at step {step}")
|
logging.info(f"Eval policy at step {step}")
|
||||||
with (
|
with torch.no_grad(), accelerator.autocast():
|
||||||
torch.no_grad(),
|
|
||||||
torch.autocast(device_type=device.type)
|
|
||||||
if cfg.policy.use_amp and not accelerator
|
|
||||||
else nullcontext(),
|
|
||||||
):
|
|
||||||
eval_info = eval_policy_all(
|
eval_info = eval_policy_all(
|
||||||
envs=eval_env, # dict[suite][task_id] -> vec_env
|
envs=eval_env, # dict[suite][task_id] -> vec_env
|
||||||
policy=policy if not accelerator else accelerator.unwrap_model(policy),
|
policy=accelerator.unwrap_model(policy),
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
postprocessor=postprocessor,
|
postprocessor=postprocessor,
|
||||||
n_episodes=cfg.eval.n_episodes,
|
n_episodes=cfg.eval.n_episodes,
|
||||||
@@ -441,8 +434,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
|||||||
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
|
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
|
||||||
wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval")
|
wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval")
|
||||||
|
|
||||||
if accelerator:
|
accelerator.wait_for_everyone()
|
||||||
accelerator.wait_for_everyone()
|
|
||||||
|
|
||||||
if eval_env:
|
if eval_env:
|
||||||
close_envs(eval_env)
|
close_envs(eval_env)
|
||||||
@@ -451,7 +443,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
|||||||
logging.info("End of training")
|
logging.info("End of training")
|
||||||
|
|
||||||
if cfg.policy.push_to_hub:
|
if cfg.policy.push_to_hub:
|
||||||
unwrapped_policy = policy if not accelerator else accelerator.unwrap_model(policy)
|
unwrapped_policy = accelerator.unwrap_model(policy)
|
||||||
unwrapped_policy.push_model_to_hub(cfg)
|
unwrapped_policy.push_model_to_hub(cfg)
|
||||||
preprocessor.push_to_hub(cfg.policy.repo_id)
|
preprocessor.push_to_hub(cfg.policy.repo_id)
|
||||||
postprocessor.push_to_hub(cfg.policy.repo_id)
|
postprocessor.push_to_hub(cfg.policy.repo_id)
|
||||||
@@ -463,25 +455,4 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import os
|
main()
|
||||||
distributed_env_vars = {
|
|
||||||
"LOCAL_RANK": os.environ.get("LOCAL_RANK", "NOT SET"),
|
|
||||||
"WORLD_SIZE": os.environ.get("WORLD_SIZE", "NOT SET"),
|
|
||||||
"RANK": os.environ.get("RANK", "NOT SET"),
|
|
||||||
"ACCELERATE_MIXED_PRECISION": os.environ.get("ACCELERATE_MIXED_PRECISION", "NOT SET"),
|
|
||||||
}
|
|
||||||
print(f"[PID {os.getpid()}] Distributed env vars: {distributed_env_vars}")
|
|
||||||
|
|
||||||
if is_launched_with_accelerate():
|
|
||||||
print(f"[PID {os.getpid()}] Detected distributed training mode")
|
|
||||||
import accelerate
|
|
||||||
|
|
||||||
# We set step_scheduler_with_optimizer False to prevent accelerate from
|
|
||||||
# adjusting the lr_scheduler steps based on the num_processes
|
|
||||||
accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False)
|
|
||||||
|
|
||||||
init_logging(accelerator=accelerator)
|
|
||||||
train(accelerator=accelerator)
|
|
||||||
else:
|
|
||||||
init_logging()
|
|
||||||
train()
|
|
||||||
|
|||||||
Reference in New Issue
Block a user