mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-03 08:07:03 +00:00
also compile forward method
This commit is contained in:
@@ -517,6 +517,8 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
if config.compile_model:
|
||||
torch.set_float32_matmul_precision("high")
|
||||
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
|
||||
# Also compile the main forward pass used during training
|
||||
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
||||
|
||||
msg = """transformers_replace is not installed correctly.
|
||||
Please install it with `pip install transformers==4.53.2`
|
||||
|
||||
Reference in New Issue
Block a user