Fix compilation

This commit is contained in:
Eugene Mironov
2025-11-07 02:28:49 +07:00
parent d9e72662c1
commit 4739ef9da3
3 changed files with 28 additions and 12 deletions
+26 -10
View File
@@ -19,6 +19,7 @@ Usage:
--device=mps --device=mps
# With torch.compile for faster inference (PyTorch 2.0+) # With torch.compile for faster inference (PyTorch 2.0+)
# Note: CUDA graphs disabled by default due to in-place ops in denoising loop
uv run python examples/rtc/eval_dataset.py \ uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \ --policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \ --dataset.repo_id=helper2424/check_rtc \
@@ -27,7 +28,7 @@ Usage:
--use_torch_compile=true \ --use_torch_compile=true \
--torch_compile_mode=max-autotune --torch_compile_mode=max-autotune
# With torch.compile on CUDA # With torch.compile on CUDA (CUDA graphs disabled by default)
uv run python examples/rtc/eval_dataset.py \ uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \ --policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \ --dataset.repo_id=helper2424/check_rtc \
@@ -36,13 +37,14 @@ Usage:
--use_torch_compile=true \ --use_torch_compile=true \
--torch_compile_mode=reduce-overhead --torch_compile_mode=reduce-overhead
# With custom compile settings # Enable CUDA graphs (advanced - may cause tensor aliasing errors)
uv run python examples/rtc/eval_dataset.py \ uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \ --policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \ --dataset.repo_id=helper2424/check_rtc \
--use_torch_compile=true \ --use_torch_compile=true \
--torch_compile_backend=inductor \ --torch_compile_backend=inductor \
--torch_compile_mode=max-autotune --torch_compile_mode=max-autotune \
--torch_compile_disable_cudagraphs=false
""" """
import gc import gc
@@ -142,6 +144,14 @@ class RTCEvalConfig(HubMixin):
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"}, metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
) )
torch_compile_disable_cudagraphs: bool = field(
default=True,
metadata={
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
},
)
def __post_init__(self): def __post_init__(self):
# Parse policy path # Parse policy path
policy_path = parser.get_path_arg("policy") policy_path = parser.get_path_arg("policy")
@@ -265,17 +275,23 @@ class RTCEvaluator:
logging.info(f" [{policy_name}] Applying torch.compile to predict_action_chunk...") logging.info(f" [{policy_name}] Applying torch.compile to predict_action_chunk...")
logging.info(f" Backend: {self.cfg.torch_compile_backend}") logging.info(f" Backend: {self.cfg.torch_compile_backend}")
logging.info(f" Mode: {self.cfg.torch_compile_mode}") logging.info(f" Mode: {self.cfg.torch_compile_mode}")
logging.info(f" Disable CUDA graphs: {self.cfg.torch_compile_disable_cudagraphs}")
logging.info(" Note: Debug tracker excluded from compilation via @torch._dynamo.disable") logging.info(" Note: Debug tracker excluded from compilation via @torch._dynamo.disable")
# Compile the predict_action_chunk method # Compile the predict_action_chunk method
# The debug tracker is excluded from compilation via @torch._dynamo.disable decorator # - Debug tracker is excluded from compilation via @torch._dynamo.disable
# on the Tracker.track() method, so it won't cause graph breaks # - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
compile_kwargs = {
"backend": self.cfg.torch_compile_backend,
"mode": self.cfg.torch_compile_mode,
}
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
if self.cfg.torch_compile_disable_cudagraphs:
compile_kwargs["options"] = {"triton.cudagraphs": False}
original_method = policy.predict_action_chunk original_method = policy.predict_action_chunk
compiled_method = torch.compile( compiled_method = torch.compile(original_method, **compile_kwargs)
original_method,
backend=self.cfg.torch_compile_backend,
mode=self.cfg.torch_compile_mode,
)
policy.predict_action_chunk = compiled_method policy.predict_action_chunk = compiled_method
logging.info(f" ✓ [{policy_name}] Successfully compiled predict_action_chunk") logging.info(f" ✓ [{policy_name}] Successfully compiled predict_action_chunk")
+1 -1
View File
@@ -270,7 +270,7 @@ class RTCProcessor:
execution_horizon=execution_horizon, execution_horizon=execution_horizon,
) )
return result, x_t return result
def get_prefix_weights(self, start, end, total): def get_prefix_weights(self, start, end, total):
start = min(start, end) start = min(start, end)
@@ -806,7 +806,7 @@ class VLAFlowMatching(nn.Module):
prev_chunk_left_over = kwargs.get("prev_chunk_left_over") prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon") execution_horizon = kwargs.get("execution_horizon")
v_t, x_t = self.rtc_processor.denoise_step( v_t = self.rtc_processor.denoise_step(
x_t=x_t, x_t=x_t,
prev_chunk_left_over=prev_chunk_left_over, prev_chunk_left_over=prev_chunk_left_over,
inference_delay=inference_delay, inference_delay=inference_delay,