From d3f1ece680dcfc28968e65393052616e47c9bf7f Mon Sep 17 00:00:00 2001 From: Pepijn Date: Tue, 14 Oct 2025 14:33:58 +0200 Subject: [PATCH] cleanup update method --- src/lerobot/scripts/lerobot_train.py | 66 +++++++++------------------- src/lerobot/utils/utils.py | 15 ------- 2 files changed, 21 insertions(+), 60 deletions(-) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 19d2a98ad..baf52c400 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -16,14 +16,13 @@ import logging import os import time -from collections.abc import Callable from contextlib import nullcontext from pprint import pformat from typing import Any import torch +from accelerate import Accelerator from termcolor import colored -from torch.amp import GradScaler from torch.optim import Optimizer from lerobot.configs import parser @@ -36,7 +35,6 @@ from lerobot.envs.utils import close_envs from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.utils import get_device_from_parameters from lerobot.rl.wandb_utils import WandBLogger from lerobot.scripts.lerobot_eval import eval_policy_all from lerobot.utils.logging_utils import AverageMeter, MetricsTracker @@ -50,10 +48,8 @@ from lerobot.utils.train_utils import ( ) from lerobot.utils.utils import ( format_big_number, - get_safe_torch_device, has_method, init_logging, - is_launched_with_accelerate, ) @@ -63,17 +59,15 @@ def update_policy( batch: Any, optimizer: Optimizer, grad_clip_norm: float, - grad_scaler: GradScaler, + accelerator: Accelerator, lr_scheduler=None, - use_amp: bool = False, lock=None, - accelerator: Callable | None = None, ) -> tuple[MetricsTracker, dict]: """ Performs a single training step to update the policy's weights. This function executes the forward and backward passes, clips gradients, and steps the optimizer and - learning rate scheduler. It also handles mixed-precision training via a GradScaler. + learning rate scheduler. Accelerator handles mixed-precision training automatically. Args: train_metrics: A MetricsTracker instance to record training statistics. @@ -81,11 +75,9 @@ def update_policy( batch: A batch of training data. optimizer: The optimizer used to update the policy's parameters. grad_clip_norm: The maximum norm for gradient clipping. - grad_scaler: The GradScaler for automatic mixed-precision training. + accelerator: The Accelerator instance for distributed training and mixed precision. lr_scheduler: An optional learning rate scheduler. - use_amp: A boolean indicating whether to use automatic mixed precision. lock: An optional lock for thread-safe optimizer updates. - accelerator: An optional accelerator, for multi-gpu training. Returns: A tuple containing: @@ -93,51 +85,35 @@ def update_policy( - A dictionary of outputs from the policy's forward pass, for logging purposes. """ start_time = time.perf_counter() - device = get_device_from_parameters(policy) policy.train() + # Let accelerator handle mixed precision - with accelerator.autocast() if accelerator else (torch.autocast(device_type=device.type) if use_amp else nullcontext()): + with accelerator.autocast(): loss, output_dict = policy.forward(batch) # TODO(rcadene): policy.unnormalize_outputs(out_dict) - if accelerator: - accelerator.backward(loss) - if grad_clip_norm > 0: - grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm) - else: - grad_norm = torch.tensor(0.0, device=policy.device) - optimizer.step() + # Use accelerator's backward method + accelerator.backward(loss) + + # Clip gradients if specified + if grad_clip_norm > 0: + grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm) else: - grad_scaler.scale(loss).backward() - # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**. - grad_scaler.unscale_(optimizer) - - grad_norm = torch.nn.utils.clip_grad_norm_( - policy.parameters(), - grad_clip_norm, - error_if_nonfinite=False, - ) - - # Optimizer's gradients are already unscaled, so scaler.step does not unscale them, - # although it still skips optimizer.step() if the gradients contain infs or NaNs. - with lock if lock is not None else nullcontext(): - grad_scaler.step(optimizer) - # Updates the scale for next iteration. - grad_scaler.update() - + grad_norm = torch.tensor(0.0, device=accelerator.device) + + # Optimizer step + with lock if lock is not None else nullcontext(): + optimizer.step() + optimizer.zero_grad() # Step through pytorch scheduler at every batch instead of epoch if lr_scheduler is not None: lr_scheduler.step() - if accelerator: - if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"): - accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update() - else: - if has_method(policy, "update"): - # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC). - policy.update() + # Update internal buffers if policy has update method + if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"): + accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update() train_metrics.loss = loss.item() train_metrics.grad_norm = grad_norm.item() diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py index 3c9cf5145..79a116a45 100644 --- a/src/lerobot/utils/utils.py +++ b/src/lerobot/utils/utils.py @@ -177,21 +177,6 @@ def format_big_number(num, precision=0): return num - -def is_launched_with_accelerate() -> bool: - """Check if the script was launched in a distributed training context. - - This checks for standard distributed training environment variables that are set - by accelerate launch, torchrun, or torch.distributed.launch. - - Returns: - True if running in a distributed context, False otherwise. - """ - # Check for LOCAL_RANK which is the standard way to detect distributed training - # This is set by accelerate, torchrun, and torch.distributed.launch - return "LOCAL_RANK" in os.environ - - def say(text: str, blocking: bool = False): system = platform.system()