# Source file: lerobot/benchmarks/policies/inference.py
# Last modified: 2025-09-23 22:34:52 +02:00
# Size: 316 lines, 12 KiB (Python)

"""
Benchmark memory footprint and inference latency of a policy on arbitrary devices.
This script loads a pretrained policy directly (similar to the async inference server)
and generates dummy input data based on the policy's input_features to perform
accurate benchmarking without requiring datasets.
"""
import argparse
import os
import statistics
from datetime import datetime
from pathlib import Path
import psutil
import torch
from lerobot.configs.types import FeatureType
from lerobot.policies.factory import get_policy_class
from lerobot.policies.pretrained import PreTrainedPolicy
def bytes_to_human(n: int) -> str:
    """Format a byte count as a human-readable string using 1024-based units.

    Args:
        n: Number of bytes. May be negative (for example a memory delta that
            shrank); the sign is preserved in the output.

    Returns:
        The value scaled to the first unit in which its magnitude is below
        1024, formatted with two decimals (e.g. ``"1.50 KB"``). Magnitudes of
        1024 TB or more are reported in PB.
    """
    size = n
    for unit in ("B", "KB", "MB", "GB", "TB"):
        # Compare on magnitude: a plain `size < 1024` is always true for
        # negative values, which would leave negative deltas unscaled in bytes.
        if abs(size) < 1024:
            return f"{size:.2f} {unit}"
        size /= 1024
    return f"{size:.2f} PB"
def percentile(values: list[float], p: float) -> float:
    """Return the p-th percentile of ``values`` using linear interpolation.

    Args:
        values: Samples, indexed by rank; the caller is expected to pass them
            already sorted in ascending order (this function does not sort).
        p: Percentile to compute, on a 0-100 scale.

    Returns:
        The interpolated percentile, or NaN when ``values`` is empty.
    """
    if not values:
        return float("nan")

    last_index = len(values) - 1
    # Fractional rank of the requested percentile within the sample.
    rank = last_index * (p / 100.0)
    lower = int(rank)
    upper = min(lower + 1, last_index)

    if lower == upper:
        # Rank falls on the final sample: nothing to interpolate towards.
        return values[lower]

    lower_value = values[lower]
    upper_value = values[upper]
    return lower_value + (upper_value - lower_value) * (rank - lower)
def generate_dummy_observation(input_features: dict, device: str = "cpu") -> dict:
    """Build a random, batched observation matching a policy's input features.

    Args:
        input_features: Mapping from observation key to a feature object that
            exposes ``shape`` and ``type`` attributes.
        device: Torch device on which the tensors are created.

    Returns:
        A dict with one float32 tensor per input feature, each with a leading
        batch dimension of size 1, plus a ``"task"`` entry set to an empty
        string for language-conditioned policies.
    """
    observation = {}
    for name, spec in input_features.items():
        # Visual inputs are sampled uniformly from [0, 1) like normalized
        # images; every other feature type is sampled from a standard normal.
        sampler = torch.rand if spec.type == FeatureType.VISUAL else torch.randn
        sample = sampler(spec.shape, dtype=torch.float32, device=device)
        # Prepend the batch dimension.
        observation[name] = sample.unsqueeze(0)

    # Task string for language-conditioned policies.
    observation["task"] = ""
    return observation
