feat(scripts): Integrate tqdm for training progress visualization (#3010)

This commit is contained in:
Steven Palma
2026-02-24 19:10:43 +01:00
committed by GitHub
parent 5095ab0845
commit 18d9cb5ac4
+15
View File
@@ -24,6 +24,7 @@ import torch
from accelerate import Accelerator from accelerate import Accelerator
from termcolor import colored from termcolor import colored
from torch.optim import Optimizer from torch.optim import Optimizer
from tqdm import tqdm
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig from lerobot.configs.train import TrainPipelineConfig
@@ -51,6 +52,7 @@ from lerobot.utils.utils import (
format_big_number, format_big_number,
has_method, has_method,
init_logging, init_logging,
inside_slurm,
) )
@@ -390,6 +392,14 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
) )
if is_main_process: if is_main_process:
progbar = tqdm(
total=cfg.steps - step,
desc="Training",
unit="step",
disable=inside_slurm(),
position=0,
leave=True,
)
logging.info( logging.info(
f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}" f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
) )
@@ -414,6 +424,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
# Note: eval and checkpoint happen *after* the `step`th training update has completed, so we # Note: eval and checkpoint happen *after* the `step`th training update has completed, so we
# increment `step` here. # increment `step` here.
step += 1 step += 1
if is_main_process:
progbar.update(1)
train_tracker.step() train_tracker.step()
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
@@ -507,6 +519,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
if is_main_process:
progbar.close()
if eval_env: if eval_env:
close_envs(eval_env) close_envs(eval_env)