fix(sac): make temperature a property to fix checkpoint resume bug (#2877)

* fix(sac): make temperature a property to fix checkpoint resume bug

Temperature was stored as a plain float and not restored after loading
a checkpoint, causing incorrect loss computations until update_temperature()
was called. Changed to a property that always computes from log_alpha,
ensuring correct behavior after checkpoint loading.

* simplify docstrings
This commit is contained in:
Michel Aractingi
2026-01-30 12:23:22 +01:00
committed by GitHub
parent 3409ef0dc2
commit 04cbf669cf
3 changed files with 8 additions and 9 deletions
+6 -5
View File
@@ -239,8 +239,10 @@ class SACPolicy(
+ target_param.data * (1.0 - self.config.critic_target_update_weight) + target_param.data * (1.0 - self.config.critic_target_update_weight)
) )
def update_temperature(self): @property
self.temperature = self.log_alpha.exp().item() def temperature(self) -> float:
"""Return the current temperature value, always in sync with log_alpha."""
return self.log_alpha.exp().item()
def compute_loss_critic( def compute_loss_critic(
self, self,
@@ -457,11 +459,10 @@ class SACPolicy(
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0) dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
self.target_entropy = -np.prod(dim) / 2 self.target_entropy = -np.prod(dim) / 2
def _init_temperature(self): def _init_temperature(self) -> None:
"""Set up temperature parameter and initial log_alpha.""" """Set up temperature parameter (log_alpha)."""
temp_init = self.config.temperature_init temp_init = self.config.temperature_init
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)])) self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
self.temperature = self.log_alpha.exp().item()
class SACObservationEncoder(nn.Module): class SACObservationEncoder(nn.Module):
-3
View File
@@ -545,9 +545,6 @@ def add_actor_information_and_train(
training_infos["temperature_grad_norm"] = temp_grad_norm training_infos["temperature_grad_norm"] = temp_grad_norm
training_infos["temperature"] = policy.temperature training_infos["temperature"] = policy.temperature
# Update temperature
policy.update_temperature()
# Push policy to actors if needed # Push policy to actors if needed
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency: if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy) push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
+2 -1
View File
@@ -441,12 +441,13 @@ def test_sac_policy_with_predefined_entropy():
def test_sac_policy_update_temperature(): def test_sac_policy_update_temperature():
"""Test that temperature property is always in sync with log_alpha."""
config = create_default_config(continuous_action_dim=10, state_dim=10) config = create_default_config(continuous_action_dim=10, state_dim=10)
policy = SACPolicy(config=config) policy = SACPolicy(config=config)
assert policy.temperature == pytest.approx(1.0) assert policy.temperature == pytest.approx(1.0)
policy.log_alpha.data = torch.tensor([math.log(0.1)]) policy.log_alpha.data = torch.tensor([math.log(0.1)])
policy.update_temperature() # Temperature property automatically reflects log_alpha changes
assert policy.temperature == pytest.approx(0.1) assert policy.temperature == pytest.approx(0.1)