mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 04:30:10 +00:00
refactor(profiling): shrink lerobot_train.py diff via start()/finalize()
Replace the `with profiler or nullcontext():` wrapper around the entire training loop with explicit `profiler.start()` / `profiler.finalize()` calls, and tighten the `_section(...)` regions in `update_policy` so they wrap only the hot calls (forward / backward / `optimizer.step`).

This avoids ~120 lines of pure re-indentation noise while keeping exactly the same artifacts on disk and the same public behavior.

`lerobot_train.py` diff vs. main: 267 -> 29 changed lines.

Made-with: Cursor
This commit is contained in:
@@ -138,8 +138,7 @@ def update_policy(
|
|||||||
policy.parameters(), float("inf"), error_if_nonfinite=False
|
policy.parameters(), float("inf"), error_if_nonfinite=False
|
||||||
)
|
)
|
||||||
|
|
||||||
with _section("optimizer"):
|
with _section("optimizer"), lock if lock is not None else nullcontext():
|
||||||
with lock if lock is not None else nullcontext():
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
@@ -327,6 +326,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
profiler.record_deterministic_forward(
|
profiler.record_deterministic_forward(
|
||||||
policy=policy, dataset=dataset, batch_size=cfg.batch_size, preprocessor=preprocessor
|
policy=policy, dataset=dataset, batch_size=cfg.batch_size, preprocessor=preprocessor
|
||||||
)
|
)
|
||||||
|
profiler.start()
|
||||||
|
|
||||||
# Load precomputed SARM progress for RA-BC if enabled
|
# Load precomputed SARM progress for RA-BC if enabled
|
||||||
# Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py
|
# Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py
|
||||||
@@ -441,7 +441,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
logging.info(
|
logging.info(
|
||||||
f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
|
f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
|
||||||
)
|
)
|
||||||
with profiler or nullcontext():
|
|
||||||
for _ in range(step, cfg.steps):
|
for _ in range(step, cfg.steps):
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
batch = next(dl_iter)
|
batch = next(dl_iter)
|
||||||
|
|||||||
@@ -356,14 +356,10 @@ class TrainingProfiler:
|
|||||||
if self._device.type == "cuda":
|
if self._device.type == "cuda":
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def __enter__(self) -> TrainingProfiler:
|
def start(self) -> None:
|
||||||
if self._device.type == "cuda":
|
if self._device.type == "cuda":
|
||||||
torch.cuda.reset_peak_memory_stats(self._device)
|
torch.cuda.reset_peak_memory_stats(self._device)
|
||||||
self._torch_profiler.__enter__()
|
self._torch_profiler.__enter__()
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, *exc: Any) -> bool:
|
|
||||||
return self._torch_profiler.__exit__(*exc)
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def section(self, name: str) -> Iterator[None]:
|
def section(self, name: str) -> Iterator[None]:
|
||||||
@@ -394,6 +390,7 @@ class TrainingProfiler:
|
|||||||
self._torch_profiler.step()
|
self._torch_profiler.step()
|
||||||
|
|
||||||
def finalize(self) -> None:
|
def finalize(self) -> None:
|
||||||
|
self._torch_profiler.__exit__(None, None, None)
|
||||||
extra: dict[str, Any] = {"profile_mode": self._mode}
|
extra: dict[str, Any] = {"profile_mode": self._mode}
|
||||||
if self._device.type == "cuda":
|
if self._device.type == "cuda":
|
||||||
extra["peak_memory_allocated_bytes"] = torch.cuda.max_memory_allocated(self._device)
|
extra["peak_memory_allocated_bytes"] = torch.cuda.max_memory_allocated(self._device)
|
||||||
|
|||||||
@@ -400,7 +400,7 @@ def test_training_profiler_section_records_duration(tmp_path):
|
|||||||
output_dir=tmp_path,
|
output_dir=tmp_path,
|
||||||
device=torch.device("cpu"),
|
device=torch.device("cpu"),
|
||||||
)
|
)
|
||||||
with profiler:
|
profiler.start()
|
||||||
with profiler.section("forward"):
|
with profiler.section("forward"):
|
||||||
pass
|
pass
|
||||||
with profiler.section("backward"):
|
with profiler.section("backward"):
|
||||||
|
|||||||
Reference in New Issue
Block a user