also compile forward method

This commit is contained in:
Pepijn
2025-09-13 11:12:54 +02:00
parent c8163662ad
commit c5a029a28a
@@ -517,6 +517,8 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
if config.compile_model:
torch.set_float32_matmul_precision("high")
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
# Also compile the main forward pass used during training
self.forward = torch.compile(self.forward, mode=config.compile_mode)
msg = """transformers_replace is not installed correctly.
Please install it with `pip install transformers==4.53.2`