mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 04:30:10 +00:00
fix: single level loop
This commit is contained in:
@@ -8,18 +8,46 @@ accurate benchmarking without requiring datasets.
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
import signal
|
||||||
import statistics
|
import statistics
|
||||||
|
from contextlib import contextmanager
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType
|
from lerobot.configs.types import FeatureType
|
||||||
from lerobot.policies.factory import get_policy_class
|
from lerobot.policies.factory import get_policy_class
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
|
|
||||||
|
|
||||||
|
class TimeoutException:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def timeout(seconds):
|
||||||
|
def signal_handler(signum, frame):
|
||||||
|
raise TimeoutException(f"Timed out after {seconds} seconds")
|
||||||
|
|
||||||
|
# On Windows, signal is not available, so we can't use this timeout mechanism
|
||||||
|
if not hasattr(signal, "SIGALRM"):
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
old_handler = signal.signal(signal.SIGALRM, signal_handler)
|
||||||
|
try:
|
||||||
|
# signal.alarm expects integer seconds
|
||||||
|
# for float seconds, we can use setitimer
|
||||||
|
signal.setitimer(signal.ITIMER_REAL, seconds)
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
signal.setitimer(signal.ITIMER_REAL, 0)
|
||||||
|
signal.signal(signal.SIGALRM, old_handler)
|
||||||
|
|
||||||
|
|
||||||
def bytes_to_human(n: int) -> str:
|
def bytes_to_human(n: int) -> str:
|
||||||
for unit in ["B", "KB", "MB", "GB", "TB"]:
|
for unit in ["B", "KB", "MB", "GB", "TB"]:
|
||||||
if n < 1024:
|
if n < 1024:
|
||||||
@@ -78,12 +106,19 @@ def main():
|
|||||||
"--device", type=str, default="mps", choices=["cuda", "cpu", "mps"], help="Device to run on"
|
"--device", type=str, default="mps", choices=["cuda", "cpu", "mps"], help="Device to run on"
|
||||||
)
|
)
|
||||||
parser.add_argument("--seed", type=int, default=42, help="Random seed")
|
parser.add_argument("--seed", type=int, default=42, help="Random seed")
|
||||||
parser.add_argument("--num-trials", type=int, default=10, help="Number of timing trials")
|
parser.add_argument(
|
||||||
parser.add_argument("--forwards-per-trial", type=int, default=10, help="Number of forwards per trial")
|
"--num-samples", type=int, default=100, help="Number of inference samples to benchmark"
|
||||||
parser.add_argument("--warmup", type=int, default=2, help="Warmup forwards (not timed)")
|
)
|
||||||
|
parser.add_argument("--warmup", type=int, default=10, help="Number of warmup samples (not timed)")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output-dir", type=str, default="outputs/benchmarks", help="Directory to save benchmark results"
|
"--output-dir", type=str, default="outputs/benchmarks", help="Directory to save benchmark results"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--timeout",
|
||||||
|
type=float,
|
||||||
|
default=0.3,
|
||||||
|
help="Timeout for each inference pass in seconds (default: 0.3s = 300ms)",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Seed & deterministic-ish setup
|
# Seed & deterministic-ish setup
|
||||||
@@ -124,6 +159,7 @@ def main():
|
|||||||
|
|
||||||
# Generate dummy observation based on policy input features
|
# Generate dummy observation based on policy input features
|
||||||
dummy_observation = generate_dummy_observation(policy.config.input_features, device)
|
dummy_observation = generate_dummy_observation(policy.config.input_features, device)
|
||||||
|
dummy_observation["task"] = ""
|
||||||
|
|
||||||
# Helper to sync for fair timings
|
# Helper to sync for fair timings
|
||||||
def _sync(dev_=device):
|
def _sync(dev_=device):
|
||||||
@@ -138,8 +174,9 @@ def main():
|
|||||||
# Warmup (to stabilize kernels/caches)
|
# Warmup (to stabilize kernels/caches)
|
||||||
print("Warming up...")
|
print("Warming up...")
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
policy.reset()
|
||||||
for _ in range(args.warmup):
|
for _ in range(args.warmup):
|
||||||
_ = policy.predict_action_chunk(dummy_observation)
|
_ = policy.select_action(dummy_observation)
|
||||||
_sync()
|
_sync()
|
||||||
|
|
||||||
# Memory footprint before timing
|
# Memory footprint before timing
|
||||||
@@ -149,47 +186,70 @@ def main():
|
|||||||
torch.cuda.reset_peak_memory_stats()
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
|
||||||
# PyTorch timing with Event objects for more accurate GPU timing
|
# PyTorch timing with Event objects for more accurate GPU timing
|
||||||
print(f"Running benchmark: {args.num_trials} trials x {args.forwards_per_trial} forwards...")
|
print(f"Running benchmark: {args.num_samples} samples...")
|
||||||
|
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
# Use CUDA Events for precise GPU timing
|
# Use CUDA Events for precise GPU timing
|
||||||
start_events = []
|
start_events = []
|
||||||
end_events = []
|
end_events = []
|
||||||
|
timeout_count = 0
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for _ in range(args.num_trials):
|
for forward in tqdm(range(args.num_samples), desc="Trials"):
|
||||||
for _ in range(args.forwards_per_trial):
|
start_event = torch.cuda.Event(enable_timing=True)
|
||||||
start_event = torch.cuda.Event(enable_timing=True)
|
end_event = torch.cuda.Event(enable_timing=True)
|
||||||
end_event = torch.cuda.Event(enable_timing=True)
|
try:
|
||||||
|
with timeout(args.timeout):
|
||||||
start_event.record()
|
start_event.record()
|
||||||
_ = policy.predict_action_chunk(dummy_observation)
|
_ = policy.select_action(dummy_observation)
|
||||||
end_event.record()
|
end_event.record()
|
||||||
|
|
||||||
start_events.append(start_event)
|
start_events.append(start_event)
|
||||||
end_events.append(end_event)
|
end_events.append(end_event)
|
||||||
|
except TimeoutException:
|
||||||
|
timeout_count += 1
|
||||||
|
# Add placeholder for timeout
|
||||||
|
start_events.append(None)
|
||||||
|
end_events.append(None)
|
||||||
|
print(f"\n[!] Timeout on forward {forward + 1}")
|
||||||
|
continue
|
||||||
|
|
||||||
# Synchronize and collect timing results
|
# Synchronize and collect timing results
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
per_forward_ms = []
|
per_forward_ms = []
|
||||||
for start_event, end_event in zip(start_events, end_events, strict=True):
|
for start_event, end_event in zip(start_events, end_events, strict=True):
|
||||||
per_forward_ms.append(start_event.elapsed_time(end_event))
|
if start_event is None:
|
||||||
|
per_forward_ms.append(args.timeout * 1000)
|
||||||
|
else:
|
||||||
|
per_forward_ms.append(start_event.elapsed_time(end_event))
|
||||||
|
|
||||||
|
if timeout_count > 0:
|
||||||
|
print(f"[!] {timeout_count} inference passes timed out (>{args.timeout * 1000:.1f}ms)")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Use torch.utils.benchmark for CPU/MPS timing
|
# Use simple time.perf_counter for CPU/MPS timing with timeout
|
||||||
from torch.utils.benchmark import Timer
|
import time
|
||||||
|
|
||||||
def run_inference():
|
|
||||||
return policy.predict_action_chunk(dummy_observation)
|
|
||||||
|
|
||||||
# Collect individual timing measurements
|
|
||||||
per_forward_ms = []
|
per_forward_ms = []
|
||||||
|
timeout_count = 0
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for _ in range(args.num_trials):
|
for sample in tqdm(range(args.num_samples), desc="Samples"):
|
||||||
for _ in range(args.forwards_per_trial):
|
try:
|
||||||
timer = Timer(stmt="run_inference()", globals={"run_inference": run_inference})
|
with timeout(args.timeout):
|
||||||
measurement = timer.timeit(1) # Single measurement
|
start_time = time.perf_counter()
|
||||||
per_forward_ms.append(measurement.mean * 1000) # Convert to ms
|
_ = policy.select_action(dummy_observation)
|
||||||
|
end_time = time.perf_counter()
|
||||||
|
|
||||||
|
per_forward_ms.append((end_time - start_time) * 1000) # Convert to ms
|
||||||
|
except TimeoutException:
|
||||||
|
timeout_count += 1
|
||||||
|
per_forward_ms.append(args.timeout * 1000)
|
||||||
|
print(f"\n[!] Timeout on sample {sample + 1}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if timeout_count > 0:
|
||||||
|
print(f"[!] {timeout_count} inference passes timed out (>{args.timeout * 1000:.1f}ms)")
|
||||||
|
|
||||||
# Memory footprint after timing
|
# Memory footprint after timing
|
||||||
rss_after = process.memory_info().rss
|
rss_after = process.memory_info().rss
|
||||||
@@ -215,11 +275,13 @@ def main():
|
|||||||
"policy_type": args.policy_type,
|
"policy_type": args.policy_type,
|
||||||
"policy_id": args.policy_id,
|
"policy_id": args.policy_id,
|
||||||
"device": device,
|
"device": device,
|
||||||
"num_trials": args.num_trials,
|
"num_trials": args.num_samples,
|
||||||
"forwards_per_trial": args.forwards_per_trial,
|
"forwards_per_trial": 1,
|
||||||
"warmup": args.warmup,
|
"warmup": args.warmup,
|
||||||
|
"timeout_ms": args.timeout * 1000,
|
||||||
"seed": args.seed,
|
"seed": args.seed,
|
||||||
"num_params": num_params,
|
"num_params": num_params,
|
||||||
|
"timeout_count": timeout_count,
|
||||||
"latency_mean_ms": mean_ms,
|
"latency_mean_ms": mean_ms,
|
||||||
"latency_std_ms": std_ms,
|
"latency_std_ms": std_ms,
|
||||||
"latency_min_ms": min_ms,
|
"latency_min_ms": min_ms,
|
||||||
@@ -248,10 +310,11 @@ Input Features: {", ".join(results["input_features"])}
|
|||||||
Output Features: {", ".join(results["output_features"])}
|
Output Features: {", ".join(results["output_features"])}
|
||||||
|
|
||||||
=== Benchmark Configuration ===
|
=== Benchmark Configuration ===
|
||||||
Trials: {results["num_trials"]}
|
Samples: {results["num_trials"]}
|
||||||
Forwards per Trial: {results["forwards_per_trial"]}
|
|
||||||
Warmup: {results["warmup"]}
|
Warmup: {results["warmup"]}
|
||||||
Total Measurements: {len(per_forward_ms)}
|
Total Measurements: {len(per_forward_ms)}
|
||||||
|
Timeout: {results["timeout_ms"]:.1f}ms
|
||||||
|
Timeouts: {results["timeout_count"]} / {results["num_trials"]}
|
||||||
|
|
||||||
=== Latency Results (ms) ===
|
=== Latency Results (ms) ===
|
||||||
Mean: {results["latency_mean_ms"]:.3f}
|
Mean: {results["latency_mean_ms"]:.3f}
|
||||||
@@ -290,7 +353,7 @@ Benchmark completed successfully at {datetime.now().strftime("%Y-%m-%d %H:%M:%S"
|
|||||||
print("\n=== Inference Benchmark Results ===")
|
print("\n=== Inference Benchmark Results ===")
|
||||||
print(f"Policy: {args.policy_type} ({args.policy_id})")
|
print(f"Policy: {args.policy_type} ({args.policy_id})")
|
||||||
print(f"Device: {device}")
|
print(f"Device: {device}")
|
||||||
print(f"Trials: {args.num_trials} | Forwards/Trial: {args.forwards_per_trial} | Warmup: {args.warmup}")
|
print(f"Samples: {args.num_samples} | Warmup: {args.warmup}")
|
||||||
print(f"Model params: {num_params:,}")
|
print(f"Model params: {num_params:,}")
|
||||||
|
|
||||||
print("\nLatency per forward (ms):")
|
print("\nLatency per forward (ms):")
|
||||||
|
|||||||
Reference in New Issue
Block a user