mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
feat(policies): add Smolvla torch compile support (#3043)
* Change LIBERO init_state_id when reset.
* pre-commit run
* Add torch.compile for smolvla
* Add torch.compile for smolvla: add a model compilation option for improved performance.
* first

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>
Co-authored-by: Aoqun Jin <aojiaojiao@foxmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
@@ -106,6 +106,9 @@ class SmolVLAConfig(PreTrainedConfig):

     # Real-Time Chunking (RTC) configuration
     rtc_config: RTCConfig | None = None

+    compile_model: bool = False  # Whether to use torch.compile for model optimization
+    compile_mode: str = "max-autotune"  # Torch compile mode
+
     def __post_init__(self):
         super().__post_init__()
@@ -593,6 +593,12 @@ class VLAFlowMatching(nn.Module):
         self.prefix_length = self.config.prefix_length
         self.rtc_processor = rtc_processor

+        # Compile model if requested
+        if config.compile_model:
+            torch.set_float32_matmul_precision("high")
+            self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
+            self.forward = torch.compile(self.forward, mode=config.compile_mode)
+
def _rtc_enabled(self):
    """Return whether Real-Time Chunking is active.

    True only when an RTC configuration object is attached to this
    policy's config AND its ``enabled`` flag is set.
    """
    rtc = self.config.rtc_config
    return rtc is not None and rtc.enabled
Reference in New Issue
Block a user