mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-14 16:19:45 +00:00
cleanup update method
This commit is contained in:
@@ -16,14 +16,13 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from contextlib import nullcontext
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from termcolor import colored
|
||||
from torch.amp import GradScaler
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from lerobot.configs import parser
|
||||
@@ -36,7 +35,6 @@ from lerobot.envs.utils import close_envs
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
from lerobot.rl.wandb_utils import WandBLogger
|
||||
from lerobot.scripts.lerobot_eval import eval_policy_all
|
||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
@@ -50,10 +48,8 @@ from lerobot.utils.train_utils import (
|
||||
)
|
||||
from lerobot.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_torch_device,
|
||||
has_method,
|
||||
init_logging,
|
||||
is_launched_with_accelerate,
|
||||
)
|
||||
|
||||
|
||||
@@ -63,17 +59,15 @@ def update_policy(
|
||||
batch: Any,
|
||||
optimizer: Optimizer,
|
||||
grad_clip_norm: float,
|
||||
grad_scaler: GradScaler,
|
||||
accelerator: Accelerator,
|
||||
lr_scheduler=None,
|
||||
use_amp: bool = False,
|
||||
lock=None,
|
||||
accelerator: Callable | None = None,
|
||||
) -> tuple[MetricsTracker, dict]:
|
||||
"""
|
||||
Performs a single training step to update the policy's weights.
|
||||
|
||||
This function executes the forward and backward passes, clips gradients, and steps the optimizer and
|
||||
learning rate scheduler. It also handles mixed-precision training via a GradScaler.
|
||||
learning rate scheduler. Accelerator handles mixed-precision training automatically.
|
||||
|
||||
Args:
|
||||
train_metrics: A MetricsTracker instance to record training statistics.
|
||||
@@ -81,11 +75,9 @@ def update_policy(
|
||||
batch: A batch of training data.
|
||||
optimizer: The optimizer used to update the policy's parameters.
|
||||
grad_clip_norm: The maximum norm for gradient clipping.
|
||||
grad_scaler: The GradScaler for automatic mixed-precision training.
|
||||
accelerator: The Accelerator instance for distributed training and mixed precision.
|
||||
lr_scheduler: An optional learning rate scheduler.
|
||||
use_amp: A boolean indicating whether to use automatic mixed precision.
|
||||
lock: An optional lock for thread-safe optimizer updates.
|
||||
accelerator: An optional accelerator, for multi-gpu training.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
@@ -93,51 +85,35 @@ def update_policy(
|
||||
- A dictionary of outputs from the policy's forward pass, for logging purposes.
|
||||
"""
|
||||
start_time = time.perf_counter()
|
||||
device = get_device_from_parameters(policy)
|
||||
policy.train()
|
||||
|
||||
# Let accelerator handle mixed precision
|
||||
with accelerator.autocast() if accelerator else (torch.autocast(device_type=device.type) if use_amp else nullcontext()):
|
||||
with accelerator.autocast():
|
||||
loss, output_dict = policy.forward(batch)
|
||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||
|
||||
if accelerator:
|
||||
accelerator.backward(loss)
|
||||
if grad_clip_norm > 0:
|
||||
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
|
||||
else:
|
||||
grad_norm = torch.tensor(0.0, device=policy.device)
|
||||
optimizer.step()
|
||||
# Use accelerator's backward method
|
||||
accelerator.backward(loss)
|
||||
|
||||
# Clip gradients if specified
|
||||
if grad_clip_norm > 0:
|
||||
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
|
||||
else:
|
||||
grad_scaler.scale(loss).backward()
|
||||
# Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
|
||||
grad_scaler.unscale_(optimizer)
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
policy.parameters(),
|
||||
grad_clip_norm,
|
||||
error_if_nonfinite=False,
|
||||
)
|
||||
|
||||
# Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
|
||||
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
|
||||
with lock if lock is not None else nullcontext():
|
||||
grad_scaler.step(optimizer)
|
||||
# Updates the scale for next iteration.
|
||||
grad_scaler.update()
|
||||
|
||||
grad_norm = torch.tensor(0.0, device=accelerator.device)
|
||||
|
||||
# Optimizer step
|
||||
with lock if lock is not None else nullcontext():
|
||||
optimizer.step()
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Step through pytorch scheduler at every batch instead of epoch
|
||||
if lr_scheduler is not None:
|
||||
lr_scheduler.step()
|
||||
|
||||
if accelerator:
|
||||
if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
|
||||
accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
|
||||
else:
|
||||
if has_method(policy, "update"):
|
||||
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
|
||||
policy.update()
|
||||
# Update internal buffers if policy has update method
|
||||
if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
|
||||
accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
|
||||
|
||||
train_metrics.loss = loss.item()
|
||||
train_metrics.grad_norm = grad_norm.item()
|
||||
|
||||
@@ -177,21 +177,6 @@ def format_big_number(num, precision=0):
|
||||
|
||||
return num
|
||||
|
||||
|
||||
def is_launched_with_accelerate() -> bool:
    """Detect whether this process was started by a distributed launcher.

    ``accelerate launch``, ``torchrun``, and ``torch.distributed.launch``
    all export the ``LOCAL_RANK`` environment variable, so its presence is
    the conventional signal that we are inside a distributed training
    context.

    Returns:
        True when ``LOCAL_RANK`` is present in the environment,
        False otherwise.
    """
    # os.environ values are always strings, so a None result from .get()
    # means the variable is absent — equivalent to a membership test.
    return os.environ.get("LOCAL_RANK") is not None
|
||||
|
||||
|
||||
def say(text: str, blocking: bool = False):
|
||||
system = platform.system()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user