refactor(profiling): shrink lerobot_train.py diff via start()/finalize()

Replace the `with profiler or nullcontext():` wrap around the entire
training loop with explicit `profiler.start()` / `profiler.finalize()`
calls, and tighten `_section(...)` regions in `update_policy` to only
wrap the hot calls (forward / backward / optimizer.step).

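This avoids ~120 lines of pure re-indentation noise while keeping the
exact same artifacts on disk and the same behavior for the training
script. Note that `TrainingProfiler` itself is no longer a context
manager: `__enter__` / `__exit__` are replaced by `start()`, and
`finalize()` now also closes the underlying torch profiler, so callers
that used `with profiler:` must switch to `start()` / `finalize()`.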

lerobot_train.py diff vs main: 267 -> 29 changed lines.

Made-with: Cursor
This commit is contained in:
Pepijn
2026-04-17 10:59:43 +01:00
parent a6dd28e8b4
commit 1ac8e96575
3 changed files with 128 additions and 131 deletions
+3 -3
View File
@@ -138,8 +138,7 @@ def update_policy(
policy.parameters(), float("inf"), error_if_nonfinite=False policy.parameters(), float("inf"), error_if_nonfinite=False
) )
with _section("optimizer"): with _section("optimizer"), lock if lock is not None else nullcontext():
with lock if lock is not None else nullcontext():
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
@@ -327,6 +326,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
profiler.record_deterministic_forward( profiler.record_deterministic_forward(
policy=policy, dataset=dataset, batch_size=cfg.batch_size, preprocessor=preprocessor policy=policy, dataset=dataset, batch_size=cfg.batch_size, preprocessor=preprocessor
) )
profiler.start()
# Load precomputed SARM progress for RA-BC if enabled # Load precomputed SARM progress for RA-BC if enabled
# Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py # Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py
@@ -441,7 +441,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
logging.info( logging.info(
f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}" f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
) )
with profiler or nullcontext():
for _ in range(step, cfg.steps): for _ in range(step, cfg.steps):
start_time = time.perf_counter() start_time = time.perf_counter()
batch = next(dl_iter) batch = next(dl_iter)
+2 -5
View File
@@ -356,14 +356,10 @@ class TrainingProfiler:
if self._device.type == "cuda": if self._device.type == "cuda":
torch.cuda.empty_cache() torch.cuda.empty_cache()
def __enter__(self) -> TrainingProfiler: def start(self) -> None:
if self._device.type == "cuda": if self._device.type == "cuda":
torch.cuda.reset_peak_memory_stats(self._device) torch.cuda.reset_peak_memory_stats(self._device)
self._torch_profiler.__enter__() self._torch_profiler.__enter__()
return self
def __exit__(self, *exc: Any) -> bool:
return self._torch_profiler.__exit__(*exc)
@contextmanager @contextmanager
def section(self, name: str) -> Iterator[None]: def section(self, name: str) -> Iterator[None]:
@@ -394,6 +390,7 @@ class TrainingProfiler:
self._torch_profiler.step() self._torch_profiler.step()
def finalize(self) -> None: def finalize(self) -> None:
self._torch_profiler.__exit__(None, None, None)
extra: dict[str, Any] = {"profile_mode": self._mode} extra: dict[str, Any] = {"profile_mode": self._mode}
if self._device.type == "cuda": if self._device.type == "cuda":
extra["peak_memory_allocated_bytes"] = torch.cuda.max_memory_allocated(self._device) extra["peak_memory_allocated_bytes"] = torch.cuda.max_memory_allocated(self._device)
+1 -1
View File
@@ -400,7 +400,7 @@ def test_training_profiler_section_records_duration(tmp_path):
output_dir=tmp_path, output_dir=tmp_path,
device=torch.device("cpu"), device=torch.device("cpu"),
) )
with profiler: profiler.start()
with profiler.section("forward"): with profiler.section("forward"):
pass pass
with profiler.section("backward"): with profiler.section("backward"):