From 036b310a9771af81e230e33a2835646f2178fc94 Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Mon, 13 Apr 2026 11:49:27 +0200 Subject: [PATCH] chore: clarify torch.compile disabled note in SACAlgorithm --- src/lerobot/policies/sac/configuration_sac.py | 2 +- src/lerobot/rl/algorithms/sac/configuration_sac.py | 2 +- src/lerobot/rl/algorithms/sac/sac_algorithm.py | 2 ++ 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/lerobot/policies/sac/configuration_sac.py b/src/lerobot/policies/sac/configuration_sac.py index ada12330c..ad846b11c 100644 --- a/src/lerobot/policies/sac/configuration_sac.py +++ b/src/lerobot/policies/sac/configuration_sac.py @@ -195,7 +195,7 @@ class SACConfig(PreTrainedConfig): concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig) # Optimizations - use_torch_compile: bool = True + use_torch_compile: bool = False def __post_init__(self): super().__post_init__() diff --git a/src/lerobot/rl/algorithms/sac/configuration_sac.py b/src/lerobot/rl/algorithms/sac/configuration_sac.py index c2aac050b..dfe58e086 100644 --- a/src/lerobot/rl/algorithms/sac/configuration_sac.py +++ b/src/lerobot/rl/algorithms/sac/configuration_sac.py @@ -48,7 +48,7 @@ class SACAlgorithmConfig(RLAlgorithmConfig): shared_encoder: bool = True critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) - use_torch_compile: bool = True + use_torch_compile: bool = False @classmethod def from_policy_config(cls, policy_cfg) -> SACAlgorithmConfig: diff --git a/src/lerobot/rl/algorithms/sac/sac_algorithm.py b/src/lerobot/rl/algorithms/sac/sac_algorithm.py index 0ce7a6875..10c104bb6 100644 --- a/src/lerobot/rl/algorithms/sac/sac_algorithm.py +++ b/src/lerobot/rl/algorithms/sac/sac_algorithm.py @@ -87,6 +87,8 @@ class SACAlgorithm(RLAlgorithm): self.critic_target = CriticEnsemble(encoder=self.encoder_critic, 
ensemble=target_heads) self.critic_target.load_state_dict(self.critic_ensemble.state_dict()) + # TODO(Khalil): Investigate and fix torch.compile + # NOTE: torch.compile is disabled by default (use_torch_compile=False) because the policy does not converge when it is enabled. if self.config.use_torch_compile: self.critic_ensemble = torch.compile(self.critic_ensemble) self.critic_target = torch.compile(self.critic_target)