mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
refactor(inference): improve timeout handling and enhance dummy observation generation
- Renamed TimeoutException to TimeoutExceptionError for clarity.
- Updated dummy observation generation to include a task string.
- Integrated pre-processing and post-processing steps in the main function.
- Added a deep copy of the dummy observations to prevent mutation during processing.
- Enhanced timeout handling to report the percentage of timeouts during inference.
This commit is contained in:
@@ -11,6 +11,7 @@ import os
|
|||||||
import signal
|
import signal
|
||||||
import statistics
|
import statistics
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from copy import deepcopy
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -19,18 +20,18 @@ import torch
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType
|
from lerobot.configs.types import FeatureType
|
||||||
from lerobot.policies.factory import get_policy_class
|
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
|
|
||||||
|
|
||||||
class TimeoutException:
|
class TimeoutExceptionError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def timeout(seconds):
|
def timeout(seconds):
|
||||||
def signal_handler(signum, frame):
|
def signal_handler(signum, frame):
|
||||||
raise TimeoutException(f"Timed out after {seconds} seconds")
|
raise TimeoutExceptionError(f"Timed out after {seconds} seconds")
|
||||||
|
|
||||||
# On Windows, signal is not available, so we can't use this timeout mechanism
|
# On Windows, signal is not available, so we can't use this timeout mechanism
|
||||||
if not hasattr(signal, "SIGALRM"):
|
if not hasattr(signal, "SIGALRM"):
|
||||||
@@ -84,12 +85,12 @@ def generate_dummy_observation(input_features: dict, device: str = "cpu") -> dic
|
|||||||
# Default: random normal for unknown types
|
# Default: random normal for unknown types
|
||||||
dummy_obs[key] = torch.randn(shape, dtype=torch.float32, device=device)
|
dummy_obs[key] = torch.randn(shape, dtype=torch.float32, device=device)
|
||||||
|
|
||||||
# Add batch dimension
|
# # Add batch dimension
|
||||||
for key in dummy_obs:
|
# for key in dummy_obs:
|
||||||
dummy_obs[key] = dummy_obs[key].unsqueeze(0)
|
# dummy_obs[key] = dummy_obs[key].unsqueeze(0)
|
||||||
|
|
||||||
# Add task string for language-conditioned policies
|
# Add task string for language-conditioned policies
|
||||||
dummy_obs["task"] = ""
|
dummy_obs["task"] = " this is a dummy task"
|
||||||
|
|
||||||
return dummy_obs
|
return dummy_obs
|
||||||
|
|
||||||
@@ -151,7 +152,9 @@ def main():
|
|||||||
policy_class = get_policy_class(args.policy_type)
|
policy_class = get_policy_class(args.policy_type)
|
||||||
policy: PreTrainedPolicy = policy_class.from_pretrained(args.policy_id)
|
policy: PreTrainedPolicy = policy_class.from_pretrained(args.policy_id)
|
||||||
policy.eval()
|
policy.eval()
|
||||||
policy.to(device)
|
policy.to(device, torch.float32)
|
||||||
|
policy.config.device = device
|
||||||
|
preprocessor, postprocessor = make_pre_post_processors(policy.config)
|
||||||
|
|
||||||
print(f"Policy loaded on {device}")
|
print(f"Policy loaded on {device}")
|
||||||
print(f"Input features: {list(policy.config.input_features.keys())}")
|
print(f"Input features: {list(policy.config.input_features.keys())}")
|
||||||
@@ -159,7 +162,7 @@ def main():
|
|||||||
|
|
||||||
# Generate dummy observation based on policy input features
|
# Generate dummy observation based on policy input features
|
||||||
dummy_observation = generate_dummy_observation(policy.config.input_features, device)
|
dummy_observation = generate_dummy_observation(policy.config.input_features, device)
|
||||||
dummy_observation["task"] = ""
|
dummy_observation["task"] = "this is a dummy task"
|
||||||
|
|
||||||
# Helper to sync for fair timings
|
# Helper to sync for fair timings
|
||||||
def _sync(dev_=device):
|
def _sync(dev_=device):
|
||||||
@@ -175,8 +178,15 @@ def main():
|
|||||||
print("Warming up...")
|
print("Warming up...")
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
policy.reset()
|
policy.reset()
|
||||||
|
preprocessor.reset()
|
||||||
|
postprocessor.reset()
|
||||||
|
orginal_dummy_observation = deepcopy(dummy_observation)
|
||||||
for _ in range(args.warmup):
|
for _ in range(args.warmup):
|
||||||
_ = policy.select_action(dummy_observation)
|
dummy_observation_model = deepcopy(orginal_dummy_observation)
|
||||||
|
dummy_observation_model = preprocessor(dummy_observation_model)
|
||||||
|
action_model = policy.select_action(dummy_observation_model)
|
||||||
|
_ = postprocessor(action_model)
|
||||||
|
policy.reset()
|
||||||
_sync()
|
_sync()
|
||||||
|
|
||||||
# Memory footprint before timing
|
# Memory footprint before timing
|
||||||
@@ -193,20 +203,25 @@ def main():
|
|||||||
start_events = []
|
start_events = []
|
||||||
end_events = []
|
end_events = []
|
||||||
timeout_count = 0
|
timeout_count = 0
|
||||||
|
orginal_dummy_observation = deepcopy(dummy_observation)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for forward in tqdm(range(args.num_samples), desc="Trials"):
|
for forward in tqdm(range(args.num_samples), desc="Trials"):
|
||||||
start_event = torch.cuda.Event(enable_timing=True)
|
start_event = torch.cuda.Event(enable_timing=True)
|
||||||
end_event = torch.cuda.Event(enable_timing=True)
|
end_event = torch.cuda.Event(enable_timing=True)
|
||||||
try:
|
try:
|
||||||
|
dummy_observation_model = deepcopy(orginal_dummy_observation)
|
||||||
|
dummy_observation_model = preprocessor(dummy_observation)
|
||||||
with timeout(args.timeout):
|
with timeout(args.timeout):
|
||||||
start_event.record()
|
start_event.record()
|
||||||
_ = policy.select_action(dummy_observation)
|
action_model = policy.select_action(dummy_observation_model)
|
||||||
end_event.record()
|
end_event.record()
|
||||||
|
_ = postprocessor(action_model)
|
||||||
|
policy.reset()
|
||||||
|
|
||||||
start_events.append(start_event)
|
start_events.append(start_event)
|
||||||
end_events.append(end_event)
|
end_events.append(end_event)
|
||||||
except TimeoutException:
|
except TimeoutExceptionError:
|
||||||
timeout_count += 1
|
timeout_count += 1
|
||||||
# Add placeholder for timeout
|
# Add placeholder for timeout
|
||||||
start_events.append(None)
|
start_events.append(None)
|
||||||
@@ -236,13 +251,16 @@ def main():
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for sample in tqdm(range(args.num_samples), desc="Samples"):
|
for sample in tqdm(range(args.num_samples), desc="Samples"):
|
||||||
try:
|
try:
|
||||||
|
dummy_observation_model = deepcopy(orginal_dummy_observation)
|
||||||
|
dummy_observation_model = preprocessor(dummy_observation_model)
|
||||||
with timeout(args.timeout):
|
with timeout(args.timeout):
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
_ = policy.select_action(dummy_observation)
|
action_model = policy.select_action(dummy_observation_model)
|
||||||
end_time = time.perf_counter()
|
end_time = time.perf_counter()
|
||||||
|
policy.reset()
|
||||||
|
|
||||||
per_forward_ms.append((end_time - start_time) * 1000) # Convert to ms
|
per_forward_ms.append((end_time - start_time) * 1000) # Convert to ms
|
||||||
except TimeoutException:
|
except TimeoutExceptionError:
|
||||||
timeout_count += 1
|
timeout_count += 1
|
||||||
per_forward_ms.append(args.timeout * 1000)
|
per_forward_ms.append(args.timeout * 1000)
|
||||||
print(f"\n[!] Timeout on sample {sample + 1}")
|
print(f"\n[!] Timeout on sample {sample + 1}")
|
||||||
@@ -250,6 +268,7 @@ def main():
|
|||||||
|
|
||||||
if timeout_count > 0:
|
if timeout_count > 0:
|
||||||
print(f"[!] {timeout_count} inference passes timed out (>{args.timeout * 1000:.1f}ms)")
|
print(f"[!] {timeout_count} inference passes timed out (>{args.timeout * 1000:.1f}ms)")
|
||||||
|
print(f"Timeout percentage: {timeout_count / args.num_samples * 100:.1f}%")
|
||||||
|
|
||||||
# Memory footprint after timing
|
# Memory footprint after timing
|
||||||
rss_after = process.memory_info().rss
|
rss_after = process.memory_info().rss
|
||||||
|
|||||||
Reference in New Issue
Block a user