mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
committed by
Francesco Capuano
parent
54c6b8ae52
commit
f6cd24be17
@@ -1,6 +1,6 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
"""
|
"""
|
||||||
Minimal SmolVLA inference + benchmarking.
|
Minimal Policy inference + benchmarking.
|
||||||
|
|
||||||
Features:
|
Features:
|
||||||
- End-to-end pipeline: dataset -> pre/post-processors -> policy.select_action
|
- End-to-end pipeline: dataset -> pre/post-processors -> policy.select_action
|
||||||
@@ -26,8 +26,8 @@ import psutil
|
|||||||
|
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||||
from lerobot.policies.factory import make_policy, make_policy_config
|
from lerobot.policies.factory import make_policy, make_policy_config
|
||||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.policies.smolvla.processor_smolvla import make_smolvla_pre_post_processors
|
from lerobot.policies.factory import make_pre_post_processors
|
||||||
|
|
||||||
|
|
||||||
def bytes_to_human(n: int) -> str:
|
def bytes_to_human(n: int) -> str:
|
||||||
@@ -64,64 +64,65 @@ def main():
|
|||||||
parser.add_argument("--forwards_per_trial", type=int, default=1, help="Number of forwards per trial")
|
parser.add_argument("--forwards_per_trial", type=int, default=1, help="Number of forwards per trial")
|
||||||
parser.add_argument("--warmup", type=int, default=20, help="Warmup forwards (not timed)")
|
parser.add_argument("--warmup", type=int, default=20, help="Warmup forwards (not timed)")
|
||||||
parser.add_argument("--print_each_trial", action="store_true", help="Print each trial's aggregate time")
|
parser.add_argument("--print_each_trial", action="store_true", help="Print each trial's aggregate time")
|
||||||
|
parser.add_argument("--policy_type", type=str, default="smolvla", help="Type of policy to benchmark")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# seed & deterministic-ish setup
|
# Seed & deterministic-ish setup
|
||||||
torch.manual_seed(args.seed)
|
torch.manual_seed(args.seed)
|
||||||
torch.cuda.manual_seed_all(args.seed)
|
torch.cuda.manual_seed_all(args.seed)
|
||||||
torch.backends.cudnn.benchmark = False
|
torch.backends.cudnn.benchmark = False
|
||||||
torch.backends.cudnn.deterministic = False # leave False to avoid perf cliffs
|
torch.backends.cudnn.deterministic = False # leave False to avoid perf cliffs
|
||||||
|
|
||||||
# device
|
# Device
|
||||||
use_cuda = args.device == "cuda" and torch.cuda.is_available()
|
use_cuda = args.device == "cuda" and torch.cuda.is_available()
|
||||||
device = "cuda" if use_cuda else "cpu"
|
device = "cuda" if use_cuda else "cpu"
|
||||||
if args.device == "cuda" and not use_cuda:
|
if args.device == "cuda" and not use_cuda:
|
||||||
print("[!] CUDA requested but unavailable. Falling back to CPU.")
|
print("[!] CUDA requested but unavailable. Falling back to CPU.")
|
||||||
|
|
||||||
# load dataset metadata
|
# Load dataset metadata
|
||||||
ds_meta = LeRobotDatasetMetadata(args.repo_id)
|
ds_meta = LeRobotDatasetMetadata(args.repo_id)
|
||||||
|
|
||||||
# policy config & creation
|
# Policy config & creation
|
||||||
cfg = make_policy_config(
|
cfg = make_policy_config(
|
||||||
"smolvla",
|
args.policy_type,
|
||||||
n_obs_steps=args.n_obs_steps,
|
n_obs_steps=args.n_obs_steps,
|
||||||
chunk_size=args.chunk_size,
|
chunk_size=args.chunk_size, # comment this if policy_type = "diffusion"
|
||||||
n_action_steps=args.n_action_steps,
|
n_action_steps=args.n_action_steps,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
policy: SmolVLAPolicy = make_policy(cfg, ds_meta=ds_meta)
|
policy: PreTrainedPolicy = make_policy(cfg, ds_meta=ds_meta)
|
||||||
policy.eval()
|
policy.eval()
|
||||||
policy.to(device)
|
policy.to(device)
|
||||||
|
|
||||||
# Pre/post processors
|
# Pre/post processors
|
||||||
preprocessor, postprocessor = make_smolvla_pre_post_processors(cfg, dataset_stats=ds_meta.stats)
|
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=ds_meta.stats)
|
||||||
|
|
||||||
# dataset sample
|
# Dataset sample
|
||||||
dataset = LeRobotDataset(args.repo_id, episodes=[args.episode])
|
dataset = LeRobotDataset(args.repo_id, episodes=[args.episode])
|
||||||
sample = dataset[args.sample_index]
|
sample = dataset[args.sample_index]
|
||||||
|
|
||||||
# preprocess once; we will reuse the same batch for all forwards (typical for latency bench)
|
# Preprocess once; we will reuse the same batch for all forwards (typical for latency bench)
|
||||||
preprocessed_batch = preprocessor(sample)
|
preprocessed_batch = preprocessor(sample)
|
||||||
|
|
||||||
# helper to sync for fair timings
|
# Helper to sync for fair timings
|
||||||
def _sync():
|
def _sync():
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
# warmup (to stabilize kernels/caches)
|
# Warmup (to stabilize kernels/caches)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for _ in range(args.warmup):
|
for _ in range(args.warmup):
|
||||||
_ = policy.select_action(preprocessed_batch)
|
_ = policy.select_action(preprocessed_batch)
|
||||||
_sync()
|
_sync()
|
||||||
|
|
||||||
# memory footprint before timing
|
# Memory footprint before timing
|
||||||
process = psutil.Process(os.getpid())
|
process = psutil.Process(os.getpid())
|
||||||
rss_before = process.memory_info().rss
|
rss_before = process.memory_info().rss
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
torch.cuda.reset_peak_memory_stats()
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
|
||||||
# timing
|
# Timing
|
||||||
trial_times_sec: List[float] = []
|
trial_times_sec: List[float] = []
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -138,17 +139,17 @@ def main():
|
|||||||
print(f"[trial {t+1:03d}] total {trial_dur*1000:.3f} ms "
|
print(f"[trial {t+1:03d}] total {trial_dur*1000:.3f} ms "
|
||||||
f"({(trial_dur/args.forwards_per_trial)*1000:.3f} ms/forward)")
|
f"({(trial_dur/args.forwards_per_trial)*1000:.3f} ms/forward)")
|
||||||
|
|
||||||
# memory footprint after timing
|
# Memory footprint after timing
|
||||||
rss_after = process.memory_info().rss
|
rss_after = process.memory_info().rss
|
||||||
rss_delta = rss_after - rss_before
|
rss_delta = rss_after - rss_before
|
||||||
cuda_peak = torch.cuda.max_memory_allocated() if use_cuda else 0
|
cuda_peak = torch.cuda.max_memory_allocated() if use_cuda else 0
|
||||||
|
|
||||||
# do a single real inference and postprocess to verify everything still works
|
# Do a single real inference and postprocess to verify everything still works
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
action = policy.select_action(preprocessed_batch)
|
action = policy.select_action(preprocessed_batch)
|
||||||
postprocessed_action = postprocessor(action)
|
postprocessed_action = postprocessor(action)
|
||||||
|
|
||||||
# summaries
|
# Summaries
|
||||||
# Per-forward latencies in ms
|
# Per-forward latencies in ms
|
||||||
per_forward_ms = [(d / args.forwards_per_trial) * 1000.0 for d in trial_times_sec]
|
per_forward_ms = [(d / args.forwards_per_trial) * 1000.0 for d in trial_times_sec]
|
||||||
per_forward_ms_sorted = sorted(per_forward_ms)
|
per_forward_ms_sorted = sorted(per_forward_ms)
|
||||||
@@ -160,10 +161,10 @@ def main():
|
|||||||
p50_ms = percentile(per_forward_ms_sorted, 50)
|
p50_ms = percentile(per_forward_ms_sorted, 50)
|
||||||
p95_ms = percentile(per_forward_ms_sorted, 95)
|
p95_ms = percentile(per_forward_ms_sorted, 95)
|
||||||
|
|
||||||
# model size
|
# Model size
|
||||||
num_params = sum(p.numel() for p in policy.parameters())
|
num_params = sum(p.numel() for p in policy.parameters())
|
||||||
|
|
||||||
print("\n=== SmolVLA Inference Benchmark ===")
|
print("\n=== Inference Benchmark for ===", args.policy_type)
|
||||||
print(f"Device: {device}")
|
print(f"Device: {device}")
|
||||||
print(f"Trials: {args.num_trials} | Forwards/Trial: {args.forwards_per_trial} | Warmup: {args.warmup}")
|
print(f"Trials: {args.num_trials} | Forwards/Trial: {args.forwards_per_trial} | Warmup: {args.warmup}")
|
||||||
print(f"Model params: {num_params:,}")
|
print(f"Model params: {num_params:,}")
|
||||||
|
|||||||
Reference in New Issue
Block a user