mirror of https://github.com/huggingface/lerobot.git (synced 2026-05-11 14:49:43 +00:00)
fix(profiling): address review feedback
@@ -20,7 +20,6 @@ on:
   pull_request:
     branches:
       - main
-      - feat/libero-benchmark
     paths:
      - .github/workflows/model_profiling.yml
      - src/lerobot/configs/train.py

@@ -16,7 +16,7 @@ import datetime as dt
 import os
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Any
+from typing import Any, Literal

 import draccus
 from huggingface_hub import hf_hub_download

@@ -58,7 +58,7 @@ class TrainPipelineConfig(HubMixin):
     batch_size: int = 8
     prefetch_factor: int = 4
     persistent_workers: bool = True
-    profile_mode: str = "off"
+    profile_mode: Literal["off", "summary", "trace"] = "off"
     profile_output_dir: Path | None = None
     steps: int = 100_000
     eval_freq: int = 20_000

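The switch from a bare str to Literal["off", "summary", "trace"] lets the config parser reject typos at the command line instead of carrying them into training. A minimal sketch of what that buys (hypothetical ProfilingArgs class, not lerobot's real TrainPipelineConfig, and it assumes draccus validates Literal choices through its argparse-style overrides):

    from dataclasses import dataclass
    from typing import Literal

    import draccus


    @dataclass
    class ProfilingArgs:  # hypothetical stand-in for TrainPipelineConfig
        profile_mode: Literal["off", "summary", "trace"] = "off"


    # Accepted value parses normally; a typo like "sumary" should fail at
    # parse time rather than flow through as an arbitrary str.
    cfg = draccus.parse(ProfilingArgs, args=["--profile_mode", "trace"])
    assert cfg.profile_mode == "trace"
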
@@ -33,7 +33,6 @@ Usage (CI):
 from __future__ import annotations

 import argparse
-import contextlib
 import hashlib
 import json
 import logging

@@ -53,7 +52,7 @@ from typing import Any
 import torch
 from huggingface_hub import CommitOperationAdd, HfApi
 from huggingface_hub.errors import HfHubHTTPError
-from torch.utils.data._utils.collate import default_collate
+from torch.utils.data import default_collate

 logger = logging.getLogger(__name__)

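default_collate has been a public re-export from torch.utils.data since around PyTorch 1.11, so the patch drops the private _utils.collate path without changing behavior. For reference, the function stacks per-sample structures into batched tensors:

    import torch
    from torch.utils.data import default_collate

    # A list of per-sample dicts becomes one dict of stacked, batched tensors.
    samples = [{"obs": torch.zeros(3), "action": torch.ones(2)} for _ in range(4)]
    batch = default_collate(samples)
    assert batch["obs"].shape == (4, 3)
    assert batch["action"].shape == (4, 2)
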
@@ -305,11 +304,10 @@ def _get_profiler_device_time_us(event: Any) -> float | None:


 def _write_profiler_table(profiler: Any, path: Path, *, sort_by: str, row_limit: int = 40) -> None:
-    # The profiler may not have recorded any events for this sort key when the
-    # schedule window lands outside the active steps — skip silently rather
-    # than crashing the whole artifact-writer pass.
-    with contextlib.suppress(Exception):
+    try:
         path.write_text(profiler.key_averages().table(sort_by=sort_by, row_limit=row_limit))
+    except Exception:
+        logger.debug("Could not write profiler table for sort_by=%s", sort_by, exc_info=True)


 def write_deterministic_forward_artifacts(
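The review feedback replaced contextlib.suppress(Exception) with an explicit try/except that logs at debug level, so a failed table write leaves a trail instead of vanishing silently. A standalone sketch of the resulting pattern (the toy workload and output file name are made up):

    import logging
    from pathlib import Path

    import torch
    from torch.profiler import ProfilerActivity, profile

    logger = logging.getLogger(__name__)

    # Profile a trivial CPU workload so key_averages() has events to tabulate.
    with profile(activities=[ProfilerActivity.CPU]) as prof:
        torch.mm(torch.randn(64, 64), torch.randn(64, 64))

    try:
        table = prof.key_averages().table(sort_by="cpu_time_total", row_limit=40)
        Path("profiler_table.txt").write_text(table)
    except Exception:
        # Log-and-continue: the artifact-writer pass keeps going either way.
        logger.debug("Could not write profiler table", exc_info=True)
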
@@ -324,7 +322,9 @@ def write_deterministic_forward_artifacts(
     """Run a seed-controlled single forward pass and dump a stable fingerprint
     (loss/output tensor hashes + op counts) for regression detection. Keeps
     the caller-selected module mode so ACT-with-VAE-style policies that only
-    materialize their full forward outputs in `train()` still match."""
+    materialize their full forward outputs in `train()` still match. Models
+    with stochastic train-mode layers still rely on the seeded RNG for stable
+    fingerprints."""
     if len(dataset) == 0:
         raise ValueError("Cannot build a reference batch from an empty dataset.")
     indices = [i % len(dataset) for i in range(batch_size)]

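The added docstring sentence explains why train-mode stochastic layers (dropout, VAE sampling) do not break the fingerprint: the RNG is seeded before the forward pass, so reruns draw the same noise. A hypothetical illustration of the idea; none of these names come from the patch:

    import hashlib

    import torch

    torch.manual_seed(0)
    model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Dropout(0.5), torch.nn.Linear(8, 2))
    model.train()  # keep the caller-selected module mode, as the docstring says

    # With the seed fixed, even dropout is deterministic, so hashing the
    # forward output yields the same value on every rerun (same build/device).
    out = model(torch.randn(3, 4))
    fingerprint = hashlib.sha256(out.detach().numpy().tobytes()).hexdigest()
    print(fingerprint)
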
@@ -701,9 +701,12 @@ def main() -> int:

     args.output_dir.mkdir(parents=True, exist_ok=True)
     repo_id = args.results_repo if "/" in args.results_repo else f"{args.hub_org}/{args.results_repo}"
-    git_exe = shutil.which("git") or (_ for _ in ()).throw(RuntimeError("git not found in PATH"))
+    git_exe = shutil.which("git")
+    if not git_exe:
+        raise RuntimeError("git not found in PATH")
     git_commit = args.git_commit or subprocess.check_output([git_exe, "rev-parse", "HEAD"], text=True).strip()
     pr_number = int(args.pr_number) if str(args.pr_number).strip() else None
+    exit_code = 0

     for policy in selected:
         run_id = f"{_utc_timestamp_slug()}__{policy}"

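The removed one-liner raised from inside an expression by calling .throw() on a throwaway generator: gen.throw(exc) on a fresh generator raises exc immediately, so the or branch either short-circuits or raises. It works, but it reads as a puzzle; the replacement states the same logic as ordinary control flow:

    import shutil

    # Old form: expression-level raise via a throwaway generator.
    git_exe = shutil.which("git") or (_ for _ in ()).throw(RuntimeError("git not found in PATH"))

    # New form: identical behavior, written as a plain statement.
    git_exe = shutil.which("git")
    if not git_exe:
        raise RuntimeError("git not found in PATH")
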
@@ -717,6 +720,8 @@ def main() -> int:

         (run_dir / "stdout.txt").write_text(result.stdout)
         (run_dir / "stderr.txt").write_text(result.stderr)
+        if result.returncode != 0:
+            exit_code = 1

         paths, urls, upload_list, row_in_repo = build_artifact_index(
             repo_id=repo_id, run_dir=run_dir, policy_name=policy, run_id=run_id

@@ -771,7 +776,7 @@ def main() -> int:

     print(json.dumps(row, indent=2, sort_keys=True))

-    return 0
+    return exit_code


 if __name__ == "__main__":

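Taken together, the three main() hunks make the script keep profiling every selected policy, remember any nonzero training return code, and surface it as the process exit status instead of always returning 0. A condensed sketch of that control flow (the command list is invented; the variable names follow the patch):

    import subprocess

    exit_code = 0
    for cmd in (["git", "--version"], ["git", "not-a-real-subcommand"]):
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            exit_code = 1  # remember the failure, keep processing the rest

    raise SystemExit(exit_code)  # CI sees a nonzero exit iff any run failed
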
@@ -156,7 +156,7 @@ def test_parse_discussion_num_handles_hf_discussion_urls():


 @pytest.fixture
-def _fake_args(tmp_path):
+def fake_args(tmp_path):
     """Shared argparse namespace for main() smoke tests — overridden per-test."""
     return argparse.Namespace(
         policies=["act"],

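The rename from _fake_args to fake_args likely addresses a lint convention (my assumption, consistent with flake8-pytest-style's rule PT019): an underscore prefix signals a fixture requested only for its side effects, whereas this fixture's return value is consumed directly by the tests. A minimal sketch of the pattern:

    import argparse

    import pytest


    @pytest.fixture
    def fake_args(tmp_path):
        # Plain-named fixture: tests use its return value, so no underscore.
        return argparse.Namespace(policies=["act"], output_dir=tmp_path)


    def test_reads_fixture_value(fake_args):
        assert fake_args.policies == ["act"]
        assert fake_args.output_dir.exists()
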
@@ -195,14 +195,14 @@ def _stub_train_subprocess(mp_module, *, returncode: int = 0, write_artifacts: b
     return _fake_run


-def test_main_smoke_writes_row(monkeypatch, _fake_args):
-    monkeypatch.setattr(mp, "parse_args", lambda: _fake_args)
+def test_main_smoke_writes_row(monkeypatch, fake_args):
+    monkeypatch.setattr(mp, "parse_args", lambda: fake_args)
     monkeypatch.setattr(mp.subprocess, "check_output", lambda *a, **k: "deadbeef\n")
     monkeypatch.setattr(mp.subprocess, "run", _stub_train_subprocess(mp))

     assert mp.main() == 0

-    row_paths = list(_fake_args.output_dir.rglob("profiling_row.json"))
+    row_paths = list(fake_args.output_dir.rglob("profiling_row.json"))
     assert len(row_paths) == 1
     row = json.loads(row_paths[0].read_text())
     assert row["policy"] == "act"

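_stub_train_subprocess itself is outside this diff, so its body is a guess; presumably it returns a callable that mimics subprocess.run closely enough for main() to proceed. One plausible shape:

    import subprocess


    def _stub_run(*, returncode: int = 0):
        # Return a subprocess.run replacement yielding a canned result.
        def _fake_run(cmd, *args, **kwargs):
            return subprocess.CompletedProcess(cmd, returncode, stdout="stub stdout", stderr="stub stderr")
        return _fake_run


    # Usage with pytest's monkeypatch, mirroring the tests above:
    #   monkeypatch.setattr(mp.subprocess, "run", _stub_run(returncode=3))
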
@@ -214,10 +214,10 @@ def test_main_smoke_writes_row(monkeypatch, _fake_args):
     assert row["deterministic_forward"]["operator_fingerprint"] == "ops-fingerprint"


-def test_main_records_publish_failure_without_failing(monkeypatch, _fake_args):
-    _fake_args.publish = True
-    _fake_args.git_commit = "deadbeef"
-    monkeypatch.setattr(mp, "parse_args", lambda: _fake_args)
+def test_main_records_publish_failure_without_failing(monkeypatch, fake_args):
+    fake_args.publish = True
+    fake_args.git_commit = "deadbeef"
+    monkeypatch.setattr(mp, "parse_args", lambda: fake_args)
     monkeypatch.setattr(mp.subprocess, "run", _stub_train_subprocess(mp, write_artifacts=False))

     def _fail_upload(**kwargs):

@@ -227,12 +227,24 @@ def test_main_records_publish_failure_without_failing(monkeypatch, _fake_args):
     monkeypatch.setattr(mp, "upload_profile_run", _fail_upload)

     assert mp.main() == 0
-    row = json.loads(next(_fake_args.output_dir.rglob("profiling_row.json")).read_text())
+    row = json.loads(next(fake_args.output_dir.rglob("profiling_row.json")).read_text())
     assert row["status"] == "success"
     assert row["publish_status"] == "failed"
     assert "Authorization error" in row["publish_error"]


+def test_main_returns_nonzero_when_training_subprocess_fails(monkeypatch, fake_args):
+    monkeypatch.setattr(mp, "parse_args", lambda: fake_args)
+    monkeypatch.setattr(mp.subprocess, "check_output", lambda *a, **k: "deadbeef\n")
+    monkeypatch.setattr(mp.subprocess, "run", _stub_train_subprocess(mp, returncode=3))
+
+    assert mp.main() == 1
+
+    row = json.loads(next(fake_args.output_dir.rglob("profiling_row.json")).read_text())
+    assert row["status"] == "failed"
+    assert row["return_code"] == 3
+
+
 # ---------------------------------------------------------------------------
 # TrainingProfiler behavior
 # ---------------------------------------------------------------------------