mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 10:40:04 +00:00
chore: clarify torch.compile disabled note in SACAlgorithm
This commit is contained in:
@@ -195,7 +195,7 @@ class SACConfig(PreTrainedConfig):
|
||||
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
|
||||
|
||||
# Optimizations
|
||||
-use_torch_compile: bool = True
|
||||
+use_torch_compile: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
@@ -48,7 +48,7 @@ class SACAlgorithmConfig(RLAlgorithmConfig):
|
||||
shared_encoder: bool = True
|
||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
-use_torch_compile: bool = True
|
||||
+use_torch_compile: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_policy_config(cls, policy_cfg) -> SACAlgorithmConfig:
|
||||
|
||||
@@ -87,6 +87,8 @@ class SACAlgorithm(RLAlgorithm):
|
||||
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
+# TODO(Khalil): Investigate and fix torch.compile
|
||||
+# NOTE: torch.compile is disabled, policy does not converge when enabled.
|
||||
if self.config.use_torch_compile:
|
||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||
self.critic_target = torch.compile(self.critic_target)
|
||||
|
||||
Reference in New Issue
Block a user