mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
359 lines
13 KiB
Python
359 lines
13 KiB
Python
#!/usr/bin/env python
|
|
|
|
"""
|
|
Comprehensive profiling script for Pi0 with RTC.

This script demonstrates how to use all the profiling tools to identify
bottlenecks in Pi0 policy inference with RTC enabled.

It profiles:
1. Overall inference time
2. RTC-specific operations (guidance, weights, etc.)
3. Preprocessing/postprocessing
4. Individual method timings

Usage:
    uv run examples/rtc/profile_pi0_rtc_detailed.py \
        --policy_path=helper2424/pi05_check_rtc \
        --device=mps \
        --num_iterations=20 \
        --execution_horizon=20 \
        --enable_rtc_profiling
|
|
"""
|
|
|
|
import argparse
|
|
import logging
|
|
import sys
|
|
import time
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from lerobot.configs.policies import PreTrainedConfig
|
|
from lerobot.configs.types import RTCAttentionSchedule
|
|
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
|
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
|
from lerobot.utils.profiling import (
|
|
ProfileContext,
|
|
clear_profiling_stats,
|
|
enable_profiling,
|
|
get_profiling_stats,
|
|
print_profiling_summary,
|
|
)
|
|
|
|
# Configure logging BEFORE anything can emit a record. The module-level
# logging.warning() helper implicitly calls basicConfig() with the default
# WARNING level when the root logger has no handlers; a later
# basicConfig(level=logging.INFO) would then be a no-op and every
# logger.info() line printed by this script would be silently dropped.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Import monkey patching for RTC profiling. The helper is optional: when it is
# missing the name is bound to None and main() skips detailed RTC profiling.
try:
    from examples.rtc.add_rtc_profiling import monkey_patch_rtc_profiling
except ImportError:
    logger.warning("Could not import add_rtc_profiling, detailed RTC profiling disabled")
    monkey_patch_rtc_profiling = None
|
|
|
|
|
|
def create_mock_observation(policy_config, device: str) -> dict:
    """Build a synthetic, batch-size-1 observation for profiling runs.

    The returned dictionary holds a random state vector, two random camera
    images, a task string and placeholder language tokens with their attention
    mask. The token values are random stand-ins, not tokenizer output.

    Args:
        policy_config: Policy configuration. Accepted for interface
            compatibility; this function does not read it.
        device: Device on which every tensor is allocated.

    Returns:
        Mock observation dictionary.
    """
    batch = 1
    state_shape = (batch, 10)  # typical robot state dimension
    image_shape = (batch, 3, 224, 224)  # [B, C, H, W]
    token_shape = (batch, 32)  # [B, max_seq_len]
    camera_keys = ("observation.images.gripper", "observation.images.front")

    mock_obs = {"observation.state": torch.randn(*state_shape, device=device)}

    # torch.rand samples uniformly from [0, 1), i.e. already-normalized pixels.
    mock_obs.update({cam: torch.rand(*image_shape, device=device) for cam in camera_keys})

    mock_obs["task"] = ["Pick up the object"]

    # Placeholder language inputs: random token ids and an all-ones mask.
    mock_obs["observation.language_tokens"] = torch.randint(0, 1000, token_shape, device=device)
    mock_obs["observation.language_attention_mask"] = torch.ones(*token_shape, device=device)

    return mock_obs
