feat(rl): consolidate HIL-SERL checkpoint into HF-style components
Make RLAlgorithm and RLAlgorithmConfig HubMixins, add abstract state_dict/load_state_dict for algorithm-owned tensors (critics, target nets, log_alpha), and persist them as a sibling algorithm/ component next to pretrained_model/. Replace the pickled training_state.pt side-file with an enriched training_step.json carrying both step and interaction_step, so resume restores actor + critics + target nets + temperature + optimizers + RNG + counters from plain HF-standard files.
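The resulting checkpoint layout, reconstructed from the constants and paths used in the diff below (the step directory name is illustrative):

    checkpoints/0020000/
    ├── pretrained_model/              # policy artifacts (unchanged)
    ├── training_state/
    │   ├── optimizer_state.safetensors
    │   ├── rng_state.safetensors
    │   └── training_step.json         # now also carries interaction_step
    └── algorithm/                     # new sibling component
        ├── model.safetensors          # critics, target nets, log_alpha
        └── config.json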
@@ -15,25 +15,38 @@
 from __future__ import annotations
 
 import abc
+import builtins
+import os
 from collections.abc import Iterator
-from typing import TYPE_CHECKING, Any
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, TypeVar
 
 import torch
+from huggingface_hub import hf_hub_download
+from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
+from huggingface_hub.errors import HfHubHTTPError
+from safetensors.torch import load_file as load_safetensors, save_file as save_safetensors
 from torch.optim import Optimizer
 
 from lerobot.types import BatchType
+from lerobot.utils.hub import HubMixin
 
 from .configs import RLAlgorithmConfig, TrainingStats
 
 if TYPE_CHECKING:
+    from torch import nn
+
     from ..data_sources.data_mixer import DataMixer
 
+T = TypeVar("T", bound="RLAlgorithm")
+
 
-class RLAlgorithm(abc.ABC):
+class RLAlgorithm(HubMixin, abc.ABC):
     """Base for all RL algorithms."""
 
+    config_class: type[RLAlgorithmConfig]
     name: str
     config: RLAlgorithmConfig
 
     @abc.abstractmethod
     def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
@@ -98,3 +111,97 @@ class RLAlgorithm(abc.ABC):
     def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
         """Load policy state-dict received from the learner."""
         raise NotImplementedError
+
+    @abc.abstractmethod
+    def state_dict(self) -> dict[str, torch.Tensor]:
+        """Algorithm-owned trainable tensors.
+
+        Must return a flat tensor mapping for everything the algorithm owns
+        that is not part of the policy (e.g. critic ensembles, target networks,
+        temperature parameters). Algorithms with no training-only tensors
+        should explicitly return an empty dict.
+        """
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def load_state_dict(
+        self,
+        state_dict: dict[str, torch.Tensor],
+        device: str | torch.device = "cpu",
+    ) -> None:
+        """In-place load of algorithm-owned tensors.
+
+        Implementations MUST keep the identity of any ``nn.Parameter`` that an
+        optimizer references (e.g. SAC's ``log_alpha``) by using ``.copy_()``
+        rather than rebinding the attribute.
+        """
+        raise NotImplementedError
+
+    def _save_pretrained(self, save_directory: Path) -> None:
+        """Persist the algorithm's tensors and config to ``save_directory``.
+
+        Writes ``model.safetensors`` (algorithm tensors via :meth:`state_dict`)
+        and ``config.json`` (via :meth:`RLAlgorithmConfig.save_pretrained`).
+        """
+        tensors = {k: v.detach().cpu().contiguous() for k, v in self.state_dict().items()}
+        save_safetensors(tensors, str(save_directory / SAFETENSORS_SINGLE_FILE))
+        self.config._save_pretrained(save_directory)
+
+    @classmethod
+    def from_pretrained(
+        cls: builtins.type[T],
+        pretrained_name_or_path: str | Path,
+        *,
+        policy: nn.Module,
+        config: RLAlgorithmConfig | None = None,
+        force_download: bool = False,
+        resume_download: bool | None = None,
+        proxies: dict | None = None,
+        token: str | bool | None = None,
+        cache_dir: str | Path | None = None,
+        local_files_only: bool = False,
+        revision: str | None = None,
+        device: str | torch.device = "cpu",
+        **algo_kwargs: Any,
+    ) -> T:
+        """Build an algorithm and load its weights from ``pretrained_name_or_path``."""
+        if config is None:
+            config = cls.config_class.from_pretrained(
+                pretrained_name_or_path,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                token=token,
+                cache_dir=cache_dir,
+                local_files_only=local_files_only,
+                revision=revision,
+            )
+        if hasattr(config, "policy_config"):
+            config.policy_config = policy.config
+
+        instance = cls(policy=policy, config=config, **algo_kwargs)
+
+        model_id = str(pretrained_name_or_path)
+        if os.path.isdir(model_id):
+            model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
+        else:
+            try:
+                model_file = hf_hub_download(
+                    repo_id=model_id,
+                    filename=SAFETENSORS_SINGLE_FILE,
+                    revision=revision,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    token=token,
+                    local_files_only=local_files_only,
+                )
+            except HfHubHTTPError as e:
+                raise FileNotFoundError(
+                    f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
+                ) from e
+
+        tensors = load_safetensors(model_file)
+        instance.load_state_dict(tensors, device=device)
+        return instance
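Taken together, these hooks give algorithms the same save/load ergonomics as policies. A minimal sketch of the round trip (the directory path is illustrative; `SACAlgorithm` and a pre-built `policy` come from later in this diff):

    # Hedged sketch: persist and restore algorithm-owned tensors.
    algo.save_pretrained("outputs/algo")  # writes model.safetensors + config.json
    restored = SACAlgorithm.from_pretrained(
        "outputs/algo",
        policy=policy,  # the policy itself is saved/restored separately
        device="cpu",
    )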
@@ -15,10 +15,23 @@
 from __future__ import annotations
 
 import abc
+import builtins
+import logging
+import os
 from dataclasses import dataclass, field
-from typing import Any
+from pathlib import Path
+from typing import Any, TypeVar
 
 import draccus
+from huggingface_hub import hf_hub_download
+from huggingface_hub.constants import CONFIG_NAME
+from huggingface_hub.errors import HfHubHTTPError
+
+from lerobot.utils.hub import HubMixin
+
+T = TypeVar("T", bound="RLAlgorithmConfig")
+
+logger = logging.getLogger(__name__)
 
 
 @dataclass
@@ -43,7 +56,7 @@ class TrainingStats:
 
 
 @dataclass
-class RLAlgorithmConfig(draccus.ChoiceRegistry, abc.ABC):
+class RLAlgorithmConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
     """Registry for algorithm configs."""
 
     @property
@@ -62,3 +75,64 @@ class RLAlgorithmConfig(draccus.ChoiceRegistry, abc.ABC):
         Must be overridden by every registered config subclass.
         """
         raise NotImplementedError(f"{cls.__name__} must implement from_policy_config()")
+
+    def _save_pretrained(self, save_directory: Path) -> None:
+        """Serialize this config as ``config.json`` inside ``save_directory``."""
+        with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
+            draccus.dump(self, f, indent=4)
+
+    @classmethod
+    def from_pretrained(
+        cls: builtins.type[T],
+        pretrained_name_or_path: str | Path,
+        *,
+        force_download: bool = False,
+        resume_download: bool | None = None,
+        proxies: dict[Any, Any] | None = None,
+        token: str | bool | None = None,
+        cache_dir: str | Path | None = None,
+        local_files_only: bool = False,
+        revision: str | None = None,
+        **algo_kwargs: Any,
+    ) -> T:
+        model_id = str(pretrained_name_or_path)
+        config_file: str | None = None
+        if Path(model_id).is_dir():
+            if CONFIG_NAME in os.listdir(model_id):
+                config_file = os.path.join(model_id, CONFIG_NAME)
+            else:
+                logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
+        else:
+            try:
+                config_file = hf_hub_download(
+                    repo_id=model_id,
+                    filename=CONFIG_NAME,
+                    revision=revision,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    token=token,
+                    local_files_only=local_files_only,
+                )
+            except HfHubHTTPError as e:
+                raise FileNotFoundError(
+                    f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
+                ) from e
+
+        if config_file is None:
+            raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
+
+        with draccus.config_type("json"):
+            instance = draccus.parse(RLAlgorithmConfig, config_file, args=[])
+
+        if cls is not RLAlgorithmConfig and not isinstance(instance, cls):
+            raise TypeError(
+                f"Config at {model_id} has type '{instance.type}' but was loaded via "
+                f"{cls.__name__}; use the matching subclass or RLAlgorithmConfig.from_pretrained()."
+            )
+
+        for key, value in algo_kwargs.items():
+            if hasattr(instance, key):
+                setattr(instance, key, value)
+        return instance
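The config side mirrors this. A short sketch (path illustrative; `.type` is the draccus ChoiceRegistry discriminator checked in the TypeError branch above):

    # Loading via the concrete subclass validates the registered type...
    cfg = SACAlgorithmConfig.from_pretrained("outputs/algo")
    # ...while the base class dispatches on the stored `type` field.
    cfg_any = RLAlgorithmConfig.from_pretrained("outputs/algo")
    assert cfg_any.type == cfg.type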
@@ -16,6 +16,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass, field
 
+from lerobot.configs.policies import PreTrainedConfig
 from lerobot.policies.gaussian_actor.configuration_gaussian_actor import (
     CriticNetworkConfig,
     GaussianActorConfig,
@@ -87,7 +88,7 @@ class SACAlgorithmConfig(RLAlgorithmConfig):
     use_torch_compile: bool = False
 
     # Policy config
-    policy_config: GaussianActorConfig | None = None
+    policy_config: PreTrainedConfig | None = None
 
     @classmethod
     def from_policy_config(cls, policy_cfg: GaussianActorConfig) -> SACAlgorithmConfig:
@@ -513,6 +513,46 @@ class SACAlgorithm(RLAlgorithm):
         """Load actor + discrete-critic weights into the policy."""
         self.policy.load_actor_weights(weights, device=device)
 
+    def state_dict(self) -> dict[str, torch.Tensor]:
+        """Algorithm-owned trainable tensors.
+
+        Encoder weights are stripped because they are owned by the policy
+        (``policy.encoder_critic``) and already saved via ``policy.save_pretrained``.
+        """
+        bundle: dict[str, torch.Tensor] = {}
+        for k, v in _strip_encoder_keys(self.critic_ensemble.state_dict()).items():
+            bundle[f"critic_ensemble.{k}"] = v
+        for k, v in _strip_encoder_keys(self.critic_target.state_dict()).items():
+            bundle[f"critic_target.{k}"] = v
+        if self.discrete_critic_target is not None:
+            for k, v in _strip_encoder_keys(self.discrete_critic_target.state_dict()).items():
+                bundle[f"discrete_critic_target.{k}"] = v
+        bundle["log_alpha"] = self.log_alpha.detach()
+        return bundle
+
+    def load_state_dict(
+        self,
+        state_dict: dict[str, torch.Tensor],
+        device: str | torch.device = "cpu",
+    ) -> None:
+        """In-place load of algorithm-owned tensors.
+
+        ``log_alpha`` is restored via ``Parameter.data.copy_`` so the
+        ``temperature`` optimizer's reference to the parameter object stays
+        valid after resume.
+        """
+        critic_ensemble_state = _split_prefix(state_dict, "critic_ensemble.")
+        critic_target_state = _split_prefix(state_dict, "critic_target.")
+        self.critic_ensemble.load_state_dict(critic_ensemble_state, strict=False)
+        self.critic_target.load_state_dict(critic_target_state, strict=False)
+
+        if self.discrete_critic_target is not None:
+            discrete_target_state = _split_prefix(state_dict, "discrete_critic_target.")
+            self.discrete_critic_target.load_state_dict(discrete_target_state, strict=False)
+
+        if "log_alpha" in state_dict:
+            self.log_alpha.data.copy_(state_dict["log_alpha"].to(self.log_alpha.device))
+
     def get_observation_features(
         self, observations: Tensor, next_observations: Tensor
     ) -> tuple[Tensor | None, Tensor | None]:
@@ -540,6 +580,16 @@ class SACAlgorithm(RLAlgorithm):
         return observation_features, next_observation_features
 
 
+def _strip_encoder_keys(state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+    """Drop ``encoder.*`` keys from a critic-module state dict."""
+    return {k: v for k, v in state.items() if not k.startswith("encoder.")}
+
+
+def _split_prefix(state: dict[str, torch.Tensor], prefix: str) -> dict[str, torch.Tensor]:
+    """Return the subset of ``state`` whose keys start with ``prefix``, prefix-stripped."""
+    return {k.removeprefix(prefix): v for k, v in state.items() if k.startswith(prefix)}
+
+
 class CriticHead(nn.Module):
     def __init__(
         self,
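Why the `.copy_()` contract matters, as a standalone PyTorch sketch independent of lerobot (rebinding the attribute would silently detach the optimizer from the live parameter):

    import torch
    from torch import nn

    log_alpha = nn.Parameter(torch.zeros(1))   # stand-in for SAC's temperature
    opt = torch.optim.Adam([log_alpha], lr=1e-3)

    # Wrong: `log_alpha = nn.Parameter(torch.tensor([0.42]))` would rebind the
    # name; the optimizer would keep stepping the old, now-orphaned parameter.

    # Right: copy into the existing storage so the optimizer's reference stays live.
    log_alpha.data.copy_(torch.tensor([0.42]))
    assert log_alpha is opt.param_groups[0]["params"][0]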
@@ -56,6 +56,8 @@ from pprint import pformat
 from typing import TYPE_CHECKING, Any
 
 import torch
+from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
+from safetensors.torch import load_file as load_safetensors
 from termcolor import colored
 from torch import nn
 from torch.multiprocessing import Queue
@@ -77,13 +79,16 @@ from lerobot.teleoperators import gamepad, so_leader # noqa: F401
 from lerobot.teleoperators.utils import TeleopEvents
 from lerobot.utils.constants import (
     ACTION,
+    ALGORITHM_DIR,
     CHECKPOINTS_DIR,
     LAST_CHECKPOINT_LINK,
     PRETRAINED_MODEL_DIR,
     TRAINING_STATE_DIR,
+    TRAINING_STEP,
 )
 from lerobot.utils.device_utils import get_safe_torch_device
 from lerobot.utils.import_utils import _grpc_available, require_package
+from lerobot.utils.io_utils import load_json, write_json
 from lerobot.utils.process import ProcessSignalHandler
 from lerobot.utils.random_utils import set_seed
 from lerobot.utils.utils import (
@@ -370,7 +375,9 @@ def add_actor_information_and_train(
 
     # If we are resuming, we need to load the training state
    optimizers = algorithm.get_optimizers()
-    resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
+    resume_optimization_step, resume_interaction_step = load_training_state(
+        cfg=cfg, optimizers=optimizers, algorithm=algorithm, device=device
+    )
 
     logging.info("Starting learner thread")
     interaction_message = None
@@ -464,6 +471,7 @@ def add_actor_information_and_train(
                 policy=policy,
                 optimizers=optimizers,
                 replay_buffer=replay_buffer,
+                algorithm=algorithm,
                 offline_replay_buffer=offline_replay_buffer,
                 dataset_repo_id=dataset_repo_id,
                 fps=fps,
@@ -550,6 +558,7 @@ def save_training_checkpoint(
     policy: nn.Module,
     optimizers: dict[str, Optimizer],
     replay_buffer: ReplayBuffer,
+    algorithm: RLAlgorithm | None = None,
     offline_replay_buffer: ReplayBuffer | None = None,
     dataset_repo_id: str | None = None,
     fps: int = 30,
@@ -588,7 +597,7 @@ def save_training_checkpoint(
     # Create checkpoint directory
     checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)
 
-    # Save checkpoint
+    # Save policy artifacts (pretrained_model/) + Trainer scaffolding (training_state/).
     save_checkpoint(
         checkpoint_dir=checkpoint_dir,
         step=optimization_step,
@@ -600,11 +609,18 @@ def save_training_checkpoint(
         postprocessor=postprocessor,
     )
 
-    # Save interaction step manually
-    training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
-    os.makedirs(training_state_dir, exist_ok=True)
-    training_state = {"step": optimization_step, "interaction_step": interaction_step}
-    torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))
+    # Algorithm-owned tensors live in their own component subfolder
+    # so they can be `push_to_hub`'d independently and don't bloat the inference artifact.
+    if algorithm is not None:
+        algorithm.save_pretrained(checkpoint_dir / ALGORITHM_DIR)
+
+    # Enrich training_step.json with the RL-specific interaction_step counter so
+    # both can be restored from a single file.
+    training_state_dir = checkpoint_dir / TRAINING_STATE_DIR
+    write_json(
+        {"step": optimization_step, "interaction_step": interaction_step},
+        training_state_dir / TRAINING_STEP,
+    )
 
     # Update the "last" symlink
     update_last_checkpoint(checkpoint_dir)
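With illustrative counter values, the enriched training_step.json written above carries both counters in one plain-JSON file:

    {"step": 20000, "interaction_step": 15000}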
@@ -698,13 +714,20 @@ def handle_resume_logic(cfg: TrainRLServerPipelineConfig) -> TrainRLServerPipelineConfig:
 def load_training_state(
     cfg: TrainRLServerPipelineConfig,
     optimizers: Optimizer | dict[str, Optimizer],
+    algorithm: RLAlgorithm | None = None,
+    device: str | torch.device = "cpu",
 ):
     """
-    Loads the training state (optimizers, step count, etc.) from a checkpoint.
+    Loads the training state (optimizers, RNG, step + interaction step, and
+    algorithm-owned tensors) from the most recent checkpoint.
 
     Args:
-        cfg (TrainRLServerPipelineConfig): Training configuration
-        optimizers (Optimizer | dict): Optimizers to load state into
+        cfg: Training configuration.
+        optimizers: Optimizers to load state into.
+        algorithm: Algorithm whose state dict should be restored.
+            Required for full main-equivalent resume;
+            the policy itself is restored separately via ``make_policy``.
+        device: Device on which to place loaded algorithm tensors.
 
     Returns:
         tuple: (optimization_step, interaction_step) or (None, None) if not resuming
@@ -713,20 +736,31 @@ def load_training_state(
         return None, None
 
     # Construct path to the last checkpoint directory
-    checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
+    checkpoint_dir = Path(cfg.output_dir) / CHECKPOINTS_DIR / LAST_CHECKPOINT_LINK
 
     logging.info(f"Loading training state from {checkpoint_dir}")
 
     try:
-        # Use the utility function from train_utils which loads the optimizer state
-        step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None)
+        # Restore optimizers + RNG + step from the standard `training_state/` folder
+        step, optimizers, _ = utils_load_training_state(checkpoint_dir, optimizers, None)
 
-        # Load interaction step separately from training_state.pt
-        training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt")
-        interaction_step = 0
-        if os.path.exists(training_state_path):
-            training_state = torch.load(training_state_path, weights_only=False)  # nosec B614: Safe usage of torch.load
-            interaction_step = training_state.get("interaction_step", 0)
+        # Restore algorithm-owned tensors
+        if algorithm is not None:
+            algo_dir = checkpoint_dir / ALGORITHM_DIR
+            if algo_dir.is_dir():
+                tensors = load_safetensors(str(algo_dir / SAFETENSORS_SINGLE_FILE))
+                algorithm.load_state_dict(tensors, device=device)
+                logging.info(f"Loaded algorithm state from {algo_dir}")
+            else:
+                logging.warning(
+                    f"No algorithm state found at {algo_dir}; algorithm-owned tensors "
+                    "will keep their freshly-initialised values. Adam moments restored from the "
+                    "old optimizer state may not match these reset parameters."
+                )
+
+        # Read interaction_step from the enriched training_step.json
+        training_step_path = checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP
+        interaction_step = int(load_json(training_step_path).get("interaction_step", 0))
 
         logging.info(f"Resuming from step {step}, interaction step {interaction_step}")
         return step, interaction_step
@@ -47,6 +47,7 @@ CHECKPOINTS_DIR = "checkpoints"
 LAST_CHECKPOINT_LINK = "last"
 PRETRAINED_MODEL_DIR = "pretrained_model"
 TRAINING_STATE_DIR = "training_state"
+ALGORITHM_DIR = "algorithm"
 RNG_STATE = "rng_state.safetensors"
 TRAINING_STEP = "training_step.json"
 OPTIMIZER_STATE = "optimizer_state.safetensors"
@@ -515,3 +515,88 @@ def test_make_algorithm_builds_sac():
     algorithm = make_algorithm(cfg=algo_config, policy=policy)
     assert isinstance(algorithm, SACAlgorithm)
     assert algorithm.config.utd_ratio == 2
+
+
+# ===========================================================================
+# state_dict / load_state_dict (algorithm-side resume)
+# ===========================================================================
+
+
+def test_state_dict_contains_algorithm_owned_tensors():
+    """state_dict should pack critics, target networks, and log_alpha (no encoder bloat)."""
+    algorithm, _ = _make_algorithm()
+    sd = algorithm.state_dict()
+
+    assert "log_alpha" in sd
+    assert any(k.startswith("critic_ensemble.") for k in sd)
+    assert any(k.startswith("critic_target.") for k in sd)
+    # Encoder weights live on the policy and must not be duplicated here.
+    assert not any(".encoder." in k for k in sd)
+
+
+def test_state_dict_includes_discrete_critic_target_when_present():
+    algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
+    sd = algorithm.state_dict()
+    assert any(k.startswith("discrete_critic_target.") for k in sd)
+
+
+def test_load_state_dict_round_trip_restores_critics_and_log_alpha():
+    """state_dict -> load_state_dict on a fresh algorithm restores all bytes exactly."""
+    sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6)
+    src_policy = GaussianActorPolicy(config=sac_cfg)
+    src = SACAlgorithm(policy=src_policy, config=SACAlgorithmConfig.from_policy_config(sac_cfg))
+    src.make_optimizers_and_scheduler()
+    # Train a few steps so weights diverge from init (action_dim=7 = 6 continuous + 1 discrete).
+    src.update(_batch_iterator(action_dim=7))
+    src.update(_batch_iterator(action_dim=7))
+
+    dst_policy = GaussianActorPolicy(config=sac_cfg)
+    dst = SACAlgorithm(policy=dst_policy, config=SACAlgorithmConfig.from_policy_config(sac_cfg))
+    dst.make_optimizers_and_scheduler()
+
+    src_sd = src.state_dict()
+    dst.load_state_dict(src_sd)
+    dst_sd = dst.state_dict()
+
+    assert set(dst_sd) == set(src_sd)
+    for key in src_sd:
+        assert torch.allclose(src_sd[key].cpu(), dst_sd[key].cpu()), f"{key} mismatch after round-trip"
+
+
+def test_load_state_dict_preserves_log_alpha_parameter_identity():
+    """The temperature optimizer holds a reference to log_alpha; identity must survive load."""
+    algorithm, _ = _make_algorithm()
+    log_alpha_id_before = id(algorithm.log_alpha)
+    optimizer_param_id = id(algorithm.optimizers["temperature"].param_groups[0]["params"][0])
+    assert log_alpha_id_before == optimizer_param_id
+
+    new_state = algorithm.state_dict()
+    new_state["log_alpha"] = torch.tensor([0.42])
+    algorithm.load_state_dict(new_state)
+
+    assert id(algorithm.log_alpha) == log_alpha_id_before
+    assert id(algorithm.optimizers["temperature"].param_groups[0]["params"][0]) == log_alpha_id_before
+    assert torch.allclose(algorithm.log_alpha.detach().cpu(), torch.tensor([0.42]))
+
+
+def test_save_pretrained_round_trip_via_disk(tmp_path):
+    """End-to-end: save_pretrained -> from_pretrained restores tensors and config."""
+    sac_cfg = _make_sac_config()
+    src_policy = GaussianActorPolicy(config=sac_cfg)
+    src = SACAlgorithm(policy=src_policy, config=SACAlgorithmConfig.from_policy_config(sac_cfg))
+    src.make_optimizers_and_scheduler()
+    src.update(_batch_iterator())
+
+    save_dir = tmp_path / "algorithm"
+    src.save_pretrained(save_dir)
+    assert (save_dir / "model.safetensors").is_file()
+    assert (save_dir / "config.json").is_file()
+
+    dst_policy = GaussianActorPolicy(config=sac_cfg)
+    dst = SACAlgorithm.from_pretrained(save_dir, policy=dst_policy)
+
+    src_sd = src.state_dict()
+    dst_sd = dst.state_dict()
+    assert set(src_sd) == set(dst_sd)
+    for key in src_sd:
+        assert torch.allclose(src_sd[key].cpu(), dst_sd[key].cpu()), f"{key} mismatch after disk round-trip"
@@ -68,6 +68,12 @@ class _DummyRLAlgorithm(RLAlgorithm):
     def load_weights(self, weights, device="cpu") -> None:
         _ = (weights, device)
 
+    def state_dict(self) -> dict[str, torch.Tensor]:
+        return {}
+
+    def load_state_dict(self, state_dict, device="cpu") -> None:
+        _ = (state_dict, device)
+
 
 class _SimpleMixer:
     def get_iterator(self, batch_size: int, async_prefetch: bool = True, queue_size: int = 2):