mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
cleanup update method
This commit is contained in:
@@ -16,14 +16,13 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from collections.abc import Callable
|
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from accelerate import Accelerator
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
from torch.amp import GradScaler
|
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
@@ -36,7 +35,6 @@ from lerobot.envs.utils import close_envs
|
|||||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.policies.utils import get_device_from_parameters
|
|
||||||
from lerobot.rl.wandb_utils import WandBLogger
|
from lerobot.rl.wandb_utils import WandBLogger
|
||||||
from lerobot.scripts.lerobot_eval import eval_policy_all
|
from lerobot.scripts.lerobot_eval import eval_policy_all
|
||||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||||
@@ -50,10 +48,8 @@ from lerobot.utils.train_utils import (
|
|||||||
)
|
)
|
||||||
from lerobot.utils.utils import (
|
from lerobot.utils.utils import (
|
||||||
format_big_number,
|
format_big_number,
|
||||||
get_safe_torch_device,
|
|
||||||
has_method,
|
has_method,
|
||||||
init_logging,
|
init_logging,
|
||||||
is_launched_with_accelerate,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -63,17 +59,15 @@ def update_policy(
|
|||||||
batch: Any,
|
batch: Any,
|
||||||
optimizer: Optimizer,
|
optimizer: Optimizer,
|
||||||
grad_clip_norm: float,
|
grad_clip_norm: float,
|
||||||
grad_scaler: GradScaler,
|
accelerator: Accelerator,
|
||||||
lr_scheduler=None,
|
lr_scheduler=None,
|
||||||
use_amp: bool = False,
|
|
||||||
lock=None,
|
lock=None,
|
||||||
accelerator: Callable | None = None,
|
|
||||||
) -> tuple[MetricsTracker, dict]:
|
) -> tuple[MetricsTracker, dict]:
|
||||||
"""
|
"""
|
||||||
Performs a single training step to update the policy's weights.
|
Performs a single training step to update the policy's weights.
|
||||||
|
|
||||||
This function executes the forward and backward passes, clips gradients, and steps the optimizer and
|
This function executes the forward and backward passes, clips gradients, and steps the optimizer and
|
||||||
learning rate scheduler. It also handles mixed-precision training via a GradScaler.
|
learning rate scheduler. Accelerator handles mixed-precision training automatically.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
train_metrics: A MetricsTracker instance to record training statistics.
|
train_metrics: A MetricsTracker instance to record training statistics.
|
||||||
@@ -81,11 +75,9 @@ def update_policy(
|
|||||||
batch: A batch of training data.
|
batch: A batch of training data.
|
||||||
optimizer: The optimizer used to update the policy's parameters.
|
optimizer: The optimizer used to update the policy's parameters.
|
||||||
grad_clip_norm: The maximum norm for gradient clipping.
|
grad_clip_norm: The maximum norm for gradient clipping.
|
||||||
grad_scaler: The GradScaler for automatic mixed-precision training.
|
accelerator: The Accelerator instance for distributed training and mixed precision.
|
||||||
lr_scheduler: An optional learning rate scheduler.
|
lr_scheduler: An optional learning rate scheduler.
|
||||||
use_amp: A boolean indicating whether to use automatic mixed precision.
|
|
||||||
lock: An optional lock for thread-safe optimizer updates.
|
lock: An optional lock for thread-safe optimizer updates.
|
||||||
accelerator: An optional accelerator, for multi-gpu training.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple containing:
|
A tuple containing:
|
||||||
@@ -93,37 +85,25 @@ def update_policy(
|
|||||||
- A dictionary of outputs from the policy's forward pass, for logging purposes.
|
- A dictionary of outputs from the policy's forward pass, for logging purposes.
|
||||||
"""
|
"""
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
device = get_device_from_parameters(policy)
|
|
||||||
policy.train()
|
policy.train()
|
||||||
|
|
||||||
# Let accelerator handle mixed precision
|
# Let accelerator handle mixed precision
|
||||||
with accelerator.autocast() if accelerator else (torch.autocast(device_type=device.type) if use_amp else nullcontext()):
|
with accelerator.autocast():
|
||||||
loss, output_dict = policy.forward(batch)
|
loss, output_dict = policy.forward(batch)
|
||||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||||
|
|
||||||
if accelerator:
|
# Use accelerator's backward method
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
if grad_clip_norm > 0:
|
|
||||||
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
|
# Clip gradients if specified
|
||||||
else:
|
if grad_clip_norm > 0:
|
||||||
grad_norm = torch.tensor(0.0, device=policy.device)
|
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
|
||||||
optimizer.step()
|
|
||||||
else:
|
else:
|
||||||
grad_scaler.scale(loss).backward()
|
grad_norm = torch.tensor(0.0, device=accelerator.device)
|
||||||
# Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
|
|
||||||
grad_scaler.unscale_(optimizer)
|
|
||||||
|
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
# Optimizer step
|
||||||
policy.parameters(),
|
with lock if lock is not None else nullcontext():
|
||||||
grad_clip_norm,
|
optimizer.step()
|
||||||
error_if_nonfinite=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
|
|
||||||
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
|
|
||||||
with lock if lock is not None else nullcontext():
|
|
||||||
grad_scaler.step(optimizer)
|
|
||||||
# Updates the scale for next iteration.
|
|
||||||
grad_scaler.update()
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
@@ -131,13 +111,9 @@ def update_policy(
|
|||||||
if lr_scheduler is not None:
|
if lr_scheduler is not None:
|
||||||
lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
|
|
||||||
if accelerator:
|
# Update internal buffers if policy has update method
|
||||||
if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
|
if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
|
||||||
accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
|
accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
|
||||||
else:
|
|
||||||
if has_method(policy, "update"):
|
|
||||||
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
|
|
||||||
policy.update()
|
|
||||||
|
|
||||||
train_metrics.loss = loss.item()
|
train_metrics.loss = loss.item()
|
||||||
train_metrics.grad_norm = grad_norm.item()
|
train_metrics.grad_norm = grad_norm.item()
|
||||||
|
|||||||
@@ -177,21 +177,6 @@ def format_big_number(num, precision=0):
|
|||||||
|
|
||||||
return num
|
return num
|
||||||
|
|
||||||
|
|
||||||
def is_launched_with_accelerate() -> bool:
|
|
||||||
"""Check if the script was launched in a distributed training context.
|
|
||||||
|
|
||||||
This checks for standard distributed training environment variables that are set
|
|
||||||
by accelerate launch, torchrun, or torch.distributed.launch.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if running in a distributed context, False otherwise.
|
|
||||||
"""
|
|
||||||
# Check for LOCAL_RANK which is the standard way to detect distributed training
|
|
||||||
# This is set by accelerate, torchrun, and torch.distributed.launch
|
|
||||||
return "LOCAL_RANK" in os.environ
|
|
||||||
|
|
||||||
|
|
||||||
def say(text: str, blocking: bool = False):
|
def say(text: str, blocking: bool = False):
|
||||||
system = platform.system()
|
system = platform.system()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user