mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 10:40:04 +00:00
refactor(inference): improve timeout handling and enhance dummy observation generation
- Renamed TimeoutException to TimeoutExceptionError for clarity. - Updated dummy observation generation to include a task string. - Integrated pre-processing and post-processing steps in the main function. - Added deep copy of dummy observations to prevent mutation during processing. - Enhanced timeout handling to provide percentage of timeouts during inference.
This commit is contained in:
@@ -11,6 +11,7 @@ import os
|
||||
import signal
|
||||
import statistics
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
@@ -19,18 +20,18 @@ import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.policies.factory import get_policy_class
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
|
||||
class TimeoutException:
|
||||
class TimeoutExceptionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@contextmanager
|
||||
def timeout(seconds):
|
||||
def signal_handler(signum, frame):
|
||||
raise TimeoutException(f"Timed out after {seconds} seconds")
|
||||
raise TimeoutExceptionError(f"Timed out after {seconds} seconds")
|
||||
|
||||
# On Windows, signal is not available, so we can't use this timeout mechanism
|
||||
if not hasattr(signal, "SIGALRM"):
|
||||
@@ -84,12 +85,12 @@ def generate_dummy_observation(input_features: dict, device: str = "cpu") -> dic
|
||||
# Default: random normal for unknown types
|
||||
dummy_obs[key] = torch.randn(shape, dtype=torch.float32, device=device)
|
||||
|
||||
# Add batch dimension
|
||||
for key in dummy_obs:
|
||||
dummy_obs[key] = dummy_obs[key].unsqueeze(0)
|
||||
# # Add batch dimension
|
||||
# for key in dummy_obs:
|
||||
# dummy_obs[key] = dummy_obs[key].unsqueeze(0)
|
||||
|
||||
# Add task string for language-conditioned policies
|
||||
dummy_obs["task"] = ""
|
||||
dummy_obs["task"] = " this is a dummy task"
|
||||
|
||||
return dummy_obs
|
||||
|
||||
@@ -151,7 +152,9 @@ def main():
|
||||
policy_class = get_policy_class(args.policy_type)
|
||||
policy: PreTrainedPolicy = policy_class.from_pretrained(args.policy_id)
|
||||
policy.eval()
|
||||
policy.to(device)
|
||||
policy.to(device, torch.float32)
|
||||
policy.config.device = device
|
||||
preprocessor, postprocessor = make_pre_post_processors(policy.config)
|
||||
|
||||
print(f"Policy loaded on {device}")
|
||||
print(f"Input features: {list(policy.config.input_features.keys())}")
|
||||
@@ -159,7 +162,7 @@ def main():
|
||||
|
||||
# Generate dummy observation based on policy input features
|
||||
dummy_observation = generate_dummy_observation(policy.config.input_features, device)
|
||||
dummy_observation["task"] = ""
|
||||
dummy_observation["task"] = "this is a dummy task"
|
||||
|
||||
# Helper to sync for fair timings
|
||||
def _sync(dev_=device):
|
||||
@@ -175,8 +178,15 @@ def main():
|
||||
print("Warming up...")
|
||||
with torch.no_grad():
|
||||
policy.reset()
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
orginal_dummy_observation = deepcopy(dummy_observation)
|
||||
for _ in range(args.warmup):
|
||||
_ = policy.select_action(dummy_observation)
|
||||
dummy_observation_model = deepcopy(orginal_dummy_observation)
|
||||
dummy_observation_model = preprocessor(dummy_observation_model)
|
||||
action_model = policy.select_action(dummy_observation_model)
|
||||
_ = postprocessor(action_model)
|
||||
policy.reset()
|
||||
_sync()
|
||||
|
||||
# Memory footprint before timing
|
||||
@@ -193,20 +203,25 @@ def main():
|
||||
start_events = []
|
||||
end_events = []
|
||||
timeout_count = 0
|
||||
orginal_dummy_observation = deepcopy(dummy_observation)
|
||||
|
||||
with torch.no_grad():
|
||||
for forward in tqdm(range(args.num_samples), desc="Trials"):
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
try:
|
||||
dummy_observation_model = deepcopy(orginal_dummy_observation)
|
||||
dummy_observation_model = preprocessor(dummy_observation)
|
||||
with timeout(args.timeout):
|
||||
start_event.record()
|
||||
_ = policy.select_action(dummy_observation)
|
||||
action_model = policy.select_action(dummy_observation_model)
|
||||
end_event.record()
|
||||
_ = postprocessor(action_model)
|
||||
policy.reset()
|
||||
|
||||
start_events.append(start_event)
|
||||
end_events.append(end_event)
|
||||
except TimeoutException:
|
||||
except TimeoutExceptionError:
|
||||
timeout_count += 1
|
||||
# Add placeholder for timeout
|
||||
start_events.append(None)
|
||||
@@ -236,13 +251,16 @@ def main():
|
||||
with torch.no_grad():
|
||||
for sample in tqdm(range(args.num_samples), desc="Samples"):
|
||||
try:
|
||||
dummy_observation_model = deepcopy(orginal_dummy_observation)
|
||||
dummy_observation_model = preprocessor(dummy_observation_model)
|
||||
with timeout(args.timeout):
|
||||
start_time = time.perf_counter()
|
||||
_ = policy.select_action(dummy_observation)
|
||||
action_model = policy.select_action(dummy_observation_model)
|
||||
end_time = time.perf_counter()
|
||||
policy.reset()
|
||||
|
||||
per_forward_ms.append((end_time - start_time) * 1000) # Convert to ms
|
||||
except TimeoutException:
|
||||
except TimeoutExceptionError:
|
||||
timeout_count += 1
|
||||
per_forward_ms.append(args.timeout * 1000)
|
||||
print(f"\n[!] Timeout on sample {sample + 1}")
|
||||
@@ -250,6 +268,7 @@ def main():
|
||||
|
||||
if timeout_count > 0:
|
||||
print(f"[!] {timeout_count} inference passes timed out (>{args.timeout * 1000:.1f}ms)")
|
||||
print(f"Timeout percentage: {timeout_count / args.num_samples * 100:.1f}%")
|
||||
|
||||
# Memory footprint after timing
|
||||
rss_after = process.memory_info().rss
|
||||
|
||||
Reference in New Issue
Block a user