fix(profiling): read generic device timings from profiler

This commit is contained in:
Pepijn
2026-04-16 10:29:01 +02:00
parent 35e3b28da1
commit 25e5062b2c
2 changed files with 16 additions and 1 deletions
+7 -1
View File
@@ -172,6 +172,12 @@ def _hash_payload(payload: Any) -> str:
return hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest() return hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest()
def _get_profiler_device_time_us(event: Any) -> float | None:
    """Return a profiler event's self device time, passed through ``_stable_float``.

    The device-agnostic ``self_device_time_total`` attribute is preferred. The
    legacy ``self_cuda_time_total`` attribute is consulted only when the generic
    one is absent; if neither exists, ``None`` is handed to ``_stable_float``.

    Args:
        event: Profiler event-like object exposing one of the attributes above.

    Returns:
        Whatever ``_stable_float`` produces for the selected attribute value.
    """
    # Resolve the fallback lazily. A nested ``getattr`` default is evaluated
    # eagerly, so the legacy attribute would be read on every call even when
    # the generic one is available.
    # NOTE(review): recent PyTorch releases appear to expose
    # ``self_cuda_time_total`` as a deprecated property that warns on access;
    # confirm against the torch version this project pins.
    try:
        device_time = event.self_device_time_total
    except AttributeError:
        device_time = getattr(event, "self_cuda_time_total", None)
    return _stable_float(device_time)
def _build_reference_batch(dataset: Any, batch_size: int) -> Any: def _build_reference_batch(dataset: Any, batch_size: int) -> Any:
if len(dataset) == 0: if len(dataset) == 0:
raise ValueError("Cannot build a reference batch from an empty dataset.") raise ValueError("Cannot build a reference batch from an empty dataset.")
@@ -212,7 +218,7 @@ def write_deterministic_forward_artifacts(
"cpu_time_total_us": _stable_float(getattr(event, "cpu_time_total", None)), "cpu_time_total_us": _stable_float(getattr(event, "cpu_time_total", None)),
} }
if device_type == "cuda": if device_type == "cuda":
entry["self_cuda_time_total_us"] = _stable_float(getattr(event, "self_cuda_time_total", None)) entry["self_cuda_time_total_us"] = _get_profiler_device_time_us(event)
operator_entries.append(entry) operator_entries.append(entry)
operator_entries = sorted(operator_entries, key=lambda item: item["key"]) operator_entries = sorted(operator_entries, key=lambda item: item["key"])
+9
View File
@@ -242,3 +242,12 @@ def test_step_timing_collector_accepts_metric_like_values(tmp_path):
payload = json.loads((tmp_path / "step_timing_summary.json").read_text()) payload = json.loads((tmp_path / "step_timing_summary.json").read_text())
assert payload["total_update_s"]["mean"] == 0.6 assert payload["total_update_s"]["mean"] == 0.6
assert payload["dataloading_s"]["mean"] == 0.05 assert payload["dataloading_s"]["mean"] == 0.05
def test_profiler_device_time_uses_generic_attr_first():
    """The generic device-time attribute wins when the legacy one is also present."""
    from lerobot.utils.profiling_utils import _get_profiler_device_time_us

    class _Event:
        # Both attributes carry different values so the assertion fails if the
        # helper ever prefers the legacy CUDA attribute over the generic one.
        self_device_time_total = 12.3456
        self_cuda_time_total = 99.0

    assert _get_profiler_device_time_us(_Event()) == 12.3456