mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 18:20:08 +00:00
fix(profiling): normalize timing metrics before export
This commit is contained in:
@@ -220,3 +220,25 @@ def test_deterministic_forward_artifacts_preserve_policy_mode(tmp_path):
|
||||
assert payload["reference_batch_size"] == 2
|
||||
assert "operator_fingerprint" in payload
|
||||
assert payload["outputs"]["loss"]["numel"] == 1
|
||||
|
||||
|
||||
def test_step_timing_collector_accepts_metric_like_values(tmp_path):
|
||||
from lerobot.utils.profiling_utils import StepTimingCollector
|
||||
|
||||
class _MetricLike:
|
||||
def __init__(self, val):
|
||||
self.val = val
|
||||
|
||||
collector = StepTimingCollector()
|
||||
collector.record(
|
||||
forward_s=0.1,
|
||||
backward_s=0.2,
|
||||
optimizer_s=0.3,
|
||||
total_update_s=_MetricLike(0.6),
|
||||
)
|
||||
collector.record_dataloading(_MetricLike(0.05))
|
||||
collector.write_json(tmp_path / "step_timing_summary.json")
|
||||
|
||||
payload = json.loads((tmp_path / "step_timing_summary.json").read_text())
|
||||
assert payload["total_update_s"]["mean"] == 0.6
|
||||
assert payload["dataloading_s"]["mean"] == 0.05
|
||||
|
||||
Reference in New Issue
Block a user