chore: clarify torch.compile disabled note in SACAlgorithm

This commit is contained in:
Khalil Meftah
2026-04-13 11:49:27 +02:00
parent e022207c75
commit 036b310a97
3 changed files with 4 additions and 2 deletions
@@ -195,7 +195,7 @@ class SACConfig(PreTrainedConfig):
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
# Optimizations
use_torch_compile: bool = True
use_torch_compile: bool = False
def __post_init__(self):
super().__post_init__()
@@ -48,7 +48,7 @@ class SACAlgorithmConfig(RLAlgorithmConfig):
shared_encoder: bool = True
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
use_torch_compile: bool = True
use_torch_compile: bool = False
@classmethod
def from_policy_config(cls, policy_cfg) -> SACAlgorithmConfig:
@@ -87,6 +87,8 @@ class SACAlgorithm(RLAlgorithm):
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
# TODO(Khalil): Investigate and fix torch.compile
# NOTE: torch.compile is disabled by default because the policy does not converge when it is enabled.
if self.config.use_torch_compile:
self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = torch.compile(self.critic_target)