Clean up the update method

This commit is contained in:
Pepijn
2025-10-14 14:33:58 +02:00
parent 4061b3f5b3
commit d3f1ece680
2 changed files with 21 additions and 60 deletions
+21 -45
View File
@@ -16,14 +16,13 @@
import logging
import os
import time
from collections.abc import Callable
from contextlib import nullcontext
from pprint import pformat
from typing import Any
import torch
from accelerate import Accelerator
from termcolor import colored
from torch.amp import GradScaler
from torch.optim import Optimizer
from lerobot.configs import parser
@@ -36,7 +35,6 @@ from lerobot.envs.utils import close_envs
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters
from lerobot.rl.wandb_utils import WandBLogger
from lerobot.scripts.lerobot_eval import eval_policy_all
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
@@ -50,10 +48,8 @@ from lerobot.utils.train_utils import (
)
from lerobot.utils.utils import (
format_big_number,
get_safe_torch_device,
has_method,
init_logging,
is_launched_with_accelerate,
)
@@ -63,17 +59,15 @@ def update_policy(
batch: Any,
optimizer: Optimizer,
grad_clip_norm: float,
grad_scaler: GradScaler,
accelerator: Accelerator,
lr_scheduler=None,
use_amp: bool = False,
lock=None,
accelerator: Callable | None = None,
) -> tuple[MetricsTracker, dict]:
"""
Performs a single training step to update the policy's weights.
This function executes the forward and backward passes, clips gradients, and steps the optimizer and
learning rate scheduler. It also handles mixed-precision training via a GradScaler.
learning rate scheduler. Accelerator handles mixed-precision training automatically.
Args:
train_metrics: A MetricsTracker instance to record training statistics.
@@ -81,11 +75,9 @@ def update_policy(
batch: A batch of training data.
optimizer: The optimizer used to update the policy's parameters.
grad_clip_norm: The maximum norm for gradient clipping.
grad_scaler: The GradScaler for automatic mixed-precision training.
accelerator: The Accelerator instance for distributed training and mixed precision.
lr_scheduler: An optional learning rate scheduler.
use_amp: A boolean indicating whether to use automatic mixed precision.
lock: An optional lock for thread-safe optimizer updates.
accelerator: An optional accelerator, for multi-gpu training.
Returns:
A tuple containing:
@@ -93,51 +85,35 @@ def update_policy(
- A dictionary of outputs from the policy's forward pass, for logging purposes.
"""
start_time = time.perf_counter()
device = get_device_from_parameters(policy)
policy.train()
# Let accelerator handle mixed precision
with accelerator.autocast() if accelerator else (torch.autocast(device_type=device.type) if use_amp else nullcontext()):
with accelerator.autocast():
loss, output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
if accelerator:
accelerator.backward(loss)
if grad_clip_norm > 0:
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
else:
grad_norm = torch.tensor(0.0, device=policy.device)
optimizer.step()
# Use accelerator's backward method
accelerator.backward(loss)
# Clip gradients if specified
if grad_clip_norm > 0:
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
else:
grad_scaler.scale(loss).backward()
# Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
grad_scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(),
grad_clip_norm,
error_if_nonfinite=False,
)
# Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
with lock if lock is not None else nullcontext():
grad_scaler.step(optimizer)
# Updates the scale for next iteration.
grad_scaler.update()
grad_norm = torch.tensor(0.0, device=accelerator.device)
# Optimizer step
with lock if lock is not None else nullcontext():
optimizer.step()
optimizer.zero_grad()
# Step through pytorch scheduler at every batch instead of epoch
if lr_scheduler is not None:
lr_scheduler.step()
if accelerator:
if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
else:
if has_method(policy, "update"):
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
policy.update()
# Update internal buffers if policy has update method
if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
train_metrics.loss = loss.item()
train_metrics.grad_norm = grad_norm.item()
-15
View File
@@ -177,21 +177,6 @@ def format_big_number(num, precision=0):
return num
def is_launched_with_accelerate() -> bool:
    """Return True when the script runs in a distributed training context.

    Detection relies on the ``LOCAL_RANK`` environment variable, which is
    set by ``accelerate launch``, ``torchrun``, and
    ``torch.distributed.launch`` alike.

    Returns:
        True if a distributed launcher set ``LOCAL_RANK``, False otherwise.
    """
    # Presence of LOCAL_RANK (regardless of its value) marks a distributed run.
    return os.environ.get("LOCAL_RANK") is not None
def say(text: str, blocking: bool = False):
system = platform.system()