mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
refactor(profiling): remove cProfile, keep torch profiler only
Remove cProfile wrapping from the training loop and profiling utilities. The torch profiler already captures fine-grained timing and operator breakdowns; cProfile added redundant overhead without actionable insight for GPU-bound models. - Remove render_cprofile_summary, run_with_cprofile from profiling_utils - Replace cProfile-wrapped calls in lerobot_train with direct calls - Remove cprofile_summaries from artifact index in run_model_profiling - Update tests to match Made-with: Cursor
This commit is contained in:
@@ -184,14 +184,12 @@ def build_artifact_index(
|
||||
artifact_paths: dict[str, Any] = {
|
||||
"row": row_path_in_repo,
|
||||
"profiling_files": {},
|
||||
"cprofile_summaries": {},
|
||||
"torch_tables": {},
|
||||
"trace_files": {},
|
||||
}
|
||||
artifact_urls: dict[str, Any] = {
|
||||
"row": make_hub_file_url(repo_id, row_path_in_repo),
|
||||
"profiling_files": {},
|
||||
"cprofile_summaries": {},
|
||||
"torch_tables": {},
|
||||
"trace_files": {},
|
||||
}
|
||||
@@ -219,9 +217,6 @@ def build_artifact_index(
|
||||
if path.name == "step_timing_summary.json":
|
||||
artifact_paths["step_timing_summary"] = repo_path
|
||||
artifact_urls["step_timing_summary"] = make_hub_file_url(repo_id, repo_path)
|
||||
elif "cprofile" in path.parts:
|
||||
artifact_paths["cprofile_summaries"][path.stem] = repo_path
|
||||
artifact_urls["cprofile_summaries"][path.stem] = make_hub_file_url(repo_id, repo_path)
|
||||
elif "torch_tables" in path.parts:
|
||||
artifact_paths["torch_tables"][path.name] = repo_path
|
||||
artifact_urls["torch_tables"][path.name] = make_hub_file_url(repo_id, repo_path)
|
||||
|
||||
@@ -54,7 +54,6 @@ from lerobot.utils.profiling_utils import (
|
||||
StepTimingCollector,
|
||||
ensure_dir,
|
||||
make_torch_profiler,
|
||||
run_with_cprofile,
|
||||
write_deterministic_forward_artifacts,
|
||||
write_torch_profiler_outputs,
|
||||
)
|
||||
@@ -231,10 +230,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
|
||||
profiling_enabled = cfg.profile_mode != "off"
|
||||
profile_output_dir = None
|
||||
cprofile_dir = None
|
||||
if profiling_enabled and is_main_process and cfg.profile_output_dir is not None:
|
||||
profile_output_dir = ensure_dir(Path(cfg.profile_output_dir))
|
||||
cprofile_dir = ensure_dir(profile_output_dir / "cprofile")
|
||||
logging.info("Profiling enabled. Artifacts will be written to %s", profile_output_dir)
|
||||
|
||||
# Initialize wandb only on main process
|
||||
@@ -260,10 +257,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
# Dataset loading synchronization: main process downloads first to avoid race conditions
|
||||
if is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
if cprofile_dir is not None:
|
||||
dataset = run_with_cprofile("dataset_setup", cprofile_dir, make_dataset, cfg)
|
||||
else:
|
||||
dataset = make_dataset(cfg)
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -281,21 +275,11 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
|
||||
if is_main_process:
|
||||
logging.info("Creating policy")
|
||||
if is_main_process and cprofile_dir is not None:
|
||||
policy = run_with_cprofile(
|
||||
"policy_setup",
|
||||
cprofile_dir,
|
||||
make_policy,
|
||||
cfg=cfg.policy,
|
||||
ds_meta=dataset.meta,
|
||||
rename_map=cfg.rename_map,
|
||||
)
|
||||
else:
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
ds_meta=dataset.meta,
|
||||
rename_map=cfg.rename_map,
|
||||
)
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
ds_meta=dataset.meta,
|
||||
rename_map=cfg.rename_map,
|
||||
)
|
||||
|
||||
if cfg.peft is not None:
|
||||
logging.info("Using PEFT! Wrapping model.")
|
||||
@@ -349,36 +333,16 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
},
|
||||
}
|
||||
|
||||
if is_main_process and cprofile_dir is not None:
|
||||
preprocessor, postprocessor = run_with_cprofile(
|
||||
"processor_setup",
|
||||
cprofile_dir,
|
||||
make_pre_post_processors,
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=processor_pretrained_path,
|
||||
**processor_kwargs,
|
||||
**postprocessor_kwargs,
|
||||
)
|
||||
else:
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=processor_pretrained_path,
|
||||
**processor_kwargs,
|
||||
**postprocessor_kwargs,
|
||||
)
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=processor_pretrained_path,
|
||||
**processor_kwargs,
|
||||
**postprocessor_kwargs,
|
||||
)
|
||||
|
||||
if is_main_process:
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
if is_main_process and cprofile_dir is not None:
|
||||
optimizer, lr_scheduler = run_with_cprofile(
|
||||
"optimizer_setup",
|
||||
cprofile_dir,
|
||||
make_optimizer_and_scheduler,
|
||||
cfg,
|
||||
policy,
|
||||
)
|
||||
else:
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
|
||||
if profiling_enabled and is_main_process and profile_output_dir is not None:
|
||||
logging.info("Recording deterministic forward-pass artifacts")
|
||||
|
||||
@@ -16,13 +16,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import cProfile
|
||||
import hashlib
|
||||
import io
|
||||
import json
|
||||
import pstats
|
||||
import statistics
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from numbers import Real
|
||||
from pathlib import Path
|
||||
@@ -37,15 +33,6 @@ def ensure_dir(path: Path) -> Path:
|
||||
return path
|
||||
|
||||
|
||||
def render_cprofile_summary(
|
||||
profile: cProfile.Profile, *, sort_by: str = "cumulative", limit: int = 40
|
||||
) -> str:
|
||||
output = io.StringIO()
|
||||
stats = pstats.Stats(profile, stream=output).strip_dirs().sort_stats(sort_by)
|
||||
stats.print_stats(limit)
|
||||
return output.getvalue()
|
||||
|
||||
|
||||
def write_profiler_table(
|
||||
profiler: Any,
|
||||
output_path: Path,
|
||||
@@ -103,26 +90,6 @@ def write_torch_profiler_outputs(
|
||||
write_profiler_table(profiler, tables_dir / "flops.txt", sort_by="flops")
|
||||
|
||||
|
||||
def run_with_cprofile[T](
|
||||
label: str,
|
||||
output_dir: Path,
|
||||
fn: Callable[..., T],
|
||||
*args: Any,
|
||||
sort_by: str = "cumulative",
|
||||
limit: int = 40,
|
||||
**kwargs: Any,
|
||||
) -> T:
|
||||
ensure_dir(output_dir)
|
||||
profile = cProfile.Profile()
|
||||
profile.enable()
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
finally:
|
||||
profile.disable()
|
||||
summary = render_cprofile_summary(profile, sort_by=sort_by, limit=limit)
|
||||
(output_dir / f"{label}.txt").write_text(summary)
|
||||
|
||||
|
||||
def _stable_float(value: float | int | None) -> float | None:
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
@@ -106,18 +106,16 @@ def test_build_train_command_includes_profiling_outputs(tmp_path):
|
||||
assert "--cudnn_deterministic=true" in cmd
|
||||
|
||||
|
||||
def test_build_artifact_index_collects_cprofile_tables_and_traces(tmp_path):
|
||||
def test_build_artifact_index_collects_tables_and_traces(tmp_path):
|
||||
module = _import_model_profiling_script()
|
||||
run_dir = tmp_path / "act" / "20260415T000000Z__act"
|
||||
profiling_dir = run_dir / "profiling"
|
||||
(profiling_dir / "cprofile").mkdir(parents=True, exist_ok=True)
|
||||
(profiling_dir / "torch_tables").mkdir(parents=True, exist_ok=True)
|
||||
(profiling_dir / "torch_traces").mkdir(parents=True, exist_ok=True)
|
||||
(profiling_dir / "step_timing_summary.json").write_text("{}")
|
||||
(profiling_dir / "deterministic_forward.json").write_text(
|
||||
json.dumps({"operator_fingerprint": "ops123", "output_fingerprint": "out123"})
|
||||
)
|
||||
(profiling_dir / "cprofile" / "policy_setup.txt").write_text("policy setup")
|
||||
(profiling_dir / "torch_tables" / "cpu_time_total.txt").write_text("cpu table")
|
||||
(profiling_dir / "torch_traces" / "trace_step_9.json").write_text("{}")
|
||||
(run_dir / "stdout.txt").write_text("stdout")
|
||||
@@ -133,14 +131,13 @@ def test_build_artifact_index_collects_cprofile_tables_and_traces(tmp_path):
|
||||
assert row_path_in_repo == "rows/act/20260415T000000Z__act.json"
|
||||
assert artifact_paths["stdout"].endswith("/stdout.txt")
|
||||
assert artifact_paths["step_timing_summary"].endswith("/profiling/step_timing_summary.json")
|
||||
assert "policy_setup" in artifact_paths["cprofile_summaries"]
|
||||
assert "cpu_time_total.txt" in artifact_paths["torch_tables"]
|
||||
assert "trace_step_9.json" in artifact_paths["trace_files"]
|
||||
assert artifact_paths["profiling_files"]["profiling/deterministic_forward.json"].endswith(
|
||||
"/profiling/deterministic_forward.json"
|
||||
)
|
||||
assert artifact_urls["row"].startswith("https://huggingface.co/datasets/lerobot/model-profiling-history/")
|
||||
assert len(targets) == 7
|
||||
assert len(targets) == 6
|
||||
|
||||
|
||||
def test_upload_targets_batches_preview_publish_into_single_hf_pr(monkeypatch, tmp_path):
|
||||
@@ -222,7 +219,6 @@ def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
|
||||
profile_dir = Path(
|
||||
next(arg.split("=", 1)[1] for arg in cmd if arg.startswith("--profile_output_dir="))
|
||||
)
|
||||
(profile_dir / "cprofile").mkdir(parents=True, exist_ok=True)
|
||||
(profile_dir / "torch_tables").mkdir(parents=True, exist_ok=True)
|
||||
(profile_dir / "step_timing_summary.json").write_text(
|
||||
json.dumps(
|
||||
@@ -241,7 +237,6 @@ def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
|
||||
}
|
||||
)
|
||||
)
|
||||
(profile_dir / "cprofile" / "policy_setup.txt").write_text("policy setup profile")
|
||||
(profile_dir / "torch_tables" / "cpu_time_total.txt").write_text("cpu time table")
|
||||
return subprocess.CompletedProcess(cmd, 0, "stdout ok", "")
|
||||
|
||||
@@ -259,7 +254,6 @@ def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
|
||||
assert row["pr_number"] == 3389
|
||||
assert row["step_timing_summary"]["forward_s"]["mean"] == 0.1
|
||||
assert row["deterministic_forward"]["operator_fingerprint"] == "ops-fingerprint"
|
||||
assert "policy_setup" in row["artifact_paths"]["cprofile_summaries"]
|
||||
|
||||
|
||||
def test_model_profiling_publish_failure_is_recorded_without_failing(monkeypatch, tmp_path):
|
||||
|
||||
Reference in New Issue
Block a user