From 1ac8e96575de3b54af75bc09fcf54da3a2ab3f08 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Fri, 17 Apr 2026 10:59:43 +0100 Subject: [PATCH] refactor(profiling): shrink lerobot_train.py diff via start()/finalize() Replace the `with profiler or nullcontext():` wrap around the entire training loop with explicit `profiler.start()` / `profiler.finalize()` calls, and tighten `_section(...)` regions in `update_policy` to only wrap the hot calls (forward / backward / optimizer.step). This avoids ~120 lines of pure re-indentation noise while keeping the exact same artifacts on disk and the same public behavior. lerobot_train.py diff vs main: 267 -> 29 changed lines. Made-with: Cursor --- src/lerobot/scripts/lerobot_train.py | 240 +++++++++++++------------- src/lerobot/utils/profiling_utils.py | 7 +- tests/scripts/test_model_profiling.py | 12 +- 3 files changed, 128 insertions(+), 131 deletions(-) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index c2bca99cf..348d25356 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -130,23 +130,22 @@ def update_policy( with _section("backward"): accelerator.backward(loss) - # Clip gradients if specified - if grad_clip_norm > 0: - grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm) - else: - grad_norm = torch.nn.utils.clip_grad_norm_( - policy.parameters(), float("inf"), error_if_nonfinite=False - ) + # Clip gradients if specified + if grad_clip_norm > 0: + grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm) + else: + grad_norm = torch.nn.utils.clip_grad_norm_( + policy.parameters(), float("inf"), error_if_nonfinite=False + ) - with _section("optimizer"): - with lock if lock is not None else nullcontext(): - optimizer.step() + with _section("optimizer"), lock if lock is not None else nullcontext(): + optimizer.step() - optimizer.zero_grad() + optimizer.zero_grad() - # Step through pytorch scheduler at every batch instead of epoch - if lr_scheduler is not None: - lr_scheduler.step() + # Step through pytorch scheduler at every batch instead of epoch + if lr_scheduler is not None: + lr_scheduler.step() # Update internal buffers if policy has update method if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"): @@ -327,6 +326,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): profiler.record_deterministic_forward( policy=policy, dataset=dataset, batch_size=cfg.batch_size, preprocessor=preprocessor ) + profiler.start() # Load precomputed SARM progress for RA-BC if enabled # Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py @@ -441,124 +441,124 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): logging.info( f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}" ) - with profiler or nullcontext(): - for _ in range(step, cfg.steps): - start_time = time.perf_counter() - batch = next(dl_iter) - batch = preprocessor(batch) - train_tracker.dataloading_s = time.perf_counter() - start_time - train_tracker, output_dict = update_policy( - train_tracker, - policy, - batch, - optimizer, - cfg.optimizer.grad_clip_norm, - accelerator=accelerator, - lr_scheduler=lr_scheduler, - rabc_weights_provider=rabc_weights, - profiler=profiler, - ) + for _ in range(step, cfg.steps): + start_time = time.perf_counter() + batch = next(dl_iter) + batch = preprocessor(batch) + train_tracker.dataloading_s = time.perf_counter() - start_time - # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we - # increment `step` here. - step += 1 + train_tracker, output_dict = update_policy( + train_tracker, + policy, + batch, + optimizer, + cfg.optimizer.grad_clip_norm, + accelerator=accelerator, + lr_scheduler=lr_scheduler, + rabc_weights_provider=rabc_weights, + profiler=profiler, + ) + + # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we + # increment `step` here. + step += 1 + if is_main_process: + progbar.update(1) + if profiler: + profiler.step(step, train_tracker) + train_tracker.step() + is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process + is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps + is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 + + if is_log_step: + logging.info(train_tracker) + if wandb_logger: + wandb_log_dict = train_tracker.to_dict() + if output_dict: + wandb_log_dict.update(output_dict) + # Log RA-BC statistics if enabled + if rabc_weights is not None: + rabc_stats = rabc_weights.get_stats() + wandb_log_dict.update( + { + "rabc_delta_mean": rabc_stats["delta_mean"], + "rabc_delta_std": rabc_stats["delta_std"], + "rabc_num_frames": rabc_stats["num_frames"], + } + ) + wandb_logger.log_dict(wandb_log_dict, step) + train_tracker.reset_averages() + + if cfg.save_checkpoint and is_saving_step: if is_main_process: - progbar.update(1) - if profiler: - profiler.step(step, train_tracker) - train_tracker.step() - is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process - is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps - is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 - - if is_log_step: - logging.info(train_tracker) + logging.info(f"Checkpoint policy after step {step}") + checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step) + save_checkpoint( + checkpoint_dir=checkpoint_dir, + step=step, + cfg=cfg, + policy=accelerator.unwrap_model(policy), + optimizer=optimizer, + scheduler=lr_scheduler, + preprocessor=preprocessor, + postprocessor=postprocessor, + ) + update_last_checkpoint(checkpoint_dir) if wandb_logger: - wandb_log_dict = train_tracker.to_dict() - if output_dict: - wandb_log_dict.update(output_dict) - # Log RA-BC statistics if enabled - if rabc_weights is not None: - rabc_stats = rabc_weights.get_stats() - wandb_log_dict.update( - { - "rabc_delta_mean": rabc_stats["delta_mean"], - "rabc_delta_std": rabc_stats["delta_std"], - "rabc_num_frames": rabc_stats["num_frames"], - } - ) - wandb_logger.log_dict(wandb_log_dict, step) - train_tracker.reset_averages() + wandb_logger.log_policy(checkpoint_dir) - if cfg.save_checkpoint and is_saving_step: - if is_main_process: - logging.info(f"Checkpoint policy after step {step}") - checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step) - save_checkpoint( - checkpoint_dir=checkpoint_dir, - step=step, - cfg=cfg, + accelerator.wait_for_everyone() + + if cfg.env and is_eval_step: + if is_main_process: + step_id = get_step_identifier(step, cfg.steps) + logging.info(f"Eval policy at step {step}") + with torch.no_grad(), accelerator.autocast(): + eval_info = eval_policy_all( + envs=eval_env, # dict[suite][task_id] -> vec_env policy=accelerator.unwrap_model(policy), - optimizer=optimizer, - scheduler=lr_scheduler, + env_preprocessor=env_preprocessor, + env_postprocessor=env_postprocessor, preprocessor=preprocessor, postprocessor=postprocessor, + n_episodes=cfg.eval.n_episodes, + videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}", + max_episodes_rendered=4, + start_seed=cfg.seed, + max_parallel_tasks=cfg.env.max_parallel_tasks, ) - update_last_checkpoint(checkpoint_dir) - if wandb_logger: - wandb_logger.log_policy(checkpoint_dir) + # overall metrics (suite-agnostic) + aggregated = eval_info["overall"] - accelerator.wait_for_everyone() + # optional: per-suite logging + for suite, suite_info in eval_info.items(): + logging.info("Suite %s aggregated: %s", suite, suite_info) - if cfg.env and is_eval_step: - if is_main_process: - step_id = get_step_identifier(step, cfg.steps) - logging.info(f"Eval policy at step {step}") - with torch.no_grad(), accelerator.autocast(): - eval_info = eval_policy_all( - envs=eval_env, # dict[suite][task_id] -> vec_env - policy=accelerator.unwrap_model(policy), - env_preprocessor=env_preprocessor, - env_postprocessor=env_postprocessor, - preprocessor=preprocessor, - postprocessor=postprocessor, - n_episodes=cfg.eval.n_episodes, - videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}", - max_episodes_rendered=4, - start_seed=cfg.seed, - max_parallel_tasks=cfg.env.max_parallel_tasks, - ) - # overall metrics (suite-agnostic) - aggregated = eval_info["overall"] + # meters/tracker + eval_metrics = { + "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"), + "pc_success": AverageMeter("success", ":.1f"), + "eval_s": AverageMeter("eval_s", ":.3f"), + } + eval_tracker = MetricsTracker( + cfg.batch_size, + dataset.num_frames, + dataset.num_episodes, + eval_metrics, + initial_step=step, + accelerator=accelerator, + ) + eval_tracker.eval_s = aggregated.pop("eval_s") + eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward") + eval_tracker.pc_success = aggregated.pop("pc_success") + if wandb_logger: + wandb_log_dict = {**eval_tracker.to_dict(), **eval_info} + wandb_logger.log_dict(wandb_log_dict, step, mode="eval") + wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval") - # optional: per-suite logging - for suite, suite_info in eval_info.items(): - logging.info("Suite %s aggregated: %s", suite, suite_info) - - # meters/tracker - eval_metrics = { - "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"), - "pc_success": AverageMeter("success", ":.1f"), - "eval_s": AverageMeter("eval_s", ":.3f"), - } - eval_tracker = MetricsTracker( - cfg.batch_size, - dataset.num_frames, - dataset.num_episodes, - eval_metrics, - initial_step=step, - accelerator=accelerator, - ) - eval_tracker.eval_s = aggregated.pop("eval_s") - eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward") - eval_tracker.pc_success = aggregated.pop("pc_success") - if wandb_logger: - wandb_log_dict = {**eval_tracker.to_dict(), **eval_info} - wandb_logger.log_dict(wandb_log_dict, step, mode="eval") - wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval") - - accelerator.wait_for_everyone() + accelerator.wait_for_everyone() if is_main_process: progbar.close() diff --git a/src/lerobot/utils/profiling_utils.py b/src/lerobot/utils/profiling_utils.py index 874b1ffb9..67e02b712 100644 --- a/src/lerobot/utils/profiling_utils.py +++ b/src/lerobot/utils/profiling_utils.py @@ -356,14 +356,10 @@ class TrainingProfiler: if self._device.type == "cuda": torch.cuda.empty_cache() - def __enter__(self) -> TrainingProfiler: + def start(self) -> None: if self._device.type == "cuda": torch.cuda.reset_peak_memory_stats(self._device) self._torch_profiler.__enter__() - return self - - def __exit__(self, *exc: Any) -> bool: - return self._torch_profiler.__exit__(*exc) @contextmanager def section(self, name: str) -> Iterator[None]: @@ -394,6 +390,7 @@ class TrainingProfiler: self._torch_profiler.step() def finalize(self) -> None: + self._torch_profiler.__exit__(None, None, None) extra: dict[str, Any] = {"profile_mode": self._mode} if self._device.type == "cuda": extra["peak_memory_allocated_bytes"] = torch.cuda.max_memory_allocated(self._device) diff --git a/tests/scripts/test_model_profiling.py b/tests/scripts/test_model_profiling.py index 10d4e4b1b..d3e2de58b 100644 --- a/tests/scripts/test_model_profiling.py +++ b/tests/scripts/test_model_profiling.py @@ -400,12 +400,12 @@ def test_training_profiler_section_records_duration(tmp_path): output_dir=tmp_path, device=torch.device("cpu"), ) - with profiler: - with profiler.section("forward"): - pass - with profiler.section("backward"): - pass - profiler.step(1, argparse.Namespace(update_s=0.5, dataloading_s=0.01)) + profiler.start() + with profiler.section("forward"): + pass + with profiler.section("backward"): + pass + profiler.step(1, argparse.Namespace(update_s=0.5, dataloading_s=0.01)) profiler.finalize() payload = json.loads((tmp_path / "step_timing_summary.json").read_text())