Add torch compilation for eval_dataset

This commit is contained in:
Eugene Mironov
2025-11-06 16:02:47 +07:00
parent c5b246f57c
commit 6fdee95923
+89 -1
View File
@@ -11,11 +11,38 @@ It compares action predictions with and without RTC on dataset samples,
measuring consistency and ground truth alignment.
Usage:
python eval_dataset.py \
# Basic usage
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=mps
# With torch.compile for faster inference (PyTorch 2.0+)
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=mps \
--use_torch_compile=true \
--torch_compile_mode=max-autotune
# With torch.compile for faster inference (PyTorch 2.0+)
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=cuda \
--use_torch_compile=true \
--torch_compile_mode=reduce-overhead
# With custom compile settings
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--use_torch_compile=true \
--torch_compile_backend=inductor \
--torch_compile_mode=max-autotune
"""
import logging
@@ -98,6 +125,22 @@ class RTCEvalConfig(HubMixin):
metadata={"help": "Inference delay for RTC"},
)
# Torch compile configuration
use_torch_compile: bool = field(
default=False,
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
)
torch_compile_backend: str = field(
default="inductor",
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
)
torch_compile_mode: str = field(
default="default",
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
)
def __post_init__(self):
# Parse policy path
policy_path = parser.get_path_arg("policy")
@@ -144,6 +187,10 @@ class RTCEvaluator:
self.policy.config.rtc_config = cfg.rtc
self.policy.init_rtc_processor()
# Apply torch.compile if enabled
if cfg.use_torch_compile:
self._apply_torch_compile()
logging.info(f"Policy loaded: {self.policy.name}")
logging.info(f"RTC enabled: {cfg.rtc.enabled}")
logging.info(f"Execution horizon: {cfg.rtc.execution_horizon}")
@@ -162,6 +209,47 @@ class RTCEvaluator:
},
)
def _apply_torch_compile(self):
"""Apply torch.compile to the policy model for faster inference."""
try:
# Check if torch.compile is available (PyTorch 2.0+)
if not hasattr(torch, "compile"):
logging.warning(
"torch.compile is not available. Requires PyTorch 2.0+. "
f"Current version: {torch.__version__}. Skipping compilation."
)
return
logging.info("Applying torch.compile to policy model...")
logging.info(f" Backend: {self.cfg.torch_compile_backend}")
logging.info(f" Mode: {self.cfg.torch_compile_mode}")
# Compile the policy's model (not the policy itself to preserve methods)
if hasattr(self.policy, "model"):
original_model = self.policy.model
compiled_model = torch.compile(
original_model,
backend=self.cfg.torch_compile_backend,
mode=self.cfg.torch_compile_mode,
)
self.policy.model = compiled_model
logging.info("✓ Successfully compiled policy.model")
else:
logging.warning(
"Policy does not have a 'model' attribute. "
"Attempting to compile entire policy (may not work for all policy types)."
)
self.policy = torch.compile(
self.policy,
backend=self.cfg.torch_compile_backend,
mode=self.cfg.torch_compile_mode,
)
logging.info("✓ Successfully compiled policy")
except Exception as e:
logging.error(f"Failed to apply torch.compile: {e}")
logging.warning("Continuing without torch.compile")
def run_evaluation(self):
"""Run evaluation on two random dataset samples."""
# Create output directory