mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
fix(profiling): normalize timing metrics before export
This commit is contained in:
@@ -173,7 +173,7 @@ def update_policy(
|
||||
forward_s=forward_s,
|
||||
backward_s=backward_s,
|
||||
optimizer_s=optimizer_s,
|
||||
total_update_s=train_metrics.update_s,
|
||||
total_update_s=train_metrics.update_s.val,
|
||||
)
|
||||
return train_metrics, output_dict
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ import io
|
||||
import json
|
||||
import pstats
|
||||
import statistics
|
||||
from numbers import Real
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
@@ -244,6 +245,14 @@ def _summary(values: list[float]) -> dict[str, float] | dict[str, None]:
|
||||
}
|
||||
|
||||
|
||||
def _as_float(value: Any) -> float:
    """Coerce a timing metric to a built-in ``float``.

    Two input shapes are accepted: a real number, or a metric-like wrapper
    that exposes its reading through a ``val`` attribute.

    Args:
        value: A ``numbers.Real`` instance, or any object with a ``val``
            attribute whose content ``float()`` can convert.

    Returns:
        The metric as a plain ``float``.

    Raises:
        TypeError: If ``value`` is neither real-valued nor a wrapper whose
            ``val`` can be converted to ``float``.
    """
    if isinstance(value, Real):
        return float(value)
    if hasattr(value, "val"):
        inner = value.val
        try:
            return float(inner)
        except (TypeError, ValueError) as err:
            # Surface one consistent error for bad metrics instead of leaking
            # whatever float() happened to raise for the wrapped object.
            raise TypeError(
                f"Expected a real-valued metric, got {type(value).__name__} "
                f"wrapping {type(inner).__name__}"
            ) from err
    raise TypeError(f"Expected a real-valued metric, got {type(value).__name__}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class StepTimingCollector:
|
||||
forward_s: list[float] = field(default_factory=list)
|
||||
@@ -261,13 +270,13 @@ class StepTimingCollector:
|
||||
optimizer_s: float,
|
||||
total_update_s: float,
|
||||
) -> None:
|
||||
self.forward_s.append(forward_s)
|
||||
self.backward_s.append(backward_s)
|
||||
self.optimizer_s.append(optimizer_s)
|
||||
self.total_update_s.append(total_update_s)
|
||||
self.forward_s.append(_as_float(forward_s))
|
||||
self.backward_s.append(_as_float(backward_s))
|
||||
self.optimizer_s.append(_as_float(optimizer_s))
|
||||
self.total_update_s.append(_as_float(total_update_s))
|
||||
|
||||
def record_dataloading(self, dataloading_s: float) -> None:
    """Store one data-loading timing sample.

    Args:
        dataloading_s: The sample to record (presumably a duration in
            seconds, going by the ``_s`` suffix). It is passed through
            ``_as_float`` before being stored.
    """
    # Append exactly once, after normalization, so the list only ever holds
    # the converted value and each sample is counted a single time.
    self.dataloading_s.append(_as_float(dataloading_s))
|
||||
|
||||
def record_memory(self, *, step: int, allocated_bytes: int, reserved_bytes: int) -> None:
|
||||
self.memory_timeline.append(
|
||||
|
||||
@@ -220,3 +220,25 @@ def test_deterministic_forward_artifacts_preserve_policy_mode(tmp_path):
|
||||
assert payload["reference_batch_size"] == 2
|
||||
assert "operator_fingerprint" in payload
|
||||
assert payload["outputs"]["loss"]["numel"] == 1
|
||||
|
||||
|
||||
def test_step_timing_collector_accepts_metric_like_values(tmp_path):
    """Objects exposing ``val`` are accepted by the collector and show up as numbers in the JSON summary."""
    from lerobot.utils.profiling_utils import StepTimingCollector

    class _ValueHolder:
        # Minimal stand-in for a metric object: it only carries ``val``.
        def __init__(self, val):
            self.val = val

    summary_path = tmp_path / "step_timing_summary.json"

    timings = StepTimingCollector()
    timings.record(
        forward_s=0.1,
        backward_s=0.2,
        optimizer_s=0.3,
        total_update_s=_ValueHolder(0.6),
    )
    timings.record_dataloading(_ValueHolder(0.05))
    timings.write_json(summary_path)

    summary = json.loads(summary_path.read_text())
    assert summary["total_update_s"]["mean"] == 0.6
    assert summary["dataloading_s"]["mean"] == 0.05
|
||||
|
||||
Reference in New Issue
Block a user