mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 03:30:10 +00:00
@@ -92,18 +92,13 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
|
|||||||
def lr_lambda(current_step):
|
def lr_lambda(current_step):
|
||||||
def linear_warmup_schedule(current_step):
|
def linear_warmup_schedule(current_step):
|
||||||
if current_step <= 0:
|
if current_step <= 0:
|
||||||
return 0.1 # Start at 10% instead of 0.1% of peak LR
|
return 1 / (self.num_warmup_steps + 1)
|
||||||
if current_step >= self.num_warmup_steps:
|
frac = 1 - current_step / self.num_warmup_steps
|
||||||
return 1.0 # Reach peak at end of warmup
|
return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1
|
||||||
# Linear interpolation from 10% to 100% of peak LR
|
|
||||||
return 0.1 + 0.9 * (current_step / self.num_warmup_steps)
|
|
||||||
|
|
||||||
def cosine_decay_schedule(current_step):
|
def cosine_decay_schedule(current_step):
|
||||||
# CRITICAL FIX: Decay should count from END of warmup, not from step 0!
|
step = min(current_step, self.num_decay_steps)
|
||||||
decay_step = current_step - self.num_warmup_steps
|
cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
|
||||||
decay_step = max(0, min(decay_step, self.num_decay_steps))
|
|
||||||
|
|
||||||
cosine_decay = 0.5 * (1 + math.cos(math.pi * decay_step / self.num_decay_steps))
|
|
||||||
alpha = self.decay_lr / self.peak_lr
|
alpha = self.decay_lr / self.peak_lr
|
||||||
decayed = (1 - alpha) * cosine_decay + alpha
|
decayed = (1 - alpha) * cosine_decay + alpha
|
||||||
return decayed
|
return decayed
|
||||||
|
|||||||
+18
-100
@@ -20,8 +20,6 @@ from pprint import pformat
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate import Accelerator
|
|
||||||
from accelerate.utils import DistributedDataParallelKwargs
|
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
from torch.amp import GradScaler
|
from torch.amp import GradScaler
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
@@ -149,40 +147,17 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
cfg.validate()
|
cfg.validate()
|
||||||
logging.info(pformat(cfg.to_dict()))
|
logging.info(pformat(cfg.to_dict()))
|
||||||
|
|
||||||
# Initialize Accelerate if requested
|
|
||||||
accelerator = None
|
|
||||||
if cfg.use_accelerate:
|
|
||||||
# Configure DDP to handle unused parameters
|
|
||||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
|
||||||
accelerator = Accelerator(
|
|
||||||
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
|
||||||
mixed_precision=cfg.mixed_precision,
|
|
||||||
kwargs_handlers=[ddp_kwargs],
|
|
||||||
)
|
|
||||||
device = accelerator.device
|
|
||||||
if accelerator.is_main_process:
|
|
||||||
logging.info(
|
|
||||||
f"Accelerate initialized with device: {device}, mixed_precision: {cfg.mixed_precision}"
|
|
||||||
)
|
|
||||||
logging.info(f"Training on {accelerator.num_processes} processes")
|
|
||||||
else:
|
|
||||||
# Check device is available (original behavior)
|
|
||||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
|
||||||
|
|
||||||
# Only create wandb logger on main process when using accelerate
|
|
||||||
if cfg.wandb.enable and cfg.wandb.project:
|
if cfg.wandb.enable and cfg.wandb.project:
|
||||||
if accelerator is None or accelerator.is_main_process:
|
wandb_logger = WandBLogger(cfg)
|
||||||
wandb_logger = WandBLogger(cfg)
|
|
||||||
else:
|
|
||||||
wandb_logger = None
|
|
||||||
else:
|
else:
|
||||||
wandb_logger = None
|
wandb_logger = None
|
||||||
if accelerator is None or accelerator.is_main_process:
|
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
|
||||||
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
|
|
||||||
|
|
||||||
if cfg.seed is not None:
|
if cfg.seed is not None:
|
||||||
set_seed(cfg.seed)
|
set_seed(cfg.seed)
|
||||||
|
|
||||||
|
# Check device is available
|
||||||
|
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
|
||||||
@@ -225,12 +200,6 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
if cfg.resume:
|
if cfg.resume:
|
||||||
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
|
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
|
||||||
|
|
||||||
# Prepare objects with Accelerate if enabled
|
|
||||||
if accelerator is not None:
|
|
||||||
policy, optimizer, lr_scheduler = accelerator.prepare(policy, optimizer, lr_scheduler)
|
|
||||||
if accelerator.is_main_process:
|
|
||||||
logging.info("Policy, optimizer, and scheduler prepared with Accelerate")
|
|
||||||
|
|
||||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||||
|
|
||||||
@@ -266,11 +235,6 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
drop_last=False,
|
drop_last=False,
|
||||||
prefetch_factor=2,
|
prefetch_factor=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prepare dataloader with Accelerate if enabled
|
|
||||||
if accelerator is not None:
|
|
||||||
dataloader = accelerator.prepare(dataloader)
|
|
||||||
|
|
||||||
dl_iter = cycle(dataloader)
|
dl_iter = cycle(dataloader)
|
||||||
|
|
||||||
policy.train()
|
policy.train()
|
||||||
@@ -288,68 +252,22 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Start offline training on a fixed dataset")
|
logging.info("Start offline training on a fixed dataset")
|
||||||
|
|
||||||
# Calculate gradient accumulation steps for multi-GPU training
|
|
||||||
# This ensures effective batch size matches single-GPU training
|
|
||||||
gradient_accumulation_steps = accelerator.num_processes if accelerator is not None else 1
|
|
||||||
if accelerator and accelerator.is_main_process:
|
|
||||||
logging.info(f"Using gradient accumulation: {gradient_accumulation_steps} steps")
|
|
||||||
logging.info(f"Effective batch size: {cfg.batch_size} (same as single-GPU)")
|
|
||||||
|
|
||||||
for _ in range(step, cfg.steps):
|
for _ in range(step, cfg.steps):
|
||||||
policy.train()
|
start_time = time.perf_counter()
|
||||||
optimizer.zero_grad()
|
batch = next(dl_iter)
|
||||||
|
batch = preprocessor(batch)
|
||||||
|
train_tracker.dataloading_s = time.perf_counter() - start_time
|
||||||
|
|
||||||
# Accumulate gradients over multiple mini-batches to match single-GPU effective batch size
|
train_tracker, output_dict = update_policy(
|
||||||
accumulated_loss = 0
|
train_tracker,
|
||||||
accumulated_output_dict = {}
|
policy,
|
||||||
|
batch,
|
||||||
for accum_step in range(gradient_accumulation_steps):
|
optimizer,
|
||||||
start_time = time.perf_counter()
|
cfg.optimizer.grad_clip_norm,
|
||||||
batch = next(dl_iter)
|
grad_scaler=grad_scaler,
|
||||||
batch = preprocessor(batch)
|
lr_scheduler=lr_scheduler,
|
||||||
if accum_step == 0: # Only track data loading time once per step
|
use_amp=cfg.policy.use_amp,
|
||||||
train_tracker.dataloading_s = time.perf_counter() - start_time
|
)
|
||||||
|
|
||||||
# Forward pass
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
with accelerator.autocast() if accelerator else nullcontext():
|
|
||||||
loss, output_dict = policy.forward(batch)
|
|
||||||
# Scale loss by accumulation steps to get proper average
|
|
||||||
loss = loss / gradient_accumulation_steps
|
|
||||||
|
|
||||||
# Backward pass
|
|
||||||
if accelerator:
|
|
||||||
accelerator.backward(loss)
|
|
||||||
else:
|
|
||||||
grad_scaler.scale(loss).backward()
|
|
||||||
|
|
||||||
# Accumulate metrics
|
|
||||||
accumulated_loss += loss.item()
|
|
||||||
if accum_step == 0:
|
|
||||||
accumulated_output_dict = output_dict
|
|
||||||
|
|
||||||
# Gradient clipping and optimizer step
|
|
||||||
if accelerator:
|
|
||||||
accelerator.clip_grad_norm_(policy.parameters(), cfg.optimizer.grad_clip_norm)
|
|
||||||
optimizer.step()
|
|
||||||
else:
|
|
||||||
grad_scaler.unscale_(optimizer)
|
|
||||||
_ = torch.nn.utils.clip_grad_norm_(
|
|
||||||
policy.parameters(), cfg.optimizer.grad_clip_norm, error_if_nonfinite=False
|
|
||||||
)
|
|
||||||
grad_scaler.step(optimizer)
|
|
||||||
grad_scaler.update()
|
|
||||||
|
|
||||||
# Update learning rate scheduler
|
|
||||||
if lr_scheduler is not None:
|
|
||||||
lr_scheduler.step()
|
|
||||||
|
|
||||||
# Update metrics with accumulated values
|
|
||||||
train_tracker.loss = accumulated_loss
|
|
||||||
train_tracker.lr = optimizer.param_groups[0]["lr"]
|
|
||||||
train_tracker.update_s = time.perf_counter() - start_time
|
|
||||||
output_dict = accumulated_output_dict
|
|
||||||
|
|
||||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||||
# increment `step` here.
|
# increment `step` here.
|
||||||
|
|||||||
Reference in New Issue
Block a user