mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
feat(profiling): record forward/backward/optimizer timings
The dashboard expects per-phase timings (forward_s, backward_s, optimizer_s) in step_timing_summary.json, but only total_update_s and dataloading_s were collected — leaving every chart except dataloading empty. Add a lightweight TrainingProfiler.section(name) context manager that times a region with torch.cuda.synchronize before and after (so GPU work is captured, not just the kernel-launch latency) and accumulates per-section samples into step_timing_summary.json. Wrap forward, backward (incl. grad clip), and optimizer (incl. zero_grad and scheduler.step) in update_policy with these sections. When profiling is off (profiler=None) the wrappers become no-ops, so training performance is unchanged outside CI. Made-with: Cursor
This commit is contained in:
@@ -72,6 +72,7 @@ def update_policy(
|
|||||||
lr_scheduler=None,
|
lr_scheduler=None,
|
||||||
lock=None,
|
lock=None,
|
||||||
rabc_weights_provider=None,
|
rabc_weights_provider=None,
|
||||||
|
profiler: "TrainingProfiler | None" = None,
|
||||||
) -> tuple[MetricsTracker, dict]:
|
) -> tuple[MetricsTracker, dict]:
|
||||||
"""
|
"""
|
||||||
Performs a single training step to update the policy's weights.
|
Performs a single training step to update the policy's weights.
|
||||||
@@ -104,8 +105,10 @@ def update_policy(
|
|||||||
if rabc_weights_provider is not None:
|
if rabc_weights_provider is not None:
|
||||||
rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch)
|
rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch)
|
||||||
|
|
||||||
# Let accelerator handle mixed precision
|
def _section(name: str) -> Any:
|
||||||
with accelerator.autocast():
|
return profiler.section(name) if profiler is not None else nullcontext()
|
||||||
|
|
||||||
|
with _section("forward"), accelerator.autocast():
|
||||||
# Use per-sample loss when RA-BC is enabled for proper weighting
|
# Use per-sample loss when RA-BC is enabled for proper weighting
|
||||||
if rabc_batch_weights is not None:
|
if rabc_batch_weights is not None:
|
||||||
# Get per-sample losses
|
# Get per-sample losses
|
||||||
@@ -124,26 +127,26 @@ def update_policy(
|
|||||||
|
|
||||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||||
|
|
||||||
# Use accelerator's backward method
|
with _section("backward"):
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
|
|
||||||
# Clip gradients if specified
|
# Clip gradients if specified
|
||||||
if grad_clip_norm > 0:
|
if grad_clip_norm > 0:
|
||||||
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
|
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
|
||||||
else:
|
else:
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
policy.parameters(), float("inf"), error_if_nonfinite=False
|
policy.parameters(), float("inf"), error_if_nonfinite=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Optimizer step
|
with _section("optimizer"):
|
||||||
with lock if lock is not None else nullcontext():
|
with lock if lock is not None else nullcontext():
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
# Step through pytorch scheduler at every batch instead of epoch
|
# Step through pytorch scheduler at every batch instead of epoch
|
||||||
if lr_scheduler is not None:
|
if lr_scheduler is not None:
|
||||||
lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
|
|
||||||
# Update internal buffers if policy has update method
|
# Update internal buffers if policy has update method
|
||||||
if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
|
if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
|
||||||
@@ -454,6 +457,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
accelerator=accelerator,
|
accelerator=accelerator,
|
||||||
lr_scheduler=lr_scheduler,
|
lr_scheduler=lr_scheduler,
|
||||||
rabc_weights_provider=rabc_weights,
|
rabc_weights_provider=rabc_weights,
|
||||||
|
profiler=profiler,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||||
|
|||||||
@@ -20,6 +20,9 @@ import hashlib
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import statistics
|
import statistics
|
||||||
|
import time
|
||||||
|
from collections.abc import Iterator
|
||||||
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from numbers import Real
|
from numbers import Real
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -245,6 +248,7 @@ def _as_float(value: Any) -> float:
|
|||||||
class _StepTimingCollector:
|
class _StepTimingCollector:
|
||||||
total_update_s: list[float] = field(default_factory=list)
|
total_update_s: list[float] = field(default_factory=list)
|
||||||
dataloading_s: list[float] = field(default_factory=list)
|
dataloading_s: list[float] = field(default_factory=list)
|
||||||
|
section_s: dict[str, list[float]] = field(default_factory=dict)
|
||||||
memory_timeline: list[dict[str, float | int]] = field(default_factory=list)
|
memory_timeline: list[dict[str, float | int]] = field(default_factory=list)
|
||||||
|
|
||||||
def record_step(self, total_update_s: float) -> None:
|
def record_step(self, total_update_s: float) -> None:
|
||||||
@@ -253,6 +257,9 @@ class _StepTimingCollector:
|
|||||||
def record_dataloading(self, dataloading_s: float) -> None:
|
def record_dataloading(self, dataloading_s: float) -> None:
|
||||||
self.dataloading_s.append(_as_float(dataloading_s))
|
self.dataloading_s.append(_as_float(dataloading_s))
|
||||||
|
|
||||||
|
def record_section(self, name: str, duration_s: float) -> None:
|
||||||
|
self.section_s.setdefault(name, []).append(_as_float(duration_s))
|
||||||
|
|
||||||
def record_memory(self, *, step: int, allocated_bytes: int, reserved_bytes: int) -> None:
|
def record_memory(self, *, step: int, allocated_bytes: int, reserved_bytes: int) -> None:
|
||||||
self.memory_timeline.append(
|
self.memory_timeline.append(
|
||||||
{
|
{
|
||||||
@@ -263,11 +270,14 @@ class _StepTimingCollector:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
return {
|
payload: dict[str, Any] = {
|
||||||
"total_update_s": _summary(self.total_update_s),
|
"total_update_s": _summary(self.total_update_s),
|
||||||
"dataloading_s": _summary(self.dataloading_s),
|
"dataloading_s": _summary(self.dataloading_s),
|
||||||
"memory_timeline": self.memory_timeline,
|
"memory_timeline": self.memory_timeline,
|
||||||
}
|
}
|
||||||
|
for name, values in self.section_s.items():
|
||||||
|
payload[f"{name}_s"] = _summary(values)
|
||||||
|
return payload
|
||||||
|
|
||||||
def write_json(self, output_path: Path, extra: dict[str, Any] | None = None) -> None:
|
def write_json(self, output_path: Path, extra: dict[str, Any] | None = None) -> None:
|
||||||
payload = self.to_dict()
|
payload = self.to_dict()
|
||||||
@@ -355,6 +365,23 @@ class TrainingProfiler:
|
|||||||
def __exit__(self, *exc: Any) -> bool:
|
def __exit__(self, *exc: Any) -> bool:
|
||||||
return self._torch_profiler.__exit__(*exc)
|
return self._torch_profiler.__exit__(*exc)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def section(self, name: str) -> Iterator[None]:
|
||||||
|
"""Time a region of the training step (e.g. forward/backward/optimizer).
|
||||||
|
|
||||||
|
On CUDA we synchronize before and after so the reported duration
|
||||||
|
reflects GPU work, not just the CPU-side kernel-launch latency.
|
||||||
|
"""
|
||||||
|
if self._device.type == "cuda":
|
||||||
|
torch.cuda.synchronize(self._device)
|
||||||
|
start = time.perf_counter()
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
if self._device.type == "cuda":
|
||||||
|
torch.cuda.synchronize(self._device)
|
||||||
|
self._timing.record_section(name, time.perf_counter() - start)
|
||||||
|
|
||||||
def step(self, step_num: int, train_tracker: Any) -> None:
|
def step(self, step_num: int, train_tracker: Any) -> None:
|
||||||
self._timing.record_step(_as_float(train_tracker.update_s))
|
self._timing.record_step(_as_float(train_tracker.update_s))
|
||||||
self._timing.record_dataloading(_as_float(train_tracker.dataloading_s))
|
self._timing.record_dataloading(_as_float(train_tracker.dataloading_s))
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import subprocess
|
|||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub.errors import HfHubHTTPError
|
from huggingface_hub.errors import HfHubHTTPError
|
||||||
|
|
||||||
@@ -374,6 +375,45 @@ def test_step_timing_collector_accepts_metric_like_values(tmp_path):
|
|||||||
assert payload["dataloading_s"]["mean"] == 0.05
|
assert payload["dataloading_s"]["mean"] == 0.05
|
||||||
|
|
||||||
|
|
||||||
|
def test_step_timing_collector_records_forward_backward_optimizer(tmp_path):
|
||||||
|
from lerobot.utils.profiling_utils import _StepTimingCollector
|
||||||
|
|
||||||
|
collector = _StepTimingCollector()
|
||||||
|
for _ in range(3):
|
||||||
|
collector.record_section("forward", 0.10)
|
||||||
|
collector.record_section("backward", 0.20)
|
||||||
|
collector.record_section("optimizer", 0.05)
|
||||||
|
collector.write_json(tmp_path / "step_timing_summary.json")
|
||||||
|
|
||||||
|
payload = json.loads((tmp_path / "step_timing_summary.json").read_text())
|
||||||
|
assert payload["forward_s"]["mean"] == pytest.approx(0.10)
|
||||||
|
assert payload["backward_s"]["mean"] == pytest.approx(0.20)
|
||||||
|
assert payload["optimizer_s"]["mean"] == pytest.approx(0.05)
|
||||||
|
assert payload["forward_s"]["count"] == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_training_profiler_section_records_duration(tmp_path):
|
||||||
|
from lerobot.utils.profiling_utils import TrainingProfiler
|
||||||
|
|
||||||
|
profiler = TrainingProfiler(
|
||||||
|
mode="summary",
|
||||||
|
output_dir=tmp_path,
|
||||||
|
device=torch.device("cpu"),
|
||||||
|
)
|
||||||
|
with profiler:
|
||||||
|
with profiler.section("forward"):
|
||||||
|
pass
|
||||||
|
with profiler.section("backward"):
|
||||||
|
pass
|
||||||
|
profiler.step(1, argparse.Namespace(update_s=0.5, dataloading_s=0.01))
|
||||||
|
profiler.finalize()
|
||||||
|
|
||||||
|
payload = json.loads((tmp_path / "step_timing_summary.json").read_text())
|
||||||
|
assert payload["forward_s"]["count"] == 1
|
||||||
|
assert payload["backward_s"]["count"] == 1
|
||||||
|
assert payload["forward_s"]["mean"] >= 0.0
|
||||||
|
|
||||||
|
|
||||||
def test_profiler_device_time_uses_generic_attr_first():
|
def test_profiler_device_time_uses_generic_attr_first():
|
||||||
from lerobot.utils.profiling_utils import _get_profiler_device_time_us
|
from lerobot.utils.profiling_utils import _get_profiler_device_time_us
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user