refactor(inference): improve timeout handling and enhance dummy observation generation

- Renamed TimeoutException to TimeoutExceptionError for clarity.
- Updated dummy observation generation to include a task string.
- Integrated pre-processing and post-processing steps in the main function.
- Added deep copy of dummy observations to prevent mutation during processing.
- Enhanced timeout handling to provide percentage of timeouts during inference.
This commit is contained in:
AdilZouitine
2025-09-24 14:32:47 +02:00
parent 6eaf6a861a
commit cffd545527
+33 -14
View File
@@ -11,6 +11,7 @@ import os
import signal import signal
import statistics import statistics
from contextlib import contextmanager from contextlib import contextmanager
from copy import deepcopy
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
@@ -19,18 +20,18 @@ import torch
from tqdm import tqdm from tqdm import tqdm
from lerobot.configs.types import FeatureType from lerobot.configs.types import FeatureType
from lerobot.policies.factory import get_policy_class from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
class TimeoutException: class TimeoutExceptionError(Exception):
pass pass
@contextmanager @contextmanager
def timeout(seconds): def timeout(seconds):
def signal_handler(signum, frame): def signal_handler(signum, frame):
raise TimeoutException(f"Timed out after {seconds} seconds") raise TimeoutExceptionError(f"Timed out after {seconds} seconds")
# On Windows, signal is not available, so we can't use this timeout mechanism # On Windows, signal is not available, so we can't use this timeout mechanism
if not hasattr(signal, "SIGALRM"): if not hasattr(signal, "SIGALRM"):
@@ -84,12 +85,12 @@ def generate_dummy_observation(input_features: dict, device: str = "cpu") -> dic
# Default: random normal for unknown types # Default: random normal for unknown types
dummy_obs[key] = torch.randn(shape, dtype=torch.float32, device=device) dummy_obs[key] = torch.randn(shape, dtype=torch.float32, device=device)
# Add batch dimension # # Add batch dimension
for key in dummy_obs: # for key in dummy_obs:
dummy_obs[key] = dummy_obs[key].unsqueeze(0) # dummy_obs[key] = dummy_obs[key].unsqueeze(0)
# Add task string for language-conditioned policies # Add task string for language-conditioned policies
dummy_obs["task"] = "" dummy_obs["task"] = " this is a dummy task"
return dummy_obs return dummy_obs
@@ -151,7 +152,9 @@ def main():
policy_class = get_policy_class(args.policy_type) policy_class = get_policy_class(args.policy_type)
policy: PreTrainedPolicy = policy_class.from_pretrained(args.policy_id) policy: PreTrainedPolicy = policy_class.from_pretrained(args.policy_id)
policy.eval() policy.eval()
policy.to(device) policy.to(device, torch.float32)
policy.config.device = device
preprocessor, postprocessor = make_pre_post_processors(policy.config)
print(f"Policy loaded on {device}") print(f"Policy loaded on {device}")
print(f"Input features: {list(policy.config.input_features.keys())}") print(f"Input features: {list(policy.config.input_features.keys())}")
@@ -159,7 +162,7 @@ def main():
# Generate dummy observation based on policy input features # Generate dummy observation based on policy input features
dummy_observation = generate_dummy_observation(policy.config.input_features, device) dummy_observation = generate_dummy_observation(policy.config.input_features, device)
dummy_observation["task"] = "" dummy_observation["task"] = "this is a dummy task"
# Helper to sync for fair timings # Helper to sync for fair timings
def _sync(dev_=device): def _sync(dev_=device):
@@ -175,8 +178,15 @@ def main():
print("Warming up...") print("Warming up...")
with torch.no_grad(): with torch.no_grad():
policy.reset() policy.reset()
preprocessor.reset()
postprocessor.reset()
orginal_dummy_observation = deepcopy(dummy_observation)
for _ in range(args.warmup): for _ in range(args.warmup):
_ = policy.select_action(dummy_observation) dummy_observation_model = deepcopy(orginal_dummy_observation)
dummy_observation_model = preprocessor(dummy_observation_model)
action_model = policy.select_action(dummy_observation_model)
_ = postprocessor(action_model)
policy.reset()
_sync() _sync()
# Memory footprint before timing # Memory footprint before timing
@@ -193,20 +203,25 @@ def main():
start_events = [] start_events = []
end_events = [] end_events = []
timeout_count = 0 timeout_count = 0
orginal_dummy_observation = deepcopy(dummy_observation)
with torch.no_grad(): with torch.no_grad():
for forward in tqdm(range(args.num_samples), desc="Trials"): for forward in tqdm(range(args.num_samples), desc="Trials"):
start_event = torch.cuda.Event(enable_timing=True) start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True)
try: try:
dummy_observation_model = deepcopy(orginal_dummy_observation)
dummy_observation_model = preprocessor(dummy_observation)
with timeout(args.timeout): with timeout(args.timeout):
start_event.record() start_event.record()
_ = policy.select_action(dummy_observation) action_model = policy.select_action(dummy_observation_model)
end_event.record() end_event.record()
_ = postprocessor(action_model)
policy.reset()
start_events.append(start_event) start_events.append(start_event)
end_events.append(end_event) end_events.append(end_event)
except TimeoutException: except TimeoutExceptionError:
timeout_count += 1 timeout_count += 1
# Add placeholder for timeout # Add placeholder for timeout
start_events.append(None) start_events.append(None)
@@ -236,13 +251,16 @@ def main():
with torch.no_grad(): with torch.no_grad():
for sample in tqdm(range(args.num_samples), desc="Samples"): for sample in tqdm(range(args.num_samples), desc="Samples"):
try: try:
dummy_observation_model = deepcopy(orginal_dummy_observation)
dummy_observation_model = preprocessor(dummy_observation_model)
with timeout(args.timeout): with timeout(args.timeout):
start_time = time.perf_counter() start_time = time.perf_counter()
_ = policy.select_action(dummy_observation) action_model = policy.select_action(dummy_observation_model)
end_time = time.perf_counter() end_time = time.perf_counter()
policy.reset()
per_forward_ms.append((end_time - start_time) * 1000) # Convert to ms per_forward_ms.append((end_time - start_time) * 1000) # Convert to ms
except TimeoutException: except TimeoutExceptionError:
timeout_count += 1 timeout_count += 1
per_forward_ms.append(args.timeout * 1000) per_forward_ms.append(args.timeout * 1000)
print(f"\n[!] Timeout on sample {sample + 1}") print(f"\n[!] Timeout on sample {sample + 1}")
@@ -250,6 +268,7 @@ def main():
if timeout_count > 0: if timeout_count > 0:
print(f"[!] {timeout_count} inference passes timed out (>{args.timeout * 1000:.1f}ms)") print(f"[!] {timeout_count} inference passes timed out (>{args.timeout * 1000:.1f}ms)")
print(f"Timeout percentage: {timeout_count / args.num_samples * 100:.1f}%")
# Memory footprint after timing # Memory footprint after timing
rss_after = process.memory_info().rss rss_after = process.memory_info().rss