fix: single level loop

This commit is contained in:
Francesco Capuano
2025-09-24 01:06:13 +02:00
parent cdd6cb606c
commit 6eaf6a861a
+93 -30
View File
@@ -8,18 +8,46 @@ accurate benchmarking without requiring datasets.
import argparse import argparse
import os import os
import signal
import statistics import statistics
from contextlib import contextmanager
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
import psutil import psutil
import torch import torch
from tqdm import tqdm
from lerobot.configs.types import FeatureType from lerobot.configs.types import FeatureType
from lerobot.policies.factory import get_policy_class from lerobot.policies.factory import get_policy_class
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
class TimeoutException(Exception):
    """Signal that a timed section exceeded its time budget.

    Raised by the SIGALRM handler installed by the ``timeout`` context
    manager and caught by the benchmark loops to count timed-out passes.

    It must derive from ``Exception``: Python refuses to ``raise`` or
    ``except`` a class that does not inherit from ``BaseException``, and a
    plain class would also reject the message argument passed by the handler.
    """
@contextmanager
def timeout(seconds):
    """Abort the enclosed block with ``TimeoutException`` after ``seconds``.

    The limit is enforced with ``SIGALRM`` via ``signal.setitimer``, so
    fractional seconds are supported. On platforms without ``SIGALRM``
    (e.g. Windows) the block runs with no time limit at all.

    Args:
        seconds: Time budget in seconds (int or float). Passing 0 clears the
            interval timer, which means no timeout is enforced.

    Raises:
        TimeoutException: Raised from the signal handler, inside the enclosed
            block, once the timer expires.

    Note:
        ``signal.signal`` may only be called from the main thread, so this
        context manager is meant to be used from the main thread.
    """

    def signal_handler(signum, frame):
        raise TimeoutException(f"Timed out after {seconds} seconds")

    # On Windows, SIGALRM is not available, so we can't use this timeout mechanism.
    if not hasattr(signal, "SIGALRM"):
        yield
        return

    old_handler = signal.signal(signal.SIGALRM, signal_handler)
    try:
        try:
            # signal.alarm only accepts whole seconds; setitimer accepts floats.
            signal.setitimer(signal.ITIMER_REAL, seconds)
            yield
        finally:
            # Disarm the timer first so it cannot fire after the block exits.
            signal.setitimer(signal.ITIMER_REAL, 0)
    finally:
        # Restore the previous handler even if the alarm fired while the
        # timer was being disarmed (the handler would raise right there).
        signal.signal(signal.SIGALRM, old_handler)
