mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
refactor(profiling): shrink lerobot_train.py diff via start()/finalize()
Replace the `with profiler or nullcontext():` wrap around the entire training loop with explicit `profiler.start()` / `profiler.finalize()` calls, and tighten `_section(...)` regions in `update_policy` to only wrap the hot calls (forward / backward / optimizer.step). This avoids ~120 lines of pure re-indentation noise while keeping the exact same artifacts on disk and the same public behavior. lerobot_train.py diff vs main: 267 -> 29 changed lines. Made-with: Cursor
This commit is contained in:
@@ -130,23 +130,22 @@ def update_policy(
|
|||||||
with _section("backward"):
|
with _section("backward"):
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
|
|
||||||
# Clip gradients if specified
|
# Clip gradients if specified
|
||||||
if grad_clip_norm > 0:
|
if grad_clip_norm > 0:
|
||||||
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
|
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
|
||||||
else:
|
else:
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
policy.parameters(), float("inf"), error_if_nonfinite=False
|
policy.parameters(), float("inf"), error_if_nonfinite=False
|
||||||
)
|
)
|
||||||
|
|
||||||
with _section("optimizer"):
|
with _section("optimizer"), lock if lock is not None else nullcontext():
|
||||||
with lock if lock is not None else nullcontext():
|
optimizer.step()
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
# Step through pytorch scheduler at every batch instead of epoch
|
# Step through pytorch scheduler at every batch instead of epoch
|
||||||
if lr_scheduler is not None:
|
if lr_scheduler is not None:
|
||||||
lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
|
|
||||||
# Update internal buffers if policy has update method
|
# Update internal buffers if policy has update method
|
||||||
if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
|
if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
|
||||||
@@ -327,6 +326,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
profiler.record_deterministic_forward(
|
profiler.record_deterministic_forward(
|
||||||
policy=policy, dataset=dataset, batch_size=cfg.batch_size, preprocessor=preprocessor
|
policy=policy, dataset=dataset, batch_size=cfg.batch_size, preprocessor=preprocessor
|
||||||
)
|
)
|
||||||
|
profiler.start()
|
||||||
|
|
||||||
# Load precomputed SARM progress for RA-BC if enabled
|
# Load precomputed SARM progress for RA-BC if enabled
|
||||||
# Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py
|
# Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py
|
||||||
@@ -441,124 +441,124 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
logging.info(
|
logging.info(
|
||||||
f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
|
f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
|
||||||
)
|
)
|
||||||
with profiler or nullcontext():
|
|
||||||
for _ in range(step, cfg.steps):
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
batch = next(dl_iter)
|
|
||||||
batch = preprocessor(batch)
|
|
||||||
train_tracker.dataloading_s = time.perf_counter() - start_time
|
|
||||||
|
|
||||||
train_tracker, output_dict = update_policy(
|
for _ in range(step, cfg.steps):
|
||||||
train_tracker,
|
start_time = time.perf_counter()
|
||||||
policy,
|
batch = next(dl_iter)
|
||||||
batch,
|
batch = preprocessor(batch)
|
||||||
optimizer,
|
train_tracker.dataloading_s = time.perf_counter() - start_time
|
||||||
cfg.optimizer.grad_clip_norm,
|
|
||||||
accelerator=accelerator,
|
|
||||||
lr_scheduler=lr_scheduler,
|
|
||||||
rabc_weights_provider=rabc_weights,
|
|
||||||
profiler=profiler,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
train_tracker, output_dict = update_policy(
|
||||||
# increment `step` here.
|
train_tracker,
|
||||||
step += 1
|
policy,
|
||||||
|
batch,
|
||||||
|
optimizer,
|
||||||
|
cfg.optimizer.grad_clip_norm,
|
||||||
|
accelerator=accelerator,
|
||||||
|
lr_scheduler=lr_scheduler,
|
||||||
|
rabc_weights_provider=rabc_weights,
|
||||||
|
profiler=profiler,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||||
|
# increment `step` here.
|
||||||
|
step += 1
|
||||||
|
if is_main_process:
|
||||||
|
progbar.update(1)
|
||||||
|
if profiler:
|
||||||
|
profiler.step(step, train_tracker)
|
||||||
|
train_tracker.step()
|
||||||
|
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
|
||||||
|
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||||
|
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
|
||||||
|
|
||||||
|
if is_log_step:
|
||||||
|
logging.info(train_tracker)
|
||||||
|
if wandb_logger:
|
||||||
|
wandb_log_dict = train_tracker.to_dict()
|
||||||
|
if output_dict:
|
||||||
|
wandb_log_dict.update(output_dict)
|
||||||
|
# Log RA-BC statistics if enabled
|
||||||
|
if rabc_weights is not None:
|
||||||
|
rabc_stats = rabc_weights.get_stats()
|
||||||
|
wandb_log_dict.update(
|
||||||
|
{
|
||||||
|
"rabc_delta_mean": rabc_stats["delta_mean"],
|
||||||
|
"rabc_delta_std": rabc_stats["delta_std"],
|
||||||
|
"rabc_num_frames": rabc_stats["num_frames"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
wandb_logger.log_dict(wandb_log_dict, step)
|
||||||
|
train_tracker.reset_averages()
|
||||||
|
|
||||||
|
if cfg.save_checkpoint and is_saving_step:
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
progbar.update(1)
|
logging.info(f"Checkpoint policy after step {step}")
|
||||||
if profiler:
|
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
||||||
profiler.step(step, train_tracker)
|
save_checkpoint(
|
||||||
train_tracker.step()
|
checkpoint_dir=checkpoint_dir,
|
||||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
|
step=step,
|
||||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
cfg=cfg,
|
||||||
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
|
policy=accelerator.unwrap_model(policy),
|
||||||
|
optimizer=optimizer,
|
||||||
if is_log_step:
|
scheduler=lr_scheduler,
|
||||||
logging.info(train_tracker)
|
preprocessor=preprocessor,
|
||||||
|
postprocessor=postprocessor,
|
||||||
|
)
|
||||||
|
update_last_checkpoint(checkpoint_dir)
|
||||||
if wandb_logger:
|
if wandb_logger:
|
||||||
wandb_log_dict = train_tracker.to_dict()
|
wandb_logger.log_policy(checkpoint_dir)
|
||||||
if output_dict:
|
|
||||||
wandb_log_dict.update(output_dict)
|
|
||||||
# Log RA-BC statistics if enabled
|
|
||||||
if rabc_weights is not None:
|
|
||||||
rabc_stats = rabc_weights.get_stats()
|
|
||||||
wandb_log_dict.update(
|
|
||||||
{
|
|
||||||
"rabc_delta_mean": rabc_stats["delta_mean"],
|
|
||||||
"rabc_delta_std": rabc_stats["delta_std"],
|
|
||||||
"rabc_num_frames": rabc_stats["num_frames"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
wandb_logger.log_dict(wandb_log_dict, step)
|
|
||||||
train_tracker.reset_averages()
|
|
||||||
|
|
||||||
if cfg.save_checkpoint and is_saving_step:
|
accelerator.wait_for_everyone()
|
||||||
if is_main_process:
|
|
||||||
logging.info(f"Checkpoint policy after step {step}")
|
if cfg.env and is_eval_step:
|
||||||
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
if is_main_process:
|
||||||
save_checkpoint(
|
step_id = get_step_identifier(step, cfg.steps)
|
||||||
checkpoint_dir=checkpoint_dir,
|
logging.info(f"Eval policy at step {step}")
|
||||||
step=step,
|
with torch.no_grad(), accelerator.autocast():
|
||||||
cfg=cfg,
|
eval_info = eval_policy_all(
|
||||||
|
envs=eval_env, # dict[suite][task_id] -> vec_env
|
||||||
policy=accelerator.unwrap_model(policy),
|
policy=accelerator.unwrap_model(policy),
|
||||||
optimizer=optimizer,
|
env_preprocessor=env_preprocessor,
|
||||||
scheduler=lr_scheduler,
|
env_postprocessor=env_postprocessor,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
postprocessor=postprocessor,
|
postprocessor=postprocessor,
|
||||||
|
n_episodes=cfg.eval.n_episodes,
|
||||||
|
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
|
||||||
|
max_episodes_rendered=4,
|
||||||
|
start_seed=cfg.seed,
|
||||||
|
max_parallel_tasks=cfg.env.max_parallel_tasks,
|
||||||
)
|
)
|
||||||
update_last_checkpoint(checkpoint_dir)
|
# overall metrics (suite-agnostic)
|
||||||
if wandb_logger:
|
aggregated = eval_info["overall"]
|
||||||
wandb_logger.log_policy(checkpoint_dir)
|
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
# optional: per-suite logging
|
||||||
|
for suite, suite_info in eval_info.items():
|
||||||
|
logging.info("Suite %s aggregated: %s", suite, suite_info)
|
||||||
|
|
||||||
if cfg.env and is_eval_step:
|
# meters/tracker
|
||||||
if is_main_process:
|
eval_metrics = {
|
||||||
step_id = get_step_identifier(step, cfg.steps)
|
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
|
||||||
logging.info(f"Eval policy at step {step}")
|
"pc_success": AverageMeter("success", ":.1f"),
|
||||||
with torch.no_grad(), accelerator.autocast():
|
"eval_s": AverageMeter("eval_s", ":.3f"),
|
||||||
eval_info = eval_policy_all(
|
}
|
||||||
envs=eval_env, # dict[suite][task_id] -> vec_env
|
eval_tracker = MetricsTracker(
|
||||||
policy=accelerator.unwrap_model(policy),
|
cfg.batch_size,
|
||||||
env_preprocessor=env_preprocessor,
|
dataset.num_frames,
|
||||||
env_postprocessor=env_postprocessor,
|
dataset.num_episodes,
|
||||||
preprocessor=preprocessor,
|
eval_metrics,
|
||||||
postprocessor=postprocessor,
|
initial_step=step,
|
||||||
n_episodes=cfg.eval.n_episodes,
|
accelerator=accelerator,
|
||||||
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
|
)
|
||||||
max_episodes_rendered=4,
|
eval_tracker.eval_s = aggregated.pop("eval_s")
|
||||||
start_seed=cfg.seed,
|
eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward")
|
||||||
max_parallel_tasks=cfg.env.max_parallel_tasks,
|
eval_tracker.pc_success = aggregated.pop("pc_success")
|
||||||
)
|
if wandb_logger:
|
||||||
# overall metrics (suite-agnostic)
|
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
|
||||||
aggregated = eval_info["overall"]
|
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
|
||||||
|
wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval")
|
||||||
|
|
||||||
# optional: per-suite logging
|
accelerator.wait_for_everyone()
|
||||||
for suite, suite_info in eval_info.items():
|
|
||||||
logging.info("Suite %s aggregated: %s", suite, suite_info)
|
|
||||||
|
|
||||||
# meters/tracker
|
|
||||||
eval_metrics = {
|
|
||||||
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
|
|
||||||
"pc_success": AverageMeter("success", ":.1f"),
|
|
||||||
"eval_s": AverageMeter("eval_s", ":.3f"),
|
|
||||||
}
|
|
||||||
eval_tracker = MetricsTracker(
|
|
||||||
cfg.batch_size,
|
|
||||||
dataset.num_frames,
|
|
||||||
dataset.num_episodes,
|
|
||||||
eval_metrics,
|
|
||||||
initial_step=step,
|
|
||||||
accelerator=accelerator,
|
|
||||||
)
|
|
||||||
eval_tracker.eval_s = aggregated.pop("eval_s")
|
|
||||||
eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward")
|
|
||||||
eval_tracker.pc_success = aggregated.pop("pc_success")
|
|
||||||
if wandb_logger:
|
|
||||||
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
|
|
||||||
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
|
|
||||||
wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval")
|
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
|
||||||
|
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
progbar.close()
|
progbar.close()
|
||||||
|
|||||||
@@ -356,14 +356,10 @@ class TrainingProfiler:
|
|||||||
if self._device.type == "cuda":
|
if self._device.type == "cuda":
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def __enter__(self) -> TrainingProfiler:
|
def start(self) -> None:
|
||||||
if self._device.type == "cuda":
|
if self._device.type == "cuda":
|
||||||
torch.cuda.reset_peak_memory_stats(self._device)
|
torch.cuda.reset_peak_memory_stats(self._device)
|
||||||
self._torch_profiler.__enter__()
|
self._torch_profiler.__enter__()
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, *exc: Any) -> bool:
|
|
||||||
return self._torch_profiler.__exit__(*exc)
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def section(self, name: str) -> Iterator[None]:
|
def section(self, name: str) -> Iterator[None]:
|
||||||
@@ -394,6 +390,7 @@ class TrainingProfiler:
|
|||||||
self._torch_profiler.step()
|
self._torch_profiler.step()
|
||||||
|
|
||||||
def finalize(self) -> None:
|
def finalize(self) -> None:
|
||||||
|
self._torch_profiler.__exit__(None, None, None)
|
||||||
extra: dict[str, Any] = {"profile_mode": self._mode}
|
extra: dict[str, Any] = {"profile_mode": self._mode}
|
||||||
if self._device.type == "cuda":
|
if self._device.type == "cuda":
|
||||||
extra["peak_memory_allocated_bytes"] = torch.cuda.max_memory_allocated(self._device)
|
extra["peak_memory_allocated_bytes"] = torch.cuda.max_memory_allocated(self._device)
|
||||||
|
|||||||
@@ -400,12 +400,12 @@ def test_training_profiler_section_records_duration(tmp_path):
|
|||||||
output_dir=tmp_path,
|
output_dir=tmp_path,
|
||||||
device=torch.device("cpu"),
|
device=torch.device("cpu"),
|
||||||
)
|
)
|
||||||
with profiler:
|
profiler.start()
|
||||||
with profiler.section("forward"):
|
with profiler.section("forward"):
|
||||||
pass
|
pass
|
||||||
with profiler.section("backward"):
|
with profiler.section("backward"):
|
||||||
pass
|
pass
|
||||||
profiler.step(1, argparse.Namespace(update_s=0.5, dataloading_s=0.01))
|
profiler.step(1, argparse.Namespace(update_s=0.5, dataloading_s=0.01))
|
||||||
profiler.finalize()
|
profiler.finalize()
|
||||||
|
|
||||||
payload = json.loads((tmp_path / "step_timing_summary.json").read_text())
|
payload = json.loads((tmp_path / "step_timing_summary.json").read_text())
|
||||||
|
|||||||
Reference in New Issue
Block a user