mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 17:50:09 +00:00
Add torch compilation for eval_dataset
This commit is contained in:
@@ -11,11 +11,38 @@ It compares action predictions with and without RTC on dataset samples,
|
|||||||
measuring consistency and ground truth alignment.
|
measuring consistency and ground truth alignment.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
python eval_dataset.py \
|
# Basic usage
|
||||||
|
uv run python examples/rtc/eval_dataset.py \
|
||||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||||
--dataset.repo_id=helper2424/check_rtc \
|
--dataset.repo_id=helper2424/check_rtc \
|
||||||
--rtc.execution_horizon=8 \
|
--rtc.execution_horizon=8 \
|
||||||
--device=mps
|
--device=mps
|
||||||
|
|
||||||
|
# With torch.compile for faster inference (PyTorch 2.0+)
|
||||||
|
uv run python examples/rtc/eval_dataset.py \
|
||||||
|
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||||
|
--dataset.repo_id=helper2424/check_rtc \
|
||||||
|
--rtc.execution_horizon=8 \
|
||||||
|
--device=mps \
|
||||||
|
--use_torch_compile=true \
|
||||||
|
--torch_compile_mode=max-autotune
|
||||||
|
|
||||||
|
# With torch.compile for faster inference (PyTorch 2.0+)
|
||||||
|
uv run python examples/rtc/eval_dataset.py \
|
||||||
|
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||||
|
--dataset.repo_id=helper2424/check_rtc \
|
||||||
|
--rtc.execution_horizon=8 \
|
||||||
|
--device=cuda \
|
||||||
|
--use_torch_compile=true \
|
||||||
|
--torch_compile_mode=reduce-overhead
|
||||||
|
|
||||||
|
# With custom compile settings
|
||||||
|
uv run python examples/rtc/eval_dataset.py \
|
||||||
|
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||||
|
--dataset.repo_id=helper2424/check_rtc \
|
||||||
|
--use_torch_compile=true \
|
||||||
|
--torch_compile_backend=inductor \
|
||||||
|
--torch_compile_mode=max-autotune
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@@ -98,6 +125,22 @@ class RTCEvalConfig(HubMixin):
|
|||||||
metadata={"help": "Inference delay for RTC"},
|
metadata={"help": "Inference delay for RTC"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Torch compile configuration
|
||||||
|
use_torch_compile: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
|
||||||
|
)
|
||||||
|
|
||||||
|
torch_compile_backend: str = field(
|
||||||
|
default="inductor",
|
||||||
|
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
|
||||||
|
)
|
||||||
|
|
||||||
|
torch_compile_mode: str = field(
|
||||||
|
default="default",
|
||||||
|
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Parse policy path
|
# Parse policy path
|
||||||
policy_path = parser.get_path_arg("policy")
|
policy_path = parser.get_path_arg("policy")
|
||||||
@@ -144,6 +187,10 @@ class RTCEvaluator:
|
|||||||
self.policy.config.rtc_config = cfg.rtc
|
self.policy.config.rtc_config = cfg.rtc
|
||||||
self.policy.init_rtc_processor()
|
self.policy.init_rtc_processor()
|
||||||
|
|
||||||
|
# Apply torch.compile if enabled
|
||||||
|
if cfg.use_torch_compile:
|
||||||
|
self._apply_torch_compile()
|
||||||
|
|
||||||
logging.info(f"Policy loaded: {self.policy.name}")
|
logging.info(f"Policy loaded: {self.policy.name}")
|
||||||
logging.info(f"RTC enabled: {cfg.rtc.enabled}")
|
logging.info(f"RTC enabled: {cfg.rtc.enabled}")
|
||||||
logging.info(f"Execution horizon: {cfg.rtc.execution_horizon}")
|
logging.info(f"Execution horizon: {cfg.rtc.execution_horizon}")
|
||||||
@@ -162,6 +209,47 @@ class RTCEvaluator:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _apply_torch_compile(self):
|
||||||
|
"""Apply torch.compile to the policy model for faster inference."""
|
||||||
|
try:
|
||||||
|
# Check if torch.compile is available (PyTorch 2.0+)
|
||||||
|
if not hasattr(torch, "compile"):
|
||||||
|
logging.warning(
|
||||||
|
"torch.compile is not available. Requires PyTorch 2.0+. "
|
||||||
|
f"Current version: {torch.__version__}. Skipping compilation."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
logging.info("Applying torch.compile to policy model...")
|
||||||
|
logging.info(f" Backend: {self.cfg.torch_compile_backend}")
|
||||||
|
logging.info(f" Mode: {self.cfg.torch_compile_mode}")
|
||||||
|
|
||||||
|
# Compile the policy's model (not the policy itself to preserve methods)
|
||||||
|
if hasattr(self.policy, "model"):
|
||||||
|
original_model = self.policy.model
|
||||||
|
compiled_model = torch.compile(
|
||||||
|
original_model,
|
||||||
|
backend=self.cfg.torch_compile_backend,
|
||||||
|
mode=self.cfg.torch_compile_mode,
|
||||||
|
)
|
||||||
|
self.policy.model = compiled_model
|
||||||
|
logging.info("✓ Successfully compiled policy.model")
|
||||||
|
else:
|
||||||
|
logging.warning(
|
||||||
|
"Policy does not have a 'model' attribute. "
|
||||||
|
"Attempting to compile entire policy (may not work for all policy types)."
|
||||||
|
)
|
||||||
|
self.policy = torch.compile(
|
||||||
|
self.policy,
|
||||||
|
backend=self.cfg.torch_compile_backend,
|
||||||
|
mode=self.cfg.torch_compile_mode,
|
||||||
|
)
|
||||||
|
logging.info("✓ Successfully compiled policy")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Failed to apply torch.compile: {e}")
|
||||||
|
logging.warning("Continuing without torch.compile")
|
||||||
|
|
||||||
def run_evaluation(self):
|
def run_evaluation(self):
|
||||||
"""Run evaluation on two random dataset samples."""
|
"""Run evaluation on two random dataset samples."""
|
||||||
# Create output directory
|
# Create output directory
|
||||||
|
|||||||
Reference in New Issue
Block a user