From f30da2dec1ba2f3e81dee36ab9bec59f357be357 Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Thu, 2 Oct 2025 18:11:27 +0200 Subject: [PATCH] Enhance training and logging functionality with accelerator support - Added support for multi-GPU training by introducing an `accelerator` parameter in training functions. - Updated `update_policy` to handle gradient updates based on the presence of an accelerator. - Modified logging to prevent duplicate messages in non-main processes. - Enhanced `set_seed` and `get_safe_torch_device` functions to accommodate accelerator usage. - Updated `MetricsTracker` to account for the number of processes when calculating metrics. - Introduced a new feature in `pyproject.toml` for the `accelerate` library dependency. --- pyproject.toml | 1 + src/lerobot/scripts/lerobot_train.py | 149 ++++++++++++++++++++------- src/lerobot/utils/logging_utils.py | 6 +- src/lerobot/utils/random_utils.py | 10 +- src/lerobot/utils/utils.py | 16 ++- 5 files changed, 137 insertions(+), 45 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f350fac0a..f639fa0a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,6 +125,7 @@ hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.11", "lerobot[grpcio-dep]" # Features async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3"] +accelerate = ["accelerate>=1.10.0"] # Development dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"] diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index bc66618ca..4493b1167 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -18,6 +18,7 @@ import time from contextlib import nullcontext from pprint import pformat from typing import Any +from collections.abc import Callable import torch from termcolor import colored @@ -51,6 +52,7 @@ from lerobot.utils.utils import ( get_safe_torch_device, has_method, init_logging, + is_launched_with_accelerate, ) @@ 
-64,6 +66,7 @@ def update_policy( lr_scheduler=None, use_amp: bool = False, lock=None, + accelerator: Callable | None = None, ) -> tuple[MetricsTracker, dict]: """ Performs a single training step to update the policy's weights. @@ -81,6 +84,7 @@ def update_policy( lr_scheduler: An optional learning rate scheduler. use_amp: A boolean indicating whether to use automatic mixed precision. lock: An optional lock for thread-safe optimizer updates. + accelerator: An optional accelerator, for multi-gpu training. Returns: A tuple containing: @@ -90,26 +94,36 @@ def update_policy( start_time = time.perf_counter() device = get_device_from_parameters(policy) policy.train() - with torch.autocast(device_type=device.type) if use_amp else nullcontext(): + with torch.autocast(device_type=device.type) if use_amp and accelerator is None else nullcontext(): loss, output_dict = policy.forward(batch) # TODO(rcadene): policy.unnormalize_outputs(out_dict) - grad_scaler.scale(loss).backward() - # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**. - grad_scaler.unscale_(optimizer) + if accelerator: + accelerator.backward(loss) + accelerator.unscale_gradients(optimizer=optimizer) + grad_norm = torch.nn.utils.clip_grad_norm_( + policy.parameters(), + grad_clip_norm, + error_if_nonfinite=False, + ) + optimizer.step() + else: + grad_scaler.scale(loss).backward() + # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**. + grad_scaler.unscale_(optimizer) - grad_norm = torch.nn.utils.clip_grad_norm_( - policy.parameters(), - grad_clip_norm, - error_if_nonfinite=False, - ) + grad_norm = torch.nn.utils.clip_grad_norm_( + policy.parameters(), + grad_clip_norm, + error_if_nonfinite=False, + ) - # Optimizer's gradients are already unscaled, so scaler.step does not unscale them, - # although it still skips optimizer.step() if the gradients contain infs or NaNs. 
- with lock if lock is not None else nullcontext(): - grad_scaler.step(optimizer) - # Updates the scale for next iteration. - grad_scaler.update() + # Optimizer's gradients are already unscaled, so scaler.step does not unscale them, + # although it still skips optimizer.step() if the gradients contain infs or NaNs. + with lock if lock is not None else nullcontext(): + grad_scaler.step(optimizer) + # Updates the scale for next iteration. + grad_scaler.update() optimizer.zero_grad() @@ -117,9 +131,13 @@ def update_policy( if lr_scheduler is not None: lr_scheduler.step() - if has_method(policy, "update"): - # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC). - policy.update() + if accelerator: + if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"): + accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update() + else: + if has_method(policy, "update"): + # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC). + policy.update() train_metrics.loss = loss.item() train_metrics.grad_norm = grad_norm.item() @@ -129,7 +147,7 @@ def update_policy( @parser.wrap() -def train(cfg: TrainPipelineConfig): +def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None): """ Main function to train a policy. @@ -147,6 +165,10 @@ def train(cfg: TrainPipelineConfig): cfg.validate() logging.info(pformat(cfg.to_dict())) + if accelerator and not accelerator.is_main_process: + # Disable logging on non-main processes. 
+ cfg.wandb.enable = False + if cfg.wandb.enable and cfg.wandb.project: wandb_logger = WandBLogger(cfg) else: @@ -154,10 +176,10 @@ def train(cfg: TrainPipelineConfig): logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) if cfg.seed is not None: - set_seed(cfg.seed) + set_seed(cfg.seed, accelerator=accelerator) # Check device is available - device = get_safe_torch_device(cfg.policy.device, log=True) + device = get_safe_torch_device(cfg.policy.device, log=True, accelerator=accelerator) torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True @@ -177,6 +199,7 @@ def train(cfg: TrainPipelineConfig): cfg=cfg.policy, ds_meta=dataset.meta, ) + policy.to(device) # Create processors - only provide dataset_stats if not resuming from saved processors processor_kwargs = {} @@ -221,14 +244,15 @@ def train(cfg: TrainPipelineConfig): num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_total_params = sum(p.numel() for p in policy.parameters()) - logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") - if cfg.env is not None: - logging.info(f"{cfg.env.task=}") - logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})") - logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})") - logging.info(f"{dataset.num_episodes=}") - logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") - logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") + if not accelerator or accelerator.is_main_process: + logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") + if cfg.env is not None: + logging.info(f"{cfg.env.task=}") + logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})") + logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})") + logging.info(f"{dataset.num_episodes=}") + logging.info(f"{num_learnable_params=} 
({format_big_number(num_learnable_params)})") + logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") # create dataloader for offline training if hasattr(cfg.policy, "drop_n_last_frames"): @@ -253,6 +277,10 @@ def train(cfg: TrainPipelineConfig): drop_last=False, prefetch_factor=2, ) + if accelerator: + policy, optimizer, dataloader, lr_scheduler = accelerator.prepare( + policy, optimizer, dataloader, lr_scheduler + ) dl_iter = cycle(dataloader) policy.train() @@ -266,10 +294,16 @@ def train(cfg: TrainPipelineConfig): } train_tracker = MetricsTracker( - cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step + cfg.batch_size, + dataset.num_frames, + dataset.num_episodes, + train_metrics, + initial_step=step, + accelerator=accelerator, ) - logging.info("Start offline training on a fixed dataset") + if not accelerator or accelerator.is_main_process: + logging.info("Start offline training on a fixed dataset") for _ in range(step, cfg.steps): start_time = time.perf_counter() batch = next(dl_iter) @@ -285,15 +319,26 @@ def train(cfg: TrainPipelineConfig): grad_scaler=grad_scaler, lr_scheduler=lr_scheduler, use_amp=cfg.policy.use_amp, + accelerator=accelerator, ) # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we # increment `step` here. 
         step += 1
         train_tracker.step()