def bytes_to_human(n: int) -> str: def bytes_to_human(n: int) -> str:
for unit in ["B", "KB", "MB", "GB", "TB"]: for unit in ["B", "KB", "MB", "GB", "TB"]:
if n < 1024: if n < 1024:
@@ -78,12 +106,19 @@ def main():
"--device", type=str, default="mps", choices=["cuda", "cpu", "mps"], help="Device to run on" "--device", type=str, default="mps", choices=["cuda", "cpu", "mps"], help="Device to run on"
) )
parser.add_argument("--seed", type=int, default=42, help="Random seed") parser.add_argument("--seed", type=int, default=42, help="Random seed")
parser.add_argument("--num-trials", type=int, default=10, help="Number of timing trials") parser.add_argument(
parser.add_argument("--forwards-per-trial", type=int, default=10, help="Number of forwards per trial") "--num-samples", type=int, default=100, help="Number of inference samples to benchmark"
parser.add_argument("--warmup", type=int, default=2, help="Warmup forwards (not timed)") )
parser.add_argument("--warmup", type=int, default=10, help="Number of warmup samples (not timed)")
parser.add_argument( parser.add_argument(
"--output-dir", type=str, default="outputs/benchmarks", help="Directory to save benchmark results" "--output-dir", type=str, default="outputs/benchmarks", help="Directory to save benchmark results"
) )
parser.add_argument(
"--timeout",
type=float,
default=0.3,
help="Timeout for each inference pass in seconds (default: 0.3s = 300ms)",
)
args = parser.parse_args() args = parser.parse_args()
# Seed & deterministic-ish setup # Seed & deterministic-ish setup
@@ -124,6 +159,7 @@ def main():
# Generate dummy observation based on policy input features # Generate dummy observation based on policy input features
dummy_observation = generate_dummy_observation(policy.config.input_features, device) dummy_observation = generate_dummy_observation(policy.config.input_features, device)
dummy_observation["task"] = ""
# Helper to sync for fair timings # Helper to sync for fair timings
def _sync(dev_=device): def _sync(dev_=device):
@@ -138,8 +174,9 @@ def main():
# Warmup (to stabilize kernels/caches) # Warmup (to stabilize kernels/caches)
print("Warming up...") print("Warming up...")
with torch.no_grad(): with torch.no_grad():
policy.reset()
for _ in range(args.warmup): for _ in range(args.warmup):
_ = policy.predict_action_chunk(dummy_observation) _ = policy.select_action(dummy_observation)
_sync() _sync()
# Memory footprint before timing # Memory footprint before timing
@@ -149,47 +186,70 @@ def main():
torch.cuda.reset_peak_memory_stats() torch.cuda.reset_peak_memory_stats()
# PyTorch timing with Event objects for more accurate GPU timing # PyTorch timing with Event objects for more accurate GPU timing
print(f"Running benchmark: {args.num_trials} trials x {args.forwards_per_trial} forwards...") print(f"Running benchmark: {args.num_samples} samples...")
if use_cuda: if use_cuda:
# Use CUDA Events for precise GPU timing # Use CUDA Events for precise GPU timing
start_events = [] start_events = []
end_events = [] end_events = []
timeout_count = 0
with torch.no_grad(): with torch.no_grad():
for _ in range(args.num_trials): for forward in tqdm(range(args.num_samples), desc="Trials"):
for _ in range(args.forwards_per_trial): start_event = torch.cuda.Event(enable_timing=True)
start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True) try:
with timeout(args.timeout):
start_event.record() start_event.record()
_ = policy.predict_action_chunk(dummy_observation) _ = policy.select_action(dummy_observation)
end_event.record() end_event.record()
start_events.append(start_event) start_events.append(start_event)
end_events.append(end_event) end_events.append(end_event)
except TimeoutException:
timeout_count += 1
# Add placeholder for timeout
start_events.append(None)
end_events.append(None)
print(f"\n[!] Timeout on forward {forward + 1}")
continue
# Synchronize and collect timing results # Synchronize and collect timing results
torch.cuda.synchronize() torch.cuda.synchronize()
per_forward_ms = [] per_forward_ms = []
for start_event, end_event in zip(start_events, end_events, strict=True): for start_event, end_event in zip(start_events, end_events, strict=True):
per_forward_ms.append(start_event.elapsed_time(end_event)) if start_event is None:
per_forward_ms.append(args.timeout * 1000)
else:
per_forward_ms.append(start_event.elapsed_time(end_event))
if timeout_count > 0:
print(f"[!] {timeout_count} inference passes timed out (>{args.timeout * 1000:.1f}ms)")
else: else:
# Use torch.utils.benchmark for CPU/MPS timing # Use simple time.perf_counter for CPU/MPS timing with timeout
from torch.utils.benchmark import Timer import time
def run_inference():
return policy.predict_action_chunk(dummy_observation)
# Collect individual timing measurements
per_forward_ms = [] per_forward_ms = []
timeout_count = 0
with torch.no_grad(): with torch.no_grad():
for _ in range(args.num_trials): for sample in tqdm(range(args.num_samples), desc="Samples"):
for _ in range(args.forwards_per_trial): try:
timer = Timer(stmt="run_inference()", globals={"run_inference": run_inference}) with timeout(args.timeout):
measurement = timer.timeit(1) # Single measurement start_time = time.perf_counter()
per_forward_ms.append(measurement.mean * 1000) # Convert to ms _ = policy.select_action(dummy_observation)
end_time = time.perf_counter()
per_forward_ms.append((end_time - start_time) * 1000) # Convert to ms
except TimeoutException:
timeout_count += 1
per_forward_ms.append(args.timeout * 1000)
print(f"\n[!] Timeout on sample {sample + 1}")
continue
if timeout_count > 0:
print(f"[!] {timeout_count} inference passes timed out (>{args.timeout * 1000:.1f}ms)")
# Memory footprint after timing # Memory footprint after timing
rss_after = process.memory_info().rss rss_after = process.memory_info().rss
@@ -215,11 +275,13 @@ def main():
"policy_type": args.policy_type, "policy_type": args.policy_type,
"policy_id": args.policy_id, "policy_id": args.policy_id,
"device": device, "device": device,
"num_trials": args.num_trials, "num_trials": args.num_samples,
"forwards_per_trial": args.forwards_per_trial, "forwards_per_trial": 1,
"warmup": args.warmup, "warmup": args.warmup,
"timeout_ms": args.timeout * 1000,
"seed": args.seed, "seed": args.seed,
"num_params": num_params, "num_params": num_params,
"timeout_count": timeout_count,
"latency_mean_ms": mean_ms, "latency_mean_ms": mean_ms,
"latency_std_ms": std_ms, "latency_std_ms": std_ms,
"latency_min_ms": min_ms, "latency_min_ms": min_ms,
@@ -248,10 +310,11 @@ Input Features: {", ".join(results["input_features"])}
Output Features: {", ".join(results["output_features"])} Output Features: {", ".join(results["output_features"])}
=== Benchmark Configuration === === Benchmark Configuration ===
Trials: {results["num_trials"]} Samples: {results["num_trials"]}
Forwards per Trial: {results["forwards_per_trial"]}
Warmup: {results["warmup"]} Warmup: {results["warmup"]}
Total Measurements: {len(per_forward_ms)} Total Measurements: {len(per_forward_ms)}
Timeout: {results["timeout_ms"]:.1f}ms
Timeouts: {results["timeout_count"]} / {results["num_trials"]}
=== Latency Results (ms) === === Latency Results (ms) ===
Mean: {results["latency_mean_ms"]:.3f} Mean: {results["latency_mean_ms"]:.3f}
@@ -290,7 +353,7 @@ Benchmark completed successfully at {datetime.now().strftime("%Y-%m-%d %H:%M:%S"
print("\n=== Inference Benchmark Results ===") print("\n=== Inference Benchmark Results ===")
print(f"Policy: {args.policy_type} ({args.policy_id})") print(f"Policy: {args.policy_type} ({args.policy_id})")
print(f"Device: {device}") print(f"Device: {device}")
print(f"Trials: {args.num_trials} | Forwards/Trial: {args.forwards_per_trial} | Warmup: {args.warmup}") print(f"Samples: {args.num_samples} | Warmup: {args.warmup}")
print(f"Model params: {num_params:,}") print(f"Model params: {num_params:,}")
print("\nLatency per forward (ms):") print("\nLatency per forward (ms):")