refactor: extract profiling into self-contained TrainingProfiler class

Move all profiling orchestration out of lerobot_train.py and
TrainPipelineConfig into a TrainingProfiler class in profiling_utils.py.

- lerobot_train.py: ~74 lines of profiling code reduced to ~7 call sites
- TrainPipelineConfig: 10 profile_* fields reduced to 2 (mode + output_dir)
- update_policy: reverted to clean main-branch signature (no timing_collector);
  per-phase forward/backward/optimizer timings are no longer recorded
- TrainingProfiler encapsulates torch profiler, timing collection,
  deterministic forward artifacts, and all output writing
- CI script (run_model_profiling.py) unchanged—it only passes the 2 kept fields

Made-with: Cursor
This commit is contained in:
Pepijn
2026-04-16 16:00:49 +02:00
parent a4544ffea7
commit b1e16783de
4 changed files with 148 additions and 138 deletions
-12
View File
@@ -57,15 +57,7 @@ class TrainPipelineConfig(HubMixin):
num_workers: int = 4
batch_size: int = 8
profile_mode: str = "off"
profile_wait_steps: int = 1
profile_warmup_steps: int = 2
profile_active_steps: int = 6
profile_repeat: int = 1
profile_output_dir: Path | None = None
profile_record_shapes: bool = True
profile_with_memory: bool = True
profile_with_flops: bool = True
profile_with_stack: bool = False
steps: int = 100_000
eval_freq: int = 20_000
log_freq: int = 200
@@ -147,10 +139,6 @@ class TrainPipelineConfig(HubMixin):
raise ValueError(
f"`profile_mode` must be one of 'off', 'summary', or 'trace', got {self.profile_mode}."
)
if self.profile_wait_steps < 0 or self.profile_warmup_steps < 0 or self.profile_active_steps < 0:
raise ValueError("Profiler schedule steps must be non-negative.")
if self.profile_repeat <= 0:
raise ValueError("`profile_repeat` must be strictly positive.")
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
+12 -73
View File
@@ -22,7 +22,6 @@ import dataclasses
import logging
import time
from contextlib import nullcontext
from pathlib import Path
from pprint import pformat
from typing import TYPE_CHECKING, Any
@@ -50,13 +49,7 @@ from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.profiling_utils import (
StepTimingCollector,
ensure_dir,
make_torch_profiler,
write_deterministic_forward_artifacts,
write_torch_profiler_outputs,
)
from lerobot.utils.profiling_utils import TrainingProfiler
from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import (
cycle,
@@ -79,7 +72,6 @@ def update_policy(
lr_scheduler=None,
lock=None,
rabc_weights_provider=None,
timing_collector: StepTimingCollector | None = None,
) -> tuple[MetricsTracker, dict]:
"""
Performs a single training step to update the policy's weights.
@@ -113,7 +105,6 @@ def update_policy(
rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch)
# Let accelerator handle mixed precision
forward_start = time.perf_counter()
with accelerator.autocast():
# Use per-sample loss when RA-BC is enabled for proper weighting
if rabc_batch_weights is not None:
@@ -132,15 +123,11 @@ def update_policy(
loss, output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
forward_s = time.perf_counter() - forward_start
# Use accelerator's backward method
backward_start = time.perf_counter()
accelerator.backward(loss)
backward_s = time.perf_counter() - backward_start
# Clip gradients if specified
optimizer_start = time.perf_counter()
if grad_clip_norm > 0:
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
else:
@@ -161,19 +148,11 @@ def update_policy(
# Update internal buffers if policy has update method
if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
optimizer_s = time.perf_counter() - optimizer_start
train_metrics.loss = loss.item()
train_metrics.grad_norm = grad_norm.item()
train_metrics.lr = optimizer.param_groups[0]["lr"]
train_metrics.update_s = time.perf_counter() - start_time
if timing_collector is not None:
timing_collector.record(
forward_s=forward_s,
backward_s=backward_s,
optimizer_s=optimizer_s,
total_update_s=train_metrics.update_s.val,
)
return train_metrics, output_dict
@@ -228,12 +207,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
if is_main_process:
logging.info(pformat(cfg.to_dict()))
profiling_enabled = cfg.profile_mode != "off"
profile_output_dir = None
if profiling_enabled and is_main_process and cfg.profile_output_dir is not None:
profile_output_dir = ensure_dir(Path(cfg.profile_output_dir))
logging.info("Profiling enabled. Artifacts will be written to %s", profile_output_dir)
# Initialize wandb only on main process
if cfg.wandb.enable and cfg.wandb.project and is_main_process:
wandb_logger = WandBLogger(cfg)
@@ -344,15 +317,12 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
if profiling_enabled and is_main_process and profile_output_dir is not None:
logging.info("Recording deterministic forward-pass artifacts")
write_deterministic_forward_artifacts(
policy=policy,
dataset=dataset,
batch_size=cfg.batch_size,
preprocessor=preprocessor,
output_dir=profile_output_dir,
device_type=device.type,
profiler = (
TrainingProfiler.from_cfg(cfg, device) if cfg.profile_mode != "off" and is_main_process else None
)
if profiler:
profiler.record_deterministic_forward(
policy=policy, dataset=dataset, batch_size=cfg.batch_size, preprocessor=preprocessor
)
# Load precomputed SARM progress for RA-BC if enabled
@@ -468,16 +438,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
logging.info(
f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
)
timing_collector = StepTimingCollector() if profiling_enabled and is_main_process else None
profiler = None
profiler_context = nullcontext()
if profiling_enabled and is_main_process and profile_output_dir is not None:
if device.type == "cuda":
torch.cuda.reset_peak_memory_stats(device)
profiler = make_torch_profiler(cfg, profile_output_dir, device.type)
profiler_context = profiler
with profiler_context:
with profiler or nullcontext():
for _ in range(step, cfg.steps):
start_time = time.perf_counter()
batch = next(dl_iter)
@@ -493,7 +454,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
accelerator=accelerator,
lr_scheduler=lr_scheduler,
rabc_weights_provider=rabc_weights,
timing_collector=timing_collector,
)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -501,17 +461,9 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
step += 1
if is_main_process:
progbar.update(1)
if timing_collector is not None:
timing_collector.record_dataloading(train_tracker.dataloading_s.val)
if device.type == "cuda":
timing_collector.record_memory(
step=step,
allocated_bytes=torch.cuda.memory_allocated(device),
reserved_bytes=torch.cuda.memory_reserved(device),
)
if profiler:
profiler.step(step, train_tracker)
train_tracker.step()
if profiler is not None:
profiler.step()
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
@@ -606,21 +558,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
if is_main_process:
progbar.close()
if timing_collector is not None and profile_output_dir is not None:
extra_profile_metrics = {
"profile_mode": cfg.profile_mode,
"peak_memory_allocated_bytes": (
torch.cuda.max_memory_allocated(device) if device.type == "cuda" else None
),
"peak_memory_reserved_bytes": (
torch.cuda.max_memory_reserved(device) if device.type == "cuda" else None
),
}
timing_collector.write_json(
profile_output_dir / "step_timing_summary.json", extra=extra_profile_metrics
)
if profiler is not None and profile_output_dir is not None:
write_torch_profiler_outputs(profiler, profile_output_dir, device_type=device.type)
if profiler:
profiler.finalize()
if eval_env:
close_envs(eval_env)
+122 -28
View File
@@ -18,6 +18,7 @@ from __future__ import annotations
import hashlib
import json
import logging
import statistics
from dataclasses import dataclass, field
from numbers import Real
@@ -47,7 +48,20 @@ def write_profiler_table(
output_path.write_text(table)
def make_torch_profiler(cfg: Any, output_dir: Path, device_type: str) -> Any:
def _make_torch_profiler(
*,
mode: str,
output_dir: Path,
device_type: str,
wait_steps: int = 1,
warmup_steps: int = 2,
active_steps: int = 6,
repeat: int = 1,
record_shapes: bool = True,
with_memory: bool = True,
with_flops: bool = True,
with_stack: bool = False,
) -> Any:
activities = [torch.profiler.ProfilerActivity.CPU]
if device_type == "cuda":
activities.append(torch.profiler.ProfilerActivity.CUDA)
@@ -55,23 +69,23 @@ def make_torch_profiler(cfg: Any, output_dir: Path, device_type: str) -> Any:
trace_dir = ensure_dir(output_dir / "torch_traces")
def _trace_ready(profiler: Any) -> None:
if cfg.profile_mode != "trace":
if mode != "trace":
return
profiler.export_chrome_trace(str(trace_dir / f"trace_step_{profiler.step_num}.json"))
return torch.profiler.profile(
activities=activities,
schedule=torch.profiler.schedule(
wait=cfg.profile_wait_steps,
warmup=cfg.profile_warmup_steps,
active=cfg.profile_active_steps,
repeat=cfg.profile_repeat,
wait=wait_steps,
warmup=warmup_steps,
active=active_steps,
repeat=repeat,
),
on_trace_ready=_trace_ready,
record_shapes=cfg.profile_record_shapes,
profile_memory=cfg.profile_with_memory,
with_flops=cfg.profile_with_flops,
with_stack=cfg.profile_with_stack,
record_shapes=record_shapes,
profile_memory=with_memory,
with_flops=with_flops,
with_stack=with_stack,
)
@@ -228,25 +242,12 @@ def _as_float(value: Any) -> float:
@dataclass
class StepTimingCollector:
forward_s: list[float] = field(default_factory=list)
backward_s: list[float] = field(default_factory=list)
optimizer_s: list[float] = field(default_factory=list)
class _StepTimingCollector:
total_update_s: list[float] = field(default_factory=list)
dataloading_s: list[float] = field(default_factory=list)
memory_timeline: list[dict[str, float | int]] = field(default_factory=list)
def record(
self,
*,
forward_s: float,
backward_s: float,
optimizer_s: float,
total_update_s: float,
) -> None:
self.forward_s.append(_as_float(forward_s))
self.backward_s.append(_as_float(backward_s))
self.optimizer_s.append(_as_float(optimizer_s))
def record_step(self, total_update_s: float) -> None:
self.total_update_s.append(_as_float(total_update_s))
def record_dataloading(self, dataloading_s: float) -> None:
@@ -263,9 +264,6 @@ class StepTimingCollector:
def to_dict(self) -> dict[str, Any]:
return {
"forward_s": _summary(self.forward_s),
"backward_s": _summary(self.backward_s),
"optimizer_s": _summary(self.optimizer_s),
"total_update_s": _summary(self.total_update_s),
"dataloading_s": _summary(self.dataloading_s),
"memory_timeline": self.memory_timeline,
@@ -277,3 +275,99 @@ class StepTimingCollector:
payload.update(extra)
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text(json.dumps(payload, indent=2, sort_keys=True))
class TrainingProfiler:
    """Self-contained profiling orchestrator for the training loop.

    Encapsulates torch profiler setup, step-level timing collection, deterministic
    forward-pass artifact recording, and all output writing. The training script
    interacts with it through a thin interface (~7 lines):

    1. build it with :meth:`from_cfg` (or the constructor),
    2. optionally call :meth:`record_deterministic_forward` once before training,
    3. wrap the training loop in ``with profiler:``,
    4. call :meth:`step` once per training step,
    5. call :meth:`finalize` after the loop to write the summary files.
    """

    def __init__(
        self,
        mode: str,
        output_dir: Path,
        device: torch.device,
        *,
        wait_steps: int = 1,
        warmup_steps: int = 2,
        active_steps: int = 6,
        repeat: int = 1,
        record_shapes: bool = True,
        with_memory: bool = True,
        with_flops: bool = True,
        with_stack: bool = False,
    ) -> None:
        """Create the output directory, the timing collector and the torch profiler.

        Args:
            mode: Profiling mode, forwarded to the torch profiler factory and
                echoed into the timing summary as ``profile_mode``.
            output_dir: Directory that receives every profiling artifact.
            device: Device the training runs on; its ``type`` decides whether
                CUDA memory statistics are collected.
            wait_steps: ``wait`` value of the torch profiler schedule.
            warmup_steps: ``warmup`` value of the torch profiler schedule.
            active_steps: ``active`` value of the torch profiler schedule.
            repeat: ``repeat`` value of the torch profiler schedule.
            record_shapes: Forwarded to the torch profiler as ``record_shapes``.
            with_memory: Forwarded to the torch profiler as ``profile_memory``.
            with_flops: Forwarded to the torch profiler as ``with_flops``.
            with_stack: Forwarded to the torch profiler as ``with_stack``.
        """
        self._mode = mode
        self._output_dir = ensure_dir(output_dir)
        self._device = device
        self._timing = _StepTimingCollector()
        # Hand the profiler the same prepared directory every other artifact
        # uses, so all outputs are guaranteed to land under one root.
        self._torch_profiler = _make_torch_profiler(
            mode=mode,
            output_dir=self._output_dir,
            device_type=device.type,
            wait_steps=wait_steps,
            warmup_steps=warmup_steps,
            active_steps=active_steps,
            repeat=repeat,
            record_shapes=record_shapes,
            with_memory=with_memory,
            with_flops=with_flops,
            with_stack=with_stack,
        )
        logging.info("Profiling enabled. Artifacts will be written to %s", self._output_dir)

    @classmethod
    def from_cfg(cls, cfg: Any, device: torch.device) -> TrainingProfiler:
        """Build a profiler from a training config.

        Only ``cfg.profile_mode``, ``cfg.profile_output_dir`` and ``cfg.output_dir``
        are read; the profiler schedule and recording flags keep the constructor
        defaults.

        Args:
            cfg: Training config object exposing the attributes listed above.
            device: Device the training runs on.

        Returns:
            A profiler writing to ``cfg.profile_output_dir`` or, when that is
            ``None``, to ``<cfg.output_dir>/profiling``.
        """
        output_dir = cfg.profile_output_dir
        if output_dir is None:
            output_dir = Path(cfg.output_dir) / "profiling"
        return cls(mode=cfg.profile_mode, output_dir=Path(output_dir), device=device)

    def record_deterministic_forward(
        self,
        *,
        policy: Any,
        dataset: Any,
        batch_size: int,
        preprocessor: Any,
    ) -> None:
        """Write the deterministic forward-pass artifacts into the output directory.

        Thin wrapper that forwards its arguments, together with this profiler's
        output directory and device type, to ``write_deterministic_forward_artifacts``.
        """
        logging.info("Recording deterministic forward-pass artifacts")
        write_deterministic_forward_artifacts(
            policy=policy,
            dataset=dataset,
            batch_size=batch_size,
            preprocessor=preprocessor,
            output_dir=self._output_dir,
            device_type=self._device.type,
        )

    def __enter__(self) -> TrainingProfiler:
        """Start the torch profiler and return ``self``."""
        if self._device.type == "cuda":
            # Reset the peak counters so the peaks read back in `finalize` start
            # from the moment profiling begins rather than from process start.
            torch.cuda.reset_peak_memory_stats(self._device)
        self._torch_profiler.__enter__()
        return self

    def __exit__(self, *exc: Any) -> bool | None:
        """Stop the torch profiler.

        Returns whatever the wrapped profiler's ``__exit__`` returns, so this
        class adds no exception suppression of its own.
        NOTE(review): torch's profiler is expected to return ``None`` here
        (exceptions propagate) - confirm against the installed torch version.
        """
        return self._torch_profiler.__exit__(*exc)

    def step(self, step_num: int, train_tracker: Any) -> None:
        """Record one training step and advance the torch profiler schedule.

        Args:
            step_num: Step index stored alongside the CUDA memory sample.
            train_tracker: Object exposing ``update_s`` and ``dataloading_s``;
                both are converted with ``_as_float`` before being recorded.
        """
        self._timing.record_step(_as_float(train_tracker.update_s))
        self._timing.record_dataloading(_as_float(train_tracker.dataloading_s))
        if self._device.type == "cuda":
            self._timing.record_memory(
                step=step_num,
                allocated_bytes=torch.cuda.memory_allocated(self._device),
                reserved_bytes=torch.cuda.memory_reserved(self._device),
            )
        self._torch_profiler.step()

    def finalize(self) -> None:
        """Write ``step_timing_summary.json`` and the torch profiler outputs.

        The peak-memory keys are always present in the summary and are ``None``
        on non-CUDA devices, so the JSON schema does not depend on the device
        the run happened to use.
        """
        on_cuda = self._device.type == "cuda"
        extra: dict[str, Any] = {
            "profile_mode": self._mode,
            "peak_memory_allocated_bytes": (
                torch.cuda.max_memory_allocated(self._device) if on_cuda else None
            ),
            "peak_memory_reserved_bytes": (
                torch.cuda.max_memory_reserved(self._device) if on_cuda else None
            ),
        }
        self._timing.write_json(self._output_dir / "step_timing_summary.json", extra=extra)
        write_torch_profiler_outputs(self._torch_profiler, self._output_dir, device_type=self._device.type)
+14 -25
View File
@@ -65,29 +65,24 @@ def test_pretrained_libero_specs_match_expected_camera_keys_and_normalization():
specs = module.load_specs(spec_path)
assert (
"--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", "
"\"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
in specs["pi0"]["train_args"]
'--rename_map={"observation.images.front": "observation.images.base_0_rgb", '
'"observation.images.wrist": "observation.images.left_wrist_0_rgb"}' in specs["pi0"].train_args
)
assert (
"--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", "
"\"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
in specs["pi0_fast"]["train_args"]
'--rename_map={"observation.images.front": "observation.images.base_0_rgb", '
'"observation.images.wrist": "observation.images.left_wrist_0_rgb"}' in specs["pi0_fast"].train_args
)
assert (
"--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", "
"\"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
in specs["pi05"]["train_args"]
'--rename_map={"observation.images.front": "observation.images.base_0_rgb", '
'"observation.images.wrist": "observation.images.left_wrist_0_rgb"}' in specs["pi05"].train_args
)
assert (
"--policy.normalization_mapping={\"ACTION\": \"MEAN_STD\", "
"\"STATE\": \"MEAN_STD\", \"VISUAL\": \"IDENTITY\"}"
in specs["pi05"]["train_args"]
'--policy.normalization_mapping={"ACTION": "MEAN_STD", '
'"STATE": "MEAN_STD", "VISUAL": "IDENTITY"}' in specs["pi05"].train_args
)
assert (
"--rename_map={\"observation.images.front\": \"observation.images.camera1\", "
"\"observation.images.wrist\": \"observation.images.camera2\"}"
in specs["smolvla"]["train_args"]
'--rename_map={"observation.images.front": "observation.images.camera1", '
'"observation.images.wrist": "observation.images.camera2"}' in specs["smolvla"].train_args
)
@@ -222,7 +217,6 @@ def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
(profile_dir / "step_timing_summary.json").write_text(
json.dumps(
{
"forward_s": {"count": 1, "mean": 0.1, "median": 0.1, "min": 0.1, "max": 0.1},
"total_update_s": {"count": 1, "mean": 0.3, "median": 0.3, "min": 0.3, "max": 0.3},
"peak_memory_allocated_bytes": 1024,
}
@@ -251,7 +245,7 @@ def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
assert row["git_commit"] == "deadbeef"
assert row["git_ref"] == "codex/model-profiling"
assert row["pr_number"] == 3389
assert row["step_timing_summary"]["forward_s"]["mean"] == 0.1
assert row["step_timing_summary"]["total_update_s"]["mean"] == 0.3
assert row["deterministic_forward"]["operator_fingerprint"] == "ops-fingerprint"
@@ -364,19 +358,14 @@ def test_deterministic_forward_artifacts_preserve_policy_mode(tmp_path):
def test_step_timing_collector_accepts_metric_like_values(tmp_path):
from lerobot.utils.profiling_utils import StepTimingCollector
from lerobot.utils.profiling_utils import _StepTimingCollector
class _MetricLike:
def __init__(self, val):
self.val = val
collector = StepTimingCollector()
collector.record(
forward_s=0.1,
backward_s=0.2,
optimizer_s=0.3,
total_update_s=_MetricLike(0.6),
)
collector = _StepTimingCollector()
collector.record_step(_MetricLike(0.6))
collector.record_dataloading(_MetricLike(0.05))
collector.write_json(tmp_path / "step_timing_summary.json")