refactor: extract profiling into self-contained TrainingProfiler class

Move all profiling orchestration out of lerobot_train.py and
TrainPipelineConfig into a TrainingProfiler class in profiling_utils.py.

- lerobot_train.py: ~74 lines of profiling code reduced to ~7 call sites
- TrainPipelineConfig: 10 profile_* fields reduced to 2 (mode + output_dir)
- update_policy: reverted to clean main-branch signature (no timing_collector)
- TrainingProfiler encapsulates the torch profiler, timing collection,
  deterministic forward artifacts, and all output writing
- CI script (run_model_profiling.py) is unchanged: it only passes the 2 kept fields (mode + output_dir)

Made-with: Cursor
This commit is contained in:
Pepijn
2026-04-16 16:00:49 +02:00
parent a4544ffea7
commit b1e16783de
4 changed files with 148 additions and 138 deletions
+14 -25
View File
@@ -65,29 +65,24 @@ def test_pretrained_libero_specs_match_expected_camera_keys_and_normalization():
specs = module.load_specs(spec_path)
assert (
"--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", "
"\"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
in specs["pi0"]["train_args"]
'--rename_map={"observation.images.front": "observation.images.base_0_rgb", '
'"observation.images.wrist": "observation.images.left_wrist_0_rgb"}' in specs["pi0"].train_args
)
assert (
"--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", "
"\"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
in specs["pi0_fast"]["train_args"]
'--rename_map={"observation.images.front": "observation.images.base_0_rgb", '
'"observation.images.wrist": "observation.images.left_wrist_0_rgb"}' in specs["pi0_fast"].train_args
)
assert (
"--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", "
"\"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
in specs["pi05"]["train_args"]
'--rename_map={"observation.images.front": "observation.images.base_0_rgb", '
'"observation.images.wrist": "observation.images.left_wrist_0_rgb"}' in specs["pi05"].train_args
)
assert (
"--policy.normalization_mapping={\"ACTION\": \"MEAN_STD\", "
"\"STATE\": \"MEAN_STD\", \"VISUAL\": \"IDENTITY\"}"
in specs["pi05"]["train_args"]
'--policy.normalization_mapping={"ACTION": "MEAN_STD", '
'"STATE": "MEAN_STD", "VISUAL": "IDENTITY"}' in specs["pi05"].train_args
)
assert (
"--rename_map={\"observation.images.front\": \"observation.images.camera1\", "
"\"observation.images.wrist\": \"observation.images.camera2\"}"
in specs["smolvla"]["train_args"]
'--rename_map={"observation.images.front": "observation.images.camera1", '
'"observation.images.wrist": "observation.images.camera2"}' in specs["smolvla"].train_args
)
@@ -222,7 +217,6 @@ def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
(profile_dir / "step_timing_summary.json").write_text(
json.dumps(
{
"forward_s": {"count": 1, "mean": 0.1, "median": 0.1, "min": 0.1, "max": 0.1},
"total_update_s": {"count": 1, "mean": 0.3, "median": 0.3, "min": 0.3, "max": 0.3},
"peak_memory_allocated_bytes": 1024,
}
@@ -251,7 +245,7 @@ def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
assert row["git_commit"] == "deadbeef"
assert row["git_ref"] == "codex/model-profiling"
assert row["pr_number"] == 3389
assert row["step_timing_summary"]["forward_s"]["mean"] == 0.1
assert row["step_timing_summary"]["total_update_s"]["mean"] == 0.3
assert row["deterministic_forward"]["operator_fingerprint"] == "ops-fingerprint"
@@ -364,19 +358,14 @@ def test_deterministic_forward_artifacts_preserve_policy_mode(tmp_path):
def test_step_timing_collector_accepts_metric_like_values(tmp_path):
from lerobot.utils.profiling_utils import StepTimingCollector
from lerobot.utils.profiling_utils import _StepTimingCollector
class _MetricLike:
def __init__(self, val):
self.val = val
collector = StepTimingCollector()
collector.record(
forward_s=0.1,
backward_s=0.2,
optimizer_s=0.3,
total_update_s=_MetricLike(0.6),
)
collector = _StepTimingCollector()
collector.record_step(_MetricLike(0.6))
collector.record_dataloading(_MetricLike(0.05))
collector.write_json(tmp_path / "step_timing_summary.json")