diff --git a/src/lerobot/policies/sac/modeling_sac.py b/src/lerobot/policies/sac/modeling_sac.py
index c7c6798ed..d5dd71a48 100644
--- a/src/lerobot/policies/sac/modeling_sac.py
+++ b/src/lerobot/policies/sac/modeling_sac.py
@@ -239,8 +239,10 @@ class SACPolicy(
                 + target_param.data * (1.0 - self.config.critic_target_update_weight)
             )
 
-    def update_temperature(self):
-        self.temperature = self.log_alpha.exp().item()
+    @property
+    def temperature(self) -> float:
+        """Return the current temperature value, always in sync with log_alpha."""
+        return self.log_alpha.exp().item()
 
     def compute_loss_critic(
         self,
@@ -457,11 +459,10 @@ class SACPolicy(
         dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
         self.target_entropy = -np.prod(dim) / 2
 
-    def _init_temperature(self):
-        """Set up temperature parameter and initial log_alpha."""
+    def _init_temperature(self) -> None:
+        """Set up temperature parameter (log_alpha)."""
         temp_init = self.config.temperature_init
         self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
-        self.temperature = self.log_alpha.exp().item()
 
 
 class SACObservationEncoder(nn.Module):
diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py
index abc5c9504..ee09ac9ac 100644
--- a/src/lerobot/rl/learner.py
+++ b/src/lerobot/rl/learner.py
@@ -545,9 +545,6 @@ def add_actor_information_and_train(
             training_infos["temperature_grad_norm"] = temp_grad_norm
             training_infos["temperature"] = policy.temperature
 
-            # Update temperature
-            policy.update_temperature()
-
             # Push policy to actors if needed
             if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
                 push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
diff --git a/tests/policies/test_sac_policy.py b/tests/policies/test_sac_policy.py
index 8576883bd..6fad2979e 100644
--- a/tests/policies/test_sac_policy.py
+++ b/tests/policies/test_sac_policy.py
@@ -441,12 +441,13 @@ def test_sac_policy_with_predefined_entropy():
 
 
 def test_sac_policy_update_temperature():
+    """Test that temperature property is always in sync with log_alpha."""
     config = create_default_config(continuous_action_dim=10, state_dim=10)
     policy = SACPolicy(config=config)
 
     assert policy.temperature == pytest.approx(1.0)
 
     policy.log_alpha.data = torch.tensor([math.log(0.1)])
-    policy.update_temperature()
+    # Temperature property automatically reflects log_alpha changes
     assert policy.temperature == pytest.approx(0.1)
 