diff --git a/src/lerobot/policies/gaussian_actor/modeling_gaussian_actor.py b/src/lerobot/policies/gaussian_actor/modeling_gaussian_actor.py index 9a7bcf1bc..a833d01cc 100644 --- a/src/lerobot/policies/gaussian_actor/modeling_gaussian_actor.py +++ b/src/lerobot/policies/gaussian_actor/modeling_gaussian_actor.py @@ -17,7 +17,6 @@ from collections.abc import Callable from dataclasses import asdict -from typing import Any import torch import torch.nn as nn @@ -25,7 +24,6 @@ from torch import Tensor from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE -from lerobot.utils.transition import move_state_dict_to_device from ..pretrained import PreTrainedPolicy from ..utils import get_device_from_parameters @@ -113,16 +111,6 @@ class GaussianActorPolicy( actions, log_probs, means = self.actor(observations, observation_features) return {"action": actions, "log_prob": log_probs, "action_mean": means} - def load_actor_weights(self, state_dicts: dict[str, Any], device: str | torch.device = "cpu") -> None: - actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device) - self.actor.load_state_dict(actor_state_dict) - - if "discrete_critic" in state_dicts and self.discrete_critic is not None: - discrete_critic_state_dict = move_state_dict_to_device( - state_dicts["discrete_critic"], device=device - ) - self.discrete_critic.load_state_dict(discrete_critic_state_dict) - def _init_encoders(self): """Initialize shared or separate encoders for actor and critic.""" self.shared_encoder = self.config.shared_encoder diff --git a/src/lerobot/rl/actor.py b/src/lerobot/rl/actor.py index e7820d14f..c553abd12 100644 --- a/src/lerobot/rl/actor.py +++ b/src/lerobot/rl/actor.py @@ -61,7 +61,7 @@ from torch.multiprocessing import Queue from lerobot.cameras import opencv # noqa: F401 from lerobot.configs import parser -from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors +from lerobot.policies import make_policy, make_pre_post_processors from lerobot.processor import TransitionKey from lerobot.robots import so_follower # noqa: F401 from lerobot.teleoperators import gamepad, so_leader # noqa: F401 @@ -80,6 +80,9 @@ from lerobot.utils.utils import ( init_logging, ) +from .algorithms.base import RLAlgorithm +from .algorithms.factory import make_algorithm + if TYPE_CHECKING or _grpc_available: import grpc @@ -277,6 +280,9 @@ def act_with_policy( policy = policy.to(device).eval() assert isinstance(policy, nn.Module) + # Build the algorithm + algorithm = make_algorithm(cfg=cfg.algorithm, policy=policy) + preprocessor, postprocessor = make_pre_post_processors( policy_cfg=cfg.policy, dataset_stats=cfg.policy.dataset_stats, @@ -380,7 +386,7 @@ def act_with_policy( if done or truncated: logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}") - update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device) + update_policy_parameters(algorithm=algorithm, parameters_queue=parameters_queue, device=device) if len(list_transition_to_send_to_learner) > 0: push_transitions_to_transport_queue( @@ -675,7 +681,8 @@ def interactions_stream( # Policy functions -def update_policy_parameters(policy: PreTrainedPolicy, parameters_queue: Queue, device): +def update_policy_parameters(algorithm: RLAlgorithm, parameters_queue: Queue, device): + """Drain the latest learner-pushed weights into ``algorithm.policy``.""" bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False) if bytes_state_dict is not None: logging.info("[ACTOR] Load new parameters from Learner.") @@ -690,7 +697,7 @@ def update_policy_parameters(policy: PreTrainedPolicy, parameters_queue: Queue, # - Send critic's encoder state when shared_encoder=True # - Skip encoder params entirely when freeze_vision_encoder=True # - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic) - policy.load_actor_weights(state_dicts, device=device) + algorithm.load_weights(state_dicts, device=device) # Utilities functions diff --git a/src/lerobot/rl/algorithms/sac/sac_algorithm.py b/src/lerobot/rl/algorithms/sac/sac_algorithm.py index eeb9f1fc5..81c44068f 100644 --- a/src/lerobot/rl/algorithms/sac/sac_algorithm.py +++ b/src/lerobot/rl/algorithms/sac/sac_algorithm.py @@ -511,7 +511,11 @@ class SACAlgorithm(RLAlgorithm): def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None: """Load actor + discrete-critic weights into the policy.""" - self.policy.load_actor_weights(weights, device=device) + actor_sd = move_state_dict_to_device(weights["policy"], device=device) + self.policy.actor.load_state_dict(actor_sd) + if "discrete_critic" in weights and self.policy.discrete_critic is not None: + discrete_sd = move_state_dict_to_device(weights["discrete_critic"], device=device) + self.policy.discrete_critic.load_state_dict(discrete_sd) def state_dict(self) -> dict[str, torch.Tensor]: """Algorithm-owned trainable tensors. diff --git a/tests/rl/test_sac_algorithm.py b/tests/rl/test_sac_algorithm.py index 990d63164..2d77ae9ba 100644 --- a/tests/rl/test_sac_algorithm.py +++ b/tests/rl/test_sac_algorithm.py @@ -445,35 +445,39 @@ def test_load_weights_ignores_missing_discrete_critic(): def test_actor_side_weight_sync_with_discrete_critic(): - """End-to-end: learner ``algorithm.get_weights()`` -> actor ``policy.load_actor_weights()``.""" - # Learner side: train the algorithm so its weights diverge from init. + """End-to-end: learner ``algorithm.get_weights()`` -> actor ``algorithm.load_weights()``.""" + # Learner side: train the source algorithm so its weights diverge from init. algo_src, _ = _make_algorithm(num_discrete_actions=3, action_dim=6) algo_src.update(_batch_iterator(action_dim=7)) weights = algo_src.get_weights() - # Actor side: fresh policy, no algorithm/optimizer. + # Actor side: fresh policy + fresh algorithm holding it. sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6) policy_actor = GaussianActorPolicy(config=sac_cfg) + algo_actor = SACAlgorithm( + policy=policy_actor, + config=SACAlgorithmConfig.from_policy_config(sac_cfg), + ) # Snapshot initial actor state for the "did it change?" assertion below. initial_discrete_critic_state_dict = { k: v.clone() for k, v in policy_actor.discrete_critic.state_dict().items() } - policy_actor.load_actor_weights(weights, device="cpu") + algo_actor.load_weights(weights, device="cpu") # Actor weights match the learner's exported actor state dict. actor_state_dict = policy_actor.actor.state_dict() for key, tensor in weights["policy"].items(): assert torch.equal(actor_state_dict[key].cpu(), tensor.cpu()), ( - f"Actor param '{key}' not synced by load_actor_weights" + f"Actor param '{key}' not synced by algorithm.load_weights" ) # Discrete critic weights match the learner's exported discrete critic. discrete_critic_state_dict = policy_actor.discrete_critic.state_dict() for key, tensor in weights["discrete_critic"].items(): assert torch.equal(discrete_critic_state_dict[key].cpu(), tensor.cpu()), ( - f"Discrete critic param '{key}' not synced by load_actor_weights" + f"Discrete critic param '{key}' not synced by algorithm.load_weights" ) # Sanity: the discrete critic actually changed (otherwise the sync is trivial).