def main():
    """Run the policy inference benchmark from the command line.

    Parses the CLI arguments, loads the pretrained policy onto the requested
    device (falling back to CPU when CUDA or MPS is unavailable), runs warmup
    and timed calls to ``predict_action_chunk`` on a dummy observation, then
    writes a full report to a timestamped text file under ``--output-dir`` and
    prints a shorter summary to the console.
    """
    parser = argparse.ArgumentParser(description="Policy inference benchmark")
    parser.add_argument(
        "--policy-id", type=str, required=True, help="Model ID or local path to pretrained policy"
    )
    parser.add_argument(
        "--policy-type", type=str, required=True, help="Type of policy (smolvla, act, diffusion, etc.)"
    )
    parser.add_argument(
        "--device", type=str, default="mps", choices=["cuda", "cpu", "mps"], help="Device to run on"
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--num-trials", type=int, default=10, help="Number of timing trials")
    parser.add_argument("--forwards-per-trial", type=int, default=10, help="Number of forwards per trial")
    parser.add_argument("--warmup", type=int, default=2, help="Warmup forwards (not timed)")
    parser.add_argument(
        "--output-dir", type=str, default="outputs/benchmarks", help="Directory to save benchmark results"
    )
    args = parser.parse_args()

    # Seed & deterministic-ish setup
    torch.manual_seed(args.seed)
    if args.device == "cuda":
        torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = False  # leave False to avoid perf cliffs

    # Resolve device availability
    device = args.device.lower()
    if device == "cuda" and not torch.cuda.is_available():
        print("[!] CUDA requested but unavailable. Falling back to CPU.")
        device = "cpu"
    elif device == "mps" and not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()):
        print("[!] MPS requested but unavailable. Falling back to CPU.")
        device = "cpu"
    use_cuda = device == "cuda"

    # Create output directory and log file
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Path separators in the policy id would otherwise create sub-directories.
    policy_name = args.policy_id.replace("/", "_").replace("\\", "_")
    log_file = output_dir / f"benchmark_{args.policy_type}_{policy_name}_{device}_{timestamp}.txt"

    # Load policy directly from pretrained (similar to async inference server)
    print(f"Loading policy {args.policy_type} from {args.policy_id}...")
    policy_class = get_policy_class(args.policy_type)
    policy: PreTrainedPolicy = policy_class.from_pretrained(args.policy_id)
    policy.eval()
    policy.to(device)
    print(f"Policy loaded on {device}")
    print(f"Input features: {list(policy.config.input_features.keys())}")
    print(f"Output features: {list(policy.config.output_features.keys())}")

    # Generate dummy observation based on policy input features
    dummy_observation = generate_dummy_observation(policy.config.input_features, device)

    def _sync(dev_=device):
        """Block until queued work on ``dev_`` has finished (no-op on CPU)."""
        if dev_ == "cuda" and torch.cuda.is_available():
            torch.cuda.synchronize()
        elif dev_ == "mps" and hasattr(torch, "mps"):
            try:
                torch.mps.synchronize()
            except AttributeError:
                pass  # MPS sync not available in this PyTorch version

    # Warmup (to stabilize kernels/caches)
    print("Warming up...")
    with torch.no_grad():
        for _ in range(args.warmup):
            _ = policy.predict_action_chunk(dummy_observation)
    _sync()

    # Memory footprint before timing (taken after warmup, so one-off
    # allocations made by the first forwards are not counted in the delta).
    process = psutil.Process(os.getpid())
    rss_before = process.memory_info().rss
    if use_cuda:
        torch.cuda.reset_peak_memory_stats()

    print(f"Running benchmark: {args.num_trials} trials x {args.forwards_per_trial} forwards...")
    if use_cuda:
        # Use CUDA Events for precise GPU timing
        start_events = []
        end_events = []
        with torch.no_grad():
            for _ in range(args.num_trials):
                for _ in range(args.forwards_per_trial):
                    start_event = torch.cuda.Event(enable_timing=True)
                    end_event = torch.cuda.Event(enable_timing=True)
                    start_event.record()
                    _ = policy.predict_action_chunk(dummy_observation)
                    end_event.record()
                    start_events.append(start_event)
                    end_events.append(end_event)
        # Elapsed times can only be read once all recorded events completed.
        torch.cuda.synchronize()
        per_forward_ms = [
            start_event.elapsed_time(end_event)
            for start_event, end_event in zip(start_events, end_events, strict=True)
        ]
    else:
        # Use torch.utils.benchmark for CPU/MPS timing
        from torch.utils.benchmark import Timer

        def run_inference():
            output = policy.predict_action_chunk(dummy_observation)
            # Wait for the device inside the timed call so asynchronously
            # queued MPS work is included in the measurement; no-op on CPU.
            _sync()
            return output

        # Build the Timer once; only the measurement is repeated.
        # NOTE(review): Timer defaults to num_threads=1 per the PyTorch docs,
        # so CPU numbers are single-threaded - confirm this is intended.
        timer = Timer(stmt="run_inference()", globals={"run_inference": run_inference})
        per_forward_ms = []
        with torch.no_grad():
            for _ in range(args.num_trials):
                for _ in range(args.forwards_per_trial):
                    measurement = timer.timeit(1)  # Single measurement
                    per_forward_ms.append(measurement.mean * 1000)  # Convert to ms

    # Memory footprint after timing
    rss_after = process.memory_info().rss
    rss_delta = rss_after - rss_before
    cuda_peak = torch.cuda.max_memory_allocated() if use_cuda else 0

    # Sort timing results for percentile calculations
    per_forward_ms_sorted = sorted(per_forward_ms)
    mean_ms = statistics.fmean(per_forward_ms) if per_forward_ms else float("nan")
    std_ms = statistics.pstdev(per_forward_ms) if len(per_forward_ms) > 1 else 0.0
    min_ms = per_forward_ms_sorted[0] if per_forward_ms_sorted else float("nan")
    max_ms = per_forward_ms_sorted[-1] if per_forward_ms_sorted else float("nan")
    p50_ms = percentile(per_forward_ms_sorted, 50)
    p95_ms = percentile(per_forward_ms_sorted, 95)

    # Model size
    num_params = sum(p.numel() for p in policy.parameters())

    # Availability of the device that was actually used (after any fallback).
    if device == "cuda":
        device_available = torch.cuda.is_available()
    elif device == "mps":
        device_available = torch.backends.mps.is_available()
    else:
        device_available = True

    # Prepare results for logging
    results = {
        "timestamp": datetime.now().isoformat(),
        "policy_type": args.policy_type,
        "policy_id": args.policy_id,
        "device": device,
        "num_trials": args.num_trials,
        "forwards_per_trial": args.forwards_per_trial,
        "warmup": args.warmup,
        "seed": args.seed,
        "num_params": num_params,
        "latency_mean_ms": mean_ms,
        "latency_std_ms": std_ms,
        "latency_min_ms": min_ms,
        "latency_max_ms": max_ms,
        "latency_p50_ms": p50_ms,
        "latency_p95_ms": p95_ms,
        "cpu_rss_before": rss_before,
        "cpu_rss_after": rss_after,
        "cpu_rss_delta": rss_delta,
        "cuda_peak_alloc": cuda_peak,
        "input_features": list(policy.config.input_features.keys()),
        "output_features": list(policy.config.output_features.keys()),
    }

    # Format and write results to log file
    log_content = f"""
=== LeRobot Policy Inference Benchmark ===
Timestamp: {results["timestamp"]}
Policy: {results["policy_type"]} ({results["policy_id"]})
Device: {results["device"]}
Seed: {results["seed"]}
=== Model Information ===
Parameters: {results["num_params"]:,}
Input Features: {", ".join(results["input_features"])}
Output Features: {", ".join(results["output_features"])}
=== Benchmark Configuration ===
Trials: {results["num_trials"]}
Forwards per Trial: {results["forwards_per_trial"]}
Warmup: {results["warmup"]}
Total Measurements: {len(per_forward_ms)}
=== Latency Results (ms) ===
Mean: {results["latency_mean_ms"]:.3f}
Std Dev: {results["latency_std_ms"]:.3f}
Min: {results["latency_min_ms"]:.3f}
Max: {results["latency_max_ms"]:.3f}
P50: {results["latency_p50_ms"]:.3f}
P95: {results["latency_p95_ms"]:.3f}
=== Memory Footprint ===
CPU RSS Before: {bytes_to_human(results["cpu_rss_before"])}
CPU RSS After: {bytes_to_human(results["cpu_rss_after"])} (delta: {bytes_to_human(results["cpu_rss_delta"])})
"""
    if use_cuda:
        log_content += f"CUDA Peak: {bytes_to_human(results['cuda_peak_alloc'])} (reset before timing)\n"
    log_content += f"""
=== Raw Timing Data (first 20 measurements, ms) ===
{", ".join(f"{t:.3f}" for t in per_forward_ms[:20])}
{"..." if len(per_forward_ms) > 20 else ""}
=== Summary Statistics ===
Timing Method: {"CUDA Events" if use_cuda else "torch.utils.benchmark.Timer"}
Device Available: {device_available}
PyTorch Version: {torch.__version__}
Benchmark completed successfully at {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
"""

    # Write to log file; explicit encoding keeps the report independent of
    # the platform's default locale.
    with open(log_file, "w", encoding="utf-8") as f:
        f.write(log_content)

    # Print to console (shorter version)
    print("\n=== Inference Benchmark Results ===")
    print(f"Policy: {args.policy_type} ({args.policy_id})")
    print(f"Device: {device}")
    print(f"Trials: {args.num_trials} | Forwards/Trial: {args.forwards_per_trial} | Warmup: {args.warmup}")
    print(f"Model params: {num_params:,}")
    print("\nLatency per forward (ms):")
    print(f" mean: {mean_ms:.3f} std: {std_ms:.3f}")
    print(f" min: {min_ms:.3f} max: {max_ms:.3f}")
    print(f" p50: {p50_ms:.3f} p95: {p95_ms:.3f}")
    print("\nMemory footprint:")
    print(f" CPU RSS before: {bytes_to_human(rss_before)}")
    print(f" CPU RSS after : {bytes_to_human(rss_after)} (delta: {bytes_to_human(rss_delta)})")
    if use_cuda:
        print(
            f" CUDA peak allocated: {bytes_to_human(cuda_peak)} "
            f"(reset by reset_peak_memory_stats before timing)"
        )
    print(f"\nResults saved to: {log_file}")
    print("Benchmark completed successfully!")
# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()