mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
feat(scripts): Integrate tqdm for training progress visualization (#3010)
This commit is contained in:
@@ -24,6 +24,7 @@ import torch
|
|||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
@@ -51,6 +52,7 @@ from lerobot.utils.utils import (
|
|||||||
format_big_number,
|
format_big_number,
|
||||||
has_method,
|
has_method,
|
||||||
init_logging,
|
init_logging,
|
||||||
|
inside_slurm,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -390,6 +392,14 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
|
progbar = tqdm(
|
||||||
|
total=cfg.steps - step,
|
||||||
|
desc="Training",
|
||||||
|
unit="step",
|
||||||
|
disable=inside_slurm(),
|
||||||
|
position=0,
|
||||||
|
leave=True,
|
||||||
|
)
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
|
f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
|
||||||
)
|
)
|
||||||
@@ -414,6 +424,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||||
# increment `step` here.
|
# increment `step` here.
|
||||||
step += 1
|
step += 1
|
||||||
|
if is_main_process:
|
||||||
|
progbar.update(1)
|
||||||
train_tracker.step()
|
train_tracker.step()
|
||||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
|
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
|
||||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||||
@@ -507,6 +519,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
if is_main_process:
|
||||||
|
progbar.close()
|
||||||
|
|
||||||
if eval_env:
|
if eval_env:
|
||||||
close_envs(eval_env)
|
close_envs(eval_env)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user