From cffd545527099c0dc97f3390bd7f47148d46a55f Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Wed, 24 Sep 2025 14:32:47 +0200 Subject: [PATCH] refactor(inference): improve timeout handling and enhance dummy observation generation - Renamed TimeoutException to TimeoutExceptionError for clarity. - Updated dummy observation generation to include a task string. - Integrated pre-processing and post-processing steps in the main function. - Added deep copy of dummy observations to prevent mutation during processing. - Enhanced timeout handling to provide percentage of timeouts during inference. --- benchmarks/policies/inference.py | 47 ++++++++++++++++++++++---------- 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/benchmarks/policies/inference.py b/benchmarks/policies/inference.py index 65a7974b3..6bd4922f5 100644 --- a/benchmarks/policies/inference.py +++ b/benchmarks/policies/inference.py @@ -11,6 +11,7 @@ import os import signal import statistics from contextlib import contextmanager +from copy import deepcopy from datetime import datetime from pathlib import Path @@ -19,18 +20,18 @@ import torch from tqdm import tqdm from lerobot.configs.types import FeatureType -from lerobot.policies.factory import get_policy_class +from lerobot.policies.factory import get_policy_class, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy -class TimeoutException: +class TimeoutExceptionError(Exception): pass @contextmanager def timeout(seconds): def signal_handler(signum, frame): - raise TimeoutException(f"Timed out after {seconds} seconds") + raise TimeoutExceptionError(f"Timed out after {seconds} seconds") # On Windows, signal is not available, so we can't use this timeout mechanism if not hasattr(signal, "SIGALRM"): @@ -84,12 +85,12 @@ def generate_dummy_observation(input_features: dict, device: str = "cpu") -> dic # Default: random normal for unknown types dummy_obs[key] = torch.randn(shape, dtype=torch.float32, device=device) - # 
Add batch dimension - for key in dummy_obs: - dummy_obs[key] = dummy_obs[key].unsqueeze(0) + # # Add batch dimension + # for key in dummy_obs: + # dummy_obs[key] = dummy_obs[key].unsqueeze(0) # Add task string for language-conditioned policies - dummy_obs["task"] = "" + dummy_obs["task"] = " this is a dummy task" return dummy_obs @@ -151,7 +152,9 @@ def main(): policy_class = get_policy_class(args.policy_type) policy: PreTrainedPolicy = policy_class.from_pretrained(args.policy_id) policy.eval() - policy.to(device) + policy.to(device, torch.float32) + policy.config.device = device + preprocessor, postprocessor = make_pre_post_processors(policy.config) print(f"Policy loaded on {device}") print(f"Input features: {list(policy.config.input_features.keys())}") @@ -159,7 +162,7 @@ def main(): # Generate dummy observation based on policy input features dummy_observation = generate_dummy_observation(policy.config.input_features, device) - dummy_observation["task"] = "" + dummy_observation["task"] = "this is a dummy task" # Helper to sync for fair timings def _sync(dev_=device): @@ -175,8 +178,15 @@ def main(): print("Warming up...") with torch.no_grad(): policy.reset() + preprocessor.reset() + postprocessor.reset() + orginal_dummy_observation = deepcopy(dummy_observation) for _ in range(args.warmup): - _ = policy.select_action(dummy_observation) + dummy_observation_model = deepcopy(orginal_dummy_observation) + dummy_observation_model = preprocessor(dummy_observation_model) + action_model = policy.select_action(dummy_observation_model) + _ = postprocessor(action_model) + policy.reset() _sync() # Memory footprint before timing @@ -193,20 +203,25 @@ def main(): start_events = [] end_events = [] timeout_count = 0 + orginal_dummy_observation = deepcopy(dummy_observation) with torch.no_grad(): for forward in tqdm(range(args.num_samples), desc="Trials"): start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) try: + 
dummy_observation_model = deepcopy(orginal_dummy_observation) + dummy_observation_model = preprocessor(dummy_observation_model) with timeout(args.timeout): start_event.record() - _ = policy.select_action(dummy_observation) + action_model = policy.select_action(dummy_observation_model) end_event.record() + _ = postprocessor(action_model) + policy.reset() start_events.append(start_event) end_events.append(end_event) - except TimeoutException: + except TimeoutExceptionError: timeout_count += 1 # Add placeholder for timeout start_events.append(None) @@ -236,13 +251,16 @@ def main(): with torch.no_grad(): for sample in tqdm(range(args.num_samples), desc="Samples"): try: + dummy_observation_model = deepcopy(orginal_dummy_observation) + dummy_observation_model = preprocessor(dummy_observation_model) with timeout(args.timeout): start_time = time.perf_counter() - _ = policy.select_action(dummy_observation) + action_model = policy.select_action(dummy_observation_model) end_time = time.perf_counter() + policy.reset() per_forward_ms.append((end_time - start_time) * 1000) # Convert to ms - except TimeoutException: + except TimeoutExceptionError: timeout_count += 1 per_forward_ms.append(args.timeout * 1000) print(f"\n[!] Timeout on sample {sample + 1}") @@ -250,6 +268,7 @@ def main(): if timeout_count > 0: print(f"[!] {timeout_count} inference passes timed out (>{args.timeout * 1000:.1f}ms)") + print(f"Timeout percentage: {timeout_count / args.num_samples * 100:.1f}%") # Memory footprint after timing rss_after = process.memory_info().rss