Enhance training and logging functionality with accelerator support

- Added support for multi-GPU training by introducing an `accelerator` parameter in training functions.
- Updated `update_policy` to handle gradient updates based on the presence of an accelerator.
- Modified logging to prevent duplicate messages in non-main processes.
- Enhanced `set_seed` and `get_safe_torch_device` functions to accommodate accelerator usage.
- Updated `MetricsTracker` to account for the number of processes when calculating metrics.
- Introduced a new feature in `pyproject.toml` for the `accelerate` library dependency.
This commit is contained in:
AdilZouitine
2025-10-02 18:11:27 +02:00
parent abde7be3b3
commit f30da2dec1
5 changed files with 137 additions and 45 deletions
+1
View File
@@ -125,6 +125,7 @@ hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.11", "lerobot[grpcio-dep]"
# Features # Features
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3"] async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3"]
accelerate = ["accelerate>=1.10.0"]
# Development # Development
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"] dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
+109 -40
View File
@@ -18,6 +18,7 @@ import time
from contextlib import nullcontext from contextlib import nullcontext
from pprint import pformat from pprint import pformat
from typing import Any from typing import Any
from collections.abc import Callable
import torch import torch
from termcolor import colored from termcolor import colored
@@ -51,6 +52,7 @@ from lerobot.utils.utils import (
get_safe_torch_device, get_safe_torch_device,
has_method, has_method,
init_logging, init_logging,
is_launched_with_accelerate,
) )
@@ -64,6 +66,7 @@ def update_policy(
lr_scheduler=None, lr_scheduler=None,
use_amp: bool = False, use_amp: bool = False,
lock=None, lock=None,
accelerator: Callable | None = None,
) -> tuple[MetricsTracker, dict]: ) -> tuple[MetricsTracker, dict]:
""" """
Performs a single training step to update the policy's weights. Performs a single training step to update the policy's weights.
@@ -81,6 +84,7 @@ def update_policy(
lr_scheduler: An optional learning rate scheduler. lr_scheduler: An optional learning rate scheduler.
use_amp: A boolean indicating whether to use automatic mixed precision. use_amp: A boolean indicating whether to use automatic mixed precision.
lock: An optional lock for thread-safe optimizer updates. lock: An optional lock for thread-safe optimizer updates.
accelerator: An optional accelerator, for multi-gpu training.
Returns: Returns:
A tuple containing: A tuple containing:
@@ -90,26 +94,36 @@ def update_policy(
start_time = time.perf_counter() start_time = time.perf_counter()
device = get_device_from_parameters(policy) device = get_device_from_parameters(policy)
policy.train() policy.train()
with torch.autocast(device_type=device.type) if use_amp else nullcontext(): with torch.autocast(device_type=device.type) if use_amp and accelerator is None else nullcontext():
loss, output_dict = policy.forward(batch) loss, output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict) # TODO(rcadene): policy.unnormalize_outputs(out_dict)
grad_scaler.scale(loss).backward()
# Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**. if accelerator:
grad_scaler.unscale_(optimizer) accelerator.backward(loss)
accelerator.unscale_gradients(optimizer=optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(),
grad_clip_norm,
error_if_nonfinite=False,
)
optimizer.step()
else:
grad_scaler.scale(loss).backward()
# Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
grad_scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_( grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(), policy.parameters(),
grad_clip_norm, grad_clip_norm,
error_if_nonfinite=False, error_if_nonfinite=False,
) )
# Optimizer's gradients are already unscaled, so scaler.step does not unscale them, # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
# although it still skips optimizer.step() if the gradients contain infs or NaNs. # although it still skips optimizer.step() if the gradients contain infs or NaNs.
with lock if lock is not None else nullcontext(): with lock if lock is not None else nullcontext():
grad_scaler.step(optimizer) grad_scaler.step(optimizer)
# Updates the scale for next iteration. # Updates the scale for next iteration.
grad_scaler.update() grad_scaler.update()
optimizer.zero_grad() optimizer.zero_grad()
@@ -117,9 +131,13 @@ def update_policy(
if lr_scheduler is not None: if lr_scheduler is not None:
lr_scheduler.step() lr_scheduler.step()
if has_method(policy, "update"): if accelerator:
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC). if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
policy.update() accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
else:
if has_method(policy, "update"):
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
policy.update()
train_metrics.loss = loss.item() train_metrics.loss = loss.item()
train_metrics.grad_norm = grad_norm.item() train_metrics.grad_norm = grad_norm.item()
@@ -129,7 +147,7 @@ def update_policy(
@parser.wrap() @parser.wrap()
def train(cfg: TrainPipelineConfig): def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
""" """
Main function to train a policy. Main function to train a policy.
@@ -147,6 +165,10 @@ def train(cfg: TrainPipelineConfig):
cfg.validate() cfg.validate()
logging.info(pformat(cfg.to_dict())) logging.info(pformat(cfg.to_dict()))
if accelerator and not accelerator.is_main_process:
# Disable logging on non-main processes.
cfg.wandb.enable = False
if cfg.wandb.enable and cfg.wandb.project: if cfg.wandb.enable and cfg.wandb.project:
wandb_logger = WandBLogger(cfg) wandb_logger = WandBLogger(cfg)
else: else:
@@ -154,10 +176,10 @@ def train(cfg: TrainPipelineConfig):
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
if cfg.seed is not None: if cfg.seed is not None:
set_seed(cfg.seed) set_seed(cfg.seed, accelerator=accelerator)
# Check device is available # Check device is available
device = get_safe_torch_device(cfg.policy.device, log=True) device = get_safe_torch_device(cfg.policy.device, log=True, accelerator=accelerator)
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
@@ -177,6 +199,7 @@ def train(cfg: TrainPipelineConfig):
cfg=cfg.policy, cfg=cfg.policy,
ds_meta=dataset.meta, ds_meta=dataset.meta,
) )
policy.to(device)
# Create processors - only provide dataset_stats if not resuming from saved processors # Create processors - only provide dataset_stats if not resuming from saved processors
processor_kwargs = {} processor_kwargs = {}
@@ -221,14 +244,15 @@ def train(cfg: TrainPipelineConfig):
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters()) num_total_params = sum(p.numel() for p in policy.parameters())
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") if not accelerator or accelerator.is_main_process:
if cfg.env is not None: logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
logging.info(f"{cfg.env.task=}") if cfg.env is not None:
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})") logging.info(f"{cfg.env.task=}")
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})") logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
logging.info(f"{dataset.num_episodes=}") logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") logging.info(f"{dataset.num_episodes=}")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# create dataloader for offline training # create dataloader for offline training
if hasattr(cfg.policy, "drop_n_last_frames"): if hasattr(cfg.policy, "drop_n_last_frames"):
@@ -253,6 +277,10 @@ def train(cfg: TrainPipelineConfig):
drop_last=False, drop_last=False,
prefetch_factor=2, prefetch_factor=2,
) )
if accelerator:
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
policy, optimizer, dataloader, lr_scheduler
)
dl_iter = cycle(dataloader) dl_iter = cycle(dataloader)
policy.train() policy.train()
@@ -266,10 +294,16 @@ def train(cfg: TrainPipelineConfig):
} }
train_tracker = MetricsTracker( train_tracker = MetricsTracker(
cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step cfg.batch_size,
dataset.num_frames,
dataset.num_episodes,
train_metrics,
initial_step=step,
accelerator=accelerator,
) )
logging.info("Start offline training on a fixed dataset") if not accelerator or accelerator.is_main_process:
logging.info("Start offline training on a fixed dataset")
for _ in range(step, cfg.steps): for _ in range(step, cfg.steps):
start_time = time.perf_counter() start_time = time.perf_counter()
batch = next(dl_iter) batch = next(dl_iter)
@@ -285,15 +319,26 @@ def train(cfg: TrainPipelineConfig):
grad_scaler=grad_scaler, grad_scaler=grad_scaler,
lr_scheduler=lr_scheduler, lr_scheduler=lr_scheduler,
use_amp=cfg.policy.use_amp, use_amp=cfg.policy.use_amp,
accelerator=accelerator,
) )
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
# increment `step` here. # increment `step` here.
step += 1 step += 1
train_tracker.step() train_tracker.step()
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 is_log_step = (
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps cfg.log_freq > 0 and step % cfg.log_freq == 0 and (not accelerator or accelerator.is_main_process)
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 )
is_saving_step = (
step % cfg.save_freq == 0
or step == cfg.steps
and (not accelerator or accelerator.is_main_process)
)
is_eval_step = (
cfg.eval_freq > 0
and step % cfg.eval_freq == 0
and (not accelerator or accelerator.is_main_process)
)
if is_log_step: if is_log_step:
logging.info(train_tracker) logging.info(train_tracker)
@@ -308,22 +353,31 @@ def train(cfg: TrainPipelineConfig):
logging.info(f"Checkpoint policy after step {step}") logging.info(f"Checkpoint policy after step {step}")
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step) checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
save_checkpoint( save_checkpoint(
checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor, postprocessor checkpoint_dir=checkpoint_dir,
step=step,
cfg=cfg,
policy=policy if not accelerator else accelerator.unwrap_model(policy),
optimizer=optimizer,
scheduler=lr_scheduler,
) )
update_last_checkpoint(checkpoint_dir) update_last_checkpoint(checkpoint_dir)
if wandb_logger: if wandb_logger:
wandb_logger.log_policy(checkpoint_dir) wandb_logger.log_policy(checkpoint_dir)
if accelerator:
accelerator.wait_for_everyone()
if cfg.env and is_eval_step: if cfg.env and is_eval_step:
step_id = get_step_identifier(step, cfg.steps) step_id = get_step_identifier(step, cfg.steps)
logging.info(f"Eval policy at step {step}") logging.info(f"Eval policy at step {step}")
with ( with (
torch.no_grad(), torch.no_grad(),
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(), torch.autocast(device_type=device.type)
if cfg.policy.use_amp and not accelerator
else nullcontext(),
): ):
eval_info = eval_policy_all( eval_info = eval_policy_all(
envs=eval_env, # dict[suite][task_id] -> vec_env envs=eval_env, # dict[suite][task_id] -> vec_env
policy=policy, policy=policy if not accelerator else accelerator.unwrap_model(policy),
preprocessor=preprocessor, preprocessor=preprocessor,
postprocessor=postprocessor, postprocessor=postprocessor,
n_episodes=cfg.eval.n_episodes, n_episodes=cfg.eval.n_episodes,
@@ -346,7 +400,12 @@ def train(cfg: TrainPipelineConfig):
"eval_s": AverageMeter("eval_s", ":.3f"), "eval_s": AverageMeter("eval_s", ":.3f"),
} }
eval_tracker = MetricsTracker( eval_tracker = MetricsTracker(
cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step cfg.batch_size,
dataset.num_frames,
dataset.num_episodes,
eval_metrics,
initial_step=step,
accelerator=accelerator,
) )
eval_tracker.eval_s = aggregated.pop("eval_s") eval_tracker.eval_s = aggregated.pop("eval_s")
eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward") eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward")
@@ -358,7 +417,9 @@ def train(cfg: TrainPipelineConfig):
if eval_env: if eval_env:
close_envs(eval_env) close_envs(eval_env)
logging.info("End of training")
if not accelerator or accelerator.is_main_process:
logging.info("End of training")
if cfg.policy.push_to_hub: if cfg.policy.push_to_hub:
policy.push_model_to_hub(cfg) policy.push_model_to_hub(cfg)
@@ -372,4 +433,12 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() if is_launched_with_accelerate():
import accelerate
# We set step_scheduler_with_optimizer False to prevent accelerate from
# adjusting the lr_scheduler steps based on the num_processes
accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False)
train(accelerator=accelerator)
else:
train()
+5 -1
View File
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections.abc import Callable
from typing import Any from typing import Any
from lerobot.utils.utils import format_big_number from lerobot.utils.utils import format_big_number
@@ -84,6 +85,7 @@ class MetricsTracker:
"samples", "samples",
"episodes", "episodes",
"epochs", "epochs",
"accelerator",
] ]
def __init__( def __init__(
@@ -93,6 +95,7 @@ class MetricsTracker:
num_episodes: int, num_episodes: int,
metrics: dict[str, AverageMeter], metrics: dict[str, AverageMeter],
initial_step: int = 0, initial_step: int = 0,
accelerator: Callable | None = None,
): ):
self.__dict__.update(dict.fromkeys(self.__keys__)) self.__dict__.update(dict.fromkeys(self.__keys__))
self._batch_size = batch_size self._batch_size = batch_size
@@ -106,6 +109,7 @@ class MetricsTracker:
self.samples = self.steps * self._batch_size self.samples = self.steps * self._batch_size
self.episodes = self.samples / self._avg_samples_per_ep self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames self.epochs = self.samples / self._num_frames
self.accelerator = accelerator
def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any: def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any:
if name in self.__dict__: if name in self.__dict__:
@@ -128,7 +132,7 @@ class MetricsTracker:
Updates metrics that depend on 'step' for one step. Updates metrics that depend on 'step' for one step.
""" """
self.steps += 1 self.steps += 1
self.samples += self._batch_size self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1)
self.episodes = self.samples / self._avg_samples_per_ep self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames self.epochs = self.samples / self._num_frames
+8 -2
View File
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import random import random
from collections.abc import Generator from collections.abc import Callable, Generator
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@@ -164,14 +164,20 @@ def set_rng_state(random_state_dict: dict[str, Any]):
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"]) torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
def set_seed(seed) -> None: def set_seed(seed, accelerator: Callable | None = None) -> None:
"""Set seed for reproducibility.""" """Set seed for reproducibility."""
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed_all(seed)
if accelerator:
from accelerate.utils import set_seed
set_seed(seed)
@contextmanager @contextmanager
def seeded_context(seed: int) -> Generator[None, None, None]: def seeded_context(seed: int) -> Generator[None, None, None]:
+14 -2
View File
@@ -20,6 +20,7 @@ import select
import subprocess import subprocess
import sys import sys
import time import time
from collections.abc import Callable
from copy import copy, deepcopy from copy import copy, deepcopy
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
@@ -49,13 +50,15 @@ def auto_select_torch_device() -> torch.device:
# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level # TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device: def get_safe_torch_device(
try_device: str, log: bool = False, accelerator: Callable | None = None
) -> torch.device:
"""Given a string, return a torch.device with checks on whether the device is available.""" """Given a string, return a torch.device with checks on whether the device is available."""
try_device = str(try_device) try_device = str(try_device)
match try_device: match try_device:
case "cuda": case "cuda":
assert torch.cuda.is_available() assert torch.cuda.is_available()
device = torch.device("cuda") device = accelerator.device if accelerator else torch.device("cuda")
case "mps": case "mps":
assert torch.backends.mps.is_available() assert torch.backends.mps.is_available()
device = torch.device("mps") device = torch.device("mps")
@@ -109,6 +112,7 @@ def init_logging(
display_pid: bool = False, display_pid: bool = False,
console_level: str = "INFO", console_level: str = "INFO",
file_level: str = "DEBUG", file_level: str = "DEBUG",
accelerator: Callable | None = None,
): ):
def custom_format(record: logging.LogRecord) -> str: def custom_format(record: logging.LogRecord) -> str:
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S") dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@@ -137,6 +141,10 @@ def init_logging(
console_handler.setFormatter(formatter) console_handler.setFormatter(formatter)
console_handler.setLevel(console_level.upper()) console_handler.setLevel(console_level.upper())
logger.addHandler(console_handler) logger.addHandler(console_handler)
if accelerator is not None and not accelerator.is_main_process:
# Disable duplicate logging on non-main processes
logging.info(f"Setting logging level on non-main process {accelerator.process_index} to WARNING.")
logging.getLogger().setLevel(logging.WARNING)
# Additionally write logs to file # Additionally write logs to file
if log_file is not None: if log_file is not None:
@@ -158,6 +166,10 @@ def format_big_number(num, precision=0):
return num return num
def is_launched_with_accelerate() -> bool:
return "ACCELERATE_MIXED_PRECISION" in os.environ
def say(text: str, blocking: bool = False): def say(text: str, blocking: bool = False):
system = platform.system() system = platform.system()