refactor(rl/sac): consolidate hyperparameter ownership and clean up discrete critic

Khalil Meftah
2026-04-24 13:18:33 +02:00
parent 06255996ea
commit 1ed32210c7
9 changed files with 162 additions and 190 deletions
+1 -1
@@ -926,7 +926,7 @@ The ideal behaviour is that your intervention rate should drop gradually during
Some configuration values have a disproportionate impact on training stability and speed:
- **`temperature_init`** (`policy.temperature_init`) initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning.
- **`temperature_init`** (`algorithm.temperature_init`) initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning.
- **`policy_parameters_push_frequency`** (`policy.actor_learner_config.policy_parameters_push_frequency`) interval in _seconds_ between two weight pushes from the learner to the actor. The default is `4 s`. Decrease to **1-2 s** to provide fresher weights (at the cost of more network traffic); increase only if your connection is slow, as this will reduce sample efficiency.
- **`storage_device`** (`policy.storage_device`) device on which the learner keeps the policy parameters. If you have spare GPU memory, set this to `"cuda"` (instead of the default `"cpu"`). Keeping the weights on-GPU removes CPU→GPU transfer overhead and can significantly increase the number of learner updates per second.
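The docs hunk above only renames the config path for `temperature_init`; for orientation, here is a rough sketch of where each knob lives after this refactor. Field names are taken from this diff, while the import paths and construction details are assumptions, not code from the commit.

```python
# Sketch only: field names from this diff, import paths assumed.
from lerobot.policies.sac.configuration_sac import GaussianActorConfig  # assumed path
from lerobot.rl.sac.configuration_sac import SACAlgorithmConfig         # assumed path

policy_cfg = GaussianActorConfig(device="cuda", storage_device="cuda")  # stays on the policy config
policy_cfg.actor_learner_config.policy_parameters_push_frequency = 2    # seconds; stays on the policy config

algo_cfg = SACAlgorithmConfig.from_policy_config(policy_cfg)
algo_cfg.temperature_init = 1e-2                                        # now an algorithm hyperparameter
```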
@@ -136,80 +136,41 @@ class GaussianActorConfig(PreTrainedConfig):
# Dimension of the image embedding pooling
image_embedding_pooling_dim: int = 8
# Training parameters
# Number of steps for online training
online_steps: int = 1000000
# Capacity of the online replay buffer
online_buffer_capacity: int = 100000
# Capacity of the offline replay buffer
offline_buffer_capacity: int = 100000
# Whether to use asynchronous prefetching for the buffers
async_prefetch: bool = False
# Number of steps before learning starts
online_step_before_learning: int = 100
# Frequency of policy updates
policy_update_freq: int = 1
# SAC algorithm parameters
# Discount factor for the SAC algorithm
discount: float = 0.99
# Initial temperature value
temperature_init: float = 1.0
# Number of critics in the ensemble
num_critics: int = 2
# Number of subsampled critics for training
num_subsample_critics: int | None = None
# Learning rate for the critic network
critic_lr: float = 3e-4
# Learning rate for the actor network
actor_lr: float = 3e-4
# Learning rate for the temperature parameter
temperature_lr: float = 3e-4
# Weight for the critic target update
critic_target_update_weight: float = 0.005
# Update-to-data ratio for the UTD algorithm (to enable UTD, set this to a value > 1)
utd_ratio: int = 1
# Encoder architecture
# Hidden dimension size for the state encoder
state_encoder_hidden_dim: int = 256
# Dimension of the latent space
latent_dim: int = 256
# Target entropy for the SAC algorithm
target_entropy: float | None = None
# Whether to use backup entropy for the SAC algorithm
use_backup_entropy: bool = True
# Gradient clipping norm for the SAC algorithm
grad_clip_norm: float = 40.0
# Network configuration
# Configuration for the critic network architecture
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
# Configuration for the actor network architecture
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
# Configuration for the policy parameters
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
# Configuration for the discrete critic network
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
# Configuration for actor-learner architecture
# Online training (TODO(Khalil): relocate to TrainRLServerPipelineConfig)
online_steps: int = 1000000
online_buffer_capacity: int = 100000
offline_buffer_capacity: int = 100000
async_prefetch: bool = False
online_step_before_learning: int = 100
# Actor-learner transport (TODO(Khalil): relocate to TrainRLServerPipelineConfig).
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
# Configuration for concurrency settings (you can use threads or processes for the actor and learner)
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
# Optimizations
# torch.compile is currently disabled by default due to known issues with the SAC
# critic ensemble and shared encoder.
use_torch_compile: bool = False
# Network architecture
# Actor network
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
# Gaussian head parameters
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
# Discrete critic
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
def __post_init__(self):
super().__post_init__()
# Any validation specific to SAC configuration
def get_optimizer_preset(self) -> MultiAdamConfig:
return MultiAdamConfig(
weight_decay=0.0,
optimizer_groups={
"actor": {"lr": self.actor_lr},
"critic": {"lr": self.critic_lr},
"temperature": {"lr": self.temperature_lr},
"actor": {"lr": 3e-4},
"critic": {"lr": 3e-4},
"temperature": {"lr": 3e-4},
},
)
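With this hunk, SAC hyperparameters (discount, learning rates, temperature, UTD ratio, etc.) are no longer `GaussianActorConfig` fields, and `get_optimizer_preset` falls back to fixed `3e-4` defaults. A rough illustration of the resulting config surface (a sketch, not a test from this commit):

```python
# Architecture fields still live on the policy config.
cfg = GaussianActorConfig(device="cpu", latent_dim=128, state_encoder_hidden_dim=128)

# Removed SAC hyperparameters are now rejected as unknown keyword arguments.
try:
    GaussianActorConfig(device="cpu", discount=0.95)
except TypeError:
    pass  # `discount` now belongs to SACAlgorithmConfig
```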
@@ -19,7 +19,6 @@ from collections.abc import Callable
from dataclasses import asdict
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
@@ -61,7 +60,7 @@ class GaussianActorPolicy(
continuous_action_dim = config.output_features[ACTION].shape[0]
self._init_encoders()
self._init_actor(continuous_action_dim)
self.discrete_critic = None
self._init_discrete_critic()
def get_optim_params(self) -> dict:
optim_params = {
@@ -125,19 +124,14 @@ class GaussianActorPolicy(
def load_actor_weights(self, state_dicts: dict[str, Any], device: str | torch.device = "cpu") -> None:
from lerobot.utils.transition import move_state_dict_to_device
actor_sd = move_state_dict_to_device(state_dicts["policy"], device=device)
self.actor.load_state_dict(actor_sd)
actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device)
self.actor.load_state_dict(actor_state_dict)
if "discrete_critic" in state_dicts:
dc_sd = move_state_dict_to_device(state_dicts["discrete_critic"], device=device)
if self.discrete_critic is None:
self.discrete_critic = DiscreteCritic(
encoder=self.encoder_critic,
input_dim=self.encoder_critic.output_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
).to(device)
self.discrete_critic.load_state_dict(dc_sd)
if "discrete_critic" in state_dicts and self.discrete_critic is not None:
discrete_critic_state_dict = move_state_dict_to_device(
state_dicts["discrete_critic"], device=device
)
self.discrete_critic.load_state_dict(discrete_critic_state_dict)
def _init_encoders(self):
"""Initialize shared or separate encoders for actor and critic."""
@@ -148,7 +142,7 @@ class GaussianActorPolicy(
)
def _init_actor(self, continuous_action_dim):
"""Initialize policy actor network and default target entropy."""
"""Initialize policy actor network."""
# NOTE: The actor selects only the continuous action part
self.actor = Policy(
encoder=self.encoder_actor,
@@ -158,10 +152,19 @@ class GaussianActorPolicy(
**asdict(self.config.policy_kwargs),
)
self.target_entropy = self.config.target_entropy
if self.target_entropy is None:
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
self.target_entropy = -np.prod(dim) / 2
def _init_discrete_critic(self) -> None:
"""Initialize discrete critic network."""
if self.config.num_discrete_actions is None:
self.discrete_critic = None
return
# TODO(Khalil): Compile the discrete critic
self.discrete_critic = DiscreteCritic(
encoder=self.encoder_critic,
input_dim=self.encoder_critic.output_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
class GaussianActorObservationEncoder(nn.Module):
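With `_init_discrete_critic` in place, the policy owns its discrete critic for its whole lifetime and `load_actor_weights` only fills an already constructed module; the learner keeps just the target copy (see the `SACAlgorithm` hunks below). A small ownership sketch, assuming `cfg` is a `GaussianActorConfig` with `num_discrete_actions` set and valid features:

```python
policy = GaussianActorPolicy(config=cfg)
assert policy.discrete_critic is not None        # built eagerly by _init_discrete_critic()

algo_cfg = SACAlgorithmConfig.from_policy_config(cfg)
algorithm = SACAlgorithm(policy=policy, config=algo_cfg)
# The learner optimizes the policy-owned critic and only owns the soft-update target itself.
assert algorithm.discrete_critic_target is not policy.discrete_critic
```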
@@ -35,7 +35,7 @@ class SACAlgorithmConfig(RLAlgorithmConfig):
"""SAC algorithm hyperparameters."""
# Policy config
sac_config: GaussianActorConfig
policy_config: GaussianActorConfig
# Optimizer learning rates
actor_lr: float = 3e-4
@@ -55,31 +55,26 @@ class SACAlgorithmConfig(RLAlgorithmConfig):
# Temperature / entropy
temperature_init: float = 1.0
# Target entropy for automatic temperature tuning. If ``None``, defaults to
# ``-|A|/2`` where ``|A|`` is the total action dimension (continuous + 1 if
# there is a discrete action head).
target_entropy: float | None = None
# Update loop
utd_ratio: int = 1
policy_update_freq: int = 1
grad_clip_norm: float = 40.0
# Optimizations
# torch.compile is currently disabled by default
use_torch_compile: bool = False
@classmethod
def from_policy_config(cls, policy_cfg: GaussianActorConfig) -> SACAlgorithmConfig:
"""Build an algorithm config by copying hyperparameters from the policy config."""
"""Build an algorithm config with default hyperparameters for a given policy."""
return cls(
actor_lr=policy_cfg.actor_lr,
critic_lr=policy_cfg.critic_lr,
temperature_lr=policy_cfg.temperature_lr,
discount=policy_cfg.discount,
use_backup_entropy=policy_cfg.use_backup_entropy,
critic_target_update_weight=policy_cfg.critic_target_update_weight,
num_critics=policy_cfg.num_critics,
num_subsample_critics=policy_cfg.num_subsample_critics,
critic_network_kwargs=policy_cfg.critic_network_kwargs,
policy_config=policy_cfg,
discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs,
temperature_init=policy_cfg.temperature_init,
utd_ratio=policy_cfg.utd_ratio,
policy_update_freq=policy_cfg.policy_update_freq,
grad_clip_norm=policy_cfg.grad_clip_norm,
sac_config=policy_cfg,
)
def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm:
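Since `from_policy_config` now only embeds the policy config and otherwise uses SAC defaults, hyperparameter overrides are applied on the algorithm config afterwards. A usage sketch consistent with the updated tests below (`policy_cfg` and `policy` are assumed to exist; the values are illustrative):

```python
algo_cfg = SACAlgorithmConfig.from_policy_config(policy_cfg)
assert algo_cfg.policy_config is policy_cfg      # architecture comes from the embedded policy config

algo_cfg.utd_ratio = 4                           # SAC hyperparameters are tuned here now
algo_cfg.policy_update_freq = 2
algo_cfg.critic_lr = 1e-4

algorithm = algo_cfg.build_algorithm(policy)
```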
+29 -35
@@ -54,14 +54,14 @@ class SACAlgorithm(RLAlgorithm):
config: SACAlgorithmConfig,
):
self.config = config
self.policy_config = config.sac_config
self.policy_config = config.policy_config
self.policy = policy
self.optimizers: dict[str, Optimizer] = {}
self._optimization_step: int = 0
action_dim = self.policy.config.output_features[ACTION].shape[0]
self._init_critics(action_dim)
self._init_temperature()
self._init_temperature(action_dim)
self._device = torch.device(self.policy.config.device)
self._move_to_device()
@@ -90,49 +90,44 @@ class SACAlgorithm(RLAlgorithm):
# TODO(Khalil): Investigate and fix torch.compile
# NOTE: torch.compile is disabled, policy does not converge when enabled.
if self.policy_config.use_torch_compile:
if self.config.use_torch_compile:
self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = torch.compile(self.critic_target)
self.discrete_critic = None
self.discrete_critic_target = None
if self.policy_config.num_discrete_actions is not None:
self.discrete_critic, self.discrete_critic_target = self._init_discrete_critics(encoder)
self.policy.discrete_critic = self.discrete_critic
self.discrete_critic_target = self._init_discrete_critic_target(encoder)
def _init_discrete_critics(
self, encoder: GaussianActorObservationEncoder
) -> tuple[DiscreteCritic, DiscreteCritic]:
"""Build discrete critic ensemble and target networks."""
discrete_critic = DiscreteCritic(
encoder=encoder,
input_dim=encoder.output_dim,
output_dim=self.policy_config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
def _init_discrete_critic_target(self, encoder: GaussianActorObservationEncoder) -> DiscreteCritic:
"""Build target discrete critic (main network is owned by the policy)."""
discrete_critic_target = DiscreteCritic(
encoder=encoder,
input_dim=encoder.output_dim,
output_dim=self.policy_config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
# TODO(Khalil): Compile the discrete critic
discrete_critic_target.load_state_dict(discrete_critic.state_dict())
return discrete_critic, discrete_critic_target
discrete_critic_target.load_state_dict(self.policy.discrete_critic.state_dict())
return discrete_critic_target
def _init_temperature(self) -> None:
"""Set up temperature parameter (log_alpha)."""
def _init_temperature(self, continuous_action_dim: int) -> None:
"""Set up temperature parameter (log_alpha) and target entropy."""
temp_init = self.config.temperature_init
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
self.target_entropy = self.config.target_entropy
if self.target_entropy is None:
total_action_dim = continuous_action_dim + (
1 if self.policy_config.num_discrete_actions is not None else 0
)
self.target_entropy = -total_action_dim / 2
def _move_to_device(self) -> None:
self.policy.to(self._device)
self.critic_ensemble.to(self._device)
self.critic_target.to(self._device)
self.log_alpha = nn.Parameter(self.log_alpha.data.to(self._device))
if self.discrete_critic is not None:
self.discrete_critic.to(self._device)
if self.discrete_critic_target is not None:
self.discrete_critic_target.to(self._device)
@property
@@ -175,7 +170,7 @@ class SACAlgorithm(RLAlgorithm):
Returns:
Tensor of Q-values from the discrete critic network
"""
discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic
discrete_critic = self.discrete_critic_target if use_target else self.policy.discrete_critic
q_values = discrete_critic(observations, observation_features)
return q_values
@@ -196,7 +191,7 @@ class SACAlgorithm(RLAlgorithm):
loss_dc = self._compute_loss_discrete_critic(fb)
self.optimizers["discrete_critic"].zero_grad()
loss_dc.backward()
torch.nn.utils.clip_grad_norm_(self.discrete_critic.parameters(), max_norm=clip)
torch.nn.utils.clip_grad_norm_(self.policy.discrete_critic.parameters(), max_norm=clip)
self.optimizers["discrete_critic"].step()
self._update_target_networks()
@@ -219,7 +214,9 @@ class SACAlgorithm(RLAlgorithm):
loss_dc = self._compute_loss_discrete_critic(fb)
self.optimizers["discrete_critic"].zero_grad()
loss_dc.backward()
dc_grad = torch.nn.utils.clip_grad_norm_(self.discrete_critic.parameters(), max_norm=clip).item()
dc_grad = torch.nn.utils.clip_grad_norm_(
self.policy.discrete_critic.parameters(), max_norm=clip
).item()
self.optimizers["discrete_critic"].step()
stats.losses["loss_discrete_critic"] = loss_dc.item()
stats.grad_norms["discrete_critic"] = dc_grad
@@ -396,7 +393,7 @@ class SACAlgorithm(RLAlgorithm):
with torch.no_grad():
_, log_probs, _ = self.policy.actor(observations, observation_features)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.policy.target_entropy)).mean()
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
return temperature_loss
def _update_target_networks(self) -> None:
@@ -411,7 +408,7 @@ class SACAlgorithm(RLAlgorithm):
if self.policy_config.num_discrete_actions is not None:
for target_p, p in zip(
self.discrete_critic_target.parameters(),
self.discrete_critic.parameters(),
self.policy.discrete_critic.parameters(),
strict=True,
):
target_p.data.copy_(
@@ -471,7 +468,7 @@ class SACAlgorithm(RLAlgorithm):
}
if self.policy_config.num_discrete_actions is not None:
self.optimizers["discrete_critic"] = torch.optim.Adam(
self.discrete_critic.parameters(), lr=self.config.critic_lr
self.policy.discrete_critic.parameters(), lr=self.config.critic_lr
)
return self.optimizers
@@ -485,16 +482,13 @@ class SACAlgorithm(RLAlgorithm):
}
if self.policy_config.num_discrete_actions is not None:
state_dicts["discrete_critic"] = move_state_dict_to_device(
self.discrete_critic.state_dict(), device="cpu"
self.policy.discrete_critic.state_dict(), device="cpu"
)
return state_dicts
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
actor_sd = move_state_dict_to_device(weights["policy"], device=device)
self.policy.actor.load_state_dict(actor_sd)
if "discrete_critic" in weights and self.policy_config.num_discrete_actions is not None:
dc_sd = move_state_dict_to_device(weights["discrete_critic"], device=device)
self.discrete_critic.load_state_dict(dc_sd)
"""Load actor + discrete-critic weights into the policy."""
self.policy.load_actor_weights(weights, device=device)
def get_observation_features(
self, observations: Tensor, next_observations: Tensor
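The target-entropy default also moves to the algorithm: when `target_entropy` is `None`, `_init_temperature` uses `-|A|/2` over the total action dimension. A standalone restatement of that rule, matching the test values further down (the helper name is mine, not from the commit):

```python
def default_target_entropy(continuous_action_dim: int, has_discrete_head: bool) -> float:
    """-|A| / 2, where |A| is the continuous dimension plus 1 for a discrete head."""
    total_action_dim = continuous_action_dim + (1 if has_discrete_head else 0)
    return -total_action_dim / 2

assert default_target_entropy(10, False) == -5.0  # test_sac_algorithm_target_entropy
assert default_target_entropy(6, True) == -3.5    # the discrete-action variant below
```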
+8 -29
@@ -55,9 +55,6 @@ def test_gaussian_actor_config_default_initialization():
# Basic parameters
assert config.device == "cpu"
assert config.storage_device == "cpu"
assert config.discount == 0.99
assert config.temperature_init == 1.0
assert config.num_critics == 2
# Architecture specifics
assert config.vision_encoder_name is None
@@ -66,6 +63,8 @@ def test_gaussian_actor_config_default_initialization():
assert config.shared_encoder is True
assert config.num_discrete_actions is None
assert config.image_embedding_pooling_dim == 8
assert config.state_encoder_hidden_dim == 256
assert config.latent_dim == 256
# Training parameters
assert config.online_steps == 1000000
@@ -73,20 +72,6 @@ def test_gaussian_actor_config_default_initialization():
assert config.offline_buffer_capacity == 100000
assert config.async_prefetch is False
assert config.online_step_before_learning == 100
assert config.policy_update_freq == 1
# SAC algorithm parameters
assert config.num_subsample_critics is None
assert config.critic_lr == 3e-4
assert config.actor_lr == 3e-4
assert config.temperature_lr == 3e-4
assert config.critic_target_update_weight == 0.005
assert config.utd_ratio == 1
assert config.state_encoder_hidden_dim == 256
assert config.latent_dim == 256
assert config.target_entropy is None
assert config.use_backup_entropy is True
assert config.grad_clip_norm == 40.0
# Dataset stats defaults
expected_dataset_stats = {
@@ -105,11 +90,6 @@ def test_gaussian_actor_config_default_initialization():
}
assert config.dataset_stats == expected_dataset_stats
# Critic network configuration
assert config.critic_network_kwargs.hidden_dims == [256, 256]
assert config.critic_network_kwargs.activate_final is True
assert config.critic_network_kwargs.final_activation is None
# Actor network configuration
assert config.actor_network_kwargs.hidden_dims == [256, 256]
assert config.actor_network_kwargs.activate_final is True
@@ -135,7 +115,6 @@ def test_gaussian_actor_config_default_initialization():
assert config.concurrency.learner == "threads"
assert isinstance(config.actor_network_kwargs, ActorNetworkConfig)
assert isinstance(config.critic_network_kwargs, CriticNetworkConfig)
assert isinstance(config.policy_kwargs, PolicyConfig)
assert isinstance(config.actor_learner_config, ActorLearnerConfig)
assert isinstance(config.concurrency, ConcurrencyConfig)
@@ -178,15 +157,15 @@ def test_concurrency_config():
def test_gaussian_actor_config_custom_initialization():
config = GaussianActorConfig(
device="cpu",
discount=0.95,
temperature_init=0.5,
num_critics=3,
latent_dim=128,
state_encoder_hidden_dim=128,
num_discrete_actions=3,
)
assert config.device == "cpu"
assert config.discount == 0.95
assert config.temperature_init == 0.5
assert config.num_critics == 3
assert config.latent_dim == 128
assert config.state_encoder_hidden_dim == 128
assert config.num_discrete_actions == 3
def test_validate_features():
+10 -9
@@ -404,19 +404,16 @@ def test_sac_training_with_discrete_critic():
def test_sac_algorithm_target_entropy():
"""Target entropy is an SAC hyperparameter and lives on the algorithm."""
config = create_default_config(continuous_action_dim=10, state_dim=10)
_, policy = _make_algorithm(config)
algo_config = SACAlgorithmConfig.from_policy_config(config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
algorithm, _ = _make_algorithm(config)
assert algorithm.target_entropy == -5.0
def test_sac_algorithm_target_entropy_with_discrete_action():
config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
config.num_discrete_actions = 5
algo_config = SACAlgorithmConfig.from_policy_config(config)
policy = GaussianActorPolicy(config=config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
algorithm, _ = _make_algorithm(config)
assert algorithm.target_entropy == -3.5
@@ -435,8 +432,8 @@ def test_sac_algorithm_temperature():
def test_sac_algorithm_update_target_network():
config = create_default_config(state_dim=10, continuous_action_dim=6)
config.critic_target_update_weight = 1.0
algo_config = SACAlgorithmConfig.from_policy_config(config)
algo_config.critic_target_update_weight = 1.0
policy = GaussianActorPolicy(config=config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
@@ -454,9 +451,13 @@ def test_sac_algorithm_with_critics_number_of_heads(num_critics: int):
action_dim = 10
state_dim = 10
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
config.num_critics = num_critics
algorithm, policy = _make_algorithm(config)
policy = GaussianActorPolicy(config=config)
policy.train()
algo_config = SACAlgorithmConfig.from_policy_config(config)
algo_config.num_critics = num_critics
algorithm = SACAlgorithm(policy=policy, config=algo_config)
algorithm.make_optimizers_and_scheduler()
assert len(algorithm.critic_ensemble.critics) == num_critics
-3
@@ -327,7 +327,6 @@ def test_learner_algorithm_wiring():
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
},
use_torch_compile=False,
)
sac_cfg.validate_features()
@@ -412,7 +411,6 @@ def test_initial_and_periodic_weight_push_consistency():
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
},
use_torch_compile=False,
)
sac_cfg.validate_features()
@@ -450,7 +448,6 @@ def test_actor_side_algorithm_select_action_and_load_weights():
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
},
use_torch_compile=False,
)
sac_cfg.validate_features()
+62 -20
@@ -44,8 +44,6 @@ def _make_sac_config(
state_dim: int = 10,
action_dim: int = 6,
num_discrete_actions: int | None = None,
utd_ratio: int = 1,
policy_update_freq: int = 1,
with_images: bool = False,
) -> GaussianActorConfig:
config = GaussianActorConfig(
@@ -55,10 +53,7 @@ def _make_sac_config(
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
},
utd_ratio=utd_ratio,
policy_update_freq=policy_update_freq,
num_discrete_actions=num_discrete_actions,
use_torch_compile=False,
)
if with_images:
config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
@@ -83,14 +78,14 @@ def _make_algorithm(
sac_cfg = _make_sac_config(
state_dim=state_dim,
action_dim=action_dim,
utd_ratio=utd_ratio,
policy_update_freq=policy_update_freq,
num_discrete_actions=num_discrete_actions,
with_images=with_images,
)
policy = GaussianActorPolicy(config=sac_cfg)
policy.train()
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
algo_config.utd_ratio = utd_ratio
algo_config.policy_update_freq = policy_update_freq
algorithm = SACAlgorithm(policy=policy, config=algo_config)
algorithm.make_optimizers_and_scheduler()
return algorithm, policy
@@ -136,13 +131,16 @@ def test_sac_algorithm_config_registered():
def test_sac_algorithm_config_from_policy_config():
"""from_policy_config should copy algorithm hyperparameters from the policy config."""
sac_cfg = _make_sac_config(utd_ratio=4, policy_update_freq=2)
"""from_policy_config embeds the policy config and uses SAC defaults."""
sac_cfg = _make_sac_config()
algo_cfg = SACAlgorithmConfig.from_policy_config(sac_cfg)
assert algo_cfg.sac_config is sac_cfg
assert algo_cfg.utd_ratio == 4
assert algo_cfg.policy_update_freq == 2
assert algo_cfg.grad_clip_norm == sac_cfg.grad_clip_norm
assert algo_cfg.policy_config is sac_cfg
assert algo_cfg.discrete_critic_network_kwargs is sac_cfg.discrete_critic_network_kwargs
# Defaults come from SACAlgorithmConfig, not from the policy config.
assert algo_cfg.utd_ratio == 1
assert algo_cfg.policy_update_freq == 1
assert algo_cfg.grad_clip_norm == 40.0
assert algo_cfg.actor_lr == 3e-4
# ===========================================================================
@@ -377,12 +375,14 @@ def test_actor_side_no_optimizers():
assert algorithm.optimizers == {}
def test_make_algorithm_copies_config_fields():
sac_cfg = _make_sac_config(utd_ratio=5, policy_update_freq=3)
def test_make_algorithm_uses_sac_algorithm_defaults():
"""make_algorithm populates SACAlgorithmConfig with its own defaults."""
sac_cfg = _make_sac_config()
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
assert algorithm.config.utd_ratio == 5
assert algorithm.config.policy_update_freq == 3
assert algorithm.config.utd_ratio == 1
assert algorithm.config.policy_update_freq == 1
assert algorithm.config.grad_clip_norm == 40.0
def test_make_algorithm_raises_for_unknown_type():
@@ -431,10 +431,10 @@ def test_load_weights_round_trip_with_discrete_critic():
assert "discrete_critic" in weights
assert len(weights["discrete_critic"]) > 0
dst_dc_state_dict = algo_dst.discrete_critic.state_dict()
dst_discrete_critic_state_dict = algo_dst.policy.discrete_critic.state_dict()
for key, tensor in weights["discrete_critic"].items():
assert torch.equal(
dst_dc_state_dict[key].cpu(),
dst_discrete_critic_state_dict[key].cpu(),
tensor.cpu(),
), f"Discrete critic param '{key}' mismatch after load_weights"
@@ -446,6 +446,47 @@ def test_load_weights_ignores_missing_discrete_critic():
algorithm.load_weights(weights, device="cpu")
def test_actor_side_weight_sync_with_discrete_critic():
"""End-to-end: learner ``algorithm.get_weights()`` -> actor ``policy.load_actor_weights()``."""
# Learner side: train the algorithm so its weights diverge from init.
algo_src, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
algo_src.update(_batch_iterator(action_dim=7))
weights = algo_src.get_weights()
# Actor side: fresh policy, no algorithm/optimizer.
sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6)
policy_actor = GaussianActorPolicy(config=sac_cfg)
# Snapshot initial actor state for the "did it change?" assertion below.
initial_discrete_critic_state_dict = {
k: v.clone() for k, v in policy_actor.discrete_critic.state_dict().items()
}
policy_actor.load_actor_weights(weights, device="cpu")
# Actor weights match the learner's exported actor state dict.
actor_state_dict = policy_actor.actor.state_dict()
for key, tensor in weights["policy"].items():
assert torch.equal(actor_state_dict[key].cpu(), tensor.cpu()), (
f"Actor param '{key}' not synced by load_actor_weights"
)
# Discrete critic weights match the learner's exported discrete critic.
discrete_critic_state_dict = policy_actor.discrete_critic.state_dict()
for key, tensor in weights["discrete_critic"].items():
assert torch.equal(discrete_critic_state_dict[key].cpu(), tensor.cpu()), (
f"Discrete critic param '{key}' not synced by load_actor_weights"
)
# Sanity: the discrete critic actually changed (otherwise the sync is trivial).
changed = any(
not torch.equal(initial_discrete_critic_state_dict[key], discrete_critic_state_dict[key])
for key in initial_discrete_critic_state_dict
if key in discrete_critic_state_dict
)
assert changed, "Discrete critic weights did not change between init and after sync"
# ===========================================================================
# TrainingStats generic losses dict
# ===========================================================================
@@ -468,8 +509,9 @@ def test_training_stats_generic_losses():
def test_build_algorithm_via_config():
"""SACAlgorithmConfig.build_algorithm should produce a working SACAlgorithm."""
sac_cfg = _make_sac_config(utd_ratio=2)
sac_cfg = _make_sac_config()
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
algo_config.utd_ratio = 2
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = algo_config.build_algorithm(policy)