diff --git a/benchmarks/policies/inference.py b/benchmarks/policies/inference.py index 6bd4922f5..f89a820d4 100644 --- a/benchmarks/policies/inference.py +++ b/benchmarks/policies/inference.py @@ -234,7 +234,8 @@ def main(): per_forward_ms = [] for start_event, end_event in zip(start_events, end_events, strict=True): if start_event is None: - per_forward_ms.append(args.timeout * 1000) + # per_forward_ms.append(args.timeout * 1000) + continue else: per_forward_ms.append(start_event.elapsed_time(end_event)) @@ -262,7 +263,8 @@ def main(): per_forward_ms.append((end_time - start_time) * 1000) # Convert to ms except TimeoutExceptionError: timeout_count += 1 - per_forward_ms.append(args.timeout * 1000) + # per_forward_ms.append(args.timeout * 1000) + print(f"\n[!] Timeout on sample {sample + 1}") continue @@ -374,6 +376,7 @@ Benchmark completed successfully at {datetime.now().strftime("%Y-%m-%d %H:%M:%S" print(f"Device: {device}") print(f"Samples: {args.num_samples} | Warmup: {args.warmup}") print(f"Model params: {num_params:,}") + print(f"Timeout percentage: {timeout_count / args.num_samples * 100:.1f}%") print("\nLatency per forward (ms):") print(f" mean: {mean_ms:.3f} std: {std_ms:.3f}")