diff --git a/benchmarks/policies/inference.py b/benchmarks/policies/inference.py
index 7fccd5362..65a7974b3 100644
--- a/benchmarks/policies/inference.py
+++ b/benchmarks/policies/inference.py
@@ -8,18 +8,46 @@ accurate benchmarking without requiring datasets.
 
 import argparse
 import os
+import signal
 import statistics
+from contextlib import contextmanager
 from datetime import datetime
 from pathlib import Path
 
 import psutil
 import torch
+from tqdm import tqdm
 
 from lerobot.configs.types import FeatureType
 from lerobot.policies.factory import get_policy_class
 from lerobot.policies.pretrained import PreTrainedPolicy
 
 
+class TimeoutException(Exception):
+    pass
+
+
+@contextmanager
+def timeout(seconds):
+    def signal_handler(signum, frame):
+        raise TimeoutException(f"Timed out after {seconds} seconds")
+
+    # On Windows, signal is not available, so we can't use this timeout mechanism
+    if not hasattr(signal, "SIGALRM"):
+        yield
+        return
+
+    old_handler = signal.signal(signal.SIGALRM, signal_handler)
+    try:
+        # signal.alarm expects integer seconds
+        # for float seconds, we can use setitimer
+        signal.setitimer(signal.ITIMER_REAL, seconds)
+        yield
+    finally:
+        signal.setitimer(signal.ITIMER_REAL, 0)
+        signal.signal(signal.SIGALRM, old_handler)
+
+
 def bytes_to_human(n: int) -> str:
     for unit in ["B", "KB", "MB", "GB", "TB"]:
         if n < 1024:
@@ -78,12 +106,19 @@ def main():
         "--device", type=str, default="mps", choices=["cuda", "cpu", "mps"], help="Device to run on"
     )
     parser.add_argument("--seed", type=int, default=42, help="Random seed")
-    parser.add_argument("--num-trials", type=int, default=10, help="Number of timing trials")
-    parser.add_argument("--forwards-per-trial", type=int, default=10, help="Number of forwards per trial")
-    parser.add_argument("--warmup", type=int, default=2, help="Warmup forwards (not timed)")
+    parser.add_argument(
+        "--num-samples", type=int, default=100, help="Number of inference samples to benchmark"
+    )
+    parser.add_argument("--warmup", type=int, default=10, help="Number of warmup samples (not timed)")
     parser.add_argument(
         "--output-dir", type=str, default="outputs/benchmarks", help="Directory to save benchmark results"
     )
+    parser.add_argument(
+        "--timeout",
+        type=float,
+        default=0.3,
+        help="Timeout for each inference pass in seconds (default: 0.3s = 300ms)",
+    )
     args = parser.parse_args()
 
     # Seed & deterministic-ish setup
@@ -124,6 +159,7 @@ def main():
 
     # Generate dummy observation based on policy input features
    dummy_observation = generate_dummy_observation(policy.config.input_features, device)
+    dummy_observation["task"] = ""
 
     # Helper to sync for fair timings
     def _sync(dev_=device):
@@ -138,8 +174,9 @@ def main():
     # Warmup (to stabilize kernels/caches)
     print("Warming up...")
     with torch.no_grad():
+        policy.reset()
         for _ in range(args.warmup):
-            _ = policy.predict_action_chunk(dummy_observation)
+            _ = policy.select_action(dummy_observation)
     _sync()
 
     # Memory footprint before timing
@@ -149,47 +186,70 @@ def main():
         torch.cuda.reset_peak_memory_stats()
 
     # PyTorch timing with Event objects for more accurate GPU timing
-    print(f"Running benchmark: {args.num_trials} trials x {args.forwards_per_trial} forwards...")
+    print(f"Running benchmark: {args.num_samples} samples...")
 
     if use_cuda:
         # Use CUDA Events for precise GPU timing
         start_events = []
         end_events = []
+        timeout_count = 0
 
         with torch.no_grad():
-            for _ in range(args.num_trials):
-                for _ in range(args.forwards_per_trial):
-                    start_event = torch.cuda.Event(enable_timing=True)
-                    end_event = torch.cuda.Event(enable_timing=True)
-
-                    start_event.record()
-                    _ = policy.predict_action_chunk(dummy_observation)
-                    end_event.record()
+            for forward in tqdm(range(args.num_samples), desc="Trials"):
+                start_event = torch.cuda.Event(enable_timing=True)
+                end_event = torch.cuda.Event(enable_timing=True)
+                try:
+                    with timeout(args.timeout):
+                        start_event.record()
+                        _ = policy.select_action(dummy_observation)
+                        end_event.record()
 
                     start_events.append(start_event)
                     end_events.append(end_event)
+                except TimeoutException:
+                    timeout_count += 1
+                    # Add placeholder for timeout
+                    start_events.append(None)
+                    end_events.append(None)
+                    print(f"\n[!] Timeout on forward {forward + 1}")
+                    continue
 
         # Synchronize and collect timing results
         torch.cuda.synchronize()
 
         per_forward_ms = []
         for start_event, end_event in zip(start_events, end_events, strict=True):
-            per_forward_ms.append(start_event.elapsed_time(end_event))
+            if start_event is None:
+                per_forward_ms.append(args.timeout * 1000)
+            else:
+                per_forward_ms.append(start_event.elapsed_time(end_event))
+
+        if timeout_count > 0:
+            print(f"[!] {timeout_count} inference passes timed out (>{args.timeout * 1000:.1f}ms)")
     else:
-        # Use torch.utils.benchmark for CPU/MPS timing
-        from torch.utils.benchmark import Timer
+        # Use simple time.perf_counter for CPU/MPS timing with timeout
+        import time
 
-        def run_inference():
-            return policy.predict_action_chunk(dummy_observation)
-
-        # Collect individual timing measurements
         per_forward_ms = []
+        timeout_count = 0
+
         with torch.no_grad():
-            for _ in range(args.num_trials):
-                for _ in range(args.forwards_per_trial):
-                    timer = Timer(stmt="run_inference()", globals={"run_inference": run_inference})
-                    measurement = timer.timeit(1)  # Single measurement
-                    per_forward_ms.append(measurement.mean * 1000)  # Convert to ms
+            for sample in tqdm(range(args.num_samples), desc="Samples"):
+                try:
+                    with timeout(args.timeout):
+                        start_time = time.perf_counter()
+                        _ = policy.select_action(dummy_observation)
+                        end_time = time.perf_counter()
+
+                    per_forward_ms.append((end_time - start_time) * 1000)  # Convert to ms
+                except TimeoutException:
+                    timeout_count += 1
+                    per_forward_ms.append(args.timeout * 1000)
+                    print(f"\n[!] Timeout on sample {sample + 1}")
+                    continue
+
+        if timeout_count > 0:
+            print(f"[!] {timeout_count} inference passes timed out (>{args.timeout * 1000:.1f}ms)")
 
     # Memory footprint after timing
     rss_after = process.memory_info().rss
@@ -215,11 +275,13 @@ def main():
         "policy_type": args.policy_type,
         "policy_id": args.policy_id,
         "device": device,
-        "num_trials": args.num_trials,
-        "forwards_per_trial": args.forwards_per_trial,
+        "num_trials": args.num_samples,
+        "forwards_per_trial": 1,
         "warmup": args.warmup,
+        "timeout_ms": args.timeout * 1000,
         "seed": args.seed,
         "num_params": num_params,
+        "timeout_count": timeout_count,
         "latency_mean_ms": mean_ms,
         "latency_std_ms": std_ms,
         "latency_min_ms": min_ms,
@@ -248,10 +310,11 @@ Input Features: {", ".join(results["input_features"])}
 Output Features: {", ".join(results["output_features"])}
 
 === Benchmark Configuration ===
-Trials: {results["num_trials"]}
-Forwards per Trial: {results["forwards_per_trial"]}
+Samples: {results["num_trials"]}
 Warmup: {results["warmup"]}
 Total Measurements: {len(per_forward_ms)}
+Timeout: {results["timeout_ms"]:.1f}ms
+Timeouts: {results["timeout_count"]} / {results["num_trials"]}
 
 === Latency Results (ms) ===
 Mean: {results["latency_mean_ms"]:.3f}
@@ -290,7 +353,7 @@ Benchmark completed successfully at {datetime.now().strftime("%Y-%m-%d %H:%M:%S"
     print("\n=== Inference Benchmark Results ===")
     print(f"Policy: {args.policy_type} ({args.policy_id})")
    print(f"Device: {device}")
-    print(f"Trials: {args.num_trials} | Forwards/Trial: {args.forwards_per_trial} | Warmup: {args.warmup}")
+    print(f"Samples: {args.num_samples} | Warmup: {args.warmup}")
     print(f"Model params: {num_params:,}")
     print("\nLatency per forward (ms):")
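
Note: for anyone who wants to poke at the new per-sample timeout guard outside the benchmark, the snippet below is a minimal standalone sketch of the same SIGALRM/setitimer pattern the patch adds (the filename and the sleep() stand-in for a slow forward pass are illustrative only; on platforms without SIGALRM, such as Windows, the guard is a no-op exactly as in the patch).

# timeout_sketch.py -- standalone sketch of the timeout guard added in this diff
import signal
import time
from contextlib import contextmanager


class TimeoutException(Exception):
    pass


@contextmanager
def timeout(seconds: float):
    # Fall back to a no-op guard on platforms without SIGALRM (e.g. Windows).
    if not hasattr(signal, "SIGALRM"):
        yield
        return

    def _handler(signum, frame):
        raise TimeoutException(f"Timed out after {seconds} seconds")

    old_handler = signal.signal(signal.SIGALRM, _handler)
    try:
        # setitimer (unlike signal.alarm) accepts fractional seconds.
        signal.setitimer(signal.ITIMER_REAL, seconds)
        yield
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)  # cancel the timer
        signal.signal(signal.SIGALRM, old_handler)


if __name__ == "__main__":
    try:
        with timeout(0.3):
            time.sleep(1.0)  # stand-in for a slow policy forward pass
    except TimeoutException as exc:
        print(exc)  # -> "Timed out after 0.3 seconds"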