Mirror of https://github.com/huggingface/lerobot.git — synced 2026-05-15 08:39:49 +00:00.
feat(profiling): record forward/backward/optimizer timings
The dashboard expects per-phase timings (forward_s, backward_s, optimizer_s) in step_timing_summary.json, but only total_update_s and dataloading_s were collected — leaving every chart except dataloading empty. Add a lightweight TrainingProfiler.section(name) context manager that times a region with torch.cuda.synchronize before and after (so GPU work is captured, not just the kernel-launch latency) and accumulates per-section samples into step_timing_summary.json. Wrap forward, backward (incl. grad clip), and optimizer (incl. zero_grad and scheduler.step) in update_policy with these sections. When profiling is off (profiler=None) the wrappers become no-ops, so training performance is unchanged outside CI. Made-with: Cursor
This commit is contained in:
@@ -23,6 +23,7 @@ import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
@@ -374,6 +375,45 @@ def test_step_timing_collector_accepts_metric_like_values(tmp_path):
|
||||
assert payload["dataloading_s"]["mean"] == 0.05
|
||||
|
||||
|
||||
def test_step_timing_collector_records_forward_backward_optimizer(tmp_path):
    """Per-section samples (forward/backward/optimizer) are aggregated into the summary JSON.

    Records three identical samples per section so the expected mean equals the
    sample value and the count is exactly 3.
    """
    from lerobot.utils.profiling_utils import _StepTimingCollector

    summary_path = tmp_path / "step_timing_summary.json"
    collector = _StepTimingCollector()

    # Simulate three training steps, each contributing one sample per phase.
    for _ in range(3):
        collector.record_section("forward", 0.10)
        collector.record_section("backward", 0.20)
        collector.record_section("optimizer", 0.05)
    collector.write_json(summary_path)

    payload = json.loads(summary_path.read_text())
    # Constant samples: mean == sample value for every phase.
    assert payload["forward_s"]["mean"] == pytest.approx(0.10)
    assert payload["backward_s"]["mean"] == pytest.approx(0.20)
    assert payload["optimizer_s"]["mean"] == pytest.approx(0.05)
    assert payload["forward_s"]["count"] == 3
|
||||
|
||||
|
||||
def test_training_profiler_section_records_duration(tmp_path):
    """profiler.section(name) times a region and the durations land in step_timing_summary.json.

    Runs one profiled step with empty forward/backward sections on CPU, then
    checks each section produced exactly one non-negative sample.
    """
    from lerobot.utils.profiling_utils import TrainingProfiler

    summary_path = tmp_path / "step_timing_summary.json"
    profiler = TrainingProfiler(
        mode="summary",
        output_dir=tmp_path,
        device=torch.device("cpu"),
    )
    with profiler:
        # Empty bodies: we only care that the section timers fire, not the values.
        with profiler.section("forward"):
            pass
        with profiler.section("backward"):
            pass
        # NOTE(review): step/finalize placement reconstructed from a
        # whitespace-mangled paste — step assumed inside the context,
        # finalize after it; confirm against the original file.
        profiler.step(1, argparse.Namespace(update_s=0.5, dataloading_s=0.01))
    profiler.finalize()

    payload = json.loads(summary_path.read_text())
    assert payload["forward_s"]["count"] == 1
    assert payload["backward_s"]["count"] == 1
    # Wall-clock duration of an empty block can be ~0 but never negative.
    assert payload["forward_s"]["mean"] >= 0.0
|
||||
|
||||
|
||||
def test_profiler_device_time_uses_generic_attr_first():
|
||||
from lerobot.utils.profiling_utils import _get_profiler_device_time_us
|
||||
|
||||
|
||||
Reference in New Issue
Block a user