-        is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
-        is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
-        is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
+        is_log_step = (
+            cfg.log_freq > 0 and step % cfg.log_freq == 0 and (not accelerator or accelerator.is_main_process)
+        )
+        is_saving_step = (
+            (step % cfg.save_freq == 0
+            or step == cfg.steps)
+            and (not accelerator or accelerator.is_main_process)
+        )
+        is_eval_step = (
+            cfg.eval_freq > 0
+            and step % cfg.eval_freq == 0
+            and (not accelerator or accelerator.is_main_process)
+        )
 
         if is_log_step:
             logging.info(train_tracker)
@@ -308,22 +353,31 @@
             logging.info(f"Checkpoint policy after step {step}")
             checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
             save_checkpoint(
-                checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor, postprocessor
+                checkpoint_dir=checkpoint_dir,
+                step=step,
+                cfg=cfg,
+                policy=policy if not accelerator else accelerator.unwrap_model(policy),
+                optimizer=optimizer,
+                scheduler=lr_scheduler,
             )
             update_last_checkpoint(checkpoint_dir)
             if wandb_logger:
                 wandb_logger.log_policy(checkpoint_dir)
+            if accelerator:
+                accelerator.wait_for_everyone()
 
         if cfg.env and is_eval_step:
             step_id = get_step_identifier(step, cfg.steps)
             logging.info(f"Eval policy at step {step}")
             with (
                 torch.no_grad(),
-                torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
+                torch.autocast(device_type=device.type)
+                if cfg.policy.use_amp and not accelerator
+                else nullcontext(),
             ):
                 eval_info = eval_policy_all(
                     envs=eval_env,  # dict[suite][task_id] -> vec_env
-                    policy=policy,
+                    policy=policy if not accelerator else accelerator.unwrap_model(policy),
                     preprocessor=preprocessor,
                     postprocessor=postprocessor,
                     n_episodes=cfg.eval.n_episodes,
@@ -346,7 +400,12 @@
                 "eval_s": AverageMeter("eval_s", ":.3f"),
             }
             eval_tracker = MetricsTracker(
-                cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
+                cfg.batch_size,
+                dataset.num_frames,
+                dataset.num_episodes,
+                eval_metrics,
+                initial_step=step,
+                accelerator=accelerator,
             )
             eval_tracker.eval_s = aggregated.pop("eval_s")
             eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward")
