mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
feat(scripts): Integrate tqdm for training progress visualization (#3010)
This commit is contained in:
@@ -24,6 +24,7 @@ import torch
|
||||
from accelerate import Accelerator
|
||||
from termcolor import colored
|
||||
from torch.optim import Optimizer
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
@@ -51,6 +52,7 @@ from lerobot.utils.utils import (
|
||||
format_big_number,
|
||||
has_method,
|
||||
init_logging,
|
||||
inside_slurm,
|
||||
)
|
||||
|
||||
|
||||
@@ -390,6 +392,14 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
)
|
||||
|
||||
if is_main_process:
|
||||
progbar = tqdm(
|
||||
total=cfg.steps - step,
|
||||
desc="Training",
|
||||
unit="step",
|
||||
disable=inside_slurm(),
|
||||
position=0,
|
||||
leave=True,
|
||||
)
|
||||
logging.info(
|
||||
f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
|
||||
)
|
||||
@@ -414,6 +424,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
# increment `step` here.
|
||||
step += 1
|
||||
if is_main_process:
|
||||
progbar.update(1)
|
||||
train_tracker.step()
|
||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
|
||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||
@@ -507,6 +519,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if is_main_process:
|
||||
progbar.close()
|
||||
|
||||
if eval_env:
|
||||
close_envs(eval_env)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user