refactor(inference): improve timeout handling and enhance dummy observation generation

- Renamed TimeoutException to TimeoutExceptionError for clarity.
- Updated dummy observation generation to include a task string.
- Integrated pre-processing and post-processing steps in the main function.
- Added deep copy of dummy observations to prevent mutation during processing.
- Enhanced timeout handling to provide percentage of timeouts during inference.
This commit is contained in:
AdilZouitine
2025-09-24 14:32:47 +02:00
parent 6eaf6a861a
commit cffd545527
+33 -14
View File
@@ -11,6 +11,7 @@ import os
import signal
import statistics
from contextlib import contextmanager
from copy import deepcopy
from datetime import datetime
from pathlib import Path
@@ -19,18 +20,18 @@ import torch
from tqdm import tqdm
from lerobot.configs.types import FeatureType
from lerobot.policies.factory import get_policy_class
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
class TimeoutException:
class TimeoutExceptionError(Exception):
pass
@contextmanager
def timeout(seconds):
def signal_handler(signum, frame):
raise TimeoutException(f"Timed out after {seconds} seconds")
raise TimeoutExceptionError(f"Timed out after {seconds} seconds")
# On Windows, signal is not available, so we can't use this timeout mechanism
if not hasattr(signal, "SIGALRM"):
@@ -84,12 +85,12 @@ def generate_dummy_observation(input_features: dict, device: str = "cpu") -> dic
# Default: random normal for unknown types
dummy_obs[key] = torch.randn(shape, dtype=torch.float32, device=device)
# Add batch dimension
for key in dummy_obs:
dummy_obs[key] = dummy_obs[key].unsqueeze(0)
# # Add batch dimension
# for key in dummy_obs:
# dummy_obs[key] = dummy_obs[key].unsqueeze(0)
# Add task string for language-conditioned policies
dummy_obs["task"] = ""
dummy_obs["task"] = " this is a dummy task"
return dummy_obs
@@ -151,7 +152,9 @@ def main():
policy_class = get_policy_class(args.policy_type)
policy: PreTrainedPolicy = policy_class.from_pretrained(args.policy_id)
policy.eval()
policy.to(device)
policy.to(device, torch.float32)
policy.config.device = device
preprocessor, postprocessor = make_pre_post_processors(policy.config)
print(f"Policy loaded on {device}")
print(f"Input features: {list(policy.config.input_features.keys())}")
@@ -159,7 +162,7 @@ def main():
# Generate dummy observation based on policy input features
dummy_observation = generate_dummy_observation(policy.config.input_features, device)
dummy_observation["task"] = ""
dummy_observation["task"] = "this is a dummy task"
# Helper to sync for fair timings
def _sync(dev_=device):
@@ -175,8 +178,15 @@ def main():
print("Warming up...")
with torch.no_grad():
policy.reset()
preprocessor.reset()
postprocessor.reset()
orginal_dummy_observation = deepcopy(dummy_observation)
for _ in range(args.warmup):
_ = policy.select_action(dummy_observation)
dummy_observation_model = deepcopy(orginal_dummy_observation)
dummy_observation_model = preprocessor(dummy_observation_model)
action_model = policy.select_action(dummy_observation_model)
_ = postprocessor(action_model)
policy.reset()
_sync()
# Memory footprint before timing
@@ -193,20 +203,25 @@ def main():
start_events = []
end_events = []
timeout_count = 0
orginal_dummy_observation = deepcopy(dummy_observation)
with torch.no_grad():
for forward in tqdm(range(args.num_samples), desc="Trials"):
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
try:
dummy_observation_model = deepcopy(orginal_dummy_observation)
dummy_observation_model = preprocessor(dummy_observation)
with timeout(args.timeout):
start_event.record()
_ = policy.select_action(dummy_observation)
action_model = policy.select_action(dummy_observation_model)
end_event.record()
_ = postprocessor(action_model)
policy.reset()
start_events.append(start_event)
end_events.append(end_event)
except TimeoutException:
except TimeoutExceptionError:
timeout_count += 1
# Add placeholder for timeout
start_events.append(None)
@@ -236,13 +251,16 @@ def main():
with torch.no_grad():
for sample in tqdm(range(args.num_samples), desc="Samples"):
try:
dummy_observation_model = deepcopy(orginal_dummy_observation)
dummy_observation_model = preprocessor(dummy_observation_model)
with timeout(args.timeout):
start_time = time.perf_counter()
_ = policy.select_action(dummy_observation)
action_model = policy.select_action(dummy_observation_model)
end_time = time.perf_counter()
policy.reset()
per_forward_ms.append((end_time - start_time) * 1000) # Convert to ms
except TimeoutException:
except TimeoutExceptionError:
timeout_count += 1
per_forward_ms.append(args.timeout * 1000)
print(f"\n[!] Timeout on sample {sample + 1}")
@@ -250,6 +268,7 @@ def main():
if timeout_count > 0:
print(f"[!] {timeout_count} inference passes timed out (>{args.timeout * 1000:.1f}ms)")
print(f"Timeout percentage: {timeout_count / args.num_samples * 100:.1f}%")
# Memory footprint after timing
rss_after = process.memory_info().rss