|
|
|
|
|
|
def profile_single_iteration(
    policy,
    preprocessor,
    postprocessor,
    observation: dict,
    prev_actions: torch.Tensor | None,
    use_rtc: bool,
    inference_delay: int = 0,
) -> tuple[torch.Tensor, torch.Tensor | None, dict]:
    """Run one preprocess -> predict -> postprocess pass under profiling scopes.

    Each stage is wrapped in its own ``ProfileContext`` ("iteration.*" names),
    all nested inside "iteration.total".

    Args:
        policy: Policy instance exposing ``predict_action_chunk``.
        preprocessor: Callable applied to the raw observation.
        postprocessor: Callable applied to the predicted action chunk.
        observation: Input observation.
        prev_actions: Left-over actions from the previous chunk (RTC only).
        use_rtc: Whether to pass the RTC arguments to the policy.
        inference_delay: Inference delay in timesteps (RTC only).

    Returns:
        Tuple of (postprocessed actions, left-over actions for the next call
        or None, timings). The timings dict is always empty here; measurements
        are collected through ``ProfileContext`` instead.
    """
    # Only the RTC path forwards the delay and the previous chunk.
    rtc_kwargs = {}
    if use_rtc:
        rtc_kwargs = {"inference_delay": inference_delay, "prev_chunk_left_over": prev_actions}

    leftover = None

    with ProfileContext("iteration.total"):
        with ProfileContext("iteration.preprocessing"):
            model_inputs = preprocessor(observation)

        with ProfileContext("iteration.policy_inference"):
            chunk = policy.predict_action_chunk(model_inputs, **rtc_kwargs)

        if use_rtc:
            with ProfileContext("iteration.prepare_prev_actions"):
                horizon = policy.config.rtc_config.execution_horizon
                # Keep (as a copy) the part of the chunk past the execution
                # horizon; it is fed back as prev_actions on the next call.
                if chunk.shape[1] > horizon:
                    leftover = chunk[:, horizon:].clone()

        with ProfileContext("iteration.postprocessing"):
            final_actions = postprocessor(chunk)

    return final_actions, leftover, {}
|
|
|
|
|
|
def _synchronize_device(device: str) -> None:
    """Block until all work queued on ``device`` has finished.

    CUDA and MPS launch kernels asynchronously, so a wall-clock reading is
    only meaningful after the device queue has drained. Other devices (e.g.
    "cpu") need no synchronization and are ignored.
    """
    if device.startswith("cuda"):
        # Pass the device so "cuda:1" syncs that GPU, not just the current one.
        torch.cuda.synchronize(torch.device(device))
    elif device.startswith("mps"):
        torch.mps.synchronize()


