fix: single level loop
@@ -8,18 +8,46 @@ accurate benchmarking without requiring datasets.
 import argparse
 import os
+import signal
 import statistics
+from contextlib import contextmanager
 from datetime import datetime
 from pathlib import Path
 
 import psutil
 import torch
+from tqdm import tqdm
 
 from lerobot.configs.types import FeatureType
 from lerobot.policies.factory import get_policy_class
 from lerobot.policies.pretrained import PreTrainedPolicy
 
 
+class TimeoutException(Exception):
+    pass
+
+
+@contextmanager
+def timeout(seconds):
+    def signal_handler(signum, frame):
+        raise TimeoutException(f"Timed out after {seconds} seconds")
+
+    # On Windows, SIGALRM is not available, so we can't use this timeout mechanism
+    if not hasattr(signal, "SIGALRM"):
+        yield
+        return
+
+    old_handler = signal.signal(signal.SIGALRM, signal_handler)
+    try:
+        # signal.alarm() only takes integer seconds;
+        # signal.setitimer() accepts float seconds
+        signal.setitimer(signal.ITIMER_REAL, seconds)
+        yield
+    finally:
+        signal.setitimer(signal.ITIMER_REAL, 0)
+        signal.signal(signal.SIGALRM, old_handler)
+
+
 def bytes_to_human(n: int) -> str:
     for unit in ["B", "KB", "MB", "GB", "TB"]:
         if n < 1024:
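A note on the new helper: `timeout()` arms a POSIX interval timer and converts the resulting SIGALRM into a `TimeoutException`. A minimal usage sketch, assuming the `timeout` and `TimeoutException` names from the hunk above; `slow_call` is a hypothetical stand-in for a slow inference pass, and this only works where SIGALRM exists (i.e. not on Windows):

```python
import time

def slow_call():
    time.sleep(1.0)  # hypothetical stand-in for a slow inference pass

try:
    with timeout(0.3):  # 300 ms budget; SIGALRM interrupts slow_call mid-sleep
        slow_call()
except TimeoutException:
    print("call exceeded the 300 ms budget")
```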
@@ -78,12 +106,19 @@ def main():
         "--device", type=str, default="mps", choices=["cuda", "cpu", "mps"], help="Device to run on"
     )
     parser.add_argument("--seed", type=int, default=42, help="Random seed")
-    parser.add_argument("--num-trials", type=int, default=10, help="Number of timing trials")
-    parser.add_argument("--forwards-per-trial", type=int, default=10, help="Number of forwards per trial")
-    parser.add_argument("--warmup", type=int, default=2, help="Warmup forwards (not timed)")
+    parser.add_argument(
+        "--num-samples", type=int, default=100, help="Number of inference samples to benchmark"
+    )
+    parser.add_argument("--warmup", type=int, default=10, help="Number of warmup samples (not timed)")
     parser.add_argument(
         "--output-dir", type=str, default="outputs/benchmarks", help="Directory to save benchmark results"
     )
+    parser.add_argument(
+        "--timeout",
+        type=float,
+        default=0.3,
+        help="Timeout for each inference pass in seconds (default: 0.3s = 300ms)",
+    )
     args = parser.parse_args()
 
     # Seed & deterministic-ish setup
@@ -124,6 +159,7 @@ def main():
 
     # Generate dummy observation based on policy input features
     dummy_observation = generate_dummy_observation(policy.config.input_features, device)
+    dummy_observation["task"] = ""
 
     # Helper to sync for fair timings
     def _sync(dev_=device):
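The hunk cuts off right at the `_sync` definition, so its body is not shown here. Given the script's cuda/cpu/mps device choices, a plausible implementation would look like the sketch below; this is an assumption, not the commit's actual body:

```python
def _sync(dev_=device):
    # Block until queued device work finishes so timers measure real latency.
    if dev_ == "cuda":
        torch.cuda.synchronize()
    elif dev_ == "mps":
        torch.mps.synchronize()
    # cpu: nothing to synchronize
```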
@@ -138,8 +174,9 @@ def main():
     # Warmup (to stabilize kernels/caches)
     print("Warming up...")
     with torch.no_grad():
+        policy.reset()
         for _ in range(args.warmup):
-            _ = policy.predict_action_chunk(dummy_observation)
+            _ = policy.select_action(dummy_observation)
     _sync()
 
     # Memory footprint before timing
@@ -149,47 +186,70 @@ def main():
         torch.cuda.reset_peak_memory_stats()
 
     # PyTorch timing with Event objects for more accurate GPU timing
-    print(f"Running benchmark: {args.num_trials} trials x {args.forwards_per_trial} forwards...")
+    print(f"Running benchmark: {args.num_samples} samples...")
 
     if use_cuda:
         # Use CUDA Events for precise GPU timing
         start_events = []
        end_events = []
+        timeout_count = 0
 
         with torch.no_grad():
-            for _ in range(args.num_trials):
-                for _ in range(args.forwards_per_trial):
-                    start_event = torch.cuda.Event(enable_timing=True)
-                    end_event = torch.cuda.Event(enable_timing=True)
-
-                    start_event.record()
-                    _ = policy.predict_action_chunk(dummy_observation)
-                    end_event.record()
+            for forward in tqdm(range(args.num_samples), desc="Trials"):
+                start_event = torch.cuda.Event(enable_timing=True)
+                end_event = torch.cuda.Event(enable_timing=True)
+                try:
+                    with timeout(args.timeout):
+                        start_event.record()
+                        _ = policy.select_action(dummy_observation)
+                        end_event.record()
 
                     start_events.append(start_event)
                     end_events.append(end_event)
+                except TimeoutException:
+                    timeout_count += 1
+                    # Add placeholder for timeout
+                    start_events.append(None)
+                    end_events.append(None)
+                    print(f"\n[!] Timeout on forward {forward + 1}")
+                    continue
 
         # Synchronize and collect timing results
         torch.cuda.synchronize()
         per_forward_ms = []
         for start_event, end_event in zip(start_events, end_events, strict=True):
-            per_forward_ms.append(start_event.elapsed_time(end_event))
+            if start_event is None:
+                per_forward_ms.append(args.timeout * 1000)
+            else:
+                per_forward_ms.append(start_event.elapsed_time(end_event))
+
+        if timeout_count > 0:
+            print(f"[!] {timeout_count} inference passes timed out (>{args.timeout * 1000:.1f}ms)")
 
     else:
-        # Use torch.utils.benchmark for CPU/MPS timing
-        from torch.utils.benchmark import Timer
-
-        def run_inference():
-            return policy.predict_action_chunk(dummy_observation)
+        # Use simple time.perf_counter for CPU/MPS timing with timeout
+        import time
 
         # Collect individual timing measurements
         per_forward_ms = []
+        timeout_count = 0
 
         with torch.no_grad():
-            for _ in range(args.num_trials):
-                for _ in range(args.forwards_per_trial):
-                    timer = Timer(stmt="run_inference()", globals={"run_inference": run_inference})
-                    measurement = timer.timeit(1)  # Single measurement
-                    per_forward_ms.append(measurement.mean * 1000)  # Convert to ms
+            for sample in tqdm(range(args.num_samples), desc="Samples"):
+                try:
+                    with timeout(args.timeout):
+                        start_time = time.perf_counter()
+                        _ = policy.select_action(dummy_observation)
+                        end_time = time.perf_counter()
+
+                    per_forward_ms.append((end_time - start_time) * 1000)  # Convert to ms
+                except TimeoutException:
+                    timeout_count += 1
+                    per_forward_ms.append(args.timeout * 1000)
+                    print(f"\n[!] Timeout on sample {sample + 1}")
+                    continue
+
+        if timeout_count > 0:
+            print(f"[!] {timeout_count} inference passes timed out (>{args.timeout * 1000:.1f}ms)")
 
     # Memory footprint after timing
     rss_after = process.memory_info().rss
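For reference, the CUDA branch above uses the standard event-timing pattern: record paired events around the measured call, then read `elapsed_time()` only after a synchronize. A minimal sketch of that pattern in isolation, with a matrix multiply as a hypothetical stand-in workload (assumes a CUDA device is available):

```python
import torch

x = torch.randn(1024, 1024, device="cuda")
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
_ = x @ x  # stand-in for policy.select_action(dummy_observation)
end.record()

torch.cuda.synchronize()  # elapsed_time() is only valid once both events have completed
print(f"{start.elapsed_time(end):.3f} ms")
```

Events measure time between the two markers on the GPU stream itself, which host-side clocks cannot do; that is presumably why only the CPU/MPS branch falls back to `time.perf_counter()`.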
@@ -215,11 +275,13 @@ def main():
         "policy_type": args.policy_type,
         "policy_id": args.policy_id,
         "device": device,
-        "num_trials": args.num_trials,
-        "forwards_per_trial": args.forwards_per_trial,
+        "num_trials": args.num_samples,
+        "forwards_per_trial": 1,
         "warmup": args.warmup,
+        "timeout_ms": args.timeout * 1000,
         "seed": args.seed,
         "num_params": num_params,
+        "timeout_count": timeout_count,
         "latency_mean_ms": mean_ms,
         "latency_std_ms": std_ms,
         "latency_min_ms": min_ms,
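The hunk elides where `mean_ms`, `std_ms`, and `min_ms` are computed; given the `statistics` import at the top of the file, a plausible (assumed, not shown in the commit) computation is:

```python
import statistics

per_forward_ms = [12.1, 11.8, 300.0, 12.3]  # illustrative measurements in ms
mean_ms = statistics.mean(per_forward_ms)
std_ms = statistics.stdev(per_forward_ms) if len(per_forward_ms) > 1 else 0.0
min_ms = min(per_forward_ms)
```

Note that timed-out passes enter this list clamped to `args.timeout * 1000`, so the reported mean is a lower bound when timeouts occur.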
@@ -248,10 +310,11 @@ Input Features: {", ".join(results["input_features"])}
 Output Features: {", ".join(results["output_features"])}
 
 === Benchmark Configuration ===
-Trials: {results["num_trials"]}
-Forwards per Trial: {results["forwards_per_trial"]}
+Samples: {results["num_trials"]}
 Warmup: {results["warmup"]}
 Total Measurements: {len(per_forward_ms)}
+Timeout: {results["timeout_ms"]:.1f}ms
+Timeouts: {results["timeout_count"]} / {results["num_trials"]}
 
 === Latency Results (ms) ===
 Mean: {results["latency_mean_ms"]:.3f}
@@ -290,7 +353,7 @@ Benchmark completed successfully at {datetime.now().strftime("%Y-%m-%d %H:%M:%S"
     print("\n=== Inference Benchmark Results ===")
     print(f"Policy: {args.policy_type} ({args.policy_id})")
     print(f"Device: {device}")
-    print(f"Trials: {args.num_trials} | Forwards/Trial: {args.forwards_per_trial} | Warmup: {args.warmup}")
+    print(f"Samples: {args.num_samples} | Warmup: {args.warmup}")
     print(f"Model params: {num_params:,}")
 
     print("\nLatency per forward (ms):")