mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 02:59:50 +00:00
feat: Enable torch.compile for DiffusionPolicy inference (#2486)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
@@ -139,6 +139,10 @@ class DiffusionConfig(PreTrainedConfig):
|
|||||||
# Inference
|
# Inference
|
||||||
num_inference_steps: int | None = None
|
num_inference_steps: int | None = None
|
||||||
|
|
||||||
|
# Optimization
|
||||||
|
compile_model: bool = False
|
||||||
|
compile_mode: str = "reduce-overhead"
|
||||||
|
|
||||||
# Loss computation
|
# Loss computation
|
||||||
do_mask_loss_for_padding: bool = False
|
do_mask_loss_for_padding: bool = False
|
||||||
|
|
||||||
|
|||||||
@@ -182,6 +182,11 @@ class DiffusionModel(nn.Module):
|
|||||||
|
|
||||||
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
|
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
|
||||||
|
|
||||||
|
if config.compile_model:
|
||||||
|
# Compile the U-Net. "reduce-overhead" is preferred for the small-batch repetitive loops
|
||||||
|
# common in diffusion inference.
|
||||||
|
self.unet = torch.compile(self.unet, mode=config.compile_mode)
|
||||||
|
|
||||||
self.noise_scheduler = _make_noise_scheduler(
|
self.noise_scheduler = _make_noise_scheduler(
|
||||||
config.noise_scheduler_type,
|
config.noise_scheduler_type,
|
||||||
num_train_timesteps=config.num_train_timesteps,
|
num_train_timesteps=config.num_train_timesteps,
|
||||||
|
|||||||
Reference in New Issue
Block a user