feat: Enable torch.compile for DiffusionPolicy inference (#2486)

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
Jash Shah
2026-02-24 08:29:08 -08:00
committed by GitHub
parent 7fd71c83a3
commit dac1efd13d
2 changed files with 9 additions and 0 deletions
@@ -139,6 +139,10 @@ class DiffusionConfig(PreTrainedConfig):
# Inference
num_inference_steps: int | None = None
# Optimization
compile_model: bool = False
compile_mode: str = "reduce-overhead"
# Loss computation
do_mask_loss_for_padding: bool = False
@@ -182,6 +182,11 @@ class DiffusionModel(nn.Module):
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
if config.compile_model:
# Compile the U-Net. "reduce-overhead" is preferred for the small-batch repetitive loops
# common in diffusion inference.
self.unet = torch.compile(self.unet, mode=config.compile_mode)
self.noise_scheduler = _make_noise_scheduler(
config.noise_scheduler_type,
num_train_timesteps=config.num_train_timesteps,