mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
refactor(sac): decouple algorithm hyperparameters from policy config
This commit is contained in:
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.sac.configuration_sac import CriticNetworkConfig
|
||||
from lerobot.policies.sac.configuration_sac import CriticNetworkConfig, SACConfig
|
||||
from lerobot.rl.algorithms.configs import RLAlgorithmConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -29,49 +29,54 @@ if TYPE_CHECKING:
|
||||
@RLAlgorithmConfig.register_subclass("sac")
|
||||
@dataclass
|
||||
class SACAlgorithmConfig(RLAlgorithmConfig):
|
||||
"""SAC-specific hyper-parameters that control the update loop."""
|
||||
"""SAC algorithm hyperparameters."""
|
||||
|
||||
utd_ratio: int = 1
|
||||
policy_update_freq: int = 1
|
||||
clip_grad_norm: float = 40.0
|
||||
# Policy config
|
||||
sac_config: SACConfig
|
||||
|
||||
# Optimizer learning rates
|
||||
actor_lr: float = 3e-4
|
||||
critic_lr: float = 3e-4
|
||||
temperature_lr: float = 3e-4
|
||||
|
||||
# Bellman update
|
||||
discount: float = 0.99
|
||||
temperature_init: float = 1.0
|
||||
target_entropy: float | None = None
|
||||
use_backup_entropy: bool = True
|
||||
critic_target_update_weight: float = 0.005
|
||||
|
||||
# Critic ensemble
|
||||
num_critics: int = 2
|
||||
num_subsample_critics: int | None = None
|
||||
num_discrete_actions: int | None = None
|
||||
shared_encoder: bool = True
|
||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
use_torch_compile: bool = False
|
||||
|
||||
# Temperature / entropy
|
||||
temperature_init: float = 1.0
|
||||
|
||||
# Update loop
|
||||
utd_ratio: int = 1
|
||||
policy_update_freq: int = 1
|
||||
grad_clip_norm: float = 40.0
|
||||
|
||||
@classmethod
|
||||
def from_policy_config(cls, policy_cfg) -> SACAlgorithmConfig:
|
||||
"""Build from an existing ``SACConfig`` (cfg.policy) for backwards compat."""
|
||||
def from_policy_config(cls, policy_cfg: SACConfig) -> SACAlgorithmConfig:
|
||||
"""Build an algorithm config by copying hyperparameters from the policy config."""
|
||||
return cls(
|
||||
utd_ratio=policy_cfg.utd_ratio,
|
||||
policy_update_freq=policy_cfg.policy_update_freq,
|
||||
clip_grad_norm=policy_cfg.grad_clip_norm,
|
||||
actor_lr=policy_cfg.actor_lr,
|
||||
critic_lr=policy_cfg.critic_lr,
|
||||
temperature_lr=policy_cfg.temperature_lr,
|
||||
discount=policy_cfg.discount,
|
||||
temperature_init=policy_cfg.temperature_init,
|
||||
target_entropy=policy_cfg.target_entropy,
|
||||
use_backup_entropy=policy_cfg.use_backup_entropy,
|
||||
critic_target_update_weight=policy_cfg.critic_target_update_weight,
|
||||
num_critics=policy_cfg.num_critics,
|
||||
num_subsample_critics=policy_cfg.num_subsample_critics,
|
||||
num_discrete_actions=policy_cfg.num_discrete_actions,
|
||||
shared_encoder=policy_cfg.shared_encoder,
|
||||
critic_network_kwargs=policy_cfg.critic_network_kwargs,
|
||||
discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs,
|
||||
use_torch_compile=policy_cfg.use_torch_compile,
|
||||
temperature_init=policy_cfg.temperature_init,
|
||||
utd_ratio=policy_cfg.utd_ratio,
|
||||
policy_update_freq=policy_cfg.policy_update_freq,
|
||||
grad_clip_norm=policy_cfg.grad_clip_norm,
|
||||
sac_config=policy_cfg,
|
||||
)
|
||||
|
||||
def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm:
|
||||
|
||||
@@ -53,8 +53,9 @@ class SACAlgorithm(RLAlgorithm):
|
||||
policy: SACPolicy,
|
||||
config: SACAlgorithmConfig,
|
||||
):
|
||||
self.policy = policy
|
||||
self.config = config
|
||||
self.policy_config = config.sac_config
|
||||
self.policy = policy
|
||||
self.optimizers: dict[str, Optimizer] = {}
|
||||
self._optimization_step: int = 0
|
||||
|
||||
@@ -89,13 +90,13 @@ class SACAlgorithm(RLAlgorithm):
|
||||
|
||||
# TODO(Khalil): Investigate and fix torch.compile
|
||||
# NOTE: torch.compile is disabled, policy does not converge when enabled.
|
||||
if self.config.use_torch_compile:
|
||||
if self.policy_config.use_torch_compile:
|
||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||
self.critic_target = torch.compile(self.critic_target)
|
||||
|
||||
self.discrete_critic = None
|
||||
self.discrete_critic_target = None
|
||||
if self.config.num_discrete_actions is not None:
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
self.discrete_critic, self.discrete_critic_target = self._init_discrete_critics(encoder)
|
||||
self.policy.discrete_critic = self.discrete_critic
|
||||
|
||||
@@ -104,13 +105,13 @@ class SACAlgorithm(RLAlgorithm):
|
||||
discrete_critic = DiscreteCritic(
|
||||
encoder=encoder,
|
||||
input_dim=encoder.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
output_dim=self.policy_config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
discrete_critic_target = DiscreteCritic(
|
||||
encoder=encoder,
|
||||
input_dim=encoder.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
output_dim=self.policy_config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
|
||||
@@ -177,7 +178,7 @@ class SACAlgorithm(RLAlgorithm):
|
||||
return q_values
|
||||
|
||||
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
|
||||
clip = self.config.clip_grad_norm
|
||||
clip = self.config.grad_clip_norm
|
||||
|
||||
for _ in range(self.config.utd_ratio - 1):
|
||||
batch = next(batch_iterator)
|
||||
@@ -189,7 +190,7 @@ class SACAlgorithm(RLAlgorithm):
|
||||
torch.nn.utils.clip_grad_norm_(self.critic_ensemble.parameters(), max_norm=clip)
|
||||
self.optimizers["critic"].step()
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
loss_dc = self._compute_loss_discrete_critic(fb)
|
||||
self.optimizers["discrete_critic"].zero_grad()
|
||||
loss_dc.backward()
|
||||
@@ -212,7 +213,7 @@ class SACAlgorithm(RLAlgorithm):
|
||||
grad_norms={"critic": critic_grad},
|
||||
)
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
loss_dc = self._compute_loss_discrete_critic(fb)
|
||||
self.optimizers["discrete_critic"].zero_grad()
|
||||
loss_dc.backward()
|
||||
@@ -284,7 +285,7 @@ class SACAlgorithm(RLAlgorithm):
|
||||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
||||
|
||||
# 3- compute predicted qs
|
||||
if self.config.num_discrete_actions is not None:
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
# NOTE: We only want to keep the continuous action part
|
||||
# In the buffer we have the full action space (continuous + discrete)
|
||||
# We need to split them before concatenating them in the critic forward
|
||||
@@ -405,7 +406,7 @@ class SACAlgorithm(RLAlgorithm):
|
||||
p.data * self.config.critic_target_update_weight
|
||||
+ target_p.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
if self.config.num_discrete_actions is not None:
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
for target_p, p in zip(
|
||||
self.discrete_critic_target.parameters(),
|
||||
self.discrete_critic.parameters(),
|
||||
@@ -466,7 +467,7 @@ class SACAlgorithm(RLAlgorithm):
|
||||
"critic": torch.optim.Adam(self.critic_ensemble.parameters(), lr=self.config.critic_lr),
|
||||
"temperature": torch.optim.Adam([self.log_alpha], lr=self.config.temperature_lr),
|
||||
}
|
||||
if self.config.num_discrete_actions is not None:
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
self.optimizers["discrete_critic"] = torch.optim.Adam(
|
||||
self.discrete_critic.parameters(), lr=self.config.critic_lr
|
||||
)
|
||||
@@ -480,7 +481,7 @@ class SACAlgorithm(RLAlgorithm):
|
||||
state_dicts: dict[str, Any] = {
|
||||
"policy": move_state_dict_to_device(self.policy.actor.state_dict(), device="cpu"),
|
||||
}
|
||||
if self.config.num_discrete_actions is not None:
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
state_dicts["discrete_critic"] = move_state_dict_to_device(
|
||||
self.discrete_critic.state_dict(), device="cpu"
|
||||
)
|
||||
@@ -489,7 +490,7 @@ class SACAlgorithm(RLAlgorithm):
|
||||
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
|
||||
actor_sd = move_state_dict_to_device(weights["policy"], device=device)
|
||||
self.policy.actor.load_state_dict(actor_sd)
|
||||
if "discrete_critic" in weights and self.config.num_discrete_actions is not None:
|
||||
if "discrete_critic" in weights and self.policy_config.num_discrete_actions is not None:
|
||||
dc_sd = move_state_dict_to_device(weights["discrete_critic"], device=device)
|
||||
self.discrete_critic.load_state_dict(dc_sd)
|
||||
|
||||
|
||||
@@ -136,12 +136,13 @@ def test_sac_algorithm_config_registered():
|
||||
|
||||
|
||||
def test_sac_algorithm_config_from_policy_config():
|
||||
"""from_policy_config should copy relevant fields."""
|
||||
"""from_policy_config should copy algorithm hyperparameters from the policy config."""
|
||||
sac_cfg = _make_sac_config(utd_ratio=4, policy_update_freq=2)
|
||||
algo_cfg = SACAlgorithmConfig.from_policy_config(sac_cfg)
|
||||
assert algo_cfg.sac_config is sac_cfg
|
||||
assert algo_cfg.utd_ratio == 4
|
||||
assert algo_cfg.policy_update_freq == 2
|
||||
assert algo_cfg.clip_grad_norm == sac_cfg.grad_clip_norm
|
||||
assert algo_cfg.grad_clip_norm == sac_cfg.grad_clip_norm
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
|
||||
Reference in New Issue
Block a user