diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 392bfb51f..c2bca99cf 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -72,6 +72,7 @@ def update_policy( lr_scheduler=None, lock=None, rabc_weights_provider=None, + profiler: "TrainingProfiler | None" = None, ) -> tuple[MetricsTracker, dict]: """ Performs a single training step to update the policy's weights. @@ -104,8 +105,10 @@ def update_policy( if rabc_weights_provider is not None: rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch) - # Let accelerator handle mixed precision - with accelerator.autocast(): + def _section(name: str) -> Any: + return profiler.section(name) if profiler is not None else nullcontext() + + with _section("forward"), accelerator.autocast(): # Use per-sample loss when RA-BC is enabled for proper weighting if rabc_batch_weights is not None: # Get per-sample losses @@ -124,26 +127,26 @@ def update_policy( # TODO(rcadene): policy.unnormalize_outputs(out_dict) - # Use accelerator's backward method - accelerator.backward(loss) + with _section("backward"): + accelerator.backward(loss) - # Clip gradients if specified - if grad_clip_norm > 0: - grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm) - else: - grad_norm = torch.nn.utils.clip_grad_norm_( - policy.parameters(), float("inf"), error_if_nonfinite=False - ) + # Clip gradients if specified + if grad_clip_norm > 0: + grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm) + else: + grad_norm = torch.nn.utils.clip_grad_norm_( + policy.parameters(), float("inf"), error_if_nonfinite=False + ) - # Optimizer step - with lock if lock is not None else nullcontext(): - optimizer.step() + with _section("optimizer"): + with lock if lock is not None else nullcontext(): + optimizer.step() - optimizer.zero_grad() + optimizer.zero_grad() - # Step through pytorch scheduler at every batch instead of epoch - if lr_scheduler is not None: - lr_scheduler.step() + # Step through pytorch scheduler at every batch instead of epoch + if lr_scheduler is not None: + lr_scheduler.step() # Update internal buffers if policy has update method if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"): @@ -454,6 +457,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): accelerator=accelerator, lr_scheduler=lr_scheduler, rabc_weights_provider=rabc_weights, + profiler=profiler, ) # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we diff --git a/src/lerobot/utils/profiling_utils.py b/src/lerobot/utils/profiling_utils.py index 85931051d..874b1ffb9 100644 --- a/src/lerobot/utils/profiling_utils.py +++ b/src/lerobot/utils/profiling_utils.py @@ -20,6 +20,9 @@ import hashlib import json import logging import statistics +import time +from collections.abc import Iterator +from contextlib import contextmanager from dataclasses import dataclass, field from numbers import Real from pathlib import Path @@ -245,6 +248,7 @@ def _as_float(value: Any) -> float: class _StepTimingCollector: total_update_s: list[float] = field(default_factory=list) dataloading_s: list[float] = field(default_factory=list) + section_s: dict[str, list[float]] = field(default_factory=dict) memory_timeline: list[dict[str, float | int]] = field(default_factory=list) def record_step(self, total_update_s: float) -> None: @@ -253,6 +257,9 @@ class _StepTimingCollector: def record_dataloading(self, dataloading_s: float) -> None: self.dataloading_s.append(_as_float(dataloading_s)) + def record_section(self, name: str, duration_s: float) -> None: + self.section_s.setdefault(name, []).append(_as_float(duration_s)) + def record_memory(self, *, step: int, allocated_bytes: int, reserved_bytes: int) -> None: self.memory_timeline.append( { @@ -263,11 +270,14 @@ class _StepTimingCollector: ) def to_dict(self) -> dict[str, Any]: - return { + payload: dict[str, Any] = { "total_update_s": _summary(self.total_update_s), "dataloading_s": _summary(self.dataloading_s), "memory_timeline": self.memory_timeline, } + for name, values in self.section_s.items(): + payload[f"{name}_s"] = _summary(values) + return payload def write_json(self, output_path: Path, extra: dict[str, Any] | None = None) -> None: payload = self.to_dict() @@ -355,6 +365,23 @@ class TrainingProfiler: def __exit__(self, *exc: Any) -> bool: return self._torch_profiler.__exit__(*exc) + @contextmanager + def section(self, name: str) -> Iterator[None]: + """Time a region of the training step (e.g. forward/backward/optimizer). + + On CUDA we synchronize before and after so the reported duration + reflects GPU work, not just the CPU-side kernel-launch latency. + """ + if self._device.type == "cuda": + torch.cuda.synchronize(self._device) + start = time.perf_counter() + try: + yield + finally: + if self._device.type == "cuda": + torch.cuda.synchronize(self._device) + self._timing.record_section(name, time.perf_counter() - start) + def step(self, step_num: int, train_tracker: Any) -> None: self._timing.record_step(_as_float(train_tracker.update_s)) self._timing.record_dataloading(_as_float(train_tracker.dataloading_s)) diff --git a/tests/scripts/test_model_profiling.py b/tests/scripts/test_model_profiling.py index 4a6930f65..10d4e4b1b 100644 --- a/tests/scripts/test_model_profiling.py +++ b/tests/scripts/test_model_profiling.py @@ -23,6 +23,7 @@ import subprocess import sys from pathlib import Path +import pytest import torch from huggingface_hub.errors import HfHubHTTPError @@ -374,6 +375,45 @@ def test_step_timing_collector_accepts_metric_like_values(tmp_path): assert payload["dataloading_s"]["mean"] == 0.05 +def test_step_timing_collector_records_forward_backward_optimizer(tmp_path): + from lerobot.utils.profiling_utils import _StepTimingCollector + + collector = _StepTimingCollector() + for _ in range(3): + collector.record_section("forward", 0.10) + collector.record_section("backward", 0.20) + collector.record_section("optimizer", 0.05) + collector.write_json(tmp_path / "step_timing_summary.json") + + payload = json.loads((tmp_path / "step_timing_summary.json").read_text()) + assert payload["forward_s"]["mean"] == pytest.approx(0.10) + assert payload["backward_s"]["mean"] == pytest.approx(0.20) + assert payload["optimizer_s"]["mean"] == pytest.approx(0.05) + assert payload["forward_s"]["count"] == 3 + + +def test_training_profiler_section_records_duration(tmp_path): + from lerobot.utils.profiling_utils import TrainingProfiler + + profiler = TrainingProfiler( + mode="summary", + output_dir=tmp_path, + device=torch.device("cpu"), + ) + with profiler: + with profiler.section("forward"): + pass + with profiler.section("backward"): + pass + profiler.step(1, argparse.Namespace(update_s=0.5, dataloading_s=0.01)) + profiler.finalize() + + payload = json.loads((tmp_path / "step_timing_summary.json").read_text()) + assert payload["forward_s"]["count"] == 1 + assert payload["backward_s"]["count"] == 1 + assert payload["forward_s"]["mean"] >= 0.0 + + def test_profiler_device_time_uses_generic_attr_first(): from lerobot.utils.profiling_utils import _get_profiler_device_time_us