refactor(rl): move actor weight-sync wire format from policy to algorithm

This commit is contained in:
Khalil Meftah
2026-05-08 21:57:21 +02:00
parent 0944b84279
commit b2f3fd746f
4 changed files with 26 additions and 23 deletions
@@ -17,7 +17,6 @@
from collections.abc import Callable from collections.abc import Callable
from dataclasses import asdict from dataclasses import asdict
from typing import Any
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -25,7 +24,6 @@ from torch import Tensor
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE
from lerobot.utils.transition import move_state_dict_to_device
from ..pretrained import PreTrainedPolicy from ..pretrained import PreTrainedPolicy
from ..utils import get_device_from_parameters from ..utils import get_device_from_parameters
@@ -113,16 +111,6 @@ class GaussianActorPolicy(
actions, log_probs, means = self.actor(observations, observation_features) actions, log_probs, means = self.actor(observations, observation_features)
return {"action": actions, "log_prob": log_probs, "action_mean": means} return {"action": actions, "log_prob": log_probs, "action_mean": means}
def load_actor_weights(self, state_dicts: dict[str, Any], device: str | torch.device = "cpu") -> None:
actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device)
self.actor.load_state_dict(actor_state_dict)
if "discrete_critic" in state_dicts and self.discrete_critic is not None:
discrete_critic_state_dict = move_state_dict_to_device(
state_dicts["discrete_critic"], device=device
)
self.discrete_critic.load_state_dict(discrete_critic_state_dict)
def _init_encoders(self): def _init_encoders(self):
"""Initialize shared or separate encoders for actor and critic.""" """Initialize shared or separate encoders for actor and critic."""
self.shared_encoder = self.config.shared_encoder self.shared_encoder = self.config.shared_encoder
+11 -4
View File
@@ -61,7 +61,7 @@ from torch.multiprocessing import Queue
from lerobot.cameras import opencv # noqa: F401 from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors from lerobot.policies import make_policy, make_pre_post_processors
from lerobot.processor import TransitionKey from lerobot.processor import TransitionKey
from lerobot.robots import so_follower # noqa: F401 from lerobot.robots import so_follower # noqa: F401
from lerobot.teleoperators import gamepad, so_leader # noqa: F401 from lerobot.teleoperators import gamepad, so_leader # noqa: F401
@@ -80,6 +80,9 @@ from lerobot.utils.utils import (
init_logging, init_logging,
) )
from .algorithms.base import RLAlgorithm
from .algorithms.factory import make_algorithm
if TYPE_CHECKING or _grpc_available: if TYPE_CHECKING or _grpc_available:
import grpc import grpc
@@ -277,6 +280,9 @@ def act_with_policy(
policy = policy.to(device).eval() policy = policy.to(device).eval()
assert isinstance(policy, nn.Module) assert isinstance(policy, nn.Module)
# Build the algorithm
algorithm = make_algorithm(cfg=cfg.algorithm, policy=policy)
preprocessor, postprocessor = make_pre_post_processors( preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy, policy_cfg=cfg.policy,
dataset_stats=cfg.policy.dataset_stats, dataset_stats=cfg.policy.dataset_stats,
@@ -380,7 +386,7 @@ def act_with_policy(
if done or truncated: if done or truncated:
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}") logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device) update_policy_parameters(algorithm=algorithm, parameters_queue=parameters_queue, device=device)
if len(list_transition_to_send_to_learner) > 0: if len(list_transition_to_send_to_learner) > 0:
push_transitions_to_transport_queue( push_transitions_to_transport_queue(
@@ -675,7 +681,8 @@ def interactions_stream(
# Policy functions # Policy functions
def update_policy_parameters(policy: PreTrainedPolicy, parameters_queue: Queue, device): def update_policy_parameters(algorithm: RLAlgorithm, parameters_queue: Queue, device):
"""Drain the latest learner-pushed weights into ``algorithm.policy``."""
bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False) bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False)
if bytes_state_dict is not None: if bytes_state_dict is not None:
logging.info("[ACTOR] Load new parameters from Learner.") logging.info("[ACTOR] Load new parameters from Learner.")
@@ -690,7 +697,7 @@ def update_policy_parameters(policy: PreTrainedPolicy, parameters_queue: Queue,
# - Send critic's encoder state when shared_encoder=True # - Send critic's encoder state when shared_encoder=True
# - Skip encoder params entirely when freeze_vision_encoder=True # - Skip encoder params entirely when freeze_vision_encoder=True
# - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic) # - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic)
policy.load_actor_weights(state_dicts, device=device) algorithm.load_weights(state_dicts, device=device)
# Utilities functions # Utilities functions
@@ -511,7 +511,11 @@ class SACAlgorithm(RLAlgorithm):
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None: def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
"""Load actor + discrete-critic weights into the policy.""" """Load actor + discrete-critic weights into the policy."""
self.policy.load_actor_weights(weights, device=device) actor_sd = move_state_dict_to_device(weights["policy"], device=device)
self.policy.actor.load_state_dict(actor_sd)
if "discrete_critic" in weights and self.policy.discrete_critic is not None:
discrete_sd = move_state_dict_to_device(weights["discrete_critic"], device=device)
self.policy.discrete_critic.load_state_dict(discrete_sd)
def state_dict(self) -> dict[str, torch.Tensor]: def state_dict(self) -> dict[str, torch.Tensor]:
"""Algorithm-owned trainable tensors. """Algorithm-owned trainable tensors.
+10 -6
View File
@@ -445,35 +445,39 @@ def test_load_weights_ignores_missing_discrete_critic():
def test_actor_side_weight_sync_with_discrete_critic(): def test_actor_side_weight_sync_with_discrete_critic():
"""End-to-end: learner ``algorithm.get_weights()`` -> actor ``policy.load_actor_weights()``.""" """End-to-end: learner ``algorithm.get_weights()`` -> actor ``algorithm.load_weights()``."""
# Learner side: train the algorithm so its weights diverge from init. # Learner side: train the source algorithm so its weights diverge from init.
algo_src, _ = _make_algorithm(num_discrete_actions=3, action_dim=6) algo_src, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
algo_src.update(_batch_iterator(action_dim=7)) algo_src.update(_batch_iterator(action_dim=7))
weights = algo_src.get_weights() weights = algo_src.get_weights()
# Actor side: fresh policy, no algorithm/optimizer. # Actor side: fresh policy + fresh algorithm holding it.
sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6) sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6)
policy_actor = GaussianActorPolicy(config=sac_cfg) policy_actor = GaussianActorPolicy(config=sac_cfg)
algo_actor = SACAlgorithm(
policy=policy_actor,
config=SACAlgorithmConfig.from_policy_config(sac_cfg),
)
# Snapshot the actor-side discrete critic's initial state for the "did it change?" assertion below. # Snapshot the actor-side discrete critic's initial state for the "did it change?" assertion below.
initial_discrete_critic_state_dict = { initial_discrete_critic_state_dict = {
k: v.clone() for k, v in policy_actor.discrete_critic.state_dict().items() k: v.clone() for k, v in policy_actor.discrete_critic.state_dict().items()
} }
policy_actor.load_actor_weights(weights, device="cpu") algo_actor.load_weights(weights, device="cpu")
# Actor weights match the learner's exported actor state dict. # Actor weights match the learner's exported actor state dict.
actor_state_dict = policy_actor.actor.state_dict() actor_state_dict = policy_actor.actor.state_dict()
for key, tensor in weights["policy"].items(): for key, tensor in weights["policy"].items():
assert torch.equal(actor_state_dict[key].cpu(), tensor.cpu()), ( assert torch.equal(actor_state_dict[key].cpu(), tensor.cpu()), (
f"Actor param '{key}' not synced by load_actor_weights" f"Actor param '{key}' not synced by algorithm.load_weights"
) )
# Discrete critic weights match the learner's exported discrete critic. # Discrete critic weights match the learner's exported discrete critic.
discrete_critic_state_dict = policy_actor.discrete_critic.state_dict() discrete_critic_state_dict = policy_actor.discrete_critic.state_dict()
for key, tensor in weights["discrete_critic"].items(): for key, tensor in weights["discrete_critic"].items():
assert torch.equal(discrete_critic_state_dict[key].cpu(), tensor.cpu()), ( assert torch.equal(discrete_critic_state_dict[key].cpu(), tensor.cpu()), (
f"Discrete critic param '{key}' not synced by load_actor_weights" f"Discrete critic param '{key}' not synced by algorithm.load_weights"
) )
# Sanity: the discrete critic actually changed (otherwise the sync is trivial). # Sanity: the discrete critic actually changed (otherwise the sync is trivial).