From 35e3b28da151a884e219ac27aa733497c18e0d6e Mon Sep 17 00:00:00 2001 From: Pepijn Date: Thu, 16 Apr 2026 10:11:14 +0200 Subject: [PATCH] fix(profiling): normalize timing metrics before export --- src/lerobot/scripts/lerobot_train.py | 2 +- src/lerobot/utils/profiling_utils.py | 19 ++++++++++++++----- tests/scripts/test_model_profiling.py | 22 ++++++++++++++++++++++ 3 files changed, 37 insertions(+), 6 deletions(-) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 9093d7516..765f5d9e2 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -173,7 +173,7 @@ def update_policy( forward_s=forward_s, backward_s=backward_s, optimizer_s=optimizer_s, - total_update_s=train_metrics.update_s, + total_update_s=train_metrics.update_s.val, ) return train_metrics, output_dict diff --git a/src/lerobot/utils/profiling_utils.py b/src/lerobot/utils/profiling_utils.py index c3625363b..9a7289ad2 100644 --- a/src/lerobot/utils/profiling_utils.py +++ b/src/lerobot/utils/profiling_utils.py @@ -22,6 +22,7 @@ import io import json import pstats import statistics +from numbers import Real from collections.abc import Callable from dataclasses import dataclass, field from pathlib import Path @@ -244,6 +245,14 @@ def _summary(values: list[float]) -> dict[str, float] | dict[str, None]: } +def _as_float(value: Any) -> float: + if isinstance(value, Real): + return float(value) + if hasattr(value, "val"): + return float(value.val) + raise TypeError(f"Expected a real-valued metric, got {type(value).__name__}") + + @dataclass class StepTimingCollector: forward_s: list[float] = field(default_factory=list) @@ -261,13 +270,13 @@ class StepTimingCollector: optimizer_s: float, total_update_s: float, ) -> None: - self.forward_s.append(forward_s) - self.backward_s.append(backward_s) - self.optimizer_s.append(optimizer_s) - self.total_update_s.append(total_update_s) + self.forward_s.append(_as_float(forward_s)) + self.backward_s.append(_as_float(backward_s)) + self.optimizer_s.append(_as_float(optimizer_s)) + self.total_update_s.append(_as_float(total_update_s)) def record_dataloading(self, dataloading_s: float) -> None: - self.dataloading_s.append(dataloading_s) + self.dataloading_s.append(_as_float(dataloading_s)) def record_memory(self, *, step: int, allocated_bytes: int, reserved_bytes: int) -> None: self.memory_timeline.append( diff --git a/tests/scripts/test_model_profiling.py b/tests/scripts/test_model_profiling.py index 61dae1232..9b87e0811 100644 --- a/tests/scripts/test_model_profiling.py +++ b/tests/scripts/test_model_profiling.py @@ -220,3 +220,25 @@ def test_deterministic_forward_artifacts_preserve_policy_mode(tmp_path): assert payload["reference_batch_size"] == 2 assert "operator_fingerprint" in payload assert payload["outputs"]["loss"]["numel"] == 1 + + +def test_step_timing_collector_accepts_metric_like_values(tmp_path): + from lerobot.utils.profiling_utils import StepTimingCollector + + class _MetricLike: + def __init__(self, val): + self.val = val + + collector = StepTimingCollector() + collector.record( + forward_s=0.1, + backward_s=0.2, + optimizer_s=0.3, + total_update_s=_MetricLike(0.6), + ) + collector.record_dataloading(_MetricLike(0.05)) + collector.write_json(tmp_path / "step_timing_summary.json") + + payload = json.loads((tmp_path / "step_timing_summary.json").read_text()) + assert payload["total_update_s"]["mean"] == 0.6 + assert payload["dataloading_s"]["mean"] == 0.05