diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index b017bbc57..6500ada20 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -538,6 +538,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` if config.compile_model: torch.set_float32_matmul_precision("high") self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode) + # Also compile the main forward pass used during training + self.forward = torch.compile(self.forward, mode=config.compile_mode) msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""