mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 10:40:04 +00:00
fix(pi05): use fused AdamW by default
Route full PI05/PI052 fine-tuning through PyTorch's fused AdamW path to avoid the single-tensor Adam denominator allocation near GPU memory limits. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -105,6 +105,7 @@ class AdamWConfig(OptimizerConfig):
|
||||
weight_decay: float = 1e-2
|
||||
grad_clip_norm: float = 10.0
|
||||
foreach: bool | None = None
|
||||
fused: bool | None = None
|
||||
|
||||
def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
|
||||
kwargs = asdict(self)
|
||||
|
||||
@@ -94,6 +94,7 @@ class PI05Config(PreTrainedConfig):
|
||||
optimizer_weight_decay: float = 0.01
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
optimizer_foreach: bool | None = False
|
||||
optimizer_fused: bool | None = True
|
||||
|
||||
# Scheduler settings: see openpi `CosineDecaySchedule`
|
||||
# Note: These will auto-scale if --steps < scheduler_decay_steps
|
||||
@@ -154,6 +155,7 @@ class PI05Config(PreTrainedConfig):
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
foreach=self.optimizer_foreach,
|
||||
fused=self.optimizer_fused,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
|
||||
Reference in New Issue
Block a user