mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
refactor: extract profiling into self-contained TrainingProfiler class
Move all profiling orchestration out of lerobot_train.py and TrainPipelineConfig into a TrainingProfiler class in profiling_utils.py.

- lerobot_train.py: ~74 lines of profiling code reduced to ~7 call sites
- TrainPipelineConfig: 10 profile_* fields reduced to 2 (mode + output_dir)
- update_policy: reverted to clean main-branch signature (no timing_collector)
- TrainingProfiler encapsulates torch profiler, timing collection, deterministic forward artifacts, and all output writing
- CI script (run_model_profiling.py) unchanged—it only passes the 2 kept fields

Made-with: Cursor
This commit is contained in:
@@ -65,29 +65,24 @@ def test_pretrained_libero_specs_match_expected_camera_keys_and_normalization():
|
||||
specs = module.load_specs(spec_path)
|
||||
|
||||
assert (
|
||||
"--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", "
|
||||
"\"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
|
||||
in specs["pi0"]["train_args"]
|
||||
'--rename_map={"observation.images.front": "observation.images.base_0_rgb", '
|
||||
'"observation.images.wrist": "observation.images.left_wrist_0_rgb"}' in specs["pi0"].train_args
|
||||
)
|
||||
assert (
|
||||
"--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", "
|
||||
"\"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
|
||||
in specs["pi0_fast"]["train_args"]
|
||||
'--rename_map={"observation.images.front": "observation.images.base_0_rgb", '
|
||||
'"observation.images.wrist": "observation.images.left_wrist_0_rgb"}' in specs["pi0_fast"].train_args
|
||||
)
|
||||
assert (
|
||||
"--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", "
|
||||
"\"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
|
||||
in specs["pi05"]["train_args"]
|
||||
'--rename_map={"observation.images.front": "observation.images.base_0_rgb", '
|
||||
'"observation.images.wrist": "observation.images.left_wrist_0_rgb"}' in specs["pi05"].train_args
|
||||
)
|
||||
assert (
|
||||
"--policy.normalization_mapping={\"ACTION\": \"MEAN_STD\", "
|
||||
"\"STATE\": \"MEAN_STD\", \"VISUAL\": \"IDENTITY\"}"
|
||||
in specs["pi05"]["train_args"]
|
||||
'--policy.normalization_mapping={"ACTION": "MEAN_STD", '
|
||||
'"STATE": "MEAN_STD", "VISUAL": "IDENTITY"}' in specs["pi05"].train_args
|
||||
)
|
||||
assert (
|
||||
"--rename_map={\"observation.images.front\": \"observation.images.camera1\", "
|
||||
"\"observation.images.wrist\": \"observation.images.camera2\"}"
|
||||
in specs["smolvla"]["train_args"]
|
||||
'--rename_map={"observation.images.front": "observation.images.camera1", '
|
||||
'"observation.images.wrist": "observation.images.camera2"}' in specs["smolvla"].train_args
|
||||
)
|
||||
|
||||
|
||||
@@ -222,7 +217,6 @@ def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
|
||||
(profile_dir / "step_timing_summary.json").write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"forward_s": {"count": 1, "mean": 0.1, "median": 0.1, "min": 0.1, "max": 0.1},
|
||||
"total_update_s": {"count": 1, "mean": 0.3, "median": 0.3, "min": 0.3, "max": 0.3},
|
||||
"peak_memory_allocated_bytes": 1024,
|
||||
}
|
||||
@@ -251,7 +245,7 @@ def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
|
||||
assert row["git_commit"] == "deadbeef"
|
||||
assert row["git_ref"] == "codex/model-profiling"
|
||||
assert row["pr_number"] == 3389
|
||||
assert row["step_timing_summary"]["forward_s"]["mean"] == 0.1
|
||||
assert row["step_timing_summary"]["total_update_s"]["mean"] == 0.3
|
||||
assert row["deterministic_forward"]["operator_fingerprint"] == "ops-fingerprint"
|
||||
|
||||
|
||||
@@ -364,19 +358,14 @@ def test_deterministic_forward_artifacts_preserve_policy_mode(tmp_path):
|
||||
|
||||
|
||||
def test_step_timing_collector_accepts_metric_like_values(tmp_path):
|
||||
from lerobot.utils.profiling_utils import StepTimingCollector
|
||||
from lerobot.utils.profiling_utils import _StepTimingCollector
|
||||
|
||||
class _MetricLike:
|
||||
def __init__(self, val):
|
||||
self.val = val
|
||||
|
||||
collector = StepTimingCollector()
|
||||
collector.record(
|
||||
forward_s=0.1,
|
||||
backward_s=0.2,
|
||||
optimizer_s=0.3,
|
||||
total_update_s=_MetricLike(0.6),
|
||||
)
|
||||
collector = _StepTimingCollector()
|
||||
collector.record_step(_MetricLike(0.6))
|
||||
collector.record_dataloading(_MetricLike(0.05))
|
||||
collector.write_json(tmp_path / "step_timing_summary.json")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user