refactor: extract profiling into self-contained TrainingProfiler class

Move all profiling orchestration out of lerobot_train.py and TrainPipelineConfig into a TrainingProfiler class in profiling_utils.py.

- lerobot_train.py: ~74 lines of profiling code reduced to ~7 call sites
- TrainPipelineConfig: 10 profile_* fields reduced to 2 (mode + output_dir)
- update_policy: reverted to clean main-branch signature (no timing_collector)
- TrainingProfiler encapsulates torch profiler, timing collection, deterministic forward artifacts, and all output writing
- CI script (run_model_profiling.py) unchanged; it only passes the 2 kept fields

Made-with: Cursor
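For orientation, this is the thin interface the training loop now sees, condensed from the hunks below; `device`, `is_main_process`, and `train_tracker` come from the surrounding training script:

    # Sketch of the ~7 remaining call sites in lerobot_train.py (condensed from the diff):
    profiler = (
        TrainingProfiler.from_cfg(cfg, device) if cfg.profile_mode != "off" and is_main_process else None
    )
    if profiler:
        profiler.record_deterministic_forward(
            policy=policy, dataset=dataset, batch_size=cfg.batch_size, preprocessor=preprocessor
        )

    with profiler or nullcontext():
        for _ in range(step, cfg.steps):
            ...  # update_policy(...) runs with its clean main-branch signature
            if profiler:
                profiler.step(step, train_tracker)

    if profiler:
        profiler.finalize()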
@@ -57,15 +57,7 @@ class TrainPipelineConfig(HubMixin):
     num_workers: int = 4
     batch_size: int = 8
     profile_mode: str = "off"
-    profile_wait_steps: int = 1
-    profile_warmup_steps: int = 2
-    profile_active_steps: int = 6
-    profile_repeat: int = 1
     profile_output_dir: Path | None = None
-    profile_record_shapes: bool = True
-    profile_with_memory: bool = True
-    profile_with_flops: bool = True
-    profile_with_stack: bool = False
     steps: int = 100_000
     eval_freq: int = 20_000
     log_freq: int = 200
@@ -147,10 +139,6 @@ class TrainPipelineConfig(HubMixin):
             raise ValueError(
                 f"`profile_mode` must be one of 'off', 'summary', or 'trace', got {self.profile_mode}."
             )
-        if self.profile_wait_steps < 0 or self.profile_warmup_steps < 0 or self.profile_active_steps < 0:
-            raise ValueError("Profiler schedule steps must be non-negative.")
-        if self.profile_repeat <= 0:
-            raise ValueError("`profile_repeat` must be strictly positive.")

         if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
             raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
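With the schedule and profiler-option fields gone, the only profiling validation left in TrainPipelineConfig.validate() is the mode check; the dropped schedule checks are unnecessary because TrainingProfiler now owns those values as fixed defaults. A minimal sketch of the surviving check (the exact condition is an assumption reconstructed from the kept error message above):

    if self.profile_mode not in ("off", "summary", "trace"):
        raise ValueError(
            f"`profile_mode` must be one of 'off', 'summary', or 'trace', got {self.profile_mode}."
        )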
@@ -22,7 +22,6 @@ import dataclasses
 import logging
 import time
 from contextlib import nullcontext
-from pathlib import Path
 from pprint import pformat
 from typing import TYPE_CHECKING, Any

@@ -50,13 +49,7 @@ from lerobot.optim.factory import make_optimizer_and_scheduler
 from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
 from lerobot.utils.import_utils import register_third_party_plugins
 from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
-from lerobot.utils.profiling_utils import (
-    StepTimingCollector,
-    ensure_dir,
-    make_torch_profiler,
-    write_deterministic_forward_artifacts,
-    write_torch_profiler_outputs,
-)
+from lerobot.utils.profiling_utils import TrainingProfiler
 from lerobot.utils.random_utils import set_seed
 from lerobot.utils.utils import (
     cycle,
@@ -79,7 +72,6 @@ def update_policy(
     lr_scheduler=None,
     lock=None,
     rabc_weights_provider=None,
-    timing_collector: StepTimingCollector | None = None,
 ) -> tuple[MetricsTracker, dict]:
     """
     Performs a single training step to update the policy's weights.
@@ -113,7 +105,6 @@ def update_policy(
         rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch)

     # Let accelerator handle mixed precision
-    forward_start = time.perf_counter()
     with accelerator.autocast():
         # Use per-sample loss when RA-BC is enabled for proper weighting
         if rabc_batch_weights is not None:
@@ -132,15 +123,11 @@ def update_policy(
             loss, output_dict = policy.forward(batch)

     # TODO(rcadene): policy.unnormalize_outputs(out_dict)
-    forward_s = time.perf_counter() - forward_start

     # Use accelerator's backward method
-    backward_start = time.perf_counter()
     accelerator.backward(loss)
-    backward_s = time.perf_counter() - backward_start

     # Clip gradients if specified
-    optimizer_start = time.perf_counter()
     if grad_clip_norm > 0:
         grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
     else:
@@ -161,19 +148,11 @@ def update_policy(
     # Update internal buffers if policy has update method
     if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
         accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
-    optimizer_s = time.perf_counter() - optimizer_start

     train_metrics.loss = loss.item()
     train_metrics.grad_norm = grad_norm.item()
     train_metrics.lr = optimizer.param_groups[0]["lr"]
     train_metrics.update_s = time.perf_counter() - start_time
-    if timing_collector is not None:
-        timing_collector.record(
-            forward_s=forward_s,
-            backward_s=backward_s,
-            optimizer_s=optimizer_s,
-            total_update_s=train_metrics.update_s.val,
-        )
     return train_metrics, output_dict


@@ -228,12 +207,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
     if is_main_process:
         logging.info(pformat(cfg.to_dict()))

-    profiling_enabled = cfg.profile_mode != "off"
-    profile_output_dir = None
-    if profiling_enabled and is_main_process and cfg.profile_output_dir is not None:
-        profile_output_dir = ensure_dir(Path(cfg.profile_output_dir))
-        logging.info("Profiling enabled. Artifacts will be written to %s", profile_output_dir)
-
     # Initialize wandb only on main process
     if cfg.wandb.enable and cfg.wandb.project and is_main_process:
         wandb_logger = WandBLogger(cfg)
@@ -344,15 +317,12 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
     logging.info("Creating optimizer and scheduler")
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)

-    if profiling_enabled and is_main_process and profile_output_dir is not None:
-        logging.info("Recording deterministic forward-pass artifacts")
-        write_deterministic_forward_artifacts(
-            policy=policy,
-            dataset=dataset,
-            batch_size=cfg.batch_size,
-            preprocessor=preprocessor,
-            output_dir=profile_output_dir,
-            device_type=device.type,
-        )
+    profiler = (
+        TrainingProfiler.from_cfg(cfg, device) if cfg.profile_mode != "off" and is_main_process else None
+    )
+    if profiler:
+        profiler.record_deterministic_forward(
+            policy=policy, dataset=dataset, batch_size=cfg.batch_size, preprocessor=preprocessor
+        )

     # Load precomputed SARM progress for RA-BC if enabled
@@ -468,16 +438,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
         logging.info(
             f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
         )
-    timing_collector = StepTimingCollector() if profiling_enabled and is_main_process else None
-    profiler = None
-    profiler_context = nullcontext()
-    if profiling_enabled and is_main_process and profile_output_dir is not None:
-        if device.type == "cuda":
-            torch.cuda.reset_peak_memory_stats(device)
-        profiler = make_torch_profiler(cfg, profile_output_dir, device.type)
-        profiler_context = profiler
-
-    with profiler_context:
+    with profiler or nullcontext():
        for _ in range(step, cfg.steps):
            start_time = time.perf_counter()
            batch = next(dl_iter)
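The `profiler or nullcontext()` idiom replaces the separately tracked `profiler_context`: on non-main processes, or with profiling off, `profiler` is None, so `or` falls through to a no-op context manager. A self-contained illustration:

    from contextlib import nullcontext

    profiler = None  # e.g. profiling disabled, or not the main process
    with profiler or nullcontext():
        pass  # the loop body executes identically either way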
@@ -493,7 +454,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
                 accelerator=accelerator,
                 lr_scheduler=lr_scheduler,
                 rabc_weights_provider=rabc_weights,
-                timing_collector=timing_collector,
             )

             # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -501,17 +461,9 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
             step += 1
             if is_main_process:
                 progbar.update(1)
-                if timing_collector is not None:
-                    timing_collector.record_dataloading(train_tracker.dataloading_s.val)
-                    if device.type == "cuda":
-                        timing_collector.record_memory(
-                            step=step,
-                            allocated_bytes=torch.cuda.memory_allocated(device),
-                            reserved_bytes=torch.cuda.memory_reserved(device),
-                        )
+                if profiler:
+                    profiler.step(step, train_tracker)
             train_tracker.step()
-            if profiler is not None:
-                profiler.step()
             is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
             is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
             is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
@@ -606,21 +558,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):

     if is_main_process:
         progbar.close()
-        if timing_collector is not None and profile_output_dir is not None:
-            extra_profile_metrics = {
-                "profile_mode": cfg.profile_mode,
-                "peak_memory_allocated_bytes": (
-                    torch.cuda.max_memory_allocated(device) if device.type == "cuda" else None
-                ),
-                "peak_memory_reserved_bytes": (
-                    torch.cuda.max_memory_reserved(device) if device.type == "cuda" else None
-                ),
-            }
-            timing_collector.write_json(
-                profile_output_dir / "step_timing_summary.json", extra=extra_profile_metrics
-            )
-        if profiler is not None and profile_output_dir is not None:
-            write_torch_profiler_outputs(profiler, profile_output_dir, device_type=device.type)
+        if profiler:
+            profiler.finalize()

     if eval_env:
         close_envs(eval_env)
@@ -18,6 +18,7 @@ from __future__ import annotations

 import hashlib
 import json
+import logging
 import statistics
 from dataclasses import dataclass, field
 from numbers import Real
@@ -47,7 +48,20 @@ def write_profiler_table(
     output_path.write_text(table)


-def make_torch_profiler(cfg: Any, output_dir: Path, device_type: str) -> Any:
+def _make_torch_profiler(
+    *,
+    mode: str,
+    output_dir: Path,
+    device_type: str,
+    wait_steps: int = 1,
+    warmup_steps: int = 2,
+    active_steps: int = 6,
+    repeat: int = 1,
+    record_shapes: bool = True,
+    with_memory: bool = True,
+    with_flops: bool = True,
+    with_stack: bool = False,
+) -> Any:
     activities = [torch.profiler.ProfilerActivity.CPU]
     if device_type == "cuda":
         activities.append(torch.profiler.ProfilerActivity.CUDA)
@@ -55,23 +69,23 @@ def make_torch_profiler(cfg: Any, output_dir: Path, device_type: str) -> Any:
     trace_dir = ensure_dir(output_dir / "torch_traces")

     def _trace_ready(profiler: Any) -> None:
-        if cfg.profile_mode != "trace":
+        if mode != "trace":
             return
         profiler.export_chrome_trace(str(trace_dir / f"trace_step_{profiler.step_num}.json"))

     return torch.profiler.profile(
         activities=activities,
         schedule=torch.profiler.schedule(
-            wait=cfg.profile_wait_steps,
-            warmup=cfg.profile_warmup_steps,
-            active=cfg.profile_active_steps,
-            repeat=cfg.profile_repeat,
+            wait=wait_steps,
+            warmup=warmup_steps,
+            active=active_steps,
+            repeat=repeat,
         ),
         on_trace_ready=_trace_ready,
-        record_shapes=cfg.profile_record_shapes,
-        profile_memory=cfg.profile_with_memory,
-        with_flops=cfg.profile_with_flops,
-        with_stack=cfg.profile_with_stack,
+        record_shapes=record_shapes,
+        profile_memory=with_memory,
+        with_flops=with_flops,
+        with_stack=with_stack,
     )


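The schedule arguments keep their standard torch.profiler meaning: skip `wait` steps, trace but discard `warmup` steps, record `active` steps, then repeat the cycle `repeat` times. A standalone sketch of the schedule the new defaults produce:

    import torch

    # wait=1: ignore the first step; warmup=2: trace but discard two steps;
    # active=6: record six steps; repeat=1: run the wait/warmup/active cycle once.
    schedule = torch.profiler.schedule(wait=1, warmup=2, active=6, repeat=1)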
@@ -228,25 +242,12 @@ def _as_float(value: Any) -> float:


 @dataclass
-class StepTimingCollector:
-    forward_s: list[float] = field(default_factory=list)
-    backward_s: list[float] = field(default_factory=list)
-    optimizer_s: list[float] = field(default_factory=list)
+class _StepTimingCollector:
     total_update_s: list[float] = field(default_factory=list)
     dataloading_s: list[float] = field(default_factory=list)
     memory_timeline: list[dict[str, float | int]] = field(default_factory=list)

-    def record(
-        self,
-        *,
-        forward_s: float,
-        backward_s: float,
-        optimizer_s: float,
-        total_update_s: float,
-    ) -> None:
-        self.forward_s.append(_as_float(forward_s))
-        self.backward_s.append(_as_float(backward_s))
-        self.optimizer_s.append(_as_float(optimizer_s))
+    def record_step(self, total_update_s: float) -> None:
         self.total_update_s.append(_as_float(total_update_s))

     def record_dataloading(self, dataloading_s: float) -> None:
@@ -263,9 +264,6 @@ class StepTimingCollector:

     def to_dict(self) -> dict[str, Any]:
         return {
-            "forward_s": _summary(self.forward_s),
-            "backward_s": _summary(self.backward_s),
-            "optimizer_s": _summary(self.optimizer_s),
             "total_update_s": _summary(self.total_update_s),
             "dataloading_s": _summary(self.dataloading_s),
             "memory_timeline": self.memory_timeline,
@@ -277,3 +275,99 @@ class StepTimingCollector:
         payload.update(extra)
         output_path.parent.mkdir(parents=True, exist_ok=True)
         output_path.write_text(json.dumps(payload, indent=2, sort_keys=True))
+
+
+class TrainingProfiler:
+    """Self-contained profiling orchestrator for the training loop.
+
+    Encapsulates torch profiler setup, step-level timing collection, deterministic
+    forward-pass artifact recording, and all output writing. The training script
+    interacts with it through a thin interface (~7 lines).
+    """
+
+    def __init__(
+        self,
+        mode: str,
+        output_dir: Path,
+        device: torch.device,
+        *,
+        wait_steps: int = 1,
+        warmup_steps: int = 2,
+        active_steps: int = 6,
+        repeat: int = 1,
+        record_shapes: bool = True,
+        with_memory: bool = True,
+        with_flops: bool = True,
+        with_stack: bool = False,
+    ) -> None:
+        self._mode = mode
+        self._output_dir = ensure_dir(output_dir)
+        self._device = device
+        self._timing = _StepTimingCollector()
+        self._torch_profiler = _make_torch_profiler(
+            mode=mode,
+            output_dir=output_dir,
+            device_type=device.type,
+            wait_steps=wait_steps,
+            warmup_steps=warmup_steps,
+            active_steps=active_steps,
+            repeat=repeat,
+            record_shapes=record_shapes,
+            with_memory=with_memory,
+            with_flops=with_flops,
+            with_stack=with_stack,
+        )
+        logging.info("Profiling enabled. Artifacts will be written to %s", output_dir)
+
+    @classmethod
+    def from_cfg(cls, cfg: Any, device: torch.device) -> TrainingProfiler:
+        output_dir = cfg.profile_output_dir
+        if output_dir is None:
+            output_dir = Path(cfg.output_dir) / "profiling"
+        return cls(mode=cfg.profile_mode, output_dir=Path(output_dir), device=device)
+
+    def record_deterministic_forward(
+        self,
+        *,
+        policy: Any,
+        dataset: Any,
+        batch_size: int,
+        preprocessor: Any,
+    ) -> None:
+        logging.info("Recording deterministic forward-pass artifacts")
+        write_deterministic_forward_artifacts(
+            policy=policy,
+            dataset=dataset,
+            batch_size=batch_size,
+            preprocessor=preprocessor,
+            output_dir=self._output_dir,
+            device_type=self._device.type,
+        )
+
+    def __enter__(self) -> TrainingProfiler:
+        if self._device.type == "cuda":
+            torch.cuda.reset_peak_memory_stats(self._device)
+        self._torch_profiler.__enter__()
+        return self
+
+    def __exit__(self, *exc: Any) -> bool:
+        return self._torch_profiler.__exit__(*exc)
+
+    def step(self, step_num: int, train_tracker: Any) -> None:
+        self._timing.record_step(_as_float(train_tracker.update_s))
+        self._timing.record_dataloading(_as_float(train_tracker.dataloading_s))
+        if self._device.type == "cuda":
+            self._timing.record_memory(
+                step=step_num,
+                allocated_bytes=torch.cuda.memory_allocated(self._device),
+                reserved_bytes=torch.cuda.memory_reserved(self._device),
+            )
+        self._torch_profiler.step()
+
+    def finalize(self) -> None:
+        extra: dict[str, Any] = {"profile_mode": self._mode}
+        if self._device.type == "cuda":
+            extra["peak_memory_allocated_bytes"] = torch.cuda.max_memory_allocated(self._device)
+            extra["peak_memory_reserved_bytes"] = torch.cuda.max_memory_reserved(self._device)
+        self._timing.write_json(self._output_dir / "step_timing_summary.json", extra=extra)
+        write_torch_profiler_outputs(self._torch_profiler, self._output_dir, device_type=self._device.type)
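After finalize(), the output directory holds the timing summary, plus Chrome traces under torch_traces/ when mode is "trace". A sketch of consuming the summary; the path is hypothetical, and the key set follows from _StepTimingCollector.to_dict() plus the extras added in finalize():

    import json
    from pathlib import Path

    summary = json.loads(Path("outputs/profiling/step_timing_summary.json").read_text())
    print(summary["total_update_s"]["mean"])           # per-step update-time statistics
    print(summary.get("peak_memory_allocated_bytes"))  # present on CUDA runs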
@@ -65,29 +65,24 @@ def test_pretrained_libero_specs_match_expected_camera_keys_and_normalization():
     specs = module.load_specs(spec_path)

     assert (
-        "--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", "
-        "\"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
-        in specs["pi0"]["train_args"]
+        '--rename_map={"observation.images.front": "observation.images.base_0_rgb", '
+        '"observation.images.wrist": "observation.images.left_wrist_0_rgb"}' in specs["pi0"].train_args
     )
     assert (
-        "--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", "
-        "\"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
-        in specs["pi0_fast"]["train_args"]
+        '--rename_map={"observation.images.front": "observation.images.base_0_rgb", '
+        '"observation.images.wrist": "observation.images.left_wrist_0_rgb"}' in specs["pi0_fast"].train_args
     )
     assert (
-        "--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", "
-        "\"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
-        in specs["pi05"]["train_args"]
+        '--rename_map={"observation.images.front": "observation.images.base_0_rgb", '
+        '"observation.images.wrist": "observation.images.left_wrist_0_rgb"}' in specs["pi05"].train_args
     )
     assert (
-        "--policy.normalization_mapping={\"ACTION\": \"MEAN_STD\", "
-        "\"STATE\": \"MEAN_STD\", \"VISUAL\": \"IDENTITY\"}"
-        in specs["pi05"]["train_args"]
+        '--policy.normalization_mapping={"ACTION": "MEAN_STD", '
+        '"STATE": "MEAN_STD", "VISUAL": "IDENTITY"}' in specs["pi05"].train_args
     )
     assert (
-        "--rename_map={\"observation.images.front\": \"observation.images.camera1\", "
-        "\"observation.images.wrist\": \"observation.images.camera2\"}"
-        in specs["smolvla"]["train_args"]
+        '--rename_map={"observation.images.front": "observation.images.camera1", '
+        '"observation.images.wrist": "observation.images.camera2"}' in specs["smolvla"].train_args
     )


@@ -222,7 +217,6 @@ def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
     (profile_dir / "step_timing_summary.json").write_text(
         json.dumps(
             {
-                "forward_s": {"count": 1, "mean": 0.1, "median": 0.1, "min": 0.1, "max": 0.1},
                 "total_update_s": {"count": 1, "mean": 0.3, "median": 0.3, "min": 0.3, "max": 0.3},
                 "peak_memory_allocated_bytes": 1024,
             }
@@ -251,7 +245,7 @@ def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
     assert row["git_commit"] == "deadbeef"
     assert row["git_ref"] == "codex/model-profiling"
     assert row["pr_number"] == 3389
-    assert row["step_timing_summary"]["forward_s"]["mean"] == 0.1
+    assert row["step_timing_summary"]["total_update_s"]["mean"] == 0.3
     assert row["deterministic_forward"]["operator_fingerprint"] == "ops-fingerprint"


@@ -364,19 +358,14 @@ def test_deterministic_forward_artifacts_preserve_policy_mode(tmp_path):


 def test_step_timing_collector_accepts_metric_like_values(tmp_path):
-    from lerobot.utils.profiling_utils import StepTimingCollector
+    from lerobot.utils.profiling_utils import _StepTimingCollector

     class _MetricLike:
         def __init__(self, val):
             self.val = val

-    collector = StepTimingCollector()
-    collector.record(
-        forward_s=0.1,
-        backward_s=0.2,
-        optimizer_s=0.3,
-        total_update_s=_MetricLike(0.6),
-    )
+    collector = _StepTimingCollector()
+    collector.record_step(_MetricLike(0.6))
     collector.record_dataloading(_MetricLike(0.05))
     collector.write_json(tmp_path / "step_timing_summary.json")