mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
chore: clarify torch.compile disabled note in SACAlgorithm
This commit is contained in:
@@ -195,7 +195,7 @@ class SACConfig(PreTrainedConfig):
|
|||||||
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
|
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
|
||||||
|
|
||||||
# Optimizations
|
# Optimizations
|
||||||
use_torch_compile: bool = True
|
use_torch_compile: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ class SACAlgorithmConfig(RLAlgorithmConfig):
|
|||||||
shared_encoder: bool = True
|
shared_encoder: bool = True
|
||||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||||
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||||
use_torch_compile: bool = True
|
use_torch_compile: bool = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_policy_config(cls, policy_cfg) -> SACAlgorithmConfig:
|
def from_policy_config(cls, policy_cfg) -> SACAlgorithmConfig:
|
||||||
|
|||||||
@@ -87,6 +87,8 @@ class SACAlgorithm(RLAlgorithm):
|
|||||||
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
|
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
|
||||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||||
|
|
||||||
|
# TODO(Khalil): Investigate and fix torch.compile
|
||||||
|
# NOTE: torch.compile is disabled by default because the policy does not converge when it is enabled.
|
||||||
if self.config.use_torch_compile:
|
if self.config.use_torch_compile:
|
||||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||||
self.critic_target = torch.compile(self.critic_target)
|
self.critic_target = torch.compile(self.critic_target)
|
||||||
|
|||||||
Reference in New Issue
Block a user