def main():
    """Profile Pi0 inference with RTC enabled and log a timing report.

    Parses the command line, enables profiling (plus the optional detailed RTC
    monkey patch), loads the policy and attaches an RTC configuration, builds
    the pre/post processors and a mock observation, runs warmup iterations,
    then runs the timed iterations and logs summary statistics, the profiling
    breakdown and a few recommendations.

    NOTE(review): the per-stage ``ProfileContext`` scopes opened inside
    ``profile_single_iteration`` are not followed by a device synchronization
    in this file; unless ``ProfileContext`` synchronizes internally, only the
    per-iteration wall-clock numbers below are sync-accurate on GPU backends.
    """
    parser = argparse.ArgumentParser(description="Detailed profiling for Pi0 with RTC")
    parser.add_argument("--policy_path", type=str, required=True, help="Path to pretrained policy")
    parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu/mps)")
    parser.add_argument("--num_iterations", type=int, default=20, help="Number of iterations")
    parser.add_argument("--execution_horizon", type=int, default=10, help="RTC execution horizon")
    parser.add_argument("--warmup_iterations", type=int, default=5, help="Warmup iterations")
    parser.add_argument("--enable_rtc_profiling", action="store_true", help="Enable detailed RTC profiling")
    parser.add_argument("--use_torch_compile", action="store_true", help="Use torch.compile")

    args = parser.parse_args()

    # The summary statistics (np.min / np.max) are undefined for zero samples;
    # fail fast instead of crashing after the expensive model load.
    if args.num_iterations < 1:
        parser.error("--num_iterations must be at least 1")

    logger.info("="*80)
    logger.info("DETAILED PI0 RTC PROFILING")
    logger.info("="*80)
    logger.info(f"Policy: {args.policy_path}")
    logger.info(f"Device: {args.device}")
    logger.info(f"Iterations: {args.num_iterations}")
    logger.info(f"Execution Horizon: {args.execution_horizon}")
    logger.info(f"RTC Profiling: {args.enable_rtc_profiling}")
    logger.info("="*80 + "\n")

    # Enable profiling
    enable_profiling()

    # Apply RTC profiling if requested
    if args.enable_rtc_profiling:
        if monkey_patch_rtc_profiling is not None:
            monkey_patch_rtc_profiling()
            logger.info("✓ Detailed RTC profiling enabled\n")
        else:
            logger.warning("⚠ Could not enable detailed RTC profiling\n")

    # Load policy
    logger.info("Loading policy...")
    config = PreTrainedConfig.from_pretrained(args.policy_path)

    # Only configs that expose the flag can toggle torch.compile.
    if hasattr(config, "compile_model"):
        config.compile_model = args.use_torch_compile

    policy_class = get_policy_class(config.type)
    policy = policy_class.from_pretrained(args.policy_path, config=config)

    # Configure RTC
    policy.config.rtc_config = RTCConfig(
        enabled=True,
        execution_horizon=args.execution_horizon,
        max_guidance_weight=1.0,
        prefix_attention_schedule=RTCAttentionSchedule.EXP,
    )
    policy.init_rtc_processor()

    policy = policy.to(args.device)
    policy.eval()

    logger.info(f"✓ Policy loaded: {config.type}\n")

    # Create preprocessor and postprocessor
    logger.info("Loading preprocessor/postprocessor...")
    preprocessor, postprocessor = make_pre_post_processors(
        policy_cfg=config,
        pretrained_path=args.policy_path,
        dataset_stats=None,
        preprocessor_overrides={
            "device_processor": {"device": args.device},
        },
    )
    logger.info("✓ Preprocessor/postprocessor loaded\n")

    # Create mock observation
    logger.info("Creating mock observation...")
    observation = create_mock_observation(config, args.device)
    logger.info("✓ Mock observation created\n")

    # Warmup
    logger.info(f"Warming up ({args.warmup_iterations} iterations)...")
    prev_actions = None
    for _ in range(args.warmup_iterations):
        with torch.no_grad():
            _, prev_actions, _ = profile_single_iteration(
                policy=policy,
                preprocessor=preprocessor,
                postprocessor=postprocessor,
                observation=observation,
                prev_actions=prev_actions,
                use_rtc=True,
                inference_delay=0,
            )

    # Clear warmup stats
    clear_profiling_stats()
    logger.info("✓ Warmup complete\n")

    # Profiled run WITH RTC
    logger.info(f"Running profiled iterations WITH RTC ({args.num_iterations} iterations)...")
    prev_actions = None
    iteration_times = []

    for i in range(args.num_iterations):
        # Drain work still queued from warmup / the previous iteration so it
        # is not billed to this one.
        _synchronize_device(args.device)
        start = time.perf_counter()

        with torch.no_grad():
            _, prev_actions, _ = profile_single_iteration(
                policy=policy,
                preprocessor=preprocessor,
                postprocessor=postprocessor,
                observation=observation,
                prev_actions=prev_actions,
                use_rtc=True,
                inference_delay=0,
            )

        # GPU backends (CUDA and MPS) run asynchronously: wait for the device
        # before reading the clock, otherwise only launch time is measured.
        _synchronize_device(args.device)

        elapsed = time.perf_counter() - start
        iteration_times.append(elapsed)

        if (i + 1) % 5 == 0:
            logger.info(f" Completed {i+1}/{args.num_iterations}")

    logger.info("✓ Profiling complete\n")

    # Print summary statistics
    logger.info("\n" + "="*80)
    logger.info("ITERATION TIMING SUMMARY")
    logger.info("="*80)

    times_arr = np.array(iteration_times)
    logger.info(f"Mean time: {np.mean(times_arr)*1000:.2f} ms")
    logger.info(f"Median time: {np.median(times_arr)*1000:.2f} ms")
    logger.info(f"Std dev: {np.std(times_arr)*1000:.2f} ms")
    logger.info(f"Min time: {np.min(times_arr)*1000:.2f} ms")
    logger.info(f"Max time: {np.max(times_arr)*1000:.2f} ms")
    logger.info(f"Total time: {np.sum(times_arr):.2f} s")
    logger.info(f"Throughput: {len(times_arr)/np.sum(times_arr):.2f} iter/s")
    logger.info("="*80 + "\n")

    # Print detailed profiling breakdown
    print_profiling_summary(sort_by="total")

    # Print key insights
    stats = get_profiling_stats()

    logger.info("\n" + "="*80)
    logger.info("KEY INSIGHTS")
    logger.info("="*80)

    # Find bottlenecks
    if stats:
        policy_inference_time = stats.get("iteration.policy_inference", {}).get("mean", 0)
        preprocessing_time = stats.get("iteration.preprocessing", {}).get("mean", 0)
        postprocessing_time = stats.get("iteration.postprocessing", {}).get("mean", 0)

        total_time = policy_inference_time + preprocessing_time + postprocessing_time

        if total_time > 0:
            logger.info("\nTime breakdown:")
            logger.info(
                f" Policy inference: {policy_inference_time*1000:.2f} ms "
                f"({policy_inference_time/total_time*100:.1f}%)"
            )
            logger.info(
                f" Preprocessing: {preprocessing_time*1000:.2f} ms "
                f"({preprocessing_time/total_time*100:.1f}%)"
            )
            logger.info(
                f" Postprocessing: {postprocessing_time*1000:.2f} ms "
                f"({postprocessing_time/total_time*100:.1f}%)"
            )

        # RTC-specific insights (these stats exist only with the RTC patch)
        if args.enable_rtc_profiling:
            rtc_guidance = stats.get("rtc.denoise_step.guidance_computation", {}).get("mean", 0)
            rtc_autograd = stats.get("rtc.denoise_step.autograd_correction", {}).get("mean", 0)
            rtc_base = stats.get("rtc.denoise_step.base_denoising", {}).get("mean", 0)

            if rtc_guidance > 0:
                logger.info("\nRTC breakdown:")
                logger.info(f" Base denoising: {rtc_base*1000:.2f} ms")
                logger.info(f" Guidance compute: {rtc_guidance*1000:.2f} ms")
                logger.info(f" Autograd correct: {rtc_autograd*1000:.2f} ms")
                logger.info(f" RTC overhead: {(rtc_guidance - rtc_base)*1000:.2f} ms")

        # Recommendations
        logger.info("\nRecommendations:")

        # The 30% threshold is measured against policy inference time.
        if preprocessing_time > policy_inference_time * 0.3:
            logger.info(" ⚠ Preprocessing is taking >30% of time")
            logger.info(" → Consider reducing image resolution")
            logger.info(" → Consider using fewer cameras")

        # rtc_* names are bound only when RTC profiling is on; the `and`
        # short-circuit keeps them from being evaluated otherwise.
        if args.enable_rtc_profiling and rtc_autograd > rtc_base * 0.5:
            logger.info(" ⚠ RTC autograd overhead is significant")
            logger.info(" → This is expected, but consider increasing execution_horizon")
            logger.info(" → Try torch.compile if not already enabled")

        if not args.use_torch_compile:
            logger.info(" 💡 torch.compile not enabled")
            logger.info(" → Try --use_torch_compile for potential speedup")

    logger.info("="*80 + "\n")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run main() and map the outcome to a process exit
    # status (0 on Ctrl-C, 1 on any other error, normal return otherwise).
    _exit_status = None
    try:
        main()
    except KeyboardInterrupt:
        logger.info("\n\nProfiling interrupted by user")
        _exit_status = 0
    except Exception as e:
        logger.error(f"\n\nError during profiling: {e}")
        # Imported lazily: only needed on the failure path.
        import traceback

        traceback.print_exc()
        _exit_status = 1

    if _exit_status is not None:
        sys.exit(_exit_status)
|
|
|