refactor(rl): move actor weight-sync wire format from policy to algorithm

This commit is contained in:
Khalil Meftah
2026-05-08 21:57:21 +02:00
parent 0944b84279
commit b2f3fd746f
4 changed files with 26 additions and 23 deletions
@@ -17,7 +17,6 @@
from collections.abc import Callable
from dataclasses import asdict
from typing import Any
import torch
import torch.nn as nn
@@ -25,7 +24,6 @@ from torch import Tensor
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE
from lerobot.utils.transition import move_state_dict_to_device
from ..pretrained import PreTrainedPolicy
from ..utils import get_device_from_parameters
@@ -113,16 +111,6 @@ class GaussianActorPolicy(
actions, log_probs, means = self.actor(observations, observation_features)
return {"action": actions, "log_prob": log_probs, "action_mean": means}
def load_actor_weights(self, state_dicts: dict[str, Any], device: str | torch.device = "cpu") -> None:
actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device)
self.actor.load_state_dict(actor_state_dict)
if "discrete_critic" in state_dicts and self.discrete_critic is not None:
discrete_critic_state_dict = move_state_dict_to_device(
state_dicts["discrete_critic"], device=device
)
self.discrete_critic.load_state_dict(discrete_critic_state_dict)
def _init_encoders(self):
"""Initialize shared or separate encoders for actor and critic."""
self.shared_encoder = self.config.shared_encoder
+11 -4
View File
@@ -61,7 +61,7 @@ from torch.multiprocessing import Queue
from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
from lerobot.policies import make_policy, make_pre_post_processors
from lerobot.processor import TransitionKey
from lerobot.robots import so_follower # noqa: F401
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
@@ -80,6 +80,9 @@ from lerobot.utils.utils import (
init_logging,
)
from .algorithms.base import RLAlgorithm
from .algorithms.factory import make_algorithm
if TYPE_CHECKING or _grpc_available:
import grpc
@@ -277,6 +280,9 @@ def act_with_policy(
policy = policy.to(device).eval()
assert isinstance(policy, nn.Module)
# Build the algorithm
algorithm = make_algorithm(cfg=cfg.algorithm, policy=policy)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
dataset_stats=cfg.policy.dataset_stats,
@@ -380,7 +386,7 @@ def act_with_policy(
if done or truncated:
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device)
update_policy_parameters(algorithm=algorithm, parameters_queue=parameters_queue, device=device)
if len(list_transition_to_send_to_learner) > 0:
push_transitions_to_transport_queue(
@@ -675,7 +681,8 @@ def interactions_stream(
# Policy functions
def update_policy_parameters(policy: PreTrainedPolicy, parameters_queue: Queue, device):
def update_policy_parameters(algorithm: RLAlgorithm, parameters_queue: Queue, device):
"""Drain the latest learner-pushed weights into ``algorithm.policy``."""
bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False)
if bytes_state_dict is not None:
logging.info("[ACTOR] Load new parameters from Learner.")
@@ -690,7 +697,7 @@ def update_policy_parameters(policy: PreTrainedPolicy, parameters_queue: Queue,
# - Send critic's encoder state when shared_encoder=True
# - Skip encoder params entirely when freeze_vision_encoder=True
# - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic)
policy.load_actor_weights(state_dicts, device=device)
algorithm.load_weights(state_dicts, device=device)
# Utilities functions
@@ -511,7 +511,11 @@ class SACAlgorithm(RLAlgorithm):
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
"""Load actor + discrete-critic weights into the policy."""
self.policy.load_actor_weights(weights, device=device)
actor_sd = move_state_dict_to_device(weights["policy"], device=device)
self.policy.actor.load_state_dict(actor_sd)
if "discrete_critic" in weights and self.policy.discrete_critic is not None:
discrete_sd = move_state_dict_to_device(weights["discrete_critic"], device=device)
self.policy.discrete_critic.load_state_dict(discrete_sd)
def state_dict(self) -> dict[str, torch.Tensor]:
"""Algorithm-owned trainable tensors.
+10 -6
View File
@@ -445,35 +445,39 @@ def test_load_weights_ignores_missing_discrete_critic():
def test_actor_side_weight_sync_with_discrete_critic():
"""End-to-end: learner ``algorithm.get_weights()`` -> actor ``policy.load_actor_weights()``."""
# Learner side: train the algorithm so its weights diverge from init.
"""End-to-end: learner ``algorithm.get_weights()`` -> actor ``algorithm.load_weights()``."""
# Learner side: train the source algorithm so its weights diverge from init.
algo_src, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
algo_src.update(_batch_iterator(action_dim=7))
weights = algo_src.get_weights()
# Actor side: fresh policy, no algorithm/optimizer.
# Actor side: fresh policy + fresh algorithm holding it.
sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6)
policy_actor = GaussianActorPolicy(config=sac_cfg)
algo_actor = SACAlgorithm(
policy=policy_actor,
config=SACAlgorithmConfig.from_policy_config(sac_cfg),
)
# Snapshot initial actor state for the "did it change?" assertion below.
initial_discrete_critic_state_dict = {
k: v.clone() for k, v in policy_actor.discrete_critic.state_dict().items()
}
policy_actor.load_actor_weights(weights, device="cpu")
algo_actor.load_weights(weights, device="cpu")
# Actor weights match the learner's exported actor state dict.
actor_state_dict = policy_actor.actor.state_dict()
for key, tensor in weights["policy"].items():
assert torch.equal(actor_state_dict[key].cpu(), tensor.cpu()), (
f"Actor param '{key}' not synced by load_actor_weights"
f"Actor param '{key}' not synced by algorithm.load_weights"
)
# Discrete critic weights match the learner's exported discrete critic.
discrete_critic_state_dict = policy_actor.discrete_critic.state_dict()
for key, tensor in weights["discrete_critic"].items():
assert torch.equal(discrete_critic_state_dict[key].cpu(), tensor.cpu()), (
f"Discrete critic param '{key}' not synced by load_actor_weights"
f"Discrete critic param '{key}' not synced by algorithm.load_weights"
)
# Sanity: the discrete critic actually changed (otherwise the sync is trivial).