@@ -358,7 +417,9 @@
 
     if eval_env:
         close_envs(eval_env)
-    logging.info("End of training")
+
+    if not accelerator or accelerator.is_main_process:
+        logging.info("End of training")
 
     if cfg.policy.push_to_hub:
         policy.push_model_to_hub(cfg)
@@ -372,4 +433,12 @@
 
 
 if __name__ == "__main__":
-    main()
+    if is_launched_with_accelerate():
+        import accelerate
+
+        # We set step_scheduler_with_optimizer False to prevent accelerate from
+        # adjusting the lr_scheduler steps based on the num_processes
+        accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False)
+        train(accelerator=accelerator)
+    else:
+        train()
diff --git a/src/lerobot/utils/logging_utils.py b/src/lerobot/utils/logging_utils.py
index b6404e66d..c4c1f42e0 100644
--- a/src/lerobot/utils/logging_utils.py
+++ b/src/lerobot/utils/logging_utils.py
@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections.abc import Callable from typing import Any from lerobot.utils.utils import format_big_number @@ -84,6 +85,7 @@ class MetricsTracker: "samples", "episodes", "epochs", + "accelerator", ] def __init__( @@ -93,6 +95,7 @@ class MetricsTracker: num_episodes: int, metrics: dict[str, AverageMeter], initial_step: int = 0, + accelerator: Callable | None = None, ): self.__dict__.update(dict.fromkeys(self.__keys__)) self._batch_size = batch_size @@ -106,6 +109,7 @@ class MetricsTracker: self.samples = self.steps * self._batch_size self.episodes = self.samples / self._avg_samples_per_ep self.epochs = self.samples / self._num_frames + self.accelerator = accelerator def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any: if name in self.__dict__: @@ -128,7 +132,7 @@ class MetricsTracker: Updates metrics that depend on 'step' for one step. """ self.steps += 1 - self.samples += self._batch_size + self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1) self.episodes = self.samples / self._avg_samples_per_ep self.epochs = self.samples / self._num_frames diff --git a/src/lerobot/utils/random_utils.py b/src/lerobot/utils/random_utils.py index 1bb1f0631..4cf1d6bd1 100644 --- a/src/lerobot/utils/random_utils.py +++ b/src/lerobot/utils/random_utils.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import random -from collections.abc import Generator +from collections.abc import Callable, Generator from contextlib import contextmanager from pathlib import Path from typing import Any @@ -164,14 +164,20 @@ def set_rng_state(random_state_dict: dict[str, Any]): torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"]) -def set_seed(seed) -> None: +def set_seed(seed, accelerator: Callable | None = None) -> None: """Set seed for reproducibility.""" random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) + if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) + if accelerator: + from accelerate.utils import set_seed + + set_seed(seed) + @contextmanager def seeded_context(seed: int) -> Generator[None, None, None]: diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py index 8777d5a9d..c01906bc9 100644 --- a/src/lerobot/utils/utils.py +++ b/src/lerobot/utils/utils.py @@ -20,6 +20,7 @@ import select import subprocess import sys import time +from collections.abc import Callable from copy import copy, deepcopy from datetime import datetime from pathlib import Path @@ -49,13 +50,15 @@ def auto_select_torch_device() -> torch.device: # TODO(Steven): Remove log. 
log shouldn't be an argument, this should be handled by the logger level -def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device: +def get_safe_torch_device( + try_device: str, log: bool = False, accelerator: Callable | None = None +) -> torch.device: """Given a string, return a torch.device with checks on whether the device is available.""" try_device = str(try_device) match try_device: case "cuda": assert torch.cuda.is_available() - device = torch.device("cuda") + device = accelerator.device if accelerator else torch.device("cuda") case "mps": assert torch.backends.mps.is_available() device = torch.device("mps") @@ -109,6 +112,7 @@ def init_logging( display_pid: bool = False, console_level: str = "INFO", file_level: str = "DEBUG", + accelerator: Callable | None = None, ): def custom_format(record: logging.LogRecord) -> str: dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S") @@ -137,6 +141,10 @@ def init_logging( console_handler.setFormatter(formatter) console_handler.setLevel(console_level.upper()) logger.addHandler(console_handler) + if accelerator is not None and not accelerator.is_main_process: + # Disable duplicate logging on non-main processes + logging.info(f"Setting logging level on non-main process {accelerator.process_index} to WARNING.") + logging.getLogger().setLevel(logging.WARNING) # Additionally write logs to file if log_file is not None: @@ -158,6 +166,10 @@ def format_big_number(num, precision=0): return num +def is_launched_with_accelerate() -> bool: + return "ACCELERATE_MIXED_PRECISION" in os.environ + + def say(text: str, blocking: bool = False): system = platform.system()