mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
fix pi05 forward compile (#2551)
This commit is contained in:
@@ -538,6 +538,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
if config.compile_model:
|
if config.compile_model:
|
||||||
torch.set_float32_matmul_precision("high")
|
torch.set_float32_matmul_precision("high")
|
||||||
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
|
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
|
||||||
|
# Also compile the main forward pass used during training
|
||||||
|
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
||||||
|
|
||||||
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
|
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user