fix(profiling): address review feedback

This commit is contained in:
Pepijn
2026-04-23 13:23:09 +02:00
parent bfff81fd4b
commit a23ebf9d35
4 changed files with 37 additions and 21 deletions
-1
View File
@@ -20,7 +20,6 @@ on:
pull_request:
branches:
- main
- feat/libero-benchmark
paths:
- .github/workflows/model_profiling.yml
- src/lerobot/configs/train.py
+2 -2
View File
@@ -16,7 +16,7 @@ import datetime as dt
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from typing import Any, Literal
import draccus
from huggingface_hub import hf_hub_download
@@ -58,7 +58,7 @@ class TrainPipelineConfig(HubMixin):
batch_size: int = 8
prefetch_factor: int = 4
persistent_workers: bool = True
profile_mode: str = "off"
profile_mode: Literal["off", "summary", "trace"] = "off"
profile_output_dir: Path | None = None
steps: int = 100_000
eval_freq: int = 20_000
+14 -9
View File
@@ -33,7 +33,6 @@ Usage (CI):
from __future__ import annotations
import argparse
import contextlib
import hashlib
import json
import logging
@@ -53,7 +52,7 @@ from typing import Any
import torch
from huggingface_hub import CommitOperationAdd, HfApi
from huggingface_hub.errors import HfHubHTTPError
from torch.utils.data._utils.collate import default_collate
from torch.utils.data import default_collate
logger = logging.getLogger(__name__)
@@ -305,11 +304,10 @@ def _get_profiler_device_time_us(event: Any) -> float | None:
def _write_profiler_table(profiler: Any, path: Path, *, sort_by: str, row_limit: int = 40) -> None:
# The profiler may not have recorded any events for this sort key when the
# schedule window lands outside the active steps — log at debug level and
# continue rather than crashing the whole artifact-writer pass.
with contextlib.suppress(Exception):
try:
path.write_text(profiler.key_averages().table(sort_by=sort_by, row_limit=row_limit))
except Exception:
logger.debug("Could not write profiler table for sort_by=%s", sort_by, exc_info=True)
def write_deterministic_forward_artifacts(
@@ -324,7 +322,9 @@ def write_deterministic_forward_artifacts(
"""Run a seed-controlled single forward pass and dump a stable fingerprint
(loss/output tensor hashes + op counts) for regression detection. Keeps
the caller-selected module mode so ACT-with-VAE-style policies that only
materialize their full forward outputs in `train()` still match."""
materialize their full forward outputs in `train()` still match. Models
with stochastic train-mode layers still rely on the seeded RNG for stable
fingerprints."""
if len(dataset) == 0:
raise ValueError("Cannot build a reference batch from an empty dataset.")
indices = [i % len(dataset) for i in range(batch_size)]
@@ -701,9 +701,12 @@ def main() -> int:
args.output_dir.mkdir(parents=True, exist_ok=True)
repo_id = args.results_repo if "/" in args.results_repo else f"{args.hub_org}/{args.results_repo}"
git_exe = shutil.which("git") or (_ for _ in ()).throw(RuntimeError("git not found in PATH"))
git_exe = shutil.which("git")
if not git_exe:
raise RuntimeError("git not found in PATH")
git_commit = args.git_commit or subprocess.check_output([git_exe, "rev-parse", "HEAD"], text=True).strip()
pr_number = int(args.pr_number) if str(args.pr_number).strip() else None
exit_code = 0
for policy in selected:
run_id = f"{_utc_timestamp_slug()}__{policy}"
@@ -717,6 +720,8 @@ def main() -> int:
(run_dir / "stdout.txt").write_text(result.stdout)
(run_dir / "stderr.txt").write_text(result.stderr)
if result.returncode != 0:
exit_code = 1
paths, urls, upload_list, row_in_repo = build_artifact_index(
repo_id=repo_id, run_dir=run_dir, policy_name=policy, run_id=run_id
@@ -771,7 +776,7 @@ def main() -> int:
print(json.dumps(row, indent=2, sort_keys=True))
return 0
return exit_code
if __name__ == "__main__":
+21 -9
View File
@@ -156,7 +156,7 @@ def test_parse_discussion_num_handles_hf_discussion_urls():
@pytest.fixture
def _fake_args(tmp_path):
def fake_args(tmp_path):
"""Shared argparse namespace for main() smoke tests — overridden per-test."""
return argparse.Namespace(
policies=["act"],
@@ -195,14 +195,14 @@ def _stub_train_subprocess(mp_module, *, returncode: int = 0, write_artifacts: b
return _fake_run
def test_main_smoke_writes_row(monkeypatch, _fake_args):
monkeypatch.setattr(mp, "parse_args", lambda: _fake_args)
def test_main_smoke_writes_row(monkeypatch, fake_args):
monkeypatch.setattr(mp, "parse_args", lambda: fake_args)
monkeypatch.setattr(mp.subprocess, "check_output", lambda *a, **k: "deadbeef\n")
monkeypatch.setattr(mp.subprocess, "run", _stub_train_subprocess(mp))
assert mp.main() == 0
row_paths = list(_fake_args.output_dir.rglob("profiling_row.json"))
row_paths = list(fake_args.output_dir.rglob("profiling_row.json"))
assert len(row_paths) == 1
row = json.loads(row_paths[0].read_text())
assert row["policy"] == "act"
@@ -214,10 +214,10 @@ def test_main_smoke_writes_row(monkeypatch, _fake_args):
assert row["deterministic_forward"]["operator_fingerprint"] == "ops-fingerprint"
def test_main_records_publish_failure_without_failing(monkeypatch, _fake_args):
_fake_args.publish = True
_fake_args.git_commit = "deadbeef"
monkeypatch.setattr(mp, "parse_args", lambda: _fake_args)
def test_main_records_publish_failure_without_failing(monkeypatch, fake_args):
fake_args.publish = True
fake_args.git_commit = "deadbeef"
monkeypatch.setattr(mp, "parse_args", lambda: fake_args)
monkeypatch.setattr(mp.subprocess, "run", _stub_train_subprocess(mp, write_artifacts=False))
def _fail_upload(**kwargs):
@@ -227,12 +227,24 @@ def test_main_records_publish_failure_without_failing(monkeypatch, _fake_args):
monkeypatch.setattr(mp, "upload_profile_run", _fail_upload)
assert mp.main() == 0
row = json.loads(next(_fake_args.output_dir.rglob("profiling_row.json")).read_text())
row = json.loads(next(fake_args.output_dir.rglob("profiling_row.json")).read_text())
assert row["status"] == "success"
assert row["publish_status"] == "failed"
assert "Authorization error" in row["publish_error"]
def test_main_returns_nonzero_when_training_subprocess_fails(monkeypatch, fake_args):
monkeypatch.setattr(mp, "parse_args", lambda: fake_args)
monkeypatch.setattr(mp.subprocess, "check_output", lambda *a, **k: "deadbeef\n")
monkeypatch.setattr(mp.subprocess, "run", _stub_train_subprocess(mp, returncode=3))
assert mp.main() == 1
row = json.loads(next(fake_args.output_dir.rglob("profiling_row.json")).read_text())
assert row["status"] == "failed"
assert row["return_code"] == 3
# ---------------------------------------------------------------------------
# TrainingProfiler behavior
# ---------------------------------------------------------------------------