Signed-off-by: Jade Choghari <chogharijade@gmail.com>
This commit is contained in:
Jade Choghari
2025-09-22 17:39:26 +02:00
committed by Francesco Capuano
parent 54c6b8ae52
commit f6cd24be17
+23 -22
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
""" """
Minimal SmolVLA inference + benchmarking. Minimal Policy inference + benchmarking.
Features: Features:
- End-to-end pipeline: dataset -> pre/post-processors -> policy.select_action - End-to-end pipeline: dataset -> pre/post-processors -> policy.select_action
@@ -26,8 +26,8 @@ import psutil
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.policies.factory import make_policy, make_policy_config from lerobot.policies.factory import make_policy, make_policy_config
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.smolvla.processor_smolvla import make_smolvla_pre_post_processors from lerobot.policies.factory import make_pre_post_processors
def bytes_to_human(n: int) -> str: def bytes_to_human(n: int) -> str:
@@ -64,64 +64,65 @@ def main():
parser.add_argument("--forwards_per_trial", type=int, default=1, help="Number of forwards per trial") parser.add_argument("--forwards_per_trial", type=int, default=1, help="Number of forwards per trial")
parser.add_argument("--warmup", type=int, default=20, help="Warmup forwards (not timed)") parser.add_argument("--warmup", type=int, default=20, help="Warmup forwards (not timed)")
parser.add_argument("--print_each_trial", action="store_true", help="Print each trial's aggregate time") parser.add_argument("--print_each_trial", action="store_true", help="Print each trial's aggregate time")
parser.add_argument("--policy_type", type=str, default="smolvla", help="Type of policy to benchmark")
args = parser.parse_args() args = parser.parse_args()
# seed & deterministic-ish setup # Seed & deterministic-ish setup
torch.manual_seed(args.seed) torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed) torch.cuda.manual_seed_all(args.seed)
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False # leave False to avoid perf cliffs torch.backends.cudnn.deterministic = False # leave False to avoid perf cliffs
# device # Device
use_cuda = args.device == "cuda" and torch.cuda.is_available() use_cuda = args.device == "cuda" and torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu" device = "cuda" if use_cuda else "cpu"
if args.device == "cuda" and not use_cuda: if args.device == "cuda" and not use_cuda:
print("[!] CUDA requested but unavailable. Falling back to CPU.") print("[!] CUDA requested but unavailable. Falling back to CPU.")
# load dataset metadata # Load dataset metadata
ds_meta = LeRobotDatasetMetadata(args.repo_id) ds_meta = LeRobotDatasetMetadata(args.repo_id)
# policy config & creation # Policy config & creation
cfg = make_policy_config( cfg = make_policy_config(
"smolvla", args.policy_type,
n_obs_steps=args.n_obs_steps, n_obs_steps=args.n_obs_steps,
chunk_size=args.chunk_size, chunk_size=args.chunk_size, # comment this if policy_type = "diffusion"
n_action_steps=args.n_action_steps, n_action_steps=args.n_action_steps,
device=device, device=device,
) )
policy: SmolVLAPolicy = make_policy(cfg, ds_meta=ds_meta) policy: PreTrainedPolicy = make_policy(cfg, ds_meta=ds_meta)
policy.eval() policy.eval()
policy.to(device) policy.to(device)
# Pre/post processors # Pre/post processors
preprocessor, postprocessor = make_smolvla_pre_post_processors(cfg, dataset_stats=ds_meta.stats) preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=ds_meta.stats)
# dataset sample # Dataset sample
dataset = LeRobotDataset(args.repo_id, episodes=[args.episode]) dataset = LeRobotDataset(args.repo_id, episodes=[args.episode])
sample = dataset[args.sample_index] sample = dataset[args.sample_index]
# preprocess once; we will reuse the same batch for all forwards (typical for latency bench) # Preprocess once; we will reuse the same batch for all forwards (typical for latency bench)
preprocessed_batch = preprocessor(sample) preprocessed_batch = preprocessor(sample)
# helper to sync for fair timings # Helper to sync for fair timings
def _sync(): def _sync():
if use_cuda: if use_cuda:
torch.cuda.synchronize() torch.cuda.synchronize()
# warmup (to stabilize kernels/caches) # Warmup (to stabilize kernels/caches)
with torch.no_grad(): with torch.no_grad():
for _ in range(args.warmup): for _ in range(args.warmup):
_ = policy.select_action(preprocessed_batch) _ = policy.select_action(preprocessed_batch)
_sync() _sync()
# memory footprint before timing # Memory footprint before timing
process = psutil.Process(os.getpid()) process = psutil.Process(os.getpid())
rss_before = process.memory_info().rss rss_before = process.memory_info().rss
if use_cuda: if use_cuda:
torch.cuda.reset_peak_memory_stats() torch.cuda.reset_peak_memory_stats()
# timing # Timing
trial_times_sec: List[float] = [] trial_times_sec: List[float] = []
with torch.no_grad(): with torch.no_grad():
@@ -138,17 +139,17 @@ def main():
print(f"[trial {t+1:03d}] total {trial_dur*1000:.3f} ms " print(f"[trial {t+1:03d}] total {trial_dur*1000:.3f} ms "
f"({(trial_dur/args.forwards_per_trial)*1000:.3f} ms/forward)") f"({(trial_dur/args.forwards_per_trial)*1000:.3f} ms/forward)")
# memory footprint after timing # Memory footprint after timing
rss_after = process.memory_info().rss rss_after = process.memory_info().rss
rss_delta = rss_after - rss_before rss_delta = rss_after - rss_before
cuda_peak = torch.cuda.max_memory_allocated() if use_cuda else 0 cuda_peak = torch.cuda.max_memory_allocated() if use_cuda else 0
# do a single real inference and postprocess to verify everything still works # Do a single real inference and postprocess to verify everything still works
with torch.no_grad(): with torch.no_grad():
action = policy.select_action(preprocessed_batch) action = policy.select_action(preprocessed_batch)
postprocessed_action = postprocessor(action) postprocessed_action = postprocessor(action)
# summaries # Summaries
# Per-forward latencies in ms # Per-forward latencies in ms
per_forward_ms = [(d / args.forwards_per_trial) * 1000.0 for d in trial_times_sec] per_forward_ms = [(d / args.forwards_per_trial) * 1000.0 for d in trial_times_sec]
per_forward_ms_sorted = sorted(per_forward_ms) per_forward_ms_sorted = sorted(per_forward_ms)
@@ -160,10 +161,10 @@ def main():
p50_ms = percentile(per_forward_ms_sorted, 50) p50_ms = percentile(per_forward_ms_sorted, 50)
p95_ms = percentile(per_forward_ms_sorted, 95) p95_ms = percentile(per_forward_ms_sorted, 95)
# model size # Model size
num_params = sum(p.numel() for p in policy.parameters()) num_params = sum(p.numel() for p in policy.parameters())
print("\n=== SmolVLA Inference Benchmark ===") print("\n=== Inference Benchmark for ===", args.policy_type)
print(f"Device: {device}") print(f"Device: {device}")
print(f"Trials: {args.num_trials} | Forwards/Trial: {args.forwards_per_trial} | Warmup: {args.warmup}") print(f"Trials: {args.num_trials} | Forwards/Trial: {args.forwards_per_trial} | Warmup: {args.warmup}")
print(f"Model params: {num_params:,}") print(f"Model params: {num_params:,}")