feat(profiling): record forward/backward/optimizer timings

The dashboard expects per-phase timings (forward_s, backward_s,
optimizer_s) in step_timing_summary.json, but only total_update_s
and dataloading_s were collected — leaving every chart except
dataloading empty.

Add a lightweight TrainingProfiler.section(name) context manager
that times a region with torch.cuda.synchronize before and after
(so GPU work is captured, not just the kernel-launch latency) and
accumulates per-section samples into step_timing_summary.json.

Wrap forward, backward (incl. grad clip), and optimizer (incl.
zero_grad and scheduler.step) in update_policy with these sections.
When profiling is off (profiler=None) the wrappers become no-ops,
so training performance is unchanged outside CI.

Made-with: Cursor
This commit is contained in:
Pepijn
2026-04-16 20:26:27 +02:00
parent 00e9defb80
commit 1842100402
3 changed files with 90 additions and 19 deletions
+22 -18
View File
@@ -72,6 +72,7 @@ def update_policy(
lr_scheduler=None, lr_scheduler=None,
lock=None, lock=None,
rabc_weights_provider=None, rabc_weights_provider=None,
profiler: "TrainingProfiler | None" = None,
) -> tuple[MetricsTracker, dict]: ) -> tuple[MetricsTracker, dict]:
""" """
Performs a single training step to update the policy's weights. Performs a single training step to update the policy's weights.
@@ -104,8 +105,10 @@ def update_policy(
if rabc_weights_provider is not None: if rabc_weights_provider is not None:
rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch) rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch)
# Let accelerator handle mixed precision def _section(name: str) -> Any:
with accelerator.autocast(): return profiler.section(name) if profiler is not None else nullcontext()
with _section("forward"), accelerator.autocast():
# Use per-sample loss when RA-BC is enabled for proper weighting # Use per-sample loss when RA-BC is enabled for proper weighting
if rabc_batch_weights is not None: if rabc_batch_weights is not None:
# Get per-sample losses # Get per-sample losses
@@ -124,26 +127,26 @@ def update_policy(
# TODO(rcadene): policy.unnormalize_outputs(out_dict) # TODO(rcadene): policy.unnormalize_outputs(out_dict)
# Use accelerator's backward method with _section("backward"):
accelerator.backward(loss) accelerator.backward(loss)
# Clip gradients if specified # Clip gradients if specified
if grad_clip_norm > 0: if grad_clip_norm > 0:
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm) grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
else: else:
grad_norm = torch.nn.utils.clip_grad_norm_( grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(), float("inf"), error_if_nonfinite=False policy.parameters(), float("inf"), error_if_nonfinite=False
) )
# Optimizer step with _section("optimizer"):
with lock if lock is not None else nullcontext(): with lock if lock is not None else nullcontext():
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
# Step through pytorch scheduler at every batch instead of epoch # Step through pytorch scheduler at every batch instead of epoch
if lr_scheduler is not None: if lr_scheduler is not None:
lr_scheduler.step() lr_scheduler.step()
# Update internal buffers if policy has update method # Update internal buffers if policy has update method
if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"): if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
@@ -454,6 +457,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
accelerator=accelerator, accelerator=accelerator,
lr_scheduler=lr_scheduler, lr_scheduler=lr_scheduler,
rabc_weights_provider=rabc_weights, rabc_weights_provider=rabc_weights,
profiler=profiler,
) )
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
+28 -1
View File
@@ -20,6 +20,9 @@ import hashlib
import json import json
import logging import logging
import statistics import statistics
import time
from collections.abc import Iterator
from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from numbers import Real from numbers import Real
from pathlib import Path from pathlib import Path
@@ -245,6 +248,7 @@ def _as_float(value: Any) -> float:
class _StepTimingCollector: class _StepTimingCollector:
total_update_s: list[float] = field(default_factory=list) total_update_s: list[float] = field(default_factory=list)
dataloading_s: list[float] = field(default_factory=list) dataloading_s: list[float] = field(default_factory=list)
section_s: dict[str, list[float]] = field(default_factory=dict)
memory_timeline: list[dict[str, float | int]] = field(default_factory=list) memory_timeline: list[dict[str, float | int]] = field(default_factory=list)
def record_step(self, total_update_s: float) -> None: def record_step(self, total_update_s: float) -> None:
@@ -253,6 +257,9 @@ class _StepTimingCollector:
def record_dataloading(self, dataloading_s: float) -> None: def record_dataloading(self, dataloading_s: float) -> None:
self.dataloading_s.append(_as_float(dataloading_s)) self.dataloading_s.append(_as_float(dataloading_s))
def record_section(self, name: str, duration_s: float) -> None:
self.section_s.setdefault(name, []).append(_as_float(duration_s))
def record_memory(self, *, step: int, allocated_bytes: int, reserved_bytes: int) -> None: def record_memory(self, *, step: int, allocated_bytes: int, reserved_bytes: int) -> None:
self.memory_timeline.append( self.memory_timeline.append(
{ {
@@ -263,11 +270,14 @@ class _StepTimingCollector:
) )
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return { payload: dict[str, Any] = {
"total_update_s": _summary(self.total_update_s), "total_update_s": _summary(self.total_update_s),
"dataloading_s": _summary(self.dataloading_s), "dataloading_s": _summary(self.dataloading_s),
"memory_timeline": self.memory_timeline, "memory_timeline": self.memory_timeline,
} }
for name, values in self.section_s.items():
payload[f"{name}_s"] = _summary(values)
return payload
def write_json(self, output_path: Path, extra: dict[str, Any] | None = None) -> None: def write_json(self, output_path: Path, extra: dict[str, Any] | None = None) -> None:
payload = self.to_dict() payload = self.to_dict()
@@ -355,6 +365,23 @@ class TrainingProfiler:
def __exit__(self, *exc: Any) -> bool: def __exit__(self, *exc: Any) -> bool:
return self._torch_profiler.__exit__(*exc) return self._torch_profiler.__exit__(*exc)
@contextmanager
def section(self, name: str) -> Iterator[None]:
"""Time a region of the training step (e.g. forward/backward/optimizer).
On CUDA we synchronize before and after so the reported duration
reflects GPU work, not just the CPU-side kernel-launch latency.
"""
if self._device.type == "cuda":
torch.cuda.synchronize(self._device)
start = time.perf_counter()
try:
yield
finally:
if self._device.type == "cuda":
torch.cuda.synchronize(self._device)
self._timing.record_section(name, time.perf_counter() - start)
def step(self, step_num: int, train_tracker: Any) -> None: def step(self, step_num: int, train_tracker: Any) -> None:
self._timing.record_step(_as_float(train_tracker.update_s)) self._timing.record_step(_as_float(train_tracker.update_s))
self._timing.record_dataloading(_as_float(train_tracker.dataloading_s)) self._timing.record_dataloading(_as_float(train_tracker.dataloading_s))
+40
View File
@@ -23,6 +23,7 @@ import subprocess
import sys import sys
from pathlib import Path from pathlib import Path
import pytest
import torch import torch
from huggingface_hub.errors import HfHubHTTPError from huggingface_hub.errors import HfHubHTTPError
@@ -374,6 +375,45 @@ def test_step_timing_collector_accepts_metric_like_values(tmp_path):
assert payload["dataloading_s"]["mean"] == 0.05 assert payload["dataloading_s"]["mean"] == 0.05
def test_step_timing_collector_records_forward_backward_optimizer(tmp_path):
from lerobot.utils.profiling_utils import _StepTimingCollector
collector = _StepTimingCollector()
for _ in range(3):
collector.record_section("forward", 0.10)
collector.record_section("backward", 0.20)
collector.record_section("optimizer", 0.05)
collector.write_json(tmp_path / "step_timing_summary.json")
payload = json.loads((tmp_path / "step_timing_summary.json").read_text())
assert payload["forward_s"]["mean"] == pytest.approx(0.10)
assert payload["backward_s"]["mean"] == pytest.approx(0.20)
assert payload["optimizer_s"]["mean"] == pytest.approx(0.05)
assert payload["forward_s"]["count"] == 3
def test_training_profiler_section_records_duration(tmp_path):
from lerobot.utils.profiling_utils import TrainingProfiler
profiler = TrainingProfiler(
mode="summary",
output_dir=tmp_path,
device=torch.device("cpu"),
)
with profiler:
with profiler.section("forward"):
pass
with profiler.section("backward"):
pass
profiler.step(1, argparse.Namespace(update_s=0.5, dataloading_s=0.01))
profiler.finalize()
payload = json.loads((tmp_path / "step_timing_summary.json").read_text())
assert payload["forward_s"]["count"] == 1
assert payload["backward_s"]["count"] == 1
assert payload["forward_s"]["mean"] >= 0.0
def test_profiler_device_time_uses_generic_attr_first(): def test_profiler_device_time_uses_generic_attr_first():
from lerobot.utils.profiling_utils import _get_profiler_device_time_us from lerobot.utils.profiling_utils import _get_profiler_device_time_us