diff --git a/src/lerobot/rl/algorithms/sac/configuration_sac.py b/src/lerobot/rl/algorithms/sac/configuration_sac.py
index dfe58e086..9ae99424a 100644
--- a/src/lerobot/rl/algorithms/sac/configuration_sac.py
+++ b/src/lerobot/rl/algorithms/sac/configuration_sac.py
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING
 
 import torch
 
-from lerobot.policies.sac.configuration_sac import CriticNetworkConfig
+from lerobot.policies.sac.configuration_sac import CriticNetworkConfig, SACConfig
 from lerobot.rl.algorithms.configs import RLAlgorithmConfig
 
 if TYPE_CHECKING:
@@ -29,49 +29,54 @@ if TYPE_CHECKING:
 @RLAlgorithmConfig.register_subclass("sac")
 @dataclass
 class SACAlgorithmConfig(RLAlgorithmConfig):
-    """SAC-specific hyper-parameters that control the update loop."""
+    """SAC algorithm hyperparameters."""
 
-    utd_ratio: int = 1
-    policy_update_freq: int = 1
-    clip_grad_norm: float = 40.0
+    # Policy config
+    sac_config: SACConfig
+
+    # Optimizer learning rates
     actor_lr: float = 3e-4
     critic_lr: float = 3e-4
     temperature_lr: float = 3e-4
+
+    # Bellman update
     discount: float = 0.99
-    temperature_init: float = 1.0
-    target_entropy: float | None = None
     use_backup_entropy: bool = True
     critic_target_update_weight: float = 0.005
+
+    # Critic ensemble
     num_critics: int = 2
     num_subsample_critics: int | None = None
-    num_discrete_actions: int | None = None
-    shared_encoder: bool = True
    critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
     discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
-    use_torch_compile: bool = False
+
+    # Temperature / entropy
+    temperature_init: float = 1.0
+
+    # Update loop
+    utd_ratio: int = 1
+    policy_update_freq: int = 1
+    grad_clip_norm: float = 40.0
 
     @classmethod
-    def from_policy_config(cls, policy_cfg) -> SACAlgorithmConfig:
-        """Build from an existing ``SACConfig`` (cfg.policy) for backwards compat."""
+    def from_policy_config(cls, policy_cfg: SACConfig) -> SACAlgorithmConfig:
+        """Build an algorithm config by copying hyperparameters from the policy config."""
         return cls(
-            utd_ratio=policy_cfg.utd_ratio,
-            policy_update_freq=policy_cfg.policy_update_freq,
-            clip_grad_norm=policy_cfg.grad_clip_norm,
             actor_lr=policy_cfg.actor_lr,
             critic_lr=policy_cfg.critic_lr,
             temperature_lr=policy_cfg.temperature_lr,
             discount=policy_cfg.discount,
-            temperature_init=policy_cfg.temperature_init,
-            target_entropy=policy_cfg.target_entropy,
             use_backup_entropy=policy_cfg.use_backup_entropy,
             critic_target_update_weight=policy_cfg.critic_target_update_weight,
             num_critics=policy_cfg.num_critics,
             num_subsample_critics=policy_cfg.num_subsample_critics,
-            num_discrete_actions=policy_cfg.num_discrete_actions,
-            shared_encoder=policy_cfg.shared_encoder,
             critic_network_kwargs=policy_cfg.critic_network_kwargs,
             discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs,
-            use_torch_compile=policy_cfg.use_torch_compile,
+            temperature_init=policy_cfg.temperature_init,
+            utd_ratio=policy_cfg.utd_ratio,
+            policy_update_freq=policy_cfg.policy_update_freq,
+            grad_clip_norm=policy_cfg.grad_clip_norm,
+            sac_config=policy_cfg,
         )
 
     def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm:
diff --git a/src/lerobot/rl/algorithms/sac/sac_algorithm.py b/src/lerobot/rl/algorithms/sac/sac_algorithm.py
index d732b4b5a..e10ce005e 100644
--- a/src/lerobot/rl/algorithms/sac/sac_algorithm.py
+++ b/src/lerobot/rl/algorithms/sac/sac_algorithm.py
@@ -53,8 +53,9 @@ class SACAlgorithm(RLAlgorithm):
         policy: SACPolicy,
         config: SACAlgorithmConfig,
     ):
-        self.policy = policy
         self.config = config
+        self.policy_config = config.sac_config
+        self.policy = policy
         self.optimizers: dict[str, Optimizer] = {}
         self._optimization_step: int = 0
 
@@ -89,13 +90,13 @@ class SACAlgorithm(RLAlgorithm):
 
         # TODO(Khalil): Investigate and fix torch.compile
         # NOTE: torch.compile is disabled, policy does not converge when enabled.
-        if self.config.use_torch_compile:
+        if self.policy_config.use_torch_compile:
             self.critic_ensemble = torch.compile(self.critic_ensemble)
             self.critic_target = torch.compile(self.critic_target)
 
         self.discrete_critic = None
         self.discrete_critic_target = None
-        if self.config.num_discrete_actions is not None:
+        if self.policy_config.num_discrete_actions is not None:
             self.discrete_critic, self.discrete_critic_target = self._init_discrete_critics(encoder)
             self.policy.discrete_critic = self.discrete_critic
 
@@ -104,13 +105,13 @@ class SACAlgorithm(RLAlgorithm):
         discrete_critic = DiscreteCritic(
             encoder=encoder,
             input_dim=encoder.output_dim,
-            output_dim=self.config.num_discrete_actions,
+            output_dim=self.policy_config.num_discrete_actions,
             **asdict(self.config.discrete_critic_network_kwargs),
         )
         discrete_critic_target = DiscreteCritic(
             encoder=encoder,
             input_dim=encoder.output_dim,
-            output_dim=self.config.num_discrete_actions,
+            output_dim=self.policy_config.num_discrete_actions,
             **asdict(self.config.discrete_critic_network_kwargs),
         )
 
@@ -177,7 +178,7 @@ class SACAlgorithm(RLAlgorithm):
         return q_values
 
     def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
-        clip = self.config.clip_grad_norm
+        clip = self.config.grad_clip_norm
 
         for _ in range(self.config.utd_ratio - 1):
             batch = next(batch_iterator)
@@ -189,7 +190,7 @@ class SACAlgorithm(RLAlgorithm):
             torch.nn.utils.clip_grad_norm_(self.critic_ensemble.parameters(), max_norm=clip)
             self.optimizers["critic"].step()
 
-            if self.config.num_discrete_actions is not None:
+            if self.policy_config.num_discrete_actions is not None:
                 loss_dc = self._compute_loss_discrete_critic(fb)
                 self.optimizers["discrete_critic"].zero_grad()
                 loss_dc.backward()
@@ -212,7 +213,7 @@ class SACAlgorithm(RLAlgorithm):
             grad_norms={"critic": critic_grad},
         )
 
-        if self.config.num_discrete_actions is not None:
+        if self.policy_config.num_discrete_actions is not None:
             loss_dc = self._compute_loss_discrete_critic(fb)
             self.optimizers["discrete_critic"].zero_grad()
             loss_dc.backward()
@@ -284,7 +285,7 @@ class SACAlgorithm(RLAlgorithm):
             td_target = rewards + (1 - done) * self.config.discount * min_q
 
         # 3- compute predicted qs
-        if self.config.num_discrete_actions is not None:
+        if self.policy_config.num_discrete_actions is not None:
             # NOTE: We only want to keep the continuous action part
             # In the buffer we have the full action space (continuous + discrete)
             # We need to split them before concatenating them in the critic forward
@@ -405,7 +406,7 @@ class SACAlgorithm(RLAlgorithm):
                 p.data * self.config.critic_target_update_weight
                 + target_p.data * (1.0 - self.config.critic_target_update_weight)
             )
-        if self.config.num_discrete_actions is not None:
+        if self.policy_config.num_discrete_actions is not None:
             for target_p, p in zip(
                 self.discrete_critic_target.parameters(),
                 self.discrete_critic.parameters(),
@@ -466,7 +467,7 @@ class SACAlgorithm(RLAlgorithm):
             "critic": torch.optim.Adam(self.critic_ensemble.parameters(), lr=self.config.critic_lr),
             "temperature": torch.optim.Adam([self.log_alpha], lr=self.config.temperature_lr),
         }
-        if self.config.num_discrete_actions is not None:
+        if self.policy_config.num_discrete_actions is not None:
             self.optimizers["discrete_critic"] = torch.optim.Adam(
                 self.discrete_critic.parameters(), lr=self.config.critic_lr
             )
@@ -480,7 +481,7 @@ class SACAlgorithm(RLAlgorithm):
         state_dicts: dict[str, Any] = {
             "policy": move_state_dict_to_device(self.policy.actor.state_dict(), device="cpu"),
         }
-        if self.config.num_discrete_actions is not None:
+        if self.policy_config.num_discrete_actions is not None:
             state_dicts["discrete_critic"] = move_state_dict_to_device(
                 self.discrete_critic.state_dict(), device="cpu"
             )
@@ -489,7 +490,7 @@ class SACAlgorithm(RLAlgorithm):
     def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
         actor_sd = move_state_dict_to_device(weights["policy"], device=device)
         self.policy.actor.load_state_dict(actor_sd)
-        if "discrete_critic" in weights and self.config.num_discrete_actions is not None:
+        if "discrete_critic" in weights and self.policy_config.num_discrete_actions is not None:
             dc_sd = move_state_dict_to_device(weights["discrete_critic"], device=device)
             self.discrete_critic.load_state_dict(dc_sd)
 
diff --git a/tests/rl/test_sac_algorithm.py b/tests/rl/test_sac_algorithm.py
index 6ef3e3ddf..79019777f 100644
--- a/tests/rl/test_sac_algorithm.py
+++ b/tests/rl/test_sac_algorithm.py
@@ -136,12 +136,13 @@ def test_sac_algorithm_config_registered():
 
 
 def test_sac_algorithm_config_from_policy_config():
-    """from_policy_config should copy relevant fields."""
+    """from_policy_config should copy algorithm hyperparameters from the policy config."""
     sac_cfg = _make_sac_config(utd_ratio=4, policy_update_freq=2)
     algo_cfg = SACAlgorithmConfig.from_policy_config(sac_cfg)
+    assert algo_cfg.sac_config is sac_cfg
     assert algo_cfg.utd_ratio == 4
     assert algo_cfg.policy_update_freq == 2
-    assert algo_cfg.clip_grad_norm == sac_cfg.grad_clip_norm
+    assert algo_cfg.grad_clip_norm == sac_cfg.grad_clip_norm
 
 
 # ===========================================================================
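For context, a minimal sketch of the config flow this refactor sets up. Import paths, field names, and method names are taken from the diff above; constructing SACConfig with no arguments is an assumption about its defaults.

    from lerobot.policies.sac.configuration_sac import SACConfig
    from lerobot.rl.algorithms.sac.configuration_sac import SACAlgorithmConfig

    policy_cfg = SACConfig()  # assumption: SACConfig is constructible with defaults

    # Update-loop hyperparameters (utd_ratio, policy_update_freq, grad_clip_norm,
    # learning rates, ...) are copied onto the algorithm config, while
    # policy-architecture fields (num_discrete_actions, shared_encoder,
    # use_torch_compile, ...) stay on the embedded policy config.
    algo_cfg = SACAlgorithmConfig.from_policy_config(policy_cfg)

    assert algo_cfg.sac_config is policy_cfg
    assert algo_cfg.grad_clip_norm == policy_cfg.grad_clip_norm

    # SACAlgorithm then reads loop knobs from self.config and policy-level knobs
    # from self.policy_config (= config.sac_config), as in the sac_algorithm.py
    # hunks above.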