fix(pi05): use fused AdamW by default

Route full PI05/PI052 fine-tuning through PyTorch's fused AdamW path to avoid the single-tensor Adam denominator allocation near GPU memory limits.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-05-18 19:23:17 +00:00
parent 2b4c5f49e3
commit 2629175d2d
2 changed files with 3 additions and 0 deletions
+1
View File
@@ -105,6 +105,7 @@ class AdamWConfig(OptimizerConfig):
weight_decay: float = 1e-2
grad_clip_norm: float = 10.0
foreach: bool | None = None
fused: bool | None = None
def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
kwargs = asdict(self)
@@ -94,6 +94,7 @@ class PI05Config(PreTrainedConfig):
optimizer_weight_decay: float = 0.01
optimizer_grad_clip_norm: float = 1.0
optimizer_foreach: bool | None = False
optimizer_fused: bool | None = True
# Scheduler settings: see openpi `CosineDecaySchedule`
# Note: These will auto-scale if --steps < scheduler_decay_steps
@@ -154,6 +155,7 @@ class PI05Config(PreTrainedConfig):
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
foreach=self.optimizer_foreach,
fused=self.optimizer_fused,
)
def get_scheduler_preset(self):