fix(profiling): normalize timing metrics before export

This commit is contained in:
Pepijn
2026-04-16 10:11:14 +02:00
parent ed8a98dda6
commit 35e3b28da1
3 changed files with 37 additions and 6 deletions
+1 -1
View File
@@ -173,7 +173,7 @@ def update_policy(
forward_s=forward_s,
backward_s=backward_s,
optimizer_s=optimizer_s,
total_update_s=train_metrics.update_s,
total_update_s=train_metrics.update_s.val,
)
return train_metrics, output_dict
+14 -5
View File
@@ -22,6 +22,7 @@ import io
import json
import pstats
import statistics
from numbers import Real
from collections.abc import Callable
from dataclasses import dataclass, field
from pathlib import Path
@@ -244,6 +245,14 @@ def _summary(values: list[float]) -> dict[str, float] | dict[str, None]:
}
def _as_float(value: Any) -> float:
if isinstance(value, Real):
return float(value)
if hasattr(value, "val"):
return float(value.val)
raise TypeError(f"Expected a real-valued metric, got {type(value).__name__}")
@dataclass
class StepTimingCollector:
forward_s: list[float] = field(default_factory=list)
@@ -261,13 +270,13 @@ class StepTimingCollector:
optimizer_s: float,
total_update_s: float,
) -> None:
self.forward_s.append(forward_s)
self.backward_s.append(backward_s)
self.optimizer_s.append(optimizer_s)
self.total_update_s.append(total_update_s)
self.forward_s.append(_as_float(forward_s))
self.backward_s.append(_as_float(backward_s))
self.optimizer_s.append(_as_float(optimizer_s))
self.total_update_s.append(_as_float(total_update_s))
def record_dataloading(self, dataloading_s: float) -> None:
self.dataloading_s.append(dataloading_s)
self.dataloading_s.append(_as_float(dataloading_s))
def record_memory(self, *, step: int, allocated_bytes: int, reserved_bytes: int) -> None:
self.memory_timeline.append(
+22
View File
@@ -220,3 +220,25 @@ def test_deterministic_forward_artifacts_preserve_policy_mode(tmp_path):
assert payload["reference_batch_size"] == 2
assert "operator_fingerprint" in payload
assert payload["outputs"]["loss"]["numel"] == 1
def test_step_timing_collector_accepts_metric_like_values(tmp_path):
from lerobot.utils.profiling_utils import StepTimingCollector
class _MetricLike:
def __init__(self, val):
self.val = val
collector = StepTimingCollector()
collector.record(
forward_s=0.1,
backward_s=0.2,
optimizer_s=0.3,
total_update_s=_MetricLike(0.6),
)
collector.record_dataloading(_MetricLike(0.05))
collector.write_json(tmp_path / "step_timing_summary.json")
payload = json.loads((tmp_path / "step_timing_summary.json").read_text())
assert payload["total_update_s"]["mean"] == 0.6
assert payload["dataloading_s"]["mean"] == 0.05