From 6fdee95923f9076d235f6f0f63d5a03a75588d24 Mon Sep 17 00:00:00 2001 From: Eugene Mironov Date: Thu, 6 Nov 2025 16:02:47 +0700 Subject: [PATCH] Add torch compilation for eval_dataset --- examples/rtc/eval_dataset.py | 90 +++++++++++++++++++++++++++++++++++- 1 file changed, 89 insertions(+), 1 deletion(-) diff --git a/examples/rtc/eval_dataset.py b/examples/rtc/eval_dataset.py index ec7bf1c99..f5ad118c2 100644 --- a/examples/rtc/eval_dataset.py +++ b/examples/rtc/eval_dataset.py @@ -11,11 +11,38 @@ It compares action predictions with and without RTC on dataset samples, measuring consistency and ground truth alignment. Usage: - python eval_dataset.py \ + # Basic usage + uv run python examples/rtc/eval_dataset.py \ --policy.path=helper2424/smolvla_check_rtc_last3 \ --dataset.repo_id=helper2424/check_rtc \ --rtc.execution_horizon=8 \ --device=mps + + # With torch.compile for faster inference (PyTorch 2.0+) + uv run python examples/rtc/eval_dataset.py \ + --policy.path=helper2424/smolvla_check_rtc_last3 \ + --dataset.repo_id=helper2424/check_rtc \ + --rtc.execution_horizon=8 \ + --device=mps \ + --use_torch_compile=true \ + --torch_compile_mode=max-autotune + + # With torch.compile for faster inference (PyTorch 2.0+) + uv run python examples/rtc/eval_dataset.py \ + --policy.path=helper2424/smolvla_check_rtc_last3 \ + --dataset.repo_id=helper2424/check_rtc \ + --rtc.execution_horizon=8 \ + --device=cuda \ + --use_torch_compile=true \ + --torch_compile_mode=reduce-overhead + + # With custom compile settings + uv run python examples/rtc/eval_dataset.py \ + --policy.path=helper2424/smolvla_check_rtc_last3 \ + --dataset.repo_id=helper2424/check_rtc \ + --use_torch_compile=true \ + --torch_compile_backend=inductor \ + --torch_compile_mode=max-autotune """ import logging @@ -98,6 +125,22 @@ class RTCEvalConfig(HubMixin): metadata={"help": "Inference delay for RTC"}, ) + # Torch compile configuration + use_torch_compile: bool = field( + default=False, + metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"}, + ) + + torch_compile_backend: str = field( + default="inductor", + metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"}, + ) + + torch_compile_mode: str = field( + default="default", + metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"}, + ) + def __post_init__(self): # Parse policy path policy_path = parser.get_path_arg("policy") @@ -144,6 +187,10 @@ class RTCEvaluator: self.policy.config.rtc_config = cfg.rtc self.policy.init_rtc_processor() + # Apply torch.compile if enabled + if cfg.use_torch_compile: + self._apply_torch_compile() + logging.info(f"Policy loaded: {self.policy.name}") logging.info(f"RTC enabled: {cfg.rtc.enabled}") logging.info(f"Execution horizon: {cfg.rtc.execution_horizon}") @@ -162,6 +209,47 @@ class RTCEvaluator: }, ) + def _apply_torch_compile(self): + """Apply torch.compile to the policy model for faster inference.""" + try: + # Check if torch.compile is available (PyTorch 2.0+) + if not hasattr(torch, "compile"): + logging.warning( + "torch.compile is not available. Requires PyTorch 2.0+. " + f"Current version: {torch.__version__}. Skipping compilation." + ) + return + + logging.info("Applying torch.compile to policy model...") + logging.info(f" Backend: {self.cfg.torch_compile_backend}") + logging.info(f" Mode: {self.cfg.torch_compile_mode}") + + # Compile the policy's model (not the policy itself to preserve methods) + if hasattr(self.policy, "model"): + original_model = self.policy.model + compiled_model = torch.compile( + original_model, + backend=self.cfg.torch_compile_backend, + mode=self.cfg.torch_compile_mode, + ) + self.policy.model = compiled_model + logging.info("✓ Successfully compiled policy.model") + else: + logging.warning( + "Policy does not have a 'model' attribute. " + "Attempting to compile entire policy (may not work for all policy types)." + ) + self.policy = torch.compile( + self.policy, + backend=self.cfg.torch_compile_backend, + mode=self.cfg.torch_compile_mode, + ) + logging.info("✓ Successfully compiled policy") + + except Exception as e: + logging.error(f"Failed to apply torch.compile: {e}") + logging.warning("Continuing without torch.compile") + def run_evaluation(self): """Run evaluation on two random dataset samples.""" # Create output directory