mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 20:50:02 +00:00
fix(profiling): read generic device timings from profiler
This commit is contained in:
@@ -172,6 +172,12 @@ def _hash_payload(payload: Any) -> str:
|
||||
return hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest()
|
||||
|
||||
|
||||
def _get_profiler_device_time_us(event: Any) -> float | None:
    """Return a profiler event's self device time, passed through ``_stable_float``.

    The generic ``self_device_time_total`` attribute is preferred. The legacy
    ``self_cuda_time_total`` attribute is consulted only when the generic one
    is missing, and ``None`` is used when neither attribute exists.

    Args:
        event: Profiler event-like object exposing timing attributes.

    Returns:
        Whatever ``_stable_float`` produces for the selected timing value
        (the function name suggests microseconds).
    """
    try:
        device_time = event.self_device_time_total
    except AttributeError:
        # Look up the legacy name lazily: a nested ``getattr`` default would be
        # evaluated on every call, reading the legacy attribute even when the
        # generic one is available.
        # NOTE(review): recent PyTorch releases are reported to warn when the
        # legacy attribute is accessed - confirm against the installed version.
        device_time = getattr(event, "self_cuda_time_total", None)
    return _stable_float(device_time)
|
||||
|
||||
|
||||
def _build_reference_batch(dataset: Any, batch_size: int) -> Any:
|
||||
if len(dataset) == 0:
|
||||
raise ValueError("Cannot build a reference batch from an empty dataset.")
|
||||
@@ -212,7 +218,7 @@ def write_deterministic_forward_artifacts(
|
||||
"cpu_time_total_us": _stable_float(getattr(event, "cpu_time_total", None)),
|
||||
}
|
||||
if device_type == "cuda":
|
||||
entry["self_cuda_time_total_us"] = _stable_float(getattr(event, "self_cuda_time_total", None))
|
||||
entry["self_cuda_time_total_us"] = _get_profiler_device_time_us(event)
|
||||
operator_entries.append(entry)
|
||||
operator_entries = sorted(operator_entries, key=lambda item: item["key"])
|
||||
|
||||
|
||||
@@ -242,3 +242,12 @@ def test_step_timing_collector_accepts_metric_like_values(tmp_path):
|
||||
payload = json.loads((tmp_path / "step_timing_summary.json").read_text())
|
||||
assert payload["total_update_s"]["mean"] == 0.6
|
||||
assert payload["dataloading_s"]["mean"] == 0.05
|
||||
|
||||
|
||||
def test_profiler_device_time_uses_generic_attr_first():
    """The generic device-time attribute must take precedence over the legacy one."""
    from lerobot.utils.profiling_utils import _get_profiler_device_time_us

    class _Event:
        # Both attributes are defined with different values so the assertion
        # can only pass when the generic attribute is the one that is read.
        self_device_time_total = 12.3456
        self_cuda_time_total = 99.0

    assert _get_profiler_device_time_us(_Event()) == 12.3456
|
||||
|
||||
Reference in New Issue
Block a user