mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
Add torch compilation for eval_dataset
This commit is contained in:
@@ -11,11 +11,38 @@ It compares action predictions with and without RTC on dataset samples,
|
||||
measuring consistency and ground truth alignment.
|
||||
|
||||
Usage:
|
||||
python eval_dataset.py \
|
||||
# Basic usage
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=mps
|
||||
|
||||
# With torch.compile for faster inference (PyTorch 2.0+)
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=mps \
|
||||
--use_torch_compile=true \
|
||||
--torch_compile_mode=max-autotune
|
||||
|
||||
# With torch.compile for faster inference (PyTorch 2.0+)
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=cuda \
|
||||
--use_torch_compile=true \
|
||||
--torch_compile_mode=reduce-overhead
|
||||
|
||||
# With custom compile settings
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--use_torch_compile=true \
|
||||
--torch_compile_backend=inductor \
|
||||
--torch_compile_mode=max-autotune
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -98,6 +125,22 @@ class RTCEvalConfig(HubMixin):
|
||||
metadata={"help": "Inference delay for RTC"},
|
||||
)
|
||||
|
||||
# Torch compile configuration
|
||||
use_torch_compile: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
|
||||
)
|
||||
|
||||
torch_compile_backend: str = field(
|
||||
default="inductor",
|
||||
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
|
||||
)
|
||||
|
||||
torch_compile_mode: str = field(
|
||||
default="default",
|
||||
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# Parse policy path
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
@@ -144,6 +187,10 @@ class RTCEvaluator:
|
||||
self.policy.config.rtc_config = cfg.rtc
|
||||
self.policy.init_rtc_processor()
|
||||
|
||||
# Apply torch.compile if enabled
|
||||
if cfg.use_torch_compile:
|
||||
self._apply_torch_compile()
|
||||
|
||||
logging.info(f"Policy loaded: {self.policy.name}")
|
||||
logging.info(f"RTC enabled: {cfg.rtc.enabled}")
|
||||
logging.info(f"Execution horizon: {cfg.rtc.execution_horizon}")
|
||||
@@ -162,6 +209,47 @@ class RTCEvaluator:
|
||||
},
|
||||
)
|
||||
|
||||
def _apply_torch_compile(self):
|
||||
"""Apply torch.compile to the policy model for faster inference."""
|
||||
try:
|
||||
# Check if torch.compile is available (PyTorch 2.0+)
|
||||
if not hasattr(torch, "compile"):
|
||||
logging.warning(
|
||||
"torch.compile is not available. Requires PyTorch 2.0+. "
|
||||
f"Current version: {torch.__version__}. Skipping compilation."
|
||||
)
|
||||
return
|
||||
|
||||
logging.info("Applying torch.compile to policy model...")
|
||||
logging.info(f" Backend: {self.cfg.torch_compile_backend}")
|
||||
logging.info(f" Mode: {self.cfg.torch_compile_mode}")
|
||||
|
||||
# Compile the policy's model (not the policy itself to preserve methods)
|
||||
if hasattr(self.policy, "model"):
|
||||
original_model = self.policy.model
|
||||
compiled_model = torch.compile(
|
||||
original_model,
|
||||
backend=self.cfg.torch_compile_backend,
|
||||
mode=self.cfg.torch_compile_mode,
|
||||
)
|
||||
self.policy.model = compiled_model
|
||||
logging.info("✓ Successfully compiled policy.model")
|
||||
else:
|
||||
logging.warning(
|
||||
"Policy does not have a 'model' attribute. "
|
||||
"Attempting to compile entire policy (may not work for all policy types)."
|
||||
)
|
||||
self.policy = torch.compile(
|
||||
self.policy,
|
||||
backend=self.cfg.torch_compile_backend,
|
||||
mode=self.cfg.torch_compile_mode,
|
||||
)
|
||||
logging.info("✓ Successfully compiled policy")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to apply torch.compile: {e}")
|
||||
logging.warning("Continuing without torch.compile")
|
||||
|
||||
def run_evaluation(self):
|
||||
"""Run evaluation on two random dataset samples."""
|
||||
# Create output directory
|
||||
|
||||
Reference in New Issue
Block a user