mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
refactor(sac): decouple algorithm hyperparameters from policy config
This commit is contained in:
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.policies.sac.configuration_sac import CriticNetworkConfig
|
from lerobot.policies.sac.configuration_sac import CriticNetworkConfig, SACConfig
|
||||||
from lerobot.rl.algorithms.configs import RLAlgorithmConfig
|
from lerobot.rl.algorithms.configs import RLAlgorithmConfig
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -29,49 +29,54 @@ if TYPE_CHECKING:
|
|||||||
@RLAlgorithmConfig.register_subclass("sac")
|
@RLAlgorithmConfig.register_subclass("sac")
|
||||||
@dataclass
|
@dataclass
|
||||||
class SACAlgorithmConfig(RLAlgorithmConfig):
|
class SACAlgorithmConfig(RLAlgorithmConfig):
|
||||||
"""SAC-specific hyper-parameters that control the update loop."""
|
"""SAC algorithm hyperparameters."""
|
||||||
|
|
||||||
utd_ratio: int = 1
|
# Policy config
|
||||||
policy_update_freq: int = 1
|
sac_config: SACConfig
|
||||||
clip_grad_norm: float = 40.0
|
|
||||||
|
# Optimizer learning rates
|
||||||
actor_lr: float = 3e-4
|
actor_lr: float = 3e-4
|
||||||
critic_lr: float = 3e-4
|
critic_lr: float = 3e-4
|
||||||
temperature_lr: float = 3e-4
|
temperature_lr: float = 3e-4
|
||||||
|
|
||||||
|
# Bellman update
|
||||||
discount: float = 0.99
|
discount: float = 0.99
|
||||||
temperature_init: float = 1.0
|
|
||||||
target_entropy: float | None = None
|
|
||||||
use_backup_entropy: bool = True
|
use_backup_entropy: bool = True
|
||||||
critic_target_update_weight: float = 0.005
|
critic_target_update_weight: float = 0.005
|
||||||
|
|
||||||
|
# Critic ensemble
|
||||||
num_critics: int = 2
|
num_critics: int = 2
|
||||||
num_subsample_critics: int | None = None
|
num_subsample_critics: int | None = None
|
||||||
num_discrete_actions: int | None = None
|
|
||||||
shared_encoder: bool = True
|
|
||||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||||
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||||
use_torch_compile: bool = False
|
|
||||||
|
# Temperature / entropy
|
||||||
|
temperature_init: float = 1.0
|
||||||
|
|
||||||
|
# Update loop
|
||||||
|
utd_ratio: int = 1
|
||||||
|
policy_update_freq: int = 1
|
||||||
|
grad_clip_norm: float = 40.0
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_policy_config(cls, policy_cfg) -> SACAlgorithmConfig:
|
def from_policy_config(cls, policy_cfg: SACConfig) -> SACAlgorithmConfig:
|
||||||
"""Build from an existing ``SACConfig`` (cfg.policy) for backwards compat."""
|
"""Build an algorithm config by copying hyperparameters from the policy config."""
|
||||||
return cls(
|
return cls(
|
||||||
utd_ratio=policy_cfg.utd_ratio,
|
|
||||||
policy_update_freq=policy_cfg.policy_update_freq,
|
|
||||||
clip_grad_norm=policy_cfg.grad_clip_norm,
|
|
||||||
actor_lr=policy_cfg.actor_lr,
|
actor_lr=policy_cfg.actor_lr,
|
||||||
critic_lr=policy_cfg.critic_lr,
|
critic_lr=policy_cfg.critic_lr,
|
||||||
temperature_lr=policy_cfg.temperature_lr,
|
temperature_lr=policy_cfg.temperature_lr,
|
||||||
discount=policy_cfg.discount,
|
discount=policy_cfg.discount,
|
||||||
temperature_init=policy_cfg.temperature_init,
|
|
||||||
target_entropy=policy_cfg.target_entropy,
|
|
||||||
use_backup_entropy=policy_cfg.use_backup_entropy,
|
use_backup_entropy=policy_cfg.use_backup_entropy,
|
||||||
critic_target_update_weight=policy_cfg.critic_target_update_weight,
|
critic_target_update_weight=policy_cfg.critic_target_update_weight,
|
||||||
num_critics=policy_cfg.num_critics,
|
num_critics=policy_cfg.num_critics,
|
||||||
num_subsample_critics=policy_cfg.num_subsample_critics,
|
num_subsample_critics=policy_cfg.num_subsample_critics,
|
||||||
num_discrete_actions=policy_cfg.num_discrete_actions,
|
|
||||||
shared_encoder=policy_cfg.shared_encoder,
|
|
||||||
critic_network_kwargs=policy_cfg.critic_network_kwargs,
|
critic_network_kwargs=policy_cfg.critic_network_kwargs,
|
||||||
discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs,
|
discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs,
|
||||||
use_torch_compile=policy_cfg.use_torch_compile,
|
temperature_init=policy_cfg.temperature_init,
|
||||||
|
utd_ratio=policy_cfg.utd_ratio,
|
||||||
|
policy_update_freq=policy_cfg.policy_update_freq,
|
||||||
|
grad_clip_norm=policy_cfg.grad_clip_norm,
|
||||||
|
sac_config=policy_cfg,
|
||||||
)
|
)
|
||||||
|
|
||||||
def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm:
|
def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm:
|
||||||
|
|||||||
@@ -53,8 +53,9 @@ class SACAlgorithm(RLAlgorithm):
|
|||||||
policy: SACPolicy,
|
policy: SACPolicy,
|
||||||
config: SACAlgorithmConfig,
|
config: SACAlgorithmConfig,
|
||||||
):
|
):
|
||||||
self.policy = policy
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.policy_config = config.sac_config
|
||||||
|
self.policy = policy
|
||||||
self.optimizers: dict[str, Optimizer] = {}
|
self.optimizers: dict[str, Optimizer] = {}
|
||||||
self._optimization_step: int = 0
|
self._optimization_step: int = 0
|
||||||
|
|
||||||
@@ -89,13 +90,13 @@ class SACAlgorithm(RLAlgorithm):
|
|||||||
|
|
||||||
# TODO(Khalil): Investigate and fix torch.compile
|
# TODO(Khalil): Investigate and fix torch.compile
|
||||||
# NOTE: torch.compile is disabled, policy does not converge when enabled.
|
# NOTE: torch.compile is disabled, policy does not converge when enabled.
|
||||||
if self.config.use_torch_compile:
|
if self.policy_config.use_torch_compile:
|
||||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||||
self.critic_target = torch.compile(self.critic_target)
|
self.critic_target = torch.compile(self.critic_target)
|
||||||
|
|
||||||
self.discrete_critic = None
|
self.discrete_critic = None
|
||||||
self.discrete_critic_target = None
|
self.discrete_critic_target = None
|
||||||
if self.config.num_discrete_actions is not None:
|
if self.policy_config.num_discrete_actions is not None:
|
||||||
self.discrete_critic, self.discrete_critic_target = self._init_discrete_critics(encoder)
|
self.discrete_critic, self.discrete_critic_target = self._init_discrete_critics(encoder)
|
||||||
self.policy.discrete_critic = self.discrete_critic
|
self.policy.discrete_critic = self.discrete_critic
|
||||||
|
|
||||||
@@ -104,13 +105,13 @@ class SACAlgorithm(RLAlgorithm):
|
|||||||
discrete_critic = DiscreteCritic(
|
discrete_critic = DiscreteCritic(
|
||||||
encoder=encoder,
|
encoder=encoder,
|
||||||
input_dim=encoder.output_dim,
|
input_dim=encoder.output_dim,
|
||||||
output_dim=self.config.num_discrete_actions,
|
output_dim=self.policy_config.num_discrete_actions,
|
||||||
**asdict(self.config.discrete_critic_network_kwargs),
|
**asdict(self.config.discrete_critic_network_kwargs),
|
||||||
)
|
)
|
||||||
discrete_critic_target = DiscreteCritic(
|
discrete_critic_target = DiscreteCritic(
|
||||||
encoder=encoder,
|
encoder=encoder,
|
||||||
input_dim=encoder.output_dim,
|
input_dim=encoder.output_dim,
|
||||||
output_dim=self.config.num_discrete_actions,
|
output_dim=self.policy_config.num_discrete_actions,
|
||||||
**asdict(self.config.discrete_critic_network_kwargs),
|
**asdict(self.config.discrete_critic_network_kwargs),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -177,7 +178,7 @@ class SACAlgorithm(RLAlgorithm):
|
|||||||
return q_values
|
return q_values
|
||||||
|
|
||||||
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
|
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
|
||||||
clip = self.config.clip_grad_norm
|
clip = self.config.grad_clip_norm
|
||||||
|
|
||||||
for _ in range(self.config.utd_ratio - 1):
|
for _ in range(self.config.utd_ratio - 1):
|
||||||
batch = next(batch_iterator)
|
batch = next(batch_iterator)
|
||||||
@@ -189,7 +190,7 @@ class SACAlgorithm(RLAlgorithm):
|
|||||||
torch.nn.utils.clip_grad_norm_(self.critic_ensemble.parameters(), max_norm=clip)
|
torch.nn.utils.clip_grad_norm_(self.critic_ensemble.parameters(), max_norm=clip)
|
||||||
self.optimizers["critic"].step()
|
self.optimizers["critic"].step()
|
||||||
|
|
||||||
if self.config.num_discrete_actions is not None:
|
if self.policy_config.num_discrete_actions is not None:
|
||||||
loss_dc = self._compute_loss_discrete_critic(fb)
|
loss_dc = self._compute_loss_discrete_critic(fb)
|
||||||
self.optimizers["discrete_critic"].zero_grad()
|
self.optimizers["discrete_critic"].zero_grad()
|
||||||
loss_dc.backward()
|
loss_dc.backward()
|
||||||
@@ -212,7 +213,7 @@ class SACAlgorithm(RLAlgorithm):
|
|||||||
grad_norms={"critic": critic_grad},
|
grad_norms={"critic": critic_grad},
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.config.num_discrete_actions is not None:
|
if self.policy_config.num_discrete_actions is not None:
|
||||||
loss_dc = self._compute_loss_discrete_critic(fb)
|
loss_dc = self._compute_loss_discrete_critic(fb)
|
||||||
self.optimizers["discrete_critic"].zero_grad()
|
self.optimizers["discrete_critic"].zero_grad()
|
||||||
loss_dc.backward()
|
loss_dc.backward()
|
||||||
@@ -284,7 +285,7 @@ class SACAlgorithm(RLAlgorithm):
|
|||||||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
td_target = rewards + (1 - done) * self.config.discount * min_q
|
||||||
|
|
||||||
# 3- compute predicted qs
|
# 3- compute predicted qs
|
||||||
if self.config.num_discrete_actions is not None:
|
if self.policy_config.num_discrete_actions is not None:
|
||||||
# NOTE: We only want to keep the continuous action part
|
# NOTE: We only want to keep the continuous action part
|
||||||
# In the buffer we have the full action space (continuous + discrete)
|
# In the buffer we have the full action space (continuous + discrete)
|
||||||
# We need to split them before concatenating them in the critic forward
|
# We need to split them before concatenating them in the critic forward
|
||||||
@@ -405,7 +406,7 @@ class SACAlgorithm(RLAlgorithm):
|
|||||||
p.data * self.config.critic_target_update_weight
|
p.data * self.config.critic_target_update_weight
|
||||||
+ target_p.data * (1.0 - self.config.critic_target_update_weight)
|
+ target_p.data * (1.0 - self.config.critic_target_update_weight)
|
||||||
)
|
)
|
||||||
if self.config.num_discrete_actions is not None:
|
if self.policy_config.num_discrete_actions is not None:
|
||||||
for target_p, p in zip(
|
for target_p, p in zip(
|
||||||
self.discrete_critic_target.parameters(),
|
self.discrete_critic_target.parameters(),
|
||||||
self.discrete_critic.parameters(),
|
self.discrete_critic.parameters(),
|
||||||
@@ -466,7 +467,7 @@ class SACAlgorithm(RLAlgorithm):
|
|||||||
"critic": torch.optim.Adam(self.critic_ensemble.parameters(), lr=self.config.critic_lr),
|
"critic": torch.optim.Adam(self.critic_ensemble.parameters(), lr=self.config.critic_lr),
|
||||||
"temperature": torch.optim.Adam([self.log_alpha], lr=self.config.temperature_lr),
|
"temperature": torch.optim.Adam([self.log_alpha], lr=self.config.temperature_lr),
|
||||||
}
|
}
|
||||||
if self.config.num_discrete_actions is not None:
|
if self.policy_config.num_discrete_actions is not None:
|
||||||
self.optimizers["discrete_critic"] = torch.optim.Adam(
|
self.optimizers["discrete_critic"] = torch.optim.Adam(
|
||||||
self.discrete_critic.parameters(), lr=self.config.critic_lr
|
self.discrete_critic.parameters(), lr=self.config.critic_lr
|
||||||
)
|
)
|
||||||
@@ -480,7 +481,7 @@ class SACAlgorithm(RLAlgorithm):
|
|||||||
state_dicts: dict[str, Any] = {
|
state_dicts: dict[str, Any] = {
|
||||||
"policy": move_state_dict_to_device(self.policy.actor.state_dict(), device="cpu"),
|
"policy": move_state_dict_to_device(self.policy.actor.state_dict(), device="cpu"),
|
||||||
}
|
}
|
||||||
if self.config.num_discrete_actions is not None:
|
if self.policy_config.num_discrete_actions is not None:
|
||||||
state_dicts["discrete_critic"] = move_state_dict_to_device(
|
state_dicts["discrete_critic"] = move_state_dict_to_device(
|
||||||
self.discrete_critic.state_dict(), device="cpu"
|
self.discrete_critic.state_dict(), device="cpu"
|
||||||
)
|
)
|
||||||
@@ -489,7 +490,7 @@ class SACAlgorithm(RLAlgorithm):
|
|||||||
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
|
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
|
||||||
actor_sd = move_state_dict_to_device(weights["policy"], device=device)
|
actor_sd = move_state_dict_to_device(weights["policy"], device=device)
|
||||||
self.policy.actor.load_state_dict(actor_sd)
|
self.policy.actor.load_state_dict(actor_sd)
|
||||||
if "discrete_critic" in weights and self.config.num_discrete_actions is not None:
|
if "discrete_critic" in weights and self.policy_config.num_discrete_actions is not None:
|
||||||
dc_sd = move_state_dict_to_device(weights["discrete_critic"], device=device)
|
dc_sd = move_state_dict_to_device(weights["discrete_critic"], device=device)
|
||||||
self.discrete_critic.load_state_dict(dc_sd)
|
self.discrete_critic.load_state_dict(dc_sd)
|
||||||
|
|
||||||
|
|||||||
@@ -136,12 +136,13 @@ def test_sac_algorithm_config_registered():
|
|||||||
|
|
||||||
|
|
||||||
def test_sac_algorithm_config_from_policy_config():
|
def test_sac_algorithm_config_from_policy_config():
|
||||||
"""from_policy_config should copy relevant fields."""
|
"""from_policy_config should copy algorithm hyperparameters from the policy config."""
|
||||||
sac_cfg = _make_sac_config(utd_ratio=4, policy_update_freq=2)
|
sac_cfg = _make_sac_config(utd_ratio=4, policy_update_freq=2)
|
||||||
algo_cfg = SACAlgorithmConfig.from_policy_config(sac_cfg)
|
algo_cfg = SACAlgorithmConfig.from_policy_config(sac_cfg)
|
||||||
|
assert algo_cfg.sac_config is sac_cfg
|
||||||
assert algo_cfg.utd_ratio == 4
|
assert algo_cfg.utd_ratio == 4
|
||||||
assert algo_cfg.policy_update_freq == 2
|
assert algo_cfg.policy_update_freq == 2
|
||||||
assert algo_cfg.clip_grad_norm == sac_cfg.grad_clip_norm
|
assert algo_cfg.grad_clip_norm == sac_cfg.grad_clip_norm
|
||||||
|
|
||||||
|
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
|
|||||||
Reference in New Issue
Block a user