From 25e5062b2c8ba5b36eeb9dbfcc8aa2e31b137637 Mon Sep 17 00:00:00 2001
From: Pepijn
Date: Thu, 16 Apr 2026 10:29:01 +0200
Subject: [PATCH] fix(profiling): read generic device timings from profiler

---
 src/lerobot/utils/profiling_utils.py  | 16 +++++++++++++++-
 tests/scripts/test_model_profiling.py | 31 +++++++++++++++++++++++++++++++
 2 files changed, 46 insertions(+), 1 deletion(-)

diff --git a/src/lerobot/utils/profiling_utils.py b/src/lerobot/utils/profiling_utils.py
index 9a7289ad2..fabff6156 100644
--- a/src/lerobot/utils/profiling_utils.py
+++ b/src/lerobot/utils/profiling_utils.py
@@ -172,6 +172,20 @@ def _hash_payload(payload: Any) -> str:
     return hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest()
 
 
+def _get_profiler_device_time_us(event: Any) -> float | None:
+    """Return the profiler event's self device time, preferring the generic attribute.
+
+    ``self_device_time_total`` is read first; the older CUDA-specific
+    ``self_cuda_time_total`` is consulted only when the generic attribute is
+    missing or ``None``, so the legacy attribute is never touched on events
+    that already expose the generic one.
+    """
+    device_time = getattr(event, "self_device_time_total", None)
+    if device_time is None:
+        device_time = getattr(event, "self_cuda_time_total", None)
+    return _stable_float(device_time)
+
+
 def _build_reference_batch(dataset: Any, batch_size: int) -> Any:
     if len(dataset) == 0:
         raise ValueError("Cannot build a reference batch from an empty dataset.")
@@ -212,7 +226,7 @@ def write_deterministic_forward_artifacts(
             "cpu_time_total_us": _stable_float(getattr(event, "cpu_time_total", None)),
         }
         if device_type == "cuda":
-            entry["self_cuda_time_total_us"] = _stable_float(getattr(event, "self_cuda_time_total", None))
+            entry["self_cuda_time_total_us"] = _get_profiler_device_time_us(event)
         operator_entries.append(entry)
 
     operator_entries = sorted(operator_entries, key=lambda item: item["key"])
diff --git a/tests/scripts/test_model_profiling.py b/tests/scripts/test_model_profiling.py
index 9b87e0811..b2fde3f88 100644
--- a/tests/scripts/test_model_profiling.py
+++ b/tests/scripts/test_model_profiling.py
@@ -242,3 +242,34 @@ def test_step_timing_collector_accepts_metric_like_values(tmp_path):
     payload = json.loads((tmp_path / "step_timing_summary.json").read_text())
     assert payload["total_update_s"]["mean"] == 0.6
     assert payload["dataloading_s"]["mean"] == 0.05
+
+
+def test_profiler_device_time_uses_generic_attr_first():
+    from lerobot.utils.profiling_utils import _get_profiler_device_time_us
+
+    class _Event:
+        self_device_time_total = 12.3456
+
+    assert _get_profiler_device_time_us(_Event()) == 12.3456
+
+
+def test_profiler_device_time_skips_legacy_attr_when_generic_present():
+    from lerobot.utils.profiling_utils import _get_profiler_device_time_us
+
+    class _Event:
+        self_device_time_total = 1.5
+
+        @property
+        def self_cuda_time_total(self):
+            raise AssertionError("legacy attribute must not be read")
+
+    assert _get_profiler_device_time_us(_Event()) == 1.5
+
+
+def test_profiler_device_time_falls_back_to_cuda_attr():
+    from lerobot.utils.profiling_utils import _get_profiler_device_time_us
+
+    class _Event:
+        self_cuda_time_total = 7.5
+
+    assert _get_profiler_device_time_us(_Event()) == 7.5