mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 22:49:48 +00:00
refactor(profiling): shrink lerobot_train.py diff via start()/finalize()
Replace the `with profiler or nullcontext():` wrapper around the entire training loop with explicit `profiler.start()` / `profiler.finalize()` calls, and tighten the `_section(...)` regions in `update_policy` so they wrap only the hot calls (forward / backward / optimizer.step). This avoids ~120 lines of pure re-indentation noise while keeping the exact same artifacts on disk and the same public behavior. `lerobot_train.py` diff vs. main: 267 -> 29 changed lines.
Made-with: Cursor
This commit is contained in:
@@ -138,8 +138,7 @@ def update_policy(
|
||||
policy.parameters(), float("inf"), error_if_nonfinite=False
|
||||
)
|
||||
|
||||
with _section("optimizer"):
|
||||
with lock if lock is not None else nullcontext():
|
||||
with _section("optimizer"), lock if lock is not None else nullcontext():
|
||||
optimizer.step()
|
||||
|
||||
optimizer.zero_grad()
|
||||
@@ -327,6 +326,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
profiler.record_deterministic_forward(
|
||||
policy=policy, dataset=dataset, batch_size=cfg.batch_size, preprocessor=preprocessor
|
||||
)
|
||||
profiler.start()
|
||||
|
||||
# Load precomputed SARM progress for RA-BC if enabled
|
||||
# Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py
|
||||
@@ -441,7 +441,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
logging.info(
|
||||
f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
|
||||
)
|
||||
with profiler or nullcontext():
|
||||
|
||||
for _ in range(step, cfg.steps):
|
||||
start_time = time.perf_counter()
|
||||
batch = next(dl_iter)
|
||||
|
||||
@@ -356,14 +356,10 @@ class TrainingProfiler:
|
||||
if self._device.type == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def __enter__(self) -> TrainingProfiler:
|
||||
def start(self) -> None:
|
||||
if self._device.type == "cuda":
|
||||
torch.cuda.reset_peak_memory_stats(self._device)
|
||||
self._torch_profiler.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc: Any) -> bool:
|
||||
return self._torch_profiler.__exit__(*exc)
|
||||
|
||||
@contextmanager
|
||||
def section(self, name: str) -> Iterator[None]:
|
||||
@@ -394,6 +390,7 @@ class TrainingProfiler:
|
||||
self._torch_profiler.step()
|
||||
|
||||
def finalize(self) -> None:
|
||||
self._torch_profiler.__exit__(None, None, None)
|
||||
extra: dict[str, Any] = {"profile_mode": self._mode}
|
||||
if self._device.type == "cuda":
|
||||
extra["peak_memory_allocated_bytes"] = torch.cuda.max_memory_allocated(self._device)
|
||||
|
||||
@@ -400,7 +400,7 @@ def test_training_profiler_section_records_duration(tmp_path):
|
||||
output_dir=tmp_path,
|
||||
device=torch.device("cpu"),
|
||||
)
|
||||
with profiler:
|
||||
profiler.start()
|
||||
with profiler.section("forward"):
|
||||
pass
|
||||
with profiler.section("backward"):
|
||||
|
||||
Reference in New Issue
Block a user