Compare commits

...

2 Commits

Author SHA1 Message Date
Khalil Meftah b2f3fd746f refactor(rl): move actor weight-sync wire format from policy to algorithm 2026-05-08 21:57:21 +02:00
Khalil Meftah 0944b84279 feat(rl): consolidate HIL-SERL checkpoint into HF-style components
Make  and  s, add abstract
 /  for algorithm-owned tensors (critics,
target nets, ), and persist them as a sibling
component next to . Replace the pickled
 side-file with an enriched
carrying both  and , so resume restores actor +
critics + target nets + temperature + optimizers + RNG + counters from
plain HF-standard files.
2026-05-08 21:24:23 +02:00
10 changed files with 408 additions and 47 deletions
@@ -17,7 +17,6 @@
from collections.abc import Callable from collections.abc import Callable
from dataclasses import asdict from dataclasses import asdict
from typing import Any
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -25,7 +24,6 @@ from torch import Tensor
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE
from lerobot.utils.transition import move_state_dict_to_device
from ..pretrained import PreTrainedPolicy from ..pretrained import PreTrainedPolicy
from ..utils import get_device_from_parameters from ..utils import get_device_from_parameters
@@ -113,16 +111,6 @@ class GaussianActorPolicy(
actions, log_probs, means = self.actor(observations, observation_features) actions, log_probs, means = self.actor(observations, observation_features)
return {"action": actions, "log_prob": log_probs, "action_mean": means} return {"action": actions, "log_prob": log_probs, "action_mean": means}
def load_actor_weights(self, state_dicts: dict[str, Any], device: str | torch.device = "cpu") -> None:
actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device)
self.actor.load_state_dict(actor_state_dict)
if "discrete_critic" in state_dicts and self.discrete_critic is not None:
discrete_critic_state_dict = move_state_dict_to_device(
state_dicts["discrete_critic"], device=device
)
self.discrete_critic.load_state_dict(discrete_critic_state_dict)
def _init_encoders(self): def _init_encoders(self):
"""Initialize shared or separate encoders for actor and critic.""" """Initialize shared or separate encoders for actor and critic."""
self.shared_encoder = self.config.shared_encoder self.shared_encoder = self.config.shared_encoder
+11 -4
View File
@@ -61,7 +61,7 @@ from torch.multiprocessing import Queue
from lerobot.cameras import opencv # noqa: F401 from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors from lerobot.policies import make_policy, make_pre_post_processors
from lerobot.processor import TransitionKey from lerobot.processor import TransitionKey
from lerobot.robots import so_follower # noqa: F401 from lerobot.robots import so_follower # noqa: F401
from lerobot.teleoperators import gamepad, so_leader # noqa: F401 from lerobot.teleoperators import gamepad, so_leader # noqa: F401
@@ -80,6 +80,9 @@ from lerobot.utils.utils import (
init_logging, init_logging,
) )
from .algorithms.base import RLAlgorithm
from .algorithms.factory import make_algorithm
if TYPE_CHECKING or _grpc_available: if TYPE_CHECKING or _grpc_available:
import grpc import grpc
@@ -277,6 +280,9 @@ def act_with_policy(
policy = policy.to(device).eval() policy = policy.to(device).eval()
assert isinstance(policy, nn.Module) assert isinstance(policy, nn.Module)
# Build the algorithm
algorithm = make_algorithm(cfg=cfg.algorithm, policy=policy)
preprocessor, postprocessor = make_pre_post_processors( preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy, policy_cfg=cfg.policy,
dataset_stats=cfg.policy.dataset_stats, dataset_stats=cfg.policy.dataset_stats,
@@ -380,7 +386,7 @@ def act_with_policy(
if done or truncated: if done or truncated:
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}") logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device) update_policy_parameters(algorithm=algorithm, parameters_queue=parameters_queue, device=device)
if len(list_transition_to_send_to_learner) > 0: if len(list_transition_to_send_to_learner) > 0:
push_transitions_to_transport_queue( push_transitions_to_transport_queue(
@@ -675,7 +681,8 @@ def interactions_stream(
# Policy functions # Policy functions
def update_policy_parameters(policy: PreTrainedPolicy, parameters_queue: Queue, device): def update_policy_parameters(algorithm: RLAlgorithm, parameters_queue: Queue, device):
"""Drain the latest learner-pushed weights into ``algorithm.policy``."""
bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False) bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False)
if bytes_state_dict is not None: if bytes_state_dict is not None:
logging.info("[ACTOR] Load new parameters from Learner.") logging.info("[ACTOR] Load new parameters from Learner.")
@@ -690,7 +697,7 @@ def update_policy_parameters(policy: PreTrainedPolicy, parameters_queue: Queue,
# - Send critic's encoder state when shared_encoder=True # - Send critic's encoder state when shared_encoder=True
# - Skip encoder params entirely when freeze_vision_encoder=True # - Skip encoder params entirely when freeze_vision_encoder=True
# - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic) # - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic)
policy.load_actor_weights(state_dicts, device=device) algorithm.load_weights(state_dicts, device=device)
# Utilities functions # Utilities functions
+109 -2
View File
@@ -15,25 +15,38 @@
from __future__ import annotations from __future__ import annotations
import abc import abc
import builtins
import os
from collections.abc import Iterator from collections.abc import Iterator
from typing import TYPE_CHECKING, Any from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar
import torch import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_file as load_safetensors, save_file as save_safetensors
from torch.optim import Optimizer from torch.optim import Optimizer
from lerobot.types import BatchType from lerobot.types import BatchType
from lerobot.utils.hub import HubMixin
from .configs import RLAlgorithmConfig, TrainingStats from .configs import RLAlgorithmConfig, TrainingStats
if TYPE_CHECKING: if TYPE_CHECKING:
from torch import nn
from ..data_sources.data_mixer import DataMixer from ..data_sources.data_mixer import DataMixer
T = TypeVar("T", bound="RLAlgorithm")
class RLAlgorithm(abc.ABC):
class RLAlgorithm(HubMixin, abc.ABC):
"""Base for all RL algorithms.""" """Base for all RL algorithms."""
config_class: type[RLAlgorithmConfig] config_class: type[RLAlgorithmConfig]
name: str name: str
config: RLAlgorithmConfig
@abc.abstractmethod @abc.abstractmethod
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats: def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
@@ -98,3 +111,97 @@ class RLAlgorithm(abc.ABC):
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None: def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
"""Load policy state-dict received from the learner.""" """Load policy state-dict received from the learner."""
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod
def state_dict(self) -> dict[str, torch.Tensor]:
"""Algorithm-owned trainable tensors.
Must return a flat tensor mapping for everything the algorithm owns
that is not part of the policy (e.g. critic ensembles, target networks,
temperature parameters). Algorithms with no training-only tensors
should explicitly return an empty dict.
"""
raise NotImplementedError
@abc.abstractmethod
def load_state_dict(
self,
state_dict: dict[str, torch.Tensor],
device: str | torch.device = "cpu",
) -> None:
"""In-place load of algorithm-owned tensors.
Implementations MUST keep the identity of any ``nn.Parameter`` that an
optimizer references (e.g. SAC's ``log_alpha``) by using ``.copy_()``
rather than rebinding the attribute.
"""
raise NotImplementedError
def _save_pretrained(self, save_directory: Path) -> None:
"""Persist the algorithm's tensors and config to ``save_directory``.
Writes ``model.safetensors`` (algorithm tensors via :meth:`state_dict`)
and ``config.json`` (via :meth:`RLAlgorithmConfig.save_pretrained`).
"""
tensors = {k: v.detach().cpu().contiguous() for k, v in self.state_dict().items()}
save_safetensors(tensors, str(save_directory / SAFETENSORS_SINGLE_FILE))
self.config._save_pretrained(save_directory)
@classmethod
def from_pretrained(
cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
policy: nn.Module,
config: RLAlgorithmConfig | None = None,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
device: str | torch.device = "cpu",
**algo_kwargs: Any,
) -> T:
"""Build an algorithm and load its weights from ``pretrained_name_or_path``."""
if config is None:
config = cls.config_class.from_pretrained(
pretrained_name_or_path,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
)
if hasattr(config, "policy_config"):
config.policy_config = policy.config
instance = cls(policy=policy, config=config, **algo_kwargs)
model_id = str(pretrained_name_or_path)
if os.path.isdir(model_id):
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
else:
try:
model_file = hf_hub_download(
repo_id=model_id,
filename=SAFETENSORS_SINGLE_FILE,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
) from e
tensors = load_safetensors(model_file)
instance.load_state_dict(tensors, device=device)
return instance
+76 -2
View File
@@ -15,10 +15,23 @@
from __future__ import annotations from __future__ import annotations
import abc import abc
import builtins
import logging
import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from pathlib import Path
from typing import Any, TypeVar
import draccus import draccus
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.errors import HfHubHTTPError
from lerobot.utils.hub import HubMixin
T = TypeVar("T", bound="RLAlgorithmConfig")
logger = logging.getLogger(__name__)
@dataclass @dataclass
@@ -43,7 +56,7 @@ class TrainingStats:
@dataclass @dataclass
class RLAlgorithmConfig(draccus.ChoiceRegistry, abc.ABC): class RLAlgorithmConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
"""Registry for algorithm configs.""" """Registry for algorithm configs."""
@property @property
@@ -62,3 +75,64 @@ class RLAlgorithmConfig(draccus.ChoiceRegistry, abc.ABC):
Must be overridden by every registered config subclass. Must be overridden by every registered config subclass.
""" """
raise NotImplementedError(f"{cls.__name__} must implement from_policy_config()") raise NotImplementedError(f"{cls.__name__} must implement from_policy_config()")
def _save_pretrained(self, save_directory: Path) -> None:
"""Serialize this config as ``config.json`` inside ``save_directory``."""
with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
draccus.dump(self, f, indent=4)
@classmethod
def from_pretrained(
cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict[Any, Any] | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
**algo_kwargs: Any,
) -> T:
model_id = str(pretrained_name_or_path)
config_file: str | None = None
if Path(model_id).is_dir():
if CONFIG_NAME in os.listdir(model_id):
config_file = os.path.join(model_id, CONFIG_NAME)
else:
logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
else:
try:
config_file = hf_hub_download(
repo_id=model_id,
filename=CONFIG_NAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
) from e
if config_file is None:
raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
with draccus.config_type("json"):
instance = draccus.parse(RLAlgorithmConfig, config_file, args=[])
if cls is not RLAlgorithmConfig and not isinstance(instance, cls):
raise TypeError(
f"Config at {model_id} has type '{instance.type}' but was loaded via "
f"{cls.__name__}; use the matching subclass or RLAlgorithmConfig.from_pretrained()."
)
for key, value in algo_kwargs.items():
if hasattr(instance, key):
setattr(instance, key, value)
return instance
@@ -16,6 +16,7 @@ from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import ( from lerobot.policies.gaussian_actor.configuration_gaussian_actor import (
CriticNetworkConfig, CriticNetworkConfig,
GaussianActorConfig, GaussianActorConfig,
@@ -87,7 +88,7 @@ class SACAlgorithmConfig(RLAlgorithmConfig):
use_torch_compile: bool = False use_torch_compile: bool = False
# Policy config # Policy config
policy_config: GaussianActorConfig | None = None policy_config: PreTrainedConfig | None = None
@classmethod @classmethod
def from_policy_config(cls, policy_cfg: GaussianActorConfig) -> SACAlgorithmConfig: def from_policy_config(cls, policy_cfg: GaussianActorConfig) -> SACAlgorithmConfig:
+55 -1
View File
@@ -511,7 +511,51 @@ class SACAlgorithm(RLAlgorithm):
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None: def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
"""Load actor + discrete-critic weights into the policy.""" """Load actor + discrete-critic weights into the policy."""
self.policy.load_actor_weights(weights, device=device) actor_sd = move_state_dict_to_device(weights["policy"], device=device)
self.policy.actor.load_state_dict(actor_sd)
if "discrete_critic" in weights and self.policy.discrete_critic is not None:
discrete_sd = move_state_dict_to_device(weights["discrete_critic"], device=device)
self.policy.discrete_critic.load_state_dict(discrete_sd)
def state_dict(self) -> dict[str, torch.Tensor]:
"""Algorithm-owned trainable tensors.
Encoder weights are stripped because they are owned by the policy
(``policy.encoder_critic``) and already saved via ``policy.save_pretrained``.
"""
bundle: dict[str, torch.Tensor] = {}
for k, v in _strip_encoder_keys(self.critic_ensemble.state_dict()).items():
bundle[f"critic_ensemble.{k}"] = v
for k, v in _strip_encoder_keys(self.critic_target.state_dict()).items():
bundle[f"critic_target.{k}"] = v
if self.discrete_critic_target is not None:
for k, v in _strip_encoder_keys(self.discrete_critic_target.state_dict()).items():
bundle[f"discrete_critic_target.{k}"] = v
bundle["log_alpha"] = self.log_alpha.detach()
return bundle
def load_state_dict(
self,
state_dict: dict[str, torch.Tensor],
device: str | torch.device = "cpu",
) -> None:
"""In-place load of algorithm-owned tensors.
``log_alpha`` is restored via ``Parameter.data.copy_`` so the
``temperature`` optimizer's reference to the parameter object stays
valid after resume.
"""
critic_ensemble_state = _split_prefix(state_dict, "critic_ensemble.")
critic_target_state = _split_prefix(state_dict, "critic_target.")
self.critic_ensemble.load_state_dict(critic_ensemble_state, strict=False)
self.critic_target.load_state_dict(critic_target_state, strict=False)
if self.discrete_critic_target is not None:
discrete_target_state = _split_prefix(state_dict, "discrete_critic_target.")
self.discrete_critic_target.load_state_dict(discrete_target_state, strict=False)
if "log_alpha" in state_dict:
self.log_alpha.data.copy_(state_dict["log_alpha"].to(self.log_alpha.device))
def get_observation_features( def get_observation_features(
self, observations: Tensor, next_observations: Tensor self, observations: Tensor, next_observations: Tensor
@@ -540,6 +584,16 @@ class SACAlgorithm(RLAlgorithm):
return observation_features, next_observation_features return observation_features, next_observation_features
def _strip_encoder_keys(state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""Drop ``encoder.*`` keys from a critic-module state dict."""
return {k: v for k, v in state.items() if not k.startswith("encoder.")}
def _split_prefix(state: dict[str, torch.Tensor], prefix: str) -> dict[str, torch.Tensor]:
"""Return the subset of ``state`` whose keys start with ``prefix``, prefix-stripped."""
return {k.removeprefix(prefix): v for k, v in state.items() if k.startswith(prefix)}
class CriticHead(nn.Module): class CriticHead(nn.Module):
def __init__( def __init__(
self, self,
+53 -19
View File
@@ -56,6 +56,8 @@ from pprint import pformat
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
import torch import torch
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from safetensors.torch import load_file as load_safetensors
from termcolor import colored from termcolor import colored
from torch import nn from torch import nn
from torch.multiprocessing import Queue from torch.multiprocessing import Queue
@@ -77,13 +79,16 @@ from lerobot.teleoperators import gamepad, so_leader # noqa: F401
from lerobot.teleoperators.utils import TeleopEvents from lerobot.teleoperators.utils import TeleopEvents
from lerobot.utils.constants import ( from lerobot.utils.constants import (
ACTION, ACTION,
ALGORITHM_DIR,
CHECKPOINTS_DIR, CHECKPOINTS_DIR,
LAST_CHECKPOINT_LINK, LAST_CHECKPOINT_LINK,
PRETRAINED_MODEL_DIR, PRETRAINED_MODEL_DIR,
TRAINING_STATE_DIR, TRAINING_STATE_DIR,
TRAINING_STEP,
) )
from lerobot.utils.device_utils import get_safe_torch_device from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.import_utils import _grpc_available, require_package from lerobot.utils.import_utils import _grpc_available, require_package
from lerobot.utils.io_utils import load_json, write_json
from lerobot.utils.process import ProcessSignalHandler from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.random_utils import set_seed from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import ( from lerobot.utils.utils import (
@@ -370,7 +375,9 @@ def add_actor_information_and_train(
# If we are resuming, we need to load the training state # If we are resuming, we need to load the training state
optimizers = algorithm.get_optimizers() optimizers = algorithm.get_optimizers()
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers) resume_optimization_step, resume_interaction_step = load_training_state(
cfg=cfg, optimizers=optimizers, algorithm=algorithm, device=device
)
logging.info("Starting learner thread") logging.info("Starting learner thread")
interaction_message = None interaction_message = None
@@ -464,6 +471,7 @@ def add_actor_information_and_train(
policy=policy, policy=policy,
optimizers=optimizers, optimizers=optimizers,
replay_buffer=replay_buffer, replay_buffer=replay_buffer,
algorithm=algorithm,
offline_replay_buffer=offline_replay_buffer, offline_replay_buffer=offline_replay_buffer,
dataset_repo_id=dataset_repo_id, dataset_repo_id=dataset_repo_id,
fps=fps, fps=fps,
@@ -550,6 +558,7 @@ def save_training_checkpoint(
policy: nn.Module, policy: nn.Module,
optimizers: dict[str, Optimizer], optimizers: dict[str, Optimizer],
replay_buffer: ReplayBuffer, replay_buffer: ReplayBuffer,
algorithm: RLAlgorithm | None = None,
offline_replay_buffer: ReplayBuffer | None = None, offline_replay_buffer: ReplayBuffer | None = None,
dataset_repo_id: str | None = None, dataset_repo_id: str | None = None,
fps: int = 30, fps: int = 30,
@@ -588,7 +597,7 @@ def save_training_checkpoint(
# Create checkpoint directory # Create checkpoint directory
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step) checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)
# Save checkpoint # Save policy artifacts (pretrained_model/) + Trainer scaffolding (training_state/).
save_checkpoint( save_checkpoint(
checkpoint_dir=checkpoint_dir, checkpoint_dir=checkpoint_dir,
step=optimization_step, step=optimization_step,
@@ -600,11 +609,18 @@ def save_training_checkpoint(
postprocessor=postprocessor, postprocessor=postprocessor,
) )
# Save interaction step manually # Algorithm-owned tensors live in their own component subfolder
training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR) # so they can be `push_to_hub`'d independently and don't bloat the inference artifact.
os.makedirs(training_state_dir, exist_ok=True) if algorithm is not None:
training_state = {"step": optimization_step, "interaction_step": interaction_step} algorithm.save_pretrained(checkpoint_dir / ALGORITHM_DIR)
torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))
# Enrich training_step.json with the RL-specific interaction_step counter so
# both can be restored from a single file.
training_state_dir = checkpoint_dir / TRAINING_STATE_DIR
write_json(
{"step": optimization_step, "interaction_step": interaction_step},
training_state_dir / TRAINING_STEP,
)
# Update the "last" symlink # Update the "last" symlink
update_last_checkpoint(checkpoint_dir) update_last_checkpoint(checkpoint_dir)
@@ -698,13 +714,20 @@ def handle_resume_logic(cfg: TrainRLServerPipelineConfig) -> TrainRLServerPipeli
def load_training_state( def load_training_state(
cfg: TrainRLServerPipelineConfig, cfg: TrainRLServerPipelineConfig,
optimizers: Optimizer | dict[str, Optimizer], optimizers: Optimizer | dict[str, Optimizer],
algorithm: RLAlgorithm | None = None,
device: str | torch.device = "cpu",
): ):
""" """
Loads the training state (optimizers, step count, etc.) from a checkpoint. Loads the training state (optimizers, RNG, step + interaction step, and
algorithm-owned tensors) from the most recent checkpoint.
Args: Args:
cfg (TrainRLServerPipelineConfig): Training configuration cfg: Training configuration.
optimizers (Optimizer | dict): Optimizers to load state into optimizers: Optimizers to load state into.
algorithm: Algorithm whose state dict should be restored.
Required for full main-equivalent resume;
the policy itself is restored separately via ``make_policy``.
device: Device on which to place loaded algorithm tensors.
Returns: Returns:
tuple: (optimization_step, interaction_step) or (None, None) if not resuming tuple: (optimization_step, interaction_step) or (None, None) if not resuming
@@ -713,20 +736,31 @@ def load_training_state(
return None, None return None, None
# Construct path to the last checkpoint directory # Construct path to the last checkpoint directory
checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) checkpoint_dir = Path(cfg.output_dir) / CHECKPOINTS_DIR / LAST_CHECKPOINT_LINK
logging.info(f"Loading training state from {checkpoint_dir}") logging.info(f"Loading training state from {checkpoint_dir}")
try: try:
# Use the utility function from train_utils which loads the optimizer state # Restore optimizers + RNG + step from the standard `training_state/` folder
step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None) step, optimizers, _ = utils_load_training_state(checkpoint_dir, optimizers, None)
# Load interaction step separately from training_state.pt # Restore algorithm-owned tensors
training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt") if algorithm is not None:
interaction_step = 0 algo_dir = checkpoint_dir / ALGORITHM_DIR
if os.path.exists(training_state_path): if algo_dir.is_dir():
training_state = torch.load(training_state_path, weights_only=False) # nosec B614: Safe usage of torch.load tensors = load_safetensors(str(algo_dir / SAFETENSORS_SINGLE_FILE))
interaction_step = training_state.get("interaction_step", 0) algorithm.load_state_dict(tensors, device=device)
logging.info(f"Loaded algorithm state from {algo_dir}")
else:
logging.warning(
f"No algorithm state found at {algo_dir}; "
"will keep their freshly-initialised values. Adam moments restored from the "
"old optimizer state may not match these reset parameters."
)
# Read interaction_step from the enriched training_step.json
training_step_path = checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP
interaction_step = int(load_json(training_step_path).get("interaction_step", 0))
logging.info(f"Resuming from step {step}, interaction step {interaction_step}") logging.info(f"Resuming from step {step}, interaction step {interaction_step}")
return step, interaction_step return step, interaction_step
+1
View File
@@ -47,6 +47,7 @@ CHECKPOINTS_DIR = "checkpoints"
LAST_CHECKPOINT_LINK = "last" LAST_CHECKPOINT_LINK = "last"
PRETRAINED_MODEL_DIR = "pretrained_model" PRETRAINED_MODEL_DIR = "pretrained_model"
TRAINING_STATE_DIR = "training_state" TRAINING_STATE_DIR = "training_state"
ALGORITHM_DIR = "algorithm"
RNG_STATE = "rng_state.safetensors" RNG_STATE = "rng_state.safetensors"
TRAINING_STEP = "training_step.json" TRAINING_STEP = "training_step.json"
OPTIMIZER_STATE = "optimizer_state.safetensors" OPTIMIZER_STATE = "optimizer_state.safetensors"
+95 -6
View File
@@ -445,35 +445,39 @@ def test_load_weights_ignores_missing_discrete_critic():
def test_actor_side_weight_sync_with_discrete_critic(): def test_actor_side_weight_sync_with_discrete_critic():
"""End-to-end: learner ``algorithm.get_weights()`` -> actor ``policy.load_actor_weights()``.""" """End-to-end: learner ``algorithm.get_weights()`` -> actor ``algorithm.load_weights()``."""
# Learner side: train the algorithm so its weights diverge from init. # Learner side: train the source algorithm so its weights diverge from init.
algo_src, _ = _make_algorithm(num_discrete_actions=3, action_dim=6) algo_src, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
algo_src.update(_batch_iterator(action_dim=7)) algo_src.update(_batch_iterator(action_dim=7))
weights = algo_src.get_weights() weights = algo_src.get_weights()
# Actor side: fresh policy, no algorithm/optimizer. # Actor side: fresh policy + fresh algorithm holding it.
sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6) sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6)
policy_actor = GaussianActorPolicy(config=sac_cfg) policy_actor = GaussianActorPolicy(config=sac_cfg)
algo_actor = SACAlgorithm(
policy=policy_actor,
config=SACAlgorithmConfig.from_policy_config(sac_cfg),
)
# Snapshot initial actor state for the "did it change?" assertion below. # Snapshot initial actor state for the "did it change?" assertion below.
initial_discrete_critic_state_dict = { initial_discrete_critic_state_dict = {
k: v.clone() for k, v in policy_actor.discrete_critic.state_dict().items() k: v.clone() for k, v in policy_actor.discrete_critic.state_dict().items()
} }
policy_actor.load_actor_weights(weights, device="cpu") algo_actor.load_weights(weights, device="cpu")
# Actor weights match the learner's exported actor state dict. # Actor weights match the learner's exported actor state dict.
actor_state_dict = policy_actor.actor.state_dict() actor_state_dict = policy_actor.actor.state_dict()
for key, tensor in weights["policy"].items(): for key, tensor in weights["policy"].items():
assert torch.equal(actor_state_dict[key].cpu(), tensor.cpu()), ( assert torch.equal(actor_state_dict[key].cpu(), tensor.cpu()), (
f"Actor param '{key}' not synced by load_actor_weights" f"Actor param '{key}' not synced by algorithm.load_weights"
) )
# Discrete critic weights match the learner's exported discrete critic. # Discrete critic weights match the learner's exported discrete critic.
discrete_critic_state_dict = policy_actor.discrete_critic.state_dict() discrete_critic_state_dict = policy_actor.discrete_critic.state_dict()
for key, tensor in weights["discrete_critic"].items(): for key, tensor in weights["discrete_critic"].items():
assert torch.equal(discrete_critic_state_dict[key].cpu(), tensor.cpu()), ( assert torch.equal(discrete_critic_state_dict[key].cpu(), tensor.cpu()), (
f"Discrete critic param '{key}' not synced by load_actor_weights" f"Discrete critic param '{key}' not synced by algorithm.load_weights"
) )
# Sanity: the discrete critic actually changed (otherwise the sync is trivial). # Sanity: the discrete critic actually changed (otherwise the sync is trivial).
@@ -515,3 +519,88 @@ def test_make_algorithm_builds_sac():
algorithm = make_algorithm(cfg=algo_config, policy=policy) algorithm = make_algorithm(cfg=algo_config, policy=policy)
assert isinstance(algorithm, SACAlgorithm) assert isinstance(algorithm, SACAlgorithm)
assert algorithm.config.utd_ratio == 2 assert algorithm.config.utd_ratio == 2
# ===========================================================================
# state_dict / load_state_dict (algorithm-side resume)
# ===========================================================================
def test_state_dict_contains_algorithm_owned_tensors():
"""state_dict should pack critics, target networks, and log_alpha (no encoder bloat)."""
algorithm, _ = _make_algorithm()
sd = algorithm.state_dict()
assert "log_alpha" in sd
assert any(k.startswith("critic_ensemble.") for k in sd)
assert any(k.startswith("critic_target.") for k in sd)
# encoder weights live on the policy and must not be duplicated here.
assert not any(".encoder." in k for k in sd)
def test_state_dict_includes_discrete_critic_target_when_present():
algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
sd = algorithm.state_dict()
assert any(k.startswith("discrete_critic_target.") for k in sd)
def test_load_state_dict_round_trip_restores_critics_and_log_alpha():
"""state_dict -> load_state_dict on a fresh algorithm restores all bytes exactly."""
sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6)
src_policy = GaussianActorPolicy(config=sac_cfg)
src = SACAlgorithm(policy=src_policy, config=SACAlgorithmConfig.from_policy_config(sac_cfg))
src.make_optimizers_and_scheduler()
# Train a few steps so weights diverge from init (action_dim=7 = 6 continuous + 1 discrete).
src.update(_batch_iterator(action_dim=7))
src.update(_batch_iterator(action_dim=7))
dst_policy = GaussianActorPolicy(config=sac_cfg)
dst = SACAlgorithm(policy=dst_policy, config=SACAlgorithmConfig.from_policy_config(sac_cfg))
dst.make_optimizers_and_scheduler()
src_sd = src.state_dict()
dst.load_state_dict(src_sd)
dst_sd = dst.state_dict()
assert set(dst_sd) == set(src_sd)
for key in src_sd:
assert torch.allclose(src_sd[key].cpu(), dst_sd[key].cpu()), f"{key} mismatch after round-trip"
def test_load_state_dict_preserves_log_alpha_parameter_identity():
"""The temperature optimizer holds a reference to log_alpha; identity must survive load."""
algorithm, _ = _make_algorithm()
log_alpha_id_before = id(algorithm.log_alpha)
optimizer_param_id = id(algorithm.optimizers["temperature"].param_groups[0]["params"][0])
assert log_alpha_id_before == optimizer_param_id
new_state = algorithm.state_dict()
new_state["log_alpha"] = torch.tensor([0.42])
algorithm.load_state_dict(new_state)
assert id(algorithm.log_alpha) == log_alpha_id_before
assert id(algorithm.optimizers["temperature"].param_groups[0]["params"][0]) == log_alpha_id_before
assert torch.allclose(algorithm.log_alpha.detach().cpu(), torch.tensor([0.42]))
def test_save_pretrained_round_trip_via_disk(tmp_path):
"""End-to-end: save_pretrained -> from_pretrained restores tensors and config."""
sac_cfg = _make_sac_config()
src_policy = GaussianActorPolicy(config=sac_cfg)
src = SACAlgorithm(policy=src_policy, config=SACAlgorithmConfig.from_policy_config(sac_cfg))
src.make_optimizers_and_scheduler()
src.update(_batch_iterator())
save_dir = tmp_path / "algorithm"
src.save_pretrained(save_dir)
assert (save_dir / "model.safetensors").is_file()
assert (save_dir / "config.json").is_file()
dst_policy = GaussianActorPolicy(config=sac_cfg)
dst = SACAlgorithm.from_pretrained(save_dir, policy=dst_policy)
src_sd = src.state_dict()
dst_sd = dst.state_dict()
assert set(src_sd) == set(dst_sd)
for key in src_sd:
assert torch.allclose(src_sd[key].cpu(), dst_sd[key].cpu()), f"{key} mismatch after disk round-trip"
+6
View File
@@ -68,6 +68,12 @@ class _DummyRLAlgorithm(RLAlgorithm):
def load_weights(self, weights, device="cpu") -> None: def load_weights(self, weights, device="cpu") -> None:
_ = (weights, device) _ = (weights, device)
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state_dict, device="cpu") -> None:
_ = (state_dict, device)
class _SimpleMixer: class _SimpleMixer:
def get_iterator(self, batch_size: int, async_prefetch: bool = True, queue_size: int = 2): def get_iterator(self, batch_size: int, async_prefetch: bool = True, queue_size: int = 2):