refactor(profiling): shrink lerobot_train.py diff via start()/finalize()

Replace the `with profiler or nullcontext():` wrap around the entire
training loop with explicit `profiler.start()` / `profiler.finalize()`
calls, and tighten `_section(...)` regions in `update_policy` to only
wrap the hot calls (forward / backward / optimizer.step).

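This avoids ~120 lines of pure re-indentation noise while keeping the
exact same artifacts on disk and the same public behavior, with one API
change: `TrainingProfiler` no longer implements `__enter__` / `__exit__`,
so `with profiler:` must be replaced by `profiler.start()` followed by
`profiler.finalize()`. `finalize()` now also stops the underlying torch
profiler; if an exception escapes before `finalize()` runs, the torch
profiler is not closed unless the caller handles that itself.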

lerobot_train.py diff vs main: 267 -> 29 changed lines.

Made-with: Cursor
This commit is contained in:
Pepijn
2026-04-17 10:59:43 +01:00
parent a6dd28e8b4
commit 1ac8e96575
3 changed files with 128 additions and 131 deletions
+120 -120
View File
@@ -130,23 +130,22 @@ def update_policy(
with _section("backward"): with _section("backward"):
accelerator.backward(loss) accelerator.backward(loss)
# Clip gradients if specified # Clip gradients if specified
if grad_clip_norm > 0: if grad_clip_norm > 0:
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm) grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
else: else:
grad_norm = torch.nn.utils.clip_grad_norm_( grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(), float("inf"), error_if_nonfinite=False policy.parameters(), float("inf"), error_if_nonfinite=False
) )
with _section("optimizer"): with _section("optimizer"), lock if lock is not None else nullcontext():
with lock if lock is not None else nullcontext(): optimizer.step()
optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
# Step through pytorch scheduler at every batch instead of epoch # Step through pytorch scheduler at every batch instead of epoch
if lr_scheduler is not None: if lr_scheduler is not None:
lr_scheduler.step() lr_scheduler.step()
# Update internal buffers if policy has update method # Update internal buffers if policy has update method
if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"): if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
@@ -327,6 +326,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
profiler.record_deterministic_forward( profiler.record_deterministic_forward(
policy=policy, dataset=dataset, batch_size=cfg.batch_size, preprocessor=preprocessor policy=policy, dataset=dataset, batch_size=cfg.batch_size, preprocessor=preprocessor
) )
profiler.start()
# Load precomputed SARM progress for RA-BC if enabled # Load precomputed SARM progress for RA-BC if enabled
# Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py # Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py
@@ -441,124 +441,124 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
logging.info( logging.info(
f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}" f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
) )
with profiler or nullcontext():
for _ in range(step, cfg.steps):
start_time = time.perf_counter()
batch = next(dl_iter)
batch = preprocessor(batch)
train_tracker.dataloading_s = time.perf_counter() - start_time
train_tracker, output_dict = update_policy( for _ in range(step, cfg.steps):
train_tracker, start_time = time.perf_counter()
policy, batch = next(dl_iter)
batch, batch = preprocessor(batch)
optimizer, train_tracker.dataloading_s = time.perf_counter() - start_time
cfg.optimizer.grad_clip_norm,
accelerator=accelerator,
lr_scheduler=lr_scheduler,
rabc_weights_provider=rabc_weights,
profiler=profiler,
)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we train_tracker, output_dict = update_policy(
# increment `step` here. train_tracker,
step += 1 policy,
batch,
optimizer,
cfg.optimizer.grad_clip_norm,
accelerator=accelerator,
lr_scheduler=lr_scheduler,
rabc_weights_provider=rabc_weights,
profiler=profiler,
)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
# increment `step` here.
step += 1
if is_main_process:
progbar.update(1)
if profiler:
profiler.step(step, train_tracker)
train_tracker.step()
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
if is_log_step:
logging.info(train_tracker)
if wandb_logger:
wandb_log_dict = train_tracker.to_dict()
if output_dict:
wandb_log_dict.update(output_dict)
# Log RA-BC statistics if enabled
if rabc_weights is not None:
rabc_stats = rabc_weights.get_stats()
wandb_log_dict.update(
{
"rabc_delta_mean": rabc_stats["delta_mean"],
"rabc_delta_std": rabc_stats["delta_std"],
"rabc_num_frames": rabc_stats["num_frames"],
}
)
wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages()
if cfg.save_checkpoint and is_saving_step:
if is_main_process: if is_main_process:
progbar.update(1) logging.info(f"Checkpoint policy after step {step}")
if profiler: checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
profiler.step(step, train_tracker) save_checkpoint(
train_tracker.step() checkpoint_dir=checkpoint_dir,
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process step=step,
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps cfg=cfg,
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 policy=accelerator.unwrap_model(policy),
optimizer=optimizer,
if is_log_step: scheduler=lr_scheduler,
logging.info(train_tracker) preprocessor=preprocessor,
postprocessor=postprocessor,
)
update_last_checkpoint(checkpoint_dir)
if wandb_logger: if wandb_logger:
wandb_log_dict = train_tracker.to_dict() wandb_logger.log_policy(checkpoint_dir)
if output_dict:
wandb_log_dict.update(output_dict)
# Log RA-BC statistics if enabled
if rabc_weights is not None:
rabc_stats = rabc_weights.get_stats()
wandb_log_dict.update(
{
"rabc_delta_mean": rabc_stats["delta_mean"],
"rabc_delta_std": rabc_stats["delta_std"],
"rabc_num_frames": rabc_stats["num_frames"],
}
)
wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages()
if cfg.save_checkpoint and is_saving_step: accelerator.wait_for_everyone()
if is_main_process:
logging.info(f"Checkpoint policy after step {step}") if cfg.env and is_eval_step:
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step) if is_main_process:
save_checkpoint( step_id = get_step_identifier(step, cfg.steps)
checkpoint_dir=checkpoint_dir, logging.info(f"Eval policy at step {step}")
step=step, with torch.no_grad(), accelerator.autocast():
cfg=cfg, eval_info = eval_policy_all(
envs=eval_env, # dict[suite][task_id] -> vec_env
policy=accelerator.unwrap_model(policy), policy=accelerator.unwrap_model(policy),
optimizer=optimizer, env_preprocessor=env_preprocessor,
scheduler=lr_scheduler, env_postprocessor=env_postprocessor,
preprocessor=preprocessor, preprocessor=preprocessor,
postprocessor=postprocessor, postprocessor=postprocessor,
n_episodes=cfg.eval.n_episodes,
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
max_episodes_rendered=4,
start_seed=cfg.seed,
max_parallel_tasks=cfg.env.max_parallel_tasks,
) )
update_last_checkpoint(checkpoint_dir) # overall metrics (suite-agnostic)
if wandb_logger: aggregated = eval_info["overall"]
wandb_logger.log_policy(checkpoint_dir)
accelerator.wait_for_everyone() # optional: per-suite logging
for suite, suite_info in eval_info.items():
logging.info("Suite %s aggregated: %s", suite, suite_info)
if cfg.env and is_eval_step: # meters/tracker
if is_main_process: eval_metrics = {
step_id = get_step_identifier(step, cfg.steps) "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
logging.info(f"Eval policy at step {step}") "pc_success": AverageMeter("success", ":.1f"),
with torch.no_grad(), accelerator.autocast(): "eval_s": AverageMeter("eval_s", ":.3f"),
eval_info = eval_policy_all( }
envs=eval_env, # dict[suite][task_id] -> vec_env eval_tracker = MetricsTracker(
policy=accelerator.unwrap_model(policy), cfg.batch_size,
env_preprocessor=env_preprocessor, dataset.num_frames,
env_postprocessor=env_postprocessor, dataset.num_episodes,
preprocessor=preprocessor, eval_metrics,
postprocessor=postprocessor, initial_step=step,
n_episodes=cfg.eval.n_episodes, accelerator=accelerator,
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}", )
max_episodes_rendered=4, eval_tracker.eval_s = aggregated.pop("eval_s")
start_seed=cfg.seed, eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward")
max_parallel_tasks=cfg.env.max_parallel_tasks, eval_tracker.pc_success = aggregated.pop("pc_success")
) if wandb_logger:
# overall metrics (suite-agnostic) wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
aggregated = eval_info["overall"] wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval")
# optional: per-suite logging accelerator.wait_for_everyone()
for suite, suite_info in eval_info.items():
logging.info("Suite %s aggregated: %s", suite, suite_info)
# meters/tracker
eval_metrics = {
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
"pc_success": AverageMeter("success", ":.1f"),
"eval_s": AverageMeter("eval_s", ":.3f"),
}
eval_tracker = MetricsTracker(
cfg.batch_size,
dataset.num_frames,
dataset.num_episodes,
eval_metrics,
initial_step=step,
accelerator=accelerator,
)
eval_tracker.eval_s = aggregated.pop("eval_s")
eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward")
eval_tracker.pc_success = aggregated.pop("pc_success")
if wandb_logger:
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval")
accelerator.wait_for_everyone()
if is_main_process: if is_main_process:
progbar.close() progbar.close()
+2 -5
View File
@@ -356,14 +356,10 @@ class TrainingProfiler:
if self._device.type == "cuda": if self._device.type == "cuda":
torch.cuda.empty_cache() torch.cuda.empty_cache()
def __enter__(self) -> TrainingProfiler: def start(self) -> None:
if self._device.type == "cuda": if self._device.type == "cuda":
torch.cuda.reset_peak_memory_stats(self._device) torch.cuda.reset_peak_memory_stats(self._device)
self._torch_profiler.__enter__() self._torch_profiler.__enter__()
return self
def __exit__(self, *exc: Any) -> bool:
return self._torch_profiler.__exit__(*exc)
@contextmanager @contextmanager
def section(self, name: str) -> Iterator[None]: def section(self, name: str) -> Iterator[None]:
@@ -394,6 +390,7 @@ class TrainingProfiler:
self._torch_profiler.step() self._torch_profiler.step()
def finalize(self) -> None: def finalize(self) -> None:
self._torch_profiler.__exit__(None, None, None)
extra: dict[str, Any] = {"profile_mode": self._mode} extra: dict[str, Any] = {"profile_mode": self._mode}
if self._device.type == "cuda": if self._device.type == "cuda":
extra["peak_memory_allocated_bytes"] = torch.cuda.max_memory_allocated(self._device) extra["peak_memory_allocated_bytes"] = torch.cuda.max_memory_allocated(self._device)
+6 -6
View File
@@ -400,12 +400,12 @@ def test_training_profiler_section_records_duration(tmp_path):
output_dir=tmp_path, output_dir=tmp_path,
device=torch.device("cpu"), device=torch.device("cpu"),
) )
with profiler: profiler.start()
with profiler.section("forward"): with profiler.section("forward"):
pass pass
with profiler.section("backward"): with profiler.section("backward"):
pass pass
profiler.step(1, argparse.Namespace(update_s=0.5, dataloading_s=0.01)) profiler.step(1, argparse.Namespace(update_s=0.5, dataloading_s=0.01))
profiler.finalize() profiler.finalize()
payload = json.loads((tmp_path / "step_timing_summary.json").read_text()) payload = json.loads((tmp_path / "step_timing_summary.json").read_text())