refactor(profiling): remove cProfile, keep torch profiler only

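Remove cProfile wrapping from the setup stages of the training entry
point (dataset, policy, processor, and optimizer creation) and drop the
cProfile helpers from the profiling utilities.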
The torch profiler already captures fine-grained timing and operator
breakdowns; cProfile added redundant overhead without actionable
insight for GPU-bound models.

- Remove render_cprofile_summary and run_with_cprofile, along with their
  now-unused imports (cProfile, io, pstats, Callable), from profiling_utils
- Replace cProfile-wrapped calls in lerobot_train with direct calls
- Remove cprofile_summaries from the artifact index (build_artifact_index)
  in run_model_profiling
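- Update tests to match: drop the cprofile fixtures and assertions, rename
  test_build_artifact_index_collects_cprofile_tables_and_traces to
  test_build_artifact_index_collects_tables_and_traces, and lower the
  expected len(targets) from 7 to 6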

Made-with: Cursor
This commit is contained in:
Pepijn
2026-04-16 15:32:59 +02:00
parent 4137b5785d
commit e16a95a78e
4 changed files with 15 additions and 95 deletions
-5
View File
@@ -184,14 +184,12 @@ def build_artifact_index(
artifact_paths: dict[str, Any] = { artifact_paths: dict[str, Any] = {
"row": row_path_in_repo, "row": row_path_in_repo,
"profiling_files": {}, "profiling_files": {},
"cprofile_summaries": {},
"torch_tables": {}, "torch_tables": {},
"trace_files": {}, "trace_files": {},
} }
artifact_urls: dict[str, Any] = { artifact_urls: dict[str, Any] = {
"row": make_hub_file_url(repo_id, row_path_in_repo), "row": make_hub_file_url(repo_id, row_path_in_repo),
"profiling_files": {}, "profiling_files": {},
"cprofile_summaries": {},
"torch_tables": {}, "torch_tables": {},
"trace_files": {}, "trace_files": {},
} }
@@ -219,9 +217,6 @@ def build_artifact_index(
if path.name == "step_timing_summary.json": if path.name == "step_timing_summary.json":
artifact_paths["step_timing_summary"] = repo_path artifact_paths["step_timing_summary"] = repo_path
artifact_urls["step_timing_summary"] = make_hub_file_url(repo_id, repo_path) artifact_urls["step_timing_summary"] = make_hub_file_url(repo_id, repo_path)
elif "cprofile" in path.parts:
artifact_paths["cprofile_summaries"][path.stem] = repo_path
artifact_urls["cprofile_summaries"][path.stem] = make_hub_file_url(repo_id, repo_path)
elif "torch_tables" in path.parts: elif "torch_tables" in path.parts:
artifact_paths["torch_tables"][path.name] = repo_path artifact_paths["torch_tables"][path.name] = repo_path
artifact_urls["torch_tables"][path.name] = make_hub_file_url(repo_id, repo_path) artifact_urls["torch_tables"][path.name] = make_hub_file_url(repo_id, repo_path)
-36
View File
@@ -54,7 +54,6 @@ from lerobot.utils.profiling_utils import (
StepTimingCollector, StepTimingCollector,
ensure_dir, ensure_dir,
make_torch_profiler, make_torch_profiler,
run_with_cprofile,
write_deterministic_forward_artifacts, write_deterministic_forward_artifacts,
write_torch_profiler_outputs, write_torch_profiler_outputs,
) )
@@ -231,10 +230,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
profiling_enabled = cfg.profile_mode != "off" profiling_enabled = cfg.profile_mode != "off"
profile_output_dir = None profile_output_dir = None
cprofile_dir = None
if profiling_enabled and is_main_process and cfg.profile_output_dir is not None: if profiling_enabled and is_main_process and cfg.profile_output_dir is not None:
profile_output_dir = ensure_dir(Path(cfg.profile_output_dir)) profile_output_dir = ensure_dir(Path(cfg.profile_output_dir))
cprofile_dir = ensure_dir(profile_output_dir / "cprofile")
logging.info("Profiling enabled. Artifacts will be written to %s", profile_output_dir) logging.info("Profiling enabled. Artifacts will be written to %s", profile_output_dir)
# Initialize wandb only on main process # Initialize wandb only on main process
@@ -260,9 +257,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# Dataset loading synchronization: main process downloads first to avoid race conditions # Dataset loading synchronization: main process downloads first to avoid race conditions
if is_main_process: if is_main_process:
logging.info("Creating dataset") logging.info("Creating dataset")
if cprofile_dir is not None:
dataset = run_with_cprofile("dataset_setup", cprofile_dir, make_dataset, cfg)
else:
dataset = make_dataset(cfg) dataset = make_dataset(cfg)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
@@ -281,16 +275,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
if is_main_process: if is_main_process:
logging.info("Creating policy") logging.info("Creating policy")
if is_main_process and cprofile_dir is not None:
policy = run_with_cprofile(
"policy_setup",
cprofile_dir,
make_policy,
cfg=cfg.policy,
ds_meta=dataset.meta,
rename_map=cfg.rename_map,
)
else:
policy = make_policy( policy = make_policy(
cfg=cfg.policy, cfg=cfg.policy,
ds_meta=dataset.meta, ds_meta=dataset.meta,
@@ -349,17 +333,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
}, },
} }
if is_main_process and cprofile_dir is not None:
preprocessor, postprocessor = run_with_cprofile(
"processor_setup",
cprofile_dir,
make_pre_post_processors,
policy_cfg=cfg.policy,
pretrained_path=processor_pretrained_path,
**processor_kwargs,
**postprocessor_kwargs,
)
else:
preprocessor, postprocessor = make_pre_post_processors( preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy, policy_cfg=cfg.policy,
pretrained_path=processor_pretrained_path, pretrained_path=processor_pretrained_path,
@@ -369,15 +342,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
if is_main_process: if is_main_process:
logging.info("Creating optimizer and scheduler") logging.info("Creating optimizer and scheduler")
if is_main_process and cprofile_dir is not None:
optimizer, lr_scheduler = run_with_cprofile(
"optimizer_setup",
cprofile_dir,
make_optimizer_and_scheduler,
cfg,
policy,
)
else:
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
if profiling_enabled and is_main_process and profile_output_dir is not None: if profiling_enabled and is_main_process and profile_output_dir is not None:
-33
View File
@@ -16,13 +16,9 @@
from __future__ import annotations from __future__ import annotations
import cProfile
import hashlib import hashlib
import io
import json import json
import pstats
import statistics import statistics
from collections.abc import Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from numbers import Real from numbers import Real
from pathlib import Path from pathlib import Path
@@ -37,15 +33,6 @@ def ensure_dir(path: Path) -> Path:
return path return path
def render_cprofile_summary(
profile: cProfile.Profile, *, sort_by: str = "cumulative", limit: int = 40
) -> str:
output = io.StringIO()
stats = pstats.Stats(profile, stream=output).strip_dirs().sort_stats(sort_by)
stats.print_stats(limit)
return output.getvalue()
def write_profiler_table( def write_profiler_table(
profiler: Any, profiler: Any,
output_path: Path, output_path: Path,
@@ -103,26 +90,6 @@ def write_torch_profiler_outputs(
write_profiler_table(profiler, tables_dir / "flops.txt", sort_by="flops") write_profiler_table(profiler, tables_dir / "flops.txt", sort_by="flops")
def run_with_cprofile[T](
label: str,
output_dir: Path,
fn: Callable[..., T],
*args: Any,
sort_by: str = "cumulative",
limit: int = 40,
**kwargs: Any,
) -> T:
ensure_dir(output_dir)
profile = cProfile.Profile()
profile.enable()
try:
return fn(*args, **kwargs)
finally:
profile.disable()
summary = render_cprofile_summary(profile, sort_by=sort_by, limit=limit)
(output_dir / f"{label}.txt").write_text(summary)
def _stable_float(value: float | int | None) -> float | None: def _stable_float(value: float | int | None) -> float | None:
if value is None: if value is None:
return None return None
+2 -8
View File
@@ -106,18 +106,16 @@ def test_build_train_command_includes_profiling_outputs(tmp_path):
assert "--cudnn_deterministic=true" in cmd assert "--cudnn_deterministic=true" in cmd
def test_build_artifact_index_collects_cprofile_tables_and_traces(tmp_path): def test_build_artifact_index_collects_tables_and_traces(tmp_path):
module = _import_model_profiling_script() module = _import_model_profiling_script()
run_dir = tmp_path / "act" / "20260415T000000Z__act" run_dir = tmp_path / "act" / "20260415T000000Z__act"
profiling_dir = run_dir / "profiling" profiling_dir = run_dir / "profiling"
(profiling_dir / "cprofile").mkdir(parents=True, exist_ok=True)
(profiling_dir / "torch_tables").mkdir(parents=True, exist_ok=True) (profiling_dir / "torch_tables").mkdir(parents=True, exist_ok=True)
(profiling_dir / "torch_traces").mkdir(parents=True, exist_ok=True) (profiling_dir / "torch_traces").mkdir(parents=True, exist_ok=True)
(profiling_dir / "step_timing_summary.json").write_text("{}") (profiling_dir / "step_timing_summary.json").write_text("{}")
(profiling_dir / "deterministic_forward.json").write_text( (profiling_dir / "deterministic_forward.json").write_text(
json.dumps({"operator_fingerprint": "ops123", "output_fingerprint": "out123"}) json.dumps({"operator_fingerprint": "ops123", "output_fingerprint": "out123"})
) )
(profiling_dir / "cprofile" / "policy_setup.txt").write_text("policy setup")
(profiling_dir / "torch_tables" / "cpu_time_total.txt").write_text("cpu table") (profiling_dir / "torch_tables" / "cpu_time_total.txt").write_text("cpu table")
(profiling_dir / "torch_traces" / "trace_step_9.json").write_text("{}") (profiling_dir / "torch_traces" / "trace_step_9.json").write_text("{}")
(run_dir / "stdout.txt").write_text("stdout") (run_dir / "stdout.txt").write_text("stdout")
@@ -133,14 +131,13 @@ def test_build_artifact_index_collects_cprofile_tables_and_traces(tmp_path):
assert row_path_in_repo == "rows/act/20260415T000000Z__act.json" assert row_path_in_repo == "rows/act/20260415T000000Z__act.json"
assert artifact_paths["stdout"].endswith("/stdout.txt") assert artifact_paths["stdout"].endswith("/stdout.txt")
assert artifact_paths["step_timing_summary"].endswith("/profiling/step_timing_summary.json") assert artifact_paths["step_timing_summary"].endswith("/profiling/step_timing_summary.json")
assert "policy_setup" in artifact_paths["cprofile_summaries"]
assert "cpu_time_total.txt" in artifact_paths["torch_tables"] assert "cpu_time_total.txt" in artifact_paths["torch_tables"]
assert "trace_step_9.json" in artifact_paths["trace_files"] assert "trace_step_9.json" in artifact_paths["trace_files"]
assert artifact_paths["profiling_files"]["profiling/deterministic_forward.json"].endswith( assert artifact_paths["profiling_files"]["profiling/deterministic_forward.json"].endswith(
"/profiling/deterministic_forward.json" "/profiling/deterministic_forward.json"
) )
assert artifact_urls["row"].startswith("https://huggingface.co/datasets/lerobot/model-profiling-history/") assert artifact_urls["row"].startswith("https://huggingface.co/datasets/lerobot/model-profiling-history/")
assert len(targets) == 7 assert len(targets) == 6
def test_upload_targets_batches_preview_publish_into_single_hf_pr(monkeypatch, tmp_path): def test_upload_targets_batches_preview_publish_into_single_hf_pr(monkeypatch, tmp_path):
@@ -222,7 +219,6 @@ def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
profile_dir = Path( profile_dir = Path(
next(arg.split("=", 1)[1] for arg in cmd if arg.startswith("--profile_output_dir=")) next(arg.split("=", 1)[1] for arg in cmd if arg.startswith("--profile_output_dir="))
) )
(profile_dir / "cprofile").mkdir(parents=True, exist_ok=True)
(profile_dir / "torch_tables").mkdir(parents=True, exist_ok=True) (profile_dir / "torch_tables").mkdir(parents=True, exist_ok=True)
(profile_dir / "step_timing_summary.json").write_text( (profile_dir / "step_timing_summary.json").write_text(
json.dumps( json.dumps(
@@ -241,7 +237,6 @@ def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
} }
) )
) )
(profile_dir / "cprofile" / "policy_setup.txt").write_text("policy setup profile")
(profile_dir / "torch_tables" / "cpu_time_total.txt").write_text("cpu time table") (profile_dir / "torch_tables" / "cpu_time_total.txt").write_text("cpu time table")
return subprocess.CompletedProcess(cmd, 0, "stdout ok", "") return subprocess.CompletedProcess(cmd, 0, "stdout ok", "")
@@ -259,7 +254,6 @@ def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
assert row["pr_number"] == 3389 assert row["pr_number"] == 3389
assert row["step_timing_summary"]["forward_s"]["mean"] == 0.1 assert row["step_timing_summary"]["forward_s"]["mean"] == 0.1
assert row["deterministic_forward"]["operator_fingerprint"] == "ops-fingerprint" assert row["deterministic_forward"]["operator_fingerprint"] == "ops-fingerprint"
assert "policy_setup" in row["artifact_paths"]["cprofile_summaries"]
def test_model_profiling_publish_failure_is_recorded_without_failing(monkeypatch, tmp_path): def test_model_profiling_publish_failure_is_recorded_without_failing(monkeypatch, tmp_path):