From a0fdbf037ac918d0f2cdfd540db72199e0b925d0 Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Fri, 27 Feb 2026 18:58:36 +0300 Subject: [PATCH] feat(policies): add Smolvla torch compile support (#3043) * Change LIBERO init_state_id when reset. Signed-off-by: Aoqun Jin * Change LIBERO init_state_id when reset. Signed-off-by: Aoqun Jin * pre-commit run * Add torch.compile for smolvla Signed-off-by: Aoqun Jin * Add torch.compile for smolvla Add model compilation option for improved performance. Signed-off-by: Aoqun Jin * first --------- Signed-off-by: Aoqun Jin Co-authored-by: Aoqun Jin Co-authored-by: Steven Palma --- src/lerobot/policies/smolvla/configuration_smolvla.py | 3 +++ src/lerobot/policies/smolvla/modeling_smolvla.py | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/src/lerobot/policies/smolvla/configuration_smolvla.py b/src/lerobot/policies/smolvla/configuration_smolvla.py index c696265f2..b861b856b 100644 --- a/src/lerobot/policies/smolvla/configuration_smolvla.py +++ b/src/lerobot/policies/smolvla/configuration_smolvla.py @@ -106,6 +106,9 @@ class SmolVLAConfig(PreTrainedConfig): # Real-Time Chunking (RTC) configuration rtc_config: RTCConfig | None = None + compile_model: bool = False # Whether to use torch.compile for model optimization + compile_mode: str = "max-autotune" # Torch compile mode + def __post_init__(self): super().__post_init__() diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index 10544a949..e49226d26 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -593,6 +593,12 @@ class VLAFlowMatching(nn.Module): self.prefix_length = self.config.prefix_length self.rtc_processor = rtc_processor + # Compile model if requested + if config.compile_model: + torch.set_float32_matmul_precision("high") + self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode) + self.forward = torch.compile(self.forward, mode=config.compile_mode) + def _rtc_enabled(self): return self.config.rtc_config is not None and self.config.rtc_config.enabled