fix: single level loop

This commit is contained in:
Francesco Capuano
2025-09-24 01:06:13 +02:00
parent cdd6cb606c
commit 6eaf6a861a
+93 -30
View File
@@ -8,18 +8,46 @@ accurate benchmarking without requiring datasets.
import argparse
import os
import signal
import statistics
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
import psutil
import torch
from tqdm import tqdm
from lerobot.configs.types import FeatureType
from lerobot.policies.factory import get_policy_class
from lerobot.policies.pretrained import PreTrainedPolicy
class TimeoutException:
pass
@contextmanager
def timeout(seconds):
def signal_handler(signum, frame):
raise TimeoutException(f"Timed out after {seconds} seconds")
# On Windows, signal is not available, so we can't use this timeout mechanism
if not hasattr(signal, "SIGALRM"):
yield
return
old_handler = signal.signal(signal.SIGALRM, signal_handler)
try:
# signal.alarm expects integer seconds
# for float seconds, we can use setitimer
signal.setitimer(signal.ITIMER_REAL, seconds)
yield
finally:
signal.setitimer(signal.ITIMER_REAL, 0)
signal.signal(signal.SIGALRM, old_handler)
def bytes_to_human(n: int) -> str:
for unit in ["B", "KB", "MB", "GB", "TB"]:
if n < 1024:
@@ -78,12 +106,19 @@ def main():
"--device", type=str, default="mps", choices=["cuda", "cpu", "mps"], help="Device to run on"
)
parser.add_argument("--seed", type=int, default=42, help="Random seed")
parser.add_argument("--num-trials", type=int, default=10, help="Number of timing trials")
parser.add_argument("--forwards-per-trial", type=int, default=10, help="Number of forwards per trial")
parser.add_argument("--warmup", type=int, default=2, help="Warmup forwards (not timed)")
parser.add_argument(
"--num-samples", type=int, default=100, help="Number of inference samples to benchmark"
)
parser.add_argument("--warmup", type=int, default=10, help="Number of warmup samples (not timed)")
parser.add_argument(
"--output-dir", type=str, default="outputs/benchmarks", help="Directory to save benchmark results"
)
parser.add_argument(
"--timeout",
type=float,
default=0.3,
help="Timeout for each inference pass in seconds (default: 0.3s = 300ms)",
)
args = parser.parse_args()
# Seed & deterministic-ish setup
@@ -124,6 +159,7 @@ def main():
# Generate dummy observation based on policy input features
dummy_observation = generate_dummy_observation(policy.config.input_features, device)
dummy_observation["task"] = ""
# Helper to sync for fair timings
def _sync(dev_=device):
@@ -138,8 +174,9 @@ def main():
# Warmup (to stabilize kernels/caches)
print("Warming up...")
with torch.no_grad():
policy.reset()
for _ in range(args.warmup):
_ = policy.predict_action_chunk(dummy_observation)
_ = policy.select_action(dummy_observation)
_sync()
# Memory footprint before timing
@@ -149,47 +186,70 @@ def main():
torch.cuda.reset_peak_memory_stats()
# PyTorch timing with Event objects for more accurate GPU timing
print(f"Running benchmark: {args.num_trials} trials x {args.forwards_per_trial} forwards...")
print(f"Running benchmark: {args.num_samples} samples...")
if use_cuda:
# Use CUDA Events for precise GPU timing
start_events = []
end_events = []
timeout_count = 0
with torch.no_grad():
for _ in range(args.num_trials):
for _ in range(args.forwards_per_trial):
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
_ = policy.predict_action_chunk(dummy_observation)
end_event.record()
for forward in tqdm(range(args.num_samples), desc="Trials"):
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
try:
with timeout(args.timeout):
start_event.record()
_ = policy.select_action(dummy_observation)
end_event.record()
start_events.append(start_event)
end_events.append(end_event)
except TimeoutException:
timeout_count += 1
# Add placeholder for timeout
start_events.append(None)
end_events.append(None)
print(f"\n[!] Timeout on forward {forward + 1}")
continue
# Synchronize and collect timing results
torch.cuda.synchronize()
per_forward_ms = []
for start_event, end_event in zip(start_events, end_events, strict=True):
per_forward_ms.append(start_event.elapsed_time(end_event))
if start_event is None:
per_forward_ms.append(args.timeout * 1000)
else:
per_forward_ms.append(start_event.elapsed_time(end_event))
if timeout_count > 0:
print(f"[!] {timeout_count} inference passes timed out (>{args.timeout * 1000:.1f}ms)")
else:
# Use torch.utils.benchmark for CPU/MPS timing
from torch.utils.benchmark import Timer
# Use simple time.perf_counter for CPU/MPS timing with timeout
import time
def run_inference():
return policy.predict_action_chunk(dummy_observation)
# Collect individual timing measurements
per_forward_ms = []
timeout_count = 0
with torch.no_grad():
for _ in range(args.num_trials):
for _ in range(args.forwards_per_trial):
timer = Timer(stmt="run_inference()", globals={"run_inference": run_inference})
measurement = timer.timeit(1) # Single measurement
per_forward_ms.append(measurement.mean * 1000) # Convert to ms
for sample in tqdm(range(args.num_samples), desc="Samples"):
try:
with timeout(args.timeout):
start_time = time.perf_counter()
_ = policy.select_action(dummy_observation)
end_time = time.perf_counter()
per_forward_ms.append((end_time - start_time) * 1000) # Convert to ms
except TimeoutException:
timeout_count += 1
per_forward_ms.append(args.timeout * 1000)
print(f"\n[!] Timeout on sample {sample + 1}")
continue
if timeout_count > 0:
print(f"[!] {timeout_count} inference passes timed out (>{args.timeout * 1000:.1f}ms)")
# Memory footprint after timing
rss_after = process.memory_info().rss
@@ -215,11 +275,13 @@ def main():
"policy_type": args.policy_type,
"policy_id": args.policy_id,
"device": device,
"num_trials": args.num_trials,
"forwards_per_trial": args.forwards_per_trial,
"num_trials": args.num_samples,
"forwards_per_trial": 1,
"warmup": args.warmup,
"timeout_ms": args.timeout * 1000,
"seed": args.seed,
"num_params": num_params,
"timeout_count": timeout_count,
"latency_mean_ms": mean_ms,
"latency_std_ms": std_ms,
"latency_min_ms": min_ms,
@@ -248,10 +310,11 @@ Input Features: {", ".join(results["input_features"])}
Output Features: {", ".join(results["output_features"])}
=== Benchmark Configuration ===
Trials: {results["num_trials"]}
Forwards per Trial: {results["forwards_per_trial"]}
Samples: {results["num_trials"]}
Warmup: {results["warmup"]}
Total Measurements: {len(per_forward_ms)}
Timeout: {results["timeout_ms"]:.1f}ms
Timeouts: {results["timeout_count"]} / {results["num_trials"]}
=== Latency Results (ms) ===
Mean: {results["latency_mean_ms"]:.3f}
@@ -290,7 +353,7 @@ Benchmark completed successfully at {datetime.now().strftime("%Y-%m-%d %H:%M:%S"
print("\n=== Inference Benchmark Results ===")
print(f"Policy: {args.policy_type} ({args.policy_id})")
print(f"Device: {device}")
print(f"Trials: {args.num_trials} | Forwards/Trial: {args.forwards_per_trial} | Warmup: {args.warmup}")
print(f"Samples: {args.num_samples} | Warmup: {args.warmup}")
print(f"Model params: {num_params:,}")
print("\nLatency per forward (ms):")