mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-28 15:09:51 +00:00
fix(profiling): normalize timing metrics before export
This commit is contained in:
@@ -173,7 +173,7 @@ def update_policy(
|
|||||||
forward_s=forward_s,
|
forward_s=forward_s,
|
||||||
backward_s=backward_s,
|
backward_s=backward_s,
|
||||||
optimizer_s=optimizer_s,
|
optimizer_s=optimizer_s,
|
||||||
total_update_s=train_metrics.update_s,
|
total_update_s=train_metrics.update_s.val,
|
||||||
)
|
)
|
||||||
return train_metrics, output_dict
|
return train_metrics, output_dict
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import io
|
|||||||
import json
|
import json
|
||||||
import pstats
|
import pstats
|
||||||
import statistics
|
import statistics
|
||||||
|
from numbers import Real
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -244,6 +245,14 @@ def _summary(values: list[float]) -> dict[str, float] | dict[str, None]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _as_float(value: Any) -> float:
|
||||||
|
if isinstance(value, Real):
|
||||||
|
return float(value)
|
||||||
|
if hasattr(value, "val"):
|
||||||
|
return float(value.val)
|
||||||
|
raise TypeError(f"Expected a real-valued metric, got {type(value).__name__}")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class StepTimingCollector:
|
class StepTimingCollector:
|
||||||
forward_s: list[float] = field(default_factory=list)
|
forward_s: list[float] = field(default_factory=list)
|
||||||
@@ -261,13 +270,13 @@ class StepTimingCollector:
|
|||||||
optimizer_s: float,
|
optimizer_s: float,
|
||||||
total_update_s: float,
|
total_update_s: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.forward_s.append(forward_s)
|
self.forward_s.append(_as_float(forward_s))
|
||||||
self.backward_s.append(backward_s)
|
self.backward_s.append(_as_float(backward_s))
|
||||||
self.optimizer_s.append(optimizer_s)
|
self.optimizer_s.append(_as_float(optimizer_s))
|
||||||
self.total_update_s.append(total_update_s)
|
self.total_update_s.append(_as_float(total_update_s))
|
||||||
|
|
||||||
def record_dataloading(self, dataloading_s: float) -> None:
|
def record_dataloading(self, dataloading_s: float) -> None:
|
||||||
self.dataloading_s.append(dataloading_s)
|
self.dataloading_s.append(_as_float(dataloading_s))
|
||||||
|
|
||||||
def record_memory(self, *, step: int, allocated_bytes: int, reserved_bytes: int) -> None:
|
def record_memory(self, *, step: int, allocated_bytes: int, reserved_bytes: int) -> None:
|
||||||
self.memory_timeline.append(
|
self.memory_timeline.append(
|
||||||
|
|||||||
@@ -220,3 +220,25 @@ def test_deterministic_forward_artifacts_preserve_policy_mode(tmp_path):
|
|||||||
assert payload["reference_batch_size"] == 2
|
assert payload["reference_batch_size"] == 2
|
||||||
assert "operator_fingerprint" in payload
|
assert "operator_fingerprint" in payload
|
||||||
assert payload["outputs"]["loss"]["numel"] == 1
|
assert payload["outputs"]["loss"]["numel"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_step_timing_collector_accepts_metric_like_values(tmp_path):
    """Timings passed as objects exposing ``.val`` end up in the JSON summary as numbers."""
    from lerobot.utils.profiling_utils import StepTimingCollector

    class _ValHolder:
        # Minimal stand-in for a metric object: only carries a ``val`` attribute.
        def __init__(self, val):
            self.val = val

    timings = StepTimingCollector()
    step_kwargs = {
        "forward_s": 0.1,
        "backward_s": 0.2,
        "optimizer_s": 0.3,
        "total_update_s": _ValHolder(0.6),
    }
    timings.record(**step_kwargs)
    timings.record_dataloading(_ValHolder(0.05))

    summary_path = tmp_path / "step_timing_summary.json"
    timings.write_json(summary_path)

    payload = json.loads(summary_path.read_text())
    expected_means = (("total_update_s", 0.6), ("dataloading_s", 0.05))
    for key, expected in expected_means:
        assert payload[key]["mean"] == expected
|
||||||
|
|||||||
Reference in New Issue
Block a user