feat(profiling): record forward/backward/optimizer timings

The dashboard expects per-phase timings (forward_s, backward_s,
optimizer_s) in step_timing_summary.json, but only total_update_s
and dataloading_s were collected — leaving every chart except
dataloading empty.

Add a lightweight TrainingProfiler.section(name) context manager
that times a region with torch.cuda.synchronize before and after
(so GPU work is captured, not just the kernel-launch latency) and
accumulates per-section samples, which are summarized in
step_timing_summary.json under "<name>_s" keys.

Wrap forward, backward (incl. grad clip), and optimizer (incl.
zero_grad and scheduler.step) in update_policy with these sections.
When profiling is off (profiler=None) the wrappers become no-ops,
so training performance is unchanged outside CI.

Made-with: Cursor
This commit is contained in:
Pepijn
2026-04-16 20:26:27 +02:00
parent 00e9defb80
commit 1842100402
3 changed files with 90 additions and 19 deletions
+22 -18
View File
@@ -72,6 +72,7 @@ def update_policy(
lr_scheduler=None,
lock=None,
rabc_weights_provider=None,
profiler: "TrainingProfiler | None" = None,
) -> tuple[MetricsTracker, dict]:
"""
Performs a single training step to update the policy's weights.
@@ -104,8 +105,10 @@ def update_policy(
if rabc_weights_provider is not None:
rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch)
# Let accelerator handle mixed precision
with accelerator.autocast():
def _section(name: str) -> Any:
return profiler.section(name) if profiler is not None else nullcontext()
with _section("forward"), accelerator.autocast():
# Use per-sample loss when RA-BC is enabled for proper weighting
if rabc_batch_weights is not None:
# Get per-sample losses
@@ -124,26 +127,26 @@ def update_policy(
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
# Use accelerator's backward method
accelerator.backward(loss)
with _section("backward"):
accelerator.backward(loss)
# Clip gradients if specified
if grad_clip_norm > 0:
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
else:
grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(), float("inf"), error_if_nonfinite=False
)
# Clip gradients if specified
if grad_clip_norm > 0:
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
else:
grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(), float("inf"), error_if_nonfinite=False
)
# Optimizer step
with lock if lock is not None else nullcontext():
optimizer.step()
with _section("optimizer"):
with lock if lock is not None else nullcontext():
optimizer.step()
optimizer.zero_grad()
optimizer.zero_grad()
# Step through pytorch scheduler at every batch instead of epoch
if lr_scheduler is not None:
lr_scheduler.step()
# Step through pytorch scheduler at every batch instead of epoch
if lr_scheduler is not None:
lr_scheduler.step()
# Update internal buffers if policy has update method
if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
@@ -454,6 +457,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
accelerator=accelerator,
lr_scheduler=lr_scheduler,
rabc_weights_provider=rabc_weights,
profiler=profiler,
)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
+28 -1
View File
@@ -20,6 +20,9 @@ import hashlib
import json
import logging
import statistics
import time
from collections.abc import Iterator
from contextlib import contextmanager
from dataclasses import dataclass, field
from numbers import Real
from pathlib import Path
@@ -245,6 +248,7 @@ def _as_float(value: Any) -> float:
class _StepTimingCollector:
total_update_s: list[float] = field(default_factory=list)
dataloading_s: list[float] = field(default_factory=list)
section_s: dict[str, list[float]] = field(default_factory=dict)
memory_timeline: list[dict[str, float | int]] = field(default_factory=list)
def record_step(self, total_update_s: float) -> None:
@@ -253,6 +257,9 @@ class _StepTimingCollector:
def record_dataloading(self, dataloading_s: float) -> None:
self.dataloading_s.append(_as_float(dataloading_s))
def record_section(self, name: str, duration_s: float) -> None:
    """Append one timing sample for the section called ``name``.

    The sample list for ``name`` is created on first use, and the duration
    is passed through ``_as_float`` before it is stored.
    """
    samples = self.section_s.setdefault(name, [])
    samples.append(_as_float(duration_s))
def record_memory(self, *, step: int, allocated_bytes: int, reserved_bytes: int) -> None:
self.memory_timeline.append(
{
@@ -263,11 +270,14 @@ class _StepTimingCollector:
)
def to_dict(self) -> dict[str, Any]:
    """Build the JSON-ready summary of everything collected so far.

    Returns:
        A dict holding the ``total_update_s`` and ``dataloading_s``
        summaries, the raw ``memory_timeline`` list, and one
        ``"<name>_s"`` summary for every recorded section.
    """
    payload: dict[str, Any] = {
        "total_update_s": _summary(self.total_update_s),
        "dataloading_s": _summary(self.dataloading_s),
        "memory_timeline": self.memory_timeline,
    }
    for name, values in self.section_s.items():
        key = f"{name}_s"
        if key in payload:
            # A section named e.g. "dataloading" maps onto a built-in key;
            # keep the built-in summary instead of silently replacing it.
            logging.getLogger(__name__).warning(
                "Section %r collides with built-in timing key %r; skipping it.", name, key
            )
            continue
        payload[key] = _summary(values)
    return payload
def write_json(self, output_path: Path, extra: dict[str, Any] | None = None) -> None:
payload = self.to_dict()
@@ -355,6 +365,23 @@ class TrainingProfiler:
def __exit__(self, *exc: Any) -> bool:
return self._torch_profiler.__exit__(*exc)
@contextmanager
def section(self, name: str) -> Iterator[None]:
    """Time a region of the training step (e.g. forward/backward/optimizer).

    On asynchronous accelerators (CUDA, and MPS when available) the device
    is synchronized before and after the region, so the reported duration
    reflects device work and not just the CPU-side kernel-launch latency.
    On any other device type no synchronization is performed.

    Args:
        name: Label under which the elapsed time is recorded.

    Yields:
        None. The wrapped region runs while the timer is active.

    The duration is recorded even when the region raises; the exception
    then propagates to the caller unchanged.
    """
    device_type = self._device.type

    def _wait_for_device() -> None:
        # Drain queued device work so perf_counter brackets real execution.
        if device_type == "cuda":
            torch.cuda.synchronize(self._device)
        elif device_type == "mps" and hasattr(torch, "mps"):
            torch.mps.synchronize()

    _wait_for_device()
    start = time.perf_counter()
    try:
        yield
    finally:
        _wait_for_device()
        self._timing.record_section(name, time.perf_counter() - start)
def step(self, step_num: int, train_tracker: Any) -> None:
self._timing.record_step(_as_float(train_tracker.update_s))
self._timing.record_dataloading(_as_float(train_tracker.dataloading_s))
+40
View File
@@ -23,6 +23,7 @@ import subprocess
import sys
from pathlib import Path
import pytest
import torch
from huggingface_hub.errors import HfHubHTTPError
@@ -374,6 +375,45 @@ def test_step_timing_collector_accepts_metric_like_values(tmp_path):
assert payload["dataloading_s"]["mean"] == 0.05
def test_step_timing_collector_records_forward_backward_optimizer(tmp_path):
    """Section samples end up summarised under ``<name>_s`` keys in the JSON file."""
    from lerobot.utils.profiling_utils import _StepTimingCollector

    expected_means = {"forward": 0.10, "backward": 0.20, "optimizer": 0.05}
    summary_path = tmp_path / "step_timing_summary.json"

    collector = _StepTimingCollector()
    for _ in range(3):
        # Same recording order per iteration: forward, backward, optimizer.
        for section_name, duration in expected_means.items():
            collector.record_section(section_name, duration)
    collector.write_json(summary_path)

    payload = json.loads(summary_path.read_text())
    for section_name, mean in expected_means.items():
        assert payload[f"{section_name}_s"]["mean"] == pytest.approx(mean)
    assert payload["forward_s"]["count"] == 3
def test_training_profiler_section_records_duration(tmp_path):
    """Each ``section`` block contributes exactly one sample to the written summary."""
    from lerobot.utils.profiling_utils import TrainingProfiler

    cpu_profiler = TrainingProfiler(mode="summary", output_dir=tmp_path, device=torch.device("cpu"))
    tracker = argparse.Namespace(update_s=0.5, dataloading_s=0.01)

    with cpu_profiler:
        for section_name in ("forward", "backward"):
            with cpu_profiler.section(section_name):
                pass
        cpu_profiler.step(1, tracker)
    cpu_profiler.finalize()

    summary = json.loads((tmp_path / "step_timing_summary.json").read_text())
    assert summary["forward_s"]["count"] == 1
    assert summary["backward_s"]["count"] == 1
    assert summary["forward_s"]["mean"] >= 0.0
def test_profiler_device_time_uses_generic_attr_first():
from lerobot.utils.profiling_utils import _get_profiler_device_time_us