fix pi05 forward compile (#2551)

This commit is contained in:
Michel Aractingi
2025-12-02 11:01:43 +01:00
committed by GitHub
parent af4766b602
commit 797cd2725a
@@ -538,6 +538,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
if config.compile_model:
torch.set_float32_matmul_precision("high")
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
# Also compile the main forward pass used during training
self.forward = torch.compile(self.forward, mode=config.compile_mode)
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""