cleanup update method

This commit is contained in:
Pepijn
2025-10-14 14:33:58 +02:00
parent 4061b3f5b3
commit d3f1ece680
2 changed files with 21 additions and 60 deletions
+19 -43
View File
@@ -16,14 +16,13 @@
import logging import logging
import os import os
import time import time
from collections.abc import Callable
from contextlib import nullcontext from contextlib import nullcontext
from pprint import pformat from pprint import pformat
from typing import Any from typing import Any
import torch import torch
from accelerate import Accelerator
from termcolor import colored from termcolor import colored
from torch.amp import GradScaler
from torch.optim import Optimizer from torch.optim import Optimizer
from lerobot.configs import parser from lerobot.configs import parser
@@ -36,7 +35,6 @@ from lerobot.envs.utils import close_envs
from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters
from lerobot.rl.wandb_utils import WandBLogger from lerobot.rl.wandb_utils import WandBLogger
from lerobot.scripts.lerobot_eval import eval_policy_all from lerobot.scripts.lerobot_eval import eval_policy_all
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
@@ -50,10 +48,8 @@ from lerobot.utils.train_utils import (
) )
from lerobot.utils.utils import ( from lerobot.utils.utils import (
format_big_number, format_big_number,
get_safe_torch_device,
has_method, has_method,
init_logging, init_logging,
is_launched_with_accelerate,
) )
@@ -63,17 +59,15 @@ def update_policy(
batch: Any, batch: Any,
optimizer: Optimizer, optimizer: Optimizer,
grad_clip_norm: float, grad_clip_norm: float,
grad_scaler: GradScaler, accelerator: Accelerator,
lr_scheduler=None, lr_scheduler=None,
use_amp: bool = False,
lock=None, lock=None,
accelerator: Callable | None = None,
) -> tuple[MetricsTracker, dict]: ) -> tuple[MetricsTracker, dict]:
""" """
Performs a single training step to update the policy's weights. Performs a single training step to update the policy's weights.
This function executes the forward and backward passes, clips gradients, and steps the optimizer and This function executes the forward and backward passes, clips gradients, and steps the optimizer and
learning rate scheduler. It also handles mixed-precision training via a GradScaler. learning rate scheduler. Accelerator handles mixed-precision training automatically.
Args: Args:
train_metrics: A MetricsTracker instance to record training statistics. train_metrics: A MetricsTracker instance to record training statistics.
@@ -81,11 +75,9 @@ def update_policy(
batch: A batch of training data. batch: A batch of training data.
optimizer: The optimizer used to update the policy's parameters. optimizer: The optimizer used to update the policy's parameters.
grad_clip_norm: The maximum norm for gradient clipping. grad_clip_norm: The maximum norm for gradient clipping.
grad_scaler: The GradScaler for automatic mixed-precision training. accelerator: The Accelerator instance for distributed training and mixed precision.
lr_scheduler: An optional learning rate scheduler. lr_scheduler: An optional learning rate scheduler.
use_amp: A boolean indicating whether to use automatic mixed precision.
lock: An optional lock for thread-safe optimizer updates. lock: An optional lock for thread-safe optimizer updates.
accelerator: An optional accelerator, for multi-gpu training.
Returns: Returns:
A tuple containing: A tuple containing:
@@ -93,37 +85,25 @@ def update_policy(
- A dictionary of outputs from the policy's forward pass, for logging purposes. - A dictionary of outputs from the policy's forward pass, for logging purposes.
""" """
start_time = time.perf_counter() start_time = time.perf_counter()
device = get_device_from_parameters(policy)
policy.train() policy.train()
# Let accelerator handle mixed precision # Let accelerator handle mixed precision
with accelerator.autocast() if accelerator else (torch.autocast(device_type=device.type) if use_amp else nullcontext()): with accelerator.autocast():
loss, output_dict = policy.forward(batch) loss, output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict) # TODO(rcadene): policy.unnormalize_outputs(out_dict)
if accelerator: # Use accelerator's backward method
accelerator.backward(loss) accelerator.backward(loss)
if grad_clip_norm > 0:
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm) # Clip gradients if specified
else: if grad_clip_norm > 0:
grad_norm = torch.tensor(0.0, device=policy.device) grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
optimizer.step()
else: else:
grad_scaler.scale(loss).backward() grad_norm = torch.tensor(0.0, device=accelerator.device)
# Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
grad_scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_( # Optimizer step
policy.parameters(), with lock if lock is not None else nullcontext():
grad_clip_norm, optimizer.step()
error_if_nonfinite=False,
)
# Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
with lock if lock is not None else nullcontext():
grad_scaler.step(optimizer)
# Updates the scale for next iteration.
grad_scaler.update()
optimizer.zero_grad() optimizer.zero_grad()
@@ -131,13 +111,9 @@ def update_policy(
if lr_scheduler is not None: if lr_scheduler is not None:
lr_scheduler.step() lr_scheduler.step()
if accelerator: # Update internal buffers if policy has update method
if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"): if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update() accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
else:
if has_method(policy, "update"):
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
policy.update()
train_metrics.loss = loss.item() train_metrics.loss = loss.item()
train_metrics.grad_norm = grad_norm.item() train_metrics.grad_norm = grad_norm.item()
-15
View File
@@ -177,21 +177,6 @@ def format_big_number(num, precision=0):
return num return num
def is_launched_with_accelerate() -> bool:
    """Report whether the process was started by a distributed launcher.

    The check is purely environmental: it looks for the ``LOCAL_RANK``
    variable, which is set by ``accelerate launch``, ``torchrun`` and
    ``torch.distributed.launch``. Despite the function's name, it therefore
    cannot tell those launchers apart.

    Only the *presence* of the variable is tested; its value is never read,
    so an empty ``LOCAL_RANK`` still counts as a distributed context.

    Returns:
        True if ``LOCAL_RANK`` is defined in the environment, False otherwise.
    """
    # Membership test on purpose: rank 0 is exported as the string "0", and
    # any truthiness-based check on a parsed value would misreport it.
    return "LOCAL_RANK" in os.environ
def say(text: str, blocking: bool = False): def say(text: str, blocking: bool = False):
system = platform.system() system = platform.system()