fix(profiling): address review feedback

This commit is contained in:
Pepijn
2026-04-23 13:23:09 +02:00
parent bfff81fd4b
commit a23ebf9d35
4 changed files with 37 additions and 21 deletions
-1
View File
@@ -20,7 +20,6 @@ on:
pull_request: pull_request:
branches: branches:
- main - main
- feat/libero-benchmark
paths: paths:
- .github/workflows/model_profiling.yml - .github/workflows/model_profiling.yml
- src/lerobot/configs/train.py - src/lerobot/configs/train.py
+2 -2
View File
@@ -16,7 +16,7 @@ import datetime as dt
import os import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, Literal
import draccus import draccus
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
@@ -58,7 +58,7 @@ class TrainPipelineConfig(HubMixin):
batch_size: int = 8 batch_size: int = 8
prefetch_factor: int = 4 prefetch_factor: int = 4
persistent_workers: bool = True persistent_workers: bool = True
profile_mode: str = "off" profile_mode: Literal["off", "summary", "trace"] = "off"
profile_output_dir: Path | None = None profile_output_dir: Path | None = None
steps: int = 100_000 steps: int = 100_000
eval_freq: int = 20_000 eval_freq: int = 20_000
+14 -9
View File
@@ -33,7 +33,6 @@ Usage (CI):
from __future__ import annotations from __future__ import annotations
import argparse import argparse
import contextlib
import hashlib import hashlib
import json import json
import logging import logging
@@ -53,7 +52,7 @@ from typing import Any
import torch import torch
from huggingface_hub import CommitOperationAdd, HfApi from huggingface_hub import CommitOperationAdd, HfApi
from huggingface_hub.errors import HfHubHTTPError from huggingface_hub.errors import HfHubHTTPError
from torch.utils.data._utils.collate import default_collate from torch.utils.data import default_collate
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -305,11 +304,10 @@ def _get_profiler_device_time_us(event: Any) -> float | None:
def _write_profiler_table(profiler: Any, path: Path, *, sort_by: str, row_limit: int = 40) -> None: def _write_profiler_table(profiler: Any, path: Path, *, sort_by: str, row_limit: int = 40) -> None:
# The profiler may not have recorded any events for this sort key when the try:
# schedule window lands outside the active steps — skip silently rather
# than crashing the whole artifact-writer pass.
with contextlib.suppress(Exception):
path.write_text(profiler.key_averages().table(sort_by=sort_by, row_limit=row_limit)) path.write_text(profiler.key_averages().table(sort_by=sort_by, row_limit=row_limit))
except Exception:
logger.debug("Could not write profiler table for sort_by=%s", sort_by, exc_info=True)
def write_deterministic_forward_artifacts( def write_deterministic_forward_artifacts(
@@ -324,7 +322,9 @@ def write_deterministic_forward_artifacts(
"""Run a seed-controlled single forward pass and dump a stable fingerprint """Run a seed-controlled single forward pass and dump a stable fingerprint
(loss/output tensor hashes + op counts) for regression detection. Keeps (loss/output tensor hashes + op counts) for regression detection. Keeps
the caller-selected module mode so ACT-with-VAE-style policies that only the caller-selected module mode so ACT-with-VAE-style policies that only
materialize their full forward outputs in `train()` still match.""" materialize their full forward outputs in `train()` still match. Models
with stochastic train-mode layers still rely on the seeded RNG for stable
fingerprints."""
if len(dataset) == 0: if len(dataset) == 0:
raise ValueError("Cannot build a reference batch from an empty dataset.") raise ValueError("Cannot build a reference batch from an empty dataset.")
indices = [i % len(dataset) for i in range(batch_size)] indices = [i % len(dataset) for i in range(batch_size)]
@@ -701,9 +701,12 @@ def main() -> int:
args.output_dir.mkdir(parents=True, exist_ok=True) args.output_dir.mkdir(parents=True, exist_ok=True)
repo_id = args.results_repo if "/" in args.results_repo else f"{args.hub_org}/{args.results_repo}" repo_id = args.results_repo if "/" in args.results_repo else f"{args.hub_org}/{args.results_repo}"
git_exe = shutil.which("git") or (_ for _ in ()).throw(RuntimeError("git not found in PATH")) git_exe = shutil.which("git")
if not git_exe:
raise RuntimeError("git not found in PATH")
git_commit = args.git_commit or subprocess.check_output([git_exe, "rev-parse", "HEAD"], text=True).strip() git_commit = args.git_commit or subprocess.check_output([git_exe, "rev-parse", "HEAD"], text=True).strip()
pr_number = int(args.pr_number) if str(args.pr_number).strip() else None pr_number = int(args.pr_number) if str(args.pr_number).strip() else None
exit_code = 0
for policy in selected: for policy in selected:
run_id = f"{_utc_timestamp_slug()}__{policy}" run_id = f"{_utc_timestamp_slug()}__{policy}"
@@ -717,6 +720,8 @@ def main() -> int:
(run_dir / "stdout.txt").write_text(result.stdout) (run_dir / "stdout.txt").write_text(result.stdout)
(run_dir / "stderr.txt").write_text(result.stderr) (run_dir / "stderr.txt").write_text(result.stderr)
if result.returncode != 0:
exit_code = 1
paths, urls, upload_list, row_in_repo = build_artifact_index( paths, urls, upload_list, row_in_repo = build_artifact_index(
repo_id=repo_id, run_dir=run_dir, policy_name=policy, run_id=run_id repo_id=repo_id, run_dir=run_dir, policy_name=policy, run_id=run_id
@@ -771,7 +776,7 @@ def main() -> int:
print(json.dumps(row, indent=2, sort_keys=True)) print(json.dumps(row, indent=2, sort_keys=True))
return 0 return exit_code
if __name__ == "__main__": if __name__ == "__main__":
+21 -9
View File
@@ -156,7 +156,7 @@ def test_parse_discussion_num_handles_hf_discussion_urls():
@pytest.fixture @pytest.fixture
def _fake_args(tmp_path): def fake_args(tmp_path):
"""Shared argparse namespace for main() smoke tests — overridden per-test.""" """Shared argparse namespace for main() smoke tests — overridden per-test."""
return argparse.Namespace( return argparse.Namespace(
policies=["act"], policies=["act"],
@@ -195,14 +195,14 @@ def _stub_train_subprocess(mp_module, *, returncode: int = 0, write_artifacts: b
return _fake_run return _fake_run
def test_main_smoke_writes_row(monkeypatch, _fake_args): def test_main_smoke_writes_row(monkeypatch, fake_args):
monkeypatch.setattr(mp, "parse_args", lambda: _fake_args) monkeypatch.setattr(mp, "parse_args", lambda: fake_args)
monkeypatch.setattr(mp.subprocess, "check_output", lambda *a, **k: "deadbeef\n") monkeypatch.setattr(mp.subprocess, "check_output", lambda *a, **k: "deadbeef\n")
monkeypatch.setattr(mp.subprocess, "run", _stub_train_subprocess(mp)) monkeypatch.setattr(mp.subprocess, "run", _stub_train_subprocess(mp))
assert mp.main() == 0 assert mp.main() == 0
row_paths = list(_fake_args.output_dir.rglob("profiling_row.json")) row_paths = list(fake_args.output_dir.rglob("profiling_row.json"))
assert len(row_paths) == 1 assert len(row_paths) == 1
row = json.loads(row_paths[0].read_text()) row = json.loads(row_paths[0].read_text())
assert row["policy"] == "act" assert row["policy"] == "act"
@@ -214,10 +214,10 @@ def test_main_smoke_writes_row(monkeypatch, _fake_args):
assert row["deterministic_forward"]["operator_fingerprint"] == "ops-fingerprint" assert row["deterministic_forward"]["operator_fingerprint"] == "ops-fingerprint"
def test_main_records_publish_failure_without_failing(monkeypatch, _fake_args): def test_main_records_publish_failure_without_failing(monkeypatch, fake_args):
_fake_args.publish = True fake_args.publish = True
_fake_args.git_commit = "deadbeef" fake_args.git_commit = "deadbeef"
monkeypatch.setattr(mp, "parse_args", lambda: _fake_args) monkeypatch.setattr(mp, "parse_args", lambda: fake_args)
monkeypatch.setattr(mp.subprocess, "run", _stub_train_subprocess(mp, write_artifacts=False)) monkeypatch.setattr(mp.subprocess, "run", _stub_train_subprocess(mp, write_artifacts=False))
def _fail_upload(**kwargs): def _fail_upload(**kwargs):
@@ -227,12 +227,24 @@ def test_main_records_publish_failure_without_failing(monkeypatch, _fake_args):
monkeypatch.setattr(mp, "upload_profile_run", _fail_upload) monkeypatch.setattr(mp, "upload_profile_run", _fail_upload)
assert mp.main() == 0 assert mp.main() == 0
row = json.loads(next(_fake_args.output_dir.rglob("profiling_row.json")).read_text()) row = json.loads(next(fake_args.output_dir.rglob("profiling_row.json")).read_text())
assert row["status"] == "success" assert row["status"] == "success"
assert row["publish_status"] == "failed" assert row["publish_status"] == "failed"
assert "Authorization error" in row["publish_error"] assert "Authorization error" in row["publish_error"]
def test_main_returns_nonzero_when_training_subprocess_fails(monkeypatch, fake_args):
monkeypatch.setattr(mp, "parse_args", lambda: fake_args)
monkeypatch.setattr(mp.subprocess, "check_output", lambda *a, **k: "deadbeef\n")
monkeypatch.setattr(mp.subprocess, "run", _stub_train_subprocess(mp, returncode=3))
assert mp.main() == 1
row = json.loads(next(fake_args.output_dir.rglob("profiling_row.json")).read_text())
assert row["status"] == "failed"
assert row["return_code"] == 3
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# TrainingProfiler behavior # TrainingProfiler behavior
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------