diff --git a/src/lerobot/policies/diffusion/configuration_diffusion.py b/src/lerobot/policies/diffusion/configuration_diffusion.py index 8ac0920dd..3d30e0941 100644 --- a/src/lerobot/policies/diffusion/configuration_diffusion.py +++ b/src/lerobot/policies/diffusion/configuration_diffusion.py @@ -139,6 +139,10 @@ class DiffusionConfig(PreTrainedConfig): # Inference num_inference_steps: int | None = None + # Optimization + compile_model: bool = False + compile_mode: str = "reduce-overhead" + # Loss computation do_mask_loss_for_padding: bool = False diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index 1fdc76f10..7525c9252 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -182,6 +182,11 @@ class DiffusionModel(nn.Module): self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps) + if config.compile_model: + # Compile the U-Net. "reduce-overhead" is preferred for the small-batch repetitive loops + # common in diffusion inference. + self.unet = torch.compile(self.unet, mode=config.compile_mode) + self.noise_scheduler = _make_noise_scheduler( config.noise_scheduler_type, num_train_timesteps=config.num_train_timesteps,