refactor(rl): move actor weight-sync wire format from policy to algorithm

The learner-to-actor weight payload ({"policy": ..., "discrete_critic": ...}) is now
consumed by RLAlgorithm.load_weights() instead of PreTrainedPolicy.load_actor_weights().
The actor process builds an algorithm wrapper via make_algorithm() and routes
learner-pushed state dicts through it, so the policy class no longer owns the wire format.
@@ -17,7 +17,6 @@
 from collections.abc import Callable
 from dataclasses import asdict
-from typing import Any

 import torch
 import torch.nn as nn
@@ -25,7 +24,6 @@ from torch import Tensor
 from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution

 from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE
-from lerobot.utils.transition import move_state_dict_to_device

 from ..pretrained import PreTrainedPolicy
 from ..utils import get_device_from_parameters
@@ -113,16 +111,6 @@ class GaussianActorPolicy(
         actions, log_probs, means = self.actor(observations, observation_features)
         return {"action": actions, "log_prob": log_probs, "action_mean": means}

-    def load_actor_weights(self, state_dicts: dict[str, Any], device: str | torch.device = "cpu") -> None:
-        actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device)
-        self.actor.load_state_dict(actor_state_dict)
-
-        if "discrete_critic" in state_dicts and self.discrete_critic is not None:
-            discrete_critic_state_dict = move_state_dict_to_device(
-                state_dicts["discrete_critic"], device=device
-            )
-            self.discrete_critic.load_state_dict(discrete_critic_state_dict)
-
     def _init_encoders(self):
         """Initialize shared or separate encoders for actor and critic."""
         self.shared_encoder = self.config.shared_encoder
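For reference, the payload this removed method consumed, now owned by SACAlgorithm.load_weights() further down, is a plain nested state-dict mapping. A minimal sketch of that contract (the tensor names and shapes are illustrative placeholders; only the two top-level keys come from this diff):

from typing import Any

import torch

# Sketch of the weight-sync payload exchanged between learner and actor.
# The top-level keys ("policy", "discrete_critic") appear in this diff;
# the tensor names and shapes inside are illustrative placeholders.
weights: dict[str, Any] = {
    "policy": {  # actor network state dict
        "net.0.weight": torch.zeros(256, 64),
        "net.0.bias": torch.zeros(256),
    },
    # Optional: only present when the policy has a discrete critic.
    "discrete_critic": {
        "head.weight": torch.zeros(3, 256),
        "head.bias": torch.zeros(3),
    },
}

Consumers tolerate a missing "discrete_critic" entry, mirroring the `if "discrete_critic" in state_dicts and self.discrete_critic is not None` guard above.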
+11 -4
@@ -61,7 +61,7 @@ from torch.multiprocessing import Queue

 from lerobot.cameras import opencv # noqa: F401
 from lerobot.configs import parser
-from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
+from lerobot.policies import make_policy, make_pre_post_processors
 from lerobot.processor import TransitionKey
 from lerobot.robots import so_follower # noqa: F401
 from lerobot.teleoperators import gamepad, so_leader # noqa: F401
@@ -80,6 +80,9 @@ from lerobot.utils.utils import (
     init_logging,
 )

+from .algorithms.base import RLAlgorithm
+from .algorithms.factory import make_algorithm
+
 if TYPE_CHECKING or _grpc_available:
     import grpc
@@ -277,6 +280,9 @@ def act_with_policy(
     policy = policy.to(device).eval()
     assert isinstance(policy, nn.Module)

+    # Build the algorithm
+    algorithm = make_algorithm(cfg=cfg.algorithm, policy=policy)
+
     preprocessor, postprocessor = make_pre_post_processors(
         policy_cfg=cfg.policy,
         dataset_stats=cfg.policy.dataset_stats,
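make_algorithm() is called here but not defined in this diff. A registry-style factory is one plausible shape; everything below (the registry, the cfg.type field, the stand-in classes) is an assumption for illustration, not the lerobot implementation:

# Hypothetical sketch of a registry-style make_algorithm factory; not the
# actual lerobot code. ``cfg.type`` is an assumed config field.
import torch.nn as nn


class RLAlgorithm:  # stand-in for .algorithms.base.RLAlgorithm
    def __init__(self, policy: nn.Module, config) -> None:
        self.policy = policy
        self.config = config


class SACAlgorithm(RLAlgorithm):  # stand-in for the real SACAlgorithm
    pass


_ALGORITHMS: dict[str, type[RLAlgorithm]] = {"sac": SACAlgorithm}


def make_algorithm(cfg, policy: nn.Module) -> RLAlgorithm:
    """Instantiate the algorithm named by ``cfg.type`` around ``policy``."""
    try:
        algo_cls = _ALGORITHMS[cfg.type]
    except KeyError as err:
        raise ValueError(f"Unknown algorithm type: {cfg.type!r}") from err
    return algo_cls(policy=policy, config=cfg)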
@@ -380,7 +386,7 @@ def act_with_policy(
         if done or truncated:
             logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")

-            update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device)
+            update_policy_parameters(algorithm=algorithm, parameters_queue=parameters_queue, device=device)

             if len(list_transition_to_send_to_learner) > 0:
                 push_transitions_to_transport_queue(
@@ -675,7 +681,8 @@ def interactions_stream(
 # Policy functions


-def update_policy_parameters(policy: PreTrainedPolicy, parameters_queue: Queue, device):
+def update_policy_parameters(algorithm: RLAlgorithm, parameters_queue: Queue, device):
+    """Drain the latest learner-pushed weights into ``algorithm.policy``."""
     bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False)
     if bytes_state_dict is not None:
         logging.info("[ACTOR] Load new parameters from Learner.")
@@ -690,7 +697,7 @@ def update_policy_parameters(policy: PreTrainedPolicy, parameters_queue: Queue,
         # - Send critic's encoder state when shared_encoder=True
         # - Skip encoder params entirely when freeze_vision_encoder=True
         # - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic)
-        policy.load_actor_weights(state_dicts, device=device)
+        algorithm.load_weights(state_dicts, device=device)


 # Utilities functions
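Putting the actor-side flow together: drain the queue down to the freshest payload, decode it, and hand it to the algorithm. A condensed sketch; drain_latest and deserialize_weights are placeholder names (the diff's get_last_item_from_queue and its actual decoding step are not shown here), and the stdlib queue stands in for torch.multiprocessing's:

# Illustrative sketch of the actor-side sync step; helper names are
# placeholders, not lerobot APIs.
import io
import queue

import torch


def drain_latest(q: queue.Queue):
    """Return the newest item in ``q``, discarding older ones, or None."""
    item = None
    while True:
        try:
            item = q.get_nowait()
        except queue.Empty:
            return item


def deserialize_weights(raw: bytes) -> dict:
    """Decode a learner-pushed payload back into nested state dicts."""
    return torch.load(io.BytesIO(raw), map_location="cpu")


def sync(algorithm, parameters_queue: queue.Queue, device: str = "cpu") -> None:
    raw = drain_latest(parameters_queue)
    if raw is not None:
        algorithm.load_weights(deserialize_weights(raw), device=device)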
@@ -511,7 +511,11 @@ class SACAlgorithm(RLAlgorithm):

     def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
         """Load actor + discrete-critic weights into the policy."""
-        self.policy.load_actor_weights(weights, device=device)
+        actor_sd = move_state_dict_to_device(weights["policy"], device=device)
+        self.policy.actor.load_state_dict(actor_sd)
+        if "discrete_critic" in weights and self.policy.discrete_critic is not None:
+            discrete_sd = move_state_dict_to_device(weights["discrete_critic"], device=device)
+            self.policy.discrete_critic.load_state_dict(discrete_sd)

     def state_dict(self) -> dict[str, torch.Tensor]:
         """Algorithm-owned trainable tensors.
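The learner-side counterpart, get_weights(), is exercised by the test below but not shown in this diff. A plausible exporter matching the {"policy": ..., "discrete_critic": ...} contract, written here as a free function over the policy (inside SACAlgorithm it would read self.policy); the body is an assumption:

# Hypothetical sketch of the exporter; only the payload keys and the
# optional discrete critic are grounded in this diff.
import torch


def get_weights(policy) -> dict[str, dict[str, torch.Tensor]]:
    """Export the actor-facing weights, keyed by wire-format name."""
    weights = {
        "policy": {k: v.detach().cpu() for k, v in policy.actor.state_dict().items()}
    }
    if policy.discrete_critic is not None:
        weights["discrete_critic"] = {
            k: v.detach().cpu()
            for k, v in policy.discrete_critic.state_dict().items()
        }
    return weights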
@@ -445,35 +445,39 @@ def test_load_weights_ignores_missing_discrete_critic():


 def test_actor_side_weight_sync_with_discrete_critic():
-    """End-to-end: learner ``algorithm.get_weights()`` -> actor ``policy.load_actor_weights()``."""
-    # Learner side: train the algorithm so its weights diverge from init.
+    """End-to-end: learner ``algorithm.get_weights()`` -> actor ``algorithm.load_weights()``."""
+    # Learner side: train the source algorithm so its weights diverge from init.
     algo_src, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
     algo_src.update(_batch_iterator(action_dim=7))
     weights = algo_src.get_weights()

-    # Actor side: fresh policy, no algorithm/optimizer.
+    # Actor side: fresh policy + fresh algorithm holding it.
     sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6)
     policy_actor = GaussianActorPolicy(config=sac_cfg)
+    algo_actor = SACAlgorithm(
+        policy=policy_actor,
+        config=SACAlgorithmConfig.from_policy_config(sac_cfg),
+    )

     # Snapshot the initial discrete-critic state for the "did it change?" assertion below.
     initial_discrete_critic_state_dict = {
         k: v.clone() for k, v in policy_actor.discrete_critic.state_dict().items()
     }

-    policy_actor.load_actor_weights(weights, device="cpu")
+    algo_actor.load_weights(weights, device="cpu")

     # Actor weights match the learner's exported actor state dict.
     actor_state_dict = policy_actor.actor.state_dict()
     for key, tensor in weights["policy"].items():
         assert torch.equal(actor_state_dict[key].cpu(), tensor.cpu()), (
-            f"Actor param '{key}' not synced by load_actor_weights"
+            f"Actor param '{key}' not synced by algorithm.load_weights"
         )

     # Discrete critic weights match the learner's exported discrete critic.
     discrete_critic_state_dict = policy_actor.discrete_critic.state_dict()
     for key, tensor in weights["discrete_critic"].items():
         assert torch.equal(discrete_critic_state_dict[key].cpu(), tensor.cpu()), (
-            f"Discrete critic param '{key}' not synced by load_actor_weights"
+            f"Discrete critic param '{key}' not synced by algorithm.load_weights"
         )

     # Sanity: the discrete critic actually changed (otherwise the sync is trivial).
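For completeness, the learner-side push that pairs with the actor-side drain sketched earlier might look like this; serializing the nested dict with torch.save over the parameters queue is an assumption, not the lerobot wire protocol:

# Hypothetical learner-side push; pairs with the actor-side drain sketch above.
import io

import torch


def push_weights(algorithm, parameters_queue) -> None:
    """Serialize the algorithm's wire-format weights and enqueue them."""
    buffer = io.BytesIO()
    torch.save(algorithm.get_weights(), buffer)
    parameters_queue.put(buffer.getvalue())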