mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
Fix compilation
This commit is contained in:
@@ -19,6 +19,7 @@ Usage:
|
|||||||
--device=mps
|
--device=mps
|
||||||
|
|
||||||
# With torch.compile for faster inference (PyTorch 2.0+)
|
# With torch.compile for faster inference (PyTorch 2.0+)
|
||||||
|
# Note: CUDA graphs disabled by default due to in-place ops in denoising loop
|
||||||
uv run python examples/rtc/eval_dataset.py \
|
uv run python examples/rtc/eval_dataset.py \
|
||||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||||
--dataset.repo_id=helper2424/check_rtc \
|
--dataset.repo_id=helper2424/check_rtc \
|
||||||
@@ -27,7 +28,7 @@ Usage:
|
|||||||
--use_torch_compile=true \
|
--use_torch_compile=true \
|
||||||
--torch_compile_mode=max-autotune
|
--torch_compile_mode=max-autotune
|
||||||
|
|
||||||
# With torch.compile on CUDA
|
# With torch.compile on CUDA (CUDA graphs disabled by default)
|
||||||
uv run python examples/rtc/eval_dataset.py \
|
uv run python examples/rtc/eval_dataset.py \
|
||||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||||
--dataset.repo_id=helper2424/check_rtc \
|
--dataset.repo_id=helper2424/check_rtc \
|
||||||
@@ -36,13 +37,14 @@ Usage:
|
|||||||
--use_torch_compile=true \
|
--use_torch_compile=true \
|
||||||
--torch_compile_mode=reduce-overhead
|
--torch_compile_mode=reduce-overhead
|
||||||
|
|
||||||
# With custom compile settings
|
# Enable CUDA graphs (advanced - may cause tensor aliasing errors)
|
||||||
uv run python examples/rtc/eval_dataset.py \
|
uv run python examples/rtc/eval_dataset.py \
|
||||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||||
--dataset.repo_id=helper2424/check_rtc \
|
--dataset.repo_id=helper2424/check_rtc \
|
||||||
--use_torch_compile=true \
|
--use_torch_compile=true \
|
||||||
--torch_compile_backend=inductor \
|
--torch_compile_backend=inductor \
|
||||||
--torch_compile_mode=max-autotune
|
--torch_compile_mode=max-autotune \
|
||||||
|
--torch_compile_disable_cudagraphs=false
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
@@ -142,6 +144,14 @@ class RTCEvalConfig(HubMixin):
|
|||||||
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
|
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
torch_compile_disable_cudagraphs: bool = field(
|
||||||
|
default=True,
|
||||||
|
metadata={
|
||||||
|
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
|
||||||
|
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Parse policy path
|
# Parse policy path
|
||||||
policy_path = parser.get_path_arg("policy")
|
policy_path = parser.get_path_arg("policy")
|
||||||
@@ -265,17 +275,23 @@ class RTCEvaluator:
|
|||||||
logging.info(f" [{policy_name}] Applying torch.compile to predict_action_chunk...")
|
logging.info(f" [{policy_name}] Applying torch.compile to predict_action_chunk...")
|
||||||
logging.info(f" Backend: {self.cfg.torch_compile_backend}")
|
logging.info(f" Backend: {self.cfg.torch_compile_backend}")
|
||||||
logging.info(f" Mode: {self.cfg.torch_compile_mode}")
|
logging.info(f" Mode: {self.cfg.torch_compile_mode}")
|
||||||
|
logging.info(f" Disable CUDA graphs: {self.cfg.torch_compile_disable_cudagraphs}")
|
||||||
logging.info(" Note: Debug tracker excluded from compilation via @torch._dynamo.disable")
|
logging.info(" Note: Debug tracker excluded from compilation via @torch._dynamo.disable")
|
||||||
|
|
||||||
# Compile the predict_action_chunk method
|
# Compile the predict_action_chunk method
|
||||||
# The debug tracker is excluded from compilation via @torch._dynamo.disable decorator
|
# - Debug tracker is excluded from compilation via @torch._dynamo.disable
|
||||||
# on the Tracker.track() method, so it won't cause graph breaks
|
# - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
|
||||||
|
compile_kwargs = {
|
||||||
|
"backend": self.cfg.torch_compile_backend,
|
||||||
|
"mode": self.cfg.torch_compile_mode,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
|
||||||
|
if self.cfg.torch_compile_disable_cudagraphs:
|
||||||
|
compile_kwargs["options"] = {"triton.cudagraphs": False}
|
||||||
|
|
||||||
original_method = policy.predict_action_chunk
|
original_method = policy.predict_action_chunk
|
||||||
compiled_method = torch.compile(
|
compiled_method = torch.compile(original_method, **compile_kwargs)
|
||||||
original_method,
|
|
||||||
backend=self.cfg.torch_compile_backend,
|
|
||||||
mode=self.cfg.torch_compile_mode,
|
|
||||||
)
|
|
||||||
policy.predict_action_chunk = compiled_method
|
policy.predict_action_chunk = compiled_method
|
||||||
logging.info(f" ✓ [{policy_name}] Successfully compiled predict_action_chunk")
|
logging.info(f" ✓ [{policy_name}] Successfully compiled predict_action_chunk")
|
||||||
|
|
||||||
|
|||||||
@@ -270,7 +270,7 @@ class RTCProcessor:
|
|||||||
execution_horizon=execution_horizon,
|
execution_horizon=execution_horizon,
|
||||||
)
|
)
|
||||||
|
|
||||||
return result, x_t
|
return result
|
||||||
|
|
||||||
def get_prefix_weights(self, start, end, total):
|
def get_prefix_weights(self, start, end, total):
|
||||||
start = min(start, end)
|
start = min(start, end)
|
||||||
|
|||||||
@@ -806,7 +806,7 @@ class VLAFlowMatching(nn.Module):
|
|||||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||||
execution_horizon = kwargs.get("execution_horizon")
|
execution_horizon = kwargs.get("execution_horizon")
|
||||||
|
|
||||||
v_t, x_t = self.rtc_processor.denoise_step(
|
v_t = self.rtc_processor.denoise_step(
|
||||||
x_t=x_t,
|
x_t=x_t,
|
||||||
prev_chunk_left_over=prev_chunk_left_over,
|
prev_chunk_left_over=prev_chunk_left_over,
|
||||||
inference_delay=inference_delay,
|
inference_delay=inference_delay,
|
||||||
|
|||||||
Reference in New Issue
Block a user