fix(profiling): normalize timing metrics before export

This commit is contained in:
Pepijn
2026-04-16 10:11:14 +02:00
parent ed8a98dda6
commit 35e3b28da1
3 changed files with 37 additions and 6 deletions
+1 -1
View File
@@ -173,7 +173,7 @@ def update_policy(
forward_s=forward_s, forward_s=forward_s,
backward_s=backward_s, backward_s=backward_s,
optimizer_s=optimizer_s, optimizer_s=optimizer_s,
total_update_s=train_metrics.update_s, total_update_s=train_metrics.update_s.val,
) )
return train_metrics, output_dict return train_metrics, output_dict
+14 -5
View File
@@ -22,6 +22,7 @@ import io
import json import json
import pstats import pstats
import statistics import statistics
from numbers import Real
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
@@ -244,6 +245,14 @@ def _summary(values: list[float]) -> dict[str, float] | dict[str, None]:
} }
def _as_float(value: Any) -> float:
    """Coerce a timing metric into a plain ``float``.

    Real numbers are converted directly. Any other object is accepted only
    when it exposes a ``val`` attribute, in which case that attribute is
    converted instead.

    Args:
        value: A real number, or an object carrying its reading in ``val``.

    Returns:
        The metric as a built-in ``float``.

    Raises:
        TypeError: If ``value`` is neither a real number nor has a ``val``
            attribute.
    """
    if not isinstance(value, Real):
        if not hasattr(value, "val"):
            raise TypeError(f"Expected a real-valued metric, got {type(value).__name__}")
        # Unwrap the metric-like object so a single conversion below serves both cases.
        value = value.val
    return float(value)
@dataclass @dataclass
class StepTimingCollector: class StepTimingCollector:
forward_s: list[float] = field(default_factory=list) forward_s: list[float] = field(default_factory=list)
@@ -261,13 +270,13 @@ class StepTimingCollector:
optimizer_s: float, optimizer_s: float,
total_update_s: float, total_update_s: float,
) -> None: ) -> None:
self.forward_s.append(forward_s) self.forward_s.append(_as_float(forward_s))
self.backward_s.append(backward_s) self.backward_s.append(_as_float(backward_s))
self.optimizer_s.append(optimizer_s) self.optimizer_s.append(_as_float(optimizer_s))
self.total_update_s.append(total_update_s) self.total_update_s.append(_as_float(total_update_s))
def record_dataloading(self, dataloading_s: float) -> None: def record_dataloading(self, dataloading_s: float) -> None:
self.dataloading_s.append(dataloading_s) self.dataloading_s.append(_as_float(dataloading_s))
def record_memory(self, *, step: int, allocated_bytes: int, reserved_bytes: int) -> None: def record_memory(self, *, step: int, allocated_bytes: int, reserved_bytes: int) -> None:
self.memory_timeline.append( self.memory_timeline.append(
+22
View File
@@ -220,3 +220,25 @@ def test_deterministic_forward_artifacts_preserve_policy_mode(tmp_path):
assert payload["reference_batch_size"] == 2 assert payload["reference_batch_size"] == 2
assert "operator_fingerprint" in payload assert "operator_fingerprint" in payload
assert payload["outputs"]["loss"]["numel"] == 1 assert payload["outputs"]["loss"]["numel"] == 1
def test_step_timing_collector_accepts_metric_like_values(tmp_path):
    """Objects exposing ``.val`` are accepted by the collector in place of floats."""
    from lerobot.utils.profiling_utils import StepTimingCollector

    class _MetricLike:
        def __init__(self, val):
            self.val = val

    summary_path = tmp_path / "step_timing_summary.json"

    timings = StepTimingCollector()
    timings.record(
        forward_s=0.1,
        backward_s=0.2,
        optimizer_s=0.3,
        total_update_s=_MetricLike(0.6),
    )
    timings.record_dataloading(_MetricLike(0.05))
    timings.write_json(summary_path)

    # One sample per series, so each exported mean equals the recorded value.
    summary = json.loads(summary_path.read_text())
    assert summary["total_update_s"]["mean"] == 0.6
    assert summary["dataloading_s"]["mean"] == 0.05