mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
refactor(inference): improve timeout handling and report timeout percentage
- Commented out the timeout handling logic to prevent appending timeout values to the results. - Added a print statement to display the percentage of timeouts during inference.
This commit is contained in:
@@ -234,7 +234,8 @@ def main():
|
|||||||
per_forward_ms = []
|
per_forward_ms = []
|
||||||
for start_event, end_event in zip(start_events, end_events, strict=True):
|
for start_event, end_event in zip(start_events, end_events, strict=True):
|
||||||
if start_event is None:
|
if start_event is None:
|
||||||
per_forward_ms.append(args.timeout * 1000)
|
# per_forward_ms.append(args.timeout * 1000)
|
||||||
|
continue
|
||||||
else:
|
else:
|
||||||
per_forward_ms.append(start_event.elapsed_time(end_event))
|
per_forward_ms.append(start_event.elapsed_time(end_event))
|
||||||
|
|
||||||
@@ -262,7 +263,8 @@ def main():
|
|||||||
per_forward_ms.append((end_time - start_time) * 1000) # Convert to ms
|
per_forward_ms.append((end_time - start_time) * 1000) # Convert to ms
|
||||||
except TimeoutExceptionError:
|
except TimeoutExceptionError:
|
||||||
timeout_count += 1
|
timeout_count += 1
|
||||||
per_forward_ms.append(args.timeout * 1000)
|
# per_forward_ms.append(args.timeout * 1000)
|
||||||
|
|
||||||
print(f"\n[!] Timeout on sample {sample + 1}")
|
print(f"\n[!] Timeout on sample {sample + 1}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -374,6 +376,7 @@ Benchmark completed successfully at {datetime.now().strftime("%Y-%m-%d %H:%M:%S"
|
|||||||
print(f"Device: {device}")
|
print(f"Device: {device}")
|
||||||
print(f"Samples: {args.num_samples} | Warmup: {args.warmup}")
|
print(f"Samples: {args.num_samples} | Warmup: {args.warmup}")
|
||||||
print(f"Model params: {num_params:,}")
|
print(f"Model params: {num_params:,}")
|
||||||
|
print(f"Timeout percentage: {timeout_count / args.num_samples * 100:.1f}%")
|
||||||
|
|
||||||
print("\nLatency per forward (ms):")
|
print("\nLatency per forward (ms):")
|
||||||
print(f" mean: {mean_ms:.3f} std: {std_ms:.3f}")
|
print(f" mean: {mean_ms:.3f} std: {std_ms:.3f}")
|
||||||
|
|||||||
Reference in New Issue
Block a user