mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
fix(profiling): read generic device timings from profiler
This commit is contained in:
@@ -172,6 +172,12 @@ def _hash_payload(payload: Any) -> str:
|
|||||||
return hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest()
|
return hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_profiler_device_time_us(event: Any) -> float | None:
|
||||||
|
return _stable_float(
|
||||||
|
getattr(event, "self_device_time_total", getattr(event, "self_cuda_time_total", None))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _build_reference_batch(dataset: Any, batch_size: int) -> Any:
|
def _build_reference_batch(dataset: Any, batch_size: int) -> Any:
|
||||||
if len(dataset) == 0:
|
if len(dataset) == 0:
|
||||||
raise ValueError("Cannot build a reference batch from an empty dataset.")
|
raise ValueError("Cannot build a reference batch from an empty dataset.")
|
||||||
@@ -212,7 +218,7 @@ def write_deterministic_forward_artifacts(
|
|||||||
"cpu_time_total_us": _stable_float(getattr(event, "cpu_time_total", None)),
|
"cpu_time_total_us": _stable_float(getattr(event, "cpu_time_total", None)),
|
||||||
}
|
}
|
||||||
if device_type == "cuda":
|
if device_type == "cuda":
|
||||||
entry["self_cuda_time_total_us"] = _stable_float(getattr(event, "self_cuda_time_total", None))
|
entry["self_cuda_time_total_us"] = _get_profiler_device_time_us(event)
|
||||||
operator_entries.append(entry)
|
operator_entries.append(entry)
|
||||||
operator_entries = sorted(operator_entries, key=lambda item: item["key"])
|
operator_entries = sorted(operator_entries, key=lambda item: item["key"])
|
||||||
|
|
||||||
|
|||||||
@@ -242,3 +242,12 @@ def test_step_timing_collector_accepts_metric_like_values(tmp_path):
|
|||||||
payload = json.loads((tmp_path / "step_timing_summary.json").read_text())
|
payload = json.loads((tmp_path / "step_timing_summary.json").read_text())
|
||||||
assert payload["total_update_s"]["mean"] == 0.6
|
assert payload["total_update_s"]["mean"] == 0.6
|
||||||
assert payload["dataloading_s"]["mean"] == 0.05
|
assert payload["dataloading_s"]["mean"] == 0.05
|
||||||
|
|
||||||
|
|
||||||
|
def test_profiler_device_time_uses_generic_attr_first():
|
||||||
|
from lerobot.utils.profiling_utils import _get_profiler_device_time_us
|
||||||
|
|
||||||
|
class _Event:
|
||||||
|
self_device_time_total = 12.3456
|
||||||
|
|
||||||
|
assert _get_profiler_device_time_us(_Event()) == 12.3456
|
||||||
|
|||||||
Reference in New Issue
Block a user