refactor(rl/sac): consolidate hyperparameter ownership and clean up discrete critic

This commit is contained in:
Khalil Meftah
2026-04-24 13:18:33 +02:00
parent 06255996ea
commit 1ed32210c7
9 changed files with 162 additions and 190 deletions
+1 -1
View File
@@ -926,7 +926,7 @@ The ideal behaviour is that your intervention rate should drop gradually during
Some configuration values have a disproportionate impact on training stability and speed: Some configuration values have a disproportionate impact on training stability and speed:
- **`temperature_init`** (`policy.temperature_init`) initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning. - **`temperature_init`** (`algorithm.temperature_init`) initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning.
- **`policy_parameters_push_frequency`** (`policy.actor_learner_config.policy_parameters_push_frequency`) interval in _seconds_ between two weight pushes from the learner to the actor. The default is `4 s`. Decrease to **1-2 s** to provide fresher weights (at the cost of more network traffic); increase only if your connection is slow, as this will reduce sample efficiency. - **`policy_parameters_push_frequency`** (`policy.actor_learner_config.policy_parameters_push_frequency`) interval in _seconds_ between two weight pushes from the learner to the actor. The default is `4 s`. Decrease to **1-2 s** to provide fresher weights (at the cost of more network traffic); increase only if your connection is slow, as this will reduce sample efficiency.
- **`storage_device`** (`policy.storage_device`) device on which the learner keeps the policy parameters. If you have spare GPU memory, set this to `"cuda"` (instead of the default `"cpu"`). Keeping the weights on-GPU removes CPU→GPU transfer overhead and can significantly increase the number of learner updates per second. - **`storage_device`** (`policy.storage_device`) device on which the learner keeps the policy parameters. If you have spare GPU memory, set this to `"cuda"` (instead of the default `"cpu"`). Keeping the weights on-GPU removes CPU→GPU transfer overhead and can significantly increase the number of learner updates per second.
@@ -136,80 +136,41 @@ class GaussianActorConfig(PreTrainedConfig):
# Dimension of the image embedding pooling # Dimension of the image embedding pooling
image_embedding_pooling_dim: int = 8 image_embedding_pooling_dim: int = 8
# Training parameter # Encoder architecture
# Number of steps for online training
online_steps: int = 1000000
# Capacity of the online replay buffer
online_buffer_capacity: int = 100000
# Capacity of the offline replay buffer
offline_buffer_capacity: int = 100000
# Whether to use asynchronous prefetching for the buffers
async_prefetch: bool = False
# Number of steps before learning starts
online_step_before_learning: int = 100
# Frequency of policy updates
policy_update_freq: int = 1
# SAC algorithm parameters
# Discount factor for the SAC algorithm
discount: float = 0.99
# Initial temperature value
temperature_init: float = 1.0
# Number of critics in the ensemble
num_critics: int = 2
# Number of subsampled critics for training
num_subsample_critics: int | None = None
# Learning rate for the critic network
critic_lr: float = 3e-4
# Learning rate for the actor network
actor_lr: float = 3e-4
# Learning rate for the temperature parameter
temperature_lr: float = 3e-4
# Weight for the critic target update
critic_target_update_weight: float = 0.005
# Update-to-data ratio for the UTD algorithm (If you want enable utd_ratio, you need to set it to >1)
utd_ratio: int = 1
# Hidden dimension size for the state encoder # Hidden dimension size for the state encoder
state_encoder_hidden_dim: int = 256 state_encoder_hidden_dim: int = 256
# Dimension of the latent space # Dimension of the latent space
latent_dim: int = 256 latent_dim: int = 256
# Target entropy for the SAC algorithm
target_entropy: float | None = None
# Whether to use backup entropy for the SAC algorithm
use_backup_entropy: bool = True
# Gradient clipping norm for the SAC algorithm
grad_clip_norm: float = 40.0
# Network configuration # Online training (TODO(Khalil): relocate to TrainRLServerPipelineConfig)
# Configuration for the critic network architecture online_steps: int = 1000000
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) online_buffer_capacity: int = 100000
# Configuration for the actor network architecture offline_buffer_capacity: int = 100000
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig) async_prefetch: bool = False
# Configuration for the policy parameters online_step_before_learning: int = 100
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
# Configuration for the discrete critic network # Actor-learner transport (TODO(Khalil): relocate to TrainRLServerPipelineConfig).
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
# Configuration for actor-learner architecture
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig) actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
# Configuration for concurrency settings (you can use threads or processes for the actor and learner)
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig) concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
# Optimizations # Network architecture
# torch.compile is currently disabled by default due to known issues with the SAC # Actor network
# critic ensemble and shared encoder. actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
use_torch_compile: bool = False # Gaussian head parameters
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
# Discrete critic
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
def __post_init__(self): def __post_init__(self):
super().__post_init__() super().__post_init__()
# Any validation specific to SAC configuration
def get_optimizer_preset(self) -> MultiAdamConfig: def get_optimizer_preset(self) -> MultiAdamConfig:
return MultiAdamConfig( return MultiAdamConfig(
weight_decay=0.0, weight_decay=0.0,
optimizer_groups={ optimizer_groups={
"actor": {"lr": self.actor_lr}, "actor": {"lr": 3e-4},
"critic": {"lr": self.critic_lr}, "critic": {"lr": 3e-4},
"temperature": {"lr": self.temperature_lr}, "temperature": {"lr": 3e-4},
}, },
) )
@@ -19,7 +19,6 @@ from collections.abc import Callable
from dataclasses import asdict from dataclasses import asdict
from typing import Any from typing import Any
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch import Tensor from torch import Tensor
@@ -61,7 +60,7 @@ class GaussianActorPolicy(
continuous_action_dim = config.output_features[ACTION].shape[0] continuous_action_dim = config.output_features[ACTION].shape[0]
self._init_encoders() self._init_encoders()
self._init_actor(continuous_action_dim) self._init_actor(continuous_action_dim)
self.discrete_critic = None self._init_discrete_critic()
def get_optim_params(self) -> dict: def get_optim_params(self) -> dict:
optim_params = { optim_params = {
@@ -125,19 +124,14 @@ class GaussianActorPolicy(
def load_actor_weights(self, state_dicts: dict[str, Any], device: str | torch.device = "cpu") -> None: def load_actor_weights(self, state_dicts: dict[str, Any], device: str | torch.device = "cpu") -> None:
from lerobot.utils.transition import move_state_dict_to_device from lerobot.utils.transition import move_state_dict_to_device
actor_sd = move_state_dict_to_device(state_dicts["policy"], device=device) actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device)
self.actor.load_state_dict(actor_sd) self.actor.load_state_dict(actor_state_dict)
if "discrete_critic" in state_dicts: if "discrete_critic" in state_dicts and self.discrete_critic is not None:
dc_sd = move_state_dict_to_device(state_dicts["discrete_critic"], device=device) discrete_critic_state_dict = move_state_dict_to_device(
if self.discrete_critic is None: state_dicts["discrete_critic"], device=device
self.discrete_critic = DiscreteCritic( )
encoder=self.encoder_critic, self.discrete_critic.load_state_dict(discrete_critic_state_dict)
input_dim=self.encoder_critic.output_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
).to(device)
self.discrete_critic.load_state_dict(dc_sd)
def _init_encoders(self): def _init_encoders(self):
"""Initialize shared or separate encoders for actor and critic.""" """Initialize shared or separate encoders for actor and critic."""
@@ -148,7 +142,7 @@ class GaussianActorPolicy(
) )
def _init_actor(self, continuous_action_dim): def _init_actor(self, continuous_action_dim):
"""Initialize policy actor network and default target entropy.""" """Initialize policy actor network."""
# NOTE: The actor select only the continuous action part # NOTE: The actor select only the continuous action part
self.actor = Policy( self.actor = Policy(
encoder=self.encoder_actor, encoder=self.encoder_actor,
@@ -158,10 +152,19 @@ class GaussianActorPolicy(
**asdict(self.config.policy_kwargs), **asdict(self.config.policy_kwargs),
) )
self.target_entropy = self.config.target_entropy def _init_discrete_critic(self) -> None:
if self.target_entropy is None: """Initialize discrete critic network."""
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0) if self.config.num_discrete_actions is None:
self.target_entropy = -np.prod(dim) / 2 self.discrete_critic = None
return
# TODO(Khalil): Compile the discrete critic
self.discrete_critic = DiscreteCritic(
encoder=self.encoder_critic,
input_dim=self.encoder_critic.output_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
class GaussianActorObservationEncoder(nn.Module): class GaussianActorObservationEncoder(nn.Module):
@@ -35,7 +35,7 @@ class SACAlgorithmConfig(RLAlgorithmConfig):
"""SAC algorithm hyperparameters.""" """SAC algorithm hyperparameters."""
# Policy config # Policy config
sac_config: GaussianActorConfig policy_config: GaussianActorConfig
# Optimizer learning rates # Optimizer learning rates
actor_lr: float = 3e-4 actor_lr: float = 3e-4
@@ -55,31 +55,26 @@ class SACAlgorithmConfig(RLAlgorithmConfig):
# Temperature / entropy # Temperature / entropy
temperature_init: float = 1.0 temperature_init: float = 1.0
# Target entropy for automatic temperature tuning. If ``None``, defaults to
# ``-|A|/2`` where ``|A|`` is the total action dimension (continuous + 1 if
# there is a discrete action head).
target_entropy: float | None = None
# Update loop # Update loop
utd_ratio: int = 1 utd_ratio: int = 1
policy_update_freq: int = 1 policy_update_freq: int = 1
grad_clip_norm: float = 40.0 grad_clip_norm: float = 40.0
# Optimizations
# torch.compile is currently disabled by default
use_torch_compile: bool = False
@classmethod @classmethod
def from_policy_config(cls, policy_cfg: GaussianActorConfig) -> SACAlgorithmConfig: def from_policy_config(cls, policy_cfg: GaussianActorConfig) -> SACAlgorithmConfig:
"""Build an algorithm config by copying hyperparameters from the policy config.""" """Build an algorithm config with default hyperparameters for a given policy."""
return cls( return cls(
actor_lr=policy_cfg.actor_lr, policy_config=policy_cfg,
critic_lr=policy_cfg.critic_lr,
temperature_lr=policy_cfg.temperature_lr,
discount=policy_cfg.discount,
use_backup_entropy=policy_cfg.use_backup_entropy,
critic_target_update_weight=policy_cfg.critic_target_update_weight,
num_critics=policy_cfg.num_critics,
num_subsample_critics=policy_cfg.num_subsample_critics,
critic_network_kwargs=policy_cfg.critic_network_kwargs,
discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs, discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs,
temperature_init=policy_cfg.temperature_init,
utd_ratio=policy_cfg.utd_ratio,
policy_update_freq=policy_cfg.policy_update_freq,
grad_clip_norm=policy_cfg.grad_clip_norm,
sac_config=policy_cfg,
) )
def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm: def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm:
+29 -35
View File
@@ -54,14 +54,14 @@ class SACAlgorithm(RLAlgorithm):
config: SACAlgorithmConfig, config: SACAlgorithmConfig,
): ):
self.config = config self.config = config
self.policy_config = config.sac_config self.policy_config = config.policy_config
self.policy = policy self.policy = policy
self.optimizers: dict[str, Optimizer] = {} self.optimizers: dict[str, Optimizer] = {}
self._optimization_step: int = 0 self._optimization_step: int = 0
action_dim = self.policy.config.output_features[ACTION].shape[0] action_dim = self.policy.config.output_features[ACTION].shape[0]
self._init_critics(action_dim) self._init_critics(action_dim)
self._init_temperature() self._init_temperature(action_dim)
self._device = torch.device(self.policy.config.device) self._device = torch.device(self.policy.config.device)
self._move_to_device() self._move_to_device()
@@ -90,49 +90,44 @@ class SACAlgorithm(RLAlgorithm):
# TODO(Khalil): Investigate and fix torch.compile # TODO(Khalil): Investigate and fix torch.compile
# NOTE: torch.compile is disabled, policy does not converge when enabled. # NOTE: torch.compile is disabled, policy does not converge when enabled.
if self.policy_config.use_torch_compile: if self.config.use_torch_compile:
self.critic_ensemble = torch.compile(self.critic_ensemble) self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = torch.compile(self.critic_target) self.critic_target = torch.compile(self.critic_target)
self.discrete_critic = None
self.discrete_critic_target = None self.discrete_critic_target = None
if self.policy_config.num_discrete_actions is not None: if self.policy_config.num_discrete_actions is not None:
self.discrete_critic, self.discrete_critic_target = self._init_discrete_critics(encoder) self.discrete_critic_target = self._init_discrete_critic_target(encoder)
self.policy.discrete_critic = self.discrete_critic
def _init_discrete_critics( def _init_discrete_critic_target(self, encoder: GaussianActorObservationEncoder) -> DiscreteCritic:
self, encoder: GaussianActorObservationEncoder """Build target discrete critic (main network is owned by the policy)."""
) -> tuple[DiscreteCritic, DiscreteCritic]:
"""Build discrete critic ensemble and target networks."""
discrete_critic = DiscreteCritic(
encoder=encoder,
input_dim=encoder.output_dim,
output_dim=self.policy_config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
discrete_critic_target = DiscreteCritic( discrete_critic_target = DiscreteCritic(
encoder=encoder, encoder=encoder,
input_dim=encoder.output_dim, input_dim=encoder.output_dim,
output_dim=self.policy_config.num_discrete_actions, output_dim=self.policy_config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs), **asdict(self.config.discrete_critic_network_kwargs),
) )
# TODO(Khalil): Compile the discrete critic # TODO(Khalil): Compile the discrete critic
discrete_critic_target.load_state_dict(discrete_critic.state_dict()) discrete_critic_target.load_state_dict(self.policy.discrete_critic.state_dict())
return discrete_critic, discrete_critic_target return discrete_critic_target
def _init_temperature(self) -> None: def _init_temperature(self, continuous_action_dim: int) -> None:
"""Set up temperature parameter (log_alpha).""" """Set up temperature parameter (log_alpha) and target entropy."""
temp_init = self.config.temperature_init temp_init = self.config.temperature_init
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)])) self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
self.target_entropy = self.config.target_entropy
if self.target_entropy is None:
total_action_dim = continuous_action_dim + (
1 if self.policy_config.num_discrete_actions is not None else 0
)
self.target_entropy = -total_action_dim / 2
def _move_to_device(self) -> None: def _move_to_device(self) -> None:
self.policy.to(self._device) self.policy.to(self._device)
self.critic_ensemble.to(self._device) self.critic_ensemble.to(self._device)
self.critic_target.to(self._device) self.critic_target.to(self._device)
self.log_alpha = nn.Parameter(self.log_alpha.data.to(self._device)) self.log_alpha = nn.Parameter(self.log_alpha.data.to(self._device))
if self.discrete_critic is not None: if self.discrete_critic_target is not None:
self.discrete_critic.to(self._device)
self.discrete_critic_target.to(self._device) self.discrete_critic_target.to(self._device)
@property @property
@@ -175,7 +170,7 @@ class SACAlgorithm(RLAlgorithm):
Returns: Returns:
Tensor of Q-values from the discrete critic network Tensor of Q-values from the discrete critic network
""" """
discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic discrete_critic = self.discrete_critic_target if use_target else self.policy.discrete_critic
q_values = discrete_critic(observations, observation_features) q_values = discrete_critic(observations, observation_features)
return q_values return q_values
@@ -196,7 +191,7 @@ class SACAlgorithm(RLAlgorithm):
loss_dc = self._compute_loss_discrete_critic(fb) loss_dc = self._compute_loss_discrete_critic(fb)
self.optimizers["discrete_critic"].zero_grad() self.optimizers["discrete_critic"].zero_grad()
loss_dc.backward() loss_dc.backward()
torch.nn.utils.clip_grad_norm_(self.discrete_critic.parameters(), max_norm=clip) torch.nn.utils.clip_grad_norm_(self.policy.discrete_critic.parameters(), max_norm=clip)
self.optimizers["discrete_critic"].step() self.optimizers["discrete_critic"].step()
self._update_target_networks() self._update_target_networks()
@@ -219,7 +214,9 @@ class SACAlgorithm(RLAlgorithm):
loss_dc = self._compute_loss_discrete_critic(fb) loss_dc = self._compute_loss_discrete_critic(fb)
self.optimizers["discrete_critic"].zero_grad() self.optimizers["discrete_critic"].zero_grad()
loss_dc.backward() loss_dc.backward()
dc_grad = torch.nn.utils.clip_grad_norm_(self.discrete_critic.parameters(), max_norm=clip).item() dc_grad = torch.nn.utils.clip_grad_norm_(
self.policy.discrete_critic.parameters(), max_norm=clip
).item()
self.optimizers["discrete_critic"].step() self.optimizers["discrete_critic"].step()
stats.losses["loss_discrete_critic"] = loss_dc.item() stats.losses["loss_discrete_critic"] = loss_dc.item()
stats.grad_norms["discrete_critic"] = dc_grad stats.grad_norms["discrete_critic"] = dc_grad
@@ -396,7 +393,7 @@ class SACAlgorithm(RLAlgorithm):
with torch.no_grad(): with torch.no_grad():
_, log_probs, _ = self.policy.actor(observations, observation_features) _, log_probs, _ = self.policy.actor(observations, observation_features)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.policy.target_entropy)).mean() temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
return temperature_loss return temperature_loss
def _update_target_networks(self) -> None: def _update_target_networks(self) -> None:
@@ -411,7 +408,7 @@ class SACAlgorithm(RLAlgorithm):
if self.policy_config.num_discrete_actions is not None: if self.policy_config.num_discrete_actions is not None:
for target_p, p in zip( for target_p, p in zip(
self.discrete_critic_target.parameters(), self.discrete_critic_target.parameters(),
self.discrete_critic.parameters(), self.policy.discrete_critic.parameters(),
strict=True, strict=True,
): ):
target_p.data.copy_( target_p.data.copy_(
@@ -471,7 +468,7 @@ class SACAlgorithm(RLAlgorithm):
} }
if self.policy_config.num_discrete_actions is not None: if self.policy_config.num_discrete_actions is not None:
self.optimizers["discrete_critic"] = torch.optim.Adam( self.optimizers["discrete_critic"] = torch.optim.Adam(
self.discrete_critic.parameters(), lr=self.config.critic_lr self.policy.discrete_critic.parameters(), lr=self.config.critic_lr
) )
return self.optimizers return self.optimizers
@@ -485,16 +482,13 @@ class SACAlgorithm(RLAlgorithm):
} }
if self.policy_config.num_discrete_actions is not None: if self.policy_config.num_discrete_actions is not None:
state_dicts["discrete_critic"] = move_state_dict_to_device( state_dicts["discrete_critic"] = move_state_dict_to_device(
self.discrete_critic.state_dict(), device="cpu" self.policy.discrete_critic.state_dict(), device="cpu"
) )
return state_dicts return state_dicts
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None: def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
actor_sd = move_state_dict_to_device(weights["policy"], device=device) """Load actor + discrete-critic weights into the policy."""
self.policy.actor.load_state_dict(actor_sd) self.policy.load_actor_weights(weights, device=device)
if "discrete_critic" in weights and self.policy_config.num_discrete_actions is not None:
dc_sd = move_state_dict_to_device(weights["discrete_critic"], device=device)
self.discrete_critic.load_state_dict(dc_sd)
def get_observation_features( def get_observation_features(
self, observations: Tensor, next_observations: Tensor self, observations: Tensor, next_observations: Tensor
+8 -29
View File
@@ -55,9 +55,6 @@ def test_gaussian_actor_config_default_initialization():
# Basic parameters # Basic parameters
assert config.device == "cpu" assert config.device == "cpu"
assert config.storage_device == "cpu" assert config.storage_device == "cpu"
assert config.discount == 0.99
assert config.temperature_init == 1.0
assert config.num_critics == 2
# Architecture specifics # Architecture specifics
assert config.vision_encoder_name is None assert config.vision_encoder_name is None
@@ -66,6 +63,8 @@ def test_gaussian_actor_config_default_initialization():
assert config.shared_encoder is True assert config.shared_encoder is True
assert config.num_discrete_actions is None assert config.num_discrete_actions is None
assert config.image_embedding_pooling_dim == 8 assert config.image_embedding_pooling_dim == 8
assert config.state_encoder_hidden_dim == 256
assert config.latent_dim == 256
# Training parameters # Training parameters
assert config.online_steps == 1000000 assert config.online_steps == 1000000
@@ -73,20 +72,6 @@ def test_gaussian_actor_config_default_initialization():
assert config.offline_buffer_capacity == 100000 assert config.offline_buffer_capacity == 100000
assert config.async_prefetch is False assert config.async_prefetch is False
assert config.online_step_before_learning == 100 assert config.online_step_before_learning == 100
assert config.policy_update_freq == 1
# SAC algorithm parameters
assert config.num_subsample_critics is None
assert config.critic_lr == 3e-4
assert config.actor_lr == 3e-4
assert config.temperature_lr == 3e-4
assert config.critic_target_update_weight == 0.005
assert config.utd_ratio == 1
assert config.state_encoder_hidden_dim == 256
assert config.latent_dim == 256
assert config.target_entropy is None
assert config.use_backup_entropy is True
assert config.grad_clip_norm == 40.0
# Dataset stats defaults # Dataset stats defaults
expected_dataset_stats = { expected_dataset_stats = {
@@ -105,11 +90,6 @@ def test_gaussian_actor_config_default_initialization():
} }
assert config.dataset_stats == expected_dataset_stats assert config.dataset_stats == expected_dataset_stats
# Critic network configuration
assert config.critic_network_kwargs.hidden_dims == [256, 256]
assert config.critic_network_kwargs.activate_final is True
assert config.critic_network_kwargs.final_activation is None
# Actor network configuration # Actor network configuration
assert config.actor_network_kwargs.hidden_dims == [256, 256] assert config.actor_network_kwargs.hidden_dims == [256, 256]
assert config.actor_network_kwargs.activate_final is True assert config.actor_network_kwargs.activate_final is True
@@ -135,7 +115,6 @@ def test_gaussian_actor_config_default_initialization():
assert config.concurrency.learner == "threads" assert config.concurrency.learner == "threads"
assert isinstance(config.actor_network_kwargs, ActorNetworkConfig) assert isinstance(config.actor_network_kwargs, ActorNetworkConfig)
assert isinstance(config.critic_network_kwargs, CriticNetworkConfig)
assert isinstance(config.policy_kwargs, PolicyConfig) assert isinstance(config.policy_kwargs, PolicyConfig)
assert isinstance(config.actor_learner_config, ActorLearnerConfig) assert isinstance(config.actor_learner_config, ActorLearnerConfig)
assert isinstance(config.concurrency, ConcurrencyConfig) assert isinstance(config.concurrency, ConcurrencyConfig)
@@ -178,15 +157,15 @@ def test_concurrency_config():
def test_gaussian_actor_config_custom_initialization(): def test_gaussian_actor_config_custom_initialization():
config = GaussianActorConfig( config = GaussianActorConfig(
device="cpu", device="cpu",
discount=0.95, latent_dim=128,
temperature_init=0.5, state_encoder_hidden_dim=128,
num_critics=3, num_discrete_actions=3,
) )
assert config.device == "cpu" assert config.device == "cpu"
assert config.discount == 0.95 assert config.latent_dim == 128
assert config.temperature_init == 0.5 assert config.state_encoder_hidden_dim == 128
assert config.num_critics == 3 assert config.num_discrete_actions == 3
def test_validate_features(): def test_validate_features():
+10 -9
View File
@@ -404,19 +404,16 @@ def test_sac_training_with_discrete_critic():
def test_sac_algorithm_target_entropy(): def test_sac_algorithm_target_entropy():
"""Target entropy is an SAC hyperparameter and lives on the algorithm."""
config = create_default_config(continuous_action_dim=10, state_dim=10) config = create_default_config(continuous_action_dim=10, state_dim=10)
_, policy = _make_algorithm(config) algorithm, _ = _make_algorithm(config)
algo_config = SACAlgorithmConfig.from_policy_config(config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
assert algorithm.target_entropy == -5.0 assert algorithm.target_entropy == -5.0
def test_sac_algorithm_target_entropy_with_discrete_action(): def test_sac_algorithm_target_entropy_with_discrete_action():
config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True) config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
config.num_discrete_actions = 5 config.num_discrete_actions = 5
algo_config = SACAlgorithmConfig.from_policy_config(config) algorithm, _ = _make_algorithm(config)
policy = GaussianActorPolicy(config=config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
assert algorithm.target_entropy == -3.5 assert algorithm.target_entropy == -3.5
@@ -435,8 +432,8 @@ def test_sac_algorithm_temperature():
def test_sac_algorithm_update_target_network(): def test_sac_algorithm_update_target_network():
config = create_default_config(state_dim=10, continuous_action_dim=6) config = create_default_config(state_dim=10, continuous_action_dim=6)
config.critic_target_update_weight = 1.0
algo_config = SACAlgorithmConfig.from_policy_config(config) algo_config = SACAlgorithmConfig.from_policy_config(config)
algo_config.critic_target_update_weight = 1.0
policy = GaussianActorPolicy(config=config) policy = GaussianActorPolicy(config=config)
algorithm = SACAlgorithm(policy=policy, config=algo_config) algorithm = SACAlgorithm(policy=policy, config=algo_config)
@@ -454,9 +451,13 @@ def test_sac_algorithm_with_critics_number_of_heads(num_critics: int):
action_dim = 10 action_dim = 10
state_dim = 10 state_dim = 10
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
config.num_critics = num_critics
algorithm, policy = _make_algorithm(config) policy = GaussianActorPolicy(config=config)
policy.train()
algo_config = SACAlgorithmConfig.from_policy_config(config)
algo_config.num_critics = num_critics
algorithm = SACAlgorithm(policy=policy, config=algo_config)
algorithm.make_optimizers_and_scheduler()
assert len(algorithm.critic_ensemble.critics) == num_critics assert len(algorithm.critic_ensemble.critics) == num_critics
-3
View File
@@ -327,7 +327,6 @@ def test_learner_algorithm_wiring():
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim}, OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim}, ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
}, },
use_torch_compile=False,
) )
sac_cfg.validate_features() sac_cfg.validate_features()
@@ -412,7 +411,6 @@ def test_initial_and_periodic_weight_push_consistency():
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim}, OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim}, ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
}, },
use_torch_compile=False,
) )
sac_cfg.validate_features() sac_cfg.validate_features()
@@ -450,7 +448,6 @@ def test_actor_side_algorithm_select_action_and_load_weights():
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim}, OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim}, ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
}, },
use_torch_compile=False,
) )
sac_cfg.validate_features() sac_cfg.validate_features()
+62 -20
View File
@@ -44,8 +44,6 @@ def _make_sac_config(
state_dim: int = 10, state_dim: int = 10,
action_dim: int = 6, action_dim: int = 6,
num_discrete_actions: int | None = None, num_discrete_actions: int | None = None,
utd_ratio: int = 1,
policy_update_freq: int = 1,
with_images: bool = False, with_images: bool = False,
) -> GaussianActorConfig: ) -> GaussianActorConfig:
config = GaussianActorConfig( config = GaussianActorConfig(
@@ -55,10 +53,7 @@ def _make_sac_config(
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim}, OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim}, ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
}, },
utd_ratio=utd_ratio,
policy_update_freq=policy_update_freq,
num_discrete_actions=num_discrete_actions, num_discrete_actions=num_discrete_actions,
use_torch_compile=False,
) )
if with_images: if with_images:
config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84)) config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
@@ -83,14 +78,14 @@ def _make_algorithm(
sac_cfg = _make_sac_config( sac_cfg = _make_sac_config(
state_dim=state_dim, state_dim=state_dim,
action_dim=action_dim, action_dim=action_dim,
utd_ratio=utd_ratio,
policy_update_freq=policy_update_freq,
num_discrete_actions=num_discrete_actions, num_discrete_actions=num_discrete_actions,
with_images=with_images, with_images=with_images,
) )
policy = GaussianActorPolicy(config=sac_cfg) policy = GaussianActorPolicy(config=sac_cfg)
policy.train() policy.train()
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg) algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
algo_config.utd_ratio = utd_ratio
algo_config.policy_update_freq = policy_update_freq
algorithm = SACAlgorithm(policy=policy, config=algo_config) algorithm = SACAlgorithm(policy=policy, config=algo_config)
algorithm.make_optimizers_and_scheduler() algorithm.make_optimizers_and_scheduler()
return algorithm, policy return algorithm, policy
@@ -136,13 +131,16 @@ def test_sac_algorithm_config_registered():
def test_sac_algorithm_config_from_policy_config(): def test_sac_algorithm_config_from_policy_config():
"""from_policy_config should copy algorithm hyperparameters from the policy config.""" """from_policy_config embeds the policy config and uses SAC defaults."""
sac_cfg = _make_sac_config(utd_ratio=4, policy_update_freq=2) sac_cfg = _make_sac_config()
algo_cfg = SACAlgorithmConfig.from_policy_config(sac_cfg) algo_cfg = SACAlgorithmConfig.from_policy_config(sac_cfg)
assert algo_cfg.sac_config is sac_cfg assert algo_cfg.policy_config is sac_cfg
assert algo_cfg.utd_ratio == 4 assert algo_cfg.discrete_critic_network_kwargs is sac_cfg.discrete_critic_network_kwargs
assert algo_cfg.policy_update_freq == 2 # Defaults come from SACAlgorithmConfig, not from the policy config.
assert algo_cfg.grad_clip_norm == sac_cfg.grad_clip_norm assert algo_cfg.utd_ratio == 1
assert algo_cfg.policy_update_freq == 1
assert algo_cfg.grad_clip_norm == 40.0
assert algo_cfg.actor_lr == 3e-4
# =========================================================================== # ===========================================================================
@@ -377,12 +375,14 @@ def test_actor_side_no_optimizers():
assert algorithm.optimizers == {} assert algorithm.optimizers == {}
def test_make_algorithm_copies_config_fields(): def test_make_algorithm_uses_sac_algorithm_defaults():
sac_cfg = _make_sac_config(utd_ratio=5, policy_update_freq=3) """make_algorithm populates SACAlgorithmConfig with its own defaults."""
sac_cfg = _make_sac_config()
policy = GaussianActorPolicy(config=sac_cfg) policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac") algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
assert algorithm.config.utd_ratio == 5 assert algorithm.config.utd_ratio == 1
assert algorithm.config.policy_update_freq == 3 assert algorithm.config.policy_update_freq == 1
assert algorithm.config.grad_clip_norm == 40.0
def test_make_algorithm_raises_for_unknown_type(): def test_make_algorithm_raises_for_unknown_type():
@@ -431,10 +431,10 @@ def test_load_weights_round_trip_with_discrete_critic():
assert "discrete_critic" in weights assert "discrete_critic" in weights
assert len(weights["discrete_critic"]) > 0 assert len(weights["discrete_critic"]) > 0
dst_dc_state_dict = algo_dst.discrete_critic.state_dict() dst_discrete_critic_state_dict = algo_dst.policy.discrete_critic.state_dict()
for key, tensor in weights["discrete_critic"].items(): for key, tensor in weights["discrete_critic"].items():
assert torch.equal( assert torch.equal(
dst_dc_state_dict[key].cpu(), dst_discrete_critic_state_dict[key].cpu(),
tensor.cpu(), tensor.cpu(),
), f"Discrete critic param '{key}' mismatch after load_weights" ), f"Discrete critic param '{key}' mismatch after load_weights"
@@ -446,6 +446,47 @@ def test_load_weights_ignores_missing_discrete_critic():
algorithm.load_weights(weights, device="cpu") algorithm.load_weights(weights, device="cpu")
def test_actor_side_weight_sync_with_discrete_critic():
"""End-to-end: learner ``algorithm.get_weights()`` -> actor ``policy.load_actor_weights()``."""
# Learner side: train the algorithm so its weights diverge from init.
algo_src, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
algo_src.update(_batch_iterator(action_dim=7))
weights = algo_src.get_weights()
# Actor side: fresh policy, no algorithm/optimizer.
sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6)
policy_actor = GaussianActorPolicy(config=sac_cfg)
# Snapshot initial actor state for the "did it change?" assertion below.
initial_discrete_critic_state_dict = {
k: v.clone() for k, v in policy_actor.discrete_critic.state_dict().items()
}
policy_actor.load_actor_weights(weights, device="cpu")
# Actor weights match the learner's exported actor state dict.
actor_state_dict = policy_actor.actor.state_dict()
for key, tensor in weights["policy"].items():
assert torch.equal(actor_state_dict[key].cpu(), tensor.cpu()), (
f"Actor param '{key}' not synced by load_actor_weights"
)
# Discrete critic weights match the learner's exported discrete critic.
discrete_critic_state_dict = policy_actor.discrete_critic.state_dict()
for key, tensor in weights["discrete_critic"].items():
assert torch.equal(discrete_critic_state_dict[key].cpu(), tensor.cpu()), (
f"Discrete critic param '{key}' not synced by load_actor_weights"
)
# Sanity: the discrete critic actually changed (otherwise the sync is trivial).
changed = any(
not torch.equal(initial_discrete_critic_state_dict[key], discrete_critic_state_dict[key])
for key in initial_discrete_critic_state_dict
if key in discrete_critic_state_dict
)
assert changed, "Discrete critic weights did not change between init and after sync"
# =========================================================================== # ===========================================================================
# TrainingStats generic losses dict # TrainingStats generic losses dict
# =========================================================================== # ===========================================================================
@@ -468,8 +509,9 @@ def test_training_stats_generic_losses():
def test_build_algorithm_via_config(): def test_build_algorithm_via_config():
"""SACAlgorithmConfig.build_algorithm should produce a working SACAlgorithm.""" """SACAlgorithmConfig.build_algorithm should produce a working SACAlgorithm."""
sac_cfg = _make_sac_config(utd_ratio=2) sac_cfg = _make_sac_config()
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg) algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
algo_config.utd_ratio = 2
policy = GaussianActorPolicy(config=sac_cfg) policy = GaussianActorPolicy(config=sac_cfg)
algorithm = algo_config.build_algorithm(policy) algorithm = algo_config.build_algorithm(policy)