From 1f5487eea8ddeaae2f791ee7735fbd2be8d1c63b Mon Sep 17 00:00:00 2001
From: Khalil Meftah
Date: Wed, 11 Mar 2026 16:49:14 +0100
Subject: [PATCH] refactor: decouple policy from algorithm

---
 examples/tutorial/rl/hilserl_example.py       |  31 +-
 src/lerobot/policies/sac/modeling_sac.py      | 398 ++---------------
 src/lerobot/rl/actor.py                       |  19 +-
 src/lerobot/rl/algorithms/base.py             |  13 +-
 src/lerobot/rl/algorithms/sac.py              | 262 ------------
 src/lerobot/rl/algorithms/sac/__init__.py     |  18 +
 .../rl/algorithms/sac/configuration_sac.py    |  81 ++++
 .../rl/algorithms/sac/sac_algorithm.py        | 400 ++++++++++++++++++
 src/lerobot/utils/train_utils.py              |   1 +
 tests/policies/test_sac_policy.py             | 395 ++++++++---------
 tests/rl/test_actor_learner.py                |   6 +-
 tests/rl/test_sac_algorithm.py                |  53 +--
 12 files changed, 769 insertions(+), 908 deletions(-)
 delete mode 100644 src/lerobot/rl/algorithms/sac.py
 create mode 100644 src/lerobot/rl/algorithms/sac/__init__.py
 create mode 100644 src/lerobot/rl/algorithms/sac/configuration_sac.py
 create mode 100644 src/lerobot/rl/algorithms/sac/sac_algorithm.py

diff --git a/examples/tutorial/rl/hilserl_example.py b/examples/tutorial/rl/hilserl_example.py
index 980ac7985..58b84722b 100644
--- a/examples/tutorial/rl/hilserl_example.py
+++ b/examples/tutorial/rl/hilserl_example.py
@@ -4,7 +4,6 @@
 from pathlib import Path
 from queue import Empty, Full
 
 import torch
-import torch.optim as optim
 
 from lerobot.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.datasets.utils import hw_to_dataset_features
@@ -12,6 +11,7 @@ from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
 from lerobot.policies.sac.configuration_sac import SACConfig
 from lerobot.policies.sac.modeling_sac import SACPolicy
 from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
+from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
 from lerobot.rl.buffer import ReplayBuffer
 from lerobot.rl.gym_manipulator import make_robot_env
 from lerobot.robots.so_follower import SO100FollowerConfig
@@ -40,8 +40,9 @@ def run_learner(
     policy_learner.train()
     policy_learner.to(device)
 
-    # Create Adam optimizer from scratch - simple and clean
-    optimizer = optim.Adam(policy_learner.parameters(), lr=lr)
+    algo_config = SACAlgorithmConfig.from_policy_config(policy_learner.config)
+    algorithm = SACAlgorithm(policy=policy_learner, config=algo_config)
+    algorithm.make_optimizers()
 
     print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
     print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")
@@ -83,24 +84,26 @@ def run_learner(
             else:
                 batch[key] = online_batch[key]
 
-        loss, _ = policy_learner.forward(batch)
+        # ``algorithm.update()`` pulls ``utd_ratio`` batches from its iterator,
+        # so wrap the already-sampled batch in an endless generator.
+        def batch_iter(b=batch):
+            while True:
+                yield b
 
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
+        stats = algorithm.update(batch_iter())
 
         training_step += 1
 
         if training_step % LOG_EVERY == 0:
+            log_dict = stats.to_log_dict()
             print(
-                f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, "
+                f"[LEARNER] Training step {training_step}, "
+                f"critic_loss: {log_dict.get('critic', float('nan')):.4f}, "
                 f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
             )
 
         # Send updated parameters to actor every 10 training steps
         if training_step % SEND_EVERY == 0:
             try:
-                state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()}
-                parameters_queue.put_nowait(state_dict)
+                weights = algorithm.get_weights()
+                parameters_queue.put_nowait(weights)
                 print("[LEARNER] Sent updated parameters to actor")
             except Full:
                 # Missing write due to queue not being consumed (should happen rarely)
@@ -144,15 +147,15 @@ def run_actor(
 
         while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
             try:
-                new_params = parameters_queue.get_nowait()
-                policy_actor.load_state_dict(new_params)
+                new_weights = parameters_queue.get_nowait()
+                policy_actor.load_state_dict(new_weights)
                 print("[ACTOR] Updated policy parameters from learner")
             except Empty:
                 # No new updated parameters available from learner, waiting
                 pass
 
-            # Get action from policy
+            # Get action from policy (returns full action: continuous + discrete)
             policy_obs = make_policy_obs(obs, device=device)
-            action_tensor = policy_actor.select_action(policy_obs)  # predicts a single action
+            action_tensor = policy_actor.select_action(policy_obs)
             action = action_tensor.squeeze(0).cpu().numpy()
 
             # Step environment
diff --git a/src/lerobot/policies/sac/modeling_sac.py b/src/lerobot/policies/sac/modeling_sac.py
index d5dd71a48..67da85b31 100644
--- a/src/lerobot/policies/sac/modeling_sac.py
+++ b/src/lerobot/policies/sac/modeling_sac.py
@@ -15,16 +15,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import math
 from collections.abc import Callable
 from dataclasses import asdict
-from typing import Literal
 
-import einops
-import numpy as np
 import torch
 import torch.nn as nn
-import torch.nn.functional as F  # noqa: N812
 from torch import Tensor
 from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
 
@@ -52,20 +47,13 @@ class SACPolicy(
         # Determine action dimension and initialize all components
         continuous_action_dim = config.output_features[ACTION].shape[0]
 
-        self._init_encoders()
-        self._init_critics(continuous_action_dim)
+        self.encoder = SACObservationEncoder(config)
         self._init_actor(continuous_action_dim)
-        self._init_temperature()
+        self._init_discrete_critic()
 
     def get_optim_params(self) -> dict:
         optim_params = {
-            "actor": [
-                p
-                for n, p in self.actor.named_parameters()
-                if not n.startswith("encoder") or not self.shared_encoder
-            ],
-            "critic": self.critic_ensemble.parameters(),
-            "temperature": self.log_alpha,
+            "actor": list(self.actor.parameters()),
         }
         if self.config.num_discrete_actions is not None:
             optim_params["discrete_critic"] = self.discrete_critic.parameters()
@@ -83,10 +71,9 @@ class SACPolicy(
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select action for inference/evaluation"""
         observations_features = None
-        if self.shared_encoder and self.actor.encoder.has_images:
-            observations_features = self.actor.encoder.get_cached_image_features(batch)
+        if self.encoder.has_images:
+            observations_features = self.encoder.get_cached_image_features(batch)
 
         actions, _, _ = self.actor(batch, observations_features)
 
@@ -97,372 +84,35 @@ class SACPolicy(
 
         return actions
 
-    def critic_forward(
-        self,
-        observations: dict[str, Tensor],
-        actions: Tensor,
-        use_target: bool = False,
-        observation_features: Tensor | None = None,
-    ) -> Tensor:
-        """Forward pass through a critic network ensemble
-
-        Args:
-            observations: Dictionary of observations
-            actions: Action tensor
-            use_target: If True, use target critics, otherwise use ensemble critics
-
-        Returns:
-            Tensor of Q-values from all critics
-        """
-
-        critics = self.critic_target if use_target else self.critic_ensemble
-        q_values = critics(observations, actions, observation_features)
-        return q_values
-
-    def discrete_critic_forward(
-        self, observations, use_target=False, observation_features=None
-    )
-> torch.Tensor: - """Forward pass through a discrete critic network - - Args: - observations: Dictionary of observations - use_target: If True, use target critics, otherwise use ensemble critics - observation_features: Optional pre-computed observation features to avoid recomputing encoder output - - Returns: - Tensor of Q-values from the discrete critic network - """ - discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic - q_values = discrete_critic(observations, observation_features) - return q_values - def forward( self, batch: dict[str, Tensor | dict[str, Tensor]], - model: Literal["actor", "critic", "temperature", "discrete_critic"] = "critic", ) -> dict[str, Tensor]: - """Compute the loss for the given model + """Actor forward pass.""" + observations = batch.get("state", batch) + observation_features = batch.get("observation_feature") if isinstance(batch, dict) else None + actions, log_probs, means = self.actor(observations, observation_features) + return {"action": actions, "log_prob": log_probs, "action_mean": means} - Args: - batch: Dictionary containing: - - action: Action tensor - - reward: Reward tensor - - state: Observations tensor dict - - next_state: Next observations tensor dict - - done: Done mask tensor - - observation_feature: Optional pre-computed observation features - - next_observation_feature: Optional pre-computed next observation features - model: Which model to compute the loss for ("actor", "critic", "discrete_critic", or "temperature") - - Returns: - The computed loss tensor - """ - # Extract common components from batch - actions: Tensor = batch[ACTION] - observations: dict[str, Tensor] = batch["state"] - observation_features: Tensor = batch.get("observation_feature") - - if model == "critic": - # Extract critic-specific components - rewards: Tensor = batch["reward"] - next_observations: dict[str, Tensor] = batch["next_state"] - done: Tensor = batch["done"] - next_observation_features: Tensor = batch.get("next_observation_feature") - - loss_critic = self.compute_loss_critic( - observations=observations, - actions=actions, - rewards=rewards, - next_observations=next_observations, - done=done, - observation_features=observation_features, - next_observation_features=next_observation_features, - ) - - return {"loss_critic": loss_critic} - - if model == "discrete_critic" and self.config.num_discrete_actions is not None: - # Extract critic-specific components - rewards: Tensor = batch["reward"] - next_observations: dict[str, Tensor] = batch["next_state"] - done: Tensor = batch["done"] - next_observation_features: Tensor = batch.get("next_observation_feature") - complementary_info = batch.get("complementary_info") - loss_discrete_critic = self.compute_loss_discrete_critic( - observations=observations, - actions=actions, - rewards=rewards, - next_observations=next_observations, - done=done, - observation_features=observation_features, - next_observation_features=next_observation_features, - complementary_info=complementary_info, - ) - return {"loss_discrete_critic": loss_discrete_critic} - if model == "actor": - return { - "loss_actor": self.compute_loss_actor( - observations=observations, - observation_features=observation_features, - ) - } - - if model == "temperature": - return { - "loss_temperature": self.compute_loss_temperature( - observations=observations, - observation_features=observation_features, - ) - } - - raise ValueError(f"Unknown model type: {model}") - - def update_target_networks(self): - """Update target networks 
with exponential moving average""" - for target_param, param in zip( - self.critic_target.parameters(), - self.critic_ensemble.parameters(), - strict=True, - ): - target_param.data.copy_( - param.data * self.config.critic_target_update_weight - + target_param.data * (1.0 - self.config.critic_target_update_weight) - ) - if self.config.num_discrete_actions is not None: - for target_param, param in zip( - self.discrete_critic_target.parameters(), - self.discrete_critic.parameters(), - strict=True, - ): - target_param.data.copy_( - param.data * self.config.critic_target_update_weight - + target_param.data * (1.0 - self.config.critic_target_update_weight) - ) - - @property - def temperature(self) -> float: - """Return the current temperature value, always in sync with log_alpha.""" - return self.log_alpha.exp().item() - - def compute_loss_critic( - self, - observations, - actions, - rewards, - next_observations, - done, - observation_features: Tensor | None = None, - next_observation_features: Tensor | None = None, - ) -> Tensor: - with torch.no_grad(): - next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features) - - # 2- compute q targets - q_targets = self.critic_forward( - observations=next_observations, - actions=next_action_preds, - use_target=True, - observation_features=next_observation_features, - ) - - # subsample critics to prevent overfitting if use high UTD (update to date) - # TODO: Get indices before forward pass to avoid unnecessary computation - if self.config.num_subsample_critics is not None: - indices = torch.randperm(self.config.num_critics) - indices = indices[: self.config.num_subsample_critics] - q_targets = q_targets[indices] - - # critics subsample size - min_q, _ = q_targets.min(dim=0) # Get values from min operation - if self.config.use_backup_entropy: - min_q = min_q - (self.temperature * next_log_probs) - - td_target = rewards + (1 - done) * self.config.discount * min_q - - # 3- compute predicted qs - if self.config.num_discrete_actions is not None: - # NOTE: We only want to keep the continuous action part - # In the buffer we have the full action space (continuous + discrete) - # We need to split them before concatenating them in the critic forward - actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX] - q_preds = self.critic_forward( - observations=observations, - actions=actions, - use_target=False, - observation_features=observation_features, - ) - - # 4- Calculate loss - # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. 
- td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]) - # You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up - critics_loss = ( - F.mse_loss( - input=q_preds, - target=td_target_duplicate, - reduction="none", - ).mean(dim=1) - ).sum() - return critics_loss - - def compute_loss_discrete_critic( - self, - observations, - actions, - rewards, - next_observations, - done, - observation_features=None, - next_observation_features=None, - complementary_info=None, - ): - # NOTE: We only want to keep the discrete action part - # In the buffer we have the full action space (continuous + discrete) - # We need to split them before concatenating them in the critic forward - actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone() - actions_discrete = torch.round(actions_discrete) - actions_discrete = actions_discrete.long() - - discrete_penalties: Tensor | None = None - if complementary_info is not None: - discrete_penalties: Tensor | None = complementary_info.get("discrete_penalty") - - with torch.no_grad(): - # For DQN, select actions using online network, evaluate with target network - next_discrete_qs = self.discrete_critic_forward( - next_observations, use_target=False, observation_features=next_observation_features - ) - best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True) - - # Get target Q-values from target network - target_next_discrete_qs = self.discrete_critic_forward( - observations=next_observations, - use_target=True, - observation_features=next_observation_features, - ) - - # Use gather to select Q-values for best actions - target_next_discrete_q = torch.gather( - target_next_discrete_qs, dim=1, index=best_next_discrete_action - ).squeeze(-1) - - # Compute target Q-value with Bellman equation - rewards_discrete = rewards - if discrete_penalties is not None: - rewards_discrete = rewards + discrete_penalties - target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q - - # Get predicted Q-values for current observations - predicted_discrete_qs = self.discrete_critic_forward( - observations=observations, use_target=False, observation_features=observation_features - ) - - # Use gather to select Q-values for taken actions - predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1) - - # Compute MSE loss between predicted and target Q-values - discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q) - return discrete_critic_loss - - def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor: - """Compute the temperature loss""" - # calculate temperature loss - with torch.no_grad(): - _, log_probs, _ = self.actor(observations, observation_features) - temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean() - return temperature_loss - - def compute_loss_actor( - self, - observations, - observation_features: Tensor | None = None, - ) -> Tensor: - actions_pi, log_probs, _ = self.actor(observations, observation_features) - - q_preds = self.critic_forward( - observations=observations, - actions=actions_pi, - use_target=False, - observation_features=observation_features, - ) - min_q_preds = q_preds.min(dim=0)[0] - - actor_loss = ((self.temperature * log_probs) - min_q_preds).mean() - return actor_loss - - def _init_encoders(self): - """Initialize shared or separate encoders for actor and 
critic.""" - self.shared_encoder = self.config.shared_encoder - self.encoder_critic = SACObservationEncoder(self.config) - self.encoder_actor = ( - self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config) - ) - - def _init_critics(self, continuous_action_dim): - """Build critic ensemble, targets, and optional discrete critic.""" - heads = [ - CriticHead( - input_dim=self.encoder_critic.output_dim + continuous_action_dim, - **asdict(self.config.critic_network_kwargs), - ) - for _ in range(self.config.num_critics) - ] - self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads) - target_heads = [ - CriticHead( - input_dim=self.encoder_critic.output_dim + continuous_action_dim, - **asdict(self.config.critic_network_kwargs), - ) - for _ in range(self.config.num_critics) - ] - self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads) - self.critic_target.load_state_dict(self.critic_ensemble.state_dict()) - - if self.config.use_torch_compile: - self.critic_ensemble = torch.compile(self.critic_ensemble) - self.critic_target = torch.compile(self.critic_target) - - if self.config.num_discrete_actions is not None: - self._init_discrete_critics() - - def _init_discrete_critics(self): - """Build discrete discrete critic ensemble and target networks.""" - self.discrete_critic = DiscreteCritic( - encoder=self.encoder_critic, - input_dim=self.encoder_critic.output_dim, - output_dim=self.config.num_discrete_actions, - **asdict(self.config.discrete_critic_network_kwargs), - ) - self.discrete_critic_target = DiscreteCritic( - encoder=self.encoder_critic, - input_dim=self.encoder_critic.output_dim, - output_dim=self.config.num_discrete_actions, - **asdict(self.config.discrete_critic_network_kwargs), - ) - - # TODO: (maractingi, azouitine) Compile the discrete critic - self.discrete_critic_target.load_state_dict(self.discrete_critic.state_dict()) - - def _init_actor(self, continuous_action_dim): - """Initialize policy actor network and default target entropy.""" - # NOTE: The actor select only the continuous action part + def _init_actor(self, continuous_action_dim: int) -> None: self.actor = Policy( - encoder=self.encoder_actor, - network=MLP(input_dim=self.encoder_actor.output_dim, **asdict(self.config.actor_network_kwargs)), + encoder=self.encoder, + network=MLP(input_dim=self.encoder.output_dim, **asdict(self.config.actor_network_kwargs)), action_dim=continuous_action_dim, - encoder_is_shared=self.shared_encoder, + encoder_is_shared=False, **asdict(self.config.policy_kwargs), ) - self.target_entropy = self.config.target_entropy - if self.target_entropy is None: - dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0) - self.target_entropy = -np.prod(dim) / 2 - - def _init_temperature(self) -> None: - """Set up temperature parameter (log_alpha).""" - temp_init = self.config.temperature_init - self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)])) + def _init_discrete_critic(self) -> None: + if self.config.num_discrete_actions is None: + self.discrete_critic = None + return + self.discrete_critic = DiscreteCritic( + encoder=self.encoder, + input_dim=self.encoder.output_dim, + output_dim=self.config.num_discrete_actions, + **asdict(self.config.discrete_critic_network_kwargs), + ) class SACObservationEncoder(nn.Module): diff --git a/src/lerobot/rl/actor.py b/src/lerobot/rl/actor.py index 03d48e775..df01108cb 100644 --- a/src/lerobot/rl/actor.py +++ b/src/lerobot/rl/actor.py @@ -61,8 
+61,8 @@ from lerobot.cameras import opencv # noqa: F401 from lerobot.configs import parser from lerobot.configs.train import TrainRLServerPipelineConfig from lerobot.policies.factory import make_policy, make_pre_post_processors +from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.processor import TransitionKey -from lerobot.rl.algorithms import RLAlgorithm, make_algorithm from lerobot.rl.process import ProcessSignalHandler from lerobot.rl.queue import get_last_item_from_queue from lerobot.robots import so_follower # noqa: F401 @@ -81,6 +81,7 @@ from lerobot.utils.random_utils import set_seed from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.transition import ( Transition, + move_state_dict_to_device, move_transition_to_device, ) from lerobot.utils.utils import ( @@ -247,9 +248,6 @@ def act_with_policy( logging.info("make_policy") - ### Instantiate the policy in both the actor and learner processes - ### To avoid sending a SACPolicy object through the port, we create a policy instance - ### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters policy = make_policy( cfg=cfg.policy, env_cfg=cfg.env, @@ -257,8 +255,6 @@ def act_with_policy( policy = policy.eval() assert isinstance(policy, nn.Module) - algorithm = make_algorithm(policy=policy, policy_cfg=cfg.policy, algorithm_name=cfg.algorithm) - # Build policy pre/post processors for observation normalization and action unnormalization processor_kwargs = {} postprocessor_kwargs = {} @@ -324,7 +320,7 @@ def act_with_policy( # Time policy inference and check if it meets FPS requirement with policy_timer: - action = algorithm.select_action(observation_for_inference) + action = policy.select_action(observation_for_inference) policy_fps = policy_timer.fps_last # Postprocess action (unnormalization, move to cpu). 
@@ -397,7 +393,7 @@ def act_with_policy( if done or truncated: logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}") - update_policy_parameters(algorithm=algorithm, parameters_queue=parameters_queue, device=device) + update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device) if len(list_transition_to_send_to_learner) > 0: push_transitions_to_transport_queue( @@ -695,8 +691,8 @@ def interactions_stream( # Policy functions -def update_policy_parameters(algorithm: RLAlgorithm, parameters_queue: Queue, device): - """Load the latest weights from the learner via the algorithm's ``load_weights`` API.""" +def update_policy_parameters(policy: PreTrainedPolicy, parameters_queue: Queue, device): + """Load the latest policy weights from the learner.""" bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False) if bytes_state_dict is not None: logging.info("[ACTOR] Load new parameters from Learner.") @@ -711,7 +707,8 @@ def update_policy_parameters(algorithm: RLAlgorithm, parameters_queue: Queue, de # - Skip encoder params entirely when freeze_vision_encoder=True # - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic) # Load actor state dict - algorithm.load_weights(state_dicts, device=device) + state_dicts = move_state_dict_to_device(state_dicts, device=device) + policy.load_state_dict(state_dicts) # Utilities functions diff --git a/src/lerobot/rl/algorithms/base.py b/src/lerobot/rl/algorithms/base.py index 839d2288f..36aacde7b 100644 --- a/src/lerobot/rl/algorithms/base.py +++ b/src/lerobot/rl/algorithms/base.py @@ -80,15 +80,6 @@ class RLAlgorithmConfig(draccus.ChoiceRegistry): class RLAlgorithm(abc.ABC): """Base for all RL algorithms.""" - @abc.abstractmethod - def select_action(self, observation: dict[str, Tensor]) -> Tensor: - """Select action(s) for rollout. - - Single-step policies (e.g. SAC) return shape ``(action_dim,)``; - chunking policies (e.g. VLA) return ``(chunk_size, action_dim)``. - """ - ... - @abc.abstractmethod def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats: """One complete training step. @@ -145,12 +136,12 @@ class RLAlgorithm(abc.ABC): self._optimization_step = int(value) def get_weights(self) -> dict[str, Any]: - """State-dict(s) to push to actors.""" + """Policy state-dict to push to actors.""" return {} @abc.abstractmethod def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None: - """Load state-dict(s) received from the learner (inverse of ``get_weights``).""" + """Load policy state-dict received from the learner (inverse of ``get_weights``).""" @torch.no_grad() def get_observation_features( diff --git a/src/lerobot/rl/algorithms/sac.py b/src/lerobot/rl/algorithms/sac.py deleted file mode 100644 index c16ae48f6..000000000 --- a/src/lerobot/rl/algorithms/sac.py +++ /dev/null @@ -1,262 +0,0 @@ -# Copyright 2026 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""SAC (Soft Actor-Critic) algorithm. - -This module encapsulates all SAC-specific training logic (critic, actor, -temperature, and discrete-critic updates) behind the ``RLAlgorithm`` interface. -""" - -from __future__ import annotations - -from collections.abc import Iterator -from dataclasses import dataclass -from typing import Any - -import torch -from torch import Tensor -from torch.optim import Optimizer - -from lerobot.policies.sac.modeling_sac import SACPolicy -from lerobot.rl.algorithms.base import ( - BatchType, - RLAlgorithm, - RLAlgorithmConfig, - TrainingStats, -) -from lerobot.utils.constants import ACTION -from lerobot.utils.transition import move_state_dict_to_device - - -@RLAlgorithmConfig.register_subclass("sac") -@dataclass -class SACAlgorithmConfig(RLAlgorithmConfig): - """SAC-specific hyper-parameters that control the update loop.""" - - utd_ratio: int = 1 - policy_update_freq: int = 1 - clip_grad_norm: float = 40.0 - actor_lr: float = 3e-4 - critic_lr: float = 3e-4 - - @classmethod - def from_policy_config(cls, policy_cfg) -> SACAlgorithmConfig: - """Build from an existing ``SACConfig`` (cfg.policy) for backwards compat.""" - return cls( - utd_ratio=policy_cfg.utd_ratio, - policy_update_freq=policy_cfg.policy_update_freq, - clip_grad_norm=policy_cfg.grad_clip_norm, - actor_lr=policy_cfg.actor_lr, - critic_lr=policy_cfg.critic_lr, - ) - - def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm: - return SACAlgorithm(policy=policy, config=self) - - -class SACAlgorithm(RLAlgorithm): - """Soft Actor-Critic with optional discrete-critic head. - - Owns the ``SACPolicy`` and its optimizers. - """ - - def __init__( - self, - policy: SACPolicy, - config: SACAlgorithmConfig, - ): - self.policy = policy - self.config = config - self.optimizers: dict[str, Optimizer] = {} - self._optimization_step: int = 0 - - @torch.no_grad() - def select_action(self, observation: dict[str, Tensor]) -> Tensor: - return self.policy.select_action(observation).squeeze(0) - - def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats: - """Run one full SAC update with UTD critic warm-up. - - Pulls ``utd_ratio`` batches from ``batch_iterator``. The first - ``utd_ratio - 1`` batches are used for critic-only warm-up steps; - the last batch drives the full update (critic + actor + temperature). 
- """ - for _ in range(self.config.utd_ratio - 1): - batch = next(batch_iterator) - forward_batch = self._prepare_forward_batch(batch) - - critic_output = self.policy.forward(forward_batch, model="critic") - loss_critic = critic_output["loss_critic"] - self.optimizers["critic"].zero_grad() - loss_critic.backward() - torch.nn.utils.clip_grad_norm_( - self.policy.critic_ensemble.parameters(), - max_norm=self.config.clip_grad_norm, - ).item() - self.optimizers["critic"].step() - - if self.policy.config.num_discrete_actions is not None: - discrete_critic_output = self.policy.forward(forward_batch, model="discrete_critic") - loss_discrete = discrete_critic_output["loss_discrete_critic"] - self.optimizers["discrete_critic"].zero_grad() - loss_discrete.backward() - torch.nn.utils.clip_grad_norm_( - self.policy.discrete_critic.parameters(), - max_norm=self.config.clip_grad_norm, - ).item() - self.optimizers["discrete_critic"].step() - self.policy.update_target_networks() - - batch = next(batch_iterator) - forward_batch = self._prepare_forward_batch(batch) - - critic_output = self.policy.forward(forward_batch, model="critic") - loss_critic = critic_output["loss_critic"] - self.optimizers["critic"].zero_grad() - loss_critic.backward() - critic_grad_norm = torch.nn.utils.clip_grad_norm_( - self.policy.critic_ensemble.parameters(), - max_norm=self.config.clip_grad_norm, - ).item() - self.optimizers["critic"].step() - - critic_loss_val = loss_critic.item() - stats = TrainingStats( - losses={"critic": critic_loss_val}, - grad_norms={"critic": critic_grad_norm}, - ) - - if self.policy.config.num_discrete_actions is not None: - discrete_critic_output = self.policy.forward(forward_batch, model="discrete_critic") - loss_discrete = discrete_critic_output["loss_discrete_critic"] - self.optimizers["discrete_critic"].zero_grad() - loss_discrete.backward() - dc_grad = torch.nn.utils.clip_grad_norm_( - self.policy.discrete_critic.parameters(), - max_norm=self.config.clip_grad_norm, - ).item() - self.optimizers["discrete_critic"].step() - dc_loss_val = loss_discrete.item() - stats.losses["discrete_critic"] = dc_loss_val - stats.grad_norms["discrete_critic"] = dc_grad - - if self._optimization_step % self.config.policy_update_freq == 0: - for _ in range(self.config.policy_update_freq): - actor_output = self.policy.forward(forward_batch, model="actor") - actor_loss = actor_output["loss_actor"] - self.optimizers["actor"].zero_grad() - actor_loss.backward() - actor_grad = torch.nn.utils.clip_grad_norm_( - self.policy.actor.parameters(), - max_norm=self.config.clip_grad_norm, - ).item() - self.optimizers["actor"].step() - - temperature_output = self.policy.forward(forward_batch, model="temperature") - temp_loss = temperature_output["loss_temperature"] - self.optimizers["temperature"].zero_grad() - temp_loss.backward() - temp_grad = torch.nn.utils.clip_grad_norm_( - [self.policy.log_alpha], - max_norm=self.config.clip_grad_norm, - ).item() - self.optimizers["temperature"].step() - - actor_loss_val = actor_loss.item() - temp_loss_val = temp_loss.item() - stats.losses["actor"] = actor_loss_val - stats.losses["temperature"] = temp_loss_val - stats.grad_norms["actor"] = actor_grad - stats.grad_norms["temperature"] = temp_grad - stats.extra["temperature"] = self.policy.temperature - - self.policy.update_target_networks() - - self._optimization_step += 1 - return stats - - def _prepare_forward_batch(self, batch: BatchType) -> dict[str, Any]: - """Build the dict expected by ``SACPolicy.forward()`` from a batch.""" - 
observations = batch["state"] - next_observations = batch["next_state"] - - observation_features, next_observation_features = self.get_observation_features( - observations, next_observations - ) - forward_batch: dict[str, Any] = { - ACTION: batch[ACTION], - "reward": batch["reward"], - "state": observations, - "next_state": next_observations, - "done": batch["done"], - "observation_feature": observation_features, - "next_observation_feature": next_observation_features, - } - if "complementary_info" in batch: - forward_batch["complementary_info"] = batch["complementary_info"] - return forward_batch - - def make_optimizers(self) -> dict[str, Optimizer]: - """Create Adam optimizers for the SAC components and store them.""" - actor_params = [ - p - for n, p in self.policy.actor.named_parameters() - if not self.policy.config.shared_encoder or not n.startswith("encoder") - ] - self.optimizers = { - "actor": torch.optim.Adam(actor_params, lr=self.config.actor_lr), - "critic": torch.optim.Adam(self.policy.critic_ensemble.parameters(), lr=self.config.critic_lr), - "temperature": torch.optim.Adam([self.policy.log_alpha], lr=self.config.critic_lr), - } - if self.policy.config.num_discrete_actions is not None: - self.optimizers["discrete_critic"] = torch.optim.Adam( - self.policy.discrete_critic.parameters(), lr=self.config.critic_lr - ) - return self.optimizers - - def get_optimizers(self) -> dict[str, Optimizer]: - return self.optimizers - - def get_weights(self) -> dict[str, Any]: - """State-dicts to push to the actor process.""" - out: dict[str, Any] = { - "policy": move_state_dict_to_device(self.policy.actor.state_dict(), device="cpu"), - } - if hasattr(self.policy, "discrete_critic") and self.policy.discrete_critic is not None: - out["discrete_critic"] = move_state_dict_to_device( - self.policy.discrete_critic.state_dict(), device="cpu" - ) - return out - - def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None: - """Load state-dict(s) received from the learner (inverse of ``get_weights``).""" - if "policy" in weights: - actor_state = move_state_dict_to_device(weights["policy"], device=device) - self.policy.actor.load_state_dict(actor_state) - if ( - "discrete_critic" in weights - and hasattr(self.policy, "discrete_critic") - and self.policy.discrete_critic is not None - ): - dc_state = move_state_dict_to_device(weights["discrete_critic"], device=device) - self.policy.discrete_critic.load_state_dict(dc_state) - - @torch.no_grad() - def get_observation_features( - self, observations: Tensor, next_observations: Tensor - ) -> tuple[Tensor | None, Tensor | None]: - if self.policy.config.vision_encoder_name is None or not self.policy.config.freeze_vision_encoder: - return None, None - observation_features = self.policy.actor.encoder.get_cached_image_features(observations) - next_observation_features = self.policy.actor.encoder.get_cached_image_features(next_observations) - return observation_features, next_observation_features diff --git a/src/lerobot/rl/algorithms/sac/__init__.py b/src/lerobot/rl/algorithms/sac/__init__.py new file mode 100644 index 000000000..6e0e76a66 --- /dev/null +++ b/src/lerobot/rl/algorithms/sac/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lerobot.rl.algorithms.sac.configuration_sac import SACAlgorithmConfig +from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm + +__all__ = ["SACAlgorithm", "SACAlgorithmConfig"] diff --git a/src/lerobot/rl/algorithms/sac/configuration_sac.py b/src/lerobot/rl/algorithms/sac/configuration_sac.py new file mode 100644 index 000000000..812ec6686 --- /dev/null +++ b/src/lerobot/rl/algorithms/sac/configuration_sac.py @@ -0,0 +1,81 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SAC algorithm configuration.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +import torch + +from lerobot.policies.sac.configuration_sac import CriticNetworkConfig +from lerobot.rl.algorithms.base import RLAlgorithmConfig + +if TYPE_CHECKING: + from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm + + +@RLAlgorithmConfig.register_subclass("sac") +@dataclass +class SACAlgorithmConfig(RLAlgorithmConfig): + """SAC-specific hyper-parameters that control the update loop.""" + + utd_ratio: int = 1 + policy_update_freq: int = 1 + clip_grad_norm: float = 40.0 + actor_lr: float = 3e-4 + critic_lr: float = 3e-4 + temperature_lr: float = 3e-4 + discount: float = 0.99 + temperature_init: float = 1.0 + target_entropy: float | None = None + use_backup_entropy: bool = True + critic_target_update_weight: float = 0.005 + num_critics: int = 2 + num_subsample_critics: int | None = None + num_discrete_actions: int | None = None + shared_encoder: bool = True + critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) + discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) + use_torch_compile: bool = True + + @classmethod + def from_policy_config(cls, policy_cfg) -> SACAlgorithmConfig: + """Build from an existing ``SACConfig`` (cfg.policy) for backwards compat.""" + return cls( + utd_ratio=policy_cfg.utd_ratio, + policy_update_freq=policy_cfg.policy_update_freq, + clip_grad_norm=policy_cfg.grad_clip_norm, + actor_lr=policy_cfg.actor_lr, + critic_lr=policy_cfg.critic_lr, + temperature_lr=policy_cfg.temperature_lr, + discount=policy_cfg.discount, + temperature_init=policy_cfg.temperature_init, + target_entropy=policy_cfg.target_entropy, + use_backup_entropy=policy_cfg.use_backup_entropy, + critic_target_update_weight=policy_cfg.critic_target_update_weight, + num_critics=policy_cfg.num_critics, + num_subsample_critics=policy_cfg.num_subsample_critics, + 
num_discrete_actions=policy_cfg.num_discrete_actions, + shared_encoder=policy_cfg.shared_encoder, + critic_network_kwargs=policy_cfg.critic_network_kwargs, + discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs, + use_torch_compile=policy_cfg.use_torch_compile, + ) + + def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm: + from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm + + return SACAlgorithm(policy=policy, config=self) diff --git a/src/lerobot/rl/algorithms/sac/sac_algorithm.py b/src/lerobot/rl/algorithms/sac/sac_algorithm.py new file mode 100644 index 000000000..61886289a --- /dev/null +++ b/src/lerobot/rl/algorithms/sac/sac_algorithm.py @@ -0,0 +1,400 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SAC (Soft Actor-Critic) algorithm. + +This module encapsulates all SAC-specific training logic (critic, actor, +temperature, and discrete-critic updates) behind the ``RLAlgorithm`` interface. +""" + +from __future__ import annotations + +import math +from collections.abc import Iterator +from dataclasses import asdict +from typing import Any + +import einops +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F # noqa: N812 +from torch import Tensor +from torch.optim import Optimizer + +from lerobot.policies.sac.modeling_sac import ( + DISCRETE_DIMENSION_INDEX, + CriticEnsemble, + CriticHead, + DiscreteCritic, + SACObservationEncoder, + SACPolicy, +) +from lerobot.policies.utils import get_device_from_parameters +from lerobot.rl.algorithms.base import ( + BatchType, + RLAlgorithm, + TrainingStats, +) +from lerobot.rl.algorithms.sac.configuration_sac import SACAlgorithmConfig +from lerobot.utils.constants import ACTION +from lerobot.utils.transition import move_state_dict_to_device + + +class SACAlgorithm(RLAlgorithm): + """Soft Actor-Critic with optional discrete-critic head. + + Owns the ``SACPolicy`` and its optimizers. All loss methods call + ``self.policy(batch_dict)`` rather than reaching into ``self.policy.actor`` + directly, so any policy that returns ``{"action", "log_prob"}`` from its + ``forward()`` is compatible. 
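+
+    A minimal construction sketch, mirroring ``examples/tutorial/rl/hilserl_example.py``
+    (``policy`` is assumed to be an already-initialized ``SACPolicy``)::
+
+        algo_config = SACAlgorithmConfig.from_policy_config(policy.config)
+        algorithm = SACAlgorithm(policy=policy, config=algo_config)
+        algorithm.make_optimizers()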
+ """ + + def __init__( + self, + policy: SACPolicy, + config: SACAlgorithmConfig, + ): + self.policy = policy + self.config = config + self.optimizers: dict[str, Optimizer] = {} + self._optimization_step: int = 0 + + self._device = get_device_from_parameters(self.policy) + self._init_critic_encoder() + self._init_critics() + self._init_temperature() + + def _init_critic_encoder(self) -> None: + """Build or share the encoder used by critics.""" + if self.config.shared_encoder: + self.critic_encoder = self.policy.encoder + self.policy.actor.encoder_is_shared = True + else: + self.critic_encoder = SACObservationEncoder(self.policy.config) + + def _init_critics(self) -> None: + """Build critic ensemble, targets, and optional discrete critic.""" + action_dim = self.policy.config.output_features[ACTION].shape[0] + input_dim = self.critic_encoder.output_dim + action_dim + + heads = [ + CriticHead(input_dim=input_dim, **asdict(self.config.critic_network_kwargs)) + for _ in range(self.config.num_critics) + ] + self.critic_ensemble = CriticEnsemble(encoder=self.critic_encoder, ensemble=heads) + + target_heads = [ + CriticHead(input_dim=input_dim, **asdict(self.config.critic_network_kwargs)) + for _ in range(self.config.num_critics) + ] + self.critic_target = CriticEnsemble(encoder=self.critic_encoder, ensemble=target_heads) + self.critic_target.load_state_dict(self.critic_ensemble.state_dict()) + + if self.config.use_torch_compile: + self.critic_ensemble = torch.compile(self.critic_ensemble) + self.critic_target = torch.compile(self.critic_target) + + if self.config.num_discrete_actions is not None: + self._init_discrete_critic_target() + + def _init_discrete_critic_target(self) -> None: + """Build only the target discrete critic.""" + input_dim = self.critic_encoder.output_dim + self.discrete_critic_target = DiscreteCritic( + encoder=self.critic_encoder, + input_dim=input_dim, + output_dim=self.config.num_discrete_actions, + **asdict(self.config.discrete_critic_network_kwargs), + ) + # TODO: (kmeftah) Compile the discrete critic + self.discrete_critic_target.load_state_dict(self.policy.discrete_critic.state_dict()) + + def _init_temperature(self) -> None: + """Set up temperature parameter (log_alpha) and default target entropy.""" + temp_init = self.config.temperature_init + self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)])) + + action_dim = self.policy.config.output_features[ACTION].shape[0] + self.target_entropy = self.config.target_entropy + if self.target_entropy is None: + dim = action_dim + (1 if self.config.num_discrete_actions is not None else 0) + self.target_entropy = -np.prod(dim) / 2 + + @property + def temperature(self) -> float: + return self.log_alpha.exp().item() + + def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats: + """Run one full SAC update with UTD critic warm-up. + + Pulls ``utd_ratio`` batches from ``batch_iterator``. The first + ``utd_ratio - 1`` batches are used for critic-only warm-up steps; + the last batch drives the full update (critic + actor + temperature). 
+ """ + for _ in range(self.config.utd_ratio - 1): + batch = next(batch_iterator) + forward_batch = self._prepare_forward_batch(batch) + + loss_critic = self._compute_loss_critic(forward_batch) + self.optimizers["critic"].zero_grad() + loss_critic.backward() + torch.nn.utils.clip_grad_norm_( + self.critic_ensemble.parameters(), + max_norm=self.config.clip_grad_norm, + ).item() + self.optimizers["critic"].step() + + if self.config.num_discrete_actions is not None: + loss_discrete = self._compute_loss_discrete_critic(forward_batch) + self.optimizers["discrete_critic"].zero_grad() + loss_discrete.backward() + torch.nn.utils.clip_grad_norm_( + self.policy.discrete_critic.parameters(), + max_norm=self.config.clip_grad_norm, + ).item() + self.optimizers["discrete_critic"].step() + self._update_target_networks() + + batch = next(batch_iterator) + forward_batch = self._prepare_forward_batch(batch) + + loss_critic = self._compute_loss_critic(forward_batch) + self.optimizers["critic"].zero_grad() + loss_critic.backward() + critic_grad_norm = torch.nn.utils.clip_grad_norm_( + self.critic_ensemble.parameters(), + max_norm=self.config.clip_grad_norm, + ).item() + self.optimizers["critic"].step() + + critic_loss_val = loss_critic.item() + stats = TrainingStats( + losses={"critic": critic_loss_val}, + grad_norms={"critic": critic_grad_norm}, + ) + + if self.config.num_discrete_actions is not None: + loss_discrete = self._compute_loss_discrete_critic(forward_batch) + self.optimizers["discrete_critic"].zero_grad() + loss_discrete.backward() + dc_grad = torch.nn.utils.clip_grad_norm_( + self.policy.discrete_critic.parameters(), + max_norm=self.config.clip_grad_norm, + ).item() + self.optimizers["discrete_critic"].step() + stats.losses["discrete_critic"] = loss_discrete.item() + stats.grad_norms["discrete_critic"] = dc_grad + + if self._optimization_step % self.config.policy_update_freq == 0: + for _ in range(self.config.policy_update_freq): + actor_loss = self._compute_loss_actor(forward_batch) + self.optimizers["actor"].zero_grad() + actor_loss.backward() + actor_grad = torch.nn.utils.clip_grad_norm_( + self.policy.actor.parameters(), + max_norm=self.config.clip_grad_norm, + ).item() + self.optimizers["actor"].step() + + temp_loss = self._compute_loss_temperature(forward_batch) + self.optimizers["temperature"].zero_grad() + temp_loss.backward() + temp_grad = torch.nn.utils.clip_grad_norm_( + [self.log_alpha], + max_norm=self.config.clip_grad_norm, + ).item() + self.optimizers["temperature"].step() + + stats.losses["actor"] = actor_loss.item() + stats.losses["temperature"] = temp_loss.item() + stats.grad_norms["actor"] = actor_grad + stats.grad_norms["temperature"] = temp_grad + stats.extra["temperature"] = self.temperature + + self._update_target_networks() + + self._optimization_step += 1 + return stats + + def _compute_loss_critic(self, batch: dict[str, Any]) -> Tensor: + observations = batch["state"] + actions = batch[ACTION] + rewards = batch["reward"] + next_observations = batch["next_state"] + done = batch["done"] + obs_features = batch.get("observation_feature") + next_obs_features = batch.get("next_observation_feature") + + with torch.no_grad(): + next_output = self.policy({"state": next_observations, "observation_feature": next_obs_features}) + next_actions = next_output["action"] + next_log_probs = next_output["log_prob"] + + q_targets = self.critic_target(next_observations, next_actions, next_obs_features) + + if self.config.num_subsample_critics is not None: + indices = 
torch.randperm(self.config.num_critics) + indices = indices[: self.config.num_subsample_critics] + q_targets = q_targets[indices] + + min_q, _ = q_targets.min(dim=0) + if self.config.use_backup_entropy: + min_q = min_q - (self.temperature * next_log_probs) + + td_target = rewards + (1 - done) * self.config.discount * min_q + + if self.config.num_discrete_actions is not None: + actions = actions[:, :DISCRETE_DIMENSION_INDEX] + + q_preds = self.critic_ensemble(observations, actions, obs_features) + + td_target_dup = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]) + critics_loss = (F.mse_loss(input=q_preds, target=td_target_dup, reduction="none").mean(dim=1)).sum() + return critics_loss + + def _compute_loss_discrete_critic(self, batch: dict[str, Any]) -> Tensor: + observations = batch["state"] + actions = batch[ACTION] + rewards = batch["reward"] + next_observations = batch["next_state"] + done = batch["done"] + obs_features = batch.get("observation_feature") + next_obs_features = batch.get("next_observation_feature") + complementary_info = batch.get("complementary_info") + + actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone() + actions_discrete = torch.round(actions_discrete).long() + + discrete_penalties: Tensor | None = None + if complementary_info is not None: + discrete_penalties = complementary_info.get("discrete_penalty") + + with torch.no_grad(): + next_discrete_qs = self.policy.discrete_critic(next_observations, next_obs_features) + best_next_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True) + + target_next_qs = self.discrete_critic_target(next_observations, next_obs_features) + target_next_q = torch.gather(target_next_qs, dim=1, index=best_next_action).squeeze(-1) + + rewards_disc = rewards + if discrete_penalties is not None: + rewards_disc = rewards + discrete_penalties + target_q = rewards_disc + (1 - done) * self.config.discount * target_next_q + + predicted_qs = self.policy.discrete_critic(observations, obs_features) + predicted_q = torch.gather(predicted_qs, dim=1, index=actions_discrete).squeeze(-1) + + return F.mse_loss(input=predicted_q, target=target_q) + + def _compute_loss_actor(self, batch: dict[str, Any]) -> Tensor: + observations = batch["state"] + obs_features = batch.get("observation_feature") + + output = self.policy({"state": observations, "observation_feature": obs_features}) + actions_pi = output["action"] + log_probs = output["log_prob"] + + q_preds = self.critic_ensemble(observations, actions_pi, obs_features) + min_q = q_preds.min(dim=0)[0] + + return ((self.temperature * log_probs) - min_q).mean() + + def _compute_loss_temperature(self, batch: dict[str, Any]) -> Tensor: + observations = batch["state"] + obs_features = batch.get("observation_feature") + + with torch.no_grad(): + output = self.policy({"state": observations, "observation_feature": obs_features}) + log_probs = output["log_prob"] + + return (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean() + + def _update_target_networks(self) -> None: + tau = self.config.critic_target_update_weight + for target_p, p in zip( + self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=True + ): + target_p.data.copy_(p.data * tau + target_p.data * (1.0 - tau)) + if self.config.num_discrete_actions is not None: + for target_p, p in zip( + self.discrete_critic_target.parameters(), + self.policy.discrete_critic.parameters(), + strict=True, + ): + target_p.data.copy_(p.data * tau + target_p.data * (1.0 - tau)) + + def 
_prepare_forward_batch(self, batch: BatchType) -> dict[str, Any]: + """Build the dict expected by loss computation from a sampled batch.""" + observations = batch["state"] + next_observations = batch["next_state"] + + observation_features, next_observation_features = self.get_observation_features( + observations, next_observations + ) + forward_batch: dict[str, Any] = { + ACTION: batch[ACTION], + "reward": batch["reward"], + "state": observations, + "next_state": next_observations, + "done": batch["done"], + "observation_feature": observation_features, + "next_observation_feature": next_observation_features, + } + if "complementary_info" in batch: + forward_batch["complementary_info"] = batch["complementary_info"] + return forward_batch + + def make_optimizers(self) -> dict[str, Optimizer]: + """Create Adam optimizers for the SAC components and store them.""" + actor_params = [ + p + for n, p in self.policy.actor.named_parameters() + if not self.config.shared_encoder or not n.startswith("encoder") + ] + self.optimizers = { + "actor": torch.optim.Adam(actor_params, lr=self.config.actor_lr), + "critic": torch.optim.Adam(self.critic_ensemble.parameters(), lr=self.config.critic_lr), + "temperature": torch.optim.Adam([self.log_alpha], lr=self.config.temperature_lr), + } + if self.config.num_discrete_actions is not None: + self.optimizers["discrete_critic"] = torch.optim.Adam( + self.policy.discrete_critic.parameters(), lr=self.config.critic_lr + ) + return self.optimizers + + def get_optimizers(self) -> dict[str, Optimizer]: + return self.optimizers + + def get_weights(self) -> dict[str, Any]: + """Policy state-dict to push to actors (includes actor + discrete critic).""" + return move_state_dict_to_device(self.policy.state_dict(), device="cpu") + + def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None: + """Load policy state-dict received from the learner.""" + state = move_state_dict_to_device(weights, device=device) + self.policy.load_state_dict(state) + + @torch.no_grad() + def get_observation_features( + self, observations: Tensor, next_observations: Tensor + ) -> tuple[Tensor | None, Tensor | None]: + if not self.config.shared_encoder: + return None, None + if self.policy.config.vision_encoder_name is None or not self.policy.config.freeze_vision_encoder: + return None, None + if not self.policy.encoder.has_images: + return None, None + observation_features = self.policy.encoder.get_cached_image_features(observations) + next_observation_features = self.policy.encoder.get_cached_image_features(next_observations) + return observation_features, next_observation_features diff --git a/src/lerobot/utils/train_utils.py b/src/lerobot/utils/train_utils.py index d8481f4b9..b91f28f67 100644 --- a/src/lerobot/utils/train_utils.py +++ b/src/lerobot/utils/train_utils.py @@ -95,6 +95,7 @@ def save_checkpoint( optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None. scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None. preprocessor: The preprocessor/pipeline to save. Defaults to None. + postprocessor: The postprocessor/pipeline to save. Defaults to None. 
""" pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR policy.save_pretrained(pretrained_dir) diff --git a/tests/policies/test_sac_policy.py b/tests/policies/test_sac_policy.py index 6fad2979e..f2374d9a8 100644 --- a/tests/policies/test_sac_policy.py +++ b/tests/policies/test_sac_policy.py @@ -14,8 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math - import pytest import torch from torch import Tensor, nn @@ -23,6 +21,7 @@ from torch import Tensor, nn from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sac.modeling_sac import MLP, SACPolicy +from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE from lerobot.utils.random_utils import seeded_context, set_seed @@ -138,41 +137,6 @@ def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: i } -def make_optimizers(policy: SACPolicy, has_discrete_action: bool = False) -> dict[str, torch.optim.Optimizer]: - """Create optimizers for the SAC policy.""" - optimizer_actor = torch.optim.Adam( - # Handle the case of shared encoder where the encoder weights are not optimized with the actor gradient - params=[ - p - for n, p in policy.actor.named_parameters() - if not policy.config.shared_encoder or not n.startswith("encoder") - ], - lr=policy.config.actor_lr, - ) - optimizer_critic = torch.optim.Adam( - params=policy.critic_ensemble.parameters(), - lr=policy.config.critic_lr, - ) - optimizer_temperature = torch.optim.Adam( - params=[policy.log_alpha], - lr=policy.config.critic_lr, - ) - - optimizers = { - "actor": optimizer_actor, - "critic": optimizer_critic, - "temperature": optimizer_temperature, - } - - if has_discrete_action: - optimizers["discrete_critic"] = torch.optim.Adam( - params=policy.discrete_critic.parameters(), - lr=policy.config.critic_lr, - ) - - return optimizers - - def create_default_config( state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False ) -> SACConfig: @@ -212,7 +176,6 @@ def create_config_with_visual_input( "std": torch.randn(3, 1, 1), } - # Let make tests a little bit faster config.state_encoder_hidden_dim = 32 config.latent_dim = 32 @@ -220,75 +183,112 @@ def create_config_with_visual_input( return config -@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)]) -def test_sac_policy_with_default_config(batch_size: int, state_dim: int, action_dim: int): - batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim) - config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim) - +def _make_algorithm(config: SACConfig) -> tuple[SACAlgorithm, SACPolicy]: + """Helper to create policy + algorithm pair for tests that need critics.""" policy = SACPolicy(config=config) policy.train() + algo_config = SACAlgorithmConfig.from_policy_config(config) + algorithm = SACAlgorithm(policy=policy, config=algo_config) + algorithm.make_optimizers() + return algorithm, policy - optimizers = make_optimizers(policy) - - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - assert cirtic_loss.item() is not None - assert cirtic_loss.shape == () - cirtic_loss.backward() - optimizers["critic"].step() - - actor_loss = policy.forward(batch, model="actor")["loss_actor"] - assert actor_loss.item() is not None - assert actor_loss.shape == () - - 
actor_loss.backward() - optimizers["actor"].step() - - temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"] - assert temperature_loss.item() is not None - assert temperature_loss.shape == () - - temperature_loss.backward() - optimizers["temperature"].step() +@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)]) +def test_sac_policy_select_action(batch_size: int, state_dim: int, action_dim: int): + config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim) + policy = SACPolicy(config=config) policy.eval() + with torch.no_grad(): observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) selected_action = policy.select_action(observation_batch) - assert selected_action.shape == (batch_size, action_dim) + # squeeze(0) removes batch dim when batch_size==1 + assert selected_action.shape[-1] == action_dim + + +def test_sac_policy_select_action_with_discrete(): + """select_action should return continuous + discrete actions.""" + config = create_default_config(state_dim=10, continuous_action_dim=6) + config.num_discrete_actions = 3 + policy = SACPolicy(config=config) + policy.eval() + + with torch.no_grad(): + observation_batch = create_observation_batch(batch_size=1, state_dim=10) + # Squeeze to unbatched (single observation) + observation_batch = {k: v.squeeze(0) for k, v in observation_batch.items()} + selected_action = policy.select_action(observation_batch) + assert selected_action.shape[-1] == 7 # 6 continuous + 1 discrete @pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)]) -def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_dim: int): - config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) +def test_sac_policy_forward(batch_size: int, state_dim: int, action_dim: int): + config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim) policy = SACPolicy(config=config) + policy.eval() + + batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim) + with torch.no_grad(): + output = policy.forward(batch) + assert "action" in output + assert "log_prob" in output + assert "action_mean" in output + assert output["action"].shape == (batch_size, action_dim) + + +@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)]) +def test_sac_training_through_algorithm(batch_size: int, state_dim: int, action_dim: int): + config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim) + algorithm, policy = _make_algorithm(config) + + batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim) + forward_batch = algorithm._prepare_forward_batch(batch) + + critic_loss = algorithm._compute_loss_critic(forward_batch) + assert critic_loss.item() is not None + assert critic_loss.shape == () + algorithm.optimizers["critic"].zero_grad() + critic_loss.backward() + algorithm.optimizers["critic"].step() + + actor_loss = algorithm._compute_loss_actor(forward_batch) + assert actor_loss.item() is not None + assert actor_loss.shape == () + algorithm.optimizers["actor"].zero_grad() + actor_loss.backward() + algorithm.optimizers["actor"].step() + + temp_loss = algorithm._compute_loss_temperature(forward_batch) + assert temp_loss.item() is not None + assert temp_loss.shape == () + algorithm.optimizers["temperature"].zero_grad() + temp_loss.backward() + 
algorithm.optimizers["temperature"].step() + + +@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)]) +def test_sac_training_with_visual_input(batch_size: int, state_dim: int, action_dim: int): + config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) + algorithm, policy = _make_algorithm(config) batch = create_train_batch_with_visual_input( batch_size=batch_size, state_dim=state_dim, action_dim=action_dim ) + forward_batch = algorithm._prepare_forward_batch(batch) - policy.train() + critic_loss = algorithm._compute_loss_critic(forward_batch) + assert critic_loss.item() is not None + assert critic_loss.shape == () + algorithm.optimizers["critic"].zero_grad() + critic_loss.backward() + algorithm.optimizers["critic"].step() - optimizers = make_optimizers(policy) - - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - assert cirtic_loss.item() is not None - assert cirtic_loss.shape == () - cirtic_loss.backward() - optimizers["critic"].step() - - actor_loss = policy.forward(batch, model="actor")["loss_actor"] + actor_loss = algorithm._compute_loss_actor(forward_batch) assert actor_loss.item() is not None assert actor_loss.shape == () - + algorithm.optimizers["actor"].zero_grad() actor_loss.backward() - optimizers["actor"].step() - - temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"] - assert temperature_loss.item() is not None - assert temperature_loss.shape == () - - temperature_loss.backward() - optimizers["temperature"].step() + algorithm.optimizers["actor"].step() policy.eval() with torch.no_grad(): @@ -296,207 +296,181 @@ def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_di batch_size=batch_size, state_dim=state_dim ) selected_action = policy.select_action(observation_batch) - assert selected_action.shape == (batch_size, action_dim) + assert selected_action.shape[-1] == action_dim -# Let's check best candidates for pretrained encoders @pytest.mark.parametrize( "batch_size,state_dim,action_dim,vision_encoder_name", [(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")], ) @pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed") -def test_sac_policy_with_pretrained_encoder( +def test_sac_training_with_pretrained_encoder( batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str ): config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) config.vision_encoder_name = vision_encoder_name - policy = SACPolicy(config=config) - policy.train() + algorithm, policy = _make_algorithm(config) batch = create_train_batch_with_visual_input( batch_size=batch_size, state_dim=state_dim, action_dim=action_dim ) + forward_batch = algorithm._prepare_forward_batch(batch) - optimizers = make_optimizers(policy) + critic_loss = algorithm._compute_loss_critic(forward_batch) + assert critic_loss.item() is not None + assert critic_loss.shape == () + algorithm.optimizers["critic"].zero_grad() + critic_loss.backward() + algorithm.optimizers["critic"].step() - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - assert cirtic_loss.item() is not None - assert cirtic_loss.shape == () - cirtic_loss.backward() - optimizers["critic"].step() - - actor_loss = policy.forward(batch, model="actor")["loss_actor"] + actor_loss = algorithm._compute_loss_actor(forward_batch) assert actor_loss.item() is not None assert actor_loss.shape == () -def 
test_sac_policy_with_shared_encoder(): +def test_sac_training_with_shared_encoder(): batch_size = 2 action_dim = 10 state_dim = 10 config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) config.shared_encoder = True - policy = SACPolicy(config=config) - policy.train() + algorithm, policy = _make_algorithm(config) batch = create_train_batch_with_visual_input( batch_size=batch_size, state_dim=state_dim, action_dim=action_dim ) + forward_batch = algorithm._prepare_forward_batch(batch) - policy.train() + critic_loss = algorithm._compute_loss_critic(forward_batch) + assert critic_loss.shape == () + algorithm.optimizers["critic"].zero_grad() + critic_loss.backward() + algorithm.optimizers["critic"].step() - optimizers = make_optimizers(policy) - - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - assert cirtic_loss.item() is not None - assert cirtic_loss.shape == () - cirtic_loss.backward() - optimizers["critic"].step() - - actor_loss = policy.forward(batch, model="actor")["loss_actor"] - assert actor_loss.item() is not None + actor_loss = algorithm._compute_loss_actor(forward_batch) assert actor_loss.shape == () - + algorithm.optimizers["actor"].zero_grad() actor_loss.backward() - optimizers["actor"].step() + algorithm.optimizers["actor"].step() -def test_sac_policy_with_discrete_critic(): +def test_sac_training_with_discrete_critic(): batch_size = 2 continuous_action_dim = 9 - full_action_dim = continuous_action_dim + 1 # the last action is discrete + full_action_dim = continuous_action_dim + 1 state_dim = 10 config = create_config_with_visual_input( state_dim=state_dim, continuous_action_dim=continuous_action_dim, has_discrete_action=True ) + config.num_discrete_actions = 5 - num_discrete_actions = 5 - config.num_discrete_actions = num_discrete_actions - - policy = SACPolicy(config=config) - policy.train() + algorithm, policy = _make_algorithm(config) batch = create_train_batch_with_visual_input( batch_size=batch_size, state_dim=state_dim, action_dim=full_action_dim ) + forward_batch = algorithm._prepare_forward_batch(batch) - policy.train() + critic_loss = algorithm._compute_loss_critic(forward_batch) + assert critic_loss.shape == () + algorithm.optimizers["critic"].zero_grad() + critic_loss.backward() + algorithm.optimizers["critic"].step() - optimizers = make_optimizers(policy, has_discrete_action=True) - - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - assert cirtic_loss.item() is not None - assert cirtic_loss.shape == () - cirtic_loss.backward() - optimizers["critic"].step() - - discrete_critic_loss = policy.forward(batch, model="discrete_critic")["loss_discrete_critic"] - assert discrete_critic_loss.item() is not None + discrete_critic_loss = algorithm._compute_loss_discrete_critic(forward_batch) assert discrete_critic_loss.shape == () + algorithm.optimizers["discrete_critic"].zero_grad() discrete_critic_loss.backward() - optimizers["discrete_critic"].step() + algorithm.optimizers["discrete_critic"].step() - actor_loss = policy.forward(batch, model="actor")["loss_actor"] - assert actor_loss.item() is not None + actor_loss = algorithm._compute_loss_actor(forward_batch) assert actor_loss.shape == () - + algorithm.optimizers["actor"].zero_grad() actor_loss.backward() - optimizers["actor"].step() + algorithm.optimizers["actor"].step() policy.eval() with torch.no_grad(): observation_batch = create_observation_batch_with_visual_input( batch_size=batch_size, state_dim=state_dim ) - selected_action = 
policy.select_action(observation_batch) - assert selected_action.shape == (batch_size, full_action_dim) - - discrete_actions = selected_action[:, -1].long() - discrete_action_values = set(discrete_actions.tolist()) - - assert all(action in range(num_discrete_actions) for action in discrete_action_values), ( - f"Discrete action {discrete_action_values} is not in range({num_discrete_actions})" - ) + # Policy.select_action now handles both continuous + discrete + selected_action = policy.select_action({k: v.squeeze(0) for k, v in observation_batch.items()}) + assert selected_action.shape[-1] == continuous_action_dim + 1 -def test_sac_policy_with_default_entropy(): +def test_sac_algorithm_target_entropy(): config = create_default_config(continuous_action_dim=10, state_dim=10) - policy = SACPolicy(config=config) - assert policy.target_entropy == -5.0 + _, policy = _make_algorithm(config) + algo_config = SACAlgorithmConfig.from_policy_config(config) + algorithm = SACAlgorithm(policy=policy, config=algo_config) + assert algorithm.target_entropy == -5.0 -def test_sac_policy_default_target_entropy_with_discrete_action(): +def test_sac_algorithm_target_entropy_with_discrete_action(): config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True) + config.num_discrete_actions = 5 + algo_config = SACAlgorithmConfig.from_policy_config(config) policy = SACPolicy(config=config) - assert policy.target_entropy == -3.0 + algorithm = SACAlgorithm(policy=policy, config=algo_config) + assert algorithm.target_entropy == -3.5 -def test_sac_policy_with_predefined_entropy(): - config = create_default_config(state_dim=10, continuous_action_dim=6) - config.target_entropy = -3.5 +def test_sac_algorithm_temperature(): + import math - policy = SACPolicy(config=config) - assert policy.target_entropy == pytest.approx(-3.5) - - -def test_sac_policy_update_temperature(): - """Test that temperature property is always in sync with log_alpha.""" config = create_default_config(continuous_action_dim=10, state_dim=10) + algo_config = SACAlgorithmConfig.from_policy_config(config) policy = SACPolicy(config=config) + algorithm = SACAlgorithm(policy=policy, config=algo_config) - assert policy.temperature == pytest.approx(1.0) - policy.log_alpha.data = torch.tensor([math.log(0.1)]) - # Temperature property automatically reflects log_alpha changes - assert policy.temperature == pytest.approx(0.1) + assert algorithm.temperature == pytest.approx(1.0) + algorithm.log_alpha.data = torch.tensor([math.log(0.1)]) + assert algorithm.temperature == pytest.approx(0.1) -def test_sac_policy_update_target_network(): +def test_sac_algorithm_update_target_network(): config = create_default_config(state_dim=10, continuous_action_dim=6) config.critic_target_update_weight = 1.0 - + algo_config = SACAlgorithmConfig.from_policy_config(config) policy = SACPolicy(config=config) - policy.train() + algorithm = SACAlgorithm(policy=policy, config=algo_config) - for p in policy.critic_ensemble.parameters(): + for p in algorithm.critic_ensemble.parameters(): p.data = torch.ones_like(p.data) - policy.update_target_networks() - for p in policy.critic_target.parameters(): - assert torch.allclose(p.data, torch.ones_like(p.data)), ( - f"Target network {p.data} is not equal to {torch.ones_like(p.data)}" - ) + algorithm._update_target_networks() + for p in algorithm.critic_target.parameters(): + assert torch.allclose(p.data, torch.ones_like(p.data)) @pytest.mark.parametrize("num_critics", [1, 3]) -def 
test_sac_policy_with_critics_number_of_heads(num_critics: int): +def test_sac_algorithm_with_critics_number_of_heads(num_critics: int): batch_size = 2 action_dim = 10 state_dim = 10 config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) config.num_critics = num_critics - policy = SACPolicy(config=config) - policy.train() + algorithm, policy = _make_algorithm(config) - assert len(policy.critic_ensemble.critics) == num_critics + assert len(algorithm.critic_ensemble.critics) == num_critics batch = create_train_batch_with_visual_input( batch_size=batch_size, state_dim=state_dim, action_dim=action_dim ) + forward_batch = algorithm._prepare_forward_batch(batch) - policy.train() - - optimizers = make_optimizers(policy) - - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - assert cirtic_loss.item() is not None - assert cirtic_loss.shape == () - cirtic_loss.backward() - optimizers["critic"].step() + critic_loss = algorithm._compute_loss_critic(forward_batch) + assert critic_loss.shape == () + algorithm.optimizers["critic"].zero_grad() + critic_loss.backward() + algorithm.optimizers["critic"].step() def test_sac_policy_save_and_load(tmp_path): + """Test that the policy can be saved and loaded from pretrained.""" root = tmp_path / "test_sac_save_and_load" state_dim = 10 @@ -510,34 +484,41 @@ def test_sac_policy_save_and_load(tmp_path): loaded_policy = SACPolicy.from_pretrained(root, config=config) loaded_policy.eval() - batch = create_default_train_batch(batch_size=1, state_dim=10, action_dim=10) + assert policy.state_dict().keys() == loaded_policy.state_dict().keys() + for k in policy.state_dict(): + assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6) with torch.no_grad(): with seeded_context(12): - # Collect policy values before saving - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - actor_loss = policy.forward(batch, model="actor")["loss_actor"] - temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"] - observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) actions = policy.select_action(observation_batch) with seeded_context(12): - # Collect policy values after loading - loaded_cirtic_loss = loaded_policy.forward(batch, model="critic")["loss_critic"] - loaded_actor_loss = loaded_policy.forward(batch, model="actor")["loss_actor"] - loaded_temperature_loss = loaded_policy.forward(batch, model="temperature")["loss_temperature"] - loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) loaded_actions = loaded_policy.select_action(loaded_observation_batch) - assert policy.state_dict().keys() == loaded_policy.state_dict().keys() - for k in policy.state_dict(): - assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6) - - # Compare values before and after saving and loading - # They should be the same - assert torch.allclose(cirtic_loss, loaded_cirtic_loss) - assert torch.allclose(actor_loss, loaded_actor_loss) - assert torch.allclose(temperature_loss, loaded_temperature_loss) assert torch.allclose(actions, loaded_actions) + + +def test_sac_policy_save_and_load_with_discrete_critic(tmp_path): + """Discrete critic should be saved/loaded as part of the policy.""" + root = tmp_path / "test_sac_save_and_load_discrete" + + state_dim = 10 + action_dim = 6 + + config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim) + 
config.num_discrete_actions = 3 + policy = SACPolicy(config=config) + policy.eval() + policy.save_pretrained(root) + + loaded_policy = SACPolicy.from_pretrained(root, config=config) + loaded_policy.eval() + + assert loaded_policy.discrete_critic is not None + dc_keys = [k for k in loaded_policy.state_dict() if k.startswith("discrete_critic.")] + assert len(dc_keys) > 0 + + for k in policy.state_dict(): + assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6) diff --git a/tests/rl/test_actor_learner.py b/tests/rl/test_actor_learner.py index 7c4dd25e7..6f07f1cd2 100644 --- a/tests/rl/test_actor_learner.py +++ b/tests/rl/test_actor_learner.py @@ -356,7 +356,7 @@ def test_learner_algorithm_wiring(): # get_weights -> state_to_bytes round-trip weights = algorithm.get_weights() - assert "policy" in weights + assert len(weights) > 0 serialized = state_to_bytes(weights) assert isinstance(serialized, bytes) assert len(serialized) > 0 @@ -430,8 +430,6 @@ def test_initial_and_periodic_weight_push_consistency(): periodic_decoded = bytes_to_state_dict(periodic_bytes) assert initial_decoded.keys() == periodic_decoded.keys() - for key in initial_decoded: - assert initial_decoded[key].keys() == periodic_decoded[key].keys() def test_actor_side_algorithm_select_action_and_load_weights(): @@ -462,7 +460,7 @@ def test_actor_side_algorithm_select_action_and_load_weights(): # select_action should work obs = {OBS_STATE: torch.randn(state_dim)} - action = algorithm.select_action(obs) + action = policy.select_action(obs) assert action.shape == (action_dim,) # Simulate receiving weights from learner diff --git a/tests/rl/test_sac_algorithm.py b/tests/rl/test_sac_algorithm.py index 63a5eb520..4485b1031 100644 --- a/tests/rl/test_sac_algorithm.py +++ b/tests/rl/test_sac_algorithm.py @@ -158,54 +158,55 @@ def test_training_stats_defaults(): # =========================================================================== -def test_get_weights_returns_actor_state_dict(): +def test_get_weights_returns_policy_state_dict(): algorithm, policy = _make_algorithm() weights = algorithm.get_weights() - assert "policy" in weights - for key in policy.actor.state_dict(): - assert key in weights["policy"] - assert torch.equal(weights["policy"][key].cpu(), policy.actor.state_dict()[key].cpu()) + for key in policy.state_dict(): + assert key in weights + assert torch.equal(weights[key].cpu(), policy.state_dict()[key].cpu()) def test_get_weights_includes_discrete_critic_when_present(): algorithm, policy = _make_algorithm(num_discrete_actions=3, action_dim=6) weights = algorithm.get_weights() - assert "discrete_critic" in weights - for key in policy.discrete_critic.state_dict(): - assert key in weights["discrete_critic"] + dc_keys = [k for k in weights if k.startswith("discrete_critic.")] + assert len(dc_keys) > 0 def test_get_weights_excludes_discrete_critic_when_absent(): algorithm, _ = _make_algorithm() weights = algorithm.get_weights() - assert "discrete_critic" not in weights + dc_keys = [k for k in weights if k.startswith("discrete_critic.")] + assert len(dc_keys) == 0 def test_get_weights_are_on_cpu(): algorithm, _ = _make_algorithm() weights = algorithm.get_weights() - for key, tensor in weights["policy"].items(): + for key, tensor in weights.items(): assert tensor.device == torch.device("cpu"), f"{key} is not on CPU" # =========================================================================== -# select_action +# select_action (lives on the policy, not the algorithm) # 
=========================================================================== def test_select_action_returns_correct_shape(): action_dim = 6 - algorithm, _ = _make_algorithm(state_dim=10, action_dim=action_dim) + _, policy = _make_algorithm(state_dim=10, action_dim=action_dim) + policy.eval() obs = {OBS_STATE: torch.randn(10)} - action = algorithm.select_action(obs) + action = policy.select_action(obs) assert action.shape == (action_dim,) def test_select_action_with_discrete_critic(): continuous_dim = 5 - algorithm, _ = _make_algorithm(state_dim=10, action_dim=continuous_dim, num_discrete_actions=3) + _, policy = _make_algorithm(state_dim=10, action_dim=continuous_dim, num_discrete_actions=3) + policy.eval() obs = {OBS_STATE: torch.randn(10)} - action = algorithm.select_action(obs) + action = policy.select_action(obs) assert action.shape == (continuous_dim + 1,) @@ -298,12 +299,12 @@ def test_update_utd_ratio_3_critic_warmup_changes_weights(): """With utd_ratio=3, critic weights should change after update (3 critic steps).""" algorithm, policy = _make_algorithm(utd_ratio=3) - critic_params_before = {n: p.clone() for n, p in policy.critic_ensemble.named_parameters()} + critic_params_before = {n: p.clone() for n, p in algorithm.critic_ensemble.named_parameters()} algorithm.update(_batch_iterator()) changed = False - for n, p in policy.critic_ensemble.named_parameters(): + for n, p in algorithm.critic_ensemble.named_parameters(): if not torch.equal(p, critic_params_before[n]): changed = True break @@ -403,11 +404,11 @@ def test_load_weights_round_trip(): weights = algo_src.get_weights() algo_dst.load_weights(weights, device="cpu") - for key in weights["policy"]: + for key in weights: assert torch.equal( - algo_dst.policy.actor.state_dict()[key].cpu(), - weights["policy"][key].cpu(), - ), f"Actor param '{key}' mismatch after load_weights" + algo_dst.policy.state_dict()[key].cpu(), + weights[key].cpu(), + ), f"Policy param '{key}' mismatch after load_weights" def test_load_weights_round_trip_with_discrete_critic(): @@ -421,17 +422,19 @@ def test_load_weights_round_trip_with_discrete_critic(): weights = algo_src.get_weights() algo_dst.load_weights(weights, device="cpu") - for key in weights["discrete_critic"]: + dc_keys = [k for k in weights if k.startswith("discrete_critic.")] + assert len(dc_keys) > 0 + for key in dc_keys: assert torch.equal( - algo_dst.policy.discrete_critic.state_dict()[key].cpu(), - weights["discrete_critic"][key].cpu(), + algo_dst.policy.state_dict()[key].cpu(), + weights[key].cpu(), ), f"Discrete critic param '{key}' mismatch after load_weights" def test_load_weights_ignores_missing_discrete_critic(): """load_weights should not fail when weights lack discrete_critic on a non-discrete policy.""" algorithm, _ = _make_algorithm() - weights = {"policy": algorithm.get_weights()["policy"]} + weights = algorithm.get_weights() algorithm.load_weights(weights, device="cpu")
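
Usage notes on the new SACAlgorithm surface follow; anything not visible in the diff above is called out as an assumption.

make_optimizers() filters the actor's parameters so that, with a shared encoder, the encoder weights are trained only by the critic optimizer (the same filter the deleted test helper applied). A toy, self-contained illustration of the name-prefix filter:

import torch
from torch import nn

class ToyActor(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)  # shared with the critic in the real model
        self.head = nn.Linear(8, 2)

actor = ToyActor()
shared_encoder = True
actor_params = [
    p
    for n, p in actor.named_parameters()
    if not shared_encoder or not n.startswith("encoder")
]
optimizer = torch.optim.Adam(actor_params, lr=3e-4)
print(len(actor_params))  # 2 -> head.weight and head.bias; encoder.* stays with the critic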
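
get_observation_features() only precomputes image features when they are safe to reuse: the encoder must be shared (algorithm config), pretrained and frozen (policy config), and must actually consume images. A condensed restatement of the guard, assuming a hypothetical cfg object that merges both configs:

def can_cache_image_features(cfg) -> bool:
    # Mirrors the early returns in SACAlgorithm.get_observation_features.
    return (
        cfg.shared_encoder                      # algorithm config
        and cfg.vision_encoder_name is not None  # policy config
        and cfg.freeze_vision_encoder            # policy config
        and cfg.encoder_has_images               # hypothetical stand-in for policy.encoder.has_images
    )

When any condition fails, _prepare_forward_batch forwards None features and each loss recomputes the encoding itself.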
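
update() takes a batch iterator rather than a single batch because the critic can take several gradient steps per call: the utd_ratio=3 test asserts that one update() performs three critic steps, so the algorithm can draw a fresh batch per step. A minimal learner step, assuming algorithm is built as in the tests (SACAlgorithmConfig.from_policy_config plus make_optimizers) and replay_buffer.sample(n) is a hypothetical sampler returning the keys _prepare_forward_batch expects ("state", "next_state", ACTION, "reward", "done"):

def batches(replay_buffer, batch_size=256):
    # Endless sampler; replay_buffer.sample is a stand-in for whatever
    # produces dicts with "state", "next_state", ACTION, "reward", "done".
    while True:
        yield replay_buffer.sample(batch_size)

stats = algorithm.update(batches(replay_buffer))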
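
The target-entropy defaults asserted above are consistent with target_entropy = -dim(A) / 2, with a discrete head counted as one extra action dimension (the exact formula lives in a part of sac_algorithm.py not shown in this excerpt):

def default_target_entropy(continuous_dim: int, has_discrete: bool) -> float:
    dim = continuous_dim + (1 if has_discrete else 0)
    return -dim / 2

assert default_target_entropy(10, False) == -5.0  # test_sac_algorithm_target_entropy
assert default_target_entropy(6, True) == -3.5    # ..._target_entropy_with_discrete_action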
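
test_sac_algorithm_update_target_network sets critic_target_update_weight = 1.0, which degenerates the soft update into a hard copy. A sketch of the Polyak rule that _update_target_networks presumably applies (the implementation is outside this excerpt):

import torch

@torch.no_grad()
def polyak_update(online: torch.nn.Module, target: torch.nn.Module, tau: float) -> None:
    # target <- tau * online + (1 - tau) * target; tau = 1.0 is an exact copy,
    # which is what the test asserts.
    for p, p_t in zip(online.parameters(), target.parameters()):
        p_t.data.mul_(1.0 - tau).add_(p.data, alpha=tau)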
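
The learner/actor handoff is now a flat policy state-dict: get_weights() returns policy.state_dict() moved to CPU (discrete-critic tensors included under the "discrete_critic." key prefix when present), and load_weights() is a plain load_state_dict on the receiving policy. A sketch of the round trip mirroring tests/rl/test_actor_learner.py; state_to_bytes and bytes_to_state_dict are the helpers exercised there, but their import path is not shown in this patch, and learner_algorithm, actor_algorithm, actor_policy, and obs are assumed to exist:

# learner side
weights = learner_algorithm.get_weights()  # flat dict, every tensor on CPU
payload = state_to_bytes(weights)

# actor side
state = bytes_to_state_dict(payload)
actor_algorithm.load_weights(state, device="cpu")  # policy.load_state_dict under the hood
action = actor_policy.select_action(obs)           # inference stays on the policy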