refactor(sac): decouple algorithm hyperparameters from policy config

This commit is contained in:
Author: Khalil Meftah
Date: 2026-04-18 16:40:56 +02:00
Parent: 2487a6ee6d
Commit: a84b0e8132
3 changed files with 42 additions and 35 deletions
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING
import torch import torch
from lerobot.policies.sac.configuration_sac import CriticNetworkConfig from lerobot.policies.sac.configuration_sac import CriticNetworkConfig, SACConfig
from lerobot.rl.algorithms.configs import RLAlgorithmConfig from lerobot.rl.algorithms.configs import RLAlgorithmConfig
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -29,49 +29,54 @@ if TYPE_CHECKING:
@RLAlgorithmConfig.register_subclass("sac") @RLAlgorithmConfig.register_subclass("sac")
@dataclass @dataclass
class SACAlgorithmConfig(RLAlgorithmConfig): class SACAlgorithmConfig(RLAlgorithmConfig):
"""SAC-specific hyper-parameters that control the update loop.""" """SAC algorithm hyperparameters."""
utd_ratio: int = 1 # Policy config
policy_update_freq: int = 1 sac_config: SACConfig
clip_grad_norm: float = 40.0
# Optimizer learning rates
actor_lr: float = 3e-4 actor_lr: float = 3e-4
critic_lr: float = 3e-4 critic_lr: float = 3e-4
temperature_lr: float = 3e-4 temperature_lr: float = 3e-4
# Bellman update
discount: float = 0.99 discount: float = 0.99
temperature_init: float = 1.0
target_entropy: float | None = None
use_backup_entropy: bool = True use_backup_entropy: bool = True
critic_target_update_weight: float = 0.005 critic_target_update_weight: float = 0.005
# Critic ensemble
num_critics: int = 2 num_critics: int = 2
num_subsample_critics: int | None = None num_subsample_critics: int | None = None
num_discrete_actions: int | None = None
shared_encoder: bool = True
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
use_torch_compile: bool = False
# Temperature / entropy
temperature_init: float = 1.0
# Update loop
utd_ratio: int = 1
policy_update_freq: int = 1
grad_clip_norm: float = 40.0
@classmethod @classmethod
def from_policy_config(cls, policy_cfg) -> SACAlgorithmConfig: def from_policy_config(cls, policy_cfg: SACConfig) -> SACAlgorithmConfig:
"""Build from an existing ``SACConfig`` (cfg.policy) for backwards compat.""" """Build an algorithm config by copying hyperparameters from the policy config."""
return cls( return cls(
utd_ratio=policy_cfg.utd_ratio,
policy_update_freq=policy_cfg.policy_update_freq,
clip_grad_norm=policy_cfg.grad_clip_norm,
actor_lr=policy_cfg.actor_lr, actor_lr=policy_cfg.actor_lr,
critic_lr=policy_cfg.critic_lr, critic_lr=policy_cfg.critic_lr,
temperature_lr=policy_cfg.temperature_lr, temperature_lr=policy_cfg.temperature_lr,
discount=policy_cfg.discount, discount=policy_cfg.discount,
temperature_init=policy_cfg.temperature_init,
target_entropy=policy_cfg.target_entropy,
use_backup_entropy=policy_cfg.use_backup_entropy, use_backup_entropy=policy_cfg.use_backup_entropy,
critic_target_update_weight=policy_cfg.critic_target_update_weight, critic_target_update_weight=policy_cfg.critic_target_update_weight,
num_critics=policy_cfg.num_critics, num_critics=policy_cfg.num_critics,
num_subsample_critics=policy_cfg.num_subsample_critics, num_subsample_critics=policy_cfg.num_subsample_critics,
num_discrete_actions=policy_cfg.num_discrete_actions,
shared_encoder=policy_cfg.shared_encoder,
critic_network_kwargs=policy_cfg.critic_network_kwargs, critic_network_kwargs=policy_cfg.critic_network_kwargs,
discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs, discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs,
use_torch_compile=policy_cfg.use_torch_compile, temperature_init=policy_cfg.temperature_init,
utd_ratio=policy_cfg.utd_ratio,
policy_update_freq=policy_cfg.policy_update_freq,
grad_clip_norm=policy_cfg.grad_clip_norm,
sac_config=policy_cfg,
) )
def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm: def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm:
+14 -13
View File
@@ -53,8 +53,9 @@ class SACAlgorithm(RLAlgorithm):
policy: SACPolicy, policy: SACPolicy,
config: SACAlgorithmConfig, config: SACAlgorithmConfig,
): ):
self.policy = policy
self.config = config self.config = config
self.policy_config = config.sac_config
self.policy = policy
self.optimizers: dict[str, Optimizer] = {} self.optimizers: dict[str, Optimizer] = {}
self._optimization_step: int = 0 self._optimization_step: int = 0
@@ -89,13 +90,13 @@ class SACAlgorithm(RLAlgorithm):
# TODO(Khalil): Investigate and fix torch.compile # TODO(Khalil): Investigate and fix torch.compile
# NOTE: torch.compile is disabled, policy does not converge when enabled. # NOTE: torch.compile is disabled, policy does not converge when enabled.
if self.config.use_torch_compile: if self.policy_config.use_torch_compile:
self.critic_ensemble = torch.compile(self.critic_ensemble) self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = torch.compile(self.critic_target) self.critic_target = torch.compile(self.critic_target)
self.discrete_critic = None self.discrete_critic = None
self.discrete_critic_target = None self.discrete_critic_target = None
if self.config.num_discrete_actions is not None: if self.policy_config.num_discrete_actions is not None:
self.discrete_critic, self.discrete_critic_target = self._init_discrete_critics(encoder) self.discrete_critic, self.discrete_critic_target = self._init_discrete_critics(encoder)
self.policy.discrete_critic = self.discrete_critic self.policy.discrete_critic = self.discrete_critic
@@ -104,13 +105,13 @@ class SACAlgorithm(RLAlgorithm):
discrete_critic = DiscreteCritic( discrete_critic = DiscreteCritic(
encoder=encoder, encoder=encoder,
input_dim=encoder.output_dim, input_dim=encoder.output_dim,
output_dim=self.config.num_discrete_actions, output_dim=self.policy_config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs), **asdict(self.config.discrete_critic_network_kwargs),
) )
discrete_critic_target = DiscreteCritic( discrete_critic_target = DiscreteCritic(
encoder=encoder, encoder=encoder,
input_dim=encoder.output_dim, input_dim=encoder.output_dim,
output_dim=self.config.num_discrete_actions, output_dim=self.policy_config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs), **asdict(self.config.discrete_critic_network_kwargs),
) )
@@ -177,7 +178,7 @@ class SACAlgorithm(RLAlgorithm):
return q_values return q_values
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats: def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
clip = self.config.clip_grad_norm clip = self.config.grad_clip_norm
for _ in range(self.config.utd_ratio - 1): for _ in range(self.config.utd_ratio - 1):
batch = next(batch_iterator) batch = next(batch_iterator)
@@ -189,7 +190,7 @@ class SACAlgorithm(RLAlgorithm):
torch.nn.utils.clip_grad_norm_(self.critic_ensemble.parameters(), max_norm=clip) torch.nn.utils.clip_grad_norm_(self.critic_ensemble.parameters(), max_norm=clip)
self.optimizers["critic"].step() self.optimizers["critic"].step()
if self.config.num_discrete_actions is not None: if self.policy_config.num_discrete_actions is not None:
loss_dc = self._compute_loss_discrete_critic(fb) loss_dc = self._compute_loss_discrete_critic(fb)
self.optimizers["discrete_critic"].zero_grad() self.optimizers["discrete_critic"].zero_grad()
loss_dc.backward() loss_dc.backward()
@@ -212,7 +213,7 @@ class SACAlgorithm(RLAlgorithm):
grad_norms={"critic": critic_grad}, grad_norms={"critic": critic_grad},
) )
if self.config.num_discrete_actions is not None: if self.policy_config.num_discrete_actions is not None:
loss_dc = self._compute_loss_discrete_critic(fb) loss_dc = self._compute_loss_discrete_critic(fb)
self.optimizers["discrete_critic"].zero_grad() self.optimizers["discrete_critic"].zero_grad()
loss_dc.backward() loss_dc.backward()
@@ -284,7 +285,7 @@ class SACAlgorithm(RLAlgorithm):
td_target = rewards + (1 - done) * self.config.discount * min_q td_target = rewards + (1 - done) * self.config.discount * min_q
# 3- compute predicted qs # 3- compute predicted qs
if self.config.num_discrete_actions is not None: if self.policy_config.num_discrete_actions is not None:
# NOTE: We only want to keep the continuous action part # NOTE: We only want to keep the continuous action part
# In the buffer we have the full action space (continuous + discrete) # In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward # We need to split them before concatenating them in the critic forward
@@ -405,7 +406,7 @@ class SACAlgorithm(RLAlgorithm):
p.data * self.config.critic_target_update_weight p.data * self.config.critic_target_update_weight
+ target_p.data * (1.0 - self.config.critic_target_update_weight) + target_p.data * (1.0 - self.config.critic_target_update_weight)
) )
if self.config.num_discrete_actions is not None: if self.policy_config.num_discrete_actions is not None:
for target_p, p in zip( for target_p, p in zip(
self.discrete_critic_target.parameters(), self.discrete_critic_target.parameters(),
self.discrete_critic.parameters(), self.discrete_critic.parameters(),
@@ -466,7 +467,7 @@ class SACAlgorithm(RLAlgorithm):
"critic": torch.optim.Adam(self.critic_ensemble.parameters(), lr=self.config.critic_lr), "critic": torch.optim.Adam(self.critic_ensemble.parameters(), lr=self.config.critic_lr),
"temperature": torch.optim.Adam([self.log_alpha], lr=self.config.temperature_lr), "temperature": torch.optim.Adam([self.log_alpha], lr=self.config.temperature_lr),
} }
if self.config.num_discrete_actions is not None: if self.policy_config.num_discrete_actions is not None:
self.optimizers["discrete_critic"] = torch.optim.Adam( self.optimizers["discrete_critic"] = torch.optim.Adam(
self.discrete_critic.parameters(), lr=self.config.critic_lr self.discrete_critic.parameters(), lr=self.config.critic_lr
) )
@@ -480,7 +481,7 @@ class SACAlgorithm(RLAlgorithm):
state_dicts: dict[str, Any] = { state_dicts: dict[str, Any] = {
"policy": move_state_dict_to_device(self.policy.actor.state_dict(), device="cpu"), "policy": move_state_dict_to_device(self.policy.actor.state_dict(), device="cpu"),
} }
if self.config.num_discrete_actions is not None: if self.policy_config.num_discrete_actions is not None:
state_dicts["discrete_critic"] = move_state_dict_to_device( state_dicts["discrete_critic"] = move_state_dict_to_device(
self.discrete_critic.state_dict(), device="cpu" self.discrete_critic.state_dict(), device="cpu"
) )
@@ -489,7 +490,7 @@ class SACAlgorithm(RLAlgorithm):
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None: def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
actor_sd = move_state_dict_to_device(weights["policy"], device=device) actor_sd = move_state_dict_to_device(weights["policy"], device=device)
self.policy.actor.load_state_dict(actor_sd) self.policy.actor.load_state_dict(actor_sd)
if "discrete_critic" in weights and self.config.num_discrete_actions is not None: if "discrete_critic" in weights and self.policy_config.num_discrete_actions is not None:
dc_sd = move_state_dict_to_device(weights["discrete_critic"], device=device) dc_sd = move_state_dict_to_device(weights["discrete_critic"], device=device)
self.discrete_critic.load_state_dict(dc_sd) self.discrete_critic.load_state_dict(dc_sd)
+3 -2
View File
@@ -136,12 +136,13 @@ def test_sac_algorithm_config_registered():
def test_sac_algorithm_config_from_policy_config(): def test_sac_algorithm_config_from_policy_config():
"""from_policy_config should copy relevant fields.""" """from_policy_config should copy algorithm hyperparameters from the policy config."""
sac_cfg = _make_sac_config(utd_ratio=4, policy_update_freq=2) sac_cfg = _make_sac_config(utd_ratio=4, policy_update_freq=2)
algo_cfg = SACAlgorithmConfig.from_policy_config(sac_cfg) algo_cfg = SACAlgorithmConfig.from_policy_config(sac_cfg)
assert algo_cfg.sac_config is sac_cfg
assert algo_cfg.utd_ratio == 4 assert algo_cfg.utd_ratio == 4
assert algo_cfg.policy_update_freq == 2 assert algo_cfg.policy_update_freq == 2
assert algo_cfg.clip_grad_norm == sac_cfg.grad_clip_norm assert algo_cfg.grad_clip_norm == sac_cfg.grad_clip_norm
# =========================================================================== # ===========================================================================