diff --git a/src/lerobot/optim/optimizers.py b/src/lerobot/optim/optimizers.py index 3d9ec6b98..6ec38491b 100644 --- a/src/lerobot/optim/optimizers.py +++ b/src/lerobot/optim/optimizers.py @@ -105,6 +105,7 @@ class AdamWConfig(OptimizerConfig): weight_decay: float = 1e-2 grad_clip_norm: float = 10.0 foreach: bool | None = None + fused: bool | None = None def build(self, params: OptimizerParams) -> torch.optim.Optimizer: kwargs = asdict(self) diff --git a/src/lerobot/policies/pi05/configuration_pi05.py b/src/lerobot/policies/pi05/configuration_pi05.py index d15cc0ee8..192fce448 100644 --- a/src/lerobot/policies/pi05/configuration_pi05.py +++ b/src/lerobot/policies/pi05/configuration_pi05.py @@ -94,6 +94,7 @@ class PI05Config(PreTrainedConfig): optimizer_weight_decay: float = 0.01 optimizer_grad_clip_norm: float = 1.0 optimizer_foreach: bool | None = False + optimizer_fused: bool | None = True # Scheduler settings: see openpi `CosineDecaySchedule` # Note: These will auto-scale if --steps < scheduler_decay_steps @@ -154,6 +155,7 @@ class PI05Config(PreTrainedConfig): weight_decay=self.optimizer_weight_decay, grad_clip_norm=self.optimizer_grad_clip_norm, foreach=self.optimizer_foreach, + fused=self.optimizer_fused, ) def get_scheduler_preset(self):