Fix torch.compile compilation: disable CUDA graphs by default and return only the result from denoise_step

This commit is contained in:
Eugene Mironov
2025-11-07 02:28:49 +07:00
parent d9e72662c1
commit 4739ef9da3
3 changed files with 28 additions and 12 deletions
+26 -10
View File
@@ -19,6 +19,7 @@ Usage:
--device=mps
# With torch.compile for faster inference (PyTorch 2.0+)
# Note: CUDA graphs are disabled by default because of in-place ops in the denoising loop
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
@@ -27,7 +28,7 @@ Usage:
--use_torch_compile=true \
--torch_compile_mode=max-autotune
# With torch.compile on CUDA
# With torch.compile on CUDA (CUDA graphs disabled by default)
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
@@ -36,13 +37,14 @@ Usage:
--use_torch_compile=true \
--torch_compile_mode=reduce-overhead
# With custom compile settings
# Enable CUDA graphs (advanced - may cause tensor aliasing errors)
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--use_torch_compile=true \
--torch_compile_backend=inductor \
--torch_compile_mode=max-autotune
--torch_compile_mode=max-autotune \
--torch_compile_disable_cudagraphs=false
"""
import gc
@@ -142,6 +144,14 @@ class RTCEvalConfig(HubMixin):
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
)
torch_compile_disable_cudagraphs: bool = field(
default=True,
metadata={
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
},
)
def __post_init__(self):
# Parse policy path
policy_path = parser.get_path_arg("policy")
@@ -265,17 +275,23 @@ class RTCEvaluator:
logging.info(f" [{policy_name}] Applying torch.compile to predict_action_chunk...")
logging.info(f" Backend: {self.cfg.torch_compile_backend}")
logging.info(f" Mode: {self.cfg.torch_compile_mode}")
logging.info(f" Disable CUDA graphs: {self.cfg.torch_compile_disable_cudagraphs}")
logging.info(" Note: Debug tracker excluded from compilation via @torch._dynamo.disable")
# Compile the predict_action_chunk method
# The debug tracker is excluded from compilation via @torch._dynamo.disable decorator
# on the Tracker.track() method, so it won't cause graph breaks
# - Debug tracker is excluded from compilation via @torch._dynamo.disable
# - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
compile_kwargs = {
"backend": self.cfg.torch_compile_backend,
"mode": self.cfg.torch_compile_mode,
}
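# Disable CUDA graphs if requested (prevents tensor aliasing issues).
# NOTE(review): torch.compile appears to reject `mode` and `options` given together
# (RuntimeError: either mode or options, not both) — verify; CUDA graphs may need to be
# disabled another way when a compile mode is also set.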
if self.cfg.torch_compile_disable_cudagraphs:
compile_kwargs["options"] = {"triton.cudagraphs": False}
original_method = policy.predict_action_chunk
compiled_method = torch.compile(
original_method,
backend=self.cfg.torch_compile_backend,
mode=self.cfg.torch_compile_mode,
)
compiled_method = torch.compile(original_method, **compile_kwargs)
policy.predict_action_chunk = compiled_method
logging.info(f" ✓ [{policy_name}] Successfully compiled predict_action_chunk")
+1 -1
View File
@@ -270,7 +270,7 @@ class RTCProcessor:
execution_horizon=execution_horizon,
)
return result, x_t
return result
def get_prefix_weights(self, start, end, total):
start = min(start, end)
@@ -806,7 +806,7 @@ class VLAFlowMatching(nn.Module):
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
v_t, x_t = self.rtc_processor.denoise_step(
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk_left_over,
inference_delay=inference_delay,