mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
refactor(rl/sac): consolidate hyperparameter ownership and clean up discrete critic
This commit is contained in:
@@ -926,7 +926,7 @@ The ideal behaviour is that your intervention rate should drop gradually during
|
|||||||
|
|
||||||
Some configuration values have a disproportionate impact on training stability and speed:
|
Some configuration values have a disproportionate impact on training stability and speed:
|
||||||
|
|
||||||
- **`temperature_init`** (`policy.temperature_init`) – initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning.
|
- **`temperature_init`** (`algorithm.temperature_init`) – initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning.
|
||||||
- **`policy_parameters_push_frequency`** (`policy.actor_learner_config.policy_parameters_push_frequency`) – interval in _seconds_ between two weight pushes from the learner to the actor. The default is `4 s`. Decrease to **1-2 s** to provide fresher weights (at the cost of more network traffic); increase only if your connection is slow, as this will reduce sample efficiency.
|
- **`policy_parameters_push_frequency`** (`policy.actor_learner_config.policy_parameters_push_frequency`) – interval in _seconds_ between two weight pushes from the learner to the actor. The default is `4 s`. Decrease to **1-2 s** to provide fresher weights (at the cost of more network traffic); increase only if your connection is slow, as this will reduce sample efficiency.
|
||||||
- **`storage_device`** (`policy.storage_device`) – device on which the learner keeps the policy parameters. If you have spare GPU memory, set this to `"cuda"` (instead of the default `"cpu"`). Keeping the weights on-GPU removes CPU→GPU transfer overhead and can significantly increase the number of learner updates per second.
|
- **`storage_device`** (`policy.storage_device`) – device on which the learner keeps the policy parameters. If you have spare GPU memory, set this to `"cuda"` (instead of the default `"cpu"`). Keeping the weights on-GPU removes CPU→GPU transfer overhead and can significantly increase the number of learner updates per second.
|
||||||
|
|
||||||
|
|||||||
@@ -136,80 +136,41 @@ class GaussianActorConfig(PreTrainedConfig):
|
|||||||
# Dimension of the image embedding pooling
|
# Dimension of the image embedding pooling
|
||||||
image_embedding_pooling_dim: int = 8
|
image_embedding_pooling_dim: int = 8
|
||||||
|
|
||||||
# Training parameter
|
# Encoder architecture
|
||||||
# Number of steps for online training
|
|
||||||
online_steps: int = 1000000
|
|
||||||
# Capacity of the online replay buffer
|
|
||||||
online_buffer_capacity: int = 100000
|
|
||||||
# Capacity of the offline replay buffer
|
|
||||||
offline_buffer_capacity: int = 100000
|
|
||||||
# Whether to use asynchronous prefetching for the buffers
|
|
||||||
async_prefetch: bool = False
|
|
||||||
# Number of steps before learning starts
|
|
||||||
online_step_before_learning: int = 100
|
|
||||||
# Frequency of policy updates
|
|
||||||
policy_update_freq: int = 1
|
|
||||||
|
|
||||||
# SAC algorithm parameters
|
|
||||||
# Discount factor for the SAC algorithm
|
|
||||||
discount: float = 0.99
|
|
||||||
# Initial temperature value
|
|
||||||
temperature_init: float = 1.0
|
|
||||||
# Number of critics in the ensemble
|
|
||||||
num_critics: int = 2
|
|
||||||
# Number of subsampled critics for training
|
|
||||||
num_subsample_critics: int | None = None
|
|
||||||
# Learning rate for the critic network
|
|
||||||
critic_lr: float = 3e-4
|
|
||||||
# Learning rate for the actor network
|
|
||||||
actor_lr: float = 3e-4
|
|
||||||
# Learning rate for the temperature parameter
|
|
||||||
temperature_lr: float = 3e-4
|
|
||||||
# Weight for the critic target update
|
|
||||||
critic_target_update_weight: float = 0.005
|
|
||||||
# Update-to-data ratio for the UTD algorithm (If you want enable utd_ratio, you need to set it to >1)
|
|
||||||
utd_ratio: int = 1
|
|
||||||
# Hidden dimension size for the state encoder
|
# Hidden dimension size for the state encoder
|
||||||
state_encoder_hidden_dim: int = 256
|
state_encoder_hidden_dim: int = 256
|
||||||
# Dimension of the latent space
|
# Dimension of the latent space
|
||||||
latent_dim: int = 256
|
latent_dim: int = 256
|
||||||
# Target entropy for the SAC algorithm
|
|
||||||
target_entropy: float | None = None
|
|
||||||
# Whether to use backup entropy for the SAC algorithm
|
|
||||||
use_backup_entropy: bool = True
|
|
||||||
# Gradient clipping norm for the SAC algorithm
|
|
||||||
grad_clip_norm: float = 40.0
|
|
||||||
|
|
||||||
# Network configuration
|
# Online training (TODO(Khalil): relocate to TrainRLServerPipelineConfig)
|
||||||
# Configuration for the critic network architecture
|
online_steps: int = 1000000
|
||||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
online_buffer_capacity: int = 100000
|
||||||
# Configuration for the actor network architecture
|
offline_buffer_capacity: int = 100000
|
||||||
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
|
async_prefetch: bool = False
|
||||||
# Configuration for the policy parameters
|
online_step_before_learning: int = 100
|
||||||
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
|
|
||||||
# Configuration for the discrete critic network
|
# Actor-learner transport (TODO(Khalil): relocate to TrainRLServerPipelineConfig).
|
||||||
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
|
||||||
# Configuration for actor-learner architecture
|
|
||||||
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
|
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
|
||||||
# Configuration for concurrency settings (you can use threads or processes for the actor and learner)
|
|
||||||
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
|
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
|
||||||
|
|
||||||
# Optimizations
|
# Network architecture
|
||||||
# torch.compile is currently disabled by default due to known issues with the SAC
|
# Actor network
|
||||||
# critic ensemble and shared encoder.
|
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
|
||||||
use_torch_compile: bool = False
|
# Gaussian head parameters
|
||||||
|
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
|
||||||
|
# Discrete critic
|
||||||
|
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
# Any validation specific to SAC configuration
|
|
||||||
|
|
||||||
def get_optimizer_preset(self) -> MultiAdamConfig:
|
def get_optimizer_preset(self) -> MultiAdamConfig:
|
||||||
return MultiAdamConfig(
|
return MultiAdamConfig(
|
||||||
weight_decay=0.0,
|
weight_decay=0.0,
|
||||||
optimizer_groups={
|
optimizer_groups={
|
||||||
"actor": {"lr": self.actor_lr},
|
"actor": {"lr": 3e-4},
|
||||||
"critic": {"lr": self.critic_lr},
|
"critic": {"lr": 3e-4},
|
||||||
"temperature": {"lr": self.temperature_lr},
|
"temperature": {"lr": 3e-4},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ from collections.abc import Callable
|
|||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
@@ -61,7 +60,7 @@ class GaussianActorPolicy(
|
|||||||
continuous_action_dim = config.output_features[ACTION].shape[0]
|
continuous_action_dim = config.output_features[ACTION].shape[0]
|
||||||
self._init_encoders()
|
self._init_encoders()
|
||||||
self._init_actor(continuous_action_dim)
|
self._init_actor(continuous_action_dim)
|
||||||
self.discrete_critic = None
|
self._init_discrete_critic()
|
||||||
|
|
||||||
def get_optim_params(self) -> dict:
|
def get_optim_params(self) -> dict:
|
||||||
optim_params = {
|
optim_params = {
|
||||||
@@ -125,19 +124,14 @@ class GaussianActorPolicy(
|
|||||||
def load_actor_weights(self, state_dicts: dict[str, Any], device: str | torch.device = "cpu") -> None:
|
def load_actor_weights(self, state_dicts: dict[str, Any], device: str | torch.device = "cpu") -> None:
|
||||||
from lerobot.utils.transition import move_state_dict_to_device
|
from lerobot.utils.transition import move_state_dict_to_device
|
||||||
|
|
||||||
actor_sd = move_state_dict_to_device(state_dicts["policy"], device=device)
|
actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device)
|
||||||
self.actor.load_state_dict(actor_sd)
|
self.actor.load_state_dict(actor_state_dict)
|
||||||
|
|
||||||
if "discrete_critic" in state_dicts:
|
if "discrete_critic" in state_dicts and self.discrete_critic is not None:
|
||||||
dc_sd = move_state_dict_to_device(state_dicts["discrete_critic"], device=device)
|
discrete_critic_state_dict = move_state_dict_to_device(
|
||||||
if self.discrete_critic is None:
|
state_dicts["discrete_critic"], device=device
|
||||||
self.discrete_critic = DiscreteCritic(
|
)
|
||||||
encoder=self.encoder_critic,
|
self.discrete_critic.load_state_dict(discrete_critic_state_dict)
|
||||||
input_dim=self.encoder_critic.output_dim,
|
|
||||||
output_dim=self.config.num_discrete_actions,
|
|
||||||
**asdict(self.config.discrete_critic_network_kwargs),
|
|
||||||
).to(device)
|
|
||||||
self.discrete_critic.load_state_dict(dc_sd)
|
|
||||||
|
|
||||||
def _init_encoders(self):
|
def _init_encoders(self):
|
||||||
"""Initialize shared or separate encoders for actor and critic."""
|
"""Initialize shared or separate encoders for actor and critic."""
|
||||||
@@ -148,7 +142,7 @@ class GaussianActorPolicy(
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _init_actor(self, continuous_action_dim):
|
def _init_actor(self, continuous_action_dim):
|
||||||
"""Initialize policy actor network and default target entropy."""
|
"""Initialize policy actor network."""
|
||||||
# NOTE: The actor select only the continuous action part
|
# NOTE: The actor select only the continuous action part
|
||||||
self.actor = Policy(
|
self.actor = Policy(
|
||||||
encoder=self.encoder_actor,
|
encoder=self.encoder_actor,
|
||||||
@@ -158,10 +152,19 @@ class GaussianActorPolicy(
|
|||||||
**asdict(self.config.policy_kwargs),
|
**asdict(self.config.policy_kwargs),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.target_entropy = self.config.target_entropy
|
def _init_discrete_critic(self) -> None:
|
||||||
if self.target_entropy is None:
|
"""Initialize discrete critic network."""
|
||||||
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
|
if self.config.num_discrete_actions is None:
|
||||||
self.target_entropy = -np.prod(dim) / 2
|
self.discrete_critic = None
|
||||||
|
return
|
||||||
|
|
||||||
|
# TODO(Khalil): Compile the discrete critic
|
||||||
|
self.discrete_critic = DiscreteCritic(
|
||||||
|
encoder=self.encoder_critic,
|
||||||
|
input_dim=self.encoder_critic.output_dim,
|
||||||
|
output_dim=self.config.num_discrete_actions,
|
||||||
|
**asdict(self.config.discrete_critic_network_kwargs),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class GaussianActorObservationEncoder(nn.Module):
|
class GaussianActorObservationEncoder(nn.Module):
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ class SACAlgorithmConfig(RLAlgorithmConfig):
|
|||||||
"""SAC algorithm hyperparameters."""
|
"""SAC algorithm hyperparameters."""
|
||||||
|
|
||||||
# Policy config
|
# Policy config
|
||||||
sac_config: GaussianActorConfig
|
policy_config: GaussianActorConfig
|
||||||
|
|
||||||
# Optimizer learning rates
|
# Optimizer learning rates
|
||||||
actor_lr: float = 3e-4
|
actor_lr: float = 3e-4
|
||||||
@@ -55,31 +55,26 @@ class SACAlgorithmConfig(RLAlgorithmConfig):
|
|||||||
|
|
||||||
# Temperature / entropy
|
# Temperature / entropy
|
||||||
temperature_init: float = 1.0
|
temperature_init: float = 1.0
|
||||||
|
# Target entropy for automatic temperature tuning. If ``None``, defaults to
|
||||||
|
# ``-|A|/2`` where ``|A|`` is the total action dimension (continuous + 1 if
|
||||||
|
# there is a discrete action head).
|
||||||
|
target_entropy: float | None = None
|
||||||
|
|
||||||
# Update loop
|
# Update loop
|
||||||
utd_ratio: int = 1
|
utd_ratio: int = 1
|
||||||
policy_update_freq: int = 1
|
policy_update_freq: int = 1
|
||||||
grad_clip_norm: float = 40.0
|
grad_clip_norm: float = 40.0
|
||||||
|
|
||||||
|
# Optimizations
|
||||||
|
# torch.compile is currently disabled by default
|
||||||
|
use_torch_compile: bool = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_policy_config(cls, policy_cfg: GaussianActorConfig) -> SACAlgorithmConfig:
|
def from_policy_config(cls, policy_cfg: GaussianActorConfig) -> SACAlgorithmConfig:
|
||||||
"""Build an algorithm config by copying hyperparameters from the policy config."""
|
"""Build an algorithm config with default hyperparameters for a given policy."""
|
||||||
return cls(
|
return cls(
|
||||||
actor_lr=policy_cfg.actor_lr,
|
policy_config=policy_cfg,
|
||||||
critic_lr=policy_cfg.critic_lr,
|
|
||||||
temperature_lr=policy_cfg.temperature_lr,
|
|
||||||
discount=policy_cfg.discount,
|
|
||||||
use_backup_entropy=policy_cfg.use_backup_entropy,
|
|
||||||
critic_target_update_weight=policy_cfg.critic_target_update_weight,
|
|
||||||
num_critics=policy_cfg.num_critics,
|
|
||||||
num_subsample_critics=policy_cfg.num_subsample_critics,
|
|
||||||
critic_network_kwargs=policy_cfg.critic_network_kwargs,
|
|
||||||
discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs,
|
discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs,
|
||||||
temperature_init=policy_cfg.temperature_init,
|
|
||||||
utd_ratio=policy_cfg.utd_ratio,
|
|
||||||
policy_update_freq=policy_cfg.policy_update_freq,
|
|
||||||
grad_clip_norm=policy_cfg.grad_clip_norm,
|
|
||||||
sac_config=policy_cfg,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm:
|
def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm:
|
||||||
|
|||||||
@@ -54,14 +54,14 @@ class SACAlgorithm(RLAlgorithm):
|
|||||||
config: SACAlgorithmConfig,
|
config: SACAlgorithmConfig,
|
||||||
):
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.policy_config = config.sac_config
|
self.policy_config = config.policy_config
|
||||||
self.policy = policy
|
self.policy = policy
|
||||||
self.optimizers: dict[str, Optimizer] = {}
|
self.optimizers: dict[str, Optimizer] = {}
|
||||||
self._optimization_step: int = 0
|
self._optimization_step: int = 0
|
||||||
|
|
||||||
action_dim = self.policy.config.output_features[ACTION].shape[0]
|
action_dim = self.policy.config.output_features[ACTION].shape[0]
|
||||||
self._init_critics(action_dim)
|
self._init_critics(action_dim)
|
||||||
self._init_temperature()
|
self._init_temperature(action_dim)
|
||||||
|
|
||||||
self._device = torch.device(self.policy.config.device)
|
self._device = torch.device(self.policy.config.device)
|
||||||
self._move_to_device()
|
self._move_to_device()
|
||||||
@@ -90,49 +90,44 @@ class SACAlgorithm(RLAlgorithm):
|
|||||||
|
|
||||||
# TODO(Khalil): Investigate and fix torch.compile
|
# TODO(Khalil): Investigate and fix torch.compile
|
||||||
# NOTE: torch.compile is disabled, policy does not converge when enabled.
|
# NOTE: torch.compile is disabled, policy does not converge when enabled.
|
||||||
if self.policy_config.use_torch_compile:
|
if self.config.use_torch_compile:
|
||||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||||
self.critic_target = torch.compile(self.critic_target)
|
self.critic_target = torch.compile(self.critic_target)
|
||||||
|
|
||||||
self.discrete_critic = None
|
|
||||||
self.discrete_critic_target = None
|
self.discrete_critic_target = None
|
||||||
if self.policy_config.num_discrete_actions is not None:
|
if self.policy_config.num_discrete_actions is not None:
|
||||||
self.discrete_critic, self.discrete_critic_target = self._init_discrete_critics(encoder)
|
self.discrete_critic_target = self._init_discrete_critic_target(encoder)
|
||||||
self.policy.discrete_critic = self.discrete_critic
|
|
||||||
|
|
||||||
def _init_discrete_critics(
|
def _init_discrete_critic_target(self, encoder: GaussianActorObservationEncoder) -> DiscreteCritic:
|
||||||
self, encoder: GaussianActorObservationEncoder
|
"""Build target discrete critic (main network is owned by the policy)."""
|
||||||
) -> tuple[DiscreteCritic, DiscreteCritic]:
|
|
||||||
"""Build discrete critic ensemble and target networks."""
|
|
||||||
discrete_critic = DiscreteCritic(
|
|
||||||
encoder=encoder,
|
|
||||||
input_dim=encoder.output_dim,
|
|
||||||
output_dim=self.policy_config.num_discrete_actions,
|
|
||||||
**asdict(self.config.discrete_critic_network_kwargs),
|
|
||||||
)
|
|
||||||
discrete_critic_target = DiscreteCritic(
|
discrete_critic_target = DiscreteCritic(
|
||||||
encoder=encoder,
|
encoder=encoder,
|
||||||
input_dim=encoder.output_dim,
|
input_dim=encoder.output_dim,
|
||||||
output_dim=self.policy_config.num_discrete_actions,
|
output_dim=self.policy_config.num_discrete_actions,
|
||||||
**asdict(self.config.discrete_critic_network_kwargs),
|
**asdict(self.config.discrete_critic_network_kwargs),
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO(Khalil): Compile the discrete critic
|
# TODO(Khalil): Compile the discrete critic
|
||||||
discrete_critic_target.load_state_dict(discrete_critic.state_dict())
|
discrete_critic_target.load_state_dict(self.policy.discrete_critic.state_dict())
|
||||||
return discrete_critic, discrete_critic_target
|
return discrete_critic_target
|
||||||
|
|
||||||
def _init_temperature(self) -> None:
|
def _init_temperature(self, continuous_action_dim: int) -> None:
|
||||||
"""Set up temperature parameter (log_alpha)."""
|
"""Set up temperature parameter (log_alpha) and target entropy."""
|
||||||
temp_init = self.config.temperature_init
|
temp_init = self.config.temperature_init
|
||||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
|
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
|
||||||
|
|
||||||
|
self.target_entropy = self.config.target_entropy
|
||||||
|
if self.target_entropy is None:
|
||||||
|
total_action_dim = continuous_action_dim + (
|
||||||
|
1 if self.policy_config.num_discrete_actions is not None else 0
|
||||||
|
)
|
||||||
|
self.target_entropy = -total_action_dim / 2
|
||||||
|
|
||||||
def _move_to_device(self) -> None:
|
def _move_to_device(self) -> None:
|
||||||
self.policy.to(self._device)
|
self.policy.to(self._device)
|
||||||
self.critic_ensemble.to(self._device)
|
self.critic_ensemble.to(self._device)
|
||||||
self.critic_target.to(self._device)
|
self.critic_target.to(self._device)
|
||||||
self.log_alpha = nn.Parameter(self.log_alpha.data.to(self._device))
|
self.log_alpha = nn.Parameter(self.log_alpha.data.to(self._device))
|
||||||
if self.discrete_critic is not None:
|
if self.discrete_critic_target is not None:
|
||||||
self.discrete_critic.to(self._device)
|
|
||||||
self.discrete_critic_target.to(self._device)
|
self.discrete_critic_target.to(self._device)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -175,7 +170,7 @@ class SACAlgorithm(RLAlgorithm):
|
|||||||
Returns:
|
Returns:
|
||||||
Tensor of Q-values from the discrete critic network
|
Tensor of Q-values from the discrete critic network
|
||||||
"""
|
"""
|
||||||
discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic
|
discrete_critic = self.discrete_critic_target if use_target else self.policy.discrete_critic
|
||||||
q_values = discrete_critic(observations, observation_features)
|
q_values = discrete_critic(observations, observation_features)
|
||||||
return q_values
|
return q_values
|
||||||
|
|
||||||
@@ -196,7 +191,7 @@ class SACAlgorithm(RLAlgorithm):
|
|||||||
loss_dc = self._compute_loss_discrete_critic(fb)
|
loss_dc = self._compute_loss_discrete_critic(fb)
|
||||||
self.optimizers["discrete_critic"].zero_grad()
|
self.optimizers["discrete_critic"].zero_grad()
|
||||||
loss_dc.backward()
|
loss_dc.backward()
|
||||||
torch.nn.utils.clip_grad_norm_(self.discrete_critic.parameters(), max_norm=clip)
|
torch.nn.utils.clip_grad_norm_(self.policy.discrete_critic.parameters(), max_norm=clip)
|
||||||
self.optimizers["discrete_critic"].step()
|
self.optimizers["discrete_critic"].step()
|
||||||
|
|
||||||
self._update_target_networks()
|
self._update_target_networks()
|
||||||
@@ -219,7 +214,9 @@ class SACAlgorithm(RLAlgorithm):
|
|||||||
loss_dc = self._compute_loss_discrete_critic(fb)
|
loss_dc = self._compute_loss_discrete_critic(fb)
|
||||||
self.optimizers["discrete_critic"].zero_grad()
|
self.optimizers["discrete_critic"].zero_grad()
|
||||||
loss_dc.backward()
|
loss_dc.backward()
|
||||||
dc_grad = torch.nn.utils.clip_grad_norm_(self.discrete_critic.parameters(), max_norm=clip).item()
|
dc_grad = torch.nn.utils.clip_grad_norm_(
|
||||||
|
self.policy.discrete_critic.parameters(), max_norm=clip
|
||||||
|
).item()
|
||||||
self.optimizers["discrete_critic"].step()
|
self.optimizers["discrete_critic"].step()
|
||||||
stats.losses["loss_discrete_critic"] = loss_dc.item()
|
stats.losses["loss_discrete_critic"] = loss_dc.item()
|
||||||
stats.grad_norms["discrete_critic"] = dc_grad
|
stats.grad_norms["discrete_critic"] = dc_grad
|
||||||
@@ -396,7 +393,7 @@ class SACAlgorithm(RLAlgorithm):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
_, log_probs, _ = self.policy.actor(observations, observation_features)
|
_, log_probs, _ = self.policy.actor(observations, observation_features)
|
||||||
|
|
||||||
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.policy.target_entropy)).mean()
|
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
|
||||||
return temperature_loss
|
return temperature_loss
|
||||||
|
|
||||||
def _update_target_networks(self) -> None:
|
def _update_target_networks(self) -> None:
|
||||||
@@ -411,7 +408,7 @@ class SACAlgorithm(RLAlgorithm):
|
|||||||
if self.policy_config.num_discrete_actions is not None:
|
if self.policy_config.num_discrete_actions is not None:
|
||||||
for target_p, p in zip(
|
for target_p, p in zip(
|
||||||
self.discrete_critic_target.parameters(),
|
self.discrete_critic_target.parameters(),
|
||||||
self.discrete_critic.parameters(),
|
self.policy.discrete_critic.parameters(),
|
||||||
strict=True,
|
strict=True,
|
||||||
):
|
):
|
||||||
target_p.data.copy_(
|
target_p.data.copy_(
|
||||||
@@ -471,7 +468,7 @@ class SACAlgorithm(RLAlgorithm):
|
|||||||
}
|
}
|
||||||
if self.policy_config.num_discrete_actions is not None:
|
if self.policy_config.num_discrete_actions is not None:
|
||||||
self.optimizers["discrete_critic"] = torch.optim.Adam(
|
self.optimizers["discrete_critic"] = torch.optim.Adam(
|
||||||
self.discrete_critic.parameters(), lr=self.config.critic_lr
|
self.policy.discrete_critic.parameters(), lr=self.config.critic_lr
|
||||||
)
|
)
|
||||||
return self.optimizers
|
return self.optimizers
|
||||||
|
|
||||||
@@ -485,16 +482,13 @@ class SACAlgorithm(RLAlgorithm):
|
|||||||
}
|
}
|
||||||
if self.policy_config.num_discrete_actions is not None:
|
if self.policy_config.num_discrete_actions is not None:
|
||||||
state_dicts["discrete_critic"] = move_state_dict_to_device(
|
state_dicts["discrete_critic"] = move_state_dict_to_device(
|
||||||
self.discrete_critic.state_dict(), device="cpu"
|
self.policy.discrete_critic.state_dict(), device="cpu"
|
||||||
)
|
)
|
||||||
return state_dicts
|
return state_dicts
|
||||||
|
|
||||||
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
|
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
|
||||||
actor_sd = move_state_dict_to_device(weights["policy"], device=device)
|
"""Load actor + discrete-critic weights into the policy."""
|
||||||
self.policy.actor.load_state_dict(actor_sd)
|
self.policy.load_actor_weights(weights, device=device)
|
||||||
if "discrete_critic" in weights and self.policy_config.num_discrete_actions is not None:
|
|
||||||
dc_sd = move_state_dict_to_device(weights["discrete_critic"], device=device)
|
|
||||||
self.discrete_critic.load_state_dict(dc_sd)
|
|
||||||
|
|
||||||
def get_observation_features(
|
def get_observation_features(
|
||||||
self, observations: Tensor, next_observations: Tensor
|
self, observations: Tensor, next_observations: Tensor
|
||||||
|
|||||||
@@ -55,9 +55,6 @@ def test_gaussian_actor_config_default_initialization():
|
|||||||
# Basic parameters
|
# Basic parameters
|
||||||
assert config.device == "cpu"
|
assert config.device == "cpu"
|
||||||
assert config.storage_device == "cpu"
|
assert config.storage_device == "cpu"
|
||||||
assert config.discount == 0.99
|
|
||||||
assert config.temperature_init == 1.0
|
|
||||||
assert config.num_critics == 2
|
|
||||||
|
|
||||||
# Architecture specifics
|
# Architecture specifics
|
||||||
assert config.vision_encoder_name is None
|
assert config.vision_encoder_name is None
|
||||||
@@ -66,6 +63,8 @@ def test_gaussian_actor_config_default_initialization():
|
|||||||
assert config.shared_encoder is True
|
assert config.shared_encoder is True
|
||||||
assert config.num_discrete_actions is None
|
assert config.num_discrete_actions is None
|
||||||
assert config.image_embedding_pooling_dim == 8
|
assert config.image_embedding_pooling_dim == 8
|
||||||
|
assert config.state_encoder_hidden_dim == 256
|
||||||
|
assert config.latent_dim == 256
|
||||||
|
|
||||||
# Training parameters
|
# Training parameters
|
||||||
assert config.online_steps == 1000000
|
assert config.online_steps == 1000000
|
||||||
@@ -73,20 +72,6 @@ def test_gaussian_actor_config_default_initialization():
|
|||||||
assert config.offline_buffer_capacity == 100000
|
assert config.offline_buffer_capacity == 100000
|
||||||
assert config.async_prefetch is False
|
assert config.async_prefetch is False
|
||||||
assert config.online_step_before_learning == 100
|
assert config.online_step_before_learning == 100
|
||||||
assert config.policy_update_freq == 1
|
|
||||||
|
|
||||||
# SAC algorithm parameters
|
|
||||||
assert config.num_subsample_critics is None
|
|
||||||
assert config.critic_lr == 3e-4
|
|
||||||
assert config.actor_lr == 3e-4
|
|
||||||
assert config.temperature_lr == 3e-4
|
|
||||||
assert config.critic_target_update_weight == 0.005
|
|
||||||
assert config.utd_ratio == 1
|
|
||||||
assert config.state_encoder_hidden_dim == 256
|
|
||||||
assert config.latent_dim == 256
|
|
||||||
assert config.target_entropy is None
|
|
||||||
assert config.use_backup_entropy is True
|
|
||||||
assert config.grad_clip_norm == 40.0
|
|
||||||
|
|
||||||
# Dataset stats defaults
|
# Dataset stats defaults
|
||||||
expected_dataset_stats = {
|
expected_dataset_stats = {
|
||||||
@@ -105,11 +90,6 @@ def test_gaussian_actor_config_default_initialization():
|
|||||||
}
|
}
|
||||||
assert config.dataset_stats == expected_dataset_stats
|
assert config.dataset_stats == expected_dataset_stats
|
||||||
|
|
||||||
# Critic network configuration
|
|
||||||
assert config.critic_network_kwargs.hidden_dims == [256, 256]
|
|
||||||
assert config.critic_network_kwargs.activate_final is True
|
|
||||||
assert config.critic_network_kwargs.final_activation is None
|
|
||||||
|
|
||||||
# Actor network configuration
|
# Actor network configuration
|
||||||
assert config.actor_network_kwargs.hidden_dims == [256, 256]
|
assert config.actor_network_kwargs.hidden_dims == [256, 256]
|
||||||
assert config.actor_network_kwargs.activate_final is True
|
assert config.actor_network_kwargs.activate_final is True
|
||||||
@@ -135,7 +115,6 @@ def test_gaussian_actor_config_default_initialization():
|
|||||||
assert config.concurrency.learner == "threads"
|
assert config.concurrency.learner == "threads"
|
||||||
|
|
||||||
assert isinstance(config.actor_network_kwargs, ActorNetworkConfig)
|
assert isinstance(config.actor_network_kwargs, ActorNetworkConfig)
|
||||||
assert isinstance(config.critic_network_kwargs, CriticNetworkConfig)
|
|
||||||
assert isinstance(config.policy_kwargs, PolicyConfig)
|
assert isinstance(config.policy_kwargs, PolicyConfig)
|
||||||
assert isinstance(config.actor_learner_config, ActorLearnerConfig)
|
assert isinstance(config.actor_learner_config, ActorLearnerConfig)
|
||||||
assert isinstance(config.concurrency, ConcurrencyConfig)
|
assert isinstance(config.concurrency, ConcurrencyConfig)
|
||||||
@@ -178,15 +157,15 @@ def test_concurrency_config():
|
|||||||
def test_gaussian_actor_config_custom_initialization():
|
def test_gaussian_actor_config_custom_initialization():
|
||||||
config = GaussianActorConfig(
|
config = GaussianActorConfig(
|
||||||
device="cpu",
|
device="cpu",
|
||||||
discount=0.95,
|
latent_dim=128,
|
||||||
temperature_init=0.5,
|
state_encoder_hidden_dim=128,
|
||||||
num_critics=3,
|
num_discrete_actions=3,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert config.device == "cpu"
|
assert config.device == "cpu"
|
||||||
assert config.discount == 0.95
|
assert config.latent_dim == 128
|
||||||
assert config.temperature_init == 0.5
|
assert config.state_encoder_hidden_dim == 128
|
||||||
assert config.num_critics == 3
|
assert config.num_discrete_actions == 3
|
||||||
|
|
||||||
|
|
||||||
def test_validate_features():
|
def test_validate_features():
|
||||||
|
|||||||
@@ -404,19 +404,16 @@ def test_sac_training_with_discrete_critic():
|
|||||||
|
|
||||||
|
|
||||||
def test_sac_algorithm_target_entropy():
|
def test_sac_algorithm_target_entropy():
|
||||||
|
"""Target entropy is an SAC hyperparameter and lives on the algorithm."""
|
||||||
config = create_default_config(continuous_action_dim=10, state_dim=10)
|
config = create_default_config(continuous_action_dim=10, state_dim=10)
|
||||||
_, policy = _make_algorithm(config)
|
algorithm, _ = _make_algorithm(config)
|
||||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
|
||||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
|
||||||
assert algorithm.target_entropy == -5.0
|
assert algorithm.target_entropy == -5.0
|
||||||
|
|
||||||
|
|
||||||
def test_sac_algorithm_target_entropy_with_discrete_action():
|
def test_sac_algorithm_target_entropy_with_discrete_action():
|
||||||
config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
|
config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
|
||||||
config.num_discrete_actions = 5
|
config.num_discrete_actions = 5
|
||||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
algorithm, _ = _make_algorithm(config)
|
||||||
policy = GaussianActorPolicy(config=config)
|
|
||||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
|
||||||
assert algorithm.target_entropy == -3.5
|
assert algorithm.target_entropy == -3.5
|
||||||
|
|
||||||
|
|
||||||
@@ -435,8 +432,8 @@ def test_sac_algorithm_temperature():
|
|||||||
|
|
||||||
def test_sac_algorithm_update_target_network():
|
def test_sac_algorithm_update_target_network():
|
||||||
config = create_default_config(state_dim=10, continuous_action_dim=6)
|
config = create_default_config(state_dim=10, continuous_action_dim=6)
|
||||||
config.critic_target_update_weight = 1.0
|
|
||||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||||
|
algo_config.critic_target_update_weight = 1.0
|
||||||
policy = GaussianActorPolicy(config=config)
|
policy = GaussianActorPolicy(config=config)
|
||||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||||
|
|
||||||
@@ -454,9 +451,13 @@ def test_sac_algorithm_with_critics_number_of_heads(num_critics: int):
|
|||||||
action_dim = 10
|
action_dim = 10
|
||||||
state_dim = 10
|
state_dim = 10
|
||||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||||
config.num_critics = num_critics
|
|
||||||
|
|
||||||
algorithm, policy = _make_algorithm(config)
|
policy = GaussianActorPolicy(config=config)
|
||||||
|
policy.train()
|
||||||
|
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||||
|
algo_config.num_critics = num_critics
|
||||||
|
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||||
|
algorithm.make_optimizers_and_scheduler()
|
||||||
|
|
||||||
assert len(algorithm.critic_ensemble.critics) == num_critics
|
assert len(algorithm.critic_ensemble.critics) == num_critics
|
||||||
|
|
||||||
|
|||||||
@@ -327,7 +327,6 @@ def test_learner_algorithm_wiring():
|
|||||||
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
|
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
|
||||||
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
|
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
|
||||||
},
|
},
|
||||||
use_torch_compile=False,
|
|
||||||
)
|
)
|
||||||
sac_cfg.validate_features()
|
sac_cfg.validate_features()
|
||||||
|
|
||||||
@@ -412,7 +411,6 @@ def test_initial_and_periodic_weight_push_consistency():
|
|||||||
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
|
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
|
||||||
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
|
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
|
||||||
},
|
},
|
||||||
use_torch_compile=False,
|
|
||||||
)
|
)
|
||||||
sac_cfg.validate_features()
|
sac_cfg.validate_features()
|
||||||
|
|
||||||
@@ -450,7 +448,6 @@ def test_actor_side_algorithm_select_action_and_load_weights():
|
|||||||
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
|
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
|
||||||
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
|
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
|
||||||
},
|
},
|
||||||
use_torch_compile=False,
|
|
||||||
)
|
)
|
||||||
sac_cfg.validate_features()
|
sac_cfg.validate_features()
|
||||||
|
|
||||||
|
|||||||
@@ -44,8 +44,6 @@ def _make_sac_config(
|
|||||||
state_dim: int = 10,
|
state_dim: int = 10,
|
||||||
action_dim: int = 6,
|
action_dim: int = 6,
|
||||||
num_discrete_actions: int | None = None,
|
num_discrete_actions: int | None = None,
|
||||||
utd_ratio: int = 1,
|
|
||||||
policy_update_freq: int = 1,
|
|
||||||
with_images: bool = False,
|
with_images: bool = False,
|
||||||
) -> GaussianActorConfig:
|
) -> GaussianActorConfig:
|
||||||
config = GaussianActorConfig(
|
config = GaussianActorConfig(
|
||||||
@@ -55,10 +53,7 @@ def _make_sac_config(
|
|||||||
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
|
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
|
||||||
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
|
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
|
||||||
},
|
},
|
||||||
utd_ratio=utd_ratio,
|
|
||||||
policy_update_freq=policy_update_freq,
|
|
||||||
num_discrete_actions=num_discrete_actions,
|
num_discrete_actions=num_discrete_actions,
|
||||||
use_torch_compile=False,
|
|
||||||
)
|
)
|
||||||
if with_images:
|
if with_images:
|
||||||
config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
|
config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
|
||||||
@@ -83,14 +78,14 @@ def _make_algorithm(
|
|||||||
sac_cfg = _make_sac_config(
|
sac_cfg = _make_sac_config(
|
||||||
state_dim=state_dim,
|
state_dim=state_dim,
|
||||||
action_dim=action_dim,
|
action_dim=action_dim,
|
||||||
utd_ratio=utd_ratio,
|
|
||||||
policy_update_freq=policy_update_freq,
|
|
||||||
num_discrete_actions=num_discrete_actions,
|
num_discrete_actions=num_discrete_actions,
|
||||||
with_images=with_images,
|
with_images=with_images,
|
||||||
)
|
)
|
||||||
policy = GaussianActorPolicy(config=sac_cfg)
|
policy = GaussianActorPolicy(config=sac_cfg)
|
||||||
policy.train()
|
policy.train()
|
||||||
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
|
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
|
||||||
|
algo_config.utd_ratio = utd_ratio
|
||||||
|
algo_config.policy_update_freq = policy_update_freq
|
||||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||||
algorithm.make_optimizers_and_scheduler()
|
algorithm.make_optimizers_and_scheduler()
|
||||||
return algorithm, policy
|
return algorithm, policy
|
||||||
@@ -136,13 +131,16 @@ def test_sac_algorithm_config_registered():
|
|||||||
|
|
||||||
|
|
||||||
def test_sac_algorithm_config_from_policy_config():
|
def test_sac_algorithm_config_from_policy_config():
|
||||||
"""from_policy_config should copy algorithm hyperparameters from the policy config."""
|
"""from_policy_config embeds the policy config and uses SAC defaults."""
|
||||||
sac_cfg = _make_sac_config(utd_ratio=4, policy_update_freq=2)
|
sac_cfg = _make_sac_config()
|
||||||
algo_cfg = SACAlgorithmConfig.from_policy_config(sac_cfg)
|
algo_cfg = SACAlgorithmConfig.from_policy_config(sac_cfg)
|
||||||
assert algo_cfg.sac_config is sac_cfg
|
assert algo_cfg.policy_config is sac_cfg
|
||||||
assert algo_cfg.utd_ratio == 4
|
assert algo_cfg.discrete_critic_network_kwargs is sac_cfg.discrete_critic_network_kwargs
|
||||||
assert algo_cfg.policy_update_freq == 2
|
# Defaults come from SACAlgorithmConfig, not from the policy config.
|
||||||
assert algo_cfg.grad_clip_norm == sac_cfg.grad_clip_norm
|
assert algo_cfg.utd_ratio == 1
|
||||||
|
assert algo_cfg.policy_update_freq == 1
|
||||||
|
assert algo_cfg.grad_clip_norm == 40.0
|
||||||
|
assert algo_cfg.actor_lr == 3e-4
|
||||||
|
|
||||||
|
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
@@ -377,12 +375,14 @@ def test_actor_side_no_optimizers():
|
|||||||
assert algorithm.optimizers == {}
|
assert algorithm.optimizers == {}
|
||||||
|
|
||||||
|
|
||||||
def test_make_algorithm_copies_config_fields():
|
def test_make_algorithm_uses_sac_algorithm_defaults():
|
||||||
sac_cfg = _make_sac_config(utd_ratio=5, policy_update_freq=3)
|
"""make_algorithm populates SACAlgorithmConfig with its own defaults."""
|
||||||
|
sac_cfg = _make_sac_config()
|
||||||
policy = GaussianActorPolicy(config=sac_cfg)
|
policy = GaussianActorPolicy(config=sac_cfg)
|
||||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||||
assert algorithm.config.utd_ratio == 5
|
assert algorithm.config.utd_ratio == 1
|
||||||
assert algorithm.config.policy_update_freq == 3
|
assert algorithm.config.policy_update_freq == 1
|
||||||
|
assert algorithm.config.grad_clip_norm == 40.0
|
||||||
|
|
||||||
|
|
||||||
def test_make_algorithm_raises_for_unknown_type():
|
def test_make_algorithm_raises_for_unknown_type():
|
||||||
@@ -431,10 +431,10 @@ def test_load_weights_round_trip_with_discrete_critic():
|
|||||||
|
|
||||||
assert "discrete_critic" in weights
|
assert "discrete_critic" in weights
|
||||||
assert len(weights["discrete_critic"]) > 0
|
assert len(weights["discrete_critic"]) > 0
|
||||||
dst_dc_state_dict = algo_dst.discrete_critic.state_dict()
|
dst_discrete_critic_state_dict = algo_dst.policy.discrete_critic.state_dict()
|
||||||
for key, tensor in weights["discrete_critic"].items():
|
for key, tensor in weights["discrete_critic"].items():
|
||||||
assert torch.equal(
|
assert torch.equal(
|
||||||
dst_dc_state_dict[key].cpu(),
|
dst_discrete_critic_state_dict[key].cpu(),
|
||||||
tensor.cpu(),
|
tensor.cpu(),
|
||||||
), f"Discrete critic param '{key}' mismatch after load_weights"
|
), f"Discrete critic param '{key}' mismatch after load_weights"
|
||||||
|
|
||||||
@@ -446,6 +446,47 @@ def test_load_weights_ignores_missing_discrete_critic():
|
|||||||
algorithm.load_weights(weights, device="cpu")
|
algorithm.load_weights(weights, device="cpu")
|
||||||
|
|
||||||
|
|
||||||
|
def test_actor_side_weight_sync_with_discrete_critic():
|
||||||
|
"""End-to-end: learner ``algorithm.get_weights()`` -> actor ``policy.load_actor_weights()``."""
|
||||||
|
# Learner side: train the algorithm so its weights diverge from init.
|
||||||
|
algo_src, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
|
||||||
|
algo_src.update(_batch_iterator(action_dim=7))
|
||||||
|
weights = algo_src.get_weights()
|
||||||
|
|
||||||
|
# Actor side: fresh policy, no algorithm/optimizer.
|
||||||
|
sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6)
|
||||||
|
policy_actor = GaussianActorPolicy(config=sac_cfg)
|
||||||
|
|
||||||
|
# Snapshot initial actor state for the "did it change?" assertion below.
|
||||||
|
initial_discrete_critic_state_dict = {
|
||||||
|
k: v.clone() for k, v in policy_actor.discrete_critic.state_dict().items()
|
||||||
|
}
|
||||||
|
|
||||||
|
policy_actor.load_actor_weights(weights, device="cpu")
|
||||||
|
|
||||||
|
# Actor weights match the learner's exported actor state dict.
|
||||||
|
actor_state_dict = policy_actor.actor.state_dict()
|
||||||
|
for key, tensor in weights["policy"].items():
|
||||||
|
assert torch.equal(actor_state_dict[key].cpu(), tensor.cpu()), (
|
||||||
|
f"Actor param '{key}' not synced by load_actor_weights"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Discrete critic weights match the learner's exported discrete critic.
|
||||||
|
discrete_critic_state_dict = policy_actor.discrete_critic.state_dict()
|
||||||
|
for key, tensor in weights["discrete_critic"].items():
|
||||||
|
assert torch.equal(discrete_critic_state_dict[key].cpu(), tensor.cpu()), (
|
||||||
|
f"Discrete critic param '{key}' not synced by load_actor_weights"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sanity: the discrete critic actually changed (otherwise the sync is trivial).
|
||||||
|
changed = any(
|
||||||
|
not torch.equal(initial_discrete_critic_state_dict[key], discrete_critic_state_dict[key])
|
||||||
|
for key in initial_discrete_critic_state_dict
|
||||||
|
if key in discrete_critic_state_dict
|
||||||
|
)
|
||||||
|
assert changed, "Discrete critic weights did not change between init and after sync"
|
||||||
|
|
||||||
|
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
# TrainingStats generic losses dict
|
# TrainingStats generic losses dict
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
@@ -468,8 +509,9 @@ def test_training_stats_generic_losses():
|
|||||||
|
|
||||||
def test_build_algorithm_via_config():
|
def test_build_algorithm_via_config():
|
||||||
"""SACAlgorithmConfig.build_algorithm should produce a working SACAlgorithm."""
|
"""SACAlgorithmConfig.build_algorithm should produce a working SACAlgorithm."""
|
||||||
sac_cfg = _make_sac_config(utd_ratio=2)
|
sac_cfg = _make_sac_config()
|
||||||
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
|
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
|
||||||
|
algo_config.utd_ratio = 2
|
||||||
policy = GaussianActorPolicy(config=sac_cfg)
|
policy = GaussianActorPolicy(config=sac_cfg)
|
||||||
|
|
||||||
algorithm = algo_config.build_algorithm(policy)
|
algorithm = algo_config.build_algorithm(policy)
|
||||||
|
|||||||
Reference in New Issue
Block a user