From a23ebf9d350d3ea12ea6b8db8e51208c04daffd2 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Thu, 23 Apr 2026 13:23:09 +0200 Subject: [PATCH] fix(profiling): address review feedback --- .github/workflows/model_profiling.yml | 1 - src/lerobot/configs/train.py | 4 ++-- src/lerobot/utils/model_profiling.py | 23 ++++++++++++-------- tests/test_model_profiling.py | 30 +++++++++++++++++++-------- 4 files changed, 37 insertions(+), 21 deletions(-) diff --git a/.github/workflows/model_profiling.yml b/.github/workflows/model_profiling.yml index f148e02d7..500a70763 100644 --- a/.github/workflows/model_profiling.yml +++ b/.github/workflows/model_profiling.yml @@ -20,7 +20,6 @@ on: pull_request: branches: - main - - feat/libero-benchmark paths: - .github/workflows/model_profiling.yml - src/lerobot/configs/train.py diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index 480fb3536..c375c85be 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -16,7 +16,7 @@ import datetime as dt import os from dataclasses import dataclass, field from pathlib import Path -from typing import Any +from typing import Any, Literal import draccus from huggingface_hub import hf_hub_download @@ -58,7 +58,7 @@ class TrainPipelineConfig(HubMixin): batch_size: int = 8 prefetch_factor: int = 4 persistent_workers: bool = True - profile_mode: str = "off" + profile_mode: Literal["off", "summary", "trace"] = "off" profile_output_dir: Path | None = None steps: int = 100_000 eval_freq: int = 20_000 diff --git a/src/lerobot/utils/model_profiling.py b/src/lerobot/utils/model_profiling.py index 7fe2a9cd1..ab2575528 100644 --- a/src/lerobot/utils/model_profiling.py +++ b/src/lerobot/utils/model_profiling.py @@ -33,7 +33,6 @@ Usage (CI): from __future__ import annotations import argparse -import contextlib import hashlib import json import logging @@ -53,7 +52,7 @@ from typing import Any import torch from huggingface_hub import CommitOperationAdd, HfApi from huggingface_hub.errors import HfHubHTTPError -from torch.utils.data._utils.collate import default_collate +from torch.utils.data import default_collate logger = logging.getLogger(__name__) @@ -305,11 +304,10 @@ def _get_profiler_device_time_us(event: Any) -> float | None: def _write_profiler_table(profiler: Any, path: Path, *, sort_by: str, row_limit: int = 40) -> None: - # The profiler may not have recorded any events for this sort key when the - # schedule window lands outside the active steps — skip silently rather - # than crashing the whole artifact-writer pass. - with contextlib.suppress(Exception): + try: path.write_text(profiler.key_averages().table(sort_by=sort_by, row_limit=row_limit)) + except Exception: + logger.debug("Could not write profiler table for sort_by=%s", sort_by, exc_info=True) def write_deterministic_forward_artifacts( @@ -324,7 +322,9 @@ def write_deterministic_forward_artifacts( """Run a seed-controlled single forward pass and dump a stable fingerprint (loss/output tensor hashes + op counts) for regression detection. Keeps the caller-selected module mode so ACT-with-VAE-style policies that only - materialize their full forward outputs in `train()` still match.""" + materialize their full forward outputs in `train()` still match. Models + with stochastic train-mode layers still rely on the seeded RNG for stable + fingerprints.""" if len(dataset) == 0: raise ValueError("Cannot build a reference batch from an empty dataset.") indices = [i % len(dataset) for i in range(batch_size)] @@ -701,9 +701,12 @@ def main() -> int: args.output_dir.mkdir(parents=True, exist_ok=True) repo_id = args.results_repo if "/" in args.results_repo else f"{args.hub_org}/{args.results_repo}" - git_exe = shutil.which("git") or (_ for _ in ()).throw(RuntimeError("git not found in PATH")) + git_exe = shutil.which("git") + if not git_exe: + raise RuntimeError("git not found in PATH") git_commit = args.git_commit or subprocess.check_output([git_exe, "rev-parse", "HEAD"], text=True).strip() pr_number = int(args.pr_number) if str(args.pr_number).strip() else None + exit_code = 0 for policy in selected: run_id = f"{_utc_timestamp_slug()}__{policy}" @@ -717,6 +720,8 @@ def main() -> int: (run_dir / "stdout.txt").write_text(result.stdout) (run_dir / "stderr.txt").write_text(result.stderr) + if result.returncode != 0: + exit_code = 1 paths, urls, upload_list, row_in_repo = build_artifact_index( repo_id=repo_id, run_dir=run_dir, policy_name=policy, run_id=run_id @@ -771,7 +776,7 @@ def main() -> int: print(json.dumps(row, indent=2, sort_keys=True)) - return 0 + return exit_code if __name__ == "__main__": diff --git a/tests/test_model_profiling.py b/tests/test_model_profiling.py index 6797b7453..83f373da0 100644 --- a/tests/test_model_profiling.py +++ b/tests/test_model_profiling.py @@ -156,7 +156,7 @@ def test_parse_discussion_num_handles_hf_discussion_urls(): @pytest.fixture -def _fake_args(tmp_path): +def fake_args(tmp_path): """Shared argparse namespace for main() smoke tests — overridden per-test.""" return argparse.Namespace( policies=["act"], @@ -195,14 +195,14 @@ def _stub_train_subprocess(mp_module, *, returncode: int = 0, write_artifacts: b return _fake_run -def test_main_smoke_writes_row(monkeypatch, _fake_args): - monkeypatch.setattr(mp, "parse_args", lambda: _fake_args) +def test_main_smoke_writes_row(monkeypatch, fake_args): + monkeypatch.setattr(mp, "parse_args", lambda: fake_args) monkeypatch.setattr(mp.subprocess, "check_output", lambda *a, **k: "deadbeef\n") monkeypatch.setattr(mp.subprocess, "run", _stub_train_subprocess(mp)) assert mp.main() == 0 - row_paths = list(_fake_args.output_dir.rglob("profiling_row.json")) + row_paths = list(fake_args.output_dir.rglob("profiling_row.json")) assert len(row_paths) == 1 row = json.loads(row_paths[0].read_text()) assert row["policy"] == "act" @@ -214,10 +214,10 @@ def test_main_smoke_writes_row(monkeypatch, _fake_args): assert row["deterministic_forward"]["operator_fingerprint"] == "ops-fingerprint" -def test_main_records_publish_failure_without_failing(monkeypatch, _fake_args): - _fake_args.publish = True - _fake_args.git_commit = "deadbeef" - monkeypatch.setattr(mp, "parse_args", lambda: _fake_args) +def test_main_records_publish_failure_without_failing(monkeypatch, fake_args): + fake_args.publish = True + fake_args.git_commit = "deadbeef" + monkeypatch.setattr(mp, "parse_args", lambda: fake_args) monkeypatch.setattr(mp.subprocess, "run", _stub_train_subprocess(mp, write_artifacts=False)) def _fail_upload(**kwargs): @@ -227,12 +227,24 @@ def test_main_records_publish_failure_without_failing(monkeypatch, _fake_args): monkeypatch.setattr(mp, "upload_profile_run", _fail_upload) assert mp.main() == 0 - row = json.loads(next(_fake_args.output_dir.rglob("profiling_row.json")).read_text()) + row = json.loads(next(fake_args.output_dir.rglob("profiling_row.json")).read_text()) assert row["status"] == "success" assert row["publish_status"] == "failed" assert "Authorization error" in row["publish_error"] +def test_main_returns_nonzero_when_training_subprocess_fails(monkeypatch, fake_args): + monkeypatch.setattr(mp, "parse_args", lambda: fake_args) + monkeypatch.setattr(mp.subprocess, "check_output", lambda *a, **k: "deadbeef\n") + monkeypatch.setattr(mp.subprocess, "run", _stub_train_subprocess(mp, returncode=3)) + + assert mp.main() == 1 + + row = json.loads(next(fake_args.output_dir.rglob("profiling_row.json")).read_text()) + assert row["status"] == "failed" + assert row["return_code"] == 3 + + # --------------------------------------------------------------------------- # TrainingProfiler behavior # ---------------------------------------------------------------------------