mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
fix(sac): make temperature a property to fix checkpoint resume bug (#2877)
* fix(sac): make temperature a property to fix checkpoint resume bug

  Temperature was stored as a plain float and was not restored after loading a checkpoint, causing incorrect loss computations until update_temperature() was called. It has been changed to a property that always computes its value from log_alpha, ensuring correct behavior after checkpoint loading.

* simplify docstrings
This commit is contained in:
@@ -239,8 +239,10 @@ class SACPolicy(
|
|||||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_temperature(self):
|
@property
|
||||||
self.temperature = self.log_alpha.exp().item()
|
def temperature(self) -> float:
|
||||||
|
"""Return the current temperature value, always in sync with log_alpha."""
|
||||||
|
return self.log_alpha.exp().item()
|
||||||
|
|
||||||
def compute_loss_critic(
|
def compute_loss_critic(
|
||||||
self,
|
self,
|
||||||
@@ -457,11 +459,10 @@ class SACPolicy(
|
|||||||
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
|
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
|
||||||
self.target_entropy = -np.prod(dim) / 2
|
self.target_entropy = -np.prod(dim) / 2
|
||||||
|
|
||||||
def _init_temperature(self):
|
def _init_temperature(self) -> None:
|
||||||
"""Set up temperature parameter and initial log_alpha."""
|
"""Set up temperature parameter (log_alpha)."""
|
||||||
temp_init = self.config.temperature_init
|
temp_init = self.config.temperature_init
|
||||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
|
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
|
||||||
self.temperature = self.log_alpha.exp().item()
|
|
||||||
|
|
||||||
|
|
||||||
class SACObservationEncoder(nn.Module):
|
class SACObservationEncoder(nn.Module):
|
||||||
|
|||||||
@@ -545,9 +545,6 @@ def add_actor_information_and_train(
|
|||||||
training_infos["temperature_grad_norm"] = temp_grad_norm
|
training_infos["temperature_grad_norm"] = temp_grad_norm
|
||||||
training_infos["temperature"] = policy.temperature
|
training_infos["temperature"] = policy.temperature
|
||||||
|
|
||||||
# Update temperature
|
|
||||||
policy.update_temperature()
|
|
||||||
|
|
||||||
# Push policy to actors if needed
|
# Push policy to actors if needed
|
||||||
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
|
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
|
||||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
||||||
|
|||||||
@@ -441,12 +441,13 @@ def test_sac_policy_with_predefined_entropy():
|
|||||||
|
|
||||||
|
|
||||||
def test_sac_policy_update_temperature():
|
def test_sac_policy_update_temperature():
|
||||||
|
"""Test that temperature property is always in sync with log_alpha."""
|
||||||
config = create_default_config(continuous_action_dim=10, state_dim=10)
|
config = create_default_config(continuous_action_dim=10, state_dim=10)
|
||||||
policy = SACPolicy(config=config)
|
policy = SACPolicy(config=config)
|
||||||
|
|
||||||
assert policy.temperature == pytest.approx(1.0)
|
assert policy.temperature == pytest.approx(1.0)
|
||||||
policy.log_alpha.data = torch.tensor([math.log(0.1)])
|
policy.log_alpha.data = torch.tensor([math.log(0.1)])
|
||||||
policy.update_temperature()
|
# Temperature property automatically reflects log_alpha changes
|
||||||
assert policy.temperature == pytest.approx(0.1)
|
assert policy.temperature == pytest.approx(0.1)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user