mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 02:00:03 +00:00
feat(policies): add Smolvla torch compile support (#3043)
* Change LIBERO init_state_id when reset. Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com> * Change LIBERO init_state_id when reset. Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com> * pre-commit run * Add torch.compile for smolvla Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com> * Add torch.compile for smolvla Add model compilation option for improved performance. Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com> * first --------- Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com> Co-authored-by: Aoqun Jin <aojiaojiao@foxmail.com> Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
@@ -106,6 +106,9 @@ class SmolVLAConfig(PreTrainedConfig):
|
|||||||
# Real-Time Chunking (RTC) configuration
|
# Real-Time Chunking (RTC) configuration
|
||||||
rtc_config: RTCConfig | None = None
|
rtc_config: RTCConfig | None = None
|
||||||
|
|
||||||
|
compile_model: bool = False # Whether to use torch.compile for model optimization
|
||||||
|
compile_mode: str = "max-autotune" # Torch compile mode
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
|
|||||||
@@ -593,6 +593,12 @@ class VLAFlowMatching(nn.Module):
|
|||||||
self.prefix_length = self.config.prefix_length
|
self.prefix_length = self.config.prefix_length
|
||||||
self.rtc_processor = rtc_processor
|
self.rtc_processor = rtc_processor
|
||||||
|
|
||||||
|
# Compile model if requested
|
||||||
|
if config.compile_model:
|
||||||
|
torch.set_float32_matmul_precision("high")
|
||||||
|
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
|
||||||
|
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
||||||
|
|
||||||
def _rtc_enabled(self):
|
def _rtc_enabled(self):
|
||||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user