diff --git a/src/lerobot/scripts/train_accelerate.py b/src/lerobot/scripts/train_accelerate.py new file mode 100644 index 000000000..f226a3dba --- /dev/null +++ b/src/lerobot/scripts/train_accelerate.py @@ -0,0 +1,311 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import time +from contextlib import nullcontext +from pprint import pformat +from typing import Any, Callable + +import accelerate +import torch +from termcolor import colored +from torch.amp import GradScaler +from torch.optim import Optimizer + +from lerobot.common.datasets.factory import make_dataset +from lerobot.common.datasets.sampler import EpisodeAwareSampler +from lerobot.common.datasets.utils import cycle +from lerobot.common.envs.factory import make_env +from lerobot.common.optim.factory import make_optimizer_and_scheduler +from lerobot.common.policies.factory import make_policy +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.policies.utils import get_device_from_parameters +from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker +from lerobot.common.utils.random_utils import set_seed +from lerobot.common.utils.train_utils import ( + get_step_checkpoint_dir, + get_step_identifier, + load_training_state, + save_checkpoint, + update_last_checkpoint, +) +from lerobot.common.utils.utils import ( + format_big_number, + get_safe_torch_device, + has_method, + init_logging, + is_launched_with_accelerate, +) +from lerobot.common.utils.wandb_utils import WandBLogger +from lerobot.configs import parser +from lerobot.configs.train import TrainPipelineConfig +from lerobot.scripts.eval import eval_policy + + +def update_policy( + train_metrics: MetricsTracker, + policy: PreTrainedPolicy, + batch: Any, + optimizer: Optimizer, + grad_clip_norm: float, + grad_scaler: GradScaler, + lr_scheduler=None, + use_amp: bool = False, + lock=None, + accelerator: Callable = None, +) -> tuple[MetricsTracker, dict]: + start_time = time.perf_counter() + + policy.train() + + loss, output_dict = policy.forward(batch) + + accelerator.backward(loss) + accelerator.unscale_gradients(optimizer=optimizer) + grad_norm = torch.nn.utils.clip_grad_norm_( + policy.parameters(), + grad_clip_norm, + error_if_nonfinite=False, + ) + optimizer.step() + + optimizer.zero_grad() + + # Step through pytorch scheduler at every batch instead of epoch + if lr_scheduler is not None: + lr_scheduler.step() + + if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"): + accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update() + + train_metrics.loss = loss.item() + train_metrics.grad_norm = grad_norm.item() + train_metrics.lr = optimizer.param_groups[0]["lr"] + train_metrics.update_s = time.perf_counter() - start_time + return train_metrics, output_dict + + +@parser.wrap() +def train(cfg: TrainPipelineConfig, accelerator: Callable): + cfg.validate() + logging.info(pformat(cfg.to_dict())) + + if accelerator.is_main_process: + # Disable logging on non-main processes. + cfg.wandb.enable = False + + if cfg.wandb.enable and cfg.wandb.project: + wandb_logger = WandBLogger(cfg) + else: + wandb_logger = None + logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) + + if cfg.seed is not None: + set_seed(cfg.seed, accelerator=accelerator) + + # Check device is available + device = get_safe_torch_device(cfg.device, log=True, accelerator=accelerator) + torch.backends.cudnn.benchmark = True + torch.backends.cuda.matmul.allow_tf32 = True + + logging.info("Creating dataset") + dataset = make_dataset(cfg) + + # Create environment used for evaluating checkpoints during training on simulation data. + # On real-world data, no need to create an environment as evaluations are done outside train.py, + # using the eval.py instead, with gym_dora environment and dora-rs. + eval_env = None + if cfg.eval_freq > 0 and cfg.env is not None: + logging.info("Creating env") + eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size) + + logging.info("Creating policy") + policy = make_policy( + cfg=cfg.policy, + device=device, + ds_meta=dataset.meta, + ) + policy.to(device) + logging.info("Creating optimizer and scheduler") + optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) + grad_scaler = GradScaler(device, enabled=cfg.use_amp) + + step = 0 # number of policy updates (forward + backward + optim) + + if cfg.resume: + step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler) + + num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) + num_total_params = sum(p.numel() for p in policy.parameters()) + if accelerator.is_main_process: + logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") + if cfg.env is not None: + logging.info(f"{cfg.env.task=}") + logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})") + logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})") + logging.info(f"{dataset.num_episodes=}") + logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") + logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") + + # create dataloader for offline training + if hasattr(cfg.policy, "drop_n_last_frames"): + shuffle = False + sampler = EpisodeAwareSampler( + dataset.episode_data_index, + drop_n_last_frames=cfg.policy.drop_n_last_frames, + shuffle=True, + ) + else: + shuffle = True + sampler = None + + dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=cfg.num_workers, + batch_size=cfg.batch_size, + shuffle=shuffle, + sampler=sampler, + pin_memory=device.type != "cpu", + drop_last=False, + ) + + policy, optimizer, dataloader, lr_scheduler = accelerator.prepare( + policy, optimizer, dataloader, lr_scheduler + ) + + dl_iter = cycle(dataloader) + + policy.train() + + train_metrics = { + "loss": AverageMeter("loss", ":.3f"), + "grad_norm": AverageMeter("grdn", ":.3f"), + "lr": AverageMeter("lr", ":0.1e"), + "update_s": AverageMeter("updt_s", ":.3f"), + "dataloading_s": AverageMeter("data_s", ":.3f"), + } + + train_tracker = MetricsTracker( + cfg.batch_size, + dataset.num_frames, + dataset.num_episodes, + train_metrics, + initial_step=step, + accelerator=accelerator, + ) + if accelerator.is_main_process: + logging.info("Start offline training on a fixed dataset") + + for _ in range(step, cfg.steps): + start_time = time.perf_counter() + batch = next(dl_iter) + train_tracker.dataloading_s = time.perf_counter() - start_time + + train_tracker, output_dict = update_policy( + train_tracker, + policy, + batch, + optimizer, + cfg.optimizer.grad_clip_norm, + grad_scaler=grad_scaler, + lr_scheduler=lr_scheduler, + use_amp=cfg.use_amp, + accelerator=accelerator, + ) + + # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we + # increment `step` here. + step += 1 + train_tracker.step() + is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and accelerator.is_main_process + is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps and accelerator.is_main_process + is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 and accelerator.is_main_process + + if is_log_step: + logging.info(train_tracker) + if wandb_logger: + wandb_log_dict = train_tracker.to_dict() + if output_dict: + wandb_log_dict.update(output_dict) + wandb_logger.log_dict(wandb_log_dict, step) + train_tracker.reset_averages() + + if cfg.save_checkpoint and is_saving_step: + logging.info(f"Checkpoint policy after step {step}") + checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step) + save_checkpoint( + checkpoint_dir, + step, + cfg, + accelerator.unwrap_model(policy), + optimizer, + lr_scheduler, + ) + update_last_checkpoint(checkpoint_dir) + if wandb_logger: + wandb_logger.log_policy(checkpoint_dir) + + accelerator.wait_for_everyone() + + if cfg.env and is_eval_step: + step_id = get_step_identifier(step, cfg.steps) + logging.info(f"Eval policy at step {step}") + + with torch.no_grad(): + eval_info = eval_policy( + env=eval_env, + policy=accelerator.unwrap_model(policy), + n_episodes=cfg.eval.n_episodes, + videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}", + max_episodes_rendered=4, + start_seed=cfg.seed, + ) + + eval_metrics = { + "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"), + "pc_success": AverageMeter("success", ":.1f"), + "eval_s": AverageMeter("eval_s", ":.3f"), + } + eval_tracker = MetricsTracker( + cfg.batch_size, + dataset.num_frames, + dataset.num_episodes, + eval_metrics, + initial_step=step, + accelerator=None, + ) + eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s") + eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward") + eval_tracker.pc_success = eval_info["aggregated"].pop("pc_success") + logging.info(eval_tracker) + if wandb_logger: + wandb_log_dict = {**eval_tracker.to_dict(), **eval_info} + wandb_logger.log_dict(wandb_log_dict, step, mode="eval") + wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval") + + if eval_env: + eval_env.close() + if not accelerator or accelerator.is_main_process: + logging.info("End of training") + + +if __name__ == "__main__": + init_logging() + + # We set step_scheduler_with_optimizer False to prevent accelerate from + # adjusting the lr_scheduler steps based on the num_processes + accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False) + train(accelerator=accelerator) diff --git a/src/lerobot/utils/logging_utils.py b/src/lerobot/utils/logging_utils.py index b6404e66d..7bfeb349d 100644 --- a/src/lerobot/utils/logging_utils.py +++ b/src/lerobot/utils/logging_utils.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Callable from lerobot.utils.utils import format_big_number @@ -84,6 +84,7 @@ class MetricsTracker: "samples", "episodes", "epochs", + "accelerator", ] def __init__( @@ -93,12 +94,14 @@ class MetricsTracker: num_episodes: int, metrics: dict[str, AverageMeter], initial_step: int = 0, + accelerator: Callable | None = None, ): self.__dict__.update(dict.fromkeys(self.__keys__)) self._batch_size = batch_size self._num_frames = num_frames self._avg_samples_per_ep = num_frames / num_episodes self.metrics = metrics + self.accelerator = accelerator self.steps = initial_step # A sample is an (observation,action) pair, where observation and action @@ -128,7 +131,7 @@ class MetricsTracker: Updates metrics that depend on 'step' for one step. """ self.steps += 1 - self.samples += self._batch_size + self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1) self.episodes = self.samples / self._avg_samples_per_ep self.epochs = self.samples / self._num_frames diff --git a/src/lerobot/utils/random_utils.py b/src/lerobot/utils/random_utils.py index da3ecf37f..1004b15a5 100644 --- a/src/lerobot/utils/random_utils.py +++ b/src/lerobot/utils/random_utils.py @@ -17,7 +17,7 @@ import random from collections.abc import Generator from contextlib import contextmanager from pathlib import Path -from typing import Any +from typing import Any, Callable, Generator import numpy as np import torch @@ -164,7 +164,7 @@ def set_rng_state(random_state_dict: dict[str, Any]): torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"]) -def set_seed(seed) -> None: +def set_seed(seed: int, accelerator: Callable | None = None) -> None: """Set seed for reproducibility.""" random.seed(seed) np.random.seed(seed) @@ -172,6 +172,11 @@ def set_seed(seed) -> None: if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) + if accelerator: + from accelerate.utils import set_seed as accelerate_set_seed + + accelerate_set_seed(seed) + @contextmanager def seeded_context(seed: int) -> Generator[None, None, None]: diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py index 6e13646b0..3ae276958 100644 --- a/src/lerobot/utils/utils.py +++ b/src/lerobot/utils/utils.py @@ -24,6 +24,7 @@ import time from copy import copy, deepcopy from datetime import datetime, timezone from pathlib import Path +from typing import Callable from statistics import mean import numpy as np @@ -56,13 +57,15 @@ def auto_select_torch_device() -> torch.device: # TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level -def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device: +def get_safe_torch_device( + try_device: str, log: bool = False, accelerator: Callable | None = None +) -> torch.device: """Given a string, return a torch.device with checks on whether the device is available.""" try_device = str(try_device) match try_device: case "cuda": assert torch.cuda.is_available() - device = torch.device("cuda") + device = accelerator.device if accelerator else torch.device("cuda") case "mps": assert torch.backends.mps.is_available() device = torch.device("mps") @@ -116,6 +119,7 @@ def init_logging( display_pid: bool = False, console_level: str = "INFO", file_level: str = "DEBUG", + accelerator: Callable | None = None, ): def custom_format(record: logging.LogRecord) -> str: dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S") @@ -152,6 +156,11 @@ def init_logging( file_handler.setLevel(file_level.upper()) logger.addHandler(file_handler) + if accelerator is not None and not accelerator.is_main_process: + # Disable duplicate logging on non-main processes + logging.info(f"Setting logging level on non-main process {accelerator.process_index} to WARNING.") + logging.getLogger().setLevel(logging.WARNING) + def format_big_number(num, precision=0): suffixes = ["", "K", "M", "B", "T", "Q"] @@ -165,6 +174,10 @@ def format_big_number(num, precision=0): return num +def is_launched_with_accelerate() -> bool: + return "ACCELERATE_MIXED_PRECISION" in os.environ + + def _relative_path_between(path1: Path, path2: Path) -> Path: """Returns path1 relative to path2.""" path1 = path1.absolute()